使用 Web Worker 訓練模型

在本教學課程中,您將探索一個範例 Web 應用程式,該應用程式使用 Web Worker 訓練循環神經網路 (RNN) 執行整數加法。此範例應用程式未明確定義加法運算子。而是使用範例總和訓練 RNN。

當然,這不是加總兩個整數最有效率的方法!但本教學課程示範了 Web ML 中的一項重要技術:如何在不封鎖處理 UI 邏輯的主要執行緒的情況下,執行長時間執行的運算。

本教學課程的範例應用程式可線上取得,因此您不需要下載任何程式碼或設定開發環境。如果您想在本機執行程式碼,請完成在本機執行範例中的選用步驟。如果您不想設定開發環境,可以跳至探索範例

範例程式碼可在 GitHub 上取得。

(選用) 在本機執行範例

先決條件

若要在本機執行範例應用程式,您的開發環境中需要安裝下列項目

安裝並執行範例應用程式

  1. 複製或下載 tfjs-examples 存放區。
  2. 變更至 addition-rnn-webworker 目錄

    cd tfjs-examples/addition-rnn-webworker
    
  3. 安裝依附元件

    yarn
    
  4. 啟動開發伺服器

    yarn run watch
    

探索範例

開啟範例應用程式。(或者,如果您在本機執行範例,請前往瀏覽器中的 https://127.0.0.1:1234。)

您應該會看到一個標題為TensorFlow.js: Addition RNN的頁面。按照指示試用此應用程式。

使用 Web 表單,您可以更新用於訓練模型的某些參數,包括下列項目

  • 位數:要加總的項目的最大位數。
  • 訓練大小:要產生的訓練範例數量。
  • RNN 類型SimpleRNNGRULSTM 其中之一。
  • RNN 隱藏層大小:輸出空間的維度 (必須為正整數)。
  • 批次大小:每次梯度更新的樣本數。
  • 訓練迭代次數:透過叫用 model.fit() 訓練模型的次數
  • 測試範例數:要產生的範例字串 (例如 27+41) 數量。

嘗試使用不同的參數訓練模型,並查看是否可以提高各種位數集合的預測準確性。另請注意模型擬合時間如何受到不同參數的影響。

探索程式碼

範例應用程式示範了您可以針對訓練 RNN 設定的某些參數。它也示範了如何使用 Web Worker 在主要執行緒之外訓練模型。Web Worker 在 Web ML 中很重要,因為它們可讓您在背景執行緒上執行運算密集型訓練工作,從而避免主要執行緒上可能影響使用者效能的問題。主要執行緒和 Worker 執行緒透過訊息事件彼此通訊。

若要進一步瞭解 Web Worker,請參閱 Web Workers APIUsing Web Workers

範例應用程式的主要模組是 index.jsindex.js 指令碼會建立一個 Web Worker,以執行 worker.js 模組

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js 主要由單一函式 runAdditionRNNDemo 組成,該函式處理表單提交、處理表單資料、將表單資料傳遞至 Worker、等待 Worker 訓練模型並傳回結果,然後在頁面上顯示結果。

若要將表單資料傳送至 Worker,指令碼會在 Worker 上叫用 postMessage

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

Worker 會接聽此訊息,並將表單資料傳遞至準備資料並開始訓練的函式

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

在訓練期間,Worker 可以傳送兩種不同的訊息類型,其中一種訊息類型的 isPredict 設定為 true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

另一種訊息類型的 isPredict 設定為 false

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

當 UI 執行緒 (index.js) 處理訊息事件時,它會檢查 isPredict 旗標,以判斷從 Worker 傳回的資料形狀。如果 isPredict 為 true,則資料應代表預測,且指令碼會使用 tfjs-vis 更新頁面。如果 isPredict 為 false,則指令碼會執行程式碼區塊,假設資料代表範例。它會將資料包裝在 HTML 中,並將 HTML 插入頁面。

接下來的步驟

本教學課程提供了一個使用 Web Worker 避免長時間執行的訓練程序封鎖 UI 執行緒的範例。若要進一步瞭解在背景執行緒上執行耗時運算的優點,請參閱使用 Web Worker 在瀏覽器主要執行緒之外執行 JavaScript

若要進一步瞭解如何訓練 TensorFlow.js 模型,請參閱訓練模型