Cluster preserving quantization aware training (CQAT) Keras example


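Overview

This is an end-to-end example showing the usage of the cluster preserving quantization aware training (CQAT) API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.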

Other pages

For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.

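Contents

In this tutorial, you will:

  1. Train a Keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with clustering and check the accuracy.
  3. Apply QAT and observe the loss of clusters.
  4. Apply CQAT and observe that the clustering applied earlier has been preserved.
  5. Generate a TFLite model and observe the effects of applying CQAT on it.
  6. Compare the accuracy achieved with CQAT against a model quantized with post-training quantization.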

Setup

You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.

 pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a Keras model for MNIST without clustering

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                      activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
2024-03-09 12:45:06.324078: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3191 - accuracy: 0.9090 - val_loss: 0.1358 - val_accuracy: 0.9645
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1291 - accuracy: 0.9635 - val_loss: 0.0912 - val_accuracy: 0.9748
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0886 - accuracy: 0.9740 - val_loss: 0.0749 - val_accuracy: 0.9795
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0710 - accuracy: 0.9789 - val_loss: 0.0637 - val_accuracy: 0.9818
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0601 - accuracy: 0.9819 - val_loss: 0.0659 - val_accuracy: 0.9817
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0532 - accuracy: 0.9838 - val_loss: 0.0630 - val_accuracy: 0.9828
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0477 - accuracy: 0.9855 - val_loss: 0.0639 - val_accuracy: 0.9832
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0427 - accuracy: 0.9865 - val_loss: 0.0598 - val_accuracy: 0.9850
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0393 - accuracy: 0.9876 - val_loss: 0.0590 - val_accuracy: 0.9837
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0353 - accuracy: 0.9891 - val_loss: 0.0610 - val_accuracy: 0.9842
<tf_keras.src.callbacks.History at 0x7f22b1b9a0d0>

Evaluate the baseline model and save it for later use

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9818000197410583
Saving model to:  /tmpfs/tmp/tmpo7dgy4fg.h5
/tmpfs/tmp/ipykernel_38069/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

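Cluster and fine-tune the model with 8 clusters

Apply the cluster_weights() API to cluster the whole pre-trained model, to demonstrate and observe how it reduces the model size once zip compression is applied, while maintaining accuracy. For how to best use the API to achieve the best compression rate while maintaining your target accuracy, refer to the comprehensive clustering guide.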
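To make the idea concrete: weight clustering replaces each weight with the nearest of a small set of shared centroid values, so a layer could in principle be stored as the centroids plus a small integer index per weight, and the many repeated values compress well under zip. The sketch below is a minimal numpy illustration of that idea, not the tfmot implementation. It assumes plain 1-D k-means (Lloyd iterations) with linearly spaced initial centroids, whereas the tutorial uses CentroidInitialization.KMEANS_PLUS_PLUS and then updates the centroids during fine-tuning. The function name cluster_weights_1d is hypothetical.

```python
import numpy as np


def cluster_weights_1d(weights, number_of_clusters=8, iterations=20):
    """Illustrative 1-D k-means clustering of a weight array (not the tfmot code).

    Returns (clustered_weights, centroids, indices), where every clustered
    weight equals centroids[index] for its index.
    """
    flat = weights.reshape(-1).astype(np.float64)
    # Linearly spaced initial centroids between the smallest and largest weight.
    centroids = np.linspace(flat.min(), flat.max(), number_of_clusters)
    for _ in range(iterations):
        # Assignment step: index of the nearest centroid for every weight.
        indices = np.argmin(np.abs(flat[:, None] - centroids[None, :]), axis=1)
        # Update step: move each non-empty cluster's centroid to its members' mean.
        for k in range(number_of_clusters):
            members = flat[indices == k]
            if members.size > 0:
                centroids[k] = members.mean()
    # Final assignment against the converged centroids.
    indices = np.argmin(np.abs(flat[:, None] - centroids[None, :]), axis=1)
    clustered = centroids[indices].reshape(weights.shape)
    return clustered, centroids, indices.reshape(weights.shape)


# Synthetic stand-in for a trained kernel, same shape as this tutorial's Dense kernel.
rng = np.random.default_rng(0)
kernel = rng.normal(size=(2028, 10)).astype(np.float32)

clustered, centroids, indices = cluster_weights_1d(kernel)
print("distinct values before:", len(np.unique(kernel)))
print("distinct values after:", len(np.unique(clustered)))
```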

Define the model and apply the clustering API

The model needs to be pre-trained before the clustering API can be used.

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'cluster_per_channel': True,
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 cluster_reshape (ClusterWe  (None, 28, 28, 1)         0         
 ights)                                                          
                                                                 
 cluster_conv2d (ClusterWei  (None, 26, 26, 12)        324       
 ghts)                                                           
                                                                 
 cluster_max_pooling2d (Clu  (None, 13, 13, 12)        0         
 sterWeights)                                                    
                                                                 
 cluster_flatten (ClusterWe  (None, 2028)              0         
 ights)                                                          
                                                                 
 cluster_dense (ClusterWeig  (None, 10)                40578     
 hts)                                                            
                                                                 
=================================================================
Total params: 40902 (239.41 KB)
Trainable params: 20514 (80.13 KB)
Non-trainable params: 20388 (159.28 KB)
_________________________________________________________________

Fine-tune the model and evaluate the accuracy against the baseline

Fine-tune the model with clustering for 3 epochs.

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 11s 5ms/step - loss: 0.0316 - accuracy: 0.9909 - val_loss: 0.0610 - val_accuracy: 0.9837
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0297 - accuracy: 0.9916 - val_loss: 0.0603 - val_accuracy: 0.9852
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0291 - accuracy: 0.9919 - val_loss: 0.0596 - val_accuracy: 0.9850
<tf_keras.src.callbacks.History at 0x7f22946065e0>

Define a helper function to calculate and print the number of clusters in each kernel of the model.

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

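Check that the model kernels were correctly clustered. The clustering wrapper must be stripped first. Because cluster_per_channel is set to True, the Conv2D kernel is clustered separately for each of its 12 output channels, so the whole kernel can contain up to 12 × 8 = 96 distinct values, while the Dense kernel is clustered into 8.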

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 96 clusters 
dense/kernel:0: 8 clusters
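print_model_weight_clusters counts distinct values over a whole kernel, so for the per-channel-clustered Conv2D kernel it reports the total across all output channels. To inspect each channel on its own, a small numpy helper can count distinct values per output channel. This is a sketch under two assumptions: the hypothetical helper count_clusters_per_channel expects Keras's channels-last kernel layout, with output channels on the last axis, and the demo builds a synthetic kernel with the same (3, 3, 1, 12) shape as this tutorial's Conv2D kernel.

```python
import numpy as np


def count_clusters_per_channel(kernel):
    """Return the number of distinct values in each output channel of a kernel.

    Assumes the Keras layout, where the output-channel axis is the last axis
    (for Conv2D: [kernel_h, kernel_w, in_channels, out_channels]).
    """
    channels = kernel.reshape(-1, kernel.shape[-1])
    return [len(np.unique(channels[:, c])) for c in range(channels.shape[1])]


# Hypothetical per-channel clustered kernel: each of the 12 output channels
# draws its values from its own set of 8 centroids.
rng = np.random.default_rng(0)
centroids = rng.normal(size=(12, 8)).astype(np.float32)
indices = rng.integers(0, 8, size=(3, 3, 1, 12))
kernel = centroids[np.arange(12), indices]  # element [..., c] = centroids[c, indices[..., c]]

per_channel = count_clusters_per_channel(kernel)
print(per_channel)             # each entry is at most 8
print(len(np.unique(kernel)))  # whole-kernel count, at most 12 * 8 = 96
```

On the tutorial's model, it could be applied to the stripped Conv2D kernel (assuming that layer is reachable as stripped_clustered_model.layers[1], which holds when the Sequential model's layers list starts at the Reshape layer), for example count_clusters_per_channel(stripped_clustered_model.layers[1].kernel.numpy()).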

In this example, the test accuracy after clustering is essentially unchanged from the baseline.

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9818000197410583
Clustered test accuracy: 0.9818999767303467

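Apply QAT and CQAT and check the effect on model clusters in both cases

Next, we apply both QAT and cluster preserving QAT (CQAT) to the clustered model and observe that CQAT preserves the weight clusters of the clustered model. Note that we stripped the clustering wrappers from the model with tfmot.clustering.keras.strip_clustering before applying the CQAT API.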

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0315 - accuracy: 0.9905 - val_loss: 0.0573 - val_accuracy: 0.9855
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
422/422 [==============================] - 6s 8ms/step - loss: 0.0290 - accuracy: 0.9917 - val_loss: 0.0597 - val_accuracy: 0.9847
<tf_keras.src.callbacks.History at 0x7f229414f130>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters 
quant_dense/dense/kernel:0: 19910 clusters 
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 96 clusters 
quant_dense/dense/kernel:0: 8 clusters
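Why does plain QAT lose the clusters while CQAT keeps them? Quantization by itself cannot create new distinct values, because it maps each float to exactly one quantized value; the clusters disappear because QAT fine-tunes every weight independently. The numpy sketch below is a simplified model of the idea, not tfmot's implementation. It assumes that a cluster-preserving step updates shared centroids (each centroid receives the summed gradient of its member weights) while keeping every weight's cluster index fixed, and it uses random numbers as stand-in gradients. The symmetric per-tensor fake_quant_int8 function is likewise an assumption, not the exact tfmot quantization scheme.

```python
import numpy as np

rng = np.random.default_rng(0)
number_of_clusters = 8

# A clustered kernel: every weight is one of 8 shared centroid values.
centroids = rng.normal(size=number_of_clusters)
indices = rng.integers(0, number_of_clusters, size=2028 * 10)
weights = centroids[indices]

grads = rng.normal(size=weights.shape)  # random stand-in gradients, one per weight
lr = 1e-3

# QAT-style step (simplified): each weight is updated on its own,
# so weights that shared a value drift apart.
qat_weights = weights - lr * grads

# Cluster-preserving step (simplified): by the chain rule each centroid gets the
# summed gradient of its member weights; the weights are then rebuilt from the
# unchanged indices, so they still take at most 8 distinct values.
centroid_grads = np.bincount(indices, weights=grads, minlength=number_of_clusters)
cqat_weights = (centroids - lr * centroid_grads)[indices]


def fake_quant_int8(x):
    """Symmetric per-tensor int8 quantize-dequantize (a simplified assumption)."""
    scale = np.max(np.abs(x)) / 127.0
    return np.clip(np.round(x / scale), -127, 127) * scale


for name, w in [("QAT", qat_weights), ("CQAT", cqat_weights)]:
    print(name, "float values:", len(np.unique(w)),
          "after int8 fake quant:", len(np.unique(fake_quant_int8(w))))
```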

See the compression benefits of the CQAT model

Define a helper function to get the zipped model file size.

def get_gzipped_model_size(file):
  # Returns the size, in kilobytes, of the model file after zip (DEFLATE) compression.
  fd, zipped_file = tempfile.mkstemp('.zip')
  os.close(fd)  # only the path is needed; ZipFile reopens the file
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Note that this is a small model. Applying clustering and CQAT to a larger production model would yield more significant compression.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqnilvtco/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqnilvtco/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709988433.596770   38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988433.596818   38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqlp0tpfp/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqlp0tpfp/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
QAT model size:  17.487  KB
CQAT model size:  10.64  KB
W0000 00:00:1709988436.818205   38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988436.818236   38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
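The size gap comes from repetition: after clustering, each kernel holds only a handful of distinct values, so the serialized bytes contain many repeated patterns that DEFLATE (the algorithm behind zipfile.ZIP_DEFLATED in get_gzipped_model_size) can encode compactly. The sketch below shows the effect on a synthetic float32 kernel shaped like this tutorial's Dense kernel, using zlib, which implements DEFLATE. As a simplification of k-means clustering, it snaps weights to 8 evenly spaced quantiles. The sizes it prints depend on the random data and are not the tutorial's numbers.

```python
import zlib

import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for the Dense kernel (2028 x 10 float32 weights).
kernel = rng.normal(size=(2028, 10)).astype(np.float32)

# Simplified clustering: snap every weight to the nearest of 8 quantile values.
centroids = np.quantile(kernel, np.linspace(0.0625, 0.9375, 8)).astype(np.float32)
indices = np.argmin(np.abs(kernel[..., None] - centroids), axis=-1)
clustered_kernel = centroids[indices]

# zlib implements DEFLATE, the same compression that zipfile.ZIP_DEFLATED uses.
raw_size = len(zlib.compress(kernel.tobytes(), 9))
clustered_size = len(zlib.compress(clustered_kernel.tobytes(), 9))
print(f"distinct values: {len(np.unique(kernel))} -> {len(np.unique(clustered_kernel))}")
print(f"DEFLATE size: {raw_size} bytes -> {clustered_size} bytes")
```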

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the clustered and quantized model, and check that the accuracy from TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9822
Clustered TF test accuracy: 0.9818999767303467

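Apply post-training quantization and compare to the CQAT model

Next, we apply post-training quantization (no fine-tuning) to the clustered model and compare its accuracy with the CQAT model. This demonstrates why you would use CQAT to improve the quantized model's accuracy. The difference may not be very visible here, because the MNIST model is quite small and overparameterized.

First, define a generator for the calibration dataset from the first 1000 training images.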

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]
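The representative dataset is what the converter uses for calibration: it runs these samples through the model to observe the range of activation values, from which it derives the scale and zero point that map floats to int8. The sketch below shows a simplified asymmetric version of that mapping in numpy. The helper names (affine_int8_params, quantize, dequantize) are hypothetical, the uniform random batch merely stands in for mnist_representative_data_gen, and the TFLite converter's exact rules (for example, per-channel scales for weights) differ in details.

```python
import numpy as np


def affine_int8_params(calibration_values):
    """Derive an int8 scale and zero point from an observed [min, max] range.

    Simplified asymmetric scheme (an assumption, not the converter's exact rules).
    The range is widened to include 0 so that 0.0 is exactly representable.
    """
    lo = min(float(np.min(calibration_values)), 0.0)
    hi = max(float(np.max(calibration_values)), 0.0)
    scale = (hi - lo) / 255.0
    zero_point = int(round(-128 - lo / scale))
    return scale, zero_point


def quantize(x, scale, zero_point):
    # Map floats to int8 steps; values outside the calibrated range are clipped.
    return np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)


def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale


# Hypothetical calibration batch: pixel values in [0, 1], like the normalized MNIST images.
rng = np.random.default_rng(0)
calibration = rng.uniform(0.0, 1.0, size=(1000, 28, 28)).astype(np.float32)
scale, zero_point = affine_int8_params(calibration)

x = rng.uniform(0.0, 1.0, size=10).astype(np.float32)
error = np.abs(dequantize(quantize(x, scale, zero_point), scale, zero_point) - x)
print(scale, zero_point, error.max())
```

For values inside the calibrated range, the round-trip error stays within about half a quantization step (scale / 2); values outside the range are clipped, which is one reason the calibration samples should resemble real inputs.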

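Quantize the model and compare its accuracy with the CQAT model obtained earlier. Note that the model quantized with fine-tuning (CQAT) achieves slightly higher accuracy.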

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpxyohbvab/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpxyohbvab/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1709988438.608574   38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988438.608603   38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


CQAT TFLite test_accuracy: 0.9822
Post-training (no fine-tuning) TF test accuracy: 0.9817

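Conclusion

In this tutorial, you learned how to create a model, cluster it using the cluster_weights() API, and apply cluster preserving quantization aware training (CQAT) to preserve the clusters while using QAT. The final CQAT model was compared to the QAT one to show that the clusters are preserved in the former and lost in the latter. Next, the models were converted to TFLite to show the compression benefits of chaining the clustering and CQAT model optimization techniques, and the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend. Finally, the CQAT model was compared to a quantized clustered model produced with the post-training quantization API, to demonstrate the advantage of CQAT in recovering the accuracy loss caused by normal quantization.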