![]() |
![]() |
![]() |
![]() |
除了成為 TensorFlow 生態系統的一部分之外,TFF 的目標是實現與其他前端和後端 ML 架構的互通性。目前,對其他 ML 架構的支援仍處於孵化階段,API 和支援的功能可能會變更 (很大程度上取決於 TFF 使用者的需求)。本教學課程說明如何將 TFF 與 JAX 作為替代 ML 前端,以及 XLA 編譯器作為替代後端搭配使用。此處顯示的範例完全基於端對端的原生 JAX/XLA 堆疊。跨架構 (例如,JAX 與 TensorFlow) 混合程式碼的可能性將在未來的教學課程中討論。
一如既往,我們歡迎您的貢獻。如果對您而言,支援 JAX/XLA 或與其他 ML 架構互通的能力很重要,請考慮協助我們將這些功能發展到與 TFF 其餘部分同等的水準。
開始之前
請參閱 TFF 文件的主體,以瞭解如何設定您的環境。根據您執行本教學課程的位置,您可能想要取消註解並執行以下部分或全部程式碼。
# !pip install --quiet --upgrade tensorflow-federated
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()
本教學課程也假設您已檢閱 TFF 的主要 TensorFlow 教學課程,並且熟悉核心 TFF 概念。如果您尚未完成此操作,請考慮至少檢閱其中一個教學課程。
JAX 運算
TFF 中對 JAX 的支援旨在與 TFF 與 TensorFlow 互通的方式對稱,從匯入開始
import jax
import numpy as np
import tensorflow_federated as tff
此外,就像 TensorFlow 一樣,表達任何 TFF 程式碼的基礎是在本機執行的邏輯。您可以使用 JAX 表達此邏輯,如下所示,使用 @tff.jax_computation
包裝函式。它的行為類似於您現在熟悉的 @tff.tf_computation
。讓我們從簡單的事情開始,例如,將兩個整數相加的運算
@tff.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
return jax.numpy.add(x, y)
您可以像平常使用 TFF 運算一樣使用上面定義的 JAX 運算。例如,您可以檢查其類型簽章,如下所示
str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'
請注意,我們使用 np.int32
來定義引數的類型。TFF 不區分 Numpy 類型 (例如 np.int32
) 和 TensorFlow 類型 (例如 tf.int32
)。從 TFF 的角度來看,它們只是指稱同一事物的不同方式。
現在,請記住 TFF 不是 Python (如果這聽起來很陌生,請檢閱我們之前的一些教學課程,例如,關於自訂演算法的教學課程)。您可以將 @tff.jax_computation
包裝函式與任何可以追蹤和序列化的 JAX 程式碼搭配使用,也就是說,與您通常使用 @jax.jit
註解註解並預期編譯成 XLA 的程式碼搭配使用 (但您不需要實際使用 @jax.jit
註解將您的 JAX 程式碼嵌入 TFF 中)。
實際上,在底層,TFF 會立即將 JAX 運算編譯為 XLA。您可以透過手動從 add_numbers
擷取和列印序列化的 XLA 程式碼來自行檢查,如下所示
comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7 ENTRY xla_computation_add_numbers.7 { constant.4 = pred[] constant(false) parameter.1 = (s32[], s32[]) parameter(0) get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0 get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1 add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3) ROOT tuple.6 = (s32[]) tuple(add.5) }
將 JAX 運算的表示視為 XLA 程式碼,就像以 TensorFlow 表示的運算的 tf.GraphDef
的功能等效項。它與 tf.GraphDef
一樣,可以在各種支援 XLA 的環境中移植和執行,tf.GraphDef
可以在任何 TensorFlow 執行階段上執行。
TFF 提供基於 XLA 編譯器作為後端的執行階段堆疊。您可以按如下方式啟用它
tff.backends.xla.set_local_python_execution_context()
現在,您可以執行我們上面定義的運算
add_numbers(2, 3)
5
很簡單。讓我們繼續進行更複雜的事情,例如 MNIST。
使用標準 API 進行 MNIST 訓練的範例
與往常一樣,我們先為資料批次和模型定義一堆 TFF 類型 (請記住,TFF 是一個強型別架構)。
import collections
BATCH_TYPE = collections.OrderedDict([
('pixels', tff.TensorType(np.float32, (50, 784))),
('labels', tff.TensorType(np.int32, (50,)))
])
MODEL_TYPE = collections.OrderedDict([
('weights', tff.TensorType(np.float32, (784, 10))),
('bias', tff.TensorType(np.float32, (10,)))
])
現在,讓我們在 JAX 中為模型定義一個損失函數,將模型和單一批次的資料作為參數
def loss(model, batch):
y = jax.nn.softmax(
jax.numpy.add(
jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))
現在,一種方法是使用標準 API。以下範例說明如何使用我們的 API 建立基於剛定義的損失函數的訓練程序。
STEP_SIZE = 0.001
trainer = tff.learning.build_jax_federated_averaging_process(
BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)
您可以像使用從 TensorFlow 中的 tf.Keras
模型建置的訓練器一樣使用上述內容。例如,以下說明如何建立用於訓練的初始模型
initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])
為了執行實際訓練,我們需要一些資料。讓我們建立隨機資料以保持簡單。由於資料是隨機的,我們將在訓練資料上評估,因為否則,使用隨機評估資料,很難期望模型表現良好。此外,對於這個小規模的示範,我們不會擔心隨機取樣用戶端 (我們將其作為練習留給使用者,讓他們透過遵循其他教學課程中的範本來探索這些類型的變更)
def random_batch():
pixels = np.random.uniform(
low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
return collections.OrderedDict([('pixels', pixels), ('labels', labels)])
NUM_CLIENTS = 2
NUM_BATCHES = 10
train_data = [
[random_batch() for _ in range(NUM_BATCHES)]
for _ in range(NUM_CLIENTS)]
這樣,我們可以執行單步訓練,如下所示
trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05, 2.54597180e-05, ..., 5.61640409e-05, -5.32875274e-05, -4.62881755e-04], [ 7.30908650e-05, 4.67643113e-05, 2.03352147e-06, ..., 3.77510623e-05, 3.52839161e-05, -4.59865667e-04], [ 8.14835730e-05, 3.03147244e-05, -1.89143739e-05, ..., 1.12527239e-04, 4.09212225e-06, -4.59960109e-04], ..., [ 9.23552434e-05, 2.44302555e-06, -2.20817346e-05, ..., 7.61375341e-05, 1.76906979e-05, -4.43495519e-04], [ 1.17451040e-04, 2.47748958e-05, 1.04728279e-05, ..., 5.26388249e-07, 7.21131510e-05, -4.67137404e-04], [ 3.75041491e-05, 6.58061981e-05, 1.14522081e-05, ..., 2.52584141e-05, 3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04, 2.6502126e-05, -1.9462314e-05, 8.1269856e-05, 2.1832302e-04, 1.6636557e-04, 1.2815947e-04, 9.0642272e-05, 7.7109929e-05, -9.1987278e-04], dtype=float32))])
讓我們評估訓練步驟的結果。為了簡化,我們可以在集中式方式中評估它
import itertools
eval_data = list(itertools.chain.from_iterable(train_data))
def average_loss(model, data):
return np.mean([loss(model, batch) for batch in data])
print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854 2.282762
損失正在減少。太好了!現在,讓我們在多個回合中執行此操作
NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
trained_model = trainer.next(trained_model, train_data)
print(average_loss(trained_model, eval_data))
2.2685437 2.257856 2.2495182 2.2428129 2.2372835 2.2326245 2.2286277 2.2251441 2.2220676 2.219318 2.2168345 2.2145717 2.2124937 2.2105706 2.2087805 2.2071042 2.2055268 2.2040353 2.2026198 2.2012706
如您所見,將 JAX 與 TFF 搭配使用並沒有那麼大的不同,儘管實驗性 API 在功能上還無法與 TensorFlow API 相提並論。
底層原理
如果您不喜歡使用我們的標準 API,您可以實作自己的自訂運算,就像您在 TensorFlow 的自訂演算法教學課程中看到的那樣,只是您將使用 JAX 的梯度下降機制。例如,以下說明如何定義在單個小批次上更新模型的 JAX 運算
@tff.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
grads = jax.grad(loss)(model, batch)
return collections.OrderedDict([
(k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
])
以下說明如何測試它是否有效
sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854 2.2977567
使用 JAX 的一個注意事項是它不提供等效於 tf.data.Dataset
的功能。因此,為了迭代資料集,您需要使用 TFF 的宣告式結構來進行序列運算,如下所示
@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
return tff.sequence_reduce(batches, model, train_on_one_batch)
讓我們看看它是否有效
sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854 2.2284968
執行單回合訓練的運算看起來就像您在 TensorFlow 教學課程中可能看到的那樣
@tff.federated_computation(
tff.FederatedType(MODEL_TYPE, tff.SERVER),
tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
locally_trained_models = tff.federated_map(
train_on_one_client,
collections.OrderedDict([
('model', tff.federated_broadcast(model)),
('batches', federated_data)]))
return tff.federated_mean(locally_trained_models)
讓我們看看它是否有效
trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854 2.282762
如您所見,在 TFF 中使用 JAX,無論是透過標準 API,還是直接使用低階 TFF 結構,都類似於將 TFF 與 TensorFlow 搭配使用。請繼續關注未來的更新,如果您希望看到對跨 ML 架構的互通性提供更好的支援,請隨時向我們發送提取請求!