查看原文
其他

【源头活水】基于对抗的迁移学习方法: DANN域对抗网络

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

来源:知乎—NaNNN
地址:https://zhuanlan.zhihu.com/p/73947456
论文解读:Domain-Adversarial Training of Neural Networks(DANN)
论文地址:https://www.jmlr.org/papers/volume17/15-239/15-239.pdf

论文在线阅读:https://bbs.sffai.com/d/262-domain-adversarial-training-of-neural-networks

(扫码阅读)


本次介绍的论文是对抗迁移学习领域中一篇很经典的论文,论文作者Yaroslav Ganin [1] 等人首次将对抗的思想引入迁移学习领域当中。

01

背景简介
在传统的机器学习中,我们经常需要大量带标签的数据进行训练, 并且需要保证训练集和测试集中的数据分布相似。在一些问题中,如果训练集和测试集的数据具有不同的分布,训练后的分类器在测试集上就没有好的表现。打个比方,在情感分析中,我们可能拥有‘电影’的大量带标签用户评价,然而我们却希望可以对‘书籍’下的用户评价进行分类。这种情况下该怎么办呢?
域适应(Domain Adaption)是迁移学习中一个重要的分支,目的是把具有不同分布的源域(Source Domain) 和目标域 (Target Domain) 中的数据,映射到同一个特征空间,寻找某一种度量准则,使其在这个空间上的“距离”尽可能近。然后,我们在源域 (带标签) 上训练好的分类器,就可以直接用于目标域数据的分类。

图1. 域适用举例 [2]

如上图所示,图1.a为源域样本分布(带标签),图1.b为目标域样本分布,它们具有共同的特征空间和标签空间。在传统的机器学习场景中,训练集和测试集具有相同的分布,我们可以用训练集训练好的分类器,直接用于测试集分类。但在域适应问题中,源域和目标域通常具有不同的分布,这就意味着我们无法将源域训练好的分类器,直接用于目标域样本的分类。因此,在域适应问题中,我们尝试对两个域中的数据做一个映射,使得属于同一类(标签)的样本聚在一起。此时,我们就可以利用带标签的源域数据,训练分类器供目标域样本使用。
然后,我们再简单回顾一下GAN(生成对抗网络)[3]。生成对抗网络由GoodFellow等人在2014年提出,其结构如图2所示。
图2. 生成对抗网络结构[3]
生成对抗网络包含一个生成器(Generator)和一个判别器(Discriminator)。生成器用来生成假图片,判别器则用来区分,输入的图片是真图片还是假图片。生成器希望生成的图片可以骗过判别器(以假乱真),而判别器则希望提高辨别能力防止被骗。两者互相博弈,直到系统达到一个稳定状态(纳什平衡)。
那么,GAN思想如何用到Domain Adapatation中呢?在域适应问题中, 存在一个源域和目标域。和生成对抗网络相比,域适应问题免去了生成样本的过程,直接将目标域中的数据看作生成的样本。因此,生成器的目的发生了变化,不再是生成样本,而是扮演了一个特征提取(feature extractor)的功能:如何从源域和目标域中提取特征,使得判别器无法区分提取的特征是来自源域,还是目标?
这篇论文首次将对抗学习的思想,引入到迁移学习中。一般来说,传统的域适应问题一般会选用固定的特征 (fixed feature represenatations),但是本文提出的对抗迁移网络则关注于如何在不同域之间选择可供迁移的特征(transferable features)。也就是说,一个好的可迁移特征,应该满足两个条件:
1. Domain-invariance - 面对这些特征,你无法区分它们是来自目标域还是源域。
2. Discriminativeness - 利用这些特征,你可以很好的完成分类任务。
因此,域适应问题的网络损失由两部分构成:训练损失(标签预测器损失)和域判别损失

02

域对抗迁移网络 (DANN)
接下来我们讲一下域对抗迁移网络(DANN),如图3所示:
图3. DANN对抗迁移网络结构 [1]
DANN结构主要包含3个部分:
特征提取器 (feature extractor) - 图示绿色部分,用来将数据映射到特定的特征空间,使标签预测器能够分辨出来自源域数据的类别的同时,域判别器无法区分数据来自哪个域。
标签预测器 (label predictor) - 图示蓝色部分,对来自源域的数据进行分类,尽可能分出正确的标签。
域判别器(domain classifier)- 图示红色部分,对特征空间的数据进行分类,尽可能分出数据来自哪个域。
其中,特征提取器和标签分类器构成了一个前馈神经网络。然后,在特征提取器后面,我们加上一个域判别器,中间通过一个梯度反转层 (gradient reversal layer, GRL) 连接。在训练的过程中,对来自源域的带标签数据,网络不断最小化标签预测器的损失 (loss)。对来自源域和目标域的全部数据,网络不断最小化域判别器的损失。
2.1 标签预测器的损失
对于特征提取器 (以单隐层为例),sigmoid作为激活函数,其输出为:

对于标签预测器,softmax作为激活函数,其输出为:

当给定数据点  ,负对数似然 (negative log-probabality) 作为损失函数,其标签预测器的损失为:

因此在源域上,我们的训练优化目标就是:

其中,表示第    个样本的标签预测损失,是一个可选的正则化器(Regularizer), 是人为设置的正则化参数,目的是用来防止神经网络过拟合。
标签预测器的损失表达函数现在讲完了,和其他普通神经网络差不多。然而,DANN网络的核心主要在接下来要讲的部分:跨域正则器(Domain Regularizer)。
2.2 域判别器损失
对于域判别器,sigmoid作为激活函数,其输出为

然后,我们定义域判别器  的损失为 (负对数似然作为损失函数):

  表示第    个样本的二元标签,用来表示这个样本属于源域还是目标域。此时,域判别器的目标函数为

2.3 总损失
对抗迁移网络的总损失由两部分构成:网络的训练损失(标签预测器损失)和域判别损失。
在这里,我们可以给出DANN的总目标函数为 

其中,我们通过最小化目标函数来更新标签预测器的参数,最大化目标函数来更新域判别器的参数。

好啦,今天我们简单介绍了DANN的基本思想。由于本人目前也在学习过程中,如有不足,欢迎指正 :)
参考文献
[1] Ganin, Yaroslav, et al. "Domain-adversarial training of neural networks."The Journal of Machine Learning Research17.1 (2016): 2096-2030.
[2] Yanwen Zhang, medium.com/deep-learnin
[3] Goodfellow, Ian, et al. "Generative adversarial nets."Advances in neural information processing systems. 2014.

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存