Pruning preserving quantization aware training (PQAT) Keras example

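Overview

This is an end-to-end example showing how to use the **pruning preserving quantization aware training (PQAT)** API, which is part of the collaborative optimization pipeline of the TensorFlow Model Optimization Toolkit.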

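Other pages

For an introduction to the pipeline and the other available techniques, see the collaborative optimization overview page.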

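Contents

In this tutorial, you will:

  1. Train a Keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with pruning using the sparsity API, and check the accuracy.
  3. Apply QAT and observe the loss of sparsity.
  4. Apply PQAT and observe that the sparsity applied earlier is preserved.
  5. Generate a TFLite model and observe the effect of applying PQAT on it.
  6. Compare the accuracy of the PQAT model with that of a model quantized using post-training quantization.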

Setup

You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, see the installation guide.

!pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a Keras model for MNIST without pruning

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
2024-03-09 12:40:49.225662: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3056 - accuracy: 0.9125 - val_loss: 0.1308 - val_accuracy: 0.9640
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1348 - accuracy: 0.9614 - val_loss: 0.0882 - val_accuracy: 0.9760
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0951 - accuracy: 0.9730 - val_loss: 0.0719 - val_accuracy: 0.9797
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0761 - accuracy: 0.9778 - val_loss: 0.0694 - val_accuracy: 0.9798
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0648 - accuracy: 0.9808 - val_loss: 0.0599 - val_accuracy: 0.9838
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0564 - accuracy: 0.9831 - val_loss: 0.0601 - val_accuracy: 0.9837
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0496 - accuracy: 0.9852 - val_loss: 0.0578 - val_accuracy: 0.9848
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0445 - accuracy: 0.9864 - val_loss: 0.0556 - val_accuracy: 0.9847
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0408 - accuracy: 0.9874 - val_loss: 0.0539 - val_accuracy: 0.9853
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0382 - accuracy: 0.9881 - val_loss: 0.0585 - val_accuracy: 0.9848
<tf_keras.src.callbacks.History at 0x7f06cb26d670>

Evaluate the baseline model and save it for later use

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9812999963760376
Saving model to:  /tmpfs/tmp/tmpgyooj7vn.h5
/tmpfs/tmp/ipykernel_34779/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

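Prune and fine-tune the model to 50% sparsity

Apply the prune_low_magnitude() API to prune the whole pre-trained model, to demonstrate and observe how effectively it reduces the zipped model size while maintaining accuracy. For how best to use the API to achieve the best compression rate while maintaining your target accuracy, see the pruning comprehensive guide.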
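
As a rough illustration of what magnitude-based pruning to 50% sparsity means, the following sketch zeroes the half of a weight list with the smallest absolute values. It uses plain Python lists; `prune_by_magnitude` and `sparsity_of` are hypothetical helper names made up for this illustration and are not part of the tfmot API, which applies pruning during training according to a pruning schedule.

```python
def prune_by_magnitude(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to 0.0.

    Hypothetical helper for illustration only; not the tfmot implementation.
    """
    num_to_prune = int(len(weights) * sparsity)
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:num_to_prune]:
        pruned[i] = 0.0
    return pruned


def sparsity_of(weights):
    """Fraction of entries that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


weights = [0.9, -0.05, 0.3, -0.7, 0.01, 0.2, -0.4, 0.08]
pruned = prune_by_magnitude(weights, sparsity=0.5)
print(pruned)
print(f"{sparsity_of(pruned):.2%} sparsity")
```

The four entries closest to zero are removed and the larger-magnitude weights are left untouched, which is why a short fine-tuning run is usually enough to recover most of the accuracy.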

Define the model and apply the sparsity API

The model needs to be pre-trained before using the sparsity API.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

pruned_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_dense   (None, 10)                40572     
 (PruneLowMagnitude)                                             
                                                                 
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________

Fine-tune the model and evaluate the accuracy against the baseline

Fine-tune the model with pruning for 3 epochs.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)
Epoch 1/3
1688/1688 [==============================] - 10s 4ms/step - loss: 0.0852 - accuracy: 0.9716 - val_loss: 0.0814 - val_accuracy: 0.9742
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0641 - accuracy: 0.9800 - val_loss: 0.0721 - val_accuracy: 0.9763
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0559 - accuracy: 0.9829 - val_loss: 0.0682 - val_accuracy: 0.9788
<tf_keras.src.callbacks.History at 0x7f06b00f8eb0>

Define a helper function to calculate and print the sparsity of the model.

def print_model_weights_sparsity(model):

    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

Check that the model was pruned correctly. We need to strip the pruning wrappers first.

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)
conv2d/kernel:0: 50.00% sparsity  (54/108)
conv2d/bias:0: 0.00% sparsity  (0/12)
dense/kernel:0: 50.00% sparsity  (10140/20280)
dense/bias:0: 0.00% sparsity  (0/10)

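In this example, the loss in test accuracy after pruning is small compared to the baseline (about 0.4 percentage points in the run shown below).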

_, pruned_model_accuracy = pruned_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', pruned_model_accuracy)
Baseline test accuracy: 0.9812999963760376
Pruned test accuracy: 0.9769999980926514

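Apply QAT and PQAT and check the effect on model sparsity in both cases

Next, we apply both QAT and pruning-preserving QAT (PQAT) to the pruned model and observe that PQAT preserves the sparsity of the pruned model. Note that we stripped the pruning wrappers from the pruned model with tfmot.sparsity.keras.strip_pruning before applying the PQAT API.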
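
One way to picture the difference: ordinary fine-tuning updates every weight, so pruned zeros drift away from zero, while a sparsity-preserving update re-applies a mask of the pruned positions. The sketch below is a conceptual illustration in plain Python under that assumption; `sgd_step`, `count_zeros`, and the sample values are hypothetical, and it is not a description of how the tfmot PQAT scheme is implemented internally.

```python
def sgd_step(weights, grads, lr, mask=None):
    """Apply one gradient step; if `mask` is given, re-zero the pruned positions."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    if mask is not None:
        updated = [u * m for u, m in zip(updated, mask)]
    return updated


def count_zeros(weights):
    return sum(1 for w in weights if w == 0.0)


# Hypothetical pruned weights (two of four entries are zero) and gradients.
pruned_weights = [0.9, 0.0, 0.3, 0.0]
grads = [0.1, -0.2, 0.05, 0.3]

# The mask is 1.0 where a weight survived pruning and 0.0 where it was pruned.
mask = [0.0 if w == 0.0 else 1.0 for w in pruned_weights]

plain = sgd_step(pruned_weights, grads, lr=0.1)
masked = sgd_step(pruned_weights, grads, lr=0.1, mask=mask)

print("zeros after plain update: ", count_zeros(plain))
print("zeros after masked update:", count_zeros(masked))
```

The unmasked update leaves no zeros, while the masked update keeps both pruned positions at zero. This mirrors the pattern in the sparsity report for the QAT and PQAT models.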

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_pruned_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_pruned_model)
pqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())

pqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pqat model:')
pqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0384 - accuracy: 0.9893 - val_loss: 0.0539 - val_accuracy: 0.9847
Train pqat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0395 - accuracy: 0.9890 - val_loss: 0.0543 - val_accuracy: 0.9850
<tf_keras.src.callbacks.History at 0x7f06cb25efa0>
print("QAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("PQAT Model sparsity:")
print_model_weights_sparsity(pqat_model)
QAT Model sparsity:
conv2d/kernel:0: 15.74% sparsity  (17/108)
conv2d/bias:0: 0.00% sparsity  (0/12)
dense/kernel:0: 11.48% sparsity  (2328/20280)
dense/bias:0: 0.00% sparsity  (0/10)
PQAT Model sparsity:
conv2d/kernel:0: 50.00% sparsity  (54/108)
conv2d/bias:0: 0.00% sparsity  (0/12)
dense/kernel:0: 50.00% sparsity  (10140/20280)
dense/bias:0: 0.00% sparsity  (0/10)

See the compression benefits of the PQAT model

Define a helper function that zips a model file and returns its compressed size.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000
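
Sparsity shows up in the zipped size because frequently repeated zero bytes are cheap for DEFLATE to encode. A minimal sketch, assuming synthetic byte "weights" rather than a real model file, and using `zlib` (the same DEFLATE algorithm that `zipfile.ZIP_DEFLATED` selects):

```python
import random
import zlib

random.seed(0)
num_weights = 20000

# Synthetic stand-in for quantized weights: random nonzero bytes.
dense = bytes(random.randrange(1, 256) for _ in range(num_weights))

# Zero out roughly half of the positions at random to mimic 50% sparsity.
sparse = bytes(0 if random.random() < 0.5 else b for b in dense)

dense_size = len(zlib.compress(dense))
sparse_size = len(zlib.compress(sparse))

print("dense compressed bytes: ", dense_size)
print("sparse compressed bytes:", sparse_size)
```

The sparse buffer compresses to fewer bytes than the dense one even though both have the same uncompressed length. The exact ratio depends on the data, so it shouldn't be read as a prediction for any particular model.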

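Since this is a small model, the difference between the two models isn't very noticeable. Applying pruning and PQAT to a larger production model would yield more significant compression.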

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pqat_tflite_model = converter.convert()
pqat_model_file = 'pqat_model.tflite'
# Save the model.
with open(pqat_model_file, 'wb') as f:
    f.write(pqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PQAT model size: ", get_gzipped_model_size(pqat_model_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpe5bjb60h/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709988173.208136   34779 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988173.208187   34779 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkxxlbf0o/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1709988175.442658   34779 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988175.442688   34779 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
QAT model size:  16.923  KB
PQAT model size:  14.491  KB

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the pruned and quantized model, then check whether the accuracy measured in TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(pqat_model_file)
interpreter.allocate_tensors()

pqat_test_accuracy = eval_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', pqat_test_accuracy)
print('Pruned TF test accuracy:', pruned_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy: 0.9821
Pruned TF test accuracy: 0.9769999980926514

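Apply post-training quantization and compare to the PQAT model

Next, we use normal post-training quantization (no fine-tuning) on the pruned model and check its accuracy against the PQAT model. This demonstrates why you would use PQAT to improve the quantized model's accuracy.

First, define a generator for the calibration dataset from the first 1000 training images.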

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]
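
The representative dataset lets the converter observe the range of values flowing through the model so that it can pick quantization parameters. The sketch below shows a simplified version of that idea, deriving an int8 scale and zero point from the min/max of a few calibration values. The function names and sample values are hypothetical, and the actual TFLite calibration is more involved.

```python
def affine_qparams(values, qmin=-128, qmax=127):
    """Derive (scale, zero_point) so that [min, max] of `values` maps onto int8."""
    lo = min(min(values), 0.0)  # make sure 0.0 is inside the range
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate case: all calibration values are zero
    zero_point = round(qmin - lo / scale)
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))  # clamp to the int8 range


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


# Hypothetical calibration values, e.g. normalized pixel intensities.
calibration = [0.0, 0.25, 0.5, 1.0]
scale, zero_point = affine_qparams(calibration)

for x in calibration:
    q = quantize(x, scale, zero_point)
    print(x, "->", q, "->", dequantize(q, scale, zero_point))
```

Because the range is forced to include 0.0, zero is represented exactly after the round trip, and every other in-range value comes back with an error of at most about half a quantization step.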

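Quantize the model and compare its accuracy with that of the PQAT model obtained earlier. Note that the model quantized with fine-tuning achieves higher accuracy.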

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('PQAT TFLite test_accuracy:', pqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp4xwi7ko3/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1709988177.152521   34779 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988177.152549   34779 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


PQAT TFLite test_accuracy: 0.9821
Post-training (no fine-tuning) TF test accuracy: 0.9764

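Conclusion

In this tutorial, you learned how to create a model, prune it using the sparsity API, and apply sparsity preserving quantization aware training (PQAT) to preserve sparsity while using QAT. The PQAT model was compared with a QAT model to show that the former preserves sparsity while the latter loses it. Next, the models were converted to TFLite to show the compression benefits of chaining the pruning and PQAT model optimization techniques, and the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend. Finally, the PQAT model was compared with a quantized pruned model produced using the post-training quantization API, to demonstrate the advantage of PQAT in recovering the accuracy lost to normal quantization.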