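Callbacks are a powerful tool for customizing the behavior of a Keras model during training, evaluation, or inference. Examples include tf.keras.callbacks.TensorBoard, which visualizes training progress and results with TensorBoard, and tf.keras.callbacks.ModelCheckpoint, which periodically saves your model during training.

In this guide, you will learn what a Keras callback is, what it can do, and how to build your own. A few demos of simple callback applications are included to get you started.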


import tensorflow as tf
from tensorflow import keras

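Keras callbacks overview

All callbacks subclass the keras.callbacks.Callback class and override a set of methods that are called at various stages of training, testing, and predicting. Callbacks are useful for getting a view of a model's internal states and statistics during training.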

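You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:

  • fit()
  • evaluate()
  • predict()

Global methods
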
on_(train|test|predict)_begin(self, logs=None)

Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)

Called at the end of fit/evaluate/predict.


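Batch-level methods for training/testing/predicting

on_(train|test|predict)_batch_begin(self, batch, logs=None)

Called right before processing a batch during training/testing/predicting.

on_(train|test|predict)_batch_end(self, batch, logs=None)

Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.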

Epoch-level methods (training only)

on_epoch_begin(self, epoch, logs=None)

Called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None)

Called at the end of an epoch during training.
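
The order in which the training hooks fire can be sketched without Keras at all. The block below is a simplified stand-in for a training loop, not Keras's actual implementation: RecordingCallback and run_fit are hypothetical names introduced for illustration, and the logs values are placeholders.

```python
class RecordingCallback:
    """Hypothetical stand-in for a callback: records each hook as it fires."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_train_end(self, logs=None):
        self.events.append("train_end")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_begin:{batch}")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch}")


def run_fit(callback, epochs, batches_per_epoch):
    """Simplified illustration of the order in which training hooks fire."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            callback.on_train_batch_begin(batch)
            # A real loop would run one optimization step here.
            callback.on_train_batch_end(batch, logs={"loss": 0.0})  # placeholder logs
        callback.on_epoch_end(epoch, logs={"loss": 0.0})  # placeholder logs
    callback.on_train_end()


recorder = RecordingCallback()
run_fit(recorder, epochs=1, batches_per_epoch=2)
print(recorder.events)
```

The global hooks wrap the whole run, the epoch hooks wrap each pass over the data, and the batch hooks wrap each step. Evaluation and prediction follow the same pattern with the test and predict hooks, minus the epoch level.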


Let's take a look at a concrete example. First, with tensorflow imported as shown above, define a simple Sequential Keras model:

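# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    # Compile so fit/evaluate can run; the optimizer choice is illustrative.
    model.compile(
        optimizer="rmsprop", loss="mean_squared_error", metrics=["mean_absolute_error"]
    )
    return model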

Then, load the MNIST data for training and testing from the Keras datasets API:

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]


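Now, define a simple custom callback that logs:

  • When fit/evaluate/predict starts and ends
  • When each epoch starts and ends
  • When each training batch starts and ends
  • When each evaluation (test) batch starts and ends
  • When each inference (prediction) batch starts and ends
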
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))


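model = get_model()
# Train for one epoch so that the training hooks fire as well.
model.fit(
    x_train, y_train, batch_size=128, epochs=1, verbose=0, callbacks=[CustomCallback()]
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])

Usage of the logs dict

The logs dict contains the loss value and all the metrics at the end of a batch or epoch. Examples include the loss and the mean absolute error.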

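class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )

model = get_model()
# batch_size and epochs are illustrative values.
model.fit(
    x_train, y_train, batch_size=128, epochs=2, verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)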
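
The "average loss" wording in the messages above reflects that the loss reported at the end of a batch is a running mean over the batches seen so far in the current epoch, not the loss of that batch alone. A framework-free sketch of that bookkeeping follows. It assumes equally sized batches; running_mean_logs is a hypothetical helper and the per-batch losses are made up.

```python
def running_mean_logs(batch_losses):
    """Yield a logs-like dict per batch holding the mean loss seen so far.

    Assumes equally sized batches; the inputs stand in for real batch losses.
    """
    total = 0.0
    for count, loss in enumerate(batch_losses, start=1):
        total += loss
        yield {"loss": total / count}


# Made-up per-batch losses for one epoch.
batch_losses = [4.0, 2.0, 3.0]
for batch, logs in enumerate(running_mean_logs(batch_losses)):
    print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))
```

The second batch has a loss of 2.0, yet the reported value is 3.0, the mean of the first two batches. This is why the value in logs at the end of the last batch matches the epoch-level loss.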
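
Usage of the self.model attribute

Besides the logs passed to each method, a callback has access to the model associated with the current round of training, evaluation, or inference through self.model.

Here are a few of the things you can do with self.model in a callback: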

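  • Set self.model.stop_training = True to immediately interrupt training.
  • Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
  • Save the model at regular intervals.
  • Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training.
  • Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
  • And so on.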
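
To make the stop_training item concrete, here is a minimal sketch of how such a flag can be set by a callback and honored by a loop. It does not use Keras: StopBelowThreshold, run_epochs, and model_stub are hypothetical stand-ins, the epoch losses are made up, and the loop here checks the flag only at epoch boundaries.

```python
from types import SimpleNamespace


class StopBelowThreshold:
    """Hypothetical callback-like object: requests a stop once loss is low enough."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.model = None  # Attached by the loop below before any hook is called.

    def on_epoch_end(self, epoch, logs=None):
        if logs["loss"] < self.threshold:
            self.model.stop_training = True


def run_epochs(model, callback, epoch_losses):
    """Simplified loop that honors model.stop_training at epoch boundaries."""
    callback.model = model
    model.stop_training = False
    completed = 0
    for epoch, loss in enumerate(epoch_losses):
        callback.on_epoch_end(epoch, logs={"loss": loss})
        completed += 1
        if model.stop_training:
            break
    return completed


model_stub = SimpleNamespace(stop_training=False)  # Stand-in for a Keras model.
epochs_run = run_epochs(model_stub, StopBelowThreshold(0.5), [0.9, 0.6, 0.4, 0.3])
print(epochs_run)
```

The callback never breaks out of the loop itself. It only flips a flag on the shared model object, and the loop decides when to act on it, which keeps several callbacks from interfering with each other.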


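Examples of Keras callback applications


The first example shows how to create a Callback that stops training when the minimum of the loss has been reached, by setting the attribute self.model.stop_training (boolean). Optionally, you can provide the argument patience to specify how many epochs to wait before stopping after a local minimum has been reached.

tf.keras.callbacks.EarlyStopping provides a more complete and general implementation.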

import numpy as np

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
        number of no improvement, training stops.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity (np.Inf was removed in NumPy 2.0).
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            # No improvement: count the epoch and stop once patience runs out.
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

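model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,  # batch_size and epochs are illustrative values
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)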
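
The effect of patience is easier to see on a plain list of numbers. The sketch below follows the intended wait/patience bookkeeping of the callback above: reset the counter on improvement, otherwise count up and stop once the counter reaches patience. early_stop_epoch is a hypothetical helper, the losses are made up, and no model is involved.

```python
import math


def early_stop_epoch(losses, patience=0):
    """Return the 0-based epoch at which training would stop, or None.

    Applies the wait/patience rule to a plain list of per-epoch losses.
    """
    best = math.inf
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best:
            # Improvement: remember the new best and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.9, 0.7, 0.75, 0.72]  # made-up per-epoch losses
print(early_stop_epoch(losses, patience=0))
print(early_stop_epoch(losses, patience=2))
```

With patience=0 the run stops at the first epoch that fails to improve (epoch 2), which would miss the better loss at epoch 3. With patience=2 the temporary bump at epoch 2 is tolerated and the run stops at epoch 5, after two consecutive epochs without a new best.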
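
In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.

See callbacks.LearningRateScheduler for a more general implementation.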

class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

    Arguments:
        schedule: a function that takes an epoch index
            (integer, indexed from 0) and current learning rate
            as inputs and returns a new learning rate as output (float).
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        # Use the learning_rate attribute consistently for the check, read, and write.
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))

LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]

def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

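model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,  # batch_size and epochs are illustrative values
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)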
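
To see what the schedule does without training anything, the sketch below traces the learning rate across epochs. It repeats the schedule table from above so that it runs on its own, and it replaces the loop-based lookup with an equivalent dict lookup. SCHEDULE_TABLE, scheduled_rate, and trace_schedule are hypothetical names, and the starting rate of 0.1 is an assumption.

```python
# Repeated from the section above so that this sketch runs on its own.
SCHEDULE_TABLE = [(3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)]


def scheduled_rate(epoch, lr, table=SCHEDULE_TABLE):
    """Return the table's rate if `epoch` is a listed start epoch, else keep `lr`."""
    return dict(table).get(epoch, lr)


def trace_schedule(initial_lr, epochs):
    """Apply the schedule epoch by epoch, carrying the rate forward."""
    lr = initial_lr
    rates = []
    for epoch in range(epochs):
        lr = scheduled_rate(epoch, lr)
        rates.append(lr)
    return rates


# The initial learning rate of 0.1 is an assumed value for illustration.
rates = trace_schedule(initial_lr=0.1, epochs=14)
for epoch, rate in enumerate(rates):
    print("Epoch {:02d}: learning rate {:.4f}".format(epoch, rate))
```

The schedule function only returns a new value at the listed start epochs. The rate stays in effect afterwards because the callback writes it to the optimizer and reads it back at the start of the next epoch.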
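
Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by reading the API docs. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!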