查看原文
其他

教程 | Tensorflow keras 极简神经网络构建与使用

gloomyfish OpenCV学堂 2020-02-04

点击上方蓝字关注我们

微信公众号:OpenCV学堂
星标或者置顶【OpenCV学堂】

Tensorflow keras极简神经网络构建教程

Keras介绍

Keras (κέρας) 在希腊语中意为号角,它来自古希腊和拉丁文学中的一个文学形象。发布于2015年,是一套高级API框架,其默认的backend是tensorflow,但是可以支持CNTK、Theano、MXNet作为backend运行。其特点是语法简单,容易上手,提供了大量的实验数据接口与预训练网络接口,最初是谷歌的一位工程师开发的,非常适合快速开发。Tensorflow虽然是非常流行的深度学习框架,但是tensorflow开发需要了解计算图与自动微分相关技术,对于完全没有任何深度学习基础的人不是一个很好的选择,而keras完全是为零基础的人准备,它简化了tensorflow中计算图、会话等基本概念,通过Sequential与功能API两个组件实现网络搭建,通过简单的添加一些层就可以快速搭建神经网络模型。

Mnist数据集准备

我们以mnist数据集为例,构建一个神经网络实现手写数字的训练与测试,首先我们需要认识一下mnist数据集,mnist数据集有6万张手写图像,1万张测试图像。Keras通过datase来下载与使用mnist数据集,下载与读取的代码如下:

mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) =mnist.load_data()

通过下面的代码可以显示手写数字图像:

print(train_labels[0])
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([   ])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.gray)
    plt.xlabel(str(train_labels[i]))
plt.show()

对数据re-scale到0~1.0之间,对标签进行了one-hot编码,代码如下:

# re-scale to 0~1.0之间
train_images = train_images / 255.0
test_images = test_images / 255.0
train_labels = one_hot(train_labels)
test_labels = one_hot(test_labels)

其中one-hot编码函数如下:

def one_hot(labels):
    onehot_labels = np.zeros(shape=[len(labels), 10])
    for i in range(len(labels)):
        index = labels[i]
        onehot_labels[i][index] = 1
    return onehot_labels

建立模型

构建神经网络

  • 输入层为28x28=784个输入节点

  • 隐藏层120个节点

  • 输出层10个节点

首先需要定义模型:

model = keras.Sequential()

然后按顺序添加模型各层

model.add(keras.layers.Flatten(input_shape=(2828)))
model.add(keras.layers.Dense(units=120, activation=tf.nn.relu))
model.add(keras.layers.Dense(units=10, activation=tf.nn.softmax))

编译模型
模型还需要再进行几项设置才可以开始训练。这些设置会添加到模型的编译步骤:

损失函数
衡量模型在训练期间的准确率。我们希望尽可能缩小该函数,以“引导”模型朝着正确的方向优化。
优化器
根据模型看到的数据及其损失函数更新模型的方式。
指标
用于监控训练和测试步骤。以下示例使用准确率,即图像被正确分类的比例

model.compile(optimizer=tf.train.AdamOptimizer(), 
loss="categorical_crossentropy", metrics=['accuracy'])

训练模型
训练神经网络模型需要执行以下步骤:
将训练数据馈送到模型中,在本示例中为 train_images 和 train_labels 数组。
模型学习将图像与标签相关联。我们要求模型对测试集进行预测,在本示例中为 test_images 数组。我们会验证预测结果是否与 test_labels 数组中的标签一致。
要开始训练,请调用 model.fit 方法,使模型与训练数据“拟合”:

model.fit(x=train_images, y=train_labels, epochs=5)

评估模型
模型在测试集数据上运行:

test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)
print("Test Accuracy %.2f"% test_acc)

使用模型进行预测

# 开始预测
cnt = 0
predictions = model.predict(test_images)
for i in range(len(test_images)):
    target = np.argmax(predictions[i])
    label = np.argmax(test_labels[i])
    if target == label:
        cnt += 1
print("correct prediction of total : %.2f"%(cnt/len(test_images)))

卷积神经网络

mnist数据转换为四维

train_images = np.expand_dims(train_images, axis=3)
test_images = np.expand_dims(test_images, axis=3)

创建模型并构建CNN各层

model = keras.Sequential()
model.add(keras.layers.Conv2D(filters=32, kernel_size=5, strides=(11),
                              padding='same', activation=tf.nn.relu, input_shape=(28281)))
model.add(keras.layers.MaxPool2D(pool_size=(22), strides=(22), padding='valid'))
model.add(keras.layers.Conv2D(filters=64, kernel_size=3, strides=(11),
                              padding='same', activation=tf.nn.relu))
model.add(keras.layers.MaxPool2D(pool_size=(22), strides=(22), padding='valid'))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(units=128, activation=tf.nn.relu))
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(units=10, activation=tf.nn.softmax))

编译与训练模型

# 训练模型
model.compile(optimizer=tf.train.AdamOptimizer(), loss="categorical_crossentropy", metrics=['accuracy'])
model.fit(x=train_images, y=train_labels, epochs=10)

欢迎扫码加入【OpenCV研习社】

- 学习OpenCV+tensorflow开发技术
- 与更多伙伴相互交流、一起学习进步
- 每周一到每周五分享知识点学习(音频+文字+源码)
- 系统化学习知识点,从易到难、由浅入深
- 直接向老师提问、每天答疑辅导


推荐阅读

OpenCV学堂-原创精华文章

《tensorflow零基础入门视频教程》

OpenCV研习社介绍与加入指南

MTCNN实时人脸检测网络详解与代码演示

详解对象检测网络性能评价指标mAP计算

卷积神经网络是如何实现不变性特征提取的

深度学习中常用的图像数据增强方法-纯干货

基于OpenCV与tensorflow实现实时手势识别

tensorflow风格迁移网络训练与使用

使用tensorflow layers相关API快速构建卷积神经网络

基于OpenCV Python实现二维码检测与识别

OpenCV+Tensorflow实现实时人脸识别演示

OpenCV标准霍夫直线检测详解


关注【OpenCV学堂】

长按或者扫码即可关注

参考:

https://www.tensorflow.org/guide/keras

    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存