刚刚和小伙伴参加完 kaggle 的 Global Wheat Detection 比赛获得了 Private Leaderboard 第七的名次,首先,在这次比赛中我们发现在 Public Leaderboard 所得到成绩和 Private Leaderboard 所得到的成绩有很大的差异,其次,我们还发现了一些除魔改模型之外对涨点有效的方法。 这是我们成绩排名截图。下面就具体看看这两种方法。
Data argument
在训练神经网络时,我们常常会遇到的一个只有小几百数据,然而,神经网络模型都需要至少成千上万的图片数据。因此,为了获得更多的数据,我们只要对现有的数据集进行微小的改变。 比如翻转(flips)、平移(translations)、旋转(rotations)等等。而我们要介绍的是 MixMatch,可以看做是半监督学习下的 mixup 扩增。
论文标题: MixMatch: A Holistic Approach to Semi-Supervised Learning 论文链接: https://arxiv.org/pdf/1905.02249.pdf 代码链接: https://github.com/google-research/mixmatch 对于许多半监督学习方法,往往都是增加了一个损失项,这个损失项是在未标记的数据上计算的,以促进模型更好地泛化到训练集之外的数据中。一般地,这个损失项可分为三类: 熵最小化——它鼓励模型对未标记的数据输出有信心的预测;
一致性正则化——当模型的输入受到扰动时,它鼓励模型产生相同的输出分布;
泛型正则化——这有助于模型很好地泛化,避免对训练数据的过度拟合。
MixMatch 整合了前面提到的一些 ideas 。对于给定一个已经标签的 batch X 和同样大小未标签的 batch U,先生成一批经过 Mixup 处理的增强标签数据 X' 和一批伪标签的 U',然后分别计算带标签数据和未标签数据的损失项。具体地流程如下:
MixMatch 就是将无监督和有监督的数据分开进行 mixup 增强,然后无监督的 loss 使用的是 MSE。在比赛中,我们发现如果有监督和无监督一起进行 mixup,性能会下降,而分开进行 mixup 增强,则会进一步提升。
尽管 SSL 取得显著进展,但 SSL 方法主要应用于图像分类,今天介绍一种用于目标检测的 SSL,称为 STAC。
论文标题: A Simple Semi-Supervised Learning Framework for Object Detection
论文链接: https://arxiv.org/pdf/2005.04757.pdf
代码链接: https://github.com/google-research/ssl_detection/
这篇文章利用了 Self-training和 Augmentation driven Consistency regularization,所以称为 STAC。具体训练步骤如下:
现在就看 SSL 的另一个关键点——未标记数据的无监督的损失函数:
其中,ls 是有监督的损失函数,lu 是无监督的损失函数,A 是应用于未标记图像的强数据增强,p 和 s 是类别,t 和 q 是边框坐标。 将 data augmentations 应用于半监督学习的方法在很早就有文献提出,其背后的思想是 Consistency Regularization,即使对未标记的示例进行了增强,分类器也应该输出相同的类分布。 具体地,一致性正则化强制未标记的样本 x 应该与增强后的样本 Augment(x) 保持一致,其中 Augment 是一个随机数据增强函数,例如:随机空间平移或添加噪声。而本文实验发现 λu ∈ [1, 2] 的时候效果最好。说明了半监督和有监督的重要性是不一样的。
Global Wheat Detection
https://www.kaggle.com/c/global-wheat-detection/overview/code-requirements
比赛背景: 主要是准确估计算出不同品种的小麦头的密度和大小,从而帮助农民评估自己的农作物 比赛要求: 检测并框出图片中的小麦头,评估方式是 MAP,MAP,主要是权衡 precision 和 recall 的一个指标。截止时间 8 月 4 号,提交要求不能联网并且 CPU Notebook <= 9 hours run-time,GPU Notebook <= 6 hours run-time 数据集: 训练集为 3434 张小麦图片,在 Public Leaderboard 上计算成绩的测试集占总的测试集的 62%,而在最终计算 Private Leaderboard 成绩的测试集为占中的测试集的 38%。 在训练阶段通过对图像进行增强来数据扩增训练出多个模型,然后在测试集上进行半监督学习。最后,在检测时利用 TTA(Test time augmentation)增加检测的准确性,并利用 wbf 融合多个模型的结果。 3.2.1 训练数据扩增
由于在训练的数据量较小(容易过拟合),而并且测试集的分布比较分散,对模型泛化能力要求比较高,因此采取对图像增强的方式对训练集进行扩增,采取的图像扩增的方式有图像的缩放、随机水平翻转和垂直翻转、多个图像的拼接、色彩空间 hsv 增强,通过这个方式训练集扩增了 5 倍以此缓解训练数据量小的问题。 3.2.2 半监督训练
伪标签对成绩的提升有很大的帮助,最初在 Public Leaderboard 上没加入伪标签技术成绩:0.7522 , 加入伪标签技术后成绩为:0.7720 ,增加了 0.0198,排名提升了一百多名效果可以说是相当的明显了。 具体地,我们对图像的增强策略包括 Vertical Flip,HorizontalFlip,Rotate90,180,270,Multi-Scale 0.83 and 1.2 ,cutout,mixup,然后利用在训练集训练好的模型对未标记的测试集图片进行伪标签制作。 最开始,我们也仅仅是增加这些 argument,能够达到 0.7720,进一步使用 MixMatch 和 STAC 的方法后,分别能够达到 0.7734 和 0.7751。 在检测的过程中使用了 TTA(Test time augmentation),对原始图像进行旋转(90°,180°,270°)、垂直水平翻折、图像缩放(放大 1.2 倍,缩小 0.87 倍),然后对 TTA 后的图像进行检测,最终将所得到的 box 进行 nms。 采用 TTA(测试时增强),可以对一幅小麦图像做多种变换,创造出多个不同版本,对多个版本数据进行计算最后得到平均输出作为最终结果,提高了结果的稳定性和精准度。 在这次比赛中最终提交的两个方案中,方案一也就是上面使用的方案取得了 Private Leaderboard 第七的成绩,方案二:增加了根据验证集计算成绩自动选择最好的阈值,对于伪标签的训练 epoch 增加到 15,而减少了半监督训练中的 Argument(只剩下了旋转)。 方案一在 Public Leaderboard 表现一般的方案成绩为 0.7721 排在 55 名,但是却在 Private Leaderboard 排在了第七名。方案二在 Public Leaderboard 上成绩还不错的方案 0.7751 在排名在 23,但是在 Private Leaderboard上37% 的测试集我的成绩却为 0.6954 排在了 300 多名。
写在最后
由于本次比赛的数据集较小,很容易导致过拟合的现象。比赛结束的时候发现 Public leaderboard 成绩还不错,但是当 Private Leaderboard 出来后排名一落千丈,相比较而言,数据量大了的比赛绝大部分人排名都没有变化,少数有 1~2 名的浮动在。 在这次比赛里的方案二由于 Public Leaderboard 上测试集占 62%,测试集样本较多,因此增加伪标签的训练使得它在 Public Leaderboard 上的成绩增加很多,但是方案二发生了过拟合使得在 Private Leaderboard 上的成绩下降就很明显。 因此,深度学习网络训练到什么时候停止?在关注训练集数量、质量以及分布等等因素的同时,更应该测试集(实际场景)的情况。否则常常会出现悲惨结局。另外,除了魔改模型,数据增强和半监督都是跳出魔改模型的好方法,能够使得模型获得更多的泛化能力。
#投 稿 通 道 #
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读 ,也可以是学习心得 或技术干货 。我们的目的只有一个,让知识真正流动起来。
📝 来稿标准:
• 稿件确系个人原创作品 ,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)
• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接
• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志
📬 投稿邮箱:
• 投稿邮箱: hr@paperweekly.site
• 所有文章配图,请单独在附件中发送
• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通
🔍
现在,在「知乎」 也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」 订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」 ,小助手将把你带入 PaperWeekly 的交流群里。