以 TensorFlow GraphDef 為基礎的模型 (通常透過 Python API 建立) 可以儲存為下列其中一種格式
- TensorFlow SavedModel
- 凍結模型
- TensorFlow Hub 模組
以上所有格式都可以透過 TensorFlow.js 轉換器轉換成可以直接載入 TensorFlow.js 以進行推論的格式。
(注意:TensorFlow 已淘汰工作階段套件組合格式。請將模型移轉至 SavedModel 格式。)
需求條件
轉換程序需要 Python 環境;您可能會想使用 pipenv 或 virtualenv 維護隔離的環境。
如要安裝轉換器,請執行下列指令
pip install tensorflowjs
將 TensorFlow 模型匯入 TensorFlow.js 是由兩個步驟組成的程序。首先,將現有模型轉換為 TensorFlow.js Web 格式,然後將其載入 TensorFlow.js。
步驟 1:將現有 TensorFlow 模型轉換為 TensorFlow.js Web 格式
執行 pip 套件提供的轉換器指令碼
SavedModel 範例
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
凍結模型範例
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
TensorFlow Hub 模組範例
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
位置引數 | 說明 |
---|---|
input_path |
儲存模型目錄、工作階段套件組合目錄、凍結模型檔案或 TensorFlow Hub 模組控制代碼或路徑的完整路徑。 |
output_path |
所有輸出成品的路徑。 |
選項 | 說明 |
---|---|
--input_format |
輸入模型的格式。SavedModel 請使用 tf_saved_model,凍結模型請使用 tf_frozen_model,工作階段套件組合請使用 tf_session_bundle,TensorFlow Hub 模組請使用 tf_hub,Keras HDF5 請使用 keras。 |
--output_node_names |
輸出節點的名稱,以逗號分隔。 |
--saved_model_tags |
僅適用於 SavedModel 轉換。要載入的 MetaGraphDef 標記,以逗號分隔的格式。預設值為 serve 。 |
--signature_name |
僅適用於 TensorFlow Hub 模組轉換,要載入的簽名。預設值為 default 。請參閱 https://tensorflow.dev.org.tw/hub/common_signatures/ |
使用下列指令取得詳細的說明訊息
tensorflowjs_converter --help
轉換器產生的檔案
上述轉換指令碼會產生兩種檔案
model.json
:資料流程圖和權重資訊清單group1-shard\*of\*
:二進位權重檔案的集合
例如,以下是轉換 MobileNet v2 的輸出
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
步驟 2:在瀏覽器中載入和執行
- 安裝 tfjs-converter npm 套件
yarn add @tensorflow/tfjs
或 npm install @tensorflow/tfjs
- 例項化 FrozenModel 類別並執行推論。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
查看 MobileNet 示範。
loadGraphModel
API 接受額外的 LoadOptions
參數,可用於隨要求傳送憑證或自訂標頭。如需詳細資訊,請參閱 loadGraphModel() 文件。
支援的運算
目前 TensorFlow.js 僅支援一組有限的 TensorFlow 運算元。如果您的模型使用不支援的運算元,tensorflowjs_converter
指令碼將會失敗,並列印模型中不支援的運算元清單。請針對每個運算元提交 問題單,讓我們知道您需要哪些運算元支援。
僅載入權重
如果您偏好僅載入權重,可以使用下列程式碼片段
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
// Use `weightMap` ...