TFX 作為平台,是框架中立的,可以與其他 ML 框架搭配使用,例如 JAX、scikit-learn。
對於模型開發人員而言,這表示他們不需要重寫以其他 ML 框架實作的模型程式碼,而是可以在 TFX 中重複使用大部分的訓練程式碼,並受益於 TFX 和 TensorFlow 生態系統其他部分提供的其他功能。
TFX 管線 SDK 和 TFX 中的大多數模組 (例如管線協調器) 沒有任何對 TensorFlow 的直接依賴性,但有些方面是面向 TensorFlow 的,例如資料格式。只要考量特定模型建構框架的需求,TFX 管線就可以用於訓練任何其他以 Python 為基礎的 ML 框架中的模型。這包括 Scikit-learn、XGBoost 和 PyTorch 等。將標準 TFX 元件與其他框架搭配使用的一些考量因素包括
- ExampleGen 會以 TFRecord 檔案輸出 tf.train.Example。
tf.train.Example
是訓練資料的通用表示法,下游元件會使用 TFXIO 在記憶體中將其讀取為 Arrow/RecordBatch,然後可以進一步轉換為tf.dataset
、Tensors
或其他格式。正在考慮使用 tf.train.Example/TFRecord 以外的酬載/檔案格式,但對於 TFXIO 使用者而言,這應該是黑箱作業。 - 無論訓練使用何種框架,Transform 都可以用於產生轉換後的訓練範例,但如果模型格式不是
saved_model
,使用者將無法將轉換圖表嵌入模型中。在這種情況下,模型預測需要採用轉換後的特徵而不是原始特徵,並且使用者可以在部署時,在呼叫模型預測之前執行轉換作為預先處理步驟。 - Trainer 預設僅支援
saved_model
,但使用者可以提供產生模型評估預測結果的 UDF。 - Evaluator 預設僅支援
saved_model
,但使用者可以提供產生模型評估預測結果的 UDF。
以非 Python 為基礎的框架訓練模型將需要在 Docker 容器中隔離自訂訓練元件,作為在容器化環境 (例如 Kubernetes) 中執行的管線的一部分。
JAX
JAX 是 Autograd 和 XLA 的結合,專為高效能機器學習研究而設計。Flax 是 JAX 的神經網路程式庫和生態系統,專為彈性而設計。
透過 jax2tf,我們可以將訓練過的 JAX/Flax 模型轉換為 saved_model
格式,該格式可以在 TFX 中與通用訓練和模型評估無縫搭配使用。如需詳細資訊,請查看此範例。
scikit-learn
Scikit-learn 是 Python 程式設計語言的機器學習程式庫。我們在 TFX-Addons 中有一個端對端 範例,其中包含自訂訓練和評估。