Pruning in Keras example


Overview

Welcome to an end-to-end example of magnitude-based weight pruning.
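For intuition about what "magnitude-based" means, the sketch below zeros out the given fraction of weights with the smallest absolute values and keeps a binary mask recording which weights were removed. It is a simplified, pure-Python illustration under our own assumptions (a flat list of weights, one threshold for the whole list, ties broken by position). The function magnitude_prune is hypothetical and not part of the tfmot API, which works on tensors and keeps a mask per pruned layer.

```python
def magnitude_prune(weights, sparsity):
    """Hypothetical helper (not a tfmot API): zero out the `sparsity` fraction
    of weights with the smallest absolute value and return (pruned, mask)."""
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    num_to_prune = int(round(sparsity * len(weights)))
    # Indices ordered by magnitude, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_indices = set(order[:num_to_prune])
    # 0 marks a pruned weight, 1 a kept weight.
    mask = [0 if i in pruned_indices else 1 for i in range(len(weights))]
    pruned = [w if m else 0.0 for w, m in zip(weights, mask)]
    return pruned, mask


weights = [0.1, -0.5, 0.3, -0.05, 0.9]
pruned, mask = magnitude_prune(weights, sparsity=0.4)
print(pruned)  # [0.0, -0.5, 0.3, 0.0, 0.9]
print(mask)    # [0, 1, 1, 0, 1]
```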

Other pages

For an introduction to what pruning is and to help you decide whether to use it (including what is supported), see the overview page.

To quickly find the APIs you need for your use case (beyond fully pruning a model to 80% sparsity), see the comprehensive guide.

Summary

In this tutorial, you will:

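  1. Train a Keras model for MNIST from scratch.
  2. Fine-tune the model by applying the pruning API, and check the accuracy.
  3. Create 3x smaller TF and TFLite models from pruning.
  4. Create a 10x smaller TFLite model by combining pruning and post-training quantization.
  5. See the persistence of accuracy from TF to TFLite.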

Setup

! pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow_model_optimization.python.core.keras.compat import keras

%load_ext tensorboard

Train a model for MNIST without pruning

# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
Epoch 1/4
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3272 - accuracy: 0.9066 - val_loss: 0.1451 - val_accuracy: 0.9602
Epoch 2/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1396 - accuracy: 0.9602 - val_loss: 0.1017 - val_accuracy: 0.9698
Epoch 3/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0925 - accuracy: 0.9730 - val_loss: 0.0776 - val_accuracy: 0.9780
Epoch 4/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0726 - accuracy: 0.9787 - val_loss: 0.0653 - val_accuracy: 0.9822
<tf_keras.src.callbacks.History at 0x7f630863b790>

Evaluate the baseline test accuracy and save the model for later use.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
Baseline test accuracy: 0.9790999889373779
Saved baseline model to: /tmpfs/tmp/tmpq067am7o.h5
/tmpfs/tmp/ipykernel_10506/3790298460.py:7: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

Fine-tune the pre-trained model with pruning

Define the model

You will apply pruning to the whole model, which you can see in the model summary.

In this example, you start the model at 50% sparsity (50% of the weights are zero) and end at 80% sparsity.
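The ramp from 50% to 80% sparsity is controlled by the pruning schedule. The sketch below evaluates a polynomial-decay curve of the form final + (initial - final) * (1 - progress)^power, where progress is the clamped fraction of the way from begin_step to end_step. We assume an exponent of 3, which to our understanding is the default power argument of tfmot.sparsity.keras.PolynomialDecay, and we ignore its frequency argument, which limits how often masks are actually updated. The default end_step of 844 is what the tutorial's end_step evaluates to (54,000 training images, batch size 128, 2 epochs). The function polynomial_decay_sparsity is a hypothetical helper, not the library implementation.

```python
def polynomial_decay_sparsity(step, initial_sparsity=0.50, final_sparsity=0.80,
                              begin_step=0, end_step=844, power=3):
    """Hypothetical helper: target sparsity at `step` under a polynomial decay
    from initial_sparsity to final_sparsity (not the tfmot implementation)."""
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    # Sparsity rises quickly at first and flattens out near final_sparsity.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


for step in (0, 211, 422, 633, 844):
    print(step, round(polynomial_decay_sparsity(step), 4))
```

Halfway through the schedule (step 422) this curve already gives 0.8 - 0.3 * 0.5^3 = 0.7625, so most of the extra sparsity is introduced early, leaving later steps for the remaining weights to recover accuracy.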

In the comprehensive guide, you can see how to prune only some layers to improve model accuracy.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_dense   (None, 10)                40572     
 (PruneLowMagnitude)                                             
                                                                 
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________
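The parameter counts in this summary can be reconstructed by hand. The arithmetic below is our reading of the numbers, assuming that each PruneLowMagnitude wrapper adds one non-trainable pruning-step scalar, and that each prunable kernel additionally gets a non-trainable mask of the same shape plus a scalar threshold. Under those assumptions the totals match the summary exactly. The reshape, pooling, and flatten layers have no weights, so only their step counter (1 parameter each) appears.

```python
# Trainable weights of the original Conv2D and Dense layers.
conv_kernel = 3 * 3 * 1 * 12      # kernel height x width x in_channels x filters
conv_bias = 12
dense_kernel = 13 * 13 * 12 * 10  # 2028 flattened inputs x 10 units
dense_bias = 10
trainable = conv_kernel + conv_bias + dense_kernel + dense_bias

# Assumed pruning bookkeeping (non-trainable): one mask and one threshold per
# pruned kernel, plus one pruning-step counter per wrapped layer (5 layers).
masks = conv_kernel + dense_kernel
thresholds = 2
pruning_steps = 5
non_trainable = masks + thresholds + pruning_steps

print("Trainable params:", trainable)                  # 20410
print("Non-trainable params:", non_trainable)          # 20395
print("Total params:", trainable + non_trainable)      # 40805
```

The masks nearly double the number of stored values while training, which is one reason to strip the pruning wrappers before export.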

Train and evaluate the model against the baseline

Fine-tune with pruning for two epochs.

tfmot.sparsity.keras.UpdatePruningStep is required during training, and tfmot.sparsity.keras.PruningSummaries provides logs for tracking progress and debugging.
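To see why a step counter matters, the following pure-Python simulation keeps a counter that is advanced once per batch (the role we understand UpdatePruningStep to play) and re-prunes a toy weight vector to whatever sparsity the schedule targets at that step. Everything here (simulate_pruning, the linear target_sparsity schedule, the random toy weights) is a hypothetical stand-in; in the real library the counter lives inside each PruneLowMagnitude wrapper and masks are only updated every few steps. If the counter never advanced, the schedule would stay at its initial sparsity.

```python
import random


def target_sparsity(step, begin_step=0, end_step=100, initial=0.5, final=0.8):
    # Linear ramp used only for this simulation (the tutorial uses PolynomialDecay).
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return initial + (final - initial) * progress


def simulate_pruning(num_weights=1000, total_steps=120, update_step=True, seed=0):
    """Hypothetical simulation: re-prune toy weights to the scheduled sparsity each batch."""
    rng = random.Random(seed)
    weights = [rng.gauss(0.0, 1.0) for _ in range(num_weights)]
    step = 0  # stands in for the per-layer pruning step counter
    for _ in range(total_steps):
        # A real training step would update the unpruned weights here.
        num_to_prune = int(target_sparsity(step) * num_weights)
        by_magnitude = sorted(range(num_weights), key=lambda i: abs(weights[i]))
        for i in by_magnitude[:num_to_prune]:
            weights[i] = 0.0
        if update_step:
            step += 1  # the job we understand UpdatePruningStep to do
    return sum(1 for w in weights if w == 0.0) / num_weights


print("final sparsity with step updates:   ", simulate_pruning(update_step=True))
print("final sparsity without step updates:", simulate_pruning(update_step=False))
```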

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)
Epoch 1/2
  1/422 [..............................] - ETA: 19:24 - loss: 0.0619 - accuracy: 0.9766WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0067s). Check your callbacks.
422/422 [==============================] - 5s 6ms/step - loss: 0.1062 - accuracy: 0.9710 - val_loss: 0.1166 - val_accuracy: 0.9718
Epoch 2/2
422/422 [==============================] - 2s 5ms/step - loss: 0.1099 - accuracy: 0.9697 - val_loss: 0.0930 - val_accuracy: 0.9747
<tf_keras.src.callbacks.History at 0x7f6278426cd0>
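The 422 steps per epoch in this log follow from the same arithmetic used to compute end_step earlier. A quick worked check, using the 60,000 MNIST training images (train_images.shape[0]) and the 10% validation split:

```python
import math

num_train_images = 60000   # train_images.shape[0] for MNIST
validation_split = 0.1
batch_size = 128
epochs = 2

num_images = num_train_images * (1 - validation_split)  # 54000 images used for training
steps_per_epoch = math.ceil(num_images / batch_size)     # the last partial batch still counts
end_step = steps_per_epoch * epochs

print(steps_per_epoch, end_step)  # 422 844
```

So the pruning schedule finishes exactly at the end of the second epoch.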

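For this example, pruning costs only a small amount of test accuracy compared to the baseline (about one percentage point in the run shown here).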

_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)
Baseline test accuracy: 0.9790999889373779
Pruned test accuracy: 0.9686999917030334

The logs show how sparsity progresses for each layer.

%tensorboard --logdir={logdir}

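For non-Colab users, the results of a previous run of this code block were published on TensorBoard.dev. TensorBoard.dev has since been discontinued, so that link may no longer work.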

Create 3x smaller models from pruning

To see the compression benefits of pruning, you need both tfmot.sparsity.keras.strip_pruning and a standard compression algorithm (for example, gzip).

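  • strip_pruning is necessary because it removes every tf.Variable that pruning needs only during training; otherwise these variables would add to the model size at inference time.
  • Applying a standard compression algorithm is necessary because the serialized weight matrices are the same size as they were before pruning. However, pruning turns most of the weights into zeros, and this added redundancy is what the compression algorithm exploits to shrink the model further.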
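The second point can be checked without TensorFlow. The sketch below serializes the same number of float32 values twice, once dense and once with the 80% smallest-magnitude values set to zero, and compresses both with zlib (an implementation of DEFLATE, the algorithm used by gzip and by zipfile.ZIP_DEFLATED). The random weights and the 80% figure are illustrative assumptions chosen to mirror this tutorial's final sparsity.

```python
import random
import struct
import zlib

rng = random.Random(42)
num_weights = 20000
dense = [rng.uniform(-1.0, 1.0) for _ in range(num_weights)]

# Zero out the 80% of weights with the smallest magnitude.
cutoff = sorted(abs(w) for w in dense)[int(0.8 * num_weights)]
sparse = [w if abs(w) >= cutoff else 0.0 for w in dense]

# Serialize both as little-endian float32, similar to how a weight file stores them.
dense_bytes = struct.pack(f"<{num_weights}f", *dense)
sparse_bytes = struct.pack(f"<{num_weights}f", *sparse)

print("raw sizes:       ", len(dense_bytes), len(sparse_bytes))
print("compressed sizes:", len(zlib.compress(dense_bytes)), len(zlib.compress(sparse_bytes)))
```

The raw sizes are identical (4 bytes per weight either way); only the compressed size of the sparse buffer drops substantially.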

First, create a compressible model for TensorFlow.

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
Saved pruned Keras model to: /tmpfs/tmp/tmphtdbakhm.h5
/tmpfs/tmp/ipykernel_10506/3267383138.py:4: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
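To confirm that the stripped model really is sparse without opening TensorBoard, you can count zeros per weight tensor. The helper layer_sparsity below is a hypothetical pure-Python sketch that takes a mapping from a name to a flat sequence of numbers. With the real model you might build that mapping as {w.name: w.numpy().ravel().tolist() for w in model_for_export.weights}; we have not run that line against this tutorial's model. After pruning to 80%, we would expect the kernels to be close to 80% zeros, while the biases, which to our understanding prune_low_magnitude leaves unpruned by default, should stay near 0%.

```python
def layer_sparsity(named_weights):
    """Hypothetical helper: fraction of exactly-zero entries per weight tensor.

    `named_weights` maps a name to a flat sequence of numbers.
    """
    report = {}
    for name, values in named_weights.items():
        values = list(values)
        zeros = sum(1 for v in values if v == 0)
        # Guard against empty tensors to avoid dividing by zero.
        report[name] = zeros / len(values) if values else 0.0
    return report


toy = {
    "conv2d/kernel": [0.0, 0.2, 0.0, 0.0, -0.1],
    "dense/bias": [0.3, -0.2],
}
print(layer_sparsity(toy))  # {'conv2d/kernel': 0.6, 'dense/bias': 0.0}
```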

Then, create a compressible model for TFLite.

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp0i4b2dha/assets
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709986617.944296   10506 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709986617.944339   10506 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
Saved pruned TFLite model to: /tmpfs/tmp/tmp1i3mzxcp.tflite

Define a helper function that actually compresses the models via gzip and measures the compressed size.

def get_gzipped_model_size(file):
  # Returns the size of the model after ZIP_DEFLATED compression (DEFLATE, the algorithm gzip uses), in bytes.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)
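Despite its name, get_gzipped_model_size writes a ZIP archive with DEFLATE compression rather than a .gz file. Both use the same compression algorithm, so the sizes should differ only by small container overheads. If you prefer to measure actual gzip output, a minimal alternative using Python's standard gzip module could look like this (the name get_gzip_size is our own):

```python
import gzip
import os
import tempfile


def get_gzip_size(path):
    """Size in bytes of the file at `path` after gzip compression, computed in memory."""
    with open(path, "rb") as f:
        return len(gzip.compress(f.read()))


# Usage example with a throwaway file of highly redundant bytes.
fd, tmp_path = tempfile.mkstemp(".bin")
with os.fdopen(fd, "wb") as f:
    f.write(b"\x00" * 100000)
print(os.path.getsize(tmp_path), get_gzip_size(tmp_path))
os.remove(tmp_path)
```

You could call get_gzip_size(pruned_tflite_file) in place of get_gzipped_model_size; we have not compared its results with the sizes printed in this tutorial. Unlike the original helper, this version does not leave a temporary archive on disk.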

Compare the sizes and see that pruning makes the models 3x smaller.

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))
Size of gzipped baseline Keras model: 78239.00 bytes
Size of gzipped pruned Keras model: 25908.00 bytes
Size of gzipped pruned TFlite model: 24848.00 bytes

Create a 10x smaller model by combining pruning and quantization

You can apply post-training quantization to the pruned model for additional benefits.
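To our understanding, setting converter.optimizations = [tf.lite.Optimize.DEFAULT] without a representative dataset makes the converter apply dynamic range quantization, which stores weights as 8-bit integers instead of 32-bit floats. The sketch below uses a simplified symmetric, per-tensor int8 scheme written for illustration; TFLite's actual scheme can differ (for example, by using per-channel scales). It shows why the two techniques combine well: each weight shrinks from 4 bytes to 1, and the zeros produced by pruning stay exactly zero, so the redundancy that gzip exploits survives quantization. Both functions are hypothetical helpers.

```python
def quantize_symmetric_int8(weights):
    """Hypothetical helper: map floats to int8 values with one scale for the whole tensor."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp to the symmetric int8 range.
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    return [v * scale for v in q]


pruned_weights = [0.0, 0.42, 0.0, -0.9, 0.0, 0.05]
q, scale = quantize_symmetric_int8(pruned_weights)
print(q)  # zeros remain 0; the largest magnitude maps to -127
print([round(w, 3) for w in dequantize(q, scale)])
```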

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)

print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp35zwmyql/assets
W0000 00:00:1709986618.942100   10506 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709986618.942130   10506 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
Saved quantized and pruned TFLite model to: /tmpfs/tmp/tmp3v6lm0h4.tflite
Size of gzipped baseline Keras model: 78239.00 bytes
Size of gzipped pruned and quantized TFlite model: 8064.00 bytes
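The "3x" and "10x" figures in the section titles are rounded ratios of the gzipped sizes printed above. A quick check with the numbers copied from this run's output (your sizes will vary slightly between runs):

```python
baseline_keras = 78239
pruned_keras = 25908
pruned_tflite = 24848
pruned_quantized_tflite = 8064

for name, size in [("pruned Keras", pruned_keras),
                   ("pruned TFLite", pruned_tflite),
                   ("pruned + quantized TFLite", pruned_quantized_tflite)]:
    print(f"{name}: {baseline_keras / size:.2f}x smaller than the baseline")
```

In this run the ratios come out to about 3.02x, 3.15x, and 9.70x, so "10x" is a rounded figure.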

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
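The post-processing comment in evaluate_model refers to the digit with the highest probability, but the model's final Dense layer outputs raw logits (it was compiled with from_logits=True). Taking the argmax of the logits directly is still correct, because softmax preserves the ordering of the logits (p_i / p_j = exp(z_i - z_j)). A small pure-Python check of that property, with made-up logits:

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability; this does not change the result.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])


logits = [1.3, -0.7, 4.2, 0.0, 2.9]
probs = softmax(logits)
print(argmax(logits), argmax(probs))  # 2 2
```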

Evaluate the pruned and quantized model and see that the accuracy from TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#13 is a dynamic-sized tensor).
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy: 0.9691
Pruned TF test accuracy: 0.9686999917030334

Conclusion

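In this tutorial, you saw how to create sparse models for both TensorFlow and TFLite with the TensorFlow Model Optimization Toolkit API. You then combined pruning with post-training quantization for additional benefits.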

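You created a model for MNIST that is about 10x smaller (8,064 vs. 78,239 bytes gzipped), with only about a one-percentage-point drop in test accuracy relative to the baseline (0.9691 vs. 0.9791).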

We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.