TFLite 撰寫工具

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

TensorFlow Lite Authoring API 提供一種維護與 TensorFlow Lite 相容的 tf.function 模型的方法。

設定

import tensorflow as tf

TensorFlow 至 TensorFlow Lite 相容性問題

如果您想在裝置上使用 TF 模型,則需要將其轉換為 TFLite 模型,才能從 TFLite 解譯器中使用。在轉換期間,您可能會因為 TFLite 內建運算元集不支援 TensorFlow 運算元而遇到相容性錯誤。

這是個惱人的問題。如何在模型撰寫時等更早的時間點偵測到它?

請注意,以下程式碼會在 converter.convert() 呼叫時失敗。

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
# Convert the tf.function
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [f.get_concrete_function()], f)
try:
  fb_model = converter.convert()
except Exception as e:
  print(f"Got an exception: {e}")

簡易目標感知撰寫用法

我們推出了 Authoring API,以便在模型撰寫時偵測 TensorFlow Lite 相容性問題。

您只需要新增 @tf.lite.experimental.authoring.compatible 裝飾器來包裝您的 tf.function 模型,以檢查 TFLite 相容性。

完成後,當您評估模型時,就會自動檢查相容性。

@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")

如果發現任何 TensorFlow Lite 相容性問題,系統會顯示 COMPATIBILITY WARNINGCOMPATIBILITY ERROR,並指出有問題的運算元的確切位置。在此範例中,系統會顯示 tf.function 模型中 tf.Cosh 運算元的位置。

您也可以使用 <function_name>.get_compatibility_log() 方法檢查相容性記錄。

compatibility_log = '\n'.join(f.get_compatibility_log())
print (f"compatibility_log = {compatibility_log}")

針對不相容性引發例外狀況

您可以為 @tf.lite.experimental.authoring.compatible 裝飾器提供選項。raise_exception 選項會在您嘗試評估裝飾的模型時提供您例外狀況。

@tf.lite.experimental.authoring.compatible(raise_exception=True)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
try:
  result = f(tf.constant([0.0]))
  print (f"result = {result}")
except Exception as e:
  print(f"Got an exception: {e}")

指定「選取 TF 運算元」用法

如果您已瞭解選取 TF 運算元用法,可以透過設定 converter_target_spec 將此資訊告知 Authoring API。它與您將用於 tf.lite.TFLiteConverter API 的 tf.lite.TargetSpec 物件相同。

target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec, raise_exception=True)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")

檢查 GPU 相容性

如果您想確保模型與 TensorFlow Lite 的 GPU 委派相容,可以設定 tf.lite.TargetSpecexperimental_supported_backends

以下範例說明如何確保模型的 GPU 委派相容性。請注意,此模型有相容性問題,因為它將 2D 張量與 tf.slice 運算元和不受支援的 tf.cosh 運算元搭配使用。您會看到兩個 COMPATIBILITY WARNING,其中包含位置資訊。

target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
target_spec.experimental_supported_backends = ["GPU"]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[4, 4], dtype=tf.float32)
])
def func(x):
  y = tf.cosh(x)
  return y + tf.slice(x, [1, 1], [1, 1])

result = func(tf.ones(shape=(4,4), dtype=tf.float32))

閱讀詳情

如需更多資訊,請參閱