![]() |
![]() |
![]() |
![]() |
設定
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
擴充類型
使用者定義的類型可以讓專案更易讀、模組化、可維護。但是,大多數 TensorFlow API 對於使用者定義的 Python 類型的支援非常有限。這包括高階 API (例如 Keras、tf.function
、tf.SavedModel
) 和低階 API (例如 tf.while_loop
和 tf.concat
)。TensorFlow 擴充類型可用於建立使用者定義的物件導向類型,這些類型可以與 TensorFlow 的 API 無縫協作。若要建立擴充類型,只需定義一個以 tf.experimental.ExtensionType
作為其基底的 Python 類別,並使用類型註解來指定每個欄位的類型。
class TensorGraph(tf.experimental.ExtensionType):
"""A collection of labeled nodes connected by weighted edges."""
edge_weights: tf.Tensor # shape=[num_nodes, num_nodes]
node_labels: Mapping[str, tf.Tensor] # shape=[num_nodes]; dtype=any
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for missing/invalid values.
class CSRSparseMatrix(tf.experimental.ExtensionType):
"""Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
values: tf.Tensor # shape=[num_nonzero]; dtype=any
col_index: tf.Tensor # shape=[num_nonzero]; dtype=int64
row_index: tf.Tensor # shape=[num_rows+1]; dtype=int64
tf.experimental.ExtensionType
基底類別的功能與標準 Python 程式庫中的 typing.NamedTuple
和 @dataclasses.dataclass
類似。特別是,它會根據欄位類型註解自動新增建構函式和特殊方法 (例如 __repr__
和 __eq__
)。
通常,擴充類型傾向於分為以下兩類之一
資料結構,它將相關值的集合組合在一起,並且可以根據這些值提供有用的運算。資料結構可能相當通用 (例如上面的
TensorGraph
範例);或者它們可以高度自訂以適用於特定模型。類似張量的類型,它們專門化或擴充「張量」的概念。此類別中的類型具有
rank
、shape
,通常還具有dtype
;並且將它們與張量運算 (例如tf.stack
、tf.add
或tf.matmul
) 一起使用是有意義的。MaskedTensor
和CSRSparseMatrix
是類似張量的類型的範例。
支援的 API
以下 TensorFlow API 支援擴充類型
- Keras:擴充類型可以用作 Keras
Models
和Layers
的輸入和輸出。 tf.data.Dataset
:擴充類型可以包含在Datasets
中,並由資料集Iterators
傳回。- TensorFlow Hub:擴充類型可以用作
tf.hub
模組的輸入和輸出。 - SavedModel:擴充類型可以用作
SavedModel
函式的輸入和輸出。 tf.function
:擴充類型可以用作以@tf.function
裝飾器包裝之函式的引數和傳回值。- While 迴圈:擴充類型可以用作
tf.while_loop
中的迴圈變數,並且可以用作 while 迴圈主體的引數和傳回值。 - 條件式:擴充類型可以使用
tf.cond
和tf.case
進行條件式選取。 tf.py_function
:擴充類型可以用作tf.py_function
的func
引數的引數和傳回值。- 張量運算:擴充類型可以擴充以支援大多數接受張量輸入的 TensorFlow 運算 (例如
tf.matmul
、tf.gather
和tf.reduce_sum
)。前往下方的「分派」章節以取得更多資訊。 - 分佈策略:擴充類型可以用作每個副本的值。
如需更多詳細資訊,請參閱下方關於「支援 ExtensionType 的 TensorFlow API」的章節。
需求
欄位類型
所有欄位 (執行個體變數) 都必須宣告,並且必須為每個欄位提供類型註解。支援以下類型註解
類型 | 範例 |
---|---|
Python 整數 | i: int |
Python 浮點數 | f: float |
Python 字串 | s: str |
Python 布林值 | b: bool |
Python None |
n: None |
張量形狀 | shape: tf.TensorShape |
張量 dtype |
dtype: tf.DType |
張量 | 張量 |
擴充類型 | mt: MyMaskedTensor |
參差不齊的張量 | rt: tf.RaggedTensor |
稀疏張量 | st: tf.SparseTensor |
索引切片 | s: tf.IndexedSlices |
選用張量 | o: tf.experimental.Optional |
類型聯集 | int_or_float: typing.Union[int, float] |
元組 | params: typing.Tuple[int, float, tf.Tensor, int] |
變長元組 | lengths: typing.Tuple[int, ...] |
對應 | tags: typing.Mapping[str, tf.Tensor] |
選用值 | weight: typing.Optional[tf.Tensor] |
可變性
擴充類型必須是不可變的。這確保它們可以由 TensorFlow 的圖追蹤機制正確追蹤。如果您發現自己想要變更擴充類型的值,請考慮改為定義轉換值的方法。例如,您可以定義一個 replace_mask
方法來傳回新的 MaskedTensor
,而不是定義 set_mask
方法來變更 MaskedTensor
。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def replace_mask(self, new_mask):
self.values.shape.assert_is_compatible_with(new_mask.shape)
return MaskedTensor(self.values, new_mask)
ExtensionType
新增的功能
ExtensionType
基底類別提供以下功能
- 建構函式 (
__init__
)。 - 可列印的表示方法 (
__repr__
)。 - 相等和不相等運算子 (
__eq__
)。 - 驗證方法 (
__validate__
)。 - 強制不可變性。
- 巢狀
TypeSpec
。 - 張量 API 分派支援。
前往下方的「自訂 ExtensionType
」章節以取得有關自訂此功能的更多資訊。
建構函式
ExtensionType
新增的建構函式將每個欄位作為具名引數 (依其在類別定義中列出的順序)。此建構函式將對每個參數進行類型檢查,並在必要時轉換它們。特別是,Tensor
欄位使用 tf.convert_to_tensor
轉換;Tuple
欄位轉換為 tuple
;並且 Mapping
欄位轉換為不可變的 dict。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)
如果欄位值無法轉換為其宣告的類型,則建構函式會引發 TypeError
try:
MaskedTensor([1, 2, 3], None)
except TypeError as e:
print(f"Got expected TypeError: {e}")
可以透過在類別層級設定欄位的預設值來指定該值
class Pencil(tf.experimental.ExtensionType):
color: str = "black"
has_erasor: bool = True
length: tf.Tensor = 1.0
Pencil()
Pencil(length=0.5, color="blue")
可列印的表示
ExtensionType
新增預設的可列印表示方法 (__repr__
),其中包括類別名稱和每個欄位的值
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
相等運算子
ExtensionType
新增預設的相等運算子 (__eq__
和 __ne__
),如果兩個值具有相同的類型且其所有欄位都相等,則認為它們相等。如果張量欄位具有相同的形狀並且對於所有元素都逐元素相等,則認為它們相等。
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
驗證方法
ExtensionType
新增 __validate__
方法,可以覆寫該方法以對欄位執行驗證檢查。它在呼叫建構函式之後,以及在欄位經過類型檢查並轉換為其宣告的類型之後執行,因此它可以假設所有欄位都具有其宣告的類型。
以下範例更新 MaskedTensor
以驗證其欄位的 shape
和 dtype
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor
def __validate__(self):
self.values.shape.assert_is_compatible_with(self.mask.shape)
assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
MaskedTensor([1, 2, 3], [0, 1, 0]) # Wrong `dtype` for mask.
except AssertionError as e:
print(f"Got expected AssertionError: {e}")
try:
MaskedTensor([1, 2, 3], [True, False]) # shapes don't match.
except ValueError as e:
print(f"Got expected ValueError: {e}")
強制不可變性
ExtensionType
覆寫 __setattr__
和 __delattr__
方法以防止變更,從而確保擴充類型值是不可變的。
mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
mt.mask = [True, True, True]
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
try:
mt.mask[0] = False
except TypeError as e:
print(f"Got expected TypeError: {e}")
try:
del mt.mask
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
巢狀 TypeSpec
每個 ExtensionType
類別都有對應的 TypeSpec
類別,該類別會自動建立並儲存為 <extension_type_name>.Spec
。
此類別擷取值的所有資訊,除了任何巢狀張量的值之外。特別是,值的 TypeSpec
是透過將任何巢狀張量、ExtensionType 或 CompositeTensor 替換為其 TypeSpec
來建立的。
class Player(tf.experimental.ExtensionType):
name: tf.Tensor
attributes: Mapping[str, tf.Tensor]
anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name) # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes) # Records keys and TensorSpecs for values.
TypeSpec
值可以明確建構,也可以使用 tf.type_spec_from_value
從 ExtensionType
值建構
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TypeSpec
由 TensorFlow 用於將值分為靜態元件和動態元件
- 靜態元件 (在圖建構時固定) 使用
tf.TypeSpec
編碼。 - 動態元件 (每次執行圖時可能會有所不同) 編碼為
tf.Tensor
的清單。
例如,每當引數具有先前未見過的 TypeSpec
時,tf.function
就會重新追蹤其包裝的函式
@tf.function
def anonymize_player(player):
print("<<TRACING>>")
return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
如需更多資訊,請參閱 tf.function 指南。
自訂 ExtensionType
除了簡單地宣告欄位及其類型之外,擴充類型還可以
- 覆寫預設的可列印表示 (
__repr__
)。 - 定義方法。
- 定義
classmethod
和staticmethod
。 - 定義屬性。
- 覆寫預設建構函式 (
__init__
)。 - 覆寫預設相等運算子 (
__eq__
)。 - 定義運算子 (例如
__add__
和__lt__
)。 - 宣告欄位的預設值。
- 定義子類別。
覆寫預設的可列印表示
您可以覆寫擴充類型的此預設字串轉換運算子。以下範例更新 MaskedTensor
類別,以在 Eager 模式下列印值時產生更易讀的字串表示。
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for invalid values.
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def masked_tensor_str(values, mask):
if isinstance(values, tf.Tensor):
if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
else:
return f'MaskedTensor(values={values}, mask={mask})'
if len(values.shape) == 1:
items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
else:
items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
return '[%s]' % ', '.join(items)
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
print(mt)
定義方法
擴充類型可以定義方法,就像任何一般的 Python 類別一樣。例如,MaskedTensor
類型可以定義一個 with_default
方法,該方法傳回 self
的副本,其中遮罩值已替換為給定的 default
值。方法可以選擇性地使用 @tf.function
裝飾器進行註解。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def with_default(self, default):
return tf.where(self.mask, self.values, default)
MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
定義 classmethod
和 staticmethod
擴充類型可以使用 @classmethod
和 @staticmethod
裝飾器來定義方法。例如,MaskedTensor
類型可以定義一個工廠方法,該方法會遮罩具有給定值的任何元素
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
@staticmethod
def from_tensor_and_value_to_mask(values, value_to_mask):
return MaskedTensor(values, values != value_to_mask)
x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
定義屬性
擴充類型可以使用 @property
裝飾器來定義屬性,就像任何一般的 Python 類別一樣。例如,MaskedTensor
類型可以定義一個 dtype
屬性,該屬性是值的 dtype
的簡寫
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@property
def dtype(self):
return self.values.dtype
MaskedTensor([1, 2, 3], [True, False, True]).dtype
覆寫預設建構函式
您可以覆寫擴充類型的預設建構函式。自訂建構函式必須為每個宣告的欄位設定值;並且在自訂建構函式傳回之後,所有欄位都將進行類型檢查,並且值將如上所述進行轉換。
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
def __init__(self, name, price, discount=0):
self.name = name
self.price = price * (1 - discount)
print(Toy("ball", 5.0, discount=0.2)) # On sale -- 20% off!
或者,您可以考慮保持預設建構函式不變,但新增一個或多個工廠方法。例如
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
@staticmethod
def new_toy_with_discount(name, price, discount):
return Toy(name, price * (1 - discount))
print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
覆寫預設相等運算子 (__eq__
)
您可以覆寫擴充類型的預設 __eq__
運算子。以下範例更新 MaskedTensor
以在比較相等性時忽略遮罩元素。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def __eq__(self, other):
result = tf.math.equal(self.values, other.values)
result = result | ~(self.mask & other.mask)
return tf.reduce_all(result)
x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
使用前向參照
如果欄位的類型尚未定義,您可以使用包含類型名稱的字串來代替。在以下範例中,字串 "Node"
用於註解 children
欄位,因為 Node
類型尚未 (完全) 定義。
class Node(tf.experimental.ExtensionType):
value: tf.Tensor
children: Tuple["Node", ...] = ()
Node(3, [Node(5), Node(2)])
定義子類別
可以使用標準 Python 語法對擴充類型進行子類別化。擴充類型子類別可以新增新的欄位、方法和屬性;並且可以覆寫建構函式、可列印表示和相等運算子。以下範例定義了一個基本的 TensorGraph
類別,該類別使用三個 Tensor
欄位來編碼節點之間的一組邊緣。然後,它定義一個子類別,該子類別新增一個 Tensor
欄位來記錄每個節點的「特徵值」。子類別還定義了一種沿邊緣傳播特徵值的方法。
class TensorGraph(tf.experimental.ExtensionType):
num_nodes: tf.Tensor
edge_src: tf.Tensor # edge_src[e] = index of src node for edge e.
edge_dst: tf.Tensor # edge_dst[e] = index of dst node for edge e.
class TensorGraphWithNodeFeature(TensorGraph):
node_features: tf.Tensor # node_features[n] = feature value for node n.
def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
updates = tf.gather(self.node_features, self.edge_src) * weight
new_node_features = tf.tensor_scatter_nd_add(
self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
return TensorGraphWithNodeFeature(
self.num_nodes, self.edge_src, self.edge_dst, new_node_features)
g = TensorGraphWithNodeFeature( # Edges: 0->1, 4->3, 2->2, 2->1
num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])
print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
定義私有欄位
擴充類型的欄位可以透過在其前面加上底線 (遵循標準 Python 慣例) 來標記為私有。這不會以任何方式影響 TensorFlow 處理欄位的方式;但只是作為對擴充類型任何使用者的訊號,表明這些欄位是私有的。
自訂 ExtensionType
的 TypeSpec
每個 ExtensionType
類別都有對應的 TypeSpec
類別,該類別會自動建立並儲存為 <extension_type_name>.Spec
。如需更多資訊,請參閱上面的「巢狀 TypeSpec」章節。
若要自訂 TypeSpec
,只需定義您自己的巢狀類別,名為 Spec
,並且 ExtensionType
將使用它作為自動建構的 TypeSpec
的基礎。您可以透過以下方式自訂 Spec
類別
- 覆寫預設的可列印表示。
- 覆寫預設建構函式。
- 定義方法、
classmethod
、staticmethod
和屬性。
以下範例自訂 MaskedTensor.Spec
類別,使其更易於使用
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def with_values(self, new_values):
return MaskedTensor(new_values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
def __repr__(self):
return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
張量 API 分派
擴充類型可以是「類似張量」的,因為它們專門化或擴充 tf.Tensor
類型定義的介面。類似張量的擴充類型的範例包括 RaggedTensor
、SparseTensor
和 MaskedTensor
。當分派裝飾器應用於類似張量的擴充類型時,可用於覆寫 TensorFlow 運算的預設行為。TensorFlow 目前定義了三個分派裝飾器
@tf.experimental.dispatch_for_api(tf_api)
@tf.experimental.dispatch_for_unary_elementwise_apis(x_type)
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
單個 API 的分派
tf.experimental.dispatch_for_api
裝飾器會覆寫指定的 TensorFlow 運算的預設行為,當使用指定的簽名呼叫它時。例如,您可以使用此裝飾器來指定 tf.stack
應如何處理 MaskedTensor
值
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
每當使用 MaskedTensor
值的清單呼叫 tf.stack
時,這都會覆寫預設實作 (因為 values
引數使用 typing.List[MaskedTensor]
進行了註解)
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
若要允許 tf.stack
處理混合 MaskedTensor
和 Tensor
值的清單,您可以精簡 values
參數的類型註解,並適當地更新函式的主體
tf.experimental.unregister_dispatch_for(masked_stack)
def convert_to_masked_tensor(x):
if isinstance(x, MaskedTensor):
return x
else:
return MaskedTensor(x, tf.ones_like(x, tf.bool))
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
values = [convert_to_masked_tensor(v) for v in values]
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
如需可以覆寫的 API 清單,請參閱 tf.experimental.dispatch_for_api
的 API 文件。
所有一元逐元素 API 的分派
每當第一個引數 (通常命名為 x
) 的值符合類型註解 x_type
時,tf.experimental.dispatch_for_unary_elementwise_apis
裝飾器會覆寫所有一元逐元素運算 (例如 tf.math.cos
) 的預設行為。裝飾的函式應採用兩個引數
api_func
:一個採用單個參數並執行逐元素運算的函式 (例如,tf.abs
)。x
:逐元素運算的第一個引數。
以下範例更新所有一元逐元素運算以處理 MaskedTensor
類型
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
現在,每當在一元逐元素運算在 MaskedTensor
上呼叫時,都會使用此函式。
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
print(tf.ones_like(x, dtype=tf.float32))
二元所有逐元素 API 的分派
同樣地,tf.experimental.dispatch_for_binary_elementwise_apis
可用於更新所有二元逐元素運算以處理 MaskedTensor
類型
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
如需覆寫的逐元素 API 清單,請前往 tf.experimental.dispatch_for_unary_elementwise_apis
和 tf.experimental.dispatch_for_binary_elementwise_apis
的 API 文件。
可批次處理的 ExtensionType
如果可以使用單個執行個體來表示一批值,則 ExtensionType
是可批次處理的。通常,這是透過將批次維度新增至所有巢狀 Tensor
來完成的。以下 TensorFlow API 要求任何擴充類型輸入都是可批次處理的
tf.data.Dataset
(batch
、unbatch
、from_tensor_slices
)tf.keras
(fit
、evaluate
、predict
)tf.map_fn
預設情況下,BatchableExtensionType
透過批次處理任何巢狀 Tensor
、CompositeTensor
和 ExtensionType
來建立批次處理的值。如果這不適用於您的類別,則您需要使用 tf.experimental.ExtensionTypeBatchEncoder
來覆寫此預設行為。例如,透過簡單地堆疊個別稀疏張量的 values
、indices
和 dense_shape
欄位來建立一批 tf.SparseTensor
值是不適當的 - 在大多數情況下,您無法堆疊這些張量,因為它們具有不相容的形狀;即使您可以,結果也不會是有效的 SparseTensor
。
BatchableExtensionType
範例:Network
作為範例,請考慮用於負載平衡的簡單 Network
類別,它追蹤每個節點中剩餘的工作量,以及可用於在節點之間移動工作的頻寬
class Network(tf.experimental.ExtensionType): # This version is not batchable.
work: tf.Tensor # work[n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[n1, n2] = bandwidth from n1->n2
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
若要使此類型可批次處理,請將基底類型變更為 BatchableExtensionType
,並調整每個欄位的形狀以包含選用的批次維度。以下範例還新增了 shape
欄位以追蹤批次形狀。雖然 tf.data.Dataset
或 tf.map_fn
並不要求此 shape
欄位,但 tf.keras
確實要求。
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
def network_repr(network):
work = network.work
bandwidth = network.bandwidth
if hasattr(work, 'numpy'):
work = ' '.join(str(work.numpy()).split())
if hasattr(bandwidth, 'numpy'):
bandwidth = ' '.join(str(bandwidth.numpy()).split())
return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
work=tf.stack([net1.work, net2.work]),
bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
然後,您可以使用 tf.data.Dataset
迭代一批網路
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
print(f"Batch element {i}: {network}")
您也可以使用 map_fn
將函式套用至每個批次元素
def balance_work_greedy(network):
delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
delta /= 4
delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
new_work = network.work + tf.reduce_sum(delta, -1)
return Network(new_work, network.bandwidth)
tf.map_fn(balance_work_greedy, batch_of_networks)
支援 ExtensionType
的 TensorFlow API
@tf.function
tf.function
是一個裝飾器,可為 Python 函式預先計算 TensorFlow 圖,這可以大幅提升 TensorFlow 程式碼的效能。擴充類型值可以與 @tf.function
裝飾的函式透明地使用。
class Pastry(tf.experimental.ExtensionType):
sweetness: tf.Tensor # 2d embedding that encodes sweetness
chewiness: tf.Tensor # 2d embedding that encodes chewiness
@tf.function
def combine_pastry_features(x: Pastry):
return (x.sweetness + x.chewiness) / 2
cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
如果您希望明確指定 tf.function
的 input_signature
,則可以使用擴充類型的 TypeSpec
來執行此操作。
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))
@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
return Pastry(x.sweetness + delta, x.chewiness)
increase_sweetness(cookie)
具體函式
具體函式封裝了由 tf.function
建置的個別追蹤圖。擴充類型可以與具體函式透明地使用。
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
控制流程運算
TensorFlow 的控制流程運算支援擴充類型
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
Autograph 控制流程
在 tf.function
(使用 autograph) 中的控制流程陳述式也支援擴充類型。在以下範例中,if
陳述式和 for
陳述式會自動轉換為 tf.cond
和 tf.while_loop
運算,這些運算支援擴充類型。
@tf.function
def fn(x, b):
if b:
x = MaskedTensor(x, tf.less(x, 0))
else:
x = MaskedTensor(x, tf.greater(x, 0))
for i in tf.range(5 if b else 7):
x = x.with_values(x.values + 1 / 2)
return x
print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
Keras
tf.keras 是 TensorFlow 的高階 API,用於建構和訓練深度學習模型。擴充類型可以做為 Keras 模型的輸入傳遞、在 Keras 層之間傳遞,以及由 Keras 模型傳回。Keras 目前對擴充類型有兩項要求:
- 它們必須可批次處理 (前往上方的「可批次處理的
ExtensionType
」)。 - 它們必須具有名為
shape
的欄位或屬性。shape[0]
假設為批次維度。
以下兩個小節提供範例,說明如何將擴充類型與 Keras 搭配使用。
Keras 範例:Network
對於第一個範例,請考慮上方「可批次處理的 ExtensionType
」章節中定義的 Network
類別,該類別可用於在節點之間進行負載平衡。其定義在此重複:
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
single_network = Network( # A single network with 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
您可以定義一個新的 Keras 層來處理 Network
。
class BalanceNetworkLayer(tf.keras.layers.Layer):
"""Layer that balances work between nodes in a network.
Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
"""
def call(self, inputs):
# This function is defined above in the "Batchable `ExtensionType`s" section.
return balance_work_greedy(inputs)
然後,您可以使用這些層來建立簡單的模型。若要將 ExtensionType
饋送至模型,您可以使用 tf.keras.layer.Input
層,並將 type_spec
設定為擴充類型的 TypeSpec
。如果 Keras 模型將用於處理批次,則 type_spec
必須包含批次維度。
input_spec = Network.Spec(shape=None,
work=tf.TensorSpec(None, tf.float32),
bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
BalanceNetworkLayer(),
])
最後,您可以將模型套用至單一網路和一批網路。
model(single_network)
model(batch_of_networks)
Keras 範例:MaskedTensor
在此範例中,擴充 MaskedTensor
以支援 Keras
。shape
定義為從 values
欄位計算的屬性。Keras 要求您將此屬性新增至擴充類型及其 TypeSpec
。MaskedTensor
也定義了 __name__
變數,這將是 SavedModel
序列化 (如下所述) 所需的。
class MaskedTensor(tf.experimental.BatchableExtensionType):
# __name__ is required for serialization in SavedModel; see below for details.
__name__ = 'extension_type_colab.MaskedTensor'
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_default(self, default):
return tf.where(self.mask, self.values, default)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_shape(self):
return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
tf.TensorSpec(shape, self.mask.dtype))
接下來,dispatch 裝飾器用於覆寫數個 TensorFlow API 的預設行為。由於標準 Keras 層 (例如 Dense
層) 使用這些 API,因此覆寫這些 API 將允許我們將這些層與 MaskedTensor
搭配使用。就本範例而言,遮罩張量的 matmul
定義為將遮罩值視為零 (也就是說,不將它們包含在乘積中)。
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
return MaskedTensor(op(x.values), x.mask)
@tf.experimental.dispatch_for_binary_elementwise_apis(
Union[MaskedTensor, tf.Tensor],
Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
x = convert_to_masked_tensor(x)
y = convert_to_masked_tensor(y)
return MaskedTensor(op(x.values, y.values), x.mask & y.mask)
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
transpose_a=False, transpose_b=False,
adjoint_a=False, adjoint_b=False,
a_is_sparse=False, b_is_sparse=False,
output_type=None):
if isinstance(a, MaskedTensor):
a = a.with_default(0)
if isinstance(b, MaskedTensor):
b = b.with_default(0)
return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
adjoint_b, a_is_sparse, b_is_sparse, output_type)
然後,您可以建構一個接受 MaskedTensor
輸入的 Keras 模型,使用標準 Keras 層
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
SavedModel
SavedModel 是序列化的 TensorFlow 程式,包括權重和運算。它可以從 Keras 模型或自訂模型建置。在任一種情況下,擴充類型都可以與 SavedModel 定義的函式和方法透明地搭配使用。
SavedModel 可以儲存處理擴充類型的模型、層和函式,只要擴充類型具有 __name__
欄位即可。此名稱用於註冊擴充類型,以便在載入模型時可以找到它。
範例:儲存 Keras 模型
可以使用 SavedModel
儲存使用擴充類型的 Keras 模型。
masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
範例:儲存自訂模型
SavedModel 也可用於儲存自訂 tf.Module
子類別,其中包含處理擴充類型的函式。
class CustomModule(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def grow(self, x: MaskedTensor):
"""Increase values in `x` by multiplying them by `self.v`."""
return MaskedTensor(x.values * self.v, x.mask)
module = CustomModule(100.0)
module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
當 ExtensionType
無法使用時載入 SavedModel
如果您載入使用 ExtensionType
的 SavedModel
,但該 ExtensionType
無法使用 (也就是說,尚未匯入),則您會收到警告,且 TensorFlow 將回復為使用「匿名擴充類型」物件。此物件將具有與原始類型相同的欄位,但會缺少您為類型新增的任何進一步自訂,例如自訂方法或屬性。
將 ExtensionType
與 TensorFlow Serving 搭配使用
目前,TensorFlow Serving (和其他 SavedModel「簽名」字典的消費者) 要求所有輸入和輸出都必須是原始張量。如果您希望將 TensorFlow Serving 與使用擴充類型的模型搭配使用,則可以新增包裝函式方法,以從張量組成或分解擴充類型值。例如:
class CustomModuleWrapper(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def var_weighted_mean(self, x: MaskedTensor):
"""Mean value of unmasked values in x, weighted by self.v."""
x = MaskedTensor(x.values * self.v, x.mask)
return (tf.reduce_sum(x.with_default(0)) /
tf.reduce_sum(tf.cast(x.mask, x.dtype)))
@tf.function()
def var_weighted_mean_wrapper(self, x_values, x_mask):
"""Raw tensor wrapper for var_weighted_mean."""
return self.var_weighted_mean(MaskedTensor(x_values, x_mask))
module = CustomModuleWrapper([3., 2., 8., 5.])
module.var_weighted_mean_wrapper.get_concrete_function(
tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
Dataset
tf.data
是一個 API,可讓您從簡單、可重複使用的片段建構複雜的輸入管線。其核心資料結構是 tf.data.Dataset
,它代表一系列元素,其中每個元素都由一個或多個元件組成。
使用擴充類型建構 Dataset
可以使用 Dataset.from_tensors
、Dataset.from_tensor_slices
或 Dataset.from_generator
,從擴充類型值建構資料集
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
print(value)
def value_gen():
for i in range(2, 7):
yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])
ds = tf.data.Dataset.from_generator(
value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
print(value)
使用擴充類型批次和取消批次處理 Dataset
具有擴充類型的資料集可以使用 Dataset.batch
和 Dataset.unbatch
進行批次處理和取消批次處理。
batched_ds = ds.batch(2)
for value in batched_ds:
print(value)
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
print(value)