訓練檢查點

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

「儲存 TensorFlow 模型」這個詞組通常意指以下兩種情況之一

  1. 檢查點,或
  2. SavedModel。

檢查點會擷取模型使用的所有參數 (tf.Variable 物件) 的確切值。檢查點不包含模型定義的任何運算說明,因此通常僅在可使用已儲存參數值的原始碼時才有用。

另一方面,SavedModel 格式除了參數值 (檢查點) 之外,還包含模型定義的運算序列化說明。此格式的模型與建立模型的原始碼無關。因此,這些模型適合透過 TensorFlow Serving、TensorFlow Lite、TensorFlow.js 或其他程式設計語言 (C、C++、Java、Go、Rust、C# 等 TensorFlow API) 的程式進行部署。

本指南涵蓋用於寫入和讀取檢查點的 API。

設定

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras 訓練 API 儲存

請參閱 tf.keras 儲存和還原指南。

tf.keras.Model.save_weights 會儲存 TensorFlow 檢查點。

net.save_weights('easy_checkpoint')

寫入檢查點

TensorFlow 模型的持續狀態儲存在 tf.Variable 物件中。這些物件可以直接建構,但通常透過高階 API 建立,例如 tf.keras.layerstf.keras.Model

管理變數最簡單的方式是將變數附加至 Python 物件,然後參照這些物件。

tf.train.Checkpointtf.keras.layers.Layertf.keras.Model 的子類別會自動追蹤指派給其屬性的變數。以下範例會建構簡單的線性模型,然後寫入包含模型所有變數值的檢查點。

您可以使用 Model.save_weights 輕鬆儲存模型檢查點。

手動檢查點

設定

為了協助示範 tf.train.Checkpoint 的所有功能,請定義玩具資料集和最佳化步驟

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

建立檢查點物件

使用 tf.train.Checkpoint 物件手動建立檢查點,您要檢查點的物件會設定為物件的屬性。

tf.train.CheckpointManager 也可用於管理多個檢查點。

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

訓練模型並建立檢查點

下列訓練迴圈會建立模型和最佳化工具的執行個體,然後將它們收集到 tf.train.Checkpoint 物件中。它會在每個批次的資料上迴圈呼叫訓練步驟,並定期將檢查點寫入磁碟。

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)

還原並繼續訓練

在第一個訓練週期之後,您可以傳遞新的模型和管理員,但可以從上次停止的位置繼續訓練

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)

tf.train.CheckpointManager 物件會刪除舊的檢查點。在上方,它設定為僅保留最近的三個檢查點。

print(manager.checkpoints)  # List the three remaining checkpoints

這些路徑 (例如 './tf_ckpts/ckpt-10') 不是磁碟上的檔案。相反地,它們是 index 檔案和一或多個包含變數值的資料檔案的前置字串。這些前置字串會群組在單一 checkpoint 檔案 ('./tf_ckpts/checkpoint') 中,CheckpointManager 會在其中儲存其狀態。

ls ./tf_ckpts

載入機制

TensorFlow 會透過周遊具有命名邊緣的有向圖,從載入的物件開始,將變數與檢查點值進行比對。邊緣名稱通常來自物件中的屬性名稱,例如 self.l1 = tf.keras.layers.Dense(5) 中的 "l1"tf.train.Checkpoint 使用其關鍵字引數名稱,如 tf.train.Checkpoint(step=...) 中的 "step"

上方範例中的依附關係圖如下所示

Visualization of the dependency graph for the example training loop

最佳化工具為紅色,一般變數為藍色,而最佳化工具插槽變數為橘色。其他節點 (例如,代表 tf.train.Checkpoint) 為黑色。

插槽變數是最佳化工具狀態的一部分,但針對特定變數建立。例如,上方的 'm' 邊緣對應於動量,Adam 最佳化工具會針對每個變數追蹤動量。只有在變數和最佳化工具都將儲存時,插槽變數才會儲存在檢查點中,因此邊緣為虛線。

tf.train.Checkpoint 物件上呼叫 restore,會將要求的還原排入佇列,並在 Checkpoint 物件中找到相符路徑後立即還原變數值。例如,您可以透過網路和層重建單一路徑到偏差值,從上方定義的模型中僅載入偏差值。

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.

這些新物件的依附關係圖是您上方寫入的較大檢查點的較小子圖。它僅包含偏差值和 tf.train.Checkpoint 用於編號檢查點的儲存計數器。

Visualization of a subgraph for the bias variable

restore 會傳回狀態物件,其中具有選用判斷提示。新 Checkpoint 中建立的所有物件都已還原,因此 status.assert_existing_objects_matched 會通過。

status.assert_existing_objects_matched()

檢查點中有許多物件尚未比對,包括層的核心和最佳化工具的變數。status.assert_consumed 僅在檢查點和程式完全相符時才會通過,並且在此處會擲回例外狀況。

延遲還原

TensorFlow 中的 Layer 物件可能會將變數的建立延遲到第一次呼叫時 (當輸入形狀可用時)。例如,Dense 層核心的形狀取決於層的輸入和輸出形狀,因此作為建構函式引數所需的輸出形狀不足以自行建立變數。由於呼叫 Layer 也會讀取變數的值,因此還原必須在變數的建立及其第一次使用之間發生。

為了支援此慣用語,tf.train.Checkpoint 會延遲尚無相符變數的還原。

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored

手動檢查檢查點

tf.train.load_checkpoint 會傳回 CheckpointReader,其提供對檢查點內容的較低層級存取權。它包含從每個變數的金鑰到檢查點中每個變數的形狀和 dtype 的對應。變數的金鑰是其物件路徑,如上方顯示的圖形所示。

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())

因此,如果您對 net.l1.kernel 的值感興趣,您可以使用下列程式碼取得該值

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)

它也提供 get_tensor 方法,讓您可以檢查變數的值

reader.get_tensor(key)

物件追蹤

檢查點會透過「追蹤」在其屬性中設定的任何變數或可追蹤物件,來儲存和還原 tf.Variable 物件的值。執行儲存時,變數會從所有可到達的追蹤物件以遞迴方式收集。

與直接屬性指派 (例如 self.l1 = tf.keras.layers.Dense(5)) 一樣,將清單和字典指派給屬性將會追蹤其內容。

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

您可能會注意到清單和字典的包裝函式物件。這些包裝函式是基礎資料結構的可檢查點版本。就像基於屬性的載入一樣,這些包裝函式會在變數的值新增至容器後立即還原。

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()

可追蹤物件包括 tf.train.Checkpointtf.Module 及其子類別 (例如 keras.layers.Layerkeras.Model) 以及已辨識的 Python 容器

  • dict (和 collections.OrderedDict)
  • list
  • tuple (和 collections.namedtupletyping.NamedTuple)

不支援其他容器類型,包括

  • collections.defaultdict
  • set

所有其他 Python 物件都會遭到忽略,包括

  • int
  • string
  • float

摘要

TensorFlow 物件提供簡單的自動機制,用於儲存和還原它們使用的變數值。