TensorFlow.js Layers API 給 Keras 使用者的指南

TensorFlow.js 的 Layers API 以 Keras 為模型,並且考量到 JavaScript 和 Python 之間的差異,我們盡力使 Layers API 與 Keras 盡可能相似。這讓有使用 Python 開發 Keras 模型經驗的使用者,可以更輕鬆地遷移到 JavaScript 中的 TensorFlow.js Layers。例如,以下 Keras 程式碼可以轉換為 JavaScript

# Python:
import keras
import numpy as np

# Build and compile model.
model = keras.Sequential()
model.add(keras.layers.Dense(units=1, input_shape=[1]))
model.compile(optimizer='sgd', loss='mean_squared_error')

# Generate some synthetic data for training.
xs = np.array([[1], [2], [3], [4]])
ys = np.array([[1], [3], [5], [7]])

# Train model with fit().
model.fit(xs, ys, epochs=1000)

# Run inference with predict().
print(model.predict(np.array([[5]])))
// JavaScript:
import * as tf from '@tensorflow/tfjs';

// Build and compile model.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

// Generate some synthetic data for training.
const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]);

// Train model with fit().
await model.fit(xs, ys, {epochs: 1000});

// Run inference with predict().
model.predict(tf.tensor2d([[5]], [1, 1])).print();

然而,我們想在此文件中指出並說明一些差異。一旦您了解這些差異及其背後的理由,您的 Python 到 JavaScript 遷移(或反向遷移)應該會是相對順暢的體驗。

建構函式採用 JavaScript 物件作為組態

比較上方範例中的以下 Python 和 JavaScript 行:它們都建立一個 Dense 層。

# Python:
keras.layers.Dense(units=1, inputShape=[1])
// JavaScript:
tf.layers.dense({units: 1, inputShape: [1]});

JavaScript 函式沒有與 Python 函式中關鍵字引數等效的功能。我們希望避免在 JavaScript 中將建構函式選項實作為位置引數,這對於具有大量關鍵字引數的建構函式(例如,LSTM)來說,尤其難以記住和使用。這就是我們使用 JavaScript 組態物件的原因。此類物件提供與 Python 關鍵字引數相同的位置不變性和靈活性。

Model 類別的某些方法,例如 Model.compile(),也採用 JavaScript 組態物件作為輸入。但是,請記住 Model.fit()Model.evaluate()Model.predict() 略有不同。由於這些方法採用必要的 x (特徵) 和 y (標籤或目標) 資料作為輸入;xy 是位置引數,與隨後的組態物件分開,後者扮演關鍵字引數的角色。例如

// JavaScript:
await model.fit(xs, ys, {epochs: 1000});

Model.fit() 是非同步的

Model.fit() 是使用者在 TensorFlow.js 中執行模型訓練的主要方法。此方法通常會長時間執行,持續數秒或數分鐘。因此,我們利用 JavaScript 語言的 async 功能,以便此函式可以以不會在瀏覽器中執行時封鎖主 UI 執行緒的方式使用。這與 JavaScript 中其他可能長時間執行的函式類似,例如非同步 async fetch。請注意,async 是 Python 中不存在的建構。雖然 Keras 中的 fit() 方法會傳回 History 物件,但 JavaScript 中 fit() 方法的對應方法會傳回 History 的 Promise,可以 await(如上方範例所示)或與 then() 方法一起使用。

TensorFlow.js 沒有 NumPy

Python Keras 使用者經常使用 NumPy 執行基本數值和陣列運算,例如在上方範例中產生 2D 張量。

# Python:
xs = np.array([[1], [2], [3], [4]])

在 TensorFlow.js 中,這類基本數值運算由套件本身完成。例如

// JavaScript:
const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);

tf.* 命名空間也提供許多其他用於陣列和線性代數運算的函式,例如矩陣乘法。如需更多資訊,請參閱 TensorFlow.js Core 文件

使用工廠方法,而非建構函式

Python 中的這行程式碼(來自上方範例)是建構函式呼叫

# Python:
model = keras.Sequential()

如果嚴格翻譯成 JavaScript,等效的建構函式呼叫看起來會像這樣

// JavaScript:
const model = new tf.Sequential();  // !!! DON'T DO THIS !!!

然而,我們決定不使用「new」建構函式,因為 1) 「new」關鍵字會使程式碼更加臃腫,而且 2) 「new」建構函式被視為 JavaScript 的「不良部分」:一個潛在的陷阱,正如 JavaScript: the Good Parts 中所論證的那樣。若要在 TensorFlow.js 中建立模型和層,您可以呼叫工廠方法,這些方法具有 lowerCamelCase 名稱,例如

// JavaScript:
const model = tf.sequential();

const layer = tf.layers.batchNormalization({axis: 1});

選項字串值為 lowerCamelCase,而非 snake_case

在 JavaScript 中,符號名稱更常使用 camel case(例如,請參閱 Google JavaScript Style Guide),相較之下,Python 中常見 snake case(例如,在 Keras 中)。因此,我們決定對選項的字串值使用 lowerCamelCase,包括以下選項

  • DataFormat,例如,channelsFirst 而非 channels_first
  • Initializer,例如,glorotNormal 而非 glorot_normal
  • Loss 和 metrics,例如,meanSquaredError 而非 mean_squared_errorcategoricalCrossentropy 而非 categorical_crossentropy

例如,如上方範例所示

// JavaScript:
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

關於模型序列化和還原序列化,請放心。TensorFlow.js 的內部機制確保 JSON 物件中的 snake case 能夠正確處理,例如,從 Python Keras 載入預先訓練模型時。

使用 apply() 執行 Layer 物件,而非將其作為函式呼叫

在 Keras 中,Layer 物件已定義 __call__ 方法。因此,使用者可以透過將物件作為函式呼叫來調用層的邏輯,例如,

# Python:
my_input = keras.Input(shape=[2, 4])
flatten = keras.layers.Flatten()

print(flatten(my_input).shape)

此 Python 語法糖在 TensorFlow.js 中實作為 apply() 方法

// JavaScript:
const myInput = tf.input({shape: [2, 4]});
const flatten = tf.layers.flatten();

console.log(flatten.apply(myInput).shape);

Layer.apply() 支援在具體張量上進行命令式 (Eager) 評估

目前,在 Keras 中,call 方法只能在 (Python) TensorFlow 的 tf.Tensor 物件上運作(假設為 TensorFlow 後端),這些物件是符號式且不包含實際數值。這就是前一節範例中顯示的內容。但是,在 TensorFlow.js 中,層的 apply() 方法可以在符號式和命令式模式下運作。如果使用 SymbolicTensor(tf.Tensor 的近似類比)調用 apply(),則傳回值將為 SymbolicTensor。這通常發生在模型建置期間。但是,如果使用實際的具體 Tensor 值調用 apply(),它將傳回一個具體的 Tensor。例如

// JavaScript:
const flatten = tf.layers.flatten();

flatten.apply(tf.ones([2, 3, 4])).print();

此功能讓人聯想到 (Python) TensorFlow 的 Eager Execution。除了為組合動態神經網路開啟大門之外,它還在模型開發期間提供了更高的互動性和可除錯性。

Optimizer 位於 train.* 下,而非 optimizers.* 下

在 Keras 中,Optimizer 物件的建構函式位於 keras.optimizers.* 命名空間下。在 TensorFlow.js Layers 中,Optimizer 的工廠方法位於 tf.train.* 命名空間下。例如

# Python:
my_sgd = keras.optimizers.sgd(lr=0.2)
// JavaScript:
const mySGD = tf.train.sgd({lr: 0.2});

loadLayersModel() 從 URL 載入,而非 HDF5 檔案

在 Keras 中,模型通常 儲存 為 HDF5 (.h5) 檔案,稍後可以使用 keras.models.load_model() 方法載入。該方法採用 .h5 檔案的路徑。TensorFlow.js 中 load_model() 的對應方法是 tf.loadLayersModel()。由於 HDF5 不是瀏覽器友善的檔案格式,因此 tf.loadLayersModel() 採用 TensorFlow.js 特定的格式。tf.loadLayersModel() 採用 model.json 檔案作為其輸入引數。model.json 可以使用 tensorflowjs pip 套件從 Keras HDF5 檔案轉換而來。

// JavaScript:
const model = await tf.loadLayersModel('https://foo.bar/model.json');

另請注意,tf.loadLayersModel() 會傳回 Promisetf.Model

一般而言,在 TensorFlow.js 中儲存和載入 tf.Model 是分別使用 tf.Model.savetf.loadLayersModel 方法完成的。我們設計這些 API 的目的是使其與 Keras 的 save 和 load_model API 相似。但是,瀏覽器環境與 Keras 等主要深度學習框架運作的後端環境截然不同,尤其是在持久保存和傳輸資料的路徑陣列中。因此,TensorFlow.js 和 Keras 中的 save/load API 之間存在一些有趣的差異。如需更多詳細資訊,請參閱我們的儲存和載入 tf.Model 教學課程。

使用 fitDataset() 使用 tf.data.Dataset 物件訓練模型

在 Python TensorFlow 的 tf.keras 中,可以使用 Dataset 物件訓練模型。模型的 fit() 方法直接接受此類物件。TensorFlow.js 模型也可以使用 Dataset 物件的 JavaScript 等效項進行訓練(請參閱 TensorFlow.js 中 tf.data API 的文件)。但是,與 Python 不同的是,基於 Dataset 的訓練是透過專用方法 fitDataset 完成的。fit() 方法僅適用於基於張量的模型訓練。

Layer 和 Model 物件的記憶體管理

TensorFlow.js 在瀏覽器中的 WebGL 上執行,其中 Layer 和 Model 物件的權重由 WebGL 紋理支援。但是,WebGL 沒有內建的垃圾收集支援。Layer 和 Model 物件在其推論和訓練呼叫期間,會在內部為使用者管理張量記憶體。但它們也允許使用者處置它們,以釋放它們佔用的 WebGL 記憶體。這在單一頁面載入中建立和釋放許多模型實例的情況下非常有用。若要處置 Layer 或 Model 物件,請使用 dispose() 方法。