![]() |
![]() |
![]() |
![]() |
開始之前
開始之前,請執行以下步驟以確保您的環境設定正確。如果您沒有看到歡迎訊息,請參閱安裝指南以取得相關指示。
pip install --quiet --upgrade tensorflow-federated
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
組合學習演算法
建構您自己的聯邦學習演算法教學課程使用 TFF 的聯邦核心直接實作聯邦平均 (FedAvg) 演算法的版本。
在本教學課程中,您將使用 TFF API 中的聯邦學習元件,以模組化方式建構聯邦學習演算法,而無需從頭開始重新實作所有內容。
在本教學課程中,您將實作 FedAvg 的變體,該變體透過本機訓練採用梯度裁剪。
學習演算法建構區塊
在高階層次上,許多學習演算法可以分成 4 個獨立的元件,稱為建構區塊。這些區塊如下
- 分配器 (即伺服器到用戶端通訊)
- 用戶端工作 (即本機用戶端運算)
- 彙整器 (即用戶端到伺服器通訊)
- 定案器 (即使用彙整的用戶端輸出進行伺服器運算)
雖然建構您自己的聯邦學習演算法教學課程從頭開始實作所有這些建構區塊,但通常是不必要的。您可以重複使用類似演算法中的建構區塊。
在本例中,若要實作具有梯度裁剪的 FedAvg,您只需要修改用戶端工作建構區塊即可。其餘區塊可以與「原始」FedAvg 中使用的區塊相同。
實作用戶端工作
首先,讓我們編寫執行具有梯度裁剪的本機模型訓練的 TF 邏輯。為了簡化,梯度裁剪的範數最多為 1。
TF 邏輯
@tf.function
def client_update(
model: tff.learning.models.FunctionalModel,
dataset: tf.data.Dataset,
initial_weights: tff.learning.models.ModelWeights,
client_optimizer: tff.learning.optimizers.Optimizer,
):
"""Performs training (using the initial server model weights) on the client's dataset."""
# Keep track of the number of examples.
num_examples = 0.0
# Use the client_optimizer to update the local model.
trainable_weights, non_trainable_weights = (
initial_weights.trainable,
initial_weights.non_trainable,
)
optimizer_state = client_optimizer.initialize(
tf.nest.map_structure(lambda x: tf.TensorSpec, trainable_weights)
)
for batch in dataset:
x, y = batch
with tf.GradientTape() as tape:
tape.watch(trainable_weights)
logits = model.predict_on_batch(
model_weights=(trainable_weights, non_trainable_weights),
x=x,
training=True,
)
num_examples += tf.cast(tf.shape(y)[0], tf.float32)
loss = model.loss(output=logits, label=y)
# Compute the corresponding gradient
grads = tape.gradient(loss, trainable_weights)
# Compute the gradient norm and clip
gradient_norm = tf.linalg.global_norm(grads)
if gradient_norm > 1:
grads = tf.nest.map_structure(lambda x: x / gradient_norm, grads)
# Apply the gradient using a client optimizer.
optimizer_state, trainable_weights = client_optimizer.next(
optimizer_state, trainable_weights, grads
)
# Compute the difference between the initial weights and the client weights
client_update = tf.nest.map_structure(
tf.subtract, trainable_weights, initial_weights[0]
)
return tff.learning.templates.ClientResult(
update=client_update, update_weight=num_examples
)
關於上述程式碼,有幾個重點。首先,它會追蹤已檢視範例的數量,因為這將構成用戶端更新的權重 (在計算跨用戶端的平均值時)。
其次,它使用 tff.learning.templates.ClientResult
來封裝輸出。此傳回類型用於標準化 tff.learning
中的用戶端工作建構區塊。
建立 ClientWorkProcess
雖然上述 TF 邏輯將執行具有裁剪的本機訓練,但仍需要包裝在 TFF 程式碼中,才能建立必要的建構區塊。
具體而言,4 個建構區塊表示為 tff.templates.MeasuredProcess
。這表示所有 4 個區塊都具有 initialize
和 next
函式,用於例項化和執行運算。
這允許每個建構區塊視需要追蹤自己的狀態 (儲存在伺服器上),以執行其運算。雖然在本教學課程中不會使用它,但它可用於追蹤已發生的迭代次數,或追蹤最佳化工具狀態等。
用戶端工作 TF 邏輯通常應包裝為 tff.learning.templates.ClientWorkProcess
,它編纂了用戶端本機訓練的預期輸入和輸出類型。它可以透過模型和最佳化工具進行參數化,如下所示。
def build_gradient_clipping_client_work(
model: tff.learning.models.FunctionalModel,
optimizer: tff.learning.optimizers.Optimizer,
) -> tff.learning.templates.ClientWorkProcess:
"""Creates a client work process that uses gradient clipping."""
data_type = tff.SequenceType(tff.types.tensorflow_to_type(model.input_spec))
model_weights_type = tff.types.to_type(
tf.nest.map_structure(
lambda arr: tff.types.TensorType(shape=arr.shape, dtype=arr.dtype),
tff.learning.models.ModelWeights(*model.initial_weights),
)
)
@tff.federated_computation
def initialize_fn():
return tff.federated_value((), tff.SERVER)
@tff.tensorflow.computation(model_weights_type, data_type)
def client_update_computation(model_weights, dataset):
return client_update(model, dataset, model_weights, optimizer)
@tff.federated_computation(
initialize_fn.type_signature.result,
tff.FederatedType(model_weights_type, tff.CLIENTS),
tff.FederatedType(data_type, tff.CLIENTS),
)
def next_fn(state, model_weights, client_dataset):
client_result = tff.federated_map(
client_update_computation, (model_weights, client_dataset)
)
# Return empty measurements, though a more complete algorithm might
# measure something here.
measurements = tff.federated_value((), tff.SERVER)
return tff.templates.MeasuredProcessOutput(
state, client_result, measurements
)
return tff.learning.templates.ClientWorkProcess(initialize_fn, next_fn)
組合學習演算法
讓我們將上述用戶端工作放入功能完整的演算法中。首先,讓我們設定資料和模型。
準備輸入資料
載入並預先處理 TFF 中包含的 EMNIST 資料集。如需更多詳細資訊,請參閱影像分類教學課程。
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
為了將資料集饋送到我們的模型中,資料會被扁平化並轉換為 (flattened_image_vector, label)
形式的元組。
讓我們選取少量用戶端,並將上述預先處理套用至其資料集。
NUM_CLIENTS = 10
BATCH_SIZE = 20
def preprocess(dataset):
def batch_format_fn(element):
"""Flatten a batch of EMNIST data and return a (features, label) tuple."""
return (tf.reshape(element['pixels'], [-1, 784]),
tf.reshape(element['label'], [-1, 1]))
return dataset.batch(BATCH_SIZE).map(batch_format_fn)
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
for x in client_ids
]
準備模型
這使用與影像分類教學課程中相同的模型。此模型 (透過 tf.keras
實作) 具有單一隱藏層,後接 softmax 層。為了在 TFF 中使用此模型,Keras 模型會包裝為 tff.learning.models.FunctionalModel
。這允許我們執行模型的 前向傳遞 aggregator_factory = tff.aggregators.MeanFactory() aggregator = aggregator_factory.create( model_weights_type.trainable, tff.TensorType(np.float32) ) finalizer = tff.learning.templates.build_apply_optimizer_finalizer( server_optimizer, model_weights_type )
initializer = tf.keras.initializers.GlorotNormal(seed=0)
keras_model = tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer=initializer),
tf.keras.layers.Softmax(),
])
tff_model = tff.learning.models.functional_model_from_keras(
keras_model,
loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
input_spec=federated_train_data[0].element_spec,
metrics_constructor=collections.OrderedDict(
accuracy=tf.keras.metrics.SparseCategoricalAccuracy
),
)
準備最佳化工具
如同 tff.learning.algorithms.build_weighted_fed_avg
中一樣,這裡有兩個最佳化工具:用戶端最佳化工具和伺服器最佳化工具。為了簡化,最佳化工具將是具有不同學習率的 SGD。
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=1.0)
定義建構區塊
既然已設定用戶端工作建構區塊、資料、模型和最佳化工具,剩下的就是為分配器、彙整器和定案器建立建構區塊。這可以透過借用 TFF 中可用的某些預設值以及 FedAvg 使用的預設值來完成。
@tff.tensorflow.computation
def initial_model_weights_fn():
return tff.learning.models.ModelWeights(*tff_model.initial_weights)
model_weights_type = initial_model_weights_fn.type_signature.result
distributor = tff.learning.templates.build_broadcast_process(model_weights_type)
client_work = build_gradient_clipping_client_work(tff_model, client_optimizer)
# TFF aggregators use a factory pattern, which create an aggregator
# based on the output type of the client work. This also uses a float (the number
# of examples) to govern the weight in the average being computed.)
aggregator_factory = tff.aggregators.MeanFactory()
aggregator = aggregator_factory.create(
model_weights_type.trainable, tff.TensorType(np.float32)
)
finalizer = tff.learning.templates.build_apply_optimizer_finalizer(
server_optimizer, model_weights_type
)
組合建構區塊
最後,您可以使用 TFF 中的內建組合器來將建構區塊組合在一起。這是一個相對簡單的組合器,它採用上述 4 個建構區塊,並將它們的類型連接在一起。
fed_avg_with_clipping = tff.learning.templates.compose_learning_process(
initial_model_weights_fn,
distributor,
client_work,
aggregator,
finalizer
)
執行演算法
現在演算法已完成,讓我們執行它。首先,初始化演算法。此演算法的狀態為每個建構區塊都有一個元件,以及一個用於全域模型權重的元件。
state = fed_avg_with_clipping.initialize()
state.client_work
()
如預期的那樣,用戶端工作具有空狀態 (請記住上面的用戶端工作程式碼!)。但是,其他建構區塊可能具有非空狀態。例如,定案器會追蹤已發生的迭代次數。由於 next
尚未執行,因此其狀態為 0
。
state.finalizer
OrderedDict([('learning_rate', 1.0)])
現在執行訓練回合。
learning_process_output = fed_avg_with_clipping.next(state, federated_train_data)
此輸出 (tff.learning.templates.LearningProcessOutput
) 同時具有 .state
和 .metrics
輸出。讓我們看看兩者。
learning_process_output.state.finalizer
OrderedDict([('learning_rate', 1.0)])
顯然,定案器狀態已遞增一,因為已執行一輪 .next
。
learning_process_output.metrics
OrderedDict([('distributor', ()), ('client_work', ()), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
雖然指標為空,但對於更複雜和實用的演算法,它們通常會充滿有用的資訊。
結論
透過使用上述建構區塊/組合器架構,您可以建立全新的學習演算法,而無需從頭開始重新執行所有操作。但是,這僅僅是起點。此架構使將演算法表示為 FedAvg 的簡單修改變得更加容易。如需更多演算法,請參閱 tff.learning.algorithms
,其中包含諸如 FedProx 和 具有用戶端學習率排程的 FedAvg 等演算法。這些 API 甚至可以協助實作全新的演算法,例如 聯邦 k 平均值叢集。