Loading Remote Data in TFF


In a real-world application of federated learning, the raw training data is typically distributed across many devices or data silos, and therefore requires special preprocessing and loading before it can be used.

This tutorial describes how to load examples stored in those remote locations with TFF's DataBackend and DataExecutor interfaces, and how to use them to train a model with federated learning. We will demonstrate the data-loading APIs by using a training dataset stored locally and simulating the sampling of examples as if the dataset were partitioned across distinct remote clients. As you adapt this tutorial to your use case, you can simply swap that dataset with your own distributed data.

If you are new to federated learning or TFF, consider reading Federated Learning for Image Classification for a primer.


Before we start

Before we start, please run the following to make sure that your environment is correctly set up. For more information, refer to the Installation guide.

Set up the open-source environment
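
The notebook cell below is a minimal sketch of the setup step assumed throughout this tutorial; the exact release you need may differ, so check the Installation guide.

!pip install --quiet --upgrade tensorflow-federated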

Import packages
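
The code in this tutorial relies on the imports below; this is a minimal sketch covering exactly the modules the snippets use.

import collections
import random
from typing import Any, List, Optional, Sequence

import tensorflow as tf
import tensorflow_federated as tff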

Preparing the input data

First, load TFF's federated version of the EMNIST dataset from the built-in repository:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

And construct a preprocessing function to transform the raw examples in the EMNIST dataset.

NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100


def preprocess(dataset):

  def map_fn(element):
    # Rename the features from `pixels` and `label`, to `x` and `y` for use with
    # Keras.
    return collections.OrderedDict(
        # Transform each `28x28` image into a `784`-element array.
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  # Shuffle the individual examples and `repeat` over several epochs.
  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)

Let's verify this works:

# The local dataset corresponding to a single client as tf.data.Dataset.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

preprocessed_example_dataset = preprocess(example_dataset)
print(preprocessed_example_dataset)
<MapDataset element_spec=OrderedDict([('x', TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))])>

Next, we'll construct an implementation of DataBackend that loads and preprocesses the local examples of a client in the EMNIST dataset, which is crucial for fetching trainable examples during federated learning.

Defining how to fetch client data

We need an instance of DataBackend to instruct the TFF workers how to load and transform the local data.

A TFF worker is a process that runs on an edge machine and performs the work for one or more logical clients. In this example, the EMNIST dataset we'll use for training is already partitioned by logical clients, and all the workers will run in the same local environment, so our DataBackend can reference the data corresponding to any client. In a non-experimental setting, however, the TFF workers will be distributed across individual remote machines, each mapping to a distinct set of clients, and you need to ensure that the DataBackend can correctly resolve data references according to its local context.

# A `DataBackend` is a programmatic construct that resolves symbolic references,
# represented as application-specific URIs, to materialized examples that
# TFF operations can process.
class MyDataBackend(tff.framework.DataBackend):

  async def materialize(self, data, type_spec):
    # In this example, the URI contains the Id of a client.
    client_id = int(data.uri[-1])
    # The client Id is used to retrieve the corresponding local data.
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[client_id])
    # We process the client dataset before returning so it's compatible with our
    # model definitions.
    return preprocess(client_dataset)
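
To smoke-test the backend in isolation, you can call materialize directly. The _FakeDataReference class below is a hypothetical stand-in for the reference object TFF passes in at runtime; our materialize only reads data.uri and ignores type_spec, so we pass None:

import asyncio


# Hypothetical stand-in for the data reference TFF hands to `materialize`.
class _FakeDataReference:

  def __init__(self, uri):
    self.uri = uri


backend = MyDataBackend()
# Our `materialize` only inspects `data.uri`, so passing `type_spec=None` is
# fine for this check.
dataset = asyncio.run(backend.materialize(_FakeDataReference('uri://0'), None))
print(dataset.element_spec)

In a notebook with an already-running event loop, you would await the coroutine instead of calling asyncio.run.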

Setting up the runtime environment

TFF computations are invoked by an ExecutionContext. In order for data URIs defined in TFF computations to be understood at runtime, a custom context must be defined for the TFF workers that includes a pointer to the DataBackend we just created, so the URIs can be properly resolved.

def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
  # A `DataBackend` object is wrapped by a `DataExecutor`, which queries the
  # backend when a TFF worker encounters an operation that requires fetching
  # local data.
  return tff.framework.DataExecutor(
      tff.framework.EagerTFExecutor(device), data_backend=MyDataBackend())


# In a distributed setting, this needs to run in the TFF worker as a service
# connecting to some port. The top-level controller feeding TFF computations
# would then connect to this port.
factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)
ctx = tff.framework.SyncExecutionContext(executor_fn=factory)
tff.framework.set_default_context(ctx)

Training the model

Now we are ready to train a model using federated learning. Let's define a Keras model:

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])


def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

We can pass this TFF-wrapped model definition to the Federated Averaging algorithm by invoking the helper function tff.learning.algorithms.build_weighted_fed_avg, as follows:

iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

state = iterative_process.initialize()

The initialize computation returns the initial state of the Federated Averaging process.
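
To sanity-check that state, we can pull out the initial model weights. This sketch assumes the LearningProcess returned by build_weighted_fed_avg, which exposes the get_model_weights helper:

# Extract the `ModelWeights` structure from the server state and inspect the
# shapes of the trainable variables (a 784x10 kernel and a 10-element bias).
initial_weights = iterative_process.get_model_weights(state)
print([w.shape for w in initial_weights.trainable])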

To run a round of training, we need to construct a sample of data by assembling a sample of URI references, as follows:

NUM_CLIENTS = 10

element_type = tff.types.StructWithPythonType(
    preprocessed_example_dataset.element_spec,
    container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

round_data_uris = [f'uri://{i}' for i in range(NUM_CLIENTS)]
round_train_data = tff.framework.CreateDataDescriptor(
    arg_uris=round_data_uris, arg_type=dataset_type)

Now we can run a round of training:

result = iterative_process.next(state, round_train_data)
state = result.state
metrics = result.metrics
print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.11234568), ('loss', 11.965633), ('num_examples', 4860), ('num_batches', 4860)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
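The metrics come back as nested OrderedDicts that mirror the composition of the algorithm (distributor, client_work, aggregator, finalizer), so you can drill into the training metrics directly:

train_metrics = metrics['client_work']['train']
print('accuracy={}, loss={}'.format(
    train_metrics['sparse_categorical_accuracy'], train_metrics['loss']))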

Training over multiple rounds

We can define a FederatedDataSource container for selecting clients and assembling the inputs that fetch local data. This makes it convenient to loop over multiple rounds of training, and it can be reused across multiple training jobs.

class MyFederatedDataSourceIterator(tff.program.FederatedDataSourceIterator):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

  def select(self, num_clients: Optional[int] = None) -> Any:
    client_ids_sample = random.sample(self._client_ids, num_clients)
    data_uris = [f'uri://{i}' for i in client_ids_sample]
    return tff.framework.CreateDataDescriptor(
        arg_uris=data_uris, arg_type=self._federated_type)


class MyFederatedDataSource(tff.program.FederatedDataSource):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type
    self._capabilities = [tff.program.Capability.RANDOM_UNIFORM]

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

  @property
  def capabilities(self) -> List[tff.program.Capability]:
    return self._capabilities

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    return MyFederatedDataSourceIterator(self._client_ids, self._federated_type)


train_data_source = MyFederatedDataSource(
    client_ids=emnist_train.client_ids, federated_type=dataset_type)
train_data_iterator = train_data_source.iterator()

Now we can run a federated training loop like this:

NUM_ROUNDS = 10

for round_num in range(2, NUM_ROUNDS + 1):
  round_train_data = train_data_iterator.select(NUM_CLIENTS)
  result = iterative_process.next(state, round_train_data)
  state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12357217), ('loss', 9.161968), ('num_examples', 4815), ('num_batches', 4815)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20563674), ('loss', 7.0862083), ('num_examples', 4790), ('num_batches', 4790)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.30241227), ('loss', 5.6945825), ('num_examples', 4560), ('num_batches', 4560)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3867347), ('loss', 4.7210026), ('num_examples', 4900), ('num_batches', 4900)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.42311886), ('loss', 4.205554), ('num_examples', 4585), ('num_batches', 4585)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.4501548), ('loss', 4.1297464), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.56590474), ('loss', 2.8927681), ('num_examples', 5250), ('num_batches', 5250)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.59917355), ('loss', 2.7431731), ('num_examples', 4840), ('num_batches', 4840)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.5717234), ('loss', 2.9738288), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
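
After training, you can copy the learned weights back into a standalone Keras model for local evaluation or serving. This is a minimal sketch, assuming the assign_weights_to helper on the ModelWeights structure returned by get_model_weights:

# Copy the trained server weights into a fresh Keras model.
final_model = create_keras_model()
iterative_process.get_model_weights(state).assign_weights_to(final_model)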

Conclusion

This concludes the tutorial. We encourage you to explore the other tutorials we've developed to learn about the many other features of the TFF framework.