NeurIPS 2022 | 知识蒸馏中如何让“大教师网络”也教得好?
本文介绍一篇发表在机器学习顶会 NeurIPS 2022 (CCF-A 类会议)的论文《Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again》。该工作的研究内容为知识蒸馏(Knowledge Distillation),是与华为诺亚联合实验室共同研究产出的一篇工作。本文分为以下几个部分对该工作进行介绍:
文章链接 代码链接 研究背景 提出的方法 实验效果 投稿历程
论文题目:
Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again
NeurIPS 2022
https://arxiv.org/abs/2210.04427
https://github.com/lxcnju/ATS-LargeKD
https://gitee.com/lxcnju/ats-mindspore
http://www.lamda.nju.edu.cn/lixc/
研究背景
知识蒸馏(Knowledge Distillation)可以将大(强)模型的能力传递给轻量(弱)模型,其基本形式如下:
其基本步骤为:1)在训练集上训练一个大教师网络,或者拿现有的当作教师网络;2)使用图示的损失去指导学生网络进行训练。损失包括两部分:正常分类的损失和知识蒸馏损失。前者是 hard-label,后者是 soft-label。引入后者的目的是因为学生直接学习 hard-label 太困难了,因此期望学生能够模仿教师的 soft 输出,从而把握类别之间的相似度,从而更好地学习。
▲ 文章的研究内容
提出的方法
每个样本正确类别的 logit,记作 每个样本错误类别 logits 间的方差,记作
如果 target logit 非常大,那么无论用什么温度系数对教师的输出进行 softmax,最后得到的 都为 one-hot 形式; 如果 wrong logits 之间差异很小,就假设都一样,那么无论用什么温度系数对教师的输出进行 softmax,最后得到的 在错误类别之间都无法提供差异化信息。
Inherent Variance:错误类别 logits 经过 softmax 之后得到的类别概率分布的方差; Derived Average:所有类别 logits 经过 softmax 之后得到的错误类别概率的平均值; Derived Variance:所有类别 logits 经过 softmax 之后得到的错误类别概率的方差。
大教师网络给的正确类别 logit 的值很大,导致 DA 变小; 大教室网络给的错误类别 logit 的差异很小,导致 IV 变小。
大教师网络给的正确类别 logit 的值很大,用较大的 增大 DA; 大教室网络给的错误类别 logit 的差异很小,用较小的 增大 IV。
实验结果
投稿历程
发现大神级网络和小神经网络输出的结果具有一些差异,兴奋值 ++; 发现可以将知识蒸馏的损失分解为三部分,特别是 class discriminability 的定义很有意思,兴奋值 ++; 发现可以用公式解释大教师神经网络的 DV 很小,兴奋值 ++; 发现可以提出一个非常简单的 ATS 来使得大教师教地更好,兴奋值 ++。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」