查看原文
其他

让学生网络相互学习,为什么深度相互学习优于传统蒸馏模型?| 论文精读

CSDN App AI科技大本营 2019-10-30

作者 | Ying Zhang,Tao Xiang等
译者 | 李杰
出品 | AI科技大本营(ID:rgznai100)

蒸馏模型是一种将知识从教师网络(teacher)传递到学生网络(student)的有效且广泛使用的技术。通常来说,蒸馏模型是从功能强大的大型网络或集成网络转移到结构简单,运行快速的小型网络。本文决定打破这种预先定义好的“强弱关系”,提出了一种深度相互学习策略(deep mutual learning, DML)。

在此策略中,一组学生网络在整个训练过程中相互学习、相互指导,而不是静态的预先定义好教师和学生之间的单向转换通路。作者通过在CIFAR-100和Market-1501数据集上的实验,表明DML网络在分类和任务重识别任务中的有效性。更重要的是,DML的成功揭示了没有强大的教师网络是可行的,相互学习的对象是由一个个简单的学生网络组成的集合。

简介


深度神经网络已经广泛应用到计算机视觉的各个任务中,并获得了很好的性能表现,但是,这种SOTA通常是依靠深度堆叠网络层数,增加网络宽度实现,这种结构设计会产生大量的参数,一方面,会拖慢运行速度和执行效率,另一方面,需要很大的存储空间进行存储。这两方面也限制了很多网络在实际应用中落地。

因此,如何在保证效果的情况下设计更小,更快速的网络,就成了我们关注的重点。基于这种思想,涌现了很多好的工作,蒸馏模型(model distillation)就是其中的代表,为了更好地学习小型网络,蒸馏方法从一个强大的(更深或更宽)教师网络开始,然后训练一个更小的学生网络来模仿教师网络。下图是蒸馏模型的一个结构表示:

蒸馏模型

符号表示:

  • Big Model:复杂强大的教师网络

  • Small Model:轻巧简单的学生网络

  • soft targets:输入x经过教师网络后得到的softmax层输出

  • hard targets:输入数据对应的label标签

  • softmax公式表示:



其中,qi是第i类的概率, Zi和Zj分别表示softmax层输出,T是温度系数,控制着输出概率的软化(soft)程度,T越大,不同类别输出概率在不改变相对大小关系的情况下,差值会越小,也就是更加soft。

定义好基本概念后,实现步骤可以表示为:

1.设置一个较大的 T,输入x训练一个教师网络,经过softmax层后生成soft targets。
2.使用步骤1得到的soft targets来训练学生网络。
3.最终模型的目标函数由soft targets和学生网络的输出数据的交叉熵,hard targets 和学生网络的输出数据的交叉熵两部分共同组成。

这些训练步骤,能够保证学生网络和教师网络的结果尽可能一致,也就代表学生网络学到了教师网络的知识;能够保证学生网络的结果和实际类别标签尽可能一致,也就代表学生网络的能力很强。

在本文中,作者剥离了教师,学生网络的概念,提出了一个与蒸馏模型不同但又相关的概念——相互学习(mutual learning)。通过上文的介绍我们知道,蒸馏模型从一个强大的、预先培训过的教师网络开始,然后将知识传递给一个小的、未经训练的学生网络,这种传递方式是一条单向的通路。与之相反,在相互学习中,从一组未经训练的学生网络开始,它们同时学习,共同解决任务。

在训练过程中,每个学生网络的损失函数由两部分组成:(1)传统的监督学习损失(2)模仿损失,使每个学生的预测类别与其他学生的类别预测概率保持一致。实验证明,在这种基于同伴教学(peer-teaching)的训练中,每个学生网络的学习效果都比在传统的监督学习场景中单独学习要好得多。此外,虽然传统的蒸馏模型需要一个比预期学生网络更大、更强大的教师网络,但事实证明,在许多情况下,几个大型网络的相互学习也比独立学习提高了性能。

你可能会有这样的疑问:为什么这种相互学习的策略会比蒸馏模型更有用?如果整个训练过程都是从一个小的且未经预训练的网络开始,那么网络中额外的知识从哪里产生?为什么它会收敛的好,而不是被群体思维所束缚,造成“瞎子带领瞎子"(theblind lead the blind)的局面?

针对这些疑问,作者给出了相应的解释:每个学生网络主要受传统的监督学习损失的指导,这意味着他们的表现通常会提高,而且也限制了他们作为一个群体任意地进行群体思维的能力。有了监督学习,所有的网络很快就可以为每个训练实例预测相同的标签,这些标签大多是正确且相同的。

但是由于每个网络从不同的初始条件开始,它们对下一个最有可能的类的概率的估计是不同的,而正是这些secondary信息,为蒸馏和相互学习提供了额外的知识。在相互学习网络中,每个学生网络有效地汇集了他们对下一个最有可能的类别的集体估计,根据每个训练实例找出并匹配其他最有可能的类会增加每个学生网络的后验熵,这有助于得到一个更健壮和泛化能力更强的网络。

综上所述,相互学习通过利用一组小的未经训练的网络协作进行训练,可以简洁而有效的提高网络的泛化能力。实验结果表明,与经过预训练的静态大型网络相比,同伴相互学习可以获得更好的性能。此外,作者认为相互学习还有以下几点优势:

1.网络效果随队列中网络的数量增加而增加;
2.相互学习适用于各种网络架构,以及由不同大小的混合网络组成的异构群组;
3.与独立训练相比,即使是在队列中相互训练的大型网络也能提高性能;
4.虽然作者的重点是获得一个单一有效的网络,但整个队列也可以整合为一个高效的集成模型。
深度相互学习
*为表述方便,本文以两个网络为例进行说明。

  • DML通用表示


如下图所示,本文提出的DML网络,在队列中有两个网络θ1,θ2。给定来自 M个类别的 N个样本,表示为:

其对应的标签集合为:

那么θ1网络中某个样本 xi属于类别 m的概率可以表示为:

其中,是θ1网络中经过softmax层后输出的预测概率。

对于多目标分类任务而言,θ1网络的目标函数可以用交叉熵表示:

其中, 相当于一个指示函数,如下式所示,如果标签值和预测值相同,置为1,否则置为0:

传统的监督损失训练能够帮助网络预测实例的正确标签,为了进一步提升网络θ1的泛化能力,DML引入了同伴网络θ2,θ2同样会产生一个预测概率p2,在这里引入KL散度的概念,相信了解过GAN网络的小伙伴对KL应该不会陌生,KL 散度是一种衡量两个概率分布的匹配程度的指标,两个分布差异越大,KL散度越大。作者采用KL散度,衡量这两个网络的预测p1和p2是否匹配。
p1和p2的KL散度距离计算公式为:

综上,对于θ1网络来说,此时总的损失函数就由两部分构成:自身监督损失函数,来自θ2网络的匹配损失函数:

同理,θ2可以表示为:

  • 算法优化


DML在每次训练迭代中,都计算两个模型的预测,并根据另一个模型的预测更新两个网络的参数。θ1和θ2网络一直在迭代直至收敛,整个训练优化细节如下图所示:
输入:
训练集 X,标签集 Y,学习率 γ1,γ2

初始化:
θ1,θ2不同初始化条件

步骤:
从训练集 X中随机抽样 x
1.根据上文中的概率计算公式p,分别计算两个网络的在当前batch的预测p1和p2,得到θ1的总损失函数 Lθ1
2.利用随机梯度下降,更新θ1参数:

3. 根据上文中的概率计算公式p,分别计算两个网络的在当前batch的预测p1和p2,得到θ2的总损失函数Lθ2
4. 利用随机梯度下降,更新θ2参数:


重复以上步骤直至网络收敛

  • 学生网络的扩展


前几节我们用两个网络θ1和θ2说明了DML的结构,算法。其实DML不仅在两个网络中有效,还可以扩展到多个网络中去。假定我们要训练一个有K(K>2)个学生网络的互相学习网络,那么对于其中的某个网络θk而言,总的损失函数可以表示为:

该公式说明,每个学生网络都能够从另外的K-1个网络中学到知识,换而言之,对于一个学生网络,另外的K-1个网络都能作为该网络的教师网络。K=2就是该扩展网络的特例。注意,在上式中,对于其他网络的KL散度和,前面添加了权重系数1/(K-1),这是为了确保整个训练过程主要以监督学习的真正标签为指导。

对于两个以上的网络,除了DML训练策略外,在K个网络的训练中,对于一个学生网络,我们还可以将所有其他的k-1个网络集成作为一个单独的教师网络来提供综合平均的学习知识。这种思想与蒸馏模型类似,但是在参数更新上,在每个mini-batch上进行更新。基于这种思想,一个学生网络θk的损失函数可以表示为:

实验
主要是两方面的实验,利用CIFAR-100和Market-1501两个数据集分别进行目标分类和人物重识别任务测试。

  • Results on CIFAR-100


在CIFAR-100上进行top-1指标测试。首先对只有两个网络的DML进行测试,。采用不同的网络结构,结果如下表所示,可以看到,相比独立的分类网络,基于任何组合方式的,添加DML策略的网络,表现都有所提升;体量较小的网络(如ResNet-32),从DML中提升更多;在大网络(如WRN-28-10)中添加DML策略,也会使得性能得到提升,与传统的蒸馏模型相比,可以看到一个大型的预培训的教师网络并非必要条件。



  • Results on Market-1501


在Market-1501上进行mAP和rank-1指标测试。每个MobileNet在一个双网络队列中训练,并计算队列中两个网络的平均性能。如下表所示,与单独学习相比,DML显著的提升了MobileNet的性能,我们还可以看到,使用两个MobileNet训练的DML方法的性能显著优于先前最主流的方法。


  • Comparison with Distillation


本文提出的DML模型与蒸馏模型密切相关,因此作者对比了这两个模型的效果。如下表所示,设置了三组网络,分别是:独立网络net1,net2;蒸馏模型net1为教师网络,net2为学生网络;DML模型,net1和net2相互学习。从实验结果分析,意料之中,传统的蒸馏方法从一个强大的预训练的教师网络指导学生网络的确实提升了性能。但结果同样表明,预训练的强大教师网络不是必要条件,与蒸馏模型相比,在DML中一起训练的两个网络也获得了明显的提升。


  • DML的有效性


上述实验部分证明了DML的有效性,我们再从理论上讨论一下DML为什么能够提升以及通过哪些方法进行提升。

1)更鲁棒的最小值

与传统的优化方法相比,DML不是帮助我们找到一个更好的或者更深层次的训练损失最小值,而是帮助我们找到一个更广泛或者更可靠的最小值,它能更好地概括测试数据,更加健壮。作者利用Market-1501数据集和MobileNet主干网络做了一个小实验来证明DML能够找到更鲁棒的最小值。

作者比较了DML模型和独立模型在添加高斯噪声前后训练的损失变化。从图(a)可以看出两个模型的极小值的深度是相同的,但是在加入高斯噪声后,独立模型的训练损失增加较多,而DML模型的训练损失较少。这表明DML模型找到了一个更广泛,健壮的最小值,进而提供更好的泛化性能。


2)怎样找到更好的最小值?

那么DML是怎样找到这个广泛健壮的最小值的呢?DML会要求每个网络匹配其同伴网络的概率估计,如果给定网络预测为零,而其对等网络预测为非零,则该网络将受到严重惩罚。总体上,DML是指,当每个网络独立地将一个关注点放在一个小的次概率集合上时,DML中的所有网络都倾向于聚合它们对次级概率的预测。也就是说所有的网络把重心放在次概率上,并且把更多重心放在更明显的次概率上。因此,DML是通过对“合理的”次概率预测的相互概率匹配来寻找更宽泛的最小值。

结论
本文提出了一种简单且普适的方法DML来提高深度神经网络的性能,方法是将几个网络一起训练,相互蒸馏。用这种方法,可以获得紧凑的网络。实验证明,DML相比传统的蒸馏模型更好更健壮。此外,DML也能提高大型网络的性能,并且以这种方式训练的网络队列可以作为一个集成来进一步提高性能。

论文链接:
https://arxiv.org/abs/1706.00384
代码链接
YingZhangDUT/Deep-Mutual-Learninggithub.com

(*本文为 AI科技大本营编译文章,请微信联系 1092722531



精彩推荐


2019 中国大数据技术大会(BDTC)再度来袭!豪华主席阵容及百位技术专家齐聚,15 场精选专题技术和行业论坛,超强干货+技术剖析+行业实践立体解读,深入解析热门技术在行业中的实践落地。


即日起,限量 5 折票开售,数量有限,扫码购买,先到先得!



推荐阅读

你点的每个“在看”,我都认真当成了AI

    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存