Author: fchollet
Introduction
When you're doing supervised learning, you can use `fit()` and everything works smoothly.

When you need to write your own training loop from scratch, you can use `GradientTape` and take control of every little detail.

But what if you need a custom training algorithm, but you still want to benefit from the convenient features of `fit()`, such as callbacks, built-in distribution support, or step fusing?
A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.
When you need to customize what `fit()` does, you should override the training step function of the `Model` class. This is the function that is called by `fit()` for every batch of data. You will then be able to call `fit()` as usual -- and it will be running your own learning algorithm.
Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building Sequential models, Functional API models, or subclassed models.
Let's see how that works.
Setup
Requires TensorFlow 2.8 or later.
import tensorflow as tf
from tensorflow import keras
A first simple example
Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We just override the method `train_step(self, data)`.
- We return a dictionary mapping metric names (including the loss) to their current value.
The input argument `data` is what gets passed to `fit` as training data:

- If you pass Numpy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple `(x, y)`.
- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be what gets yielded by `dataset` at each batch (see the short Dataset sketch after the first example below).
In the body of the `train_step` method, we implement a regular training update, similar to what you are already familiar with. Importantly, we compute the loss via `self.compute_loss()`, which wraps the loss function(s) that were passed to `compile()`.
Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`, to update the state of the metrics that were passed in `compile()`, and we query results from `self.metrics` at the end to retrieve their current value.
class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = self.compute_loss(y=y, y_pred=y_pred)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update metrics (includes the metric that tracks the loss)
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
Let's try this out:
import numpy as np
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
32/32 [==============================] - 3s 2ms/step - loss: 1.6446
Epoch 2/3
32/32 [==============================] - 0s 2ms/step - loss: 0.7554
Epoch 3/3
32/32 [==============================] - 0s 2ms/step - loss: 0.3924
<keras.src.callbacks.History at 0x7fef5c11ba30>
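Since `data` is simply whatever you pass to `fit()`, the same model also trains from a `tf.data.Dataset`: each batch then reaches `train_step` as the `(x, y)` tuple yielded by the dataset. A minimal sketch, reusing the `model`, `x`, and `y` defined above:

# Build a dataset yielding (x, y) tuples; `train_step` unpacks each
# batch exactly as it did for the Numpy arrays.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
model.fit(dataset, epochs=3)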
Going lower-level
Naturally, you could just skip passing a loss function in `compile()`, and instead do everything manually in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:
- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics (by calling `update_state()` on them), then queries them (via `result()`) to return their current average value, to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise calling `result()` would return an average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want to reset in the `metrics` property of the model. The model will call `reset_states()` on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to `evaluate()`.
class CustomModel(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_tracker = keras.metrics.Mean(name="loss")
self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute our own loss
loss = keras.losses.mean_squared_error(y, y_pred)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Compute our own metrics
self.loss_tracker.update_state(loss)
self.mae_metric.update_state(y, y_pred)
return {"loss": self.loss_tracker.result(), "mae": self.mae_metric.result()}
@property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [self.loss_tracker, self.mae_metric]
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model.compile(optimizer="adam")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 0.3240 - mae: 0.4583
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2416 - mae: 0.3984
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2340 - mae: 0.3919
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2274 - mae: 0.3870
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2197 - mae: 0.3808
<keras.src.callbacks.History at 0x7fef3c130b20>
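For reference, if you did not expose the trackers through the `metrics` property, you would have to reset them yourself at the moments of your choosing; a minimal sketch of the manual alternative:

# Manual alternative to the `metrics` property: reset the trackers
# yourself, e.g. before each epoch or before an evaluation pass.
model.loss_tracker.reset_states()
model.mae_metric.reset_states()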
Supporting `sample_weight` and `class_weight`
You may have noticed that our first basic example didn't make any mention of sample weighting. If you want to support the `fit()` arguments `sample_weight` and `class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument.
- Pass it to `compute_loss` and `update_state` (of course, you could also just apply it manually if you don't rely on `compile()` for losses & metrics).
- That's it.
class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x, y, sample_weight = data
else:
sample_weight = None
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value.
# The loss function is configured in `compile()`.
loss = self.compute_loss(
y=y,
y_pred=y_pred,
sample_weight=sample_weight,
)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update the metrics.
# Metrics are configured in `compile()`.
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred, sample_weight=sample_weight)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
Epoch 1/3
32/32 [==============================] - 0s 2ms/step - loss: 0.1298
Epoch 2/3
32/32 [==============================] - 0s 2ms/step - loss: 0.1179
Epoch 3/3
32/32 [==============================] - 0s 2ms/step - loss: 0.1121
<keras.src.callbacks.History at 0x7fef3c168100>
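Note that `class_weight` requires no extra code in `train_step`: `fit()` converts the class-weight dictionary into per-sample weights before the data reaches your training step. A hedged sketch, assuming hypothetical binary integer targets so that the dictionary keys match actual classes:

# Hypothetical binary targets so that the `class_weight` keys {0, 1}
# apply; `fit()` turns the dict into per-sample weights internally.
y_binary = np.random.randint(0, 2, size=(1000, 1)).astype("float32")
model.fit(x, y_binary, class_weight={0: 1.0, 1: 2.0}, epochs=1)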
Providing your own evaluation step
What if you want to do the same for calls to `model.evaluate()`? Then you would override `test_step` in exactly the same way. Here's what it looks like:
class CustomModel(keras.Model):
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_pred = self(x, training=False)
        # Update the metrics tracking the loss
self.compute_loss(y=y, y_pred=y_pred)
# Update the metrics.
for metric in self.metrics:
if metric.name != "loss":
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 [==============================] - 0s 1ms/step - loss: 0.9028
0.9028095006942749
Wrapping up: an end-to-end GAN example
Let's walk through an end-to-end example that leverages everything you just learned.
Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and "real").
- One optimizer for each.
- A loss function to train the discriminator.
from tensorflow.keras import layers
# Create the discriminator
discriminator = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.GlobalMaxPooling2D(),
layers.Dense(1),
],
name="discriminator",
)
# Create the generator
latent_dim = 128
generator = keras.Sequential(
[
keras.Input(shape=(latent_dim,)),
# We want to generate 128 coefficients to reshape into a 7x7x128 map
layers.Dense(7 * 7 * 128),
layers.LeakyReLU(alpha=0.2),
layers.Reshape((7, 7, 128)),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
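As a quick sanity check (not part of the original walkthrough), you can verify that the generator's architecture indeed upsamples a latent vector to a 28x28x1 image: Dense produces a 7x7x128 map, and the two stride-2 transposed convolutions double the spatial size twice. A minimal sketch:

# The generator maps a latent vector to a 28x28x1 image.
z = tf.random.normal(shape=(1, latent_dim))
print(generator(z).shape)  # expected: (1, 28, 28, 1)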
Here's a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in 17 lines in `train_step`:
class GAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super().__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
def compile(self, d_optimizer, g_optimizer, loss_fn):
super().compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_fn = loss_fn
def train_step(self, real_images):
if isinstance(real_images, tuple):
real_images = real_images[0]
# Sample random points in the latent space
batch_size = tf.shape(real_images)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Decode them to fake images
generated_images = self.generator(random_latent_vectors)
# Combine them with real images
combined_images = tf.concat([generated_images, real_images], axis=0)
# Assemble labels discriminating real from fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
)
# Add random noise to the labels - important trick!
labels += 0.05 * tf.random.uniform(tf.shape(labels))
# Train the discriminator
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_images)
d_loss = self.loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
predictions = self.discriminator(self.generator(random_latent_vectors))
g_loss = self.loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
# Update metrics and return their value.
self.d_loss_tracker.update_state(d_loss)
self.g_loss_tracker.update_state(g_loss)
return {
"d_loss": self.d_loss_tracker.result(),
"g_loss": self.g_loss_tracker.result(),
}
Let's test-drive it:
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
100/100 [==============================] - 8s 15ms/step - d_loss: 0.4372 - g_loss: 0.8775
<keras.src.callbacks.History at 0x7feee42ff190>
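As a quick usage note (not covered in the walkthrough above), you could sample images from the trained generator directly; a minimal sketch:

# Decode a few random latent vectors into 28x28x1 images.
random_latent_vectors = tf.random.normal(shape=(4, latent_dim))
generated_images = gan.generator(random_latent_vectors)  # shape (4, 28, 28, 1)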
The ideas behind deep learning are simple, so why should their implementation be painful?