论文分享|自适应块级正则化和知识蒸馏的联邦学习
联邦学习(FL)是一个分布式模型训练框架,它允许多个客户端在边缘计算场景中协作训练全局模型,同时不暴露客户端的本地数据。然而,FL通常面临数据异构性(如non-IID数据)和系统异构性(如计算和通信能力),导致模型具有较差的训练表现。为了应对上述两个挑战,我们在资源受限的边缘计算场景下提出了一个高效的FL框架-FedBR,它将块级正则化(Block-wise Regularization)和知识蒸馏(Knowledge distillation,KD)的思想引入到联邦学习算法FedAvg中。具体来说,我们首先根据深度神经网络(DNN)的层次顺序将模型划分为多个块。为了提高通信效率,服务器不再向客户端发送整个全局模型而是发送一部分连续的全局模型块。然后客户端使用知识蒸馏技术来吸收全局模型块的知识以缓解数据异质性带来的问题。我们为FedBR提供了一个理论上的收敛保证,并证明收敛界将随着服务器发送的模型块的数量增加而变大。此外,由于模型块数量的增加会带来更多的计算以及通信成本,我们设计了一种启发式算法(GMBS),它通过客户端不同的数据分布、计算和通信能力来确定其接收的全局模型块的数量。大量实验结果表明,与现有方法相比,FedBR在异构环境下可以将带宽消耗减少约31%,并提高约5.6%的平均测试精度。
本文工作已被 CCF 推荐 A 类国际期刊 IEEE/ACM Transactions on Networking 录用。
论文信息如下:
论文标题:Adaptive Block-wise Regularization and Knowledge Distillation for Enhancing Federated Learning 论文作者:Jianchun Liu, Qingmin Zeng, Hongli Xu, Yang Xu, Zhiyuan Wang, He Huang
第一作者简介:刘建春,中国科学技术大学计算机学院,特任副研究员,主要致力于联邦学习和边缘计算的研究。
研究意义
随着物联网(IoT)的持续快速增长,在移动电话、可穿戴设备和显示器等物联网设备上生成的数据量正在增加 。由于在传统的云计算中,数据需要传输到远程云服务器上,处理这些日益增长的数据将导致更高的延迟和更糟糕的用户体验。因此,一种新的计算范式——边缘计算被提出,它将数据的处理和存储放在网络边缘,以减少带宽和云服务器的处理压力。
随着边缘设备(或客户端)的计算能力变得越来越强大,分布式模型训练框架——联邦学习(FL)被应用于处理更复杂的边缘计算任务。在开创性的FL算法FedAvg[1]中,多个客户端使用其本地数据来训练模型,然后仅将其本地模型上传到参数服务器进行全局聚合。之后,服务器将聚合后的模型分发回客户端以进行进一步的本地训练。这种模型交互过程将一直持续到模型收敛。通过这种方式,FL可以减少延迟。因为与数据传输相比,模型传输消耗更少的时间。此外,由于服务器不会直接访问客户端的本地数据,客户端的隐私得到了很好的保护。
尽管有上述好处,但在EC中执行高效的联邦学习仍然面临三个主要挑战。(1)统计异构性。客户端的本地数据通常根据客户端的偏好和位置生成。例如,位于社区的监视器拍摄的图像大多是居民的图像,而位于十字路口的监视器则大多拍摄车辆的图像。因此,不同客户端之间的数据分布差异很大。来自不同客户端的数据样本通常是非独立同分布的(non-IID),并不能代表整体的数据分布(即来自所有客户端的总的数据分布)。non-IID数据的这种特性将严重损害模型训练性能并降低收敛速度。(2)系统异构性。参与模型训练的客户端具有不同的计算和通信能力,这分别与模型更新和传输的时间密切相关。通常,完成时间会随着它们的计算和通信能力的减弱而增加。最慢的客户端由于其最弱的计算或通信能力而需要最长的完成时间。在同步联邦学习架构中,每个全局轮次的训练时间总是取决于最慢的客户端,导致更长的完成时间。(3)通信限制。在训练期间,服务器和客户端之间传递模型需要大量带宽。因此,服务器上有限的通信带宽也是FL中的一个瓶颈。例如,在FL中,数百个客户端联合训练一个VGG16(大小约为500MB),每个全局轮需要消耗50GB以上的带宽,这可能会导致服务器上的网络拥塞。
为了减轻数据异构性的影响,FedAvg已经结合了一些相关技术,如数据扩充、客户端采样和正则化。数据扩充是一种采用随机变换或知识转移来增强训练数据多样性的技术。然而,服务器需要收集客户端的标签分布信息(例如,每个类中的样本数量),可能会泄露客户端的隐私。将客户端采样与FedAvg相结合,通过选择具有独立同分布(IID)数据的客户端来缓解数据异构性[5]。然而,一些客户端可能一直没有被选择,导致用于训练的数据量减少,使得模型训练的性能将降低。将正则化应用于FedAvg,通过使局部模型与全局模型更相似,有效地减轻了数据异构性的影响。与数据扩充和客户端采样相比,正则化获得了更多的全局信息,从而提高了全局模型的泛化性能。FedMLB[4]是一种结合知识蒸馏(KD)技术[6]的正则化方法,它在FedAvg的基础上将模型分块,并构造多个辅助分支。它通过知识蒸馏将每个辅助分支获得的输出和局部模型获得的输出作为正则化项放入损失函数中,以获得更多的全局信息。尽管这种方法可以有效地缓解non-IID数据问题,但辅助分支的引入大大增加了计算成本。此外,根据实验结果,由于所有客户端使用相同数量的辅助分支并接收整个全局模型,FedMLB在系统异构和通信资源有限的情况下性能较差。
综上所述,我们设计了一个高效的联邦学习框架FedBR,将块级正则化和知识蒸馏的思想集成到FedAvg中,以同时解决上述三个挑战。具体来说,我们首先根据DNN的层次顺序将模型划分为多个块,其中每个块由神经网络中的多个连续层组成。然后,我们利用服务器发送的全局模型块为客户端的本地模型构建多条路径(即从输入到输出)。为了减轻non-IID数据的影响,我们引入了一种新的正则化技术,通过KD吸收全局模型块的知识。由于不同数量的全局模型块可以构造不同数量的正则化路径,因此我们称之为块级正则化。与前几层相比,神经网络的后几层对模型训练性能更重要[7]。因此,服务器只在FedBR中分发全局模型的最后几个连续块,从而与FedAvg相比显著降低了通信成本。尽管更多数量的全局模型块允许客户端学习更多的全局信息,但这将导致更多的计算和通信成本。因此,如何动态地确定服务器为每个客户端发送的全局模型块的数量是一个关键挑战。本研究的主要贡献总结如下:
我们提出了一种高效的联邦学习框架FedBR,它通过仅分发部分全局模型来降低通信成本,并通过引入块级正则化和知识蒸馏来吸收全局模型块的知识从而缓解异构挑战。 我们为FedBR提供了理论上的收敛保证,并表明收敛边界将随着服务器发送的模型块数量的增加而降低。 我们提出了一种启发式算法(GMBS),该算法根据客户端不同的数据分布、计算和通信能力,自适应地确定客户端的全局模型块数量。 我们通过大量实验评估FedBR的性能。实验结果表明,在异构环境下,与基线相比,FedBR平均可以减少约31%的带宽消耗,并实现约5.6%的精度提高。
准备工作和问题定义
2.1 联邦学习
我们假设网络中共有个客户端,其目标是在参数服务器(PS)的协调下协同训练高效的全局模型,并且最小化平均损失函数,解决以下优化问题:
其中 是模型参数, 是客户端的损失函数,。客户端的损失函数被定义为:
其中 表示客户端 的本地数据集的一个样本,对应客户端的样本的损失。
我们假设有个客户端参与模型训练,其中。具体来说,FL的优化过程由多个通信轮次组成,每一轮通信过程被划分为以下三个阶段:
(1)模型分发:在每一轮全局轮次开始时,服务器分发当前的模型给个客户端,其中表示全局通信轮次的数量。
(2)本地更新:每个客户端在自己的本地数据集上执行轮本地迭代,其中。之后客户端将更新后的模型发送给服务器。
(3)模型聚合:在服务器收到所有的本地模型后,使用模型参数平均更新全局模型。
2.2 FedBR的训练过程
当客户端之间的数据分布是non-IID时,传统FL方法(如FedAvg)的模型训练性能将显著下降。为此,我们提出一种具有块级正则化的高效联邦学习框架(FedBR)。通过采用分块正则化和知识蒸馏技术,FedBR可以使客户端在本地更新过程中吸收更多的全局知识,从而减轻non-IID数据对训练性能的影响。此外,FedBR根据客户端不同的计算和通信能力,分配不同数量的全局模型块,充分利用网络资源。具体来说,我们将模型划分为个块,其中由服务器根据模型的层次结构确定。之后,FedBR的训练过程主要由以下四个阶段构成:
(1)模型分发:在每一轮全局轮次开始时,服务器分发不同数量的最后一部分连续全局模型块给个客户端。我们使用和来分别表示第轮次的全局模型和发送给客户端的全局模型块的数量。为防止客户端遗忘全局信息,我们每隔轮发送就发送一次整个的全局模型,其中。
(2)模型组合:如果客户端收到的是最后个全局模型块,它会把前个上一轮次的本地模型块和这后个全局模型块进行组合。如果收到的是整个的全局模型,它会直接将这个全局模型替代为组合模型。
(3)本地更新:每个客户端基于组合模型和块级正则化结构在自己的本地数据集上执行轮本地迭代。之后客户端将更新后的模型发送给服务器同时保存这个模型用于下一轮的模型组合。
(4)模型聚合:在服务器收到所有的本地模型后,使用模型参数平均更新全局模型。然后,FedBR继续进行第轮全局模型训练,直到全局模型收敛或网络资源耗尽。
2.3 块级正则化
为了更好地说明如何通过在FedBR中使用块正则化和知识蒸馏来更新客户端的本地模型,我们在图1中给出了一个示例。我们假设模型被分为六个块,服务器在某个全局轮次将最后三个模型块分发给客户端。客户端收到这三个块(即)后,将它们与客户端在第个全局轮次存储的本地模型的前三个块组合,形成组合模型。客户端使用这个组合模型覆盖局部模型,并对其进行次本地更新。在前向传播中,只包含本地模型块的路径称为本地路径,由本地模型块和全局模型块组成的路径称为混合路径。客户端的每个输入都经过条混合路径(例如)和一条本地路径(例如)。因此,每个输入将获得图1中的四个输出。然后,客户端将四个输出和真值标签放入损失函数中,其中我们采用基于本地路径的输出和混合路径的输出的知识蒸馏来获得全局模型块的知识。注意,在模型训练过程中我们对所有路径的输出都使用标准交叉熵损失函数。本地路径的输出和混合路径的输出之间的散度(即知识蒸馏函数)被用于正则化。在本地更新过程中,我们只更新本地模型块的参数,而保持全局模型块参数不变,即只更新块,并保持块不变。
2.4 问题定义
在本文中,我们考虑两种主要的资源:计算资源和通信资源。我们分别用和表示计算资源和通信资源的总预算。和分别表示进行一次本地更新所需的计算资源和传递一整个模型所需的通信资源。表示客户端在全局轮次所需的计算开销,则:
其中,表示客户端本地数据集大小,表示批次的大小,系数表示通过块级正则化结构产生的中间计算结果的复用关系。表示客户端在全局轮次平均产生的通信能开销,则:
我们的目标是最小化完成时间同时为每个客户端找到最优的,该优化问题可以形式化为:
其中保证每一个全局轮次的时间开销。第一组不等式反映了收敛条件,其中表示经过轮全局训练之后损失值的收敛阈值。第二组和第三组不等式表示计算和通信资源约束。第四组等式表示模型块的数量必须是整数。
算法设计
3.1 算法准备工作
(1)本地损失函数
我们使用CrossEntropy表示两个相同维度的变量和之间的交叉熵损失,则本地路径的交叉熵损失可以表示为:
整体的混合路径的交叉熵损失可以表示为:
其中,表示客户端接收的全局模型块的数量(即)。此外,我们使用正则化的知识蒸馏项来使本地路径能够吸收混合路径上的知识。以下等式给出了本地路径和混合路径之间的总体知识蒸馏项:
其中,KL表示两个相同维度变量和之间的KL散度,表示使用温度得到的输出。结合损失项和正则化项,客户端的损失函数可以用如下公式表示:
其中,和表示损失函数中的权重超参数。
(2)反馈变量
客户端接收的全局模型块的数量应取决于客户端的数据分布情况以及客户端的通信和计算能力。对于数据分布,我们使用客户端发送的本地模型与模型聚合后的全局模型之间的差异来模拟客户端上的局部数据分布与整体数据分布之间的偏差。本地模型和全局模型之间的差异与局部数据分布和整体数据分布之间的偏差成正比。我们使用来表示客户端的数据分布与整体数据分布之间的偏差。因此,由以下等式量化:
显然,当客户端的数据分布严重偏离整体数据分布时,客户端需要接收更多的全局模型块来学习更多的全局知识。我们使用和分别表示客户端某一个全局伦次所需的通信时间和计算时间。为了便于比较,我们采用了如下归一化的形式:
根据块级正则化的结构,客户端的计算时间和通信时间会随着接受模型块数量的增加而增加。基于上述分析,我们设计了一个反馈变量:
其中,指数项是为了增加时间开销的影响,表示客户端准确率的提升。
3.2 算法详细描述
为了为每一个客户端确定接收的模型块的数量,我们设计了一个基于贪心的模型块选择算法(GMBS)。服务器为了记住之前全局轮次的反馈信息为每个客户端保存一个变量,定义如下:
其中,的每一项表示服务器发送给客户端对应数量全局模型块后保存的反馈信息。我们需要记住之前轮次的反馈信息,以减少由某一轮引起的错误反馈信息的干扰。因此,由以下等式更新:
其中,表示影响实时反馈信息的超参数。由于全局模型块的数量是在训练开始时随机选择的,并且客户端的状态是可变的,因此我们不能总是采用与反馈信息的最高值相对应的作为客户端的接收的模型块的数量。因此,我们为反馈信息添加了一个惩罚项。最后,我们使用以下决策变量来选择全局模型块的数量:
其中,表示服务器过去选择了个全局模型块发送给客户端的频次。我们选择最大的嘴硬的作为客户端接收的全局模型块的数量。算法的详细流程如Algorithm 1所示。
性能评估
4.1 实验设置
我们在两个数据集上进行实验(即CIFAR-10、CIFAR-100),这两个数据集是FL算法中最常用的两个算法。我们采用ResNet18深度神经网络在两个数据集上进行模型训练。
我们将所提出的框架与四个基准方法进行比较:
FedAvg [1] 经典联邦学习算法。 FedProx [2] 在FedAvg的基础上添加了一个表示全局模型和本地模型差异的正则化项来使得本地训练更加稳定。 MOON [3] 将对比学习与FedAvg相结合,通过探测模型表征的相似性来纠正客户端的本地模型训练。 FedMLB [4] 引入多级杂交路径来吸收全局模型的知识以减少non-IID数据的负面影响 。
4.2 实验结果
图2比较了FedBR和其他方法的训练性能。与FedAvg,FedProx,MOON和FedMLB相比,FedBR在达到相同测试精度条件下分别节省了10.4%,35.4%,19.9%和23.4%的时间以及34.6%,47.6%,20.9%和22.3%的带宽。
图3展示了随机决策(FedBR-Ran)分发客户端模型块数量与使用GMBS算法(FedBR)决策分发客户端模型块数量的实验对比图。通过图3可以得到,在所有的测试数据集上,我们提出的决策算法都比随机决策算法有效,从而证明了GMBS算法的有效性。
本研究提出了FedBR框架,该框架将块级正则化和知识蒸馏的思想集成到FedAvg中,以处理资源约束下的边缘计算的数据异构性和系统异构性。我们设计了一种启发式算法(GMBS)来自适应地确定每个客户端接收到的全局模型块的数量并设计实验评估了FedBR的性能。实验结果表明,FedBR在降低资源消耗(如网络带宽和完成时间)的同时,有效地提高了模型的准确性。
[1] B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Artificial intelligence and statistics. PMLR, 2017, pp. 1273–1282.
[2] T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith, “Federated op imization in heterogeneous networks,” Proceedings of Machine Learning and Systems, vol. 2, pp. 429–450, 2020.
[3] Q. Li, B. He, and D. Song, “Model-contrastive federated learning,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2021, pp. 10 713–10 722.
[4] J. Kim, G. Kim, and B. Han, “Multi-level branched regularization for federated learning,” in International Conference on Machine Learning. PMLR, 2022, pp. 11 058–11 073.
[5] B. Luo, W. Xiao, S. Wang, J. Huang, and L. Tassiulas, “Tackling system and statisti al heterogeneity for federated learning with adaptive client sampling,” in IEEE INFOCOM 2022-IEEE Conference on Computer Communications. IEEE, 2022, pp. 1739–1748.
[6] G. Hinton, O. Vinyals, J. Dean et al., “Distilling the knowledge in a neural network,” arXiv preprint arXiv:1503.02531, vol. 2, no. 7, 2015.
[7] M. Luo, F. Chen, D. Hu, Y. Zhang, J. Liang, and J. Feng, “No fear of heterogeneity: Classifier calibration for federated learning with non-iid data,” Advances in Neural Information Processing Systems, vol. 34, pp. 5972–5984, 2021.
END
热文
2.论文分享 | BYOTEE: 用FPGA搭建属于你自己的TEE
推荐