本頁說明如何在將 TensorFlow 程式碼從 TensorFlow 1 遷移至 TensorFlow 2 的同時,繼續使用 TensorFlow Hub。它補充了 TensorFlow 的一般遷移指南。
對於 TF2,TF Hub 已切換離開舊版 hub.Module
API,不再像 tf.contrib.v1.layers
那樣建構 tf.compat.v1.Graph
。取而代之的是,現在有一個 hub.KerasLayer
可與其他 Keras 層一起使用,以建構 tf.keras.Model
(通常在 TF2 的新 eager execution 環境中)及其底層的 hub.load()
方法,用於低階 TensorFlow 程式碼。
hub.Module
API 在 tensorflow_hub
函式庫中仍然可用,以便在 TF1 和 TF2 的 TF1 相容模式中使用。它只能載入 TF1 Hub 格式的模型。
hub.load()
和 hub.KerasLayer
的新 API 適用於 TensorFlow 1.15(在 eager 和 graph 模式下)以及 TensorFlow 2。這個新 API 可以載入新的 TF2 SavedModel 資產,並且在模型相容性指南中列出的限制下,也可以載入 TF1 Hub 格式的舊版模型。
一般而言,建議盡可能使用新 API。
新 API 摘要
hub.load()
是從 TensorFlow Hub(或相容服務)載入 SavedModel 的新底層函式。它包裝了 TF2 的 tf.saved_model.load()
;TensorFlow 的 SavedModel 指南描述了您可以對結果執行的操作。
m = hub.load(handle)
outputs = m(inputs)
hub.KerasLayer
類別呼叫 hub.load()
並調整結果,以便在 Keras 中與其他 Keras 層一起使用。(它甚至可以作為以其他方式使用的已載入 SavedModel 的便利包裝函式。)
model = tf.keras.Sequential([
hub.KerasLayer(handle),
...])
許多教學課程展示了這些 API 的實際運作方式。以下是一些範例
在 Estimator 訓練中使用新 API
如果您在 Estimator 中使用 TF2 SavedModel 搭配參數伺服器進行訓練(或在 TF1 Session 中使用放置在遠端裝置上的變數),則需要在 tf.Session 的 ConfigProto 中設定 experimental.share_cluster_devices_in_session
,否則您會收到類似「Assigned device '/job:ps/replica:0/task:0/device:CPU:0' does not match any device.」的錯誤訊息。
必要的選項可以這樣設定
session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)
從 TF2.2 開始,此選項不再是實驗性的,可以省略 .experimental
部分。
載入 TF1 Hub 格式的舊版模型
有時,您的用例可能尚未提供新的 TF2 SavedModel,而您需要載入 TF1 Hub 格式的舊版模型。從 tensorflow_hub
0.7 版開始,您可以將 TF1 Hub 格式的舊版模型與 hub.KerasLayer
一起使用,如下所示
m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)
此外,KerasLayer
公開了指定 tags
、signature
、output_key
和 signature_outputs_as_dict
的功能,以便更具體地使用 TF1 Hub 格式的舊版模型和舊版 SavedModel。
如需 TF1 Hub 格式相容性的更多資訊,請參閱模型相容性指南。
使用較低階的 API
舊版 TF1 Hub 格式模型可以透過 tf.saved_model.load
載入。建議使用
# DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)
建議使用
# TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)
在這些範例中,m.signatures
是一個 TensorFlow 具體函式的 dict,以簽名名稱作為鍵。呼叫此類函式會計算其所有輸出,即使未使用也一樣。(這與 TF1 圖形模式的惰性評估不同。)