【强基固本】神经网络15分钟入门!——反向传播到底是怎么传播的?
The following article is from 括号的城堡 Author 括号先森
“强基固本,行稳致远”,科学研究离不开理论基础,人工智能学科更是需要数学、物理、神经科学等基础学科提供有力支撑,为了紧扣时代脉搏,我们推出“强基固本”专栏,讲解AI领域的基础知识,为你的科研学习提供助力,夯实理论基础,提升原始创新能力,敬请关注。
来源:括号的城堡
上一篇神经网络15分钟快速入门!足够通俗易懂了吧文章中对两层神经网络进行了描述,从中我们知道神经网络的过程就是正向传播得到Loss值,再把Loss值反向传播,并对神经网络的参数进行更新。其中反向传播正是神经网络的要点所在。
本篇将对反向传播的内容进行讲解,力求通俗,毕竟只有15分钟时间~
01
在讲反向传播之前先讲一下链式法则。
假设一个场景,一辆汽车20万元,要收10%的购置税,如果要买2辆,则正向传播的过程可以画成:
汽车单价20万,最终需要支付44万,我现在想知道汽车单价每波动1万,对最终支付价格的影响是多少。参看下图:我们从右向左依次求导,得到的值分别为
①44/44=1
②44/40=1.1
③40/20=2
那么最终价格相对于汽车单价的导数就是①×②×③=2.2
这就是链式法则。我们只需要知道每个节点导数值,然后求乘积就可以了。
链式法则的一种定义*是:
如果某个函数由复合函数表示,则该复合函数的导数可以用构成复合函数的各个函数的导数的乘积表示。
所以我们只需要关注每个节点的导数值即可。
02
下边介绍几种典型节点的反向传播算法。
2.1 加法节点
如下图:该节点可以写作z=x+y
很容易知道,z对x求导等于1,对y求导也等于1,所以在加法节点反向传递时,输入的值会原封不动地流入下一个节点。
比如:
2.2 乘法节点
如下图,该节点可以写作z=x*y
同样很容易知道,z对x求导等于y,对y求导等于x,所以在加法节点反向传递时,输入的值交叉相乘然后流入下一个节点。
比如:
2.3 仿射变换
所谓仿射变换就是这个式子,如果觉得眼生就去看上一篇文章。
画成图的话就是:
这是神经网络里的一个重要形式单元。这个图片看起来虽然复杂,但其实和乘法节点是类似的,我们对X求导,结果就是W1;对W1求导,结果就是X,到这里和乘法节点是一样的;对b1求导,结果为1,原封不动地流入即可。不过需要注意的一点是,这里的相乘是向量之间的乘法。
2.4 ReLU层
激活层我们就以ReLU为例。回忆一下,ReLU层的形式是这样的:
因为当x>0时,y=x,求导为1,也就是原封不动传递。
当x<=0时,y=0,求导为0,也就是传递值为0。
2.5 Softmax-with-Loss
Softmax-with-Loss指的就是Softmax和交叉熵损失的合称。这是我们之前提到的神经网络的最后一个环节。这部分的反向传播推导过程比较复杂,这里直接上结论吧(对推导过程感兴趣的话可以看文末参考文献*的附录A):
其中
从前面的层输入的是(a1, a2, a3),softmax层输出(y1, y2, y3)。此外,教师标签是(t1, t2, t3),Cross Entropy Error层输出损失L。
所谓教师标签,就是表示是否分类正确的标签,比如正确分类应该是第一行的结果时,(t1, t2, t3)就是(1,0,0)。
从上图可以看出,Softmax-with-Loss的反向传播的结果为(y1 − t1, y2 − t2, y3 − t3)。
03
参数的更新对象其实就是W和b,具体的在2.3中对其更新方法进行了描述,简单来说,dW就是输入值乘以X,db就等于输入值。这里用dW和db表示反向传播到W和b节点时的计算结果。
那现在该怎样更新W和b呢?
直接用W=W-dW;b=b-db么?
可以,但不太好。
其一,需要引入正则化惩罚项。这是为了避免最后求出的W过于集中所设置的项,比如[1/3,1/3,1/3]和[1,0,0],这两个结果明显前一个结果更为分散,也是我们更想要的。为了衡量分散度,我们用1/2W^2来表示。对该式求导,结果就是W。设正则化惩罚项的系数值为reg,那么修正后的dW可以写为:
其二,是步子迈的有点大。直接反向传播回来的量值可能会比较大,在寻找最优解的过程中可能会直接将最优解越过去,所以在这里设置一个参数:学习率。这个数通常很小,比如设学习率为0.0001这样。我们将学习率用epsilon表示,那么最终更新后的W和b写为:
至此,一次反向传播的流程就走完了。
04
链式法则是反向传播的基本传递方式,它大大简化了反向传播计算的复杂程度。在本例中可能还不太明显,在有些非常复杂的网络中,它的好处会更加显而易见。
另外反向传播的各个节点的算法也是比较重要的内容,本文介绍了常用的节点的反向传播计算结果,实际应用中可能会有更多的形式。不过不用担心,google一下,你就知道。
参数更新是反向传播的目的,结合例子来看可能会更容易理解。下一篇文章会不使用任何框架,纯手写一个我们之前提到的神经网络,并实现象限分类的问题。