Pruning comprehensive guide


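Welcome to the comprehensive guide for Keras weight pruning.

This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.

  • If you want to see the benefits of pruning and what's supported, see the overview.
  • For a single end-to-end example, see the pruning example.

The following use cases are covered:

  • Define and train a pruned model.
    • Sequential and Functional.
    • Keras model.fit and custom training loops.
  • Checkpoint and deserialize a pruned model.
  • Deploy a pruned model and see compression benefits.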

For configuration of the pruning algorithm, refer to the tfmot.sparsity.keras.prune_low_magnitude API docs.
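As its name suggests, prune_low_magnitude removes the weights with the smallest absolute values. The following is a minimal NumPy sketch of that idea for a single weight tensor at a fixed target sparsity. magnitude_prune is a hypothetical helper written only for illustration; it is not part of tfmot, and the library's real implementation (which recomputes a mask during training according to a pruning schedule) differs in detail.

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Hypothetical helper: zero the `sparsity` fraction of smallest-|w| entries.

    Returns (pruned_weights, mask); mask is 1.0 where a weight is kept and
    0.0 where it was pruned. Ties between equal magnitudes are broken arbitrarily.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))  # number of entries to remove
    keep = np.ones(flat.size, dtype=np.float32)
    if k > 0:
        # argpartition places the k smallest magnitudes in the first k slots.
        keep[np.argpartition(flat, k - 1)[:k]] = 0.0
    mask = keep.reshape(weights.shape)
    return weights * mask, mask


# A 20x20 matrix, the same kernel shape as the Dense layer in setup_model().
rng = np.random.default_rng(0)
kernel = rng.standard_normal((20, 20)).astype(np.float32)
pruned, mask = magnitude_prune(kernel, sparsity=0.5)
print(f"sparsity achieved: {1 - mask.mean():.2f}")
```

Every kept weight has a magnitude at least as large as every pruned one; that ordering is what "low magnitude" refers to.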

Setup

Although you can run this section to find the APIs you need and for context, you can skip reading it.

! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tf_keras as keras

%load_ext tensorboard

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=20)

def setup_model():
  model = keras.Sequential([
      keras.layers.Dense(20, input_shape=input_shape),
      keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

setup_model()
pretrained_weights = setup_pretrained_weights()

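Define model

Prune whole model (Sequential and Functional)

Tips for better model accuracy:

  • Try "Prune some layers" to skip pruning the layers that reduce accuracy the most.
  • It's generally better to fine-tune with pruning as opposed to training from scratch.

To make the whole model train with pruning, apply tfmot.sparsity.keras.prune_low_magnitude to the model.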

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

model_for_pruning.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_dense_  (None, 20)                822       
 2 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_flatte  (None, 20)                1         
 n_2 (PruneLowMagnitude)                                         
                                                                 
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________

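Prune some layers (Sequential and Functional)

Pruning a model can have a negative effect on accuracy. You can selectively prune layers of a model to explore the trade-off between accuracy, speed, and model size.

Tips for better model accuracy:

  • It's generally better to fine-tune with pruning as opposed to training from scratch.
  • Try pruning the later layers instead of the first layers.
  • Avoid pruning critical layers (e.g. attention mechanism).

More information:

In the example below, prune only the Dense layers.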

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `prune_low_magnitude` to make only the 
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
  if isinstance(layer, keras.layers.Dense):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

# Use `keras.models.clone_model` to apply `apply_pruning_to_dense` 
# to the layers of the model.
model_for_pruning = keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_dense,
)

model_for_pruning.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_dense_  (None, 20)                822       
 3 (PruneLowMagnitude)                                           
                                                                 
 flatten_3 (Flatten)         (None, 20)                0         
                                                                 
=================================================================
Total params: 822 (3.21 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 402 (1.57 KB)
_________________________________________________________________

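While this example used the type of the layer to decide what to prune, the easiest way to prune a particular layer is to set its name property and look for that name in the clone_function.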

print(base_model.layers[0].name)
dense_3

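More readable but potentially lower model accuracy

This is not compatible with fine-tuning with pruning, which is why it may be less accurate than the examples above, which support fine-tuning.

While prune_low_magnitude can be applied when defining the initial model, loading the weights afterwards does not work in the examples below.

Functional example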

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(10))(i)
o = keras.layers.Flatten()(x)
model_for_pruning = keras.Model(inputs=i, outputs=o)

model_for_pruning.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 20)]              0         
                                                                 
 prune_low_magnitude_dense_  (None, 10)                412       
 4 (PruneLowMagnitude)                                           
                                                                 
 flatten_4 (Flatten)         (None, 10)                0         
                                                                 
=================================================================
Total params: 412 (1.61 KB)
Trainable params: 210 (840.00 Byte)
Non-trainable params: 202 (812.00 Byte)
_________________________________________________________________

Sequential example

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(20, input_shape=input_shape)),
  keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_dense_  (None, 20)                822       
 5 (PruneLowMagnitude)                                           
                                                                 
 flatten_5 (Flatten)         (None, 20)                0         
                                                                 
=================================================================
Total params: 822 (3.21 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 402 (1.57 KB)
_________________________________________________________________

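Prune custom Keras layer or modify parts of layer to prune

Common mistake: pruning the bias usually harms model accuracy too much.

tfmot.sparsity.keras.PrunableLayer serves two use cases:

  1. Prune a custom Keras layer.
  2. Modify parts of a built-in Keras layer to prune.

For example, the API defaults to pruning only the kernel of the Dense layer. The example below also prunes the bias.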

class MyDenseLayer(keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

  def get_prunable_weights(self):
    # Prune bias also, though that usually harms model accuracy too much.
    return [self.kernel, self.bias]

# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
  keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_my_den  (None, 20)                843       
 se_layer (PruneLowMagnitud                                      
 e)                                                              
                                                                 
 flatten_6 (Flatten)         (None, 20)                0         
                                                                 
=================================================================
Total params: 843 (3.30 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 423 (1.66 KB)
_________________________________________________________________
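The Param # values in the summaries above can be reconstructed by hand. The sketch below assumes that the PruneLowMagnitude wrapper adds one scalar pruning-step counter per wrapped layer, plus a mask of the same shape and a scalar threshold for every weight it prunes. That assumption reproduces the counts printed in this guide (822 for the kernel-only Dense layer, 843 when the bias is pruned too, 412 for the functional Dense(10) example) and also explains the single non-trainable parameter of a wrapped Flatten layer. pruned_dense_param_counts is a hypothetical helper, not a tfmot API.

```python
def pruned_dense_param_counts(input_dim, units, prune_bias=False):
    """Hypothetical helper: (trainable, non_trainable, total) for a wrapped Dense.

    Assumes the wrapper adds a scalar pruning_step, and a mask plus a scalar
    threshold for each pruned weight (the kernel, and optionally the bias).
    """
    kernel = input_dim * units
    bias = units
    trainable = kernel + bias
    non_trainable = 1                  # pruning_step counter
    non_trainable += kernel + 1        # kernel mask + kernel threshold
    if prune_bias:
        non_trainable += bias + 1      # bias mask + bias threshold
    return trainable, non_trainable, trainable + non_trainable


print(pruned_dense_param_counts(20, 20))                   # (420, 402, 822)
print(pruned_dense_param_counts(20, 20, prune_bias=True))  # (420, 423, 843)
print(pruned_dense_param_counts(20, 10))                   # (210, 202, 412)
```

In this model the bias is only 20 of the 420 trainable parameters (under 5%), so pruning it offers little extra size reduction in exchange for the accuracy risk noted above.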

Train model

Model.fit

Call the tfmot.sparsity.keras.UpdatePruningStep callback during training.

To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.
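UpdatePruningStep advances the wrapper's step counter as training batches run; the pruning schedule then uses that counter to decide whether the mask is recomputed at a given step. The sketch below models that decision in plain Python. It assumes the rule is "every frequency steps, counted from begin_step, while inside [begin_step, end_step], with end_step == -1 meaning no end"; should_update_mask is a hypothetical name, so consult the schedule classes in the API docs for the exact behavior.

```python
def should_update_mask(step, begin_step, end_step, frequency):
    """Hypothetical model of a pruning schedule's 'prune at this step?' check."""
    in_range = step >= begin_step and (end_step == -1 or step <= end_step)
    on_turn = (step - begin_step) % frequency == 0
    return in_range and on_turn


update_steps = [
    s for s in range(0, 1001)
    if should_update_mask(s, begin_step=0, end_step=1000, frequency=100)
]
print(update_steps[:3], len(update_steps))  # [0, 100, 200] 11
```

Between two updates the mask stays fixed, which gives the remaining weights a chance to recover; this is why a larger frequency means less frequent pruning.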

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]

model_for_pruning.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
)

model_for_pruning.fit(
    x_train,
    y_train,
    callbacks=callbacks,
    epochs=2,
)

#docs_infra: no_execute
%tensorboard --logdir={log_dir}

Non-Colab users can view the results of a previous run of this code block on TensorBoard.dev.

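Custom training loop

Call the tfmot.sparsity.keras.UpdatePruningStep callback during training, invoking its hooks yourself at the points where Model.fit would call them.

To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.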

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Boilerplate
loss = keras.losses.categorical_crossentropy
optimizer = keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)

step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
  log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
  for _ in range(batches):
    step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback

    with tf.GradientTape() as tape:
      logits = model_for_pruning(x_train, training=True)
      loss_value = loss(y_train, logits)
      grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
      optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))

  step_callback.on_epoch_end(batch=unused_arg) # run pruning callback

#docs_infra: no_execute
%tensorboard --logdir={log_dir}

Non-Colab users can view the results of a previous run of this code block on TensorBoard.dev.

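Improve pruned model accuracy

First, look at the tfmot.sparsity.keras.prune_low_magnitude API docs to understand what a pruning schedule is and the math behind each type of pruning schedule.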
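As a worked illustration of the schedule math, the sketch below computes the target sparsity of a polynomial-decay ramp. It assumes the formula described for tfmot.sparsity.keras.PolynomialDecay, sparsity = final + (initial - final) * (1 - progress) ** power, where progress is the clipped fraction of the [begin_step, end_step] window that has elapsed and power defaults to 3; polynomial_decay_sparsity is a hypothetical helper, so check the API docs before relying on the exact constants. By contrast, tfmot.sparsity.keras.ConstantSparsity holds one target sparsity from begin_step onward.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Hypothetical helper: target sparsity at `step` under a polynomial ramp."""
    # Fraction of the pruning window that has elapsed, clipped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


for step in (0, 250, 500, 1000, 2000):
    s = polynomial_decay_sparsity(step, 0.0, 0.8, begin_step=0, end_step=1000)
    print(f"step {step:5d}: target sparsity {s:.3f}")
```

With power=3 the sparsity rises quickly early in the window and flattens out near end_step, and it stays at the final value after end_step.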

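Tips:

  • Use a learning rate that's neither too high nor too low while the model is pruning. Consider the pruning schedule to be a hyperparameter.

  • As a quick test, try pruning the model to the final sparsity at the beginning of training by setting begin_step to 0 with a tfmot.sparsity.keras.ConstantSparsity schedule. You might get lucky with good results.

  • Do not prune too frequently, so the model has time to recover. The pruning schedule provides a decent default frequency.

  • For general ideas on improving model accuracy, look for tips for your use case(s) under "Define model".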

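Checkpoint and deserialize

You must preserve the optimizer step during checkpointing. This means that while you can use Keras HDF5 models for checkpointing, you cannot use Keras HDF5 weights.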

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

_, keras_model_file = tempfile.mkstemp('.h5')

# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)

The above applies generally. The code below is only needed for the HDF5 model format (not for HDF5 weights or other formats).

# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
  loaded_model = keras.models.load_model(keras_model_file)

loaded_model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_dense_  (None, 20)                822       
 6 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_flatte  (None, 20)                1         
 n_7 (PruneLowMagnitude)                                         
                                                                 
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________

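Deploy pruned models

Export model with size compression

Common mistake: both strip_pruning and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of pruning.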
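Pruning only sets weights to zero; a pruned tensor keeps its shape and dtype, so its uncompressed size does not shrink. The saving appears once a general-purpose compressor can exploit the zeros. The standalone sketch below illustrates this with Python's gzip module (DEFLATE, the same algorithm get_gzipped_model_size uses through zipfile.ZIP_DEFLATED) on a random NumPy array standing in for a weight tensor; it is not a real model, and the exact byte counts depend on the data.

```python
import gzip

import numpy as np

rng = np.random.default_rng(0)
dense = rng.standard_normal(10_000).astype(np.float32)

# Zero the 50% of entries with the smallest magnitude (element-wise pruning).
threshold = np.median(np.abs(dense))
sparse = np.where(np.abs(dense) > threshold, dense, np.float32(0))

raw_dense, raw_sparse = dense.tobytes(), sparse.tobytes()
gz_dense, gz_sparse = gzip.compress(raw_dense), gzip.compress(raw_sparse)

print("uncompressed bytes:", len(raw_dense), len(raw_sparse))
print("gzipped bytes:     ", len(gz_dense), len(gz_sparse))
```

Both arrays occupy the same number of raw bytes, but the pruned one compresses noticeably better, which mirrors why the size comparison below only becomes meaningful after compression.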

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Typically you train the model here.

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

print("final model")
model_for_export.summary()

print("\n")
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
final model
Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_7 (Dense)             (None, 20)                420       
                                                                 
 flatten_8 (Flatten)         (None, 20)                0         
                                                                 
=================================================================
Total params: 420 (1.64 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


Size of gzipped pruned model without stripping: 3455.00 bytes
Size of gzipped pruned model with stripping: 2939.00 bytes

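Hardware-specific optimizations

Once different backends enable pruning to improve latency, using block sparsity can improve latency for certain hardware.

Increasing the block size will decrease the peak sparsity that's achievable for a target model accuracy. Despite this, latency can still improve.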
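The trade-off can be seen in a small NumPy sketch of 1xN block pruning. Each block of 16 consecutive weights in a row is scored by the mean of its absolute values (an assumption meant to mirror average pooling, the 'AVG' block_pooling_type), and the lowest-scoring half of the blocks is zeroed as a whole. The 1x16 shape matches the code comment below: a 128-bit register holds 128 / 8 = 16 weights quantized to 8 bits. block_magnitude_mask and elementwise_mask are hypothetical helpers, not tfmot APIs.

```python
import numpy as np


def elementwise_mask(weights, sparsity):
    """Hypothetical helper: mask that drops the smallest-|w| individual entries."""
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))
    keep = np.ones(flat.size, dtype=np.float32)
    if k > 0:
        keep[np.argpartition(flat, k - 1)[:k]] = 0.0
    return keep.reshape(weights.shape)


def block_magnitude_mask(weights, block_cols, sparsity):
    """Hypothetical helper: 1 x block_cols block pruning scored by mean |w|.

    Assumes weights.shape[1] is divisible by block_cols.
    """
    rows, cols = weights.shape
    blocks = np.abs(weights).reshape(rows, cols // block_cols, block_cols)
    scores = blocks.mean(axis=-1).ravel()  # average pooling per block
    k = int(round(sparsity * scores.size))
    keep = np.ones(scores.size, dtype=np.float32)
    if k > 0:
        keep[np.argpartition(scores, k - 1)[:k]] = 0.0
    # Expand each block's keep/drop decision back to its block_cols entries.
    keep = keep.reshape(rows, cols // block_cols, 1)
    return np.repeat(keep, block_cols, axis=-1).reshape(rows, cols)


rng = np.random.default_rng(0)
w = rng.standard_normal((16, 64)).astype(np.float32)
block_mask = block_magnitude_mask(w, block_cols=16, sparsity=0.5)
elem_mask = elementwise_mask(w, sparsity=0.5)

# Total magnitude removed by each mask at the same number of zeros.
removed_block = np.abs(w[block_mask == 0]).astype(np.float64).sum()
removed_elem = np.abs(w[elem_mask == 0]).astype(np.float64).sum()
print(int((block_mask == 0).sum()), int((elem_mask == 0).sum()))
print(f"magnitude removed: block={removed_block:.1f}, element-wise={removed_elem:.1f}")
```

At the same sparsity, element-wise pruning removes the least possible total magnitude, so block pruning removes at least as much; that extra loss is one way to see why larger blocks lower the sparsity reachable at a given accuracy, while whole zero blocks are easier for hardware to skip.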

For details on what's supported for block sparsity, see the tfmot.sparsity.keras.prune_low_magnitude API docs.

base_model = setup_model()

# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)

model_for_pruning.summary()
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_dense_  (None, 20)                822       
 8 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_flatte  (None, 20)                1         
 n_9 (PruneLowMagnitude)                                         
                                                                 
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________