查看原文
其他

【他山之石】模型转换:由Pytorch到TFlite

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—澜渊

地址:https://www.zhihu.com/people/ling-huo-de-pang-zi-ya-xin


01

前言
目前,越来越多的开源代码由Pytorch写成,在模型定义、训练和可读性上的优势都远超Tensorflow。然而在面向移动端部署的时候,某些项目仍旧需要使用TFlite。这就引发了一个矛盾:新的算法效果很好,但我们却无法直接使用Pytorch来部署,必须要转成Tflite。
那么,我们就有两个选择:
方法1:根据Pytorch的代码,使用Tensorflow重写,得到TFlite;
方法2:在Pytorch上完成训练并保存模型后,利用模型转换工具ONNX,得到TFlite。
不用说,二者的难度不是一个等级的。对于简单一点的模型,方法1还勉强可以接受,而对于目标检测、实例分割等算法,没有个把月的时间,几乎是没办法完成代码转换的。即便完成,能否在Tensorflow上训练出和Pytorch相同的效果,也很难说,毕竟二者反向传播的方式都不同,这无疑对问题排查带来了极大的难度。
因此,这篇文章主要分享的是方法2,即通过ONNX来进行Pytorch到TFlite的模型转换,也就是:Pytorch—>ONNX—>Tensorflow—>TFlite

02

ONNX简介
ONNX(Open Neural Network Exchange)是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch、MXNet)可以采用相同格式存储模型数据并交互。
目前官方支持加载ONNX模型并进行推理的深度学习框架有:Caffe2, PyTorch, MXNet,ML.NET,TensorRT 和 Microsoft CNTK,并且 TensorFlow 也非官方地支持ONNX。
https://pytorch.org/docs/stable/onnx.html

https://zhuanlan.zhihu.com/p/51387600

https://zhuanlan.zhihu.com/p/41255090


03

代码实现
Step0:环境配置(非常重要!!!)
torch==1.5.1
torchvision==0.6.1
tensorflow==tf_nightly-2.4.0.dev20200811
onnx==1.7.0
onnxruntime==1.7.0
onnx-tf==1.7.0
tensorflow-addons==0.11.2
Step1:由Pytorch得到ONNX
这里给出一个Pytorch的mobilenet_v2的模型转ONNX的例子,并且验证模型的输出是否相同。
import os.path as ospimport numpy as npimport onnximport onnxruntime as ortimport torchimport torchvision# torch --> onnx
test_arr = np.random.randn(10, 3, 224, 224).astype(np.float32)dummy_input = torch.tensor(test_arr)model = torchvision.models.mobilenet_v2(pretrained=True).eval()torch_output = model(torch.from_numpy(test_arr)) input_names = ["input"]output_names = ["output"]torch.onnx.export(model, dummy_input, "mobilenet_v2.onnx", verbose=False, input_names=input_names, output_names=output_names)
model = onnx.load("mobilenet_v2.onnx")ort_session = ort.InferenceSession('mobilenet_v2.onnx')onnx_outputs = ort_session.run(None, {'input': test_arr})print('Export ONNX!')
Step2:由ONNX转Tensorflow,得到.pb文件
from onnx_tf.backend import prepareimport onnx
TF_PATH = "tf_model" # where the representation of tensorflow model will be storedONNX_PATH = "mobilenet_v2.onnx" # path to my existing ONNX modelonnx_model = onnx.load(ONNX_PATH) # load onnx modeltf_rep = prepare(onnx_model) # creating TensorflowRep objecttf_rep.export_graph(TF_PATH)
Step3:由.pb得到TFlite
import tensorflow as tf
TF_PATH = "tf_model" TFLITE_PATH = "mobilenet_v2.tflite"converter = tf.lite.TFLiteConverter.from_saved_model(TF_PATH)converter.optimizations = [tf.lite.Optimize.DEFAULT]tf_lite_model = converter.convert()with open(TFLITE_PATH, 'wb') as f: f.write(tf_lite_model)
目前已完成tflite的推理,有空的时候会补充Pytorch模型、ONNX模型和tflite模型的推理结果,以及模型转换带来的误差。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“他山之石”历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存