遷移檢查點儲存

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

持續儲存「最佳」模型或模型權重/參數有很多好處。這些好處包括能夠追蹤訓練進度,以及從不同的儲存狀態載入已儲存的模型。

在 TensorFlow 1 中,若要使用 tf.estimator.Estimator API 設定訓練/驗證期間的檢查點儲存,您可以在 tf.estimator.RunConfig 中指定排程,或使用 tf.estimator.CheckpointSaverHook。本指南示範如何從此工作流程遷移至 TensorFlow 2 Keras API。

在 TensorFlow 2 中,您可以使用多種方式設定 tf.keras.callbacks.ModelCheckpoint

  • 根據使用 save_best_only=True 參數監控的指標,儲存「最佳」版本,其中 monitor 可以是 'loss''val_loss''accuracy' 或'val_accuracy'` 等。
  • 以特定頻率持續儲存 (使用 save_freq 引數)。
  • 只儲存權重/參數,而非整個模型,方法是將 save_weights_only 設定為 True

如需更多詳細資訊,請參閱 tf.keras.callbacks.ModelCheckpoint API 文件,以及儲存和載入模型教學課程中的訓練期間儲存檢查點一節。若要深入瞭解檢查點格式,請參閱儲存和載入 Keras 模型指南中的TF 檢查點格式一節。此外,若要新增容錯能力,您可以針對手動檢查點使用 tf.keras.callbacks.BackupAndRestoretf.train.Checkpoint。請在容錯遷移指南中瞭解詳情。

Keras 回呼是在內建 Keras Model.fit/Model.evaluate/Model.predict API 中訓練/評估/預測期間的不同時間點呼叫的物件。請在本指南結尾的後續步驟一節中瞭解詳情。

設定

從匯入項目和一個用於示範用途的簡單資料集開始

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1:使用 tf.estimator API 儲存檢查點

此 TensorFlow 1 範例示範如何設定 tf.estimator.RunConfig,以在使用 tf.estimator.Estimator API 進行訓練/評估期間的每個步驟儲存檢查點

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)

tf1.estimator.train_and_evaluate(estimator=classifier,
                                train_spec=train_spec,
                                eval_spec=eval_spec)
%ls {classifier.model_dir}

TensorFlow 2:使用 Keras 回呼為 Model.fit 儲存檢查點

在 TensorFlow 2 中,當您使用內建 Keras Model.fit (或 Model.evaluate) 進行訓練/評估時,您可以設定 tf.keras.callbacks.ModelCheckpoint,然後將其傳遞至 Model.fit (或 Model.evaluate) 的 callbacks 參數。(請在 API 文件和使用內建方法進行訓練和評估指南中的使用回呼一節瞭解詳情。)

在以下範例中,您將使用 tf.keras.callbacks.ModelCheckpoint 回呼在暫存目錄中儲存檢查點

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
%ls {model_checkpoint_callback.filepath}

後續步驟

深入瞭解檢查點

深入瞭解回呼

您可能也會發現下列與遷移相關的資源很有用