In this tutorial, we use the EMNIST dataset to demonstrate how to enable lossy compression algorithms to reduce communication cost in the Federated Averaging algorithm using the tff.learning API. For more details on the Federated Averaging algorithm, see the paper Communication-Efficient Learning of Deep Networks from Decentralized Data.
Before we start
Before we start, please run the following to make sure your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
pip install --quiet --upgrade tensorflow-federated
pip install --quiet --upgrade tensorflow-model-optimization
%load_ext tensorboard
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Verify that TFF is working.
@tff.federated_computation
def hello_world():
  return 'Hello, World!'
hello_world()
b'Hello, World!'
Preparing the input data
In this section we load and preprocess the EMNIST dataset included with TFF. Please check the Federated Learning for Image Classification tutorial for more details about the EMNIST dataset.
# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418
CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)
def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])
def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
      # Shuffle according to the largest client dataset
      .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
      # Repeat to do multiple local epochs
      .repeat(CLIENT_EPOCHS_PER_ROUND)
      # Batch to a fixed client batch size
      .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
      # Preprocessing step
      .map(reshape_emnist_element))
emnist_train = emnist_train.preprocess(preprocess_train_dataset)
Defining a model
Here we define a keras model based on the original FedAvg CNN, and then wrap the keras model in an instance of tff.learning.models.VariableModel so that it can be consumed by TFF.
Note that we need a **function** which produces a model, instead of simply a model directly. In addition, the function **cannot** just capture a pre-constructed model; it must create the model in the context in which it is called. The reason is that TFF is designed to go to devices, and needs control over when resources are constructed so that they can be captured and packaged.
def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model
# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.models.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
Training the model and outputting training metrics
Now we are ready to construct a Federated Averaging algorithm and train the defined model on the EMNIST dataset.
First we need to build a Federated Averaging algorithm using the tff.learning.algorithms.build_weighted_fed_avg API.
federated_averaging = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
Now let's run the Federated Averaging algorithm. From TFF's perspective, the execution of a Federated Learning algorithm looks like this:
- Initialize the algorithm and get the initial server state. The server state contains the information necessary to perform the algorithm. Recall that, since TFF is functional, this state includes both any optimizer state the algorithm uses (e.g. momentum terms) as well as the model parameters themselves; these will be passed as arguments to and returned as results from TFF computations.
- Execute the algorithm round by round. In each round, a new server state is returned as the result of each client training the model on its data. Typically, in one round:
  - The server broadcasts the model to all the participating clients.
  - Each client performs work based on the model and its own data.
  - The server aggregates all the models to produce a server state which contains a new model.
For more details, please see the Custom Federated Algorithms, Part 2: Implementing Federated Averaging tutorial.
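In code, that loop reduces to calling initialize once and then next repeatedly. The minimal sketch below assumes `client_datasets` is a hypothetical Python list of preprocessed tf.data.Dataset objects, one per sampled client; the train function that follows wraps this same pattern with client sampling and TensorBoard logging.

# Minimal sketch of the initialize/next pattern (illustrative only).
# `client_datasets` is a placeholder for a list of preprocessed
# tf.data.Dataset objects, one per sampled client.
state = federated_averaging.initialize()
for _ in range(3):
  result = federated_averaging.next(state, client_datasets)
  state = result.state  # Updated server state (model weights + optimizer state).
  train_metrics = result.metrics['client_work']['train']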
The training metrics are written to a Tensorboard directory so they can be displayed after training.
def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and outputs metrics."""
  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients participating in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Run one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      result = federated_averaging_process.next(state, sampled_train_data)
      state = result.state
      train_metrics = result.metrics['client_work']['train']

      # Add metrics to Tensorboard.
      for name, value in train_metrics.items():
        tf.summary.scalar(name, value, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist.
# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)
train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round 0, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.092454836), ('loss', 2.310193), ('num_examples', 941), ('num_batches', 51)]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round 1, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10029791), ('loss', 2.3102622), ('num_examples', 1007), ('num_batches', 55)]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.25Mibit
round 2, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10710711), ('loss', 2.3048222), ('num_examples', 999), ('num_batches', 54)]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round 3, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1061061), ('loss', 2.3066027), ('num_examples', 999), ('num_batches', 55)]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round 4, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1287594), ('loss', 2.2999024), ('num_examples', 1064), ('num_batches', 58)]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round 5, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.13529412), ('loss', 2.2994456), ('num_examples', 1020), ('num_batches', 55)]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round 6, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.124045804), ('loss', 2.2947247), ('num_examples', 1048), ('num_batches', 57)]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round 7, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14217557), ('loss', 2.290349), ('num_examples', 1048), ('num_batches', 57)]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round 8, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14641434), ('loss', 2.290953), ('num_examples', 1004), ('num_batches', 56)]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round 9, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1695238), ('loss', 2.2859888), ('num_examples', 1050), ('num_batches', 57)]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit
Start TensorBoard with the root log directory specified above to display the training metrics. It can take a few seconds for the data to load. Apart from Loss and Accuracy, we also output the amount of broadcasted and aggregated data. Broadcasted data refers to the tensors the server pushes to each client, while aggregated data refers to the tensors each client returns to the server.
%tensorboard --logdir /tmp/logs/scalars/ --port=0
Build a custom aggregation function
Now let's implement a function that uses a lossy compression algorithm on the aggregated data. We will use TFF's API to create a tff.aggregators.AggregationFactory for this. While researchers may often want to implement their own (which can be done via the tff.aggregators API), we will use a built-in method for doing so, specifically tff.learning.compression_aggregator.
It is important to note that this aggregator does not apply compression to the entire model at once. Instead, it applies compression only to those variables in the model that are sufficiently large. Generally, small variables such as biases are more sensitive to inaccuracy, and, being relatively small, the potential communication savings are also relatively small.
compression_aggregator = tff.learning.compression_aggregator()
isinstance(compression_aggregator, tff.aggregators.WeightedAggregationFactory)
True
Above, you can see that the compression aggregator is a weighted aggregation factory, which means that it involves weighted aggregation (in contrast to aggregators intended for differential privacy, which are typically unweighted).
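For illustration only, a broadly similar weighted factory could be composed by hand from tff.aggregators building blocks, for example by applying uniform quantization only to tensors above a size threshold. This is a sketch of the idea rather than the exact construction used inside compression_aggregator, and the quantization_bits and threshold values below are arbitrary illustrative choices.

# Illustrative sketch only: quantize aggregated values to 8 bits, but leave
# tensors with at most 20,000 elements (e.g. biases) uncompressed. These
# constants are examples, not what `tff.learning.compression_aggregator` uses.
quantized_sum_factory = tff.aggregators.EncodedSumFactory.quantize_above_threshold(
    quantization_bits=8, threshold=20000)
# Wrap the compressed sum in a weighted mean so that it can act as a model
# aggregator for weighted FedAvg.
manual_compression_aggregator = tff.aggregators.MeanFactory(
    value_sum_factory=quantized_sum_factory)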
The compression aggregator can be plugged directly into FedAvg via its `model_aggregator` argument.
federated_averaging_with_compression = tff.learning.algorithms.build_weighted_fed_avg(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    model_aggregator=compression_aggregator)
Training the model again
Now let's run the new Federated Averaging algorithm.
logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)
train(federated_averaging_process=federated_averaging_with_compression,
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round 0, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.087804876), ('loss', 2.3126457), ('num_examples', 1025), ('num_batches', 55)]), broadcasted_bits=507.62Mibit, aggregated_bits=146.47Mibit
round 1, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.073267326), ('loss', 2.3111901), ('num_examples', 1010), ('num_batches', 56)]), broadcasted_bits=1015.24Mibit, aggregated_bits=292.93Mibit
round 2, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.08925144), ('loss', 2.3071017), ('num_examples', 1042), ('num_batches', 57)]), broadcasted_bits=1.49Gibit, aggregated_bits=439.40Mibit
round 3, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.07985144), ('loss', 2.3061485), ('num_examples', 1077), ('num_batches', 59)]), broadcasted_bits=1.98Gibit, aggregated_bits=585.86Mibit
round 4, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.11947791), ('loss', 2.302166), ('num_examples', 996), ('num_batches', 55)]), broadcasted_bits=2.48Gibit, aggregated_bits=732.33Mibit
round 5, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.12195122), ('loss', 2.2997446), ('num_examples', 984), ('num_batches', 54)]), broadcasted_bits=2.97Gibit, aggregated_bits=878.79Mibit
round 6, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10429448), ('loss', 2.2997215), ('num_examples', 978), ('num_batches', 55)]), broadcasted_bits=3.47Gibit, aggregated_bits=1.00Gibit
round 7, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.16857143), ('loss', 2.2961135), ('num_examples', 1050), ('num_batches', 56)]), broadcasted_bits=3.97Gibit, aggregated_bits=1.14Gibit
round 8, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1399177), ('loss', 2.2942808), ('num_examples', 972), ('num_batches', 54)]), broadcasted_bits=4.46Gibit, aggregated_bits=1.29Gibit
round 9, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14202899), ('loss', 2.2972558), ('num_examples', 1035), ('num_batches', 57)]), broadcasted_bits=4.96Gibit, aggregated_bits=1.43Gibit
Start TensorBoard again to compare the training metrics between the two runs.
As you can see in Tensorboard, there is a significant reduction between the `original` and `compression` curves in the `aggregated_bits` plots, while the `loss` and `sparse_categorical_accuracy` curves are quite similar. For example, in the first round of the logs above, the aggregated traffic drops from about 507.62 Mibit to about 146.47 Mibit, roughly a 3.5x reduction.
In conclusion, we implemented a compression algorithm that achieves performance similar to the original Federated Averaging algorithm while significantly reducing the communication cost.
%tensorboard --logdir /tmp/logs/scalars/ --port=0
Exercises
To implement a custom compression algorithm and apply it to the training loop, you can:
- Implement your new compression algorithm as a subclass of tff.aggregators.MeanFactory.
- Perform training with the compression algorithm to see if it does better than the algorithm above.
Potentially valuable open research questions include non-uniform quantization, lossless compression such as Huffman coding, and mechanisms for adapting compression based on information from previous training rounds.
Recommended reading