影像分類聯邦學習

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

在本教學中,我們使用經典的 MNIST 訓練範例來介紹 TFF 的聯邦學習 (FL) API 層,`tff.learning` - 一組更高等級的介面,可用於執行常見的聯邦學習任務類型,例如針對使用者提供的以 TensorFlow 實作的模型進行聯邦訓練。

本教學和聯邦學習 API 主要適用於想要將自己的 TensorFlow 模型插入 TFF 的使用者,將後者主要視為黑盒子。如需更深入瞭解 TFF 以及如何實作您自己的聯邦學習演算法,請參閱 FC Core API 的教學 - 自訂聯邦演算法第 1 部分第 2 部分

如需更多關於 `tff.learning` 的資訊,請繼續閱讀文字生成聯邦學習教學,除了涵蓋循環模型外,該教學也示範了如何載入預先訓練的序列化 Keras 模型,以便使用聯邦學習進行精煉,並結合使用 Keras 進行評估。

開始之前

開始之前,請執行以下程式碼,以確保您的環境設定正確。如果您沒有看到歡迎訊息,請參閱安裝指南以取得操作說明。

pip install --quiet --upgrade tensorflow-federated
%load_ext tensorboard
Fetching TensorBoard MPM version 'live'... done.
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

準備輸入資料

讓我們先從資料開始。聯邦學習需要聯邦資料集,即來自多個使用者的資料集合。聯邦資料通常是非 i.i.d.,這帶來了一系列獨特的挑戰。

為了方便實驗,我們在 TFF 儲存庫中加入了幾個資料集,包括 MNIST 的聯邦版本,其中包含使用 Leaf 重新處理的原始 NIST 資料集版本,以便資料按數字的原始書寫者進行鍵控。由於每位書寫者都有獨特的風格,因此這個資料集展現了聯邦資料集預期的非 i.i.d. 行為。

以下是如何載入它。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

`load_data()` 傳回的資料集是 `tff.simulation.ClientData` 的執行個體,這是一個介面,可讓您列舉使用者集合、建構代表特定使用者資料的 `tf.data.Dataset`,以及查詢個別元素的結構。以下說明如何使用此介面來探索資料集的內容。請記住,雖然此介面可讓您逐一查看用戶端 ID,但這只是模擬資料的功能。您很快就會看到,聯邦學習架構不會使用用戶端身分 - 它們的唯一目的是讓您選取資料子集以進行模擬。

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

探索聯邦資料中的異質性

聯邦資料通常是非 i.i.d.,使用者通常根據使用模式有不同的資料分佈。有些用戶端可能在裝置上的訓練範例較少,在本地端面臨資料匱乏,而有些用戶端則有足夠多的訓練範例。讓我們使用我們可用的 EMNIST 資料來探索聯邦系統典型的資料異質性概念。請務必注意,我們只能對用戶端資料進行這種深入分析,因為這是模擬環境,所有資料都可在本地端取得。在真實的生產聯邦環境中,您將無法檢查單一用戶端的資料。

首先,讓我們取樣一個用戶端的資料,以瞭解一個模擬裝置上的範例。由於我們使用的資料集已按獨特的書寫者進行鍵控,因此一個用戶端的資料代表一個人的筆跡,針對 0 到 9 的數字樣本,模擬一個使用者的獨特「使用模式」。

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

現在讓我們視覺化每個用戶端針對每個 MNIST 數字標籤的範例數量。在聯邦環境中,每個用戶端的範例數量可能會因使用者行為而異。

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

現在讓我們視覺化每個用戶端針對每個 MNIST 標籤的平均影像。此程式碼將產生一個標籤的所有使用者範例的每個像素值的平均值。我們會看到,一個用戶端針對一個數字的平均影像看起來會與另一個用戶端針對相同數字的平均影像不同,這是因為每個人的筆跡風格都不同。我們可以思考每個本地訓練回合如何在每個用戶端上將模型朝不同方向推進,因為我們正在從該使用者自己的獨特資料中學習。在本教學稍後部分,我們將看到如何從所有用戶端取得模型的每次更新,並將它們匯總到我們的新全域模型中,該模型已從我們每個用戶端自己的獨特資料中學習。

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

使用者資料可能會有雜訊且標籤不可靠。例如,查看上面用戶端 #2 的資料,我們可以發現對於標籤 2,可能有一些標籤錯誤的範例,導致平均影像雜訊較多。

預先處理輸入資料

由於資料已經是 `tf.data.Dataset`,因此可以使用 Dataset 轉換來完成預先處理。在這裡,我們將 28x28 影像展平為 784 元素陣列、隨機排列個別範例、將它們組織成批次,並將特徵從 `pixels` 和 `label` 重新命名為 `x` 和 `y`,以便與 Keras 搭配使用。我們還在資料集上加入了 `repeat`,以執行多個週期。

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

讓我們驗證這是否有效。

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]], dtype=int32))])

我們幾乎已準備好所有建構區塊,可以建構聯邦資料集。

在模擬中將聯邦資料饋送至 TFF 的其中一種方法是直接作為 Python 清單,清單的每個元素都包含個別使用者的資料,無論是清單還是 `tf.data.Dataset`。由於我們已經有一個提供後者的介面,因此讓我們使用它。

這是一個簡單的輔助函式,可從給定的使用者集合建構資料集清單,作為訓練或評估回合的輸入。

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

現在,我們如何選擇用戶端?

在典型的聯邦訓練情境中,我們處理的可能是非常龐大的使用者裝置群體,其中只有一小部分可能在給定的時間點可用於訓練。例如,當用戶端裝置是手機時,只有在插入電源、關閉計量網路且處於閒置狀態時才會參與訓練,情況就是如此。

當然,我們處於模擬環境中,所有資料都可在本地端取得。因此,通常在執行模擬時,我們會簡單地取樣要參與每個訓練回合的用戶端隨機子集,通常每個回合都不同。

也就是說,正如您可以透過研究關於聯邦平均演算法的論文所發現的那樣,在每個回合中隨機取樣用戶端子集的系統中實現收斂可能需要一段時間,並且在本互動式教學中執行數百個回合是不切實際的。

我們將改為對用戶端集合取樣一次,並在各回合中重複使用相同的集合,以加速收斂(有意過度擬合這些少數使用者的資料)。我們將修改本教學以模擬隨機取樣作為讀者的練習 - 這相當容易做到(一旦您這樣做,請記住讓模型收斂可能需要一段時間)。

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')
Number of client datasets: 10
First dataset: <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>

使用 Keras 建立模型

如果您使用 Keras,您可能已經有程式碼可以建構 Keras 模型。以下是一個簡單模型的範例,足以滿足我們的需求。

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

為了在 TFF 中使用任何模型,需要將其包裝在 `tff.learning.models.VariableModel` 介面的執行個體中,該介面公開了用於標記模型前向傳遞、中繼資料屬性等的方法,類似於 Keras,但也引入了其他元素,例如控制計算聯邦指標的程序的方法。我們先不用擔心這個;如果您有一個像我們剛才在上面定義的 Keras 模型,您可以透過呼叫 `tff.learning.models.from_keras_model`,並將模型和範例資料批次作為引數傳遞,讓 TFF 為您包裝它,如下所示。

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.models.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

在聯邦資料上訓練模型

現在我們有一個模型包裝為 `tff.learning.models.VariableModel` 以便與 TFF 搭配使用,我們可以讓 TFF 透過呼叫輔助函式 `tff.learning.algorithms.build_weighted_fed_avg` 來建構聯邦平均演算法,如下所示。

請記住,引數需要是建構函式(例如上面的 `model_fn`),而不是已建構的執行個體,以便您的模型建構可以在 TFF 控制的環境中發生(如果您對此原因感到好奇,我們鼓勵您閱讀關於自訂演算法的後續教學)。

關於以下聯邦平均演算法的一個重要注意事項是,有 **2** 個最佳化器:_client_optimizer_ 和 _server_optimizer_。_client_optimizer_ 僅用於計算每個用戶端上的本地模型更新。_server_optimizer_ 將平均更新套用至伺服器上的全域模型。特別是,這表示所使用的最佳化器和學習率的選擇可能需要與您用於在標準 i.i.d. 資料集上訓練模型的最佳化器和學習率不同。我們建議從常規 SGD 開始,學習率可能比平常小。我們使用的學習率尚未經過仔細調整,請隨意實驗。

training_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

剛剛發生了什麼事?TFF 建構了一對聯邦運算,並將它們封裝到 `tff.templates.IterativeProcess` 中,其中這些運算以一對屬性 `initialize` 和 `next` 的形式提供。

簡而言之,聯邦運算是 TFF 內部語言的程式,可以表達各種聯邦演算法(您可以在自訂演算法教學中找到更多相關資訊)。在本例中,產生並封裝到 `iterative_process` 中的兩個運算實作了聯邦平均。

TFF 的目標是以可以在真實聯邦學習設定中執行的方式定義運算,但目前僅實作了本地執行模擬執行階段。若要在模擬器中執行運算,您只需像 Python 函式一樣叫用它即可。此預設解譯環境並非針對高效能而設計,但它足以用於本教學;我們預期在未來的版本中提供更高效能的模擬執行階段,以促進更大規模的研究。

讓我們從 `initialize` 運算開始。與所有聯邦運算一樣,您可以將其視為函式。運算不接受任何引數,並傳回一個結果 - 伺服器上聯邦平均程序狀態的表示。雖然我們不想深入探討 TFF 的細節,但查看此狀態的外觀可能會有所啟發。您可以按如下所示視覺化它。

print(training_process.initialize.type_signature.formatted_representation())
( -> <
  global_model_weights=<
    trainable=<
      float32[784,10],
      float32[10]
    >,
    non_trainable=<>
  >,
  distributor=<>,
  client_work=<>,
  aggregator=<
    value_sum_process=<>,
    weight_sum_process=<>
  >,
  finalizer=<
    int64,
    float32[784,10],
    float32[10]
  >
>@SERVER)

雖然上面的類型簽章起初可能看起來有點神秘,但您可以辨識出伺服器狀態由 `global_model_weights`(將分發到所有裝置的 MNIST 初始模型參數)、一些空參數(例如 `distributor`,它控制伺服器到用戶端的通訊)和 `finalizer` 元件組成。最後一個控制伺服器在回合結束時用來更新其模型的邏輯,並包含一個整數,表示已發生多少回合的 FedAvg。

讓我們叫用 `initialize` 運算來建構伺服器狀態。

train_state = training_process.initialize()

這對聯邦運算中的第二個 `next` 代表單一回合的聯邦平均,其中包括將伺服器狀態(包括模型參數)推送至用戶端、在裝置上針對其本地資料進行訓練、收集和平均模型更新,以及在伺服器上產生新的更新模型。

從概念上講,您可以將 `next` 視為具有如下所示的功能類型簽章。

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

特別是,應該將 `next()` 視為不是在伺服器上執行的函式,而是整個分散式運算的宣告式功能表示 - 某些輸入由伺服器 (`SERVER_STATE`) 提供,但每個參與裝置都會貢獻其自己的本地資料集。

讓我們執行單一回合的訓練並視覺化結果。我們可以針對使用者範例使用我們已在上面產生的聯邦資料。

result = training_process.next(train_state, federated_train_data)
train_state = result.state
train_metrics = result.metrics
print('round  1, metrics={}'.format(train_metrics))
round  1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193733), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

讓我們再執行幾個回合。如先前所述,通常在此時,您會從每個回合新隨機選取的使用者範例中挑選模擬資料的子集,以便模擬使用者持續來來去去的真實部署,但在本互動式筆記本中,為了示範起見,我們只會重複使用相同的使用者,以便系統快速收斂。

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  result = training_process.next(train_state, federated_train_data)
  train_state = result.state
  train_metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, train_metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.14012346), ('loss', 2.9851403), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.1590535), ('loss', 2.8617127), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.17860082), ('loss', 2.7401376), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20102881), ('loss', 2.6186547), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.22345679), ('loss', 2.5006158), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.24794239), ('loss', 2.3858356), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.27160493), ('loss', 2.2757034), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.2958848), ('loss', 2.17098), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3251029), ('loss', 2.072707), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

訓練損失在每個聯邦訓練回合後都在減少,表示模型正在收斂。但是,這些訓練指標有一些重要的注意事項,請參閱本教學稍後的「評估」章節。

在 TensorBoard 中顯示模型指標

接下來,讓我們使用 TensorBoard 視覺化來自這些聯邦運算的指標。

讓我們從建立目錄和對應的摘要寫入器開始,以將指標寫入其中。

logdir = "/tmp/logs/scalars/training/"
try:
  tf.io.gfile.rmtree(logdir)  # delete any previous results
except tf.errors.NotFoundError as e:
  pass # Ignore if the directory didn't previously exist.
summary_writer = tf.summary.create_file_writer(logdir)
train_state = training_process.initialize()

使用相同的摘要寫入器繪製相關的純量指標。

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    result = training_process.next(train_state, federated_train_data)
    train_state = result.state
    train_metrics = result.metrics
    for name, value in train_metrics['client_work']['train'].items():
      tf.summary.scalar(name, value, step=round_num)

使用上面指定的根記錄目錄啟動 TensorBoard。資料載入可能需要幾秒鐘。

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
# Uncomment and run this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

為了以相同方式檢視評估指標,您可以建立一個單獨的評估資料夾,例如 "logs/scalars/eval",以寫入 TensorBoard。

自訂模型實作

Keras 是 TensorFlow 建議的高階模型 API,我們鼓勵在 TFF 中盡可能使用 Keras 模型(透過 `tff.learning.models.from_keras_model`)。

然而,`tff.learning` 提供了一個較低階的模型介面 `tff.learning.models.VariableModel`,它公開了使用模型進行聯邦學習所需的最低限度功能。直接實作此介面(可能仍然使用 `tf.keras.layers` 等建構區塊)允許最大程度的自訂,而無需修改聯邦學習演算法的內部結構。

因此,讓我們從頭開始重新做一遍。

定義模型變數、前向傳遞和指標

第一步是識別我們要使用的 TensorFlow 變數。為了使以下程式碼更易於閱讀,讓我們定義一個資料結構來表示整個集合。這將包括我們要訓練的變數,例如 `weights` 和 `bias`,以及將在訓練期間更新的各種累積統計資料和計數器的變數,例如 `loss_sum`、`accuracy_sum` 和 `num_examples`。

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

這是一個建立變數的方法。為了簡單起見,我們將所有統計資料表示為 `tf.float32`,因為這將消除稍後階段進行類型轉換的需要。將變數初始化器包裝為 lambda 是資源變數強加的要求。

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

有了模型參數和累積統計資料的變數,我們現在可以定義前向傳遞方法,該方法計算損失、發出預測,並更新單一批次輸入資料的累積統計資料,如下所示。

def predict_on_batch(variables, x):
  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)

def mnist_forward_pass(variables, batch):
  y = predict_on_batch(variables, batch['x'])
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

接下來,我們定義兩個與本地指標相關的函式,再次使用 TensorFlow。

第一個函式 `get_local_unfinalized_metrics` 傳回未最終化的指標值(除了自動處理的模型更新外),這些指標值有資格在聯邦學習或評估過程中匯總到伺服器。

def get_local_unfinalized_metrics(variables):
  return collections.OrderedDict(
      num_examples=[variables.num_examples],
      loss=[variables.loss_sum, variables.num_examples],
      accuracy=[variables.accuracy_sum, variables.num_examples])

第二個函式 `get_metric_finalizers` 傳回 `tf.function` 的 `OrderedDict`,其索引鍵(即指標名稱)與 `get_local_unfinalized_metrics` 相同。每個 `tf.function` 都會接收指標的未最終化值,並計算最終化的指標。

def get_metric_finalizers():
  return collections.OrderedDict(
      num_examples=tf.function(func=lambda x: x[0]),
      loss=tf.function(func=lambda x: x[0] / x[1]),
      accuracy=tf.function(func=lambda x: x[0] / x[1]))

`get_local_unfinalized_metrics` 傳回的本地未最終化指標如何在用戶端之間匯總,由定義聯邦學習或評估程序時的 `metrics_aggregator` 參數指定。例如,在 `tff.learning.algorithms.build_weighted_fed_avg` API(在下一節中顯示)中,`metrics_aggregator` 的預設值為 `tff.learning.metrics.sum_then_finalize`,它首先對來自 `CLIENTS` 的未最終化指標求和,然後在 `SERVER` 處套用指標最終化器。

建構 `tff.learning.models.VariableModel` 的執行個體

有了以上所有內容,我們就準備好建構一個模型表示,以便與 TFF 搭配使用,類似於當您讓 TFF 擷取 Keras 模型時為您產生的模型表示。

import collections
from collections.abc import Callable

class MnistModel(tff.learning.models.VariableModel):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def predict_on_batch(self, x, training=True):
    del training
    return predict_on_batch(self._variables, x)

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.models.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_unfinalized_metrics(
      self) -> collections.OrderedDict[str, list[tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to unfinalized values."""
    return get_local_unfinalized_metrics(self._variables)

  def metric_finalizers(
      self) -> collections.OrderedDict[str, Callable[[list[tf.Tensor]], tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to finalizers."""
    return get_metric_finalizers()

  @tf.function
  def reset_metrics(self):
    """Resets metrics variables to initial value."""
    for var in self.local_variables:
      var.assign(tf.zeros_like(var))

如您所見,`tff.learning.models.VariableModel` 定義的抽象方法和屬性對應於前面章節中介紹變數並定義損失和統計資料的程式碼片段。

以下是一些值得強調的重點:

  • 您的模型將使用的所有狀態都必須擷取為 TensorFlow 變數,因為 TFF 在執行階段不使用 Python(請記住,您的程式碼應該編寫成可以部署到行動裝置;請參閱自訂演算法教學,以取得關於原因的更深入評論)。
  • 您的模型應該描述它接受哪種形式的資料(`input_spec`),因為一般來說,TFF 是一個強型別環境,並且想要判斷所有元件的類型簽章。宣告模型輸入的格式是其中一個重要部分。
  • 雖然技術上不是必需的,但我們建議將所有 TensorFlow 邏輯(前向傳遞、指標計算等)包裝為 `tf.function`,因為這有助於確保 TensorFlow 可以序列化,並消除了對明確控制相依性的需求。

以上內容足以用於評估和聯邦 SGD 等演算法。但是,對於聯邦平均,我們需要指定模型應如何在每個批次上進行本地訓練。我們將在建構聯邦平均演算法時指定本地最佳化器。

使用新模型模擬聯邦訓練

有了以上所有內容,其餘程序看起來就像我們已經看到的 - 只需將模型建構函式替換為我們新模型類別的建構函式,並使用您在迭代程序中建立的兩個聯邦運算來循環執行訓練回合。

training_process = tff.learning.algorithms.build_weighted_fed_avg(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
train_state = training_process.initialize()
result = training_process.next(train_state, federated_train_data)
train_state = result.state
metrics = result.metrics
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.119374), ('accuracy', 0.12345679)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
for round_num in range(2, 11):
  result = training_process.next(train_state, federated_train_data)
  train_state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.98514), ('accuracy', 0.14012346)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.8617127), ('accuracy', 0.1590535)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.740137), ('accuracy', 0.17860082)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6186547), ('accuracy', 0.20102881)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5006158), ('accuracy', 0.22345679)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3858361), ('accuracy', 0.24794239)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.275704), ('accuracy', 0.27160493)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1709805), ('accuracy', 0.2958848)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.0727067), ('accuracy', 0.3251029)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

若要在 TensorBoard 中查看這些指標,請參閱上面「在 TensorBoard 中顯示模型指標」中列出的步驟。

評估

到目前為止,我們所有的實驗都只呈現了聯邦式訓練指標 - 在所有客戶端上訓練的所有批次資料的平均指標。這引發了對過度擬合的常見擔憂,特別是因為為了簡單起見,我們在每一輪都使用了相同的客戶端集合,但在聯邦平均演算法的訓練指標中,還有一個額外的過度擬合概念。如果我們想像每個客戶端都只有一批資料,並且我們在該批次上訓練多次迭代(epochs),這樣更容易理解。在這種情況下,本地模型將快速精確地擬合到該批次,因此我們平均的本地準確度指標將接近 1.0。因此,這些訓練指標可以被視為訓練正在進行的信號,但僅此而已。

若要對聯邦式資料執行評估,您可以建構另一個聯邦式計算,專門用於此目的,方法是使用 tff.learning.build_federated_evaluation 函數,並傳入您的模型建構函式作為參數。請注意,與我們使用 MnistTrainableModel 的聯邦平均不同,只需傳入 MnistModel 即可。評估不會執行梯度下降,也不需要建構優化器。

對於實驗和研究,當集中式測試資料集可用時,《用於文字生成的聯邦式學習》示範了另一個評估選項:從聯邦式學習中取得訓練好的權重,將它們應用於標準的 Keras 模型,然後簡單地在集中式資料集上呼叫 tf.keras.models.Model.evaluate()

evaluation_process = tff.learning.algorithms.build_fed_eval(MnistModel)

您可以如下檢查評估函數的抽象類型簽章。

print(evaluation_process.next.type_signature.formatted_representation())
(<
  state=<
    global_model_weights=<
      trainable=<
        float32[784,10],
        float32[10]
      >,
      non_trainable=<>
    >,
    distributor=<>,
    client_work=<
      <>,
      <
        num_examples=<
          float32
        >,
        loss=<
          float32,
          float32
        >,
        accuracy=<
          float32,
          float32
        >
      >
    >,
    aggregator=<
      value_sum_process=<>,
      weight_sum_process=<>
    >,
    finalizer=<>
  >@SERVER,
  client_data={<
    x=float32[?,784],
    y=int32[?,1]
  >*}@CLIENTS
> -> <
  state=<
    global_model_weights=<
      trainable=<
        float32[784,10],
        float32[10]
      >,
      non_trainable=<>
    >,
    distributor=<>,
    client_work=<
      <>,
      <
        num_examples=<
          float32
        >,
        loss=<
          float32,
          float32
        >,
        accuracy=<
          float32,
          float32
        >
      >
    >,
    aggregator=<
      value_sum_process=<>,
      weight_sum_process=<>
    >,
    finalizer=<>
  >@SERVER,
  metrics=<
    distributor=<>,
    client_work=<
      eval=<
        current_round_metrics=<
          num_examples=float32,
          loss=float32,
          accuracy=float32
        >,
        total_rounds_metrics=<
          num_examples=float32,
          loss=float32,
          accuracy=float32
        >
      >
    >,
    aggregator=<
      mean_value=<>,
      mean_weight=<>
    >,
    finalizer=<>
  >@SERVER
>)

請注意,評估過程是一個 tff.lenaring.templates.LearningProcess 物件。該物件有一個 initialize 方法,它將建立狀態,但這將首先包含一個未經訓練的模型。使用 set_model_weights 方法,必須插入來自訓練狀態的權重以進行評估。

evaluation_state = evaluation_process.initialize()
model_weights = training_process.get_model_weights(train_state)
evaluation_state = evaluation_process.set_model_weights(evaluation_state, model_weights)

現在,評估狀態包含要評估的模型權重,我們可以透過在過程上呼叫 next 方法,就像在訓練中一樣,使用評估資料集來計算評估指標。

這將再次傳回一個 tff.learning.templates.LearingProcessOutput 實例。

evaluation_output = evaluation_process.next(evaluation_state, federated_train_data)

以下是我們得到的結果。請注意,這些數字看起來比上面最後一輪訓練報告的結果略好。依照慣例,迭代訓練過程報告的訓練指標通常反映模型在訓練輪開始時的效能,因此評估指標將始終領先一步。

str(evaluation_output.metrics)
"OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('num_examples', 4860.0), ('loss', 1.6654209), ('accuracy', 0.3621399)])), ('total_rounds_metrics', OrderedDict([('num_examples', 4860.0), ('loss', 1.6654209), ('accuracy', 0.3621399)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])"

現在,讓我們編譯一個聯邦式資料的測試樣本,並在測試資料上重新執行評估。資料將來自真實使用者的相同樣本,但來自不同的保留資料集。

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>)
evaluation_output = evaluation_process.next(evaluation_state, federated_test_data)
str(evaluation_output.metrics)
"OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('num_examples', 580.0), ('loss', 1.7750846), ('accuracy', 0.33620688)])), ('total_rounds_metrics', OrderedDict([('num_examples', 580.0), ('loss', 1.7750846), ('accuracy', 0.33620688)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])"

本教學課程到此結束。我們鼓勵您嘗試參數(例如,批次大小、使用者數量、epochs、學習率等),修改上面的程式碼以模擬在每一輪中對使用者的隨機樣本進行訓練,並探索我們開發的其他教學課程。