使用 TensorFlow Hub 從 TF1 遷移至 TF2

本頁說明如何在將 TensorFlow 程式碼從 TensorFlow 1 遷移至 TensorFlow 2 的同時,繼續使用 TensorFlow Hub。它補充了 TensorFlow 的一般遷移指南

對於 TF2,TF Hub 已切換離開舊版 hub.Module API,不再像 tf.contrib.v1.layers 那樣建構 tf.compat.v1.Graph。取而代之的是,現在有一個 hub.KerasLayer 可與其他 Keras 層一起使用,以建構 tf.keras.Model(通常在 TF2 的新 eager execution 環境中)及其底層的 hub.load() 方法,用於低階 TensorFlow 程式碼。

hub.Module API 在 tensorflow_hub 函式庫中仍然可用,以便在 TF1 和 TF2 的 TF1 相容模式中使用。它只能載入 TF1 Hub 格式的模型。

hub.load()hub.KerasLayer 的新 API 適用於 TensorFlow 1.15(在 eager 和 graph 模式下)以及 TensorFlow 2。這個新 API 可以載入新的 TF2 SavedModel 資產,並且在模型相容性指南中列出的限制下,也可以載入 TF1 Hub 格式的舊版模型。

一般而言,建議盡可能使用新 API。

新 API 摘要

hub.load() 是從 TensorFlow Hub(或相容服務)載入 SavedModel 的新底層函式。它包裝了 TF2 的 tf.saved_model.load();TensorFlow 的 SavedModel 指南描述了您可以對結果執行的操作。

m = hub.load(handle)
outputs = m(inputs)

hub.KerasLayer 類別呼叫 hub.load() 並調整結果,以便在 Keras 中與其他 Keras 層一起使用。(它甚至可以作為以其他方式使用的已載入 SavedModel 的便利包裝函式。)

model = tf.keras.Sequential([
    hub.KerasLayer(handle),
    ...])

許多教學課程展示了這些 API 的實際運作方式。以下是一些範例

在 Estimator 訓練中使用新 API

如果您在 Estimator 中使用 TF2 SavedModel 搭配參數伺服器進行訓練(或在 TF1 Session 中使用放置在遠端裝置上的變數),則需要在 tf.Session 的 ConfigProto 中設定 experimental.share_cluster_devices_in_session,否則您會收到類似「Assigned device '/job:ps/replica:0/task:0/device:CPU:0' does not match any device.」的錯誤訊息。

必要的選項可以這樣設定

session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)

從 TF2.2 開始,此選項不再是實驗性的,可以省略 .experimental 部分。

載入 TF1 Hub 格式的舊版模型

有時,您的用例可能尚未提供新的 TF2 SavedModel,而您需要載入 TF1 Hub 格式的舊版模型。從 tensorflow_hub 0.7 版開始,您可以將 TF1 Hub 格式的舊版模型與 hub.KerasLayer 一起使用,如下所示

m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)

此外,KerasLayer 公開了指定 tagssignatureoutput_keysignature_outputs_as_dict 的功能,以便更具體地使用 TF1 Hub 格式的舊版模型和舊版 SavedModel。

如需 TF1 Hub 格式相容性的更多資訊,請參閱模型相容性指南

使用較低階的 API

舊版 TF1 Hub 格式模型可以透過 tf.saved_model.load 載入。建議使用

# DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)

建議使用

# TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)

在這些範例中,m.signatures 是一個 TensorFlow 具體函式的 dict,以簽名名稱作為鍵。呼叫此類函式會計算其所有輸出,即使未使用也一樣。(這與 TF1 圖形模式的惰性評估不同。)