查看原文
其他

长文总结半监督学习(Semi-Supervised Learning)

燕皖 PaperWeekly 2022-07-04


©PaperWeekly 原创 · 作者|燕皖
单位|渊亭科技

研究方向|计算机视觉、CNN


在现实生活中,无标签的数据易于获取,而有标签的数据收集起来通常很困难,标注也耗时和耗力。在这种情况下,半监督学习(Semi-Supervised Learning)更适用于现实世界中的应用,近来也已成为深度学习领域热门的新方向,该方法只需要少量有带标签的样本和大量无标签的样本,而本文主要介绍半监督学习的三个基本假设和三类方法。


Base Assumptions


在什么假设下可以应用半监督算法呢?半监督算法仅在数据的结构保持不变的假设下起作用,没有这样的假设,不可能从有限的训练集推广到无限的不可见的集合。具体地假设有:

1.1 The Smoothness Assumption

如果两个样本 x1,x2 相似,则它们的相应输出 y1,y2 也应如此。这意味着如果两个输入相同类,并且属于同一簇,则它们相应的输出需要相近,反之亦成立。

1.2 The Cluster Assumption

假设输入数据点形成簇,每个簇对应于一个输出类,那么如果点在同一个簇中,则它们可以认为属于同一类。聚类假设也可以被视为低密度分离假设,即:给定的决策边界位于低密度地区。两个假设之间的关系很容易看出。

一个高密度区域,可能会将一个簇分为两个不同的类别,从而产生属于同一聚类的不同类,这违反了聚类假设。在这种情况下,我们可以限制我们的模型在一些小扰动的未标记数据上具有一致的预测,以将其判定边界推到低密度区域。

1.3 The Manifold Assumption

(a)输入空间由多个低维流形组成,所有数据点均位于其上;
(b)位于同一流形上的数据点具有相同标签。


Consistency Regularization


深度半监督学习的一个新的研究方向是利用未标记的数据来强化训练模型,使其符合聚类假设,即学习的决策边界必须位于低密度区域。这些方法基于一个简单的概念,即如果对一个未标记的数据应用实际的扰动,则预测不应发生显著变化,因为在聚类假设下,具有不同标签的数据点在低密度区域分离。

具体来说,给定一个未标记的数据点 及其扰动的形式 ,目标是最小化两个输出之间的距离:


流行的距离测量 d 通常是均方误差(MSE),Kullback-Leiber 散度(KL)和 Jensen-Shannon 散度(JS),我们可以按以下方式计算这些度量,其中


具体到每一种算法,核心思想是没有变化的,即最小化未标记数据与其扰动输出两者之间的距离,但计算输出的形式上有很多变化。

2.1 Pi-Model (ICLR2017)


论文标题:

Temporal Ensembling for Semi-Supervised Learning


论文链接:

https://openreview.net/forum?id=BJ6oOfqge&noteId=BJ6oOfqge


代码链接:

https://github.com/smlaine2/tempens


具体来说,由于正则化技术(例如 data augment 和 dropout)通常不会改变模型输出的概率分布,Pi-Model 正是利用神经网络中这种预测函数的特性,对于任何给定的输入 x,使用不同的正则化然后预测两次,而目标是减小两次预测之间的距离, 提升模型在不同扰动下的一致性,Pi-Model 使用 MSE 做为两个概率分布之间的损失函数。


训练过程如上图所示:对每一个参与训练的样本,在训练阶段,Pi-Model 需要进行两次前向推理。此处的前向运算,包含一次随机增强变换和不做增强的前向运算。由于增强变换是随机的,同时模型采用了 Dropout,这两个因素都会造成两次前向运算结果的不同。

损失函数:由两部分构成,如下所示,其中第一项含有一个时变系数 w,用来逐步释放此项的权重,x 是未标记数据,由两次前向运算结果的均方误差(MSE)构成。第二项由交叉熵构成,x 是标记数据,y 是对应标签,仅用来评估有标签数据的误差。可见,第一项即是用来实现一致性正则。


2.2 Temporal Ensembling (ICLR2017)


论文标题:

Temporal Ensembling for Semi-Supervised Learning


论文链接:

https://openreview.net/forum?id=BJ6oOfqge&noteId=BJ6oOfqge


代码链接:

https://github.com/smlaine2/tempens


在 Pi-Model 的基础上进一步提出了 Temporal Ensembling,其整体框架与 Pi-model 类似,在获取无标签数据的处理上采用了相同的思想,唯一的不同是:

在目标函数的无监督一项中, Pi-Model 是两次前向计算结果的均方差,而在 temporal ensembling 模型中,使用时序组合模型,采用的是当前模型预测结果与历史预测结果的平均值做均方差计算。有效地保留历史了信息,消除了扰动并稳定了当前值。


如上图所示,对于一个目标 ,在每次训练迭代中,当前输出 通过 EMA(exponential moving averag,指数滑动平均)更新被累加到整体输出中 yema:


而损失函数与 Pi-Model 相似。相对于 Pi-Model,Temporal Ensembling 有两方面的好处:

  • 用空间来换取时间,总的前向推理次数减少了一半,因而减少了训练时间;

  • 通过历史预测做平均,有利于平滑单次预测中的噪声。

2.3 Mean teachers (NIPS 2017)


论文标题:

Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results


论文链接:

https://arxiv.org/abs/1703.01780


代码链接:

https://github.com/CuriousAI/mean-teacher



如上图所示,Mean Teachers 则是 Temporal Ensembling 的改进版,Temporal Ensembling 对模型的预测值进行 EMA(exponential moving averag),而 Mean Teachers 采用了对 studenet 模型权重进行 EMA,作为 teacher model  如下:


这种情况下,损失的计算是有监督和无监督损失的总和:


2.4 Unsupervised Data Augmentation



论文标题:

Unsupervised Data Augmentation for Consistency Training


论文链接:

https://arxiv.org/pdf/1904.12848v2.pdf


代码链接:

https://github.com/google-research/uda


之前的工作中对未标记的数据加入噪声增强的方式主要是采用简单的随机噪声,但是这篇文章发现对输入 x 增加的噪声 α 对模型的性能提升有着重要的影响,因此 UDA 提出对未标记的数据采取更多样化更真实的数据增强方式,并且对未标记的数据上优化相同的平滑度或一致性增强目标。训练过程如下图所示:


(1)最小化未标记数据和增强未标记数据上预测分布之间的 KL 差异:


公式中 x 是原始未标记数据的输入, 是对未标签数据进行增强(如:图像上进行 AutoAugmen,文本进行反翻译)后的数据。

(2)为了同时使用有标记的数据和未标记的数据,添加了标记数据的 Supervised Cross-entropy Loss 和上式中定义的一致性/平滑性目标 Unsupervised Consistency Loss,权重因子 λ 为我们的训练目标,最终目标的一致性损失函数定义如下:



UDA 证明了针对性的数据增强效果明显优于无针对性的数据增强,这一点和监督学习的 AutoAugment、RandAugment 的结论是一致的。

2.5 小节

一致性正则化这类方法的主要思想是:对于无标签图像,添加噪声之后模型预测也应该保持不变。除了以上的方法外,还有 VAT [1]、ICT [2] 等等方法,这些方法也都是找到一种更适合的数据增强,因为数据增强不应该是一成不变的,而是如 UDA 所述不同的任务其数据扩增应该要不一样。


Proxy-label Methods

代理标签方法是使用预测模型或它的某些变体生成一些代理标签,这些代理标签和有标记的数据混合一起,提供一些额外的训练信息,即使生成标签通常包含嘈杂,不能反映实际情况。

这类方法主要可分为分为两类:self-training(模型本身生成代理标签)和 multi-view learning(代理标签是由根据不同数据视图训练的模型生成的)。

3.1 Self-training



如上图所示,Self-training 的训练过程如下:

Step1:首先,用少量的标签数据 L 训练 Model;也就是上图的虚线以上部分;

Step2:然后,使用训练后的 Model 给未标记的数据点 x∈U 分配 Pseudo-label(伪标签);

最受欢迎的两种方式是锐化方法和 Argmax 方法。前者在保持预测值分布的同时使分布有些极端;后者仅使用对预测具有最高置信度的预测标签进行标记。如下所示:


另一方面:我们还可以对无标签数据进行过滤,如果预测结果大于预定阈值 τ,再将其添加训练中。

Setp3:通过交叉熵损失计算模型预测和伪标签的损失。

Step4:最后,使用训练好的模型为 U 的其余部分生成代理标签,一直循环,直到模型无法生成代理标签为止。

以下就是 Self-training 的伪代码:


而 Pseudo-label [5] 与 Self-traing 基本思想是一致的,但这类方法主要缺点是:模型无法纠正自己的错误。如果模型对自己预测的结果很有“自信”,但这种自信是盲目的,那么结果就是错的,这种偏差就会在训练中得到放大。

3.2 Multi-view training

Multi-view training 利用了在实际应用中非常普遍的多视图数据。多视图数据可以通过不同的测量方法(例如颜色信息和纹理)收集不同的视图图片信息,或通过创建原始数据的有限视图来实现。

在这种情况下,MVL 的目标是学习独特的预测函数 fθi 为数据点 x 的给定视图 vi(x) 建模,并共同优化所有用于提高泛化性能的功能。理想情况下,可能的观点相互补充以便所生产的模型可以相互协作以提高彼此的性能。

3.2.1 Co-training

Co-training [3] 有 m1 和 m2 两个模型,它们分别在不同的数据集上训练。每轮迭代中,如果两个模型里的一个模型,比如模型 m1 认为自己对样本 x 的分类是可信的,置信度高,分类概率大于阈值 τ ,那 m1 会为它生成伪标签,然后把它放入 m2 的训练集。

简而言之,一个模型会为另一个模型的输入提供标签。以下是它的伪代码:



3.2.2 Tri-Training

Tri-training [4] 首先对有标记示例集进行可重复取样(bootstrap sampling)以获得三个有标记训练集,然后从每个训练集产生一个分类器。

在协同训练过程中,各分类器所获得的新标记示例都由其余两个分类器协作提供,具体来说,如果两个分类器对同一个未标记示例的预测相同,则该示例就被认为具有较高的标记置信度,并在标记后被加入第三个分类器的有标记训练集。伪代码如下:



Holistic Methods

Holistic Methods 试图在一个框架中整合当前的 SSL 的主要方法,从而获得更好的性能。

4.1 MixMatch【NeurIPS 2019】

MixMatch 整合了前面提到的一些 ideas 。对于给定一批有标签的 X 和同样大小未标签的 U,先生成一批经过处理的增强标签数据 X' 和一批伪标签的 U',然后分别计算带标签数据和未标签数据的损失项。表示为:



对于 alpha,这是一个与 Mixup 操作相关的参数,建议从 0.75 开始,并根据数据集进行调整。

具体操作如下:

Setp 1:Data Augmentation

与许多 SSL 方法中的典型方法一样,我们对标记的和未标记的数据都使用数据增强。数据增强只是标准的裁剪和翻转。

Step 2:Label Guessing

对于的每个未标记的训练数据,MixMatch 使用模型的预测为样本的生成一个“guess”标签,这个“guess”标签被用于无监督损失计算。具体地,我们计算了该模型预测的分类分布在所有 K 个增量上的平均值。如下:



每个未标记的输入数据只增加两次扩增(K=2):

Step 3:Sharpening

Sharpening 是一个很重要的过程,这个思想相当于深度学习中的 relu 过程。在给定预测的平均值的基础上,应用锐化函数减小了标签分布的熵。如下:



Sharpen 函数实际上只是一个“温度调整”,建议将温度参数 T 保持为 0.5。

Step 4:MixUp

与过去使用 MixUp 工作不同,将标记的示例与未标记的示例“混合”在一起,并发现提升了性能。具体地,将有标签数据 X 和无标签数据 U 混合在一起形成一个混合数据 W。

然后有标签数据 X 和 W 中的前 #X 个进行 mixup 后,得到的数据作为有标签数据,作为 label group,记为 X',同样,无标签数据 U 和 W 中的后 #U 个进行 mixup 后,得到的数据作为无标签数据,作为 unlabel group,记为 U'。



Loss function:对于有标签的数据,使用交叉熵;“guess”标签的数据使用MSE;然后将两者加权组合,如下图所示。


4.2 FixMatch

FixMatch 是 Google Brain 提出的一种 Holistic 的半监督学习方法,与以往的Holistic Methods不同的是,FixMatch 使用交叉熵将 weakly augment 和 strong augment 的无标签数据进行比较,并取得了不错的效果。其巧妙之处是:一致性正则化使用的是交叉熵损失函数。

FixMatch 是对弱增强图像与强增强图像之间的进行一致性正则化,但是其没有使用两种图像的概率分布一致,而是使用弱增强的数据制作了伪标签,这样就自然需要使用交叉熵进行一致性正则化了。此外,FixMatch 仅使用具有高置信度的未标记数据参与训练。


增强

  1. 弱增强:用标准的翻转和平移策略。
  2. 强增强:输出严重失真的输入图像,先使用 RandAugment 或 CTAugment,再使用 CutOut 增强。

模型

FixMatch使用 Wide-Resnet 变体作为基础体系结构,记为 Wide-Resnet-28-2,其深度为 28,扩展因子为 2。因此,此模型的宽度是 ResNet 的两倍。

训练



训练过程如下:

  1. Input:准备了 batch=B 的有标签数据和 batch=μB 的无标签数据,其中 μ 是无标签数据的比例;
  2. 监督训练:对于在标注数据的监督训练,将常规的交叉熵损失 H() 用于分类任务。有标签数据的损失记为 ls,如伪代码中第 2 行所示;
  3. 生成伪标签:对无标签数据分别应用弱增强和强增强得到增强后的图像,再送给模型得到预测值,并将弱增强对应的预测值通过 argmax 获得伪标签;
  4. 一致性正则化:将强增强对应的预测值与弱增强对应的伪标签进行交叉熵损失计算,未标注数据的损失由 lu 表示,如伪代码中的第 7 行所示;式 τ 表示伪标签的阈值;
  5. 完整损失函数:最后,我们将 ls 和 lu 损失相结合,如伪代码第 8 行所示,对其进行优化以改进模型,其中,λu 是未标记数据对应损失的权重。


总结

当标注的数据较少时模型训练容易出现过拟合,一致性正则化方法通过鼓励无标签数据扰动前后的预测相同使学习的决策边界位于低密度区域,很好缓解了过拟合这一现象,代理标签法通过对未标记数据制作伪标签然后加入训练,以得到更好的决策边界,而众多方法中,混合方法表现出了良好的性能,是近来的研究热点。


参考文献

[1] Takeru M , Shin-Ichi M , Shin I , et al. Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018:1-1.

[2] Verma V , Lamb A , Kannala J , et al. Interpolation Consistency Training for Semi-Supervised Learning[J]. 2019.

[3] Avrim Blum and Tom Mitchell. Combining labeled and unlabeled data with co-training. In Proceedings of the eleventh annual conference on Computational learning theory, pages 92–100, 1998.

[4] Zhi-Hua Zhou and Ming Li. Tri-training: Exploiting unlabeled data using three classififiers. IEEE Transactions on knowledge and Data Engineering, 17(11):1529–1541, 2005.

[5] Dong-Hyun Lee. Pseudo-label: The simple and effiffifficient semi-supervised learning method for deep neural networks. In Workshop on challenges in representation learning, ICML, volume 3, page 2, 2013.



更多阅读




#投 稿 通 道#

 让你的论文被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。


📝 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志


📬 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。



您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存