其他
【他山之石】Tensorflow模型保存方式大汇总
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
地址:https://www.zhihu.com/people/kumonoue
GraphDef
SignatureDef
list(meta_graph.signature_def.items())
tf.saved_model
# 保存
model = tf.saved_model.save(
obj, export_dir, signatures=None, options=None
)
# 读取
model= tf.saved_model.load(
export_dir, tags=None, options=None
)
# 推理
infer = model.signatures["serving_default"]
freeze_grap
from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos
该函数将图和权重以常量的形式保存在一张静态图中(pb)。
output_graph_def = convert_variables_to_constants(session, input_graph_def, output_names)
output_graph = 'pb_model/model.pb' # 保存地址
with tf.gfile.GFile(output_graph, 'wb') as f:
f.write(output_graph_def.SerializeToString())
tf.train.Saver()
# 保存断点
saver = tf.train.Saver()
saver.save()
# 加载断点
saver.restore()
.data文件保存了当前参数值 .index文件保存了当前参数名 .meta文件保存了当前图结构 .events文件是给可视化工具tensorboard使用。 .pbtxt文件是以字符串存储的计算图
tf.train.CheckpointManager()
# 设置断点
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
checkpoint, directory="/tmp/model", max_to_keep=5)
# 加载最新的断点
status = checkpoint.restore(manager.latest_checkpoint)
# 保存断点
while True:
# train
manager.save()
查看静态图输入输出节点
saved_model_cli show --dir model/ --all
也可以加载静态图后,打印所有节点,逐个查看:
tensor_name_list = [tensor.name for tensor in tf.compat.v1.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
print(tensor_name,'\n')
summaryWriter = tf.compat.v1.summary.FileWriter('log/', graph)
自定义SignatureDef
signature_def: {
key : "my_classification_signature"
value: {
inputs: {
key : "inputs"
value: {
name: "tf_example:0"
dtype: DT_STRING
tensor_shape: ...
}
}
outputs: {
key : "classes"
value: {
name: "index_to_string:0"
dtype: DT_STRING
tensor_shape: ...
}
}
outputs: {
key : "scores"
value: {
name: "TopKV2:0"
dtype: DT_FLOAT
tensor_shape: ...
}
}
method_name: "tensorflow/serving/classify"
}
}
#保存为pb模型
def export_model(session, m):
#只需要修改这一段,定义输入输出,其他保持默认即可
model_signature = signature_def_utils.build_signature_def(
inputs={"input": utils.build_tensor_info(m.a)},
outputs={
"output": utils.build_tensor_info(m.y)},
method_name=signature_constants.PREDICT_METHOD_NAME)
export_path = "pb_model/1"
if os.path.exists(export_path):
os.system("rm -rf "+ export_path)
print("Export the model to {}".format(export_path))
try:
legacy_init_op = tf.group(
tf.tables_initializer(), name='legacy_init_op')
builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
session, [tag_constants.SERVING],
clear_devices=True,
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
model_signature,
},
legacy_init_op=legacy_init_op)
builder.save()
except Exception as e:
print("Fail to export saved model, exception: {}".format(e))
在tf中打包多个模型和函数(Synchronized)
@tf.function
def full_model(image):
x1 = func_1(image)
x2 = func_2(image)
return [x1,x2]
full_model = full_model.get_concrete_function(tf.TensorSpec((832, 1344,3), tf.float32))
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
print(layer)
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir="./model",
name="model.pb",
as_text=False)
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“他山之石”历史文章
利用Tensorflow构建CNN图像多分类模型及图像参数、数据维度变化情况实例分析
pytorch中optimizer对loss的影响
使用PyTorch 1.6 for Android
神经网络解微分方程实例:三体问题
pytorch 实现双边滤波
编译PyTorch静态库
工业界视频理解解决方案大汇总
动手造轮子-rnn
凭什么相信你,我的CNN模型?关于CNN模型可解释性的思考
c++接口libtorch介绍& vscode+cmake实践
python从零开始构建知识图谱
一文读懂 PyTorch 模型保存与载入
适合PyTorch小白的官网教程:Learning PyTorch With Examples
pytorch量化备忘录
更多他山之石专栏文章,
请点击文章底部“阅读原文”查看
分享、点赞、在看,给个三连击呗!