查看原文
其他

【源头活水】多任务权重自动学习论文介绍和代码实现

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

作者:知乎—yuexiang

地址:https://zhuanlan.zhihu.com/p/367881339

本文介绍论文Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics

论文地址:https://arxiv.org/pdf/1705.07115.pdf

该论文主要介绍了一种多任务损失函数的权重怎么调。假定有3个任务,分别是semantic classification, instance regression 和depth regression。
一般会对各个损失函数加权求和作为最后的损失函数。
下面这个表展示了semantic classification 和 depth regression两个任务分别取不同的权值对应的IoU (越大越好)和Depth Error(越小越好)。本文介绍的论文的方法可以取得相对于grid search更好的结果。由于grid search需要尝试很多次,而论文的方法在一次训练中可以自动调节权重,时间成本会低很多。
上面两个任务一个是分类,一个是回归。下面是instance regression和depth regression两个任务,都是回归,也得到了同样的结论。
三个任务的实验结果,即论文中的方法可以扩展到更多的任务
假定输入为  , 模型输出为  ,如果是回归任务,定义给定  ,  是以模型输出  为均值,  为方差的正态分布。
如果是分类任务,模型输出后加上Softmax层作为每个类别的概率。
在多任务场景,假定输出  , …   多个任务独立,
计算log probability 为:
计算过程即带入正态分布的概率密度函数,化简就可以得到了。
假定两个任务都为回归任务,计算联合概率为:
损失函数计算为
计算过程就是将  代入,然后根据log两个相乘变为加法。其中  为
  也类似。
如果是分类,在经过softmax层时,除以温度  , 得到每个类的概率分布。
计算真实类别的log概率为:
计算过程也是将  代入,然后把softmax展开,根据log 两个数相除为log第一个数减去log第二个数。
假定两个任务一个是回归任务,一个是分类任务,损失函数计算为:
其中  为:
在实践中,令  为超参数,  不为0,从而保证  的稳定性。
代码:
假如有两个任务,定义两个任务的log方差分别是:
log_var_a = torch.zeros((1,), requires_grad=True)log_var_b = torch.zeros((1,), requires_grad=True)

然后将这两个参数加入优化器优化:

params = ([p for p in model.parameters()] + [log_var_a] + [log_var_b])optimizer = optim.Adam(params)
计算损失时:
def criterion(y_pred, y_true, log_vars): loss = 0 for i in range(len(y_pred)): precision = torch.exp(-log_vars[i]) diff = (y_pred[i]-y_true[i])**2. loss += torch.sum(precision * diff + log_vars[i], -1)return torch.mean(loss)
其中y_pred[i]为第i个任务的预测,y_true[i]为第i个任务的标签,log_vars[i]为第i个任务的log方差,这里是两个回归任务,计算预测和标签的差值为diff = (y_pred[i]-y_true[i])**2. 计算这个任务的权重为precision = torch.exp(-log_vars[i]),然后用权重乘以差值precision * diff并加上log方差作为正则项,这可以防止log方差过大。当方差过大时,该任务对应的权重会很小。
论文Auxiliary Tasks in Multi-task Learning对正则项进行了修改,论文地址https://arxiv.org/pdf/1805.06334.pdf
定义多任务损失为:
修改正则项从log something变为 log (1 + something),来保证正则项不为负,最终损失函数为:
代码实现
class AutomaticWeightedLoss(nn.Module): def __init__(self, num=2): super(AutomaticWeightedLoss, self).__init__() params = torch.ones(num, requires_grad=True) self.params = torch.nn.Parameter(params)
def forward(self, *x): loss_sum = 0 for i, loss in enumerate(x): loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)        return loss_sum
为什么介绍这篇文章?
最近在做多项选择自动解题(MCQA)的任务, 下面这篇AAAI 2020的论文利用多任务取得了很好的结果,论文地址:https://aaai.org/ojs/index.php/AAAI/article/view/7194/7048
对比BERTlarge WAE 和 BERTlarge,作者利用多任务取得了准确率超过baseline 2.7个点的结果。
其中正确选项的loss为交叉熵:
作者提出了一个Wrong Answer Ensemble (WAE) 的方法,错误项的标签为1,正确项的标签为0,用来教会模型哪些是错误项,用于模拟人做选择题时,排除错误项。损失为Binary Cross Entropy。
测时用一个线性回归找到最好的参数  , 用于结合正确选项的logits   和错误选项的logits   ,类似用模型融合。
其中BERTlarge    为 5.2 , BERT-base为2.2。
Happy Reading, Happy Learning!
谢谢阅读,如有错误,欢迎批评指正~
参考
https://arxiv.org/pdf/1705.07115.pdfhttps://arxiv.org/abs/1703.04977https://arxiv.org/pdf/1805.06334.pdfhttps://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynbhttps://github.com/Mikoto10032/AutomaticWeightedLoss/blob/master/AutomaticWeightedLoss.pyhttps://github.com/ranandalon/mtlhttps://aaai.org/ojs/index.php/AAAI/article/view/7194/7048

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存