![]() |
![]() |
![]() |
![]() |
「儲存 TensorFlow 模型」這個詞組通常意指以下兩種情況之一
- 檢查點,或
- SavedModel。
檢查點會擷取模型使用的所有參數 (tf.Variable
物件) 的確切值。檢查點不包含模型定義的任何運算說明,因此通常僅在可使用已儲存參數值的原始碼時才有用。
另一方面,SavedModel 格式除了參數值 (檢查點) 之外,還包含模型定義的運算序列化說明。此格式的模型與建立模型的原始碼無關。因此,這些模型適合透過 TensorFlow Serving、TensorFlow Lite、TensorFlow.js 或其他程式設計語言 (C、C++、Java、Go、Rust、C# 等 TensorFlow API) 的程式進行部署。
本指南涵蓋用於寫入和讀取檢查點的 API。
設定
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
從 tf.keras
訓練 API 儲存
請參閱 tf.keras
儲存和還原指南。
tf.keras.Model.save_weights
會儲存 TensorFlow 檢查點。
net.save_weights('easy_checkpoint')
寫入檢查點
TensorFlow 模型的持續狀態儲存在 tf.Variable
物件中。這些物件可以直接建構,但通常透過高階 API 建立,例如 tf.keras.layers
或 tf.keras.Model
。
管理變數最簡單的方式是將變數附加至 Python 物件,然後參照這些物件。
tf.train.Checkpoint
、tf.keras.layers.Layer
和 tf.keras.Model
的子類別會自動追蹤指派給其屬性的變數。以下範例會建構簡單的線性模型,然後寫入包含模型所有變數值的檢查點。
您可以使用 Model.save_weights
輕鬆儲存模型檢查點。
手動檢查點
設定
為了協助示範 tf.train.Checkpoint
的所有功能,請定義玩具資料集和最佳化步驟
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
建立檢查點物件
使用 tf.train.Checkpoint
物件手動建立檢查點,您要檢查點的物件會設定為物件的屬性。
tf.train.CheckpointManager
也可用於管理多個檢查點。
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
訓練模型並建立檢查點
下列訓練迴圈會建立模型和最佳化工具的執行個體,然後將它們收集到 tf.train.Checkpoint
物件中。它會在每個批次的資料上迴圈呼叫訓練步驟,並定期將檢查點寫入磁碟。
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
還原並繼續訓練
在第一個訓練週期之後,您可以傳遞新的模型和管理員,但可以從上次停止的位置繼續訓練
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
tf.train.CheckpointManager
物件會刪除舊的檢查點。在上方,它設定為僅保留最近的三個檢查點。
print(manager.checkpoints) # List the three remaining checkpoints
這些路徑 (例如 './tf_ckpts/ckpt-10'
) 不是磁碟上的檔案。相反地,它們是 index
檔案和一或多個包含變數值的資料檔案的前置字串。這些前置字串會群組在單一 checkpoint
檔案 ('./tf_ckpts/checkpoint'
) 中,CheckpointManager
會在其中儲存其狀態。
ls ./tf_ckpts
載入機制
TensorFlow 會透過周遊具有命名邊緣的有向圖,從載入的物件開始,將變數與檢查點值進行比對。邊緣名稱通常來自物件中的屬性名稱,例如 self.l1 = tf.keras.layers.Dense(5)
中的 "l1"
。tf.train.Checkpoint
使用其關鍵字引數名稱,如 tf.train.Checkpoint(step=...)
中的 "step"
。
上方範例中的依附關係圖如下所示
最佳化工具為紅色,一般變數為藍色,而最佳化工具插槽變數為橘色。其他節點 (例如,代表 tf.train.Checkpoint
) 為黑色。
插槽變數是最佳化工具狀態的一部分,但針對特定變數建立。例如,上方的 'm'
邊緣對應於動量,Adam 最佳化工具會針對每個變數追蹤動量。只有在變數和最佳化工具都將儲存時,插槽變數才會儲存在檢查點中,因此邊緣為虛線。
在 tf.train.Checkpoint
物件上呼叫 restore
,會將要求的還原排入佇列,並在 Checkpoint
物件中找到相符路徑後立即還原變數值。例如,您可以透過網路和層重建單一路徑到偏差值,從上方定義的模型中僅載入偏差值。
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
這些新物件的依附關係圖是您上方寫入的較大檢查點的較小子圖。它僅包含偏差值和 tf.train.Checkpoint
用於編號檢查點的儲存計數器。
restore
會傳回狀態物件,其中具有選用判斷提示。新 Checkpoint
中建立的所有物件都已還原,因此 status.assert_existing_objects_matched
會通過。
status.assert_existing_objects_matched()
檢查點中有許多物件尚未比對,包括層的核心和最佳化工具的變數。status.assert_consumed
僅在檢查點和程式完全相符時才會通過,並且在此處會擲回例外狀況。
延遲還原
TensorFlow 中的 Layer
物件可能會將變數的建立延遲到第一次呼叫時 (當輸入形狀可用時)。例如,Dense
層核心的形狀取決於層的輸入和輸出形狀,因此作為建構函式引數所需的輸出形狀不足以自行建立變數。由於呼叫 Layer
也會讀取變數的值,因此還原必須在變數的建立及其第一次使用之間發生。
為了支援此慣用語,tf.train.Checkpoint
會延遲尚無相符變數的還原。
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
手動檢查檢查點
tf.train.load_checkpoint
會傳回 CheckpointReader
,其提供對檢查點內容的較低層級存取權。它包含從每個變數的金鑰到檢查點中每個變數的形狀和 dtype 的對應。變數的金鑰是其物件路徑,如上方顯示的圖形所示。
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
因此,如果您對 net.l1.kernel
的值感興趣,您可以使用下列程式碼取得該值
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
它也提供 get_tensor
方法,讓您可以檢查變數的值
reader.get_tensor(key)
物件追蹤
檢查點會透過「追蹤」在其屬性中設定的任何變數或可追蹤物件,來儲存和還原 tf.Variable
物件的值。執行儲存時,變數會從所有可到達的追蹤物件以遞迴方式收集。
與直接屬性指派 (例如 self.l1 = tf.keras.layers.Dense(5)
) 一樣,將清單和字典指派給屬性將會追蹤其內容。
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
您可能會注意到清單和字典的包裝函式物件。這些包裝函式是基礎資料結構的可檢查點版本。就像基於屬性的載入一樣,這些包裝函式會在變數的值新增至容器後立即還原。
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
可追蹤物件包括 tf.train.Checkpoint
、tf.Module
及其子類別 (例如 keras.layers.Layer
和 keras.Model
) 以及已辨識的 Python 容器
dict
(和collections.OrderedDict
)list
tuple
(和collections.namedtuple
、typing.NamedTuple
)
不支援其他容器類型,包括
collections.defaultdict
set
所有其他 Python 物件都會遭到忽略,包括
int
string
float
摘要
TensorFlow 物件提供簡單的自動機制,用於儲存和還原它們使用的變數值。