在本教學課程中,您將探索一個範例 Web 應用程式,該應用程式使用 Web Worker 訓練循環神經網路 (RNN) 執行整數加法。此範例應用程式未明確定義加法運算子。而是使用範例總和訓練 RNN。
當然,這不是加總兩個整數最有效率的方法!但本教學課程示範了 Web ML 中的一項重要技術:如何在不封鎖處理 UI 邏輯的主要執行緒的情況下,執行長時間執行的運算。
本教學課程的範例應用程式可線上取得,因此您不需要下載任何程式碼或設定開發環境。如果您想在本機執行程式碼,請完成在本機執行範例中的選用步驟。如果您不想設定開發環境,可以跳至探索範例。
範例程式碼可在 GitHub 上取得。
(選用) 在本機執行範例
先決條件
若要在本機執行範例應用程式,您的開發環境中需要安裝下列項目
安裝並執行範例應用程式
- 複製或下載
tfjs-examples
存放區。 變更至
addition-rnn-webworker
目錄cd tfjs-examples/addition-rnn-webworker
安裝依附元件
yarn
啟動開發伺服器
yarn run watch
探索範例
開啟範例應用程式。(或者,如果您在本機執行範例,請前往瀏覽器中的 https://127.0.0.1:1234
。)
您應該會看到一個標題為TensorFlow.js: Addition RNN的頁面。按照指示試用此應用程式。
使用 Web 表單,您可以更新用於訓練模型的某些參數,包括下列項目
- 位數:要加總的項目的最大位數。
- 訓練大小:要產生的訓練範例數量。
- RNN 類型:SimpleRNN、GRU 或 LSTM 其中之一。
- RNN 隱藏層大小:輸出空間的維度 (必須為正整數)。
- 批次大小:每次梯度更新的樣本數。
- 訓練迭代次數:透過叫用
model.fit()
訓練模型的次數 - 測試範例數:要產生的範例字串 (例如
27+41
) 數量。
嘗試使用不同的參數訓練模型,並查看是否可以提高各種位數集合的預測準確性。另請注意模型擬合時間如何受到不同參數的影響。
探索程式碼
範例應用程式示範了您可以針對訓練 RNN 設定的某些參數。它也示範了如何使用 Web Worker 在主要執行緒之外訓練模型。Web Worker 在 Web ML 中很重要,因為它們可讓您在背景執行緒上執行運算密集型訓練工作,從而避免主要執行緒上可能影響使用者效能的問題。主要執行緒和 Worker 執行緒透過訊息事件彼此通訊。
若要進一步瞭解 Web Worker,請參閱 Web Workers API 和 Using Web Workers。
範例應用程式的主要模組是 index.js
。index.js
指令碼會建立一個 Web Worker,以執行 worker.js
模組
const worker =
new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});
index.js
主要由單一函式 runAdditionRNNDemo
組成,該函式處理表單提交、處理表單資料、將表單資料傳遞至 Worker、等待 Worker 訓練模型並傳回結果,然後在頁面上顯示結果。
若要將表單資料傳送至 Worker,指令碼會在 Worker 上叫用 postMessage
worker.postMessage({
digits,
trainingSize,
rnnType,
layers,
hiddenSize,
trainIterations,
batchSize,
numTestExamples
});
Worker 會接聽此訊息,並將表單資料傳遞至準備資料並開始訓練的函式
self.addEventListener('message', async (e) => {
const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
await demo.train(trainIterations, batchSize, numTestExamples);
})
在訓練期間,Worker 可以傳送兩種不同的訊息類型,其中一種訊息類型的 isPredict
設定為 true
self.postMessage({
isPredict: true,
i, iterations, modelFitTime,
lossValues, accuracyValues,
});
和另一種訊息類型的 isPredict
設定為 false
。
self.postMessage({
isPredict: false,
isCorrect, examples
});
當 UI 執行緒 (index.js
) 處理訊息事件時,它會檢查 isPredict
旗標,以判斷從 Worker 傳回的資料形狀。如果 isPredict
為 true,則資料應代表預測,且指令碼會使用 tfjs-vis
更新頁面。如果 isPredict
為 false,則指令碼會執行程式碼區塊,假設資料代表範例。它會將資料包裝在 HTML 中,並將 HTML 插入頁面。
接下來的步驟
本教學課程提供了一個使用 Web Worker 避免長時間執行的訓練程序封鎖 UI 執行緒的範例。若要進一步瞭解在背景執行緒上執行耗時運算的優點,請參閱使用 Web Worker 在瀏覽器主要執行緒之外執行 JavaScript。
若要進一步瞭解如何訓練 TensorFlow.js 模型,請參閱訓練模型。