Trainer TFX 管道元件

Trainer TFX 管道元件會訓練 TensorFlow 模型。

Trainer 和 TensorFlow

Trainer 廣泛使用 Python TensorFlow API 來訓練模型。

元件

Trainer 接受

  • 用於訓練和評估的 tf.Examples。
  • 使用者提供的模組檔案,定義 Trainer 邏輯。
  • Protobuf 訓練引數和評估引數的定義。
  • (選用) SchemaGen 管道元件建立的資料結構描述,且開發人員可選擇性地變更。
  • (選用) 上游 Transform 元件產生的轉換圖。
  • (選用) 用於暖啟動等案例的預先訓練模型。
  • (選用) 超參數,將傳遞至使用者模組函式。如需與 Tuner 整合的詳細資訊,請參閱此處

Trainer 傳送:至少一個用於推論/服務的模型 (通常採用 SavedModelFormat),以及另一個用於評估的模型 (通常是 EvalSavedModel),可為選用項目。

我們透過模型重寫程式庫,為 TFLite 等替代模型格式提供支援。如需如何轉換 Estimator 和 Keras 模型的範例,請參閱模型重寫程式庫的連結。

通用 Trainer

通用 Trainer 讓開發人員能將任何 TensorFlow 模型 API 與 Trainer 元件搭配使用。除了 TensorFlow Estimator,開發人員也可以使用 Keras 模型或自訂訓練迴圈。如需詳細資訊,請參閱通用 Trainer 的 RFC

設定 Trainer 元件

通用 Trainer 的一般管道 DSL 程式碼看起來會像這樣

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer 會叫用訓練模組,該模組在 module_file 參數中指定。如果 GenericExecutorcustom_executor_spec 中指定,則模組檔案中需要 run_fn,而不是 trainer_fntrainer_fn 負責建立模型。除此之外,run_fn 也需要處理訓練部分,並將訓練後的模型輸出至 FnArgs 指定的所需位置

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

以下是包含 run_fn範例模組檔案

請注意,如果管道中未使用 Transform 元件,則 Trainer 會直接從 ExampleGen 取得範例

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer API 參考資料中提供更多詳細資訊。