查看原文
其他

PyTorch与caffe中SGD算法实现的一点小区别

朱见深 极市平台 2021-09-20

加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动!

同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注 极市平台 公众号 ,回复 加群,立刻申请入群~


作者:朱见深

来源:https://zhuanlan.zhihu.com/p/43016574

本文已经作者授权,未经许可不得二次转载


PS: 之前我的理解有一点偏差,经过刘昊淼王赟 Maigo的提醒现在已经更正了。知乎的这个编辑器打公式太麻烦了,更新后的内容请看原文链接
  • 刘昊淼 知乎主页:https://www.zhihu.com/people/liu-hao-miao-82/activities

  • 王赟 Maigo知乎主页:https://www.zhihu.com/people/maigo/activities

  • 原文链接:http://kaizhao.net/blog/posts/momentum-caffe-pytorch/



最近在复现之前自己之前的一个paper的时候发现PyTorch与caffe在实现SGD优化算法时有一处不太引人注意的区别,导致原本复制caffe中的超参数在PyTorch中无法复现性能。



这个区别与momentum有关。简单地说,[1]和caffe的实现中,momentum项只用乘以一个系数然后就直接用来更新参数。而PyTorch的实现在此基础上又额外乘了一个学习率,导致实际的有效momentum变小,特别是在学习率很小的情况下。


假设目标函数是,目标函数的导数是,那么SGD根据以下公式更新参数:

 (1) 

(2)


(1)式中表示目标函数的导数,表示momentum的系数(在[1]中被称为velocity), 表示学习率。



我们先看caffe关于这部分的实现(代码在 github.com/BVLC/caffe/b

  • github.com/BVLC/caffe/b 链接:https://github.com/BVLC/caffe/blob/99bd99795dcdf0b1d3086a8d67ab1782a8a08383/src/caffe/solvers/sgd_solver.cpp#L232-L234


template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
// Compute the update to history, then copy it to the parameter diff.
switch (Caffe::mode()) {
case Caffe::CPU: {
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
sgd_update_gpu(net_params[param_id]->count(),
net_params[param_id]->mutable_gpu_diff(),
history_[param_id]->mutable_gpu_data(),
momentum, local_rate);
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}


函数ComputeUpdateValue主要用于计算最后参数的更新值 ,也就是(2)式中的。我们重点关注一下部分代码:


caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());


这里axpby就是,对应着local_rate就是学习率(之所以有local是因为caffe可以逐层设置学习率系数)。net_params[param_id]->cpu_diff()就是参数的导数,也就是(1)式中的。history_[param_id]->mutable_cpu_data()也就是历史累计的momentum,对应的是



我们再来看看PyTorch相关部分的代码(代码链接github.com/pytorch/pyto):

  • github.com/pytorch/pyto 链接:https://github.com/pytorch/pytorch/blob/9679fc5fcd36248ffe67f70d5c135d7af8ba0e2b/torch/optim/sgd.py#L88-L105


def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']

for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
buf.mul_(momentum).add_(d_p)
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1 - dampening, d_p)
if nesterov:
d_p = d_p.add(momentum, buf)
else:
d_p = buf

p.data.add_(-group['lr'], d_p)


这里d_p是参数的导数,可以看到PyTorch的实现和(1)(2)式不太一样,是按照下面的规则更新参数的: 


(3) 

(4)


为了方便对比我们把(1)(2)也搬过来: 


(1) 

(2)


(1)(2)是caffe的实现,和[1]一致;(3)(4)是PyTorch的实现。可以看出来,相对于caffe的实现,PyTorch真正的momentum系数相当于caffe的momentum再乘以学习率


因此使用PyTorch的时候,当学习率非常小(比如像我这样使用类似FCN结构的网络,学习率<1e-6),那么实际上的有效momentum是非常小的。




我不知道PyTorch是基于什么样的考虑要这样设计,文档中倒是有说这个区别,但是并没有解释 (文档链接torch.optim - PyTorch master documentation

  • torch.optim - PyTorch master documentation 链接:https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD



[1] Sutskever, Ilya, et al. "On the importance of initialization and momentum in deep learning."International conference on machine learning. 2013.



-End-



*延伸阅读





CV细分方向交流群


添加极市小助手微信(ID : cv-mart),备注:研究方向-姓名-学校/公司-城市(如:目标检测-小极-北大-深圳),即可申请加入目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群(已经添加小助手的好友直接私信),更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~



△长按添加极市小助手


△长按关注极市平台


觉得有用麻烦给个在看啦~  

: . Video Mini Program Like ,轻点两下取消赞 Wow ,轻点两下取消在看

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存