领域泛化问题综述
近年来,机器学习(machine learning)在诸多任务上都展现出很好的效果。机器学习的主要目标是:由训练数据得到可泛化的“知识”,并将由此得到的模型应用于对测试数据(testing data)的预测。需要注意的是,传统机器学习模型都基于如下假设:训练数据(training data)和测试数据是独立同分布的。但实际中该假设并不一定成立,如果训练数据与测试数据间的分布差异变大,那么由训练数据得到的模型在测试数据上的性能会相应地变差。
为了解决该问题,一种简单的思路是让训练数据尽可能多地涵盖数据:此举很可能会带来大量的资源消耗,且实践中也很难在初期就预判到所有情况。所以,大多数情况下只能依据有限的训练数据,通过一些操作或训练技巧来提高模型的泛化能力,具体方法包括领域自适应(Domain Adaptation)、元学习(Meta-learning)等。而领域泛化(Domain Generalization)则是其中一个较新的研究方向。本文主要基于2021年国际人工智能联合会议(International Joint Conference on Artificial Intelligence, IJCAI)录用的一篇综述文章“Generalizing to Unseen Domains: A Survey on Domain Generalization”进行介绍。
一、基本概念
在引入领域泛化问题的定义前,首先给出“领域”的定义:
定义一 设分别是非空的输入空间和输出空间,则由服从联合分布的总体中抽样得到的样本称为一个领域。
由此就可定义领域泛化问题:
定义二 给定个领域,,. 领域泛化问题要根据这个训练领域学习预测映射,使得在未知(unseen)的新领域上最小化损失,即求.
领域泛化和多任务学习(Multi-task Learning)、迁移学习(Transfer Learning)等机器学习中常见任务,都有一些类似和差异之处,具体见表 1。
由表 1可以发现,只有零次学习和领域泛化在学习过程中无法访问测试数据,这也是区别领域泛化和领域自适应(Domain Adaptation)的关键之一。
多任务学习:尝试同时优化不同任务上的模型,本质不属于泛化问题。 迁移学习:利用相似性,将在旧领域学习过的模型应用于新任务/领域上。 元学习:从已有任务中学习一种学习方法或元知识,加速新任务的学习。 领域自适应:通过已有领域上的数据及标签,最优化模型在新领域上的表现。 终身学习:在学习新知识、记忆旧知识的基础上,完成连续的序列化任务。 零次学习:在没有任何对应某标签数据的情况下,解决问题(比如:训练集中只有猫和狗的图片及标签,但希望由此学习得的模型可以识别斑马)。
表 1 多种机器学习策略比较
二、常见方法
很多学者已在领域泛化方面做了有价值的工作,这篇综述文章将目前的方法归结为数据操作、特征学习、学习策略三大类,如图 1所示。
1.数据操作
数据操作是指对原始训练数据进行操作以得到新数据,从而提高样本丰富度和模型的泛化能力。这类方法可以进一步细分为数据增强和数据生成两大类。
传统的数据增强方法有翻转、旋转、裁剪、添加噪声等,在机器学习中已经得到广泛运用,且展现出很好的效果。除此之外,域随机化(domain randomization)和对抗增强(adversarial augmentation)也是行之有效的方法。域增强指通过随机化参数等途径,由原数据得到能泛化到现实场景的数据。对抗增强则借鉴了对抗训练的思想,一方面以模型损失最大为目标获得增强数据,另一方面以预测损失最小为目标、将原始数据和增强数据一起用于模型的训练。
数据生成方法也有两个主要分支。一是通过VAE、GAN等比较成熟的生成模型(generative models)来生成与原始数据尽可能接近、或分布差距尽可能大的新数据。二是根据对模型输入和标签构建具有“凸”性质的运算,生成新的训练样本及其对应的标签,即Mixup方法。Mixup方法最简单的形式为
其中是第个样本标签对应的one-hot向量,权重服从某个贝塔分布。相较于其他数据生成方法,Mixup在保持良好效果的同时,有着更低的计算消耗。
2.特征学习
特征学习(representation learning)一直是机器学习的核心问题之一,也是实现领域泛化的关键所在。
领域泛化中的一种自然的特征学习思路是:从已有数据中提取与领域无关的特征。如此,对于不同领域上的任务,就可以用这些相对通用的特征来优化模型的表现。具体而言,提取通用特征的方法有:核函数、领域对抗学习、显式特征对齐等。特征解耦是领域泛化中另一种有效的特征学习思路,即将特征分解为领域共享(domain-shared)和领域特有(domain-specific)两部分。具体而言,可以通过多成分分析(multi-component analysis)和基于生成模型解耦来实现。对于后者,可以从领域、样本、标签等不同层面分析数据的生成机制,从而解耦出有价值的信息。
3.学习策略
这类方法通过引入机器学习中成熟的学习范式来提高模型的泛化能力,主要包括:
集成学习(ensemble learning):在各个领域上学习得对应的模型,然后通过对它们加权得到一个更好、更全面的模型。 元学习:将多领域数据重新划分为元训练集和元测试集,以此来模拟域迁移(domain shift)的场景,对应的参数更新公式为 其他范式:自监督方法、自我挑战方法、随机森林、系列交替训练等。
三、应用及基准数据集
领域泛化在社会生活中有着广泛应用。在计算机视觉领域,其主要用于图像分类、街景识别、视频理解等问题上。特别地,近年来其在健康医疗中的帕金森病识别、组织分割、震颤识别等方面展现出很好的效果。在自然语言处理方面,领域泛化也常被用于情感分析、语义分割、网页分类等任务。目前,领域泛化领域的基准数据集主要以图像分类数据为主,具体如表 2所示。
表 2 领域泛化常见基准数据集
以Office-Caltech数据集为例,其中包含了来自Caltech, Amazon, Webcam和DSLR四个网站的办公用品照片,不同网站照片的拍摄角度、拍摄条件都有一定区别。在检验领域泛化方法时,就可根据某一/某些网站的照片训练出10类办公用品的分类模型,并在其他网站的照片上进行验证。
四、展望
对领域泛化方法有如下展望:
连续领域泛化:如何让系统有连续泛化和适配的能力,而非局限于离线场合。 新类别的领域泛化:如何在不同领域有不同标签的场景下实现领域泛化。 可解释的领域泛化:除了特征解耦外,是否能找到更多具有可解释性的方法。 大规模预训练与领域泛化:如何用领域泛化来提高预训练模型的泛化能力。
尽管有工作在经验上说明:已有领域泛化方法的效果并没有大幅度领先传统的经验最小化方法,但该结论只是基于最简单的分类任务。在特定的场景中,如行人再识别等,领域泛化有着更出色的发挥和更广阔的发展空间。