查看原文
其他

论文推荐|[ICCV 2019]一种无需原始训练数据的Teacher-Student模型压缩方法

梁凯焕 CSIG文档图像分析与识别专委会 2022-07-11


本文简要介绍ICCV 2019论文“Data-Free Learning of Student Networks”的主要工作。该论文主要解决的问题是,在无法获取原始数据集的情况下,如何进行网络压缩。


一、研究背景

大部分深度神经网络(CNN)往往需要消耗巨大的计算资源以及存储,为了将模型部署到移动端等性能受限设备上,通常需要对网络进行加速压缩。现有的一些加速压缩算法,例如知识蒸馏[1]等, 能够在具有训练数据的情况下取得有效的结果。然而,在现实应用中,训练数据集由于隐私、传输等原因往往无法得到。因此,作者提出了一种无需原始训练数据的模型压缩方法。


二、原理简述


Fig.1. Overall architecture
 

Fig. 1是论文提出的整体结构。通过给定的待压缩网络(教师网络),作者训练了一个生成器来生成与原始训练集具有相似分布的数据。然后利用生成数据,基于知识蒸馏算法对学生网络进行训练,从而实现了无数据情况下的模型压缩。

那么如何在没有数据的情况下,通过给定的教师网络,训练一个可靠的生成器呢?作者提出了以下三个loss来指导生成器的学习。

(1)在图像分类任务中,对于真实的数据,网络的输出往往接近一个One-hot向量。其中,在分类类别上的输出接近1,其它类别的输出接近0。因此,如果生成器生成的图片接近真实数据,那么它在教师网络上的输出应该同样接近于一个One-hot向量。于是,作者提出了One-hotloss:

其中yT是生成图片通过教师网络的输出,t是伪标签,由于生成的图片不具备标签,作者将yT中的最大值设置为伪标签。Hcross表示交叉熵函数。

(2)另外,在神经网络中,输入真实数据往往比输入随机噪声在Feature Map上有更大的响应值。因此,作者提出了Activation Loss对生成数据进行约束:

其中fT表示生成数据通过教师网络提取得到的特征,||·||1表示l1范数。

(3)此外,为了让网络更好地训练,训练数据往往需要类别平衡。因此,为了让生成的数据同样类别平衡,作者引入了信息熵Loss来衡量类别平衡程度:

其中,Hinfo表示信息熵,yT表示每张图片的输出。如果信息熵越大,说明输入的一组图片中,每个类别的数目越平均,从而保证了生成的图片类别平均。

最后,通过结合上述三个Loss函数,可以得到生成器训练所使用的Loss:

通过优化以上Loss,可以训练得到一个生成器,然后通过生成器生成的样本进行知识蒸馏。在知识蒸馏中,待压缩网络(教师网络)通常准确率较高但参数冗余,学生网络是一个轻量化设计、随机初始化的网络。通过用教师网络的输出来指导学生网络的输出,能够提高学生网络的准确率,达到模型压缩的目的。这个过程可以用以下公式进行表示:

其中,ys、yt分别表示学生网络、教师网络的输出,Hcross表示交叉熵函数。

算法1表示文章方法的流程。首先,通过优化上文所述的Loss,得到一个与原始数据集具有相似分布的生成器。第二,通过生成器生成的图片,利用知识蒸馏的方法将教师网络的输出迁移到学生网络上。学生网络具有更少的参数,从而实现了无需数据的压缩方法。


  
三、主要实验结果及可视化效果
 
TABLE 1. Classification result on the MNIST dataset.


TABLE 2.  Effectiveness of different components of the proposed data-free learning method.


TABLE 3. Classification result on the CIFAR dataset.


TABLE 4.  Classification result on the CelebA dataset.


TABLE 5.  Classification results on various datasets.    


Fig.2. Visualization of averaged image in each category(from 0 to 9) on the MNIST dataset.
 

Fig.3. Visualizationof filters in the first convolutional layer learned on the MNIST dataset. The topline shows filters trained using the original training dataset, and the bottom line shows filters obtained using samples generated by the proposed method.
 

由TABLE 1、TABLE 3、TABLE 4来看,文章所提方案在MNIST、CIFAR、CelebA上取得的结果均比其它无数据压缩方案好,而且接近于采用数据的知识蒸馏方法。TABLE 2验证了所提出的三个loss对结果的影响,可以看到,只用单个Loss的话提升并不明显,需要组合多个loss才能得到好的结果。Fig.2中第二行是生成数据的可视化结果,可以看到生成图片与真实图片相比仍有不小的差距。Fig.3是卷积核的可视化结果,可以看出学生网络的卷积核跟教师网络的卷积核具有相似性。

 
四、总结及讨论
  1. 文章利用待压缩网络,根据真实图片在网络中的输出规律设计了三个Loss来训练生成器,使得生成图片与真实数据具有相似分布,从而达到了Data-free进行模型压缩的目的。在MNIST、CIFAR、CelebA等数据集上都验证了其方法的有效性。

  2. 通过可视化结果可以看出生成的图片与真实图片仍有不小的差距。此外,需要在更大的数据集上进一步验证其性能。


五、相关资源
  • 论文地址:https://arxiv.org/pdf/1904.01186.pdf

  • 代码地址:https://github.com/huawei-noah/DAFL


参考文献
[1] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean.Distilling the knowledge in a neural network. arXiv:1503.02531, 2015.


原文作者:Hanting Chen, Yunhe Wang, Chang Xu , Zhaohui Yang, Chuanjian Liu, Boxin Shi,Chunjing Xu, Chao Xu, Qi Tian


撰稿:梁凯焕

编排:高  学

审校:殷  飞

发布:金连文




免责声明:1)本文仅代表撰稿者观点,个人理解及总结不一定准确及全面,论文完整思想及论点应以原论文为准。(2)本文观点不代表本公众号立场。 


往期精彩内容回顾



征稿启事:本公众号将不定期介绍一些文档图像分析与识别领域为主的论文、数据集、代码等成果,欢迎自荐或推荐相关领域最新论文/代码/数据集等成果给本公众号审阅编排后发布 (投稿邮箱:xuegao@scut.edu.cn)。




(扫描识别如上二维码加关注)

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存