其他
【他山之石】PyTorch使用预训练模型进行模型加载
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
地址:https://www.zhihu.com/people/panda-9-6
class Par_net(nn.Module):
    """Source network whose parameters are all set to the constant 1.

    It owns three ``nn.Linear`` layers (1->1, 2->2, 3->3). Because every
    weight and bias equals 1, it is easy to see in printed output which
    parameters of another network were overwritten by this network's
    saved ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)
        # Replace the default random initialisation with a recognisable value.
        self.para_init()

    def para_init(self):
        """Fill every parameter of this module with 1, in place."""
        for p in self.parameters():
            nn.init.constant_(p, 1)
class Recover_Net(nn.Module):
    """Target network with the same three ``nn.Linear`` layers as ``Par_net``.

    The layers keep their default initialisation here, so the
    ``state_dict`` keys are layer1/layer2/layer3 ``.weight`` and ``.bias``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)
class Recover_Net2(nn.Module):
    """Modified target network: layers registered in a different order plus an extra layer.

    Compared with a layer1/layer2/layer3 network, ``layer2`` is registered
    before ``layer1`` and an additional ``layer4`` (4->4) is inserted. A
    checkpoint that only holds layer1..layer3 therefore has no
    ``layer4.weight`` / ``layer4.bias`` entries for this network.
    """

    def __init__(self):
        super().__init__()
        self.layer2 = nn.Linear(2, 2)
        self.layer1 = nn.Linear(1, 1)
        self.layer4 = nn.Linear(4, 4)
        self.layer3 = nn.Linear(3, 3)
[Parameter containing:
tensor([[1.]], requires_grad=True), Parameter containing:
tensor([1.], requires_grad=True), Parameter containing:
tensor([[1., 1.],
[1., 1.]], requires_grad=True), Parameter containing:
tensor([1., 1.], requires_grad=True), Parameter containing:
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], requires_grad=True), Parameter containing:
tensor([1., 1., 1.], requires_grad=True)]
# Raw string: the backslashes of the Windows path must not be parsed as
# escape sequences (the resulting path value is unchanged).
path = r'D:\Pycharm\MOT-DET\parameter_test\ori_net.pkl'
# Save only the parameters (state_dict) of the all-ones network.
ori_net = Par_net()
state_dict = ori_net.state_dict()
torch.save(state_dict, path)
# Parameters of the new network before loading the checkpoint.
new_net = Recover_Net()
print(list(new_net.parameters()))
print('+++++++++++++++++++++++++')
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(path)
new_net.load_state_dict(checkpoint)
# Parameters of the new network after loading the checkpoint.
print(list(new_net.parameters()))
_________________________
[Parameter containing:
tensor([[0.8347]], requires_grad=True), Parameter containing:
tensor([0.4333], requires_grad=True), Parameter containing:
tensor([[-0.2783, 0.1913],
[-0.5740, 0.2918]], requires_grad=True), Parameter containing:
tensor([-0.3380, -0.4790], requires_grad=True), Parameter containing:
tensor([[-0.5227, -0.4193, 0.0621],
[-0.3059, -0.2392, -0.3306],
[ 0.5414, 0.4600, 0.0212]], requires_grad=True), Parameter containing:
tensor([-0.4450, -0.2309, -0.3591], requires_grad=True)]
+++++++++++++++++++++++++
[Parameter containing:
tensor([[1.]], requires_grad=True), Parameter containing:
tensor([1.], requires_grad=True), Parameter containing:
tensor([[1., 1.],
[1., 1.]], requires_grad=True), Parameter containing:
tensor([1., 1.], requires_grad=True), Parameter containing:
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], requires_grad=True), Parameter containing:
tensor([1., 1., 1.], requires_grad=True)]
那么对于网络发生改变的情况,网络模型的加载是一个什么样的状况?
# Print change_net's parameters before any checkpoint is loaded.
print('!!!!!!!!!!!!!!!!!!!!!!!!!')
print(list(change_net.parameters()))
# strict defaults to True, so this call raises RuntimeError for change_net:
# the checkpoint has no "layer4.weight" / "layer4.bias" entries.
change_net.load_state_dict(checkpoint)
# Not reached when the load above raises.
print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
print(list(change_net.parameters()))
!!!!!!!!!!!!!!!!!!!!!!!!!
[Parameter containing:
tensor([[ 0.2289, -0.2219],
[-0.2913, -0.0575]], requires_grad=True), Parameter containing:
tensor([0.3789, 0.6743], requires_grad=True), Parameter containing:
tensor([[-0.9336]], requires_grad=True), Parameter containing:
tensor([0.2300], requires_grad=True), Parameter containing:
tensor([[ 0.4035, 0.2747, -0.3141, -0.4875],
[-0.1944, -0.4428, -0.1594, 0.3091],
[ 0.2849, 0.4464, -0.1852, 0.1225],
[-0.3696, 0.2659, 0.3679, -0.4899]], requires_grad=True), Parameter containing:
tensor([ 0.3396, -0.0440, 0.2090, -0.1895], requires_grad=True), Parameter containing:
tensor([[ 0.0609, -0.4049, 0.1330],
[-0.5245, 0.1413, -0.0111],
[ 0.2149, -0.1645, 0.2028]], requires_grad=True), Parameter containing:
tensor([-0.1993, -0.4430, -0.4672], requires_grad=True)]
Traceback (most recent call last):
File "D:/Pycharm/MOT-DET/parameter_test/paramater_net.py", line 55, in <module>
change_net.load_state_dict(checkpoint)
File "D:\Setup\python\lib\site-packages\torch\nn\modules\module.py", line 830, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Recover_Net2:
Missing key(s) in state_dict: "layer4.weight", "layer4.bias".
Process finished with exit code 1
def load_state_dict(self, state_dict, strict=True):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function.

    Arguments:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Raises:
        RuntimeError: if any error message was collected while loading;
            with ``strict=True`` this includes missing or unexpected keys.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # prefix ends with '.', so prefix[:-1] is the dotted module path.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        # NOTE: True is always passed here regardless of ``strict``; the
        # caller's ``strict`` flag is only consulted below when deciding
        # whether collected key mismatches become errors.
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        # Recurse into registered child modules, extending the key prefix.
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)
    load = None  # break load->load reference cycle

    if strict:
        # Each message is inserted at index 0, so the missing-keys message
        # ends up ahead of the unexpected-keys message.
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)
# Print change_net's parameters before any checkpoint is loaded.
print('!!!!!!!!!!!!!!!!!!!!!!!!!')
print(list(change_net.parameters()))
# The default strict load is kept for reference; it raises RuntimeError
# because the checkpoint has no layer4 entries:
# change_net.load_state_dict(checkpoint)
# strict=False: do not raise on keys that are missing from the checkpoint.
change_net.load_state_dict(checkpoint,strict=False)
# Print the parameters again to see which ones were overwritten.
print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
print(list(change_net.parameters()))
!!!!!!!!!!!!!!!!!!!!!!!!!
[Parameter containing:
tensor([[-0.2190, -0.5524],
[ 0.4523, 0.0309]], requires_grad=True), Parameter containing:
tensor([0.0235, 0.1300], requires_grad=True), Parameter containing:
tensor([[0.7876]], requires_grad=True), Parameter containing:
tensor([-0.4198], requires_grad=True), Parameter containing:
tensor([[ 0.0665, -0.3781, 0.3609, -0.4710],
[-0.2582, -0.3591, -0.3254, 0.4899],
[-0.1772, 0.2239, -0.2328, -0.4277],
[ 0.3070, -0.0599, -0.4284, -0.4030]], requires_grad=True), Parameter containing:
tensor([ 0.4186, -0.2221, 0.2654, -0.4997], requires_grad=True), Parameter containing:
tensor([[ 0.2064, -0.3040, 0.3713],
[ 0.3115, 0.4053, 0.4515],
[ 0.0090, 0.5466, 0.1865]], requires_grad=True), Parameter containing:
tensor([0.2233, 0.0831, 0.5407], requires_grad=True)]
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
[Parameter containing:
tensor([[1., 1.],
[1., 1.]], requires_grad=True), Parameter containing:
tensor([1., 1.], requires_grad=True), Parameter containing:
tensor([[1.]], requires_grad=True), Parameter containing:
tensor([1.], requires_grad=True), Parameter containing:
tensor([[ 0.0665, -0.3781, 0.3609, -0.4710],
[-0.2582, -0.3591, -0.3254, 0.4899],
[-0.1772, 0.2239, -0.2328, -0.4277],
[ 0.3070, -0.0599, -0.4284, -0.4030]], requires_grad=True), Parameter containing:
tensor([ 0.4186, -0.2221, 0.2654, -0.4997], requires_grad=True), Parameter containing:
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], requires_grad=True), Parameter containing:
tensor([1., 1., 1.], requires_grad=True)]
print('!!!!!!!!!!!!!!!!!!!!!!!!!')
# Repeat of the non-strict loading snippet: parameters before loading.
print(list(change_net.parameters()))
# The default strict load raises RuntimeError for this checkpoint:
# change_net.load_state_dict(checkpoint)
# strict=False: do not raise on keys that are missing from the checkpoint.
change_net.load_state_dict(checkpoint,strict=False)
# Parameters after loading.
print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
print(list(change_net.parameters()))
附实验完整代码
import torch.nn as nn
from torch.nn.parameter import Parameter as P
import torch
class Par_net(nn.Module):
    """Source network whose parameters are all set to the constant 1.

    It owns three ``nn.Linear`` layers (1->1, 2->2, 3->3). Because every
    weight and bias equals 1, it is easy to see in printed output which
    parameters of another network were overwritten by this network's
    saved ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)
        # Replace the default random initialisation with a recognisable value.
        self.para_init()

    def para_init(self):
        """Fill every parameter of this module with 1, in place."""
        for p in self.parameters():
            nn.init.constant_(p, 1)
class Recover_Net(nn.Module):
    """Target network with the same three ``nn.Linear`` layers as ``Par_net``.

    The layers keep their default initialisation here, so the
    ``state_dict`` keys are layer1/layer2/layer3 ``.weight`` and ``.bias``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(2, 2)
        self.layer3 = nn.Linear(3, 3)
class Recover_Net2(nn.Module):
    """Modified target network: layers registered in a different order plus an extra layer.

    Compared with a layer1/layer2/layer3 network, ``layer2`` is registered
    before ``layer1`` and an additional ``layer4`` (4->4) is inserted. A
    checkpoint that only holds layer1..layer3 therefore has no
    ``layer4.weight`` / ``layer4.bias`` entries for this network.
    """

    def __init__(self):
        super().__init__()
        self.layer2 = nn.Linear(2, 2)
        self.layer1 = nn.Linear(1, 1)
        self.layer4 = nn.Linear(4, 4)
        self.layer3 = nn.Linear(3, 3)
# Raw string: the backslashes of the Windows path must not be parsed as
# escape sequences (the resulting path value is unchanged).
path = r'D:\Pycharm\MOT-DET\parameter_test\ori_net.pkl'

# Save only the parameters (state_dict) of the all-ones network.
ori_net = Par_net()
state_dict = ori_net.state_dict()
torch.save(state_dict, path)

new_net = Recover_Net()
change_net = Recover_Net2()

print(list(ori_net.parameters()))
print('_________________________')
# Same architecture: parameters before and after loading the checkpoint.
print(list(new_net.parameters()))
print('+++++++++++++++++++++++++')
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(path)
new_net.load_state_dict(checkpoint)
print(list(new_net.parameters()))

# Modified architecture: parameters before and after a non-strict load.
print('!!!!!!!!!!!!!!!!!!!!!!!!!')
print(list(change_net.parameters()))
# The default strict load is kept for reference; it raises RuntimeError
# because the checkpoint has no layer4 entries:
# change_net.load_state_dict(checkpoint)
change_net.load_state_dict(checkpoint, strict=False)
print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
print(list(change_net.parameters()))
“他山之石”历史文章
深度学习调参经验总结
PyTorch实现断点继续训练
Pytorch/Tensorflow-gpu训练并行加速trick(含代码)
从NumPy开始实现一个支持Auto-grad的CNN框架
pytorch_lightning 全程笔记
深度学习中的那些Trade-off
PyTorch 手把手搭建神经网络 (MNIST)
autograd源码剖析
怎样才能让你的模型更加高效运行?
来自日本程序员的纯C++深度学习库tiny-dnn
MMTracking: OpenMMLab 一体化视频目标感知平台
深度学习和机器视觉top组都在研究什么
pytorch常见的坑汇总
pytorch 中张量基本操作
更多他山之石专栏文章,
请点击文章底部“阅读原文”查看
分享、点赞、在看,给个三连击呗!