In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.
You can use tf.function to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use SavedModel.
This guide will help you conceptualize how tf.function works under the hood, so you can use it effectively.
The main takeaways and recommendations are:
- Debug in eager mode, then decorate with @tf.function.
- Don't rely on Python side effects like object mutation or list appends.
- tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.
Setup
import tensorflow as tf
2023-11-28 02:22:39.038158: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-28 02:22:39.038208: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-28 02:22:39.039647: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Define a helper function to demonstrate the kinds of errors you might encounter:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
Basics
Usage
A tf.function that you define (for example by applying the @tf.function decorator) is just like a core TensorFlow operation: you can execute it eagerly, compute gradients, and so on.
@tf.function # The decorator converts `add` into a `PolymorphicFunction`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
You can use tf.functions inside other tf.functions.
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
tf.function can be faster than eager code for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup, as the following example and the small-ops sketch after it show.
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.005143817999964995 Function conv: 0.005329717999984496 Note how there's not much difference in performance for convolutions
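To see the flip side of that observation, here is a minimal benchmark sketch (the loop body, iteration count, and function names are illustrative assumptions, and actual timings will vary by machine) of a graph made of many small ops, where eager per-op Python dispatch overhead tends to dominate:
def many_small_ops(x):
  for _ in range(100):  # many cheap ops, dominated by Python dispatch overhead in eager mode
    x = tf.matmul(x, x) * 0.5
  return x
many_small_ops_fn = tf.function(many_small_ops)
x = tf.eye(4)
# Warm up
many_small_ops(x); many_small_ops_fn(x)
print("Eager small ops:", timeit.timeit(lambda: many_small_ops(x), number=100))
print("Function small ops:", timeit.timeit(lambda: many_small_ops_fn(x), number=100))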
Tracing
This section explains how tf.function works under the hood, including implementation details which may change in the future. However, once you understand why and when tracing happens, it's much easier to use tf.function effectively!
What is "tracing"?
A tf.function runs your program in a TensorFlow Graph. However, a tf.Graph cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but tf.Graph requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a tf.Graph.
tf.function bridges this gap by separating your code into two stages:
1) In the first stage, referred to as "tracing", tf.function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.
2) In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.
Depending on its inputs, tf.function will not always run the first stage when it is called. See "Rules of tracing" below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.
When tf.function does decide to trace, the tracing stage is immediately followed by the second stage, so calling the tf.function both creates and runs the tf.Graph. Later you will see how you can run only the tracing stage with get_concrete_function.
When you pass arguments of different types into a tf.function, both stages are run:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
Note that if you repeatedly call a tf.function with the same argument types, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical.
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
You can use pretty_printed_concrete_signatures() to see all of the available traces:
print(double.pretty_printed_concrete_signatures())
Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None) Output Type: TensorSpec(shape=(), dtype=tf.int32, name=None) Captures: None Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None) Output Type: TensorSpec(shape=(), dtype=tf.float32, name=None) Captures: None Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None) Output Type: TensorSpec(shape=(), dtype=tf.string, name=None) Captures: None
So far, you've seen that tf.function creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:
- A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.
- Tracing is the process through which new tf.Graphs are generated from Python code.
- An instance of tf.Graph is specialized to the specific input types it was traced with. Differing types require retracing.
- Each traced tf.Graph has a corresponding ConcreteFunction.
- A tf.function manages a cache of ConcreteFunctions and picks the right one for your inputs.
- tf.function wraps the Python function that will be traced, returning a tf.types.experimental.PolymorphicFunction object.
Rules of tracing
When called, a tf.function first evaluates the type of each input argument using the tf.types.experimental.TraceType of each argument. This is used to construct a tf.types.experimental.FunctionType describing the signature of the desired ConcreteFunction. This FunctionType is compared to the FunctionTypes of existing ConcreteFunctions. If a matching ConcreteFunction is found, the call is dispatched to it. If no match is found, a new ConcreteFunction is traced for the desired FunctionType.
If multiple matches are found, the most specific signature is chosen. Matching is done by subtyping, much like normal function calls in C++ or Java, for instance. For example, TensorShape([1, 2]) is a subtype of TensorShape([None, None]), so a call to the tf.function with TensorShape([1, 2]) can be dispatched to the ConcreteFunction produced with TensorShape([None, None]), but if a ConcreteFunction with TensorShape([1, None]) also exists, it will be prioritized since it is more specific, as the sketch below illustrates.
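A minimal sketch of that dispatch rule (the function name is illustrative): two signatures are pre-traced with get_concrete_function, and a call matching both should dispatch to the more specific one without a new trace:
@tf.function
def add_one(x):
  print("Tracing with", x.shape)
  return x + 1
# Trace a generic signature and a more specific one.
add_one.get_concrete_function(tf.TensorSpec([None, None], tf.float32))
add_one.get_concrete_function(tf.TensorSpec([1, None], tf.float32))
# A [1, 2] input is a subtype of both signatures; the more specific
# [1, None] trace should be chosen, so nothing new is printed here.
add_one(tf.ones([1, 2]))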
The TraceType is determined from input arguments as follows:
- For Tensor, the type is parameterized by the Tensor's dtype and shape; ranked shapes are a subtype of unranked shapes; fixed dimensions are a subtype of unknown dimensions.
- For Variable, the type is similar to Tensor, but also includes a unique resource ID of the variable, necessary to correctly wire control dependencies.
- For Python primitive values, the type corresponds to the value itself. For example, the TraceType of the value 3 is LiteralTraceType<3>, not int (see the sketch after this list).
- For Python ordered containers such as list and tuple, the type is parameterized by the types of their elements; for example, the type of [1, 2] is ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>> and the type of [2, 1] is ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>, which is different.
- For Python mappings such as dict, the type is also a mapping from the same keys, but to the types of the values instead of the actual values. For example, the type of {1: 2, 3: 4} is MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>. However, unlike ordered containers, {1: 2, 3: 4} and {3: 4, 1: 2} have equivalent types.
- For Python objects which implement the __tf_tracing_type__ method, the type is whatever that method returns.
- For any other Python objects, the type is a generic TraceType, and the matching procedure is:
  - First, it checks if the object is the same object used in the previous trace (using Python id() or is). Note that this will still match if the object has changed, so if you use Python objects as tf.function arguments it is best to use immutable ones.
  - Next, it checks if the object is equal to the object used in the previous trace (using Python ==).
  Note that this procedure only keeps a weakref to the object and hence only works as long as the object is in scope and not deleted.
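A minimal sketch of the literal rule for primitives (the function name is illustrative): because Python primitives trace by value, each distinct literal produces its own ConcreteFunction:
@tf.function
def scale(x, factor):
  print("Tracing with factor =", factor)
  return x * factor
x = tf.ones([2])
scale(x, 2)  # traces with LiteralTraceType<2>
scale(x, 3)  # retraces: the literal 3 is a different type
scale(x, 3)  # reuses the trace for 3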
Controlling retracing
Retracing, which is when your tf.function creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your tf.function retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use tf.function.
To control the tracing behavior, you can use the following techniques:
Pass a fixed input_signature to tf.function
This forces tf.function to constrain itself to only one tf.types.experimental.FunctionType composed of the types enumerated by the input_signature. Calls that cannot be dispatched to this FunctionType will throw an error.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(TypeError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(TypeError):
next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'TypeError'>: Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3657259638.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2, 2), dtype=int32, numpy= array([[1, 2], [3, 4]], dtype=int32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3657259638.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Use unknown dimensions for flexibility
Since TensorFlow matches tensors based on their shape, using a None dimension as a wildcard will allow tf.functions to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch. You can check out the Transformer and Deep Dream tutorials for examples.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
Use reduce_retracing for automatic flexibility
When reduce_retracing is enabled, tf.function automatically identifies supertypes of the input types it is observing and chooses to trace more generalized graphs automatically. It is less efficient than setting the input_signature directly, but useful when many types need to be supported.
@tf.function(reduce_retracing=True)
def g(x):
print('Tracing with', x)
return x
# Traces once.
print(g(tf.constant([1, 2, 3])))
# Traces again, but more generalized this time.
print(g(tf.constant([1, 2, 3, 4, 5])))
# No more tracing!
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7])))
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])))
Tracing with Tensor("x:0", shape=(3,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32) tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32) tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)
Pass tensors instead of Python literals
Often, Python arguments are used to control hyperparameters and graph constructions, for example, num_layers=10 or training=True or nonlinearity='relu'. So, if the Python argument changes, it makes sense that you'd have to retrace the graph.
However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = ", num_steps)
tf.print("Executing with num_steps = ", num_steps)
for _ in tf.range(num_steps):
train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
If you need to force retracing, create a new tf.function. Separate tf.function objects are guaranteed not to share traces.
def f():
print('Tracing!')
tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
Use the tracing protocol
Where possible, you should prefer converting the Python type into a tf.experimental.ExtensionType instead. Moreover, the TraceType of an ExtensionType is the tf.TypeSpec associated with it. Therefore, if needed, you can simply override the default tf.TypeSpec to take control of an ExtensionType's Tracing Protocol. Refer to the "Customizing the ExtensionType's TypeSpec" section in the Extension types guide for details.
Otherwise, for direct control over when tf.function should retrace in regards to a particular Python type, you can implement the Tracing Protocol for it yourself.
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
return fruit_a.flavor + fruit_b.flavor
class Fruit:
flavor = tf.constant([0, 0])
class Apple(Fruit):
flavor = tf.constant([1, 2])
class Mango(Fruit):
flavor = tf.constant([3, 4])
# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out of scope by then and been deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.
class FruitTraceType(tf.types.experimental.TraceType):
def __init__(self, fruit):
self.fruit_type = type(fruit)
self.fruit_value = fruit
def is_subtype_of(self, other):
# True if self subtypes `other` and `other`'s type matches FruitTraceType.
return (type(other) is FruitTraceType and
self.fruit_type is other.fruit_type)
def most_specific_common_supertype(self, others):
# `self` is the specific common supertype if all input types match it.
return self if all(self == other for other in others) else None
def placeholder_value(self, placeholder_context=None):
# Use the fruit itself instead of the type for correct tracing.
return self.fruit_value
def __eq__(self, other):
return type(other) is FruitTraceType and self.fruit_type == other.fruit_type
def __hash__(self):
return hash(self.fruit_type)
class FruitWithTraceType:
def __tf_tracing_type__(self, context):
return FruitTraceType(self)
class AppleWithTraceType(FruitWithTraceType):
flavor = tf.constant([1, 2])
class MangoWithTraceType(FruitWithTraceType):
flavor = tf.constant([3, 4])
# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>
Obtaining concrete functions
Every time a function is traced, a new concrete function is created. You can directly obtain a concrete function by using get_concrete_function.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
Printing a ConcreteFunction displays a summary of its input arguments (with types) and its output type.
print(double_strings)
ConcreteFunction Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None) Output Type: TensorSpec(shape=(), dtype=tf.string, name=None) Captures: None
You can also directly retrieve a concrete function's signature.
print(double_strings.function_type)
(a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None)
Using a concrete trace with incompatible types will throw an error:
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs bound_arguments = function_type.bind_with_defaults( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults with_default_args[arg_name] = constraint.cast( TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None) The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl return self._call_with_structured_signature(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature function_type_utils.canonicalize_function_inputs( TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (<tf.Tensor: shape=(), dtype=int32, numpy=1>,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None). During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]
You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.
@tf.function
def pow(a, b):
return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=<unknown>, dtype=tf.float32, name=None) b (POSITIONAL_OR_KEYWORD): Literal[2] Output Type: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None) Captures: None
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs bound_arguments = function_type.bind_with_defaults( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults with_default_args[arg_name] = constraint.cast( ValueError: Can not cast 3 to Literal[2] The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl return self._call_with_structured_signature(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature function_type_utils.canonicalize_function_inputs( TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None). During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1183, in _call_impl return self._call_with_flat_signature(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1234, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None). Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.
Obtaining graphs
Although retrieving the actual tf.Graph object is not something you'll normally need to do, you can obtain it easily from any concrete function.
graph = double_strings.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')
[] -> a ['a', 'a'] -> add ['add'] -> Identity
In reality, tf.Graphs are not directly callable. TensorFlow actually uses a tf.types.experimental.AtomicFunction to perform the computations described by the tf.Graph. You can access the AtomicFunction describing the traced tf.Graph and call it directly instead of the ConcreteFunction:
atomic_fn = double_strings.inference_fn
atomic_fn(tf.constant("a"))
<tf.Tensor: shape=(), dtype=string, numpy=b'aa'>
This has the advantage of lower Python overhead in high-performance scenarios. But it should only be used for forward inference (no gradient support), and any captured tensor values need to be supplied explicitly.
Debugging
In general, debugging code is easier in eager mode than inside tf.function. You should ensure that your code executes error-free in eager mode before decorating with tf.function. To assist in the debugging process, you can call tf.config.run_functions_eagerly(True) to globally disable and reenable tf.function; a short sketch of this toggle follows the tips below.
When tracking down issues that only appear within tf.function, here are some tips:
- Plain old Python print calls only execute during tracing, helping you track down when your function gets (re)traced.
- tf.print calls will execute every time, and can help you track down intermediate values during execution.
- tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Inf are created.
- pdb (the Python debugger) can help you understand what's going on during tracing. (Caveat: pdb will drop you into AutoGraph-transformed source code.)
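A minimal sketch of the run_functions_eagerly toggle (the function name is illustrative): with functions running eagerly, ordinary Python debugging works line by line:
@tf.function
def buggy_square(x):
  # With run_functions_eagerly(True), this print (or a pdb breakpoint)
  # executes on every call, not just during tracing.
  print("Running eagerly with", x)
  return x * x
tf.config.run_functions_eagerly(True)
print(buggy_square(tf.constant(3.0)))
print(buggy_square(tf.constant(3.0)))  # the print fires again: no graph is reused
tf.config.run_functions_eagerly(False)  # re-enable graph execution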
AutoGraph transformations
AutoGraph is a library that is on by default in tf.function, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like if, for, and while.
TensorFlow ops like tf.cond and tf.while_loop continue to work, but control flow is often easier to write and understand when written in Python.
# A simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.143583655 0.698347807 0.767881036 0.857545733 0.0981599092] [0.142604977 0.603318036 0.645695686 0.694991 0.0978458375] [0.141646087 0.539406419 0.568765223 0.601178706 0.0975347683] [0.140706316 0.492538482 0.514451861 0.537887812 0.0972266495] [0.139785022 0.456228882 0.473406643 0.491387457 0.0969214365] [0.138881624 0.427005619 0.440947741 0.455316931 0.096619077] [0.137995526 0.402815819 0.414429694 0.426259726 0.0963195339] [0.137126192 0.38235572 0.392227381 0.402190775 0.0960227624] [0.136273116 0.364751518 0.373278826 0.381821901 0.0957287177] [0.135435775 0.349392414 0.356856316 0.364288628 0.0954373628] [0.134613693 0.335836589 0.342441976 0.34898597 0.0951486528] [0.133806422 0.323755413 0.329655707 0.335475951 0.094862543] [0.133013532 0.312898636 0.318211377 0.323432535 0.0945790112] [0.132234573 0.303071797 0.307888716 0.312607288 0.0942980051] [0.13146916 0.294121206 0.298515141 0.302807152 0.0940194875] [0.13071692 0.285923541 0.289953142 0.29387942 0.0937434211] [0.12997745 0.278378516 0.282091677 0.285701483 0.0934697762] [0.129250407 0.2714037 0.274839878 0.278173655 0.0931985155] [0.12853545 0.264930487 0.268122554 0.271213919 0.0929296] [0.127832219 0.258901358 0.261877 0.264754 0.092663005] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.12714042, 0.25326762, 0.2560503 , 0.25873667, 0.0923987 ], dtype=float32)>
If you're curious, you can inspect the code AutoGraph generates.
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1 ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
Conditionals
AutoGraph will convert some if <condition> statements into the equivalent tf.cond calls. This substitution is made if <condition> is a Tensor. Otherwise, the if statement is executed as a Python conditional.
A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph (a sketch of this case follows the fizzbuzz example below). Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.
tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; check out AutoGraph tracing effects for more information.
@tf.function
def fizzbuzz(n):
for i in tf.range(1, n + 1):
print('Tracing for loop')
if i % 15 == 0:
print('Tracing fizzbuzz branch')
tf.print('fizzbuzz')
elif i % 3 == 0:
print('Tracing fizz branch')
tf.print('fizz')
elif i % 5 == 0:
print('Tracing buzz branch')
tf.print('buzz')
else:
print('Tracing default branch')
tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
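For contrast, here is a minimal sketch (names are illustrative) of a Python conditional, where the branch is resolved at trace time and only the taken branch is recorded in the graph:
@tf.function
def maybe_square(x, apply_square):
  # `apply_square` is a Python bool, so this `if` runs during tracing and
  # only one branch ends up in the traced graph.
  if apply_square:
    return x * x
  return x
print(maybe_square(tf.constant(3.0), True))   # traces a graph containing x * x
print(maybe_square(tf.constant(3.0), False))  # a different literal: a separate trace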
See the reference documentation for additional restrictions on AutoGraph-converted if statements.
Loops
AutoGraph will convert some for and while statements into the equivalent TensorFlow loop ops, like tf.while_loop. If not converted, the for or while loop is executed as a Python loop.
This substitution is made in the following situations:
- for x in y: if y is a Tensor, convert to tf.while_loop. In the special case where y is a tf.data.Dataset, a combination of tf.data.Dataset ops is generated.
- while <condition>: if <condition> is a Tensor, convert to tf.while_loop.
A Python loop executes during tracing, adding additional ops to the tf.Graph for every iteration of the loop.
A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated tf.Graph, as the sketch below illustrates.
See the reference documentation for additional restrictions on AutoGraph-converted for and while statements.
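A minimal sketch of the difference (function names are illustrative), comparing graph sizes for a Python loop and a tensor loop:
@tf.function
def unrolled():
  x = tf.constant(0)
  for _ in range(10):     # Python loop: unrolled at trace time, ten adds in the graph
    x += 1
  return x
@tf.function
def dynamic():
  x = tf.constant(0)
  for _ in tf.range(10):  # tensor loop: captured once as a tf.while_loop
    x += 1
  return x
# The unrolled graph grows with the iteration count; the dynamic one does not.
print(len(unrolled.get_concrete_function().graph.as_graph_def().node))
print(len(dynamic.get_concrete_function().graph.as_graph_def().node))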
Looping over Python data
A common pitfall is to loop over Python/NumPy data within a tf.function. This loop will execute during the tracing process, adding a copy of your model to the tf.Graph for each iteration of the loop.
If you want to wrap the entire training loop in tf.function, the safest way to do this is to wrap your data as a tf.data.Dataset so that AutoGraph will dynamically unroll the training loop.
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
When wrapping Python/NumPy data in a Dataset, be mindful of tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensor_slices. The former will keep the data in Python and fetch it via tf.py_function, which can have performance implications, whereas the latter will bundle a copy of the data as one large tf.constant() node in the graph, which can have memory implications (a short sketch of the contrast follows).
Reading data from files via TFRecordDataset, CsvDataset, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of the data, without having to involve Python. To learn more, see the tf.data: Build TensorFlow input pipelines guide.
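A minimal sketch of the two wrapping options (the data and output signature are illustrative assumptions):
data = [1, 2, 3]
# Copies the data into the graph as constant nodes: cheap at runtime,
# but the whole dataset lives in graph memory.
ds_slices = tf.data.Dataset.from_tensor_slices(data)
# Keeps the data in Python and fetches elements via tf.py_function:
# no big constant, but every element pays Python-call overhead.
ds_generator = tf.data.Dataset.from_generator(
    lambda: iter(data),
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))
for a, b in zip(ds_slices, ds_generator):
  print(a.numpy(), b.numpy())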
Accumulating values in a loop
A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use tf.TensorArray to accumulate results from a dynamically unrolled loop.
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
max_seq_len = input_data.shape[0]
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.36914825, 0.64223015, 0.7850807 , 0.19980955], [0.4491886 , 1.4083506 , 1.0351617 , 0.20833313], [0.7401295 , 1.7583194 , 1.1593785 , 0.32083678]], [[0.09818649, 0.09965849, 0.28532243, 0.2933966 ], [0.56936 , 1.0815177 , 0.7327199 , 0.6250684 ], [1.5300128 , 1.22948 , 0.8870441 , 0.770558 ]]], dtype=float32)>
Limitations
tf.function has a few limitations by design that you should be aware of when converting a Python function to a tf.function.
Executing Python side effects
Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a tf.function, sometimes executing twice or not at all. They only happen the first time you call a tf.function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code.
The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code will be executed by the TensorFlow runtime with each call.
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
If you would like to execute Python code during each invocation of a tf.function, tf.py_function is an exit hatch. The drawbacks of tf.py_function are that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph, it casts all inputs/outputs to tensors.
@tf.py_function(Tout=tf.float32)
def py_plus(x, y):
print('Executing eagerly.')
return x + y
@tf.function
def tf_wrapper(x, y):
print('Tracing.')
return py_plus(x, y)
The tf.function will trace the first time it is called:
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Tracing. Executing eagerly. 3.0
But the tf.py_function inside it will execute eagerly every time:
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Executing eagerly. 3.0
Changing Python global and free variables
Changing Python global and free variables counts as a Python side effect, so it only happens during tracing.
external_list = []
@tf.function
def side_effect(x):
print('Python side effect')
external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
Sometimes unexpected behaviors are very hard to notice. In the example below, counter is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the tf.function is used, the assign_add will be recorded unconditionally in the underlying graph. Therefore v will increase by 1 every time the tf.function is called. This issue is common among users that try to migrate their graph-mode TensorFlow code to TensorFlow 2 using tf.function decorators, when Python side effects (the counter in the example) are used to determine what ops to run (assign_add in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (e.g. if the guarded operation is very costly).
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# A python side-effect
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 2, 3
1 2 3
A workaround to achieve the expected behavior is to use tf.init_scope to lift the operations outside of the function graph. This ensures that the variable increment is only done once, during tracing. It should be noted that init_scope has other side effects, including cleared control flow and gradient tape. Sometimes the usage of init_scope can become too complex to manage realistically.
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# Lifts ops out of function-building graphs
with tf.init_scope():
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 1, 1
1 1 1
In summary, as a rule of thumb, you should avoid mutating Python objects such as integers or containers like lists that live outside the tf.function. Instead, use arguments and TF objects. For example, the section "Accumulating values in a loop" has one example of how list-like operations can be implemented.
You can, in some cases, capture and manipulate state if it is a tf.Variable. This is how the weights of Keras models are updated with repeated calls to the same ConcreteFunction, as the sketch below shows.
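A minimal sketch of that pattern (names are illustrative): a tf.Variable captured by a trace is read by reference, so assigning to it changes what the same trace computes:
v = tf.Variable(1.0)
@tf.function
def scale_by_v(x):
  # `v` is captured by reference, not baked into the graph as a constant.
  return v * x
print(scale_by_v(tf.constant(2.0)))  # 2.0
v.assign(3.0)                        # mutate the captured state; no retracing
print(scale_by_v(tf.constant(2.0)))  # 6.0, from the same trace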
Using Python iterators and generators
Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing.
@tf.function
def buggy_consume_next(iterator):
tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
Just like how TensorFlow has the specialized tf.TensorArray for list constructs, it has a specialized tf.data.Iterator for iteration constructs. See the section on AutoGraph transformations for an overview. Also, the tf.data API can help implement generator patterns:
@tf.function
def good_consume_next(iterator):
# This is ok, iterator is a tf.data.Iterator
tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
All outputs of a tf.function must be return values
With the exception of tf.Variables, a tf.function must return all its outputs. Attempting to directly access any tensors from a function without going through return values causes "leaks".
For example, the function below "leaks" the tensor a through the Python global x:
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'SymbolicTensor' object has no attribute 'numpy'
This is the case even if the leaked value is also returned:
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'SymbolicTensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://tensorflow.dev.org.tw/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information. <tf.Tensor 'add:0' shape=() dtype=int32> was defined here: File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main File "/usr/lib/python3.9/runpy.py", line 87, in _run_code File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module> File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 1077, in launch_instance File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 195, in start File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 529, in dispatch_queue File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 518, in process_one File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 424, in dispatch_shell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 766, in execute_request File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 429, in do_execute File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 7, in <module> File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 832, in __call__ File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 888, in _call File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 695, in _initialize File 
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 598, in wrapped_fn File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 4, in leaky_function File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1478, in binary_op_wrapper File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1871, in _add_dispatch File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 490, in add_v2 File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2652, in _create_op_internal File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1160, in from_node_def The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=140399462105440), which is out of scope.
Usually, leaks such as these occur when you use Python statements or data structures. In addition to leaking inaccessible tensors, such statements are also likely wrong because they count as Python side effects, and are not guaranteed to execute at every function call.
Common ways to leak local tensors also include mutating an external Python collection, or an object:
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
Recursive tf.functions are not supported
Recursive tf.functions are not supported and could cause infinite loops. For example:
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File 
"/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return 
recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in 
recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/usr/lib/python3.9/abc.py", line 119, in __instancecheck__ return _abc_instancecheck(cls, instance) RecursionError: maximum recursion depth exceeded while calling a Python object
Even if a recursive tf.function seems to work, the Python function will be traced multiple times and could have performance implications. For example:
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
Known issues
If your tf.function is not evaluating correctly, the error may be explained by these known issues, which are planned to be fixed in the future.
Depending on Python global and free variables
tf.function creates a new ConcreteFunction when called with a new value of a Python argument. However, it does not do that for the Python closure, globals, or nonlocals of that tf.function. If their value changes in between calls to the tf.function, the tf.function will still use the values they had when it was traced. This is different from how regular Python functions work.
For that reason, you should follow a functional programming style that uses arguments instead of closing over outer names.
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
Another way to update a global value is to make it a tf.Variable and use the Variable.assign method instead.
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
Depending on Python objects
Passing custom Python objects as arguments to tf.function is supported but has certain limitations.
For maximum feature coverage, consider transforming the objects into extension types before passing them to tf.function. You can also use Python primitives and tf.nest-compatible structures.
However, as covered in the rules of tracing, when a custom TraceType is not provided by the custom Python class, tf.function is forced to use instance-based equality, which means it will not create a new trace when you pass the same object with modified attributes.
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
Using the same tf.function to evaluate the modified instance of the model will be buggy since it still has the same instance-based TraceType as the original model.
For that reason, you're recommended to write your tf.function to avoid depending on mutable object attributes, or to implement the Tracing Protocol for your objects to inform tf.function about such attributes.
If that is not possible, one workaround is to make new tf.functions each time you modify your object to force retracing:
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`. `tf.function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new `tf.function` and `ConcreteFunction` since you modified `new_model`.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
As retracing can be expensive, you can use tf.Variables as object attributes, which can be mutated (but not changed, careful!) for a similar effect without needing a retrace.
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Creating tf.Variables
tf.function only supports singleton tf.Variables created once on the first call, and reused across subsequent function calls. The code snippet below would create a new tf.Variable in every function call, which results in a ValueError exception.
Example:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmpfs/tmp/ipykernel_11117/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://tensorflow.dev.org.tw/guide/function#creating_tfvariables for more information.
A common pattern used to work around this limitation is to start with a Python None value, then conditionally create the tf.Variable if the value is None:
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
Using with multiple Keras optimizers
You may encounter ValueError: tf.function only supports singleton tf.Variables created on the first call. when using more than one Keras optimizer with a tf.function. This error occurs because optimizers internally create tf.Variables when they apply gradients for the first time.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1701138168.913099 11284 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/950644149.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmpfs/tmp/ipykernel_11117/950644149.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1223, in apply_gradients ** return super().apply_gradients(grads_and_vars, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 638, in apply_gradients self.build(trainable_variables) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py", line 145, in build self.add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1125, in add_variable_from_reference return super().add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 513, in add_variable_from_reference variable = tf.Variable( ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://tensorflow.dev.org.tw/guide/function#creating_tfvariables for more information.
If you need to change a stateful object between calls, it's simplest to define a tf.Module subclass and create instances to hold those objects:
class TrainStep(tf.Module):
def __init__(self, optimizer):
self.optimizer = optimizer
@tf.function
def __call__(self, w, x, y):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
self.optimizer.apply_gradients(zip(gradients, [w]))
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
train_o1 = TrainStep(opt1)
train_o2 = TrainStep(opt2)
train_o1(w, x, y)
train_o2(w, x, y)
You could also do this manually, by creating multiple instances of the @tf.function wrapper, one for each optimizer:
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new tf.function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y, opt1)
else:
train_step_2(w, x, y, opt2)
Using with multiple Keras models
You may also encounter ValueError: tf.function only supports singleton tf.Variables created on the first call. when passing different model instances to the same tf.function.
This error occurs because Keras models (which do not have their input shape defined) and Keras layers create tf.Variables when they are first called. You may be attempting to initialize those variables inside a tf.function, which has already been called. To avoid this error, try calling model.build(input_shape) to initialize all the weights before training the model, as the sketch below shows.
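A minimal sketch of that workaround (the models and shapes are illustrative assumptions): building each model up front means no variables are created inside the already-traced function, so passing a second instance should retrace rather than error:
@tf.function
def apply_model(model, x):
  return model(x)
model_a = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model_b = tf.keras.Sequential([tf.keras.layers.Dense(1)])
# Create the weights up front, so none are created during tracing.
model_a.build(input_shape=(None, 2))
model_b.build(input_shape=(None, 2))
x = tf.ones([3, 2])
print(apply_model(model_a, x))  # traces for model_a
print(apply_model(model_b, x))  # retraces for model_b, but no ValueError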
Further reading
To learn about how to export and load a tf.function, see the SavedModel guide. To learn more about graph optimizations that are performed after tracing, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.