載入外部 tfrecord 與 TFDS

如果您有第三方工具產生的 tf.train.Example Proto (在 .tfrecord.riegeli 等檔案中),並想使用 TFDS API 直接載入,本頁面正適合您。

為了載入 .tfrecord 檔案,您只需要:

  • 遵循 TFDS 命名慣例。
  • 沿著 tfrecord 檔案新增中繼資料檔案 (dataset_info.jsonfeatures.json)。

限制

檔案命名慣例

TFDS 支援定義檔案名稱範本,可彈性使用不同的檔案命名架構。範本由 tfds.core.ShardedFileTemplate 表示,並支援下列變數:{DATASET}{SPLIT}{FILEFORMAT}{SHARD_INDEX}{NUM_SHARDS}{SHARD_X_OF_Y}。例如,TFDS 的預設檔案命名架構為:{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}。以 MNIST 為例,檔案名稱看起來如下

  • mnist-test.tfrecord-00000-of-00001
  • mnist-train.tfrecord-00000-of-00001

新增中繼資料

提供功能結構

為了讓 TFDS 能夠解碼 tf.train.Example Proto,您需要提供符合規格的 tfds.features 結構。例如:

features = tfds.features.FeaturesDict({
    'image':
        tfds.features.Image(
            shape=(256, 256, 3),
            doc='Picture taken by smartphone, downscaled.'),
    'label':
        tfds.features.ClassLabel(names=['dog', 'cat']),
    'objects':
        tfds.features.Sequence({
            'camera/K': tfds.features.Tensor(shape=(3,), dtype=tf.float32),
        }),
})

對應至下列 tf.train.Example 規格

{
    'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
    'label': tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
    'objects/camera/K': tf.io.FixedLenSequenceFeature(shape=(3,), dtype=tf.int64),
}

指定功能可讓 TFDS 自動解碼圖片、影片等。如同任何其他 TFDS 資料集,功能中繼資料 (例如標籤名稱等) 將會向使用者公開 (例如 info.features['label'].names)。

如果您控制產生管道

如果您在 TFDS 外部產生資料集,但仍控制產生管道,可以使用 tfds.features.FeatureConnector.serialize_example 將資料從 dict[np.ndarray] 編碼為 tf.train.Example Proto bytes

with tf.io.TFRecordWriter('path/to/file.tfrecord') as writer:
  for ex in all_exs:
    ex_bytes = features.serialize_example(data)
    writer.write(ex_bytes)

這可確保功能與 TFDS 相容。

同樣地,也存在 feature.deserialize_example 以解碼 Proto (範例)

如果您不控制產生管道

如果您想瞭解 tfds.featurestf.train.Example 中的表示方式,可以在 Colab 中檢查

取得分割的統計資料

TFDS 需要知道每個分片中的確切範例數量。這是 len(ds) 等功能或 subplit APIsplit='train[75%:]' 所需的資訊。

  • 如果您有這項資訊,可以明確建立 tfds.core.SplitInfo 清單,並跳至下一節

    split_infos = [
        tfds.core.SplitInfo(
            name='train',
            shard_lengths=[1024, ...],  # Num of examples in shard0, shard1,...
            num_bytes=0,  # Total size of your dataset (if unknown, set to 0)
        ),
        tfds.core.SplitInfo(name='test', ...),
    ]
    
  • 如果您不知道這項資訊,可以使用 compute_split_info.py 指令碼 (或使用 tfds.folder_dataset.compute_split_info 在您自己的指令碼中) 計算。這會啟動 Beam 管道,讀取指定目錄中的所有分片,並計算資訊。

新增中繼資料檔案

若要自動將正確的中繼資料檔案沿著資料集新增,請使用 tfds.folder_dataset.write_metadata

tfds.folder_dataset.write_metadata(
    data_dir='/path/to/my/dataset/1.0.0/',
    features=features,
    # Pass the `out_dir` argument of compute_split_info (see section above)
    # You can also explicitly pass a list of `tfds.core.SplitInfo`.
    split_infos='/path/to/my/dataset/1.0.0/',
    # Pass a custom file name template or use None for the default TFDS
    # file name template.
    filename_template='{SPLIT}-{SHARD_X_OF_Y}.{FILEFORMAT}',

    # Optionally, additional DatasetInfo metadata can be provided
    # See:
    # https://tensorflow.dev.org.tw/datasets/api_docs/python/tfds/core/DatasetInfo
    description="""Multi-line description."""
    homepage='http://my-project.org',
    supervised_keys=('image', 'label'),
    citation="""BibTex citation.""",
)

一旦在您的資料集目錄上呼叫過此函式,就會新增中繼資料檔案 (dataset_info.json 等),您的資料集即可使用 TFDS 載入 (請參閱下一節)。

使用 TFDS 載入資料集

直接從資料夾

產生中繼資料後,即可使用 tfds.builder_from_directory 載入資料集,此函式會傳回具有標準 TFDS API 的 tfds.core.DatasetBuilder (例如 tfds.builder)

builder = tfds.builder_from_directory('~/path/to/my_dataset/3.0.0/')

# Metadata are available as usual
builder.info.splits['train'].num_examples

# Construct the tf.data.Dataset pipeline
ds = builder.as_dataset(split='train[75%:]')
for ex in ds:
  ...

直接從多個資料夾

也可以從多個資料夾載入資料。舉例來說,在強化學習中,多個代理程式各自產生個別的資料集,而您想將所有資料集一起載入時,就會發生這種情況。其他使用案例包括定期產生新的資料集 (例如每天一個新的資料集),而您想從某個日期範圍載入資料。

若要從多個資料夾載入資料,請使用 tfds.builder_from_directories,此函式會傳回具有標準 TFDS API 的 tfds.core.DatasetBuilder (例如 tfds.builder)

builder = tfds.builder_from_directories(builder_dirs=[
    '~/path/my_dataset/agent1/1.0.0/',
    '~/path/my_dataset/agent2/1.0.0/',
    '~/path/my_dataset/agent3/1.0.0/',
])

# Metadata are available as usual
builder.info.splits['train'].num_examples

# Construct the tf.data.Dataset pipeline
ds = builder.as_dataset(split='train[75%:]')
for ex in ds:
  ...

資料夾結構 (選用)

為了獲得與 TFDS 更佳的相容性,您可以將資料整理為 <data_dir>/<dataset_name>[/<dataset_config>]/<dataset_version>。例如:

data_dir/
    dataset0/
        1.0.0/
        1.0.1/
    dataset1/
        config0/
            2.0.0/
        config1/
            2.0.0/

這會讓您的資料集與 tfds.load / tfds.builder API 相容,只需提供 data_dir/ 即可

ds0 = tfds.load('dataset0', data_dir='data_dir/')
ds1 = tfds.load('dataset1/config0', data_dir='data_dir/')