查看原文
其他

【强基固本】神经网络结构下理解Logistic Regression &TF框架下构造Logistic实现Mnist分类

“强基固本,行稳致远”,科学研究离不开理论基础,人工智能学科更是需要数学、物理、神经科学等基础学科提供有力支撑,为了紧扣时代脉搏,我们推出“强基固本”专栏,讲解AI领域的基础知识,为你的科研学习提供助力,夯实理论基础,提升原始创新能力,敬请关注。

作者:知乎—猫头嘤

地址:https://www.zhihu.com/people/guo-qi-97-88

因为我出身统计专业,自以为对老伙计Logistic Regression知根知底,但几天的深度学习课程让我对它有了重新的认识。

01

神经网络角度去看Logistic Regression
Logistics Regression 翻译为逻辑斯蒂回归,这个翻译最早出现在统计学领域,他名为回归,却是用来处理分类问题,那是不是说它翻译出错了,应该叫逻辑斯蒂分类器?并不,逻辑回归中的‘回归’一词正是他的本质。
在统计学中解释逻辑回归时,一般用概率去解释他,套上sigmoid函数之后会被映射到0-1之间的数值(概率),在二分类问题中,label只有0/1,回归出来的连续数值,由极大似然估计给他临门一脚!0.7>0.5会被认为是1,0.25<0.5会认为归类到0。
但从神经网络结构的角度来看,Logistic Regression = Single-layer Feed-farward neural network with sigmoid activation。即 单层结构的-前馈神经网络-激活函数为sigmoid,结构入下:
将这个猫头视为灰度值图片(虽然图片是彩色RGB,但是你可以认为他是灰度的),并拉伸为12287个灰度单位的向量,一共有i张图片,那么每组个向量的X标记为X^{i} 用于标识来自不同图片的X,一组有12287个X,也就是灰度值,X下标用于标识同一组不同位置的向量元素
如果推导过Logistic,这些公式一定不会陌生,Logistic用微积分和矩阵推都行,后者会简单一些。
  • (1)式是一个标准的线性回归方程,有权重  和偏置  ,如果你不理解  和  意味着什么,我将用于训练的数据形式全部展开,你会顿然开悟:
 的  为一共有多少个特征输入,在猫猫头图片中,拉长后的向量长度为12287,也就是12287个X作为输入。
此外 
 代表的是输入多少组向量(图片),每一组向量有它对应的  ,后者就是  ,在二分类问题中直接决定、预测样本属于哪一类

  • (2)-(3)式子是对标准线性回归方程做映射后,求其单个图片训练并做预测后,需要一个它的Loss_Function(损失函数),这里使用的是对分类问题用的比较多的交换熵,也可以用MSE 作为模型的损失函数。
  • (4)式子就是将总共m个方程的Loss加和,作为整个模型的Loss,接下来就是求偏导,然后对每一个权重和偏置参数作梯度下降\牛顿\拟牛顿 进行迭代更新,使得整个Loss减小,停止迭代后确定权重  和偏置 

02

Mnist数据集实战,TensorFlow构造逻辑斯蒂
对于多分类的问题,这里引入经典的Mnist数据集,这个时候,如果分类的结果有10个,那么不可能再用0-1进行分类,整个策略大体不变,但是结构有稍稍的改变:
#数据的导入 来自 TF2 自带数据集import tensorflow as tfmnist = tf.keras.datasets.mnist(x_train_image,y_train_label),(x_test_image,y_test_label)= mnist.load_data()

先把数据导入,TF自带Mnist,28x28灰度值图片。

Mnist可视化
但是导入数据后y_train_label,是属于0-9的数字,这种结构模型是没法识别的,需要作独热映射(One-hot-Encoding)
#这种形式无法被线性模型识别y_train_label
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
import numpy as np#数据标签的处理,label需要被手动One_hot_encodingdef encode_one_hot(labels): num = labels.shape[0] res = np.zeros((num,10)) for i in range(num): res[i,labels[i]] = 1 # labels[i]表示0,1,2,3,4,5,6,7,8,9,则对应的列是1。 return res
y_train_label=encode_one_hot(y_train_label)
y_test_label=encode_one_hot(y_test_label)
现在再看看映射后的样子
#成功映射y_train_label[1]
array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
接着把数据标准化了,至于为什么要标准化,因为不标准化会让计算复杂,训练时间就需要更久,一般为了节省时间都标准化
#接着到数据本身的标准化,减少我可怜cpu的负载。x_train_image = x_train_image / 255x_test_image = x_test_image / 255
这边就比较关键了,这个结构是与普通前馈神经网络根本性不同的关键,虽然逻辑斯蒂本身就是单层神经网络 :P
我这里用的是tf.keras.layers.Flatten,功能只是单纯的把向量拉长为28x28=784的向量,并没有产生任何的参数,把数据单纯的递给了下一层,而下一层的激活函数softmax用于分类,决定最后属于哪一类。
softmax的性质是所有输出的值相加是1,用于分类问题是十分方便的,因为只需要将10个神经元中概率最大的那一个作为预测的结果即可。比如预测结果为3的神经元概率为0.6,为最高,其他的都是0.1 0.2,那么我就将3作为我对这张手写图片的预测结果。
#初始化 权重和偏置
def Logistic(): Logistic=tf.keras.Sequential() Logistic.add(tf.keras.layers.Flatten(input_shape=(28,28))) Logistic.add(tf.keras.layers.Dense(10,activation=tf.nn.softmax)) #输出10个神经元,softmax用于分类 return Logistic
model=Logistic()

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
_________________________________________________________________
dense (Dense) (None, 10) 7850
=================================================================
Total params: 7,850
Trainable params: 7,850
Non-trainable params: 0
_________________________________________________________________
损失函数为交换熵,优化器是SGD随机梯度下降,学习率0.1,模型评价方法是accuracy,用于评价模型优劣,如果愿意也可以用ROC曲线面积。
#设置优化器、损失函数model.compile( optimizer = tf.keras.optimizers.SGD(lr = 0.1), #优化器 loss='binary_crossentropy', #损失函数,交叉熵 metrics=['accuracy'] #准确率)
训练并评价模型,这边分了60epochs 一个epoch用1000个数据作训练。准确率大概在90%左右,和逻辑斯蒂分类的准确率差不多(本环节正是用神经网络构造逻辑斯蒂)
model.fit(x_train_image,y_train_label,epochs=60,batch_size=1000)
test_loss, test_acc = model.evaluate(x_test_image, y_test_label)print('Test Acc:',test_acc)
313/313 [==============================] - 0s 955us/step - loss: 0.0619 - accuracy: 0.9007
Test Acc: 0.9006999731063843
有哪些地方符号不对希望指正,相互进步 。
谢谢阅读~

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“强基固本”历史文章


更多强基固本专栏文章,

请点击文章底部“阅读原文”查看


分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存