ICLR 2022 | 基于Transformer的跨域方法——CDTrans
©作者 | 达摩院AI Earth
论文链接:
代码链接:
前言
大多数现有的 UDA 方法都集中在学习域特征表示上,希望能够学习到一个跟类别种类相关的而跟域无关的特征。目前的研究无论是从域层面(粗粒度)还是类别层面(细粒度)上的特征对齐操作,都是使用基于卷积神经网络(CNN)的框架。大体上主流的解决思路有两种,分别是基于分布度量一致性约束的方法和基于对抗学习的方法。具有代表性的技术分别是 MMD [1] 和 DANN [2] 。
▲ 左右图分别是MMD和DANN的网络结构图
在最近的一些研究进展中,基于类别层面的 UDA 的方法中一个主流思路是在 target 数据上得到伪标签,用伪标签训练模型。但是一个比较重要的问题是,这些伪标签通常存在一定的噪音,不可避免地会影响 UDA 的性能。
▲ 左右图分别是MulT和CrossViT的Cross Attention机制
我们工作的核心思想是使用 Transformer 的 Cross Attention 机制来拉近 source 域和 target 域的图片的分布距离。据我们所知,这应该是较早使用纯 Transformer 在 UDA 上进行尝试的工作。
具体来说,在利用 Transformer 的 Cross Attention 来做两个域分布对齐时,它的输入需要是一个样本对。类似于多模态里面的图文对,这里我们的输入是由一个 source 图片和一个 target 图片组成的样本对。正常来说,两张图片应该是属于同一个类别,但是来自于不同的 domain(一个 source,一个 target)。
由于在 UDA 任务中,target 是没有标签的。因此我们只能借鉴伪标签[5]的思路,来生成潜在的可能属于同一个 ID 的样本对。但是,伪标签生成的样本对中不可避免的会存在噪声。这时,我们惊喜的发现 Cross Attention 对样本对中的噪声有着很强的鲁棒性。
我们分析这主要是因为 Attention 机制所决定的,Attention 的 weight 更多的会关注两张图片相似的部分,而忽略其不相似的部分。如果 Source 域图片和 Target 域图片不属于同一个类别的话,比如下图 1.a“小轿车 vs 卡车”的例子,Attention 的 weight 主要集中于两个图片中相似部分的对齐(比如轮胎),而对其他部位的对齐会给很小的 weight。
换句话说,Cross Attention 没有在使劲拉近对齐小轿车和卡车,而更多的是在努力对齐两个图片中的轮胎。一方面,Cross Attention 避免了强行拉近小轿车和卡车,减弱了噪声样本对 UDA 训练的影响;另一方面,拉近不同域的轮胎,在一定程度上可能帮助到目标域轮胎的识别。
Source 分支通过 Source 的 label 保持模型在 Source 数据集的表现,同时为 Cross Attention 提供合适的 Query 信息。Target 分支通过伪标签进行监督学习,让模型对 Target 进行合理的学习,同时为 Cross Attention 提供合适的 Key 和 Value 信息。Source-Target 分支用来使用对齐两个 domain 的特征分布。
注意,这里我们并不是直接用伪标签对 Source-Target 分支进行训练,而是使用蒸馏技术,让 Target 分支的输出去学习 Source-Target分 支的输出。公式如下:
之所以使用蒸馏技术,是因为我们相信 Cross Attention 的对齐能力和抗噪能力。如果输入的两张图片是相同类别,则中间的 Source-Target 分支可以用于学习他们共同的特征。
相比于 Target 分支,Source-Target 分支的特征实现了两个域的对齐。如果输入的两张图片是不同类别(即噪声),这时 Target 分支的label完全是错误的,会影响训练。但是中间的 Source-Target 分支使用了 Cross Attention,是有抗噪能力的。因此,我们相信,用中间的 Source-Target 分支去指导 Target 分支可以取得更好的效果。
2.3 Source-Tareget域样本匹配策略
最后,我们介绍下我们的如何借鉴伪标签的思路,来生成我们的样本对的。为了产生准确稳定的 Source-Target 样本对,我们设计了一种双向中心匹配算法,该算法是寻找合适的样本对信息输入到三分支参数共享的的 CDTrans。算法公式如下所示:
这里两个集合分别是从 Source 域样本去寻找 Target 域中距离最近的样本和 Target 域样本去寻找 Source 中距离最近的样本。最终的集合则是两者的并集。这样的好处是确保 Source 样本和 Target 样本尽可能参与到 Source-Target 样本对中,提高样本利用率。
同时我们发现,来自目标域的数据将其经过“源域数据训练的模型”时它会输出一个分类预测结果,这个结果可以用来进一步过滤我们生成的样本对集合 P,提高样本对精度。
具体来说,对 P 中的每一个样本对,目标域图片经过“源域数据训练的模型”的分类结果如果和源域图片的标签不一致时,我们认为可能这个样本对是一个噪声,将它删掉。如果一致,则保留。这里值得一提的是,我们发现 SHOT [5] 方法中采用自监督方式得到的分类结果相比于原始模型输出的分类结果要更准确,因此,文章中我们采用了 SHOT 的方式来生成分类结果。
实验结果
3.1 和SOTA比较
我们在四个数据集上做了实验,分别是 Office31,Office-Home,Visda-2017 和 DomainNet。
为了公平的跟基于 CNN 方法做比较,在 Office31,Office-Home 和 DomainNet 上,原有的方法基于 ResNet50 的数据集,我们提供了 DeiT-Small 和 DeiT-Base 两种结构的结果。DeiT-Small 整体参数量跟 ResNet50 差不多。在 VisDA-2017 上,原有的方法基于 Resnet-101,我们直接使用和其參数量大致相近的 DeiT-Base 模型作为对比。需要注意的是 TVT 方法使用的是在 ImageNet-21k 上预训练的模型初始化,而我们使用的 DeiT 的预训练模型和 ResNet 一致,都是在 Image-1k 上进行的预训练。
从结果上看,我们的方法在四个数据集上均取得了非常不错的结果,我们的效果远超之前 SOTA 方法,相比于之前最好的方法,分别提高了 5.5%/8.3%/3.3%/9.8%。
可以看到单纯使用 One-way-source 或者 One-way-target 策略时,Target Domain 或者 Source Domain 的样本利用率并不高,这会限制模型精度的提高。但是简单把 One-way-source 和 One-way-target 策略加起来的 Two-way 策略,虽然 source 和 target 域的样本利用率高了,但是精度只提升了一点点。
这主要的原因就是 source-target 样本匹配成对的精度不够高。当添加 Ca 策略之后,Tw+Ca 的样本对匹配精度提高,最终模型在 Target Domain 的精度接近 GroundTruth 上的表现。可以看出 Tw+Ca 方法要比单纯的 RPLL 和 MRKLD+LRENT 伪标签技术要好很多。
表 6 展示了 CDTrans 中不同损失函数的作用。CDTrans 主要包含三部分的损失函数,Source Domain 分支的带有 source label 的交叉熵损失, Target Domain 分枝的伪标签的交叉熵损失,中间 Fusion 分支的蒸馏损失。单纯使用 Target Domain 分支的损失,模型精度可以实现不错的精度表现,因为这更像是单纯的对 Target Domain 的样本进行学习,Source Domain 的样本仅仅经过一次模型,没有监督信息。
当同时使用 Source-Target 样本对中的 Source 和 Target 分支的损失函数的时候,精度又有一点提升,说明 Source Domain 的监督学习对 Target Domain 也有帮助。当 Fusion 分支使用交叉熵损失加入到训练中的时候,模型相对获得 1% 的提升效果。这证明了 Fusion Branch 的作用。当 Fusion 分支使用蒸馏损失的时候,模型可以获得获得 1.7% 的提升效果,证明了蒸馏损失比交叉熵损失更适合做融合操作,更有利于拉进带有噪声的 Source-Target 样本对的分布关系。
1. 正确的Souce-Target样本匹配
2. 错误的Source-Tareget样本匹配
例如“truck 和 car”的样本对中,Source-Target 的 Cross-Attention 关注于车顶和车轮上部位置,这是卡车和汽车中都有的共同点。在“plane 和 plant”图中,Plant 与 Plane 的相似度比较低,Cross Attention 关注到了背景部分。这样的好处是在 Source-Target 样本对中,Target 的伪标签同 Source 样本一致,Cross Attention 关注的背景使得这个 Target 样本避免把 Plant 的特征学习到 Plane 类里面去,减少了模型从噪声样本中学习到类别特征。
总结
参考文献
[1] Deep Domain Confusion: Maximizing for Domain Invariance
[2] Domain-Adversarial Training of Neural Networks
[3] Tsai, Yao-Hung Hubert, et al. "Multimodal transformer for unaligned multimodal language sequences." Proceedings of the conference. Association for Computational Linguistics. Meeting. Vol. 2019. NIH Public Access, 2019.
[4] CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification
[5] Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧