基于网络的梯度下降法
今天要跟大家分享的主题是基于网络的梯度下降法(Network Gradient Descent;NGD)。内容主要包括网络梯度下降法的应用背景、具体算法介绍以及理论性质。
背景介绍
首先对经典的梯度下降法做一个回顾,假设共有个观测值,我们的目标是寻找光滑损失函数的极小值点。则经典梯度下降的更新公式为:
其中代表算法在第t步得到的解,也称为模型权重(weights),代表关于的一阶导数,为学习率(learning rate)。传统情况下,整个数据集样本量较小,且均存储在同一台机器上。此时经典梯度下降法易于实现、计算复杂度低,是现实中(例如,深度学习)所使用优化算法的最基本框架。
然而,在当今大数据分析的不断发展下,许多有用的数据分布在大量不同设备中。用户期待汇总这些数据信息训练一个全局的模型,获得更好的效果。但是,这些数据可能是非常敏感的,人们不能直接将它们上传到同一个中心服务器(Server)从而使用经典的梯度下降方法训练模型。因此,如何在满足数据隐私和安全的前提下,设计一个新的算法,让用户的数据能够被高效地共同使用,就是一个非常重要的问题。为解决该问题,一个最自然的应对方法就是向中央服务器传输梯度而非原始数据。具体步骤如下:
Step1. 每个设备利用本机的局部数据计算梯度,并将梯度信息传给中央服务器。
Step2. 中央服务器收集各设备的梯度进行汇总平均,并采用梯度下降法更新模型参数后,将更新参数返回给各设备。
Step3. 重复Step1和Step2,直到达到停止条件。
由于中央服务器能够收集到全样本梯度,因此上述方式等价于做全局梯度下降。但这样的方法在每一步更新迭代的过程中都需要各设备与中央服务器进行信息传输,因此过于依赖中央服务器,从而存在以下问题:(1)中央服务器一旦被攻击,则攻击者能够与所有设备交流,窃取信息(2)中央服务器一旦无法工作,则整个算法崩溃(3)所有设备同时向中央服务器传输信息,因此对带宽(bandwidth)的要求很高。为了进一步解决过度依赖中央服务器的问题,可以考虑通过设备间(而不再需要中央机器)的信息传输来实现梯度下降,这就是基于网络的梯度下降法。图1给出一个具体例子:4 台设备形成一个环状的网络。每台设备把模型权重传给下一台设备,并接受上一台设备的模型权重。整个过程不依靠和 Server 之间传递信息,完全由设备之间传递信息进行模型更新。
模型和算法
在介绍基于网络的梯度下降法前,首先给出一些符号设定:假设共有个独立同分布的观测值,对应损失函数为,感兴趣的参数为,全局极小值点为。整个数据集分布在台设备上。设备之间形成一个网络结构,每个设备代表网络的节点(node)。将网络对应的邻接矩阵记为,如果设备能够接收到设备传输的信息,则,否则。将进行行正则化,则得到权重矩阵 ,其中。进一步,定义整个数据集的指标集。其中 代表分布在第台设备上样本的指标集,这里简单起见,假设第台设备上的样本数满足。因此,全局的损失函数可被重写为,
在上述模型设定下,我们给出第t步迭代下,第m台设备网络梯度下降法的更新方式:
即第
理论性质分析
给出网络梯度下降法的具体算法后,我们关心算法的理论性质,包括:数值收敛性质以及统计学性质。这里算法的数值收敛性质指,算法何时存在稳定解且算法能够收敛至该稳定解。而算法的统计学性质指,当算法数值收敛时,其稳定解和全局最优解之间的差异受到什么因素影响,稳定解在什么条件下和全局极小值的渐近统计学性质相同。为阐释网络梯度下降法理论分析的核心思想,我们从最简单的线性回归模型出发。
a) 线性回归模型和最小二乘损失
考虑线性回归模型
此时,第
其中,。将
其中,,。这里
我们首先研究算法的数值收敛性质。为此不妨先假设稳定解存在,记为
称
从定理1可知,网络梯度下降法的数值收敛性质仅受到学习率
定理1保证了网络梯度下降法在
从定理2中,可知NGD估计量与OLS估计量的距离受到3个因素影响,分别为学习率
b) 一般损失函数情形
定理2的结论能够被进一步推广至一般损失函数(e.g. 负对数似然函数)。对于一般损失函数,其稳定解不再具有显示形式,因此直接研究算法在第t步迭代得到的估计量与全局最小值之间的距离。具体而言,记,和
由定理3可知,在一般损失函数下,我们能够得到与定理2类似的结论,即NGD估计量与全局最小值间的距离被控制。为使得NGD估计量与全局最小值有相同率的渐近统计学效率,我们需要(1)
相关文献研究
现在有丰富的文献研究基于网络的梯度下降法,根据对网络结构的假设、数据分布模式的假设以及理论分析类型,文献可以大致按照如下表格进行划分,感兴趣的读者可以阅读相关文章。
参考文献
[1] Yuan, K., Ling, Q., and Yin, W. (2016), “On the convergence of decentralized gradient descent,” SIAM Journal on Optimization, 26, 1835–1854.
[2] Tang, H., Lian, X., Yan, M., Zhang, C., and Liu, J. (2018), “Decentralized training over decentralized data,” in International Conference on Machine Learning, PMLR, pp. 4848–4856.
[3] Nedic, A., Olshevsky, A., and Shi, W. (2017), “Achieving geometric convergence for distributed optimization over time-varying graphs,” SIAM Journal on Optimization, 27, 2597–2633.
[4] Lian, X., Zhang, C., Zhang, H., Hsieh, C.-J., Zhang, W., and Liu, J. (2017), “Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent,” arXiv preprint arXiv:1705.09056.
[5] Lian, X., Zhang, W., Zhang, C., and Liu, J. (2018), “Asynchronous decentralized parallel stochastic gradient descent,” in International Conference on Machine Learning, PMLR, pp. 3043–3052.
[6] Richards, D. and Rebeschini, P. (2019), “Optimal statistical rates for decentralised non-parametric regression with linear speed-up,” arXiv preprint arXiv:1905.03135.
[7] Richards, D., Rebeschini, P., and Rosasco, L. (2020), “Decentralised learning with random features and distributed gradient descent,” in International Conference on Machine Learning, PMLR, pp. 8105–8115.gression with linear speed-up,” arXiv preprint arXiv:1905.03135.
[8] Vanhaesebrouck, P., Bellet, A., and Tommasi, M. (2017), “Decentralized collaborative learning of personalized models over networks,” in Artificial Intelligence and Statistics, PMLR, pp. 509–517.
[9] Lalitha, A., Kilinc, O. C., Javidi, T., and Koushanfar, F. (2019), “Peer-to-peer federated learning on graphs,” arXiv preprint arXiv:1901.11173.
[10] Lalitha, A., Shekhar, S., Javidi, T., and Koushanfar, F. (2018), “Fully decentralized federated learning,” in Third workshop on Bayesian Deep Learning (NeurIPS).
[11] Blot, M., Picard, D., Cord, M., and Thome, N. (2016), “Gossip training for deep learning,” arXiv preprint arXiv:1611.09726.
[12] Savazzi, S., Nicoli, M., and Rampa, V. (2020), “Federated learning with cooperating devices: A consensus approach for massive IoT networks,” IEEE Internet of Things Journal, 7, 4641–4654.
- END -