前言
算法、算力、数据是深度学习的三架马车。深度学习是数据驱动式方法,目前的从业基本者都有一个共识就是:数据是非常重要的且不可或缺的。在实际环境中对数据标注又是一个耗时和昂贵的过程。但是受束于资源的限制,可能你有很多的图片,但是只有一部分可以进行人工标注。例如工业频繁更换型号的场景,花费更多时间标注意味着上线运行时间的 delay,会严重影响效率和产能。在这样的情况下,如何利用大量未标注的图像以及部分已标注的图像来提高模型的性能呢?答案是 semi-supervised 半监督学习。半监督学习(SSL)这个领域近年来得到飞速的发展,方法也有很多,但很多都是使用较为复杂的方法,标注降低了,但是训练复杂度等其他方面的代价上来了。本文提出 FixMatch,是一种对现有 SSL 方法进行显著简化的算法。FixMatch 使用模型的预测生成伪标签进行无标签数据的训练。本文贡献:利用一致性正则化( Consistency regularization)和伪标签(pseudo-labeling)技术进行无监督训练。SOTA 精度,其中 CIFAR-10 有 250 个标注,准确率为 94.93%。甚至仅使用10张带有标注的图在 CIFAR-10 上达到 78% 精度。
论文标题:
FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
论文链接:
https://arxiv.org/abs/2001.0768
官方代码:
https://github.com/google-research/fixmatch
核心思想
如上图所示,训练过程包括两个部分,有监督训练和无监督训练。有 label 的数据,执行有监督训练,和普通分类任务训练没有区别。没有 label 的数据,经过首先经过弱增强获取伪标签。然后利用该伪标签去监督强增强的输出值,只有大于一定阈值条件才执行伪标签的生成。无监督的训练过程包含两种思想在里面,即一致性正则化和伪标签训练。一致性正则化是当前半监督 SOTA 工作中一个重要的组件,其建立在一个基本假设:相同图片经过不同扰动(增强)经过网络会输出相同预测结果,因此对这二者进行 loss 计算便可以对网络进行监督训练,又被称为自监督训练。loss 计算如下:
伪标签是利用模型本身为未标记数据获取人工标签的思想。通常是使用“hard”标签,也就是 argmax 获取的 onehot 标签,仅保留最大类概率超过阈值的标签。计算 loss 的时如下:
其中 , 为阈值。我们假设 argmax 一个概率分布产生一个有效的 onehot 概率分布。这种机制为什么 work?无监督训练过程实际上是一个孪生网络,可以提取到图片的有用特征。弱增强不至于图像失真,再加上输出伪标签阈值的设置,极大程度上降低了引入错误标签噪声的可能性。而仅仅使用弱增强可能会导致训练过拟合,无法提取到本质的特征,所以使用强增强。强增强带来图片的严重失真,但是依然是保留足够可以辨认类别的特征。有监督和无监督混合训练,逐步提高模型的表达能力。
算法流程图
输入的数据包括两个部分,有标注的数据和没有标注的数据,另外需要设定置信度阈值、采样比例、loss 权重等超参。
- 对有标注的部分执行有监督训练,使用传统的 CE loss。
- 对无标注的数据,利用获取的伪标签进行训练,同样利用 CE loss。
loss计算
loss 包括两部分,标注有监督分类任务 loss 和无监督为标签训练 loss 分别如下。其中 表示弱增强,一般为 flip、平移; 表示强增强,一般为色彩变换、对比度增强、旋转等,下面会细说。
数据增强
上文提到了该方法应用到了两种数据增强,分别是 weak Augmentation 和 strong Augmentation。weak Augmentation 为标准的 flip-and-shift 增强策略,50% 的概率进行 flip 和 12.5% 的概率进行 shift,包括水平和竖直方向。对于 strong Augmentation,论文主要应用 RandAugment 和 CTAugment 两种策略,都是为提高模型表现而提出的增强策略。首先进行严重失真的增强,然后再应用 CutOut 增强。增强函数来自 PIL 库。以前的 SSL 使用的是 AutoAugment,这个工具训练了一个强化学习算法来寻找最佳准确率的增强方法。本文 FixMatch 使用了 AutoAugment 的两个变体之一,采用的是随机采样策略,减少网络对增强之间耦合程度的依赖。对于 RandAugment:
上表为 RandAugment 涉及到的所有的增强和参数区间。对于 CTAugment 同理,主要有 Autocontrast 、Brightness、Color、Contrast、Cutout、Equalize、Invert、Identity、Posterize、Rescale、Rotate、Sharpness、Shear_x、Shear_y、Smooth、Solarize、Translate_x、Translate_y,不再叙述。
实验效果
作者分别在 CIFAR 和 SVHM 等数据集上进行了训练测试,模型表现超过之前的网络。具体如下:
对于极端缺少标注的场景,仅仅使用每个类别 1 张共 10 张标注的图片就可以达到 78% 的最大 accuracy,当然这种做法和挑选的样本质量有关,作者也做了相关实验论证。不过也证明本文的方法的确 work。
另外,作者做了很多实验消融实验,包括一些训练超参调节的,不再叙述,有兴趣的可以阅读原文。
结语
本文介绍了一篇半监督领域经典论文,其做法简单有效,使用图像增强技术进行伪标签学习和一致性正则化训练,在 CIFAR 等多个数据集上仅仅利用少量的标注图片就可以达到一个不错的效果,这对于获取标注困难的场景非常有意义。例如在工业应用领域,可能会有海量数据,但是现实限制可能无法都进行人工标注,因此可以尝试利用半监督训练的方法,非常值得借鉴。
#投 稿 通 道#
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得或技术干货。我们的目的只有一个,让知识真正流动起来。
📝 来稿标准:
• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)
• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接
• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志
📬 投稿邮箱:
• 投稿邮箱:hr@paperweekly.site
• 所有文章配图,请单独在附件中发送
• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。