分割與切片

所有 TFDS 資料集都公開各種資料分割 (例如 'train''test'),這些分割可以在目錄中瀏覽。除了 all 之外,任何字母字串都可以用作分割名稱 (all 是一個保留字詞,對應於所有分割的聯集,請參閱下文)。

除了「官方」資料集分割之外,TFDS 還允許選取分割的切片和各種組合。

切片 API

切片指示在 tfds.loadtfds.DatasetBuilder.as_dataset 中透過 split= kwarg 指定。

ds = tfds.load('my_dataset', split='train[:75%]')
builder = tfds.builder('my_dataset')
ds = builder.as_dataset(split='test+train[:75%]')

分割可以是

  • 純分割名稱 (字串,例如 'train''test'、...): 選取分割中的所有範例。
  • 切片:切片具有與 Python 切片表示法相同的語意。切片可以是
    • 絕對 ('train[123:450]'train[:4000]): (請參閱以下關於讀取順序注意事項)
    • 百分比 ('train[:75%]''train[25%:75%]'): 將完整資料分割成均勻的切片。如果資料無法均勻分割,則某些百分比可能包含其他範例。支援小數百分比。
    • 分片 (train[:4shard]train[4shard]): 選取所請求分片中的所有範例。(請參閱 info.splits['train'].num_shards 以取得分割的分片數量)
  • 分割的聯集 ('train+test''train[:25%]+test'): 分割將交錯在一起。
  • 完整資料集 ('all'): 'all' 是一個特殊的分割名稱,對應於所有分割的聯集 (相當於 'train+test+...')。
  • 分割清單 (['train', 'test']): 多個 tf.data.Dataset 會分別傳回
# Returns both train and test split separately
train_ds, test_ds = tfds.load('mnist', split=['train', 'test[:50%]'])

tfds.even_splits & 多主機訓練

tfds.even_splits 產生大小相同的非重疊子分割清單。

# Divide the dataset into 3 even parts, each containing 1/3 of the data
split0, split1, split2 = tfds.even_splits('train', n=3)

ds = tfds.load('my_dataset', split=split2)

這在分散式設定中訓練時特別有用,在分散式設定中,每個主機都應接收原始資料的切片。

使用 Jax,可以使用 tfds.split_for_jax_process 進一步簡化此操作

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_processtfds.even_splits 的簡單別名,tfds.split_for_jax_process 接受任何分割值作為輸入 (例如 'train[75%:]+test')

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

tfds.even_splitstfds.split_for_jax_process 接受任何分割值作為輸入 (例如 'train[75%:]+test')

切片和中繼資料

可以使用資料集資訊取得關於分割/子分割的其他資訊 (num_examplesfile_instructions、...)

builder = tfds.builder('my_dataset')
builder.info.splits['train'].num_examples  # 10_000
builder.info.splits['train[:75%]'].num_examples  # 7_500 (also works with slices)
builder.info.splits.keys()  # ['train', 'test']

交叉驗證

使用字串 API 進行 10 折交叉驗證的範例

vals_ds = tfds.load('mnist', split=[
    f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)
])
trains_ds = tfds.load('mnist', split=[
    f'train[:{k}%]+train[{k+10}%:]' for k in range(0, 100, 10)
])

驗證資料集將各佔 10%: [0%:10%][10%:20%]、...、[90%:100%]。而訓練資料集將各佔互補的 90%: [10%:100%] (對於對應的驗證集 [0%:10%])、 `[0%:10%]

  • [20%:100%](對於驗證集[10%:20%]`)、...

tfds.core.ReadInstruction 和捨入

除了 str 之外,也可以將分割作為 tfds.core.ReadInstruction 傳遞

例如,split = 'train[50%:75%] + test' 相當於

split = (
    tfds.core.ReadInstruction(
        'train',
        from_=50,
        to=75,
        unit='%',
    )
    + tfds.core.ReadInstruction('test')
)
ds = tfds.load('my_dataset', split=split)

unit 可以是

  • abs: 絕對切片
  • %: 百分比切片
  • shard: 分片切片

tfds.ReadInstruction 也具有捨入引數。如果資料集中的範例數量無法均勻分割

  • rounding='closest' (預設值): 剩餘的範例會分配在百分比之間,因此某些百分比可能包含其他範例。
  • rounding='pct1_dropremainder': 剩餘的範例會捨棄,但這保證所有百分比都包含完全相同的範例數量 (例如: len(5%) == 5 * len(1%))。

重現性和決定性

在產生期間,對於給定的資料集版本,TFDS 保證範例會在磁碟上以決定性的方式隨機排序。因此,產生資料集兩次 (在 2 部不同的電腦中) 不會變更範例順序。

同樣地,子分割 API 將始終選取相同的範例 set,而與平台、架構等無關。這表示 set('train[:20%]') == set('train[:10%]') + set('train[10%:20%]')

但是,讀取範例的順序可能是決定性的。這取決於其他參數 (例如 shuffle_files=True 是否為 true)。