遷移模型檢查點

View on TensorFlow.org Run in Google Colab View on GitHub Download notebook

總覽

本指南假設您有一個使用 tf.compat.v1.Saver 儲存和載入檢查點的模型,並且想要遷移程式碼以使用 TF2 tf.train.Checkpoint API,或在您的 TF2 模型中使用預先存在的檢查點。

以下是您可能會遇到的一些常見情境

情境 1

有先前訓練執行產生的現有 TF1 檢查點需要載入或轉換為 TF2。

情境 2

您正在調整模型,調整方式可能會變更變數名稱和路徑 (例如,從 get_variable 逐步遷移至明確的 tf.Variable 建立),並且希望在遷移過程中維持現有檢查點的儲存/載入。

請參閱如何在模型遷移期間維持檢查點相容性一節

情境 3

您正在將訓練程式碼和檢查點遷移至 TF2,但您的推論管線目前仍需要 TF1 檢查點 (以維持生產環境穩定性)。

選項 1

在訓練時同時儲存 TF1 和 TF2 檢查點。

選項 2

將 TF2 檢查點轉換為 TF1。


以下範例顯示在 TF1/TF2 中儲存和載入檢查點的所有組合,讓您在決定如何遷移模型時具有彈性。

設定

import tensorflow as tf
import tensorflow.compat.v1 as tf1

def print_checkpoint(save_path):
  reader = tf.train.load_checkpoint(save_path)
  shapes = reader.get_variable_to_shape_map()
  dtypes = reader.get_variable_to_dtype_map()
  print(f"Checkpoint at '{save_path}':")
  for key in shapes:
    print(f"  (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
          f"value={reader.get_tensor(key)})")

從 TF1 到 TF2 的變更

如果您想知道 TF1 和 TF2 之間有哪些變更,以及我們所說的「名稱型」(TF1) 與「物件型」(TF2) 檢查點是什麼意思,本節將說明。

這兩種檢查點實際上是以相同的格式儲存,基本上是鍵值表。差異在於金鑰的產生方式。

名稱型檢查點中的金鑰是變數的名稱。物件型檢查點中的金鑰是指從根物件到變數的路徑 (以下範例將有助於您更瞭解其含義)。

首先,儲存一些檢查點

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    saver.save(sess, 'tf1-ckpt')

print_checkpoint('tf1-ckpt')
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(7.0, name='c')

ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)

如果您查看 tf2-ckpt 中的金鑰,它們都指向每個變數的物件路徑。例如,變數 avariables 清單中的第一個元素,因此其金鑰會變成 variables/0/... (您可以忽略 .ATTRIBUTES/VARIABLE_VALUE 常數)。

更仔細檢查下方的 Checkpoint 物件

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])

試著使用下方的程式碼片段進行實驗,看看檢查點金鑰如何隨著物件結構而變更

module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, 
                                c=c,
                                module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)

為什麼 TF2 使用這種機制?

由於 TF2 中不再有全域圖表,因此變數名稱不可靠,且在程式之間可能不一致。TF2 鼓勵物件導向建模方法,其中變數由層擁有,而層由模型擁有

variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer

如何在模型遷移期間維持檢查點相容性

遷移過程中的一個重要步驟是確保所有變數都初始化為正確的值,這反過來可讓您驗證運算/函式是否執行正確的計算。若要完成此步驟,您必須考量遷移各階段模型之間的檢查點相容性。基本上,本節回答的問題是:如何在變更模型的同時繼續使用相同的檢查點

以下是三種維持檢查點相容性的方法,依彈性遞增順序排列

  1. 模型具有與先前相同的變數名稱
  2. 模型具有不同的變數名稱,並維護指派對應表,將檢查點中的變數名稱對應至新名稱。
  3. 模型具有不同的變數名稱,並維護儲存所有變數的 TF2 Checkpoint 物件

當變數名稱相符時

長標題:當變數名稱相符時,如何重複使用檢查點。

簡短解答:您可以直接使用 tf1.train.Savertf.train.Checkpoint 載入預先存在的檢查點。


如果您使用 tf.compat.v1.keras.utils.track_tf1_style_variables,則可確保您的模型變數名稱與先前相同。您也可以手動確保變數名稱相符。

當遷移模型中的變數名稱相符時,您可以直接使用 tf.train.Checkpointtf.compat.v1.train.Saver 載入檢查點。這兩種 API 都與 eager 模式和圖表模式相容,因此您可以在遷移的任何階段使用它們。

以下範例說明如何搭配不同模型使用相同的檢查點。首先,使用 tf1.train.Saver 儲存 TF1 檢查點

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)

以下範例說明如何在 eager 模式中使用 tf.compat.v1.Saver 載入檢查點

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]")

# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)

下一個程式碼片段說明如何使用 TF2 API tf.train.Checkpoint 載入檢查點

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')

print("Variable names: ")
print(f"  a.name = {a.name}")
print(f"  b.name = {b.name}")
print(f"  c.name = {c.name}")
print(f"  c_2.name = {c_2.name}")

# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")

TF2 中的變數名稱

  • 變數仍然都有您可以設定的 name 引數。
  • Keras 模型也會採用 name 引數,並將其設定為變數的前置字元。
  • v1.name_scope 函式可用於設定變數名稱前置字元。這與 tf.variable_scope 非常不同。它只會影響名稱,不會追蹤變數和重複使用。

tf.compat.v1.keras.utils.track_tf1_style_variables 裝飾器是一個 shim,可協助您維持變數名稱和 TF1 檢查點相容性,方法是保持 tf.variable_scopetf.compat.v1.get_variable 的命名和重複使用語意不變。如需更多資訊,請參閱模型對應指南

注意 1:如果您使用 shim,請使用 TF2 API 載入檢查點 (即使在使用預先訓練的 TF1 檢查點時也是如此)。

請參閱檢查點 Keras 一節。

注意 2:從 get_variable 遷移至 tf.Variable

如果您的 shim 裝飾層或模組包含一些變數 (或 Keras 層/模型),這些變數使用 tf.Variable 而非 tf.compat.v1.get_variable,並以物件導向方式附加為屬性/追蹤,則它們在 TF1.x 圖表/工作階段中,與在 eager 執行期間,可能具有不同的變數命名語意。

簡而言之,當在 TF2 中執行時,名稱可能與您預期的不同

維護指派對應表

指派對應表通常用於在 TF1 模型之間傳輸權重,如果變數名稱變更,也可以在模型遷移期間使用。

您可以搭配 tf.compat.v1.train.init_from_checkpointtf.compat.v1.train.Savertf.train.load_checkpoint 使用這些對應表,將權重載入變數或範圍名稱可能已變更的模型中。

本節中的範例將使用先前儲存的檢查點

print_checkpoint('tf1-ckpt')

使用 init_from_checkpoint 載入

tf1.train.init_from_checkpoint 必須在圖表/工作階段中呼叫,因為它會將值放在變數初始設定式中,而不是建立指派運算。

您可以使用 assignment_map 引數設定變數的載入方式。摘錄自文件:

指派對應表支援下列語法

  • 'checkpoint_scope_name/': 'scope_name/' - 會從 checkpoint_scope_name 載入目前 scope_name 中的所有變數,並比對張量名稱。
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - 會從 checkpoint_scope_name/some_other_variable 初始化 scope_name/variable_name 變數。
  • 'scope_variable_name': variable - 會使用檢查點中的張量 'scope_variable_name' 初始化指定的 tf.Variable 物件。
  • 'scope_variable_name': list(variable) - 會使用檢查點中的張量 'scope_variable_name' 初始化分割變數清單。
  • '/': 'scope_name/' - 會從檢查點的根目錄 (例如,沒有範圍) 載入目前 scope_name 中的所有變數。
# Restoring with tf1.train.init_from_checkpoint:

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # The assignment map will remap all variables in the checkpoint to the
    # new scope:
    tf1.train.init_from_checkpoint(
        'tf1-ckpt',
        assignment_map={'/': 'new_scope/'})
    # `init_from_checkpoint` adds the initializers to these variables.
    # Use `sess.run` to run these initializers.
    sess.run(tf1.global_variables_initializer())

    print("Restored [a, b, c]: ", sess.run([a, b, c]))

使用 tf1.train.Saver 載入

init_from_checkpoint 不同,tf.compat.v1.train.Saver 在圖表模式和 eager 模式下皆可執行。var_list 引數可選擇性接受字典,但必須將變數名稱對應至 tf.Variable 物件。

# Restoring with tf1.train.Saver (works in both graph and eager):

# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

使用 tf.train.load_checkpoint 載入

如果您需要精確控制變數值,此選項適合您。同樣地,此選項在圖表模式和 eager 模式下皆可運作。

# Restoring with tf.train.load_checkpoint (works in both graph and eager):

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # It may be easier writing a loop if your model has a lot of variables.
    reader = tf.train.load_checkpoint('tf1-ckpt')
    sess.run(a.assign(reader.get_tensor('a')))
    sess.run(b.assign(reader.get_tensor('b')))
    sess.run(c.assign(reader.get_tensor('scoped/c')))
    print("Restored [a, b, c]: ", sess.run([a, b, c]))

維護 TF2 Checkpoint 物件

如果在遷移期間變數和範圍名稱可能會大幅變更,請使用 tf.train.Checkpoint 和 TF2 檢查點。TF2 使用物件結構而非變數名稱 (更多詳細資訊請參閱從 TF1 到 TF2 的變更)。

簡而言之,當建立 tf.train.Checkpoint 以儲存或還原檢查點時,請確保它使用相同的排序 (針對清單) 和金鑰 (針對字典和 Checkpoint 初始設定式的關鍵字引數)。檢查點相容性的一些範例

ckpt = tf.train.Checkpoint(foo=[var_a, var_b])

# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])

# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])

以下程式碼範例說明如何使用「相同」的 tf.train.Checkpoint 載入具有不同名稱的變數。首先,儲存 TF2 檢查點

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("[a, b, c]: ", sess.run([a, b, c]))

    # Save a TF2 checkpoint
    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    tf2_ckpt_path = ckpt.save('tf2-ckpt')
    print_checkpoint(tf2_ckpt_path)

即使變數/範圍名稱變更,您也可以繼續使用 tf.train.Checkpoint

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.variable_scope('different_scope'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))

    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    # `assert_consumed` validates that all checkpoint objects are restored from
    # the checkpoint. `run_restore_ops` is required when running in a TF1
    # session.
    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()

    # Removing `assert_consumed` is fine if you want to skip the validation.
    # ckpt.restore(tf2_ckpt_path).run_restore_ops()

    print("Restored [a, b, c]: ", sess.run([a, b, c]))

在 eager 模式中

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])

ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

估算器中的 TF2 檢查點

以上章節說明如何在遷移模型時維持檢查點相容性。這些概念也適用於估算器模型,但檢查點的儲存/載入方式略有不同。當您遷移估算器模型以使用 TF2 API 時,您可能會想要在模型仍在使用估算器時,從 TF1 切換至 TF2 檢查點。本節說明如何執行此操作。

tf.estimator.EstimatorMonitoredSession 具有稱為 scaffold 的儲存機制,即 tf.compat.v1.train.Scaffold 物件。Scaffold 可以包含 tf1.train.Savertf.train.Checkpoint,讓 EstimatorMonitoredSession 能夠儲存 TF1 或 TF2 樣式的檢查點。

# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=None
  )

!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=tf1.train.Scaffold(saver=ckpt)
  )

!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
                             warm_start_from='est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)  

assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4

est-tf1 暖啟動後,再訓練 5 個步驟,v 的最終值應為 16warm_start 檢查點不會延續訓練步驟值。

檢查點 Keras

使用 Keras 建置的模型仍然使用 tf1.train.Savertf.train.Checkpoint 載入預先存在的權重。當您的模型完全遷移後,請切換為使用 model.save_weightsmodel.load_weights,特別是當您在訓練時使用 ModelCheckpoint 回呼時。

您應該瞭解的檢查點和 Keras 相關事項

初始化與建置

Keras 模型和層在完全建立之前必須經過兩個步驟。第一個步驟是 Python 物件的初始化layer = tf.keras.layers.Dense(x)。第二個步驟是建置步驟,其中實際上會建立大部分的權重:layer.build(input_shape)。您也可以透過呼叫模型或執行單一 trainevalpredict 步驟 (僅限第一次) 來建置模型。

如果您發現 model.load_weights(path).assert_consumed() 引發錯誤,則可能是模型/層尚未建置。

Keras 使用 TF2 檢查點

tf.train.Checkpoint(model).write 等同於 model.save_weightstf.train.Checkpoint(model).readmodel.load_weights 也是如此。請注意,Checkpoint(model) != Checkpoint(model=model)

TF2 檢查點適用於 Keras 的 build() 步驟

tf.train.Checkpoint.restore 具有稱為延遲還原的機制,可讓 tf.Module 和 Keras 物件在變數尚未建立時儲存變數值。這可讓已初始化的模型在載入權重後進行建置

m = YourKerasModel()
status = m.load_weights(path)

# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)

status.assert_consumed()

由於此機制,我們強烈建議您將 TF2 檢查點載入 API 與 Keras 模型搭配使用 (即使將預先存在的 TF1 檢查點還原至模型對應 shim 時也是如此)。如需更多資訊,請參閱檢查點指南

程式碼片段

以下程式碼片段顯示檢查點儲存 API 中的 TF1/TF2 版本相容性。

在 TF2 中儲存 TF1 檢查點

a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(3.0, name='c')

saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)

在 TF2 中載入 TF1 檢查點

a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

在 TF1 中儲存 TF2 檢查點

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
    print_checkpoint(tf2_in_tf1_path)

在 TF1 中載入 TF2 檢查點

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
    print("Restored [a, b, c]: ", sess.run([a, b, c]))

檢查點轉換

您可以透過載入並重新儲存檢查點,在 TF1 和 TF2 之間轉換檢查點。另一種方法是使用 tf.train.load_checkpoint,如下方程式碼所示。

將 TF1 檢查點轉換為 TF2

def convert_tf1_to_tf2(checkpoint_path, output_prefix):
  """Converts a TF1 checkpoint to TF2.

  To load the converted checkpoint, you must build a dictionary that maps
  variable names to variable objects.
  ```
  ckpt = tf.train.Checkpoint(vars={name: variable})  
  ckpt.restore(converted_ckpt_path)

    ```

    Args:
      checkpoint_path: Path to the TF1 checkpoint.
      output_prefix: Path prefix to the converted checkpoint.

    Returns:
      Path to the converted checkpoint.
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      vars[key] = tf.Variable(reader.get_tensor(key))
    return tf.train.Checkpoint(vars=vars).save(output_prefix)
  ```

Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2`:


請務必執行 在 TF2 中儲存 TF1 檢查點 中的程式碼片段。

print_checkpoint('tf1-ckpt-saved-in-eager') converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 'converted-tf1-to-tf2') print("\n[Converted]") print_checkpoint(converted_path)

試著載入已轉換的檢查點。

a = tf.Variable(0.) b = tf.Variable(0.) c = tf.Variable(0.) ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c}) ckpt.restore(converted_path).assert_consumed() print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()]) ```

將 TF2 檢查點轉換為 TF1

def convert_tf2_to_tf1(checkpoint_path, output_prefix):
  """Converts a TF2 checkpoint to TF1.

  The checkpoint must be saved using a 
  `tf.train.Checkpoint(var_list={name: variable})`

  To load the converted checkpoint with `tf.compat.v1.Saver`:
  ```
  saver = tf.compat.v1.train.Saver(var_list={name: variable}) 

  # An alternative, if the variable names match the keys:
  saver = tf.compat.v1.train.Saver(var_list=[variables]) 
  saver.restore(sess, output_path)

    ```
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      # Get the "name" from the 
      if key.startswith('var_list/'):
        var_name = key.split('/')[1]
        # TF2 checkpoint keys use '/', so if they appear in the user-defined name,
        # they are escaped to '.S'.
        var_name = var_name.replace('.S', '/')
        vars[var_name] = tf.Variable(reader.get_tensor(key))

    return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)
  ```

Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1`:


請務必執行 在 TF1 中儲存 TF2 檢查點 中的程式碼片段。

print_checkpoint('tf2-ckpt-saved-in-session-1') converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1', 'converted-tf2-to-tf1') print("\n[Converted]") print_checkpoint(converted_path)

試著載入已轉換的檢查點。

with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.variable_scope('scoped'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.Session() as sess: saver = tf1.train.Saver([a, b, c]) saver.restore(sess, converted_path) print("\nRestored [a, b, c]: ", sess.run([a, b, c])) ```