Author: fchollet
Introduction
Generally, there are two ways to distribute computation across multiple devices:
- Data parallelism: a single model is replicated on multiple devices or multiple machines. Each of them processes different batches of data, then they merge their results. There are many variants of this setup, which differ in how the different model replicas merge results, in whether they stay in sync at every batch, or whether they are more loosely coupled, etc.
- Model parallelism: different parts of a single model run on different devices, processing a single batch of data together. This works best for models with a naturally-parallel architecture, such as models that feature multiple branches.
This guide focuses on data parallelism, in particular synchronous data parallelism, where the different replicas of the model stay in sync after each batch they process. Synchronicity keeps the model convergence behavior identical to what you would see for single-device training.
Specifically, this guide teaches you how to use the tf.distribute API to train Keras models on multiple GPUs, with minimal changes to your code, in the following two setups:
- On multiple GPUs (typically 2 to 8) installed on a single machine (single host, multi-device training). This is the most common setup for researchers and small-scale industry workflows.
- On a cluster of many machines, each hosting one or multiple GPUs (multi-worker distributed training). This is a good setup for large-scale industry workflows, e.g. training high-resolution image classification models on tens of millions of images using 20-100 GPUs.
Setup
import tensorflow as tf
import keras
Single-host, multi-device synchronous training
In this setup, you have one machine with several GPUs on it (typically 2 to 8). Each device will run a copy of your model (called a replica). For simplicity, in what follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.
How it works
At each step of training:
- The current batch of data (called the global batch) is split into 8 different sub-batches (called local batches). For instance, if the global batch has 512 samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: it runs a forward pass, then a backward pass, outputting the gradient of the weights with respect to the loss of the model on the local batch.
- The weight updates originating from the local gradients are efficiently merged across the 8 replicas. Because this is done at the end of every step, the replicas always stay in sync.
In practice, the process of synchronously updating the weights of the model replicas is handled at the level of each individual weight variable. This is done through a mirrored variable object.
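You can see this by creating a variable inside the strategy scope and inspecting its type; a minimal sketch, assuming at least one GPU is visible (shown purely for illustration):

# Variables created inside the scope are mirrored across replicas.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(1.0)
# `v` is kept as one synchronized copy per replica.
print(type(v).__name__)  # typically prints "MirroredVariable"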
How to use it
To do single-host, multi-device synchronous training with a Keras model, you would use the tf.distribute.MirroredStrategy API. Here's how it works:
- Instantiate a MirroredStrategy, optionally configuring which specific devices you want to use (by default the strategy will use all GPUs available).
- Use the strategy object to open a scope, and within this scope, create all the Keras objects you need that contain variables. Typically, that means creating & compiling the model inside the distribution scope.
- Train the model via fit() as usual.
Importantly, we recommend that you use tf.data.Dataset objects to load data in a multi-device or distributed workflow.
Schematically, it looks like this:
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = Model(...)
    model.compile(...)
# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)
# Test the model on all available devices.
model.evaluate(test_dataset)
Here's a simple end-to-end runnable example:
def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model
def get_dataset():
    batch_size = 32
    num_val_samples = 10000
    # Return the MNIST dataset in the form of a `tf.data.Dataset`.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")
    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()
# Train the model on all available devices.
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)
# Test the model on all available devices.
model.evaluate(test_dataset)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Number of devices: 4 2023-07-19 11:35:32.379801: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 50000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } Epoch 1/2 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 1556/1563 [============================>.] - ETA: 0s - loss: 0.2236 - sparse_categorical_accuracy: 0.9328INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 
2023-07-19 11:35:46.769935: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:2" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } 1563/1563 [==============================] - 16s 7ms/step - loss: 0.2238 - sparse_categorical_accuracy: 0.9328 - val_loss: 0.1347 - val_sparse_categorical_accuracy: 0.9592 Epoch 2/2 1563/1563 [==============================] - 11s 7ms/step - loss: 0.0940 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0984 - val_sparse_categorical_accuracy: 0.9684 2023-07-19 11:35:59.993148: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:4" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } 313/313 [==============================] - 2s 4ms/step - loss: 0.1057 - sparse_categorical_accuracy: 0.9676 [0.10571097582578659, 0.9675999879837036]
Using callbacks to ensure fault tolerance
When using distributed training, you should always make sure you have a strategy to recover from failure (fault tolerance). The simplest way to handle this is to pass a ModelCheckpoint callback to fit(), to save your model at regular intervals (e.g. every 100 batches or every epoch). You can then restart training from your saved model.
Here's a simple example:
import os
from tensorflow import keras
# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
def make_or_restore_model():
    # Either restore the latest model, or create a fresh one
    # if there is no checkpoint available.
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_compiled_model()
def run_training(epochs=1):
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()
    # Open a strategy scope and create/restore the model
    with strategy.scope():
        model = make_or_restore_model()

        callbacks = [
            # This callback saves a SavedModel every epoch
            # We include the current epoch in the folder name.
            keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_dir + "/ckpt-{epoch}", save_freq="epoch"
            )
        ]
        model.fit(
            train_dataset,
            epochs=epochs,
            callbacks=callbacks,
            validation_data=val_dataset,
            verbose=2,
        )
# Running the first time creates the model
run_training(epochs=1)
# Calling the same function again will resume from where we left off
run_training(epochs=1)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Creating a new model 2023-07-19 11:36:01.811216: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 50000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 2023-07-19 11:36:13.671835: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:2" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets 1563/1563 - 14s - loss: 0.2268 - sparse_categorical_accuracy: 0.9322 - val_loss: 0.1148 - val_sparse_categorical_accuracy: 0.9656 - 14s/epoch - 9ms/step INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Restoring from ./ckpt/ckpt-1 2023-07-19 11:36:16.521031: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: 
"TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 50000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 2023-07-19 11:36:28.440092: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:2" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets 1563/1563 - 13s - loss: 0.0974 - sparse_categorical_accuracy: 0.9703 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9724 - 13s/epoch - 9ms/step
tf.data performance tips
When doing distributed training, the efficiency with which you load data can often become critical. Here are a few tips to make sure your tf.data pipelines run as fast as possible.
Note about dataset batching
When creating your dataset, make sure it is batched with the global batch size. For instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you should use a global batch size of 512.
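As a quick sketch of that arithmetic, assuming a MirroredStrategy has already been instantiated as strategy and that x_train / y_train stand in for your own data (the per-replica batch size of 64 is purely illustrative):

per_replica_batch_size = 64  # what a single GPU can handle (illustrative)
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync  # 512 with 8 GPUs

# Batch the dataset with the global batch size; tf.distribute splits each batch across replicas.
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(global_batch_size)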
Calling dataset.cache()
If you call .cache() on a dataset, its data will be cached after running through the first iteration over the data. Every subsequent iteration will use the cached data. The cache can be in memory (the default) or in a local file you specify.
This can improve performance when:
- Your data is not expected to change from iteration to iteration
- You are reading data from a remote distributed filesystem
- You are reading data from local disk, but your data would fit in memory and your workflow is significantly IO-bound (e.g. reading & decoding image files).
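As a sketch of where .cache() fits in such a pipeline (file_dataset and the parse_image map function are hypothetical placeholders for your own data source and decoding step):

# Hypothetical pipeline: `file_dataset` yields image file paths, `parse_image` reads & decodes them.
dataset = file_dataset.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()  # after the first epoch, reuse decoded samples instead of re-reading files
dataset = dataset.batch(global_batch_size)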
Calling dataset.prefetch(buffer_size)
You should almost always call .prefetch(buffer_size) after creating a dataset. It means your data pipeline will run asynchronously from your model, with new samples being preprocessed and stored in a buffer while the current batch samples are used to train the model. The next batch will be prefetched into GPU memory by the time the current batch is over.
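For instance, as the last transformation in the pipeline (a minimal sketch, reusing the global_batch_size value from the earlier snippet):

dataset = dataset.batch(global_batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)  # let tf.data tune the buffer size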
Multi-worker distributed synchronous training
How it works
In this setup, you have multiple machines (called workers), each with one or several GPUs on them. Much like what happens for single-host training, each available GPU will run one model replica, and the value of the variables of each replica is kept in sync after each batch.
Importantly, the current implementation assumes that all workers have the same number of GPUs (homogeneous cluster).
How to use it
- Set up a cluster (we provide pointers below).
- Set up an appropriate TF_CONFIG environment variable on each worker. This tells the worker what its role is and how to communicate with its peers.
- On each worker, run your model construction & compilation code within the scope of a MultiWorkerMirroredStrategy object, similarly to how we did it for single-host training.
- Run evaluation code on a designated evaluator machine.
Setting up the cluster
First, set up a cluster (collection of machines). Each machine should be set up individually so as to be able to run your model (typically, each machine will run the same Docker image) and to be able to access your data source (e.g. GCS).
Cluster management is beyond the scope of this guide. The following documentation can help you get started. You can also take a look at Kubeflow.
Setting the TF_CONFIG environment variable
While the code that runs on each worker is almost the same as the code used in the single-host workflow (except with a different tf.distribute strategy object), one significant difference between the single-host workflow and the multi-worker workflow is that you need to set a TF_CONFIG environment variable on each machine running in your cluster.
The TF_CONFIG environment variable is a JSON string that specifies:
- The cluster configuration, with the list of addresses & ports of the machines that make up the cluster
- The worker's "task", which is the role that this specific machine has to play within the cluster.
Here is one example of TF_CONFIG:
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
In the multi-worker synchronous training setup, valid roles (task types) for the machines are "worker" and "evaluator".
For example, if you have 8 machines with 4 GPUs each, you could have 7 workers and one evaluator.
- The workers train the model, each one processing sub-batches of a global batch.
- One of the workers (worker 0) will serve as the "chief", a particular kind of worker that is responsible for saving logs and checkpoints for later reuse (typically to a Cloud storage location).
- The evaluator runs a continuous loop that loads the latest checkpoint saved by the chief worker, runs evaluation on it (asynchronously from the other workers), and writes evaluation logs (e.g. TensorBoard logs).
Running code on each worker
You would run training code on each worker (including the chief) and evaluation code on the evaluator.
The training code is basically the same as what you would use in the single-host setup, except using MultiWorkerMirroredStrategy instead of MirroredStrategy.
Each worker will run the same code (minus the difference explained in the note below), including the same callbacks.
The evaluator would simply use MirroredStrategy (since it runs on a single machine and does not need to communicate with other machines) and call model.evaluate(). It would load the latest checkpoint saved by the chief worker to a Cloud storage location, and would save evaluation logs to the same location as the chief logs.
Example: code running in a multi-worker setup
On the chief (worker 0):
# Set TF_CONFIG
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})

# Open a strategy scope and create/restore the model.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = make_or_restore_model()

callbacks = [
    # This callback saves a SavedModel every 100 batches
    keras.callbacks.ModelCheckpoint(filepath='path/to/cloud/location/ckpt',
                                    save_freq=100),
    keras.callbacks.TensorBoard('path/to/cloud/location/tb/')
]
model.fit(train_dataset,
          callbacks=callbacks,
          ...)
On other workers:
# Set TF_CONFIG
worker_index = 1 # For instance
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': worker_index}
})

# Open a strategy scope and create/restore the model.
# You can restore from the checkpoint saved by the chief.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = make_or_restore_model()

callbacks = [
    keras.callbacks.ModelCheckpoint(filepath='local/path/ckpt', save_freq=100),
    keras.callbacks.TensorBoard('local/path/tb/')
]
model.fit(train_dataset,
          callbacks=callbacks,
          ...)
On the evaluator:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = make_or_restore_model()  # Restore from the checkpoint saved by the chief.
    results = model.evaluate(val_dataset)
    # Then, log the results on a shared location, write TensorBoard logs, etc
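The snippet above runs a single evaluation pass. In practice, the evaluator would typically wrap this in a continuous loop that watches for new checkpoints written by the chief. A minimal sketch of such a loop, where the checkpoint directory path and the polling interval are illustrative assumptions:

import os
import time

checkpoint_dir = "path/to/cloud/location"  # where the chief saves its checkpoints (illustrative)
last_evaluated = None

strategy = tf.distribute.MirroredStrategy()
while True:
    # Pick the most recently written checkpoint, if any.
    checkpoints = [os.path.join(checkpoint_dir, name) for name in os.listdir(checkpoint_dir)]
    latest = max(checkpoints, key=os.path.getctime) if checkpoints else None
    if latest is not None and latest != last_evaluated:
        with strategy.scope():
            model = keras.models.load_model(latest)  # restore the chief's latest checkpoint
        results = model.evaluate(val_dataset)
        # Then, log `results` to a shared location, write TensorBoard logs, etc.
        last_evaluated = latest
    time.sleep(60)  # polling interval (illustrative)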