簡介
TensorFlow Hub 除了其他資產外,也代管 TensorFlow 2 的 SavedModel。這些 SavedModel 可以使用 obj = hub.load(url)
重新載入 Python 程式 [瞭解詳情]。傳回的 obj
是 tf.saved_model.load()
的結果 (請參閱 TensorFlow 的 SavedModel 指南)。這個物件可以有任意屬性,這些屬性可以是 tf.functions、tf.Variables (從預先訓練的值初始化)、其他資源,以及遞迴地更多這類物件。
本頁說明載入的 obj
為了在 TensorFlow Python 程式中重複使用而應實作的介面。符合這個介面的 SavedModel 稱為可重複使用的 SavedModel。
重複使用意指圍繞 obj
建構更大的模型,包括微調模型的能力。微調意指進一步訓練載入的 obj
中的權重,作為周圍模型的一部分。損失函數和最佳化工具由周圍模型決定;obj
只定義從輸入到輸出啟動的對應 (「前向傳遞」),可能包括 dropout 或批次正規化等技術。
TensorFlow Hub 團隊建議在所有旨在以上述方式重複使用的 SavedModel 中實作可重複使用的 SavedModel 介面。tensorflow_hub
程式庫中的許多公用程式 (特別是 hub.KerasLayer
) 都要求 SavedModel 實作這個介面。
與 SignatureDefs 的關係
從 tf.functions 和其他 TF2 功能來看,這個介面與 SavedModel 的簽名是分開的,後者自 TF1 起便已推出,並繼續在 TF2 中用於推論 (例如將 SavedModel 部署到 TF Serving 或 TF Lite)。用於推論的簽名不夠明確,無法支援微調,而 tf.function
為重複使用的模型提供更自然且更明確的 Python API。
與模型建構程式庫的關係
可重複使用的 SavedModel 僅使用 TensorFlow 2 基本元件,與任何特定的模型建構程式庫 (例如 Keras 或 Sonnet) 無關。這有助於跨模型建構程式庫重複使用,而不會依賴原始模型建構程式碼。
將可重複使用的 SavedModel 載入或從任何給定的模型建構程式庫儲存時,都需要進行一定程度的調整。對於 Keras,hub.KerasLayer 提供載入功能,而 Keras 內建的 SavedModel 格式儲存功能已針對 TF2 重新設計,目標是提供這個介面的超集 (請參閱 2019 年 5 月的 RFC)。
與特定工作「常見的 SavedModel API」的關係
本頁上的介面定義允許任意數量和類型的輸入和輸出。TF Hub 的常見 SavedModel API 使用特定工作的使用慣例來完善這個通用介面,使模型易於互換。
介面定義
屬性
可重複使用的 SavedModel 是 TensorFlow 2 SavedModel,因此 obj = tf.saved_model.load(...)
會傳回具有以下屬性的物件
__call__
。必要。實作模型運算 (「前向傳遞」) 的 tf.function,受以下規格約束。variables
:tf.Variable 物件的清單,列出__call__
的任何可能調用所使用的所有變數,包括可訓練和不可訓練的變數。如果這個清單是空的,則可以省略。
trainable_variables
:tf.Variable 物件的清單,對於所有元素,v.trainable
都是 true。這些變數必須是variables
的子集。這些是在微調物件時要訓練的變數。SavedModel 建立者可以選擇在這裡省略一些原本可訓練的變數,以表明這些變數在微調期間不應修改。如果這個清單是空的,則可以省略,特別是如果 SavedModel 不支援微調。
regularization_losses
:tf.functions 的清單,每個函數都接受零個輸入並傳回單一純量浮點張量。對於微調,建議 SavedModel 使用者將這些函數作為額外的正規化項納入損失中 (在最簡單的情況下,無需進一步縮放)。通常,這些函數用於表示權重正規化器。(由於缺少輸入,這些 tf.functions 無法表示活動正規化器。)如果這個清單是空的,則可以省略,特別是如果 SavedModel 不支援微調或不希望規定權重正規化。
__call__
函數
還原的 SavedModel obj
具有 obj.__call__
屬性,這個屬性是還原的 tf.function,並允許像下面這樣調用 obj
。
摘要 (虛擬程式碼)
outputs = obj(inputs, trainable=..., **kwargs)
引數
引數如下。
有一個位置必要引數,其中包含一批 SavedModel 的輸入啟動。其類型是以下其中之一
- 單一輸入的單一張量,
- 用於未命名輸入的順序序列的張量清單,
- 以特定輸入名稱集為鍵的張量字典。
(這個介面的未來版本可能會允許更通用的巢狀結構。) SavedModel 建立者選擇其中一個以及張量形狀和資料類型。在有用的情況下,形狀的一些維度應該是不確定的 (特別是批次大小)。
可能有一個選用的關鍵字引數
training
,它接受 Python 布林值True
或False
。預設值為False
。如果模型支援微調,並且如果其運算在這兩者之間有所不同 (例如,在 dropout 和批次正規化中),則這種區別是透過這個引數實作的。否則,這個引數可能會不存在。不要求
__call__
接受張量值training
引數。如有必要,呼叫者有責任使用tf.cond()
在它們之間分派。SavedModel 建立者可以選擇接受更多具有特定名稱的選用
kwargs
。對於張量值引數,SavedModel 建立者定義其允許的資料類型和形狀。
tf.function
接受以 tf.TensorSpec 輸入追蹤的引數上的 Python 預設值。這類引數可用於自訂__call__
中涉及的數值超參數 (例如,dropout 率)。對於 Python 值引數,SavedModel 建立者定義其允許的值。這類引數可以用作標記,以在追蹤的函數中進行離散選擇 (但請注意追蹤的組合爆炸)。
還原的 __call__
函數必須為引數的所有允許組合提供追蹤。在 True
和 False
之間翻轉 training
不得改變引數的允許性。
結果
從調用 obj
得到的 outputs
可以是
- 單一輸出的單一張量,
- 用於未命名輸出的順序序列的張量清單,
- 以特定輸出名稱集為鍵的張量字典。
(這個介面的未來版本可能會允許更通用的巢狀結構。) 傳回類型可能會因 Python 值 kwargs 而異。這允許標記產生額外的輸出。SavedModel 建立者定義輸出資料類型和形狀及其對輸入的依賴性。
具名可調用項
可重複使用的 SavedModel 可以透過將多個模型片段放入具名子物件 (例如 obj.foo
、obj.bar
等) 中,以上述方式提供這些模型片段。每個子物件都提供一個 __call__
方法和關於變數等的支援屬性,這些屬性特定於該模型片段。對於上面的範例,將會有 obj.foo.__call__
、obj.foo.variables
等。
請注意,這個介面不涵蓋將裸 tf.function 直接新增為 tf.foo
的方法。
可重複使用的 SavedModel 的使用者只需要處理一個層級的巢狀結構 (obj.bar
但不是 obj.bar.baz
)。(這個介面的未來版本可能會允許更深的巢狀結構,並且可能會免除頂層物件本身可調用的要求。)
總結
與程序內 API 的關係
本文檔描述了一個 Python 類別的介面,該介面由 tf.function 和 tf.Variable 等基本元件組成,這些基本元件可以在透過 tf.saved_model.save()
和 tf.saved_model.load()
序列化往返後仍然存在。但是,介面已經存在於傳遞給 tf.saved_model.save()
的原始物件上。調整到這個介面可以在單個 TensorFlow 程式中實現跨模型建構 API 的模型片段交換。