TensorFlow Federated:分散式資料上的機器學習

import collections
import tensorflow as tf
import tensorflow_federated as tff

# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()
def client_data(n):
  return source.create_tf_dataset_for_client(source.client_ids[n]).map(
      lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])
  ).repeat(10).batch(20)

# Pick a subset of client devices to participate in training.
train_data = [client_data(n) for n in range(3)]

# Wrap a Keras model for use with TFF.
keras_model = tf.keras.models.Sequential([
  tf.keras.layers.Dense(
    10, tf.nn.softmax, input_shape=(784,), kernel_initializer='zeros')
])
tff_model = tff.learning.models.functional_model_from_keras(
      keras_model,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
      input_spec=train_data[0].element_spec,
      metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy))

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.algorithms.build_weighted_fed_avg(
  tff_model,
  client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1))
state = trainer.initialize()
for _ in range(5):
  result = trainer.next(state, train_data)
  state = result.state
  metrics = result.metrics
  print(metrics['client_work']['train']['accuracy'])
  • TensorFlow Federated (TFF) 是一個適用於分散式資料的機器學習和其他運算的開放原始碼架構。TFF 的開發旨在促進對聯邦學習 (FL) 進行開放研究與實驗,聯邦學習是一種機器學習方法,其中共用全域模型會在許多參與用戶端之間進行訓練,而這些用戶端會將其訓練資料保留在本地端。例如,FL 已用於訓練行動鍵盤的預測模型,而無需將敏感的輸入資料上傳至伺服器。

    TFF 讓開發人員能夠在其模型和資料上模擬內含的聯邦學習演算法,以及試驗新穎的演算法。研究人員會找到許多研究類型的起點和完整範例。TFF 提供的建構區塊也可用於實作非學習運算,例如聯邦分析。TFF 的介面分為兩個主要層級

  • 此層提供一組高階介面,讓開發人員能夠將聯邦訓練和評估的內含實作套用至現有的 TensorFlow 模型。
  • 系統核心是一組低階介面,用於透過在強型別函數式程式設計環境中結合 TensorFlow 與分散式通訊運算子,來簡潔地表達新穎的聯邦演算法。此層也作為我們建構聯邦學習的基礎。
  • TFF 讓開發人員能夠宣告式地表達聯邦運算,以便將其部署到各種執行階段環境。TFF 隨附適用於實驗的高效能多機器模擬執行階段。請造訪教學課程並親自試用!

    如有問題和支援需求,請在 StackOverflow 上使用 tensorflow-federated 標籤找到我們。