查看原文
其他

【他山之石】Tensorflow模型保存方式大汇总

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—卡西法

地址:https://www.zhihu.com/people/kumonoue


GraphDef

GraphDef是Tensorflow中序列化的图结构。在tensorflow中,计算图被保存为Protobuf格式(pb)。pb可以只保存图的结构,也可以保存结构加权重。


SignatureDef

定义图结构输入输出的节点名称和属性,一般存储于.index文件中。
查看方法:
list(meta_graph.signature_def.items())


tf.saved_model

将动态图保存成权重(./variables)、计算图(keras_metadata.pb)、权重和计算图(saved_model.pb)三种文件。
# 保存model = tf.saved_model.save( obj, export_dir, signatures=None, options=None)# 读取model= tf.saved_model.load( export_dir, tags=None, options=None)# 推理infer = model.signatures["serving_default"]


freeze_grap

from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos

该函数将图和权重以常量的形式保存在一张静态图中(pb)。

其中的核心代码是:
output_graph_def = convert_variables_to_constants(session, input_graph_def, output_names)output_graph = 'pb_model/model.pb' # 保存地址with tf.gfile.GFile(output_graph, 'wb') as f: f.write(output_graph_def.SerializeToString())
参考:
https://github.com/tensorflow/tensorflow/blob/f5b9c2225584c79539ff6746b3417e8505443a4b/tensorflow/python/tools/freeze_graph.py

tf.train.Saver()

详细可参考:
https://zhuanlan.zhihu.com/p/64099452
# 保存断点saver = tf.train.Saver()saver.save()# 加载断点saver.restore()
  • .data文件保存了当前参数值
  • .index文件保存了当前参数名
  • .meta文件保存了当前图结构
  • .events文件是给可视化工具tensorboard使用。
  • .pbtxt文件是以字符串存储的计算图


tf.train.CheckpointManager()

CheckpointManager是一个管理断点的工具,是Saver更高级的API,类似于tensorflow.keras.callbacks中的Checkpoint类。CheckpointManager可以设置自动存点间隔步数、最大断点数、自动存点间隔时间等参数
其中,最新的断点文件名以字符串形式储存在checkpoint文件中。
可参考:
https://tensorflow.google.cn/api_docs/python/tf/train/CheckpointManager?hl=en
# 设置断点checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)manager = tf.train.CheckpointManager( checkpoint, directory="/tmp/model", max_to_keep=5)# 加载最新的断点status = checkpoint.restore(manager.latest_checkpoint)# 保存断点while True: # train manager.save()


查看静态图输入输出节点

可以使用Tensorflow自带工具saved_model_cli,输入的模型需要使用tf.saved_model.save或者tf.keras.models.Model实例的save属性保存的模型结构。其中需要的文件有.data(模型权重)、.index(模型的SignatureDef)和.pb(MetaGraph)。
saved_model_cli show --dir model/ --all

也可以加载静态图后,打印所有节点,逐个查看:

tensor_name_list = [tensor.name for tensor in tf.compat.v1.get_default_graph().as_graph_def().node]for tensor_name in tensor_name_list: print(tensor_name,'\n')
可以将静态图保存为summary,使用TensorBoard可视化查看:
summaryWriter = tf.compat.v1.summary.FileWriter('log/', graph)


自定义SignatureDef

上文说到SignatureDef是输入输出到静态图节点的映射,一般表示为字典的形式,下面是官方给的分类模型书写范例:
signature_def: { key : "my_classification_signature" value: { inputs: { key : "inputs" value: { name: "tf_example:0" dtype: DT_STRING tensor_shape: ... } } outputs: { key : "classes" value: { name: "index_to_string:0" dtype: DT_STRING tensor_shape: ... } } outputs: { key : "scores" value: { name: "TopKV2:0" dtype: DT_FLOAT tensor_shape: ... } } method_name: "tensorflow/serving/classify" }}
修改静态图的SignatureDef:
#保存为pb模型def export_model(session, m): #只需要修改这一段,定义输入输出,其他保持默认即可 model_signature = signature_def_utils.build_signature_def( inputs={"input": utils.build_tensor_info(m.a)}, outputs={ "output": utils.build_tensor_info(m.y)},
method_name=signature_constants.PREDICT_METHOD_NAME)
export_path = "pb_model/1" if os.path.exists(export_path): os.system("rm -rf "+ export_path) print("Export the model to {}".format(export_path))
try: legacy_init_op = tf.group( tf.tables_initializer(), name='legacy_init_op') builder = saved_model_builder.SavedModelBuilder(export_path) builder.add_meta_graph_and_variables( session, [tag_constants.SERVING], clear_devices=True, signature_def_map={ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: model_signature, }, legacy_init_op=legacy_init_op)
builder.save() except Exception as e: print("Fail to export saved model, exception: {}".format(e))
关于SignatureDef的编写可参考:
https://tensorflow.google.cn/tfx/serving/signature_defs?hl=en


在tf中打包多个模型和函数(Synchronized)

可以用tf.function直接对一些函数和模型的操作进行封装,其中的计算会转换为tf中的图计算。需要注意的是,用这种方法进行封装,执行的时候是Synchronized的。
如果需要实现Async,请使用OpenVINO、TensorRT、OpenGL或CUDA等进行部署。
@tf.functiondef full_model(image): x1 = func_1(image) x2 = func_2(image) return [x1,x2]
full_model = full_model.get_concrete_function(tf.TensorSpec((832, 1344,3), tf.float32))
frozen_func = convert_variables_to_constants_v2(full_model)frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]print("-" * 50)print("Frozen model layers: ")for layer in layers: print(layer)
print("-" * 50)print("Frozen model inputs: ")print(frozen_func.inputs)print("Frozen model outputs: ")print(frozen_func.outputs)
# Save frozen graph from frozen ConcreteFunction to hard drivetf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir="./model", name="model.pb", as_text=False)

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。



“他山之石”历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存