常見實作陷阱

本頁說明實作新資料集時常見的實作陷阱。

應避免使用舊版 SplitGenerator

舊版 tfds.core.SplitGenerator API 已淘汰。

def _split_generator(...):
  return [
      tfds.core.SplitGenerator(name='train', gen_kwargs={'path': train_path}),
      tfds.core.SplitGenerator(name='test', gen_kwargs={'path': test_path}),
  ]

應替換為

def _split_generator(...):
  return {
      'train': self._generate_examples(path=train_path),
      'test': self._generate_examples(path=test_path),
  }

理由:新版 API 不那麼冗長,而且更加明確。舊版 API 將在未來版本中移除。

新資料集應包含在資料夾中

tensorflow_datasets/ 存放庫中新增資料集時,請務必遵循資料集即資料夾結構 (所有檢查總和、虛擬資料、實作程式碼皆包含在資料夾中)。

  • 舊資料集 (不良):<category>/<ds_name>.py
  • 新資料集 (良好):<category>/<ds_name>/<ds_name>.py

使用 TFDS CLI (tfds new,或適用於 Google 員工的 gtfds new) 來產生範本。

理由:舊結構需要檢查總和與虛擬資料的絕對路徑,而且將資料集檔案分散在許多位置。這使得在 TFDS 存放庫外部實作資料集更加困難。為了保持一致性,現在應全面使用新結構。

描述清單應格式化為 Markdown

DatasetInfo.description str 的格式為 Markdown。Markdown 清單的第一個項目前需要空行

_DESCRIPTION = """
Some text.
                      # << Empty line here !!!
1. Item 1
2. Item 1
3. Item 1
                      # << Empty line here !!!
Some other text.
"""

理由:格式錯誤的描述會在我們的目錄文件中產生視覺瑕疵。如果沒有空行,上述文字會呈現為

Some text. 1. Item 1 2. Item 1 3. Item 1 Some other text

忘記 ClassLabel 名稱

使用 tfds.features.ClassLabel 時,請嘗試使用 names=names_file= 提供人類可讀的標籤字串 (而非 num_classes=10)。

features = {
    'label': tfds.features.ClassLabel(names=['dog', 'cat', ...]),
}

理由:人類可讀的標籤用於許多地方

忘記圖片形狀

使用 tfds.features.Imagetfds.features.Video 時,如果圖片具有靜態形狀,則應明確指定

features = {
    'image': tfds.features.Image(shape=(256, 256, 3)),
}

理由:這允許靜態形狀推論 (例如 ds.element_spec['image'].shape),這是批次處理的必要條件 (批次處理未知形狀的圖片需要先調整大小)。

偏好更具體的類型,而非 tfds.features.Tensor

如果可以,偏好更具體的類型 tfds.features.ClassLabeltfds.features.BBoxFeatures 等,而非泛型 tfds.features.Tensor

理由:除了語意上更正確之外,特定功能還能為使用者提供額外的中繼資料,並可供工具偵測。

全域空間中的延遲匯入

不應從全域空間呼叫延遲匯入。例如,以下做法是錯誤的

tfds.lazy_imports.apache_beam # << Error: Import beam in the global scope

def f() -> beam.Map:
  ...

理由:在全域範圍中使用延遲匯入會為所有 tfds 使用者匯入模組,失去延遲匯入的目的。

動態計算訓練/測試分割

如果資料集未提供官方分割,TFDS 也不應提供。應避免以下做法

_TRAIN_TEST_RATIO = 0.7

def _split_generator():
  ids = list(range(num_examples))
  np.random.RandomState(seed).shuffle(ids)

  # Split train/test
  train_ids = ids[_TRAIN_TEST_RATIO * num_examples:]
  test_ids = ids[:_TRAIN_TEST_RATIO * num_examples]
  return {
      'train': self._generate_examples(train_ids),
      'test': self._generate_examples(test_ids),
  }

理由:TFDS 嘗試提供盡可能接近原始資料的資料集。子分割 API 應改為使用,讓使用者動態建立他們想要的子分割

ds_train, ds_test = tfds.load(..., split=['train[:80%]', 'train[80%:]'])

Python 樣式指南

偏好使用 pathlib API

相較於 tf.io.gfile API,建議使用 pathlib API。所有 dl_manager 方法都會傳回與 GCS、S3 等相容的 pathlib 類物件…

path = dl_manager.download_and_extract('http://some-website/my_data.zip')

json_path = path / 'data/file.json'

json.loads(json_path.read_text())

理由:pathlib API 是現代化的物件導向檔案 API,可移除重複程式碼。使用 .read_text() / .read_bytes() 也可確保檔案已正確關閉。

如果方法未使用 self,則應為函式

如果類別方法未使用 self,則應為簡單函式 (在類別外部定義)。

理由:這向讀者明確表示函式沒有副作用,也沒有隱藏的輸入/輸出

x = f(y)  # Clear inputs/outputs

x = self.f(y)  # Does f depend on additional hidden variables ? Is it stateful ?

Python 中的延遲匯入

我們會延遲匯入 TensorFlow 等大型模組。延遲匯入會將模組的實際匯入延後到第一次使用模組時。因此,不需要這個大型模組的使用者永遠不會匯入它。我們使用 etils.epy.lazy_imports

from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
# After this statement, TensorFlow is not imported yet

...

features = tfds.features.Image(dtype=tf.uint8)
# After using it (`tf.uint8`), TensorFlow is now imported

在底層,LazyModule 類別的作用如同工廠,只會在存取屬性時 (__getattr__) 實際匯入模組。

您也可以透過內容管理員方便地使用它

from etils import epy

with epy.lazy_imports(error_callback=..., success_callback=...):
  import some_big_module