儲存與載入模型

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

模型進度可以在訓練期間和訓練後儲存。這表示模型可以從中斷處繼續執行,並避免長時間的訓練。儲存也表示您可以分享模型,而其他人可以重現您的工作。在發布研究模型和技術時,大多數機器學習從業人員都會分享

  • 用於建立模型的程式碼,以及
  • 模型的已訓練權重或參數

分享這些資料有助於其他人瞭解模型運作方式,並自行使用新資料進行嘗試。

選項

儲存 TensorFlow 模型的方式會因您使用的 API 而異。本指南使用 tf.keras,這是一種用於在 TensorFlow 中建構和訓練模型的高階 API。本教學課程中使用的新高階 .keras 格式建議用於儲存 Keras 物件,因為它提供穩健、有效率的名稱型儲存機制,通常比低階或舊版格式更易於偵錯。如需更進階的儲存或序列化工作流程 (特別是涉及自訂物件的工作流程),請參閱儲存與載入 Keras 模型指南。如需其他方法,請參閱使用 SavedModel 格式指南

設定

安裝與匯入

安裝並匯入 TensorFlow 和依附元件

pip install pyyaml h5py  # Required to save models in HDF5 format
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)

取得範例資料集

為了示範如何儲存和載入權重,您將使用 MNIST 資料集。為了加快執行速度,請使用前 1000 個範例

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

定義模型

從建構簡單的循序模型開始

# Define a simple sequential model
def create_model():
  model = tf.keras.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()

在訓練期間儲存檢查點

您可以使用已訓練的模型,而不必重新訓練,或者在訓練程序中斷時從中斷處繼續訓練。tf.keras.callbacks.ModelCheckpoint 回呼可讓您在訓練期間和訓練結束時持續儲存模型。

檢查點回呼用法

建立 tf.keras.callbacks.ModelCheckpoint 回呼,僅在訓練期間儲存權重

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.

這會建立單一 TensorFlow 檢查點檔案集合,這些檔案會在每個 epoch 結束時更新

os.listdir(checkpoint_dir)

只要兩個模型共用相同的架構,您就可以在它們之間共用權重。因此,從僅限權重的檢查點還原模型時,請建立一個與原始模型架構相同的模型,然後設定其權重。

現在重建一個全新的未訓練模型,並在測試集上評估它。未訓練的模型將隨機執行 (準確度約 10%)

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))

然後從檢查點載入權重並重新評估

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

檢查點回呼選項

此回呼提供多個選項,可為檢查點提供獨特的名稱,並調整檢查點頻率。

訓練新模型,並每五個 epoch 儲存一次具有獨特名稱的檢查點

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Calculate the number of batches per epoch
import math
n_batches = len(train_images) / batch_size
n_batches = math.ceil(n_batches)    # round up the number of batches to the nearest whole integer

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=5*n_batches)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)

現在,檢閱產生的檢查點並選擇最新的檢查點

os.listdir(checkpoint_dir)
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest

若要測試,請重設模型並載入最新的檢查點

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

這些檔案是什麼?

上述程式碼將權重儲存到 檢查點格式檔案的集合中,這些檔案僅包含二進位格式的已訓練權重。檢查點包含

  • 一個或多個包含模型權重的分片。
  • 索引檔案,指出哪些權重儲存在哪個分片中。

如果您在單一機器上訓練模型,則會有一個分片,其後置字元為:.data-00000-of-00001

手動儲存權重

若要手動儲存權重,請使用 tf.keras.Model.save_weights。根據預設,tf.keras (尤其是 Model.save_weights 方法) 會使用 TensorFlow Checkpoint 格式,並加上 .ckpt 副檔名。若要以 HDF5 格式儲存並加上 .h5 副檔名,請參閱儲存與載入模型指南。

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

儲存整個模型

呼叫 tf.keras.Model.save,將模型的架構、權重和訓練設定儲存到單一 model.keras zip 封存檔中。

整個模型可以儲存為三種不同的檔案格式 (新的 .keras 格式和兩種舊版格式:SavedModelHDF5)。將模型儲存為 path/to/model.keras 會自動以最新格式儲存。

您可以透過以下方式切換至 SavedModel 格式:

  • save_format='tf' 傳遞至 save()
  • 傳遞不含副檔名的檔案名稱

您可以透過以下方式切換至 H5 格式:

  • save_format='h5' 傳遞至 save()
  • 傳遞以 .h5 結尾的檔案名稱

儲存功能完善的模型非常有用,您可以將其載入 TensorFlow.js (Saved ModelHDF5),然後在網頁瀏覽器中訓練和執行,或使用 TensorFlow Lite 將其轉換為在行動裝置上執行 (Saved ModelHDF5)

*自訂物件 (例如,子類別化的模型或層) 在儲存和載入時需要特別注意。請參閱下方的儲存自訂物件章節。

新的高階 .keras 格式

新的 Keras v3 儲存格式以 .keras 副檔名標示,是一種更簡單、有效率的格式,可實作名稱型儲存,確保從 Python 的角度來看,您載入的內容與您儲存的內容完全相同。這讓偵錯變得更加容易,而且是 Keras 的建議格式。

以下章節說明如何在 .keras 格式中儲存和還原模型。

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a `.keras` zip archive.
model.save('my_model.keras')

.keras zip 封存檔重新載入全新的 Keras 模型

new_model = tf.keras.models.load_model('my_model.keras')

# Show the model architecture
new_model.summary()

嘗試使用已載入的模型執行評估和預測

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)

SavedModel 格式

SavedModel 格式是另一種序列化模型的方式。以這種格式儲存的模型可以使用 tf.keras.models.load_model 還原,並且與 TensorFlow Serving 相容。SavedModel 指南詳細說明如何 serve/inspect SavedModel。以下章節說明儲存和還原模型的步驟。

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')

SavedModel 格式是一個目錄,其中包含 protobuf 二進位檔和 TensorFlow 檢查點。檢查已儲存的模型目錄

# my_model directory
ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model

從已儲存的模型重新載入全新的 Keras 模型

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()

還原的模型會使用與原始模型相同的引數進行編譯。嘗試使用已載入的模型執行評估和預測

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)

HDF5 格式

Keras 提供基本舊版高階儲存格式,使用 HDF5 標準。

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')

現在,從該檔案重新建立模型

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()

檢查其準確度

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

Keras 透過檢查模型的架構來儲存模型。此技術會儲存所有內容

  • 權重值
  • 模型的架構
  • 模型的訓練設定 (您傳遞至 .compile() 方法的內容)
  • 最佳化工具及其狀態 (如果有的話) (這可讓您從中斷處重新開始訓練)

Keras 無法儲存 v1.x 最佳化工具 (來自 tf.compat.v1.train),因為它們與檢查點不相容。對於 v1.x 最佳化工具,您需要在載入後重新編譯模型,這會導致最佳化工具的狀態遺失。

儲存自訂物件

如果您使用 SavedModel 格式,則可以略過本節。高階 .keras/HDF5 格式與低階 SavedModel 格式之間的關鍵差異在於,.keras/HDF5 格式使用物件組態來儲存模型架構,而 SavedModel 則儲存執行圖。因此,SavedModel 能夠儲存自訂物件,例如子類別化的模型和自訂層,而無需原始程式碼。但是,偵錯低階 SavedModel 可能會因此變得更加困難,而且我們建議改用高階 .keras 格式,因為它具有以名稱為基礎的 Keras 原生特性。

若要將自訂物件儲存到 .keras 和 HDF5,您必須執行以下操作

  1. 在您的物件中定義 get_config 方法,並選擇性地定義 from_config 類別方法。
    • get_config(self) 會傳回 JSON 可序列化字典,其中包含重建物件所需的參數。
    • from_config(cls, config) 會使用從 get_config 傳回的組態來建立新物件。根據預設,此函式會將組態用作初始化 kwargs (return cls(**config))。
  2. 透過以下三種方式之一將自訂物件傳遞至模型

如需自訂物件和 get_config 的範例,請參閱從頭開始撰寫層和模型教學課程。

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.