背景BERT有多香这里就不赘述了,但最大的问题是没法上线。先问个问题,bert这么大的参数量,参数拟合是否都合理?能不能用更少的参数达到类似的效果?拿nlp的三大特征抽取器举例,cnn,rnn,transformer。如果你对细节足够了解。LSTM的门逻辑,transformer的attention都是从逻辑上,it should be work。但从数学上并没有solid的理论支持,告诉我们是真的达到拟合上限了。所以,我们可以大胆把能不能中的不能去掉——我们能用更少的参数达到类似的效果。 业界主流的几种模型压缩的方式,剪枝,权重分解,参数共享,量化,如封面标题所示。(图片转载于这篇文章(zhuanlan.zhihu.com/p/93)但从工业界和学术界的反馈来讲,针对BERT的模型压缩,稍微主流一点的还是知识蒸馏。 这篇笔记也主要是针对知识蒸馏做一个梳理。
02
开山之作参看paper——distilling the knowledge in a neural network 思路比较简单清晰,核心点在于loss function和temperature的定义。2.1 核心思路知识蒸馏的本质是让超大线下 teacher model来协助线上student model的training。 一种思路是学习超大model的参数,但由于nn可解释性弱加上复杂度高,故这篇文章直接放弃。但近期效果最好的tiny bert就是很好的利用了中间层的参数。 我们回忆一下,机器学习的很多任务,基本上都是输入扔进去,最后变成了一系列的概率值,或者logits——the input to the final softmax(不然为啥大部分任务的lost function都是softmax+cross entropy)。这些概率值本身就包含了teacher model的核心knowledge。所以,这篇paper是用这些logits来teach student(当然我猜,另外一个原因还是这种思路好实现)
2.2 lost function如何设计蒸馏流程图如下loss function定义如下
结合流程图和公式就很直观了,soft loss function是老师带着学生学习,teacher和student的预测label取交叉熵。而hard loss function是学生看着标准答案学习,student的预测label和训练数据的label取交叉熵。一个聪明的学生如果能合理的利用老师和参考答案的话(权重设置),大概率是比只有参考答案的学生学的好。以及soft loss function中,老师为了让学生能够学的更仔细,不仅告诉了标准答案,还把错误的答案也解释了下,这就是temperature的作用。(问题,会不会有的场景下,T都为1反而效果更好?)
注意:student训练完毕后,T要变成1。
2.3 temperature是什么?了解transformer的基本上一看就懂了。不过,我这边想要把背景介绍下,这篇paper的实验都是基于MNIST数据集,有这么一句——For tasks like MNIST in which the cumbersome model almost always produces the correct answer with very high confidence。个人理解,这是在暗示任务集不像真实环境那么复杂,比较简单。导致训练的teacher model的softmax最后产出的概率,往往都非常集中,除了预测值之外,其它的值都接近0。但事实上,可能分类1是概率0.6,分类2是0.2,分类3是0.1,分类2和分类3的knowledge也是我们想要学习的。一个合适的T就能保证概率值不会那么集中,相当于老师不仅教了你标准答案,还告诉答案哪里错了。
2.4 训练样本的分配论文中说,using the original training set works well。2.5 遗留问题——matching logits is a special case of distillation2.6 可优化点很明显,虽然老师标准答案和错误答案都告诉你了,但中间步骤细节还是不知道。所以,后面的优化点,基本上就是更多的把中间层的参数也利用起来。
03
distill transformer to another framework参看paper——DISTILLING TRANSFORMERS INTO SIMPLE NEURAL NETWORKS WITH UNLABELED TRANSFER DATA这篇paper把bert蒸馏成lstm,提出了hard distillation 和 soft distillation两种方案。3.1 hard distillation这个思路需要满足两个条件,有少量专家样本,有海量未标注样本。 流程图如下少量专家样本基于bert做fine-tunning后,根据我的经验,如果专家样本准确率足够高,分布足够合理,跑出来的model效果往往是非常令人满意的。我们有了一个效果不错的线下model和海量未标注数据。那么用效果不错的线下model标注海量未标注数据,就得到了准确率还不错的海量训练数据。最后用这些海量数据训练一个小model就可以了。虽然思路简单,但冷启动阶段这么干,不要太爽,五星大力推荐。3.2 soft distillation框架图如下简单介绍下三个loss function,细节建议看原paper。第一,student的预测label要么和teacher预测的label,或者真实label做交叉熵,或者两个都用,这个看你需求。第二,用student和teacher的logit做了一个square loss,logit的定义参看上文。第三,student和teacher的representations层做了一个square loss,transformer和LSTM的representations层长度不一致,做个转换就好了。
distill transformer to transformer参看paper——Patient Knowledge Distillation for BERT Model TINYBERT- DISTILLING BERT FOR NATURAL LANGUAGE UNDERSTANDING
transformer往transformer进行知识蒸馏因为框架的类似,所以中间层的参数的loss function设计相对简单,并且从逻辑上来说,it should be work。除此之外,根据2019年的这篇paper——What does bert look at? an analysis of bert’s attention。得出这么一个结论——attention weights learned by BERT can capture substantial linguistic knowledge。所以,中间层的参数的学习看起来是非常有意义的。这里特别要赞一下tiny bert,基本上把transformer中能尝试的中间层参数都设计了loss function,实验非常扎实。4.1 from m to n (m > n)我们要缩小bert的大小,第一个想法就是减少bert的层数。但这样大BERT和小BERT层数就没法一一对应了。现在有两种思路第一种,大bert每隔几层做一次映射到小bert上第二种,大bert取最后几层做映射,因为最后几层往往包含更多的knowledge。还是大家从逻辑上,觉得it should be work。从华为的tiny bert来看,不同的任务适用的方法不一样,大白话就是自己试吧,我也不知道那个好。4.2 tiny bert loss function的设计m代表层数,M代表总层数。Embd loss function是m = 0,0层,也就是输入层embeding。Hidn loss function是M>= m > 0,中间层的hiden向量,attn loss function自然就是中间层的attention向量。Pred loss function是m = M +1, 最后的输出层。
最后,各个层的loss function再做个相加,以及中间层的参数长度不一致自然要做矩阵转换。基本上能用的中间层参数都给用上了,这个代码量,调参量,工作量,必须要点个赞。4.3 two stage learning frame-work 其实就是genenra bert蒸馏后,发现效果还是不够好,那么在fine-tuned的时候再做一遍蒸馏。