在本教學課程中,您將探索範例網頁應用程式,示範如何使用 TensorFlow.js Layers API 進行遷移學習。此範例會載入預先訓練的模型,然後在瀏覽器中重新訓練模型。
此模型已在 Python 中針對 MNIST 數字分類資料集的數字 0-4 進行預先訓練。瀏覽器中的重新訓練 (或遷移學習) 使用數字 5-9。此範例顯示,預先訓練模型的前幾層可用於在遷移學習期間從新資料中擷取特徵,進而加快新資料的訓練速度。
本教學課程的範例應用程式可線上取得,因此您無需下載任何程式碼或設定開發環境。如果您想在本機執行程式碼,請完成在本機執行範例中的選用步驟。如果您不想設定開發環境,可以直接跳到探索範例。
範例程式碼可在 GitHub 上取得。
(選用) 在本機執行範例
先決條件
若要在本機執行範例應用程式,您的開發環境中需要安裝下列項目
安裝並執行範例應用程式
- 複製或下載
tfjs-examples
存放區。 變更至
mnist-transfer-cnn
目錄cd tfjs-examples/mnist-transfer-cnn
安裝依附元件
yarn
啟動開發伺服器
yarn run watch
探索範例
開啟範例應用程式。(或者,如果您在本機執行範例,請前往瀏覽器中的 https://127.0.0.1:1234
。)
您應該會看到標題為「MNIST CNN 遷移學習」的頁面。按照指示試用應用程式。
以下是一些您可以嘗試的事項
- 試驗不同的訓練模式,並比較損失和準確度。
- 選取不同的點陣圖範例,並檢查分類機率。請注意,每個點陣圖範例中的數字都是灰階整數值,代表圖片中的像素。
- 編輯點陣圖整數值,看看變更如何影響分類機率。
探索程式碼
範例網頁應用程式會載入已針對 MNIST 資料集子集進行預先訓練的模型。預先訓練是在 Python 程式中定義:mnist_transfer_cnn.py
。Python 程式超出本教學課程的範圍,但如果您想查看模型轉換範例,則值得一看。
index.js
檔案包含示範的大部分訓練程式碼。當 index.js
在瀏覽器中執行時,設定函式 setupMnistTransferCNN
會將 MnistTransferCNNPredictor
具現化並初始化,後者會封裝重新訓練和預測常式。
初始化方法 MnistTransferCNNPredictor.init
會載入模型、載入重新訓練資料,並建立測試資料。以下是載入模型的程式碼行
this.model = await loader.loadHostedPretrainedModel(urls.model);
如果您查看 loader.loadHostedPretrainedModel
的定義,您會看到它會傳回呼叫 tf.loadLayersModel
的結果。這是用於載入由 Layer 物件組成的模型的 TensorFlow.js API。
重新訓練邏輯在 MnistTransferCNNPredictor.retrainModel
中定義。如果使用者已選取「凍結特徵層」作為訓練模式,則基本模型的前 7 層會凍結,只有最後 5 層會在新資料上訓練。如果使用者已選取「重新初始化權重」,則所有權重都會重設,且應用程式實際上會從頭開始訓練模型。
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
接著會編譯模型,然後使用 model.fit()
在測試資料上訓練模型
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
若要進一步瞭解 model.fit()
參數,請參閱 API 文件。
在新的資料集 (數字 5-9) 上訓練後,模型即可用於進行預測。MnistTransferCNNPredictor.predict
方法使用 model.predict()
執行此操作
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
請注意 tf.tidy
的用法,這有助於防止記憶體洩漏。
瞭解詳情
本教學課程探索了範例應用程式,示範如何在瀏覽器中使用 TensorFlow.js 執行遷移學習。查看以下資源以進一步瞭解預先訓練模型和遷移學習。
TensorFlow.js
TensorFlow Core