Overview
This guide provides a list of best practices for writing code using TensorFlow 2 (TF2). It is written for users who have recently switched over from TensorFlow 1 (TF1). Refer to the migrate section of the guide for more information on migrating your TF1 code to TF2.
Setup
Import TensorFlow and the other dependencies for the examples in this guide.
import tensorflow as tf
import tensorflow_datasets as tfds
2023-10-04 01:22:53.526066: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-10-04 01:22:53.526110: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-10-04 01:22:53.526158: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Recommendations for idiomatic TensorFlow 2
Refactor your code into smaller modules
A good practice is to refactor your code into smaller functions that are called as needed. For best performance, you should try to decorate the largest blocks of computation that you can in a tf.function (note that the nested Python functions called by a tf.function do not require their own separate decorations, unless you want to use different jit_compile settings for the tf.function). Depending on your use case, this could be multiple training steps or even your whole training loop. For inference use cases, it might be a single model forward pass.
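For instance, here is a small illustrative sketch (the helper and argument names are not part of this guide's code): the outer training step carries the single tf.function decoration, while the nested Python helper it calls does not need its own.
def compute_loss(model, x, y, loss_fn):
  # A plain Python helper; it is traced as part of the decorated caller.
  return loss_fn(y, model(x, training=True))

@tf.function  # decorate the largest block of computation you can
def train_step(model, optimizer, loss_fn, x, y):
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x, y, loss_fn)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss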
Adjust the default learning rate for some tf.keras.optimizers
Some Keras optimizers have different learning rates in TF2. If you see a change in convergence behavior for your models, check the default learning rates.
There are no changes for optimizers.SGD, optimizers.Adam, or optimizers.RMSprop.
The following default learning rates have changed:
- optimizers.Adagrad from 0.01 to 0.001
- optimizers.Adadelta from 1.0 to 0.001
- optimizers.Adamax from 0.002 to 0.001
- optimizers.Nadam from 0.002 to 0.001
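If you rely on the previous behavior, you can simply pass the learning rate explicitly instead of depending on the default; a minimal sketch:
# Restore the previous Adagrad default by setting the learning rate explicitly
# (the TF2 default is 0.001).
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.01)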
Use tf.Modules and Keras layers to manage variables
tf.Modules and tf.keras.layers.Layers offer the convenient variables and trainable_variables properties, which recursively gather up all dependent variables. This makes it easy to manage variables locally to where they are being used.
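As an illustration (the MLP class below is a sketch, not part of this guide's model), a tf.Module that holds Keras layers as attributes automatically exposes their variables through these properties:
class MLP(tf.Module):
  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(8)
    self.dense2 = tf.keras.layers.Dense(2)

  def __call__(self, x):
    return self.dense2(self.dense1(x))

mlp = MLP()
_ = mlp(tf.zeros([1, 4]))  # run once so the layers create their variables
print([v.name for v in mlp.trainable_variables])  # kernels and biases of both layers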
Keras layers/models inherit from tf.train.Checkpointable and are integrated with @tf.function, which makes it possible to directly checkpoint or export SavedModels from Keras objects. You do not necessarily have to use Keras' Model.fit API to take advantage of these integrations.
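For example, here is a minimal sketch (the model and the file paths are placeholders) of checkpointing and exporting a Keras object directly, without going through Model.fit:
small_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
checkpoint = tf.train.Checkpoint(model=small_model)
checkpoint.save('/tmp/ckpt/model')                   # checkpoint the Keras object's variables
tf.saved_model.save(small_model, '/tmp/saved_model')  # or export it as a SavedModel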
Read the section on transfer learning and fine-tuning in the Keras guide to learn how to collect a subset of relevant variables using Keras.
Combine tf.data.Datasets and tf.function
The TensorFlow Datasets package (tfds) contains utilities for loading predefined datasets as tf.data.Dataset objects. For this example, you can load the MNIST dataset using tfds:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
2023-10-04 01:22:57.406511: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflow.dev.org.tw/install/gpu for how to download and setup the required libraries for your platform. Skipping registering GPU devices...
Then prepare the data for training:
- Re-scale each image.
- Shuffle the order of the examples.
- Collect batches of images and labels.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label
To keep the example short, trim the dataset to only return 5 batches:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2023-10-04 01:22:58.048011: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Use regular Python iteration to iterate over training data that fits in memory. Otherwise, tf.data.Dataset is the best way to stream training data from disk. Datasets are iterables (not iterators), and work just like other Python iterables in eager execution. You can fully utilize dataset async prefetching/streaming features by wrapping your code in tf.function, which replaces Python iteration with the equivalent graph operations using AutoGraph.
@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
If you use the Keras Model.fit API, you won't have to worry about dataset iteration.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Use Keras training loops
If you don't need low-level control of your training process, using Keras's built-in fit, evaluate, and predict methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or sub-classed).
The advantages of these methods include:
- They accept Numpy arrays, Python generators and tf.data.Datasets.
- They apply regularization and activation losses automatically.
- They support tf.distribute where the training code remains the same regardless of the hardware configuration.
- They support arbitrary callables as losses and metrics.
- They support callbacks like tf.keras.callbacks.TensorBoard, and custom callbacks.
- They are performant, automatically using TensorFlow graphs.
Here is an example of training a model using a Dataset. For details on how this works, check out the tutorials.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5 5/5 [==============================] - 2s 44ms/step - loss: 1.6644 - accuracy: 0.4906 Epoch 2/5 2023-10-04 01:22:59.569439: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 9ms/step - loss: 0.5173 - accuracy: 0.9062 Epoch 3/5 2023-10-04 01:23:00.062308: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 9ms/step - loss: 0.3418 - accuracy: 0.9469 Epoch 4/5 2023-10-04 01:23:00.384057: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 8ms/step - loss: 0.2707 - accuracy: 0.9781 Epoch 5/5 2023-10-04 01:23:00.766486: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 8ms/step - loss: 0.2195 - accuracy: 0.9812 2023-10-04 01:23:01.120149: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 4ms/step - loss: 1.6036 - accuracy: 0.6250 Loss 1.6036441326141357, Accuracy 0.625 2023-10-04 01:23:01.572685: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Customize training and write your own loop
If Keras models work for you, but you need more flexibility and control of the training step or the outer training loops, you can implement your own training steps or even entire training loops. See the Keras guide on customizing fit to learn more.
You can also implement many things as a tf.keras.callbacks.Callback.
This method has many of the advantages mentioned previously, but it gives you control of the training step and even the outer loop.
There are three steps to a standard training loop:
- Iterate over a Python generator or tf.data.Dataset to get batches of examples.
- Use tf.GradientTape to collect gradients.
- Use one of the tf.keras.optimizers to apply weight updates to the model's variables.
Remember:
- Always include a training argument on the call method of subclassed layers and models (a short sketch follows this list).
- Make sure to call the model with the training argument set correctly.
- Depending on usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
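Here is a minimal sketch of the first two points (the NoisyDense class is illustrative, not part of this guide's model): a subclassed layer whose call method accepts a training argument and behaves differently depending on it.
class NoisyDense(tf.keras.layers.Layer):
  def __init__(self, units):
    super().__init__()
    self.dense = tf.keras.layers.Dense(units)

  def call(self, inputs, training=False):
    x = self.dense(inputs)
    if training:
      # Only inject noise while training, not at inference time.
      x = x + tf.random.normal(tf.shape(x), stddev=0.1)
    return x

layer = NoisyDense(4)
y_train = layer(tf.zeros([2, 3]), training=True)   # call with training set correctly
y_infer = layer(tf.zeros([2, 3]), training=False)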
There is no need to run variable initializers or to add manual control dependencies. tf.function handles automatic control dependencies and variable initialization on creation for you.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2023-10-04 01:23:02.652222: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 0 2023-10-04 01:23:02.957452: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 1 2023-10-04 01:23:03.632425: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 2 2023-10-04 01:23:03.877866: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 3 Finished epoch 4 2023-10-04 01:23:04.197488: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Take advantage of tf.function with Python control flow
tf.function provides a way to convert data-dependent control flow into graph-mode equivalents like tf.cond and tf.while_loop.
One common place where data-dependent control flow appears is in sequence models. tf.keras.layers.RNN wraps an RNN cell, allowing you to either statically or dynamically unroll the recurrence. As an example, you could reimplement dynamic unroll as follows.
class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps = tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Read the tf.function guide for more information.
New-style metrics and losses
Metrics and losses are both objects that work eagerly and in tf.functions.
A loss object is callable, and expects (y_true, y_pred) as arguments:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Use metrics to collect and display data
You can use tf.metrics to aggregate data and tf.summary to log summaries and redirect them to a writer using a context manager. The summaries are emitted directly to the writer, which means that you must provide the step value at the call site.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)
Use tf.metrics to aggregate data before logging them as summaries. Metrics are stateful; they accumulate values and return a cumulative result when you call the result method (such as Mean.result). Clear accumulated values with Model.reset_states.
def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)

  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)

    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)
Visualize the generated summaries by pointing TensorBoard at the summary log directory:
tensorboard --logdir /tmp/summaries
Use the tf.summary API to write summary data for visualization in TensorBoard. For more information, read the tf.summary guide.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)

  # Get the metric results
  mean_loss = loss_metric.result()
  mean_accuracy = accuracy_metric.result()
  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2023-10-04 01:23:05.220607: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 0 loss: 0.176 accuracy: 0.994 2023-10-04 01:23:05.554495: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 1 loss: 0.153 accuracy: 0.991 2023-10-04 01:23:06.043597: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 2 loss: 0.134 accuracy: 0.994 2023-10-04 01:23:06.297768: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 3 loss: 0.108 accuracy: 1.000 Epoch: 4 loss: 0.095 accuracy: 1.000 2023-10-04 01:23:06.678292: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Keras metric names
Keras models are consistent about handling metric names. When you pass a string in the list of metrics, that exact string is used as the metric's name. These names are visible in the history object returned by model.fit and in the logs passed to keras.callbacks, and are set to the strings you passed in the metric list.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 9ms/step - loss: 0.1077 - acc: 0.9937 - accuracy: 0.9937 - my_accuracy: 0.9937 2023-10-04 01:23:07.849601: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Debugging
Use eager execution to run your code step-by-step to inspect shapes, data types, and values. Certain APIs, like tf.function, tf.keras, etc. are designed to use graph execution for performance and portability. When debugging, use tf.config.run_functions_eagerly(True) to use eager execution inside this code.
For example:
@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
This also works inside Keras models and other APIs that support eager execution:
class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Notes:
- tf.keras.Model methods such as fit, evaluate, and predict execute as graphs with tf.function under the hood.
- When using tf.keras.Model.compile, set run_eagerly = True to disable the Model logic from being wrapped in a tf.function (a short sketch follows this list).
- Use tf.data.experimental.enable_debug_mode to enable the debug mode for tf.data. Read the API docs for more details.
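A minimal sketch of those two debugging switches, assuming the model, optimizer, and loss_fn defined earlier in this guide:
# Run the compiled Model logic eagerly instead of inside a tf.function.
model.compile(optimizer=optimizer, loss=loss_fn, run_eagerly=True)

# Enable the tf.data debug mode (see the API docs for details on when to call it).
tf.data.experimental.enable_debug_mode()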
Do not keep tf.Tensors in your objects
These tensor objects might get created either in a tf.function or in the eager context, and they behave differently. Always use tf.Tensors only for intermediate values.
To track state, use tf.Variables, as they are always usable from both contexts. Read the tf.Variable guide to learn more.
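For example, here is a minimal sketch (the Counter class is illustrative): state kept in a tf.Variable can be read and updated the same way from eager code and from inside a tf.function.
class Counter(tf.Module):
  def __init__(self):
    super().__init__()
    self.count = tf.Variable(0, dtype=tf.int64)  # state lives in a Variable, not a Tensor

  @tf.function
  def increment(self):
    self.count.assign_add(1)  # works identically in graph and eager contexts
    return self.count

counter = Counter()
counter.increment()
print(int(counter.count))  # 1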