![]() |
![]() |
![]() |
![]() |
容錯能力是指定期儲存可追蹤物件 (例如參數和模型) 狀態的機制。這可讓您在訓練期間發生程式/機器故障時還原這些狀態。
本指南首先示範如何在 TensorFlow 1 中透過指定搭配 tf.estimator.RunConfig
的指標儲存功能,將容錯能力新增至搭配 tf.estimator.Estimator
的訓練。然後,您將學習如何以兩種方式在 Tensorflow 2 中實作訓練的容錯能力
- 如果您使用 Keras
Model.fit
API,則可以將tf.keras.callbacks.BackupAndRestore
回呼傳遞給它。 - 如果您使用自訂訓練迴圈 (搭配
tf.GradientTape
),則可以使用tf.train.Checkpoint
和tf.train.CheckpointManager
API 任意儲存檢查點。
這兩種方法都會在 檢查點檔案中備份及還原訓練狀態。
設定
安裝 tf-nightly
,因為在 TensorFlow 2.10 中導入了在特定步驟使用 tf.keras.callbacks.BackupAndRestore
中的 save_freq
引數儲存檢查點的頻率
pip install tf-nightly
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1:使用 tf.estimator.RunConfig
儲存檢查點
在 TensorFlow 1 中,您可以設定 tf.estimator
,以透過設定 tf.estimator.RunConfig
來在每個步驟儲存檢查點。
在本範例中,首先編寫一個掛鉤,以便在第五個檢查點人為擲回錯誤
class InterruptHook(tf1.train.SessionRunHook):
# A hook for artificially interrupting training.
def begin(self):
self._step = -1
def before_run(self, run_context):
self._step += 1
def after_run(self, run_context, run_values):
if self._step == 5:
raise RuntimeError('Interruption')
接下來,設定 tf.estimator.Estimator
以儲存每個檢查點,並使用 MNIST 資料集
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
開始訓練模型。您稍早定義的掛鉤將會引發人為例外狀況。
try:
classifier.train(input_fn=train_input_fn,
hooks=[InterruptHook()],
max_steps=10)
except Exception as e:
print(f'{type(e).__name__}:{e}')
使用最後儲存的檢查點重建 tf.estimator.Estimator
並繼續訓練
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
classifier.train(input_fn=train_input_fn,
max_steps = 10)
TensorFlow 2:使用回呼和 Model.fit
備份和還原
在 TensorFlow 2 中,如果您使用 Keras Model.fit
API 進行訓練,則可以提供 tf.keras.callbacks.BackupAndRestore
回呼以新增容錯能力功能。
為了協助示範此功能,首先從定義 Keras Callback
類別開始,以便在第四個週期檢查點人為擲回錯誤
class InterruptAtEpoch(tf.keras.callbacks.Callback):
# A callback for artificially interrupting training.
def __init__(self, interrupting_epoch=3):
self.interrupting_epoch = interrupting_epoch
def on_epoch_end(self, epoch, log=None):
if epoch == self.interrupting_epoch:
raise RuntimeError('Interruption')
然後,定義並例項化簡單的 Keras 模型、定義損失函數、呼叫 Model.compile
,並設定 tf.keras.callbacks.BackupAndRestore
回呼,以便在週期界限的暫時目錄中儲存檢查點
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'])
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir)
使用 Model.fit
開始訓練模型。在訓練期間,會感謝上方例項化的 tf.keras.callbacks.BackupAndRestore
而儲存檢查點,而 InterruptAtEpoch
類別將會引發人為例外狀況,以便在第四個週期後模擬故障。
try:
model.fit(x=x_train,
y=y_train,
epochs=10,
steps_per_epoch=100,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback, InterruptAtEpoch()])
except Exception as e:
print(f'{type(e).__name__}:{e}')
接下來,例項化 Keras 模型、呼叫 Model.compile
,並從先前儲存的檢查點使用 Model.fit
繼續訓練模型
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
steps_per_epoch=100,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
定義另一個 Callback
類別,以便在第 140 個步驟人為擲回錯誤
class InterruptAtStep(tf.keras.callbacks.Callback):
# A callback for artificially interrupting training.
def __init__(self, interrupting_step=140):
self.total_step_count = 0
self.interrupting_step = interrupting_step
def on_batch_begin(self, batch, logs=None):
self.total_step_count += 1
def on_batch_end(self, batch, logs=None):
if self.total_step_count == self.interrupting_step:
print("\nInterrupting at step count", self.total_step_count)
raise RuntimeError('Interruption')
為了確保每 30 個步驟儲存檢查點,請在 BackupAndRestore
回呼中將 save_freq
設定為 30
。InterruptAtStep
將會引發人為例外狀況,以便在週期 1 和步驟 40 (總步驟計數 140) 時模擬故障。檢查點將會在週期 1 和步驟 20 時最後儲存。
log_dir_2 = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir_2, save_freq=30
)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'])
try:
model.fit(x=x_train,
y=y_train,
epochs=10,
steps_per_epoch=100,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback, InterruptAtStep()])
except Exception as e:
print(f'{type(e).__name__}:{e}')
接下來,例項化 Keras 模型、呼叫 Model.compile
,並從先前儲存的檢查點使用 Model.fit
繼續訓練模型。請注意,訓練從週期 2 和步驟 21 開始。
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
steps_per_epoch=100,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
TensorFlow 2:使用自訂訓練迴圈編寫手動檢查點
如果您在 TensorFlow 2 中使用自訂訓練迴圈,則可以使用 tf.train.Checkpoint
和 tf.train.CheckpointManager
API 實作容錯能力機制。
本範例示範如何
- 使用
tf.train.Checkpoint
物件手動建立檢查點,其中您要儲存的可追蹤物件會設定為屬性。 - 使用
tf.train.CheckpointManager
管理多個檢查點。
首先定義並例項化 Keras 模型、最佳化工具和損失函數。然後,建立 Checkpoint
,以便管理具有可追蹤狀態的兩個物件 (模型和最佳化工具),以及 CheckpointManager
,以便在暫時目錄中記錄和保留數個檢查點。
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, log_dir, max_to_keep=2)
現在,實作自訂訓練迴圈,以便在每個新週期開始時,於第一個週期後載入最後一個檢查點
for epoch in range(epochs):
if epoch > 0:
tf.train.load_checkpoint(save_path)
print(f"\nStart of epoch {epoch}")
for step in range(steps_per_epoch):
with tf.GradientTape() as tape:
logits = model(x_train, training=True)
loss_value = loss_fn(y_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
save_path = checkpoint_manager.save()
print(f"Checkpoint saved to {save_path}")
print(f"Training loss at step {step}: {loss_value}")
後續步驟
若要進一步瞭解 TensorFlow 2 中的容錯能力和檢查點,請參閱下列文件
tf.keras.callbacks.BackupAndRestore
回呼 API 文件。tf.train.Checkpoint
和tf.train.CheckpointManager
API 文件。- 訓練檢查點指南,包括「編寫檢查點」章節。
您也可以找到下列與 分散式訓練相關的資料,可能對您有所幫助
- 搭配 Keras 進行多工作站訓練教學課程中的「容錯能力」章節。
- 參數伺服器訓練教學課程中的「處理工作失敗」章節。