Welcome to the comprehensive guide for Keras weight pruning.
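This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.
The following use cases are covered:
- Define and train a pruned model.
  - Sequential and Functional.
  - Keras model.fit and custom training loops.
- Checkpoint and deserialize a pruned model.
- Deploy a pruned model and see the compression benefits.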
For configuration of the pruning algorithm, refer to the tfmot.sparsity.keras.prune_low_magnitude API docs.
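If you only want to find the APIs you need and understand what they are for, you can run the setup code in this section without reading it.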
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tf_keras as keras
%load_ext tensorboard
import tempfile
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=20)
def setup_model():
    model = keras.Sequential([
        keras.layers.Dense(20, input_shape=input_shape),
        keras.layers.Flatten()
    ])
    return model
def setup_pretrained_weights():
    model = setup_model()
    model.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
    model.fit(x_train, y_train)
    _, pretrained_weights = tempfile.mkstemp('.tf')
    model.save_weights(pretrained_weights)
    return pretrained_weights
def get_gzipped_model_size(model):
    # Returns size of gzipped model, in bytes.
    import os
    import zipfile
    _, keras_file = tempfile.mkstemp('.h5')
    model.save(keras_file, include_optimizer=False)
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(keras_file)
    return os.path.getsize(zipped_file)
pretrained_weights = setup_pretrained_weights()
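Prune whole model (Sequential and Functional)
Tips for better model accuracy:
- Try "Prune some layers" to skip pruning the layers that reduce accuracy the most.
- It's generally better to fine-tune with pruning than to train from scratch.
To make the whole model train with pruning, apply tfmot.sparsity.keras.prune_low_magnitude to the model.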
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
model_for_pruning.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_2 (PruneLowMagnitude)     (None, 20)     822
 prune_low_magnitude_flatten_2 (PruneLowMagnitude)   (None, 20)     1
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________
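Conceptually, the pruning wrapper keeps a binary mask for each prunable weight. It zeroes the weights with the smallest absolute values until the target sparsity is reached. The sketch below illustrates this idea in plain NumPy for a single weight tensor. The helper name `magnitude_prune` is hypothetical, and this is not tfmot's implementation. Among other differences, the library raises sparsity gradually during training according to a pruning schedule.

```python
import numpy as np

def magnitude_prune(weights: np.ndarray, sparsity: float):
    """Return (pruned_weights, mask) with about `sparsity` of the entries zeroed.

    Hypothetical helper: drops the entries with the smallest absolute values.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    k = int(round(sparsity * weights.size))  # number of weights to drop
    mask = np.ones(weights.shape, dtype=weights.dtype)
    if k > 0:
        # Flat indices of the k smallest-magnitude weights.
        drop = np.argsort(np.abs(weights), axis=None)[:k]
        mask.flat[drop] = 0
    return weights * mask, mask

rng = np.random.default_rng(0)
# Same shape as the kernel of the Dense(20) layer on 20 inputs used above.
kernel = rng.standard_normal((20, 20)).astype(np.float32)
pruned, mask = magnitude_prune(kernel, sparsity=0.5)
print(int(mask.sum()), "of", mask.size, "weights kept")
```

Because the smallest-magnitude weights are removed first, every kept weight is at least as large in absolute value as every dropped one.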
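Prune some layers (Sequential and Functional)
Tips for better model accuracy:
- It's generally better to fine-tune with pruning than to train from scratch.
- Try pruning the later layers instead of the first layers.
- Avoid pruning critical layers (e.g. attention mechanisms).
The tfmot.sparsity.keras.prune_low_magnitude API docs provide details on how to vary the pruning configuration per layer.
In the example below, only the Dense layers are pruned.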
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
# Helper function uses `prune_low_magnitude` to make only the
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
    if isinstance(layer, keras.layers.Dense):
        return tfmot.sparsity.keras.prune_low_magnitude(layer)
    return layer
# Use `keras.models.clone_model` to apply `apply_pruning_to_dense`
# to the layers of the model.
model_for_pruning = keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_dense,
)
model_for_pruning.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_3 (PruneLowMagnitude)     (None, 20)     822
 flatten_3 (Flatten)                                 (None, 20)     0
=================================================================
Total params: 822 (3.21 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 402 (1.57 KB)
_________________________________________________________________
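This example uses the layer type to decide what to prune. The easiest way to prune a particular layer, however, is to set its name property and look for that name in the clone_function.
You can also apply prune_low_magnitude while defining the initial model, which is more readable. However, loading the weights afterwards does not work in the examples below. As a result, they cannot fine-tune a pretrained model with pruning and may be less accurate than the examples above.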
Functional example
# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(10))(i)
o = keras.layers.Flatten()(x)
model_for_pruning = keras.Model(inputs=i, outputs=o)
model_for_pruning.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 input_1 (InputLayer)                                [(None, 20)]   0
 prune_low_magnitude_dense_4 (PruneLowMagnitude)     (None, 10)     412
 flatten_4 (Flatten)                                 (None, 10)     0
=================================================================
Total params: 412 (1.61 KB)
Trainable params: 210 (840.00 Byte)
Non-trainable params: 202 (812.00 Byte)
_________________________________________________________________
Sequential example
# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = keras.Sequential([
    tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(20, input_shape=input_shape)),
    keras.layers.Flatten()
])
model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_5 (PruneLowMagnitude)     (None, 20)     822
 flatten_5 (Flatten)                                 (None, 20)     0
=================================================================
Total params: 822 (3.21 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 402 (1.57 KB)
_________________________________________________________________
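Prune custom Keras layer or modify parts of layer to prune
tfmot.sparsity.keras.PrunableLayer serves two use cases:
- Prune a custom Keras layer.
- Modify parts of a built-in Keras layer to prune.
For example, the API by default prunes only the kernel of the Dense layer. The example below prunes the bias as well.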
class MyDenseLayer(keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):
    def get_prunable_weights(self):
        # Prune bias also, though that usually harms model accuracy too much.
        return [self.kernel, self.bias]
# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = keras.Sequential([
    tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
    keras.layers.Flatten()
])
model_for_pruning.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                                               Output Shape   Param #
=================================================================
 prune_low_magnitude_my_dense_layer (PruneLowMagnitude)     (None, 20)     843
 flatten_6 (Flatten)                                        (None, 20)     0
=================================================================
Total params: 843 (3.30 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 423 (1.66 KB)
_________________________________________________________________
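When training with Model.fit, pass the tfmot.sparsity.keras.UpdatePruningStep callback so that it is called during training.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.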
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
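log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir),
]
model_for_pruning.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
model_for_pruning.fit(x_train, y_train, callbacks=callbacks, epochs=2)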
#docs_infra: no_execute
%tensorboard --logdir={log_dir}
If you are not using Colab, you can see the results of a previous run of this code block on TensorBoard.dev.
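In a custom training loop, call the hooks of the tfmot.sparsity.keras.UpdatePruningStep callback yourself during training.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.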
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Boilerplate
loss = keras.losses.categorical_crossentropy
optimizer = keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.
# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)
step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
    log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
    for _ in range(batches):
        step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback
        with tf.GradientTape() as tape:
            logits = model_for_pruning(x_train, training=True)
            loss_value = loss(y_train, logits)
        grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
        optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))
    step_callback.on_epoch_end(batch=unused_arg) # run pruning callback
#docs_infra: no_execute
%tensorboard --logdir={log_dir}
If you are not using Colab, you can see the results of a previous run of this code block on TensorBoard.dev.
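First, read the tfmot.sparsity.keras.prune_low_magnitude API docs to understand what a pruning schedule is and the math behind each type of pruning schedule.
Tips:
- As a quick test, try pruning a model to the final sparsity at the beginning of training. To do this, use a tfmot.sparsity.keras.ConstantSparsity schedule with begin_step set to 0. You might get lucky with good results.
- Do not prune too frequently, so that the model has time to recover. The pruning schedule provides a decent default frequency.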
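The schedule math can be sketched in a few lines of Python. The code assumes the polynomial-decay formula that we understand tfmot.sparsity.keras.PolynomialDecay to use. In that formula, the target sparsity moves from initial_sparsity to final_sparsity with exponent power, which defaults to 3. The code also assumes a gate that updates the mask only every `frequency` steps between begin_step and end_step, where a negative end_step means pruning never stops. The function names are hypothetical. Treat the API docs as authoritative.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Hypothetical: target sparsity at `step`, ramping over [begin_step, end_step]."""
    if end_step <= begin_step:
        raise ValueError("end_step must be greater than begin_step")
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power

def should_prune(step, begin_step, end_step, frequency):
    """Hypothetical: True if the mask would be updated at `step`.

    A negative end_step means pruning continues indefinitely.
    """
    in_range = step >= begin_step and (end_step < 0 or step <= end_step)
    return in_range and (step - begin_step) % frequency == 0

# Illustrative schedule: 0% -> 80% sparsity over steps 0..1000, mask updated every 100 steps.
for step in (0, 250, 500, 1000, 1500):
    s = polynomial_decay_sparsity(step, 0.0, 0.8, begin_step=0, end_step=1000)
    print(step, round(s, 4), should_prune(step, 0, 1000, 100))
```

Sparsity rises quickly at first and then flattens out toward final_sparsity, which gives the model time to recover between later pruning steps. A ConstantSparsity schedule, by contrast, returns the same target sparsity at every step in its range. With begin_step=0, the full target is therefore applied at the very first pruning step.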
You must preserve the optimizer step during checkpointing. This means that you can use Keras HDF5 models for checkpointing, but you cannot use Keras HDF5 weights.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
_, keras_model_file = tempfile.mkstemp('.h5')
# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
The above applies generally. The code below is needed only for the HDF5 model format (not for HDF5 weights or other formats).
# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
    loaded_model = keras.models.load_model(keras_model_file)
loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_6 (PruneLowMagnitude)     (None, 20)     822
 prune_low_magnitude_flatten_7 (PruneLowMagnitude)   (None, 20)     1
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________
Common mistake: you need both strip_pruning and a standard compression algorithm (e.g. gzip) to see the compression benefits of pruning.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Typically you train the model here.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("final model")
model_for_export.summary()
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
final model
Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense_7 (Dense)             (None, 20)                420
 flatten_8 (Flatten)         (None, 20)                0
=================================================================
Total params: 420 (1.64 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Size of gzipped pruned model without stripping: 3455.00 bytes
Size of gzipped pruned model with stripping: 2939.00 bytes
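The reason both steps matter can be illustrated without TensorFlow. Pruning zeroes values but does not change tensor shapes, so a float32 zero still occupies four bytes in an uncompressed file. strip_pruning removes the wrappers' extra variables (the final summary reports 0 non-trainable parameters). A general-purpose compressor is what turns the zeros into a smaller file. The NumPy and gzip sketch below compares the compressed size of a random float32 matrix before and after zeroing its smaller-magnitude half. It is an illustration only, not how tfmot or the helper above measures size. Exact byte counts depend on the data and the zlib version.

```python
import gzip

import numpy as np

rng = np.random.default_rng(0)
dense = rng.standard_normal((256, 256)).astype(np.float32)

# Zero the half of the entries with the smallest magnitude (simple magnitude pruning).
threshold = np.median(np.abs(dense))
sparse = np.where(np.abs(dense) > threshold, dense, 0.0).astype(np.float32)

raw_size = dense.nbytes  # identical for both arrays: each zero still takes 4 bytes
dense_gz = len(gzip.compress(dense.tobytes()))
sparse_gz = len(gzip.compress(sparse.tobytes()))
print(f"uncompressed:   {raw_size} bytes for both")
print(f"gzipped dense:  {dense_gz} bytes")
print(f"gzipped sparse: {sparse_gz} bytes")
```

Only the compressed sizes differ, which matches the common mistake described above: zeroed weights alone do not shrink the file.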
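Once different backends enable pruning to improve latency, using block sparsity can improve latency on certain hardware.
For details on what is supported for block sparsity, see the tfmot.sparsity.keras.prune_low_magnitude API docs.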
base_model = setup_model()
# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)
model_for_pruning.summary()
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_8 (PruneLowMagnitude)     (None, 20)     822
 prune_low_magnitude_flatten_9 (PruneLowMagnitude)   (None, 20)     1
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________
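With block_size=[1, 16], weights are pruned in groups of 16 consecutive entries along the last axis, so each block is kept or zeroed as a unit. The NumPy sketch below shows this idea for a 2-D kernel, under a few assumptions. Scoring each block by its mean absolute value is an assumption; prune_low_magnitude exposes a block_pooling_type option for this choice. The helper name is hypothetical. Unlike the library, the sketch simply requires the kernel shape to be divisible by the block size.

```python
import numpy as np

def block_magnitude_prune(kernel, block_size, sparsity):
    """Hypothetical: zero whole blocks of a 2-D kernel with the lowest mean |w|.

    Assumes kernel.shape is divisible by block_size.
    """
    br, bc = block_size
    rows, cols = kernel.shape
    if rows % br or cols % bc:
        raise ValueError("kernel shape must be divisible by block_size")
    # View as (rows/br, br, cols/bc, bc) and average |w| inside each block.
    blocks = np.abs(kernel).reshape(rows // br, br, cols // bc, bc)
    scores = blocks.mean(axis=(1, 3))  # one score per block
    k = int(round(sparsity * scores.size))  # number of blocks to drop
    block_mask = np.ones(scores.shape, dtype=kernel.dtype)
    if k > 0:
        block_mask.flat[np.argsort(scores, axis=None)[:k]] = 0
    # Expand the per-block mask back to per-weight resolution.
    mask = np.repeat(np.repeat(block_mask, br, axis=0), bc, axis=1)
    return kernel * mask, mask

rng = np.random.default_rng(0)
kernel = rng.standard_normal((32, 64)).astype(np.float32)
pruned, mask = block_magnitude_prune(kernel, block_size=(1, 16), sparsity=0.5)
print("fraction of weights kept:", float(mask.mean()))
```

Keeping or dropping contiguous 1x16 runs is what lets a kernel skip whole register-sized chunks, as the code comment above describes. The trade-off is that entire blocks are removed at once, which typically lowers the sparsity you can reach for a given accuracy compared with pruning individual weights.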