TFF for Federated Learning Research: Model and Update Compression


In this tutorial, we use the EMNIST dataset to demonstrate how to enable lossy compression algorithms to reduce communication cost in the Federated Averaging algorithm using the tff.learning API. For more details on the Federated Averaging algorithm, see the paper Communication-Efficient Learning of Deep Networks from Decentralized Data.

Before we start

Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.

pip install --quiet --upgrade tensorflow-federated
pip install --quiet --upgrade tensorflow-model-optimization
%load_ext tensorboard

import functools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Verify that TFF is working.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

Preparing the input data

In this section we load and preprocess the EMNIST dataset included with TFF. Please check out the Federated Learning for Image Classification tutorial for more details about the EMNIST dataset.

# This value only applies to the EMNIST dataset; consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)
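
As a quick sanity check, we can inspect the element structure of a single preprocessed client dataset; given the preprocessing above, each element should be a batched (features, label) tuple with features of shape (None, 28, 28, 1).

example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
# Each element is a (features, label) batch produced by
# preprocess_train_dataset above.
print(example_dataset.element_spec)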

Defining a model

Here we define a keras model based on the original FedAvg CNN, and then wrap the keras model in an instance of tff.learning.models.VariableModel so that it can be consumed by TFF.

Note that we need a **function** that produces a model, rather than a model directly. In addition, the function **cannot** just capture a pre-constructed model; it must create the model in the context in which it is called. The reason is that TFF is designed to go to devices, and needs control over when resources are constructed so that they can be captured and packaged up.

def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to 
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.models.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Training the model and outputting training metrics

Now we are ready to construct a Federated Averaging algorithm and train the defined model on the EMNIST dataset.

First we need to build a Federated Averaging algorithm using the tff.learning.algorithms.build_weighted_fed_avg API.

federated_averaging = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Now let's run the Federated Averaging algorithm. The execution of a federated learning algorithm from the perspective of TFF looks like this:

  1. Initialize the algorithm and get the initial server state. The server state contains the information necessary to perform the algorithm. Recall that, since TFF is functional, this state includes both any optimizer state the algorithm uses (e.g. momentum terms) and the model parameters themselves; these will be passed as arguments and returned as results from TFF computations.
  2. Execute the algorithm round by round. In each round, a new server state will be returned as the result of each client training the model on its data. Typically in one round:
    1. The server broadcasts the model to all the participating clients.
    2. Each client performs work based on the model and its own data.
    3. The server aggregates all the models to produce a server state which contains a new model.

For more details, please see the Custom Federated Algorithms, Part 2: Implementing Federated Averaging tutorial.
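
As a minimal sketch of this contract (assuming federated_train_data is a list of client tf.data.Dataset instances, as constructed in the training loop below):

state = federated_averaging.initialize()
# One round: returns the new server state and the metrics for this round.
result = federated_averaging.next(state, federated_train_data)
state = result.state
metrics = result.metrics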

Training metrics are written to the Tensorboard directory for display after training.

def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and output metrics."""

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample clients participating in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.data.Dataset` instances from the sampled clients' data.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Run one round of the algorithm based on the server state and client data,
      # and output the new state and metrics.
      result = federated_averaging_process.next(state, sampled_train_data)
      state = result.state
      train_metrics = result.metrics['client_work']['train']

      # Add metrics to Tensorboard.
      for name, value in train_metrics.items():
        tf.summary.scalar(name, value, step=round_num)
      summary_writer.flush()

# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError:
  pass  # Path doesn't exist.

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round  0, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.092454836), ('loss', 2.310193), ('num_examples', 941), ('num_batches', 51)]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round  1, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10029791), ('loss', 2.3102622), ('num_examples', 1007), ('num_batches', 55)]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.25Mibit
round  2, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10710711), ('loss', 2.3048222), ('num_examples', 999), ('num_batches', 54)]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round  3, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1061061), ('loss', 2.3066027), ('num_examples', 999), ('num_batches', 55)]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round  4, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1287594), ('loss', 2.2999024), ('num_examples', 1064), ('num_batches', 58)]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round  5, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.13529412), ('loss', 2.2994456), ('num_examples', 1020), ('num_batches', 55)]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round  6, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.124045804), ('loss', 2.2947247), ('num_examples', 1048), ('num_batches', 57)]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round  7, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14217557), ('loss', 2.290349), ('num_examples', 1048), ('num_batches', 57)]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round  8, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14641434), ('loss', 2.290953), ('num_examples', 1004), ('num_batches', 56)]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round  9, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1695238), ('loss', 2.2859888), ('num_examples', 1050), ('num_batches', 57)]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit

Launch TensorBoard with the root log directory specified above to display the training metrics. It can take a few seconds for the data to load. Apart from Loss and Accuracy, we also output the amount of broadcasted and aggregated data. Broadcasted data refers to tensors the server pushes to each client, while aggregated data refers to tensors each client returns to the server.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
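
As a rough back-of-the-envelope cross-check (an estimate, not a TFF measurement): broadcasting every float32 model weight to each of the 10 sampled clients costs roughly num_params * 32 * 10 bits per round, which lines up with the ~507 Mibit per round reported above.

keras_model = create_original_fedavg_cnn_model()
num_params = keras_model.count_params()
# 32-bit floats, broadcast to 10 clients each round.
estimated_bits_per_round = num_params * 32 * 10
print(f'Estimated broadcast size: {estimated_bits_per_round / 2**20:.2f} Mibit')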

Building a custom aggregation function

Now let's implement a function to use lossy compression algorithms on aggregated data. We will use TFF's API to create a tff.aggregators.AggregationFactory for this. While researchers may often want to implement their own (which can be done via the tff.aggregators API), we will use a built-in method for doing so, specifically tff.learning.compression_aggregator.

It is important to note that this aggregator does not apply compression to the entire model at once. Instead, it applies compression only to those variables in the model that are sufficiently large. Generally, small variables such as biases are more sensitive to inaccuracy and, being relatively small, offer relatively small potential communication savings.

compression_aggregator = tff.learning.compression_aggregator()
isinstance(compression_aggregator, tff.aggregators.WeightedAggregationFactory)
True

Above, you can see that the compression aggregator is a weighted aggregation factory, which means that it involves weighted aggregation (in contrast to aggregators intended for differential privacy, which are often unweighted).
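
For context, similar compression can also be assembled from TFF's composable aggregation building blocks. Below is a minimal sketch of a weighted mean whose value aggregation uniformly quantizes sufficiently large tensors; the 8-bit and 20000-element values are illustrative assumptions, not necessarily the defaults used by tff.learning.compression_aggregator.

quantizing_mean = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory.quantize_above_threshold(
        quantization_bits=8, threshold=20000))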

This aggregation factory can be plugged directly into FedAvg via its `model_aggregator` argument.

federated_averaging_with_compression = tff.learning.algorithms.build_weighted_fed_avg(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    model_aggregator=compression_aggregator)

Training the model again

Now let's run the new Federated Averaging algorithm.

logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression, 
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round  0, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.087804876), ('loss', 2.3126457), ('num_examples', 1025), ('num_batches', 55)]), broadcasted_bits=507.62Mibit, aggregated_bits=146.47Mibit
round  1, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.073267326), ('loss', 2.3111901), ('num_examples', 1010), ('num_batches', 56)]), broadcasted_bits=1015.24Mibit, aggregated_bits=292.93Mibit
round  2, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.08925144), ('loss', 2.3071017), ('num_examples', 1042), ('num_batches', 57)]), broadcasted_bits=1.49Gibit, aggregated_bits=439.40Mibit
round  3, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.07985144), ('loss', 2.3061485), ('num_examples', 1077), ('num_batches', 59)]), broadcasted_bits=1.98Gibit, aggregated_bits=585.86Mibit
round  4, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.11947791), ('loss', 2.302166), ('num_examples', 996), ('num_batches', 55)]), broadcasted_bits=2.48Gibit, aggregated_bits=732.33Mibit
round  5, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.12195122), ('loss', 2.2997446), ('num_examples', 984), ('num_batches', 54)]), broadcasted_bits=2.97Gibit, aggregated_bits=878.79Mibit
round  6, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10429448), ('loss', 2.2997215), ('num_examples', 978), ('num_batches', 55)]), broadcasted_bits=3.47Gibit, aggregated_bits=1.00Gibit
round  7, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.16857143), ('loss', 2.2961135), ('num_examples', 1050), ('num_batches', 56)]), broadcasted_bits=3.97Gibit, aggregated_bits=1.14Gibit
round  8, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1399177), ('loss', 2.2942808), ('num_examples', 972), ('num_batches', 54)]), broadcasted_bits=4.46Gibit, aggregated_bits=1.29Gibit
round  9, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14202899), ('loss', 2.2972558), ('num_examples', 1035), ('num_batches', 57)]), broadcasted_bits=4.96Gibit, aggregated_bits=1.43Gibit

Launch TensorBoard again to compare the training metrics between the two runs.

As you can see in Tensorboard, there is a significant reduction between the `original` and `compression` curves in the `aggregated_bits` plots, while the two curves are quite similar in the `loss` and `sparse_categorical_accuracy` plots.

In conclusion, we implemented a compression algorithm that achieves performance similar to the original Federated Averaging algorithm while significantly reducing communication cost: in the run above, roughly 3.5x fewer aggregated bits (1.43 Gibit vs. 4.96 Gibit cumulative over 10 rounds).

%tensorboard --logdir /tmp/logs/scalars/ --port=0

Exercises

To implement a custom compression algorithm and apply it to the training loop, you can:

  1. Implement a new compression algorithm as a subclass of tff.aggregators.MeanFactory (see the skeleton sketch after this list).
  2. Perform training with the compression algorithm to see if it does better than the algorithm above.
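
A hypothetical skeleton for the first exercise (the class name is a placeholder; one common pattern is to implement a tff.aggregators.UnweightedAggregationFactory and plug it into tff.aggregators.MeanFactory as its inner value aggregation):

class MyCompressionFactory(tff.aggregators.UnweightedAggregationFactory):
  """Placeholder skeleton for a custom lossy compression aggregator."""

  def create(self, value_type):
    # Should return a `tff.templates.AggregationProcess` that compresses
    # values on the clients, sums them on the server, and decompresses
    # the result.
    raise NotImplementedError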

Potentially valuable open research questions include: non-uniform quantization, lossless compression such as Huffman coding, and mechanisms for adapting compression based on information from previous training rounds.

Recommended reading