tensorflow中实现神经网络训练手写数字数据集mnist
tensorflow中实现神经网络训练手写数字数据集mnist
一:网络结构
基于tensorflow实现一个简单的三层神经网络,并使用它训练mnist数据集,神经网络三层分别为:
输入层:
像素数据输入28x28=784 个输入节点
隐藏层:
30个神经元节点
输出层:
10个神经元节点,对应 0 ~ 9 十个数字
图示结构如下:
网络结构的代码实现:
hidden_nodes = 30
x = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y = tf.placeholder(shape=[None, 10], dtype=tf.float32)
w1 = tf.Variable(tf.truncated_normal(shape=[784, hidden_nodes]), dtype=tf.float32)
b1 = tf.Variable(tf.truncated_normal(shape=[1, hidden_nodes]), dtype=tf.float32)
w2 = tf.Variable(tf.truncated_normal(shape=[hidden_nodes, 10]), dtype=tf.float32)
b2 = tf.Variable(tf.truncated_normal(shape=[1, 10]), dtype=tf.float32)
# layer hidden
nn_1 = tf.add(tf.matmul(x, w1), b1)
h1 = tf.nn.sigmoid(nn_1)
# layer output
nn_2 = tf.add(tf.matmul(h1, w2), b2)
out = tf.nn.sigmoid(nn_2)
# loss function
error = tf.square(tf.subtract(y, out))
loss = tf.reduce_sum(error)
# back prop
step = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
init = tf.global_variables_initializer()
二:数据读取与训练
读取mnist数据集
from tensorflow.examples.tutorials.mnist import inputdata
mnist = inputdata.readdatasets("MNISTdata/", onehot=True)
如果不行,就下载下来,放到本地即可
执行训练的代码如下
# accurate model
acc_mat = tf.equal(tf.argmax(out, 1), tf.argmax(y, 1))
acc = tf.reduce_sum(tf.cast(acc_mat, tf.float32))
with tf.Session() as sess:
sess.run(init)
for i in range(20000):
batch_xs, batch_ys = mnist.train.next_batch(10)
sess.run(step, feed_dict={x: batch_xs, y: batch_ys})
if i % 1000 == 0:
x_input = mnist.test.images[:1000]
y_input = mnist.test.labels[:1000]
curr_acc = sess.run(acc, feed_dict={x: x_input, y: y_input})
print("current acc : ", curr_acc)
训练结果:
测试集上对1000张手写数字图像测试正确识别921张,准确率高达92.1%。说明传统的人工神经网络表现还是不错的,这个还是在没有优化的情况下,通过修改批量数大小,修改学习率,添加隐藏层节点数与dropout正则化,可以更进一步提高识别率。
上次送书活动,感谢大家踊跃发言,留言,然图书只有三本,留言前三名
- 门德尔松
- 王健行
- 水亦心
截图为证:
请在微信公众号上,发送【本人微信号】,有效期至2018-07-14日24:00截至。过期作废!其它人可以到【京东】购买本人图书,本人一定做好答疑服务,再次感谢大家的支持与赞扬!
知不足者好学
耻下问者自满
关注【OpenCV学堂】
长按或者扫码二维码即可关注
更多相关阅读