使用 TensorFlow Lite Model Maker 進行音訊領域的遷移學習

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本 查看 TF Hub 模型

在本 Colab 筆記本中,您將學習如何使用 TensorFlow Lite Model Maker 訓練自訂音訊分類模型。

Model Maker 程式庫使用遷移學習來簡化使用自訂資料集訓練 TensorFlow Lite 模型的流程。使用您自己的自訂資料集重新訓練 TensorFlow Lite 模型,可減少所需的訓練資料和時間。

它是自訂音訊模型並部署在 Android 上的 Codelab 的一部分。

您將使用自訂鳥類資料集,並匯出可在手機上使用的 TFLite 模型、可用於瀏覽器中推論的 TensorFlow.JS 模型,以及可用於服務的 SavedModel 版本。

安裝依附元件

sudo apt -y install libportaudio2
pip install tflite-model-maker

匯入 TensorFlow、Model Maker 和其他程式庫

在所需的依附元件中,您將使用 TensorFlow 和 Model Maker。除此之外,其他依附元件用於音訊處理、播放和視覺化。

import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import itertools
import glob
import random

from IPython.display import Audio, Image
from scipy.io import wavfile

print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")

鳥類資料集

鳥類資料集是 5 種鳥類歌曲的教育收集

  • 白胸林鶯
  • 家麻雀
  • 紅交嘴雀
  • 栗冠蟻鷯
  • 艾氏旋木雀

原始音訊來自 Xeno-canto,這是一個致力於分享來自世界各地鳥類聲音的網站。

讓我們從下載資料開始。

birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
                                                'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
                                                cache_dir='./',
                                                cache_subdir='dataset',
                                                extract=True)

探索資料

音訊已在 train 和 test 資料夾中分割。在每個分割資料夾中,每個鳥類都有一個資料夾,使用它們的 bird_code 作為名稱。

音訊都是單聲道且取樣率為 16kHz。

如需每個檔案的詳細資訊,您可以閱讀 metadata.csv 檔案。它包含所有檔案作者、授權和更多資訊。在本教學課程中,您無需自行閱讀。

# @title [Run this] Util functions and data structures.

data_dir = './dataset/small_birds_dataset'

bird_code_to_name = {
  'wbwwre1': 'White-breasted Wood-Wren',
  'houspa': 'House Sparrow',
  'redcro': 'Red Crossbill',  
  'chcant2': 'Chestnut-crowned Antpitta',
  'azaspi1': "Azara's Spinetail",   
}

birds_images = {
  'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', #   Alejandro Bayer Tamayo from Armenia, Colombia 
  'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', #    Diliff
  'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', #  Elaine R. Wilson, www.naturespicsonline.com
  'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', #   Mike's Birds from Riverside, CA, US
  'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368
}

test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))

def get_random_audio_file():
  test_list = glob.glob(test_files)
  random_audio_path = random.choice(test_list)
  return random_audio_path


def show_bird_data(audio_path):
  sample_rate, audio_data = wavfile.read(audio_path, 'rb')

  bird_code = audio_path.split('/')[-2]
  print(f'Bird name: {bird_code_to_name[bird_code]}')
  print(f'Bird code: {bird_code}')
  display(Image(birds_images[bird_code]))

  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
  plt.title(plttitle)
  plt.plot(audio_data)
  display(Audio(audio_data, rate=sample_rate))

print('functions and data structures created')

播放一些音訊

為了更瞭解資料,讓我們聽一些來自測試分割的隨機音訊檔案。

random_audio = get_random_audio_file()
show_bird_data(random_audio)

訓練模型

使用 Model Maker 進行音訊處理時,您必須從模型規格開始。這是您的新模型將從中提取資訊以瞭解新類別的基礎模型。它也會影響資料集如何轉換以符合模型規格參數,例如:取樣率、頻道數。

YAMNet 是一個音訊事件分類器,在 AudioSet 資料集上訓練,以預測 AudioSet 本體論中的音訊事件。

它的輸入預期為 16kHz 和 1 個頻道。

您無需自行重新取樣。Model Maker 會為您處理。

  • frame_length 是決定每個訓練範例的長度。在本例中,為 EXPECTED_WAVEFORM_LENGTH * 3 秒

  • frame_steps 是決定訓練範例之間的間隔距離。在本例中,第 i 個範例將在第 (i-1) 個範例之後的 EXPECTED_WAVEFORM_LENGTH * 6 秒開始。

設定這些值的原因是為了解決真實世界資料集中的一些限制。

例如,在鳥類資料集中,鳥類不會一直唱歌。牠們唱歌、休息,然後再次唱歌,中間夾雜著噪音。擁有較長的影格有助於捕捉歌聲,但將其設定得太長會減少訓練的範例數量。

spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)

載入資料

Model Maker 具有 API,可從資料夾載入資料,並使其成為模型規格預期的格式。

訓練和測試分割基於資料夾。驗證資料集將建立為訓練分割的 20%。

train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)

訓練模型

audio_classifier 具有 create 方法,用於建立模型並開始訓練。

您可以自訂許多參數,如需更多資訊,您可以閱讀文件中的更多詳細資訊。

在第一次嘗試中,您將使用所有預設設定,並訓練 100 個 epoch。

batch_size = 128
epochs = 100

print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)

準確度看起來不錯,但重要的是在測試資料上執行評估步驟,並驗證您的模型在未見過的資料上取得了良好的結果。

print('Evaluating the model')
model.evaluate(test_data)

瞭解您的模型

在訓練分類器時,查看混淆矩陣很有用。混淆矩陣可讓您詳細瞭解分類器在測試資料上的效能。

Model Maker 已經為您建立混淆矩陣。

def show_confusion_matrix(confusion, test_labels):
  """Compute confusion matrix and normalize."""
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1)
  axis_labels = test_labels
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)

測試模型 [選用]

您可以嘗試在來自測試資料集的範例音訊上測試模型,以查看結果。

首先,您取得服務模型。

serving_model = model.create_serving_model()

print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')

回到您稍早載入的隨機音訊

# if you want to try another file just uncoment the line below
random_audio = get_random_audio_file()
show_bird_data(random_audio)

建立的模型具有固定的輸入視窗。

對於給定的音訊檔案,您必須將其分割成預期大小的資料視窗。最後一個視窗可能需要用零填充。

sample_rate, audio_data = wavfile.read(random_audio, 'rb')

audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]

splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)

print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')

您將迴圈遍歷所有分割的音訊,並將模型應用於每個音訊。

您剛訓練的模型有 2 個輸出:原始 YAMNet 的輸出和您剛訓練的輸出。這很重要,因為真實世界環境比鳥叫聲更複雜。您可以使用 YAMNet 的輸出過濾掉不相關的音訊,例如,在鳥類用例中,如果 YAMNet 沒有將鳥類或動物分類,這可能表示您的模型的輸出可能具有不相關的分類。

下面列印了兩個輸出,以便更容易理解它們之間的關係。您的模型犯的大部分錯誤都發生在 YAMNet 的預測與您的領域 (例如:鳥類) 無關時。

print(random_audio)

results = []
print('Result of the window ith:  your model class -> score,  (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
  yamnet_output, inference = serving_model(data)
  results.append(inference[0].numpy())
  result_index = tf.argmax(inference[0])
  spec_result_index = tf.argmax(yamnet_output[0])
  t = spec._yamnet_labels()[spec_result_index]
  result_str = f'Result of the window {i}: ' \
  f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
  f'\t({spec._yamnet_labels()[spec_result_index]} -> {yamnet_output[0][spec_result_index]:.3f})'
  print(result_str)


results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')

匯出模型

最後一步是匯出您的模型,以便在嵌入式裝置或瀏覽器上使用。

export 方法會為您匯出兩種格式。

models_path = './birds_models'
print(f'Exporing the TFLite model to {models_path}')

model.export(models_path, tflite_filename='my_birds_model.tflite')

您也可以匯出 SavedModel 版本,用於服務或在 Python 環境中使用。

model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])

後續步驟

您做到了。

現在,您可以使用 TFLite AudioClassifier Task API 將您的新模型部署在行動裝置上。

您也可以嘗試使用具有不同類別的您自己的資料進行相同的流程,以下是 Model Maker for Audio Classification 的文件。

也可以從端對端參考應用程式中學習:AndroidiOS