TF-NumPy 型別提升

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

總覽

TensorFlow 中有 4 個型別提升選項。

  • 依預設,TensorFlow 會引發錯誤,而不是針對混合型別運算提升型別。
  • 執行 tf.numpy.experimental_enable_numpy_behavior() 會切換 TensorFlow 以使用 NumPy 型別提升規則
  • 本文件說明 TensorFlow 2.15 (或目前在 tf-nightly 中) 中提供的兩個新選項
pip install -q tf_nightly

設定

import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

print("Using TensorFlow version %s" % tf.__version__)
Using TensorFlow version 2.17.0-dev20240210

啟用新的型別提升

為了在 TF-Numpy 中使用類似 JAX 的型別提升,在為 TensorFlow 啟用 NumPy 行為時,請將 'all''safe' 指定為 dtype 轉換模式。

這個新系統 (搭配 dtype_conversion_mode="all") 具有結合律、交換律,且可輕鬆控制最終使用的浮點數寬度 (不會自動轉換為更寬的浮點數)。它確實會帶來一些溢位和精確度損失的風險,但 dtype_conversion_mode="safe" 會強制您明確處理這些情況。這兩種模式在下一節中會更詳細地說明。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.

兩種模式:「ALL」模式與「SAFE」模式

在新的型別提升系統中,我們引入了兩種模式:ALL 模式和 SAFE 模式。SAFE 模式用於減輕「風險」提升的疑慮,這些提升可能會導致精確度損失或位元加寬。

Dtype

為了簡潔起見,我們將使用以下縮寫。

星號 (*) 表示對應的型別為「弱」型別 - 這種 dtype 是由系統暫時推斷的,可能會延遲到其他 dtype。此概念在此處有更詳細的說明。

精確度損失運算的範例

在以下範例中,i32 + f32ALL 模式下允許,但在 SAFE 模式下不允許,因為有精確度損失的風險。

# i32 + f32 returns a f32 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
a + b  # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=float32, numpy=15.0>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
try:
  a + b
except TypeError as e:
   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int32'>, weak=False) and (<dtype: 'float32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).

位元加寬運算的範例

在以下範例中,i8 + u32ALL 模式下允許,但在 SAFE 模式下不允許,因為位元加寬表示使用的位元數多於輸入中的位元數。請注意,新的型別提升語意僅允許必要的位元加寬。

# i8 + u32 returns an i64 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
a + b
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=int64, numpy=15>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
try:
  a + b
except TypeError as e:
   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int8'>, weak=False) and (<dtype: 'uint32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).

以格狀結構為基礎的系統

型別提升格狀結構

新的型別提升行為是透過以下型別提升格狀結構判定的

Type Promotion Lattice

更具體來說,任何兩種型別之間的提升都是透過尋找兩個節點的第一個共同子節點 (包括節點本身) 來判定的。

例如,在上圖中,i8i32 的第一個共同子節點是 i32,因為當遵循箭頭方向時,兩個節點在 i32 處第一次相交。

同樣地,另一個範例是,u64f16 之間的結果提升型別會是 f16

型別提升表

遵循格狀結構會產生以下二元提升表

Type Promotion Table

新型別提升的優點

我們針對新的型別提升採用了類似 JAX 的格狀結構系統,此系統具有以下優點

格狀結構系統的優點

首先,使用格狀結構系統可確保三個非常重要的屬性

  • 存在性:任何型別組合都有獨一無二的結果提升型別。
  • 交換律:a + b = b + a
  • 結合律:a + (b + c) = (a + b) = c

這三個屬性對於建構一致且可預測的型別提升語意至關重要。

類似 JAX 的格狀結構系統的優點

類似 JAX 的格狀結構系統的另一個關鍵優點是,在無號整數之外,它可以避免所有不必要的加寬提升。這表示如果沒有 64 位元輸入,您就無法獲得 64 位元結果。這對於在加速器上工作特別有利,因為它可以避免不必要的 64 位元值,這在舊型別提升中很常見。

但是,這會帶來一個權衡:混合浮點數/整數提升非常容易發生精確度損失。例如,在以下範例中,i64 + f16 會導致將 i64 提升為 f16

# The first input is promoted to f16 in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=float16, numpy=4.2>

為了減輕此疑慮,我們引入了 SAFE 模式,此模式將不允許這些「風險」提升。

WeakTensor

總覽

弱張量是「弱型別」的張量,類似於 JAX 中的概念

WeakTensor 的 dtype 是由系統暫時推斷的,可能會延遲到其他 dtype。新的型別提升中引入此概念,以防止 TF 值與沒有明確使用者指定型別的值 (例如 Python 純量常值) 之間的二元運算中發生不必要的型別提升。

例如,在以下範例中,tf.constant(1.2) 被視為「弱」型別,因為它沒有特定的 dtype。因此,tf.constant(1.2) 會延遲到 tf.constant(3.1, tf.float16) 的型別,產生 f16 輸出。

tf.constant(1.2) + tf.constant(3.1, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>
<tf.Tensor: shape=(), dtype=float16, numpy=4.3>

WeakTensor 建構

如果您建立張量時未指定 dtype,則會建立 WeakTensor。您可以透過檢查張量字串表示法結尾的 weak 屬性,來檢查張量是否為「弱」型別。

第一種情況:當呼叫 tf.constant 時,輸入沒有使用者指定的 dtype。

tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
tf.constant([5.0, 10.0, 3])  # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10.,  3.], dtype=float32), weak=True>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10.,  3.], dtype=float32), weak=True>
# A normal Tensor is created when dtype arg is specified.
tf.constant(5, tf.int32)  # <tf.Tensor: shape=(), dtype=int32, numpy=5>
<tf.Tensor: shape=(), dtype=int32, numpy=5>

第二種情況:當沒有使用者指定的 dtype 的輸入傳遞至支援 WeakTensor 的 API 時。

tf.math.abs([100.0, 4.0])  # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([100.,   4.], dtype=float32), weak=True>

開啟新型別提升的效果

以下是非詳盡的清單,列出開啟新型別提升所產生的變更。

  • 更一致且可預測的提升結果。
  • 降低位元加寬的風險。
  • tf.Tensor 數學 Dunder 方法使用新型別提升。
  • tf.constant 可以傳回 WeakTensor
  • tf.constant 允許在傳入 dtype 與 dtype 引數不同的張量輸入時進行隱含轉換。
  • tf.Variable 就地運算 (assign、assign-add、assign-sub) 允許隱含轉換。
  • tnp.array(1)tnp.array(1.0) 傳回 32 位元 WeakTensor。
  • WeakTensor 將針對支援 WeakTensor 的一元和二元 API 建立和使用。

更一致且可預測的提升結果

使用格狀結構系統可讓新型別提升產生一致且可預測的型別提升結果。

舊型別提升

使用舊型別提升時,變更運算順序會產生不一致的結果。

# Setup
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
  tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
  print(f'{type(e)}: {e}')  # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>

新型別提升

無論順序為何,新型別提升都會產生一致的結果。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
# (a + b) + c returns a f16 result.
tf.add(tf.add(a, b), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns a f16 result.
tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>

降低位元加寬的風險

舊型別提升

舊型別提升通常會產生 64 位元結果。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>

新型別提升

新型別提升會傳回位元數最少且必要的結果。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>
<tf.Tensor: shape=(), dtype=float16, numpy=54.2>

tf.Tensor 數學 Dunder 方法

所有 tf.Tensor 數學 Dunder 方法都將遵循新型別提升。

-tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>

tf.Variable 就地運算

隱含轉換將在 tf.Variable 就地運算中允許。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, tf.int32)
a.assign_add(tf.constant(5, tf.int16))  # <tf.Variable shape=() dtype=int32, numpy=15>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>

tf.constant 隱含轉換

在舊型別提升中,tf.constant 需要輸入張量具有與 dtype 引數相同的 dtype。但是,在新型別提升中,我們會將張量隱含轉換為指定的 dtype。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>

TF-NumPy 陣列

tnp.array 預設為 i32*f32*,適用於使用新型別提升的 Python 輸入。

tnp.array(1)  # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0)  # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>

輸入型別推斷

以下說明如何在新型別提升中推斷不同輸入的型別。

  • tf.Tensor:由於 tf.Tensor 具有 dtype 屬性,因此我們不會執行進一步的推斷。
  • NumPy 型別:這包括 np.array(1)np.int16(1)np.float 等型別。由於 NumPy 輸入也有 dtype 屬性,因此我們將 dtype 屬性視為結果推斷型別。請注意,NumPy 預設為 i64f64
  • Python 純量/巢狀型別:這包括 1[1, 2, 3](1.0, 2.0) 等型別。
    • Python int 推斷為 i32*
    • Python float 推斷為 f32*
    • Python complex 推斷為 c128*
  • 如果輸入不屬於上述任何類別,但具有 dtype 屬性,我們會將 dtype 屬性視為結果推斷型別。

延伸閱讀

新型別提升與 JAX-NumPy 的型別提升非常相似。如果您想瞭解有關新型別提升和設計選擇的更多詳細資訊,請查看以下資源。

參考資料

支援 WeakTensor 的 API

以下是支援 WeakTensor 的 API 清單。

對於一元運算,這表示如果傳入沒有使用者指定型別的輸入,則會傳回 WeakTensor

對於二元運算,它將遵循此處的提升表。它可能會或可能不會傳回 WeakTensor,具體取決於兩個輸入的提升結果。