查看原文
其他

NeurIPS 2022 | 知识蒸馏中如何让“大教师网络”也教得好?

李新春 PaperWeekly 2023-02-10

©作者 | 李新春
单位 | 南京大学
研究方向 | 知识蒸馏


本文介绍一篇发表在机器学习顶会 NeurIPS 2022 (CCF-A 类会议)的论文《Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again》。该工作的研究内容为知识蒸馏(Knowledge Distillation),是与华为诺亚联合实验室共同研究产出的一篇工作。本文分为以下几个部分对该工作进行介绍:


  • 文章链接
  • 代码链接
  • 研究背景
  • 提出的方法
  • 实验效果
  • 投稿历程


论文题目:

Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again

文章来源:

NeurIPS 2022

论文链接:

https://arxiv.org/abs/2210.04427

代码链接:

https://github.com/lxcnju/ATS-LargeKD

https://gitee.com/lxcnju/ats-mindspore

作者主页:

http://www.lamda.nju.edu.cn/lixc/





研究背景


知识蒸馏(Knowledge Distillation)可以将大(强)模型的能力传递给轻量(弱)模型,其基本形式如下:


▲ 知识蒸馏示意图


其基本步骤为:1)在训练集上训练一个大教师网络,或者拿现有的当作教师网络;2)使用图示的损失去指导学生网络进行训练。损失包括两部分:正常分类的损失和知识蒸馏损失。前者是 hard-label,后者是 soft-label。引入后者的目的是因为学生直接学习 hard-label 太困难了,因此期望学生能够模仿教师的 soft 输出,从而把握类别之间的相似度,从而更好地学习。


值得注意的是:知识蒸馏损失里面的温度系数 Temperature 很重要!如果 很小,那么教师的输出结果像 hard-label,导致和正常分类损失相比没有什么额外的信息;如果 很大,那么教师的输出结果像 uniform-label,类别之间的差异性就没有了,仅仅起到了一个 label smoothing 的作用。

▲ 知识蒸馏中温度系数的作用

普遍的认知是越好的教师教学生教地越好。然而实际上,2019 年有学者 [Jang Hyun Cho, 2019] 指出:大神经网络不一定教地好!

引用下面的一个示意图(来自 [Seyed Iman Mirzadeh, 2020]),随着 teacher size 逐渐变大,教师的准确率越来越高(红色的 teacher accuracy),但是其教的学生的准确率先变高再变低(蓝色的 student accuracy)。现有的工作都将这个奇怪的现象归因于”大教师网络“和”小学生网络“之间的容量差异(capacity gap),但是没有形象地指出这种差异为何出现。


▲ 大神经网络不一定教得好

因此,本文的研究内容就是:为什么大神经网络不一定教地好,有没有什么简单的办法让大神经网络教地好?


▲ 文章的研究内容




提出的方法


本文最直接的猜测起源于下面的式子:

▲ 大教师网络和小教师网络在教同一个学生网络的区别

也就是说,在遍历所有可能温度系数的情况下,相比较于大教师网络,小教师网络更容易给出质量更好的指导信息,即

首先,文章通过一些观察实验发现:大教师网络更容易给出置信度较高的预测,包括两个方面。其一,正确类别的 logit 可能更大;其二,错误类别 logits 之间差异更小。本文称神经网络最后一层给出的类别预测得分称为 logits。

▲ 大神经网络和小神经网络的输出 logits 的分布

具体地,在 CIFAR-100 和 CIFAR-10 上训练 ResNet14/44/110 和 WRN28-1/4/8,统计神经网络输出的 logits 的如下指标:

  • 每个样本正确类别的 logit,记作  
  • 每个样本错误类别 logits 间的方差,记作  

可以看出,在 CIFAR-100 上,ResNet110 很明显给出了更大的 ,在 CIFAR-10 上 WideResNet28-8 给出了更小的 。举例而言,给定五个类别,第一个为正确类别,小教师网络和大教师网络给出的 logits 大概如下:

▲ 大教师和小教师给出的logits示例

这就是最基本的现象,也是整个工作的启发点:大神经网络更为置信,给出的 target logit 更大,或者 wrong logits 差异更小!

那么我们不妨设想两个极端:

  • 如果 target logit 非常大,那么无论用什么温度系数对教师的输出进行 softmax,最后得到的 都为 one-hot 形式;
  • 如果 wrong logits 之间差异很小,就假设都一样,那么无论用什么温度系数对教师的输出进行 softmax,最后得到的 在错误类别之间都无法提供差异化信息。

也就是说,大教师网络的高置信度导致:无论在什么样子的温度系数下,其给出的指导信息(即:)都很难具有足够有效的信息!这里足够有效的信息如何定义呢?本文将其定义为,错误类别之间的概率值的方差!

▲ 根据现象得到的猜测

因此本文的猜测为:大神经网络不能教地好的原因是无论使用怎样的温度系数,都难以使得错误类别概率“错落有致”。

为了从理论上去推导验证,本文将知识蒸馏分为三个部分:

▲ 知识蒸馏分解

分别包括:1)Correct Guidance,类似于 hard-label 的 one-hot 标签;2)Smooth Regularization,错误类别的平均概率值,类似于 label smoothing;3)Class Discriminability,错误类别之间的差异,可以用方差来度量,错误类别差异越大,教师提供的指导信息越多!

接下来是理论分析,先定义一些符号和公式:

▲ 一些基本的符号

理论分析:


事实上:随着 不断增大,得到的 的熵越来越大,即越来越均匀。本研究证明了:随着 不断增大,得到的 元素之间的方差也会越来越小。


在正确类别 logit 最大情况下:随着 不断增大,错误类别概率的均值 会逐渐增大。

最重要的等式为:


其中 DA、IV、DV 分别解释如下:

  • Inherent Variance:错误类别 logits 经过 softmax 之后得到的类别概率分布的方差;
  • Derived Average:所有类别 logits 经过 softmax 之后得到的错误类别概率的平均值;
  • Derived Variance:所有类别 logits 经过 softmax 之后得到的错误类别概率的方差。


针对某一个样本的计算示意图如下(SF 代表 Softmax):

▲ DA、IV、DV关系示意图

利用该公式解释为什么大教师网络教不好:


翻译为中文为:

  • 大教师网络给的正确类别 logit 的值很大,导致 DA 变小;
  • 大教室网络给的错误类别 logit 的差异很小,导致 IV 变小。


最终都会导致 DV 变小,即:大教师网络的 DV 很小,传统温度缩放下很难让错误类别的概率“错落有致”。

提出的方法为 Asymmetric Temperature Scaling(ATS),针对正确/错误类别施加较大/较小的温度系数:

  • 大教师网络给的正确类别 logit 的值很大,用较大的 增大 DA;
  • 大教室网络给的错误类别 logit 的差异很小,用较小的 增大 IV。

结论:ATS 可以使得大教师网络的 DV 变大让错误类别的概率“错落有致”。




实验结果


实验设置和结果就不详细介绍了,有兴趣的可以看文章。下面就简单贴一下结果:





投稿历程


到此,本文的基本方法都介绍完了,是一个非常简单的改进。研究设计的过程中也充满了乐趣,主要包括三个过程:

  1. 发现大神级网络和小神经网络输出的结果具有一些差异,兴奋值 ++;
  2. 发现可以将知识蒸馏的损失分解为三部分,特别是 class discriminability 的定义很有意思,兴奋值 ++;
  3. 发现可以用公式解释大教师神经网络的 DV 很小,兴奋值 ++;
  4. 发现可以提出一个非常简单的 ATS 来使得大教师教地更好,兴奋值 ++。


该工作完成于 2021.1 月份左右,在新年前几天完成的,满怀期待投稿了 ICML 2022。很不幸的是被拒了,个人感觉是在边缘,因为审稿人给的意见都没有特别严重的,主要是一些行文思路和概念没有解释清楚。

于是完善了之后转投了 NeurIPS,得分为 2 (Strong Reject),5(Borderline Accept),6 (Weak Accept)。看到审稿意见本想放弃,但仔细一看给 2 分的貌似只是针对我们公式符号的不合理性进行了攻击,感觉还是有希望的。于是修改了符号,提交了 rebuttal revision。审稿人然后就将分数改为 6。最终得分为 666。


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存