在 TFX 中使用其他 ML 框架

TFX 作為平台,是框架中立的,可以與其他 ML 框架搭配使用,例如 JAX、scikit-learn。

對於模型開發人員而言,這表示他們不需要重寫以其他 ML 框架實作的模型程式碼,而是可以在 TFX 中重複使用大部分的訓練程式碼,並受益於 TFX 和 TensorFlow 生態系統其他部分提供的其他功能。

TFX 管線 SDK 和 TFX 中的大多數模組 (例如管線協調器) 沒有任何對 TensorFlow 的直接依賴性,但有些方面是面向 TensorFlow 的,例如資料格式。只要考量特定模型建構框架的需求,TFX 管線就可以用於訓練任何其他以 Python 為基礎的 ML 框架中的模型。這包括 Scikit-learn、XGBoost 和 PyTorch 等。將標準 TFX 元件與其他框架搭配使用的一些考量因素包括

  • ExampleGen 會以 TFRecord 檔案輸出 tf.train.Exampletf.train.Example 是訓練資料的通用表示法,下游元件會使用 TFXIO 在記憶體中將其讀取為 Arrow/RecordBatch,然後可以進一步轉換為 tf.datasetTensors 或其他格式。正在考慮使用 tf.train.Example/TFRecord 以外的酬載/檔案格式,但對於 TFXIO 使用者而言,這應該是黑箱作業。
  • 無論訓練使用何種框架,Transform 都可以用於產生轉換後的訓練範例,但如果模型格式不是 saved_model,使用者將無法將轉換圖表嵌入模型中。在這種情況下,模型預測需要採用轉換後的特徵而不是原始特徵,並且使用者可以在部署時,在呼叫模型預測之前執行轉換作為預先處理步驟。
  • Trainer 預設僅支援 saved_model,但使用者可以提供產生模型評估預測結果的 UDF。
  • Evaluator 預設僅支援 saved_model,但使用者可以提供產生模型評估預測結果的 UDF。

以非 Python 為基礎的框架訓練模型將需要在 Docker 容器中隔離自訂訓練元件,作為在容器化環境 (例如 Kubernetes) 中執行的管線的一部分。

JAX

JAX 是 Autograd 和 XLA 的結合,專為高效能機器學習研究而設計。Flax 是 JAX 的神經網路程式庫和生態系統,專為彈性而設計。

透過 jax2tf,我們可以將訓練過的 JAX/Flax 模型轉換為 saved_model 格式,該格式可以在 TFX 中與通用訓練和模型評估無縫搭配使用。如需詳細資訊,請查看此範例

scikit-learn

Scikit-learn 是 Python 程式設計語言的機器學習程式庫。我們在 TFX-Addons 中有一個端對端 範例,其中包含自訂訓練和評估。