使用Tensorflow Object Detection API实现对象检测
一:预训练模型介绍
Tensorflow Object Detection API自从发布以来,其提供预训练模型也是不断更新发布,功能越来越强大,对常见的物体几乎都可以做到实时准确的检测,对应用场景相对简单的视频分析与对象检测提供了极大的方便与更多的技术方案选择。tensorflow object detection提供的预训练模型都是基于以下三个数据集训练生成,它们是:
COCO数据集
Kitti数据集
Open Images数据集
每个预训练模型都是以tar文件形式存在,其中包括以下几个部分:
图协议graph.pbtxt
检查点(checkpoint)文件(odel.ckpt.data-00000-of-00001, model.ckpt.index, model.ckpt.meta)
冻结图协议包含作为常量的权重数据
一个config的配置文件
基于COCO数据集训练的模型名称、运行速度、mAP指标及输出列表如下:
二:使用模型实现对象检测
这里我们使用ssd_mobilenet模型,基于COCO数据集训练生成的,支持90个分类物体对象检测,首先需要读取模型文件,代码如下
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
file_name = os.path.basename(file.name)
if 'frozen_inference_graph.pb' in file_name:
tar_file.extract(file, os.getcwd())
然后加载模型完成计算图构建
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
最后通过session来执行计算图并输入适当的参数即可
# image_np == [1, None, None, 3]
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
# Actual detection.
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
最终检测效果 - 检测人与书
检测我的苹果电脑与喝水玻璃杯
更多相关阅读
Windows系统如何安装Tensorflow Object Detection API
在ubuntu上配置tensorflow 1.7+CUDA踩过的坑
知不足者好学,
耻下问者自满!
关注【OpenCV学堂】
长按或者扫码二维码即可关注
Tensorflow+OpenCV深度图像
+QQ 573300093