![]() |
![]() |
![]() |
![]() |
模型進度可以在訓練期間和訓練後儲存。這表示模型可以從中斷處繼續執行,並避免長時間的訓練。儲存也表示您可以分享模型,而其他人可以重現您的工作。在發布研究模型和技術時,大多數機器學習從業人員都會分享
- 用於建立模型的程式碼,以及
- 模型的已訓練權重或參數
分享這些資料有助於其他人瞭解模型運作方式,並自行使用新資料進行嘗試。
選項
儲存 TensorFlow 模型的方式會因您使用的 API 而異。本指南使用 tf.keras,這是一種用於在 TensorFlow 中建構和訓練模型的高階 API。本教學課程中使用的新高階 .keras
格式建議用於儲存 Keras 物件,因為它提供穩健、有效率的名稱型儲存機制,通常比低階或舊版格式更易於偵錯。如需更進階的儲存或序列化工作流程 (特別是涉及自訂物件的工作流程),請參閱儲存與載入 Keras 模型指南。如需其他方法,請參閱使用 SavedModel 格式指南。
設定
安裝與匯入
安裝並匯入 TensorFlow 和依附元件
pip install pyyaml h5py # Required to save models in HDF5 format
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
取得範例資料集
為了示範如何儲存和載入權重,您將使用 MNIST 資料集。為了加快執行速度,請使用前 1000 個範例
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
定義模型
從建構簡單的循序模型開始
# Define a simple sequential model
def create_model():
model = tf.keras.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
在訓練期間儲存檢查點
您可以使用已訓練的模型,而不必重新訓練,或者在訓練程序中斷時從中斷處繼續訓練。tf.keras.callbacks.ModelCheckpoint
回呼可讓您在訓練期間和訓練結束時持續儲存模型。
檢查點回呼用法
建立 tf.keras.callbacks.ModelCheckpoint
回呼,僅在訓練期間儲存權重
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
這會建立單一 TensorFlow 檢查點檔案集合,這些檔案會在每個 epoch 結束時更新
os.listdir(checkpoint_dir)
只要兩個模型共用相同的架構,您就可以在它們之間共用權重。因此,從僅限權重的檢查點還原模型時,請建立一個與原始模型架構相同的模型,然後設定其權重。
現在重建一個全新的未訓練模型,並在測試集上評估它。未訓練的模型將隨機執行 (準確度約 10%)
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
然後從檢查點載入權重並重新評估
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
檢查點回呼選項
此回呼提供多個選項,可為檢查點提供獨特的名稱,並調整檢查點頻率。
訓練新模型,並每五個 epoch 儲存一次具有獨特名稱的檢查點
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
# Calculate the number of batches per epoch
import math
n_batches = len(train_images) / batch_size
n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer
# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*n_batches)
# Create a new model instance
model = create_model()
# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
現在,檢閱產生的檢查點並選擇最新的檢查點
os.listdir(checkpoint_dir)
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
若要測試,請重設模型並載入最新的檢查點
# Create a new model instance
model = create_model()
# Load the previously saved weights
model.load_weights(latest)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
這些檔案是什麼?
上述程式碼將權重儲存到 檢查點格式檔案的集合中,這些檔案僅包含二進位格式的已訓練權重。檢查點包含
- 一個或多個包含模型權重的分片。
- 索引檔案,指出哪些權重儲存在哪個分片中。
如果您在單一機器上訓練模型,則會有一個分片,其後置字元為:.data-00000-of-00001
手動儲存權重
若要手動儲存權重,請使用 tf.keras.Model.save_weights
。根據預設,tf.keras
(尤其是 Model.save_weights
方法) 會使用 TensorFlow Checkpoint 格式,並加上 .ckpt
副檔名。若要以 HDF5 格式儲存並加上 .h5
副檔名,請參閱儲存與載入模型指南。
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
儲存整個模型
呼叫 tf.keras.Model.save
,將模型的架構、權重和訓練設定儲存到單一 model.keras
zip 封存檔中。
整個模型可以儲存為三種不同的檔案格式 (新的 .keras
格式和兩種舊版格式:SavedModel
和 HDF5
)。將模型儲存為 path/to/model.keras
會自動以最新格式儲存。
您可以透過以下方式切換至 SavedModel 格式:
- 將
save_format='tf'
傳遞至save()
- 傳遞不含副檔名的檔案名稱
您可以透過以下方式切換至 H5 格式:
- 將
save_format='h5'
傳遞至save()
- 傳遞以
.h5
結尾的檔案名稱
儲存功能完善的模型非常有用,您可以將其載入 TensorFlow.js (Saved Model、HDF5),然後在網頁瀏覽器中訓練和執行,或使用 TensorFlow Lite 將其轉換為在行動裝置上執行 (Saved Model、HDF5)
*自訂物件 (例如,子類別化的模型或層) 在儲存和載入時需要特別注意。請參閱下方的儲存自訂物件章節。
新的高階 .keras
格式
新的 Keras v3 儲存格式以 .keras
副檔名標示,是一種更簡單、有效率的格式,可實作名稱型儲存,確保從 Python 的角度來看,您載入的內容與您儲存的內容完全相同。這讓偵錯變得更加容易,而且是 Keras 的建議格式。
以下章節說明如何在 .keras
格式中儲存和還原模型。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a `.keras` zip archive.
model.save('my_model.keras')
從 .keras
zip 封存檔重新載入全新的 Keras 模型
new_model = tf.keras.models.load_model('my_model.keras')
# Show the model architecture
new_model.summary()
嘗試使用已載入的模型執行評估和預測
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
SavedModel 格式
SavedModel 格式是另一種序列化模型的方式。以這種格式儲存的模型可以使用 tf.keras.models.load_model
還原,並且與 TensorFlow Serving 相容。SavedModel 指南詳細說明如何 serve/inspect
SavedModel。以下章節說明儲存和還原模型的步驟。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
SavedModel 格式是一個目錄,其中包含 protobuf 二進位檔和 TensorFlow 檢查點。檢查已儲存的模型目錄
# my_model directory
ls saved_model
# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
從已儲存的模型重新載入全新的 Keras 模型
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
還原的模型會使用與原始模型相同的引數進行編譯。嘗試使用已載入的模型執行評估和預測
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
HDF5 格式
Keras 提供基本舊版高階儲存格式,使用 HDF5 標準。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
現在,從該檔案重新建立模型
# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')
# Show the model architecture
new_model.summary()
檢查其準確度
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
Keras 透過檢查模型的架構來儲存模型。此技術會儲存所有內容
- 權重值
- 模型的架構
- 模型的訓練設定 (您傳遞至
.compile()
方法的內容) - 最佳化工具及其狀態 (如果有的話) (這可讓您從中斷處重新開始訓練)
Keras 無法儲存 v1.x
最佳化工具 (來自 tf.compat.v1.train
),因為它們與檢查點不相容。對於 v1.x 最佳化工具,您需要在載入後重新編譯模型,這會導致最佳化工具的狀態遺失。
儲存自訂物件
如果您使用 SavedModel 格式,則可以略過本節。高階 .keras
/HDF5 格式與低階 SavedModel 格式之間的關鍵差異在於,.keras
/HDF5 格式使用物件組態來儲存模型架構,而 SavedModel 則儲存執行圖。因此,SavedModel 能夠儲存自訂物件,例如子類別化的模型和自訂層,而無需原始程式碼。但是,偵錯低階 SavedModel 可能會因此變得更加困難,而且我們建議改用高階 .keras
格式,因為它具有以名稱為基礎的 Keras 原生特性。
若要將自訂物件儲存到 .keras
和 HDF5,您必須執行以下操作
- 在您的物件中定義
get_config
方法,並選擇性地定義from_config
類別方法。get_config(self)
會傳回 JSON 可序列化字典,其中包含重建物件所需的參數。from_config(cls, config)
會使用從get_config
傳回的組態來建立新物件。根據預設,此函式會將組態用作初始化 kwargs (return cls(**config)
)。
- 透過以下三種方式之一將自訂物件傳遞至模型
- 使用
@tf.keras.utils.register_keras_serializable
裝飾器註冊自訂物件。(建議) - 在載入模型時,將物件直接傳遞至
custom_objects
引數。引數必須是字典,將字串類別名稱對應至 Python 類別。例如,tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
- 使用
tf.keras.utils.custom_object_scope
,其中物件包含在custom_objects
字典引數中,並在範圍內放置tf.keras.models.load_model(path)
呼叫。
- 使用
如需自訂物件和 get_config
的範例,請參閱從頭開始撰寫層和模型教學課程。
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.