Python 函數型元件定義讓您更輕鬆地建立 TFX 自訂元件,省去定義元件規格類別、執行器類別和元件介面類別的功夫。在這種元件定義樣式中,您會編寫一個以類型提示標註的函式。類型提示會說明元件的輸入工件、輸出工件和參數。
以這種樣式編寫自訂元件非常簡單,如下列範例所示。
class MyOutput(TypedDict):
accuracy: float
@component
def MyValidationComponent(
model: InputArtifact[Model],
blessing: OutputArtifact[Model],
accuracy_threshold: Parameter[int] = 10,
) -> MyOutput:
'''My simple custom model validation component.'''
accuracy = evaluate_model(model)
if accuracy >= accuracy_threshold:
write_output_blessing(blessing)
return {
'accuracy': accuracy
}
在幕後,這會定義一個自訂元件,該元件是 BaseComponent
及其 Spec 和 Executor 類別的子類別。
如果您想定義 BaseBeamComponent
的子類別,以便搭配 TFX 管線式共用設定 (亦即編譯管線時的 beam_pipeline_args
) 使用 Beam 管線 (範例請參閱芝加哥計程車管線範例),您可以在裝飾器中設定 use_beam=True
,並在函式中新增另一個預設值為 None
的 BeamComponentParameter
,如下列範例所示
@component(use_beam=True)
def MyDataProcessor(
examples: InputArtifact[Example],
processed_examples: OutputArtifact[Example],
beam_pipeline: BeamComponentParameter[beam.Pipeline] = None,
) -> None:
'''My simple custom model validation component.'''
with beam_pipeline as p:
# data pipeline definition with beam_pipeline begins
...
# data pipeline definition with beam_pipeline ends
如果您不熟悉 TFX 管線,請進一步瞭解 TFX 管線的核心概念。
輸入、輸出和參數
在 TFX 中,輸入和輸出會以 Artifact 物件追蹤,這些物件會說明基礎資料的位置和與之相關聯的中繼資料屬性;這項資訊會儲存在 ML Metadata 中。Artifact 可以說明複雜的資料類型或簡單的資料類型,例如:int、float、位元組或 Unicode 字串。
參數是管線建構時已知元件的引數 (int、float、位元組或 Unicode 字串)。參數可用於指定引數和超參數,例如訓練迭代次數、Dropout 率,以及元件的其他設定。當在 ML Metadata 中追蹤時,參數會儲存為元件執行的屬性。
定義
如要建立自訂元件,請編寫一個函式來實作您的自訂邏輯,並使用 @component
裝飾器 (來自 tfx.dsl.component.experimental.decorators
模組) 加以裝飾。如要定義元件的輸入和輸出結構描述,請使用 tfx.dsl.component.experimental.annotations
模組中的註解,為函式的引數和傳回值加上註解
針對每個工件輸入,套用
InputArtifact[ArtifactType]
類型提示註解。將ArtifactType
替換為工件的類型,即tfx.types.Artifact
的子類別。這些輸入可以是選用引數。針對每個輸出工件,套用
OutputArtifact[ArtifactType]
類型提示註解。將ArtifactType
替換為工件的類型,即tfx.types.Artifact
的子類別。元件輸出工件應做為函式的輸入引數傳遞,如此您的元件才能將輸出寫入系統管理的儲存位置,並設定適當的工件中繼資料屬性。此引數可以是選用項目,也可以使用預設值定義此引數。針對每個參數,使用類型提示註解
Parameter[T]
。將T
替換為參數的類型。我們目前僅支援基本 Python 類型:bool
、int
、float
、str
或bytes
。針對 Beam 管線,使用類型提示註解
BeamComponentParameter[beam.Pipeline]
。將預設值設為None
。None
值將由BaseBeamExecutor
的_make_beam_pipeline()
建立的已例項化 Beam 管線取代針對每個簡單資料類型輸入 (
int
、float
、str
或bytes
),若在管線建構時未知,請使用類型提示T
。請注意,在 TFX 0.22 版本中,針對這類輸入,無法在管線建構時傳遞具體值 (請改用Parameter
註解,如上一節所述)。此引數可以是選用項目,也可以使用預設值定義此引數。如果您的元件具有簡單資料類型輸出 (int
、float
、str
或bytes
),您可以透過使用TypedDict
做為傳回類型註解,並傳回適當的 dict 物件來傳回這些輸出。
在函式主體中,輸入和輸出工件會以 tfx.types.Artifact
物件傳遞;您可以檢查其 .uri
以取得其系統管理的儲存位置,並讀取/設定任何屬性。輸入參數和簡單資料類型輸入會以指定類型的物件傳遞。簡單資料類型輸出應以字典形式傳回,其中鍵是適當的輸出名稱,而值是所需的傳回值。
完成的函數元件可能如下所示
from typing import TypedDict
import tfx.v1 as tfx
from tfx.dsl.component.experimental.decorators import component
class MyOutput(TypedDict):
loss: float
accuracy: float
@component
def MyTrainerComponent(
training_data: tfx.dsl.components.InputArtifact[tfx.types.standard_artifacts.Examples],
model: tfx.dsl.components.OutputArtifact[tfx.types.standard_artifacts.Model],
dropout_hyperparameter: float,
num_iterations: tfx.dsl.components.Parameter[int] = 10
) -> MyOutput:
'''My simple trainer component.'''
records = read_examples(training_data.uri)
model_obj = train_model(records, num_iterations, dropout_hyperparameter)
model_obj.write_to(model.uri)
return {
'loss': model_obj.loss,
'accuracy': model_obj.accuracy
}
# Example usage in a pipeline graph definition:
# ...
trainer = MyTrainerComponent(
examples=example_gen.outputs['examples'],
dropout_hyperparameter=other_component.outputs['dropout'],
num_iterations=1000)
pusher = Pusher(model=trainer.outputs['model'])
# ...
先前的範例將 MyTrainerComponent
定義為以 Python 函數為基礎的自訂元件。這個元件會使用 examples
工件做為輸入,並產生 model
工件做為輸出。元件會使用 artifact_instance.uri
在其系統管理的儲存位置讀取或寫入工件。元件會採用 num_iterations
輸入參數和 dropout_hyperparameter
簡單資料類型值,而元件會輸出 loss
和 accuracy
指標做為簡單資料類型輸出值。輸出 model
工件接著會由 Pusher
元件使用。