Trainer TFX 管道元件會訓練 TensorFlow 模型。
Trainer 和 TensorFlow
Trainer 廣泛使用 Python TensorFlow API 來訓練模型。
元件
Trainer 採用
- 用於訓練和評估的 tf.Examples。
- 使用者提供的模組檔案,定義 Trainer 邏輯。
- Protobuf 訓練引數和評估引數的定義。
- (選用) SchemaGen 管道元件建立的資料結構描述,且可由開發人員選擇性地修改。
- (選用) 上游 Transform 元件產生的轉換圖表。
- (選用) 用於暖啟動等情境的預先訓練模型。
- (選用) 超參數,將傳遞至使用者模組函式。如需與 Tuner 整合的詳細資訊,請參閱這裡。
Trainer 發出:至少一個用於推論/服務的模型 (通常為 SavedModelFormat),以及另一個選用於評估的模型 (通常為 EvalSavedModel)。
我們透過模型重寫程式庫,提供替代模型格式 (例如 TFLite) 的支援。請參閱模型重寫程式庫的連結,以取得如何轉換 Estimator 和 Keras 模型的範例。
通用 Trainer
通用 Trainer 讓開發人員能將任何 TensorFlow 模型 API 與 Trainer 元件搭配使用。除了 TensorFlow Estimator 之外,開發人員還可以使用 Keras 模型或自訂訓練迴圈。如需詳細資訊,請參閱通用 Trainer 的 RFC。
設定 Trainer 元件
通用 Trainer 的典型管道 DSL 程式碼如下所示
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer 會叫用訓練模組,該模組在 module_file
參數中指定。如果 custom_executor_spec
中指定了 GenericExecutor
,則模組檔案中需要 run_fn
,而不是 trainer_fn
。trainer_fn
負責建立模型。除了這一點,run_fn
也需要處理訓練部分,並將訓練後的模型輸出到 FnArgs 指定的所需位置
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
以下是包含 run_fn
的範例模組檔案。
請注意,如果管道中未使用 Transform 元件,則 Trainer 會直接從 ExampleGen 取得範例
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
如需更多詳細資訊,請參閱Trainer API 參考資料。