Tuner TFX 管線組件

Tuner 組件會調整模型的超參數。

Tuner 組件和 KerasTuner 程式庫

Tuner 組件廣泛使用 Python KerasTuner API 進行超參數調整。

組件

Tuner 接受

  • 用於訓練和評估的 tf.Examples。
  • 使用者提供的模組檔案 (或模組函式),其中定義調整邏輯,包括模型定義、超參數搜尋空間、目標等。
  • 訓練引數和評估引數的 Protobuf 定義。
  • (選用)Protobuf 調整引數定義。
  • (選用) 上游 Transform 組件產生的轉換圖表。
  • (選用) 由 SchemaGen 管線組件建立,且可由開發人員選擇性變更的資料結構描述。

透過指定的資料、模型和目標,Tuner 會調整超參數並發出最佳結果。

操作說明

Tuner 需要具有下列簽名的使用者模組函式 tuner_fn

...
from keras_tuner.engine import base_tuner

TunerFnResult = NamedTuple('TunerFnResult', [('tuner', base_tuner.BaseTuner),
                                             ('fit_kwargs', Dict[Text, Any])])

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Build the tuner using the KerasTuner API.
  Args:
    fn_args: Holds args as name/value pairs.
      - working_dir: working dir for tuning.
      - train_files: List of file paths containing training tf.Example data.
      - eval_files: List of file paths containing eval tf.Example data.
      - train_steps: number of train steps.
      - eval_steps: number of eval steps.
      - schema_path: optional schema of the input data.
      - transform_graph_path: optional transform graph produced by TFT.
  Returns:
    A namedtuple contains the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to tuner's run_trial function for fitting the
                    model , e.g., the training and validation dataset. Required
                    args depend on the above tuner's implementation.
  """
  ...

在此函式中,您可以定義模型和超參數搜尋空間,並選擇調整的目標和演算法。Tuner 組件會將此模組程式碼做為輸入、調整超參數,並發出最佳結果。

Trainer 可以將 Tuner 的輸出超參數做為輸入,並在使用者模組程式碼中加以運用。管線定義如下所示

...
tuner = Tuner(
    module_file=module_file,  # Contains `tuner_fn`.
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=20),
    eval_args=trainer_pb2.EvalArgs(num_steps=5))

trainer = Trainer(
    module_file=module_file,  # Contains `run_fn`.
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    # This will be passed to `run_fn`.
    hyperparameters=tuner.outputs['best_hyperparameters'],
    train_args=trainer_pb2.TrainArgs(num_steps=100),
    eval_args=trainer_pb2.EvalArgs(num_steps=5))
...

您可能不希望每次重新訓練模型時都調整超參數。一旦您使用 Tuner 判斷一組良好的超參數,您可以從管線中移除 Tuner,並使用 ImporterNode 從先前的訓練執行匯入 Tuner 構件,以饋送至 Trainer。

hparams_importer = Importer(
    # This can be Tuner's output file or manually edited file. The file contains
    # text format of hyperparameters (keras_tuner.HyperParameters.get_config())
    source_uri='path/to/best_hyperparameters.txt',
    artifact_type=HyperParameters,
).with_id('import_hparams')

trainer = Trainer(
    ...
    # An alternative is directly use the tuned hyperparameters in Trainer's user
    # module code and set hyperparameters to None here.
    hyperparameters = hparams_importer.outputs['result'])

在 Google Cloud Platform (GCP) 上調整

在 Google Cloud Platform (GCP) 上執行時,Tuner 組件可以利用兩項服務

AI Platform Vizier 做為超參數調整的後端

AI Platform Vizier 是一項受管理的服務,根據 Google Vizier 技術執行黑箱最佳化。

CloudTunerKerasTuner 的實作,會與 AI Platform Vizier 服務通訊做為研究後端。由於 CloudTuner 是 keras_tuner.Tuner 的子類別,因此可以在 tuner_fn 模組中做為直接替換項目使用,並以 TFX Tuner 組件的一部分執行。

以下程式碼片段示範如何使用 CloudTuner。請注意,CloudTuner 的設定需要 GCP 特有的項目,例如 project_idregion

...
from tensorflow_cloud import CloudTuner

...
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """An implementation of tuner_fn that instantiates CloudTuner."""

  ...
  tuner = CloudTuner(
      _build_model,
      hyperparameters=...,
      ...
      project_id=...,       # GCP Project ID
      region=...,           # GCP Region where Vizier service is run.
  )

  ...
  return TuneFnResult(
      tuner=tuner,
      fit_kwargs={...}
  )

在 Cloud AI Platform Training 分散式工作站叢集上平行調整

KerasTuner 架構做為 Tuner 組件的底層實作,能夠平行執行超參數搜尋。雖然標準 Tuner 組件無法平行執行多個搜尋工作站,但透過使用 Google Cloud AI Platform 擴充功能 Tuner 組件,它可以提供執行平行調整的功能,方法是使用 AI Platform Training Job 做為分散式工作站叢集管理員。TuneArgs 是提供給此組件的設定。這是標準 Tuner 組件的直接替換項目。

tuner = google_cloud_ai_platform.Tuner(
    ...   # Same kwargs as the above stock Tuner component.
    tune_args=proto.TuneArgs(num_parallel_trials=3),  # 3-worker parallel
    custom_config={
        # Configures Cloud AI Platform-specific configs . For for details, see
        # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
        TUNING_ARGS_KEY:
            {
                'project': ...,
                'region': ...,
                # Configuration of machines for each master/worker in the flock.
                'masterConfig': ...,
                'workerConfig': ...,
                ...
            }
    })
...

擴充功能 Tuner 組件的行為和輸出與標準 Tuner 組件相同,但多個超參數搜尋會在不同的工作站機器上平行執行,因此 num_trials 會更快完成。當搜尋演算法可輕易地平行化 (例如 RandomSearch) 時,這特別有效。但是,如果搜尋演算法使用先前試驗結果的資訊 (例如 AI Platform Vizier 中實作的 Google Vizier 演算法),過度平行的搜尋會對搜尋的效力造成負面影響。

E2E 範例

GCP 範例上的 E2E CloudTuner

KerasTuner 教學課程

CloudTuner 教學課程

提案

更多詳細資訊請參閱 Tuner API 參考資料