查看原文
其他

【综述专栏】BERT知识蒸馏综述

在科学研究中,从方法论上来讲,都应“先见森林,再见树木”。当前,人工智能学术研究方兴未艾,技术迅猛发展,可谓万木争荣,日新月异。对于AI从业者来说,在广袤的知识森林中,系统梳理脉络,才能更好地把握趋势。为此,我们精选国内外优秀的综述文章,开辟“综述专栏”,敬请关注。

来源:知乎—中二青年
地址:https://zhuanlan.zhihu.com/p/106810758


01

背景
BERT有多香这里就不赘述了,但最大的问题是没法上线。
先问个问题,bert这么大的参数量,参数拟合是否都合理?能不能用更少的参数达到类似的效果?
拿nlp的三大特征抽取器举例,cnn,rnn,transformer。如果你对细节足够了解。LSTM的门逻辑,transformer的attention都是从逻辑上,it should be work。但从数学上并没有solid的理论支持,告诉我们是真的达到拟合上限了。
所以,我们可以大胆把能不能中的不能去掉——我们能用更少的参数达到类似的效果。
业界主流的几种模型压缩的方式,剪枝,权重分解,参数共享,量化,如封面标题所示。
(图片转载于这篇文章(zhuanlan.zhihu.com/p/93)
但从工业界和学术界的反馈来讲,针对BERT的模型压缩,稍微主流一点的还是知识蒸馏。
这篇笔记也主要是针对知识蒸馏做一个梳理。

02

开山之作
参看paper——distilling the knowledge in a neural network
思路比较简单清晰,核心点在于loss function和temperature的定义。
2.1 核心思路
知识蒸馏的本质是让超大线下 teacher model来协助线上student model的training。
一种思路是学习超大model的参数,但由于nn可解释性弱加上复杂度高,故这篇文章直接放弃。但近期效果最好的tiny bert就是很好的利用了中间层的参数。
我们回忆一下,机器学习的很多任务,基本上都是输入扔进去,最后变成了一系列的概率值,或者logits——the input to the final softmax(不然为啥大部分任务的lost function都是softmax+cross entropy)。这些概率值本身就包含了teacher model的核心knowledge。
所以,这篇paper是用这些logits来teach student(当然我猜,另外一个原因还是这种思路好实现)

2.2 lost function如何设计
蒸馏流程图如下
loss function定义如下

结合流程图和公式就很直观了,soft loss function是老师带着学生学习,teacher和student的预测label取交叉熵。而hard loss function是学生看着标准答案学习,student的预测label和训练数据的label取交叉熵。
一个聪明的学生如果能合理的利用老师和参考答案的话(权重设置),大概率是比只有参考答案的学生学的好。
以及soft loss function中,老师为了让学生能够学的更仔细,不仅告诉了标准答案,还把错误的答案也解释了下,这就是temperature的作用。(问题,会不会有的场景下,T都为1反而效果更好?)

注意:student训练完毕后,T要变成1。

2.3 temperature是什么?
了解transformer的基本上一看就懂了。
不过,我这边想要把背景介绍下,这篇paper的实验都是基于MNIST数据集,有这么一句——For tasks like MNIST in which the cumbersome model almost always produces the correct answer with very high confidence。
个人理解,这是在暗示任务集不像真实环境那么复杂,比较简单。导致训练的teacher model的softmax最后产出的概率,往往都非常集中,除了预测值之外,其它的值都接近0。但事实上,可能分类1是概率0.6,分类2是0.2,分类3是0.1,分类2和分类3的knowledge也是我们想要学习的。
一个合适的T就能保证概率值不会那么集中,相当于老师不仅教了你标准答案,还告诉答案哪里错了。

2.4 训练样本的分配
论文中说,using the original training set works well。
2.5 遗留问题——matching logits is a special case of distillation
2.6 可优化点
很明显,虽然老师标准答案和错误答案都告诉你了,但中间步骤细节还是不知道。所以,后面的优化点,基本上就是更多的把中间层的参数也利用起来。

03

distill transformer to another framework
参看paper——DISTILLING TRANSFORMERS INTO SIMPLE NEURAL NETWORKS WITH UNLABELED TRANSFER DATA
这篇paper把bert蒸馏成lstm,提出了hard distillation 和 soft distillation两种方案。
3.1 hard distillation
这个思路需要满足两个条件,有少量专家样本,有海量未标注样本。
流程图如下
少量专家样本基于bert做fine-tunning后,根据我的经验,如果专家样本准确率足够高,分布足够合理,跑出来的model效果往往是非常令人满意的。
我们有了一个效果不错的线下model和海量未标注数据。那么用效果不错的线下model标注海量未标注数据,就得到了准确率还不错的海量训练数据。
最后用这些海量数据训练一个小model就可以了。
虽然思路简单,但冷启动阶段这么干,不要太爽,五星大力推荐。
3.2 soft distillation
框架图如下
简单介绍下三个loss function,细节建议看原paper。
第一,student的预测label要么和teacher预测的label,或者真实label做交叉熵,或者两个都用,这个看你需求。
第二,用student和teacher的logit做了一个square loss,logit的定义参看上文。
第三,student和teacher的representations层做了一个square loss,transformer和LSTM的representations层长度不一致,做个转换就好了。

3.3 可优化点
transformer和LSTM中间层差别太大,针对中间层的参数设计loss function是非常困难的。
故最终,这篇paper只加了logit层和representations层。
所以,transformer到transformer的蒸馏会利用更多的中间层信息。

04

distill transformer to transformer
参看paper——
Patient Knowledge Distillation for BERT Model
TINYBERT- DISTILLING BERT FOR NATURAL LANGUAGE UNDERSTANDING

transformer往transformer进行知识蒸馏因为框架的类似,所以中间层的参数的loss function设计相对简单,并且从逻辑上来说,it should be work。
除此之外,根据2019年的这篇paper——What does bert look at? an analysis of bert’s attention。得出这么一个结论——attention weights learned by BERT can capture substantial linguistic knowledge。所以,中间层的参数的学习看起来是非常有意义的。
这里特别要赞一下tiny bert,基本上把transformer中能尝试的中间层参数都设计了loss function,实验非常扎实。
4.1 from m to n (m > n)
我们要缩小bert的大小,第一个想法就是减少bert的层数。但这样大BERT和小BERT层数就没法一一对应了。
现在有两种思路
第一种,大bert每隔几层做一次映射到小bert上
第二种,大bert取最后几层做映射,因为最后几层往往包含更多的knowledge。
还是大家从逻辑上,觉得it should be work。
从华为的tiny bert来看,不同的任务适用的方法不一样,大白话就是自己试吧,我也不知道那个好。
4.2 tiny bert loss function的设计
m代表层数,M代表总层数。
Embd loss function是m = 0,0层,也就是输入层embeding。
Hidn loss function是M>= m > 0,中间层的hiden向量,attn loss function自然就是中间层的attention向量。
Pred loss function是m = M +1, 最后的输出层。

最后,各个层的loss function再做个相加,以及中间层的参数长度不一致自然要做矩阵转换。
基本上能用的中间层参数都给用上了,这个代码量,调参量,工作量,必须要点个赞。
4.3 two stage learning frame-work
其实就是genenra bert蒸馏后,发现效果还是不够好,那么在fine-tuned的时候再做一遍蒸馏。

05

总结
近期bert的知识蒸馏的,大致分成两种。
第一种,从transformer到非transformer框架的知识蒸馏
这种由于中间层参数的不可比性,导致从teacher model可学习的知识比较受限。但比较自由,可以把知识蒸馏到一个非常小的model,但效果肯定会差一些。
第二种,从transformer到transformer框架的知识蒸馏
由于中间层参数可利用,所以知识蒸馏的效果会好很多,甚至能够接近原始bert的效果。但transformer即使只有三层,参数量其实也不少,另外蒸馏过程的计算也无法忽视。
所以最后用那种,还是要根据线上需求来取舍。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“综述专栏”历史文章


更多综述专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存