![]() |
![]() |
![]() |
![]() |
TFRecord 格式是一種用於儲存二進位記錄序列的簡單格式。
Protocol Buffer 是一個跨平台、跨語言的程式庫,可有效率地序列化結構化資料。
Protocol 訊息是由 .proto
檔案定義,這些檔案通常是瞭解訊息類型最簡單的方式。
tf.train.Example
訊息 (或 protobuf) 是一種彈性的訊息類型,代表 {"字串": 值}
對應。它專為搭配 TensorFlow 使用而設計,並廣泛用於較高層級的 API (例如 TFX)。
這個筆記本示範如何建立、剖析及使用 tf.train.Example
訊息,然後將 tf.train.Example
訊息序列化、寫入及讀取到 .tfrecord
檔案,以及從 .tfrecord
檔案讀取 tf.train.Example
訊息。
設定
import tensorflow as tf
import numpy as np
import IPython.display as display
2024-07-13 05:37:29.355021: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-07-13 05:37:29.381246: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-07-13 05:37:29.381281: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
tf.train.Example
tf.train.Example 的資料類型
從根本上來說,tf.train.Example
是一種 {"字串": tf.train.Feature}
對應。
tf.train.Feature
訊息類型可以接受下列三種類型之一 (請參閱 .proto
檔案以供參考)。大多數其他通用類型都可以強制轉換為其中一種
tf.train.BytesList
(可以強制轉換下列類型)字串
位元組
tf.train.FloatList
(可以強制轉換下列類型)float
(float32
)double
(float64
)
tf.train.Int64List
(可以強制轉換下列類型)布林值
列舉
int32
uint32
int64
uint64
為了將標準 TensorFlow 類型轉換為與 tf.train.Example
相容的 tf.train.Feature
,您可以使用下方的捷徑函式。請注意,每個函式都會接收純量輸入值,並傳回包含上述三種清單類型之一的 tf.train.Feature
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
以下是一些說明這些函式運作方式的範例。請注意,輸入類型各不相同,但輸出類型已標準化。如果函式的輸入類型與上述任何一種可強制轉換的類型不符,函式就會擲回例外狀況 (例如,_int64_feature(1.0)
會傳回錯誤,因為 1.0
是浮點數,因此應該改用 _float_feature
函式)
print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))
bytes_list { value: "test_string" } bytes_list { value: "test_bytes" } float_list { value: 2.7182817459106445 } int64_list { value: 1 } int64_list { value: 1 }
所有 Proto 訊息都可以使用 .SerializeToString
方法序列化為二進位字串
feature = _float_feature(np.exp(1))
feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'
建立 tf.train.Example 訊息
假設您想要從現有資料建立 tf.train.Example
訊息。實際上,資料集可能來自任何地方,但從單一觀察值建立 tf.train.Example
訊息的程序會是相同的
在每個觀察值中,每個值都需要轉換為包含 3 種相容類型之一的
tf.train.Feature
,方法是使用上述函式之一。您可以從特徵名稱字串建立對應 (字典) 到 #1 中產生的編碼特徵值。
步驟 2 中產生的對應會轉換為
Features
訊息。
在這個筆記本中,您將使用 NumPy 建立資料集。
這個資料集將有 4 個特徵
- 布林值特徵,False 或 True 的機率相等
- 從 [0, 5] 均勻隨機選擇的整數特徵
- 從字串表格產生,並使用整數特徵作為索引的字串特徵
- 來自標準常態分佈的浮點數特徵
考量一個範例,其中包含來自上述每個分佈的 10,000 個獨立且完全相同的分佈觀察值
# The number of observations in the dataset.
n_observations = int(1e4)
# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)
# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)
# String feature.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
# Float feature, from a standard normal distribution.
feature3 = np.random.randn(n_observations)
這些特徵都可以使用 _bytes_feature
、_float_feature
、_int64_feature
其中之一強制轉換為與 tf.train.Example
相容的類型。然後,您可以從這些編碼特徵建立 tf.train.Example
訊息
@tf.py_function(Tout=tf.string)
def serialize_example(feature0, feature1, feature2, feature3):
"""
Creates a tf.train.Example message ready to be written to a file.
"""
# Create a dictionary mapping the feature name to the tf.train.Example-compatible
# data type.
feature = {
'feature0': _int64_feature(feature0),
'feature1': _int64_feature(feature1),
'feature2': _bytes_feature(feature2),
'feature3': _float_feature(feature3),
}
# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
例如,假設您有一個來自資料集的單一觀察值 [False, 4, bytes('goat'), 0.9876]
。您可以使用 serialize_example()
為這個觀察值建立及列印 tf.train.Example
訊息。每個單一觀察值都會依照上述內容寫入為 Features
訊息。請注意,tf.train.Example
訊息只是 Features
訊息的包裝函式
# This is an example observation from the dataset.
example_observation = [False, 4, b'goat', 0.9876]
serialized_example = serialize_example(*example_observation)
serialized_example
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'>
若要解碼訊息,請使用 tf.train.Example.FromString
方法。
example_proto = tf.train.Example.FromString(serialized_example.numpy())
example_proto
features { feature { key: "feature0" value { int64_list { value: 0 } } } feature { key: "feature1" value { int64_list { value: 4 } } } feature { key: "feature2" value { bytes_list { value: "goat" } } } feature { key: "feature3" value { float_list { value: 0.9876000285148621 } } } }
TFRecord 格式詳細資訊
TFRecord 檔案包含記錄序列。檔案只能循序讀取。
每筆記錄都包含一個位元組字串 (用於資料酬載)、資料長度以及 CRC-32C (使用 Castagnoli 多項式的 32 位元 CRC) 雜湊,以進行完整性檢查。
每筆記錄都以以下格式儲存
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
記錄會串連在一起以產生檔案。CRC 在此處說明,而 CRC 的遮罩為
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
讀取和寫入 TFRecord 檔案
tf.io
模組也包含用於讀取和寫入 TFRecord 檔案的純 Python 函式。
寫入 TFRecord 檔案
接下來,將 10,000 個觀察值寫入檔案 test.tfrecord
。每個觀察值都會轉換為 tf.train.Example
訊息,然後寫入檔案。接著,您可以確認已建立檔案 test.tfrecord
filename = 'test.tfrecord'
# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
for i in range(n_observations):
example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
writer.write(example.numpy())
du -sh {filename}
984K test.tfrecord
在 Python 中讀取 TFRecord 檔案
這些序列化張量可以使用 tf.train.Example.ParseFromString
輕鬆剖析
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
for raw_record in raw_dataset.take(1):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
print(example)
features { feature { key: "feature0" value { int64_list { value: 1 } } } feature { key: "feature1" value { int64_list { value: 1 } } } feature { key: "feature2" value { bytes_list { value: "dog" } } } feature { key: "feature3" value { float_list { value: 1.7843105792999268 } } } } 2024-07-13 05:37:48.209959: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
這會傳回 tf.train.Example
Proto,雖然不容易直接使用,但基本上是下列項目的表示法
Dict[str,
Union[List[float],
List[int],
List[str]]]
以下程式碼會手動將 Example 轉換為 NumPy 陣列的字典,而未使用 TensorFlow Ops。如需詳細資訊,請參閱 PROTO 檔案。
result = {}
# example.features.feature is the dictionary
for key, feature in example.features.feature.items():
# The values are the Feature objects which contain a `kind` which contains:
# one of three fields: bytes_list, float_list, int64_list
kind = feature.WhichOneof('kind')
result[key] = np.array(getattr(feature, kind).value)
result
{'feature3': array([1.78431058]), 'feature2': array([b'dog'], dtype='|S3'), 'feature1': array([1]), 'feature0': array([1])}
使用 tf.data 讀取 TFRecord 檔案
您也可以使用 tf.data.TFRecordDataset
類別讀取 TFRecord 檔案。
如需使用 tf.data
取用 TFRecord 檔案的詳細資訊,請參閱「tf.data:建構 TensorFlow 輸入管線」指南。
使用 TFRecordDataset
對於標準化輸入資料和最佳化效能非常實用。
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
此時,資料集包含序列化的 tf.train.Example
訊息。在反覆運算時,會將這些訊息傳回為純量字串張量。
使用 .take
方法僅顯示前 10 筆記錄。
for raw_record in raw_dataset.take(10):
print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04Jd\xe4?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xa0\xdap\xbd\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04S\x92\x9f=\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xc5\xe2\xf9>\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x9a\x81\xc1\xbf\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x9c?\x16\xbf'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xfe\xc6\xb3\xbd\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x0c\xaa<>'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04KE6\xbf\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x03\x1a\x14\xc0'> 2024-07-13 05:37:48.263700: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
這些張量可以使用下方的函式剖析。請注意,feature_description
在此處是必要的,因為 tf.data.Dataset
使用圖形執行,而且需要這個描述才能建構其形狀和類型簽名
# Create a description of the features.
feature_description = {
'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
def _parse_function(example_proto):
# Parse the input `tf.train.Example` proto using the dictionary above.
return tf.io.parse_single_example(example_proto, feature_description)
或者,使用 tf.parse_example
一次剖析整個批次。使用 tf.data.Dataset.map
方法將此函式套用至資料集中的每個項目
parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<_MapDataset element_spec={'feature0': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature1': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature2': TensorSpec(shape=(), dtype=tf.string, name=None), 'feature3': TensorSpec(shape=(), dtype=tf.float32, name=None)}>
使用立即執行來顯示資料集中的觀察值。這個資料集中有 10,000 個觀察值,但您只會顯示前 10 個。資料會顯示為特徵字典。每個項目都是 tf.Tensor
,而這個張量的 numpy
元素會顯示特徵的值
for parsed_record in parsed_dataset.take(10):
print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.7843106>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.058802247>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.07791581>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.48805824>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.5117676>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.5869081>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.08778189>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.18424243>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.7119948>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-2.3140876>} 2024-07-13 05:37:48.363183: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
在這裡,tf.parse_example
函式會將 tf.train.Example
欄位解壓縮到標準張量中。
逐步解說:讀取和寫入圖片資料
這是一個端對端範例,說明如何使用 TFRecord 讀取和寫入圖片資料。使用圖片作為輸入資料,您會將資料寫入為 TFRecord 檔案,然後將檔案讀取回來並顯示圖片。
舉例來說,如果您想在同一個輸入資料集上使用多個模型,這會很有用。圖片資料可以預先處理為 TFRecord 格式,而不是以原始格式儲存,這樣就可以在所有後續處理和建模中使用。
首先,讓我們下載這張雪地裡的貓咪圖片和這張紐約市威廉斯堡大橋在建工程的照片。
擷取圖片
cat_in_snow = tf.keras.utils.get_file(
'320px-Felis_catus-cat_on_snow.jpg',
'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file(
'194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg 17858/17858 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg 15477/15477 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))
寫入 TFRecord 檔案
和之前一樣,將特徵編碼為與 tf.train.Example
相容的類型。這會儲存原始圖片字串特徵,以及高度、寬度、深度和任意 label
特徵。後者用於在您寫入檔案時區分貓咪圖片和大橋圖片。貓咪圖片使用 0
,大橋圖片使用 1
image_labels = {
cat_in_snow : 0,
williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()
label = image_labels[cat_in_snow]
# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
image_shape = tf.io.decode_jpeg(image_string).shape
feature = {
'height': _int64_feature(image_shape[0]),
'width': _int64_feature(image_shape[1]),
'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(label),
'image_raw': _bytes_feature(image_string),
}
return tf.train.Example(features=tf.train.Features(feature=feature))
for line in str(image_example(image_string, label)).split('\n')[:15]:
print(line)
print('...')
features { feature { key: "depth" value { int64_list { value: 3 } } } feature { key: "height" value { int64_list { value: 213 } ...
請注意,所有特徵現在都儲存在 tf.train.Example
訊息中。接下來,將上述程式碼函數化,並將範例訊息寫入名為 images.tfrecords
的檔案
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
for filename, label in image_labels.items():
image_string = open(filename, 'rb').read()
tf_example = image_example(image_string, label)
writer.write(tf_example.SerializeToString())
du -sh {record_file}
36K images.tfrecords
讀取 TFRecord 檔案
您現在有了檔案—images.tfrecords
—而且現在可以反覆運算其中的記錄,以讀回您寫入的內容。由於在這個範例中,您只會重製圖片,因此您唯一需要的特徵是原始圖片字串。使用上述 getter 擷取它,即 example.features.feature['image_raw'].bytes_list.value[0]
。您也可以使用標籤來判斷哪個記錄是貓咪,哪個是大橋
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')
# Create a dictionary describing the features.
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'label': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
def _parse_image_function(example_proto):
# Parse the input tf.train.Example proto using the dictionary above.
return tf.io.parse_single_example(example_proto, image_feature_description)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<_MapDataset element_spec={'depth': TensorSpec(shape=(), dtype=tf.int64, name=None), 'height': TensorSpec(shape=(), dtype=tf.int64, name=None), 'image_raw': TensorSpec(shape=(), dtype=tf.string, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'width': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
從 TFRecord 檔案復原圖片
for image_features in parsed_image_dataset:
image_raw = image_features['image_raw'].numpy()
display.display(display.Image(data=image_raw))
2024-07-13 05:37:48.876637: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence