適用於 Jax 和 PyTorch 的 TFDS

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

TFDS 向來與架構無關。例如,您可以輕鬆載入 NumPy 格式的資料集,以便在 Jax 和 PyTorch 中使用。

TensorFlow 及其資料載入解決方案 (tf.data) 依設計在我們的 API 中是一級公民。

我們擴充了 TFDS 以支援無 TensorFlow 僅限 NumPy 的資料載入。這對於在 Jax 和 PyTorch 等 ML 架構中使用可能很方便。事實上,對於後者使用者而言,TensorFlow 可能會

  • 保留 GPU/TPU 記憶體;
  • 增加 CI/CD 中的建置時間;
  • 花費時間在執行階段匯入。

TensorFlow 不再是讀取資料集的依附元件。

ML 管線需要資料載入器來載入範例、解碼範例並將其呈現給模型。資料載入器使用「來源/取樣器/載入器」範例

 TFDS dataset       ┌────────────────┐
   on disk          │                │
        ┌──────────►│      Data      │
|..|... │     |     │     source     ├─┐
├──┼────┴─────┤     │                │ │
│12│image12   │     └────────────────┘ │    ┌────────────────┐
├──┼──────────┤                        │    │                │
│13│image13   │                        ├───►│      Data      ├───► ML pipeline
├──┼──────────┤                        │    │     loader     │
│14│image14   │     ┌────────────────┐ │    │                │
├──┼──────────┤     │                │ │    └────────────────┘
|..|...       |     │     Index      ├─┘
                    │    sampler     │
                    │                │
                    └────────────────┘
  • 資料來源負責即時存取和解碼 TFDS 資料集中的範例。
  • 索引取樣器負責判斷處理記錄的順序。這對於在讀取任何記錄之前實作全域轉換 (例如,全域隨機排序、分片、針對多個週期重複) 很重要。
  • 資料載入器透過運用資料來源和索引取樣器來協調載入。它允許效能最佳化 (例如,預先擷取、多重處理或多執行緒)。

重點摘要

tfds.data_source 是一個 API,用於建立資料來源

  1. 適用於純 Python 管線中的快速原型設計;
  2. 以大規模管理資料密集的 ML 管線。

設定

讓我們安裝並匯入所需的依附元件

!pip install array_record
!pip install tfds-nightly

import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

資料來源

資料來源基本上是 Python 序列。因此,它們需要實作下列協定

class RandomAccessDataSource(Protocol):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, record_key: int) -> Sequence[Any]:
    """Retrieves records for the given record_keys."""

基礎檔案格式需要支援有效率的隨機存取。目前,TFDS 依賴 array_record

array_record 是一種衍生自 Riegeli 的新檔案格式,達成了 IO 效率的新境界。特別是,ArrayRecord 支援依記錄索引進行平行讀取、寫入和隨機存取。ArrayRecord 建構於 Riegeli 之上,並支援相同的壓縮演算法。

fashion_mnist 是電腦視覺的常見資料集。若要使用 TFDS 擷取以 ArrayRecord 為基礎的資料來源,只需使用

ds = tfds.data_source('fashion_mnist')
2023-10-03 09:33:57.614443: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-03 09:33:57.614491: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-03 09:33:57.614528: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Downloading and preparing dataset 29.45 MiB (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1...
2023-10-03 09:34:02.776237: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Dataset fashion_mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.

tfds.data_source 是一個方便的包裝函式。它相當於

builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()

這會輸出資料來源的字典

{
  'train': DataSource(name=fashion_mnist, split='train', decoders=None),
  'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}

一旦 download_and_prepare 執行,而且您產生記錄檔,我們就不再需要 TensorFlow。一切都會在 Python/NumPy 中發生!

讓我們透過解除安裝 TensorFlow 並在另一個子程序中重新載入資料來源來檢查這一點

pip uninstall -y tensorflow
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

try:
  import tensorflow as tf
except ImportError:
  print('No TensorFlow found...')

ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')
Writing no_tensorflow.py
python no_tensorflow.py
No TensorFlow found...
...but the data source could still be loaded...
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...
...and the records can be decoded.

在未來的版本中,我們也將使資料集準備工作無需 TensorFlow。

資料來源具有長度

len(ds['train'])
60000

存取資料集的第一個元素

%%timeit
ds['train'][0]
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...
553 µs ± 4.26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

...就像存取任何其他元素一樣便宜。這是隨機存取的定義

%%timeit
ds['train'][1000]
551 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

功能現在使用 NumPy DType (而不是 TensorFlow DType)。您可以使用下列項目檢查功能

features = tfds.builder('fashion_mnist').info.features

您可以在我們的文件中找到關於功能的更多資訊。在這裡,我們尤其可以擷取影像的形狀和類別數量

shape = features['image'].shape
num_classes = features['label'].num_classes

在純 Python 中使用

您可以透過在 Python 中反覆運算資料來源來取用資料來源

for example in ds['train']:
  print(example)
  break
{'image': array([[[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 18],
        [ 77],
        [227],
        [227],
        [208],
        [210],
        [225],
        [216],
        [ 85],
        [ 32],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 61],
        [100],
        [ 97],
        [ 80],
        [ 57],
        [117],
        [227],
        [238],
        [115],
        [ 49],
        [ 78],
        [106],
        [108],
        [ 71],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 81],
        [105],
        [ 80],
        [ 69],
        [ 72],
        [ 64],
        [ 44],
        [ 21],
        [ 13],
        [ 44],
        [ 69],
        [ 75],
        [ 75],
        [ 80],
        [114],
        [ 80],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 26],
        [ 92],
        [ 69],
        [ 68],
        [ 75],
        [ 75],
        [ 71],
        [ 74],
        [ 83],
        [ 75],
        [ 77],
        [ 78],
        [ 74],
        [ 74],
        [ 83],
        [ 77],
        [108],
        [ 34],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 55],
        [ 92],
        [ 69],
        [ 74],
        [ 74],
        [ 71],
        [ 71],
        [ 77],
        [ 69],
        [ 66],
        [ 75],
        [ 74],
        [ 77],
        [ 80],
        [ 80],
        [ 78],
        [ 94],
        [ 63],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 63],
        [ 95],
        [ 66],
        [ 68],
        [ 72],
        [ 72],
        [ 69],
        [ 72],
        [ 74],
        [ 74],
        [ 74],
        [ 75],
        [ 75],
        [ 77],
        [ 80],
        [ 77],
        [106],
        [ 61],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 80],
        [108],
        [ 71],
        [ 69],
        [ 72],
        [ 71],
        [ 69],
        [ 72],
        [ 75],
        [ 75],
        [ 72],
        [ 72],
        [ 75],
        [ 78],
        [ 72],
        [ 85],
        [128],
        [ 64],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 88],
        [120],
        [ 75],
        [ 74],
        [ 77],
        [ 75],
        [ 72],
        [ 77],
        [ 74],
        [ 74],
        [ 77],
        [ 78],
        [ 83],
        [ 83],
        [ 66],
        [111],
        [123],
        [ 78],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 85],
        [134],
        [ 74],
        [ 85],
        [ 69],
        [ 75],
        [ 75],
        [ 74],
        [ 75],
        [ 74],
        [ 75],
        [ 75],
        [ 81],
        [ 75],
        [ 61],
        [151],
        [115],
        [ 91],
        [ 12],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 10],
        [ 85],
        [153],
        [ 83],
        [ 80],
        [ 68],
        [ 77],
        [ 75],
        [ 74],
        [ 75],
        [ 74],
        [ 75],
        [ 77],
        [ 80],
        [ 68],
        [ 61],
        [162],
        [122],
        [ 78],
        [  6],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 30],
        [ 75],
        [154],
        [ 85],
        [ 80],
        [ 71],
        [ 80],
        [ 72],
        [ 77],
        [ 75],
        [ 75],
        [ 77],
        [ 78],
        [ 77],
        [ 75],
        [ 49],
        [191],
        [132],
        [ 72],
        [ 15],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 58],
        [ 66],
        [174],
        [115],
        [ 66],
        [ 77],
        [ 80],
        [ 72],
        [ 78],
        [ 75],
        [ 77],
        [ 78],
        [ 78],
        [ 77],
        [ 66],
        [ 49],
        [222],
        [131],
        [ 77],
        [ 37],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 69],
        [ 55],
        [179],
        [139],
        [ 55],
        [ 92],
        [ 74],
        [ 74],
        [ 78],
        [ 74],
        [ 78],
        [ 77],
        [ 75],
        [ 80],
        [ 64],
        [ 55],
        [242],
        [111],
        [ 95],
        [ 44],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 74],
        [ 57],
        [159],
        [180],
        [ 55],
        [ 92],
        [ 64],
        [ 72],
        [ 74],
        [ 74],
        [ 77],
        [ 75],
        [ 77],
        [ 78],
        [ 55],
        [ 66],
        [255],
        [ 97],
        [108],
        [ 49],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 74],
        [ 66],
        [145],
        [153],
        [ 72],
        [ 83],
        [ 58],
        [ 78],
        [ 77],
        [ 75],
        [ 75],
        [ 75],
        [ 72],
        [ 80],
        [ 30],
        [132],
        [255],
        [ 37],
        [122],
        [ 60],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 80],
        [ 69],
        [142],
        [180],
        [142],
        [ 57],
        [ 64],
        [ 78],
        [ 74],
        [ 75],
        [ 75],
        [ 75],
        [ 72],
        [ 85],
        [ 21],
        [185],
        [227],
        [ 37],
        [143],
        [ 63],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 83],
        [ 71],
        [136],
        [194],
        [126],
        [ 46],
        [ 69],
        [ 75],
        [ 72],
        [ 75],
        [ 75],
        [ 75],
        [ 74],
        [ 78],
        [ 38],
        [139],
        [185],
        [ 60],
        [151],
        [ 58],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  4],
        [ 81],
        [ 74],
        [145],
        [177],
        [ 78],
        [ 49],
        [ 74],
        [ 77],
        [ 75],
        [ 75],
        [ 75],
        [ 75],
        [ 74],
        [ 72],
        [ 63],
        [ 80],
        [156],
        [117],
        [153],
        [ 55],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 10],
        [ 80],
        [ 72],
        [157],
        [163],
        [ 61],
        [ 55],
        [ 75],
        [ 77],
        [ 75],
        [ 77],
        [ 75],
        [ 75],
        [ 75],
        [ 77],
        [ 71],
        [ 60],
        [ 98],
        [156],
        [132],
        [ 58],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 13],
        [ 77],
        [ 74],
        [157],
        [143],
        [ 43],
        [ 61],
        [ 72],
        [ 75],
        [ 77],
        [ 75],
        [ 74],
        [ 77],
        [ 77],
        [ 75],
        [ 71],
        [ 58],
        [ 80],
        [157],
        [120],
        [ 66],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 18],
        [ 81],
        [ 74],
        [156],
        [114],
        [ 35],
        [ 72],
        [ 71],
        [ 75],
        [ 78],
        [ 72],
        [ 66],
        [ 80],
        [ 78],
        [ 77],
        [ 75],
        [ 64],
        [ 63],
        [165],
        [119],
        [ 68],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 23],
        [ 85],
        [ 81],
        [177],
        [ 57],
        [ 52],
        [ 77],
        [ 71],
        [ 78],
        [ 80],
        [ 72],
        [ 75],
        [ 74],
        [ 77],
        [ 77],
        [ 75],
        [ 64],
        [ 37],
        [173],
        [ 95],
        [ 72],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 26],
        [ 81],
        [ 86],
        [160],
        [ 20],
        [ 75],
        [ 77],
        [ 77],
        [ 80],
        [ 78],
        [ 80],
        [ 89],
        [ 78],
        [ 81],
        [ 83],
        [ 80],
        [ 74],
        [ 20],
        [177],
        [ 77],
        [ 74],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 49],
        [ 77],
        [ 91],
        [200],
        [  0],
        [ 83],
        [ 95],
        [ 86],
        [ 88],
        [ 88],
        [ 89],
        [ 88],
        [ 89],
        [ 88],
        [ 83],
        [ 89],
        [ 86],
        [  0],
        [191],
        [ 78],
        [ 80],
        [ 24],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 54],
        [ 71],
        [108],
        [165],
        [  0],
        [ 24],
        [ 57],
        [ 52],
        [ 57],
        [ 60],
        [ 60],
        [ 60],
        [ 63],
        [ 63],
        [ 77],
        [ 89],
        [ 52],
        [  0],
        [211],
        [ 97],
        [ 77],
        [ 61],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 68],
        [ 91],
        [117],
        [137],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 18],
        [216],
        [ 94],
        [ 97],
        [ 57],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 54],
        [115],
        [105],
        [185],
        [  0],
        [  0],
        [  1],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [153],
        [ 78],
        [106],
        [ 37],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 18],
        [ 61],
        [ 41],
        [103],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [106],
        [ 47],
        [ 69],
        [ 23],
        [  0],
        [  0],
        [  0]]], dtype=uint8), 'label': 2}

如果您檢查元素,您也會注意到所有功能都已使用 NumPy 解碼。在幕後,我們預設使用 OpenCV,因為它速度很快。如果您沒有安裝 OpenCV,我們會預設為 Pillow 以提供輕量且快速的影像解碼。

{
  'image': array([[[0], [0], ..., [0]],
                  [[0], [0], ..., [0]]], dtype=uint8),
  'label': 2,
}

與 PyTorch 搭配使用

PyTorch 使用來源/取樣器/載入器範例。在 Torch 中,「資料來源」稱為「資料集」。torch.utils.data 包含您需要知道的所有詳細資訊,以在 Torch 中建置有效率的輸入管線。

TFDS 資料來源可以用作一般的地圖樣式資料集

首先,我們安裝並匯入 Torch

!pip install torch

from tqdm import tqdm
import torch

我們已經定義了用於訓練和測試的資料來源 (分別為 ds['train']ds['test'])。我們現在可以定義取樣器和載入器

batch_size = 128
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)

使用 PyTorch,我們在第一個範例上訓練和評估簡單的邏輯迴歸

class LinearClassifier(torch.nn.Module):
  def __init__(self, shape, num_classes):
    super(LinearClassifier, self).__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)


model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
for example in tqdm(test_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  num_examples += image.shape[0]
  predicted_label = prediction.argmax(dim=1)
  true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')
Training...
100%|██████████| 40/40 [00:01<00:00, 28.90it/s]
Testing...
100%|██████████| 79/79 [00:02<00:00, 34.75it/s]
Accuracy: 62.07%

即將推出:與 JAX 搭配使用

我們正與 Grain 密切合作。Grain 是適用於 Python 的開放原始碼、快速且具決定性的資料載入器。敬請期待!

閱讀更多資訊

如需更多資訊,請參閱 tfds.data_source API 文件。