查看原文
其他

如何通过Meta Learning实现域泛化(Domain Generalization)?

丘明姗 PaperWeekly 2022-07-06


©作者 | 丘明姗

单位 | 华南理工大学

研究方向 | 领域泛化


域泛化(Domain Generalization)中有很多工作是用 meta learning 做的。Meta learning 在 few shot 中很常用,它的目的也是提升模型的泛化性,所以我们来看看 DG 中采用 meta learning 的工作。



Revisit Meta Learning


Meta learning 的motivation 就是让模型学会学习。一个学会了如何学习的模型,自然就有好的泛化性。

以 few shot learning 背景为例,我们只有少量的样本来训练一个任务。直接用少量的样本训练模型显然会过拟合,那怎么办?Meta learning 给出的策略就是采用公用大型数据集和已有的少样本共同训练模型。它将数据集分成两类,大型数据集的样本称为 support sets,少样本称为 query sets;将训练分成两个阶段,一次学习称为一个 epoch(整个数据集),首先在 support sets 上训练并更新一次梯度,接着用 query sets 基于 support sets 更新的模型再求一次梯度,本轮 epoch 的梯度更新与 query sets 上梯度更新方向一致。


可以这么理解,support sets 的作用就是让模型有一个好的初始化,接着再用 query sets 对模型进行 fine-tune,使模型真正适用于任务场景。显然,大型数据集和拥有的少样本数据来自不同 domain,存在 distribution shift,大型数据集训练的模型在任务上只能得到次优的效果。而通过一次次 query sets 的"fine-tune",模型就能很好地适应任务场景。 

这么一看,是不是跟 DG 要做的很像?所以,DG 也这么干了。但是 DG 的场景会更困难一些,因为 DG 在训练时根本不知道目标域数据,就没法用目标域数据作为 query sets。因此 DG 退而求其次的策略是将源域数据划分成 support sets 和 query sets(DG 的论文里一般称为 meta-training sets 和 meta-testing sets),核心依然是模拟 distribution shifts,训练出对 distribution shift robust 的模型,就认为模型拥有了泛化到目标域的能力。



Meta Learning与Domain Alignment对比


Domain Alignment 专注于特征的学习,学到 domain agnostic 的特征。因此它会通过 loss 或者是 domain 判别器等其他各种手段对提取的特征施加约束,认为成功实现分布对齐的模型就是泛化性好的模型。它只是简单通过不同源域的训练数据来模拟 distribution shift。 

Meta learning 主要是对输入数据的设计,强调数据的 distribution shift,并通过两次梯度更新使模型 robust,认为学到 distribution shift 的模型就是泛化性好的模型。但没有对数据作显式对齐。 

其实,meta learning 可以看做是一个训练 trick,它可以和所有 DG 方法结合使用。因为 meta learning 对模型结构,loss 都没有任何要求(也称为 model agnostic),只需要对训练数据和训练过程做简单的调整,就可以套在任何模型上了。因此,要是你发现自己的 DG 模型效果不够满意,可以考虑叠加这个 buff(感觉我在教坏人-_-



DG中的Meta Learning


下面就来看几篇 DG 中的论文,了解它们是怎么使用这个 trick 的。


3.1 Meta Learning实现DG


本文给出的方法很简单,但是它对 meta learning 的 insight 做了很好的解释。



论文标题:

Learning to Generalize: Meta-Learning for Domain Generalization

论文链接:

https://arxiv.org/abs/1710.03463


训练时共有 个源域,每次训练采用一个源域作为 meta-testing set,另外的源域作为 meta-training set,得到目标函数:


有意思的点是作者对上述目标函数做 Taylor 展开,得到了以下的形式:


这揭示了目标函数一是要最小化在 meta-training set 和 meta-testing set 上的误差(上式第一第二项),二是使 meta-training set 和 meta-testing set 的优化方向最大程度地相似(上式第三项)。显然,如果目标函数是 ,模型很可能偷懒,找一个容易使该式最小化的源域的梯度方向进行优化,从而过拟合这个源域。而 meta leanring 的目标函数函数加上了这个正则化约束,就促使模型考虑所有源域的梯度方向。因此作者还给出下面两种改进的 meta learning 目标函数,可以替代上式的点积计算相似度。



第一种改进是将点积替换成余弦相似度。第二种是退化为用 meta-training set 的方向优化 meat-testing set,这种方式关键是需要模型有好的初始化。
3.2 解决DG中的Batch Normalization问题


论文标题:

MetaNorm: Learning to Normalize Few-Shot Batches Across Domains

论文链接:

https://openreview.net/forum?id=9z_dNsC4B5t


这篇文章的 motivation 是解决 DG 中的 BN 问题,它也用了 meta leanring 的 trick。 
我们都知道,当网络层数很多时,每一层参数的更新会导致上一层输入数据分布变化,也就是发生 iternal covariate shift,这样很容易导致梯度消失或梯度爆炸。BN 可以调整数据接近独立同分布,使训练更稳定。BN 用训练数据计算均值和方差来实现正则化,这在训练和测试数据是独立同分布时显然没问题,但 DG 的训练和测试数据不同分布,这么做就行不通了。 
这时要怎么办?思路依然很简单,就是分布对齐。这篇文章首先用下面的公式推断 domain-specific 统计量。


接着最小化所有 domain-specific 统计量的 KL 散度。



数据依然是分成 meta-training set 和 meta-testing set 两部分。在 meta-training 阶段的目标是最小化交叉熵损失和 KL 散度,meta-testing 阶段不再最小化 KL 散度,只是最小化正则化数据的交叉熵。
3.3 语义空间对齐


论文标题:

Domain Generalization via Model-Agnostic Learning of Semantic Features

论文链接:https://proceedings.neurips.cc/paper/2019/hash/2974788b53f73e7950e8aa49f3a306db-Abstract.html
以往 DG 中都是实现特征空间对齐,目的是 domain invariant。本文还进行了另一种对齐:语义空间对齐,目的是保持多个源域在语义空间上 class 之间的关系。因为 DG 场景没有任何有标签的目标域数据可以提供语义空间的信息,为了提升预测的准确率,一种思路就是不妨也将源域的语义空间信息也迁移到目标域上。下面这篇 DA 的研究也提到了语义空间对齐的好处。


论文标题:

Simultaneous Deep Transfer Across Domains and Tasks

论文链接:

https://openaccess.thecvf.com/content_iccv_2015/html/Tzeng_Simultaneous_Deep_Transfer_ICCV_2015_paper.html


Recall that in this setting, we have access to target labeled data for only half of our categories. We use soft label information from the source domain to provide information about the held-out categories which lack labeled target examples.

一个好的特征空间自然是不同 domain 的数据尽量混在一起难以区分,不同 class 的数据尽量形成良好的聚簇。作者就此分别对语义空间和特征空间采用了不同的操作。 
首先是语义空间。对于每个 domain,计算特征空间中属于同一 class 的样本的均值,作为这个 class 的 'concept',并通过 softmax 得到这个 class 的软标签。



接着聚合同一个 domain 的所有软标签向量,得到软标签混淆矩阵。我们希望训练过程中不同 domain 的 inter-class 关系能够被保持,因此操作还是进行 domain 的对齐,也就是最小化不同 domain 混淆矩阵的对称 KL 散度。


接着是特征空间对齐。同样是借鉴对比损失的思想,计算下面的 triplet loss,使 positive sample 与 anchor 的距离小于 negative sample 与 anchor 的距离。


本文的训练数据同样被分为 meta-training set 和 meta-testing set 来模拟 distribution shift。



总结
Meta learning 就是通过对已有的数据作简单的划分模拟 distribution shift,使模型学得更 robust。它是一种训练的思路,可以和任何 DG 的模型结构结合来增强泛化性。 
但是 meta learning 同样存在一些缺陷。一是虽然可能训练得到的模型对 distribution shift 不那么敏感,但仍不能避免模型对源域数据过拟合。二是模型每一层更新都要求两次梯度,计算效率自然会慢。


独家定制「炼丹贴纸」

限量 200 份!

扫码回复「贴纸」 

立即免费参与领取

👇👇👇




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存