Sparsity preserving clustering Keras example

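Overview

This is an end-to-end example showing the usage of the sparsity preserving clustering API, which is part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.

Other pages

For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.

Contents

In this tutorial, you will:

  1. Train a keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with sparsity, check the accuracy, and observe that the model was successfully pruned.
  3. Apply weight clustering to the pruned model and observe the loss of sparsity.
  4. Apply sparsity preserving clustering to the pruned model and observe that the sparsity applied earlier has been preserved.
  5. Generate a TFLite model and check that the accuracy has been preserved in the pruned clustered model.
  6. Compare the sizes of the different models to observe the compression benefits of applying sparsity followed by the collaborative optimization technique of sparsity preserving clustering.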

Setup

You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.

! pip install -q tensorflow-model-optimization

import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a keras model for MNIST to be pruned and clustered

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
2024-03-09 12:54:09.347032: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3119 - accuracy: 0.9120 - val_loss: 0.1272 - val_accuracy: 0.9640
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1224 - accuracy: 0.9655 - val_loss: 0.0870 - val_accuracy: 0.9770
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0908 - accuracy: 0.9740 - val_loss: 0.0740 - val_accuracy: 0.9800
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0759 - accuracy: 0.9775 - val_loss: 0.0639 - val_accuracy: 0.9830
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0659 - accuracy: 0.9810 - val_loss: 0.0653 - val_accuracy: 0.9832
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0593 - accuracy: 0.9822 - val_loss: 0.0675 - val_accuracy: 0.9805
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0538 - accuracy: 0.9839 - val_loss: 0.0615 - val_accuracy: 0.9825
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0489 - accuracy: 0.9848 - val_loss: 0.0619 - val_accuracy: 0.9832
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0457 - accuracy: 0.9862 - val_loss: 0.0639 - val_accuracy: 0.9838
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0422 - accuracy: 0.9870 - val_loss: 0.0593 - val_accuracy: 0.9835
<tf_keras.src.callbacks.History at 0x7f0f2dac7b80>

Evaluate the baseline model and save it for later usage

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.98089998960495
Saving model to:  /tmpfs/tmp/tmp98l4xiax.h5
/tmpfs/tmp/ipykernel_44770/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

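Prune and fine-tune the model to 50% sparsity

Apply the prune_low_magnitude() API to prune the whole pre-trained model and obtain the model that will be clustered in the next step. For how best to use the API to achieve the best compression rate while maintaining your target accuracy, refer to the pruning comprehensive guide.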
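
To make the idea behind magnitude pruning concrete, here is a minimal NumPy sketch that zeroes out the smallest-magnitude half of a kernel in one shot. The helper name prune_by_magnitude is hypothetical and not part of TFMOT; the real API applies the mask during training according to the pruning schedule, so treat this only as an illustration of what "50% sparsity" means.

```python
import numpy as np

def prune_by_magnitude(weights, target_sparsity):
    """Zero out the smallest-magnitude entries (illustrative helper, not a TFMOT API)."""
    flat = np.abs(weights).ravel()
    num_to_prune = int(round(target_sparsity * flat.size))
    mask = np.ones(flat.size, dtype=bool)
    if num_to_prune > 0:
        # Indices of the num_to_prune smallest magnitudes.
        prune_idx = np.argpartition(flat, num_to_prune - 1)[:num_to_prune]
        mask[prune_idx] = False
    mask = mask.reshape(weights.shape)
    return weights * mask, mask

rng = np.random.default_rng(0)
# Random stand-in with the same shape as the Conv2D kernel in this tutorial.
kernel = rng.normal(size=(3, 3, 1, 12))
pruned_kernel, mask = prune_by_magnitude(kernel, 0.5)

zero_num = np.count_nonzero(pruned_kernel == 0)
print(f"{zero_num / pruned_kernel.size:.2%} sparsity ({zero_num}/{pruned_kernel.size})")
```

Every surviving weight has a magnitude at least as large as every pruned one, which is the property the sparsity check later in this tutorial relies on.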

Define the model and apply the sparsity API

Note that the pre-trained model is used.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

pruned_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_dense   (None, 10)                40572     
 (PruneLowMagnitude)                                             
                                                                 
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________

Fine-tune the model, check sparsity, and evaluate the accuracy against baseline

Fine-tune the model with pruning for 3 epochs.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)
Epoch 1/3
1688/1688 [==============================] - 10s 4ms/step - loss: 0.0805 - accuracy: 0.9729 - val_loss: 0.0834 - val_accuracy: 0.9753
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0636 - accuracy: 0.9791 - val_loss: 0.0735 - val_accuracy: 0.9798
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0557 - accuracy: 0.9825 - val_loss: 0.0688 - val_accuracy: 0.9813
<tf_keras.src.callbacks.History at 0x7f0f2d9bc220>

Define a helper function to calculate and print the sparsity of the model.

def print_model_weights_sparsity(model):

    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

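Check that the model kernels were correctly pruned. We need to strip the pruning wrappers first. We also create a deep copy of the model to be used in the next step.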

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

stripped_pruned_model_copy = keras.models.clone_model(stripped_pruned_model)
stripped_pruned_model_copy.set_weights(stripped_pruned_model.get_weights())
conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)

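Apply clustering and sparsity preserving clustering and check its effect on model sparsity in both cases

Next, we apply both clustering and sparsity preserving clustering to the pruned model and observe that only the latter preserves the sparsity of your pruned model. Note that we stripped the pruning wrappers from the pruned model with tfmot.sparsity.keras.strip_pruning before applying the clustering API.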

# Clustering
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train clustering model:')
clustered_model.fit(train_images, train_labels,epochs=3, validation_split=0.1)


stripped_pruned_model.save("stripped_pruned_model_clustered.h5")

# Sparsity preserving clustering
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model_copy, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels,epochs=3, validation_split=0.1)
Train clustering model:
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0477 - accuracy: 0.9846 - val_loss: 0.0659 - val_accuracy: 0.9813
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0464 - accuracy: 0.9851 - val_loss: 0.0611 - val_accuracy: 0.9825
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0452 - accuracy: 0.9855 - val_loss: 0.0728 - val_accuracy: 0.9797
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_keras/src/engine/training.py:3098: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  saving_api.save_model(
Train sparsity preserving clustering model:
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0471 - accuracy: 0.9853 - val_loss: 0.0669 - val_accuracy: 0.9823
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0446 - accuracy: 0.9863 - val_loss: 0.0661 - val_accuracy: 0.9817
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0435 - accuracy: 0.9857 - val_loss: 0.0714 - val_accuracy: 0.9813
<tf_keras.src.callbacks.History at 0x7f0e4415cc40>

Check sparsity for both models.

print("Clustered Model sparsity:\n")
print_model_weights_sparsity(clustered_model)
print("\nSparsity preserved clustered Model sparsity:\n")
print_model_weights_sparsity(sparsity_clustered_model)
Clustered Model sparsity:

conv2d/kernel:0: 0.00% sparsity  (0/108)
dense/kernel:0: 0.34% sparsity  (69/20280)

Sparsity preserved clustered Model sparsity:

conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)
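
The difference between the two results above can be illustrated with a small NumPy sketch. It assumes, as a simplification, that sparsity is lost because the centroid sitting at zero is trainable like any other and drifts during fine-tuning, and that it is preserved by re-applying the zero mask to the clustered weights. The centroid values and the drift are made up for illustration; this is not the TFMOT implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the pruned Dense kernel: 50% of the entries are exactly zero.
weights = rng.normal(size=(2028, 10))
order = np.argsort(np.abs(weights), axis=None)
weights.flat[order[: weights.size // 2]] = 0.0
mask = weights != 0

# Hypothetical centroids: one at zero plus seven spread over the weight range.
centroids = np.sort(np.append(np.linspace(weights.min(), weights.max(), 7), 0.0))
# Map every weight to the index of its nearest centroid.
idx = np.argmin(np.abs(weights[..., None] - centroids), axis=-1)

# Simulated fine-tuning: every centroid is trainable, so each one drifts a little.
tuned = centroids + rng.normal(scale=0.01, size=centroids.shape)

plain = tuned[idx]             # pruned positions follow the drifted "zero" centroid
preserved = tuned[idx] * mask  # the mask forces pruned positions back to zero

def sparsity(w):
    return np.count_nonzero(w == 0) / w.size

print(f"before clustering:   {sparsity(weights):.2%}")
print(f"plain clustering:    {sparsity(plain):.2%}")
print(f"sparsity preserving: {sparsity(preserved):.2%}")
```

In this sketch the plain variant ends with no exact zeros, while the masked variant keeps all 10140 of them, mirroring the pattern in the output above.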

Create 1.6x smaller models from clustering

Define a helper function to get the zipped model file.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

# Clustered model
clustered_model_file = 'clustered_model.h5'

# Save the model.
clustered_model.save(clustered_model_file)

#Sparsity Preserve Clustered model
sparsity_clustered_model_file = 'sparsity_clustered_model.h5'

# Save the model.
sparsity_clustered_model.save(sparsity_clustered_model_file)

print("Clustered Model size: ", get_gzipped_model_size(clustered_model_file), ' KB')
print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')
Clustered Model size:  247.191  KB
Sparsity preserved clustered Model size:  155.272  KB
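
Why sparsity and shared weight values help a DEFLATE-based compressor can be seen on a raw weight buffer, without any model file. The following sketch uses zlib (the same DEFLATE algorithm as zipfile.ZIP_DEFLATED) on random stand-in weights; the eight evenly spaced levels are a hypothetical substitute for learned centroids, and the sizes it prints are not comparable to the .h5 sizes above.

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
dense = rng.normal(size=20280).astype(np.float32)

# 50% sparsity: zero out the smaller-magnitude half of the weights.
pruned = dense.copy()
order = np.argsort(np.abs(pruned))
pruned[order[: pruned.size // 2]] = 0.0

# Hypothetical clustering: snap each weight to the nearest of 8 evenly spaced levels,
# then restore the zeros so that the sparsity is preserved.
levels = np.linspace(pruned.min(), pruned.max(), 8).astype(np.float32)
clustered = levels[np.argmin(np.abs(pruned[:, None] - levels), axis=1)]
clustered[pruned == 0] = 0.0

def deflated_kb(array):
    """Size of the DEFLATE-compressed raw buffer in kilobytes."""
    return len(zlib.compress(array.tobytes(), 9)) / 1000

sizes = {
    "dense": deflated_kb(dense),
    "pruned": deflated_kb(pruned),
    "pruned + clustered": deflated_kb(clustered),
}
for name, size in sizes.items():
    print(f"{name}: {size} KB")
```

Runs of zero bytes and a small set of repeated values both give the compressor more redundancy to exploit, which is why keeping the sparsity during clustering matters for the final size.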

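Create a TFLite model from combining sparsity preserving weight clustering and post-training quantization

Strip the clustering wrappers and convert to TFLite.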

stripped_sparsity_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_sparsity_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
sparsity_clustered_quant_model = converter.convert()

_, pruned_and_clustered_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_and_clustered_tflite_file, 'wb') as f:
  f.write(sparsity_clustered_quant_model)

print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')
print("Sparsity preserved clustered and quantized TFLite model size:",
       get_gzipped_model_size(pruned_and_clustered_tflite_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpzf7jus7v/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpzf7jus7v/assets
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709989008.133294   44770 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709989008.133348   44770 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
Sparsity preserved clustered Model size:  155.272  KB
Sparsity preserved clustered and quantized TFLite model size: 8.183  KB
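
One reason quantization combines well with the previous steps is that a symmetric integer mapping sends zero to zero and distinct centroids to distinct integers, so neither the sparsity nor the cluster structure has to be lost. The sketch below assumes a simplified symmetric per-tensor int8 scheme with made-up centroid values; the scheme the TFLite converter actually applies may differ in its details.

```python
import numpy as np

def quantize_symmetric_int8(weights):
    """Symmetric per-tensor int8 quantization (illustrative helper, not a TFLite API)."""
    scale = np.max(np.abs(weights)) / 127.0
    quantized = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return quantized, scale

rng = np.random.default_rng(0)
# Hypothetical clustered, 50% sparse kernel: zero plus seven shared nonzero values.
centroids = np.array([0.0, -0.9, -0.5, -0.2, 0.1, 0.3, 0.7, 1.2], dtype=np.float32)
probabilities = [0.5] + [0.5 / 7] * 7
weights = rng.choice(centroids, size=(2028, 10), p=probabilities)

quantized, scale = quantize_symmetric_int8(weights)
dequantized = quantized.astype(np.float32) * scale

print("zeros before/after:", np.count_nonzero(weights == 0), np.count_nonzero(quantized == 0))
print("unique values before/after:", len(np.unique(weights)), len(np.unique(quantized)))
print("max abs error:", np.max(np.abs(dequantized - weights)))
```

The rounding error per weight is bounded by half the scale, which is consistent with the small accuracy difference observed after conversion.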

See the persistence of accuracy from TF to TFLite

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

You evaluate the model, which has been pruned, clustered and quantized, and then see that the accuracy from TensorFlow persists in the TFLite backend.

# Keras model evaluation
stripped_sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
_, sparsity_clustered_keras_accuracy = stripped_sparsity_clustered_model.evaluate(
    test_images, test_labels, verbose=0)

# TFLite model evaluation
interpreter = tf.lite.Interpreter(pruned_and_clustered_tflite_file)
interpreter.allocate_tensors()

sparsity_clustered_tflite_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized Keras model accuracy:', sparsity_clustered_keras_accuracy)
print('Pruned, clustered and quantized TFLite model accuracy:', sparsity_clustered_tflite_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#13 is a dynamic-sized tensor).
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned, clustered and quantized Keras model accuracy: 0.9782999753952026
Pruned, clustered and quantized TFLite model accuracy: 0.9784
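
A note on the post-processing in eval_model: the model ends in Dense(10) without a softmax, so the interpreter returns logits rather than probabilities. Taking the argmax is still correct because softmax is monotonic and does not change which class is largest. The toy logits and labels below are invented purely to show this and the accuracy calculation.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax with the usual max-shift for numerical stability."""
    shifted = x - x.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Made-up logits for four samples and three classes, with their true labels.
logits = np.array([[0.1, 2.3, -1.0],
                   [1.5, 0.2, 0.3],
                   [-0.4, 0.0, 0.9],
                   [0.2, 0.1, 0.0]])
labels = np.array([1, 0, 2, 1])

pred_from_logits = np.argmax(logits, axis=1)
pred_from_probs = np.argmax(softmax(logits), axis=1)

# Same computation as in eval_model: fraction of predictions that match the labels.
accuracy = (pred_from_logits == labels).mean()
print(pred_from_logits, pred_from_probs, accuracy)  # [1 0 2 0] [1 0 2 0] 0.75
```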

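Conclusion

In this tutorial, you learned how to create a model, prune it using the prune_low_magnitude() API, and apply sparsity preserving clustering to preserve sparsity while clustering the weights. The sparsity preserving clustered model was compared with a clustered one to show that sparsity is preserved in the former and lost in the latter. Next, the pruned clustered model was converted to TFLite to show the compression benefits of chaining pruning and sparsity preserving clustering as model optimization techniques. Finally, the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend.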