【源头活水】Mutual Mean-Teaching:为无监督学习提供更鲁棒的伪标签
“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。
本文介绍一篇我们发表于ICLR-2020的论文《Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification》[1],其旨在解决更实际的开放集无监督领域自适应问题,所谓开放集指预先无法获知目标域所含的类别。这项工作在多个行人重识别任务上验证其有效性,精度显著地超过最先进技术13%-18%,大幅度逼近有监督学习性能。这也是ICLR收录的第一篇行人重识别任务相关的论文,代码和模型均已公开。
论文链接:https://openreview.net/forum?id=rJlnOhVYPS
代码链接:https://github.com/yxgeee/MMT
视频介绍:
01
任务
行人重识别(Person ReID)旨在跨相机下检索出特定行人的图像,被广泛应用于监控场景。如今许多带有人工标注的大规模数据集推动了这项任务的快速发展,也为这项任务带来了精度上质的提升。然而,在实际应用中,即使是用大规模数据集训练好的模型,若直接部署于一个新的监控系统,显著的领域差异通常会导致明显的精度下降。在每个监控系统上都重新进行数据采集和人工标注由于太过费时费力,也很难实现。所以无监督领域自适应(Unsupervised Domain Adaptation)的任务被提出以解决上述问题,让在有标注的源域(Source Domain)上训练好的模型适应于无标注的目标域(Target Domain),以获得在目标域上检索精度的提升。值得注意的是,有别于一般的无监督领域自适应问题(目标域与源域共享类别),行人重识别的任务中目标域的类别数无法预知,且通常与源域没有重复,这里称之为开放集(Open-set)的无监督领域自适应任务,该任务更为实际,也更具挑战性。
动机
无监督领域自适应在行人重识别上的现有技术方案主要分为基于聚类的伪标签法、领域转换法、基于图像或特征相似度的伪标签法,其中基于聚类的伪标签法被证实较为有效,且保持目前最先进的精度 [2,3],所以该论文主要围绕该类方法进行展开。基于聚类的伪标签法,顾名思义,(i)首先用聚类算法(K-Means, DBSCAN等)对无标签的目标域图像特征进行聚类,从而生成伪标签,(ii)再用该伪标签监督网络在目标域上的学习。以上两步循环直至收敛,如下图所示:
尽管该类方法可以一定程度上随着模型的优化改善伪标签质量,但是模型的训练往往被无法避免的伪标签噪声所干扰,并且在初始伪标签噪声较大的情况下,模型有较大的崩溃风险。所谓伪标签噪声主要来自于源域预训练的网络在目标域上有限的表现力、未知的目标域类别数、聚类算法本身的局限性等等。所以如何处理伪标签噪声对网络最终的性能产生了至关重要的影响,但现有方案并没有有效地解决它。
02
概述
为了有效地解决基于聚类的算法中的伪标签噪声的问题,该文提出利用"同步平均教学"框架进行伪标签优化,核心思想是利用更为鲁棒的"软"标签对伪标签进行在线优化。在这里,"硬"标签指代置信度为100%的标签,如常用的one-hot标签[0,1,0,0],而"软"标签指代置信度<100%的标签,如[0.1,0.6,0.2,0.1]。
如上图所示,A1与A2为同一类,外貌相似的B实际为另一类,由于姿态多样性,聚类算法产生的伪标签错误地将A1与B分为一类,而将A1与A2分为不同类,使用错误的伪标签进行训练会造成误差的不断放大。该文指出,网络由于具备学习和捕获数据分布的能力,所以网络的输出本身就可以作为一种有效的监督。然而,利用网络的输出来训练自己是不可取的,会无法避免地造成误差的放大。所以该文提出同步训练对称的网络,在协同训练下达到相互监督的效果,从而避免对网络自身的输出误差形成过拟合。在实际操作中,该文利用"平均模型"进行监督,提供更为可信和稳定的"软"标签,将在下文进行描述。总的来说,该文
提出"相互平均教学"(Mutual Mean-Teaching)框架为无监督领域自适应的任务提供更为可信的、鲁棒的伪标签;
针对三元组(Triplet)设计合理的伪标签以及匹配的损失函数,以支持协同训练的框架。
相互平均教学(MMT)
如上图所示,该文提出的"相互平均教学"框架利用离线优化的"硬"伪标签与在线优化的"软"伪标签进行联合训练。"硬"伪标签由聚类生成,在每个训练epoch前进行单独更新;"软"伪标签由协同训练的网络生成,随着网络的更新被在线优化。直观地来说,该框架利用同行网络(Peer Networks)的输出来减轻伪标签中的噪声,并利用该输出的互补性来优化彼此。而为了增强该互补性,主要采取以下措施:
对两个网络Net 1和Net 2使用不同的初始化参数;
随机产生不同干扰,例如,对输入两个网络的图像采用不同的随机增强方式,如随机裁剪、随机翻转、随机擦除等,对两个网络的输出特征采用随机dropout;
训练Net 1和Net 2时采用不同的"软"监督,i.e. "软"标签来自对方网络的"平均模型";
采用网络的"平均模型"Mean-Net 1/2而不是当前的网络本身Net 1/2进行相互监督。
此处,"平均模型"的参数
这里, 指第 个iteration,
在行人重识别任务中,通常使用分类损失与三元损失进行联合训练以达到较好的精度。其中分类损失作用于分类器的预测值,而三元损失直接作用于图像特征。为了方便展示,下文中,我们使用
"软"分类损失
利用"硬"伪标签进行监督时,分类损失可以用一般的多分类交叉熵损失函数
上式中,
上式中 和
"软"三元损失
传统的三元(anchor, positive, negative)损失函数表示为:
上式中
这里softmax-triplet的取值范围为 ,可以用来替换传统的三元损失,当使用"硬"伪标签进行监督时,可以看作二分类问题,使用二元交叉熵损失函数
这里的" "指的是每个样本与其负样本的欧氏距离应该远远大于与正样本的欧氏距离。但由于伪标签存在噪声,并不能完全正确地区分正负样本,所以该文提出需要软化对三元组的监督(使用"平均模型"输出的特征距离比
该损失函数旨在让Net 1输出的softmax-triplet逼近Mean-Net 2的softmax-triplet预测值,让Net 2输出的softmax-triplet逼近Mean-Net 1的softmax-triplet预测值。通过该损失函数的设计,该文有效地解决了传统三元损失函数无法支持"软"标签训练的局限性。"软"三元损失函数可以有效提升无监督领域自适应在行人重识别任务中的精度,实验详情参见原论文消融学习的对比实验。
算法流程
该文提出的"相互平均教学"框架利用"硬"/"软"分类损失和"硬"/"软"三元损失联合训练,在每个训练iteration中,主要由三步组成:
1. 通过"平均模型"计算分类预测和三元组特征的"软"伪标签;
2. 通过损失函数的反向传播更新Net 1和Net 2的参数;
3. 通过参数加权平均法更新Mean-Net 1和Mean-Net 2的参数。
03
该文在四个行人重识别任务上进行了验证,精度均比现有最先进的方法 [2,3] 提升十个点以上,媲美有监督学习的性能。论文中使用K-Means聚类进行实验,在每个行人重识别任务中都对不同的伪类别数(表格中表示为"MMT-伪类别数")进行了验证。发现无需设定特定的数目,均可获得最先进的结果。另外,开源的代码中包含了基于DBSCAN的实验脚本,可以进一步提升性能,感兴趣的同学可以尝试。论文中的消融研究有效证明了"相互平均教学"框架的设计有效性和可解释性,在这里就不细细展开了。
MMT+ (VisDA-2020)
我们在ECCV 2020 Workshop的Visual Domain Adaptation Challenge中进一步优化了MMT,获得第二名,方案解读参见:
https://zhuanlan.zhihu.com/p/265758275
04
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“源头活水”历史文章
Few-shot object detection论文整理(CVPR2021)
图神经网络也可以很快——Cluster-GCN
基于3D卷积神经网络的人体行为识别(3D CNN)
CVPR2021 | Variational Relational Point Completion Network
一种极简的深度子领域自适应方法DSAN
以因果为先验的解耦表示 | 生成模型——CausalVAE及其扩展
使用具有外部记忆的神经网络模型对上下文和结构化知识进行对话
IPT CVPR 2021 | 底层视觉预训练Transformer | 华为开源代码解读
CVPR 2021 | LapStyle - 基于拉普拉斯金字塔的高质量风格化方法
ICLR2021 | 通过干预的无监督解耦表示
实时目标检测算法
Shuffle Transformer 高效快速的基础模型
BeBold:一种新的强化学习探索准则
这篇CVPR文章真是妙蛙种子到了妙妙屋
更多源头活水专栏文章,
请点击文章底部“阅读原文”查看
分享、在看,给个三连击呗!