TensorFlow Lite 中的簽名

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視來源 下載筆記本

TensorFlow Lite 支援將 TensorFlow 模型的輸入/輸出規格轉換為 TensorFlow Lite 模型。輸入/輸出規格稱為「簽名」。在建構 SavedModel 或建立具體函式時,可以指定簽名。

TensorFlow Lite 中的簽名提供以下功能

  • 它們透過遵循 TensorFlow 模型的簽名,指定轉換後的 TensorFlow Lite 模型的輸入和輸出。
  • 允許單一 TensorFlow Lite 模型支援多個進入點。

簽名由三個部分組成

  • 輸入:從簽名中的輸入名稱到輸入張量的輸入對應。
  • 輸出:從簽名中的輸出名稱到輸出張量的輸出對應。
  • 簽名金鑰:識別圖表進入點的名稱。

設定

import tensorflow as tf

範例模型

假設我們有兩個任務,例如編碼和解碼,作為 TensorFlow 模型

class Model(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def encode(self, x):
    result = tf.strings.as_string(x)
    return {
         "encoded_result": result
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def decode(self, x):
    result = tf.strings.to_number(x)
    return {
         "decoded_result": result
    }

在簽名方面,上述 TensorFlow 模型可以總結如下

  • 簽名

    • 金鑰:encode
    • 輸入:{"x"}
    • 輸出:{"encoded_result"}
  • 簽名

    • 金鑰:decode
    • 輸入:{"x"}
    • 輸出:{"decoded_result"}

轉換具有簽名的模型

TensorFlow Lite 轉換器 API 會將上述簽名資訊帶入轉換後的 TensorFlow Lite 模型。

從 TensorFlow 2.7.0 版開始,所有轉換器 API 均提供此轉換功能。請參閱範例用法。

從 Saved Model

model = Model()

# Save the model
SAVED_MODEL_PATH = 'content/saved_models/coding'

tf.saved_model.save(
    model, SAVED_MODEL_PATH,
    signatures={
      'encode': model.encode.get_concrete_function(),
      'decode': model.decode.get_concrete_function()
    })

# Convert the saved model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)

從 Keras 模型

# Generate a Keras model.
keras_model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),
        tf.keras.layers.Dense(1, activation='relu', name='output'),
    ]
)

# Convert the keras model using TFLiteConverter.
# Keras model converter API uses the default signature automatically.
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)

signatures = interpreter.get_signature_list()
print(signatures)

從具體函式

model = Model()

# Convert the concrete functions using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.encode.get_concrete_function(),
     model.decode.get_concrete_function()], model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)

執行簽名

TensorFlow 推論 API 支援以簽名為基礎的執行

  • 透過簽名指定的輸入和輸出名稱存取輸入/輸出張量。
  • 分別執行圖表的每個進入點,並以簽名金鑰識別。
  • 支援 SavedModel 的初始化程序。

目前提供 Java、C++ 和 Python 語言繫結。請參閱以下章節範例。

Java

try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
  // Run encoding signature.
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("x", input);
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("encoded_result", encoded_result);
  interpreter.runSignature(inputs, outputs, "encode");

  // Run decoding signature.
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("x", encoded_result);
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("decoded_result", decoded_result);
  interpreter.runSignature(inputs, outputs, "decode");
}

C++

SignatureRunner* encode_runner =
    interpreter->GetSignatureRunner("encode");
encode_runner->ResizeInputTensor("x", {100});
encode_runner->AllocateTensors();

TfLiteTensor* input_tensor = encode_runner->input_tensor("x");
float* input = GetTensorData<float>(input_tensor);
// Fill `input`.

encode_runner->Invoke();

const TfLiteTensor* output_tensor = encode_runner->output_tensor(
    "encoded_result");
float* output = GetTensorData<float>(output_tensor);
// Access `output`.

Python

# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(model_content=tflite_model)

# Print the signatures from the converted model
signatures = interpreter.get_signature_list()
print('Signature:', signatures)

# encode and decode are callable with input as arguments.
encode = interpreter.get_signature_runner('encode')
decode = interpreter.get_signature_runner('decode')

# 'encoded' and 'decoded' are dictionaries with all outputs from the inference.
input = tf.constant([1, 2, 3], dtype=tf.float32)
print('Input:', input)
encoded = encode(x=input)
print('Encoded result:', encoded)
decoded = decode(x=encoded['encoded_result'])
print('Decoded result:', decoded)

已知限制

  • 由於 TFLite 解譯器不保證執行緒安全,因此來自同一個解譯器的簽名執行器不會同時執行。
  • 尚不支援 iOS/Swift。

更新

  • 版本 2.7
    • 已實作多重簽名功能。
    • 版本 2 的所有轉換器 API 都會產生啟用簽名的 TensorFlow Lite 模型。
  • 版本 2.5
    • 簽名功能可透過 from_saved_model 轉換器 API 取得。