Better performance with tf.function

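In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

You can use tf.function to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use SavedModel.

This guide will help you conceptualize how tf.function works under the hood, so you can use it effectively.

The main takeaways and recommendations are:

  • Debug in eager mode, then decorate with @tf.function.
  • Don't rely on Python side effects like object mutation or list appends.
  • tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.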
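
The last point is easy to trip over, so here is a minimal sketch of it. The function names add_noise and add_tf_noise are made up for illustration. A NumPy call inside a tf.function runs once, while the function is traced, and its result is frozen into the graph as a constant; the equivalent TensorFlow op runs on every call.

```python
import numpy as np
import tensorflow as tf

@tf.function
def add_noise(x):
  # np.random.rand() is plain NumPy, so it runs only during tracing.
  # The number it returns is baked into the graph as a constant.
  return x + np.random.rand()

first = add_noise(tf.constant(0.0)).numpy()
second = add_noise(tf.constant(0.0)).numpy()  # Reuses the trace and the constant.
print("Same 'random' value on both calls:", first == second)

@tf.function
def add_tf_noise(x):
  # tf.random.uniform is a TensorFlow op, so it is part of the graph
  # and produces a fresh value each time the graph runs.
  return x + tf.random.uniform([])

print(add_tf_noise(tf.constant(0.0)).numpy(),
      add_tf_noise(tf.constant(0.0)).numpy())
```

If you need randomness (or any other per-call behavior) inside a tf.function, prefer the TensorFlow op over its NumPy counterpart.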

Setup

import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

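Basics

Usage

A tf.function that you define (for example by applying the @tf.function decorator) is just like a core TensorFlow operation: you can execute it eagerly, you can compute gradients, and so on.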

@tf.function  # The decorator converts `add` into a `PolymorphicFunction`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

You can use tf.functions inside other tf.functions.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

For graphs with many small ops, tf.function can be faster than eager code. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.005143817999964995
Function conv: 0.005329717999984496
Note how there's not much difference in performance for convolutions
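
For contrast, the following is a rough sketch of the "many small ops" case. The function many_small_ops and the loop count of 200 are arbitrary choices for illustration, and the timings depend heavily on your hardware and TensorFlow build, so no expected numbers are shown.

```python
import timeit
import tensorflow as tf

def many_small_ops(x):
  # 200 tiny element-wise updates. This is a Python loop, so it is
  # unrolled into individual ops when the function is traced.
  for _ in range(200):
    x = x * 1.001 + 0.001
  return x

graph_fn = tf.function(many_small_ops)
x = tf.ones([10])

# Warm up (this also traces graph_fn) and keep the results for comparison.
eager_result = many_small_ops(x)
graph_result = graph_fn(x)

print("Eager small ops:", timeit.timeit(lambda: many_small_ops(x), number=20))
print("Function small ops:", timeit.timeit(lambda: graph_fn(x), number=20))
```

Both versions compute the same values; only the per-op dispatch overhead differs.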

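Tracing

This section exposes how tf.function works under the hood, including implementation details which may change in the future. However, once you understand why and when tracing happens, it is much easier to use tf.function effectively.

What is "tracing"?

A tf.function runs your program in a TensorFlow Graph. However, a tf.Graph cannot represent all the things that you would write in an eager TensorFlow program. For instance, Python supports polymorphism, but tf.Graph requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a tf.Graph.

tf.function bridges this gap by separating your code into two stages:

1) In the first stage, referred to as "tracing", tf.function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.

2) In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.

Depending on its inputs, tf.function will not always run the first stage when it is called. See "Rules of tracing" below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.

When tf.function does decide to trace, the tracing stage is immediately followed by the second stage, so calling the tf.function both creates and runs the tf.Graph. Later you will see how you can run only the tracing stage with get_concrete_function.

When you pass arguments of different types into a tf.function, both stages are run: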

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

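Note that if you repeatedly call a tf.function with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical.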

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

You can use pretty_printed_concrete_signatures() to see all of the available traces:

print(double.pretty_printed_concrete_signatures())
Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
  None

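So far, you have seen that tf.function creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:

  • A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.
  • Tracing is the process through which new tf.Graphs are generated from Python code.
  • An instance of tf.Graph is specialized to the specific input types it was traced with. Differing types require retracing.
  • Each traced tf.Graph has a corresponding ConcreteFunction.
  • A tf.function manages a cache of ConcreteFunctions and picks the right one for your inputs.
  • tf.function wraps the Python function that will be traced, returning a tf.types.experimental.PolymorphicFunction object.

Rules of tracing

When called, a tf.function first evaluates the type of each input argument using the tf.types.experimental.TraceType of each argument. This is used to construct a tf.types.experimental.FunctionType describing the signature of the desired ConcreteFunction. This FunctionType is compared to the FunctionTypes of the existing ConcreteFunctions. If a matching ConcreteFunction is found, the call is dispatched to it. If no match is found, a new ConcreteFunction is traced for the desired FunctionType.

If multiple matches are found, the most specific signature is chosen. Matching is done by subtyping, much like normal function calls in C++ or Java. For example, TensorShape([1, 2]) is a subtype of TensorShape([None, None]), so a call to the tf.function with TensorShape([1, 2]) can be dispatched to the ConcreteFunction produced with TensorShape([None, None]). But if a ConcreteFunction with TensorShape([1, None]) also exists, it will be prioritized since it is more specific.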
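
As a small sketch of subtype-based dispatch, assuming a recent TensorFlow 2 release: the function total and the trace_count counter below are hypothetical names used only for illustration. The counter is bumped by a Python side effect, so it only changes while tracing.

```python
import tensorflow as tf

trace_count = 0  # Incremented only while the Python body runs, i.e. during tracing.

@tf.function
def total(x):
  global trace_count
  trace_count += 1
  return tf.reduce_sum(x)

# Trace ahead of time for any rank-2 float32 tensor.
total.get_concrete_function(tf.TensorSpec([None, None], tf.float32))
traces_before = trace_count

# TensorShape([1, 2]) and TensorShape([3, 4]) are both subtypes of
# TensorShape([None, None]), so these calls can reuse the existing trace.
total(tf.ones([1, 2]))
total(tf.ones([3, 4]))
traces_after = trace_count

print("New traces caused by the two calls:", traces_after - traces_before)
```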

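The TraceType is determined from input arguments as follows:

  • For Tensor, the type is parameterized by the Tensor's dtype and shape; ranked shapes are a subtype of unranked shapes; fixed dimensions are a subtype of unknown dimensions.
  • For Variable, the type is similar to Tensor, but also includes a unique resource ID of the variable, which is necessary to correctly wire control dependencies.
  • For Python primitive values, the type corresponds to the value itself. For example, the TraceType of the value 3 is LiteralTraceType<3>, not int.
  • For Python ordered containers such as list and tuple, the type is parameterized by the types of their elements; for example, the type of [1, 2] is ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>> and the type of [2, 1] is ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>, which is different.
  • For Python mappings such as dict, the type is also a mapping from the same keys, but to the types of the values instead of the actual values. For example, the type of {1: 2, 3: 4} is MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>. However, unlike ordered containers, {1: 2, 3: 4} and {3: 4, 1: 2} have equivalent types.
  • For Python objects which implement the __tf_tracing_type__ method, the type is whatever that method returns.
  • For any other Python objects, the type is a generic TraceType, and the matching procedure is:

    • First, it checks if the object is the same object used in the previous trace (using Python id() or is). Note that this will still match if the object has changed, so if you use Python objects as tf.function arguments it is best to use immutable ones.
    • Next, it checks if the object is equal to the object used in the previous trace (using Python ==).

    Note that this procedure only keeps a weakref to the object and hence only works as long as the object is in scope/not deleted.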
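
The difference between the Tensor rule and the Python primitive rule above can be sketched as follows. The names add_one and trace_count are hypothetical; the counter changes only while tracing because it is a Python side effect.

```python
import tensorflow as tf

trace_count = 0  # Incremented only during tracing.

@tf.function
def add_one(x):
  global trace_count
  trace_count += 1
  return tf.add(x, 1)

add_one(3)
add_one(3)  # Same literal type as before: reuses the trace.
add_one(4)  # A different literal is a different type: retraces.
literal_traces = trace_count

add_one(tf.constant(3))
add_one(tf.constant(4))  # Same dtype and shape: reuses the trace.
tensor_traces = trace_count - literal_traces

print("Traces for Python literals:", literal_traces)
print("Traces for Tensor arguments:", tensor_traces)
```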

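Controlling retracing

Retracing, which is when your tf.function creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation. If your tf.function retraces a new graph for every call, you will find that your code executes more slowly than if you had not used tf.function.

To control the tracing behavior, you can use the following techniques:

Pass a fixed input_signature to tf.function

This forces tf.function to constrain itself to only one tf.types.experimental.FunctionType, composed of the types enumerated by the input_signature. Calls that cannot be dispatched to this FunctionType will throw an error.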

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'TypeError'>:
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/3657259638.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4]], dtype=int32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/3657259638.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).

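Use unknown dimensions for flexibility

Since TensorFlow matches tensors based on their shape, using a None dimension as a wildcard will allow tf.functions to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different lengths, or images of different sizes for each batch. You can check out the Transformer and Deep Dream tutorials for examples.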

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

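Use reduce_retracing for automatic flexibility

When reduce_retracing is enabled, tf.function automatically identifies supertypes of the input types it is observing and chooses to trace more generalized graphs automatically. It is less efficient than setting the input_signature directly, but it is useful when many types need to be supported.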

@tf.function(reduce_retracing=True)
def g(x):
  print('Tracing with', x)
  return x

# Traces once.
print(g(tf.constant([1, 2, 3])))

# Traces again, but more generalized this time.
print(g(tf.constant([1, 2, 3, 4, 5])))

# No more tracing!
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7])))
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])))
Tracing with Tensor("x:0", shape=(3,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)

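Pass tensors instead of Python literals

Often, Python arguments are used to control hyperparameters and graph construction, for example num_layers=10, training=True, or nonlinearity='relu'. So, if the Python argument changes, it makes sense that you would have to retrace the graph.

However, it is possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.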

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

If you need to force retracing, create a new tf.function. Separate tf.function objects are guaranteed not to share traces.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

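Use the tracing protocol

Where possible, you should prefer converting the Python type into a tf.experimental.ExtensionType instead. Moreover, the TraceType of an ExtensionType is the tf.TypeSpec associated with it. Therefore, if needed, you can simply override the default tf.TypeSpec to take control of an ExtensionType's Tracing Protocol. Refer to the "Customizing the ExtensionType's TypeSpec" section in the Extension types guide for details.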
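
As a rough sketch of the ExtensionType route: the class Fruit, the function mix, and the trace_count counter are hypothetical names chosen for illustration. Two different Fruit instances whose flavor tensors have the same dtype and shape share a TypeSpec, so the second call is expected to reuse the first trace.

```python
import tensorflow as tf

class Fruit(tf.experimental.ExtensionType):
  # A hypothetical extension type with a single tensor field.
  flavor: tf.Tensor

trace_count = 0  # Incremented only during tracing.

@tf.function
def mix(fruit_a, fruit_b):
  global trace_count
  trace_count += 1
  return fruit_a.flavor + fruit_b.flavor

apple = Fruit(flavor=tf.constant([1, 2]))
mango = Fruit(flavor=tf.constant([3, 4]))

mixed = mix(apple, mango)
# New objects, but the same TypeSpec (int32 tensors of shape [2]).
mix(Fruit(flavor=tf.constant([5, 6])), Fruit(flavor=tf.constant([7, 8])))

print(mixed.numpy(), "traces:", trace_count)
```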

Otherwise, for direct control over when tf.function should retrace with regard to a particular Python type, you can implement the Tracing Protocol for it yourself.

@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
  return fruit_a.flavor + fruit_b.flavor

class Fruit:
  flavor = tf.constant([0, 0])

class Apple(Fruit):
  flavor = tf.constant([1, 2])

class Mango(Fruit):
  flavor = tf.constant([3, 4])

# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out of scope by then and been deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again

# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.

class FruitTraceType(tf.types.experimental.TraceType):
  def __init__(self, fruit):
    self.fruit_type = type(fruit)
    self.fruit_value = fruit

  def is_subtype_of(self, other):
      # True if self subtypes `other` and `other`'s type matches FruitTraceType.
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

  def most_specific_common_supertype(self, others):
      # `self` is the specific common supertype if all input types match it.
      return self if all(self == other for other in others) else None

  def placeholder_value(self, placeholder_context=None):
      # Use the fruit itself instead of the type for correct tracing.
      return self.fruit_value

  def __eq__(self, other):
    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type

  def __hash__(self):
    return hash(self.fruit_type)

class FruitWithTraceType:

  def __tf_tracing_type__(self, context):
    return FruitTraceType(self)

class AppleWithTraceType(FruitWithTraceType):
  flavor = tf.constant([1, 2])

class MangoWithTraceType(FruitWithTraceType):
  flavor = tf.constant([3, 4])

# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>

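Obtaining concrete functions

Every time a function is traced, a new concrete function is created. You can directly obtain a concrete function by using get_concrete_function.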

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on a TensorSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

Printing a ConcreteFunction displays a summary of its input arguments (with types) and its output type.

print(double_strings)
ConcreteFunction Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
  None

You can also directly retrieve a concrete function's signature.

print(double_strings.function_type)
(a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None)

Using a concrete trace with incompatible types will throw an error:

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
    bound_arguments = function_type.bind_with_defaults(
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
    with_default_args[arg_name] = constraint.cast(
TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl
    return self._call_with_structured_signature(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature
    function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (<tf.Tensor: shape=(), dtype=int32, numpy=1>,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

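You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.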

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
  b (POSITIONAL_OR_KEYWORD): Literal[2]
Output Type:
  TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
Captures:
  None
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
    bound_arguments = function_type.bind_with_defaults(
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
    with_default_args[arg_name] = constraint.cast(
ValueError: Can not cast 3 to Literal[2]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl
    return self._call_with_structured_signature(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature
    function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1183, in _call_impl
    return self._call_with_flat_signature(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1234, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).
Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.

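Obtaining graphs

Although retrieving the actual tf.Graph object is not something you will normally need to do, you can obtain it easily from any concrete function.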

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

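In reality, tf.Graphs are not directly callable. We actually use a tf.types.experimental.AtomicFunction to perform the computations described by the tf.Graph. You can access the AtomicFunction describing the traced tf.Graph and call it directly instead of the ConcreteFunction: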

atomic_fn = double_strings.inference_fn
atomic_fn(tf.constant("a"))
<tf.Tensor: shape=(), dtype=string, numpy=b'aa'>

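This has the advantage of lower Python overhead for high-performance scenarios. But it should only be used for forward inference (no gradient support), and captured tensor values (if any) would need to be explicitly supplied.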

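Debugging

In general, debugging code is easier in eager mode than inside tf.function. You should ensure that your code executes error-free in eager mode before decorating with tf.function. To assist in the debugging process, you can call tf.config.run_functions_eagerly(True) to globally disable tf.function, and tf.config.run_functions_eagerly(False) to reenable it.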
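
The effect of that switch can be sketched as follows. The function square and the python_runs counter are hypothetical names for illustration; the counter records how often the Python body of the function actually runs.

```python
import tensorflow as tf

python_runs = 0  # Counts executions of the Python body of `square`.

@tf.function
def square(x):
  global python_runs
  python_runs += 1
  return x * x

tf.config.run_functions_eagerly(True)
try:
  square(tf.constant(2.0))
  square(tf.constant(3.0))
  eager_runs = python_runs  # The Python body ran on every call.
finally:
  # The switch is global, so always restore graph execution afterwards.
  tf.config.run_functions_eagerly(False)

square(tf.constant(2.0))
square(tf.constant(3.0))
graph_runs = python_runs - eager_runs  # Only the first call traces.

print("Python runs while eager:", eager_runs)
print("Python runs with tf.function enabled:", graph_runs)
```

While the switch is on, breakpoints and plain print calls inside the function behave as they would in ordinary Python code.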

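When tracking down issues that only appear within tf.function, here are some tips:

  • Plain old Python print calls only execute during tracing, helping you track down when your function gets (re)traced.
  • tf.print calls will execute every time, and can help you track down intermediate values during execution.
  • tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Infs are created.
  • pdb (the Python debugger) can help you understand what is going on during tracing. (Caveat: pdb will drop you into AutoGraph-transformed source code.)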

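AutoGraph transformations

AutoGraph is a library that is on by default in tf.function, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like if, for, and while.

TensorFlow ops like tf.cond and tf.while_loop continue to work, but control flow is often easier to write and understand when written in Python.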

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.143583655 0.698347807 0.767881036 0.857545733 0.0981599092]
[0.142604977 0.603318036 0.645695686 0.694991 0.0978458375]
[0.141646087 0.539406419 0.568765223 0.601178706 0.0975347683]
[0.140706316 0.492538482 0.514451861 0.537887812 0.0972266495]
[0.139785022 0.456228882 0.473406643 0.491387457 0.0969214365]
[0.138881624 0.427005619 0.440947741 0.455316931 0.096619077]
[0.137995526 0.402815819 0.414429694 0.426259726 0.0963195339]
[0.137126192 0.38235572 0.392227381 0.402190775 0.0960227624]
[0.136273116 0.364751518 0.373278826 0.381821901 0.0957287177]
[0.135435775 0.349392414 0.356856316 0.364288628 0.0954373628]
[0.134613693 0.335836589 0.342441976 0.34898597 0.0951486528]
[0.133806422 0.323755413 0.329655707 0.335475951 0.094862543]
[0.133013532 0.312898636 0.318211377 0.323432535 0.0945790112]
[0.132234573 0.303071797 0.307888716 0.312607288 0.0942980051]
[0.13146916 0.294121206 0.298515141 0.302807152 0.0940194875]
[0.13071692 0.285923541 0.289953142 0.29387942 0.0937434211]
[0.12997745 0.278378516 0.282091677 0.285701483 0.0934697762]
[0.129250407 0.2714037 0.274839878 0.278173655 0.0931985155]
[0.12853545 0.264930487 0.268122554 0.271213919 0.0929296]
[0.127832219 0.258901358 0.261877 0.264754 0.092663005]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.12714042, 0.25326762, 0.2560503 , 0.25873667, 0.0923987 ],
      dtype=float32)>

If you are curious, you can inspect the code AutoGraph generates.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

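Conditionals

AutoGraph will convert some if <condition> statements into the equivalent tf.cond calls. This substitution is made if <condition> is a Tensor. Otherwise, the if statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; check out AutoGraph tracing effects for more information.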

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

See the reference documentation for additional restrictions on AutoGraph-converted if statements.
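
To contrast the two kinds of conditional described in this section, the sketch below traces one function with a Python bool condition and one with a Tensor condition, then looks for conditional ops in each traced graph. The function names and the conditional_ops helper are hypothetical, and matching on op type names containing "If" is an assumption about how tf.cond appears in the graph, not a documented contract.

```python
import tensorflow as tf

@tf.function
def python_branch(x, double_it):
  if double_it:  # Python bool: decided once, while tracing.
    return x * 2.0
  return x

@tf.function
def tensor_branch(x):
  if x > 0:  # Tensor condition: AutoGraph converts this to tf.cond.
    y = x * 2.0
  else:
    y = x
  return y

def conditional_ops(concrete_fn):
  # Op types in the traced graph that look like graph-level conditionals.
  return [op.type for op in concrete_fn.graph.get_operations()
          if "If" in op.type]

scalar = tf.TensorSpec([], tf.float32)
python_ifs = conditional_ops(python_branch.get_concrete_function(scalar, True))
tensor_ifs = conditional_ops(tensor_branch.get_concrete_function(scalar))

print("Conditional ops with a Python condition:", python_ifs)
print("Conditional ops with a Tensor condition:", tensor_ifs)
```

Only the Tensor version keeps both branches available at execution time, which is why it can return different results for positive and negative inputs from a single trace.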

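Loops

AutoGraph will convert some for and while statements into the equivalent TensorFlow looping ops, like tf.while_loop. If not converted, the for or while loop is executed as a Python loop.

This substitution is made in the following situations:

  • for x in y: if y is a Tensor, convert to tf.while_loop. In the special case where y is a tf.data.Dataset, a combination of tf.data.Dataset ops is generated.
  • while <condition>: if <condition> is a Tensor, convert to tf.while_loop.

A Python loop executes during tracing, adding additional ops to the tf.Graph for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated tf.Graph.

See the reference documentation for additional restrictions on AutoGraph-converted for and while statements.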
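
The difference between a Python loop and a TensorFlow loop shows up directly in the traced graph. In the sketch below, python_loop, tensor_loop, and the op_types helper are hypothetical names, and matching on op type names that start with "Add" or contain "While" is an assumption about how these ops are named.

```python
import tensorflow as tf

@tf.function
def python_loop(x):
  for _ in range(10):  # Python range: unrolled while tracing.
    x = x + 1
  return x

@tf.function
def tensor_loop(x):
  for _ in tf.range(10):  # Tensor range: converted to a graph loop.
    x = x + 1
  return x

def op_types(fn):
  # Op types of the top-level graph traced for a scalar int32 input.
  spec = tf.TensorSpec([], tf.int32)
  graph = fn.get_concrete_function(spec).graph
  return [op.type for op in graph.get_operations()]

python_ops = op_types(python_loop)
tensor_ops = op_types(tensor_loop)

python_adds = sum(t.startswith("Add") for t in python_ops)
python_has_while = any("While" in t for t in python_ops)
tensor_has_while = any("While" in t for t in tensor_ops)

print("Add ops in the unrolled Python loop:", python_adds)
print("While op present (Python loop, Tensor loop):",
      python_has_while, tensor_has_while)
```

Both functions return the same value; the Python loop simply pays for it with one copy of the loop body per iteration in the graph.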

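Looping over Python data

A common pitfall is to loop over Python/NumPy data within a tf.function. This loop will execute during the tracing process, adding a copy of your model to the tf.Graph for each iteration of the loop.

If you want to wrap the entire training loop in tf.function, the safest way to do this is to wrap your data as a tf.data.Dataset so that AutoGraph will dynamically unroll the training loop.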

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph

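When wrapping Python/NumPy data in a Dataset, be mindful of tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensor_slices. The former will keep the data in Python and fetch it via tf.py_function, which can have performance implications, whereas the latter will bundle a copy of the data as one large tf.constant() node in the graph, which can have memory implications.

Reading data from files via TFRecordDataset, CsvDataset, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the tf.data: Build TensorFlow input pipelines guide.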

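Accumulating values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use tf.TensorArray to accumulate results from a dynamically unrolled loop.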

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.36914825, 0.64223015, 0.7850807 , 0.19980955],
        [0.4491886 , 1.4083506 , 1.0351617 , 0.20833313],
        [0.7401295 , 1.7583194 , 1.1593785 , 0.32083678]],

       [[0.09818649, 0.09965849, 0.28532243, 0.2933966 ],
        [0.56936   , 1.0815177 , 0.7327199 , 0.6250684 ],
        [1.5300128 , 1.22948   , 0.8870441 , 0.770558  ]]], dtype=float32)>

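Limitations

tf.function has a few limitations by design that you should be aware of when converting a Python function to a tf.function.

Executing Python side effects

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a tf.function, sometimes executing twice or not at all. They only happen the first time you call a tf.function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code.

The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code will be executed by the TensorFlow runtime with each call.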

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

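If you would like to execute Python code during each invocation of a tf.function, tf.py_function is an exit hatch. The drawbacks of tf.py_function are that it is not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph, it casts all inputs/outputs to tensors.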

@tf.py_function(Tout=tf.float32)
def py_plus(x, y):
  print('Executing eagerly.')
  return x + y

@tf.function
def tf_wrapper(x, y):
  print('Tracing.')
  return py_plus(x, y)

The tf.function will trace the first time:

tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Tracing.
Executing eagerly.
3.0

But the tf.py_function inside executes eagerly every time:

tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Executing eagerly.
3.0

Changing Python global and free variables

Changing Python global and free variables counts as a Python side effect, so it only happens during tracing.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

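Sometimes unexpected behaviors are very hard to notice. In the example below, the counter is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the tf.function is used, the assign_add will be recorded unconditionally in the underlying graph. Therefore v will increase by 1 every time the tf.function is called. This issue is common among users who try to migrate their graph-mode TensorFlow code to TensorFlow 2 using tf.function decorators, when Python side effects (the counter in the example) are used to determine what ops to run (assign_add in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (for example, if the guarded operation is very costly).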

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

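A workaround to achieve the expected behavior is using tf.init_scope to lift the operations outside of the function graph. This ensures that the variable increment is only done once, during tracing time. It should be noted that init_scope has other side effects, including cleared control flow and gradient tape. Sometimes the usage of init_scope can become too complex to manage realistically.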

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

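In summary, as a rule of thumb, avoid mutating Python objects, such as integers or containers like lists, that live outside the tf.function. Instead, use arguments and TF objects. For example, the section "Accumulating values in a loop" shows one way to implement list-like operations.

In some cases you can capture and manipulate state if it is a tf.Variable. This is how the weights of Keras models are updated by repeated calls to the same ConcreteFunction.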
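
As a minimal sketch of the advice above, the guard itself can be stored in a tf.Variable, so that the condition is evaluated by the graph on every call instead of once at trace time. The class name GuardedModel and the variable-based counter are assumptions made for illustration; they are not part of the original example.

```python
import tensorflow as tf

class GuardedModel(tf.Module):  # hypothetical name, for illustration only
  def __init__(self):
    self.v = tf.Variable(0)
    # The guard is a tf.Variable, so its value is read at run time.
    self.counter = tf.Variable(0)

  @tf.function
  def __call__(self):
    # The condition is a tensor, so AutoGraph lowers this `if` to a graph conditional.
    if tf.equal(self.counter, 0):
      self.counter.assign_add(1)
      self.v.assign_add(1)
    return self.v.read_value()

m = GuardedModel()
values = [int(m().numpy()) for _ in range(3)]
print(values)  # [1, 1, 1]
```

Unlike the tf.init_scope workaround, this version keeps the guard inside the graph, at the cost of one extra variable.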

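Using Python iterators and generators

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing.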

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

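Just as TensorFlow has a specialized tf.TensorArray for list constructs, it has a specialized tf.data.Iterator for iteration constructs. See the section on AutoGraph transformations for an overview. The tf.data API can also help implement generator patterns: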

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3
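
When the whole dataset should be consumed in one call, a for loop over a tf.data.Dataset inside the tf.function is another option. The sketch below is illustrative only; the function name sum_dataset is an assumption and does not come from the guide.

```python
import tensorflow as tf

@tf.function
def sum_dataset(ds):  # hypothetical helper name
  total = tf.constant(0)
  # AutoGraph converts a `for` loop over a tf.data.Dataset into graph ops,
  # so the iteration state is tracked by TensorFlow rather than by Python.
  for value in ds:
    total += value
  return total

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
total = sum_dataset(ds)
print(total.numpy())  # 6
```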

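All outputs of a tf.function must be return values

With the exception of tf.Variables, a tf.function must return all its outputs. Attempting to access any tensor from a function directly, without going through return values, causes "leaks".

For example, the function below "leaks" the tensor a through the Python global x: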

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'SymbolicTensor' object has no attribute 'numpy'

This is true even if the leaked value is also returned:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'SymbolicTensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://tensorflow.dev.org.tw/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'add:0' shape=() dtype=int32> was defined here:
    File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 1077, in launch_instance
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 195, in start
    File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 529, in dispatch_queue
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 518, in process_one
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 424, in dispatch_shell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 766, in execute_request
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 7, in <module>
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 832, in __call__
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 888, in _call
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 695, in _initialize
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 598, in wrapped_fn
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
    File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 4, in leaky_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1478, in binary_op_wrapper
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1871, in _add_dispatch
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 490, in add_v2
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2652, in _create_op_internal
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1160, in from_node_def

The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=140399462105440), which is out of scope.

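Usually, leaks such as these occur when you use Python statements or data structures. In addition to leaking inaccessible tensors, such statements are also likely wrong because they count as Python side effects and are not guaranteed to execute at every function call.

Common ways to leak local tensors also include mutating an external Python collection or object: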

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

@tf.function
def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor
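
A leak-free counterpart, sketched under the assumption that the caller only needs the tensor value, is to return the tensor and mutate the Python collection outside of the tf.function. The names non_leaky_function and collected are hypothetical.

```python
import tensorflow as tf

@tf.function
def non_leaky_function():  # hypothetical name
  a = tf.constant(1)
  return a  # Good - the tensor leaves the function as a return value

collected = []
# The list is mutated in eager code, so it holds a concrete tensor.
collected.append(non_leaky_function())
print(collected[0].numpy())  # 1
```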

Recursive tf.functions are not supported

Recursive tf.functions are not supported and could cause infinite loops. For example,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    [Previous 2 lines repeated 117 more times]
    File "/usr/lib/python3.9/abc.py", line 119, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

Even if a recursive tf.function seems to work, the Python function is traced multiple times, which could affect performance. For example,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
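
Recursion can often be rewritten as a loop, which AutoGraph can convert to graph control flow when the loop condition depends on a tensor. The sketch below is an illustration under that assumption, using a hypothetical iterative_factorial function rather than the recursive_fn above.

```python
import tensorflow as tf

@tf.function
def iterative_factorial(n):  # hypothetical example function
  result = tf.ones_like(n)
  # `n` is a tensor, so AutoGraph converts this loop into a graph-level while loop
  # instead of unrolling or recursing in Python.
  while n > 1:
    result *= n
    n -= 1
  return result

out = iterative_factorial(tf.constant(5))
print(out.numpy())  # 120
```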

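Known Issues

If your tf.function is not evaluating correctly, the error may be explained by one of these known issues, which are planned to be fixed in the future.

Depending on Python global and free variables

tf.function creates a new ConcreteFunction when called with a new value of a Python argument. However, it does not do that for the Python closure, globals, or nonlocals of that tf.function. If their value changes between calls to the tf.function, the tf.function still uses the values they had when it was traced. This is different from how regular Python functions work.

For that reason, you should follow a functional programming style that uses arguments instead of closing over outer names.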

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Another way to update a global value is to make it a tf.Variable and use the Variable.assign method instead.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)
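
The retracing rule stated at the start of this section can be observed with a small sketch. It counts traces by using a Python side effect, which runs only during tracing; the trace_log list and the add_one function are hypothetical names introduced for illustration.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper list, used only to count traces

@tf.function
def add_one(x):
  trace_log.append('traced')  # Python side effect: runs only while tracing
  return x + 1

add_one(1)
add_one(2)               # new Python value -> new trace
add_one(tf.constant(1))
add_one(tf.constant(2))  # same dtype and shape -> the trace is reused
print(len(trace_log))  # 3
```

Passing tensors instead of Python numbers is therefore one way to avoid a retrace per value.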

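Depending on Python objects

Passing custom Python objects as arguments to tf.function is supported but has certain limitations.

For maximum feature coverage, consider transforming the objects into Extension types before passing them to tf.function. You can also use Python primitives and tf.nest-compatible structures.

However, as covered in the rules of tracing, when the custom Python class does not provide a custom TraceType, tf.function is forced to use instance-based equality. This means it will not create a new trace when you pass the same object with modified attributes.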

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

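Using the same tf.function to evaluate the modified instance of the model is buggy, since the instance still has the same instance-based TraceType as the original model.

For that reason, it is recommended that you write your tf.function to avoid depending on mutable object attributes, or implement the Tracing Protocol for the objects to inform tf.function about such attributes.

If that is not possible, one workaround is to make a new tf.function each time you modify your object, to force retracing: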

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`. `tf.function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new `tf.function` and `ConcreteFunction` since you modified `new_model`.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

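Because retracing can be expensive, you can use tf.Variables as object attributes. They can be mutated in place (but not replaced with a new object, careful!) for a similar effect without needing a retrace.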

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

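Creating tf.Variables

tf.function only supports singleton tf.Variables that are created once on the first call and reused across subsequent function calls. The code snippet below would create a new tf.Variable in every function call, which results in a ValueError exception.

Example: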

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_11117/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://tensorflow.dev.org.tw/guide/function#creating_tfvariables for more information.

A common pattern used to work around this limitation is to start with a Python None value, then conditionally create the tf.Variable if the value is None:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
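
The error message also suggests creating the variable outside the tf.function. A minimal sketch of that alternative, assuming the initial value is known up front, creates the tf.Variable eagerly in the constructor. The class name EagerCount is hypothetical.

```python
import tensorflow as tf

class EagerCount(tf.Module):  # hypothetical name
  def __init__(self):
    # Created once, eagerly, outside of any tf.function.
    self.count = tf.Variable(0)

  @tf.function
  def __call__(self):
    # Only reads and updates the existing variable; nothing is created here.
    return self.count.assign_add(1)

c = EagerCount()
first = c()
second = c()
print(first.numpy(), second.numpy())  # 1 2
```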

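Using with multiple Keras optimizers

You may encounter ValueError: tf.function only supports singleton tf.Variables created on the first call. when using more than one Keras optimizer with a tf.function. This error occurs because optimizers internally create tf.Variables when they apply gradients for the first time.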

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_11117/950644149.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_11117/950644149.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1223, in apply_gradients  **
        return super().apply_gradients(grads_and_vars, name=name)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 638, in apply_gradients
        self.build(trainable_variables)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py", line 145, in build
        self.add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1125, in add_variable_from_reference
        return super().add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 513, in add_variable_from_reference
        variable = tf.Variable(

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://tensorflow.dev.org.tw/guide/function#creating_tfvariables for more information.

If you need to change a stateful object between calls, it is simplest to define a tf.Module subclass and create instances to hold those objects:

class TrainStep(tf.Module):
  def __init__(self, optimizer):
    self.optimizer = optimizer

  @tf.function
  def __call__(self, w, x, y):
    with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
    gradients = tape.gradient(L, [w])
    self.optimizer.apply_gradients(zip(gradients, [w]))


opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

train_o1 = TrainStep(opt1)
train_o2 = TrainStep(opt2)

train_o1(w, x, y)
train_o2(w, x, y)

You could also do this manually by creating multiple instances of the @tf.function wrapper, one for each optimizer:

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new tf.function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y, opt1)
  else:
    train_step_2(w, x, y, opt2)

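Using with multiple Keras models

You may also encounter ValueError: tf.function only supports singleton tf.Variables created on the first call. when passing different model instances to the same tf.function.

This error occurs because Keras models (which do not have their input shape defined) and Keras layers create tf.Variables when they are first called. You may be attempting to initialize those variables inside a tf.function that has already been called. To avoid this error, try calling model.build(input_shape) to initialize all the weights before training the model.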
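
The following sketch illustrates that suggestion under simple assumptions: two small Sequential models, each built with an explicit input shape before being passed to the same tf.function. The helper make_model, the function predict, and the layer sizes are hypothetical choices made for illustration.

```python
import tensorflow as tf

@tf.function
def predict(model, x):  # hypothetical shared tf.function
  return model(x)

def make_model():  # hypothetical helper for this sketch
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  # Create the weights eagerly, before the model reaches the tf.function.
  model.build(input_shape=(None, 3))
  return model

model_a = make_model()
model_b = make_model()
x = tf.ones([2, 3])
out_a = predict(model_a, x)
out_b = predict(model_b, x)  # a second instance; its weights already exist
print(out_a.shape, out_b.shape)
```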

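Further reading

To learn how to export and load a tf.function, see the SavedModel guide. To learn more about the graph optimizations that are performed after tracing, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.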