![]() |
![]() |
![]() |
![]() |
總覽
本指南假設您有一個使用 tf.compat.v1.Saver
儲存和載入檢查點的模型,並且想要遷移程式碼以使用 TF2 tf.train.Checkpoint
API,或在您的 TF2 模型中使用預先存在的檢查點。
以下是您可能會遇到的一些常見情境
情境 1
有先前訓練執行產生的現有 TF1 檢查點需要載入或轉換為 TF2。
- 若要在 TF2 中載入 TF1 檢查點,請參閱程式碼片段在 TF2 中載入 TF1 檢查點。
- 若要將檢查點轉換為 TF2,請參閱檢查點轉換。
情境 2
您正在調整模型,調整方式可能會變更變數名稱和路徑 (例如,從 get_variable
逐步遷移至明確的 tf.Variable
建立),並且希望在遷移過程中維持現有檢查點的儲存/載入。
請參閱如何在模型遷移期間維持檢查點相容性一節
情境 3
您正在將訓練程式碼和檢查點遷移至 TF2,但您的推論管線目前仍需要 TF1 檢查點 (以維持生產環境穩定性)。
選項 1
在訓練時同時儲存 TF1 和 TF2 檢查點。
選項 2
將 TF2 檢查點轉換為 TF1。
- 請參閱檢查點轉換
以下範例顯示在 TF1/TF2 中儲存和載入檢查點的所有組合,讓您在決定如何遷移模型時具有彈性。
設定
import tensorflow as tf
import tensorflow.compat.v1 as tf1
def print_checkpoint(save_path):
reader = tf.train.load_checkpoint(save_path)
shapes = reader.get_variable_to_shape_map()
dtypes = reader.get_variable_to_dtype_map()
print(f"Checkpoint at '{save_path}':")
for key in shapes:
print(f" (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
f"value={reader.get_tensor(key)})")
從 TF1 到 TF2 的變更
如果您想知道 TF1 和 TF2 之間有哪些變更,以及我們所說的「名稱型」(TF1) 與「物件型」(TF2) 檢查點是什麼意思,本節將說明。
這兩種檢查點實際上是以相同的格式儲存,基本上是鍵值表。差異在於金鑰的產生方式。
名稱型檢查點中的金鑰是變數的名稱。物件型檢查點中的金鑰是指從根物件到變數的路徑 (以下範例將有助於您更瞭解其含義)。
首先,儲存一些檢查點
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
saver = tf1.train.Saver()
sess.run(a.assign(1))
sess.run(b.assign(2))
sess.run(c.assign(3))
saver.save(sess, 'tf1-ckpt')
print_checkpoint('tf1-ckpt')
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(7.0, name='c')
ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)
如果您查看 tf2-ckpt
中的金鑰,它們都指向每個變數的物件路徑。例如,變數 a
是 variables
清單中的第一個元素,因此其金鑰會變成 variables/0/...
(您可以忽略 .ATTRIBUTES/VARIABLE_VALUE 常數)。
更仔細檢查下方的 Checkpoint
物件
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])
試著使用下方的程式碼片段進行實驗,看看檢查點金鑰如何隨著物件結構而變更
module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b},
c=c,
module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)
為什麼 TF2 使用這種機制?
由於 TF2 中不再有全域圖表,因此變數名稱不可靠,且在程式之間可能不一致。TF2 鼓勵物件導向建模方法,其中變數由層擁有,而層由模型擁有
variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer
如何在模型遷移期間維持檢查點相容性
遷移過程中的一個重要步驟是確保所有變數都初始化為正確的值,這反過來可讓您驗證運算/函式是否執行正確的計算。若要完成此步驟,您必須考量遷移各階段模型之間的檢查點相容性。基本上,本節回答的問題是:如何在變更模型的同時繼續使用相同的檢查點。
以下是三種維持檢查點相容性的方法,依彈性遞增順序排列
- 模型具有與先前相同的變數名稱。
- 模型具有不同的變數名稱,並維護指派對應表,將檢查點中的變數名稱對應至新名稱。
- 模型具有不同的變數名稱,並維護儲存所有變數的 TF2 Checkpoint 物件。
當變數名稱相符時
長標題:當變數名稱相符時,如何重複使用檢查點。
簡短解答:您可以直接使用 tf1.train.Saver
或 tf.train.Checkpoint
載入預先存在的檢查點。
如果您使用 tf.compat.v1.keras.utils.track_tf1_style_variables
,則可確保您的模型變數名稱與先前相同。您也可以手動確保變數名稱相符。
當遷移模型中的變數名稱相符時,您可以直接使用 tf.train.Checkpoint
或 tf.compat.v1.train.Saver
載入檢查點。這兩種 API 都與 eager 模式和圖表模式相容,因此您可以在遷移的任何階段使用它們。
以下範例說明如何搭配不同模型使用相同的檢查點。首先,使用 tf1.train.Saver
儲存 TF1 檢查點
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
saver = tf1.train.Saver()
sess.run(a.assign(1))
sess.run(b.assign(2))
sess.run(c.assign(3))
save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)
以下範例說明如何在 eager 模式中使用 tf.compat.v1.Saver
載入檢查點
a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(0.0, name='c')
# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]: [{a.numpy()}, {b.numpy()}, {c.numpy()}]")
# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
下一個程式碼片段說明如何使用 TF2 API tf.train.Checkpoint
載入檢查點
a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(0.0, name='c')
# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')
print("Variable names: ")
print(f" a.name = {a.name}")
print(f" b.name = {b.name}")
print(f" c.name = {c.name}")
print(f" c_2.name = {c_2.name}")
# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]: [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")
TF2 中的變數名稱
- 變數仍然都有您可以設定的
name
引數。 - Keras 模型也會採用
name
引數,並將其設定為變數的前置字元。 v1.name_scope
函式可用於設定變數名稱前置字元。這與tf.variable_scope
非常不同。它只會影響名稱,不會追蹤變數和重複使用。
tf.compat.v1.keras.utils.track_tf1_style_variables
裝飾器是一個 shim,可協助您維持變數名稱和 TF1 檢查點相容性,方法是保持 tf.variable_scope
和 tf.compat.v1.get_variable
的命名和重複使用語意不變。如需更多資訊,請參閱模型對應指南。
注意 1:如果您使用 shim,請使用 TF2 API 載入檢查點 (即使在使用預先訓練的 TF1 檢查點時也是如此)。
請參閱檢查點 Keras 一節。
注意 2:從 get_variable
遷移至 tf.Variable
時
如果您的 shim 裝飾層或模組包含一些變數 (或 Keras 層/模型),這些變數使用 tf.Variable
而非 tf.compat.v1.get_variable
,並以物件導向方式附加為屬性/追蹤,則它們在 TF1.x 圖表/工作階段中,與在 eager 執行期間,可能具有不同的變數命名語意。
簡而言之,當在 TF2 中執行時,名稱可能與您預期的不同。
維護指派對應表
指派對應表通常用於在 TF1 模型之間傳輸權重,如果變數名稱變更,也可以在模型遷移期間使用。
您可以搭配 tf.compat.v1.train.init_from_checkpoint
、tf.compat.v1.train.Saver
和 tf.train.load_checkpoint
使用這些對應表,將權重載入變數或範圍名稱可能已變更的模型中。
本節中的範例將使用先前儲存的檢查點
print_checkpoint('tf1-ckpt')
使用 init_from_checkpoint
載入
tf1.train.init_from_checkpoint
必須在圖表/工作階段中呼叫,因為它會將值放在變數初始設定式中,而不是建立指派運算。
您可以使用 assignment_map
引數設定變數的載入方式。摘錄自文件:
指派對應表支援下列語法
'checkpoint_scope_name/': 'scope_name/'
- 會從checkpoint_scope_name
載入目前scope_name
中的所有變數,並比對張量名稱。'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'
- 會從checkpoint_scope_name/some_other_variable
初始化scope_name/variable_name
變數。'scope_variable_name': variable
- 會使用檢查點中的張量 'scope_variable_name' 初始化指定的tf.Variable
物件。'scope_variable_name': list(variable)
- 會使用檢查點中的張量 'scope_variable_name' 初始化分割變數清單。'/': 'scope_name/'
- 會從檢查點的根目錄 (例如,沒有範圍) 載入目前scope_name
中的所有變數。
# Restoring with tf1.train.init_from_checkpoint:
# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
with tf1.variable_scope('new_scope'):
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
# The assignment map will remap all variables in the checkpoint to the
# new scope:
tf1.train.init_from_checkpoint(
'tf1-ckpt',
assignment_map={'/': 'new_scope/'})
# `init_from_checkpoint` adds the initializers to these variables.
# Use `sess.run` to run these initializers.
sess.run(tf1.global_variables_initializer())
print("Restored [a, b, c]: ", sess.run([a, b, c]))
使用 tf1.train.Saver
載入
與 init_from_checkpoint
不同,tf.compat.v1.train.Saver
在圖表模式和 eager 模式下皆可執行。var_list
引數可選擇性接受字典,但必須將變數名稱對應至 tf.Variable
物件。
# Restoring with tf1.train.Saver (works in both graph and eager):
# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
使用 tf.train.load_checkpoint
載入
如果您需要精確控制變數值,此選項適合您。同樣地,此選項在圖表模式和 eager 模式下皆可運作。
# Restoring with tf.train.load_checkpoint (works in both graph and eager):
# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
with tf1.variable_scope('new_scope'):
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
# It may be easier writing a loop if your model has a lot of variables.
reader = tf.train.load_checkpoint('tf1-ckpt')
sess.run(a.assign(reader.get_tensor('a')))
sess.run(b.assign(reader.get_tensor('b')))
sess.run(c.assign(reader.get_tensor('scoped/c')))
print("Restored [a, b, c]: ", sess.run([a, b, c]))
維護 TF2 Checkpoint 物件
如果在遷移期間變數和範圍名稱可能會大幅變更,請使用 tf.train.Checkpoint
和 TF2 檢查點。TF2 使用物件結構而非變數名稱 (更多詳細資訊請參閱從 TF1 到 TF2 的變更)。
簡而言之,當建立 tf.train.Checkpoint
以儲存或還原檢查點時,請確保它使用相同的排序 (針對清單) 和金鑰 (針對字典和 Checkpoint
初始設定式的關鍵字引數)。檢查點相容性的一些範例
ckpt = tf.train.Checkpoint(foo=[var_a, var_b])
# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])
# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])
以下程式碼範例說明如何使用「相同」的 tf.train.Checkpoint
載入具有不同名稱的變數。首先,儲存 TF2 檢查點
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(1))
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(2))
with tf1.variable_scope('scoped'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(3))
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
print("[a, b, c]: ", sess.run([a, b, c]))
# Save a TF2 checkpoint
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
tf2_ckpt_path = ckpt.save('tf2-ckpt')
print_checkpoint(tf2_ckpt_path)
即使變數/範圍名稱變更,您也可以繼續使用 tf.train.Checkpoint
with tf.Graph().as_default() as g:
a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.variable_scope('different_scope'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
print("Initialized [a, b, c]: ", sess.run([a, b, c]))
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
# `assert_consumed` validates that all checkpoint objects are restored from
# the checkpoint. `run_restore_ops` is required when running in a TF1
# session.
ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
# Removing `assert_consumed` is fine if you want to skip the validation.
# ckpt.restore(tf2_ckpt_path).run_restore_ops()
print("Restored [a, b, c]: ", sess.run([a, b, c]))
在 eager 模式中
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
估算器中的 TF2 檢查點
以上章節說明如何在遷移模型時維持檢查點相容性。這些概念也適用於估算器模型,但檢查點的儲存/載入方式略有不同。當您遷移估算器模型以使用 TF2 API 時,您可能會想要在模型仍在使用估算器時,從 TF1 切換至 TF2 檢查點。本節說明如何執行此操作。
tf.estimator.Estimator
和 MonitoredSession
具有稱為 scaffold
的儲存機制,即 tf.compat.v1.train.Scaffold
物件。Scaffold
可以包含 tf1.train.Saver
或 tf.train.Checkpoint
,讓 Estimator
和 MonitoredSession
能夠儲存 TF1 或 TF2 樣式的檢查點。
# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
# This model adds 2 to the variable `v` in every train step.
train_step = tf1.train.get_or_create_global_step()
v = tf1.get_variable('var', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
return tf.estimator.EstimatorSpec(
mode,
predictions=v,
train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
loss=tf.constant(1.),
scaffold=None
)
!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')
def train_fn():
return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)
latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
# This model adds 2 to the variable `v` in every train step.
train_step = tf1.train.get_or_create_global_step()
v = tf1.get_variable('var', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
return tf.estimator.EstimatorSpec(
mode,
predictions=v,
train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
loss=tf.constant(1.),
scaffold=tf1.train.Scaffold(saver=ckpt)
)
!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
warm_start_from='est-tf1')
def train_fn():
return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)
latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)
assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4
從 est-tf1
暖啟動後,再訓練 5 個步驟,v
的最終值應為 16
。warm_start
檢查點不會延續訓練步驟值。
檢查點 Keras
使用 Keras 建置的模型仍然使用 tf1.train.Saver
和 tf.train.Checkpoint
載入預先存在的權重。當您的模型完全遷移後,請切換為使用 model.save_weights
和 model.load_weights
,特別是當您在訓練時使用 ModelCheckpoint
回呼時。
您應該瞭解的檢查點和 Keras 相關事項
初始化與建置
Keras 模型和層在完全建立之前必須經過兩個步驟。第一個步驟是 Python 物件的初始化:layer = tf.keras.layers.Dense(x)
。第二個步驟是建置步驟,其中實際上會建立大部分的權重:layer.build(input_shape)
。您也可以透過呼叫模型或執行單一 train
、eval
或 predict
步驟 (僅限第一次) 來建置模型。
如果您發現 model.load_weights(path).assert_consumed()
引發錯誤,則可能是模型/層尚未建置。
Keras 使用 TF2 檢查點
tf.train.Checkpoint(model).write
等同於 model.save_weights
。tf.train.Checkpoint(model).read
和 model.load_weights
也是如此。請注意,Checkpoint(model) != Checkpoint(model=model)
。
TF2 檢查點適用於 Keras 的 build()
步驟
tf.train.Checkpoint.restore
具有稱為延遲還原的機制,可讓 tf.Module
和 Keras 物件在變數尚未建立時儲存變數值。這可讓已初始化的模型在載入權重後進行建置。
m = YourKerasModel()
status = m.load_weights(path)
# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)
status.assert_consumed()
由於此機制,我們強烈建議您將 TF2 檢查點載入 API 與 Keras 模型搭配使用 (即使將預先存在的 TF1 檢查點還原至模型對應 shim 時也是如此)。如需更多資訊,請參閱檢查點指南。
程式碼片段
以下程式碼片段顯示檢查點儲存 API 中的 TF1/TF2 版本相容性。
在 TF2 中儲存 TF1 檢查點
a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(3.0, name='c')
saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
在 TF2 中載入 TF1 檢查點
a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
在 TF1 中儲存 TF2 檢查點
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(1))
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(2))
with tf1.variable_scope('scoped'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(3))
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
ckpt = tf.train.Checkpoint(
var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
print_checkpoint(tf2_in_tf1_path)
在 TF1 中載入 TF2 檢查點
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
with tf1.variable_scope('scoped'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
print("Initialized [a, b, c]: ", sess.run([a, b, c]))
ckpt = tf.train.Checkpoint(
var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
print("Restored [a, b, c]: ", sess.run([a, b, c]))
檢查點轉換
您可以透過載入並重新儲存檢查點,在 TF1 和 TF2 之間轉換檢查點。另一種方法是使用 tf.train.load_checkpoint
,如下方程式碼所示。
將 TF1 檢查點轉換為 TF2
def convert_tf1_to_tf2(checkpoint_path, output_prefix):
"""Converts a TF1 checkpoint to TF2.
To load the converted checkpoint, you must build a dictionary that maps
variable names to variable objects.
```
ckpt = tf.train.Checkpoint(vars={name: variable})
ckpt.restore(converted_ckpt_path)
```
Args:
checkpoint_path: Path to the TF1 checkpoint.
output_prefix: Path prefix to the converted checkpoint.
Returns:
Path to the converted checkpoint.
"""
vars = {}
reader = tf.train.load_checkpoint(checkpoint_path)
dtypes = reader.get_variable_to_dtype_map()
for key in dtypes.keys():
vars[key] = tf.Variable(reader.get_tensor(key))
return tf.train.Checkpoint(vars=vars).save(output_prefix)
```
Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2`:
請務必執行 在 TF2 中儲存 TF1 檢查點
中的程式碼片段。
print_checkpoint('tf1-ckpt-saved-in-eager') converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 'converted-tf1-to-tf2') print("\n[Converted]") print_checkpoint(converted_path)
試著載入已轉換的檢查點。
a = tf.Variable(0.) b = tf.Variable(0.) c = tf.Variable(0.) ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c}) ckpt.restore(converted_path).assert_consumed() print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()]) ```
將 TF2 檢查點轉換為 TF1
def convert_tf2_to_tf1(checkpoint_path, output_prefix):
"""Converts a TF2 checkpoint to TF1.
The checkpoint must be saved using a
`tf.train.Checkpoint(var_list={name: variable})`
To load the converted checkpoint with `tf.compat.v1.Saver`:
```
saver = tf.compat.v1.train.Saver(var_list={name: variable})
# An alternative, if the variable names match the keys:
saver = tf.compat.v1.train.Saver(var_list=[variables])
saver.restore(sess, output_path)
```
"""
vars = {}
reader = tf.train.load_checkpoint(checkpoint_path)
dtypes = reader.get_variable_to_dtype_map()
for key in dtypes.keys():
# Get the "name" from the
if key.startswith('var_list/'):
var_name = key.split('/')[1]
# TF2 checkpoint keys use '/', so if they appear in the user-defined name,
# they are escaped to '.S'.
var_name = var_name.replace('.S', '/')
vars[var_name] = tf.Variable(reader.get_tensor(key))
return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)
```
Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1`:
請務必執行 在 TF1 中儲存 TF2 檢查點
中的程式碼片段。
print_checkpoint('tf2-ckpt-saved-in-session-1') converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1', 'converted-tf2-to-tf1') print("\n[Converted]") print_checkpoint(converted_path)
試著載入已轉換的檢查點。
with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.variable_scope('scoped'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.Session() as sess: saver = tf1.train.Saver([a, b, c]) saver.restore(sess, converted_path) print("\nRestored [a, b, c]: ", sess.run([a, b, c])) ```
相關指南
- 驗證數值等價性和正確性
- 模型對應指南 和
tf.compat.v1.keras.utils.track_tf1_style_variables
- TF2 檢查點指南.