使用預先訓練的模型

在本教學課程中,您將探索範例網頁應用程式,示範如何使用 TensorFlow.js Layers API 進行遷移學習。此範例會載入預先訓練的模型,然後在瀏覽器中重新訓練模型。

此模型已在 Python 中針對 MNIST 數字分類資料集的數字 0-4 進行預先訓練。瀏覽器中的重新訓練 (或遷移學習) 使用數字 5-9。此範例顯示,預先訓練模型的前幾層可用於在遷移學習期間從新資料中擷取特徵,進而加快新資料的訓練速度。

本教學課程的範例應用程式可線上取得,因此您無需下載任何程式碼或設定開發環境。如果您想在本機執行程式碼,請完成在本機執行範例中的選用步驟。如果您不想設定開發環境,可以直接跳到探索範例

範例程式碼可在 GitHub 上取得。

(選用) 在本機執行範例

先決條件

若要在本機執行範例應用程式,您的開發環境中需要安裝下列項目

安裝並執行範例應用程式

  1. 複製或下載 tfjs-examples 存放區。
  2. 變更至 mnist-transfer-cnn 目錄

    cd tfjs-examples/mnist-transfer-cnn
    
  3. 安裝依附元件

    yarn
    
  4. 啟動開發伺服器

    yarn run watch
    

探索範例

開啟範例應用程式。(或者,如果您在本機執行範例,請前往瀏覽器中的 https://127.0.0.1:1234。)

您應該會看到標題為「MNIST CNN 遷移學習」的頁面。按照指示試用應用程式。

以下是一些您可以嘗試的事項

  • 試驗不同的訓練模式,並比較損失和準確度。
  • 選取不同的點陣圖範例,並檢查分類機率。請注意,每個點陣圖範例中的數字都是灰階整數值,代表圖片中的像素。
  • 編輯點陣圖整數值,看看變更如何影響分類機率。

探索程式碼

範例網頁應用程式會載入已針對 MNIST 資料集子集進行預先訓練的模型。預先訓練是在 Python 程式中定義:mnist_transfer_cnn.py。Python 程式超出本教學課程的範圍,但如果您想查看模型轉換範例,則值得一看。

index.js 檔案包含示範的大部分訓練程式碼。當 index.js 在瀏覽器中執行時,設定函式 setupMnistTransferCNN 會將 MnistTransferCNNPredictor 具現化並初始化,後者會封裝重新訓練和預測常式。

初始化方法 MnistTransferCNNPredictor.init 會載入模型、載入重新訓練資料,並建立測試資料。以下是載入模型的程式碼行

this.model = await loader.loadHostedPretrainedModel(urls.model);

如果您查看 loader.loadHostedPretrainedModel 的定義,您會看到它會傳回呼叫 tf.loadLayersModel 的結果。這是用於載入由 Layer 物件組成的模型的 TensorFlow.js API。

重新訓練邏輯在 MnistTransferCNNPredictor.retrainModel 中定義。如果使用者已選取「凍結特徵層」作為訓練模式,則基本模型的前 7 層會凍結,只有最後 5 層會在新資料上訓練。如果使用者已選取「重新初始化權重」,則所有權重都會重設,且應用程式實際上會從頭開始訓練模型。

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

接著會編譯模型,然後使用 model.fit() 在測試資料上訓練模型

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

若要進一步瞭解 model.fit() 參數,請參閱 API 文件

在新的資料集 (數字 5-9) 上訓練後,模型即可用於進行預測。MnistTransferCNNPredictor.predict 方法使用 model.predict() 執行此操作

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

請注意 tf.tidy 的用法,這有助於防止記憶體洩漏。

瞭解詳情

本教學課程探索了範例應用程式,示範如何在瀏覽器中使用 TensorFlow.js 執行遷移學習。查看以下資源以進一步瞭解預先訓練模型和遷移學習。

TensorFlow.js

TensorFlow Core