Effective Tensorflow 2

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視 下載筆記本

總覽

本指南提供使用 TensorFlow 2 (TF2) 撰寫程式碼的最佳做法清單,是為最近從 TensorFlow 1 (TF1) 轉換過來的使用者所撰寫。如要進一步瞭解如何將 TF1 程式碼遷移至 TF2,請參閱指南的遷移章節

設定

匯入 TensorFlow 和本指南範例的其他依附元件。

import tensorflow as tf
import tensorflow_datasets as tfds
2023-10-04 01:22:53.526066: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-04 01:22:53.526110: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-04 01:22:53.526158: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

慣用 TensorFlow 2 的建議做法

將您的程式碼重構為較小的模組

良好的做法是將您的程式碼重構為較小的函式,並在需要時呼叫這些函式。為了獲得最佳效能,您應盡量裝飾您可以在 tf.function 中進行的最大運算區塊 (請注意,由 tf.function 呼叫的巢狀 Python 函式不需要有自己的個別裝飾,除非您想要針對 tf.function 使用不同的 jit_compile 設定)。視您的使用案例而定,這可能是多個訓練步驟,甚至是您的整個訓練迴圈。對於推論使用案例,這可能只是單一模型正向傳遞。

調整部分 tf.keras.optimizer 的預設學習率

部分 Keras 最佳化工具在 TF2 中的學習率不同。如果您發現模型的收斂行為發生變化,請檢查預設學習率。

optimizers.SGDoptimizers.Adamoptimizers.RMSprop 沒有任何變更。

以下預設學習率已變更

使用 tf.Module 和 Keras 層來管理變數

tf.Moduletf.keras.layers.Layer 提供方便的 variablestrainable_variables 屬性,這些屬性會以遞迴方式收集所有依附變數。這讓您可以輕鬆地在本機管理正在使用的變數。

Keras 層/模型繼承自 tf.train.Checkpointable,並與 @tf.function 整合,這使得可以直接從 Keras 物件檢查點或匯出 SavedModel。您不一定要使用 Keras 的 Model.fit API 來利用這些整合。

請閱讀 Keras 指南中關於 遷移學習和微調 的章節,以瞭解如何使用 Keras 收集相關變數的子集。

合併 tf.data.Datasettf.function

TensorFlow Datasets 套件 (tfds) 包含用於將預先定義的資料集載入為 tf.data.Dataset 物件的公用程式。在此範例中,您可以使用 tfds 載入 MNIST 資料集

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
2023-10-04 01:22:57.406511: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflow.dev.org.tw/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...

然後準備要用於訓練的資料

  • 重新調整每張圖片的比例。
  • 隨機排序範例的順序。
  • 收集圖片和標籤的批次。
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

為了讓範例簡短,請修剪資料集,使其僅傳回 5 個批次

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2023-10-04 01:22:58.048011: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

使用一般 Python 迭代來迭代記憶體中容納的訓練資料。否則,tf.data.Dataset 是從磁碟串流訓練資料的最佳方式。資料集是 可迭代物件 (而非迭代器),並且在 eager 執行中就像其他 Python 可迭代物件一樣運作。您可以透過將程式碼包裝在 tf.function 中來充分利用資料集非同步預先擷取/串流功能,這會使用 AutoGraph 將 Python 迭代取代為等效的圖形運算。

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

如果您使用 Keras Model.fit API,您就不必擔心資料集迭代。

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

使用 Keras 訓練迴圈

如果您不需要對訓練過程進行低階控制,建議使用 Keras 的內建 fitevaluatepredict 方法。這些方法提供統一的介面來訓練模型,無論實作 (循序、函式或子類別化) 為何。

這些方法的優點包括

  • 它們接受 Numpy 陣列、Python 產生器和 tf.data.Datasets
  • 它們會自動套用正規化和啟動損失。
  • 它們支援 tf.distribute,其中訓練程式碼會保持不變,無論硬體組態為何
  • 它們支援任意可呼叫物件作為損失和指標。
  • 它們支援 tf.keras.callbacks.TensorBoard 等回呼,以及自訂回呼。
  • 它們效能良好,會自動使用 TensorFlow 圖形。

以下是使用 Dataset 訓練模型的範例。如需瞭解其運作方式的詳細資訊,請查看教學課程

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 2s 44ms/step - loss: 1.6644 - accuracy: 0.4906
Epoch 2/5
2023-10-04 01:22:59.569439: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 9ms/step - loss: 0.5173 - accuracy: 0.9062
Epoch 3/5
2023-10-04 01:23:00.062308: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 9ms/step - loss: 0.3418 - accuracy: 0.9469
Epoch 4/5
2023-10-04 01:23:00.384057: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 8ms/step - loss: 0.2707 - accuracy: 0.9781
Epoch 5/5
2023-10-04 01:23:00.766486: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 8ms/step - loss: 0.2195 - accuracy: 0.9812
2023-10-04 01:23:01.120149: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 4ms/step - loss: 1.6036 - accuracy: 0.6250
Loss 1.6036441326141357, Accuracy 0.625
2023-10-04 01:23:01.572685: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

自訂訓練並撰寫您自己的迴圈

如果 Keras 模型適用於您,但您需要對訓練步驟或外部訓練迴圈有更高的彈性和控制權,您可以實作您自己的訓練步驟,甚至是整個訓練迴圈。請參閱 Keras 指南中關於自訂 fit 的內容以瞭解詳情。

您也可以將許多事物實作為 tf.keras.callbacks.Callback

此方法具有許多先前提及的優點,但可讓您控制訓練步驟,甚至是外部迴圈。

標準訓練迴圈有三個步驟

  1. 迭代 Python 產生器或 tf.data.Dataset 以取得範例批次。
  2. 使用 tf.GradientTape 收集梯度。
  3. 使用 tf.keras.optimizers 之一,將權重更新套用至模型的變數。

請記住

  • 一律在子類別化層和模型的 call 方法中包含 training 引數。
  • 請務必使用正確設定的 training 引數呼叫模型。
  • 視使用情況而定,模型變數可能在模型於一批資料上執行之前不存在。
  • 您需要手動處理模型的正規化損失等項目。

不需要執行變數初始設定式或新增手動控制依附元件。tf.function 會為您處理自動控制依附元件和建立時的變數初始化。

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2023-10-04 01:23:02.652222: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2023-10-04 01:23:02.957452: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2023-10-04 01:23:03.632425: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2023-10-04 01:23:03.877866: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2023-10-04 01:23:04.197488: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

利用具有 Python 控制流程的 tf.function

tf.function 提供一種方式,可將資料依附控制流程轉換為圖形模式等效項目,例如 tf.condtf.while_loop

資料依附控制流程常見的一個位置是在序列模型中。tf.keras.layers.RNN 包裝 RNN 儲存格,讓您可以靜態或動態地展開遞迴。例如,您可以重新實作動態展開,如下所示。

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

請閱讀 tf.function 指南以取得更多資訊。

新樣式指標和損失

指標和損失都是在 eager 和 tf.function 中運作的物件。

損失物件是可呼叫物件,並預期 (y_truey_pred) 作為引數

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

使用指標來收集和顯示資料

您可以使用 tf.metrics 彙總資料,並使用 tf.summary 記錄摘要,並使用內容管理員將其重新導向至寫入器。摘要會直接發出至寫入器,這表示您必須在呼叫點提供 step 值。

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

在將 tf.metrics 用於彙總資料之後,再將其記錄為摘要。指標是有狀態的;它們會累積值,並在您呼叫 result 方法時傳回累積結果 (例如 Mean.result)。使用 Model.reset_states 清除累積值。

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

將 TensorBoard 指向摘要記錄目錄,即可視覺化產生的摘要

tensorboard --logdir /tmp/summaries

使用 tf.summary API 寫入摘要資料,以便在 TensorBoard 中視覺化。如需更多資訊,請閱讀 tf.summary 指南

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2023-10-04 01:23:05.220607: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.176
  accuracy: 0.994
2023-10-04 01:23:05.554495: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.153
  accuracy: 0.991
2023-10-04 01:23:06.043597: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.134
  accuracy: 0.994
2023-10-04 01:23:06.297768: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.108
  accuracy: 1.000
Epoch:  4
  loss:     0.095
  accuracy: 1.000
2023-10-04 01:23:06.678292: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Keras 指標名稱

Keras 模型在處理指標名稱時具有一致性。當您在指標清單中傳遞字串時,這個確切字串會用作指標的 name。這些名稱在 model.fit 傳回的歷程記錄物件中以及傳遞至 keras.callbacks 的記錄中可見。會設定為您在指標清單中傳遞的字串。

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 9ms/step - loss: 0.1077 - acc: 0.9937 - accuracy: 0.9937 - my_accuracy: 0.9937
2023-10-04 01:23:07.849601: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

偵錯

使用 eager 執行逐步執行您的程式碼,以檢查形狀、資料類型和值。部分 API (例如 tf.functiontf.keras 等) 旨在為了效能和可攜性而使用圖形執行。偵錯時,請使用 tf.config.run_functions_eagerly(True),以便在此程式碼內使用 eager 執行。

例如

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

這也適用於 Keras 模型和支援 eager 執行的其他 API 內

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

注意事項

請勿將 tf.Tensors 保留在您的物件中

這些張量物件可能會在 tf.function 或 eager 內容中建立,而且這些張量的行為有所不同。一律僅將 tf.Tensor 用於中繼值。

如要追蹤狀態,請使用 tf.Variable,因為它們一律可從這兩種內容中使用。請閱讀 tf.Variable 指南以瞭解詳情。

資源和延伸閱讀

  • 請閱讀 TF2 指南教學課程,以進一步瞭解如何使用 TF2。

  • 如果您先前使用過 TF1.x,強烈建議您將程式碼遷移至 TF2。請閱讀遷移指南以瞭解詳情。