將 SessionRunHook 遷移至 Keras 回呼

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

在 TensorFlow 1 中,若要自訂訓練行為,您會搭配 tf.estimator.SessionRunHook 使用 tf.estimator.Estimator。本指南示範如何從 SessionRunHook 遷移至 TensorFlow 2 的自訂回呼,並搭配 tf.keras.callbacks.Callback API,此 API 可與 Keras Model.fit 搭配使用以進行訓練 (以及 Model.evaluateModel.predict)。您將學習如何執行這項操作,方法是實作 SessionRunHookCallback 工作,以測量訓練期間每秒的範例數。

回呼範例包括檢查點儲存 (tf.keras.callbacks.ModelCheckpoint) 和 TensorBoard 摘要寫入。Keras 回呼是在內建 Keras Model.fit/Model.evaluate/Model.predict API 中訓練/評估/預測期間不同時間點呼叫的物件。您可以在 tf.keras.callbacks.Callback API 文件,以及撰寫您自己的回呼使用內建方法進行訓練和評估 (「使用回呼」章節) 指南中,進一步瞭解回呼。

設定

從匯入項目和簡單資料集開始,以用於示範目的

import tensorflow as tf
import tensorflow.compat.v1 as tf1

import time
from datetime import datetime
from absl import flags
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow 1:使用 tf.estimator API 建立自訂 SessionRunHook

下列 TensorFlow 1 範例示範如何設定自訂 SessionRunHook,以測量訓練期間每秒的範例數。建立 Hook (LoggerHook) 後,將其傳遞至 tf.estimator.Estimator.trainhooks 參數。

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (features, labels)).batch(1).repeat(100)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
class LoggerHook(tf1.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
estimator.train(_input_fn, hooks=[LoggerHook()])

TensorFlow 2:為 Model.fit 建立自訂 Keras 回呼

在 TensorFlow 2 中,當您使用內建 Keras Model.fit (或 Model.evaluate) 進行訓練/評估時,您可以設定自訂 tf.keras.callbacks.Callback,然後將其傳遞至 Model.fit (或 Model.evaluate) 的 callbacks 參數。(在撰寫您自己的回呼指南中瞭解詳情。)

在以下範例中,您將撰寫自訂 tf.keras.callbacks.Callback,以記錄各種指標,其將測量每秒的範例數,這應與先前的 SessionRunHook 範例中的指標相當。

class CustomCallback(tf.keras.callbacks.Callback):

    def on_train_begin(self, logs = None):
      self._step = -1
      self._start_time = time.time()
      self.log_frequency = 10

    def on_train_batch_begin(self, batch, logs = None):
      self._step += 1

    def on_train_batch_end(self, batch, logs = None):
      if self._step % self.log_frequency == 0:
        current_time = time.time()
        duration = current_time - self._start_time
        self._start_time = current_time
        examples_per_sec = self.log_frequency / duration
        print('Time:', datetime.now(), ', Step #:', self._step,
              ', Examples per second:', examples_per_sec)

callback = CustomCallback()

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(1).repeat(100)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, "mse")

# Begin training.
result = model.fit(dataset, callbacks=[callback], verbose = 0)
# Provide the results of training metrics.
result.history

後續步驟

進一步瞭解回呼,請參閱:

您也可能會發現下列與遷移相關的資源很有用: