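Overview
This notebook shows how to do lossy data compression using neural networks and TensorFlow Compression.
Lossy compression involves a trade-off between rate (the expected number of bits needed to encode a sample) and distortion (the expected error in the reconstruction of the sample).
The examples below use an autoencoder-like model to compress images from the MNIST dataset. The method is based on the paper End-to-end Optimized Image Compression.
For more background on learned data compression, see this paper (aimed at readers familiar with classical data compression) or this survey (aimed at a machine learning audience).
Setup
Install TensorFlow Compression via pip.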
# Installs the latest version of TFC compatible with the installed TF version.
read MAJOR MINOR <<< "$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+)\.(\d+).*/\1 \2/sg')"
pip install "tensorflow-compression<$MAJOR.$(($MINOR+1))"
Import library dependencies.
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds
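Define the trainer model.
Because the model resembles an autoencoder, and we need it to perform different sets of functions during training and inference, the setup differs slightly from that of, say, a classifier.
The training model consists of three parts:
- the analysis (or encoder) transform, which converts the image into a latent space,
- the synthesis (or decoder) transform, which converts from the latent space back into image space, and
- a prior and entropy model, which model the marginal probabilities of the latents.
First, define the transforms: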
def make_analysis_transform(latent_dims):
  """Creates the analysis (encoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2D(
          50, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          latent_dims, use_bias=True, activation=None, name="fc_2"),
  ], name="analysis_transform")


def make_synthesis_transform():
  """Creates the synthesis (decoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          2450, use_bias=True, activation="leaky_relu", name="fc_2"),
      tf.keras.layers.Reshape((7, 7, 50)),
      tf.keras.layers.Conv2DTranspose(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2DTranspose(
          1, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
  ], name="synthesis_transform")
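The layer sizes in the two transforms are tied together by simple shape arithmetic. The following is a minimal sketch, assuming the usual rule that a strided convolution with padding="same" produces ceil(size / stride) outputs per spatial dimension; the helper name same_conv_output_size is hypothetical and not part of the notebook.

```python
import math


def same_conv_output_size(size: int, stride: int) -> int:
  """Spatial output size of a strided convolution with padding="same"."""
  return math.ceil(size / stride)


size = 28  # MNIST digits are 28x28 pixels.
for _ in range(2):  # conv_1 and conv_2 both use strides=2.
  size = same_conv_output_size(size, stride=2)

channels = 50  # Number of filters in conv_2 of the analysis transform.
flattened = size * size * channels
print(size, flattened)
```

This is why the synthesis transform uses a Dense layer with 2450 units followed by Reshape((7, 7, 50)): it mirrors the flattened output of the analysis convolutions. Each Conv2DTranspose with strides=2 and padding="same" then doubles the spatial size again, from 7 to 14 to 28.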
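The trainer holds an instance of both transforms, as well as the parameters of the prior.
Its call method is set up to compute:
- rate, an estimate of the number of bits needed to represent the batch of digits, and
- distortion, the mean absolute difference between the pixels of the original digits and their reconstructions.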
class MNISTCompressionTrainer(tf.keras.Model):
  """Model that trains a compressor/decompressor for MNIST."""

  def __init__(self, latent_dims):
    super().__init__()
    self.analysis_transform = make_analysis_transform(latent_dims)
    self.synthesis_transform = make_synthesis_transform()
    self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

  @property
  def prior(self):
    return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    x = tf.reshape(x, (-1, 28, 28, 1))

    # Compute latent space representation y, perturb it and model its entropy,
    # then compute the reconstructed pixel-level representation x_hat.
    y = self.analysis_transform(x)
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=1, compression=False)
    y_tilde, rate = entropy_model(y, training=training)
    x_tilde = self.synthesis_transform(y_tilde)

    # Average number of bits per MNIST digit.
    rate = tf.reduce_mean(rate)

    # Mean absolute difference across pixels.
    distortion = tf.reduce_mean(abs(x - x_tilde))

    return dict(rate=rate, distortion=distortion)
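Compute rate and distortion.
Let's walk through this step by step, using a single image from the validation set. Load the MNIST dataset for training and validation: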
training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=False,
)
And extract one image \(x\):
(x, _), = validation_dataset.take(1)
plt.imshow(tf.squeeze(x))
print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
To obtain the latent representation \(y\), we cast the image to float32, add a batch dimension, and pass it through the analysis transform.
x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)
print("y:", y)
y: tf.Tensor( [[ 0.02126924 -0.06894962 0.09302588 -0.03838679 -0.07932404 0.04434497 -0.01607926 0.03123217 -0.00262138 -0.10917155]], shape=(1, 10), dtype=float32)
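The latents are quantized at test time. To model this in a differentiable way during training, we add uniform noise in the interval \((-.5, .5)\) and call the result \(\tilde y\). This is the same terminology as used in the paper End-to-end Optimized Image Compression.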
y_tilde = y + tf.random.uniform(y.shape, -.5, .5)
print("y_tilde:", y_tilde)
y_tilde: tf.Tensor( [[ 0.3478785 0.08000787 -0.18317918 0.25482622 -0.31788376 -0.13823514 0.362665 -0.05973584 0.34111002 -0.1322978 ]], shape=(1, 10), dtype=float32)
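To see why uniform noise is a reasonable stand-in for quantization, it helps to compare it with the error that rounding actually introduces. The following is a minimal sketch in plain Python, assuming hypothetical latent values spread over a wide range; it is an illustration of the idea, not part of the notebook's model.

```python
import random

random.seed(0)

# Hypothetical latent values, spread widely so rounding errors are well mixed.
latents = [random.uniform(-10.0, 10.0) for _ in range(20_000)]

# Error introduced by rounding to the nearest integer (test-time behavior).
rounding_error = [round(y) - y for y in latents]
# Additive uniform noise (training-time surrogate).
uniform_noise = [random.uniform(-0.5, 0.5) for _ in latents]


def variance(values):
  """Population variance of a list of floats."""
  mean = sum(values) / len(values)
  return sum((v - mean) ** 2 for v in values) / len(values)


print("range of rounding error:", min(rounding_error), max(rounding_error))
print("variance of rounding error:", variance(rounding_error))
print("variance of uniform noise: ", variance(uniform_noise))
print("variance of U(-0.5, 0.5):  ", 1 / 12)
```

Both errors stay within half a unit and have a variance close to 1/12, but only the additive noise leaves the mapping differentiable, which is what gradient-based training needs.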
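The "prior" is a probability density that we train to model the marginal distribution of the noisy latents. For example, it could be a set of independent logistic distributions with a different scale for each latent dimension. tfc.NoisyLogistic accounts for the fact that the latents have additive noise. As the scale approaches zero, a logistic distribution approaches a Dirac delta (a spike), but the added noise causes the "noisy" distribution to approach a uniform distribution instead.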
prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))
_ = tf.linspace(-6., 6., 501)[:, None]
plt.plot(_, prior.prob(_));
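The limiting behavior plotted here can be reproduced by hand. The sketch below assumes that adding uniform noise on (-0.5, 0.5) to a logistic variable yields a density equal to a difference of two logistic CDFs, F(x + 0.5) - F(x - 0.5); the function names are hypothetical, and this is an illustration of the math rather than the code inside tfc.NoisyLogistic.

```python
import math


def logistic_cdf(x, loc=0.0, scale=1.0):
  """Numerically stable CDF of a logistic distribution."""
  z = (x - loc) / scale
  if z >= 0:
    return 1.0 / (1.0 + math.exp(-z))
  e = math.exp(z)
  return e / (1.0 + e)


def noisy_logistic_pdf(x, loc=0.0, scale=1.0):
  """Density of a logistic variable plus uniform noise on (-0.5, 0.5)."""
  return logistic_cdf(x + 0.5, loc, scale) - logistic_cdf(x - 0.5, loc, scale)


# A very narrow logistic turns into (almost) a unit-width box.
for x in (0.0, 0.4, 0.6):
  print(f"scale=0.01, x={x}: {noisy_logistic_pdf(x, scale=0.01):.4f}")

# Evaluated on the integer grid, the density values sum to one, so they can
# serve as probabilities of the quantized latent.
total = sum(noisy_logistic_pdf(k, scale=2.0) for k in range(-200, 201))
print(f"sum over integers: {total:.6f}")
```

The second property is what makes this family convenient for compression: the same function that acts as a density during training gives a valid probability mass function over the integer-quantized latents at test time.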
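During training, tfc.ContinuousBatchedEntropyModel adds uniform noise, and uses the noise and the prior to compute a (differentiable) upper bound on the rate (the average number of bits needed to encode the latent representation). That bound can be minimized as a loss.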
entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)
print("rate:", rate)
print("y_tilde:", y_tilde)
rate: tf.Tensor([18.012472], shape=(1,), dtype=float32) y_tilde: tf.Tensor( [[-0.09349963 -0.02974442 -0.1147738 0.39858678 0.33601853 -0.13327162 -0.10957606 0.2511304 -0.15129623 0.25686544]], shape=(1, 10), dtype=float32)
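The printed rate can be related to the prior directly: with coding_rank=1, it is the sum over the latent dimensions of the negative base-2 log of the noisy prior density. The following is a minimal sketch in plain Python, assuming the density is a difference of logistic CDFs; the function names are hypothetical, and the library may differ in numerical details.

```python
import math


def logistic_cdf(x, scale):
  """Numerically stable CDF of a zero-centered logistic distribution."""
  z = x / scale
  if z >= 0:
    return 1.0 / (1.0 + math.exp(-z))
  e = math.exp(z)
  return e / (1.0 + e)


def noisy_logistic_pdf(x, scale):
  """Density of a logistic variable plus uniform noise on (-0.5, 0.5)."""
  return logistic_cdf(x + 0.5, scale) - logistic_cdf(x - 0.5, scale)


def rate_in_bits(latents, scales):
  """Sum of -log2 p(y_i) over the latent dimensions."""
  return sum(
      -math.log2(noisy_logistic_pdf(y, s)) for y, s in zip(latents, scales))


# Noisy latents copied from the output above.
y_tilde = [-0.09349963, -0.02974442, -0.1147738, 0.39858678, 0.33601853,
           -0.13327162, -0.10957606, 0.2511304, -0.15129623, 0.25686544]
# Same ten scales as tf.linspace(.01, 2., 10).
scales = [0.01 + i * (2.0 - 0.01) / 9 for i in range(10)]

rate = rate_in_bits(y_tilde, scales)
print(f"rate: {rate:.3f} bits")
```

Under these assumptions the result should land near the roughly 18 bits printed above. Dimensions with a narrow prior cost almost nothing, while dimensions with a wide prior cost around 3 bits each.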
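Lastly, the noisy latents are passed back through the synthesis transform to produce an image reconstruction \(\tilde x\). Distortion is the error between the original image and the reconstruction. Obviously, with untrained transforms, the reconstruction is not very useful.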
x_tilde = make_synthesis_transform()(y_tilde)
# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)
x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
plt.imshow(tf.squeeze(x_tilde))
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
distortion: tf.Tensor(0.17072156, shape=(), dtype=float32) Data type: <dtype: 'uint8'> Shape: (28, 28, 1)
For every batch of digits, calling the MNISTCompressionTrainer produces the rate and distortion as an average over that batch:
(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)
print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
rate: tf.Tensor(20.296253, shape=(), dtype=float32)
distortion: tf.Tensor(0.14659302, shape=(), dtype=float32)
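In the next section, we set up the model to do gradient descent on these two losses.
Train the model.
We compile the trainer so that it optimizes the rate-distortion Lagrangian, that is, a sum of rate and distortion in which one of the terms is weighted by the Lagrange parameter \(\lambda\). In the code below, \(\lambda\) weights the distortion term.
This loss function affects the different parts of the model differently:
- The analysis transform is trained to produce a latent representation that achieves the desired trade-off between rate and distortion.
- The synthesis transform is trained to minimize distortion, given the latent representation.
- The parameters of the prior are trained to minimize the rate, given the latent representation. This is identical to fitting the prior to the marginal distribution of the latents in a maximum likelihood sense.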
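Written out, the objective can be sketched as follows. The symbols \(L\), \(R\), and \(D\) are notation assumed here for the loss, rate, and distortion; \(p\) stands for the noisy prior, and the expectations run over the training images and the added noise.

```latex
\[
  L = R + \lambda D,
  \qquad
  R = \mathbb{E}\big[-\log_2 p(\tilde y)\big],
  \qquad
  D = \mathbb{E}\big[\operatorname{mean}\lvert x - \tilde x \rvert\big]
\]
```

A larger \(\lambda\) makes reconstruction error more expensive relative to bits, so the optimizer spends more bits per digit; a smaller \(\lambda\) does the opposite.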
def pass_through_loss(_, x):
  # Since rate and distortion are unsupervised, the loss doesn't need a target.
  return x


def make_mnist_compression_trainer(lmbda, latent_dims=50):
  trainer = MNISTCompressionTrainer(latent_dims)
  trainer.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
      # Just pass through rate and distortion as losses/metrics.
      loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
      metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
      loss_weights=dict(rate=1., distortion=lmbda),
  )
  return trainer
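Next, train the model. Human annotations are not needed here because we only want to compress the images, so we drop them using map and add "dummy" targets for rate and distortion instead.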
def add_rd_targets(image, label):
  # Training is unsupervised, so labels aren't necessary here. However, we
  # need to add "dummy" targets for rate and distortion.
  return image, dict(rate=0., distortion=0.)


def train_mnist_model(lmbda):
  trainer = make_mnist_compression_trainer(lmbda)
  trainer.fit(
      training_dataset.map(add_rd_targets).batch(128).prefetch(8),
      epochs=15,
      validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
      validation_freq=1,
      verbose=1,
  )
  return trainer


trainer = train_mnist_model(lmbda=2000)
Epoch 1/15 467/469 [============================>.] - ETA: 0s - loss: 219.4970 - distortion_loss: 0.0598 - rate_loss: 99.8539 - distortion_pass_through_loss: 0.0598 - rate_pass_through_loss: 99.8539 WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive. 469/469 [==============================] - 14s 23ms/step - loss: 219.3287 - distortion_loss: 0.0598 - rate_loss: 99.8235 - distortion_pass_through_loss: 0.0597 - rate_pass_through_loss: 99.8190 - val_loss: 176.6952 - val_distortion_loss: 0.0424 - val_rate_loss: 91.9455 - val_distortion_pass_through_loss: 0.0424 - val_rate_pass_through_loss: 91.9553 Epoch 2/15 469/469 [==============================] - 10s 21ms/step - loss: 166.0387 - distortion_loss: 0.0410 - rate_loss: 83.9876 - distortion_pass_through_loss: 0.0410 - rate_pass_through_loss: 83.9833 - val_loss: 156.0399 - val_distortion_loss: 0.0397 - val_rate_loss: 76.5447 - val_distortion_pass_through_loss: 0.0397 - val_rate_pass_through_loss: 76.5545 Epoch 3/15 469/469 [==============================] - 10s 21ms/step - loss: 150.7590 - distortion_loss: 0.0397 - rate_loss: 71.2698 - distortion_pass_through_loss: 0.0397 - rate_pass_through_loss: 71.2663 - val_loss: 143.7437 - val_distortion_loss: 0.0399 - val_rate_loss: 64.0326 - val_distortion_pass_through_loss: 0.0399 - val_rate_pass_through_loss: 64.0276 Epoch 4/15 469/469 [==============================] - 10s 21ms/step - loss: 142.2607 - distortion_loss: 0.0395 - rate_loss: 63.2149 - distortion_pass_through_loss: 0.0395 - rate_pass_through_loss: 63.2126 - val_loss: 136.2439 - val_distortion_loss: 0.0406 - val_rate_loss: 55.1126 - val_distortion_pass_through_loss: 0.0406 - val_rate_pass_through_loss: 55.0992 Epoch 5/15 469/469 [==============================] - 10s 21ms/step - loss: 136.6350 - 
distortion_loss: 0.0392 - rate_loss: 58.2967 - distortion_pass_through_loss: 0.0392 - rate_pass_through_loss: 58.2950 - val_loss: 131.8454 - val_distortion_loss: 0.0418 - val_rate_loss: 48.2437 - val_distortion_pass_through_loss: 0.0418 - val_rate_pass_through_loss: 48.2309 Epoch 6/15 469/469 [==============================] - 10s 21ms/step - loss: 132.8357 - distortion_loss: 0.0388 - rate_loss: 55.1559 - distortion_pass_through_loss: 0.0388 - rate_pass_through_loss: 55.1541 - val_loss: 126.8537 - val_distortion_loss: 0.0413 - val_rate_loss: 44.1647 - val_distortion_pass_through_loss: 0.0413 - val_rate_pass_through_loss: 44.1611 Epoch 7/15 469/469 [==============================] - 10s 21ms/step - loss: 129.6230 - distortion_loss: 0.0383 - rate_loss: 52.9344 - distortion_pass_through_loss: 0.0383 - rate_pass_through_loss: 52.9332 - val_loss: 123.1619 - val_distortion_loss: 0.0406 - val_rate_loss: 41.8766 - val_distortion_pass_through_loss: 0.0406 - val_rate_pass_through_loss: 41.8841 Epoch 8/15 469/469 [==============================] - 10s 21ms/step - loss: 127.0022 - distortion_loss: 0.0379 - rate_loss: 51.1485 - distortion_pass_through_loss: 0.0379 - rate_pass_through_loss: 51.1476 - val_loss: 119.7338 - val_distortion_loss: 0.0393 - val_rate_loss: 41.0842 - val_distortion_pass_through_loss: 0.0393 - val_rate_pass_through_loss: 41.0881 Epoch 9/15 469/469 [==============================] - 10s 21ms/step - loss: 124.5822 - distortion_loss: 0.0375 - rate_loss: 49.6230 - distortion_pass_through_loss: 0.0375 - rate_pass_through_loss: 49.6216 - val_loss: 118.7188 - val_distortion_loss: 0.0396 - val_rate_loss: 39.5464 - val_distortion_pass_through_loss: 0.0396 - val_rate_pass_through_loss: 39.5388 Epoch 10/15 469/469 [==============================] - 10s 20ms/step - loss: 122.2946 - distortion_loss: 0.0370 - rate_loss: 48.2343 - distortion_pass_through_loss: 0.0370 - rate_pass_through_loss: 48.2335 - val_loss: 116.0431 - val_distortion_loss: 0.0378 - val_rate_loss: 
40.5313 - val_distortion_pass_through_loss: 0.0378 - val_rate_pass_through_loss: 40.5214 Epoch 11/15 469/469 [==============================] - 10s 20ms/step - loss: 120.2840 - distortion_loss: 0.0366 - rate_loss: 47.0336 - distortion_pass_through_loss: 0.0366 - rate_pass_through_loss: 47.0329 - val_loss: 115.0391 - val_distortion_loss: 0.0374 - val_rate_loss: 40.2651 - val_distortion_pass_through_loss: 0.0374 - val_rate_pass_through_loss: 40.2673 Epoch 12/15 469/469 [==============================] - 9s 20ms/step - loss: 118.5191 - distortion_loss: 0.0363 - rate_loss: 45.9115 - distortion_pass_through_loss: 0.0363 - rate_pass_through_loss: 45.9113 - val_loss: 113.4488 - val_distortion_loss: 0.0362 - val_rate_loss: 41.1207 - val_distortion_pass_through_loss: 0.0362 - val_rate_pass_through_loss: 41.1314 Epoch 13/15 469/469 [==============================] - 10s 20ms/step - loss: 116.7709 - distortion_loss: 0.0358 - rate_loss: 45.1241 - distortion_pass_through_loss: 0.0358 - rate_pass_through_loss: 45.1242 - val_loss: 112.9705 - val_distortion_loss: 0.0358 - val_rate_loss: 41.3719 - val_distortion_pass_through_loss: 0.0358 - val_rate_pass_through_loss: 41.3784 Epoch 14/15 469/469 [==============================] - 10s 20ms/step - loss: 115.5197 - distortion_loss: 0.0355 - rate_loss: 44.5284 - distortion_pass_through_loss: 0.0355 - rate_pass_through_loss: 44.5273 - val_loss: 111.5821 - val_distortion_loss: 0.0356 - val_rate_loss: 40.3043 - val_distortion_pass_through_loss: 0.0356 - val_rate_pass_through_loss: 40.3116 Epoch 15/15 469/469 [==============================] - 10s 21ms/step - loss: 114.3993 - distortion_loss: 0.0352 - rate_loss: 44.0437 - distortion_pass_through_loss: 0.0352 - rate_pass_through_loss: 44.0429 - val_loss: 110.9734 - val_distortion_loss: 0.0350 - val_rate_loss: 40.9841 - val_distortion_pass_through_loss: 0.0350 - val_rate_pass_through_loss: 40.9915
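The logged loss can be checked against its two components. The sketch below copies the final-epoch figures from the log above and recomputes rate + lmbda * distortion; the helper name lagrangian is hypothetical. Because the log rounds distortion to four decimals, agreement is only expected to within about 2000 * 0.00005 = 0.1.

```python
def lagrangian(rate, distortion, lmbda):
  """Weighted sum used as the training loss: rate + lmbda * distortion."""
  return rate + lmbda * distortion


lmbda = 2000

# Final-epoch (15/15) values copied from the training log above.
logged = {
    "train": dict(loss=114.3993, rate=44.0437, distortion=0.0352),
    "validation": dict(loss=110.9734, rate=40.9841, distortion=0.0350),
}

for split, m in logged.items():
  recomputed = lagrangian(m["rate"], m["distortion"], lmbda)
  print(f"{split}: logged loss {m['loss']:.2f}, recomputed {recomputed:.2f}")
```

At this setting the weighted distortion term contributes about 70 to the loss and the rate about 41 to 44, so both terms have a comparable influence on the gradients.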
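Compress some MNIST images.
For compression and decompression at test time, we split the trained model into two parts:
- The encoder side consists of the analysis transform and the entropy model.
- The decoder side consists of the synthesis transform and the same entropy model.
At test time, the latents do not have additive noise; instead, they are quantized and then losslessly compressed, so we give them new names. We call the quantized latents \(\hat y\) and the image reconstruction \(\hat x\) (following End-to-end Optimized Image Compression).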
class MNISTCompressor(tf.keras.Model):
  """Compresses MNIST images to strings."""

  def __init__(self, analysis_transform, entropy_model):
    super().__init__()
    self.analysis_transform = analysis_transform
    self.entropy_model = entropy_model

  def call(self, x):
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    y = self.analysis_transform(x)
    # Also return the exact information content of each digit.
    _, bits = self.entropy_model(y, training=False)
    return self.entropy_model.compress(y), bits


class MNISTDecompressor(tf.keras.Model):
  """Decompresses MNIST images from strings."""

  def __init__(self, entropy_model, synthesis_transform):
    super().__init__()
    self.entropy_model = entropy_model
    self.synthesis_transform = synthesis_transform

  def call(self, string):
    y_hat = self.entropy_model.decompress(string, ())
    x_hat = self.synthesis_transform(y_hat)
    # Scale and cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)
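When instantiated with compression=True, the entropy model converts the learned prior into tables for a range coding algorithm. When compress() is called, this algorithm is invoked to convert the latent space vector into bit sequences. The length of each binary string approximates the information content of the latent (the negative log likelihood of the latent under the prior).
The entropy model used for compression and decompression must be the same instance, because the range coding tables need to be exactly identical on both sides. Otherwise, decoding errors can occur.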
def make_mnist_codec(trainer, **kwargs):
  # The entropy model must be created with `compression=True` and the same
  # instance must be shared between compressor and decompressor.
  entropy_model = tfc.ContinuousBatchedEntropyModel(
      trainer.prior, coding_rank=1, compression=True, **kwargs)
  compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
  decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
  return compressor, decompressor


compressor, decompressor = make_mnist_codec(trainer)
Grab 16 images from the validation dataset. You can select a different subset by changing the argument to skip.
(originals, _), = validation_dataset.batch(16).skip(3).take(1)
Compress them to strings, and keep track of the information content of each string in bits.
strings, entropies = compressor(originals)
print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
String representation of first digit in hexadecimal: 0x4d63d90ed1 Number of bits actually needed to represent it: 37.99
Decompress the images back from the strings.
reconstructions = decompressor(strings)
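The display_digits helper called in the next step is not defined anywhere in this text. The following is a hypothetical sketch of such a helper, assuming 16 images of 28x28 pixels and inputs that are either eager TensorFlow tensors or plain NumPy arrays and bytes; the layout choices (figure size, text placement) are assumptions, not taken from the source.

```python
import matplotlib.pyplot as plt
import numpy as np


def display_digits(originals, strings, entropies, reconstructions):
  """Shows 16 originals next to their reconstructions, labeled by code."""
  fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(12.5, 5))
  axes = axes.ravel()
  for i in range(len(axes)):
    original = np.squeeze(np.asarray(originals[i]))
    reconstruction = np.squeeze(np.asarray(reconstructions[i]))
    # Black gap between the two digits; the label is drawn on top of it.
    gap = np.zeros((28, 14), dtype=original.dtype)
    axes[i].imshow(np.concatenate([original, gap, reconstruction], axis=1))
    code = strings[i]
    # Eager string tensors expose their bytes through .numpy().
    code = code.numpy() if hasattr(code, "numpy") else code
    axes[i].text(
        .5, .5, f"-> 0x{code.hex()} ->\n{float(entropies[i]):0.2f} bits",
        ha="center", va="top", color="white", fontsize="small",
        transform=axes[i].transAxes)
    axes[i].axis("off")
  plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
  return fig
```

Each panel shows the original digit on the left, the hexadecimal string and its information content in the middle, and the reconstruction on the right.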
Display each of the 16 original digits together with its compressed binary representation and the reconstructed digit.
display_digits(originals, strings, entropies, reconstructions)
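Note that the length of the encoded string differs from the information content of each digit.
This is because the range coding process works with discrete probabilities and has a small amount of overhead. So, especially for short strings, the correspondence is only approximate. However, range coding is asymptotically optimal: in the limit, the expected bit count approaches the cross entropy (the expected information content), for which the rate term in the training model is an upper bound.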
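The size of this gap can be worked out for the first digit, using the hexadecimal string and the information content printed earlier. This is plain arithmetic on those two values.

```python
import math

code = bytes.fromhex("4d63d90ed1")  # String printed for the first digit.
information_content = 37.99         # Bits reported by the entropy model.

string_bits = 8 * len(code)
overhead = string_bits - information_content
# Strings consist of whole bytes, so this is the shortest possible length.
whole_byte_minimum = 8 * math.ceil(information_content / 8)

print(f"string length: {string_bits} bits")
print(f"overhead: {overhead:.2f} bits")
print(f"shortest whole-byte length: {whole_byte_minimum} bits")
```

For this digit the string is 40 bits long, about 2 bits more than its information content, and 40 bits is already the shortest whole-byte length that can hold 37.99 bits. On a code this short, a couple of bits amounts to roughly 5% of the total, which is why the correspondence is loosest for short strings.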
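The rate-distortion trade-off
Above, the model was trained for a specific trade-off (given by lmbda=2000) between the average number of bits used to represent each digit and the error incurred in the reconstruction.
What happens when we repeat the experiment with different values?
Let's start by reducing \(\lambda\) to 500.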
def train_and_visualize_model(lmbda):
  trainer = train_mnist_model(lmbda=lmbda)
  compressor, decompressor = make_mnist_codec(trainer)
  strings, entropies = compressor(originals)
  reconstructions = decompressor(strings)
  display_digits(originals, strings, entropies, reconstructions)


train_and_visualize_model(lmbda=500)
Epoch 1/15 469/469 [==============================] - ETA: 0s - loss: 127.5305 - distortion_loss: 0.0700 - rate_loss: 92.5392 - distortion_pass_through_loss: 0.0700 - rate_pass_through_loss: 92.5329 WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive. 469/469 [==============================] - 12s 21ms/step - loss: 127.5305 - distortion_loss: 0.0700 - rate_loss: 92.5392 - distortion_pass_through_loss: 0.0700 - rate_pass_through_loss: 92.5329 - val_loss: 107.4705 - val_distortion_loss: 0.0548 - val_rate_loss: 80.0521 - val_distortion_pass_through_loss: 0.0549 - val_rate_pass_through_loss: 80.0560 Epoch 2/15 469/469 [==============================] - 10s 20ms/step - loss: 97.2982 - distortion_loss: 0.0541 - rate_loss: 70.2615 - distortion_pass_through_loss: 0.0541 - rate_pass_through_loss: 70.2563 - val_loss: 86.4681 - val_distortion_loss: 0.0599 - val_rate_loss: 56.5326 - val_distortion_pass_through_loss: 0.0599 - val_rate_pass_through_loss: 56.5317 Epoch 3/15 469/469 [==============================] - 10s 20ms/step - loss: 81.2346 - distortion_loss: 0.0562 - rate_loss: 53.1251 - distortion_pass_through_loss: 0.0562 - rate_pass_through_loss: 53.1215 - val_loss: 71.9185 - val_distortion_loss: 0.0682 - val_rate_loss: 37.8106 - val_distortion_pass_through_loss: 0.0683 - val_rate_pass_through_loss: 37.8011 Epoch 4/15 469/469 [==============================] - 10s 20ms/step - loss: 71.5505 - distortion_loss: 0.0593 - rate_loss: 41.8880 - distortion_pass_through_loss: 0.0593 - rate_pass_through_loss: 41.8856 - val_loss: 63.3954 - val_distortion_loss: 0.0762 - val_rate_loss: 25.3134 - val_distortion_pass_through_loss: 0.0762 - val_rate_pass_through_loss: 25.3113 Epoch 5/15 469/469 [==============================] - 10s 20ms/step - loss: 65.9471 - distortion_loss: 
0.0622 - rate_loss: 34.8656 - distortion_pass_through_loss: 0.0622 - rate_pass_through_loss: 34.8637 - val_loss: 57.6301 - val_distortion_loss: 0.0786 - val_rate_loss: 18.3391 - val_distortion_pass_through_loss: 0.0786 - val_rate_pass_through_loss: 18.3298 Epoch 6/15 469/469 [==============================] - 9s 20ms/step - loss: 62.4032 - distortion_loss: 0.0642 - rate_loss: 30.3145 - distortion_pass_through_loss: 0.0642 - rate_pass_through_loss: 30.3131 - val_loss: 55.5859 - val_distortion_loss: 0.0848 - val_rate_loss: 13.1697 - val_distortion_pass_through_loss: 0.0848 - val_rate_pass_through_loss: 13.1684 Epoch 7/15 469/469 [==============================] - 9s 20ms/step - loss: 59.9169 - distortion_loss: 0.0656 - rate_loss: 27.1412 - distortion_pass_through_loss: 0.0655 - rate_pass_through_loss: 27.1403 - val_loss: 51.1707 - val_distortion_loss: 0.0758 - val_rate_loss: 13.2823 - val_distortion_pass_through_loss: 0.0758 - val_rate_pass_through_loss: 13.2810 Epoch 8/15 469/469 [==============================] - 9s 20ms/step - loss: 57.7210 - distortion_loss: 0.0660 - rate_loss: 24.7239 - distortion_pass_through_loss: 0.0660 - rate_pass_through_loss: 24.7234 - val_loss: 49.3867 - val_distortion_loss: 0.0736 - val_rate_loss: 12.5935 - val_distortion_pass_through_loss: 0.0736 - val_rate_pass_through_loss: 12.5888 Epoch 9/15 469/469 [==============================] - 10s 20ms/step - loss: 55.7091 - distortion_loss: 0.0658 - rate_loss: 22.7991 - distortion_pass_through_loss: 0.0658 - rate_pass_through_loss: 22.7988 - val_loss: 48.4335 - val_distortion_loss: 0.0695 - val_rate_loss: 13.6779 - val_distortion_pass_through_loss: 0.0695 - val_rate_pass_through_loss: 13.6865 Epoch 10/15 469/469 [==============================] - 10s 20ms/step - loss: 53.9030 - distortion_loss: 0.0651 - rate_loss: 21.3707 - distortion_pass_through_loss: 0.0651 - rate_pass_through_loss: 21.3701 - val_loss: 47.2973 - val_distortion_loss: 0.0689 - val_rate_loss: 12.8318 - 
val_distortion_pass_through_loss: 0.0690 - val_rate_pass_through_loss: 12.8296 Epoch 11/15 469/469 [==============================] - 9s 20ms/step - loss: 52.3183 - distortion_loss: 0.0641 - rate_loss: 20.2617 - distortion_pass_through_loss: 0.0641 - rate_pass_through_loss: 20.2610 - val_loss: 46.6293 - val_distortion_loss: 0.0671 - val_rate_loss: 13.0571 - val_distortion_pass_through_loss: 0.0672 - val_rate_pass_through_loss: 13.0598 Epoch 12/15 469/469 [==============================] - 10s 20ms/step - loss: 51.0015 - distortion_loss: 0.0633 - rate_loss: 19.3680 - distortion_pass_through_loss: 0.0633 - rate_pass_through_loss: 19.3676 - val_loss: 46.1679 - val_distortion_loss: 0.0642 - val_rate_loss: 14.0836 - val_distortion_pass_through_loss: 0.0642 - val_rate_pass_through_loss: 14.0873 Epoch 13/15 469/469 [==============================] - 9s 20ms/step - loss: 49.9189 - distortion_loss: 0.0627 - rate_loss: 18.5747 - distortion_pass_through_loss: 0.0627 - rate_pass_through_loss: 18.5742 - val_loss: 45.6781 - val_distortion_loss: 0.0640 - val_rate_loss: 13.6870 - val_distortion_pass_through_loss: 0.0640 - val_rate_pass_through_loss: 13.6915 Epoch 14/15 469/469 [==============================] - 9s 20ms/step - loss: 48.9510 - distortion_loss: 0.0621 - rate_loss: 17.9217 - distortion_pass_through_loss: 0.0621 - rate_pass_through_loss: 17.9215 - val_loss: 45.1058 - val_distortion_loss: 0.0614 - val_rate_loss: 14.4060 - val_distortion_pass_through_loss: 0.0614 - val_rate_pass_through_loss: 14.4218 Epoch 15/15 469/469 [==============================] - 9s 20ms/step - loss: 48.1553 - distortion_loss: 0.0615 - rate_loss: 17.4100 - distortion_pass_through_loss: 0.0615 - rate_pass_through_loss: 17.4101 - val_loss: 44.9707 - val_distortion_loss: 0.0609 - val_rate_loss: 14.5456 - val_distortion_pass_through_loss: 0.0609 - val_rate_pass_through_loss: 14.5536
The bit rate of our code goes down, as does the fidelity of the digits. However, most of the digits remain recognizable.
Let's reduce \(\lambda\) further.
train_and_visualize_model(lmbda=300)
Epoch 1/15 469/469 [==============================] - ETA: 0s - loss: 114.0398 - distortion_loss: 0.0768 - rate_loss: 90.9927 - distortion_pass_through_loss: 0.0768 - rate_pass_through_loss: 90.9862 WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive. 469/469 [==============================] - 12s 21ms/step - loss: 114.0398 - distortion_loss: 0.0768 - rate_loss: 90.9927 - distortion_pass_through_loss: 0.0768 - rate_pass_through_loss: 90.9862 - val_loss: 96.7790 - val_distortion_loss: 0.0688 - val_rate_loss: 76.1469 - val_distortion_pass_through_loss: 0.0688 - val_rate_pass_through_loss: 76.1451 Epoch 2/15 469/469 [==============================] - 10s 20ms/step - loss: 85.8627 - distortion_loss: 0.0613 - rate_loss: 67.4866 - distortion_pass_through_loss: 0.0613 - rate_pass_through_loss: 67.4809 - val_loss: 74.3796 - val_distortion_loss: 0.0796 - val_rate_loss: 50.5063 - val_distortion_pass_through_loss: 0.0797 - val_rate_pass_through_loss: 50.5014 Epoch 3/15 469/469 [==============================] - 10s 21ms/step - loss: 68.7960 - distortion_loss: 0.0645 - rate_loss: 49.4494 - distortion_pass_through_loss: 0.0645 - rate_pass_through_loss: 49.4456 - val_loss: 58.5856 - val_distortion_loss: 0.0911 - val_rate_loss: 31.2701 - val_distortion_pass_through_loss: 0.0912 - val_rate_pass_through_loss: 31.2587 Epoch 4/15 469/469 [==============================] - 10s 21ms/step - loss: 58.2693 - distortion_loss: 0.0693 - rate_loss: 37.4914 - distortion_pass_through_loss: 0.0693 - rate_pass_through_loss: 37.4892 - val_loss: 48.5439 - val_distortion_loss: 0.0983 - val_rate_loss: 19.0433 - val_distortion_pass_through_loss: 0.0984 - val_rate_pass_through_loss: 19.0396 Epoch 5/15 469/469 [==============================] - 10s 21ms/step - loss: 51.9585 - distortion_loss: 
0.0734 - rate_loss: 29.9320 - distortion_pass_through_loss: 0.0734 - rate_pass_through_loss: 29.9306 - val_loss: 42.2598 - val_distortion_loss: 0.0994 - val_rate_loss: 12.4350 - val_distortion_pass_through_loss: 0.0995 - val_rate_pass_through_loss: 12.4339 Epoch 6/15 469/469 [==============================] - 10s 20ms/step - loss: 48.0547 - distortion_loss: 0.0769 - rate_loss: 24.9776 - distortion_pass_through_loss: 0.0769 - rate_pass_through_loss: 24.9766 - val_loss: 38.5121 - val_distortion_loss: 0.1000 - val_rate_loss: 8.4996 - val_distortion_pass_through_loss: 0.1000 - val_rate_pass_through_loss: 8.5025 Epoch 7/15 469/469 [==============================] - 9s 20ms/step - loss: 45.3418 - distortion_loss: 0.0798 - rate_loss: 21.4144 - distortion_pass_through_loss: 0.0798 - rate_pass_through_loss: 21.4135 - val_loss: 36.2517 - val_distortion_loss: 0.0992 - val_rate_loss: 6.4969 - val_distortion_pass_through_loss: 0.0991 - val_rate_pass_through_loss: 6.5043 Epoch 8/15 469/469 [==============================] - 9s 20ms/step - loss: 43.1311 - distortion_loss: 0.0812 - rate_loss: 18.7581 - distortion_pass_through_loss: 0.0812 - rate_pass_through_loss: 18.7577 - val_loss: 34.4578 - val_distortion_loss: 0.0923 - val_rate_loss: 6.7545 - val_distortion_pass_through_loss: 0.0923 - val_rate_pass_through_loss: 6.7532 Epoch 9/15 469/469 [==============================] - 10s 20ms/step - loss: 41.1835 - distortion_loss: 0.0815 - rate_loss: 16.7441 - distortion_pass_through_loss: 0.0815 - rate_pass_through_loss: 16.7435 - val_loss: 33.3371 - val_distortion_loss: 0.0870 - val_rate_loss: 7.2369 - val_distortion_pass_through_loss: 0.0870 - val_rate_pass_through_loss: 7.2306 Epoch 10/15 469/469 [==============================] - 10s 20ms/step - loss: 39.4039 - distortion_loss: 0.0806 - rate_loss: 15.2191 - distortion_pass_through_loss: 0.0806 - rate_pass_through_loss: 15.2184 - val_loss: 32.8390 - val_distortion_loss: 0.0840 - val_rate_loss: 7.6290 - 
val_distortion_pass_through_loss: 0.0840 - val_rate_pass_through_loss: 7.6263 Epoch 11/15 469/469 [==============================] - 9s 20ms/step - loss: 37.9107 - distortion_loss: 0.0793 - rate_loss: 14.1125 - distortion_pass_through_loss: 0.0793 - rate_pass_through_loss: 14.1124 - val_loss: 32.6137 - val_distortion_loss: 0.0816 - val_rate_loss: 8.1349 - val_distortion_pass_through_loss: 0.0816 - val_rate_pass_through_loss: 8.1267 Epoch 12/15 469/469 [==============================] - 10s 20ms/step - loss: 36.6647 - distortion_loss: 0.0780 - rate_loss: 13.2545 - distortion_pass_through_loss: 0.0780 - rate_pass_through_loss: 13.2540 - val_loss: 32.2822 - val_distortion_loss: 0.0814 - val_rate_loss: 7.8685 - val_distortion_pass_through_loss: 0.0814 - val_rate_pass_through_loss: 7.8598 Epoch 13/15 469/469 [==============================] - 10s 20ms/step - loss: 35.6982 - distortion_loss: 0.0771 - rate_loss: 12.5623 - distortion_pass_through_loss: 0.0771 - rate_pass_through_loss: 12.5622 - val_loss: 31.9754 - val_distortion_loss: 0.0791 - val_rate_loss: 8.2589 - val_distortion_pass_through_loss: 0.0791 - val_rate_pass_through_loss: 8.2520 Epoch 14/15 469/469 [==============================] - 9s 20ms/step - loss: 34.9956 - distortion_loss: 0.0765 - rate_loss: 12.0543 - distortion_pass_through_loss: 0.0765 - rate_pass_through_loss: 12.0540 - val_loss: 31.8855 - val_distortion_loss: 0.0786 - val_rate_loss: 8.3127 - val_distortion_pass_through_loss: 0.0786 - val_rate_pass_through_loss: 8.2984 Epoch 15/15 469/469 [==============================] - 9s 20ms/step - loss: 34.4267 - distortion_loss: 0.0759 - rate_loss: 11.6673 - distortion_pass_through_loss: 0.0759 - rate_pass_through_loss: 11.6670 - val_loss: 31.7163 - val_distortion_loss: 0.0777 - val_rate_loss: 8.3933 - val_distortion_pass_through_loss: 0.0778 - val_rate_pass_through_loss: 8.3861
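The strings now become much shorter, on the order of one byte per digit. However, this comes at a cost: more digits become unrecognizable.
This demonstrates that the model is agnostic to human perception of error; it only measures the absolute deviation in pixel values. To achieve better perceived image quality, we would need to replace the pixel loss with a perceptual loss.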
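The three experiments can be placed side by side. The sketch below copies the final-epoch validation rate and distortion from the three training logs above and converts the rate to bytes; nothing is recomputed from the model.

```python
# Final-epoch validation metrics copied from the three training logs above.
results = {
    2000: dict(rate=40.9841, distortion=0.0350),
    500: dict(rate=14.5456, distortion=0.0609),
    300: dict(rate=8.3933, distortion=0.0777),
}

for lmbda, m in sorted(results.items(), reverse=True):
  print(f"lmbda={lmbda:>4}: {m['rate']:6.2f} bits "
        f"({m['rate'] / 8:.2f} bytes) per digit, "
        f"mean absolute error {m['distortion']:.4f}")
```

Lowering lmbda from 2000 to 300 cuts the rate from about 41 bits to about 8.4 bits per digit, close to a single byte, while the mean absolute pixel error more than doubles.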
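Use the decoder as a generative model.
If we feed random bits to the decoder, this effectively samples from the distribution that the model learned to represent digits.
First, re-instantiate the compressor/decompressor without the sanity check that would detect whether the input string was not completely decoded.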
compressor, decompressor = make_mnist_codec(trainer, decode_sanity_check=False)
Now, feed sufficiently long random strings into the decompressor so that it can decode/sample digits from them.
import os

strings = tf.constant([os.urandom(8) for _ in range(16)])
samples = decompressor(strings)

fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(5, 5))
axes = axes.ravel()
for i in range(len(axes)):
  axes[i].imshow(tf.squeeze(samples[i]))
  axes[i].axis("off")
plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)