论文笔记:Model-Contrastive Federated Learning (MOON) 联邦学习撞上对比学习
前言
本文作者是新加坡国立大学的Qinbin Li(博士生,导师 何炳胜),Bingsheng He(何炳胜教授,导师 宋晓东)以及加州大学伯克利分校的Dawn Song(宋晓东教授,论文总引用量7万+)。
论文一作个人主页: https://qinbinli.com 论文链接:https://arxiv.org/pdf/2103.16257.pdf 代码:https://github.com/QinbinLi/MOON 会议:CVPR 2021
CVPR作为计算机视觉领域的顶级会议(CCF-A),目前有4篇联邦学习相关的论文
Multi-Institutional Collaborations for Improving Deep Learning-Based Magnetic Resonance Image Reconstruction Using Federated Learning
Model-Contrastive Federated Learning
FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space
Soteria: Provable Defense Against Privacy Leakage in Federated Learning From Representation Perspective
Motivation
联邦学习的关键挑战是客户端之间数据的异质性(Non-IID),尽管已有很多方法(例如FedProx,SCAFFOLD)来解决这个问题,但是他们在图像数据集上的效果欠佳(见实验Table1)。 传统的对比学习是data-level的,本文改进了FedAvg的本地模型训练阶段,提出了model-level的联邦对比学习(Model-Contrastive Federated Learning)
作者从NT-Xent loss中获得灵感,提出了model-contrastive loss。model-contrastive loss可以从两方面影响本地模型 1. 本地模型能够学到接近于全局模型的representation 2. 本地模型可以学到比上一轮本地模型更好的representation
背景知识
联邦学习训练过程
本文主要针对客户端本地训练阶段进行了改进(说白了就是加了个loss)。
对比学习的基本想法是同类相聚,异类相离。从不同的图像获得的表征应该相互远离,从相同的图像获得的表征应该彼此靠近。
常用NT-Xent loss(the normalized temperature-scaled cross entropy loss)
SimCLR伪代码:
Preliminary Experiment
本文基于这样一个直观的想法来解决Non-IID问题:
the model trained on the whole dataset is able to extract a better feature representation than the model trained on a skewed subset.
2a:用所有数据集放在一起训练一个CNN模型。
2b:将所有数据集以Non-IID的方式划分10个客户端,各自训练CNN模型,最后随机选择一个客户端的模型。
2c:在10个客户端上使用FedAvg算法训练得到一个global model(10个本地模型加权平均)
2d:在10个客户端上使用FedAvg算法训练,然后随机选择一个客户端的local model。(2d学习到的蓝色的类别表征明显比2c差)
通过T-SNE可视化表征向量,证实了如下观点:全局模型应该要比本地模型的性能好(全局模型能学到一个更好的表征),因此在non-iid的场景下,我们应该控制这种drift以及处理好由全局模型和本地模型学到的表征。
方法:MOON
Since there is always drift in local training and the global model learns a better representation than the local model, MOON aims to decrease the distance between the representation learned by the local model and the representation learned by the global model, and increase the distance between the representation learned by the local model and the representation learned by the previous local model.
(上一轮本地训练好的发往server的模型得到的表征)固定 (这轮开始时发送到本地的全局模型得到的表征)固定 (这轮正在被更新的本地模型得到的表征)不断被更新
MOON的优化目标(loss)如下:
The network has three components: a base encoder, a projection head, and an output layer.
SimCLR和MOON
SimCLR是想让同一张图片(数据层面)的不同view的表征zi和zj最大程度地相近 MOON是想让全局模型和本地模型的参数(模型层面)对应的表征zglob和zlocal最大程度地相近。
作者还提到,理想情况下(IID),全局模型和本地模型训练得到的表征应该是一样好的,那么lcon是一个常数,此时会得到FedAvg一样的效果。在这种意义上,MOON比FedAvg更具鲁棒性(能处理Non-IID的情况)
实验
数据集
作者通过实验展示了在数据集Non-IID的情况下FedProx,SCAFFOLD这些方法应用到图片数据集的效果会大打折扣,甚至和FedAvg一样差。
SOLO表示每个客户端只利用自己本地数据训练模型
总结
一句话总结:作者在联邦学习本地模型训练的时候加了个model-contrastive loss,使得在Non-IID的图片数据集上训练的联邦学习模型效果很好。
往期推荐