其他
【他山之石】Pytorch:eval()的用法比较
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
01
1.1 model.train()
1.2 model.eval()
1.3 分析原因
# 定义一个网络
class Net(nn.Module):
def __init__(self, l1=120, l2=84):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, l1)
self.fc2 = nn.Linear(l1, l2)
self.fc3 = nn.Linear(l2, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 实例化这个网络
Model = Net()
# 训练模式使用.train()
Model.train(mode=True)
# 测试模型使用.eval()
Model.eval()
def train(model, optimizer, epoch, train_loader, validation_loader):
model.train() # ???????????? 错误的位置
for batch_idx, (data, target) in experiment.batch_loop(iterable=train_loader):
# model.train() # 正确的位置,保证每一个batch都能进入model.train()的模式
data, target = Variable(data), Variable(target)
# Inference
output = model(data)
loss_t = F.nll_loss(output, target)
# The iconic grad-back-step trio
optimizer.zero_grad()
loss_t.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
train_loss = loss_t.item()
train_accuracy = get_correct_count(output, target) * 100.0 / len(target)
experiment.add_metric(LOSS_METRIC, train_loss)
experiment.add_metric(ACC_METRIC, train_accuracy)
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx, len(train_loader),
100. * batch_idx / len(train_loader), train_loss))
with experiment.validation():
val_loss, val_accuracy = test(model, validation_loader) # ????????????
experiment.add_metric(LOSS_METRIC, val_loss)
experiment.add_metric(ACC_METRIC, val_accuracy)
def test(model, test_loader):
model.eval()
# ...
02
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“他山之石”历史文章
ONNX模型文件->可执行文件 C Runtime通路 具体实现方法
Pytorch mixed precision 概述(混合精度)
Weights & Biases (兼容多种深度学习框架的可视化工具WB中文简介)
GCN实现及其中的归一化
Pytorch Lightning 完全攻略
Tensorflow之TFRecord的原理和使用心得
从零开始实现一个卷积神经网络
斯坦福大规模网络数据集
超轻量的YOLO-Nano
MMAction2: 新一代视频理解工具箱
TensorFlow神经网络实现二分类的正确姿势
人类早期驯服野生机器学习模型的珍贵资料
不会强化学习,只会numpy,能解决多难的RL问题?
技术总结《OpenAI Gym》
ROC和CMC曲线的理解(FAR, FRR的理解)
更多他山之石专栏文章,
请点击文章底部“阅读原文”查看
分享、点赞、在看,给个三连击呗!