Classify Flowers with Transfer Learning

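Have you ever seen a beautiful flower and wondered what kind of flower it is? You are not the first, so let's build a way to identify the type of flower from a photo!

For classifying images, a particular type of deep neural network, called a convolutional neural network, has proved to be particularly powerful. However, modern convolutional neural networks have millions of parameters. Training them from scratch requires a lot of labeled training data and a lot of computing power (hundreds of GPU-hours or more). We only have about three thousand labeled photos and want to spend much less time, so we need a smarter approach.

We will use a technique called transfer learning: we take a pre-trained network (trained on about a million general images), use it to extract features, and train a new layer on top for our own task of classifying flower images.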

Setup

import collections
import io
import math
import os
import random
from six.moves import urllib

from IPython.display import clear_output, Image, display, HTML

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import tensorflow_hub as hub

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.metrics as sk_metrics
import time

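The flowers dataset

The flowers dataset consists of images of flowers with 5 possible class labels.

When training a machine learning model, we split our data into training and test datasets. We train the model on the training data and then evaluate how well it performs on data it has never seen (the test set).

Let's download the examples (this may take a while) and split them into train and test sets.

Run the following two cells: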

FLOWERS_DIR = './flower_photos'
TRAIN_FRACTION = 0.8
RANDOM_SEED = 2018


def download_images():
  """If the images aren't already downloaded, save them to FLOWERS_DIR."""
  if not os.path.exists(FLOWERS_DIR):
    DOWNLOAD_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
    print('Downloading flower images from %s...' % DOWNLOAD_URL)
    urllib.request.urlretrieve(DOWNLOAD_URL, 'flower_photos.tgz')
    !tar xfz flower_photos.tgz
  print('Flower photos are located in %s' % FLOWERS_DIR)


def make_train_and_test_sets():
  """Split the data into train and test sets and get the label classes."""
  train_examples, test_examples = [], []
  shuffler = random.Random(RANDOM_SEED)
  is_root = True
  for (dirname, subdirs, filenames) in tf.gfile.Walk(FLOWERS_DIR):
    # The root directory gives us the classes
    if is_root:
      subdirs = sorted(subdirs)
      classes = collections.OrderedDict(enumerate(subdirs))
      label_to_class = dict([(x, i) for i, x in enumerate(subdirs)])
      is_root = False
    # The sub directories give us the image files for training.
    else:
      filenames.sort()
      shuffler.shuffle(filenames)
      full_filenames = [os.path.join(dirname, f) for f in filenames]
      label = dirname.split('/')[-1]
      label_class = label_to_class[label]
      # An example is the image file and its label class.
      examples = list(zip(full_filenames, [label_class] * len(filenames)))
      num_train = int(len(filenames) * TRAIN_FRACTION)
      train_examples.extend(examples[:num_train])
      test_examples.extend(examples[num_train:])

  shuffler.shuffle(train_examples)
  shuffler.shuffle(test_examples)
  return train_examples, test_examples, classes


# Download the images and split the images into train and test sets.
download_images()
TRAIN_EXAMPLES, TEST_EXAMPLES, CLASSES = make_train_and_test_sets()
NUM_CLASSES = len(CLASSES)

print('\nThe dataset has %d label classes: %s' % (NUM_CLASSES, CLASSES.values()))
print('There are %d training images' % len(TRAIN_EXAMPLES))
print('There are %d test images' % len(TEST_EXAMPLES))

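Explore the data

The flowers dataset consists of examples, each of which is a labeled image of a flower. Each example contains a JPEG flower image and its class label: what type of flower it is. Let's display a few images together with their labels.

Show some labeled images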
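
Later cells call helper functions such as get_label, get_class, and get_encoded_image, whose definitions do not appear in this text. The following is a minimal sketch of what they might look like, assuming each example is an (image_path, label_index) pair as produced by make_train_and_test_sets(). The CLASSES stand-in, its class names, and the throwaway file are illustrative assumptions so the sketch runs on its own. The image display helpers (get_image, display_images) are not covered here.

```python
import os
import tempfile

# Stand-in for the CLASSES mapping built by make_train_and_test_sets().
# The names are placeholders; drop this line when running in the notebook.
CLASSES = {0: 'daisy', 1: 'tulips'}


def get_label(example):
  """Return the integer label of an (image_path, label_index) example."""
  return example[1]


def get_class(example):
  """Return the class name that corresponds to the example's label."""
  return CLASSES[get_label(example)]


def get_encoded_image(example):
  """Return the raw, still-encoded bytes of the example's image file."""
  with open(example[0], 'rb') as f:
    return f.read()


# Quick self-check with a throwaway file standing in for a JPEG.
with tempfile.TemporaryDirectory() as tmp_dir:
  fake_path = os.path.join(tmp_dir, 'fake.jpg')
  with open(fake_path, 'wb') as f:
    f.write(b'not-a-real-jpeg')
  demo_example = (fake_path, 1)
  demo_label = get_label(demo_example)
  demo_class = get_class(demo_example)
  demo_bytes = get_encoded_image(demo_example)
```

Returning the encoded bytes rather than decoded pixels fits the model below, which feeds JPEG strings into a tf.string placeholder and decodes them inside the graph.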

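Build the model

We will load a TF-Hub image feature vector module, stack a linear classifier on top of it, and add training and evaluation ops. The following cell builds a TF graph describing the model and its training, but it does not run the training (that is the next step).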

LEARNING_RATE = 0.01

tf.reset_default_graph()

# Load a pre-trained TF-Hub module for extracting features from images. We've
# chosen this particular module for speed, but many other choices are available.
image_module = hub.Module('https://tfhub.dev/google/imagenet/mobilenet_v2_035_128/feature_vector/2')

# Preprocessing images into tensors with size expected by the image module.
encoded_images = tf.placeholder(tf.string, shape=[None])
image_size = hub.get_expected_image_size(image_module)


def decode_and_resize_image(encoded):
  decoded = tf.image.decode_jpeg(encoded, channels=3)
  decoded = tf.image.convert_image_dtype(decoded, tf.float32)
  return tf.image.resize_images(decoded, image_size)


batch_images = tf.map_fn(decode_and_resize_image, encoded_images, dtype=tf.float32)

# The image module can be applied as a function to extract feature vectors for a
# batch of images.
features = image_module(batch_images)


def create_model(features):
  """Build a model for classification from extracted features."""
  # Currently, the model is just a single linear layer. You can try to add
  # another layer, but be careful... two linear layers (when activation=None)
  # are equivalent to a single linear layer. You can create a nonlinear layer
  # like this:
  # layer = tf.layers.dense(inputs=..., units=..., activation=tf.nn.relu)
  layer = tf.layers.dense(inputs=features, units=NUM_CLASSES, activation=None)
  return layer


# For each class (kind of flower), the model outputs a real number that scores
# how much the input resembles this class. This vector of numbers is often
# called the "logits".
logits = create_model(features)
labels = tf.placeholder(tf.float32, [None, NUM_CLASSES])

# Mathematically, a good way to measure how much the predicted probabilities
# diverge from the truth is the "cross-entropy" between the two probability
# distributions. For numerical stability, this is best done directly from the
# logits, not the probabilities extracted from them.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)
cross_entropy_mean = tf.reduce_mean(cross_entropy)

# Let's add an optimizer so we can train the network.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
train_op = optimizer.minimize(loss=cross_entropy_mean)

# The "softmax" function transforms the logits vector into a vector of
# probabilities: non-negative numbers that sum up to one, and the i-th number
# says how likely the input comes from class i.
probabilities = tf.nn.softmax(logits)

# We choose the highest one as the predicted class.
prediction = tf.argmax(probabilities, 1)
correct_prediction = tf.equal(prediction, tf.argmax(labels, 1))

# The accuracy lets us evaluate the model on our test set.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
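
To make the comment about numerical stability concrete, here is a small NumPy sketch (not the TensorFlow implementation) that computes cross-entropy directly from logits and compares it with the naive route through softmax probabilities. The toy logits and labels are made-up values; the second row is deliberately extreme.

```python
import numpy as np


def softmax(logits):
  """Turn each row of logits into probabilities that sum to one."""
  shifted = logits - np.max(logits, axis=-1, keepdims=True)
  exp = np.exp(shifted)
  return exp / np.sum(exp, axis=-1, keepdims=True)


def cross_entropy_from_logits(logits, one_hot_labels):
  """Cross-entropy computed via log-softmax, without forming probabilities."""
  shifted = logits - np.max(logits, axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
  return -np.sum(one_hot_labels * log_probs, axis=-1)


# Made-up values: row 0 is ordinary, row 1 has extreme logits.
toy_logits = np.array([[2.0, 1.0, 0.1], [1000.0, 0.0, -1000.0]])
toy_labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

toy_probs = softmax(toy_logits)
stable_losses = cross_entropy_from_logits(toy_logits, toy_labels)

# Naive route: take the log of the probability assigned to the true class.
# For row 1 that probability underflows to 0, so the loss becomes infinite.
with np.errstate(divide='ignore'):
  naive_losses = -np.log(np.sum(toy_labels * toy_probs, axis=-1))

# Accuracy, computed the same way as in the graph: argmax versus true class.
toy_accuracy = np.mean(
    np.argmax(toy_probs, axis=1) == np.argmax(toy_labels, axis=1))
```

For the ordinary row both routes agree. For the extreme row the logits-based loss stays finite while the naive one overflows to infinity, which would break gradient-based training.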

Train the network

Now that our model is built, let's train it and see how it performs on our test set.

# How long will we train the network (number of batches).
NUM_TRAIN_STEPS = 100
# How many training examples we use in each step.
TRAIN_BATCH_SIZE = 10
# How often to evaluate the model performance.
EVAL_EVERY = 10

def get_batch(batch_size=None, test=False):
  """Get a random batch of examples."""
  examples = TEST_EXAMPLES if test else TRAIN_EXAMPLES
  batch_examples = random.sample(examples, batch_size) if batch_size else examples
  return batch_examples

def get_images_and_labels(batch_examples):
  images = [get_encoded_image(e) for e in batch_examples]
  one_hot_labels = [get_label_one_hot(e) for e in batch_examples]
  return images, one_hot_labels

def get_label_one_hot(example):
  """Get the one hot encoding vector for the example."""
  one_hot_vector = np.zeros(NUM_CLASSES)
  np.put(one_hot_vector, get_label(example), 1)
  return one_hot_vector

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(NUM_TRAIN_STEPS):
    # Get a random batch of training examples.
    train_batch = get_batch(batch_size=TRAIN_BATCH_SIZE)
    batch_images, batch_labels = get_images_and_labels(train_batch)
    # Run the train_op to train the model.
    train_loss, _, train_accuracy = sess.run(
        [cross_entropy_mean, train_op, accuracy],
        feed_dict={encoded_images: batch_images, labels: batch_labels})
    is_final_step = (i == (NUM_TRAIN_STEPS - 1))
    if i % EVAL_EVERY == 0 or is_final_step:
      # Get a batch of test examples.
      test_batch = get_batch(batch_size=None, test=True)
      batch_images, batch_labels = get_images_and_labels(test_batch)
      # Evaluate how well our model performs on the test set.
      test_loss, test_accuracy, test_prediction, correct_predicate = sess.run(
        [cross_entropy_mean, accuracy, prediction, correct_prediction],
        feed_dict={encoded_images: batch_images, labels: batch_labels})
      print('Test accuracy at step %s: %.2f%%' % (i, (test_accuracy * 100)))
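

def show_confusion_matrix(test_labels, predictions):
  """Compute confusion matrix and normalize."""
  confusion = sk_metrics.confusion_matrix(
    np.argmax(test_labels, axis=1), predictions)
  # Divide each row by its own total so every true-label row sums to 1.
  # keepdims=True keeps the totals as a column, so broadcasting is per row.
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1, keepdims=True)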
  axis_labels = list(CLASSES.values())
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

show_confusion_matrix(batch_labels, test_prediction)
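
Row-normalizing a confusion matrix is sensitive to array shapes, so it helps to check the arithmetic on a tiny example. The sketch below uses made-up counts for three classes and shows that the row totals must be kept as a column (keepdims=True) for each true-label row to sum to one.

```python
import numpy as np

# Made-up counts: rows are true labels, columns are predicted labels.
toy_confusion = np.array([[8, 2, 0],
                          [1, 3, 0],
                          [0, 5, 15]])

# Shape (3, 1): each row is divided by its own total.
row_totals = toy_confusion.sum(axis=1, keepdims=True)
row_normalized = toy_confusion / row_totals

# Shape (3,): broadcasting now divides column j by the total of row j,
# which is not a per-row normalization.
misaligned = toy_confusion / toy_confusion.sum(axis=1)
```

With the column-shaped totals, every row of row_normalized sums to one, so each cell reads as the fraction of that true class predicted as each label. The misaligned version does not have this property.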

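Incorrect predictions

Let's take a closer look at the test examples that our model got wrong.

  • Are there any mislabeled examples in our test set?
  • Is there any bad data in the test set (images that aren't actually pictures of flowers)?
  • Are there images where you can understand why the model made a mistake?

incorrect = [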
    (example, CLASSES[prediction])
    for example, prediction, is_correct in zip(test_batch, test_prediction, correct_predicate)
    if not is_correct
]
display_images(
  [(get_image(example), "prediction: {0}\nlabel:{1}".format(incorrect_prediction, get_class(example)))
   for (example, incorrect_prediction) in incorrect[:20]])

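Exercises: Improve the model!

We've trained a baseline model; now let's try to improve it to achieve better accuracy. (Remember that you'll need to re-run the cells when you make a change.)

Exercise 1: Try a different image model.

With TF-Hub, trying a few different image models is simple. Just replace the "https://tfhub.dev/google/imagenet/mobilenet_v2_035_128/feature_vector/2" handle in the hub.Module() call with a handle of a different module and rerun all the code. You can see all available image modules at tfhub.dev.

A good choice might be one of the other MobileNet V2 modules. Many of the modules (including the MobileNet modules) were trained on the ImageNet dataset, which contains over 1 million images and 1000 classes. Choosing a network architecture involves a tradeoff between speed and classification accuracy: models like MobileNet or NASNet Mobile are fast and small, while more traditional architectures like Inception and ResNet were designed for accuracy.

For the larger Inception V3 architecture, you can also explore the benefits of pre-training on a domain closer to your own task: it is also available as a module trained on the iNaturalist dataset of plants and animals.

Exercise 2: Add a hidden layer.

Stack a hidden layer between the extracted image features and the linear classifier (in the create_model() function above). To create a non-linear hidden layer with, for example, 100 nodes, use tf.layers.dense with units set to 100 and activation set to tf.nn.relu. Does changing the size of the hidden layer affect the test accuracy? Does adding a second hidden layer improve the accuracy?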
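
The activation matters because stacking two layers without one gains nothing. The NumPy sketch below (random stand-in weights and features, not the notebook's TensorFlow graph) shows that two linear layers collapse into a single linear layer, while inserting a ReLU between them does not. The sizes 8, 100, and 5 are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in data: 4 examples with 8 features each (sizes are arbitrary).
toy_features = rng.normal(size=(4, 8))

# Hidden layer with 100 units, then an output layer with 5 classes.
w_hidden = rng.normal(size=(8, 100))
b_hidden = rng.normal(size=100)
w_out = rng.normal(size=(100, 5))
b_out = rng.normal(size=5)

# Two stacked layers with no activation in between.
two_linear = (toy_features @ w_hidden + b_hidden) @ w_out + b_out

# The same mapping written as one linear layer with merged weights.
collapsed = toy_features @ (w_hidden @ w_out) + (b_hidden @ w_out + b_out)

# With a ReLU after the hidden layer, the mapping is no longer linear.
hidden_relu = np.maximum(toy_features @ w_hidden + b_hidden, 0.0)
with_relu = hidden_relu @ w_out + b_out
```

In create_model(), the equivalent change would be a tf.layers.dense call with units=100 and activation=tf.nn.relu placed before the existing NUM_CLASSES output layer.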

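Exercise 3: Change hyperparameters.

Does increasing the number of training steps improve the final accuracy? Can you change the learning rate to make your model converge more quickly? Does the training batch size affect your model's performance?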
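
As a rough intuition for the learning-rate question, the toy sketch below runs plain gradient descent on a one-dimensional quadratic, f(w) = (w - 3)^2, and counts the steps needed to get close to the minimum. The function, tolerance, and rates are illustrative assumptions and say nothing specific about the flower model.

```python
def steps_to_converge(learning_rate, tolerance=1e-3, max_steps=1000):
  """Count gradient-descent steps on f(w) = (w - 3)^2 until |w - 3| < tolerance."""
  w = 0.0
  for step in range(1, max_steps + 1):
    gradient = 2.0 * (w - 3.0)
    w -= learning_rate * gradient
    if abs(w - 3.0) < tolerance:
      return step
  return None  # Did not get within tolerance in max_steps steps.


steps_small_rate = steps_to_converge(0.01)   # Converges, but slowly.
steps_larger_rate = steps_to_converge(0.1)   # Converges in fewer steps.
steps_too_large = steps_to_converge(1.1)     # Overshoots more each step.
```

A larger rate converges faster up to a point, after which each update overshoots and the iterates diverge. The same tension applies when tuning LEARNING_RATE, although the useful range for a real model has to be found by experiment.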

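Exercise 4: Try a different optimizer.

Replace the basic GradientDescentOptimizer with a more sophisticated optimizer, such as AdagradOptimizer. Does it make a difference to your model training? If you want to learn more about the benefits of different optimization algorithms, check out this post.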
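
To see how an adaptive optimizer differs from plain gradient descent, here is a NumPy sketch of the textbook Adagrad update rule on a toy quadratic whose two coordinates have very different curvature. This is an illustration under assumed values, not TensorFlow's AdagradOptimizer implementation, whose defaults and details may differ.

```python
import numpy as np

# Toy loss: 0.5 * (100 * w0^2 + w1^2); coordinate 0 is 100x steeper.
CURVATURE = np.array([100.0, 1.0])


def loss(w):
  return 0.5 * np.sum(CURVATURE * w ** 2)


def gradient(w):
  return CURVATURE * w


def run_gradient_descent(w, learning_rate, num_steps):
  """Plain gradient descent: one shared step size for every coordinate."""
  for _ in range(num_steps):
    w = w - learning_rate * gradient(w)
  return w


def run_adagrad(w, learning_rate, num_steps, epsilon=1e-8):
  """Textbook Adagrad: scale each coordinate by its accumulated squared gradients."""
  accumulator = np.zeros_like(w)
  for _ in range(num_steps):
    g = gradient(w)
    accumulator = accumulator + g ** 2
    w = w - learning_rate * g / (np.sqrt(accumulator) + epsilon)
  return w


start = np.array([1.0, 1.0])
# Plain gradient descent needs a rate below 2 / 100 to stay stable here.
w_gd = run_gradient_descent(start, learning_rate=0.019, num_steps=100)
w_adagrad = run_adagrad(start, learning_rate=0.5, num_steps=100)
```

Because Adagrad divides by the accumulated gradient magnitude per coordinate, both coordinates move at the same pace here despite the 100x difference in gradient scale, whereas plain gradient descent is limited by the steepest direction.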

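Want to learn more?

If you are interested in a more advanced version of this tutorial, check out the TensorFlow image retraining tutorial, which walks you through visualizing the training using TensorBoard, advanced techniques like dataset augmentation by distorting images, and replacing the flowers dataset to learn an image classifier on your own dataset.

You can learn more about TensorFlow at tensorflow.org and see the TF-Hub API documentation at tensorflow.org/hub. Find available TensorFlow Hub modules at tfhub.dev, including more image feature vector modules and text embedding modules.

Also check out the Machine Learning Crash Course, Google's fast-paced, practical introduction to machine learning.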