
【他山之石】Loading Pretrained Models in PyTorch

"Stones from other hills may serve to polish jade" — only by standing on the shoulders of giants can we see higher and go further, and on the road of research a favorable wind helps us move faster. To that end, we have collected practical code links, datasets, software, and programming tips in this "他山之石" column to help you push boldly ahead. Stay tuned.

Author: HUST小菜鸡 (Zhihu)

Profile: https://www.zhihu.com/people/panda-9-6

I have already covered saving and loading a model and its parameters in an earlier article. In my recent work, and in discussions with Zhihu readers, one question kept coming up: if the network architecture changes, how does loading from a pretrained model actually behave? I ran a few simple experiments to verify my understanding.

For saving and restoring models, see my earlier article on resuming training from a checkpoint in PyTorch.

Following the PyTorch Chinese documentation (https://pytorch-cn.readthedocs.io/zh/latest/), we can take a closer look at torch.save, torch.load, and load_state_dict.

When loading saved parameters, each key in the saved state_dict is matched against the corresponding key in model.state_dict(), and only matching entries are restored. The crucial point is that the keys in the saved state_dict and in model.state_dict() must agree: restoration works by matching this dictionary of named parameters. If you instead rely on default, auto-generated layer names (for example, the positional numbering used by a sequential container), then inserting a new layer shifts the default names of every subsequent layer, and none of those later layers can be loaded. It is therefore better to give each layer an explicit name of your own rather than relying on default naming.

To verify this, I defined three networks: the first initializes its parameters to a fixed value and is saved; the second checks loading into an identical network; the third checks loading after layers have been reordered and added (deletion can be tested the same way).

To keep the results easy to inspect, the networks use only fully connected layers, and layer1 through layer4 have 1, 2, 3, and 4 units respectively, so each layer's parameters are easy to tell apart.
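The key-matching described above is easy to see directly: the keys of a module's state_dict come from the attribute names you assign, not from any positional order. A minimal sketch (TinyNet is a hypothetical name, not from the article):

```python
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.layer1 = nn.Linear(1, 1)  # contributes keys layer1.weight, layer1.bias
        self.layer2 = nn.Linear(2, 2)  # contributes keys layer2.weight, layer2.bias

net = TinyNet()
# Keys are "<attribute name>.<parameter name>", in registration order:
print(list(net.state_dict().keys()))
# → ['layer1.weight', 'layer1.bias', 'layer2.weight', 'layer2.bias']
```

Because the keys are derived from attribute names, renaming an attribute (or relying on auto-generated numeric names that shift when layers are inserted) breaks the match and the corresponding parameters cannot be restored.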

class Par_net(nn.Module):
    def __init__(self):
        super(Par_net, self).__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)
        self.para_init()

    def para_init(self):
        for p in self.parameters():
            nn.init.constant_(p, 1)


class Recover_Net(nn.Module):
    def __init__(self):
        super(Recover_Net, self).__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)


class Recover_Net2(nn.Module):
    def __init__(self):
        super(Recover_Net2, self).__init__()
        self.layer2 = nn.Linear(2, 2)
        self.layer1 = nn.Linear(1, 1)
        self.layer4 = nn.Linear(4, 4)
        self.layer3 = nn.Linear(3, 3)
Because Par_net runs a custom initializer in its constructor, all of its parameters are set to 1 as soon as it is created:
[Parameter containing: tensor([[1.]], requires_grad=True),
 Parameter containing: tensor([1.], requires_grad=True),
 Parameter containing: tensor([[1., 1.],
                               [1., 1.]], requires_grad=True),
 Parameter containing: tensor([1., 1.], requires_grad=True),
 Parameter containing: tensor([[1., 1., 1.],
                               [1., 1., 1.],
                               [1., 1., 1.]], requires_grad=True),
 Parameter containing: tensor([1., 1., 1.], requires_grad=True)]
Save the parameters, load them into a new network, and print the new network's parameters before and after loading:
path = 'D:\Pycharm\MOT-DET\parameter_test\ori_net.pkl'
ori_net = Par_net()
state_dict = ori_net.state_dict()
torch.save(state_dict, path)

new_net = Recover_Net()

print(list(new_net.parameters()))
print('+++++++++++++++++++++++++')
checkpoint = torch.load(path)
new_net.load_state_dict(checkpoint)
print(list(new_net.parameters()))
The output below shows that the parameters are restored from the checkpoint: random values before the separator, all ones after:

_________________________
[Parameter containing: tensor([[0.8347]], requires_grad=True),
 Parameter containing: tensor([0.4333], requires_grad=True),
 Parameter containing: tensor([[-0.2783,  0.1913],
                               [-0.5740,  0.2918]], requires_grad=True),
 Parameter containing: tensor([-0.3380, -0.4790], requires_grad=True),
 Parameter containing: tensor([[-0.5227, -0.4193,  0.0621],
                               [-0.3059, -0.2392, -0.3306],
                               [ 0.5414,  0.4600,  0.0212]], requires_grad=True),
 Parameter containing: tensor([-0.4450, -0.2309, -0.3591], requires_grad=True)]
+++++++++++++++++++++++++
[Parameter containing: tensor([[1.]], requires_grad=True),
 Parameter containing: tensor([1.], requires_grad=True),
 Parameter containing: tensor([[1., 1.],
                               [1., 1.]], requires_grad=True),
 Parameter containing: tensor([1., 1.], requires_grad=True),
 Parameter containing: tensor([[1., 1., 1.],
                               [1., 1., 1.],
                               [1., 1., 1.]], requires_grad=True),
 Parameter containing: tensor([1., 1., 1.], requires_grad=True)]

So what happens when the network itself has changed? How does loading behave then?

To test this, we shuffle the layer order and add a layer (layers can also be deleted in further experiments):
print('!!!!!!!!!!!!!!!!!!!!!!!!!')
print(list(change_net.parameters()))
change_net.load_state_dict(checkpoint)
print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
print(list(change_net.parameters()))
The randomly initialized parameters print as expected, but loading raises an error and fails:
!!!!!!!!!!!!!!!!!!!!!!!!!
[Parameter containing: tensor([[ 0.2289, -0.2219],
                               [-0.2913, -0.0575]], requires_grad=True),
 Parameter containing: tensor([0.3789, 0.6743], requires_grad=True),
 Parameter containing: tensor([[-0.9336]], requires_grad=True),
 Parameter containing: tensor([0.2300], requires_grad=True),
 Parameter containing: tensor([[ 0.4035,  0.2747, -0.3141, -0.4875],
                               [-0.1944, -0.4428, -0.1594,  0.3091],
                               [ 0.2849,  0.4464, -0.1852,  0.1225],
                               [-0.3696,  0.2659,  0.3679, -0.4899]], requires_grad=True),
 Parameter containing: tensor([ 0.3396, -0.0440,  0.2090, -0.1895], requires_grad=True),
 Parameter containing: tensor([[ 0.0609, -0.4049,  0.1330],
                               [-0.5245,  0.1413, -0.0111],
                               [ 0.2149, -0.1645,  0.2028]], requires_grad=True),
 Parameter containing: tensor([-0.1993, -0.4430, -0.4672], requires_grad=True)]
Traceback (most recent call last):
  File "D:/Pycharm/MOT-DET/parameter_test/paramater_net.py", line 55, in <module>
    change_net.load_state_dict(checkpoint)
  File "D:\Setup\python\lib\site-packages\torch\nn\modules\module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Recover_Net2:
    Missing key(s) in state_dict: "layer4.weight", "layer4.bias".

Process finished with exit code 1
The error is raised because the saved state_dict contains no layer4 key, which the new model requires. Looking at the source of load_state_dict makes the mechanism clear:
def load_state_dict(self, state_dict, strict=True):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function.

    Arguments:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)
    load = None  # break load->load reference cycle

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)
Reading load_state_dict() carefully, we find a strict parameter that decides whether restoration is strict or non-strict; the default is strict. Under strict loading, every key in the state_dict must match the model's keys exactly, which is exactly what caused the error above. In everyday use, however, loading a pretrained backbone often produces only a warning rather than an error, and that is because strict mode is not used there:
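Note that with strict=False the mismatches are not silently lost: the return value of load_state_dict reports them. A small sketch (SmallNet and BigNet are hypothetical names, not from the article's experiments):

```python
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()
        self.layer1 = nn.Linear(1, 1)

class BigNet(nn.Module):
    def __init__(self):
        super(BigNet, self).__init__()
        self.layer1 = nn.Linear(1, 1)  # matches the checkpoint
        self.layer4 = nn.Linear(4, 4)  # absent from the checkpoint

# Non-strict load returns a NamedTuple instead of raising:
result = BigNet().load_state_dict(SmallNet().state_dict(), strict=False)
print(result.missing_keys)     # → ['layer4.weight', 'layer4.bias']
print(result.unexpected_keys)  # → []
```

Checking these two lists after a non-strict load is a cheap way to confirm that exactly the layers you expected were (or were not) restored.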
whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
After modifying the code to pass strict=False and reloading:
print('!!!!!!!!!!!!!!!!!!!!!!!!!')
print(list(change_net.parameters()))
# change_net.load_state_dict(checkpoint)
change_net.load_state_dict(checkpoint, strict=False)
print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
print(list(change_net.parameters()))
Layers whose names match a key in the checkpoint have their parameters restored, while layers without a matching key keep their random initialization:
!!!!!!!!!!!!!!!!!!!!!!!!!
[Parameter containing: tensor([[-0.2190, -0.5524],
                               [ 0.4523,  0.0309]], requires_grad=True),
 Parameter containing: tensor([0.0235, 0.1300], requires_grad=True),
 Parameter containing: tensor([[0.7876]], requires_grad=True),
 Parameter containing: tensor([-0.4198], requires_grad=True),
 Parameter containing: tensor([[ 0.0665, -0.3781,  0.3609, -0.4710],
                               [-0.2582, -0.3591, -0.3254,  0.4899],
                               [-0.1772,  0.2239, -0.2328, -0.4277],
                               [ 0.3070, -0.0599, -0.4284, -0.4030]], requires_grad=True),
 Parameter containing: tensor([ 0.4186, -0.2221,  0.2654, -0.4997], requires_grad=True),
 Parameter containing: tensor([[ 0.2064, -0.3040,  0.3713],
                               [ 0.3115,  0.4053,  0.4515],
                               [ 0.0090,  0.5466,  0.1865]], requires_grad=True),
 Parameter containing: tensor([0.2233, 0.0831, 0.5407], requires_grad=True)]
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
[Parameter containing: tensor([[1., 1.],
                               [1., 1.]], requires_grad=True),
 Parameter containing: tensor([1., 1.], requires_grad=True),
 Parameter containing: tensor([[1.]], requires_grad=True),
 Parameter containing: tensor([1.], requires_grad=True),
 Parameter containing: tensor([[ 0.0665, -0.3781,  0.3609, -0.4710],
                               [-0.2582, -0.3591, -0.3254,  0.4899],
                               [-0.1772,  0.2239, -0.2328, -0.4277],
                               [ 0.3070, -0.0599, -0.4284, -0.4030]], requires_grad=True),
 Parameter containing: tensor([ 0.4186, -0.2221,  0.2654, -0.4997], requires_grad=True),
 Parameter containing: tensor([[1., 1., 1.],
                               [1., 1., 1.],
                               [1., 1., 1.]], requires_grad=True),
 Parameter containing: tensor([1., 1., 1.], requires_grad=True)]
The above can be read as a supplement to my earlier article and as a verification of my own understanding. If anything is wrong, corrections are very welcome.
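One caveat with strict=False: it skips missing keys, but a key that exists with a different tensor shape still fails. A common, more defensive pattern (a sketch, not part of the article's experiments; OldNet and NewNet are hypothetical names) is to filter the checkpoint down to keys whose name and shape both match before loading:

```python
import torch.nn as nn

class OldNet(nn.Module):
    def __init__(self):
        super(OldNet, self).__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer3 = nn.Linear(3, 3)

class NewNet(nn.Module):
    def __init__(self):
        super(NewNet, self).__init__()
        self.layer1 = nn.Linear(1, 1)  # same name and shape: will be restored
        self.layer4 = nn.Linear(4, 4)  # no counterpart: keeps its init

checkpoint = OldNet().state_dict()
model = NewNet()
model_dict = model.state_dict()

# Keep only checkpoint entries whose key AND tensor shape match the new model
filtered = {k: v for k, v in checkpoint.items()
            if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(filtered)
model.load_state_dict(model_dict)  # strict load now succeeds

print(sorted(filtered.keys()))  # → ['layer1.bias', 'layer1.weight']
```

Because the merged dictionary contains every key the model expects, the final load can stay strict, which catches any remaining inconsistency instead of hiding it.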

The full experiment code is attached:

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter as P


class Par_net(nn.Module):
    def __init__(self):
        super(Par_net, self).__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)
        self.para_init()

    def para_init(self):
        for p in self.parameters():
            nn.init.constant_(p, 1)


class Recover_Net(nn.Module):
    def __init__(self):
        super(Recover_Net, self).__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)


class Recover_Net2(nn.Module):
    def __init__(self):
        super(Recover_Net2, self).__init__()
        self.layer2 = nn.Linear(2, 2)
        self.layer1 = nn.Linear(1, 1)
        self.layer4 = nn.Linear(4, 4)
        self.layer3 = nn.Linear(3, 3)


path = 'D:\Pycharm\MOT-DET\parameter_test\ori_net.pkl'
ori_net = Par_net()
state_dict = ori_net.state_dict()
torch.save(state_dict, path)

new_net = Recover_Net()
change_net = Recover_Net2()

print(list(ori_net.parameters()))
print('_________________________')
print(list(new_net.parameters()))
print('+++++++++++++++++++++++++')
checkpoint = torch.load(path)
new_net.load_state_dict(checkpoint)
print(list(new_net.parameters()))

print('!!!!!!!!!!!!!!!!!!!!!!!!!')
print(list(change_net.parameters()))
# change_net.load_state_dict(checkpoint)
change_net.load_state_dict(checkpoint, strict=False)
print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
print(list(change_net.parameters()))


This article is shared for academic exchange only; it does not imply that this public account endorses its views or vouches for the accuracy of its content. Copyright belongs to the original author; please contact us for removal in case of infringement.

