![]() |
![]() |
![]() |
![]() |
開始之前
在開始之前,請執行以下程式碼,以確保您的環境設定正確。如果您沒有看到問候語,請參閱安裝指南以取得操作說明。
pip install --quiet --upgrade tensorflow-federated
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
在圖片分類和文字產生教學課程中,您學習瞭如何為聯邦學習 (FL) 設定模型和資料管道,並透過 TFF 的 tff.learning
API 層執行聯邦訓練。
這只是 FL 研究的冰山一角。本教學課程討論如何在不參考 tff.learning
API 的情況下實作聯邦學習演算法。在本教學課程中,您將完成以下事項
目標
- 瞭解聯邦學習演算法的一般結構。
- 探索 TFF 的聯邦核心。
- 直接使用聯邦核心實作聯邦平均。
雖然本教學課程是獨立的,但建議您先查看圖片分類和文字產生教學課程。
準備輸入資料
首先載入並預先處理 TFF 中包含的 EMNIST 資料集。如需更多詳細資訊,請參閱圖片分類教學課程。
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
為了將資料集饋送至我們的模型,資料會扁平化,且每個範例都會轉換為 (flattened_image_vector, label)
形式的元組。
NUM_CLIENTS = 10
BATCH_SIZE = 20
def preprocess(dataset):
def batch_format_fn(element):
"""Flatten a batch of EMNIST data and return a (features, label) tuple."""
return (tf.reshape(element['pixels'], [-1, 784]),
tf.reshape(element['label'], [-1, 1]))
return dataset.batch(BATCH_SIZE).map(batch_format_fn)
現在,選取少量用戶端,並將上述預先處理套用至其資料集。
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
for x in client_ids
]
準備模型
這會使用與圖片分類教學課程中相同的模型。此模型 (透過 tf.keras
實作) 具有單一隱藏層,後接 softmax 層。
def create_keras_model():
initializer = tf.keras.initializers.GlorotNormal(seed=0)
return tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer=initializer),
tf.keras.layers.Softmax(),
])
為了在 TFF 中使用此模型,請將 Keras 模型包裝為 tff.learning.models.FunctionalModel
。這可讓您在 TFF 內執行模型的 正向傳遞,並擷取模型輸出。如需更多詳細資訊,另請參閱圖片分類教學課程。
keras_model = create_keras_model()
tff_model = tff.learning.models.functional_model_from_keras(
keras_model,
loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
input_spec=federated_train_data[0].element_spec,
metrics_constructor=collections.OrderedDict(
accuracy=tf.keras.metrics.SparseCategoricalAccuracy
),
)
雖然以上使用 tf.keras
建立 tff.learning.models.FunctionalModel
,但 TFF 支援更通用的模型。這些模型具有以下相關屬性,可擷取模型權重
trainable_variables
:對應於可訓練層的張量可迭代物件。non_trainable_variables
:對應於不可訓練層的張量可迭代物件。
在本教學課程中,僅會使用 trainable_variables
(因為模型僅具有這些!)。
建構您自己的聯邦學習演算法
雖然 tff.learning
API 可讓您建立聯邦平均的許多變體,但還有其他聯邦演算法無法完全符合此架構。例如,您可能想要新增正規化、剪輯或更複雜的演算法,例如 聯邦 GAN 訓練。您也可能對 聯邦分析感興趣。
對於這些更進階的演算法,您必須使用 TFF 撰寫自己的自訂演算法。在許多情況下,聯邦演算法具有 4 個主要元件
- 伺服器到用戶端廣播步驟。
- 本機用戶端更新步驟。
- 用戶端到伺服器上傳步驟。
- 伺服器更新步驟。
在 TFF 中,聯邦演算法通常表示為 tff.templates.IterativeProcess
(在本文中將簡稱為 IterativeProcess
)。這是一個包含 initialize
和 next
函式的類別。在此,initialize
用於初始化伺服器,而 next
將執行聯邦演算法的一個通訊回合。讓我們撰寫 FedAvg 的迭代程序骨架應有的樣子。
首先,有一個初始化函式,只需建立 tff.learning.models.FunctionalModel
,並傳回其可訓練權重。
def initialize_fn():
trainable_weights, _ = tff_model.initial_weights
return trainable_weights
此函式看起來不錯,但稍後您會看到,您需要稍微修改才能使其成為「TFF 運算」。
接下來,讓我們撰寫 next_fn
的草圖。
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
server_weights_at_client = broadcast(server_weights)
# Each client computes their updated weights.
client_weights = client_update(federated_dataset, server_weights_at_client)
# The server averages these updates.
mean_client_weights = mean(client_weights)
# The server updates its model.
server_weights = server_update(mean_client_weights)
return server_weights
讓我們專注於分別實作這四個元件。首先,讓我們專注於可以在純 TensorFlow 中實作的部分,即用戶端和伺服器更新步驟。
TensorFlow 區塊
用戶端更新
tff.learning.models.FunctionalModel
可用於執行用戶端訓練,方式基本上與訓練 TensorFlow 模型的方式相同。特別是,可以使用 tf.GradientTape
計算資料批次的梯度,然後使用 client_optimizer
套用這些梯度。這只會涉及可訓練權重。
@tf.function
def client_update(model, dataset, initial_weights, client_optimizer):
"""Performs training (using the server model weights) on the client's dataset."""
# Initialize the client model with the current server weights and the optimizer
# state.
client_weights = initial_weights.trainable
optimizer_state = client_optimizer.initialize(
tf.nest.map_structure(tf.TensorSpec.from_tensor, client_weights)
)
# Use the client_optimizer to update the local model.
for batch in dataset:
x, y = batch
with tf.GradientTape() as tape:
tape.watch(client_weights)
# Compute a forward pass on the batch of data
outputs = model.predict_on_batch(
model_weights=(client_weights, ()), x=x, training=True
)
loss = model.loss(output=outputs, label=y)
# Compute the corresponding gradient
grads = tape.gradient(loss, client_weights)
# Apply the gradient using a client optimizer.
optimizer_state, client_weights = client_optimizer.next(
optimizer_state, weights=client_weights, gradients=grads
)
return tff.learning.models.ModelWeights(client_weights, non_trainable=())
伺服器更新
FedAvg 的伺服器更新比用戶端更新更簡單。本教學課程將實作「純粹」聯邦平均,其中伺服器模型權重會由用戶端模型權重的平均值取代。同樣地,這僅使用可訓練權重。
@tf.function
def server_update(model, mean_client_weights):
"""Updates the server model weights as the average of the client model weights."""
del model # Unused, just take the mean_client_weights.
return mean_client_weights
您可以透過簡單地傳回 mean_client_weights
來簡化程式碼片段。但是,聯邦平均的更進階實作會將 mean_client_weights
與更複雜的技術搭配使用,例如動量或適應性。
挑戰:實作 server_update
的版本,將伺服器權重更新為 model_weights 和 mean_client_weights 的中點。(注意:這種「中點」方法類似於最近關於 Lookahead Optimizer 的研究!)。
到目前為止,這僅涉及 TensorFlow 程式碼。這是經過設計的,因為 TFF 可讓您使用許多您已熟悉的 TensorFlow 程式碼。接下來,您必須指定協調邏輯,也就是指示伺服器向用戶端廣播什麼,以及用戶端上傳到伺服器的內容的邏輯。
這將需要 TFF 的聯邦核心。
聯邦核心簡介
聯邦核心 (FC) 是一組較低層級的介面,可作為 tff.learning
API 的基礎。但是,這些介面不限於學習。事實上,它們可用於分析和許多其他分散式資料的運算。
在高層次上,聯邦核心是一種開發環境,可讓簡潔表達的程式邏輯將 TensorFlow 程式碼與分散式通訊運算子 (例如分散式總和和廣播) 結合。目標是讓研究人員和從業人員能夠明確控制其系統中的分散式通訊,而無需系統實作詳細資訊 (例如指定點對點網路訊息交換)。
一個重點是 TFF 專為保護隱私而設計。因此,它允許明確控制資料的駐留位置,以防止資料在集中式伺服器位置不必要地累積。
聯邦資料
TFF 中的一個重要概念是「聯邦資料」,它指的是分散式系統中跨裝置群組 (例如用戶端資料集或伺服器模型權重) 主控的資料項目集合。跨所有裝置的整個值集合都表示為單一聯邦值。
例如,假設有用戶端裝置各自具有代表感測器溫度的浮點數。這些浮點數可以使用聯邦浮點數表示
federated_float_on_clients = tff.FederatedType(np.float32, tff.CLIENTS)
聯邦類型由其成員成分的類型 T
(例如 np.float32
) 和裝置群組 G
指定。通常,G
是 tff.CLIENTS
或 tff.SERVER
。此類聯邦類型表示為 {T}@G
,如下所示。
str(federated_float_on_clients)
'{float32}@CLIENTS'
為什麼 TFF 如此關心放置?TFF 的一個主要目標是能夠撰寫可部署在真實分散式系統上的程式碼。這表示必須推論哪些裝置子集執行哪些程式碼,以及不同資料片段的駐留位置。
TFF 專注於三件事:資料、資料的放置位置,以及資料的轉換方式。前兩者封裝在聯邦類型中,而後者封裝在聯邦運算中。
聯邦運算
TFF 是一個強型別函數式程式設計環境,其基本單位是聯邦運算。這些是接受聯邦值作為輸入,並傳回聯邦值作為輸出的邏輯片段。
例如,假設您想要平均用戶端感測器上的溫度。您可以定義以下內容 (使用我們的聯邦浮點數)
@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
return tff.federated_mean(client_temperatures)
您可能會問,這與 TensorFlow 中的 tf.function
裝飾器有何不同?關鍵答案是 tff.federated_computation
產生的程式碼既不是 TensorFlow 也不是 Python 程式碼;它是在內部平台獨立黏合語言中分散式系統的規格。
雖然這聽起來可能很複雜,但您可以將 TFF 運算視為具有明確定義類型簽章的函式。這些類型簽章可以直接查詢。
str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'
此 tff.federated_computation
接受聯邦類型 {float32}@CLIENTS
的引數,並傳回聯邦類型 {float32}@SERVER
的值。聯邦運算也可能從伺服器到用戶端、從用戶端到用戶端或從伺服器到伺服器。聯邦運算也可以像一般函式一樣組合,只要其類型簽章匹配即可。
為了支援開發,TFF 可讓您以 Python 函式形式叫用 tff.federated_computation
。例如,您可以呼叫
get_average_temperature([68.5, 70.3, 69.8])
69.53333
非急切運算和 TensorFlow
有兩個需要注意的關鍵限制。首先,當 Python 直譯器遇到 tff.federated_computation
裝飾器時,函式會追蹤一次並序列化以供日後使用。由於聯邦學習的分散式特性,此日後使用可能會在其他地方發生,例如遠端執行環境。因此,TFF 運算從根本上來說是非急切的。此行為在某種程度上類似於 TensorFlow 中 tf.function
裝飾器的行為。
其次,聯邦運算只能包含聯邦運算子 (例如 tff.federated_mean
),它們不能包含 TensorFlow 運算。TensorFlow 程式碼必須限制在以 tff.tensorflow.computation
裝飾的區塊中。大多數普通的 TensorFlow 程式碼都可以直接裝飾,例如以下函式,它接受一個數字並將 0.5
新增至其中。
@tff.tensorflow.computation(np.float32)
def add_half(x):
return tf.add(x, 0.5)
這些也有類型簽章,但沒有放置。例如,您可以呼叫
str(add_half.type_signature)
'(float32 -> float32)'
這展示了 tff.federated_computation
和 tff.tensorflow.computation
之間的重要差異。前者具有明確的放置,而後者則沒有。
您可以使用 tff.tensorflow.computation
區塊在聯邦運算中指定放置。讓我們建立一個新增一半的函式,但僅適用於用戶端聯邦浮點數。您可以透過使用 tff.federated_map
來執行此操作,它會套用給定的 tff.tensorflow.computation
,同時保留放置。
@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def add_half_on_clients(x):
return tff.federated_map(add_half, x)
此函式幾乎與 add_half
相同,只是它僅接受放置在 tff.CLIENTS
的值,並傳回具有相同放置的值。這可以在其類型簽章中看到
str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'
總結
- TFF 在聯邦值上運作。
- 每個聯邦值都有一個聯邦類型,其中包含類型 (例如
np.float32
) 和放置 (例如tff.CLIENTS
)。 - 聯邦值可以使用聯邦運算轉換,聯邦運算必須以
tff.federated_computation
和聯邦類型簽章裝飾。 - TensorFlow 程式碼必須包含在具有
tff.tensorflow.computation
裝飾器的區塊中。 - 然後,這些區塊可以併入聯邦運算中。
重新探討建構您自己的聯邦學習演算法
現在您已大致瞭解聯邦核心,您可以建構自己的聯邦學習演算法。請記住,在上方,您為我們的演算法定義了 initialize_fn
和 next_fn
。next_fn
將使用您使用純 TensorFlow 程式碼定義的 client_update
和 server_update
。
但是,為了使我們的演算法成為聯邦運算,您需要 next_fn
和 initialize_fn
都是 tff.federated_computation
。
TensorFlow Federated 區塊
建立初始化運算
初始化函式將非常簡單:您將使用 model_fn
建立模型。但是,請記住,您必須使用 tff.tensorflow.computation
分隔我們的 TensorFlow 程式碼。
@tff.tensorflow.computation
def server_init():
return tff.learning.models.ModelWeights(*tff_model.initial_weights)
然後,您可以使用 tff.federated_value
將其直接傳遞至聯邦運算。
@tff.federated_computation
def initialize_fn():
return tff.federated_eval(server_init, tff.SERVER)
建立 next_fn
用戶端和伺服器更新程式碼現在可用於撰寫實際的演算法。首先,您會將 client_update
轉換為 tff.tensorflow.computation
,它接受用戶端資料集和伺服器權重,並輸出更新的用戶端權重張量。
您需要對應的類型才能正確裝飾我們的函式。幸運的是,伺服器權重的類型可以直接從我們的模型中擷取。
tf_dataset_type = tff.SequenceType(
tff.types.tensorflow_to_type(tff_model.input_spec)
)
讓我們看看資料集類型簽章。請記住,您採用了 28 x 28 圖片 (帶有整數標籤) 並將其扁平化。
str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'
您也可以使用我們上方的 server_init
函式擷取模型權重類型。
model_weights_type = server_init.type_signature.result
檢查類型簽章,您將能夠看到我們模型的架構!
str(model_weights_type)
'<trainable=<float32[784,10],float32[10]>,non_trainable=<>>'
您現在可以為用戶端更新建立我們的 tff.tensorflow.computation
。
@tff.tensorflow.computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
return client_update(tff_model, tf_dataset, server_weights, client_optimizer)
伺服器更新的 tff.tensorflow.computation
版本可以使用您已擷取的類型以類似的方式定義。
@tff.tensorflow.computation(model_weights_type)
def server_update_fn(mean_client_weights):
return server_update(tff_model, mean_client_weights)
最後但並非最不重要的一點是,您需要建立將所有內容整合在一起的 tff.federated_computation
。此函式將接受兩個聯邦值,一個對應於伺服器權重 (放置為 tff.SERVER
),另一個對應於用戶端資料集 (放置為 tff.CLIENTS
)。
請注意,以上定義了這兩種類型!您只需要使用 tff.FederatedType
為它們提供適當的放置即可。
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
還記得 FL 演算法的 4 個要素嗎?
- 伺服器到用戶端廣播步驟。
- 本機用戶端更新步驟。
- 用戶端到伺服器上傳步驟。
- 伺服器更新步驟。
現在您已建立上述內容,每個部分都可以簡潔地表示為單行 TFF 程式碼。這種簡單性就是為什麼您必須格外小心地指定聯邦類型等事項!
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
server_weights_at_client = tff.federated_broadcast(server_weights)
# Each client computes their updated weights.
client_weights = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client)
)
# The server averages these updates.
mean_client_weights = tff.federated_mean(client_weights)
# The server updates its model.
server_weights = tff.federated_map(server_update_fn, mean_client_weights)
return server_weights
您現在有一個 tff.federated_computation
,適用於演算法初始化和執行演算法的一個步驟。若要完成我們的演算法,您可以將這些傳遞至 tff.templates.IterativeProcess
。
federated_algorithm = tff.templates.IterativeProcess(
initialize_fn=initialize_fn,
next_fn=next_fn
)
讓我們看看迭代程序的 initialize
和 next
函式的類型簽章。
str(federated_algorithm.initialize.type_signature)
'( -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'
這反映了 federated_algorithm.initialize
是一個無引數函式,它傳回單層模型 (具有 784 x 10 權重矩陣和 10 個偏差單位)。
str(federated_algorithm.next.type_signature)
'(<server_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'
在此,您可以看到 federated_algorithm.next
接受伺服器模型和用戶端資料,並傳回更新的伺服器模型。
評估演算法
讓我們執行幾個回合,看看損失如何變化。首先,您將使用第二個教學課程中討論的集中式方法定義評估函式。
您將首先建立集中式評估資料集,然後套用您用於訓練資料的相同預先處理。
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)
接下來,您將撰寫一個接受伺服器狀態的函式,並使用 Keras 在測試資料集上進行評估。如果您熟悉 tf.Keras
,這一切看起來都很熟悉,但請注意 set_weights
的使用!
def evaluate(model_weights):
keras_model = create_keras_model()
keras_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
model_weights.assign_weights_to(keras_model)
keras_model.evaluate(central_emnist_test)
現在,讓我們初始化我們的演算法並在測試集上評估。
server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 26s 10ms/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027
讓我們訓練幾個回合,看看是否有任何變化。
for _ in range(15):
server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 4s 1ms/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980
損失函數略有下降。雖然跳躍很小,但您只執行了 15 個訓練回合,並且是在用戶端的小子集上進行的。若要看到更好的結果,您可能必須執行數百次 (如果不是數千次) 回合。
修改我們的演算法
此時,讓我們停下來思考一下您已完成的工作。您已透過將純 TensorFlow 程式碼 (用於用戶端和伺服器更新) 與 TFF 聯邦核心的聯邦運算結合,直接實作了聯邦平均。
若要執行更複雜的學習,您可以簡單地變更您在上方所做的內容。特別是,透過編輯上方的純 TF 程式碼,您可以變更用戶端執行訓練的方式,或伺服器更新其模型的方式。
挑戰: 將梯度裁剪新增至 client_update
函式。
如果您想要進行更大幅度的變更,也可以讓伺服器儲存及廣播更多資料。例如,伺服器也可以儲存用戶端學習速率,並使其隨時間衰減!請注意,這會需要變更上方 tff.tensorflow.computation
呼叫中使用的類型簽章。
更困難的挑戰: 實作在用戶端上具有學習速率衰減的 Federated Averaging。
此時,您可能會開始意識到在這個架構中您可以實作的彈性有多大。如需想法 (包括上方更困難挑戰的解答),您可以查看 tff.learning.algorithms.build_weighted_fed_avg
的原始碼,或查看使用 TFF 的各種研究專案。