將不同資料傳送給特定用戶端與 tff.federated_select

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

本教學課程示範如何在 TFF 中實作自訂聯邦演算法,這些演算法需要將不同資料傳送給不同用戶端。您可能已熟悉 tff.federated_broadcast,其會將單一伺服器放置的值傳送給所有用戶端。本教學課程著重於將伺服器型值 (server-based value) 的不同部分傳送給不同用戶端的案例。這對於在不同用戶端之間劃分模型各部分,以避免將整個模型傳送給任何單一用戶端可能很有用。

讓我們開始匯入 tensorflowtensorflow_federated

pip install --quiet --upgrade tensorflow-federated
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

根據用戶端資料傳送不同值

考慮我們有一些伺服器放置的清單,我們想要根據一些用戶端放置的資料,將一些元素傳送給每個用戶端的情況。例如,伺服器上的字串清單,以及用戶端上的逗號分隔索引清單以下載。我們可以實作如下

list_of_strings_type = tff.TensorType(np.str_, [None])
# We only ever send exactly two values to each client. The number of keys per
# client must be a fixed number across all clients.
number_of_keys_per_client = 2
keys_type = tff.TensorType(np.int32, [number_of_keys_per_client])
get_size = tff.tensorflow.computation(lambda x: tf.size(x))
select_fn = tff.tensorflow.computation(lambda val, index: tf.gather(val, index))
client_data_type = np.str_

# A function from our client data to the indices of the values we'd like to
# select from the server.
@tff.tensorflow.computation(client_data_type)
def keys_for_client(client_string):
  # We assume our client data is a single string consisting of exactly three
  # comma-separated integers indicating which values to grab from the server.
  split = tf.strings.split([client_string], sep=',')[0]
  return tf.strings.to_number([split[0], split[1]], tf.int32)

@tff.tensorflow.computation(tff.SequenceType(np.str_))
def concatenate(values):
  def reduce_fn(acc, item):
    return tf.cond(tf.math.equal(acc, ''),
                   lambda: item,
                   lambda: tf.strings.join([acc, item], ','))
  return values.reduce('', reduce_fn)

@tff.federated_computation(tff.FederatedType(list_of_strings_type, tff.SERVER), tff.FederatedType(client_data_type, tff.CLIENTS))
def broadcast_based_on_client_data(list_of_strings_at_server, client_data):
  keys_at_clients = tff.federated_map(keys_for_client, client_data)
  max_key = tff.federated_map(get_size, list_of_strings_at_server)
  values_at_clients = tff.federated_select(keys_at_clients, max_key, list_of_strings_at_server, select_fn)
  value_at_clients = tff.federated_map(concatenate, values_at_clients)
  return value_at_clients

然後,我們可以透過提供伺服器放置的字串清單以及每個用戶端的字串資料來模擬我們的運算

client_data = ['0,1', '1,2', '2,0']
broadcast_based_on_client_data(['a', 'b', 'c'], client_data)
[<tf.Tensor: shape=(), dtype=string, numpy=b'a,b'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'b,c'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'c,a'>]

將隨機元素傳送給每個用戶端

或者,將伺服器資料的隨機部分傳送給每個用戶端可能很有用。我們可以透過先在每個用戶端上產生隨機金鑰,然後遵循與上述使用的選擇程序類似的程序來實作該功能

@tff.tensorflow.computation(np.int32)
def get_random_key(max_key):
  return tf.random.uniform(shape=[1], minval=0, maxval=max_key, dtype=tf.int32)

list_of_strings_type = tff.TensorType(np.str_, [None])
get_size = tff.tensorflow.computation(lambda x: tf.size(x))
select_fn = tff.tensorflow.computation(lambda val, index: tf.gather(val, index))

@tff.tensorflow.computation(tff.SequenceType(np.str_))
def get_last_element(sequence):
  return sequence.reduce('', lambda _initial_state, val: val)

@tff.federated_computation(tff.FederatedType(list_of_strings_type, tff.SERVER))
def broadcast_random_element(list_of_strings_at_server):
  max_key_at_server = tff.federated_map(get_size, list_of_strings_at_server)
  max_key_at_clients = tff.federated_broadcast(max_key_at_server)
  key_at_clients = tff.federated_map(get_random_key, max_key_at_clients)
  random_string_sequence_at_clients = tff.federated_select(
      key_at_clients, max_key_at_server, list_of_strings_at_server, select_fn)
  # Even though we only passed in a single key, `federated_select` returns a
  # sequence for each client. We only care about the last (and only) element.
  random_string_at_clients = tff.federated_map(get_last_element, random_string_sequence_at_clients)
  return random_string_at_clients

由於我們的 broadcast_random_element 函式不接收任何用戶端放置的資料,因此我們必須使用預設用戶端數量來設定 TFF 模擬執行階段

tff.backends.native.set_sync_local_cpp_execution_context(default_num_clients=3)

然後我們可以模擬選擇。我們可以變更上方的 default_num_clients 和下方的字串清單以產生不同的結果,或者只需重新執行運算即可產生不同的隨機輸出。

broadcast_random_element(tf.convert_to_tensor(['foo', 'bar', 'baz']))