其他
【他山之石】一文读懂 PyTorch 模型保存与载入
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
地址:https://www.zhihu.com/people/wen-xiao-du-4
class SampleNet(nn.Module):
def __init__(self):
super(SampleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 3)
self.bn2 = nn.BatchNorm2d(32)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.gap(x)
x = x.view(-1, 32)
x = self.fc(x)
return x
class SampleNet2(nn.Module):
def __init__(self):
super(SampleNet2, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 3)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 32, 3)
self.bn3 = nn.BatchNorm2d(32)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.gap(x)
x = x.view(-1, 32)
x = self.fc(x)
return
只保存模型参数
# 保存 PyTorch 模型参数
torch.save(model.state_dict(), "model.pt")
# 重新载入模型参数
# 首先定义模型
model = SampleNet()
#通过 load_state_dict 函数加载参数,torch.load() 函数中重要的一步是反序列化。
model.load_state_dict(torch.load("model.pt"))
加载部分模型参数
# SampleNet2 比 SampleNet 多一个卷积层
model = SampleNet2()
# load_params 和 model_params 分别为两个模型的参数字典
load_params = torch.load("model.pt")
model_params = model.state_dict()
# 构建一个新参数字典,为两个模型重复的部分
same_parsms = {k: v for k, v in load_params.items() if k in model_params.keys()}
# 更新模型参数字典,并载入
model_params.update(same_parsms)
model.load_state_dict(model_params)
保存加载全部模型
torch.save(model, "full_model.pt")
new_model = torch.load("full_model.pt")
跨设备加载模型
在GPU上保存,在CPU上加载
device = torch.device('cpu')
model = SampleNet()
model.load_state_dict(torch.load("model.pt", map_location=device))
在CPU上保存,在GPU上加载
device = torch.device("cuda")
model = SampleNet()
model.load_state_dict(torch.load("model.pt", map_location="cuda:0"))
model.to(device)
保存 torch.nn.DataParallel 模型
model = SampleNet()
model = torch.nn.DataParallel(model)
torch.save(model.module.state_dict(), "model")
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
直播预告
“他山之石”历史文章
适合PyTorch小白的官网教程:Learning PyTorch With Examples
pytorch量化备忘录
LSTM模型结构的可视化
PointNet论文复现及代码详解
SCI写作常用句型之研究结果&发现
白话生成对抗网络GAN及代码实现
pytorch的余弦退火学习率
Pytorch转ONNX-实战篇(tracing机制)
联邦学习:FedAvg 的 Pytorch 实现
PyTorch实现ShuffleNet-v2亲身实践
训练时显存优化技术——OP合并与gradient checkpoint
浅谈数据标准化与Pytorch中NLLLoss和CrossEntropyLoss损失函数的区别
在C++平台上部署PyTorch模型流程+踩坑实录
libtorch使用经验
深度学习模型转换与部署那些事(含ONNX格式详细分析)
更多他山之石专栏文章,
请点击文章底部“阅读原文”查看
分享、点赞、在看,给个三连击呗!