其他
Keras教程 | 基于迁移学习实现花卉图像分类
星标或者置顶【OpenCV学堂】
干货教程第一时间送达!
Application模块
Keras中的Application模块中有一系列基于ImageNet的预训练好的图像分类模型,这些模型如下:
Xception
VGG16
VGG19
ResNet50
InceptionV3
InceptionResNetV2
MobileNet
DenseNet
NASNet
MobileNetV2
加载与使用这些预训练模型可以实现一些简单的分类,ImageNet支持1000个分类。
加载VGG16模型
model = keras.applications.VGG16(weights='imagenet')
加载ResNet50模型
model= keras.applications.ResNet50(weights='imagenet')
加载VGG16模型但是不包括输出层
input_tensor = keras.Input(shape=(64, 64, 3))
vgg_model = keras.applications.VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)
vgg_model.summary()
显示模型加载以后的结构
图像分类预测
基于预训练ResNet50模型实现对图像分类预测,代码实现如下
def image_classification_demo():
model= keras.applications.ResNet50(weights='imagenet')
# load the image
src = cv.imread("D:/images/space_shuttle.jpg")
img = cv.resize(src, (224, 224))
img = np.expand_dims(img, 0)
proba = model.predict(img)
result = tf.keras.applications.resnet50.decode_predictions(proba)
print(result)
cv.putText(src, result[0][0][1],(50, 50), cv.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), 2, 8)
cv.imshow("input", src)
cv.waitKey(0)
cv.destroyAllWindows()
基于VGG16迁移学习
数据集下载
http://download.tensorflow.org/example_images/flower_photos.tgz
5种花卉类型,接近4000张图像,分为训练集与测试集。
通过Keras的ImageDataGenerator加载数据集,代码如下
num_classes = 5
train_datagen = keras.preprocessing.image.ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
train_generator = train_datagen.flow_from_directory(
'D:/images/train_data/flower_photos',
target_size=(64, 64),
batch_size=4,
shuffle=True,
class_mode='categorical')
print(train_generator.classes)
print(train_generator.class_indices)
test_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
validation_generator = test_datagen.flow_from_directory(
'D:/images/train_data/test_img',
target_size=(64, 64),
batch_size=4,
class_mode='categorical')
构建迁移学习网络
使用VGG6的前面两个权重block,依赖block2_pool的输出,输入张量(64x64x3)
# 构建网络的层
x = layer_dict['block2_pool'].output
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(4096, activation='relu')(x)
x = keras.layers.Dropout(0.25)(x)
x = keras.layers.Dense(num_classes, activation=tf.nn.softmax)(x)
custom_model = keras.models.Model(inputs=vgg_model.input, outputs=x)
custom_model.summary()
# 是否fine-tuning整个网络或者几层
for layer in custom_model.layers[:7]:
layer.trainable = True
训练与保存模型
# 编译与训练
custom_model.compile(loss='categorical_crossentropy',
optimizer=tf.train.AdamOptimizer(0.0001),
metrics=['accuracy'])
custom_model.fit_generator(train_generator, epochs=10, validation_data=validation_generator)
# 保存整个模型
custom_model.save("D:/my_train/my_transfer_vgg.h5")
使用模型测试花卉种类预测
代码实现如下
def flowers_demo():
# 加载与使用
flower_dict = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
new_model = keras.models.load_model("D:/my_train/my_transfer_vgg.h5")
new_model.summary()
root_dir = "D:/images/train_data/test_img/tulips/"
for file in os.listdir(root_dir):
src = cv.imread(os.path.join(root_dir, file))
img = cv.resize(src, (64, 64))
img = np.expand_dims(img, 0)
result = new_model.predict(img)
index = np.argmax(result)
print(result, index)
cv.putText(src, flower_dict[index],(50, 50), cv.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), 2, 8)
cv.imshow("input", src)
cv.waitKey(0)
cv.destroyAllWindows()
欢迎扫码加入【OpenCV研习社】
- 学习OpenCV+tensorflow开发技术
- 与更多伙伴相互交流、一起学习进步
- 每周一到每周五分享知识点学习(音频+文字+源码)
- 系统化学习知识点,从易到难、由浅入深
- 直接向老师提问、每天答疑辅导
推荐阅读
使用tensorflow layers相关API快速构建卷积神经网络
教程 | Tensorflow keras 极简神经网络构建与使用
关注【OpenCV学堂】
长按或者扫码即可关注