查看原文
其他

​ACL 2023 | AD-KD:归因驱动的预训练语言模型知识蒸馏框架

吴思越 PaperWeekly 2023-09-02
©PaperWeekly 原创 · 作者 | 吴思越
学校 | 中山大学硕士
研究方向 | 自然语言处理,知识蒸馏


论文链接: 
https://arxiv.org/abs/2305.10010

代码链接:

https://github.com/brucewsy/AD-KD



动机

近年来,基于 Transformer 结构的预训练语言模型在自然语言处理(NLP)领域取得了极大的成功,尽管性能很强大,但是随着模型规模越来越庞大,其在资源受限的场景下部署成为了一大难题。知识蒸馏,是较为成熟的模型压缩技术之一,基本思路是用大而强的教师模型辅助训练小而弱的学生模型。

然而,现有的针对预训练语言模型的蒸馏方法存在两个问题:首先,学生模型仅仅是模仿教师模型的表面行为,却忽略了教师模型背后的推理依据;其次,用于蒸馏的特征与模型强耦合(model-specific),而相比之下我们认为靠近数据端的特征更具有可迁移性。为此我们提出了一种归因驱动的预训练语言模型知识蒸馏框架 AD-KD,使学生不仅学习到教师浅层的表面知识,也能学习到深层的推理知识,提高学生模型的泛化性。




相关工作

本文涉及两个方面的背景知识:1)知识蒸馏(Knowledge Distillation);2)归因分析(Attribution)。

2.1 知识蒸馏

在现有的知识蒸馏方法中,知识可以分成三类:
  • 基于响应(response-based)的知识 [1,2]:模型输出层的软标签
  • 基于特征(feature-bassed)的知识 [3,4]:模型中间层的隐状态
  • 基于关系(realtion-based)的知识 [5,6]:不同 token 或不同层隐状态间的相对关系

这些知识只能揭示教师模型的表面行为,无法反映内在的推理依据。

2.2 归因分析

归因分析,旨在为模型的输入特征或中间层特征赋予重要性分数,以衡量其对模型最终预测的贡献程度。常见的归因方法包括擦除法 [7]、梯度法 [8]、传播法 [9] 和扰动法 [10]。归因分析已经广泛应用于其他模型压缩技术,如剪枝和动态推理。在本工作中,我们尝试探索归因分析在知识蒸馏中的应用。




方法

3.1 归因蒸馏

为了在学生模型和教师模型之间传递归因知识,我们的归因蒸馏包含三个步骤:
  • Token Embedding Attribution:根据模型的输入和输出,采用积分梯度(Integrated Gradients)[8] 计算输入 embedding 每个维度的重要分数。这里为了更加全面地利用教师的软标签信息,我们对每个潜在的标签分别进行一次归因的过程,称为多视角归因(Multi-view Attribution)。
  • Attribution Maps Computation:对单个 token embedding 归因的结果计算 L2 范数,将 embedding 层面的细粒度重要性分数转化为 token 层面的粗粒度重要性分数,称为归因图(attribution maps)。同时,在对教师模型侧计算 L2 范数之前,我们使用 Top-k 过滤掉归因分数较低的 embedding 维度,目的是为了减少教师模型归因知识的噪声,降低学生模型的学习难度。

  • Attribution Alignment:在蒸馏初期,教师模型已经在目标任务上充分训练,而学生模型是尚未训练的,两者的归因图之间存在较大差距,如果强行对齐每个 token 的绝对归因分数会使得学生模型优化困难。考虑到 token 之间的相对归因分数信息更加重要,我们对每个视角下的归因图先做归一化(Normalization),保留相对差异的同时消除绝对量级的影响,最后再对教师模型和学生模型的归一化多视角归因图进行 MSE 对齐。
3.2 损失函数

我们最终的损失函数包含三项,分别是与硬标签的交叉熵损失,与教师软标签的 KL 散度损失,以及归因蒸馏损失。




实验分析

我们主要在 GLUE 基准数据集上验证了 AD-KD 的有效性,并做了许多的消融分析实验。

4.1 总体实验结果

相比于 SOTA 方法 CKD 和 MGSKD ,在验证集上,AD-KD 平均提升了 1.0 和 1.9 个点;在测试集上,AD-KD 平均提升了 0.9 个点。

4.2 消融实验

不同损失项的影响:在大部分数据集上,归因蒸馏损失项对学生性能的贡献程度最大。
多视角归因的作用:单视角归因优于 vanilla KD,而多视角归因优于单视角归因。

不同学生模型规模的对比:在不同的规模下,AD-KD 相比于 vanilla KD 具有一致的优势。

4.3 归因层的影响

在对比不同的归因位置(输入层、第一层、倒数第二层或均匀抽取)后,我们发现输入层的归因信息对蒸馏的提升作用最显著,因为其最具有全局性。

4.4 样例分析

我们在 QNLI 任务的两个具体样例上对比了教师模型、vanilla KD 和 AD-KD的预测结果以及归因图,结果表明 AD-KD 可以正确地学习到应该关注哪些 token,以及不应该关注哪些 token,从而输出与教师一致的答案。

此外我们还做了各种超参数的影响等实验。这里就不一一详细介绍了,感兴趣的朋友请移步原文查看。


参考文献

[1] Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. 2019. Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108.

[2] Iulia Turc, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. Well-read students learn better: On the importance of pre-training compact models. arXiv preprint arXiv:1908.08962.

[3] Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019. Patient knowledge distillation for BERT model compression. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 4323–4332, Hong Kong, China. Association for Computational Linguistics.

[4] Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. 2020. TinyBERT: Distilling BERT for natural language understanding. In Findings of the Association for Computational Linguistics: EMNLP 2020, pages 4163–4174, Online. Association for Computational Linguistics.

[5] Geondo Park, Gyeongman Kim, and Eunho Yang. 2021. Distilling linguistic context for language model compression. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 364–378, Online and Punta Cana, Dominican Republic. Association for Computational Linguistics

[6] Chang Liu, Chongyang Tao, Jiazhan Feng, and Dongyan Zhao. 2022. Multi-granularity structural knowledge distillation for language model compression. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1001–1011, Dublin, Ireland. Association for Computational Linguistics

[7] Matthew D Zeiler and Rob Fergus. 2014. Visualizing and understanding convolutional networks. In 13th European Conference on Computer Vision, ECCV 2014, pages 818–833. Springer Verlag.
[8] Mukund Sundararajan, Ankur Taly, and Qiqi Yan. 2017. Axiomatic attribution for deep networks. In International Conference on Machine Learning, pages 3319–3328. PMLR.
[9] Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. 2017. Learning important features through propagating activation differences. In International Conference on Machine Learning, pages 3145–3153. PMLR.
[10] Karl Schulz, Leon Sixt, Federico Tombari, and Tim Landgraf. 2020. Restricting the flow: Information bottlenecks for attribution. arXiv preprint arXiv:2001.00396.


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存