效能秘訣

本文件提供 TensorFlow Datasets (TFDS) 專屬的效能秘訣。請注意,TFDS 以 tf.data.Dataset 物件的形式提供資料集,因此 tf.data 指南中的建議仍然適用。

基準評測資料集

使用 tfds.benchmark(ds) 基準評測任何 tf.data.Dataset 物件。

務必指出 batch_size= 以標準化結果 (例如 100 iter/秒 -> 3200 ex/秒)。這適用於任何可迭代物件 (例如 tfds.benchmark(tfds.as_numpy(ds)))。

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

小型資料集 (小於 1 GB)

所有 TFDS 資料集都以 TFRecord 格式將資料儲存在磁碟上。對於小型資料集 (例如 MNIST、CIFAR-10/-100),從 .tfrecord 讀取可能會增加相當大的額外負擔。

由於這些資料集適合放入記憶體,因此可以透過快取或預先載入資料集來大幅提升效能。請注意,TFDS 會自動快取小型資料集 (以下章節將詳細說明)。

快取資料集

以下範例示範在正規化圖片後明確快取資料集的資料管線。

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

在疊代此資料集時,由於快取,第二次疊代會比第一次快得多。

自動快取

根據預設,TFDS 會自動快取 (使用 ds.cache()) 符合以下限制的資料集

  • 已定義資料集總大小 (所有分割),且小於 250 MiB
  • 停用 shuffle_files,或僅讀取單一分片

您可以將 try_autocaching=False 傳遞至 tfds.ReadConfig (位於 tfds.load 中) 來選擇停用自動快取。請參閱資料集目錄文件,瞭解特定資料集是否會使用自動快取。

將完整資料載入為單一張量

如果您的資料集適合放入記憶體,您也可以將完整資料集載入為單一張量或 NumPy 陣列。您可以設定 batch_size=-1 以將所有範例批次處理成單一 tf.Tensor,即可達成此目的。然後使用 tfds.as_numpy,將格式從 tf.Tensor 轉換為 np.array

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

大型資料集

大型資料集已分片 (分割成多個檔案),且通常不適合放入記憶體,因此不應快取。

隨機排序和訓練

在訓練期間,務必妥善隨機排序資料,因為隨機排序不佳的資料可能會導致訓練準確度降低。

除了使用 ds.shuffle 隨機排序記錄之外,您也應設定 shuffle_files=True,以便針對已分片成多個檔案的較大型資料集獲得良好的隨機排序行為。否則,週期會以相同的順序讀取分片,因此資料不會真正隨機化。

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

此外,當 shuffle_files=True 時,TFDS 會停用 options.deterministic,這可能會稍微提升效能。若要取得決定性隨機排序,您可以透過 tfds.ReadConfig 選擇停用此功能:方法是設定 read_config.shuffle_seed 或覆寫 read_config.options.deterministic

跨工作站自動分片資料 (TF)

在多個工作站上進行訓練時,您可以使用 tfds.ReadConfiginput_context 引數,如此一來,每個工作站都會讀取資料的子集。

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

這與子分割 API 相輔相成。首先,會套用子分割 API:train[:50%] 會轉換成要讀取的檔案清單。接著,系統會對這些檔案套用 ds.shard() 運算。例如,當搭配 num_input_pipelines=2 使用 train[:50%] 時,2 個工作站的每一個都會讀取 1/4 的資料。

shuffle_files=True 時,檔案會在一個工作站內隨機排序,但不會跨工作站隨機排序。每個工作站會在週期之間讀取相同的檔案子集。

跨工作站自動分片資料 (Jax)

使用 Jax 時,您可以使用 tfds.split_for_jax_processtfds.even_splits API,將資料分散到多個工作站。請參閱分割 API 指南

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process 是下列項目的簡單別名

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

更快的圖片解碼

根據預設,TFDS 會自動解碼圖片。但是,在某些情況下,如果略過圖片解碼 (使用 tfds.decode.SkipDecoding) 並手動套用 tf.io.decode_image 運算,效能可能會更高

這兩個範例的程式碼可在解碼指南中找到。

略過未使用的功能

如果您只使用部分功能,可以完全略過某些功能。如果您的資料集有許多未使用的功能,不解碼這些功能可以大幅提升效能。請參閱 https://tensorflow.dev.org.tw/datasets/decode#only_decode_a_sub-set_of_the_features

tf.data 用盡我的所有 RAM!

如果您受限於 RAM,或者如果您在使用 tf.data 時平行載入許多資料集,以下是一些可能有幫助的選項

覆寫緩衝區大小

builder.as_dataset(
  read_config=tfds.ReadConfig(
    ...
    override_buffer_size=1024,  # Save quite a bit of RAM.
  ),
  ...
)

這會覆寫傳遞至 TFRecordDataset (或同等項目) 的 buffer_sizehttps://tensorflow.dev.org.tw/api_docs/python/tf/data/TFRecordDataset#args

使用 tf.data.Dataset.with_options 停止神奇行為

https://tensorflow.dev.org.tw/api_docs/python/tf/data/Dataset#with_options

options = tf.data.Options()

# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
  tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False

data = data.with_options(options)