遷移學習與微調

作者: fchollet

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 在 keras.io 上檢視

設定

import numpy as np
import tensorflow as tf
from tensorflow import keras
2023-10-03 11:11:08.160283: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-03 11:11:08.160349: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-03 11:11:08.160404: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

簡介

遷移學習包含取用在一個問題上學到的特徵,並在新的類似問題上加以利用。例如,從已學會識別浣熊的模型取得的特徵,可能對啟動旨在識別貉的模型很有用。

遷移學習通常用於資料集資料太少,無法從頭開始訓練完整模型的任務。

在深度學習的背景下,遷移學習最常見的形式是以下工作流程

  1. 從先前訓練的模型中取得層。
  2. 凍結它們,以避免在未來的訓練回合中破壞它們包含的任何資訊。
  3. 在凍結層的頂部新增一些新的可訓練層。它們將學習將舊特徵轉化為新資料集的預測。
  4. 在您的資料集上訓練新層。

最後一個可選步驟是微調,它包含解凍您在上面取得的整個模型(或部分模型),並以非常低的學習率在新資料上重新訓練它。這可以透過逐步調整預訓練特徵以適應新資料,來潛在地實現有意義的改進。

首先,我們將詳細介紹 Keras 可訓練 API,它是大多數遷移學習和微調工作流程的基礎。

接著,我們將示範典型的工作流程,方法是取用在 ImageNet 資料集上預訓練的模型,並在 Kaggle「貓 vs 狗」分類資料集上重新訓練它。

這改編自《Deep Learning with Python》和 2016 年的部落格文章「building powerful image classification models using very little data」。

凍結層:瞭解 trainable 屬性

層和模型有三個權重屬性

  • weights 是層的所有權重變數的清單。
  • trainable_weights 是那些旨在更新(透過梯度下降)以在訓練期間最小化損失的清單。
  • non_trainable_weights 是那些不打算訓練的清單。通常它們在正向傳遞期間由模型更新。

範例:Dense 層有 2 個可訓練權重(核心和偏差)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0
2023-10-03 11:11:10.677246: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflow.dev.org.tw/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...

一般來說,所有權重都是可訓練權重。唯一具有非可訓練權重的內建層是 BatchNormalization 層。它使用非可訓練權重來追蹤其輸入在訓練期間的平均值和變異數。若要學習如何在您自己的自訂層中使用非可訓練權重,請參閱從頭開始編寫新層的指南

範例:BatchNormalization 層有 2 個可訓練權重和 2 個非可訓練權重

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

層和模型也具有布林屬性 trainable。其值可以變更。將 layer.trainable 設定為 False 會將層的所有權重從可訓練移至非可訓練。這稱為「凍結」層:凍結層的狀態在訓練期間不會更新(無論是使用 fit() 進行訓練,還是使用任何依賴 trainable_weights 來套用梯度更新的自訂迴圈進行訓練)。

範例:將 trainable 設定為 False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

當可訓練權重變成非可訓練時,其值在訓練期間不再更新。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 324ms/step - loss: 0.0178

請勿將 layer.trainable 屬性與 layer.__call__() 中的引數 training 混淆(它控制層是否應以推論模式或訓練模式執行其正向傳遞)。如需更多資訊,請參閱 Keras 常見問題

遞迴設定 trainable 屬性

如果您在模型或任何具有子層的層上設定 trainable = False,則所有子層也會變成非可訓練。

範例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的遷移學習工作流程

這引導我們瞭解如何在 Keras 中實作典型的遷移學習工作流程

  1. 實例化基礎模型並將預訓練權重載入其中。
  2. 透過設定 trainable = False 來凍結基礎模型中的所有層。
  3. 在基礎模型中一個(或多個)層的輸出之上建立新模型。
  4. 在您的新資料集上訓練您的新模型。

請注意,另一種更輕量的工作流程也可能是:

  1. 實例化基礎模型並將預訓練權重載入其中。
  2. 透過它執行您的新資料集,並記錄基礎模型中一個(或多個)層的輸出。這稱為特徵擷取
  3. 將該輸出用作新且較小模型的輸入資料。

第二個工作流程的主要優點是您只需在您的資料上執行基礎模型一次,而不是每個訓練週期執行一次。因此它更快且更便宜。

不過,第二個工作流程的問題在於,它不允許您在訓練期間動態修改新模型的輸入資料,這在進行資料擴增時是必需的,例如。遷移學習通常用於新資料集資料太少,無法從頭開始訓練完整規模模型的任務,並且在這種情況下,資料擴增非常重要。因此,在接下來的內容中,我們將專注於第一個工作流程。

以下是第一個工作流程在 Keras 中的樣子

首先,使用預訓練權重實例化基礎模型。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

接著,凍結基礎模型。

base_model.trainable = False

在頂部建立新模型。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

在新資料上訓練模型。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微調

一旦您的模型在新資料上收斂,您可以嘗試解凍基礎模型的全部或部分,並以非常低的學習率重新端對端訓練整個模型。

這是一個可選的最後步驟,可以潛在地為您帶來漸進式的改進。它也可能潛在地導致快速過度擬合 — 請記住這一點。

務必僅在凍結層的模型已訓練至收斂後再執行此步驟。如果您將隨機初始化的可訓練層與保留預訓練特徵的可訓練層混合,則隨機初始化的層將在訓練期間導致非常大的梯度更新,這將破壞您的預訓練特徵。

在此階段使用非常低的學習率也至關重要,因為您正在訓練比第一輪訓練更大的模型,且資料集通常非常小。因此,如果您套用大的權重更新,您將面臨快速過度擬合的風險。在這裡,您只想以漸進式的方式重新調整預訓練權重。

以下是如何實作整個基礎模型的微調

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

關於 compile()trainable 的重要注意事項

在模型上呼叫 compile() 旨在「凍結」該模型的行為。這表示模型編譯時的 trainable 屬性值應在該模型的整個生命週期中保留,直到再次呼叫 compile 為止。因此,如果您變更任何 trainable 值,請務必再次在您的模型上呼叫 compile(),以使您的變更生效。

關於 BatchNormalization 層的重要注意事項

許多影像模型都包含 BatchNormalization 層。該層在所有可想像的方面都是特例。以下是一些需要記住的事項:

  • BatchNormalization 包含 2 個在訓練期間更新的非可訓練權重。這些是追蹤輸入的平均值和變異數的變數。
  • 當您設定 bn_layer.trainable = False 時,BatchNormalization 層將在推論模式下執行,並且不會更新其平均值和變異數統計資訊。一般來說,其他層並非如此,因為權重可訓練性與推論/訓練模式是兩個正交概念。但這兩者在 BatchNormalization 層的情況下是相關聯的。
  • 當您解凍包含 BatchNormalization 層的模型以進行微調時,您應該透過在呼叫基礎模型時傳遞 training=False,使 BatchNormalization 層保持在推論模式。否則,套用至非可訓練權重的更新將突然破壞模型已學到的內容。

您將在本指南結尾的端對端範例中看到此模式的運作方式。

使用自訂訓練迴圈進行遷移學習和微調

如果您使用的是自己的低階訓練迴圈而不是 fit(),則工作流程基本上保持不變。在套用梯度更新時,您應注意僅考慮清單 model.trainable_weights

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

微調也是如此。

端對端範例:在貓 vs. 狗資料集上微調影像分類模型

為了鞏固這些概念,讓我們引導您完成一個具體的端對端遷移學習和微調範例。我們將載入在 ImageNet 上預訓練的 Xception 模型,並在 Kaggle「貓 vs. 狗」分類資料集上使用它。

取得資料

首先,讓我們使用 TFDS 擷取貓 vs. 狗資料集。如果您有自己的資料集,您可能會想要使用公用程式 keras.utils.image_dataset_from_directory,從磁碟上歸檔到類別特定資料夾的一組影像中產生類似的標籤資料集物件。

當使用非常小的資料集時,遷移學習最有用。為了保持我們的資料集小巧,我們將使用原始訓練資料 (25,000 張影像) 的 40% 進行訓練,10% 進行驗證,以及 10% 進行測試。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

這些是訓練資料集中的前 9 張影像 — 如您所見,它們的大小都不同。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

我們也可以看到標籤 1 是「狗」,標籤 0 是「貓」。

標準化資料

我們的原始影像具有各種大小。此外,每個像素由介於 0 到 255 之間的 3 個整數值(RGB 層級值)組成。這不太適合饋送神經網路。我們需要做 2 件事

  • 標準化為固定的影像大小。我們選擇 150x150。
  • 將像素值正規化在 -1 和 1 之間。我們將使用 Normalization 層作為模型本身的一部分來執行此操作。

一般來說,開發以原始資料作為輸入的模型是一種好的做法,而不是開發以已預處理資料作為輸入的模型。原因是,如果您的模型預期預處理資料,則每次您匯出模型以在其他地方(在 Web 瀏覽器、行動應用程式中)使用它時,您都需要重新實作完全相同的預處理管線。這很快就會變得非常棘手。因此,我們應該在接觸模型之前執行盡可能少的預處理量。

在這裡,我們將在資料管線中執行影像大小調整(因為深度神經網路只能處理連續的資料批次),並且我們將在建立模型時,將輸入值縮放作為模型的一部分來執行。

讓我們將影像大小調整為 150x150

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

此外,讓我們對資料進行批次處理,並使用快取和預先擷取來最佳化載入速度。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

使用隨機資料擴增

當您沒有大型影像資料集時,透過對訓練影像套用隨機但逼真的轉換(例如隨機水平翻轉或小的隨機旋轉)來人工引入樣本多樣性是一種好的做法。這有助於讓模型接觸到訓練資料的不同方面,同時減緩過度擬合。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ]
)

讓我們視覺化第一個批次的第一張影像在各種隨機轉換後的外觀

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2023-10-03 11:11:16.151536: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

建構模型

現在讓我們建構一個遵循我們稍早解釋的藍圖的模型。

請注意:

  • 我們新增 Rescaling 層以將輸入值(最初在 [0, 255] 範圍內)縮放至 [-1, 1] 範圍。
  • 我們在分類層之前新增 Dropout 層,以進行正規化。
  • 我們確保在呼叫基礎模型時傳遞 training=False,使其在推論模式下執行,以便即使在我們解凍基礎模型以進行微調後,batchnorm 統計資訊也不會更新。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 [==============================] - 0s 0us/step
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (  (None, 2048)              0         
 GlobalAveragePooling2D)                                         
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20863529 (79.59 MB)
Trainable params: 2049 (8.00 KB)
Non-trainable params: 20861480 (79.58 MB)
_________________________________________________________________

訓練頂層

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
142/291 [=============>................] - ETA: 35s - loss: 0.2211 - binary_accuracy: 0.8963
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
258/291 [=========================>....] - ETA: 7s - loss: 0.1831 - binary_accuracy: 0.9172
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
272/291 [===========================>..] - ETA: 4s - loss: 0.1797 - binary_accuracy: 0.9190
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
275/291 [===========================>..] - ETA: 3s - loss: 0.1789 - binary_accuracy: 0.9194
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1771 - binary_accuracy: 0.9205
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 89s 295ms/step - loss: 0.1771 - binary_accuracy: 0.9205 - val_loss: 0.0835 - val_binary_accuracy: 0.9652
Epoch 2/20
291/291 [==============================] - 85s 293ms/step - loss: 0.1197 - binary_accuracy: 0.9493 - val_loss: 0.0846 - val_binary_accuracy: 0.9699
Epoch 3/20
291/291 [==============================] - 83s 286ms/step - loss: 0.1135 - binary_accuracy: 0.9531 - val_loss: 0.0751 - val_binary_accuracy: 0.9708
Epoch 4/20
291/291 [==============================] - 83s 285ms/step - loss: 0.1037 - binary_accuracy: 0.9558 - val_loss: 0.0704 - val_binary_accuracy: 0.9712
Epoch 5/20
291/291 [==============================] - 83s 285ms/step - loss: 0.1024 - binary_accuracy: 0.9582 - val_loss: 0.0718 - val_binary_accuracy: 0.9733
Epoch 6/20
291/291 [==============================] - 83s 284ms/step - loss: 0.1006 - binary_accuracy: 0.9595 - val_loss: 0.0749 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 83s 287ms/step - loss: 0.0961 - binary_accuracy: 0.9580 - val_loss: 0.0720 - val_binary_accuracy: 0.9699
Epoch 8/20
291/291 [==============================] - 83s 285ms/step - loss: 0.0952 - binary_accuracy: 0.9598 - val_loss: 0.0737 - val_binary_accuracy: 0.9712
Epoch 9/20
291/291 [==============================] - 83s 286ms/step - loss: 0.0984 - binary_accuracy: 0.9614 - val_loss: 0.0729 - val_binary_accuracy: 0.9708
Epoch 10/20
291/291 [==============================] - 83s 285ms/step - loss: 0.1007 - binary_accuracy: 0.9581 - val_loss: 0.0811 - val_binary_accuracy: 0.9686
Epoch 11/20
291/291 [==============================] - 83s 285ms/step - loss: 0.0960 - binary_accuracy: 0.9611 - val_loss: 0.0813 - val_binary_accuracy: 0.9703
Epoch 12/20
291/291 [==============================] - 83s 287ms/step - loss: 0.0950 - binary_accuracy: 0.9606 - val_loss: 0.0745 - val_binary_accuracy: 0.9703
Epoch 13/20
291/291 [==============================] - 84s 289ms/step - loss: 0.0970 - binary_accuracy: 0.9602 - val_loss: 0.0756 - val_binary_accuracy: 0.9703
Epoch 14/20
291/291 [==============================] - 83s 285ms/step - loss: 0.0915 - binary_accuracy: 0.9632 - val_loss: 0.0754 - val_binary_accuracy: 0.9690
Epoch 15/20
291/291 [==============================] - 84s 290ms/step - loss: 0.0938 - binary_accuracy: 0.9628 - val_loss: 0.0786 - val_binary_accuracy: 0.9682
Epoch 16/20
291/291 [==============================] - 82s 283ms/step - loss: 0.0958 - binary_accuracy: 0.9609 - val_loss: 0.0784 - val_binary_accuracy: 0.9682
Epoch 17/20
291/291 [==============================] - 83s 284ms/step - loss: 0.0907 - binary_accuracy: 0.9616 - val_loss: 0.0720 - val_binary_accuracy: 0.9721
Epoch 18/20
291/291 [==============================] - 83s 287ms/step - loss: 0.0946 - binary_accuracy: 0.9621 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 83s 286ms/step - loss: 0.1004 - binary_accuracy: 0.9597 - val_loss: 0.0726 - val_binary_accuracy: 0.9703
Epoch 20/20
291/291 [==============================] - 82s 283ms/step - loss: 0.0891 - binary_accuracy: 0.9635 - val_loss: 0.0736 - val_binary_accuracy: 0.9712
<keras.src.callbacks.History at 0x7f701c5093a0>

對整個模型進行一輪微調

最後,讓我們解凍基礎模型,並以低學習率端對端訓練整個模型。

重要的是,儘管基礎模型變得可訓練,但它仍然在推論模式下執行,因為我們在建立模型時呼叫它時傳遞了 training=False。這表示內部的批次正規化層不會更新其批次統計資訊。如果它們更新了,它們將對模型至今學到的表示造成嚴重破壞。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (  (None, 2048)              0         
 GlobalAveragePooling2D)                                         
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20863529 (79.59 MB)
Trainable params: 20809001 (79.38 MB)
Non-trainable params: 54528 (213.00 KB)
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 321s 1s/step - loss: 0.0751 - binary_accuracy: 0.9698 - val_loss: 0.0537 - val_binary_accuracy: 0.9781
Epoch 2/10
291/291 [==============================] - 302s 1s/step - loss: 0.0566 - binary_accuracy: 0.9787 - val_loss: 0.0490 - val_binary_accuracy: 0.9802
Epoch 3/10
291/291 [==============================] - 296s 1s/step - loss: 0.0455 - binary_accuracy: 0.9810 - val_loss: 0.0477 - val_binary_accuracy: 0.9794
Epoch 4/10
291/291 [==============================] - 289s 994ms/step - loss: 0.0351 - binary_accuracy: 0.9859 - val_loss: 0.0457 - val_binary_accuracy: 0.9789
Epoch 5/10
291/291 [==============================] - 289s 993ms/step - loss: 0.0268 - binary_accuracy: 0.9907 - val_loss: 0.0522 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 288s 990ms/step - loss: 0.0258 - binary_accuracy: 0.9900 - val_loss: 0.0529 - val_binary_accuracy: 0.9789
Epoch 7/10
291/291 [==============================] - 286s 982ms/step - loss: 0.0209 - binary_accuracy: 0.9918 - val_loss: 0.0518 - val_binary_accuracy: 0.9776
Epoch 8/10
291/291 [==============================] - 289s 994ms/step - loss: 0.0185 - binary_accuracy: 0.9936 - val_loss: 0.0467 - val_binary_accuracy: 0.9832
Epoch 9/10
291/291 [==============================] - 289s 994ms/step - loss: 0.0150 - binary_accuracy: 0.9953 - val_loss: 0.0509 - val_binary_accuracy: 0.9802
Epoch 10/10
291/291 [==============================] - 292s 1s/step - loss: 0.0148 - binary_accuracy: 0.9952 - val_loss: 0.0501 - val_binary_accuracy: 0.9832
<keras.src.callbacks.History at 0x7f701c7e4f10>

在 10 個週期後,微調在此處為我們帶來了不錯的改進。