查看原文
其他

巨省显存的重计算技巧在TF、Keras中的正确打开方式

苏剑林 夕小瑶的卖萌屋 2021-02-04

一只小狐狸带你解锁 炼丹术&NLP 秘籍

作者:苏剑林(来自追一科技,人称“苏神”)

前言

在前不久的文章《BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)》中介绍了一个叫做“重计算”的技巧(附pytorch和paddlepaddle实现)。简单来说重计算就是用来省显存的方法,让平均训练速度慢一点,但batch_size可以增大好几倍,该技巧首先发布于论文《Training Deep Nets with Sublinear Memory Cost》。

最近笔者发现,重计算的技巧在tensorflow也有实现。事实上从tensorflow1.8开始,tensorflow就已经自带了该功能了,当时被列入了tf.contrib这个子库中,而从tensorflow1.15开始,它就被内置为tensorflow的主函数之一,那就tf.recompute_grad找到 tf.recompute_grad 之后,笔者就琢磨了一下它的用法,经过一番折腾,最终居然真的成功地用起来了,居然成功地让 batch_size 从48增加到了144!然而,在继续整理测试的过程中,发现这玩意居然在tensorflow 2.x是失效的...于是再折腾了两天,查找了各种资料并反复调试,最终算是成功地补充了这一缺陷。

最后是笔者自己的开源实现:

Github地址:

https://github.com/bojone/keras_recompute

该实现已经内置在bert4keras中,使用bert4keras的读者可以升级到最新版本(0.7.5+)来测试该功能。

使用

笔者的实现也命名为recompute_grad,它是一个装饰器,用于自定义Keras层的 call函数,比如

from recompute import recompute_grad
class MyLayer(Layer):@recompute_graddef call(self, inputs):return inputs * 2

对于已经存在的层,可以通过继承的方式来装饰:

from recompute import recompute_grad
class MyDense(Dense):@recompute_graddef call(self, inputs):return super(MyDense, self).call(inputs)

自定义好层之后,在代码中嵌入自定义层,然后在执行代码之前,加入环境变量RECOMPUTE=1来启用重计算。

注意:不是在总模型里插入了@recomputr_grad,就能达到省内存的目的,而是要在每个层都插入@recomputr_grad才能更好地省显存。简单来说,就是插入的@recomputr_grad越多,就省显存。具体原因请仔细理解重计算的原理。


效果

bert4keras0.7.5已经内置了重计算,直接传入环境变量RECOMPUTE=1就会启用重计算,读者可以自行尝试,大概的效果是:

1、在BERT Base版本下,batch_size可以增大为原来的3倍左右;

2、在BERT Large版本下,batch_size可以增大为原来的4倍左右;

3、平均每个样本的训练时间大约增加25%;

4、理论上,层数越多,batch_size可以增大的倍数越大。

环境

在下面的环境下测试通过:

tensorflow 1.14 + keras 2.3.1

tensorflow 1.15 + keras 2.3.1

tensorflow 2.0 + keras 2.3.1

tensorflow 2.1 + keras 2.3.1

tensorflow 2.0 + 自带tf.keras

tensorflow 2.1 + 自带tf.keras

确认不支持的环境:

tensorflow 1.x + 自带tf.keras

欢迎报告更多的测试结果。

顺便说一下,强烈建议用keras2.3.1配合tensorflow1.x/2.x来跑,强烈不建议使用tensorflow 2.x自带的tf.keras来跑





夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦

    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存