查看原文
其他

【源头活水】[Meta-Learning]对Reptile的深度解析



“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

来源:知乎—周威

地址:https://zhuanlan.zhihu.com/p/239929601

大家发现了,meta-learning中取名都挺有意思,一会儿哺乳动物(MAML),一会儿又是爬行动物(Reptile),这个领域不会都动物学家转过来的吧?
有关论文和代码链接,下面给出
Paper:
https://arxiv.org/abs/1803.02999
Code:
https://github.com/gabrielhuang/reptile-pytorch

01

Reptile论文中的一些创新点
作者在论文的introduction部分的最后,总结了该论文的一些贡献,具体如下:
也就是作者提出了一个叫做Reptile的Meta-Learning,该方法和First-Order MAML很像,但是实现起来更简单。
毕竟我们在之前的学习MAML的过程中,是要将训练数据集分为support set和query set的。support set和query set的作用各不相同。
而作者的Reptile并不需要将数据集分为support set和query set的。
注意一下,这里作者提到的fast weight 和slow weight是什么意思呢?
我们在MAML中提到,我们的目标是需要学习到一个好的初始化权重。
学习这个初始化权重需要两个更新(或者称为两个梯度下降):
第一个更新/梯度下降(叫做inner loop update)是使用一个个的task(也就是support set)不断的从初始化的权重    开始进行梯度下降,下降到    ,这里的i是task的索引,这里的    就是fast weight ;
第二个更新/梯度下降(叫做outer loop update)是使用一个个(索引为i)的query set去计算总的损失函数   。然后求解总损失函数对初始化权重的导数作为梯度即可实现第二个更新,即    ,这里的    就是slow weight,    是学习率。
Reptile中并没有进行support set和query set的划分,而是采用了一种更简单的方式进行slow weight更新,具体的算法流程如下。


02

Reptile的算法流程
这里我们引入Reptile的算法流程,如下
就这?如此简单?
对的,就是如此简单,没有support set和query set的纠缠,作者的意思非常明确。
给了一堆tasks,不断地从每个task中(采用子集k个)进行inner loop update,更新fast weight。
那么当该task的k个子集全部训练结束后,网络初始化的参数    就更新为    ,其中
  
简单来说,上面的更新其实就是fast weight的更新,即    就是fast weight,这个和MAML中fast weight的更新类似。就是在同一个任务中随机采样的k个子集(每个子集包含N-way K-shot个数据)进行训练来更新参数。
在Reptile中,更新slow weight的方向并不是像MAML中由总损失Loss对初始化参数的导数决定的。因为Reptile中并不需要query set,所以无法计算总损失。
作者直接使用    的方向,也就是    来决定slow weight更新的方向,也就是上面流程图中标黄的部分,简单粗暴,简直amazing呀!
这里不妨使用图绘进行更清晰的表述,如下图所示
这么一看,Reptile要比MAML简单很多,即便是简化后的First-Order MAML,也不如Reptile简单呀!
值得注意的是,上图中的Task1(1)、Task1(2)、Task1(3)、Task1(4)都是同一个任务(比如猫狗分类数据集)中随机采样的4个子集(每个子集N-way K-shot个数据)。
这里我们结合代码进行验证,代码如下:
# Main loopfor meta_iteration in tqdm.trange(args.start_meta_iteration, args.meta_iterations):
# Update learning rate meta_lr = args.meta_lr * (1. - meta_iteration/float(args.meta_iterations)) set_learning_rate(meta_optimizer, meta_lr)
# Clone model net = meta_net.clone() optimizer = get_optimizer(net, state) # load state of base optimizer?
# Sample base task from Meta-Train train = meta_train.get_random_task(args.classes, args.train_shots or args.shots) train_iter = make_infinite(DataLoader(train, args.batch, shuffle=True))
# Update fast net # do the first batch update steps # the grads of para are from grad to grad-p1-p2-...-pn loss = do_learning(net, optimizer, train_iter, args.iterations) state = optimizer.state_dict() # save optimizer state
# Update slow net # update the meta_net's grad parameter to meta_net.param - clone_net.param # correspond to p1+p2+p3+...+pn meta_net.point_grad_to(net)    meta_optimizer.step()
核心代码非常简单、简短。值得注意的是,这里进行了一个模型的clone,意在进行fast weight的更新,同时保留原模型的初始化参数。
同样地,如果有n个任务(比如猫狗分类、香蕉苹果分类、男女分类等),那么更新的公式为

这里的i是第i个任务的索引,一共有n个任务。上面图例中只展示了在1个任务中的slow weight更新。
至此,Reptile就讲解完毕,很简单的一个模型。但是具体背后的数学原理并不是很简单,即作者为什么选择   为slow weight的更新方向,我认为还是有必要精读下原论文的数学推导的,这里我就不细说了。

03

总结
最近看的东西比较多,也比较杂,而且还有很多idea进行验证与成果化,更新的速度比较慢,大家见谅!
看reptile原文的时候,我觉得这东西全是数学推导(看着就头大),等摸清楚了后,会发现Reptile是要比MAML简单的。具体的一些关于Reptile的应用,我也还在摸索,等有成果了会和大家进行额外分享的!

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存