使用分散式策略儲存及載入模型

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

總覽

本教學課程示範如何在訓練期間或訓練後,使用 tf.distribute.Strategy,以 SavedModel 格式儲存及載入模型。儲存及載入 Keras 模型有兩種 API:高階 (tf.keras.Model.savetf.keras.models.load_model) 和低階 (tf.saved_model.savetf.saved_model.load)。

如要進一步瞭解 SavedModel 和一般序列化,請參閱SavedModel 指南Keras 模型序列化指南。我們先從簡單範例開始。

匯入依附元件

import tensorflow_datasets as tfds

import tensorflow as tf

使用 TensorFlow Datasets 和 tf.data 載入及準備資料,並使用 tf.distribute.MirroredStrategy 建立模型

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model

使用 tf.keras.Model.fit 訓練模型

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)

儲存及載入模型

現在您已有簡單的模型可供使用,接下來探索儲存/載入 API。有兩種 API 可供使用

Keras API

以下範例說明如何使用 Keras API 儲存及載入模型

keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)

還原模型,但不使用 tf.distribute.Strategy

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

還原模型後,您可以繼續訓練,甚至不必再次呼叫 Model.compile,因為模型在儲存前已編譯過。模型會以 Keras zip 封存格式儲存,並以 .keras 副檔名標示。詳情請參閱Keras 儲存指南

現在,還原模型並使用 tf.distribute.Strategy 訓練模型

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

Model.fit 輸出所示,載入作業可如預期般搭配 tf.distribute.Strategy 運作。此處使用的策略不必與儲存前使用的策略相同。

tf.saved_model API

使用較低階 API 儲存模型的方式與 Keras API 類似

model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)

載入作業可以使用 tf.saved_model.load 完成。不過,由於這是較低階的 API (因此有更廣泛的用途),因此不會傳回 Keras 模型。而是會傳回包含可用於進行推論的函式的物件。例如

DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

載入的物件可能包含多個函式,每個函式都與一個金鑰相關聯。"serving_default" 金鑰是使用已儲存 Keras 模型進行推論函式的預設金鑰。如要使用此函式進行推論

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))

您也可以分散式方式載入及進行推論

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break

呼叫還原的函式只是對已儲存模型進行正向傳遞 (tf.keras.Model.predict)。如果您想繼續訓練載入的函式,該怎麼辦?或者,如果您需要將載入的函式嵌入較大的模型中,又該怎麼辦?常見的做法是將這個載入的物件包裝在 Keras 層中,以達成此目的。幸好,TF Hub 具有 hub.KerasLayer 可用於此目的,如下所示

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)

在上述範例中,Tensorflow Hub 的 hub.KerasLayer 將從 tf.saved_model.load 載入回來的結果包裝到用於建構另一個模型的 Keras 層中。這對於遷移學習非常有用。

我應該使用哪個 API?

在儲存方面,如果您使用 Keras 模型,請使用 Keras Model.save API,除非您需要低階 API 允許的其他控制項。如果您儲存的不是 Keras 模型,則較低階的 API tf.saved_model.save 是您唯一的選擇。

在載入方面,您的 API 選擇取決於您想從模型載入 API 取得什麼。如果您無法 (或不想) 取得 Keras 模型,請使用 tf.saved_model.load。否則,請使用 tf.keras.models.load_model。請注意,只有在您儲存 Keras 模型時,才能取回 Keras 模型。

您可以混用 API。您可以使用 Model.save 儲存 Keras 模型,並使用低階 API tf.saved_model.load 載入非 Keras 模型。

model = get_model()

# Saving the model using Keras `Model.save`
model.save(saved_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)

從本機裝置儲存/載入

從本機 I/O 裝置儲存及載入,同時在遠端裝置上訓練時 (例如,使用 Cloud TPU 時),您必須在 tf.saved_model.SaveOptionstf.saved_model.LoadOptions 中使用選項 experimental_io_device,將 I/O 裝置設定為 localhost。例如

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

注意事項

一種特殊情況是,您以特定方式建立 Keras 模型,然後在訓練前儲存模型。例如

class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(saved_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)

SavedModel 會儲存追蹤 tf.function 時產生的 tf.types.experimental.ConcreteFunction 物件 (如要進一步瞭解,請查看圖表和 tf.function 簡介中的「函式何時追蹤?」)。如果您收到類似這樣的 ValueError,表示 Model.save 無法找到或建立追蹤的 ConcreteFunction

tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures

通常,模型的正向傳遞 (call 方法) 會在第一次呼叫模型時自動追蹤,通常是透過 Keras Model.fit 方法。如果您設定輸入形狀,例如將第一層設為 tf.keras.layers.InputLayer 或其他圖層類型,並將 input_shape 關鍵字引數傳遞給它,則 Keras SequentialFunctional API 也可以產生 ConcreteFunction

如要驗證您的模型是否有任何追蹤的 ConcreteFunction,請檢查 Model.save_spec 是否為 None

print(my_model.save_spec() is None)

我們使用 tf.keras.Model.fit 訓練模型,並注意 save_spec 已定義,且模型儲存作業可正常運作

BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)
my_model.save(saved_model_path)