PyTorch与caffe中SGD算法实现的一点小区别
加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动!
同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注 极市平台 公众号 ,回复 加群,立刻申请入群~
作者:朱见深
来源:https://zhuanlan.zhihu.com/p/43016574
本文已经作者授权,未经许可不得二次转载
PS: 之前我的理解有一点偏差,经过刘昊淼和王赟 Maigo的提醒现在已经更正了。知乎的这个编辑器打公式太麻烦了,更新后的内容请看原文链接
刘昊淼 知乎主页:https://www.zhihu.com/people/liu-hao-miao-82/activities
王赟 Maigo知乎主页:https://www.zhihu.com/people/maigo/activities
原文链接:http://kaizhao.net/blog/posts/momentum-caffe-pytorch/
最近在复现之前自己之前的一个paper的时候发现PyTorch与caffe在实现SGD优化算法时有一处不太引人注意的区别,导致原本复制caffe中的超参数在PyTorch中无法复现性能。
这个区别与momentum有关。简单地说,[1]和caffe的实现中,momentum项只用乘以一个系数
假设目标函数是
(1)式中
我们先看caffe关于这部分的实现(代码在 github.com/BVLC/caffe/b)
github.com/BVLC/caffe/b 链接:https://github.com/BVLC/caffe/blob/99bd99795dcdf0b1d3086a8d67ab1782a8a08383/src/caffe/solvers/sgd_solver.cpp#L232-L234
template <typename D
type>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
// Compute the update to history, then copy it to the parameter diff.
switch (Caffe::mode()) {
case Caffe::CPU: {
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
sgd_update_gpu(net_params[param_id]->count(),
net_params[param_id]->mutable_gpu_diff(),
history_[param_id]->mutable_gpu_data(),
momentum, local_rate);
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
函数ComputeUpdateValue主要用于计算最后参数的更新值 ,也就是(2)式中的
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
这里axpby就是
我们再来看看PyTorch相关部分的代码(代码链接github.com/pytorch/pyto):
github.com/pytorch/pyto 链接:https://github.com/pytorch/pytorch/blob/9679fc5fcd36248ffe67f70d5c135d7af8ba0e2b/torch/optim/sgd.py#L88-L105
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
buf.mul_(momentum).add_(d_p)
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1 - dampening, d_p)
if nesterov:
d_p = d_p.add(momentum, buf)
else:
d_p = buf
p.data.add_(-group['lr'], d_p)
这里d_p是参数的导数,可以看到PyTorch的实现和(1)(2)式不太一样,是按照下面的规则更新参数的:
为了方便对比我们把(1)(2)也搬过来:
(1)(2)是caffe的实现,和[1]一致;(3)(4)是PyTorch的实现。可以看出来,相对于caffe的实现,PyTorch真正的momentum系数相当于caffe的momentum再乘以学习率
因此使用PyTorch的时候,当学习率非常小(比如像我这样使用类似FCN结构的网络,学习率<1e-6),那么实际上的有效momentum是非常小的。
我不知道PyTorch是基于什么样的考虑要这样设计,文档中倒是有说这个区别,但是并没有解释 (文档链接torch.optim - PyTorch master documentation)
torch.optim - PyTorch master documentation 链接:https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD
[1] Sutskever, Ilya, et al. "On the importance of initialization and momentum in deep learning."International conference on machine learning. 2013.
-End-
*延伸阅读
CV细分方向交流群
添加极市小助手微信(ID : cv-mart),备注:研究方向-姓名-学校/公司-城市(如:目标检测-小极-北大-深圳),即可申请加入目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群(已经添加小助手的好友直接私信),更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~
△长按添加极市小助手
△长按关注极市平台
觉得有用麻烦给个在看啦~