擴充類型

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

設定

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

擴充類型

使用者定義的類型可以讓專案更易讀、模組化、可維護。但是,大多數 TensorFlow API 對於使用者定義的 Python 類型的支援非常有限。這包括高階 API (例如 Kerastf.functiontf.SavedModel) 和低階 API (例如 tf.while_looptf.concat)。TensorFlow 擴充類型可用於建立使用者定義的物件導向類型,這些類型可以與 TensorFlow 的 API 無縫協作。若要建立擴充類型,只需定義一個以 tf.experimental.ExtensionType 作為其基底的 Python 類別,並使用類型註解來指定每個欄位的類型。

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

tf.experimental.ExtensionType 基底類別的功能與標準 Python 程式庫中的 typing.NamedTuple@dataclasses.dataclass 類似。特別是,它會根據欄位類型註解自動新增建構函式和特殊方法 (例如 __repr____eq__)。

通常,擴充類型傾向於分為以下兩類之一

  • 資料結構,它將相關值的集合組合在一起,並且可以根據這些值提供有用的運算。資料結構可能相當通用 (例如上面的 TensorGraph 範例);或者它們可以高度自訂以適用於特定模型。

  • 類似張量的類型,它們專門化或擴充「張量」的概念。此類別中的類型具有 rankshape,通常還具有 dtype;並且將它們與張量運算 (例如 tf.stacktf.addtf.matmul) 一起使用是有意義的。MaskedTensorCSRSparseMatrix 是類似張量的類型的範例。

支援的 API

以下 TensorFlow API 支援擴充類型

  • Keras:擴充類型可以用作 Keras ModelsLayers 的輸入和輸出。
  • tf.data.Dataset:擴充類型可以包含在 Datasets 中,並由資料集 Iterators 傳回。
  • TensorFlow Hub:擴充類型可以用作 tf.hub 模組的輸入和輸出。
  • SavedModel:擴充類型可以用作 SavedModel 函式的輸入和輸出。
  • tf.function:擴充類型可以用作以 @tf.function 裝飾器包裝之函式的引數和傳回值。
  • While 迴圈:擴充類型可以用作 tf.while_loop 中的迴圈變數,並且可以用作 while 迴圈主體的引數和傳回值。
  • 條件式:擴充類型可以使用 tf.condtf.case 進行條件式選取。
  • tf.py_function:擴充類型可以用作 tf.py_functionfunc 引數的引數和傳回值。
  • 張量運算:擴充類型可以擴充以支援大多數接受張量輸入的 TensorFlow 運算 (例如 tf.matmultf.gathertf.reduce_sum)。前往下方的「分派」章節以取得更多資訊。
  • 分佈策略:擴充類型可以用作每個副本的值。

如需更多詳細資訊,請參閱下方關於「支援 ExtensionType 的 TensorFlow API」的章節。

需求

欄位類型

所有欄位 (執行個體變數) 都必須宣告,並且必須為每個欄位提供類型註解。支援以下類型註解

類型 範例
Python 整數 i: int
Python 浮點數 f: float
Python 字串 s: str
Python 布林值 b: bool
Python None n: None
張量形狀 shape: tf.TensorShape
張量 dtype dtype: tf.DType
張量 張量
擴充類型 mt: MyMaskedTensor
參差不齊的張量 rt: tf.RaggedTensor
稀疏張量 st: tf.SparseTensor
索引切片 s: tf.IndexedSlices
選用張量 o: tf.experimental.Optional
類型聯集 int_or_float: typing.Union[int, float]
元組 params: typing.Tuple[int, float, tf.Tensor, int]
變長元組 lengths: typing.Tuple[int, ...]
對應 tags: typing.Mapping[str, tf.Tensor]
選用值 weight: typing.Optional[tf.Tensor]

可變性

擴充類型必須是不可變的。這確保它們可以由 TensorFlow 的圖追蹤機制正確追蹤。如果您發現自己想要變更擴充類型的值,請考慮改為定義轉換值的方法。例如,您可以定義一個 replace_mask 方法來傳回新的 MaskedTensor,而不是定義 set_mask 方法來變更 MaskedTensor

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

ExtensionType 新增的功能

ExtensionType 基底類別提供以下功能

  • 建構函式 (__init__)。
  • 可列印的表示方法 (__repr__)。
  • 相等和不相等運算子 (__eq__)。
  • 驗證方法 (__validate__)。
  • 強制不可變性。
  • 巢狀 TypeSpec
  • 張量 API 分派支援。

前往下方的「自訂 ExtensionType」章節以取得有關自訂此功能的更多資訊。

建構函式

ExtensionType 新增的建構函式將每個欄位作為具名引數 (依其在類別定義中列出的順序)。此建構函式將對每個參數進行類型檢查,並在必要時轉換它們。特別是,Tensor 欄位使用 tf.convert_to_tensor 轉換;Tuple 欄位轉換為 tuple;並且 Mapping 欄位轉換為不可變的 dict。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)

如果欄位值無法轉換為其宣告的類型,則建構函式會引發 TypeError

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")

可以透過在類別層級設定欄位的預設值來指定該值

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(length=0.5, color="blue")

可列印的表示

ExtensionType 新增預設的可列印表示方法 (__repr__),其中包括類別名稱和每個欄位的值

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))

相等運算子

ExtensionType 新增預設的相等運算子 (__eq____ne__),如果兩個值具有相同的類型且其所有欄位都相等,則認為它們相等。如果張量欄位具有相同的形狀並且對於所有元素都逐元素相等,則認為它們相等。

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")

驗證方法

ExtensionType 新增 __validate__ 方法,可以覆寫該方法以對欄位執行驗證檢查。它在呼叫建構函式之後,以及在欄位經過類型檢查並轉換為其宣告的類型之後執行,因此它可以假設所有欄位都具有其宣告的類型。

以下範例更新 MaskedTensor 以驗證其欄位的 shapedtype

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")

強制不可變性

ExtensionType 覆寫 __setattr____delattr__ 方法以防止變更,從而確保擴充類型值是不可變的。

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")

巢狀 TypeSpec

每個 ExtensionType 類別都有對應的 TypeSpec 類別,該類別會自動建立並儲存為 <extension_type_name>.Spec

此類別擷取值的所有資訊,除了任何巢狀張量的值之外。特別是,值的 TypeSpec 是透過將任何巢狀張量、ExtensionType 或 CompositeTensor 替換為其 TypeSpec 來建立的。

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.

TypeSpec 值可以明確建構,也可以使用 tf.type_spec_from_valueExtensionType 值建構

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TypeSpec 由 TensorFlow 用於將值分為靜態元件動態元件

  • 靜態元件 (在圖建構時固定) 使用 tf.TypeSpec 編碼。
  • 動態元件 (每次執行圖時可能會有所不同) 編碼為 tf.Tensor 的清單。

例如,每當引數具有先前未見過的 TypeSpec 時,tf.function 就會重新追蹤其包裝的函式

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))

如需更多資訊,請參閱 tf.function 指南

自訂 ExtensionType

除了簡單地宣告欄位及其類型之外,擴充類型還可以

  • 覆寫預設的可列印表示 (__repr__)。
  • 定義方法。
  • 定義 classmethodstaticmethod
  • 定義屬性。
  • 覆寫預設建構函式 (__init__)。
  • 覆寫預設相等運算子 (__eq__)。
  • 定義運算子 (例如 __add____lt__)。
  • 宣告欄位的預設值。
  • 定義子類別。

覆寫預設的可列印表示

您可以覆寫擴充類型的此預設字串轉換運算子。以下範例更新 MaskedTensor 類別,以在 Eager 模式下列印值時產生更易讀的字串表示。

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)

定義方法

擴充類型可以定義方法,就像任何一般的 Python 類別一樣。例如,MaskedTensor 類型可以定義一個 with_default 方法,該方法傳回 self 的副本,其中遮罩值已替換為給定的 default 值。方法可以選擇性地使用 @tf.function 裝飾器進行註解。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)

定義 classmethodstaticmethod

擴充類型可以使用 @classmethod@staticmethod 裝飾器來定義方法。例如,MaskedTensor 類型可以定義一個工廠方法,該方法會遮罩具有給定值的任何元素

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values != value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)

定義屬性

擴充類型可以使用 @property 裝飾器來定義屬性,就像任何一般的 Python 類別一樣。例如,MaskedTensor 類型可以定義一個 dtype 屬性,該屬性是值的 dtype 的簡寫

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype

覆寫預設建構函式

您可以覆寫擴充類型的預設建構函式。自訂建構函式必須為每個宣告的欄位設定值;並且在自訂建構函式傳回之後,所有欄位都將進行類型檢查,並且值將如上所述進行轉換。

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!

或者,您可以考慮保持預設建構函式不變,但新增一個或多個工廠方法。例如

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))

覆寫預設相等運算子 (__eq__)

您可以覆寫擴充類型的預設 __eq__ 運算子。以下範例更新 MaskedTensor 以在比較相等性時忽略遮罩元素。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)

使用前向參照

如果欄位的類型尚未定義,您可以使用包含類型名稱的字串來代替。在以下範例中,字串 "Node" 用於註解 children 欄位,因為 Node 類型尚未 (完全) 定義。

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])

定義子類別

可以使用標準 Python 語法對擴充類型進行子類別化。擴充類型子類別可以新增新的欄位、方法和屬性;並且可以覆寫建構函式、可列印表示和相等運算子。以下範例定義了一個基本的 TensorGraph 類別,該類別使用三個 Tensor 欄位來編碼節點之間的一組邊緣。然後,它定義一個子類別,該子類別新增一個 Tensor 欄位來記錄每個節點的「特徵值」。子類別還定義了一種沿邊緣傳播特徵值的方法。

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)

定義私有欄位

擴充類型的欄位可以透過在其前面加上底線 (遵循標準 Python 慣例) 來標記為私有。這不會以任何方式影響 TensorFlow 處理欄位的方式;但只是作為對擴充類型任何使用者的訊號,表明這些欄位是私有的。

自訂 ExtensionTypeTypeSpec

每個 ExtensionType 類別都有對應的 TypeSpec 類別,該類別會自動建立並儲存為 <extension_type_name>.Spec。如需更多資訊,請參閱上面的「巢狀 TypeSpec」章節。

若要自訂 TypeSpec,只需定義您自己的巢狀類別,名為 Spec,並且 ExtensionType 將使用它作為自動建構的 TypeSpec 的基礎。您可以透過以下方式自訂 Spec 類別

  • 覆寫預設的可列印表示。
  • 覆寫預設建構函式。
  • 定義方法、classmethodstaticmethod 和屬性。

以下範例自訂 MaskedTensor.Spec 類別,使其更易於使用

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

張量 API 分派

擴充類型可以是「類似張量」的,因為它們專門化或擴充 tf.Tensor 類型定義的介面。類似張量的擴充類型的範例包括 RaggedTensorSparseTensorMaskedTensor。當分派裝飾器應用於類似張量的擴充類型時,可用於覆寫 TensorFlow 運算的預設行為。TensorFlow 目前定義了三個分派裝飾器

單個 API 的分派

tf.experimental.dispatch_for_api 裝飾器會覆寫指定的 TensorFlow 運算的預設行為,當使用指定的簽名呼叫它時。例如,您可以使用此裝飾器來指定 tf.stack 應如何處理 MaskedTensor

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

每當使用 MaskedTensor 值的清單呼叫 tf.stack 時,這都會覆寫預設實作 (因為 values 引數使用 typing.List[MaskedTensor] 進行了註解)

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])

若要允許 tf.stack 處理混合 MaskedTensorTensor 值的清單,您可以精簡 values 參數的類型註解,並適當地更新函式的主體

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])

如需可以覆寫的 API 清單,請參閱 tf.experimental.dispatch_for_api 的 API 文件。

所有一元逐元素 API 的分派

每當第一個引數 (通常命名為 x) 的值符合類型註解 x_type 時,tf.experimental.dispatch_for_unary_elementwise_apis 裝飾器會覆寫所有一元逐元素運算 (例如 tf.math.cos) 的預設行為。裝飾的函式應採用兩個引數

  • api_func:一個採用單個參數並執行逐元素運算的函式 (例如,tf.abs)。
  • x:逐元素運算的第一個引數。

以下範例更新所有一元逐元素運算以處理 MaskedTensor 類型

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

現在,每當在一元逐元素運算在 MaskedTensor 上呼叫時,都會使用此函式。

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
print(tf.ones_like(x, dtype=tf.float32))

二元所有逐元素 API 的分派

同樣地,tf.experimental.dispatch_for_binary_elementwise_apis 可用於更新所有二元逐元素運算以處理 MaskedTensor 類型

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)

如需覆寫的逐元素 API 清單,請前往 tf.experimental.dispatch_for_unary_elementwise_apistf.experimental.dispatch_for_binary_elementwise_apis 的 API 文件。

可批次處理的 ExtensionType

如果可以使用單個執行個體來表示一批值,則 ExtensionType可批次處理的。通常,這是透過將批次維度新增至所有巢狀 Tensor 來完成的。以下 TensorFlow API 要求任何擴充類型輸入都是可批次處理的

預設情況下,BatchableExtensionType 透過批次處理任何巢狀 TensorCompositeTensorExtensionType 來建立批次處理的值。如果這不適用於您的類別,則您需要使用 tf.experimental.ExtensionTypeBatchEncoder 來覆寫此預設行為。例如,透過簡單地堆疊個別稀疏張量的 valuesindicesdense_shape 欄位來建立一批 tf.SparseTensor 值是不適當的 - 在大多數情況下,您無法堆疊這些張量,因為它們具有不相容的形狀;即使您可以,結果也不會是有效的 SparseTensor

BatchableExtensionType 範例:Network

作為範例,請考慮用於負載平衡的簡單 Network 類別,它追蹤每個節點中剩餘的工作量,以及可用於在節點之間移動工作的頻寬

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

若要使此類型可批次處理,請將基底類型變更為 BatchableExtensionType,並調整每個欄位的形狀以包含選用的批次維度。以下範例還新增了 shape 欄位以追蹤批次形狀。雖然 tf.data.Datasettf.map_fn 並不要求此 shape 欄位,但 tf.keras確實要求。

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")

然後,您可以使用 tf.data.Dataset 迭代一批網路

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")

您也可以使用 map_fn 將函式套用至每個批次元素

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)

支援 ExtensionType 的 TensorFlow API

@tf.function

tf.function 是一個裝飾器,可為 Python 函式預先計算 TensorFlow 圖,這可以大幅提升 TensorFlow 程式碼的效能。擴充類型值可以與 @tf.function 裝飾的函式透明地使用。

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)

如果您希望明確指定 tf.functioninput_signature,則可以使用擴充類型的 TypeSpec 來執行此操作。

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)

具體函式

具體函式封裝了由 tf.function 建置的個別追蹤圖。擴充類型可以與具體函式透明地使用。

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)

控制流程運算

TensorFlow 的控制流程運算支援擴充類型

# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])

Autograph 控制流程

tf.function (使用 autograph) 中的控制流程陳述式也支援擴充類型。在以下範例中,if 陳述式和 for 陳述式會自動轉換為 tf.condtf.while_loop 運算,這些運算支援擴充類型。

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))

Keras

tf.keras 是 TensorFlow 的高階 API,用於建構和訓練深度學習模型。擴充類型可以做為 Keras 模型的輸入傳遞、在 Keras 層之間傳遞,以及由 Keras 模型傳回。Keras 目前對擴充類型有兩項要求:

  • 它們必須可批次處理 (前往上方的「可批次處理的 ExtensionType」)。
  • 它們必須具有名為 shape 的欄位或屬性。shape[0] 假設為批次維度。

以下兩個小節提供範例,說明如何將擴充類型與 Keras 搭配使用。

Keras 範例:Network

對於第一個範例,請考慮上方「可批次處理的 ExtensionType」章節中定義的 Network 類別,該類別可用於在節點之間進行負載平衡。其定義在此重複:

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network with 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

您可以定義一個新的 Keras 層來處理 Network

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above in the "Batchable `ExtensionType`s" section.
    return balance_work_greedy(inputs)

然後,您可以使用這些層來建立簡單的模型。若要將 ExtensionType 饋送至模型,您可以使用 tf.keras.layer.Input 層,並將 type_spec 設定為擴充類型的 TypeSpec。如果 Keras 模型將用於處理批次,則 type_spec 必須包含批次維度。

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

最後,您可以將模型套用至單一網路和一批網路。

model(single_network)
model(batch_of_networks)

Keras 範例:MaskedTensor

在此範例中,擴充 MaskedTensor 以支援 Kerasshape 定義為從 values 欄位計算的屬性。Keras 要求您將此屬性新增至擴充類型及其 TypeSpecMaskedTensor 也定義了 __name__ 變數,這將是 SavedModel 序列化 (如下所述) 所需的。

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

接下來,dispatch 裝飾器用於覆寫數個 TensorFlow API 的預設行為。由於標準 Keras 層 (例如 Dense 層) 使用這些 API,因此覆寫這些 API 將允許我們將這些層與 MaskedTensor 搭配使用。就本範例而言,遮罩張量的 matmul 定義為將遮罩值視為零 (也就是說,不將它們包含在乘積中)。

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

然後,您可以建構一個接受 MaskedTensor 輸入的 Keras 模型,使用標準 Keras 層

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))

SavedModel

SavedModel 是序列化的 TensorFlow 程式,包括權重和運算。它可以從 Keras 模型或自訂模型建置。在任一種情況下,擴充類型都可以與 SavedModel 定義的函式和方法透明地搭配使用。

SavedModel 可以儲存處理擴充類型的模型、層和函式,只要擴充類型具有 __name__ 欄位即可。此名稱用於註冊擴充類型,以便在載入模型時可以找到它。

範例:儲存 Keras 模型

可以使用 SavedModel 儲存使用擴充類型的 Keras 模型。

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)

範例:儲存自訂模型

SavedModel 也可用於儲存自訂 tf.Module 子類別,其中包含處理擴充類型的函式。

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))

ExtensionType 無法使用時載入 SavedModel

如果您載入使用 ExtensionTypeSavedModel,但該 ExtensionType 無法使用 (也就是說,尚未匯入),則您會收到警告,且 TensorFlow 將回復為使用「匿名擴充類型」物件。此物件將具有與原始類型相同的欄位,但會缺少您為類型新增的任何進一步自訂,例如自訂方法或屬性。

ExtensionType 與 TensorFlow Serving 搭配使用

目前,TensorFlow Serving (和其他 SavedModel「簽名」字典的消費者) 要求所有輸入和輸出都必須是原始張量。如果您希望將 TensorFlow Serving 與使用擴充類型的模型搭配使用,則可以新增包裝函式方法,以從張量組成或分解擴充類型值。例如:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)

Dataset

tf.data 是一個 API,可讓您從簡單、可重複使用的片段建構複雜的輸入管線。其核心資料結構是 tf.data.Dataset,它代表一系列元素,其中每個元素都由一個或多個元件組成。

使用擴充類型建構 Dataset

可以使用 Dataset.from_tensorsDataset.from_tensor_slicesDataset.from_generator,從擴充類型值建構資料集

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)

使用擴充類型批次和取消批次處理 Dataset

具有擴充類型的資料集可以使用 Dataset.batchDataset.unbatch 進行批次處理和取消批次處理。

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)