查看原文
其他

【源头活水】Meta Transfer Learning for Few Shot Learning

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

作者:知乎—风中的大虾

地址:https://www.zhihu.com/people/zhang-jin-yang-0

这是CVPR19的一篇文章,主要的工作是基于meta-learning的few shot learning。
link :https://arxiv.org/abs/1812.02391
code :https://github.com/yaoyao-liu/meta-transfer-learning
这篇文章对MAML提出了一些改进,主要的框架基本还是基于MAML的,关于MAML的介绍,可以参考
https://zhuanlan.zhihu.com/p/57864886
https://zhuanlan.zhihu.com/p/72920138

接下来我们来看这篇文章
Meta-learning,尤其是MAML已经被当做few shot任务的一种基本框架,它的基本思想是在已有的数据上(meta train data)构建一系列meta task,训练一个base-learner,使得其在遇到新任务,或者说只有少部分标签的未知任务时,能够快速适应。这属于更好地构建一个初始化网络模型参数的方式。
当然现在meta learning基本仍处于一个open的方向,上文知乎中关于MAML的介绍,将其分为三个主要方式,learning good weight initializations,meta-models that generate the parameters of other models 和learning transferable optimizers。
MAML和本文的方法应当属于第一种,NAS(网络结构搜索)应当属于第二种。
由于few shot的实验设置,基于meta learning的方法为了避免过拟合,都使用了比较浅的神经网络,比如MAML中仅使用4层卷积模块,这在一定程度,限制了模型的性能。
本文在meta learning的任务上使用了一个较深的网络(residual-net 12),不同于直接在预训练的模型上进行全局参数的fine tuning,本文在预训练模型的参数固定的前提下,对预训练模型的每层参数重新学习一个scale和shift,在保证预训练模型不损失general的特性的前提下,重新训练了参数,减小了模型参数,使得预训练模型可以transfer到新任务上,同时,比fine tuning的方式有更少的参数,可以避免过拟合。
同时,作者在训练过程中使用了 hard task meta-batch,使得模型能够学习错分的任务,进一步提高了模型性能。
接下来介绍一些背景,同时自己也复习一下few shot learning领域的一些经典方法。
Few shot learning致力于从极少部分的带标签样本进行学习,人类可以通过极少的样本结合自己的经验知识,学习到新的概念,但是这对机器还尚有困难。
Data argumentation可以在一定程度上提高few shot任务的性能,但是这并不能从本质上解决问题,构建多任务也可以提高模型性能,然而这个对各个任务的先验知识要求比较高。
Meta learning通过构建meta task用来训练一个base learner,使其在新任务上能够快速适应,如MAML,作者指出,MAML的方式,需要大量的meta task(240k),效率太低,而且只work在比较浅层的神经网络(4 conv)。
解决这两个问题,也是文章的主要贡献点。
Few shot learning可以大致划分成三个方向:metric learning method,memory network method和gradient based method。
通过在meta task中mining hard的样本,可以提高模型的鲁棒性和泛化能力,这是本文中使用的方式。
在基于few shot的设置中,MAML并非像fine tuning那样直接学习一个在support set上的最优参数,而是现在support set上学习一个模型参数,而后基于该参数使用query set在此更新参数,最终的参数是在第二次更新的参数基础上进行更新的,感觉这里可以看成git 中的branch,设置一些支线分别进行各个task,完成后最后在主线merge,主线的参数更新为最终支线所有基于query set梯度之和。
文章主要思路如下:
在large scale的数据集上训练一个模型,后续固定这个模型的参数,在其基础上重新学习一个scale和shift,下图
绿色是需要学习的参数,黄色是固定的参数,通过large scale预训练的模型参数为黄色表示的参数,如果正常进行fine tuning,对于预训练模型的所有参数都需要更新,如上图a所示,本文为每一组卷积和和偏置重新设置了可学习的参数,其中卷积和的参数(scale)用1来初始化,偏置(shift)用0来初始化,每一个channel仅有一个参数,那么对于卷积核,可以将需要学习的参数规模变为原来的(1/9)(仅对3*3的卷积核),最终的输出通过与预训练模型的参数乘加得到。
预训练模型的流程如下:
其中,  表示模型中卷积部分的参数,  表示全连接部分,在后续步骤中,  固定不变,仅对其每一层卷积学习一个scale和shift,  是需要更新的参数,因为在large scale 的数据集上的训练和few shot的实验设置会不一样(类别数不同),所以预训练的模型在使用时仅保留了卷积部分,全连接的部分重新随机初始化。
公式1,2表示的是在large scale数据集上得到预训练模型的更新公式和损失函数(梯度下降和交叉熵)。
公式3是使用meta train data更新  和  的更新公式,对应based learner的更新(第一次梯度更新,在support set)。
公式4,5是使用meta test data更新  和  的更新公式,对应meta learner的更新(对应第二次梯度更新,在query set)
公式6是通过scale和shift的到pre train mode基础上,重新学习得到的输出。
文中hard task meta batch通过meta train的query set的准确率来统计,通过对query set进行测试,可以得到分类准确率最低的m类,在构建task时,通过对包含在这个m类中的样本重新采样组成 hard task,继而组成hard task meta batch。
文章算法流程图如下:
输入包括数据集合三个不同阶段的学习率,分别对应为在large scale数据上得到预训练模型的学习率,base learner的学习率,mete learner的学习率。
1-5表示的是使用large scale数据集得到预训练模型的过程。
8-19对应随机选择一个task meta batch,进行训练,然后得到hard class-m。然后通过head class-m重新采样得到若干hard task,进行训练。
20 清空hard class-m。训练一下个meta batch。
算法2给了如果训练一个task的细节,给定一个task,
将其分为support set和query set
通过support set进行训练,更新base learner的参数,(迭代多次)
然后通过query set更新 base learner的参数(即使用测试集进行参数更新)(只进行一次)。
待每个batch中所有task都完成后,将meat learner的参数更新至主线模型。
最后统计query set的错误率,选出错误率最高的m类。
如果没有MAML的基础,这篇文章读起来会很拗口,下面展示两个数据集上的结果。
总结一下,文章就是在MAML的基础上,使用了一个较深的预训练模型,同时为了保证效果,为固定的预训练模型的每层卷积核偏移设置了一个可学习的scale和shift,然后在训练时,通过构建hard task meta batch来提高模型的泛化能力和鲁棒性。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存