This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR dataset.
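The Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, so users can take full advantage of TensorFlow for research and product development.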
This tutorial uses a ResNet model, a state-of-the-art image classifier. Specifically, it uses ResNet-18, a convolutional neural network that is 18 layers deep.
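The "18" counts the layers that carry weights. The sketch below reproduces that count. It assumes the standard layout from the original ResNet paper: one stem convolution, four stages of "basic" blocks with two 3x3 convolutions each, and one fully connected classifier, with 1x1 projection shortcuts not counted by convention. `BLOCKS_PER_STAGE` and `count_weighted_layers` are illustrative names, not part of Model Garden.

```python
# Number of "basic" residual blocks in each of the four stages.
# Assumption: the standard layout from the original ResNet paper.
BLOCKS_PER_STAGE = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3]}
CONVS_PER_BASIC_BLOCK = 2  # each basic block has two 3x3 convolutions


def count_weighted_layers(model_id: int) -> int:
    """Count conv + fully connected layers, ignoring 1x1 projection shortcuts."""
    stem_conv = 1   # initial 7x7 convolution
    classifier = 1  # final fully connected layer
    body = sum(BLOCKS_PER_STAGE[model_id]) * CONVS_PER_BASIC_BLOCK
    return stem_conv + body + classifier


print(count_weighted_layers(18))  # 18
print(count_weighted_layers(34))  # 34
```

The same counting rule gives ResNet-34 its name, which is a quick sanity check that the rule is the conventional one.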
This tutorial demonstrates how to:
- Use models from the TensorFlow Models package.
- Fine-tune a pre-built ResNet for image classification.
- Export the tuned ResNet model.
Setup
Install and import the necessary modules.
pip install -U -q "tf-models-official"
Import TensorFlow, TensorFlow Datasets, and a few helper libraries.
import pprint
import tempfile
from IPython import display
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
The tensorflow_models package contains the ResNet vision model, and the official.vision.serving module contains functions for saving and exporting the tuned model.
import tensorflow_models as tfm
# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib
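Configure the ResNet-18 model for the CIFAR-10 dataset
The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images per class.
In Model Garden, the collection of parameters that defines a model is called a config. Model Garden can create a config from a known set of parameters through a factory.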
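The factory idea can be pictured as a registry that maps an experiment name to a function that builds a default config, which the caller then overrides field by field. This is a hypothetical sketch of that pattern only. `register`, `get_toy_exp_config`, `ModelConfig`, `ExperimentConfig`, and the `'toy_imagenet'` entry are made-up names and do not reflect how `tfm.core.exp_factory` is implemented.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class ModelConfig:
    num_classes: int = 1000
    input_size: List[int] = field(default_factory=lambda: [224, 224, 3])
    model_id: int = 50


@dataclass
class ExperimentConfig:
    model: ModelConfig = field(default_factory=ModelConfig)


# Registry: experiment name -> function that builds a default config.
_FACTORIES: Dict[str, Callable[[], ExperimentConfig]] = {}


def register(name: str):
    def decorator(fn: Callable[[], ExperimentConfig]):
        _FACTORIES[name] = fn
        return fn
    return decorator


def get_toy_exp_config(name: str) -> ExperimentConfig:
    # Calling the factory builds a fresh object, so edits made by one caller
    # never change the registered defaults.
    return _FACTORIES[name]()


@register('toy_imagenet')
def toy_imagenet() -> ExperimentConfig:
    return ExperimentConfig()  # ImageNet-style defaults


# Usage: start from the known defaults, then override fields for CIFAR-10.
cfg = get_toy_exp_config('toy_imagenet')
cfg.model.num_classes = 10
cfg.model.input_size = [32, 32, 3]
cfg.model.model_id = 18
```

Building a new object on every call keeps each experiment's overrides local, which is the property that makes "start from a known config, then tweak it" safe.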
Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. This configuration is set up to train ResNet to convergence on ImageNet.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds, ds_info = tfds.load(
    tfds_name,
    with_info=True)
ds_info
tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_dir='gs://tensorflow-datasets/datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(32, 32, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
)
Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).
# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18
# Configure training and testing data
batch_size = 128
exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size
exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size
Adjust the trainer configuration.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

if device == 'CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps = 5000
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
Running on CPU is slow, so only train for a few steps.
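A quick worked check of what these trainer settings mean for CIFAR-10, using the split sizes from ds_info above (50,000 training and 10,000 test images) and batch_size = 128. The validation data config has drop_remainder=True, so this sketch assumes the last partial test batch is not evaluated.

```python
train_examples = 50_000  # ds_info.splits['train'].num_examples
test_examples = 10_000   # ds_info.splits['test'].num_examples
batch_size = 128

# Same floor division as exp_config.trainer.validation_steps.
validation_steps = test_examples // batch_size
# With drop_remainder=True, images that don't fill a final batch are not evaluated.
skipped_test_images = test_examples - validation_steps * batch_size

# How many passes over the training set each train_steps setting amounts to.
steps_per_epoch = train_examples / batch_size
epochs_gpu_or_tpu = 5000 / steps_per_epoch
epochs_cpu = 20 / steps_per_epoch

print(validation_steps, skipped_test_images)  # 78 16
print(steps_per_epoch)                        # 390.625
```

So evaluation covers 9,984 of the 10,000 test images, the GPU/TPU setting trains for about 12.8 epochs, and the CPU setting sees only about 5% of one epoch.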
Print the modified configuration.
pprint.pprint(exp_config.as_dict())
display.Javascript("google.colab.output.setIframeHeight('300px');")
{'runtime': {'all_reduce_alg': None, 'batchnorm_spatial_persistent': False, 'dataset_num_private_threads': None, 'default_shard_dim': -1, 'distribution_strategy': 'mirrored', 'enable_xla': True, 'gpu_thread_mode': None, 'loss_scale': None, 'mixed_precision_dtype': None, 'num_cores_per_replica': 1, 'num_gpus': 0, 'num_packs': 1, 'per_gpu_thread_count': 0, 'run_eagerly': False, 'task_index': -1, 'tpu': None, 'tpu_enable_xla_dynamic_padder': None, 'use_tpu_mp_strategy': False, 'worker_hosts': None}, 'task': {'allow_image_summary': False, 'differential_privacy_config': None, 'eval_input_partition_dims': [], 'evaluation': {'precision_and_recall_thresholds': None, 'report_per_class_precision_and_recall': False, 'top_k': 5}, 'freeze_backbone': False, 'init_checkpoint': None, 'init_checkpoint_modules': 'all', 'losses': {'l2_weight_decay': 0.0001, 'label_smoothing': 0.0, 'loss_weight': 1.0, 'one_hot': True, 'soft_labels': False, 'use_binary_cross_entropy': False}, 'model': {'add_head_batch_norm': False, 'backbone': {'resnet': {'bn_trainable': True, 'depth_multiplier': 1.0, 'model_id': 18, 'replace_stem_max_pool': False, 'resnetd_shortcut': False, 'scale_stem': True, 'se_ratio': 0.0, 'stem_type': 'v0', 'stochastic_depth_drop_rate': 0.0}, 'type': 'resnet'}, 'dropout_rate': 0.0, 'input_size': [32, 32, 3], 'kernel_initializer': 'random_uniform', 'norm_activation': {'activation': 'relu', 'norm_epsilon': 1e-05, 'norm_momentum': 0.9, 'use_sync_bn': False}, 'num_classes': 10, 'output_softmax': False}, 'model_output_keys': [], 'name': None, 'train_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'center_crop_fraction': 0.875, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'attribute_names': [], 'mask_binarize_threshold': None, 'regenerate_source_id': 
False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 128, 'image_field_key': 'image/encoded', 'input_path': '', 'is_multilabel': False, 'is_training': True, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'repeated_augment': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tf_resize_method': 'bilinear', 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': 'cifar10', 'tfds_skip_decoding_feature': '', 'tfds_split': 'train', 'three_augment': False, 'trainer_id': None, 'weights': None}, 'train_input_partition_dims': [], 'validation_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'center_crop_fraction': 0.875, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'attribute_names': [], 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 128, 'image_field_key': 'image/encoded', 'input_path': '', 'is_multilabel': False, 'is_training': False, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'repeated_augment': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 
'tf_data_service_job_name': None, 'tf_resize_method': 'bilinear', 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': 'cifar10', 'tfds_skip_decoding_feature': '', 'tfds_split': 'test', 'three_augment': False, 'trainer_id': None, 'weights': None} }, 'trainer': {'allow_tpu_summary': False, 'best_checkpoint_eval_metric': '', 'best_checkpoint_export_subdir': '', 'best_checkpoint_metric_comp': 'higher', 'checkpoint_interval': 20, 'continuous_eval_timeout': 3600, 'eval_tf_function': True, 'eval_tf_while_loop': False, 'loss_upper_bound': 1000000.0, 'max_to_keep': 5, 'optimizer_config': {'ema': None, 'learning_rate': {'cosine': {'alpha': 0.0, 'decay_steps': 20, 'initial_learning_rate': 0.1, 'name': 'CosineDecay', 'offset': 0}, 'type': 'cosine'}, 'optimizer': {'sgd': {'clipnorm': None, 'clipvalue': None, 'decay': 0.0, 'global_clipnorm': None, 'momentum': 0.9, 'name': 'SGD', 'nesterov': False}, 'type': 'sgd'}, 'warmup': {'linear': {'name': 'linear', 'warmup_learning_rate': 0, 'warmup_steps': 100}, 'type': 'linear'} }, 'preemption_on_demand_checkpoint': True, 'recovery_begin_steps': 0, 'recovery_max_trials': 0, 'steps_per_loop': 5, 'summary_interval': 100, 'train_steps': 20, 'train_tf_function': True, 'train_tf_while_loop': True, 'validation_interval': 1000, 'validation_steps': 78, 'validation_summary_subdir': 'validation'} } <IPython.core.display.Javascript object>
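The printed config shows how the learning-rate settings combine: a cosine decay over decay_steps (20 on CPU) preceded by a linear warmup of 100 steps that starts from warmup_learning_rate 0. The sketch below models that schedule in plain Python. The cosine part follows the formula of `tf.keras.optimizers.schedules.CosineDecay`. How the warmup hands over to the cosine decay is an assumption about Model Garden's linear-warmup schedule, not something this tutorial states: here the warmup ramps toward the cosine value at step warmup_steps. Under that assumption, a 20-step cosine decay has already reached 0 by step 100, so the CPU run's rate stays at 0, which would be consistent with the 'learning_rate': 0.0 entries in the training log below.

```python
import math


def cosine_decay(step, initial_lr, decay_steps, alpha=0.0):
    # CosineDecay formula; the step is clipped at decay_steps, so the rate
    # stays at alpha * initial_lr afterwards.
    step = min(step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


def warmup_then_cosine(step, initial_lr, decay_steps, warmup_steps, warmup_lr=0.0):
    # Assumption: the linear warmup ramps from warmup_lr toward the value the
    # cosine schedule has at step == warmup_steps, then hands over to it.
    target = cosine_decay(warmup_steps, initial_lr, decay_steps)
    if step < warmup_steps:
        return warmup_lr + step / warmup_steps * (target - warmup_lr)
    return cosine_decay(step, initial_lr, decay_steps)


# CPU settings from the printed config: decay_steps=20, warmup_steps=100.
cpu_rates = [warmup_then_cosine(s, 0.1, 20, 100) for s in (5, 10, 15, 20)]
# GPU/TPU settings: decay_steps=5000, warmup_steps=100.
gpu_rates = [warmup_then_cosine(s, 0.1, 5000, 100) for s in (50, 100, 2500, 5000)]

print(cpu_rates)  # [0.0, 0.0, 0.0, 0.0]
```

With the GPU/TPU settings, the same sketch ramps up to just under 0.1 at step 100, is near 0.05 halfway through training, and decays to 0 at step 5,000.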
Set up the distribution strategy.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
Warning: this will be really slow.
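Under MirroredStrategy or TPUStrategy, the global_batch_size configured earlier is the batch for one step across all replicas, and each replica processes a slice of it. A minimal sketch of that arithmetic follows. `per_replica_batch_size` is a hypothetical helper; in TensorFlow the replica count is available as `distribution_strategy.num_replicas_in_sync`. The divisibility check is a choice made for this sketch to keep replicas evenly loaded, not a requirement stated by this tutorial.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Split one global batch evenly across data-parallel replicas."""
    if num_replicas < 1:
        raise ValueError('num_replicas must be at least 1')
    # Design choice for this sketch: require an even split so every replica
    # does the same amount of work per step.
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f'global batch {global_batch_size} is not divisible by {num_replicas} replicas')
    return global_batch_size // num_replicas


print(per_replica_batch_size(128, 1))  # 128 (a single device, as with OneDeviceStrategy)
print(per_replica_batch_size(128, 8))  # 16 (for example, 8 replicas)
```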
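Create a Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.
The Task object has all the methods necessary for building the dataset, building the model, and running training and evaluation. These methods are driven by tfm.core.train_lib.run_experiment.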
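The division of labor between a task and its driver can be pictured with a toy example: the task owns data loading, model construction, and per-step logic, while a generic driver loop only calls those methods and returns the model plus evaluation logs. This is a hypothetical, pure-Python sketch of the pattern; `ToyMeanTask` and `run_toy_experiment` are invented names, and their methods do not match the real `Task` signatures.

```python
class ToyMeanTask:
    """Hypothetical task that fits one parameter w to the mean of its data."""

    def build_inputs(self, split):
        data = {'train': [1.0, 2.0, 3.0, 4.0], 'test': [2.0, 3.0]}
        return data[split]

    def build_model(self):
        return {'w': 0.0}

    def _mse(self, model, batch):
        return sum((model['w'] - x) ** 2 for x in batch) / len(batch)

    def train_step(self, model, batch, lr):
        # One gradient-descent step on the mean squared error.
        grad = sum(2 * (model['w'] - x) for x in batch) / len(batch)
        model['w'] -= lr * grad
        return {'training_loss': self._mse(model, batch)}

    def validation_step(self, model, batch):
        return {'validation_loss': self._mse(model, batch)}


def run_toy_experiment(task, train_steps, lr=0.1):
    """Generic driver: it knows nothing about the model, it only calls the task."""
    model = task.build_model()
    train_batch = task.build_inputs('train')
    for _ in range(train_steps):
        task.train_step(model, train_batch, lr)
    eval_logs = task.validation_step(model, task.build_inputs('test'))
    return model, eval_logs


toy_model, toy_logs = run_toy_experiment(ToyMeanTask(), train_steps=100)
print(round(toy_model['w'], 4), round(toy_logs['validation_loss'], 4))  # 2.5 0.25
```

Because the driver only depends on the method names, swapping in a different task changes what is trained without touching the loop, which is the benefit of keeping the two separate.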
with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')
images.shape: (128, 32, 32, 3) images.dtype: tf.float32
labels.shape: (128,)           labels.dtype: tf.int32
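Visualize the training data
The data loader applies z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be displayed directly by standard tools. The visualization code needs to rescale the data into the [0, 1] range.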
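The two transforms described above, z-score normalization for training and min-max rescaling for display, can be sketched on plain lists of RGB pixels. The MEAN_RGB and STDDEV_RGB values below are the common ImageNet statistics scaled to the 0 to 255 range; treat them as an assumption rather than the exact constants Model Garden uses. `z_score_normalize` and `rescale_to_unit_range` are illustrative names.

```python
# Assumed per-channel statistics (ImageNet mean/stddev scaled to 0..255).
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def z_score_normalize(pixels):
    """z-score each channel: (value - mean) / stddev."""
    return [tuple((v - m) / s for v, m, s in zip(px, MEAN_RGB, STDDEV_RGB))
            for px in pixels]


def rescale_to_unit_range(pixels):
    """Min-max rescale every value into [0, 1], as show_batch does for display."""
    flat = [v for px in pixels for v in px]
    lo, hi = min(flat), max(flat)
    delta = hi - lo
    return [tuple((v - lo) / delta for v in px) for px in pixels]


raw = [(0, 0, 0), (255, 255, 255), (128, 64, 32)]  # three RGB pixels
normalized = z_score_normalize(raw)                # contains negative values
display_ready = rescale_to_unit_range(normalized)  # every value in [0, 1]
```

Because each channel was shifted and scaled by a different amount, a single global min-max rescale only approximately restores the original colors; that is acceptable for a quick visual check.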
plt.hist(images.numpy().flatten());
Use ds_info (an instance of tfds.core.DatasetInfo) to look up the text description of each class ID.
label_info = ds_info.features['label']
label_info.int2str(1)
'automobile'
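A ClassLabel feature is essentially a two-way mapping between integer IDs and class names; on the real object, `label_info.names` lists all names and `label_info.str2int` performs the reverse lookup. The stand-in below hard-codes the CIFAR-10 class order for illustration only; in practice, read the list from `label_info.names` instead.

```python
# Plain-Python stand-in for the ClassLabel lookups above. The order matches
# CIFAR-10's label IDs as exposed by TFDS (1 -> 'automobile').
CIFAR10_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']
_NAME_TO_ID = {name: i for i, name in enumerate(CIFAR10_NAMES)}


def int2str(label_id: int) -> str:
    """Class ID -> class name."""
    return CIFAR10_NAMES[label_id]


def str2int(name: str) -> int:
    """Class name -> class ID."""
    return _NAME_TO_ID[name]


print(int2str(1))       # automobile
print(str2int('ship'))  # 8
```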
Visualize a batch of the data.
def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # Rescale the z-score-normalized batch into [0, 1] for display.
  # (min_val/max_val avoid shadowing the built-in min and max.)
  min_val = images.numpy().min()
  max_val = images.numpy().max()
  delta = max_val - min_val

  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i] - min_val) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      # Green title for a correct prediction, red for a wrong one.
      if labels[i] == predictions[i]:
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
    plt.axis("off")
plt.figure(figsize=(10, 10))
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)
Visualize the testing data
Visualize a batch of images from the validation dataset.
plt.figure(figsize=(10, 10));
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)
Train and evaluate
model, eval_logs = tfm.core.train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
params=exp_config,
model_dir=model_dir,
run_post_eval=True)
restoring or initializing model...
INFO:tensorflow:Customized initialization is done through the passed `init_fn`.
train | step: 0 | training until step 20...
train | step: 5 | steps/sec: 0.5 | output: {'accuracy': 0.103125, 'learning_rate': 0.0, 'top_5_accuracy': 0.4828125, 'training_loss': 2.7998607}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-5.
train | step: 10 | steps/sec: 0.8 | output: {'accuracy': 0.0828125, 'learning_rate': 0.0, 'top_5_accuracy': 0.4984375, 'training_loss': 2.8205295}
train | step: 15 | steps/sec: 0.8 | output: {'accuracy': 0.0921875, 'learning_rate': 0.0, 'top_5_accuracy': 0.503125, 'training_loss': 2.8169343}
train | step: 20 | steps/sec: 0.8 | output: {'accuracy': 0.1015625, 'learning_rate': 0.0, 'top_5_accuracy': 0.45, 'training_loss': 2.8760865}
eval | step: 20 | running 78 steps of evaluation...
eval | step: 20 | steps/sec: 24.4 | eval time: 3.2 sec | output: {'accuracy': 0.09485176, 'steps_per_second': 24.40085348913806, 'top_5_accuracy': 0.49589342, 'validation_loss': 2.5864375}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-20.
eval | step: 20 | running 78 steps of evaluation...
eval | step: 20 | steps/sec: 40.1 | eval time: 1.9 sec | output: {'accuracy': 0.09485176, 'steps_per_second': 40.14298727815298, 'top_5_accuracy': 0.49589342, 'validation_loss': 2.5864375}
# tf.keras.utils.plot_model(model, show_shapes=True)
Print the accuracy, top_5_accuracy, and validation_loss evaluation metrics.
for key, value in eval_logs.items():
  if isinstance(value, tf.Tensor):
    value = value.numpy()
  print(f'{key:20}: {value:.3f}')
accuracy            : 0.095
top_5_accuracy      : 0.496
validation_loss     : 2.586
steps_per_second    : 40.143
Run a batch of the processed training data through the model, and view the results.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device == 'CPU':
  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')
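The prediction step reduces each row of logits to its highest-scoring class with argmax, and show_batch colors the title green when that class matches the label. The plain-Python sketch below mirrors those two steps and also computes batch accuracy; `argmax` and `predict_and_color` are hypothetical helpers, and the toy logits are made up. As a reference point, guessing uniformly among 10 classes gives about 1/10 = 0.1 accuracy and 5/10 = 0.5 top-5 accuracy, close to the 0.095 and 0.496 reported after the short CPU run.

```python
def argmax(scores):
    """Index of the largest score (ties go to the first index in this sketch)."""
    return max(range(len(scores)), key=scores.__getitem__)


def predict_and_color(logits_batch, labels):
    """Argmax per row, then 'g' (correct) or 'r' (wrong) per example."""
    predictions = [argmax(row) for row in logits_batch]
    colors = ['g' if p == y else 'r' for p, y in zip(predictions, labels)]
    accuracy = colors.count('g') / len(colors)
    return predictions, colors, accuracy


toy_logits = [[0.1, 2.0, -1.0], [3.0, 0.5, 0.2], [0.0, 0.1, 0.9]]
toy_labels = [1, 2, 2]
preds, colors, acc = predict_and_color(toy_logits, toy_labels)
print(preds, colors, round(acc, 3))  # [1, 0, 2] ['g', 'r', 'g'] 0.667
```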
Export a SavedModel
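The keras.Model object returned by train_lib.run_experiment expects its input to be normalized the way the dataset loader does it, with the same mean and variance statistics used in preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). The export function handles these details, so you can pass tf.uint8 images and get correct results.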
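The idea of "preprocessing inside the exported function" can be illustrated without TensorFlow: a serving entry point accepts raw uint8 pixels, applies the same normalization the training pipeline used, and only then calls the model. This is a hypothetical sketch; `make_serving_fn` and `serve` are invented names, the MEAN_RGB and STDDEV_RGB values are assumed ImageNet statistics, and the real export is done by export_saved_model_lib.export_inference_graph.

```python
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)    # assumed statistics
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def make_serving_fn(model_fn):
    """Wrap a model that expects normalized floats so it accepts raw uint8 pixels."""

    def serve(uint8_pixels):
        for px in uint8_pixels:
            if any(not isinstance(v, int) or not 0 <= v <= 255 for v in px):
                raise ValueError(f'expected uint8 values, got {px}')
        # Same z-score normalization the training input pipeline applied.
        normalized = [tuple((v - m) / s for v, m, s in zip(px, MEAN_RGB, STDDEV_RGB))
                      for px in uint8_pixels]
        return model_fn(normalized)

    return serve


# Stand-in "model" that returns its input unchanged, to show the wiring.
serve = make_serving_fn(lambda x: x)
out = serve([(0, 0, 0), (255, 255, 255)])
```

Keeping the normalization inside the serving function means callers cannot accidentally feed the model unnormalized pixels, which is the failure mode the export step is designed to prevent.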
# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[32, 32],
params=exp_config,
checkpoint_path=tf.train.latest_checkpoint(model_dir),
export_dir='./export/')
INFO:tensorflow:Assets written to: ./export/assets
Test the exported model.
# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
Visualize the predictions.
plt.figure(figsize=(10, 10))
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device == 'CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')