可重複使用的 SavedModel

簡介

TensorFlow Hub 除了其他資產外,也代管 TensorFlow 2 的 SavedModel。這些 SavedModel 可以使用 obj = hub.load(url) 重新載入 Python 程式 [瞭解詳情]。傳回的 objtf.saved_model.load() 的結果 (請參閱 TensorFlow 的 SavedModel 指南)。這個物件可以有任意屬性,這些屬性可以是 tf.functions、tf.Variables (從預先訓練的值初始化)、其他資源,以及遞迴地更多這類物件。

本頁說明載入的 obj 為了在 TensorFlow Python 程式中重複使用而應實作的介面。符合這個介面的 SavedModel 稱為可重複使用的 SavedModel

重複使用意指圍繞 obj 建構更大的模型,包括微調模型的能力。微調意指進一步訓練載入的 obj 中的權重,作為周圍模型的一部分。損失函數和最佳化工具由周圍模型決定;obj 只定義從輸入到輸出啟動的對應 (「前向傳遞」),可能包括 dropout 或批次正規化等技術。

TensorFlow Hub 團隊建議在所有旨在以上述方式重複使用的 SavedModel 中實作可重複使用的 SavedModel 介面tensorflow_hub 程式庫中的許多公用程式 (特別是 hub.KerasLayer) 都要求 SavedModel 實作這個介面。

與 SignatureDefs 的關係

從 tf.functions 和其他 TF2 功能來看,這個介面與 SavedModel 的簽名是分開的,後者自 TF1 起便已推出,並繼續在 TF2 中用於推論 (例如將 SavedModel 部署到 TF Serving 或 TF Lite)。用於推論的簽名不夠明確,無法支援微調,而 tf.function 為重複使用的模型提供更自然且更明確的 Python API

與模型建構程式庫的關係

可重複使用的 SavedModel 僅使用 TensorFlow 2 基本元件,與任何特定的模型建構程式庫 (例如 Keras 或 Sonnet) 無關。這有助於跨模型建構程式庫重複使用,而不會依賴原始模型建構程式碼。

將可重複使用的 SavedModel 載入或從任何給定的模型建構程式庫儲存時,都需要進行一定程度的調整。對於 Keras,hub.KerasLayer 提供載入功能,而 Keras 內建的 SavedModel 格式儲存功能已針對 TF2 重新設計,目標是提供這個介面的超集 (請參閱 2019 年 5 月的 RFC)。

與特定工作「常見的 SavedModel API」的關係

本頁上的介面定義允許任意數量和類型的輸入和輸出。TF Hub 的常見 SavedModel API 使用特定工作的使用慣例來完善這個通用介面,使模型易於互換。

介面定義

屬性

可重複使用的 SavedModel 是 TensorFlow 2 SavedModel,因此 obj = tf.saved_model.load(...) 會傳回具有以下屬性的物件

  • __call__。必要。實作模型運算 (「前向傳遞」) 的 tf.function,受以下規格約束。

  • variables:tf.Variable 物件的清單,列出 __call__ 的任何可能調用所使用的所有變數,包括可訓練和不可訓練的變數。

    如果這個清單是空的,則可以省略。

  • trainable_variables:tf.Variable 物件的清單,對於所有元素,v.trainable 都是 true。這些變數必須是 variables 的子集。這些是在微調物件時要訓練的變數。SavedModel 建立者可以選擇在這裡省略一些原本可訓練的變數,以表明這些變數在微調期間不應修改。

    如果這個清單是空的,則可以省略,特別是如果 SavedModel 不支援微調。

  • regularization_losses:tf.functions 的清單,每個函數都接受零個輸入並傳回單一純量浮點張量。對於微調,建議 SavedModel 使用者將這些函數作為額外的正規化項納入損失中 (在最簡單的情況下,無需進一步縮放)。通常,這些函數用於表示權重正規化器。(由於缺少輸入,這些 tf.functions 無法表示活動正規化器。)

    如果這個清單是空的,則可以省略,特別是如果 SavedModel 不支援微調或不希望規定權重正規化。

__call__ 函數

還原的 SavedModel obj 具有 obj.__call__ 屬性,這個屬性是還原的 tf.function,並允許像下面這樣調用 obj

摘要 (虛擬程式碼)

outputs = obj(inputs, trainable=..., **kwargs)

引數

引數如下。

  • 有一個位置必要引數,其中包含一批 SavedModel 的輸入啟動。其類型是以下其中之一

    • 單一輸入的單一張量,
    • 用於未命名輸入的順序序列的張量清單,
    • 以特定輸入名稱集為鍵的張量字典。

    (這個介面的未來版本可能會允許更通用的巢狀結構。) SavedModel 建立者選擇其中一個以及張量形狀和資料類型。在有用的情況下,形狀的一些維度應該是不確定的 (特別是批次大小)。

  • 可能有一個選用的關鍵字引數 training,它接受 Python 布林值 TrueFalse。預設值為 False。如果模型支援微調,並且如果其運算在這兩者之間有所不同 (例如,在 dropout 和批次正規化中),則這種區別是透過這個引數實作的。否則,這個引數可能會不存在。

    不要求 __call__ 接受張量值 training 引數。如有必要,呼叫者有責任使用 tf.cond() 在它們之間分派。

  • SavedModel 建立者可以選擇接受更多具有特定名稱的選用 kwargs

    • 對於張量值引數,SavedModel 建立者定義其允許的資料類型和形狀。tf.function 接受以 tf.TensorSpec 輸入追蹤的引數上的 Python 預設值。這類引數可用於自訂 __call__ 中涉及的數值超參數 (例如,dropout 率)。

    • 對於 Python 值引數,SavedModel 建立者定義其允許的值。這類引數可以用作標記,以在追蹤的函數中進行離散選擇 (但請注意追蹤的組合爆炸)。

還原的 __call__ 函數必須為引數的所有允許組合提供追蹤。在 TrueFalse 之間翻轉 training 不得改變引數的允許性。

結果

從調用 obj 得到的 outputs 可以是

  • 單一輸出的單一張量,
  • 用於未命名輸出的順序序列的張量清單,
  • 以特定輸出名稱集為鍵的張量字典。

(這個介面的未來版本可能會允許更通用的巢狀結構。) 傳回類型可能會因 Python 值 kwargs 而異。這允許標記產生額外的輸出。SavedModel 建立者定義輸出資料類型和形狀及其對輸入的依賴性。

具名可調用項

可重複使用的 SavedModel 可以透過將多個模型片段放入具名子物件 (例如 obj.fooobj.bar 等) 中,以上述方式提供這些模型片段。每個子物件都提供一個 __call__ 方法和關於變數等的支援屬性,這些屬性特定於該模型片段。對於上面的範例,將會有 obj.foo.__call__obj.foo.variables 等。

請注意,這個介面涵蓋將裸 tf.function 直接新增為 tf.foo 的方法。

可重複使用的 SavedModel 的使用者只需要處理一個層級的巢狀結構 (obj.bar 但不是 obj.bar.baz)。(這個介面的未來版本可能會允許更深的巢狀結構,並且可能會免除頂層物件本身可調用的要求。)

總結

與程序內 API 的關係

本文檔描述了一個 Python 類別的介面,該介面由 tf.function 和 tf.Variable 等基本元件組成,這些基本元件可以在透過 tf.saved_model.save()tf.saved_model.load() 序列化往返後仍然存在。但是,介面已經存在於傳遞給 tf.saved_model.save() 的原始物件上。調整到這個介面可以在單個 TensorFlow 程式中實現跨模型建構 API 的模型片段交換。