快速入门 TensorFlow.js
文 / Laurence Moroney
使用 TensorFlow.js,不仅可以在浏览器中运行机器学习模型来执行推理,还可以训练它们。在本教程中,将向您展示一个基本的 “Hello World” 示例,通过该实例开启我们的全新旅程。
让我们从一个最简单的网页开始:
<html>
<head></head>
<body></body>
</html>
完成后,需要做的第一件事是添加对 TensorFlow.js 的引用,以便我们可以在浏览器环境中使用 TensorFlow API。为方便起见,可以从 CDN 上获取 JS 文件:
<html>
<head>
<!-- Load TensorFlow.js -->
<!-- Get latest version at https://github.com/tensorflow/tfjs -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2"> </script>
在写这篇文章时使用的 TensorFlow.js 版本是 0.11.2。如果想获取最新版本,我们可以从 GitHub 查看。
现在我们已经成功加载了 TensorFlow.js,让我们用它做一些有趣的事情吧!
现在有一条公式为 Y = 2X-1 的直线。并提供你一组点,如(-1,-3),(0,-1),(1,1),(2,3),(3,5)和(4,7)。虽然通过公式我们可以得出给定 X 的 Y 值,我们是否可以通过机械学习模型推导出 Y 值呢?
首先,我们可以创建一个简单的神经网络来进行推理。由于只有 1 个输入值和 1 个输出值,因此它可以是单节点。在 JavaScript 中,我们可以创建一个 tf.sequential,并添加图层定义。代码示例如下:
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
为了完成模型定义,我们需要执行编译,并指定损失类型和优化器。我们将选择最基本的损失类型 - meanSquaredError,同时优化器使用标准的
Stochastic Gradient Descent:
model.compile({
loss: 'meanSquaredError',
optimizer: 'sgd'
});
为了训练模型,我们需要定义张量,并指定其形状:
const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
为了训练模型,我们使用 fit 方法。为此,我们传递一组 X 和 Y 值,以及 epochs(循环数据)。请注意,这是异步(async/await)的,因此所有这些代码都需要在异步函数中:
await model.fit(xs, ys, {epochs: 500});
一旦准备就绪,模型就会被训练,我们就可以基于 X 值预测 Y。例如,如果我们想要找出 X = 10的 Y 值并将其写在 Web 页面上的 <div> 中,代码如下所示:
document.getElementById('output_field').innerText =
model.predict(tf.tensor2d([10], [1, 1]));
请注意,输入是包含值 10 的 1x1 的张量。
结果如下所示:
等等,你可能会问 —— 为什么不是 19?它非常接近,但它不是 19!这是因为该算法从未被赋予公式 —— 它只是根据给出的少量数据进行学习。有了更多的相关数据进行训练,ML 模型就会提供更高的准确性。
为了方便起见,完整代码如下所示:
<html>
<head>
<!-- Load TensorFlow.js -->
<!-- Get latest version at https://github.com/tensorflow/tfjs -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2">
</script>
</head>
<body>
<div id="output_field"></div>
</body>
<script>
async function learnLinear(){
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({
loss: 'meanSquaredError',
optimizer: 'sgd'
});
const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
await model.fit(xs, ys, {epochs: 500});
document.getElementById('output_field').innerText =
model.predict(tf.tensor2d([10], [1, 1]));
}
learnLinear();
</script>
<html>
这就是在浏览器环境中使用 TensorFlow.js 创建一个非常简单的机器学习模型所需要的一切。从这里开始,我们将进入崭新的世界!
更多 AI 相关阅读:
· 使用 Cloud TPU 在 30 分钟内训练并部署实时移动物体探测器
Be a Tensorflower