查看原文
其他

ICLR 2023 | HomoDistil:蒸馏和剪枝在知识传递上的有机结合

An. PaperWeekly 2023-04-24

©Paperweekly 原创 · 作者 | An.

单位 | 中科院自动化所

研究方向 | 计算机视觉、模型压缩



论文标题:
HomoDistil: Homotopic Task-Agnostic Distillation of Pre-trained Transformers

论文链接:

https://arxiv.org/pdf/2302.09632.pdf



动机&背景

随着预训练大模型规模的不断增加,任务特定蒸馏(下游微调蒸馏)的成本越来越高,任务无关蒸馏变得越来越重要。然而,由于教师模型的模型容量和表示能力远超学生模型,因此学生很难在大量开放域训练数据上模仿教师的预测。本文提出了同源蒸馏(Homotopic Distillation, HomoDistil)来缓解这一问题,该方法充分利用了蒸馏和剪枝的优势,将两者有机结合在了一起

具体来说,本文用教师模型初始化学生模型,以缓解两者在蒸馏过程中的容量和能力差异,并通过基于蒸馏损失的重要性得分的迭代剪枝,来逐步将学生模型修剪至最终想要的目标结构。在整个蒸馏+剪枝的过程中,教师和学生一直保持着较小的预测差异,这有助于知识更有效的传递。其核心动机如图 1 所示。


▲ 图1. HomoDistil 动机说明(用剪枝给知识蒸馏做初始化,并迭代式地获得最终的学生模型结构)



HomoDistil:同源任务无关蒸馏

如图 2 所示,本文所提出的 HomoDistil 先用教师模型初始化学生,并以类似 TinyBERT [1] 的蒸馏损失函数作为修剪的目标函数,在每次迭代中,根据重要性得分从学生中删除最不重要的神经元并用蒸馏损失指导学生的训练。在整个训练过程中不断重复此过程,直至学生达到目标规模。该方法可从「蒸馏损失函数」和「迭代剪枝细节」两部分进行介绍。


▲ 图2. HomoDistil 方法的示意图,矩形的宽度表示层的宽度,颜色的深度反映训练的充分性。

2.1 蒸馏损失函数

本文采用了与 TinyBERT [1] 的通用蒸馏阶段类似的蒸馏损失函数进行任务无关的蒸馏整体损失函数可以分为三部分:a)任务损失:设 是学生模型在开放域数据上预训练的任务损失(例如 BERT 的掩码语言建模损失 );b)概率蒸馏损失:即 Hinton [2] 经典 KD 论文中的 KL 散度损失;c)Transformer 蒸馏损失:具体包括教师和学生的中间层及嵌入层的隐层表示的差异损失,以及中间层的注意力得分差异损失。
设教师和学生的第 层的隐层表示为 ,中间层隐层表示的蒸馏损失可定义为:
其中 是均分误差 随机初始化的可学习线性投影。类似地,嵌入层隐层表示的蒸馏损失可定义为:

式(2)中的 , 以及 与式(1)中的含义类似。最后,注意力蒸馏损失可定义为:
式(3)的 表示第 层注意力得分矩阵的平均值,Transformer 蒸馏损失旨在从教师的中间层中捕获丰富的语义和句法只是,以提高学生的泛化性能。最终的损失函数为:

2.2 迭代剪枝细节

下面将对迭代剪枝的几个关键要素的细节进行介绍。

初始化本文迭代剪枝的对象是学生模型,且初始状态为预训练的教师模型,即
权重更新本文以 SGD 为优化器,以式(5)的 为目标函数,在每轮迭代中对模型权重进行更新,即:
修剪准则本文采用 [3] 提出的敏感度作为重要性得分 中第 个参数的敏感度 定义为其梯度和权重的乘积大小,即:
剪枝粒度对于学生模型中的任意权重矩阵 ,本文将其对应的重要性得分记为 ,并以列作为剪枝粒度,每次迭代时最小的修剪粒度为权重的一整列,单个列的重要性得分 为:
式(7)中的重要性得分是基于 计算的,直观上讲,该算法会优先修剪那些删除后对任务损失、预测差异以及蒸馏知识传递影响最小的权重列
迭代修剪策略在每轮迭代中,我们将根据下式获得掩码矩阵。
其中, 常采用立方递减函数来调整迭代过程中的稀疏性, 是最终稀疏度, 为总训练迭代次数,这样可以保证稀疏性缓慢增加,列逐步被修剪,从而防止学生预测性能的突然下降,具体公式如下:
稀疏模式与过往剪枝方法常用的全局稀疏模式不同,本文采用的是针对单个权重的局部稀疏性,即修剪后的模型在所有权重矩阵内都满足一定的稀疏性要求。一方面,局部稀疏性对硬件和软件更优化,能够实现更大的推理加速;另一方面,局部稀疏性有助于更好地保持和教师模型相近的模型框架,这有助于蒸馏知识的传递。

2.3 与过往剪枝+蒸馏方法的对比

从蒸馏的角度表 1 展示了蒸馏视角下 HomoDistil 和现有“剪枝+蒸馏”方法的区别。

▲ 表1. 蒸馏视角下 HomoDistil 和其他方法的对比
从剪枝的角度表 2 展示了剪枝视角下 HomoDistil 和现有“剪枝+蒸馏”方法的区别。

▲ 表2. 剪枝视角下 HomoDistil 和其他方法的对比




实验

GLUE 数据集:如表 3 所示,HomoDistil 在 6/8个任务上取得了最优的性能,并在 MNLI、SST-2 和 CoLA 上取得了显著提升。对于 10~20M 参数量的学生,增幅更为显著。


▲ 表3. GLUE 验证集上微调后的蒸馏模型性能对比(取 5 个随机种子结果的中位数)

SQuAD 数据集:表 4 的结果充分证明了 HomoDistil 方法的有效性,所有的 HomoBERT 学生都比最佳基线(MiniLM)高出 3 个点以上。


▲ 表4. SQuAD v1.1/2.0 验证集上微调后的蒸馏模型性能对比(取 5 个随机种子结果的中位数)


消融 1-损失函数:表 5 显示,使用蒸馏损失训练的学生性能始终优于没有蒸馏损失的模型,这说明教师的知识对于恢复由修剪导致的性能下降至关重要。


▲ 表5. GLUE 数据集上有无蒸馏损失的 HomoBERT 性能对比


消融 2 - 重要性得分:表 6 说明敏感度和 PLATON [4] 的重要性得分指标优于基线。


▲ 表6. GLUE 数据集上不同重要性得分指标下 HomoBERT 的性能对比



分析

预测差异:图 3 证明了本文动机的合理性,实验表明,由完整教师模型初始化的学生在整个迭代修剪过程中都比随机初始化的模型具有更小的预测差异。


▲ 图3. HomoBERT 在迭代修剪参数下蒸馏过程中的预测差异


通用知识传递:从图 4 可以看出,使用教师模型初始化的学生,在经过修剪后,在下游任务上具有明显更好的泛化性能,这说明 HomoDistil 这一范式确实传递了通用的任务无关的知识。


▲ 图4. GLUE 下不同迭代修剪参数的 HomoBERT-small 的微调准确率

参考文献

[1] Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. Tinybert: Distilling bert for natural language understanding. arXiv preprint arXiv:1909.10351, 2019.
[2] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
[3] Molchanov, Pavlo, et al. “Importance Estimation for Neural Network Pruning.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020.

[4] Qingru Zhang, Simiao Zuo, Chen Liang, Alexander Bukharin, Pengcheng He, Weizhu Chen, and Tuo Zhao. Platon: Pruning large transformer models with upper confidence bound of weight importance. In International Conference on Machine Learning, pp. 26809–26823. PMLR, 2022.



更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存