Node.js 中的 TensorFlow.js

本指南說明適用於 Node.js 的 TensorFlow.js 套件和 API。

如要瞭解如何在 Node.js 中安裝 TensorFlow.js,請參閱設定教學課程。如需安裝與支援的其他資訊,請參閱 Node.js 版 TensorFlow.js 存放區

TensorFlow CPU

TensorFlow CPU 套件可匯入如下

import * as tf from '@tensorflow/tfjs-node'

從這個套件匯入 TensorFlow.js 時,您會取得由 TensorFlow C 二進位檔加速,並在 CPU 上執行的模組。CPU 上的 TensorFlow 會使用硬體加速來最佳化線性代數運算。

這個套件適用於支援 TensorFlow 的 Linux、Windows 和 macOS 平台。

TensorFlow GPU

TensorFlow GPU 套件可匯入如下

import * as tf from '@tensorflow/tfjs-node-gpu'

和 CPU 套件一樣,這個模組由 TensorFlow C 二進位檔加速。但 GPU 套件會在 GPU 上使用 CUDA 執行張量運算,因此僅適用於 Linux。這個繫結的速度至少比其他繫結選項快一個數量級。

純 JavaScript 版 TensorFlow

另有一個 TensorFlow.js 版本可在 CPU 上執行純 JavaScript。這個版本可匯入如下

import * as tf from '@tensorflow/tfjs'

這個套件與您在瀏覽器中使用的套件相同。在這個套件中,運算會在 CPU 上以原生 JavaScript 執行。這個套件比其他套件小得多,因為它不需要 TensorFlow 二進位檔,但速度也慢得多。

由於這個套件不仰賴 TensorFlow,因此可在更多支援 Node.js 的裝置中使用。它不限於支援 TensorFlow 的 Linux、Windows 和 macOS 平台。

生產環境考量

Node.js 繫結為 TensorFlow.js 提供後端,以同步方式實作運算。這表示,例如,當您呼叫 tf.matMul(a, b) 等運算時,系統會封鎖主要執行緒,直到運算完成為止。

因此,這些繫結非常適合指令碼和離線工作。如果您想在網路伺服器等生產環境應用程式中使用 Node.js 繫結,則應設定工作佇列或設定工作人員執行緒,讓 TensorFlow.js 程式碼不會封鎖主要執行緒。

API

當您使用上述任何選項將套件匯入為 tf 時,所有一般的 TensorFlow.js 符號都會出現在匯入的模組中。

tf.browser

tf.browser.* 命名空間中的 API 無法在 Node.js 中使用,因為這些 API 仰賴瀏覽器專用的 API。如需 tf.browser API 的清單,請參閱「瀏覽器」。

tf.node

這兩個 Node.js 套件也提供命名空間 tf.node,其中包含 Node.js 專用的 API (例如 TensorBoard)。

以下範例說明如何在 Node.js 中將摘要匯出至 TensorBoard

const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [200] }));
model.compile({
  loss: 'meanSquaredError',
  optimizer: 'sgd',
  metrics: ['MAE']
});

// Generate some random fake data for demo purposes.
const xs = tf.randomUniform([10000, 200]);
const ys = tf.randomUniform([10000, 1]);
const valXs = tf.randomUniform([1000, 200]);
const valYs = tf.randomUniform([1000, 1]);

// Start model training process.
async function train() {
  await model.fit(xs, ys, {
    epochs: 100,
    validationData: [valXs, valYs],
    // Add the tensorBoard callback here.
    callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
  });
}
train();