查看原文
其他

【他山之石】一文读懂 PyTorch 模型保存与载入

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—涤生

地址:https://www.zhihu.com/people/wen-xiao-du-4


当我们设计好一个模型,投喂大量数据,并且经过艰苦训练,终于得到了一个表现不错的模型,得到这个模型以后,就要保存下来,为后面的部署或者测试实验做准备。那么该如何正确高效的保存好训练的模型并再次重新载入呢?
首先来定义一个简单的模型
class SampleNet(nn.Module): def __init__(self): super(SampleNet, self).__init__() self.conv1 = nn.Conv2d(3, 16, 3) self.bn1 = nn.BatchNorm2d(16) self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 3) self.bn2 = nn.BatchNorm2d(32)
self.gap = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(32, 10)
def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.pool(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = self.gap(x) x = x.view(-1, 32) x = self.fc(x)
return x
class SampleNet2(nn.Module): def __init__(self): super(SampleNet2, self).__init__() self.conv1 = nn.Conv2d(3, 16, 3) self.bn1 = nn.BatchNorm2d(16) self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 3) self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 32, 3) self.bn3 = nn.BatchNorm2d(32)
self.gap = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(32, 10)
def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.pool(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = self.conv3(x) x = self.bn3(x) x = self.gap(x) x = x.view(-1, 32) x = self.fc(x)
return


只保存模型参数

首先,PyTorch 官方推荐的第一种方式是只保存模型的参数。
对于一个卷积网络模型来说,模型的卷积层、BN层是有经过训练得到的参数的,只需要把对应每一层的参数存储起来,就可以再次加载模型。
而模型的参数存储在一个字典中,通过 `model.state_dict()` 即可得到。
# 保存 PyTorch 模型参数torch.save(model.state_dict(), "model.pt")
# 重新载入模型参数# 首先定义模型model = SampleNet()#通过 load_state_dict 函数加载参数,torch.load() 函数中重要的一步是反序列化。model.load_state_dict(torch.load("model.pt"))
利用这种方式,可以更加灵活的利用训练好的模型,比如我需要再次训练的时候,改变来模型,但是只加载部分模型参数。

加载部分模型参数

# SampleNet2 比 SampleNet 多一个卷积层model = SampleNet2()# load_params 和 model_params 分别为两个模型的参数字典load_params = torch.load("model.pt")model_params = model.state_dict()# 构建一个新参数字典,为两个模型重复的部分same_parsms = {k: v for k, v in load_params.items() if k in model_params.keys()}# 更新模型参数字典,并载入model_params.update(same_parsms)model.load_state_dict(model_params)

保存加载全部模型

torch.save(model, "full_model.pt")new_model = torch.load("full_model.pt")

跨设备加载模型

在GPU上保存,在CPU上加载

一般模型训练都是在GPU设备,保存后能在GPU设备上加载运行,而若想在CPU设备上加载,只需在load 函数中加一个map_location参数即可。
device = torch.device('cpu')model = SampleNet()model.load_state_dict(torch.load("model.pt", map_location=device))

在CPU上保存,在GPU上加载

反过来,在也只需要改变map_location参数,但要注意将模型也对应到相同设备。
device = torch.device("cuda")model = SampleNet()model.load_state_dict(torch.load("model.pt", map_location="cuda:0"))model.to(device)

保存 torch.nn.DataParallel 模型

如果是要保存在单机多GPU上训练的模型,则需要特别注意一下。
model = SampleNet()model = torch.nn.DataParallel(model)torch.save(model.module.state_dict(), "model")
以上就是 PyTorch 在保存模型时的一些方式和技巧,希望能够帮助到你~

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


直播预告



“他山之石”历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存