深度学习算法(第18期)----用RNN也能玩分类
上期我们一起学习了RNN是怎么处理变化长度的输入输出的,
深度学习算法(第17期)----RNN如何处理变化长度的输入和输出?
我们知道之前学过CNN在处理分类问题上的强大能力,今天我们看下前几期介绍的RNN是如何玩分类的。
MNIST数据集,我们都已经很熟悉了,是一个手写数字的数据集,之前我们用它来实战CNN分类器和机器学习的方法(在公众号中回复“MNIST”,即可免费下载)。今天我们就用RNN来对MNIST数据集进行一个预测。
这个时候,我们需要将每一张数据图像当成一个28x28的序列信号(图像的大小为28x28pixels)。对于整个网络框架,我们使用一个150个循环神经元外加一个有10个神经元的全连接层(每个类对应一个),最后接一个softmax层。如下:
from tensorflow.contrib.layers import fully_connected
n_steps = 28
n_inputs = 28
n_neurons = 150
n_outputs = 10
learning_rate = 0.001
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
logits = fully_connected(states, n_outputs, activation_fn=None)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=y, logits=logits)
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
init = tf.global_variables_initializer()
接下来,我们加载数据集,并对数据集进行reshape,如下:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))
y_test = mnist.test.labels
现在,我们将对上面的RNN进行training,在执行阶段跟之前的dnn也是非常类似的,如下:
n_epochs = 100
batch_size = 150
with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
for iteration in range(mnist.train.num_examples // batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
X_batch = X_batch.reshape((-1, n_steps, n_inputs))
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
输出的结果如下:
0 Train accuracy: 0.713333 Test accuracy: 0.7299
1 Train accuracy: 0.766667 Test accuracy: 0.7977
...
98 Train accuracy: 0.986667 Test accuracy: 0.9777
99 Train accuracy: 0.986667 Test accuracy: 0.9809
最终得到了98%的准确率,还挺不错的,如果我们调整下超参数或者RNN权重初始化的方式,训练的更久一些,或者加一些正则化的方法,结果应该还会更好。学习了RNN的分类玩法,下一期我们将实战下RNN在时序信号上的预测能力。
今天我们主要从我们熟悉的MNIST数据集出发,来更深层次的学习了下RNN在分类方面的知识,希望有些收获,欢迎留言或进社区共同交流,喜欢的话,就点个“在看”吧,您也可以置顶公众号,第一时间接收最新内容。