查看原文
其他

基于梯度下降法的——线性回归拟合

柯广 大数据技术派 2022-10-15


阅读本文需要的知识储备:

  • 高等数学

  • 运筹学

  • Python基础


引出梯度下降


对于,线性回归问题,上一篇我们用的是最小二乘法,很多人听到这个,或许会说:天杀的最小二乘法,因为很多人对它太敏感了。是的,从小到大,天天最小二乘法,能不能来点新花样。这里就用数学算法——梯度下降,来解决,寻优问题。


当然了,我们的目标函数还是:

在开始之前,我还是上大家熟知常见的图片。

梯度下山图片(来源:百度图片)


找了好久,我选了这张图片,因为我觉得这张图片很形象:天气骤变,一个人需要快速下山回家,但是他迷路了,不知道怎么回家,他知道自己家位于这座山海拔最低处。环顾四周,怎么样最快下山回家呢。他个子一定(假设1.8m大个子吧),每次迈步子为平时走路最大步长了,哈哈!(假设保持不变),要往哪个方向走才能使得:每迈出一步,自己下降的高度最大呢?只要每步有效下降高度最大,我们完全有理由相信,他会最快下山回家。

所以:他会告诉自己,我每次要找一个最好的下山方向(有点像“贪心”)。

其实,这个图还反映了另外一个问题,对于有多个极值点的情况,不同的初始出发点,梯度下降可能会陷入局部极小值点。就像一句古诗:不识庐山真面目,只缘身在此山中!这时候,就需要多点随机下山解决。当然了,解决线性回归问题的梯度下降是基于误差平方和,只有二次项,不存在多峰问题。


梯度下降的理论基础


我们都现在都知道这个人的任务是什么了:每次要找一个最好的下山方向。数学微分学告诉我们:其实这里的方向就是我们平时所说的:方向导数,它可以衡量函数值沿着某个方向变化的快慢,只要选择了好的方向(导数),就能快速达到(最大/最小值)。

(1)、梯度的定义

这里还是摆一个公式吧,否则看着不符合我的风格。方向导数定义就不扯远了,那是个关于极限的定义。这里给出三元函数梯度定义公式:



显然,让沿着与梯度方向,夹角为0或者180°时函数值增减最快。

其实,每个多元函数在任一点会有一个梯度。函数在某一点沿着梯度方向,函数值是变化最快的。这里就不过多证明了。

(2)、步长的求法

其实,我们可以设定一个指定步长。但是,这个指定步长到底设为多大合适。众所周知,过大会导致越过极小值点;过小在数据量大时会导致迭代次数过多。所以我们需要一套理论可以来科学得计算步长。保证在步长变换过程中,尽管有时可能会走回头路,但总体趋势是向驻点逼近。

这里简单说一下,假设在图中一点沿着梯度方向存在二阶偏导数,就可以泰勒展开到平方项,进而对这个关于步长的函数求导数,导函数零点就是此时最佳步长。详细可以参见运筹学推导。我尽量少写公式,多说明,哈哈。

用到的公式主要是步长lambda公式如下:


说明:下三角f表示梯度,海赛矩阵,X1,X2这里表示各个变量(这里是两个),对于连续函数,偏导数不分先后,所以不要算两遍,后面写程序会用到!这样每走一步,都会重新设置步长,与定步长相比,是不是更加智能了?

下降停止标志:梯度趋于0,或者小于给定的eps。


有了这些理论基础后,编程实现就容易多了,下面就编程实现了。


线性关系呢。最著名的当数最小二乘法了,很多人都知道。


梯度下降的Python实现


这里用的与上一片一样的数据。

(1)、用到的函数:

不同点的梯度函数,海赛矩阵函数,迭代主函数

这里用到的比如点乘函数,在第一篇《基于最小二乘法的——线性回归拟合(一)》里面有我是放在一个脚本里面的,所以这里没有写两次,你们可以把两个脚本放在一起是没有问题的。

程序代码:

1#-----------------梯度下降法----------------
2#返回梯度向量
3def dif(alpha,beta,x,y):
4   dif_alpha = -2*sum(err(alpha,beta,x,y))
5   dif_beta = -2*dot(err(alpha,beta,x,y),x)
6   return(dif_alpha,dif_beta)
7#返回海赛矩阵
8def hesse(x):
9   return([[2*len(x),2*sum(x)],[2*sum(x),2*dot(x,x)]])
10#计算lambda
11def lam(x1,x2):
12   s1 = dot(x1,[x2[0][0],x2[1][0]])
13   s2 = dot(x1,[x2[0][1],x2[1][1]])
14   return(dot(x1,x1)/dot([s1,s2],x1))
15#导入数学、随机数模块
16import math
17import random
18def grad(x,y):
19   #设置最大计算次数
20   n_max = 100
21   k = 0
22   error_ = 0.001
23   alpha,beta = random.random(),random.random()
24   #计算梯度向量
25   d_f = dif(alpha,beta,x,y)
26   while(math.sqrt(dot(d_f,d_f))>error_ and k<n_max):
27      h = hesse(x)
28      lamb = lam(d_f,h)
29      alpha,beta = [alpha-lamb*d_f[0],beta-lamb*d_f[1]]
30      d_f = dif(alpha,beta,x,y)
31      k+=1
32   else:
33      return(alpha,beta,k,math.sqrt(dot(d_f,d_f)))
34alpha,beta,k,error = grad(x,y)
35print('\n*------------梯度下降-----------*')
36print('系数为:',alpha,beta)
37print('梯度下降拟合次数为:',k)
38print('梯度为:',error)
39print('误差为:',error_total(alpha,beta,x,y))
40R_square = r_square(alpha,beta,x,y)
41print('R方:',R_square)
42if(R_square>0.95):
43   print('在0.05置信水平下,该线性拟合不错!')
44else:
45   print('在0.05置信水平下,该线性拟合效果不佳!')
46#可视化
47plt.figure(2)
48plt.scatter(x,y,marker = '*',color = 'b')
49plt.xlabel('x label')
50plt.ylabel('y label')
51plt.title('Linear Fit')
52plt.plot(x,[alpha+beta*x_i for x_i in x],color = 'r')
53plt.show()
54
55print('\n#-------------多个初始点下山---------------#')
56for i in range(10):
57   alpha,beta,k,error = grad(x,y)
58   R_square = r_square(alpha,beta,x,y)
59  print('系数为:',alpha,beta,'\n误差为:',error_total(alpha,beta,x,y),'\nR方:',R_square)
60   if(R_square>0.95):
61      print('在0.05置信水平下,该线性拟合不错!')
62   else:
63      print('在0.05置信水平下,该线性拟合效果不佳!')
64   print('*********************************************')


(2)、结果

*------------梯度下降-----------*
系数为:2.1672851935 2.5282506525292012
梯度下降拟合次数为:5
梯度为:1.2745428915606112e-05
误差为:9.898083702910405
R方:0.9558599578256541
在0.05置信水平下,该线性拟合不错!

拟合图如下


1#-------------多个初始点下山---------------#
2系数为:2.167285891989479 2.528250598680116
3误差为:9.898083702904094
4R方:0.9558599578256822
5在0.05置信水平下,该线性拟合不错!
6*********************************************
7系数为:2.167282336941068 2.5282508727544775
8误差为:9.898083702990858
9R方:0.9558599578252953
10在0.05置信水平下,该线性拟合不错!
11*********************************************
12系数为:2.167285928067579 2.5282505958987773
13误差为:9.898083702903905
14R方:0.9558599578256831
15在0.05置信水平下,该线性拟合不错!
16*********************************************
17系数为:2.1672811054772247 2.528250967694748
18误差为:9.898083703052635
19R方:0.9558599578250199
20在0.05置信水平下,该线性拟合不错!
21*********************************************
22系数为:2.1672836911979947 2.528250768347593
23误差为:9.898083702941747
24R方:0.9558599578255144
25在0.05置信水平下,该线性拟合不错!
26*********************************************
27系数为:2.1672838440861364 2.5282507565614916
28误差为:9.898083702937456
29R方:0.9558599578255335
30在0.05置信水平下,该线性拟合不错!
31*********************************************
32系数为:2.1672853294236947 2.5282506420502253
33误差为:9.898083702908751
34R方:0.9558599578256615
35在0.05置信水平下,该线性拟合不错!
36*********************************************
37系数为:2.1672857750441694 2.5282506076959184
38误差为:9.898083702904778
39R方:0.9558599578256792
40在0.05置信水平下,该线性拟合不错!
41*********************************************
42系数为:2.16728609101821 2.5282505833364226
43误差为:9.89808370290327
44R方:0.9558599578256859
45在0.05置信水平下,该线性拟合不错!
46*********************************************
47系数为:2.1672842715049874 2.528250723609833
48误差为:9.898083702926757
49R方:0.9558599578255812
50在0.05置信水平下,该线性拟合不错!
51*********************************************


当然了,这里多个初始点随机梯度下降不需要,以后对于多元多峰函数这是有必要的


结果分析


1*----------梯度下降----------*
2系数为:2.1672851935 2.5282506525292012
3梯度下降拟合次数为:5
4梯度为:1.2745428915606112e-05
5误差为:9.898083702910405
6R方:0.9558599578256541
7在0.05置信水平下,该线性拟合不错!


可以对比最小二乘法与梯度下降误差,我们猜测肯定是梯度下降误差大一些,因为最小二乘法基于函数极值点求法肯定是全局最优的,梯度下降由于随机原因与步长可能是靠近最优,哈哈!在有多个极值点的情况下可能是局部最优解。


1*----------最小二乘法-------*
2
3系数为:2.6786542252575067 2.538861110659364
4
5误差为:6.8591175428159215
6
7R方:0.9696451619135048
8
9在0.05置信水平下,该线性拟合不错!
10
11*------------梯度下降-----------*
12
13系数为:2.1672851935 2.5282506525292012
14
15梯度下降拟合次数为:5
16
17梯度为:1.2745428915606112e-05
18
19误差为:9.898083702910405
20
21R方:0.9558599578256541
22
23在0.05置信水平下,该线性拟合不错!


可以看出,梯度为:1.2745428915606112e-05,已经接近0了,跟据实际精度会有不同。显然,梯度下降这里不存在局部极值点问题,只能是步长迈过去了,但这个点一定是靠近最优解的,误差非常小。


欢迎留言,觉得不错,记得【点赞】分享哦!点击【阅读原文】访问我的个人博客!



猜你可能喜欢

R语言(绘图入门)

多元线性回归、逐步回归、逻辑回归的总结

K-Means算法、非负矩阵分解(NMF)与图像压缩

Python系列之——好用的Python开发工具


您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存