![]() |
![]() |
![]() |
![]() |
![]() |
這個筆記本示範如何針對來自 TFDS 的資料集或您自己的作物病害偵測資料集,微調 TensorFlow Hub 的 CropNet 模型。
您將會
- 載入 TFDS 木薯資料集或您自己的資料
- 使用未知(負面)範例豐富資料,以獲得更穩健的模型
- 將圖片增強套用至資料
- 從 TF Hub 載入和微調 CropNet 模型
- 匯出 TFLite 模型,即可直接透過 Task Library、MLKit 或 TFLite 部署在您的應用程式中
匯入和依賴項
在開始之前,您需要安裝一些必要的依賴項,例如 Model Maker 和最新版本的 TensorFlow Datasets。
sudo apt install -q libportaudio2
## image_classifier library requires numpy <= 1.23.5
pip install "numpy<=1.23.5"
pip install --use-deprecated=legacy-resolver tflite-model-maker-nightly
pip install -U tensorflow-datasets
## scann library requires tensorflow < 2.9.0
pip install "tensorflow<2.9.0"
pip install "tensorflow-datasets~=4.8.0" # protobuf>=3.12.2
pip install tensorflow-metadata~=1.10.0 # protobuf>=3.13
## tensorflowjs requires packaging < 20.10
pip install "packaging<20.10"
import matplotlib.pyplot as plt
import os
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
from tensorflow_examples.lite.model_maker.core.task import image_preprocessing
from tflite_model_maker import image_classifier
from tflite_model_maker import ImageClassifierDataLoader
from tflite_model_maker.image_classifier import ModelSpec
2023-11-07 13:39:32.174301: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: TensorFlow Addons (TFA) has ended development and introduction of new features. TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024. Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). For more information see: https://github.com/tensorflow/addons/issues/2807 warnings.warn( /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_addons/utils/ensure_tf_install.py:53: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.12.0 and strictly below 2.15.0 (nightly versions are not supported). The versions of TensorFlow you are currently using is 2.8.4 and is not supported. Some things might work, some things might not. If you were to encounter a bug, do not file an issue. If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. You can find the compatibility matrix in TensorFlow Addon's readme: https://github.com/tensorflow/addons warnings.warn(
載入 TFDS 資料集以進行微調
讓我們使用 TFDS 中公開可用的 木薯葉病資料集。
tfds_name = 'cassava'
(ds_train, ds_validation, ds_test), ds_info = tfds.load(
name=tfds_name,
split=['train', 'validation', 'test'],
with_info=True,
as_supervised=True)
TFLITE_NAME_PREFIX = tfds_name
2023-11-07 13:39:36.293577: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
或者,載入您自己的資料進行微調
除了使用 TFDS 資料集,您也可以使用自己的資料進行訓練。此程式碼片段示範如何載入您自己的自訂資料集。請參閱此連結以了解資料的支援結構。此處提供一個使用公開可用的木薯葉病資料集的範例。
# data_root_dir = tf.keras.utils.get_file(
# 'cassavaleafdata.zip',
# 'https://storage.googleapis.com/emcassavadata/cassavaleafdata.zip',
# extract=True)
# data_root_dir = os.path.splitext(data_root_dir)[0] # Remove the .zip extension
# builder = tfds.ImageFolder(data_root_dir)
# ds_info = builder.info
# ds_train = builder.as_dataset(split='train', as_supervised=True)
# ds_validation = builder.as_dataset(split='validation', as_supervised=True)
# ds_test = builder.as_dataset(split='test', as_supervised=True)
可視化訓練分割的樣本
讓我們看看資料集中的一些範例,包括圖片樣本及其標籤的類別 ID 和類別名稱。
_ = tfds.show_examples(ds_train, ds_info)
從 TFDS 資料集新增圖片以用作未知範例
將額外的未知(負面)範例新增至訓練資料集,並為它們分配新的未知類別標籤號碼。目標是建立一個模型,使其在實際應用(例如在田野中)時,能夠在看到意料之外的事物時預測「未知」。
您可以在下方看到將用於採樣額外未知影像的資料集清單。它包含 3 個完全不同的資料集,以增加多樣性。其中一個是豆類葉病資料集,以便模型可以接觸到木薯以外的患病植物。
UNKNOWN_TFDS_DATASETS = [{
'tfds_name': 'imagenet_v2/matched-frequency',
'train_split': 'test[:80%]',
'test_split': 'test[80%:]',
'num_examples_ratio_to_normal': 1.0,
}, {
'tfds_name': 'oxford_flowers102',
'train_split': 'train',
'test_split': 'test',
'num_examples_ratio_to_normal': 1.0,
}, {
'tfds_name': 'beans',
'train_split': 'train',
'test_split': 'test',
'num_examples_ratio_to_normal': 1.0,
}]
UNKNOWN 資料集也從 TFDS 載入。
# Load unknown datasets.
weights = [
spec['num_examples_ratio_to_normal'] for spec in UNKNOWN_TFDS_DATASETS
]
num_unknown_train_examples = sum(
int(w * ds_train.cardinality().numpy()) for w in weights)
ds_unknown_train = tf.data.Dataset.sample_from_datasets([
tfds.load(
name=spec['tfds_name'], split=spec['train_split'],
as_supervised=True).repeat(-1) for spec in UNKNOWN_TFDS_DATASETS
], weights).take(num_unknown_train_examples)
ds_unknown_train = ds_unknown_train.apply(
tf.data.experimental.assert_cardinality(num_unknown_train_examples))
ds_unknown_tests = [
tfds.load(
name=spec['tfds_name'], split=spec['test_split'], as_supervised=True)
for spec in UNKNOWN_TFDS_DATASETS
]
ds_unknown_test = ds_unknown_tests[0]
for ds in ds_unknown_tests[1:]:
ds_unknown_test = ds_unknown_test.concatenate(ds)
# All examples from the unknown datasets will get a new class label number.
num_normal_classes = len(ds_info.features['label'].names)
unknown_label_value = tf.convert_to_tensor(num_normal_classes, tf.int64)
ds_unknown_train = ds_unknown_train.map(lambda image, _:
(image, unknown_label_value))
ds_unknown_test = ds_unknown_test.map(lambda image, _:
(image, unknown_label_value))
# Merge the normal train dataset with the unknown train dataset.
weights = [
ds_train.cardinality().numpy(),
ds_unknown_train.cardinality().numpy()
]
ds_train_with_unknown = tf.data.Dataset.sample_from_datasets(
[ds_train, ds_unknown_train], [float(w) for w in weights])
ds_train_with_unknown = ds_train_with_unknown.apply(
tf.data.experimental.assert_cardinality(sum(weights)))
print((f"Added {ds_unknown_train.cardinality().numpy()} negative examples."
f"Training dataset has now {ds_train_with_unknown.cardinality().numpy()}"
' examples in total.'))
Added 16968 negative examples.Training dataset has now 22624 examples in total.
套用增強
為了使所有圖片更多樣化,您將套用一些增強,例如變更:
- 亮度
- 對比度
- 飽和度
- 色相
- 裁剪
這些類型的增強有助於使模型更能適應圖片輸入的變化。
def random_crop_and_random_augmentations_fn(image):
# preprocess_for_train does random crop and resize internally.
image = image_preprocessing.preprocess_for_train(image)
image = tf.image.random_brightness(image, 0.2)
image = tf.image.random_contrast(image, 0.5, 2.0)
image = tf.image.random_saturation(image, 0.75, 1.25)
image = tf.image.random_hue(image, 0.1)
return image
def random_crop_fn(image):
# preprocess_for_train does random crop and resize internally.
image = image_preprocessing.preprocess_for_train(image)
return image
def resize_and_center_crop_fn(image):
image = tf.image.resize(image, (256, 256))
image = image[16:240, 16:240]
return image
no_augment_fn = lambda image: image
train_augment_fn = lambda image, label: (
random_crop_and_random_augmentations_fn(image), label)
eval_augment_fn = lambda image, label: (resize_and_center_crop_fn(image), label)
為了套用增強,它使用了 Dataset 類別的 map
方法。
ds_train_with_unknown = ds_train_with_unknown.map(train_augment_fn)
ds_validation = ds_validation.map(eval_augment_fn)
ds_test = ds_test.map(eval_augment_fn)
ds_unknown_test = ds_unknown_test.map(eval_augment_fn)
INFO:tensorflow:Use default resize_bicubic. INFO:tensorflow:Use default resize_bicubic. INFO:tensorflow:Use customized resize method bilinear INFO:tensorflow:Use customized resize method bilinear
將資料包裝成 Model Maker 友善的格式
若要將這些資料集與 Model Maker 搭配使用,它們需要採用 ImageClassifierDataLoader 類別。
label_names = ds_info.features['label'].names + ['UNKNOWN']
train_data = ImageClassifierDataLoader(ds_train_with_unknown,
ds_train_with_unknown.cardinality(),
label_names)
validation_data = ImageClassifierDataLoader(ds_validation,
ds_validation.cardinality(),
label_names)
test_data = ImageClassifierDataLoader(ds_test, ds_test.cardinality(),
label_names)
unknown_test_data = ImageClassifierDataLoader(ds_unknown_test,
ds_unknown_test.cardinality(),
label_names)
執行訓練
TensorFlow Hub 有多個模型可用於遷移學習。
您可以在這裡選擇一個,也可以繼續嘗試其他模型,以獲得更好的結果。
如果您想要嘗試更多模型,可以從這個集合中新增它們。
選擇基礎模型
為了微調模型,您將使用 Model Maker。這使得整體解決方案更簡單,因為在模型訓練完成後,它也會將其轉換為 TFLite。
Model Maker 使此轉換成為最佳轉換,並包含所有必要資訊,以便稍後輕鬆地在裝置端部署模型。
模型規格是您告訴 Model Maker 您想要使用的基礎模型的方式。
image_model_spec = ModelSpec(uri=model_handle)
此處的一個重要細節是設定 train_whole_model
,這將使基礎模型在訓練期間進行微調。這會使過程變慢,但最終模型具有更高的準確性。設定 shuffle
將確保模型以隨機洗牌的順序查看資料,這是模型學習的最佳實務做法。
model = image_classifier.create(
train_data,
model_spec=image_model_spec,
batch_size=128,
learning_rate=0.03,
epochs=5,
shuffle=True,
train_whole_model=True,
validation_data=validation_data)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2 (HubKe (None, 1280) 4226432 rasLayerV1V2) dropout (Dropout) (None, 1280) 0 dense (Dense) (None, 6) 7686 ================================================================= Total params: 4,234,118 Trainable params: 4,209,718 Non-trainable params: 24,400 _________________________________________________________________ None Epoch 1/5 176/176 [==============================] - 485s 3s/step - loss: 0.8830 - accuracy: 0.9190 - val_loss: 1.1238 - val_accuracy: 0.8068 Epoch 2/5 176/176 [==============================] - 463s 3s/step - loss: 0.7892 - accuracy: 0.9545 - val_loss: 1.0590 - val_accuracy: 0.8290 Epoch 3/5 176/176 [==============================] - 464s 3s/step - loss: 0.7744 - accuracy: 0.9577 - val_loss: 1.0222 - val_accuracy: 0.8438 Epoch 4/5 176/176 [==============================] - 463s 3s/step - loss: 0.7617 - accuracy: 0.9633 - val_loss: 1.0435 - val_accuracy: 0.8407 Epoch 5/5 176/176 [==============================] - 461s 3s/step - loss: 0.7571 - accuracy: 0.9653 - val_loss: 0.9859 - val_accuracy: 0.8655
在測試分割上評估模型
model.evaluate(test_data)
59/59 [==============================] - 7s 101ms/step - loss: 0.9668 - accuracy: 0.8637 [0.9668245911598206, 0.863660454750061]
為了更深入了解微調後的模型,最好分析混淆矩陣。這將顯示一個類別被預測為另一個類別的頻率。
def predict_class_label_number(dataset):
"""Runs inference and returns predictions as class label numbers."""
rev_label_names = {l: i for i, l in enumerate(label_names)}
return [
rev_label_names[o[0][0]]
for o in model.predict_top_k(dataset, batch_size=128)
]
def show_confusion_matrix(cm, labels):
plt.figure(figsize=(10, 8))
sns.heatmap(cm, xticklabels=labels, yticklabels=labels,
annot=True, fmt='g')
plt.xlabel('Prediction')
plt.ylabel('Label')
plt.show()
confusion_mtx = tf.math.confusion_matrix(
list(ds_test.map(lambda x, y: y)),
predict_class_label_number(test_data),
num_classes=len(label_names))
show_confusion_matrix(confusion_mtx, label_names)
在未知測試資料上評估模型
在此評估中,我們預期模型的準確度幾乎為 1。模型測試的所有圖片都與正常資料集無關,因此我們預期模型會預測「未知」類別標籤。
model.evaluate(unknown_test_data)
259/259 [==============================] - 30s 111ms/step - loss: 0.6760 - accuracy: 0.9999 [0.6760221719741821, 0.9998791813850403]
列印混淆矩陣。
unknown_confusion_mtx = tf.math.confusion_matrix(
list(ds_unknown_test.map(lambda x, y: y)),
predict_class_label_number(unknown_test_data),
num_classes=len(label_names))
show_confusion_matrix(unknown_confusion_mtx, label_names)
將模型匯出為 TFLite 和 SavedModel
現在我們可以將訓練後的模型匯出為 TFLite 和 SavedModel 格式,以便在裝置端部署並在 TensorFlow 中用於推論。
tflite_filename = f'{TFLITE_NAME_PREFIX}_model_{model_name}.tflite'
model.export(export_dir='.', tflite_filename=tflite_filename)
2023-11-07 14:20:20.089818: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp99qci6gx/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp99qci6gx/assets /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:746: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway. warnings.warn("Statistics for quantized inputs were expected, but not " 2023-11-07 14:20:30.245779: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format. 2023-11-07 14:20:30.245840: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3 INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmpfs/tmp/tmp8co343h3/labels.txt INFO:tensorflow:Saving labels in /tmpfs/tmp/tmp8co343h3/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./cassava_model_mobilenet_v3_large_100_224.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./cassava_model_mobilenet_v3_large_100_224.tflite
# Export saved model version.
model.export(export_dir='.', export_format=ExportFormat.SAVED_MODEL)
INFO:tensorflow:Assets written to: ./saved_model/assets INFO:tensorflow:Assets written to: ./saved_model/assets
後續步驟
您剛訓練的模型可以用於行動裝置,甚至可以部署在田野中!
若要下載模型,請按一下 colab 左側「檔案」選單的資料夾圖示,然後選擇下載選項。
此處使用的相同技術可以應用於其他可能更適合您的用例的植物病害任務,或任何其他類型的圖片分類任務。如果您想繼續並在 Android 應用程式上部署,可以繼續閱讀此 Android 快速入門指南。