使用 TensorFlow Lite 實現超解析度

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本 查看 TF Hub 模型

總覽

從低解析度影像復原高解析度 (HR) 影像的任務通常稱為單張影像超解析度 (SISR)。

此處使用的模型是 ESRGAN (ESRGAN:增強型超解析度生成對抗網路)。我們將使用 TensorFlow Lite 在預先訓練的模型上執行推論。

TFLite 模型是從 TF Hub 上託管的此實作轉換而來。請注意,我們轉換的模型將 50x50 低解析度影像升採樣為 200x200 高解析度影像 (比例因子 = 4)。如果您想要不同的輸入大小或比例因子,則需要重新轉換或重新訓練原始模型。

設定

讓我們先安裝必要的程式庫。

pip install matplotlib tensorflow tensorflow-hub

匯入依附元件。

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)

下載並轉換 ESRGAN 模型

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
  return concrete_func(input);

converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'

下載測試影像 (昆蟲頭部)。

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')

使用 TensorFlow Lite 產生超解析度影像

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

將結果視覺化

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

效能基準評測

效能基準評測數字是使用此處說明的工具產生。

模型名稱 模型大小 裝置 CPU GPU
超解析度 (ESRGAN) 4.8 Mb Pixel 3 586.8 毫秒* 128.6 毫秒
Pixel 4 385.1 毫秒* 130.3 毫秒

*使用 4 個執行緒