建構您自己的聯邦學習演算法

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

開始之前

在開始之前,請執行以下程式碼,以確保您的環境設定正確。如果您沒有看到問候語,請參閱安裝指南以取得操作說明。

pip install --quiet --upgrade tensorflow-federated
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

圖片分類文字產生教學課程中,您學習瞭如何為聯邦學習 (FL) 設定模型和資料管道,並透過 TFF 的 tff.learning API 層執行聯邦訓練。

這只是 FL 研究的冰山一角。本教學課程討論如何在參考 tff.learning API 的情況下實作聯邦學習演算法。在本教學課程中,您將完成以下事項

目標

  • 瞭解聯邦學習演算法的一般結構。
  • 探索 TFF 的聯邦核心
  • 直接使用聯邦核心實作聯邦平均。

雖然本教學課程是獨立的,但建議您先查看圖片分類文字產生教學課程。

準備輸入資料

首先載入並預先處理 TFF 中包含的 EMNIST 資料集。如需更多詳細資訊,請參閱圖片分類教學課程。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

為了將資料集饋送至我們的模型,資料會扁平化,且每個範例都會轉換為 (flattened_image_vector, label) 形式的元組。

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

現在,選取少量用戶端,並將上述預先處理套用至其資料集。

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

準備模型

這會使用與圖片分類教學課程中相同的模型。此模型 (透過 tf.keras 實作) 具有單一隱藏層,後接 softmax 層。

def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

為了在 TFF 中使用此模型,請將 Keras 模型包裝為 tff.learning.models.FunctionalModel。這可讓您在 TFF 內執行模型的 正向傳遞,並擷取模型輸出。如需更多詳細資訊,另請參閱圖片分類教學課程。

keras_model = create_keras_model()
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=federated_train_data[0].element_spec,
    metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy
    ),
)

雖然以上使用 tf.keras 建立 tff.learning.models.FunctionalModel,但 TFF 支援更通用的模型。這些模型具有以下相關屬性,可擷取模型權重

  • trainable_variables:對應於可訓練層的張量可迭代物件。
  • non_trainable_variables:對應於不可訓練層的張量可迭代物件。

在本教學課程中,僅會使用 trainable_variables (因為模型僅具有這些!)。

建構您自己的聯邦學習演算法

雖然 tff.learning API 可讓您建立聯邦平均的許多變體,但還有其他聯邦演算法無法完全符合此架構。例如,您可能想要新增正規化、剪輯或更複雜的演算法,例如 聯邦 GAN 訓練。您也可能對 聯邦分析感興趣。

對於這些更進階的演算法,您必須使用 TFF 撰寫自己的自訂演算法。在許多情況下,聯邦演算法具有 4 個主要元件

  1. 伺服器到用戶端廣播步驟。
  2. 本機用戶端更新步驟。
  3. 用戶端到伺服器上傳步驟。
  4. 伺服器更新步驟。

在 TFF 中,聯邦演算法通常表示為 tff.templates.IterativeProcess (在本文中將簡稱為 IterativeProcess)。這是一個包含 initializenext 函式的類別。在此,initialize 用於初始化伺服器,而 next 將執行聯邦演算法的一個通訊回合。讓我們撰寫 FedAvg 的迭代程序骨架應有的樣子。

首先,有一個初始化函式,只需建立 tff.learning.models.FunctionalModel,並傳回其可訓練權重。

def initialize_fn():
  trainable_weights, _ =  tff_model.initial_weights
  return trainable_weights

此函式看起來不錯,但稍後您會看到,您需要稍微修改才能使其成為「TFF 運算」。

接下來,讓我們撰寫 next_fn 的草圖。

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

讓我們專注於分別實作這四個元件。首先,讓我們專注於可以在純 TensorFlow 中實作的部分,即用戶端和伺服器更新步驟。

TensorFlow 區塊

用戶端更新

tff.learning.models.FunctionalModel 可用於執行用戶端訓練,方式基本上與訓練 TensorFlow 模型的方式相同。特別是,可以使用 tf.GradientTape 計算資料批次的梯度,然後使用 client_optimizer 套用這些梯度。這只會涉及可訓練權重。

@tf.function
def client_update(model, dataset, initial_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights and the optimizer
  # state.
  client_weights = initial_weights.trainable
  optimizer_state = client_optimizer.initialize(
      tf.nest.map_structure(tf.TensorSpec.from_tensor, client_weights)
  )

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    x, y = batch
    with tf.GradientTape() as tape:
      tape.watch(client_weights)
      # Compute a forward pass on the batch of data
      outputs = model.predict_on_batch(
          model_weights=(client_weights, ()), x=x, training=True
      )
      loss = model.loss(output=outputs, label=y)

    # Compute the corresponding gradient
    grads = tape.gradient(loss, client_weights)

    # Apply the gradient using a client optimizer.
    optimizer_state, client_weights = client_optimizer.next(
        optimizer_state, weights=client_weights, gradients=grads
    )

  return tff.learning.models.ModelWeights(client_weights, non_trainable=())

伺服器更新

FedAvg 的伺服器更新比用戶端更新更簡單。本教學課程將實作「純粹」聯邦平均,其中伺服器模型權重會由用戶端模型權重的平均值取代。同樣地,這僅使用可訓練權重。

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  del model  # Unused, just take the mean_client_weights.
  return mean_client_weights

您可以透過簡單地傳回 mean_client_weights 來簡化程式碼片段。但是,聯邦平均的更進階實作會將 mean_client_weights 與更複雜的技術搭配使用,例如動量或適應性。

挑戰:實作 server_update 的版本,將伺服器權重更新為 model_weights 和 mean_client_weights 的中點。(注意:這種「中點」方法類似於最近關於 Lookahead Optimizer 的研究!)。

到目前為止,這僅涉及 TensorFlow 程式碼。這是經過設計的,因為 TFF 可讓您使用許多您已熟悉的 TensorFlow 程式碼。接下來,您必須指定協調邏輯,也就是指示伺服器向用戶端廣播什麼,以及用戶端上傳到伺服器的內容的邏輯。

這將需要 TFF 的聯邦核心

聯邦核心簡介

聯邦核心 (FC) 是一組較低層級的介面,可作為 tff.learning API 的基礎。但是,這些介面不限於學習。事實上,它們可用於分析和許多其他分散式資料的運算。

在高層次上,聯邦核心是一種開發環境,可讓簡潔表達的程式邏輯將 TensorFlow 程式碼與分散式通訊運算子 (例如分散式總和和廣播) 結合。目標是讓研究人員和從業人員能夠明確控制其系統中的分散式通訊,而無需系統實作詳細資訊 (例如指定點對點網路訊息交換)。

一個重點是 TFF 專為保護隱私而設計。因此,它允許明確控制資料的駐留位置,以防止資料在集中式伺服器位置不必要地累積。

聯邦資料

TFF 中的一個重要概念是「聯邦資料」,它指的是分散式系統中跨裝置群組 (例如用戶端資料集或伺服器模型權重) 主控的資料項目集合。跨所有裝置的整個值集合都表示為單一聯邦值

例如,假設有用戶端裝置各自具有代表感測器溫度的浮點數。這些浮點數可以使用聯邦浮點數表示

federated_float_on_clients = tff.FederatedType(np.float32, tff.CLIENTS)

聯邦類型由其成員成分的類型 T (例如 np.float32) 和裝置群組 G 指定。通常,Gtff.CLIENTStff.SERVER。此類聯邦類型表示為 {T}@G,如下所示。

str(federated_float_on_clients)
'{float32}@CLIENTS'

為什麼 TFF 如此關心放置?TFF 的一個主要目標是能夠撰寫可部署在真實分散式系統上的程式碼。這表示必須推論哪些裝置子集執行哪些程式碼,以及不同資料片段的駐留位置。

TFF 專注於三件事:資料、資料的放置位置,以及資料的轉換方式。前兩者封裝在聯邦類型中,而後者封裝在聯邦運算中。

聯邦運算

TFF 是一個強型別函數式程式設計環境,其基本單位是聯邦運算。這些是接受聯邦值作為輸入,並傳回聯邦值作為輸出的邏輯片段。

例如,假設您想要平均用戶端感測器上的溫度。您可以定義以下內容 (使用我們的聯邦浮點數)

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

您可能會問,這與 TensorFlow 中的 tf.function 裝飾器有何不同?關鍵答案是 tff.federated_computation 產生的程式碼既不是 TensorFlow 也不是 Python 程式碼;它是在內部平台獨立黏合語言中分散式系統的規格。

雖然這聽起來可能很複雜,但您可以將 TFF 運算視為具有明確定義類型簽章的函式。這些類型簽章可以直接查詢。

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

tff.federated_computation 接受聯邦類型 {float32}@CLIENTS 的引數,並傳回聯邦類型 {float32}@SERVER 的值。聯邦運算也可能從伺服器到用戶端、從用戶端到用戶端或從伺服器到伺服器。聯邦運算也可以像一般函式一樣組合,只要其類型簽章匹配即可。

為了支援開發,TFF 可讓您以 Python 函式形式叫用 tff.federated_computation。例如,您可以呼叫

get_average_temperature([68.5, 70.3, 69.8])
69.53333

非急切運算和 TensorFlow

有兩個需要注意的關鍵限制。首先,當 Python 直譯器遇到 tff.federated_computation 裝飾器時,函式會追蹤一次並序列化以供日後使用。由於聯邦學習的分散式特性,此日後使用可能會在其他地方發生,例如遠端執行環境。因此,TFF 運算從根本上來說是非急切的。此行為在某種程度上類似於 TensorFlow 中 tf.function 裝飾器的行為。

其次,聯邦運算只能包含聯邦運算子 (例如 tff.federated_mean),它們不能包含 TensorFlow 運算。TensorFlow 程式碼必須限制在以 tff.tensorflow.computation 裝飾的區塊中。大多數普通的 TensorFlow 程式碼都可以直接裝飾,例如以下函式,它接受一個數字並將 0.5 新增至其中。

@tff.tensorflow.computation(np.float32)
def add_half(x):
  return tf.add(x, 0.5)

這些也有類型簽章,但沒有放置。例如,您可以呼叫

str(add_half.type_signature)
'(float32 -> float32)'

這展示了 tff.federated_computationtff.tensorflow.computation 之間的重要差異。前者具有明確的放置,而後者則沒有。

您可以使用 tff.tensorflow.computation 區塊在聯邦運算中指定放置。讓我們建立一個新增一半的函式,但僅適用於用戶端聯邦浮點數。您可以透過使用 tff.federated_map 來執行此操作,它會套用給定的 tff.tensorflow.computation,同時保留放置。

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

此函式幾乎與 add_half 相同,只是它僅接受放置在 tff.CLIENTS 的值,並傳回具有相同放置的值。這可以在其類型簽章中看到

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

總結

  • TFF 在聯邦值上運作。
  • 每個聯邦值都有一個聯邦類型,其中包含類型 (例如 np.float32) 和放置 (例如 tff.CLIENTS)。
  • 聯邦值可以使用聯邦運算轉換,聯邦運算必須以 tff.federated_computation 和聯邦類型簽章裝飾。
  • TensorFlow 程式碼必須包含在具有 tff.tensorflow.computation 裝飾器的區塊中。
  • 然後,這些區塊可以併入聯邦運算中。

重新探討建構您自己的聯邦學習演算法

現在您已大致瞭解聯邦核心,您可以建構自己的聯邦學習演算法。請記住,在上方,您為我們的演算法定義了 initialize_fnnext_fnnext_fn 將使用您使用純 TensorFlow 程式碼定義的 client_updateserver_update

但是,為了使我們的演算法成為聯邦運算,您需要 next_fninitialize_fn 都是 tff.federated_computation

TensorFlow Federated 區塊

建立初始化運算

初始化函式將非常簡單:您將使用 model_fn 建立模型。但是,請記住,您必須使用 tff.tensorflow.computation 分隔我們的 TensorFlow 程式碼。

@tff.tensorflow.computation
def server_init():
  return tff.learning.models.ModelWeights(*tff_model.initial_weights)

然後,您可以使用 tff.federated_value 將其直接傳遞至聯邦運算。

@tff.federated_computation
def initialize_fn():
  return tff.federated_eval(server_init, tff.SERVER)

建立 next_fn

用戶端和伺服器更新程式碼現在可用於撰寫實際的演算法。首先,您會將 client_update 轉換為 tff.tensorflow.computation,它接受用戶端資料集和伺服器權重,並輸出更新的用戶端權重張量。

您需要對應的類型才能正確裝飾我們的函式。幸運的是,伺服器權重的類型可以直接從我們的模型中擷取。

tf_dataset_type = tff.SequenceType(
    tff.types.tensorflow_to_type(tff_model.input_spec)
)

讓我們看看資料集類型簽章。請記住,您採用了 28 x 28 圖片 (帶有整數標籤) 並將其扁平化。

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

您也可以使用我們上方的 server_init 函式擷取模型權重類型。

model_weights_type = server_init.type_signature.result

檢查類型簽章,您將能夠看到我們模型的架構!

str(model_weights_type)
'<trainable=<float32[784,10],float32[10]>,non_trainable=<>>'

您現在可以為用戶端更新建立我們的 tff.tensorflow.computation

@tff.tensorflow.computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
  return client_update(tff_model, tf_dataset, server_weights, client_optimizer)

伺服器更新的 tff.tensorflow.computation 版本可以使用您已擷取的類型以類似的方式定義。

@tff.tensorflow.computation(model_weights_type)
def server_update_fn(mean_client_weights):
  return server_update(tff_model, mean_client_weights)

最後但並非最不重要的一點是,您需要建立將所有內容整合在一起的 tff.federated_computation。此函式將接受兩個聯邦值,一個對應於伺服器權重 (放置為 tff.SERVER),另一個對應於用戶端資料集 (放置為 tff.CLIENTS)。

請注意,以上定義了這兩種類型!您只需要使用 tff.FederatedType 為它們提供適當的放置即可。

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

還記得 FL 演算法的 4 個要素嗎?

  1. 伺服器到用戶端廣播步驟。
  2. 本機用戶端更新步驟。
  3. 用戶端到伺服器上傳步驟。
  4. 伺服器更新步驟。

現在您已建立上述內容,每個部分都可以簡潔地表示為單行 TFF 程式碼。這種簡單性就是為什麼您必須格外小心地指定聯邦類型等事項!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client)
  )

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

您現在有一個 tff.federated_computation,適用於演算法初始化和執行演算法的一個步驟。若要完成我們的演算法,您可以將這些傳遞至 tff.templates.IterativeProcess

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

讓我們看看迭代程序的 initializenext 函式的類型簽章

str(federated_algorithm.initialize.type_signature)
'( -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'

這反映了 federated_algorithm.initialize 是一個無引數函式,它傳回單層模型 (具有 784 x 10 權重矩陣和 10 個偏差單位)。

str(federated_algorithm.next.type_signature)
'(<server_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'

在此,您可以看到 federated_algorithm.next 接受伺服器模型和用戶端資料,並傳回更新的伺服器模型。

評估演算法

讓我們執行幾個回合,看看損失如何變化。首先,您將使用第二個教學課程中討論的集中式方法定義評估函式。

您將首先建立集中式評估資料集,然後套用您用於訓練資料的相同預先處理。

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

接下來,您將撰寫一個接受伺服器狀態的函式,並使用 Keras 在測試資料集上進行評估。如果您熟悉 tf.Keras,這一切看起來都很熟悉,但請注意 set_weights 的使用!

def evaluate(model_weights):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  model_weights.assign_weights_to(keras_model)
  keras_model.evaluate(central_emnist_test)

現在,讓我們初始化我們的演算法並在測試集上評估。

server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 26s 10ms/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

讓我們訓練幾個回合,看看是否有任何變化。

for _ in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 4s 1ms/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

損失函數略有下降。雖然跳躍很小,但您只執行了 15 個訓練回合,並且是在用戶端的小子集上進行的。若要看到更好的結果,您可能必須執行數百次 (如果不是數千次) 回合。

修改我們的演算法

此時,讓我們停下來思考一下您已完成的工作。您已透過將純 TensorFlow 程式碼 (用於用戶端和伺服器更新) 與 TFF 聯邦核心的聯邦運算結合,直接實作了聯邦平均。

若要執行更複雜的學習,您可以簡單地變更您在上方所做的內容。特別是,透過編輯上方的純 TF 程式碼,您可以變更用戶端執行訓練的方式,或伺服器更新其模型的方式。

挑戰:梯度裁剪新增至 client_update 函式。

如果您想要進行更大幅度的變更,也可以讓伺服器儲存及廣播更多資料。例如,伺服器也可以儲存用戶端學習速率,並使其隨時間衰減!請注意,這會需要變更上方 tff.tensorflow.computation 呼叫中使用的類型簽章。

更困難的挑戰: 實作在用戶端上具有學習速率衰減的 Federated Averaging。

此時,您可能會開始意識到在這個架構中您可以實作的彈性有多大。如需想法 (包括上方更困難挑戰的解答),您可以查看 tff.learning.algorithms.build_weighted_fed_avg 的原始碼,或查看使用 TFF 的各種研究專案