TFDS has always been framework-agnostic. For instance, you can easily load datasets in NumPy format for use in Jax and PyTorch.
TensorFlow and its data loading solution (tf.data) are first-class citizens in our API by design.
We extended TFDS to support TensorFlow-less, NumPy-only data loading. This can be convenient for use in ML frameworks such as Jax and PyTorch. Indeed, for the latter users, TensorFlow can:
- reserve GPU/TPU memory;
- increase build time in CI/CD;
- take time to import at runtime.
TensorFlow is no longer a dependency for reading datasets.
Machine learning pipelines need a data loader to load examples, decode them and present them to the model. Data loaders use the "source/sampler/loader" paradigm:
 TFDS dataset         ┌────────────────┐
   on disk            │                │
          ┌──────────►│      Data      │
|..|...   │           │     source     ├─┐
├──┼──────┴───┤       │                │ │
│12│image12   │       └────────────────┘ │       ┌────────────────┐
├──┼──────────┤                          │       │                │
│13│image13   │                          ├──────►│      Data      ├──────► ML pipeline
├──┼──────────┤                          │       │     loader     │
│14│image14   │       ┌────────────────┐ │       │                │
├──┼──────────┤       │                │ │       └────────────────┘
|..|...       |       │     Index      ├─┘
                      │    sampler     │
                      │                │
                      └────────────────┘
- The data source is responsible for accessing and decoding examples from a TFDS dataset on the fly.
- The index sampler is responsible for determining the order in which records are processed. This makes it possible to implement global transformations (e.g., global shuffling, sharding, repeating for multiple epochs) before reading any records.
- The data loader orchestrates the loading by leveraging the data source and the index sampler. It enables performance optimizations (e.g., pre-fetching, multiprocessing or multithreading). A minimal pure-Python sketch of this paradigm follows below.
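To make the division of responsibilities concrete, here is a toy, pure-Python sketch of the paradigm. The helper names (data_source, index_sampler, data_loader) are hypothetical and are not the TFDS or Grain API; a real loader would add decoding, prefetching and multiprocessing.

import random

# Any random-access sequence can play the role of the data source.
data_source = [f'record_{i}' for i in range(10)]

def index_sampler(num_records, num_epochs=1, seed=0):
  """Decides the global visiting order before any record is read."""
  rng = random.Random(seed)
  for _ in range(num_epochs):
    indices = list(range(num_records))
    rng.shuffle(indices)  # global shuffle
    yield from indices

def data_loader(source, sampler):
  """Orchestrates loading; prefetching/multiprocessing would live here."""
  for index in sampler:
    yield source[index]

for record in data_loader(data_source, index_sampler(len(data_source))):
  print(record)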
Key takeaways
tfds.data_source is an API to create data sources:
- for fast prototyping in pure-Python pipelines;
- to manage data-intensive ML pipelines at scale.
Setup
Let's install and import the needed dependencies:
!pip install array_record
!pip install grain-nightly
!pip install jax jaxlib
!pip install tfds-nightly
import os
os.environ.pop('TFDS_DATA_DIR', None)
import tensorflow_datasets as tfds
Data sources
Data sources are basically Python sequences, so they need to implement the following protocol:
from typing import Any, Protocol, SupportsIndex

class RandomAccessDataSource(Protocol):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, key: SupportsIndex) -> Any:
    """Retrieves the record for the given key."""
The underlying file format needs to support efficient random access. Currently, TFDS relies on array_record.
array_record is a new file format derived from Riegeli, achieving a new frontier of IO efficiency. In particular, ArrayRecord supports parallel read, write and random access by record index. ArrayRecord builds on top of Riegeli and supports the same compression algorithms.
fashion_mnist is a common dataset for computer vision. To retrieve an ArrayRecord-based data source with TFDS, simply use:
ds = tfds.data_source('fashion_mnist')
Downloading and preparing dataset 29.45 MiB (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1...
2024-04-26 11:20:57.419076: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Dataset fashion_mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.
tfds.data_source is a convenient wrapper. It is equivalent to:
builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()
This outputs a dictionary of data sources:
{
  'train': DataSource(name=fashion_mnist, split='train', decoders=None),
  'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}
Once download_and_prepare has run and the record files have been generated, we don't need TensorFlow anymore. Everything happens in Python/NumPy!
Let's check this by uninstalling TensorFlow and re-loading the data source in another subprocess:
pip uninstall -y tensorflow
/usr/lib/python3.9/pty.py:85: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid, fd = os.forkpty()
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)
import tensorflow_datasets as tfds

try:
  import tensorflow as tf
except ImportError:
  print('No TensorFlow found...')

ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')

ds['train'][0]
print('...and the records can be decoded.')
Writing no_tensorflow.py
python no_tensorflow.py
No TensorFlow found...
...but the data source could still be loaded...
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...
...and the records can be decoded.
In a future version, we are also going to make the dataset preparation TensorFlow-free.
Data sources have a length:
len(ds['train'])
60000
Accessing the first element of the dataset…
%%timeit
ds['train'][0]
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...
584 µs ± 2.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
…is just as cheap as accessing any other element. This is the definition of random access:
%%timeit
ds['train'][1000]
581 µs ± 2.33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
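Because access is purely index-based, you can just as well visit records in any order you like, for instance a few shuffled indices drawn with NumPy. A small illustrative sketch (the indices drawn here are arbitrary):

import numpy as np

rng = np.random.default_rng(seed=0)
for index in rng.choice(len(ds['train']), size=3, replace=False):
  example = ds['train'][int(index)]  # random access by index
  print(index, example['image'].shape, example['label'])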
Features now use NumPy DTypes (rather than TensorFlow DTypes). You can inspect the features with:
features = tfds.builder('fashion_mnist').info.features
You will find more information about features in our documentation. Here, we can notably retrieve the shape of the images and the number of classes:
shape = features['image'].shape
num_classes = features['label'].num_classes
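For fashion_mnist these are a 28×28 grayscale image and 10 classes; a quick check, with the expected values shown as comments:

print(shape)        # (28, 28, 1)
print(num_classes)  # 10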
Use in pure Python
You can consume data sources in Python by iterating over them:
for example in ds['train']:
  print(example)
  break
{'image': array([[[ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 18], [ 77], [227], [227], [208], [210], [225], [216], [ 85], [ 32], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 61], [100], [ 97], [ 80], [ 57], [117], [227], [238], [115], [ 49], [ 78], [106], [108], [ 71], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 81], [105], [ 80], [ 69], [ 72], [ 64], [ 44], [ 21], [ 13], [ 44], [ 69], [ 75], [ 75], [ 80], [114], [ 80], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 26], [ 92], [ 69], [ 68], [ 75], [ 75], [ 71], [ 74], [ 83], [ 75], [ 77], [ 78], [ 74], [ 74], [ 83], [ 77], [108], [ 34], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 55], [ 92], [ 69], [ 74], [ 74], [ 71], [ 71], [ 77], [ 69], [ 66], [ 75], [ 74], [ 77], [ 80], [ 80], [ 78], [ 94], [ 63], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 63], [ 95], [ 66], [ 68], [ 72], [ 72], [ 69], [ 72], [ 74], [ 74], [ 74], [ 75], [ 75], [ 77], [ 80], [ 77], [106], [ 61], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 80], [108], [ 71], [ 69], [ 72], [ 71], [ 69], [ 72], [ 75], [ 75], [ 72], [ 72], [ 75], [ 78], [ 72], [ 85], [128], [ 64], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 88], [120], [ 75], [ 74], [ 77], [ 75], [ 72], [ 77], [ 74], [ 74], [ 77], [ 78], [ 83], [ 83], [ 66], [111], [123], [ 78], [ 0], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 0], [ 85], [134], [ 74], [ 85], [ 69], [ 75], [ 75], [ 74], [ 75], [ 74], [ 75], [ 75], [ 81], [ 75], [ 61], [151], [115], [ 91], [ 12], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 10], [ 85], [153], [ 83], [ 80], [ 68], [ 77], [ 75], [ 74], [ 75], [ 74], [ 75], [ 77], [ 80], [ 68], [ 61], [162], [122], [ 78], [ 6], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 30], [ 75], [154], [ 85], [ 80], [ 71], [ 80], [ 72], [ 77], [ 75], [ 75], [ 77], [ 78], [ 77], [ 75], [ 49], [191], [132], [ 72], [ 15], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 58], [ 66], [174], [115], [ 66], [ 77], [ 80], [ 72], [ 78], [ 75], [ 77], [ 78], [ 78], [ 77], [ 66], [ 49], [222], [131], [ 77], [ 37], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 69], [ 55], [179], [139], [ 55], [ 92], [ 74], [ 74], [ 78], [ 74], [ 78], [ 77], [ 75], [ 80], [ 64], [ 55], [242], [111], [ 95], [ 44], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 74], [ 57], [159], [180], [ 55], [ 92], [ 64], [ 72], [ 74], [ 74], [ 77], [ 75], [ 77], [ 78], [ 55], [ 66], [255], [ 97], [108], [ 49], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 74], [ 66], [145], [153], [ 72], [ 83], [ 58], [ 78], [ 77], [ 75], [ 75], [ 75], [ 72], [ 80], [ 30], [132], [255], [ 37], [122], [ 60], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 80], [ 69], [142], [180], [142], [ 57], [ 64], [ 78], [ 74], [ 75], [ 75], [ 75], [ 72], [ 85], [ 21], [185], [227], [ 37], [143], [ 63], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 0], [ 83], [ 71], [136], [194], [126], [ 46], [ 69], [ 75], [ 72], [ 75], [ 75], [ 75], [ 74], [ 78], [ 38], [139], [185], [ 60], [151], [ 58], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 4], [ 81], [ 74], [145], [177], [ 78], [ 49], [ 74], [ 77], [ 75], [ 75], [ 75], [ 75], [ 74], [ 72], [ 63], [ 80], [156], [117], [153], [ 55], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 10], [ 80], [ 72], [157], [163], [ 61], [ 55], [ 75], [ 77], [ 75], [ 77], [ 75], [ 75], [ 75], [ 77], [ 71], [ 60], [ 98], 
[156], [132], [ 58], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 13], [ 77], [ 74], [157], [143], [ 43], [ 61], [ 72], [ 75], [ 77], [ 75], [ 74], [ 77], [ 77], [ 75], [ 71], [ 58], [ 80], [157], [120], [ 66], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 18], [ 81], [ 74], [156], [114], [ 35], [ 72], [ 71], [ 75], [ 78], [ 72], [ 66], [ 80], [ 78], [ 77], [ 75], [ 64], [ 63], [165], [119], [ 68], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 23], [ 85], [ 81], [177], [ 57], [ 52], [ 77], [ 71], [ 78], [ 80], [ 72], [ 75], [ 74], [ 77], [ 77], [ 75], [ 64], [ 37], [173], [ 95], [ 72], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 26], [ 81], [ 86], [160], [ 20], [ 75], [ 77], [ 77], [ 80], [ 78], [ 80], [ 89], [ 78], [ 81], [ 83], [ 80], [ 74], [ 20], [177], [ 77], [ 74], [ 0], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 49], [ 77], [ 91], [200], [ 0], [ 83], [ 95], [ 86], [ 88], [ 88], [ 89], [ 88], [ 89], [ 88], [ 83], [ 89], [ 86], [ 0], [191], [ 78], [ 80], [ 24], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 54], [ 71], [108], [165], [ 0], [ 24], [ 57], [ 52], [ 57], [ 60], [ 60], [ 60], [ 63], [ 63], [ 77], [ 89], [ 52], [ 0], [211], [ 97], [ 77], [ 61], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 68], [ 91], [117], [137], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 18], [216], [ 94], [ 97], [ 57], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 54], [115], [105], [185], [ 0], [ 0], [ 1], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [153], [ 78], [106], [ 37], [ 0], [ 0], [ 0]], [[ 0], [ 0], [ 0], [ 18], [ 61], [ 41], [103], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [ 0], [106], [ 47], [ 69], [ 23], [ 0], [ 0], [ 0]]], dtype=uint8), 'label': 2}
If you inspect the element, you will also notice that all features are decoded with NumPy. Behind the scenes, we use OpenCV by default because it is fast. If you don't have OpenCV installed, we default to Pillow to provide lightweight and fast image decoding.
{
  'image': array([[[0], [0], ..., [0]],
                  [[0], [0], ..., [0]]], dtype=uint8),
  'label': 2,
}
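Since a data source is just a Python sequence, you can also build simple NumPy mini-batches by hand. The helper below is an illustrative sketch, not a TFDS API:

import numpy as np

def iter_batches(source, batch_size):
  """Yields dict-of-arrays mini-batches from a random-access source (toy helper)."""
  for start in range(0, len(source), batch_size):
    examples = [source[i] for i in range(start, min(start + batch_size, len(source)))]
    yield {
        'image': np.stack([example['image'] for example in examples]),
        'label': np.array([example['label'] for example in examples]),
    }

first_batch = next(iter_batches(ds['train'], batch_size=32))
print(first_batch['image'].shape)  # (32, 28, 28, 1)
print(first_batch['label'].shape)  # (32,)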
Use with PyTorch
PyTorch uses the source/sampler/loader paradigm. In Torch, "data sources" are called "datasets". torch.utils.data contains all the details you need to know to build efficient input pipelines in Torch.
TFDS data sources can be used as regular map-style datasets.
First, we install and import Torch:
!pip install torch
from tqdm import tqdm
import torch
We already defined data sources for training and testing (respectively, ds['train'] and ds['test']). We can now define the sampler and the loaders:
batch_size = 128

train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)
Using PyTorch, we train and evaluate a simple logistic regression on the first examples:
class LinearClassifier(torch.nn.Module):
  def __init__(self, shape, num_classes):
    super(LinearClassifier, self).__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)

model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
for example in tqdm(test_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  num_examples += image.shape[0]
  predicted_label = prediction.argmax(dim=1)
  true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/torch/cuda/__init__.py:619: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
Training...
100%|██████████| 40/40 [00:01<00:00, 31.22it/s]
Testing...
100%|██████████| 79/79 [00:02<00:00, 32.66it/s]
Accuracy: 65.63%
Use with JAX
Grain is a library for reading data for training and evaluating JAX models. It is open source, fast and deterministic. Grain uses the source/sampler/loader paradigm, so we can re-use tfds.data_source:
import grain.python as pygrain
import numpy as np

data_source = tfds.data_source("fashion_mnist", split="train")

# To shuffle the data, use a sampler:
sampler = pygrain.IndexSampler(
    num_records=5,
    num_epochs=1,
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    seed=0,
)
Transformations are defined as classes. They can be a BatchTransform, a FilterTransform or a MapTransform:
class ImageToText(pygrain.MapTransform):
  """Maps a label to its text representation."""

  LABEL_TO_TEXT = {
      0: "zero",
      1: "one",
      2: "two",
      3: "three",
      4: "four",
      5: "five",
      6: "six",
      7: "seven",
      8: "eight",
      9: "nine",
  }

  def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    label = element["label"]
    text = self.LABEL_TO_TEXT[label]
    element["text"] = text
    return element

# You can chain transformations in a list:
operations = [ImageToText()]
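For instance, a FilterTransform keeps only the elements for which its filter method returns True. A small sketch, assuming the pygrain.FilterTransform interface (the label kept here is arbitrary):

class KeepLabel(pygrain.FilterTransform):
  """Keeps only the examples with a given label (illustrative transform)."""

  def __init__(self, label: int):
    self._label = label

  def filter(self, element: dict[str, np.ndarray]) -> bool:
    return int(element["label"]) == self._label

# Transformations could be chained, e.g. filter first, then map to text:
filtered_operations = [KeepLabel(label=1), ImageToText()]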
Finally, the data loader orchestrates the loading. You can scale it with multiprocessing to enjoy both the flexibility of Python and the performance of a data loader:
loader = pygrain.DataLoader(
    data_source=data_source,
    operations=operations,
    sampler=sampler,
    worker_count=0,  # Scale to multiple workers with multiprocessing
)

for element in loader:
  print(element["text"])
two
one
one
eight
four
Read more
For more information, please refer to the tfds.data_source API documentation.