聯邦學習

總覽

本文件介紹了有助於執行聯邦學習工作的介面,例如使用 TensorFlow 中現有機 learning 模型進行聯邦訓練或評估。在設計這些介面時,我們的首要目標是讓使用者能夠在不必瞭解其底層運作方式的情況下,就能試驗聯邦學習,並針對各種現有模型和資料評估已實作的聯邦學習演算法。我們鼓勵您為平台貢獻。TFF 在設計時已將擴充性和組合性納入考量,我們歡迎各界踴躍投稿;我們很期待看到您的成果!

這個層級提供的介面包含以下三個主要部分

  • 模型。類別和輔助函式可讓您包裝現有模型,以便搭配 TFF 使用。包裝模型可以簡單到只須呼叫單一包裝函式 (例如,tff.learning.models.from_keras_model),或是定義 tff.learning.models.VariableModel 介面的子類別,以獲得完整的自訂彈性。

  • 聯邦運算建構器。輔助函式會使用您的現有模型,建構用於訓練或評估的聯邦運算。

  • 資料集。您可以下載並在 Python 中存取的資料集,用於模擬聯邦學習情境。雖然聯邦學習的設計用途是搭配無法在集中位置輕鬆下載的去中心化資料,但在研究和開發階段,使用可下載並在本機操作的資料進行初步實驗通常會很方便,對於剛接觸這種方法的開發人員而言尤其如此。

這些介面主要在 tff.learning 命名空間中定義,研究資料集和其他模擬相關功能除外,這些功能已歸類在 tff.simulation 中。這個層級是使用 Federated Core (FC) 提供的較低層級介面實作而成,而 Federated Core (FC) 也提供執行階段環境。

在繼續之前,我們建議您先查看關於圖片分類文字產生的教學課程,因為這些課程使用具體範例介紹了此處描述的大部分概念。如果您有興趣進一步瞭解 TFF 的運作方式,建議您瀏覽自訂演算法教學課程,瞭解我們用來表達聯邦運算邏輯的較低層級介面,並研究 tff.learning 介面的現有實作方式。

模型

架構假設

序列化

TFF 的目標是支援各種分散式學習情境,在這些情境中,您編寫的機器學習模型程式碼可能會在大量功能各異的異質用戶端上執行。雖然在某些應用程式中,這些用戶端可能是功能強大的資料庫伺服器,但我們的平台想要支援的許多重要用途都涉及資源有限的行動裝置和嵌入式裝置。我們無法假設這些裝置能夠託管 Python 執行階段;此時我們唯一可以假設的是,它們能夠託管本機 TensorFlow 執行階段。因此,我們在 TFF 中做出的一項基本架構假設是,您的模型程式碼必須可序列化為 TensorFlow 圖表。

您可以 (而且應該) 仍然遵循最新的最佳做法 (例如使用 Eager 模式) 開發 TF 程式碼。但是,最終程式碼必須可序列化 (例如,可以包裝為適用於 Eager 模式程式碼的 tf.function)。這可確保執行階段所需的任何 Python 狀態或控制流程都可以序列化 (可能需要 Autograph 的協助)。

目前,TensorFlow 尚未完全支援序列化和還原序列化 Eager 模式 TensorFlow。因此,TFF 中的序列化目前遵循 TF 1.0 模式,其中所有程式碼都必須在 TFF 控制的 tf.Graph 內建構。這表示 TFF 目前無法取用已建構的模型;相反地,模型定義邏輯會封裝在不含引數的函式中,該函式會傳回 tff.learning.models.VariableModel。然後 TFF 會呼叫此函式,以確保模型的所有組件都已序列化。此外,由於 TFF 是強型別環境,因此需要一點額外的中繼資料,例如模型輸入類型的規格。

彙總

我們強烈建議大多數使用者使用 Keras 建構模型,請參閱下方的「Keras 轉換器」一節。這些包裝函式會自動處理模型更新的彙總,以及為模型定義的任何指標。但是,瞭解一般 tff.learning.models.VariableModel 的彙總處理方式可能仍然很有用。

聯邦學習中始終至少有兩個彙總層級:本機裝置端彙總,以及跨裝置 (或聯邦式) 彙總

  • **本機彙總**。此層級的彙總是指跨個別用戶端擁有的多個範例批次的彙總。它適用於模型參數 (變數),這些參數會隨著模型在本機訓練而持續循序演進,也適用於您計算的統計資料 (例如平均損失、準確度和其他指標),您的模型會在疊代處理每個個別用戶端的本機資料串流時再次在本機更新這些統計資料。

    在此層級執行彙總是模型程式碼的責任,並且是使用標準 TensorFlow 建構完成的。

    一般處理結構如下:

    • 模型首先建構 tf.Variable 以保留彙總值,例如批次數量或已處理範例的數量、每批次或每範例損失的總和等等。

    • TFF 會在您的 Model 上多次叫用 forward_pass 方法,依序處理後續的用戶端資料批次,這可讓您更新保存各種彙總值的變數,作為副作用。

    • 最後,TFF 會在您的 Model 上叫用 report_local_unfinalized_metrics 方法,讓您的模型能夠將其收集的所有摘要統計資料編譯成一組精簡的指標,以供用戶端匯出。在此情況下,您的模型程式碼可能會將損失總和除以已處理範例的數量,以匯出平均損失等等。

  • **聯邦式彙總**。此層級的彙總是指跨系統中多個用戶端 (裝置) 的彙總。同樣地,它適用於模型參數 (變數),這些參數會在用戶端之間取平均值,也適用於您的模型因本機彙總而匯出的指標。

    在此層級執行彙總是 TFF 的責任。但是,身為模型建立者,您可以控制此程序 (詳情請見下文)。

    一般處理結構如下:

    • 初始模型,以及訓練所需的任何參數,都會由伺服器分配給將參與一輪訓練或評估的用戶端子集。

    • 在每個用戶端上,獨立且平行地,您的模型程式碼會在本機資料批次串流上重複叫用,以產生一組新的模型參數 (在訓練時) 和一組新的本機指標,如上所述 (這是本機彙總)。

    • TFF 會執行分散式彙總協定,以累積和彙總整個系統中的模型參數和本機匯出指標。此邏輯是以宣告方式使用 TFF 自己的聯邦運算語言 (而非 TensorFlow) 表達。如需彙總 API 的詳細資訊,請參閱自訂演算法教學課程。

抽象介面

這個基本的「建構函式 + 中繼資料」介面是由 tff.learning.models.VariableModel 介面表示,如下所示:

  • 建構函式、forward_passreport_local_unfinalized_metrics 方法應分別建構模型變數、正向傳遞和您想要報告的統計資料。這些方法建構的 TensorFlow 必須可序列化,如上所述。

  • input_spec 屬性,以及傳回可訓練、不可訓練和本機變數子集的 3 個屬性,代表中繼資料。TFF 會使用此資訊來判斷如何將模型的部分連線至聯邦最佳化演算法,以及定義內部類型簽章,以協助驗證已建構系統的正確性 (以便您的模型無法在與模型設計用途不符的資料上具現化)。

此外,抽象介面 tff.learning.models.VariableModel 會公開 metric_finalizers 屬性,該屬性會接收指標的未完成值 (由 report_local_unfinalized_metrics() 傳回) 並傳回已完成的指標值。metric_finalizersreport_local_unfinalized_metrics() 方法將一起用於建構跨用戶端指標彙總器,以在定義聯邦訓練程序或評估運算時使用。例如,簡單的 tff.learning.metrics.sum_then_finalize 彙總器會先將用戶端的未完成指標值加總,然後在伺服器端呼叫完成器函式。

您可以在我們的圖片分類教學課程的第二部分,以及我們在 model_examples.py 中用於測試的範例模型中,找到如何定義您自己的自訂 tff.learning.models.VariableModel 的範例。

Keras 轉換器

TFF 需要的幾乎所有資訊都可以透過呼叫 tf.keras 介面取得,因此如果您有 Keras 模型,可以仰賴 tff.learning.models.from_keras_model 建構 tff.learning.models.VariableModel

請注意,TFF 仍然希望您提供建構函式 - 不含引數的模型函式,例如下列範例:

def model_fn():
  keras_model = ...
  return tff.learning.models.from_keras_model(keras_model, sample_batch, loss=...)

除了模型本身之外,您還提供一個範例資料批次,TFF 會使用該批次來判斷模型輸入的類型和形狀。這可確保 TFF 可以針對實際存在於用戶端裝置上的資料正確地具現化模型 (因為我們假設在您建構要序列化的 TensorFlow 時,此資料通常無法使用)。

我們的圖片分類文字產生教學課程說明了 Keras 包裝函式的使用方式。

聯邦運算建構器

tff.learning 套件為執行學習相關工作的 tff.Computation 提供多個建構器;我們預期這類運算的集合在未來會擴充。

架構假設

執行

執行聯邦運算有兩個不同的階段。

  • 編譯:TFF 首先將聯邦學習演算法編譯成整個分散式運算的抽象序列化表示法。這是 TensorFlow 序列化發生的時間,但可能會發生其他轉換以支援更有效率的執行。我們將編譯器發出的序列化表示法稱為聯邦運算

  • 執行 TFF 提供執行這些運算的方法。目前,僅透過本機模擬 (例如,在筆記本中使用模擬的去中心化資料) 支援執行。

由 TFF 的 Federated Learning API 產生的聯邦運算 (例如使用聯邦模型平均的訓練演算法,或聯邦評估) 包含許多元素,最值得注意的是:

  • 模型程式碼的序列化形式,以及 Federated Learning 框架建構的其他 TensorFlow 程式碼,用於驅動模型的訓練/評估迴圈 (例如建構最佳化工具、套用模型更新、疊代處理 tf.data.Dataset 和計算指標,以及在伺服器上套用彙總更新,僅舉幾例)。

  • 用戶端和伺服器之間通訊的宣告式規格 (通常是跨用戶端裝置的各種彙總形式,以及從伺服器廣播到所有用戶端),以及這種分散式通訊如何與用戶端本機或伺服器本機 TensorFlow 程式碼的執行交錯。

以這種序列化形式表示的聯邦運算,是以獨立於平台的內部語言 (與 Python 不同) 表達,但若要使用 Federated Learning API,您不需要擔心此表示法的詳細資訊。運算在您的 Python 程式碼中表示為 tff.Computation 類型的物件,在大多數情況下,您可以將其視為不透明的 Python 可呼叫項。

在教學課程中,您將叫用這些聯邦運算,就好像它們是要在本機執行的常規 Python 函式一樣。但是,TFF 的設計目的是以與執行環境的大多數方面無關的方式來表達聯邦運算,以便它們可以部署到 (例如) 執行 Android 的裝置群組,或資料中心的叢集。同樣地,這項設計的主要後果是對序列化的強烈假設。特別是,當您叫用以下描述的 build_... 方法之一時,運算會完全序列化。

狀態建模

TFF 是函數式程式設計環境,但聯邦學習中許多感興趣的程序都是有狀態的。例如,涉及多輪聯邦模型平均的訓練迴圈是我們可以歸類為有狀態程序的範例。在此程序中,從一輪到另一輪演進的狀態包括正在訓練的模型參數集,以及可能與最佳化工具相關聯的其他狀態 (例如,動量向量)。

由於 TFF 是函數式的,因此有狀態程序在 TFF 中被建模為運算,這些運算接受目前狀態作為輸入,然後提供更新後的狀態作為輸出。為了完整定義有狀態程序,還需要指定初始狀態的來源 (否則我們無法啟動程序)。這在輔助類別 tff.templates.IterativeProcess 的定義中擷取,其中 initializenext 這兩個屬性分別對應於初始化和疊代。

可用的建構器

目前,TFF 提供各種建構器函式,用於產生聯邦訓練和評估的聯邦運算。兩個值得注意的範例包括:

資料集

架構假設

用戶端選取

在典型的聯邦學習情境中,我們有龐大的用戶端裝置群體,可能多達數億台,但其中只有一小部分可能在任何給定時刻處於活動狀態且可用於訓練 (例如,這可能僅限於已插入電源、未使用計量網路且處於閒置狀態的用戶端)。一般而言,可用於參與訓練或評估的用戶端集合不在開發人員的控制範圍內。此外,由於協調數百萬個用戶端是不切實際的,因此典型的訓練或評估輪次只會包含一小部分可用用戶端,這些用戶端可能會隨機取樣。

這項設計的主要後果是,聯邦運算在設計上是以忽略確切參與者集合的方式表達的;所有處理都表示為對匿名用戶端抽象群組的彙總運算,而該群組可能會因訓練輪次而異。因此,運算與具體參與者 (以及他們饋送到運算中的具體資料) 的實際繫結是在運算本身之外建模的。

為了模擬聯邦學習程式碼的實際部署,您通常會編寫如下所示的訓練迴圈:

trainer = tff.learning.algorithms.build_weighted_fed_avg(...)
state = trainer.initialize()
federated_training_data = ...

def sample(federate_data):
  return ...

while True:
  data_for_this_round = sample(federated_training_data)
  result = trainer.next(state, data_for_this_round)
  state = result.state

為了方便起見,在模擬中使用 TFF 時,聯邦資料會以 Python 列表形式接受,每個參與的用戶端裝置一個元素,以表示該裝置的本機 tf.data.Dataset

抽象介面

為了標準化處理模擬的聯邦資料集,TFF 提供抽象介面 tff.simulation.datasets.ClientData,該介面可讓使用者列舉用戶端集合,並建構包含特定用戶端資料的 tf.data.Dataset。這些 tf.data.Dataset 可以直接作為輸入饋送到在 Eager 模式下產生的聯邦運算。

應注意的是,存取用戶端身分的能力是資料集僅為模擬用途提供的功能,在模擬中,可能需要能夠訓練來自特定用戶端子集的資料 (例如,模擬不同類型用戶端的日間可用性)。編譯後的運算和底層執行階段不涉及任何用戶端身分的概念。一旦從特定用戶端子集選取的資料作為輸入 (例如,在呼叫 tff.templates.IterativeProcess.next 時),用戶端身分就不會再出現在其中。

可用的資料集

我們已將 tff.simulation.datasets 命名空間專用於實作 tff.simulation.datasets.ClientData 介面的資料集,以便用於模擬,並使用資料集為其植入種子,以支援圖片分類文字產生教學課程。我們希望鼓勵您將自己的資料集貢獻到平台。