Sparse weights using structural pruning


Structurally pruning the weights of your model so that they are sparse in a specific pattern can accelerate model inference time with appropriate hardware support.

This tutorial shows you how to:

  • Define and train a model on the mnist dataset with a specific structural sparsity
  • Convert the pruned model to tflite format
  • Visualize the structure of the pruned weights

For general background on pruning techniques for model optimization, see the pruning overview. For a tutorial on general weight pruning, see Pruning in Keras.

Structural pruning of weights

Structural pruning systematically zeroes out model weights at the beginning of the training process. You apply this pruning technique to regular blocks of weights to speed up inference on supporting hardware, for example: grouping the weights of the model into blocks of four and zeroing out two of the weights in each block, known as a 2 by 4 reduction. This technique applies only to the last dimension of the weight tensor for a model that is converted by TensorFlow Lite. For example, Conv2D layer weights in TensorFlow Lite have the structure [channel_out, height, width, channel_in] and Dense layer weights have the structure [channel_out, channel_in]. The sparsity pattern is applied to the weights in the last dimension: channel_in.
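To make the pattern concrete, here is a minimal NumPy sketch (not part of the tutorial code; the values are made up for illustration) of what a row of 2 by 4 structurally pruned weights looks like along the last dimension, and how the pattern can be checked:

import numpy as np

# Hypothetical weight row after 2 by 4 pruning: in every block of four
# consecutive values along the last dimension, at least two are zero.
pruned_row = np.array([0.0, 0.8, 0.0, -0.3,   # block 1: two zeros
                       0.5, 0.0, 0.0, 1.2])   # block 2: two zeros

# Check the pattern: reshape into blocks of four and count zeros per block.
blocks = pruned_row.reshape(-1, 4)
zeros_per_block = np.sum(blocks == 0, axis=1)
print(zeros_per_block)               # [2 2]
print(np.all(zeros_per_block >= 2))  # True: the (2, 4) pattern holds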

Compared to random sparsity, structured sparsity generally has lower accuracy because of the restrictive structure; however, it can significantly reduce inference time on supporting hardware.

Pruning can be applied to a model together with other model compression techniques for a better compression rate. See the quantization and clustering examples in the collaborative optimization technique for more details.
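As one simple illustration of stacking compression steps, the sketch below (not from this tutorial; it assumes a pruned Keras model whose pruning wrappers have already been stripped, as done later in this tutorial) applies plain post-training dynamic-range quantization during TFLite conversion. Note that this is not the sparsity-preserving collaborative optimization described in the linked guide, just a minimal example.

import tensorflow as tf

def convert_pruned_model_with_quantization(stripped_pruned_model):
  # Post-training dynamic-range quantization on top of the sparse weights.
  converter = tf.lite.TFLiteConverter.from_keras_model(stripped_pruned_model)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  return converter.convert()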

Setup

Prepare your development environment and data.

 pip install -q tensorflow
 pip install -q tensorflow-model-optimization
 pip install -q matplotlib
import tensorflow as tf
from tensorflow import keras

import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

Download the MNIST dataset and normalize the image data.

# Load MNIST dataset.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

Define structural pruning parameters

Define parameters for pruning and specify the type of structural pruning. Set the pruning parameters to (2, 4). These settings mean that in a block of four elements, at least the two with the lowest magnitude are set to zero.

You don't have to set the pruning_schedule parameter. By default, the pruning mask is defined at the first step and is not updated during training.

pruning_params_2_by_4 = {
    'sparsity_m_by_n': (2, 4),
}

Define parameters for random pruning with a target sparsity of 50%.

pruning_params_sparsity_0_5 = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.5,
                                                              begin_step=0,
                                                              frequency=100)
}

Define the model architecture and specify which layers to prune. Structural pruning is applied based on the layers of the model you select.

In the example below, we prune only some of the layers: the second Conv2D layer and the first Dense layer.

Notice that the first Conv2D layer cannot be pruned structurally. To be pruned structurally, a layer should have more than one input channel; because the MNIST input has a single channel, the last dimension of this layer's kernel is 1 and cannot hold a block of four weights. Instead, we prune the first Conv2D layer with random pruning.

model = keras.Sequential([
    prune_low_magnitude(
        keras.layers.Conv2D(
            32, 5, padding='same', activation='relu',
            input_shape=(28, 28, 1),
            name="pruning_sparsity_0_5"),
        **pruning_params_sparsity_0_5),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    prune_low_magnitude(
        keras.layers.Conv2D(
            64, 5, padding='same',
            name="structural_pruning"),
        **pruning_params_2_by_4),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    keras.layers.Flatten(),
    prune_low_magnitude(
        keras.layers.Dense(
            1024, activation='relu',
            name="structural_pruning_dense"),
        **pruning_params_2_by_4),
    keras.layers.Dropout(0.4),
    keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                                                       Output Shape         Param #
=================================================================
 prune_low_magnitude_pruning_sparsity_0_5 (PruneLowMagnitude)       (None, 28, 28, 32)   1634
 max_pooling2d (MaxPooling2D)                                       (None, 14, 14, 32)   0
 prune_low_magnitude_structural_pruning (PruneLowMagnitude)         (None, 14, 14, 64)   102466
 batch_normalization (BatchNormalization)                           (None, 14, 14, 64)   256
 re_lu (ReLU)                                                       (None, 14, 14, 64)   0
 max_pooling2d_1 (MaxPooling2D)                                     (None, 7, 7, 64)     0
 flatten (Flatten)                                                  (None, 3136)         0
 prune_low_magnitude_structural_pruning_dense (PruneLowMagnitude)   (None, 1024)         6423554
 dropout (Dropout)                                                  (None, 1024)         0
 dense (Dense)                                                      (None, 10)            10250
=================================================================
Total params: 6538160 (24.94 MB)
Trainable params: 3274762 (12.49 MB)
Non-trainable params: 3263398 (12.45 MB)
_________________________________________________________________

Train and evaluate the model.

batch_size = 128
epochs = 2

model.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    verbose=0,
    callbacks=tfmot.sparsity.keras.UpdatePruningStep(),
    validation_split=0.1)

_, pruned_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)
print('Pruned test accuracy:', pruned_model_accuracy)
Pruned test accuracy: 0.9897000193595886

Remove the pruning wrapper so that it is not included in the model when you convert it to the TensorFlow Lite format.

model = tfmot.sparsity.keras.strip_pruning(model)

Convert the model to the tflite format

import tempfile

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

_, tflite_file = tempfile.mkstemp('.tflite')
print('Saved converted pruned model to:', tflite_file)
with open(tflite_file, 'wb') as f:
  f.write(tflite_model)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp04kvq4rj/assets
Saved converted pruned model to: /tmpfs/tmp/tmp218fgsbq.tflite
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709986802.425001   13320 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709986802.425052   13320 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.

Visualize and check weights

Now visualize the structure of the weights in the Dense layer that was pruned with 2 by 4 sparsity. Extract the weights from the tflite file.

# Load tflite file with the created pruned model
interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

details = interpreter.get_tensor_details()

# Weights of the dense layer that has been pruned.
tensor_name = 'structural_pruning_dense/MatMul'
detail = [x for x in details if tensor_name in x["name"]]

# We need the first layer.
tensor_data = interpreter.tensor(detail[0]["index"])()

To verify that we selected the correct pruned layer, print the shape of the weight tensor.

print(f"Shape of Dense layer is {tensor_data.shape}")
Shape of Dense layer is (1, 1024)

Now we visualize the structure for a small subset of the weight tensor. The structure of the weight tensor is sparse in the last dimension, using the (2, 4) pattern: two out of every four elements are zero. To make the visualization clearer, we replace all non-zero values with ones.

import matplotlib.pyplot as plt
import numpy as np

# The value 24 is chosen for convenience.
width = height = 24

subset_values_to_display = tensor_data[0:height, 0:width]

val_ones = np.ones([height, width])
val_zeros = np.zeros([height, width])
subset_values_to_display = np.where(abs(subset_values_to_display) > 0, val_ones, val_zeros)

Define an auxiliary function to draw separation lines so that the structure can be seen clearly.

def plot_separation_lines(height, width):

    block_size = [1, 4]

    # Add separation lines to the figure.
    num_hlines = int((height - 1) / block_size[0])
    num_vlines = int((width - 1) / block_size[1])
    line_y_pos = [y * block_size[0] for y in range(1, num_hlines + 1)]
    line_x_pos = [x * block_size[1] for x in range(1, num_vlines + 1)]

    for y_pos in line_y_pos:
        plt.plot([-0.5, width], [y_pos - 0.5 , y_pos - 0.5], color='w')

    for x_pos in line_x_pos:
        plt.plot([x_pos - 0.5, x_pos - 0.5], [-0.5, height], color='w')

Now visualize the subset of the weight tensor.

plot_separation_lines(height, width)

plt.axis('off')
plt.imshow(subset_values_to_display)
plt.colorbar()
plt.title("Structural pruning for Dense layer")
plt.show()


Visualize the weights for the Conv2D layer. The structural sparsity is applied in the last channel, similar to the Dense layer. As pointed out above, only the second Conv2D layer is structurally pruned.

# Get weights of the convolutional layer that has been pruned with 2 by 4 sparsity.
op_details = interpreter._get_ops_details()
op_name = 'CONV_2D'
op_detail = [x for x in op_details if op_name in x["op_name"]]
tensor_data = interpreter.tensor(op_detail[1]["inputs"][1])()
print(f"Shape of the weight tensor is {tensor_data.shape}")
Shape of the weight tensor is (64, 5, 5, 32)

Similar to the weights of the Dense layer, the last dimension of the kernel has a (2, 4) structure.

weights_to_display = tf.reshape(tensor_data, [tf.reduce_prod(tensor_data.shape[:-1]), -1])
weights_to_display = weights_to_display[0:width, 0:height]

val_ones = np.ones([height, width])
val_zeros = np.zeros([height, width])
subset_values_to_display = np.where(abs(weights_to_display) > 1e-9, val_ones, val_zeros)

plot_separation_lines(height, width)

plt.axis('off')
plt.imshow(subset_values_to_display)
plt.colorbar()
plt.title("Structurally pruned weights for Conv2D layer")
plt.show()


Let's see how those randomly pruned weights look. We extract them and display a subset of the weight tensor.

# Get weights of the convolutional layer that has been pruned with random pruning.
tensor_name = 'pruning_sparsity_0_5/Conv2D'
detail = [x for x in details if tensor_name in x["name"]]
tensor_data = interpreter.tensor(detail[0]["index"])()
print(f"Shape of the weight tensor is {tensor_data.shape}")
Shape of the weight tensor is (32, 5, 5, 1)
weights_to_display = tf.reshape(tensor_data, [tensor_data.shape[0],tf.reduce_prod(tensor_data.shape[1:])])
weights_to_display = weights_to_display[0:width, 0:height]

val_ones = np.ones([height, width])
val_zeros = np.zeros([height, width])
subset_values_to_display = np.where(abs(weights_to_display) > 0, val_ones, val_zeros)

plot_separation_lines(height, width)

plt.axis('off')
plt.imshow(subset_values_to_display)
plt.colorbar()
plt.title("Unstructed pruned weights for Conv2D layer")
plt.show()


The TensorFlow Model Optimization Toolkit includes a python script that can be used to check which layers of the model in a given tflite file have structurally pruned weights: check_sparsity_m_by_n.py. The following command demonstrates how to use this tool to check for 2 by 4 sparsity in a specific model.

 python3 ./tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py --model_tflite=pruned_model.tflite --m_by_n=2,4
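If the script is not available at that path in your environment, a rough equivalent check can be run directly on a weight tensor extracted with the interpreter as shown earlier. The sketch below (the helper name and tolerance are made up for illustration) mirrors the block-wise definition of (2, 4) sparsity used in this tutorial:

import numpy as np

def check_sparsity_2_by_4(weights, atol=1e-9):
    # Flatten everything except the last dimension, which holds the blocks.
    flat = np.reshape(weights, (-1, weights.shape[-1]))
    usable = (flat.shape[1] // 4) * 4  # ignore a trailing partial block
    blocks = flat[:, :usable].reshape(-1, 4)
    zeros_per_block = np.sum(np.abs(blocks) <= atol, axis=1)
    return bool(np.all(zeros_per_block >= 2))

# Example usage with a kernel extracted from the interpreter above:
# print(check_sparsity_2_by_4(tensor_data))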