TF Hub 的通用 SavedModel API

簡介

TensorFlow Hub 託管各種工作適用的模型。建議相同工作適用的模型實作通用 API,以便模型消費者輕鬆交換模型,而無需修改使用模型的程式碼,即使模型來自不同的發布者也一樣。

目標是讓相同工作適用的不同模型交換作業,簡化為切換字串值超參數。如此一來,模型消費者就能輕鬆找到最適合他們問題的模型。

這個目錄收集 TF2 SavedModel 格式模型通用 API 的規格。(它取代了現已淘汰的 TF1 Hub 格式的通用簽名。)

可重複使用的 SavedModel:通用基礎

可重複使用的 SavedModel API 定義將 SavedModel 載回 Python 程式中,並將其重複用作較大型 TensorFlow 模型一部分的一般慣例。

基本用法

obj = hub.load("path/to/model")  # That's tf.saved_model.load() after download.
outputs = obj(inputs, training=False)  # Invokes the tf.function obj.__call__.

對於 Keras 使用者,hub.KerasLayer 類別依賴此 API 將可重複使用的 SavedModel 包裝為 Keras 層 (讓 Keras 使用者無需了解其細節),輸入和輸出則根據下方列出的工作專屬 API。

工作專屬 API

這些透過特定 ML 工作和資料類型的慣例,精進可重複使用的 SavedModel API。