查看原文
其他

无需转换!直接在 Node.js 中执行 TensorFlow SavedModel

Google TensorFlow 2021-08-05

文 / Kangyi Zhang,Sandeep Gupta 和 Brijesh Krishnaswami

TensorFlow.js 是一个开源代码库,开发者可以通过 JavaScript 语言定义、训练和运行机器学习模型。这让大多数的 JavaScript 开发者也能参与构建和部署机器学习模型,并由此产生了很多新的机器学习用例。如 TensorFlow.js 可以在所有主流浏览器中运行,服务端有 Node.js,还有最近 微信小程序插件React Native 开始实现在混合移动应用中的机器学习相关操作,开发者无需离开 JS 生态。现在,我们很高兴为 Node.js 开发者提供一种新方法,可以无需进行模型转换,轻松高效地部署预训练的 TensorFlow SavedModel


TensorFlow.js 的主要优势之一是 JavaScript 开发者可以轻松地部署预训练的 TensorFlow 模型进行推理。TensorFlow.js 提供了转换工具 tfjs-converter ,可将 TensorFlow SavedModel、TFHub 模型或 Keras 模型转换为 JavaScript 兼容格式。但是,转换工具需要 JavaScript 开发者安装 TensorFlow 的 Python 工具包并学习如何使用它。此外,转换工具不支持全部的 TensorFlow 算子(支持的算子参见此文),因此,如果模型包含不支持的算子,则无法使用此工具。



在 Node.js 中执行原生模型

我们很高兴宣布现在可以在 Node.js 中执行原生 TensorFlow SavedModel。现在,您可以把预训练的 TensorFlow 模型存为 SavedModel 格式,并通过 @tensorflow/tfjs-nodetfjs-node-gpu 包将模型加载到 Node.js 进行推理,且无需使用转换工具 tfjs-converter。


TensorFlow SavedModel 通常含有一个或几个命名函数,称为 SignatureDef。预训练的TensorFlow SavedModel 可以通过一行代码在 JavaScript 中加载模型的 SignatureDef,随后该模型便可用于推理。

const model = await tf.node.loadSavedModel(path, [tag], signatureKey);
const output = model.predict(input);


也可以将多个输入以数组或图的形式提供给模型:

const model1 = await tf.node.loadSavedModel(path1, [tag], signatureKey);
const outputArray = model1.predict([inputTensor1, inputTensor2]);


const model2 = await tf.node.loadSavedModel(path2, [tag], signatureKey);
const outputMap = model2.predict({input1: inputTensor1, input2:inputTensor2});


如需查看 TensorFlow SavedModel 的详细信息,查找模型标签和签名信息(又称为 MetaGraph),可以通过一个 JavaScript helper API 对其进行解析,类似于 TensorFlow SavedModel 客户端工具

const modelInfo = await tf.node.getMetaGraphsFromSavedModel(path);


此项新功能可在 1.3.2 或更高版本的 @tensorflow/tfjs-node 包中使用,同时支持 CPU 和 GPU。它支持在 TensorFlow Python 1.x 和 2.0 版本中训练和导出的 TensorFlow SavedModel。由此带来的好处除了无需进行任何转换,原生执行 TensorFlow SavedModel 意味着您可以在模型中使用 TensorFlow.js 尚未支持的算子。这要通过将 SavedModel 作为 TensorFlow 会话加载到 C++ 中进行绑定予以实现。


除了可用性上的优点,性能上的表现同样有亮点。在下图的性能基准测试中(使用 MobileNetV2 模型,横轴为推理用时),可以看到直接在 Node.js 中执行 SavedModel,CPU 和 GPU 的推理用时均有所降低。


您可以到 @tensorflow/tfjs-examples 仓库查看我们的示例 。欢迎加入我们的 讨论组 并分享您的反馈!



如果您想详细了解 本文提及 的相关内容,请参阅以下文档。这些文档深入探讨了这篇文章中提及的许多主题:

  • TensorFlow.js (阅读原文直接跳转)
    http://tensorflow.google.cn/js

  • 微信小程序插件
    https://github.com/tensorflow/tfjs-wechat

  • React Native 
    https://github.com/tensorflow/tfjs/tree/master/tfjs-react-native

  • SavedModel
    https://tensorflow.google.cn/guide/saved_model

  • tfjs-converter
    https://github.com/tensorflow/tfjs/tree/master/tfjs-converter

  • 此文
    https://js.tensorflow.org/api/latest/#Operations

  • @tensorflow/tfjs-node 
    https://www.npmjs.com/package/@tensorflow/tfjs-node

  • tfjs-node-gpu
    https://www.npmjs.com/package/@tensorflow/tfjs-node-gpu

  • SignatureDef
    https://tensorflow.google.cn/guide/saved_model#identifying_a_signature_to_export

  • 客户端工具
    https://tensorflow.google.cn/guide/saved_model#show_command

  • 示例
    https://github.com/tensorflow/tfjs-examples/tree/master/firebase-object-detection-node

  • 讨论组
    https://groups.google.com/a/tensorflow.org/forum/#!forum/tfjs



推荐阅读:



    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存