自訂聯邦演算法,第 2 部分:實作聯邦平均

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

本教學課程是分為兩部分的系列教學課程的第二部分,示範如何在 TFF 中使用聯邦核心 (FC) 實作自訂類型的聯邦演算法,聯邦核心是聯邦式學習 (FL) 層 (tff.learning) 的基礎。

我們建議您先閱讀本系列教學課程的第一部分,其中介紹了此處使用的一些重要概念和程式設計抽象概念。

本系列教學課程的第二部分使用第一部分中介紹的機制,實作聯邦式訓練和評估演算法的簡單版本。

我們建議您檢閱影像分類文字產生教學課程,以更深入淺出且更溫和的方式瞭解 TFF 聯邦式學習 API,因為這些教學課程將協助您將我們在此處描述的概念置於脈絡中。

開始之前

在開始之前,請嘗試執行下列「Hello World」範例,以確保您的環境已正確設定。如果無法運作,請參閱安裝指南以取得操作說明。

pip install --quiet --upgrade tensorflow-federated
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

實作聯邦平均

如同用於影像分類的聯邦式學習,我們將使用 MNIST 範例,但由於本教學課程旨在說明低階概念,因此我們將略過 Keras API 和 tff.simulation,改為編寫原始模型程式碼,並從頭開始建構聯邦式資料集。

準備聯邦式資料集

為了示範,我們將模擬一個情境,其中我們有來自 10 位使用者的資料,而且每位使用者都貢獻了如何辨識不同數字的知識。這幾乎是最不符合獨立且恆等分佈 (i.i.d.) 的情境。

首先,讓我們載入標準 MNIST 資料

mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
[(x.dtype, x.shape) for x in mnist_train]
[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

資料以 Numpy 陣列的形式提供,一個包含圖片,另一個包含數字標籤,兩者的第一個維度都涵蓋個別範例。讓我們編寫一個協助函式,以與我們將聯邦序列饋送到 TFF 運算的方式相容的方式格式化資料,也就是以清單的清單形式 - 外層清單涵蓋使用者 (數字),內層清單涵蓋每個用戶端序列中的資料批次。按照慣例,我們會將每個批次建構為一對名為 xy 的張量,每個張量都有前導批次維度。同時,我們也會將每張圖片展平成為 784 個元素的向量,並將其中的像素重新調整到 0..1 範圍,這樣我們就不必用資料轉換來混淆模型邏輯。

NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100


def get_data_for_digit(source, digit):
  output_sequence = []
  all_samples = [i for i, d in enumerate(source[1]) if d == digit]
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):
    batch_samples = all_samples[i:i + BATCH_SIZE]
    output_sequence.append({
        'x':
            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
                     dtype=np.float32),
        'y':
            np.array([source[1][i] for i in batch_samples], dtype=np.int32)
    })
  return output_sequence


federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]

federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]

作為快速健全性檢查,讓我們查看第五個用戶端 (對應於數字 5 的用戶端) 貢獻的最後一批資料中的 Y 張量。

federated_train_data[5][-1]['y']
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)

為了確保萬無一失,我們也來看看與該批次最後一個元素對應的圖片。

from matplotlib import pyplot as plt

plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()

png

關於結合 TensorFlow 和 TFF

在本教學課程中,為了簡潔起見,我們立即使用 tff.tensorflow.computation 修飾詞裝飾引入 TensorFlow 邏輯的函式。不過,對於更複雜的邏輯,我們不建議採用這種模式。偵錯 TensorFlow 可能已經是一項挑戰,而在 TensorFlow 完全序列化然後重新匯入後再進行偵錯,必然會遺失一些中繼資料並限制互動性,使偵錯更具挑戰性。

因此,我們強烈建議將複雜的 TF 邏輯編寫為獨立的 Python 函式 (也就是沒有 tff.tensorflow.computation 修飾詞)。這樣一來,TensorFlow 邏輯就可以在使用 TF 最佳做法和工具 (例如,熱切模式) 的情況下開發和測試,然後再為 TFF 序列化運算 (例如,透過以 Python 函式作為引數叫用 tff.tensorflow.computation)。

定義損失函式

現在我們有了資料,讓我們定義一個可以用於訓練的損失函式。首先,讓我們將輸入類型定義為 TFF 具名元組。由於資料批次的大小可能不同,因此我們將批次維度設定為 None,以表示此維度的大小未知。

BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.types.tensorflow_to_type(BATCH_SPEC)

str(BATCH_TYPE)
'<x=float32[?,784],y=int32[?]>'

您可能想知道為什麼我們不能只定義一個普通的 Python 類型。回想一下第 1 部分中的討論,我們在其中解釋說,雖然我們可以使用 Python 表示 TFF 運算的邏輯,但在底層,TFF 運算不是 Python。BATCH_TYPE 符號 (如上所定義) 代表抽象 TFF 類型規格。務必區分此抽象 TFF 類型與具體的 Python 表示類型,例如,可用於表示 Python 函式主體中 TFF 類型的容器 (例如 dictcollections.namedtuple)。與 Python 不同,TFF 具有單一抽象類型建構函式 tff.StructType 用於類似元組的容器,其元素可以個別命名或保持未命名。此類型也用於對運算的正式參數建模,因為 TFF 運算正式上只能宣告一個參數和一個結果 - 您很快就會看到範例。

現在,讓我們將模型參數的 TFF 類型定義為 TFF 具名元組,同樣包含權重偏差

MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.types.tensorflow_to_type(MODEL_SPEC)
print(MODEL_TYPE)
<weights=float32[784,10],bias=float32[10]>

有了這些定義,我們現在可以定義給定模型的損失,針對單一批次。請注意 @tff.tensorflow.computation 修飾詞內部的 @tf.function 修飾詞的用法。這讓我們可以使用類似 Python 的語意編寫 TF,即使我們位於 tf.Graph 環境中,而該環境是由 tff.tensorflow.computation 修飾詞所建立。

# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can 
# be later called from within another tf.function. Necessary because a
# @tf.function  decorated method cannot invoke a @tff.tensorflow.computation.

@tf.function
def forward_pass(model, batch):
  predicted_y = tf.nn.softmax(
      tf.matmul(batch['x'], model['weights']) + model['bias'])
  return -tf.reduce_mean(
      tf.reduce_sum(
          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))

@tff.tensorflow.computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  return forward_pass(model, batch)

如同預期,給定模型和單一資料批次,運算 batch_loss 會傳回 float32 損失。請注意 MODEL_TYPEBATCH_TYPE 是如何被合併成正式參數的 2 元組;您可以將 batch_loss 的類型識別為 (<MODEL_TYPE,BATCH_TYPE> -> float32)

str(batch_loss.type_signature)
'(<model=<weights=float32[784,10],bias=float32[10]>,batch=<x=float32[?,784],y=int32[?]>> -> float32)'

作為健全性檢查,讓我們建構一個以零填滿的初始模型,並計算我們在上面視覺化的資料批次的損失。

initial_model = collections.OrderedDict(
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))

sample_batch = federated_train_data[5][-1]

batch_loss(initial_model, sample_batch)
2.3025851

請注意,我們使用定義為 dict 的初始模型饋送 TFF 運算,即使定義它的 Python 函式的主體將模型參數用作 model['weight']model['bias']batch_loss 呼叫的引數不會簡單地傳遞到該函式的主體。

當我們叫用 batch_loss 時會發生什麼事?batch_loss 的 Python 主體已在定義它的上述儲存格中追蹤和序列化。TFF 在運算定義時間充當 batch_loss 的呼叫者,並在叫用 batch_loss 時充當叫用的目標。在這兩個角色中,TFF 都充當 TFF 抽象類型系統和 Python 表示類型之間的橋樑。在叫用時間,TFF 將接受大多數標準 Python 容器類型 (dictlisttuplecollections.namedtuple 等) 作為抽象 TFF 元組的具體表示。此外,雖然如上所述,TFF 運算正式上只接受單一參數,但如果參數的類型是元組,您可以使用熟悉的 Python 呼叫語法搭配位置和/或關鍵字引數 - 它會如預期般運作。

單一批次的梯度下降

現在,讓我們定義一個使用此損失函式執行單一步梯度下降的運算。請注意在定義此函式時,我們如何使用 batch_loss 作為子元件。您可以在另一個運算的主體內叫用使用 tff.tensorflow.computation 建構的運算,儘管通常這不是必要的 - 如上所述,由於序列化會遺失一些偵錯資訊,因此對於更複雜的運算,通常最好在沒有 tff.tensorflow.computation 修飾詞的情況下編寫和測試所有 TensorFlow。

@tff.tensorflow.computation(MODEL_TYPE, BATCH_TYPE, np.float32)
def batch_train(initial_model, batch, learning_rate):
  # Define a group of model variables and set them to `initial_model`. Must
  # be defined outside the @tf.function.
  model_vars = collections.OrderedDict([
      (name, tf.Variable(name=name, initial_value=value))
      for name, value in initial_model.items()
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def _train_on_batch(model_vars, batch):
    # Perform one step of gradient descent using loss from `batch_loss`.
    with tf.GradientTape() as tape:
      loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)
    optimizer.apply_gradients(
        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
    return model_vars

  return _train_on_batch(model_vars, batch)
str(batch_train.type_signature)
'(<initial_model=<weights=float32[784,10],bias=float32[10]>,batch=<x=float32[?,784],y=int32[?]>,learning_rate=float32> -> <weights=float32[784,10],bias=float32[10]>)'

當您在另一個此類函式的主體內叫用使用 tff.tensorflow.computation 修飾詞裝飾的 Python 函式時,內部 TFF 運算的邏輯會嵌入 (基本上是內嵌) 在外部運算的邏輯中。如上所述,如果您要編寫兩個運算,則內部函式 (在本例中為 batch_loss) 最好是常規 Python 或 tf.function,而不是 tff.tensorflow.computation。但是,在這裡我們說明在另一個 tff.tensorflow.computation 內部呼叫一個 tff.tensorflow.computation 基本上會如預期般運作。舉例來說,如果您沒有定義 batch_loss 的 Python 程式碼,而只有其序列化的 TFF 表示,則這可能是必要的。

現在,讓我們將此函式套用幾次到初始模型,看看損失是否會減少。

model = initial_model
losses = []
for _ in range(5):
  model = batch_train(model, sample_batch, 0.1)
  losses.append(batch_loss(model, sample_batch))
losses
[0.19690025, 0.13176318, 0.101132266, 0.08273812, 0.0703014]

一連串本機資料的梯度下降

現在,由於 batch_train 看起來可以運作,讓我們編寫一個類似的訓練函式 local_train,其會使用來自一個使用者的所有批次的完整序列,而不是僅使用單一批次。新的運算現在需要使用 tff.SequenceType(BATCH_TYPE) 而不是 BATCH_TYPE

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, np.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):

  # Reduction function to apply to each batch.
  @tff.federated_computation((MODEL_TYPE, np.float32), BATCH_TYPE)
  def batch_fn(model_with_lr, batch):
    model, lr = model_with_lr
    return batch_train(model, batch, lr), lr

  trained_model, _ = tff.sequence_reduce(
      all_batches, (initial_model, learning_rate), batch_fn
  )
  return trained_model
str(local_train.type_signature)
'(<initial_model=<weights=float32[784,10],bias=float32[10]>,learning_rate=float32,all_batches=<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'

此簡短程式碼區段中隱藏了許多細節,讓我們逐一檢視。

首先,雖然我們可以完全在 TensorFlow 中實作此邏輯,依靠 tf.data.Dataset.reduce 以類似於我們先前做法的方式處理序列,但這次我們選擇以膠合語言表示邏輯,作為 tff.federated_computation。我們使用了聯邦運算子 tff.sequence_reduce 來執行縮減。

運算子 tff.sequence_reduce 的使用方式與 tf.data.Dataset.reduce 類似。您可以將其視為與 tf.data.Dataset.reduce 基本上相同,但用於聯邦運算內部,您可能還記得,聯邦運算不能包含 TensorFlow 程式碼。它是一個範本運算子,具有正式參數 3 元組,其中包含 T 類型元素的序列、縮減的初始狀態 (我們將抽象地將其稱為) 的某些類型 U,以及類型為 (<U,T> -> U)縮減運算子,其透過處理單一元素來變更縮減的狀態。結果是縮減的最終狀態,在按順序處理序列中的所有元素之後。在我們的範例中,縮減的狀態是在資料前置字串上訓練的模型,而元素是資料批次。

其次,請注意,我們再次使用一個運算 (batch_train) 作為另一個運算 (local_train) 內的元件,但不是直接使用。我們無法將其用作縮減運算子,因為它採用額外的參數 - 學習率。為了解決這個問題,我們定義了一個內嵌的聯邦運算 batch_fn,其在其主體中繫結到 local_train 的參數 learning_rate。允許以此方式定義的子運算擷取其父項的正式參數,前提是子運算未在其父項的主體之外叫用。您可以將此模式視為 Python 中 functools.partial 的等效項。

當然,以這種方式擷取 learning_rate 的實際含義是,相同的學習率值會用於所有批次。

現在,讓我們在與貢獻範例批次 (數字 5) 的使用者相同的資料序列上試用新定義的本機訓練函式。

locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])

它運作了嗎?為了回答這個問題,我們需要實作評估。

本機評估

以下是一種實作本機評估的方法,方法是將所有資料批次的損失加總 (我們也可以計算平均值;我們會將其留給讀者作為練習)。

@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):

  @tff.tensorflow.computation((MODEL_TYPE, np.float32), BATCH_TYPE)
  def accumulate_evaluation(model_and_accumulator, batch):
    model, accumulator = model_and_accumulator
    return model, accumulator + batch_loss(model, batch)

  _, total_loss = tff.sequence_reduce(
      all_batches, (model, 0.0), accumulate_evaluation
  )
  return total_loss
str(local_eval.type_signature)
'(<model=<weights=float32[784,10],bias=float32[10]>,all_batches=<x=float32[?,784],y=int32[?]>*> -> float32)'

同樣地,此程式碼說明了一些新元素,讓我們逐一檢視。

首先,我們使用了兩個新的聯邦運算子來處理序列:tff.sequence_map,其採用對應函式 T->UT序列,並發出透過逐點套用對應函式取得的 U 序列,以及 tff.sequence_sum,其僅加總所有元素。在這裡,我們將每個資料批次對應到一個損失值,然後將產生的損失值相加,以計算總損失。

請注意,我們可以再次使用 tff.sequence_reduce,但這不是最佳選擇 - 縮減程序依定義是循序的,而對應和總和可以平行計算。當有選擇時,最好堅持使用不限制實作選擇的運算子,以便在未來將我們的 TFF 運算編譯以部署到特定環境時,可以充分利用所有可能的機會來實現更快、更可擴充、更有效率的資源執行。

其次,請注意,就像在 local_train 中一樣,我們需要的元件函式 (batch_loss) 採用的參數多於聯邦運算子 (tff.sequence_map) 預期的參數,因此我們再次定義部分,這次是透過直接將 lambda 包裝為 tff.federated_computation 來內嵌定義。使用包裝函式內嵌函式作為引數是使用 tff.tensorflow.computation 將 TensorFlow 邏輯嵌入 TFF 中的建議方式。

現在,讓我們看看我們的訓練是否有效。

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[5]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[5]))
initial_model loss = 23.025854
locally_trained_model loss = 0.43484688

確實,損失減少了。但是,如果我們在另一個使用者的資料上評估它會發生什麼事?

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[0]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[0]))
initial_model loss = 23.025854
locally_trained_model loss = 74.50075

如同預期,情況變得更糟。模型經過訓練以辨識 5,但從未見過 0。這引發了一個問題 - 從全域角度來看,本機訓練如何影響模型的品質?

聯邦式評估

這是我們旅程中的一個點,我們終於回到聯邦類型和聯邦運算 - 我們開始時的主題。以下是一對 TFF 類型定義,用於源自伺服器的模型,以及保留在用戶端上的資料。

SERVER_MODEL_TYPE = tff.FederatedType(MODEL_TYPE, tff.SERVER)
CLIENT_DATA_TYPE = tff.FederatedType(LOCAL_DATA_TYPE, tff.CLIENTS)

有了迄今為止介紹的所有定義,以 TFF 表示聯邦式評估只需一行程式碼 - 我們將模型分配給用戶端,讓每個用戶端在其本機資料部分上叫用本機評估,然後平均損失。以下是一種編寫此程式碼的方法。

@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model),  data]))

我們已經在更簡單的情境中看過 tff.federated_meantff.federated_map 的範例,並且在直覺層面上,它們如預期般運作,但此程式碼區段中還有更多值得注意的地方,因此讓我們仔細檢視。

首先,讓我們分解讓每個用戶端在其本機資料部分上叫用本機評估部分。您可能還記得前面的章節,local_eval 的類型簽名形式為 (<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32)

聯邦運算子 tff.federated_map 是一個範本,其接受作為參數的 2 元組,其中包含某些類型 T->U對應函式和類型 {T}@CLIENTS 的聯邦值 (也就是說,與對應函式的參數類型相同的成員組成部分),並傳回類型為 {U}@CLIENTS 的結果。

由於我們將 local_eval 作為對應函式饋送以在每個用戶端的基礎上套用,因此第二個引數應該是聯邦類型 {<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS,也就是說,在前述章節的命名法中,它應該是聯邦元組。每個用戶端都應該將 local_eval 的完整引數集作為成員組成部分。相反地,我們饋送給它一個 2 個元素的 Python list。這裡發生了什麼事?

確實,這是 TFF 中隱含類型轉換的一個範例,類似於您在其他地方可能遇到的隱含類型轉換,例如,當您將 int 饋送到接受 float 的函式時。目前隱含轉換的使用非常稀少,但我們計劃使其在 TFF 中更加普及,作為盡可能減少樣板程式碼的一種方式。

在此案例中套用的隱含轉換是形式為 {<X,Y>}@Z 的聯邦元組與聯邦值元組 <{X}@Z,{Y}@Z> 之間的等效性。雖然形式上,這兩個是不同的類型簽名,但從程式設計人員的角度來看,Z 中的每個裝置都持有兩個資料單位 XY。這裡發生的事情與 Python 中的 zip 沒有什麼不同,而且確實,我們提供了一個運算子 tff.federated_zip,可讓您明確地執行此類轉換。tff.federated_map 遇到元組作為第二個引數時,它只會為您叫用 tff.federated_zip

鑑於上述情況,您現在應該能夠將運算式 tff.federated_broadcast(model) 識別為表示 TFF 類型 {MODEL_TYPE}@CLIENTS 的值,以及 data 作為 TFF 類型 {LOCAL_DATA_TYPE}@CLIENTS (或簡稱 CLIENT_DATA_TYPE) 的值,兩者都透過隱含的 tff.federated_zip 一起篩選,以形成 tff.federated_map 的第二個引數。

正如您所預期的,運算子 tff.federated_broadcast 只是將資料從伺服器傳輸到用戶端。

現在,讓我們看看我們的本機訓練如何影響系統中的平均損失。

print('initial_model loss =', federated_eval(initial_model,
                                             federated_train_data))
print('locally_trained_model loss =',
      federated_eval(locally_trained_model, federated_train_data))
initial_model loss = 23.025852
locally_trained_model loss = 54.43263

確實,如同預期,損失增加了。為了改善所有使用者的模型,我們需要在所有人的資料上訓練模型。

聯邦式訓練

實作聯邦式訓練最簡單的方法是本機訓練,然後平均模型。這使用與我們已討論過的相同建構區塊和模式,如下所示。

SERVER_FLOAT_TYPE = tff.FederatedType(np.float32, tff.SERVER)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
  return tff.federated_mean(
      tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data
      ]))

請注意,在 tff.learning 提供的聯邦平均的完整功能實作中,我們傾向於平均模型差異,而不是平均模型,原因有很多,例如,能夠剪輯更新範數以進行壓縮等。

讓我們看看訓練是否有效,方法是執行幾輪訓練,並比較訓練前後的平均損失。

model = initial_model
learning_rate = 0.1
for round_num in range(5):
  model = federated_train(model, learning_rate, federated_train_data)
  learning_rate = learning_rate * 0.9
  loss = federated_eval(model, federated_train_data)
  print('round {}, loss={}'.format(round_num, loss))
round 0, loss=21.60552406311035
round 1, loss=20.365678787231445
round 2, loss=19.27480125427246
round 3, loss=18.31110954284668
round 4, loss=17.457256317138672

為了完整起見,現在我們也在測試資料上執行,以確認我們的模型能妥善泛化。

print('initial_model test loss =',
      federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))
initial_model test loss = 22.795593
trained_model test loss = 17.278767

本教學課程到此結束。

當然,我們的簡化範例並未反映在更實際的情境中需要執行的一些事項 - 例如,我們尚未計算損失以外的指標。我們建議您研究 tff.learning 中聯邦平均的實作,作為更完整的範例,並作為示範我們希望鼓勵的一些編碼實務的方法。