查看原文
其他

ICLR 2022 | 基于Transformer的跨域方法——CDTrans

达摩院AI Earth PaperWeekly 2022-10-14


©作者 | 达摩院AI Earth

本文解读我们 ICLR 2022 上发表的论文《CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation》。这篇文章提出一种基于 Transformer 的跨域方法:CDTrans。它使用 Transformer 中的 CrossAttention 机制来实现 SourceDomain 和 TargetDomain 特征对齐。具体来说,在传统方法给 TargetDomain 打伪标签的过程中难免存在噪声。由于噪声的存在,需要对齐的 Source 和 Target 的图片对可能不属于同一类,强行对齐会对训练产生很大的负面影响。
该方法经过实验发现 Transformer 中的 CrossAttention 可以有效避免噪声给对齐造成的影响。CrossAttention 让模型更多的关注 Source 和 Target 图片对中相似的信息。换句话说,即使图片对不属于同一类,被拉近的也只会是两者相似的部分。因此,CDTrans 具有一定的抗噪能力。最终实验也表明 CDTrans 的效果大幅领先 SOTA 方法。



论文标题:
CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

论文链接:

https://arxiv.org/abs/2109.06165

代码链接:

https://github.com/CDTrans/CDTrans




前言


大多数现有的 UDA 方法都集中在学习域特征表示上,希望能够学习到一个跟类别种类相关的而跟域无关的特征。目前的研究无论是从域层面(粗粒度)还是类别层面(细粒度)上的特征对齐操作,都是使用基于卷积神经网络(CNN)的框架。大体上主流的解决思路有两种,分别是基于分布度量一致性约束的方法和基于对抗学习的方法。具有代表性的技术分别是 MMD [1] 和 DANN [2] 。


 左右图分别是MMD和DANN的网络结构图


在最近的一些研究进展中,基于类别层面的 UDA 的方法中一个主流思路是在 target 数据上得到伪标签,用伪标签训练模型。但是一个比较重要的问题是,这些伪标签通常存在一定的噪音,不可避免地会影响 UDA 的性能。 

随着 Transformer 在各种任务中的成功,特别是 MulT [3] 和 CrossViT [4] 等基于 transformer 的工作分别在多模态和多尺度上取得成功,证明了 Cross Attention 可以处理不同形式的内容,可以用来对齐不同尺度或者不同模态的数据。所以我们希望借助 Transformer 的 Cross Attention 机制来处理 UDA 任务里面的不同域的特征。CrossViT 模型的输入是同一张图片的不同尺度下的图片 patch,MulT 模型输入的是同一种含义下不同模态的数据,他们两者的数据都具有含义一致性,即数据在不同的数据表现形式(多尺度或者多模态)下,表达的含义是一致的。

 左右图分别是MulT和CrossViT的Cross Attention机制


我们把 Source 域和 Target 域的图片看作不同的数据表现形式,拉近两个域的分布的过程就是追求含义一致性的过程。所以使用 Transformer 来解决跨域(Domain Adaption, DA)的问题。另一个使用 Cross Attention 的原因是,我们发现 Cross Attention 有一定的抗噪能力,可以大幅度弱化伪标签中的噪声对 UDA 性能的影响。



方法介绍

2.1 Cross Attention及其鲁棒性


我们工作的核心思想是使用 Transformer 的 Cross Attention 机制来拉近 source 域和 target 域的图片的分布距离。据我们所知,这应该是较早使用纯 Transformer 在 UDA 上进行尝试的工作。 

具体来说,在利用 Transformer 的 Cross Attention 来做两个域分布对齐时,它的输入需要是一个样本对。类似于多模态里面的图文对,这里我们的输入是由一个 source 图片和一个 target 图片组成的样本对。正常来说,两张图片应该是属于同一个类别,但是来自于不同的 domain(一个 source,一个 target)。

由于在 UDA 任务中,target 是没有标签的。因此我们只能借鉴伪标签[5]的思路,来生成潜在的可能属于同一个 ID 的样本对。但是,伪标签生成的样本对中不可避免的会存在噪声。这时,我们惊喜的发现 Cross Attention 对样本对中的噪声有着很强的鲁棒性。 

我们分析这主要是因为 Attention 机制所决定的,Attention 的 weight 更多的会关注两张图片相似的部分,而忽略其不相似的部分。如果 Source 域图片和 Target 域图片不属于同一个类别的话,比如下图 1.a“小轿车 vs 卡车”的例子,Attention 的 weight 主要集中于两个图片中相似部分的对齐(比如轮胎),而对其他部位的对齐会给很小的 weight。

换句话说,Cross Attention 没有在使劲拉近对齐小轿车和卡车,而更多的是在努力对齐两个图片中的轮胎。一方面,Cross Attention 避免了强行拉近小轿车和卡车,减弱了噪声样本对 UDA 训练的影响;另一方面,拉近不同域的轮胎,在一定程度上可能帮助到目标域轮胎的识别。


图 1.b 是 Cross-Attention 在不同噪声比例的情况下的结果。从图 1.b 中我们可以看出使用 Cross-Attention (红线) 的表现接近只用正确样本的结果(蓝线),而不使用 Cross-Attention (绿线) 的表现受到噪声影响较大。因此,进一步表明 Cross-Attention 对噪声具有良好的鲁棒性,可以从含有噪声数据中学习到有用的信息。
2.2 共享参数的三分支网络结构



基于 Cross Attention,我们设计了共享参数的三分支网络结构,如上图所示。左侧的 Source 分支(绿色)和右侧的 Target 分支(蓝色)使用 Self Attention 来学习各自数据的特征信息,而中间的 Source-Target 分支(橙色)通过使用 Source 的 Quey 和 Target 的 Key、Value 来学习他们相同的信息。 

Source 分支通过 Source 的 label 保持模型在 Source 数据集的表现,同时为 Cross Attention 提供合适的 Query 信息。Target 分支通过伪标签进行监督学习,让模型对 Target 进行合理的学习,同时为 Cross Attention 提供合适的 Key 和 Value 信息。Source-Target 分支用来使用对齐两个 domain 的特征分布。 

注意,这里我们并不是直接用伪标签对 Source-Target 分支进行训练,而是使用蒸馏技术,让 Target 分支的输出去学习 Source-Target分 支的输出。公式如下:

之所以使用蒸馏技术,是因为我们相信 Cross Attention 的对齐能力和抗噪能力。如果输入的两张图片是相同类别,则中间的 Source-Target 分支可以用于学习他们共同的特征。

相比于 Target 分支,Source-Target 分支的特征实现了两个域的对齐。如果输入的两张图片是不同类别(即噪声),这时 Target 分支的label完全是错误的,会影响训练。但是中间的 Source-Target 分支使用了 Cross Attention,是有抗噪能力的。因此,我们相信,用中间的 Source-Target 分支去指导 Target 分支可以取得更好的效果。

2.3 Source-Tareget域样本匹配策略

最后,我们介绍下我们的如何借鉴伪标签的思路,来生成我们的样本对的。为了产生准确稳定的 Source-Target 样本对,我们设计了一种双向中心匹配算法,该算法是寻找合适的样本对信息输入到三分支参数共享的的 CDTrans。算法公式如下所示:


这里两个集合分别是从 Source 域样本去寻找 Target 域中距离最近的样本和 Target 域样本去寻找 Source 中距离最近的样本。最终的集合则是两者的并集。这样的好处是确保 Source 样本和 Target 样本尽可能参与到 Source-Target 样本对中,提高样本利用率。 

同时我们发现,来自目标域的数据将其经过“源域数据训练的模型”时它会输出一个分类预测结果,这个结果可以用来进一步过滤我们生成的样本对集合 P,提高样本对精度。

具体来说,对 P 中的每一个样本对,目标域图片经过“源域数据训练的模型”的分类结果如果和源域图片的标签不一致时,我们认为可能这个样本对是一个噪声,将它删掉。如果一致,则保留。这里值得一提的是,我们发现 SHOT [5] 方法中采用自监督方式得到的分类结果相比于原始模型输出的分类结果要更准确,因此,文章中我们采用了 SHOT 的方式来生成分类结果。




实验结果


3.1 和SOTA比较


我们在四个数据集上做了实验,分别是 Office31,Office-Home,Visda-2017 和 DomainNet。


为了公平的跟基于 CNN 方法做比较,在 Office31,Office-Home 和 DomainNet 上,原有的方法基于 ResNet50 的数据集,我们提供了 DeiT-Small 和 DeiT-Base 两种结构的结果。DeiT-Small 整体参数量跟 ResNet50 差不多。在 VisDA-2017 上,原有的方法基于 Resnet-101,我们直接使用和其參数量大致相近的 DeiT-Base 模型作为对比。需要注意的是 TVT 方法使用的是在 ImageNet-21k 上预训练的模型初始化,而我们使用的 DeiT 的预训练模型和 ResNet 一致,都是在 Image-1k 上进行的预训练。

从结果上看,我们的方法在四个数据集上均取得了非常不错的结果,我们的效果远超之前 SOTA 方法,相比于之前最好的方法,分别提高了 5.5%/8.3%/3.3%/9.8%。

3.2 消融实验


表 5 展示了 CDTrans 中各个部分起到的作用。RPLL 和 MRKLD+LRENT 作为其他的伪标签技术引入作为对比。因为 UDA 任务拉进 Source Domain 和 Target Domain 的样本分布特征,那就应该要保证在训练阶段让模型尽可能利用更多的 Source Domain 和 Target Domain 的样本。

可以看到单纯使用 One-way-source 或者 One-way-target 策略时,Target Domain 或者 Source Domain 的样本利用率并不高,这会限制模型精度的提高。但是简单把 One-way-source 和 One-way-target 策略加起来的 Two-way 策略,虽然 source 和 target 域的样本利用率高了,但是精度只提升了一点点。

这主要的原因就是 source-target 样本匹配成对的精度不够高。当添加 Ca 策略之后,Tw+Ca 的样本对匹配精度提高,最终模型在 Target Domain 的精度接近 GroundTruth 上的表现。可以看出 Tw+Ca 方法要比单纯的 RPLL 和 MRKLD+LRENT 伪标签技术要好很多。 

表 6 展示了 CDTrans 中不同损失函数的作用。CDTrans 主要包含三部分的损失函数,Source Domain 分支的带有 source label 的交叉熵损失, Target Domain 分枝的伪标签的交叉熵损失,中间 Fusion 分支的蒸馏损失。单纯使用 Target Domain 分支的损失,模型精度可以实现不错的精度表现,因为这更像是单纯的对 Target Domain 的样本进行学习,Source Domain 的样本仅仅经过一次模型,没有监督信息。

当同时使用 Source-Target 样本对中的 Source 和 Target 分支的损失函数的时候,精度又有一点提升,说明 Source Domain 的监督学习对 Target Domain 也有帮助。当 Fusion 分支使用交叉熵损失加入到训练中的时候,模型相对获得 1% 的提升效果。这证明了 Fusion Branch 的作用。当 Fusion 分支使用蒸馏损失的时候,模型可以获得获得 1.7% 的提升效果,证明了蒸馏损失比交叉熵损失更适合做融合操作,更有利于拉进带有噪声的 Source-Target 样本对的分布关系。

3.3 可视化结果
以下是样本对的可视化结果。每一列的结果分别是:Source 原图,Source Self-Attention,Target原图,Target Self-Attention, Source-Target Cross-Attention。

1. 正确的Souce-Target样本匹配


2. 错误的Source-Tareget样本匹配



从可视化的图中可以看出,Source-Target 正确匹配的样本的 Cross Attention 相关性得到了加强,相同特征的区域得到更多的注意力,而 Source-Target 错误的匹配样本,Cross Attention 朝着有相似特征的区域关注,注意力相比于 Target 的 Self-Attention 可以更好的关注与 Source 相似的区域,而更少的关注 Target 自身独特区域。

例如“truck 和 car”的样本对中,Source-Target 的 Cross-Attention 关注于车顶和车轮上部位置,这是卡车和汽车中都有的共同点。在“plane 和 plant”图中,Plant 与 Plane 的相似度比较低,Cross Attention 关注到了背景部分。这样的好处是在 Source-Target 样本对中,Target 的伪标签同 Source 样本一致,Cross Attention 关注的背景使得这个 Target 样本避免把 Plant 的特征学习到 Plane 类里面去,减少了模型从噪声样本中学习到类别特征。




总结


CDTrans 是一种首先把 Cross Attention 机制引入到 UDA 场景的 Transformer 方法。这个方法最大的特点就是使用 Cross Attention 把 Source Domain 和 Target Domain 的信息融合起来,拉进跨域样本的分布距离。注意力机制可以是得模型更加聚焦于 Source Domain 和 Target Domain 的相似的特征表示,使得模型可以获得更好的跨域精度表现。
通过共享参数的三分支结构,即可以实现 Source Domain、Target Domain 单独学习其特征表示,也可以实现 Source Domain 和 Target Domain 的相同特征表示的学习。在训练的时候输入 Source-Target 样本对进行三分支模型训练,测试阶段仅仅需要进行单分支特征提取即可。这样既保证了模型高效训练,又可以实现模型同时对 Source Domain 和 Target Domain 数据提取能力,而不仅仅是学习 Target Domain 而遗忘 Source Domain 的提取特征能力。最后希望 Cross Attention 机制可以在 UDA 场景中发挥更多的作用。


参考文献

[1] Deep Domain Confusion: Maximizing for Domain Invariance

[2] Domain-Adversarial Training of Neural Networks

[3] Tsai, Yao-Hung Hubert, et al. "Multimodal transformer for unaligned multimodal language sequences." Proceedings of the conference. Association for Computational Linguistics. Meeting. Vol. 2019. NIH Public Access, 2019.

[4] CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

[5] Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存