將 TensorFlow 模型匯入 TensorFlow.js

以 TensorFlow GraphDef 為基礎的模型 (通常透過 Python API 建立) 可以儲存為下列其中一種格式

  1. TensorFlow SavedModel
  2. 凍結模型
  3. TensorFlow Hub 模組

以上所有格式都可以透過 TensorFlow.js 轉換器轉換成可以直接載入 TensorFlow.js 以進行推論的格式。

(注意:TensorFlow 已淘汰工作階段套件組合格式。請將模型移轉至 SavedModel 格式。)

需求條件

轉換程序需要 Python 環境;您可能會想使用 pipenvvirtualenv 維護隔離的環境。

如要安裝轉換器,請執行下列指令

 pip install tensorflowjs

將 TensorFlow 模型匯入 TensorFlow.js 是由兩個步驟組成的程序。首先,將現有模型轉換為 TensorFlow.js Web 格式,然後將其載入 TensorFlow.js。

步驟 1:將現有 TensorFlow 模型轉換為 TensorFlow.js Web 格式

執行 pip 套件提供的轉換器指令碼

SavedModel 範例

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

凍結模型範例

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

TensorFlow Hub 模組範例

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
位置引數 說明
input_path 儲存模型目錄、工作階段套件組合目錄、凍結模型檔案或 TensorFlow Hub 模組控制代碼或路徑的完整路徑。
output_path 所有輸出成品的路徑。
選項 說明
--input_format 輸入模型的格式。SavedModel 請使用 tf_saved_model,凍結模型請使用 tf_frozen_model,工作階段套件組合請使用 tf_session_bundle,TensorFlow Hub 模組請使用 tf_hub,Keras HDF5 請使用 keras。
--output_node_names 輸出節點的名稱,以逗號分隔。
--saved_model_tags 僅適用於 SavedModel 轉換。要載入的 MetaGraphDef 標記,以逗號分隔的格式。預設值為 serve
--signature_name 僅適用於 TensorFlow Hub 模組轉換,要載入的簽名。預設值為 default。請參閱 https://tensorflow.dev.org.tw/hub/common_signatures/

使用下列指令取得詳細的說明訊息

tensorflowjs_converter --help

轉換器產生的檔案

上述轉換指令碼會產生兩種檔案

  • model.json:資料流程圖和權重資訊清單
  • group1-shard\*of\*:二進位權重檔案的集合

例如,以下是轉換 MobileNet v2 的輸出

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

步驟 2:在瀏覽器中載入和執行

  1. 安裝 tfjs-converter npm 套件

yarn add @tensorflow/tfjsnpm install @tensorflow/tfjs

  1. 例項化 FrozenModel 類別並執行推論。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

查看 MobileNet 示範

loadGraphModel API 接受額外的 LoadOptions 參數,可用於隨要求傳送憑證或自訂標頭。如需詳細資訊,請參閱 loadGraphModel() 文件

支援的運算

目前 TensorFlow.js 僅支援一組有限的 TensorFlow 運算元。如果您的模型使用不支援的運算元,tensorflowjs_converter 指令碼將會失敗,並列印模型中不支援的運算元清單。請針對每個運算元提交 問題單,讓我們知道您需要哪些運算元支援。

僅載入權重

如果您偏好僅載入權重,可以使用下列程式碼片段

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...