JAX 在 TFF 中的實驗性支援

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視 下載筆記本

除了成為 TensorFlow 生態系統的一部分之外,TFF 的目標是實現與其他前端和後端 ML 架構的互通性。目前,對其他 ML 架構的支援仍處於孵化階段,API 和支援的功能可能會變更 (很大程度上取決於 TFF 使用者的需求)。本教學課程說明如何將 TFF 與 JAX 作為替代 ML 前端,以及 XLA 編譯器作為替代後端搭配使用。此處顯示的範例完全基於端對端的原生 JAX/XLA 堆疊。跨架構 (例如,JAX 與 TensorFlow) 混合程式碼的可能性將在未來的教學課程中討論。

一如既往,我們歡迎您的貢獻。如果對您而言,支援 JAX/XLA 或與其他 ML 架構互通的能力很重要,請考慮協助我們將這些功能發展到與 TFF 其餘部分同等的水準。

開始之前

請參閱 TFF 文件的主體,以瞭解如何設定您的環境。根據您執行本教學課程的位置,您可能想要取消註解並執行以下部分或全部程式碼。

# !pip install --quiet --upgrade tensorflow-federated
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

本教學課程也假設您已檢閱 TFF 的主要 TensorFlow 教學課程,並且熟悉核心 TFF 概念。如果您尚未完成此操作,請考慮至少檢閱其中一個教學課程。

JAX 運算

TFF 中對 JAX 的支援旨在與 TFF 與 TensorFlow 互通的方式對稱,從匯入開始

import jax
import numpy as np
import tensorflow_federated as tff

此外,就像 TensorFlow 一樣,表達任何 TFF 程式碼的基礎是在本機執行的邏輯。您可以使用 JAX 表達此邏輯,如下所示,使用 @tff.jax_computation 包裝函式。它的行為類似於您現在熟悉的 @tff.tf_computation。讓我們從簡單的事情開始,例如,將兩個整數相加的運算

@tff.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

您可以像平常使用 TFF 運算一樣使用上面定義的 JAX 運算。例如,您可以檢查其類型簽章,如下所示

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

請注意,我們使用 np.int32 來定義引數的類型。TFF 不區分 Numpy 類型 (例如 np.int32) 和 TensorFlow 類型 (例如 tf.int32)。從 TFF 的角度來看,它們只是指稱同一事物的不同方式。

現在,請記住 TFF 不是 Python (如果這聽起來很陌生,請檢閱我們之前的一些教學課程,例如,關於自訂演算法的教學課程)。您可以將 @tff.jax_computation 包裝函式與任何可以追蹤和序列化的 JAX 程式碼搭配使用,也就是說,與您通常使用 @jax.jit 註解註解並預期編譯成 XLA 的程式碼搭配使用 (但您不需要實際使用 @jax.jit 註解將您的 JAX 程式碼嵌入 TFF 中)。

實際上,在底層,TFF 會立即將 JAX 運算編譯為 XLA。您可以透過手動從 add_numbers 擷取和列印序列化的 XLA 程式碼來自行檢查,如下所示

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

將 JAX 運算的表示視為 XLA 程式碼,就像以 TensorFlow 表示的運算的 tf.GraphDef 的功能等效項。它與 tf.GraphDef 一樣,可以在各種支援 XLA 的環境中移植和執行,tf.GraphDef 可以在任何 TensorFlow 執行階段上執行。

TFF 提供基於 XLA 編譯器作為後端的執行階段堆疊。您可以按如下方式啟用它

tff.backends.xla.set_local_python_execution_context()

現在,您可以執行我們上面定義的運算

add_numbers(2, 3)
5

很簡單。讓我們繼續進行更複雜的事情,例如 MNIST。

使用標準 API 進行 MNIST 訓練的範例

與往常一樣,我們先為資料批次和模型定義一堆 TFF 類型 (請記住,TFF 是一個強型別架構)。

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

現在,讓我們在 JAX 中為模型定義一個損失函數,將模型和單一批次的資料作為參數

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

現在,一種方法是使用標準 API。以下範例說明如何使用我們的 API 建立基於剛定義的損失函數的訓練程序。

STEP_SIZE = 0.001

trainer = tff.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

您可以像使用從 TensorFlow 中的 tf.Keras 模型建置的訓練器一樣使用上述內容。例如,以下說明如何建立用於訓練的初始模型

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

為了執行實際訓練,我們需要一些資料。讓我們建立隨機資料以保持簡單。由於資料是隨機的,我們將在訓練資料上評估,因為否則,使用隨機評估資料,很難期望模型表現良好。此外,對於這個小規模的示範,我們不會擔心隨機取樣用戶端 (我們將其作為練習留給使用者,讓他們透過遵循其他教學課程中的範本來探索這些類型的變更)

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

這樣,我們可以執行單步訓練,如下所示

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

讓我們評估訓練步驟的結果。為了簡化,我們可以在集中式方式中評估它

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

損失正在減少。太好了!現在,讓我們在多個回合中執行此操作

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

如您所見,將 JAX 與 TFF 搭配使用並沒有那麼大的不同,儘管實驗性 API 在功能上還無法與 TensorFlow API 相提並論。

底層原理

如果您不喜歡使用我們的標準 API,您可以實作自己的自訂運算,就像您在 TensorFlow 的自訂演算法教學課程中看到的那樣,只是您將使用 JAX 的梯度下降機制。例如,以下說明如何定義在單個小批次上更新模型的 JAX 運算

@tff.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

以下說明如何測試它是否有效

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

使用 JAX 的一個注意事項是它不提供等效於 tf.data.Dataset 的功能。因此,為了迭代資料集,您需要使用 TFF 的宣告式結構來進行序列運算,如下所示

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

讓我們看看它是否有效

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

執行單回合訓練的運算看起來就像您在 TensorFlow 教學課程中可能看到的那樣

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

讓我們看看它是否有效

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

如您所見,在 TFF 中使用 JAX,無論是透過標準 API,還是直接使用低階 TFF 結構,都類似於將 TFF 與 TensorFlow 搭配使用。請繼續關注未來的更新,如果您希望看到對跨 ML 架構的互通性提供更好的支援,請隨時向我們發送提取請求!