使用 tff 的 ClientData。

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

以用戶端 (例如使用者) 為索引鍵的資料集概念,對於 TFF 中建模的聯邦式運算至關重要。TFF 提供 tff.simulation.datasets.ClientData 介面來抽象化此概念,而 TFF 託管的資料集 (stackoverflow、shakespeare、emnist、cifar100 和 gldv2) 全都實作此介面。

如果您使用自己的資料集進行聯邦式學習,TFF 強烈建議您實作 ClientData 介面,或使用 TFF 的其中一個輔助函式來產生代表您磁碟資料的 ClientData,例如 tff.simulation.datasets.ClientData.from_clients_and_fn

由於 TFF 的大多數端對端範例都從 ClientData 物件開始,因此使用您的自訂資料集實作 ClientData 介面,將可讓您更輕鬆地瀏覽以 TFF 撰寫的現有程式碼。此外,ClientData 建構的 tf.data.Datasets 可以直接迭代,以產生 numpy 陣列的結構,因此在移至 TFF 之前,ClientData 物件可以搭配任何以 Python 為基礎的機器學習架構使用。

如果您打算將模擬擴展到多部機器或部署模擬,可以使用幾種模式讓您更輕鬆。以下我們將逐步說明幾種使用 ClientData 和 TFF 的方法,讓您從小型迭代到大型實驗再到生產環境部署的體驗盡可能順暢。

我應該使用哪種模式將 ClientData 傳遞到 TFF 中?

我們將深入討論 TFF ClientData 的兩種用法;如果您屬於以下任一類別,您會明顯偏好其中一種用法。如果不是,您可能需要更詳細地瞭解每種用法的優缺點,才能做出更細緻的選擇。

  • 我想在本機上盡可能快速地迭代;我不需要能夠輕鬆利用 TFF 的分散式執行階段。

    • 您想要將 tf.data.Datasets 直接傳遞到 TFF 中。
    • 這可讓您以命令式方式使用 tf.data.Dataset 物件進行程式設計,並任意處理這些物件。
    • 它比以下選項提供更高的彈性;將邏輯推送至用戶端需要此邏輯可序列化。
  • 我想在 TFF 的遠端執行階段中執行我的聯邦式運算,或者我計畫很快這麼做。

    • 在這種情況下,您會想要將資料集建構和預先處理對應到用戶端。
    • 這樣做會讓您只需將 client_ids 清單直接傳遞到您的聯邦式運算。
    • 將資料集建構和預先處理推送至用戶端,可避免序列化瓶頸,並顯著提升數百到數千個用戶端的效能。

設定開放原始碼環境

匯入套件

操作 ClientData 物件

讓我們先載入並探索 TFF 的 EMNIST ClientData

client_data, _ = tff.simulation.datasets.emnist.load_data()

檢查第一個資料集可以告訴我們 ClientData 中範例的類型。

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

請注意,資料集會產生具有 pixelslabel 索引鍵的 collections.OrderedDict 物件,其中 pixels 是形狀為 [28, 28] 的張量。假設我們希望將輸入攤平為形狀 [784]。一種可能的做法是對我們的 ClientData 物件套用預先處理函式。

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

此外,我們可能還想執行一些更複雜 (且可能具狀態) 的預先處理,例如隨機排序。

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

tff.Computation 介接

現在我們可以對 ClientData 物件執行一些基本操作,我們已準備好將資料饋送至 tff.Computation。我們定義一個 tff.templates.IterativeProcess,其會實作聯邦平均,並探索將資料傳遞給它的不同方法。

keras_model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(784,)),
    tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Note: input spec is the _batched_ shape, and includes the
    # label tensor which will be passed to the loss function. This model is
    # therefore configured to accept data _after_ it has been preprocessed.
    input_spec=collections.OrderedDict(
        x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
        y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64),
    ),
    metrics_constructor=collections.OrderedDict(
        loss=lambda: tf.keras.metrics.SparseCategoricalCrossentropy(
            from_logits=True
        ),
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy,
    ),
)

trainer = tff.learning.algorithms.build_weighted_fed_avg(
    tff_model,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.01),
)

在我們開始使用此 IterativeProcess 之前,需要先說明一下 ClientData 的語意。ClientData 物件代表可用於聯邦式訓練的整個母體,一般而言,生產環境 FL 系統的執行環境無法使用,而且是模擬專用的。ClientData 的確讓使用者能夠完全略過聯邦式運算,並透過 ClientData.create_tf_dataset_from_all_clients 像平常一樣訓練伺服器端模型。

TFF 的模擬環境讓研究人員完全掌控外迴圈。特別是,這表示用戶端可用性、用戶端退出等考量事項,必須由使用者或 Python 驅動程式指令碼處理。例如,可以透過調整 ClientData client_ids 的取樣分佈來模擬用戶端退出,讓資料較多的使用者 (以及對應的執行時間較長的本機運算) 以較低的機率被選取。

然而,在真實的聯邦式系統中,模型訓練器無法明確選取用戶端;用戶端的選取會委派給執行聯邦式運算的系統。

tf.data.Datasets 直接傳遞至 TFF

我們在 ClientDataIterativeProcess 之間介接的一個選項,是在 Python 中建構 tf.data.Datasets,並將這些資料集傳遞至 TFF。

請注意,如果我們使用預先處理的 ClientData,則我們產生的資料集會是我們上方定義的模型所預期的適當類型。

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]
    )
    for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  result = trainer.next(state, preprocessed_data_for_clients)
  state = result.state
  train_metrics = result.metrics['client_work']['train']
  t2 = time.time()
  print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')
loss 2.89, round time 2.35 seconds
loss 3.05, round time 2.26 seconds
loss 2.80, round time 0.63 seconds
loss 2.94, round time 3.18 seconds
loss 3.17, round time 2.44 seconds

但是,如果我們採用此途徑,我們將無法輕易地移至多機器模擬。我們在本機 TensorFlow 執行階段中建構的資料集可以擷取周圍 Python 環境的狀態,並且在嘗試參考不再可用的狀態時,會在序列化或還原序列化中失敗。例如,這可能會在 TensorFlow tensor_util.cc 中難以理解的錯誤中顯現。

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

將建構和預先處理對應到用戶端

為了避免此問題,TFF 建議使用者將資料集例項化和預先處理視為在每個用戶端本機發生的事項,並使用 TFF 的輔助程式或 federated_map,在每個用戶端明確執行此預先處理程式碼。

從概念上講,偏好這樣做的原因很明顯:在 TFF 的本機執行階段中,由於整個聯邦式協調發生在單一機器上,因此用戶端只是「意外地」可以存取全域 Python 環境。此時值得注意的是,類似的想法促成了 TFF 的跨平台、永遠可序列化、功能性哲學。

TFF 透過 ClientData 的屬性 dataset_computation,一個接受 client_id 並傳回相關聯 tf.data.Datasettff.Computation,讓這種變更變得簡單。

請注意,preprocess 僅適用於 dataset_computation;預先處理的 ClientDatadataset_computation 屬性包含我們剛才定義的整個預先處理管線

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(str -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(str -> <x=float32[?,784],y=int64[?,1]>*)

我們可以在 Python 執行階段中叫用 dataset_computation 並接收急切資料集,但是當我們與迭代程序或其他運算組合以完全避免在全域急切執行階段中具體化這些資料集時,此方法的真正威力才能發揮出來。TFF 提供輔助函式 tff.simulation.compose_dataset_computation_with_iterative_process,可用於完全做到這一點。

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

這個 tff.templates.IterativeProcesses 和上方的那個都以相同方式執行;但前者接受預先處理的用戶端資料集,而後者接受代表用戶端 ID 的字串,在其主體中處理資料集建構和預先處理,事實上,state 可以在兩者之間傳遞。

for _ in range(5):
  t1 = time.time()
  result = trainer_accepting_ids.next(state, selected_client_ids)
  state = result.state
  train_metrics = result.metrics['client_work']['train']
  t2 = time.time()
  print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')

擴展到大量用戶端

trainer_accepting_ids 可以立即在 TFF 的多機器執行階段中使用,並避免具體化 tf.data.Datasets 和控制器 (因此可避免序列化並將它們傳送給工作站)。

這顯著加快了分散式模擬的速度,尤其是在大量用戶端的情況下,並可實現中繼彙整,以避免類似的序列化/還原序列化額外負荷。

選用深入探討:在 TFF 中手動組合預先處理邏輯

TFF 的設計從一開始就著重於組合性;剛才由 TFF 輔助程式執行的組合類型完全在我們使用者可控制的範圍內。

selected_clients_type = tff.FederatedType(
    preprocessed_and_shuffled.dataset_computation.type_signature.parameter,
    tff.CLIENTS,
)


@tff.federated_computation(
    trainer.next.type_signature.parameter[0], selected_clients_type
)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(
      preprocessed_and_shuffled.dataset_computation, selected_clients
  )
  return trainer.next(server_state, preprocessed_data)


manual_trainer_with_preprocessing = tff.templates.IterativeProcess(
    initialize_fn=trainer.initialize, next_fn=new_next
)

我們可以很簡單地手動組合我們剛才定義的預先處理運算與訓練器的 next。事實上,這實際上就是我們使用的輔助程式在幕後執行的操作 (加上執行適當的類型檢查和操作)。我們甚至可以透過稍微不同的方式來表達相同的邏輯,方法是將 preprocess_and_shuffle 序列化為 tff.Computation,並將 federated_map 分解為一個步驟 (建構未預先處理的資料集) 和另一個步驟 (在每個用戶端執行 preprocess_and_shuffle)。

我們可以驗證,這種更手動的路徑會產生與 TFF 輔助程式具有相同類型簽章的運算 (模數參數名稱)

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,client_data={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)
(<server_state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,selected_clients={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)