![]() |
![]() |
![]() |
![]() |
持續儲存「最佳」模型或模型權重/參數有很多好處。這些好處包括能夠追蹤訓練進度,以及從不同的儲存狀態載入已儲存的模型。
在 TensorFlow 1 中,若要使用 tf.estimator.Estimator
API 設定訓練/驗證期間的檢查點儲存,您可以在 tf.estimator.RunConfig
中指定排程,或使用 tf.estimator.CheckpointSaverHook
。本指南示範如何從此工作流程遷移至 TensorFlow 2 Keras API。
在 TensorFlow 2 中,您可以使用多種方式設定 tf.keras.callbacks.ModelCheckpoint
- 根據使用
save_best_only=True
參數監控的指標,儲存「最佳」版本,其中monitor
可以是'loss'
、'val_loss'
、'accuracy' 或
'val_accuracy'` 等。 - 以特定頻率持續儲存 (使用
save_freq
引數)。 - 只儲存權重/參數,而非整個模型,方法是將
save_weights_only
設定為True
。
如需更多詳細資訊,請參閱 tf.keras.callbacks.ModelCheckpoint
API 文件,以及儲存和載入模型教學課程中的訓練期間儲存檢查點一節。若要深入瞭解檢查點格式,請參閱儲存和載入 Keras 模型指南中的TF 檢查點格式一節。此外,若要新增容錯能力,您可以針對手動檢查點使用 tf.keras.callbacks.BackupAndRestore
或 tf.train.Checkpoint
。請在容錯遷移指南中瞭解詳情。
Keras 回呼是在內建 Keras Model.fit
/Model.evaluate
/Model.predict
API 中訓練/評估/預測期間的不同時間點呼叫的物件。請在本指南結尾的後續步驟一節中瞭解詳情。
設定
從匯入項目和一個用於示範用途的簡單資料集開始
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1:使用 tf.estimator API 儲存檢查點
此 TensorFlow 1 範例示範如何設定 tf.estimator.RunConfig
,以在使用 tf.estimator.Estimator
API 進行訓練/評估期間的每個步驟儲存檢查點
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
test_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_test},
y=y_test.astype(np.int32),
num_epochs=10,
shuffle=False
)
train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
steps=10,
throttle_secs=0)
tf1.estimator.train_and_evaluate(estimator=classifier,
train_spec=train_spec,
eval_spec=eval_spec)
%ls {classifier.model_dir}
TensorFlow 2:使用 Keras 回呼為 Model.fit 儲存檢查點
在 TensorFlow 2 中,當您使用內建 Keras Model.fit
(或 Model.evaluate
) 進行訓練/評估時,您可以設定 tf.keras.callbacks.ModelCheckpoint
,然後將其傳遞至 Model.fit
(或 Model.evaluate
) 的 callbacks
參數。(請在 API 文件和使用內建方法進行訓練和評估指南中的使用回呼一節瞭解詳情。)
在以下範例中,您將使用 tf.keras.callbacks.ModelCheckpoint
回呼在暫存目錄中儲存檢查點
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=log_dir)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[model_checkpoint_callback])
%ls {model_checkpoint_callback.filepath}
後續步驟
深入瞭解檢查點
- API 文件:
tf.keras.callbacks.ModelCheckpoint
- 教學課程:儲存和載入模型 (訓練期間儲存檢查點一節)
- 指南:儲存和載入 Keras 模型 (TF 檢查點格式一節)
深入瞭解回呼
- API 文件:
tf.keras.callbacks.Callback
- 指南:編寫您自己的回呼
- 指南:使用內建方法進行訓練和評估 (使用回呼一節)
您可能也會發現下列與遷移相關的資源很有用
- 容錯遷移指南:適用於
Model.fit
的tf.keras.callbacks.BackupAndRestore
,或適用於自訂訓練迴圈的tf.train.Checkpoint
和tf.train.CheckpointManager
API - 提前停止遷移指南:
tf.keras.callbacks.EarlyStopping
是一種內建的提前停止回呼 - TensorBoard 遷移指南:TensorBoard 可追蹤和顯示指標
- LoggingTensorHook 和 StopAtStepHook 至 Keras 回呼遷移指南
- SessionRunHook 至 Keras 回呼指南