tf.data:建構 TensorFlow 輸入管線

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

tf.data API 可讓您從簡單、可重複使用的片段建構複雜的輸入管線。例如,影像模型的管線可能會彙總來自分散式檔案系統中檔案的資料、對每張影像套用隨機擾動,以及將隨機選取的影像合併成批次以進行訓練。文字模型的管線可能涉及從原始文字資料中擷取符號、使用查閱表將其轉換為嵌入識別碼,以及將不同長度的序列分批處理。tf.data API 讓處理大量資料、從不同資料格式讀取,以及執行複雜轉換成為可能。

tf.data API 引入了 tf.data.Dataset 抽象概念,代表一系列元素,其中每個元素都由一個或多個元件組成。例如,在影像管線中,元素可能是單一訓練範例,具有一對張量元件代表影像及其標籤。

有兩種不同的方式可以建立資料集

  • 資料來源會從記憶體或一或多個檔案中儲存的資料建構 Dataset

  • 資料轉換會從一或多個 tf.data.Dataset 物件建構資料集。

import tensorflow as tf
2024-01-17 02:24:38.925403: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-17 02:24:38.925446: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-17 02:24:38.927059: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
import pathlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)

基本機制

若要建立輸入管線,您必須從資料來源開始。例如,若要從記憶體中的資料建構 Dataset,您可以使用 tf.data.Dataset.from_tensors()tf.data.Dataset.from_tensor_slices()。或者,如果您的輸入資料以建議的 TFRecord 格式儲存在檔案中,則可以使用 tf.data.TFRecordDataset()

取得 Dataset 物件後,您可以透過在 tf.data.Dataset 物件上串連方法呼叫來將其轉換為新的 Dataset。例如,您可以套用每個元素的轉換,例如 Dataset.map,以及多元素轉換,例如 Dataset.batch。請參閱 tf.data.Dataset 的文件,以取得完整的轉換清單。

Dataset 物件是 Python 可迭代物件。這使得可以使用 for 迴圈來取用其元素

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
<_TensorSliceDataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>
for elem in dataset:
  print(elem.numpy())
8
3
0
8
2
1

或透過使用 iter 明確建立 Python 迭代器,並使用 next 取用其元素

it = iter(dataset)

print(next(it).numpy())
8

或者,可以使用 reduce 轉換來取用資料集元素,這會減少所有元素以產生單一結果。以下範例說明如何使用 reduce 轉換來計算整數資料集的總和。

print(dataset.reduce(0, lambda state, value: state + value).numpy())
22

資料集結構

資料集會產生一系列元素,其中每個元素都是相同的 (巢狀) 元件結構。結構的個別元件可以是 tf.TypeSpec 可表示的任何類型,包括 tf.Tensortf.sparse.SparseTensortf.RaggedTensortf.TensorArraytf.data.Dataset

可用於表示元素 (巢狀) 結構的 Python 建構包括 tupledictNamedTupleOrderedDict。特別是,list 不是表示資料集元素結構的有效建構。這是因為早期的 tf.data 使用者強烈感覺到 list 輸入 (例如,傳遞至 tf.data.Dataset.from_tensors 時) 會自動封裝為張量,而 list 輸出 (例如,使用者定義函式的傳回值) 會強制轉換為 tuple。因此,如果您希望將 list 輸入視為結構,則需要將其轉換為 tuple,如果您希望 list 輸出是單一元件,則需要使用 tf.stack 明確封裝它。

Dataset.element_spec 屬性可讓您檢查每個元素元件的類型。此屬性會傳回 tf.TypeSpec 物件的巢狀結構,與元素的結構 (可能是單一元件、元件元組或元件的巢狀元組) 相符。例如

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor

Dataset 轉換支援任何結構的資料集。當使用 Dataset.mapDataset.filter 轉換 (將函式套用至每個元素) 時,元素結構會決定函式的引數

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1
<_TensorSliceDataset element_spec=TensorSpec(shape=(10,), dtype=tf.int32, name=None)>
for z in dataset1:
  print(z.numpy())
[2 1 3 2 1 3 9 6 4 5]
[8 9 2 9 1 4 7 2 4 5]
[3 9 6 8 4 6 8 4 9 5]
[4 3 7 2 8 6 4 9 7 6]
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2
<_TensorSliceDataset element_spec=(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None))>
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3
<_ZipDataset element_spec=(TensorSpec(shape=(10,), dtype=tf.int32, name=None), (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None)))>
for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

讀取輸入資料

取用 NumPy 陣列

如需更多範例,請參閱載入 NumPy 陣列教學課程。

如果您的所有輸入資料都適合放在記憶體中,則從它們建立 Dataset 的最簡單方法是將它們轉換為 tf.Tensor 物件,並使用 Dataset.from_tensor_slices

train, test = tf.keras.datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 [==============================] - 0s 0us/step
images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
<_TensorSliceDataset element_spec=(TensorSpec(shape=(28, 28), dtype=tf.float64, name=None), TensorSpec(shape=(), dtype=tf.uint8, name=None))>

取用 Python 產生器

可以輕鬆作為 tf.data.Dataset 擷取的另一個常見資料來源是 Python 產生器。

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)
0
1
2
3
4

Dataset.from_generator 建構函式會將 Python 產生器轉換為功能完整的 tf.data.Dataset

此建構函式採用可呼叫物件作為輸入,而非迭代器。這使其可以在產生器到達結尾時重新啟動產生器。它採用選用的 args 引數,該引數會作為可呼叫物件的引數傳遞。

output_types 引數是必要的,因為 tf.data 在內部建置 tf.Graph,而圖邊緣需要 tf.dtype

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

output_shapes 引數不是必要的,但強烈建議使用,因為許多 TensorFlow 運算不支援具有未知等級的張量。如果特定軸的長度未知或可變,請在 output_shapes 中將其設定為 None

同樣重要的是要注意,output_shapesoutput_types 遵循與其他資料集方法相同的巢狀規則。

以下是一個範例產生器,示範了這兩個方面:它傳回陣列的元組,其中第二個陣列是具有未知長度的向量。

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
0 : [-1.1977 -1.3099]
1 : [-0.3763  1.3639  0.0703  0.7888  0.3659 -0.0422 -0.5699 -1.4458]
2 : [ 0.4438  0.2206 -0.8348  3.0743 -0.2304  0.6876]
3 : [ 0.1138  0.3484 -0.3989  0.1871 -0.9462  0.7905  0.0224  0.204  -1.2715]
4 : [ 1.0292 -1.7965 -1.1569  0.437   1.9364  0.4718 -0.5036 -0.1318]
5 : [0.6893]
6 : [-0.2385 -0.3129  0.4913  0.2546  1.4849 -1.3109 -0.3785]

第一個輸出是 int32,第二個輸出是 float32

第一個項目是純量,形狀為 (),第二個項目是未知長度的向量,形狀為 (None,)

ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),
    output_shapes=((), (None,)))

ds_series
<_FlatMapDataset element_spec=(TensorSpec(shape=(), dtype=tf.int32, name=None), TensorSpec(shape=(None,), dtype=tf.float32, name=None))>

現在可以像常規 tf.data.Dataset 一樣使用它。請注意,在批次處理具有可變形狀的資料集時,您需要使用 Dataset.padded_batch

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
[ 4 18 19  2  8 22 13  0 25 27]

[[-0.5268  0.8465  1.8949 -0.6337 -0.9212  0.2917  0.1995 -0.2283  1.5621]
 [-0.7196  0.3447 -0.5744 -1.6807  1.9387 -0.7832  1.1232  0.5444  0.3566]
 [-1.0073  0.      0.      0.      0.      0.      0.      0.      0.    ]
 [ 1.3614 -0.0866  0.4309 -1.1438  0.066   0.3847 -0.8009  0.      0.    ]
 [-0.7528  0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.006   0.9022  1.2462  0.0703  0.      0.      0.      0.      0.    ]
 [ 0.5811  0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.1996  1.6923 -0.274  -0.7509 -0.6734 -1.687  -0.8438 -1.0904  0.    ]
 [ 0.3178  0.0775  1.3367  1.0921  0.1651  0.9298  0.0764 -0.4039  0.    ]
 [ 1.5668 -1.3154  0.8587 -0.7022  0.      0.      0.      0.      0.    ]]

如需更實際的範例,請嘗試將 preprocessing.image.ImageDataGenerator 包裝為 tf.data.Dataset

首先下載資料

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228813984/228813984 [==============================] - 1s 0us/step

建立 image.ImageDataGenerator

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
Found 3670 images belonging to 5 classes.
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
float32 (32, 256, 256, 3)
float32 (32, 5)
ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers),
    output_types=(tf.float32, tf.float32),
    output_shapes=([32,256,256,3], [32,5])
)

ds.element_spec
(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
for images, labels in ds.take(1):
  print('images.shape: ', images.shape)
  print('labels.shape: ', labels.shape)
Found 3670 images belonging to 5 classes.
images.shape:  (32, 256, 256, 3)
labels.shape:  (32, 5)

取用 TFRecord 資料

如需端對端範例,請參閱載入 TFRecords 教學課程。

tf.data API 支援各種檔案格式,以便您可以處理不適合放在記憶體中的大型資料集。例如,TFRecord 檔案格式是一種簡單的記錄導向二進位格式,許多 TensorFlow 應用程式使用它來處理訓練資料。tf.data.TFRecordDataset 類別可讓您串流處理一或多個 TFRecord 檔案的內容,作為輸入管線的一部分。

以下範例使用來自法國街道名稱標誌 (FSNS) 的測試檔案。

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7904079/7904079 [==============================] - 0s 0us/step

TFRecordDataset 初始化程式的 filenames 引數可以是字串、字串清單或字串的 tf.Tensor。因此,如果您有兩組檔案用於訓練和驗證目的,您可以建立一個工廠方法來產生資料集,並將檔案名稱作為輸入引數

dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

許多 TensorFlow 專案在其 TFRecord 檔案中使用序列化的 tf.train.Example 記錄。這些記錄需要先解碼才能檢查

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']
bytes_list {
  value: "Rue Perreyon"
}

取用文字資料

如需端對端範例,請參閱載入文字教學課程。

許多資料集以一或多個文字檔案的形式散佈。tf.data.TextLineDataset 提供了一種從一或多個文字檔案中擷取行的簡單方法。給定一或多個檔案名稱,TextLineDataset 將為這些檔案的每一行產生一個字串值元素。

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
815980/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
809730/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
807992/807992 [==============================] - 0s 0us/step
dataset = tf.data.TextLineDataset(file_paths)

以下是第一個檔案的前幾行

for line in dataset.take(5):
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

若要在檔案之間交替行,請使用 Dataset.interleave。這使得更易於將檔案混洗在一起。以下是每個翻譯的第一行、第二行和第三行

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

預設情況下,TextLineDataset 會產生每個檔案的每一行,這可能不是理想的,例如,如果檔案以標頭行開頭,或包含註解。可以使用 Dataset.skip()Dataset.filter 轉換來移除這些行。在這裡,您跳過第一行,然後篩選以僅尋找倖存者。

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
30874/30874 [==============================] - 0s 0us/step
for line in titanic_lines.take(10):
  print(line.numpy())
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

取用 CSV 資料

如需更多範例,請參閱載入 CSV 檔案載入 Pandas DataFrame 教學課程。

CSV 檔案格式是以純文字儲存表格資料的常用格式。

例如

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file)
df.head()

如果您的資料適合放在記憶體中,則相同的 Dataset.from_tensor_slices 方法適用於字典,允許輕鬆匯入此資料

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

更具延展性的方法是根據需要從磁碟載入。

tf.data 模組提供從符合 RFC 4180 的一或多個 CSV 檔案中擷取記錄的方法。

tf.data.experimental.make_csv_dataset 函式是讀取 CSV 檔案集的高階介面。它支援欄類型推斷和許多其他功能,例如批次處理和混洗,使使用變得簡單。

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [0 1 0 0]
features:
  'sex'               : [b'male' b'male' b'male' b'female']
  'age'               : [28. 25. 30. 28.]
  'n_siblings_spouses': [1 1 0 3]
  'parch'             : [0 0 0 1]
  'fare'              : [15.5     7.775  27.75   25.4667]
  'class'             : [b'Third' b'Third' b'First' b'Third']
  'deck'              : [b'unknown' b'unknown' b'C' b'unknown']
  'embark_town'       : [b'Queenstown' b'Southampton' b'Cherbourg' b'Southampton']
  'alone'             : [b'n' b'n' b'y' b'n']

如果您只需要欄的子集,可以使用 select_columns 引數。

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [1 0 1 0]
  'fare'              : [11.1333 24.15    7.925  52.    ]
  'class'             : [b'Third' b'Third' b'Third' b'First']

還有一個較低階的 experimental.CsvDataset 類別,它提供更精細的控制。它不支援欄類型推斷。相反,您必須指定每欄的類型。

titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

如果某些欄為空,則此低階介面允許您提供預設值而不是欄類型。

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
Writing missing.csv
# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<_MapDataset element_spec=TensorSpec(shape=(4,), dtype=tf.int32, name=None)>
for line in dataset:
  print(line.numpy())
[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

預設情況下,CsvDataset 會產生檔案行的一欄,這可能不是理想的,例如,如果檔案以應忽略的標頭行開頭,或者如果輸入中不需要某些欄。可以使用 headerselect_cols 引數分別移除這些行和欄位。

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<_MapDataset element_spec=TensorSpec(shape=(2,), dtype=tf.int32, name=None)>
for line in dataset:
  print(line.numpy())
[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

取用檔案集

有許多資料集以檔案集的形式散佈,其中每個檔案都是一個範例。

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

根目錄包含每個類別的目錄

for item in flowers_root.glob("*"):
  print(item.name)
roses
sunflowers
LICENSE.txt
dandelion
tulips
daisy

每個類別目錄中的檔案都是範例

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/4520577328_a94c11e806_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/22679060358_561ec823ae_m.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/daisy/6978826370_7b9aa7c7d5.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/3447650747_8299786b80_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/112951022_4892b1348b_n.jpg'

使用 tf.io.read_file 函式讀取資料,並從路徑中擷取標籤,傳回 (image, label) 配對

def process_path(file_path):
  label = tf.strings.split(file_path, os.sep)[-2]
  return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xe1\t\xcbXMP\x00://ns.adobe.com/xap/1.0/\x00<?xpacket begin='\xef\xbb\xbf' id='W5M0MpCehiHzreSzNTczk"

b'roses'

批次處理資料集元素

簡單批次處理

最簡單的批次處理形式是將資料集的 n 個連續元素堆疊到單一元素中。Dataset.batch() 轉換完全執行此操作,其約束與 tf.stack() 運算子相同,套用於元素的每個元件:即,對於每個元件i,所有元素都必須具有形狀完全相同的張量。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

雖然 tf.data 嘗試傳播形狀資訊,但 Dataset.batch 的預設設定會導致未知的批次大小,因為最後一批次可能未滿。請注意形狀中的 None

batched_dataset
<_BatchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

使用 drop_remainder 引數來忽略最後一批次,並取得完整的形狀傳播

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset
<_BatchDataset element_spec=(TensorSpec(shape=(7,), dtype=tf.int64, name=None), TensorSpec(shape=(7,), dtype=tf.int64, name=None))>

使用填補批次處理張量

上述方法適用於所有大小相同的張量。但是,許多模型 (包括序列模型) 使用可能具有不同大小的輸入資料 (例如,不同長度的序列)。為了處理這種情況,Dataset.padded_batch 轉換可讓您透過指定一或多個可以填補的維度來批次處理不同形狀的張量。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
  print(batch.numpy())
  print()
[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]

Dataset.padded_batch 轉換可讓您為每個元件的每個維度設定不同的填補,並且可以是可變長度 (以上範例中以 None 表示) 或固定長度。也可以覆寫填補值,預設值為 0。

訓練工作流程

處理多個週期

tf.data API 提供兩種主要方式來處理相同資料的多個週期。

在多個週期中迭代資料集的最簡單方法是使用 Dataset.repeat() 轉換。首先,建立 titanic 資料的資料集

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')

套用不帶引數的 Dataset.repeat() 轉換將無限期地重複輸入。

Dataset.repeat 轉換會串連其引數,而不發出一個週期結束和下一個週期開始的訊號。因此,在 Dataset.repeat 之後套用的 Dataset.batch 將產生跨越週期邊界的批次

titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

png

如果您需要明確的週期分隔,請將 Dataset.batch 放在 repeat 之前

titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)

png

如果您想要在每個 epoch 結束時執行自訂計算 (例如,收集統計資料),那麼最簡單的方法是在每個 epoch 重新啟動資料集迭代。

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  2

隨機洗牌輸入資料

The Dataset.shuffle() 轉換會維護固定大小的緩衝區,並從該緩衝區中均勻隨機選擇下一個元素。

在資料集中新增索引,以便您可以查看效果

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_14491/4092668703.py:2: CounterV2 (from tensorflow.python.data.experimental.ops.counter) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.
<_BatchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None))>

由於 buffer_size 為 100,且批次大小為 20,因此第一個批次不包含索引超過 120 的元素。

n,line_batch = next(iter(dataset))
print(n.numpy())
[ 43  13  17  72  33  89  83 105  96  81   0  67  97  84  73  32  30  71
  64 103]

Dataset.batch 一樣,相對於 Dataset.repeat 的順序很重要。

Dataset.shuffle 在洗牌緩衝區為空之前不會發出 epoch 結束的訊號。因此,在 repeat 之前放置 shuffle 將顯示一個 epoch 的每個元素,然後再移至下一個 epoch

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[625 583 586 606 504 597 453 615 575 429]
[424 456 452 605 483 566 395 556 492 365]
[570 573 611 540 545 559 388 579]
[  0  18  92  79  81  86  62 103  29  82]
[47 69 17 95  9 11 77 84 31 53]
shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f32b0355dc0>

png

但是在 shuffle 之前的 repeat 會將 epoch 邊界混合在一起

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[464 377 595 328 614 504   7  20 433 623]
[419 550 615 509 499 540 557 622 570 544]
[588 618 616   4 460  38  16 617  31 591]
[440 610 585 600  36  17  35  52 592  19]
[523  59 545 624 607  51  53  26  33 318]
[510  37   6 448 612 469  32  10  39 594]
[ 41  63  13 627  67  76 386 579 412  55]
[  1  54 626  71  64  22  47 553 525  65]
[ 69   3  15 102  14 455  23  98  74  78]
[596  12  50   5  18 112 114  97  61  42]
[103  84 583  90 350 575 606  85 107 108]
[115 127  60 602 118  43  34  58  46 587]
[119  56 620  75 564 625  88 140 539  45]
[589 100 149 452 110  11  66 132 142 111]
[101 334  94 497 520 158 120  86 135  95]
repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f32b033bdf0>

png

預處理資料

The Dataset.map(f) 轉換透過將給定的函數 f 應用於輸入資料集的每個元素,來產生新的資料集。它基於 map() 函數,該函數通常應用於函數式程式語言中的列表(和其他結構)。函數 f 接受代表輸入中單個元素的 tf.Tensor 物件,並傳回將代表新資料集中單個元素的 tf.Tensor 物件。它的實作使用標準 TensorFlow 運算將一個元素轉換為另一個元素。

本節涵蓋如何使用 Dataset.map() 的常見範例。

解碼影像資料並調整其大小

當在真實世界的影像資料上訓練神經網路時,通常需要將不同大小的影像轉換為通用大小,以便可以將它們分批處理為固定大小。

重建花卉檔案名稱資料集

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

撰寫一個操作資料集元素的函數。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(filename, os.sep)
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.io.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label

測試它是否運作。

file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')

show(image, label)

png

將其對應到資料集。

images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):
  show(image, label)

png

png

應用任意 Python 邏輯

為了效能考量,請盡可能使用 TensorFlow 運算來預處理您的資料。但是,有時在剖析您的輸入資料時,呼叫外部 Python 程式庫會很有用。您可以在 tf.py_function 轉換中使用 Dataset.map 運算。

例如,如果您想要套用隨機旋轉,tf.image 模組只有 tf.image.rot90,對於影像增強來說不是很有用。

為了示範 tf.py_function,請嘗試改用 scipy.ndimage.rotate 函數

import scipy.ndimage as ndimage

@tf.py_function(Tout=tf.float32)
def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

若要將此函數與 Dataset.map 搭配使用,與 Dataset.from_generator 適用相同的注意事項,您需要在套用函數時描述傳回形狀和類型

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  image = random_rotate_image(image)
  image.set_shape(im_shape)
  return image, label
rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
  show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

png

剖析 tf.Example 協定緩衝區訊息

許多輸入管線從 TFRecord 格式擷取 tf.train.Example 協定緩衝區訊息。每個 tf.train.Example 記錄包含一個或多個「特徵」,而輸入管線通常會將這些特徵轉換為張量。

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

您可以在 tf.data.Dataset 之外使用 tf.train.Example 原型,以了解資料

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

png

raw_example = next(iter(dataset))
def tf_parse(eg):
  example = tf.io.parse_example(
      eg[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...
decoded = dataset.map(tf_parse)
decoded
<_MapDataset element_spec=(TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.string, name=None))>
image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
TensorShape([10])

時間序列視窗化

如需端對端時間序列範例,請參閱:時間序列預測

時間序列資料通常在時間軸完整的情況下組織。

使用簡單的 Dataset.range 進行示範

range_ds = tf.data.Dataset.range(100000)

通常,基於此類資料的模型會需要連續的時間片段。

最簡單的方法是將資料分批處理

使用 batch

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
  print(batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

或者,為了對未來進行密集預測,您可以將特徵和標籤彼此相對移動一步

def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8]  =>  [1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18]  =>  [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28]  =>  [21 22 23 24 25 26 27 28 29]

若要預測整個視窗而不是固定偏移量,您可以將批次分割成兩部分

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
  return (batch[:-5],   # Inputs: All except the last 5 steps
          batch[-5:])   # Labels: The last 5 steps

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]  =>  [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42 43 44]

若要允許一個批次的特徵與另一個批次的標籤之間有一些重疊,請使用 Dataset.zip

feature_length = 10
label_length = 3

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:label_length])

predicted_steps = tf.data.Dataset.zip((features, labels))

for features, label in predicted_steps.take(5):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12]
[10 11 12 13 14 15 16 17 18 19]  =>  [20 21 22]
[20 21 22 23 24 25 26 27 28 29]  =>  [30 31 32]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42]
[40 41 42 43 44 45 46 47 48 49]  =>  [50 51 52]

使用 window

雖然使用 Dataset.batch 可以運作,但在某些情況下,您可能需要更精細的控制。Dataset.window 方法可讓您完全控制,但需要小心:它會傳回 DatasetsDataset前往「資料集結構」章節以取得詳細資訊。

window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>

The Dataset.flat_map 方法可以取得資料集的資料集,並將其展平為單一資料集

for x in windows.flat_map(lambda x: x).take(30):
   print(x.numpy(), end=' ')
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9

在幾乎所有情況下,您都會想要先 Dataset.batch 資料集

def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

現在,您可以看到 shift 引數控制每個視窗移動的距離。

將這些放在一起,您可能會撰寫此函數

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows
ds = make_window_dataset(range_ds, window_size=10, shift = 5, stride=3)

for example in ds.take(10):
  print(example.numpy())
[ 0  3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

然後,就像之前一樣,很容易擷取標籤

dense_labels_ds = ds.map(dense_1_step)

for inputs,labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
[ 0  3  6  9 12 15 18 21 24] => [ 3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

重新取樣

當處理類別非常不平衡的資料集時,您可能會想要重新取樣資料集。tf.data 提供了兩種方法來執行此操作。信用卡詐欺資料集是這類問題的一個很好的例子。

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69155632/69155632 [==============================] - 0s 0us/step
creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

現在,檢查類別的分配,它是高度偏斜的

def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)
[0.9953 0.0047]

使用不平衡資料集進行訓練的常見方法是平衡它。tf.data 包含幾種啟用此工作流程的方法

資料集取樣

重新取樣資料集的一種方法是使用 sample_from_datasets。當您為每個類別都有單獨的 tf.data.Dataset 時,這更適用。

在此,僅使用篩選器從信用卡詐欺資料中產生它們

negative_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())
positive_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())
for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
[1 1 1 1 1 1 1 1 1 1]

若要使用 tf.data.Dataset.sample_from_datasets,請傳遞資料集和每個資料集的權重

balanced_ds = tf.data.Dataset.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

現在,資料集以 50/50 的機率產生每個類別的範例

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[0 0 0 1 0 0 1 0 0 0]
[1 1 0 0 1 1 1 0 0 1]
[1 0 0 0 1 1 0 0 0 0]
[1 0 0 0 0 1 0 1 0 0]
[1 1 1 1 0 1 0 0 1 0]
[0 1 1 0 0 0 1 1 0 1]
[0 1 1 0 1 1 0 1 1 1]
[0 0 0 1 0 0 1 1 0 0]
[0 1 0 0 1 1 1 0 1 0]
[1 1 1 1 1 1 0 0 0 0]

拒絕重新取樣

上述 Dataset.sample_from_datasets 方法的一個問題是,它需要每個類別都有單獨的 tf.data.Dataset。您可以使用 Dataset.filter 來建立這兩個資料集,但這會導致所有資料載入兩次。

The tf.data.Dataset.rejection_resample 方法可以應用於資料集以重新平衡它,同時僅載入一次。元素將被捨棄或重複以實現平衡。

The rejection_resample 方法採用 class_func 引數。此 class_func 會應用於每個資料集元素,並用於判斷範例屬於哪個類別,以達到平衡的目的。

此處的目標是平衡標籤分佈,而 creditcard_ds 的元素已經是 (features, label) 配對。因此,class_func 只需要傳回這些標籤

def class_func(features, label):
  return label

重新取樣方法處理個別範例,因此在這種情況下,您必須在套用該方法之前 unbatch 資料集。

此方法需要目標分佈,以及選擇性的初始分佈估計作為輸入。

resample_ds = (
    creditcard_ds
    .unbatch()
    .rejection_resample(class_func, target_dist=[0.5,0.5],
                        initial_dist=fractions)
    .batch(10))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:4963: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:

The rejection_resample 方法傳回 (class, example) 配對,其中 classclass_func 的輸出。在這種情況下,example 已經是 (feature, label) 配對,因此使用 map 來捨棄標籤的額外副本

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

現在,資料集以 50/50 的機率產生每個類別的範例

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
[0 1 0 1 0 0 1 1 0 1]
[1 1 0 1 0 0 1 1 0 0]
[0 0 1 0 1 0 0 0 0 1]
[1 0 0 1 0 0 1 0 1 0]
[0 1 1 0 0 1 0 1 1 0]
[1 1 0 1 1 0 1 1 0 1]
[1 1 1 0 0 0 0 0 0 1]
[0 1 0 0 0 0 0 1 1 1]
[0 1 0 1 1 0 0 1 0 1]
[0 0 1 0 1 1 1 0 0 1]

迭代器檢查點

Tensorflow 支援建立檢查點,以便在您的訓練程序重新啟動時,它可以還原最新的檢查點以恢復其大部分進度。除了檢查模型變數之外,您還可以檢查資料集迭代器的進度。如果您有一個大型資料集,並且不希望在每次重新啟動時都從頭開始資料集,這可能會很有用。但是請注意,迭代器檢查點可能很大,因為 Dataset.shuffleDataset.prefetch 等轉換需要在迭代器內緩衝元素。

若要將您的迭代器包含在檢查點中,請將迭代器傳遞至 tf.train.Checkpoint 建構函式。

range_ds = tf.data.Dataset.range(20)

iterator = iter(range_ds)
ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, '/tmp/my_ckpt', max_to_keep=3)

print([next(iterator).numpy() for _ in range(5)])

save_path = manager.save()

print([next(iterator).numpy() for _ in range(5)])

ckpt.restore(manager.latest_checkpoint)

print([next(iterator).numpy() for _ in range(5)])
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[5, 6, 7, 8, 9]

tf.datatf.keras 搭配使用

The tf.keras API 簡化了建立和執行機器學習模型的許多方面。其 Model.fitModel.evaluate 以及 Model.predict API 支援資料集作為輸入。以下是快速資料集和模型設定

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

傳遞 (feature, label) 配對的資料集是 Model.fitModel.evaluate 所需的全部內容

model.fit(fmnist_train_ds, epochs=2)
Epoch 1/2
  26/1875 [..............................] - ETA: 3s - loss: 1.7645 - accuracy: 0.3930
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705458306.961075   14743 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5948 - accuracy: 0.7994
Epoch 2/2
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4601 - accuracy: 0.8427
<keras.src.callbacks.History at 0x7f32b0722250>

如果您傳遞無限資料集,例如透過呼叫 Dataset.repeat,您只需要同時傳遞 steps_per_epoch 引數

model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4791 - accuracy: 0.8344
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4777 - accuracy: 0.8438
<keras.src.callbacks.History at 0x7f32b0457100>

為了評估,您可以傳遞評估步驟的數量

loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4396 - accuracy: 0.8490
Loss : 0.4396059811115265
Accuracy : 0.8489833474159241

對於大型資料集,請設定要評估的步驟數

loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)
10/10 [==============================] - 0s 2ms/step - loss: 0.4079 - accuracy: 0.8844
Loss : 0.4079427719116211
Accuracy : 0.8843749761581421

呼叫 Model.predict 時不需要標籤。

predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
10/10 [==============================] - 0s 1ms/step
(320, 10)

但是,如果您確實傳遞包含標籤的資料集,則會忽略標籤

result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
10/10 [==============================] - 0s 1ms/step
(320, 10)