Quantization aware training in Keras example


Overview

Welcome to an end-to-end example for quantization aware training.

Other pages

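For an introduction to what quantization aware training is and to determine whether you should use it (including what's supported), see the overview page.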

To quickly find the APIs you need for your use case (beyond fully quantizing a model with 8 bits), see the comprehensive guide.

Summary

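In this tutorial, you will:

  1. Train a Keras model for MNIST from scratch.
  2. Fine-tune the model by applying the quantization aware training API, check the accuracy, and export a quantization aware model.
  3. Use that model to create an actually quantized model for the TFLite backend.
  4. Confirm that accuracy persists in TFLite and that the model is roughly 4x smaller. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.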

Setup

 pip install -q tensorflow
 pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf

from tensorflow_model_optimization.python.core.keras.compat import keras

Train a model for MNIST without quantization aware training

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3232 - accuracy: 0.9089 - val_loss: 0.1358 - val_accuracy: 0.9618
<tf_keras.src.callbacks.History at 0x7fc8a42a07f0>

Clone and fine-tune the pre-trained model with quantization aware training

Define the model

You will apply quantization aware training to the whole model, which you can see in the model summary: all layers are now prefixed by "quant".

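Note that the resulting model is quantization aware but not quantized (for example, the weights are float32 instead of int8). The sections that follow show how to create a quantized model from the quantization aware one.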
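To make "quantization aware but not quantized" concrete, here is a minimal pure-Python sketch of fake quantization (quantize, then immediately dequantize), the kind of simulation quantization aware training inserts into the forward pass. The function `fake_quantize` is a hypothetical, simplified illustration and not the TFMOT implementation; real fake-quant ops also nudge the range so that 0.0 is exactly representable. The point it shows: the value stays a float, but it can only take one of 256 grid values, so training learns to tolerate 8-bit rounding.

```python
def fake_quantize(x, x_min, x_max, num_bits=8):
    """Quantize-then-dequantize x onto a uniform grid over [x_min, x_max].

    Hypothetical, simplified illustration: it does not adjust the range so
    that 0.0 lands exactly on the grid, as production fake-quant ops do.
    """
    levels = 2 ** num_bits - 1              # 255 steps between 256 grid points
    scale = (x_max - x_min) / levels        # float spacing between grid points
    clipped = min(max(x, x_min), x_max)     # out-of-range values saturate
    q = round((clipped - x_min) / scale)    # integer grid index in [0, levels]
    return x_min + q * scale                # still a Python float, snapped to the grid


# A float "weight" stays a float after fake quantization...
w = 0.5
w_fq = fake_quantize(w, -1.0, 1.0)
print(type(w_fq).__name__, w_fq)

# ...but a dense sweep of inputs collapses onto at most 256 distinct values.
inputs = [i / 1000.0 - 1.0 for i in range(2001)]   # 2001 points in [-1, 1]
distinct = {fake_quantize(v, -1.0, 1.0) for v in inputs}
print(len(distinct))
```

The rounding error for any in-range value is at most half the grid spacing, which is what the fine-tuning step learns to absorb.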

In the comprehensive guide, you can see how to quantize only some layers to improve model accuracy.

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer (QuantizeLa  (None, 28, 28)            3         
 yer)                                                            
                                                                 
 quant_reshape (QuantizeWra  (None, 28, 28, 1)         1         
 pperV2)                                                         
                                                                 
 quant_conv2d (QuantizeWrap  (None, 26, 26, 12)        147       
 perV2)                                                          
                                                                 
 quant_max_pooling2d (Quant  (None, 13, 13, 12)        1         
 izeWrapperV2)                                                   
                                                                 
 quant_flatten (QuantizeWra  (None, 2028)              1         
 pperV2)                                                         
                                                                 
 quant_dense (QuantizeWrapp  (None, 10)                20295     
 erV2)                                                           
                                                                 
=================================================================
Total params: 20448 (79.88 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 38 (152.00 Byte)
_________________________________________________________________
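The parameter counts in the summary can be checked by hand. The sketch below reproduces them under an assumed accounting of the extra non-trainable variables the quantization wrappers add (an `optimizer_step` counter per wrapper, min/max range trackers for inputs and outputs, and per-filter kernel ranges for Conv2D). This breakdown is an assumption that happens to match every printed number, not something taken from the TFMOT source.

```python
# Trainable parameters of the original layers (kernel + bias).
conv2d_params = 3 * 3 * 1 * 12 + 12      # 3x3 kernel, 1 input channel, 12 filters
dense_params = 13 * 13 * 12 * 10 + 10    # 2028 flattened inputs -> 10 logits

# Assumed non-trainable quantization variables per wrapped layer
# (an accounting that reproduces the summary, not read from TFMOT code).
extra = {
    "quantize_layer": 2 + 1,           # input min/max + optimizer_step
    "quant_reshape": 1,                # optimizer_step only
    "quant_conv2d": 12 * 2 + 2 + 1,    # per-filter kernel min/max, output min/max, optimizer_step
    "quant_max_pooling2d": 1,          # optimizer_step only
    "quant_flatten": 1,                # optimizer_step only
    "quant_dense": 2 + 2 + 1,          # kernel min/max, output min/max, optimizer_step
}

trainable = conv2d_params + dense_params
non_trainable = sum(extra.values())

print("quant_conv2d:", conv2d_params + extra["quant_conv2d"])
print("quant_dense:", dense_params + extra["quant_dense"])
print("trainable:", trainable, "non-trainable:", non_trainable,
      "total:", trainable + non_trainable)
```

The trainable count (20410) is exactly the parameter count of the baseline float model; everything quantization aware training adds is bookkeeping that does not receive gradients.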

Train and evaluate the model against the baseline

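To demonstrate fine-tuning after training the model for just one epoch, fine-tune with quantization aware training on a subset of the training data.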

train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 2s 480ms/step - loss: 0.1749 - accuracy: 0.9500 - val_loss: 0.1863 - val_accuracy: 0.9700
<tf_keras.src.callbacks.History at 0x7fc8a4006640>
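The step counts in the two progress bars (1688/1688 for the baseline, 2/2 for fine-tuning) follow from `validation_split` and `batch_size`. Keras holds out the last `validation_split` fraction of the samples for validation and uses a default `batch_size` of 32. The helper `steps_per_epoch` below is hypothetical, and its rounding of the split point is an assumption that is exact for the sizes used here.

```python
import math

def steps_per_epoch(num_samples, validation_split, batch_size=32):
    """Batches per training epoch when Keras holds out `validation_split`.

    Hypothetical helper: the rounding of the split point is an assumption
    that is exact for the dataset sizes used in this tutorial.
    """
    num_train = math.ceil(num_samples * (1.0 - validation_split))
    # The last batch may be partial, so round up.
    return math.ceil(num_train / batch_size)

print(steps_per_epoch(60000, 0.1))                 # 1688: 54000 samples / 32 = 1687.5
print(steps_per_epoch(1000, 0.1, batch_size=500))  # 2: 900 samples / 500 = 1.8
```

So the quantization aware fine-tuning here performs only two gradient updates on 900 images.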

In this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9550999999046326
Quant test accuracy: 0.9584000110626221

Create a quantized model for the TFLite backend

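After this, you will have an actually quantized model with int8 weights and int8 activations. By default the converted model's input and output tensors stay float32, which is why the evaluation code later feeds it float32 images.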

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp3vwylslo/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709987557.593385   26473 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987557.593449   26473 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
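In the TFLite 8-bit quantization scheme, each int8 value `q` stands for the real number `scale * (q - zero_point)`. The pure-Python sketch below shows how an asymmetric (scale, zero_point) pair can be chosen from an observed [min, max] range, which is the kind of range quantization aware training tracks during fine-tuning. The helper names are hypothetical, and TFLite uses a symmetric variant (zero_point 0, per-channel for convolutions) for weights, so treat this as an illustration of the idea rather than the converter's exact algorithm.

```python
def choose_quant_params(r_min, r_max, q_min=-128, q_max=127):
    """Pick (scale, zero_point) so that [r_min, r_max] maps onto [q_min, q_max]."""
    r_min, r_max = min(r_min, 0.0), max(r_max, 0.0)  # the range must contain 0.0
    if r_max == r_min:                                 # degenerate all-zero tensor
        return 1.0, 0
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = round(q_min - r_min / scale)          # integer that represents 0.0
    return scale, max(q_min, min(q_max, zero_point))


def quantize(r, scale, zero_point, q_min=-128, q_max=127):
    q = round(r / scale) + zero_point
    return max(q_min, min(q_max, q))                   # saturate to the int8 range


def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)                    # real = scale * (q - zero_point)


scale, zp = choose_quant_params(-1.0, 3.0)   # hypothetical activation range
q = quantize(1.234, scale, zp)
print(scale, zp, q, dequantize(q, scale, zp))
```

Choosing `zero_point` as an integer guarantees that 0.0 (for example ReLU outputs and padding) is represented exactly, while every other in-range value is off by at most half a `scale` step.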

See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
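The model's final `Dense(10)` layer outputs logits (it was compiled with `from_logits=True`), yet the helper picks the digit "with highest probability" by calling `argmax` on the raw output. That works because softmax is strictly increasing, so it never changes which index is largest. The pure-Python sketch below checks this and mirrors the accuracy computation; the logits and labels are hypothetical values, not model outputs.

```python
import math

def softmax(logits):
    m = max(logits)                                  # subtract max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(predicted, labels):
    # Fraction of positions where the prediction equals the label.
    return sum(p == y for p, y in zip(predicted, labels)) / len(labels)

logits = [1.0, 3.5, -2.0, 0.3]                       # hypothetical, shortened logit vector
print(argmax(logits), argmax(softmax(logits)))       # 1 1
print(accuracy([7, 2, 1, 0], [7, 2, 1, 4]))          # 0.75
```

Skipping the softmax saves work per image without changing any prediction.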

Evaluate the quantized model and confirm that the accuracy from TensorFlow persists to the TFLite backend.

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.9584
Quant TF test accuracy: 0.9584000110626221
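The two printed accuracies differ only in trailing digits. Both are consistent with 9,584 of the 10,000 test images being classified correctly (not necessarily the same images): the TFLite value comes from a NumPy mean in float64, while the Keras metric is accumulated in float32. The sketch below assumes that count and rounds the float64 result to float32 with the standard `struct` module to show where the extra digits come from.

```python
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest float32, returned as a float."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

correct, total = 9584, 10000           # count implied by the printed accuracies
acc_float64 = correct / total          # like NumPy's .mean() on a boolean array
acc_float32 = to_float32(acc_float64)  # like a Keras metric stored as float32
print(acc_float64)
print(acc_float32)
```

The float32 value differs from 0.9584 by about 1e-8, far below one test image (1e-4), so the two backends agree.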

See the roughly 4x smaller model from quantization

Create a float TFLite model and compare file sizes: the quantized TFLite model is roughly 4x smaller (about 3.4x for the sizes printed below).

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmppztp93bk/assets
Float model in Mb: 0.08068466186523438
Quantized model in Mb: 0.0236053466796875
W0000 00:00:1709987559.464849   26473 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987559.464884   26473 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
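Why "4x" and not exactly 4x? A float32 parameter takes 4 bytes and an int8 weight takes 1 byte, but in the TFLite int8 scheme biases are stored as int32, and each file also carries fixed overhead (graph structure, tensor metadata, quantization scales and zero points). The worked arithmetic below uses the baseline model's layer shapes and the file sizes printed above (in MiB, i.e. bytes / 2**20, despite the "Mb" label); the overhead figures it derives are inferences from those numbers, not measurements of individual file sections.

```python
# Parameter counts of the baseline float model.
conv_weights, conv_biases = 3 * 3 * 1 * 12, 12
dense_weights, dense_biases = 2028 * 10, 10
total_params = conv_weights + conv_biases + dense_weights + dense_biases   # 20410

# Ideal payload sizes in bytes.
float_bytes = 4 * total_params                        # every parameter as float32
int8_bytes = (1 * (conv_weights + dense_weights)      # int8 weights
              + 4 * (conv_biases + dense_biases))     # int32 biases
ideal_ratio = float_bytes / int8_bytes                # just under 4

# File sizes printed above, converted from MiB back to bytes.
float_file_bytes = 0.08068466186523438 * 2**20
quant_file_bytes = 0.0236053466796875 * 2**20
measured_ratio = float_file_bytes / quant_file_bytes

print("payload bytes:", float_bytes, int8_bytes, "ideal ratio:", round(ideal_ratio, 3))
print("file bytes:", float_file_bytes, quant_file_bytes, "measured ratio:", round(measured_ratio, 3))
print("overhead bytes:", float_file_bytes - float_bytes, quant_file_bytes - int8_bytes)
```

For a model this small, a few kilobytes of fixed overhead are a noticeable fraction of the file, which pulls the measured ratio down to about 3.4x; for larger models the ratio approaches 4x.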

Conclusion

In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then quantized models for the TFLite backend.

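You saw a model size compression benefit of roughly 4x for a model for MNIST (about 3.4x in this run), with minimal difference in accuracy. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.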

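We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.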