Sparsity and cluster preserving quantization aware training (PCQAT) Keras example


Overview

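This is an end-to-end example that shows how to use the sparsity and cluster preserving quantization aware training (PCQAT) API, which is part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.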

Other pages

For an introduction to the pipeline and the other available techniques, see the collaborative optimization overview page.

Contents

In this tutorial, you will:

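  1. Train a Keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with pruning, check its accuracy, and verify that the model was successfully pruned.
  3. Apply sparsity preserving clustering on the pruned model and check whether the previously applied sparsity is preserved.
  4. Apply QAT and observe the loss of sparsity and clusters.
  5. Apply PCQAT and check that both the sparsity and the clustering applied earlier are preserved.
  6. Generate a TFLite model and observe the effects of applying PCQAT to it.
  7. Compare the sizes of the different models to see the compression benefit of applying sparsity, followed by the collaborative optimization techniques sparsity preserving clustering and PCQAT.
  8. Compare the accuracy of the fully optimized model with the accuracy of the unoptimized baseline model.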

Setup

You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.

 pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a Keras model for MNIST to be pruned and clustered

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                      activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

opt = keras.optimizers.Adam(learning_rate=1e-3)

# Train the digit classification model
model.compile(optimizer=opt,
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3037 - accuracy: 0.9146 - val_loss: 0.1153 - val_accuracy: 0.9682
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1133 - accuracy: 0.9680 - val_loss: 0.0895 - val_accuracy: 0.9762
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0792 - accuracy: 0.9768 - val_loss: 0.0652 - val_accuracy: 0.9825
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0662 - accuracy: 0.9803 - val_loss: 0.0633 - val_accuracy: 0.9823
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0570 - accuracy: 0.9833 - val_loss: 0.0649 - val_accuracy: 0.9825
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0498 - accuracy: 0.9853 - val_loss: 0.0571 - val_accuracy: 0.9842
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0448 - accuracy: 0.9867 - val_loss: 0.0586 - val_accuracy: 0.9840
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0405 - accuracy: 0.9873 - val_loss: 0.0586 - val_accuracy: 0.9848
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0370 - accuracy: 0.9885 - val_loss: 0.0624 - val_accuracy: 0.9828
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0332 - accuracy: 0.9902 - val_loss: 0.0554 - val_accuracy: 0.9848
<tf_keras.src.callbacks.History at 0x7f615076beb0>
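The "1688/1688" in each epoch line is the number of optimizer steps per epoch. The short sketch below shows where that number comes from. It assumes 60,000 MNIST training images, that validation_split=0.1 keeps the first 90% for training, and that model.fit() falls back to Keras's default batch size of 32 when none is given. The helper steps_per_epoch is hypothetical and only reproduces the arithmetic.

```python
import math

# Assumed values: MNIST has 60,000 training images, and validation_split=0.1
# holds out the last 10% of them for validation.
NUM_TRAIN_IMAGES = 60_000
VALIDATION_SPLIT = 0.1


def steps_per_epoch(num_examples: int, validation_split: float, batch_size: int) -> int:
    """Optimizer steps per epoch; the last batch may be smaller than batch_size."""
    num_fit_examples = int(num_examples * (1 - validation_split))  # 54,000 here
    return math.ceil(num_fit_examples / batch_size)


# model.fit() without batch_size uses the Keras default of 32.
print(steps_per_epoch(NUM_TRAIN_IMAGES, VALIDATION_SPLIT, batch_size=32))   # 1688
# The QAT and PCQAT runs later in this tutorial pass batch_size=128.
print(steps_per_epoch(NUM_TRAIN_IMAGES, VALIDATION_SPLIT, batch_size=128))  # 422
```

The same arithmetic accounts for the "422/422" shown later for the QAT and PCQAT runs, which use batch_size=128.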

Evaluate the baseline model and save it for later use

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9835000038146973
Saving model to:  /tmpfs/tmp/tmpf70eijr3.h5
/tmpfs/tmp/ipykernel_41361/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

Prune and fine-tune the model to 50% sparsity

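Apply the prune_low_magnitude() API to obtain the pruned model that will be clustered in the next step. For more information on the pruning API, see the comprehensive pruning guide.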
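The idea behind low-magnitude pruning is to zero out the weights with the smallest absolute values until a target sparsity is reached. The NumPy-only sketch below illustrates that idea on a tensor shaped like this tutorial's Conv2D kernel. The function magnitude_prune is hypothetical. It is a one-shot simplification, not the tfmot implementation, which keeps a mask variable and updates it during training according to the pruning schedule.

```python
import numpy as np


def magnitude_prune(weights: np.ndarray, target_sparsity: float) -> np.ndarray:
    """Zero out the `target_sparsity` fraction of weights with the smallest |w|.

    Hypothetical one-shot illustration; tfmot instead maintains a pruning mask
    that is recomputed during training.
    """
    flat = np.abs(weights).ravel()
    k = int(round(target_sparsity * flat.size))  # number of weights to zero
    if k == 0:
        return weights.copy()
    # After partitioning at position k - 1, the first k entries are the k
    # smallest magnitudes (ties are broken arbitrarily).
    prune_idx = np.argpartition(flat, k - 1)[:k]
    mask = np.ones(flat.size, dtype=bool)
    mask[prune_idx] = False
    return weights * mask.reshape(weights.shape)


rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 1, 12))  # same shape as the Conv2D kernel here
pruned = magnitude_prune(kernel, 0.5)
print(np.count_nonzero(pruned == 0), "/", pruned.size)  # 54 / 108
```

Every weight that survives has a magnitude at least as large as every weight that was removed, which is what "low magnitude" refers to.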

Define the model and apply the sparsity API

Note that the pre-trained baseline model from the previous step is used as the starting point.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

Fine-tune the model, check sparsity, and evaluate the accuracy against the baseline

Fine-tune the model with pruning for 3 epochs.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)
Epoch 1/3
1688/1688 [==============================] - 10s 4ms/step - loss: 0.0851 - accuracy: 0.9707 - val_loss: 0.0801 - val_accuracy: 0.9768
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0591 - accuracy: 0.9801 - val_loss: 0.0672 - val_accuracy: 0.9808
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0493 - accuracy: 0.9852 - val_loss: 0.0626 - val_accuracy: 0.9837
<tf_keras.src.callbacks.History at 0x7f60c8593ee0>
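With ConstantSparsity(0.5, begin_step=0, frequency=100), the pruning mask is refreshed periodically rather than after every batch. The sketch below is a simplified model of that schedule. It assumes that the mask is updated at every step inside [begin_step, end_step] (with end_step=-1 meaning "no end") for which (step - begin_step) is a multiple of frequency, and that training steps are counted from 0. The function should_prune is hypothetical and is not the tfmot API. Because the target sparsity is constant, the first update already applies the full 50%.

```python
def should_prune(step: int, begin_step: int = 0, end_step: int = -1,
                 frequency: int = 100) -> bool:
    """Simplified model of when a constant-sparsity schedule updates the mask.

    Assumption: updates happen inside [begin_step, end_step] (end_step=-1 means
    "never stop") whenever (step - begin_step) is a multiple of frequency.
    """
    in_range = step >= begin_step and (end_step == -1 or step <= end_step)
    return in_range and (step - begin_step) % frequency == 0


steps_per_epoch = 1688             # from the training logs above
total_steps = 3 * steps_per_epoch  # 3 fine-tuning epochs -> steps 0..5063
update_steps = [s for s in range(total_steps) if should_prune(s)]
print(len(update_steps), update_steps[:3], update_steps[-1])  # 51 [0, 100, 200] 5000
```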

Define helper functions to compute and print the sparsity and the number of clusters of the model weights.

def print_model_weights_sparsity(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

First strip the pruning wrapper, then check that the model kernels were correctly pruned.

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)
conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)
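The totals in this output (108 and 20280) follow directly from the kernel shapes of the model defined at the start of the tutorial. The worked computation below derives them, assuming the Keras defaults of 'valid' padding for Conv2D and MaxPooling2D, and shows that 50% of each total gives exactly the 54 and 10140 zeros reported.

```python
# Shapes follow the model defined at the start of this tutorial.
input_h, input_w, in_channels = 28, 28, 1
filters, kernel_h, kernel_w = 12, 3, 3
pool = 2
num_classes = 10

# Conv2D kernel shape: (kernel_h, kernel_w, in_channels, filters); the bias
# is a separate weight and is not counted here.
conv_kernel_size = kernel_h * kernel_w * in_channels * filters

# 'valid' padding (the default) shrinks each side by kernel_size - 1, and
# 2x2 max pooling then halves each spatial dimension (floor division).
conv_h, conv_w = input_h - kernel_h + 1, input_w - kernel_w + 1  # 26 x 26
pool_h, pool_w = conv_h // pool, conv_w // pool                  # 13 x 13
flatten_size = pool_h * pool_w * filters                         # 2028

# Dense kernel shape: (flatten_size, num_classes).
dense_kernel_size = flatten_size * num_classes

print(conv_kernel_size, conv_kernel_size // 2)    # 108 54
print(dense_kernel_size, dense_kernel_size // 2)  # 20280 10140
```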

Apply sparsity preserving clustering and check its effect on model sparsity

Next, apply sparsity preserving clustering to the pruned model, observe the number of clusters, and check that the sparsity is preserved.

import tensorflow_model_optimization as tfmot
# The experimental clustering module provides a cluster_weights() variant
# that accepts the preserve_sparsity argument.
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# Use the experimental variant so that preserve_sparsity is supported.
cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)
Train sparsity preserving clustering model:
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0422 - accuracy: 0.9869 - val_loss: 0.0712 - val_accuracy: 0.9818
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0398 - accuracy: 0.9878 - val_loss: 0.0627 - val_accuracy: 0.9848
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0403 - accuracy: 0.9865 - val_loss: 0.0597 - val_accuracy: 0.9830
<tf_keras.src.callbacks.History at 0x7f6080153790>

First strip the clustering wrapper, then check that the model is correctly pruned and clustered.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)
Model sparsity:

kernel:0: 50.93% sparsity  (55/108)
kernel:0: 58.12% sparsity  (11787/20280)

Model clusters:

conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters
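To make "sparsity preserving clustering" more concrete, the sketch below runs plain k-means on only the non-zero weights of a tensor shaped like the Dense kernel, so that every pruned (zero) weight stays exactly zero and the remaining weights share at most 8 values. The function sparsity_preserving_kmeans is hypothetical. It uses random centroid initialization and a fixed number of iterations, whereas this tutorial configures tfmot with KMEANS_PLUS_PLUS initialization and then fine-tunes the centroids during training, so treat it as an illustration of the idea rather than the library's algorithm.

```python
import numpy as np


def sparsity_preserving_kmeans(weights, num_clusters, iters=20, seed=0):
    """Cluster only the non-zero weights; zeros are left untouched.

    Hypothetical illustration with random init, not the tfmot algorithm.
    """
    rng = np.random.default_rng(seed)
    flat = weights.ravel()
    nonzero = flat != 0
    values = flat[nonzero]
    centroids = rng.choice(values, size=num_clusters, replace=False)
    for _ in range(iters):
        # Assign every non-zero weight to its nearest centroid.
        assign = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
        for c in range(num_clusters):
            members = values[assign == c]
            if members.size:  # keep the old centroid if a cluster empties out
                centroids[c] = members.mean()
    assign = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
    out = np.zeros_like(flat)
    out[nonzero] = centroids[assign]  # replace each weight with its centroid
    return out.reshape(weights.shape)


rng = np.random.default_rng(1)
kernel = rng.normal(size=(2028, 10))  # same shape as the Dense kernel here
kernel[np.abs(kernel) < np.median(np.abs(kernel))] = 0.0  # 50% sparsity
clustered = sparsity_preserving_kmeans(kernel, num_clusters=8)

print(np.array_equal(clustered == 0, kernel == 0))  # zero mask unchanged
print(len(np.unique(clustered[clustered != 0])))    # at most 8 shared values
```

Counting the distinct non-zero values is the same check that print_model_weight_clusters performs with np.unique.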

Apply QAT and PCQAT and check the effect on model clusters and sparsity

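Next, apply both QAT and PCQAT to the sparse, clustered model and check whether PCQAT preserves the weight sparsity and clusters in your model. Note that the stripped model is passed to the QAT and PCQAT APIs.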

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PCQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
pcqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

pcqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0298 - accuracy: 0.9911 - val_loss: 0.0587 - val_accuracy: 0.9853
Train pcqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
422/422 [==============================] - 5s 7ms/step - loss: 0.0315 - accuracy: 0.9904 - val_loss: 0.0563 - val_accuracy: 0.9842
<tf_keras.src.callbacks.History at 0x7f6050606e80>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 100 clusters 
quant_dense/dense/kernel:0: 18251 clusters 

QAT Model sparsity:
conv2d/kernel:0: 8.33% sparsity  (9/108)
dense/kernel:0: 7.52% sparsity  (1525/20280)

PCQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters 

PCQAT Model sparsity:
conv2d/kernel:0: 50.93% sparsity  (55/108)
dense/kernel:0: 58.16% sparsity  (11794/20280)
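The output above shows QAT breaking the 8 clusters apart (100 and 18251 distinct values) and removing most zeros, while PCQAT keeps 8 clusters and the sparsity. The NumPy sketch below reproduces that contrast with one gradient step on synthetic data. It assumes a simplified model: QAT updates each latent weight with its own gradient, while a cluster and sparsity preserving scheme trains only the shared centroids (by the chain rule, a centroid's gradient is the sum of the gradients of the unmasked weights assigned to it) and re-applies the sparsity mask. The symmetric per-tensor fake_quant_int8 helper is also hypothetical. None of this reproduces the internals of Default8BitClusterPreserveQuantizeScheme.

```python
import numpy as np


def fake_quant_int8(w: np.ndarray) -> np.ndarray:
    """Symmetric per-tensor int8 quantize-dequantize (simplified scheme)."""
    scale = np.max(np.abs(w)) / 127.0
    return np.round(w / scale) * scale


rng = np.random.default_rng(0)
num_weights, num_clusters, lr = 20280, 8, 1e-3

mask = rng.random(num_weights) >= 0.5                 # sparsity mask, ~50% kept
indices = rng.integers(0, num_clusters, num_weights)  # fixed cluster assignment
centroids = rng.normal(size=num_clusters)
grad = rng.normal(size=num_weights)                   # stand-in dLoss/dWeight

w = centroids[indices] * mask  # sparse, clustered starting weights

# QAT-style step: every weight moves on its own gradient.
w_qat = w - lr * grad

# Cluster-preserving step: only centroids are trained, then the mask is reapplied.
centroid_grad = np.bincount(indices[mask], weights=grad[mask], minlength=num_clusters)
w_pcqat = (centroids - lr * centroid_grad)[indices] * mask
q_pcqat = fake_quant_int8(w_pcqat)  # quantizing <= 8 values keeps <= 8 values

for name, arr in [("QAT", w_qat), ("PCQAT", w_pcqat), ("PCQAT int8", q_pcqat)]:
    print(f"{name}: {len(np.unique(arr[arr != 0]))} distinct non-zero values, "
          f"{np.mean(arr == 0):.2%} sparsity")
```

In this toy setting the QAT-style step leaves almost every weight distinct with no zeros, while the cluster-preserving step keeps at most 8 distinct non-zero values and the original zero pattern, both before and after fake quantization.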

See the compression benefits of the PCQAT model

Define a helper function to get the size of the compressed model file.

def get_gzipped_model_size(file):
  # Returns the size, in kilobytes, of the model file after zip (DEFLATE) compression.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Observe that applying sparsity, clustering, and PCQAT to a model yields a significant compression benefit compared with QAT alone.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PCQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model.
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpbd29dk98/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709988717.237025   41361 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988717.237075   41361 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpy4q5o_1n/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1709988720.060897   41361 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988720.060927   41361 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
QAT model size:  13.958  KB
PCQAT model size:  7.876  KB
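In this run the compressed PCQAT file is about 44% smaller than the QAT file (7.876 KB versus 13.958 KB, roughly 1.77 times smaller). One way to see why is that weights drawn from only a few shared values, many of them zero, have low entropy, and DEFLATE (used by both zipfile and zlib) exploits that. The sketch below compares synthetic int8 arrays with the same element count as the Dense kernel. It assumes int8 weights, as after quantization, and ignores everything else stored in a TFLite file. The helper deflate_size_kb is hypothetical.

```python
import zlib

import numpy as np


def deflate_size_kb(arr: np.ndarray) -> float:
    """DEFLATE-compressed size in KB (1 KB = 1000 bytes, as in the helper above)."""
    return len(zlib.compress(arr.tobytes(), 9)) / 1000


rng = np.random.default_rng(0)
n = 20280  # same number of weights as the Dense kernel

# Unstructured int8 weights: 255 equally likely values.
dense = rng.integers(-127, 128, size=n).astype(np.int8)

# Clustered and sparse int8 weights: 8 shared values, then ~50% zeroed out.
centroids = rng.integers(-127, 128, size=8).astype(np.int8)
clustered = centroids[rng.integers(0, 8, size=n)]
clustered[rng.random(n) < 0.5] = 0

print(f"raw:       {n / 1000:.3f} KB each")
print(f"dense:     {deflate_size_kb(dense):.3f} KB")
print(f"clustered: {deflate_size_kb(clustered):.3f} KB")
```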

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the pruned, clustered, and quantized model, and then check whether the accuracy from TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()

pcqat_test_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned, clustered and quantized TFLite test_accuracy: 0.9806
Baseline TF test accuracy: 0.9835000038146973
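These two accuracies can be turned into image counts on the 10,000-image MNIST test set. The worked arithmetic below uses the values printed above. The baseline figure carries float32 rounding from Keras's evaluate(), so it is rounded to a whole number of images.

```python
NUM_TEST_IMAGES = 10_000  # size of the MNIST test split

baseline_accuracy = 0.9835000038146973  # Keras evaluate() result (float32)
pcqat_accuracy = 0.9806                 # TFLite result computed with NumPy

baseline_correct = round(baseline_accuracy * NUM_TEST_IMAGES)
pcqat_correct = round(pcqat_accuracy * NUM_TEST_IMAGES)

drop_points = (baseline_accuracy - pcqat_accuracy) * 100
print(baseline_correct, pcqat_correct, baseline_correct - pcqat_correct)  # 9835 9806 29
print(f"accuracy drop: {drop_points:.2f} percentage points")               # 0.29
```

In other words, the fully optimized TFLite model misclassifies 29 more test images than the unoptimized baseline.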

Conclusion

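In this tutorial, you learned how to create a model, prune it with the prune_low_magnitude() API, and apply sparsity preserving clustering with the cluster_weights() API so that sparsity is preserved while the weights are clustered.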

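Next, you applied sparsity and cluster preserving quantization aware training (PCQAT) to preserve model sparsity and clusters while using QAT. Finally, you compared the PCQAT model with the QAT model, showing that the former preserves sparsity and clusters while the latter loses them.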

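Then, you converted the models to TFLite to show the compression benefit of chaining the sparsity, clustering, and PCQAT model optimization techniques, and evaluated the TFLite model to make sure that the accuracy persists in the TFLite backend.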

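Finally, you compared the accuracy of the PCQAT TFLite model with that of the pre-optimization baseline model (0.9806 versus 0.9835 on the MNIST test set), showing that the collaborative optimization techniques achieved a compression benefit while keeping accuracy close to that of the original model.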