2000字范文,分享全网优秀范文,学习好帮手!
2000字范文 > 加载dict_Pytorch模型resume training 加载模型基础上继续训练

加载dict_Pytorch模型resume training 加载模型基础上继续训练

时间:2021-04-05 06:23:48

相关推荐

加载dict_Pytorch模型resume training 加载模型基础上继续训练

Step1:首先查看源码train.py中如何保存模型的:

checkpoint_dict = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optim_state_dict': optimizer.state_dict(), 'criterion_state_dict': train_criterion.state_dict()}torch.save(checkpoint_dict, filename)

并查看源码中如何加载参数(参数名):

#print('----------model loading-------------------') model.train()#print('----------model loaded true-------------------')#print('----------train_criterion loading-------------------') loss_tmp = train_criterion(output, target_var)#print('----------train_criterion loaded-------------------') #print('----------optimizer loading-------------------') optimizer.step()#print('----------optimizer loaded-------------------')

step2:然后根据保存的模型的写法,自己写加载模型:

start_epoch = -1if opt.resume:print('-----------------------------')path_checkpoint = opt.checkbreakpoint checkpoint = torch.load(path_checkpoint) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optim_state_dict']) train_criterion.load_state_dict(checkpoint['criterion_state_dict'])start_epoch = checkpoint['epoch'] print("start_epoch:",start_epoch)print('-----------------------------')

其中“resume”和“checkbreakpoint”都是在parse参数时候定义,直接写入命令行:

self.parser.add_argument('--checkbreakpoint',type=str, default ='epoch_005.pth.tar')self.parser.add_argument('--resume', type=bool, default=False)

Step3:最后调整训练轮数

原来是输入了一个epoch值,训练迭代为range(0,epoch)

# 原来的for epoch in range(opt.epochs):

现在需要改成,输入一个开头值,加上新训的轮数,训练迭代为range(start_epoch, start_epoch+新训多少轮)

如果没有断点续传,start_epoch默认为-1,(start_epoch==-1?0:start_epoch) 为0,则轮数为(-1+1=0,0+epochs)如果有断点续传,start_epoch被更新为断点,start_epoch==-1?0:startepoch)为start_epoch,则轮数为(start_epochs+1, start_epoch+新增轮数)

# 新的start_epoch = -1if opt.resume:print('-----------------------------')path_checkpoint = opt.checkbreakpoint checkpoint = torch.load(path_checkpoint) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optim_state_dict']) train_criterion.load_state_dict(checkpoint['criterion_state_dict'])start_epoch = checkpoint['epoch'] print("start_epoch:",start_epoch)print('-----------------------------')new_start = 0 if start_epoch==-1 else start_epochfor epoch in range(start_epoch + 1, new_start+opt.epochs):

例如,断点start_epoch是300,opt.epochs=200, 就从(300+1,300+200),新训200轮。

Step4:直接训练,注意新补充参数,resume和checkbreakpoint

python train.py --resume True --checkbreakpoint ../epoch_240.pth.tar --lr 1e-5 --logdir ./1e-5 --gpus 2 --epochs 200

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。