自訂 Python 函數元件

Python 函數型元件定義讓您更輕鬆地建立 TFX 自訂元件,省去定義元件規格類別、執行器類別和元件介面類別的功夫。在這種元件定義樣式中,您會編寫一個以類型提示標註的函式。類型提示會說明元件的輸入工件、輸出工件和參數。

以這種樣式編寫自訂元件非常簡單,如下列範例所示。

class MyOutput(TypedDict):
  accuracy: float

@component
def MyValidationComponent(
    model: InputArtifact[Model],
    blessing: OutputArtifact[Model],
    accuracy_threshold: Parameter[int] = 10,
) -> MyOutput:
  '''My simple custom model validation component.'''

  accuracy = evaluate_model(model)
  if accuracy >= accuracy_threshold:
    write_output_blessing(blessing)

  return {
    'accuracy': accuracy
  }

在幕後,這會定義一個自訂元件,該元件是 BaseComponent 及其 Spec 和 Executor 類別的子類別。

如果您想定義 BaseBeamComponent 的子類別,以便搭配 TFX 管線式共用設定 (亦即編譯管線時的 beam_pipeline_args) 使用 Beam 管線 (範例請參閱芝加哥計程車管線範例),您可以在裝飾器中設定 use_beam=True,並在函式中新增另一個預設值為 NoneBeamComponentParameter,如下列範例所示

@component(use_beam=True)
def MyDataProcessor(
    examples: InputArtifact[Example],
    processed_examples: OutputArtifact[Example],
    beam_pipeline: BeamComponentParameter[beam.Pipeline] = None,
    ) -> None:
  '''My simple custom model validation component.'''

  with beam_pipeline as p:
    # data pipeline definition with beam_pipeline begins
    ...
    # data pipeline definition with beam_pipeline ends

如果您不熟悉 TFX 管線,請進一步瞭解 TFX 管線的核心概念

輸入、輸出和參數

在 TFX 中,輸入和輸出會以 Artifact 物件追蹤,這些物件會說明基礎資料的位置和與之相關聯的中繼資料屬性;這項資訊會儲存在 ML Metadata 中。Artifact 可以說明複雜的資料類型或簡單的資料類型,例如:int、float、位元組或 Unicode 字串。

參數是管線建構時已知元件的引數 (int、float、位元組或 Unicode 字串)。參數可用於指定引數和超參數,例如訓練迭代次數、Dropout 率,以及元件的其他設定。當在 ML Metadata 中追蹤時,參數會儲存為元件執行的屬性。

定義

如要建立自訂元件,請編寫一個函式來實作您的自訂邏輯,並使用 @component 裝飾器 (來自 tfx.dsl.component.experimental.decorators 模組) 加以裝飾。如要定義元件的輸入和輸出結構描述,請使用 tfx.dsl.component.experimental.annotations 模組中的註解,為函式的引數和傳回值加上註解

  • 針對每個工件輸入,套用 InputArtifact[ArtifactType] 類型提示註解。將 ArtifactType 替換為工件的類型,即 tfx.types.Artifact 的子類別。這些輸入可以是選用引數。

  • 針對每個輸出工件,套用 OutputArtifact[ArtifactType] 類型提示註解。將 ArtifactType 替換為工件的類型,即 tfx.types.Artifact 的子類別。元件輸出工件應做為函式的輸入引數傳遞,如此您的元件才能將輸出寫入系統管理的儲存位置,並設定適當的工件中繼資料屬性。此引數可以是選用項目,也可以使用預設值定義此引數。

  • 針對每個參數,使用類型提示註解 Parameter[T]。將 T 替換為參數的類型。我們目前僅支援基本 Python 類型:boolintfloatstrbytes

  • 針對 Beam 管線,使用類型提示註解 BeamComponentParameter[beam.Pipeline]。將預設值設為 NoneNone 值將由 BaseBeamExecutor_make_beam_pipeline() 建立的已例項化 Beam 管線取代

  • 針對每個簡單資料類型輸入 (intfloatstrbytes),若在管線建構時未知,請使用類型提示 T。請注意,在 TFX 0.22 版本中,針對這類輸入,無法在管線建構時傳遞具體值 (請改用 Parameter 註解,如上一節所述)。此引數可以是選用項目,也可以使用預設值定義此引數。如果您的元件具有簡單資料類型輸出 (intfloatstrbytes),您可以透過使用 TypedDict 做為傳回類型註解,並傳回適當的 dict 物件來傳回這些輸出。

在函式主體中,輸入和輸出工件會以 tfx.types.Artifact 物件傳遞;您可以檢查其 .uri 以取得其系統管理的儲存位置,並讀取/設定任何屬性。輸入參數和簡單資料類型輸入會以指定類型的物件傳遞。簡單資料類型輸出應以字典形式傳回,其中鍵是適當的輸出名稱,而值是所需的傳回值。

完成的函數元件可能如下所示

from typing import TypedDict
import tfx.v1 as tfx
from tfx.dsl.component.experimental.decorators import component

class MyOutput(TypedDict):
  loss: float
  accuracy: float

@component
def MyTrainerComponent(
    training_data: tfx.dsl.components.InputArtifact[tfx.types.standard_artifacts.Examples],
    model: tfx.dsl.components.OutputArtifact[tfx.types.standard_artifacts.Model],
    dropout_hyperparameter: float,
    num_iterations: tfx.dsl.components.Parameter[int] = 10
) -> MyOutput:
  '''My simple trainer component.'''

  records = read_examples(training_data.uri)
  model_obj = train_model(records, num_iterations, dropout_hyperparameter)
  model_obj.write_to(model.uri)

  return {
    'loss': model_obj.loss,
    'accuracy': model_obj.accuracy
  }

# Example usage in a pipeline graph definition:
# ...
trainer = MyTrainerComponent(
    examples=example_gen.outputs['examples'],
    dropout_hyperparameter=other_component.outputs['dropout'],
    num_iterations=1000)
pusher = Pusher(model=trainer.outputs['model'])
# ...

先前的範例將 MyTrainerComponent 定義為以 Python 函數為基礎的自訂元件。這個元件會使用 examples 工件做為輸入,並產生 model 工件做為輸出。元件會使用 artifact_instance.uri 在其系統管理的儲存位置讀取或寫入工件。元件會採用 num_iterations 輸入參數和 dropout_hyperparameter 簡單資料類型值,而元件會輸出 lossaccuracy 指標做為簡單資料類型輸出值。輸出 model 工件接著會由 Pusher 元件使用。