【他山之石】GPU Memory Optimization Techniques for Training: OP Fusion and Gradient Checkpointing
Source: http://bindog.github.io/
01
Background
02
How does backpropagation work?
f(x, y) = x * y
# gradient for x: y
# gradient for y: x
g(x) = sigmoid(x) # 1 / (1 + exp(-x))
# gradient for x: sigmoid(x) * (1 - sigmoid(x))
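To make these rules concrete, here is a minimal sketch (my illustration, not the framework's real implementation) of how such generic OPs can be written as PyTorch autograd Functions. The ctx.save_for_backward calls mark exactly the intermediate results that have to stay in GPU memory until the backward pass runs:

import torch

class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)   # the generic product rule keeps BOTH operands
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out * b, grad_out * a

class Sigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # keeps x so the gradient rule below can use it
        return torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        s = torch.sigmoid(x)
        return grad_out * s * (1 - s)

# Composed the naive way, swish keeps x and sigmoid(x) alive until backward:
x = torch.randn(4, requires_grad=True)
y = Mul.apply(x, Sigmoid.apply(x))
y.sum().backward()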
03
What is eating up the GPU memory?

When the generic multiplication rule is applied during differentiation, we must keep both intermediate results, x and sigmoid(x). One might ask: isn't it enough to keep only x, since sigmoid(x) can always be recomputed from it? Note, however, that the multiplication OP and its gradient rule defined by the framework are generic rules: the two operands of a product may well be completely unrelated values, so both of them have to be saved. Likewise, applying the gradient rule of the sigmoid function requires storing the intermediate result x.
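One way to see this directly (a small inspection sketch; it assumes a PyTorch version recent enough to provide torch.autograd.graph.saved_tensors_hooks) is to record every tensor the autograd graph stashes while the naive swish is built from the generic OPs:

import torch

saved_shapes = []

def pack(tensor):
    # called once for every tensor that autograd saves for the backward pass
    saved_shapes.append(tuple(tensor.shape))
    return tensor

def unpack(tensor):
    return tensor

x = torch.randn(1024, 1024, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * torch.sigmoid(x)   # naive swish built from generic mul/sigmoid OPs

# Several full-size intermediates are kept: the product rule saves both of its
# operands (x and sigmoid(x)), and sigmoid's own rule saves what it needs too.
print(len(saved_shapes), saved_shapes)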
04
Fusing OPs by hand

So is there a way to optimize this? Of course there is. Since we can work out the gradient of the swish activation analytically ahead of time, why not simply treat the whole expression as a single OP? All it takes is defining a new function together with a new gradient rule:
swish(x) = x * sigmoid(x)
# gradient for x: sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
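Wrapped as a PyTorch autograd Function, the fused OP can look like the sketch below (a simplified version; the post's actual code is in [10]). Only x is saved for backward, and sigmoid(x) is recomputed there instead of being kept alive through the whole forward pass:

import torch

class SwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)          # only x is kept
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        s = torch.sigmoid(x)              # recomputed on the fly
        # d swish / dx = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        return grad_output * (s + x * s * (1 - s))

def swish(x):
    return SwishFunction.apply(x)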
With this fusion, the computation graph collapses into a single swish node, whose backward pass only needs the saved input x.
Inplace-ABN [3][4] pushes the same idea much further by fusing batch normalization with its (Leaky ReLU) activation; in the Type II variant, it goes one step further and uses a backward pass written directly in terms of the stored output, so the intermediate input never has to be kept at all. Although the derivation is somewhat involved, once the gradient formula is written out we only need to wrap it into a hand-written module. The implementation in the original paper [4] shows that Inplace-ABN can cut peak GPU memory usage by up to roughly 50%, and since Leaky ReLU behaves almost the same as ReLU in practice, the saved memory can be spent on a larger batch_size, from which training actually gains more.
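To give a flavor of the trick, the sketch below (my illustration, not the inplace_abn library's code) shows the property the method leans on: unlike ReLU, Leaky ReLU is invertible, so the value that entered the activation can be recovered from the stored output instead of being kept in memory:

import torch

def leaky_relu(y, slope=0.01):
    return torch.where(y >= 0, y, slope * y)

def leaky_relu_inverse(z, slope=0.01):
    # positive outputs map back unchanged, negative ones are divided by the slope
    return torch.where(z >= 0, z, z / slope)

y = torch.randn(8)                  # stand-in for the normalized pre-activation
z = leaky_relu(y)                   # only z would be stored
assert torch.allclose(leaky_relu_inverse(z), y)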
05
Can we go even further?

Yes: gradient checkpointing [5][8] trades computation for memory by not keeping intermediate activations during the forward pass and recomputing them on demand during the backward pass. PyTorch ships this mechanism as torch.utils.checkpoint, and its core autograd Function [6] looks like this:
class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads
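In practice one rarely calls CheckpointFunction directly; the torch.utils.checkpoint.checkpoint wrapper does it for you. A minimal usage sketch (the block, sizes, and shapes are made up for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# The checkpointed block runs its forward under no_grad, so its intermediate
# activations are not stored; they are rebuilt inside backward by re-running it.
block = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.ReLU(),
)

x = torch.randn(32, 1024, requires_grad=True)
y = checkpoint(block, x)   # stashes x and the RNG state instead of the activations
loss = y.sum()
loss.backward()            # the block is re-executed with grad enabled here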
06
A CUDA version of the swish activation
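The fused forward and backward can be pushed one level lower into a single CUDA kernel, built as a PyTorch C++/CUDA extension following the official tutorial [9]; the complete implementation behind the numbers below lives in [10]. As a rough sketch of the Python side only (the source file names here are placeholders):

import torch
from torch.utils.cpp_extension import load

# JIT-compile and load a hypothetical fused-swish extension; the compiled module
# would expose forward/backward kernels that a torch.autograd.Function wraps,
# just like the pure-Python SwishFunction above, but with the elementwise math
# fused into one kernel launch.
swish_ext = load(
    name="swish_ext",
    sources=["swish_cuda.cpp", "swish_cuda_kernel.cu"],
    verbose=True,
)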
Comparing the three implementations:
No optimization, pure Python: GPU memory = 6383 MB, time = 223 ms
Fused OP (Python): GPU memory = 5139 MB, time = 234 ms
Fused OP (CUDA): GPU memory = 5143 MB, time = 188 ms
[1] https://zhuanlan.zhihu.com/p/122943688
[2] https://arxiv.org/abs/1710.05941
[3] https://github.com/mapillary/inplace_abn
[4] https://arxiv.org/pdf/1712.02616.pdf
[5] https://github.com/cybertronai/gradient-checkpointing
[6] https://github.com/pytorch/pytorch/blob/176174a68ba2d36b9a5aaef0943421682ecc66d4/torch/utils/checkpoint.py#L55
[7] https://zhuanlan.zhihu.com/p/138730559
[8] https://arxiv.org/abs/1604.06174
[9] https://pytorch.org/tutorials/advanced/cpp_extension.html
[10] https://github.com/bindog/swish_optimize