這個筆記本示範了生成對抗網路(GAN),該網路使用自監督式和半監督式學習技術,以少至 2.5% 的標記資料在 ImageNet 上進行訓練。產生器和鑑別器模型均可在 TF Hub 上取得。
如需模型和訓練程序的更多資訊,請參閱我們的部落格文章和論文 [1]。用於訓練這些模型的程式碼可在 GitHub 上取得。
- (選用) 在下方的第二個程式碼儲存格中選取模型。
- 按一下「執行階段」>「全部執行」以依序執行每個儲存格。
- 之後,當您使用滑桿和下拉式選單修改設定時,互動式視覺化效果應會自動更新。
[1] Mario Lucic*、Michael Tschannen*、Marvin Ritter*、Xiaohua Zhai、Olivier Bachem、Sylvain Gelly,《High-Fidelity Image Generation With Fewer Labels》,ICML 2019。
# @title Imports and utility functions
import os
import IPython
from IPython.display import display
import numpy as np
import PIL.Image
import pandas as pd
import six
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
def imgrid(imarray, cols=8, pad=1):
    """Tile a batch of images into one grid image.

    Args:
      imarray: array whose shape unpacks as (N, H, W, C): N images of
        height H, width W and C channels.
      cols: number of images per grid row; must be >= 1.
      pad: number of zero-valued pixels inserted between neighbouring
        images; must be >= 0.

    Returns:
      An array of shape (rows * (H + pad) - pad, cols * (W + pad) - pad, C)
      where rows = ceil(N / cols). Grid cells beyond the N-th image are
      filled with zeros.

    Raises:
      AssertionError: if `pad` is negative or `cols` is smaller than 1.
    """
    pad = int(pad)
    assert pad >= 0
    cols = int(cols)
    assert cols >= 1
    N, H, W, C = imarray.shape
    rows = int(np.ceil(N / float(cols)))
    # Number of blank images needed to complete the last grid row.
    batch_pad = rows * cols - N
    assert batch_pad >= 0
    # Pad only at the end of each axis: extra images on the batch axis,
    # `pad` pixels below/right of every image, nothing on the channel axis.
    post_pad = [batch_pad, pad, pad, 0]
    pad_arg = [[0, p] for p in post_pad]
    imarray = np.pad(imarray, pad_arg, 'constant')
    H += pad
    W += pad
    grid = (imarray
            .reshape(rows, cols, H, W, C)
            .transpose(0, 2, 1, 3, 4)
            .reshape(rows * H, cols * W, C))
    # Drop the padding trailing the last row/column. Explicit end indices
    # are used because `grid[:-pad]` with pad == 0 would slice to nothing.
    return grid[:rows * H - pad, :cols * W - pad]
def imshow(a, format='png', jpeg_fallback=True):
    """Encode an image array and display it inline via IPython.

    Args:
      a: image data; converted to a uint8 array before encoding.
      format: image format name handed to PIL when saving (e.g. 'png').
      jpeg_fallback: when True and displaying raises IOError for a
        non-JPEG format, retry once with format='jpeg'.

    Returns:
      The value returned by `IPython.display.display`.

    Raises:
      IOError: if displaying fails and no JPEG fallback applies.
    """
    a = np.asarray(a, dtype=np.uint8)
    # Encoded image data is binary. six.BytesIO is a bytes buffer on
    # Python 3 and an alias of StringIO.StringIO on Python 2, so it is the
    # right choice on both.
    str_file = six.BytesIO()
    PIL.Image.fromarray(a).save(str_file, format)
    png_data = str_file.getvalue()
    try:
        disp = display(IPython.display.Image(png_data))
    except IOError:
        if jpeg_fallback and format != 'jpeg':
            # The message must be formatted before printing: print()
            # returns None, so calling .format() on its result fails.
            print('Warning: image was too large to display in format "{}"; '
                  'trying jpeg instead.'.format(format))
            return imshow(a, format='jpeg')
        # No fallback left; surface the error instead of hiding it.
        raise
    return disp
class Generator(object):
    """Wrapper around the generator signature of a TF Hub module.

    The module is loaded into a private `tf.Graph` when the object is
    constructed; the session is created lazily on the first call to
    `get_samples`.
    """

    def __init__(self, module_spec):
        """Loads the generator of `module_spec` (a TF Hub handle or path)."""
        self._module_spec = module_spec
        self._sess = None
        self._graph = tf.Graph()
        # Build the graph right away so that `z_dim`, `conditional` and
        # `get_samples` can rely on `_z`, `_labels` and `_samples`.
        self._load_model()

    @property
    def z_dim(self):
        """Size of the last dimension of the `z` input placeholder."""
        return self._z.shape.as_list()[-1]

    @property
    def conditional(self):
        """True if the module declares a `labels` input."""
        return self._labels is not None

    def _load_model(self):
        """Instantiates the module and one placeholder per declared input."""
        with self._graph.as_default():
            self._generator = hub.Module(self._module_spec, name="gen_module",
                                         tags={"gen", "bsNone"})
            input_info = self._generator.get_input_info_dict()
            inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
                      for k, v in input_info.items()}
            self._samples = self._generator(
                inputs=inputs, as_dict=True)["generated"]
            print("Inputs:", inputs)
            print("Outputs:", self._samples)
            self._z = inputs["z"]
            self._labels = inputs.get("labels", None)

    def _init_session(self):
        """Creates the session and initializes variables on first use."""
        if self._sess is None:
            with self._graph.as_default():
                sess = tf.Session(graph=self._graph)
                # The module's variables must be initialized before the
                # first run, otherwise sampling fails on uninitialized
                # values.
                sess.run(tf.global_variables_initializer())
            self._sess = sess

    def get_noise(self, num_samples, seed=None):
        """Draws standard-normal latent vectors.

        Args:
          num_samples: number of vectors to draw when `seed` is None or a
            scalar. Ignored when `seed` is a sequence.
          seed: None (use NumPy's global random state), a scalar seed for
            the whole batch, or a sequence with one seed per vector.

        Returns:
          An array of shape (num_samples, z_dim), or a float32 array of
          shape (len(seed), z_dim) when `seed` is a sequence.
        """
        if seed is None:
            return np.random.normal(size=[num_samples, self.z_dim])
        if np.isscalar(seed):
            # A dedicated RandomState yields the same values as seeding the
            # global generator, without disturbing global state.
            rng = np.random.RandomState(seed)
            return rng.normal(size=[num_samples, self.z_dim])
        z = np.empty(shape=(len(seed), self.z_dim), dtype=np.float32)
        for i, s in enumerate(seed):
            # Row i depends only on seed s, so equal seeds give equal rows.
            z[i] = np.random.RandomState(s).normal(size=[self.z_dim])
        return z

    def get_samples(self, z, labels=None):
        """Runs the generator on `z` and returns uint8 images.

        Args:
          z: latent vectors fed to the `z` placeholder.
          labels: labels fed to the `labels` placeholder; required (one per
            row of `z`) when the generator is conditional.

        Returns:
          The generated samples scaled by 256, clipped to [0, 255] and cast
          to uint8 (presumably the module emits values in [0, 1] - confirm
          against the module documentation).
        """
        with self._graph.as_default():
            self._init_session()
            feed_dict = {self._z: z}
            if self.conditional:
                assert labels is not None
                assert labels.shape[0] == z.shape[0]
                feed_dict[self._labels] = labels
            samples = self._sess.run(self._samples, feed_dict=feed_dict)
            return np.uint8(np.clip(256 * samples, 0, 255))
class Discriminator(object):
    """Wrapper around the discriminator signature of a TF Hub module.

    The module is loaded into a private `tf.Graph` when the object is
    constructed; the session is created lazily on the first call to
    `predict`.
    """

    def __init__(self, module_spec):
        """Loads the discriminator of `module_spec` (a TF Hub handle or path)."""
        self._module_spec = module_spec
        self._sess = None
        self._graph = tf.Graph()
        # Build the graph right away so that `image_shape`, `conditional`
        # and `predict` can rely on `_inputs` and `_outputs`.
        self._load_model()

    @property
    def conditional(self):
        """True if the module declares a `labels` input."""
        return "labels" in self._inputs

    @property
    def image_shape(self):
        """Shape of the `images` input as a list, without its first axis."""
        return self._inputs["images"].shape.as_list()[1:]

    def _load_model(self):
        """Instantiates the module and one placeholder per declared input."""
        with self._graph.as_default():
            self._discriminator = hub.Module(self._module_spec,
                                             name="disc_module",
                                             tags={"disc", "bsNone"})
            input_info = self._discriminator.get_input_info_dict()
            self._inputs = {
                k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
                for k, v in input_info.items()}
            self._outputs = self._discriminator(
                inputs=self._inputs, as_dict=True)
            print("Inputs:", self._inputs)
            print("Outputs:", self._outputs)

    def _init_session(self):
        """Creates the session and initializes variables on first use."""
        if self._sess is None:
            with self._graph.as_default():
                sess = tf.Session(graph=self._graph)
                # The module's variables must be initialized before the
                # first run, otherwise prediction fails on uninitialized
                # values.
                sess.run(tf.global_variables_initializer())
            self._sess = sess

    def predict(self, images, labels=None):
        """Runs the discriminator on a batch of images.

        Args:
          images: batch fed to the `images` placeholder.
          labels: labels fed to the `labels` placeholder; required (one per
            image) when the discriminator is conditional.

        Returns:
          A dict mapping each module output name to its evaluated value.
        """
        with self._graph.as_default():
            self._init_session()
            feed_dict = {self._inputs["images"]: images}
            if self.conditional:
                assert labels is not None
                assert labels.shape[0] == images.shape[0]
                feed_dict[self._inputs["labels"]] = labels
            return self._sess.run(self._outputs, feed_dict=feed_dict)
2023-11-07 12:22:55.109710: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-07 12:22:55.109762: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-07 12:22:55.111269: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/compat/v2_compat.py:108: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version. Instructions for updating: non-resource variables are not supported in the long term
# @title Select a model { run: "auto" }
model_name = "S3GAN 128x128 20% labels (FID 6.9, IS 98.1)" # @param ["S3GAN 256x256 10% labels (FID 8.8, IS 130.7)", "S3GAN 128x128 2.5% labels (FID 12.6, IS 48.7)", "S3GAN 128x128 5% labels (FID 8.4, IS 74.0)", "S3GAN 128x128 10% labels (FID 7.6, IS 90.3)", "S3GAN 128x128 20% labels (FID 6.9, IS 98.1)"]
# Maps the model title (the form value without its " (FID ..., IS ...)"
# suffix) to the TF Hub handle of the corresponding module.
models = {
    "S3GAN 256x256 10% labels": "https://tfhub.dev/google/compare_gan/s3gan_10_256x256/1",
    "S3GAN 128x128 2.5% labels": "https://tfhub.dev/google/compare_gan/s3gan_2_5_128x128/1",
    "S3GAN 128x128 5% labels": "https://tfhub.dev/google/compare_gan/s3gan_5_128x128/1",
    "S3GAN 128x128 10% labels": "https://tfhub.dev/google/compare_gan/s3gan_10_128x128/1",
    "S3GAN 128x128 20% labels": "https://tfhub.dev/google/compare_gan/s3gan_20_128x128/1",
}
# Strip the metrics suffix so the selected form value matches a dict key.
module_spec = models[model_name.split(" (")[0]]
print("Module spec:", module_spec)
print("Loading model...")
sampler = Generator(module_spec)
print("Model loaded.")
Module spec: https://tfhub.dev/google/compare_gan/s3gan_20_128x128/1 Loading model... 2023-11-07 12:23:17.360038: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Saver not created because there are no variables in the graph to restore Inputs: {'labels': <tf.Tensor 'labels:0' shape=(?,) dtype=int32>, 'z': <tf.Tensor 'z:0' shape=(?, 120) dtype=float32>} Outputs: Tensor("gen_module_apply_default/generator_1/truediv:0", shape=(?, 128, 128, 3), dtype=float32) Model loaded.
# Smoke test: score a batch of random images with random class labels.
disc = Discriminator(module_spec)
batch_size = 4
num_classes = 1000
# One random image per batch entry, shaped like the discriminator's input.
images = np.random.random(size=[batch_size] + list(disc.image_shape))
# One integer label in [0, num_classes) per image.
labels = np.random.randint(0, num_classes, size=(batch_size,))
disc.predict(images, labels=labels)
INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Saver not created because there are no variables in the graph to restore Inputs: {'images': <tf.Tensor 'images:0' shape=(?, 128, 128, 3) dtype=float32>, 'labels': <tf.Tensor 'labels:0' shape=(?,) dtype=int32>} Outputs: {'prediction': <tf.Tensor 'disc_module_apply_default/discriminator/Sigmoid:0' shape=(?, 1) dtype=float32>} {'prediction': array([[0.82321596], [0.89030766], [0.8621535 ], [0.88563365]], dtype=float32)}