![]() |
![]() |
![]() |
![]() |
總覽
TensorFlow 中有 4 個型別提升選項。
- 依預設,TensorFlow 會引發錯誤,而不是針對混合型別運算提升型別。
- 執行
tf.numpy.experimental_enable_numpy_behavior()
會切換 TensorFlow 以使用 NumPy 型別提升規則。 - 本文件說明 TensorFlow 2.15 (或目前在
tf-nightly
中) 中提供的兩個新選項
pip install -q tf_nightly
設定
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
print("Using TensorFlow version %s" % tf.__version__)
Using TensorFlow version 2.17.0-dev20240210
啟用新的型別提升
為了在 TF-Numpy 中使用類似 JAX 的型別提升,在為 TensorFlow 啟用 NumPy 行為時,請將 'all'
或 'safe'
指定為 dtype 轉換模式。
這個新系統 (搭配 dtype_conversion_mode="all"
) 具有結合律、交換律,且可輕鬆控制最終使用的浮點數寬度 (不會自動轉換為更寬的浮點數)。它確實會帶來一些溢位和精確度損失的風險,但 dtype_conversion_mode="safe"
會強制您明確處理這些情況。這兩種模式在下一節中會更詳細地說明。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
兩種模式:「ALL」模式與「SAFE」模式
在新的型別提升系統中,我們引入了兩種模式:ALL
模式和 SAFE
模式。SAFE
模式用於減輕「風險」提升的疑慮,這些提升可能會導致精確度損失或位元加寬。
Dtype
為了簡潔起見,我們將使用以下縮寫。
b
代表tf.bool
u8
代表tf.uint8
i16
代表tf.int16
i32
代表tf.int32
bf16
代表tf.bfloat16
f32
代表tf.float32
f64
代表tf.float64
i32*
代表 Pythonint
或弱型別i32
f32*
代表 Pythonfloat
或弱型別f32
c128*
代表 Pythoncomplex
或弱型別c128
星號 (*) 表示對應的型別為「弱」型別 - 這種 dtype 是由系統暫時推斷的,可能會延遲到其他 dtype。此概念在此處有更詳細的說明。
精確度損失運算的範例
在以下範例中,i32
+ f32
在 ALL
模式下允許,但在 SAFE
模式下不允許,因為有精確度損失的風險。
# i32 + f32 returns a f32 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
a + b # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int32'>, weak=False) and (<dtype: 'float32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
位元加寬運算的範例
在以下範例中,i8
+ u32
在 ALL
模式下允許,但在 SAFE
模式下不允許,因為位元加寬表示使用的位元數多於輸入中的位元數。請注意,新的型別提升語意僅允許必要的位元加寬。
# i8 + u32 returns an i64 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
a + b
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=int64, numpy=15>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int8'>, weak=False) and (<dtype: 'uint32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
以格狀結構為基礎的系統
型別提升格狀結構
新的型別提升行為是透過以下型別提升格狀結構判定的
更具體來說,任何兩種型別之間的提升都是透過尋找兩個節點的第一個共同子節點 (包括節點本身) 來判定的。
例如,在上圖中,i8
和 i32
的第一個共同子節點是 i32
,因為當遵循箭頭方向時,兩個節點在 i32
處第一次相交。
同樣地,另一個範例是,u64
和 f16
之間的結果提升型別會是 f16
。
型別提升表
遵循格狀結構會產生以下二元提升表
新型別提升的優點
我們針對新的型別提升採用了類似 JAX 的格狀結構系統,此系統具有以下優點
格狀結構系統的優點
首先,使用格狀結構系統可確保三個非常重要的屬性
- 存在性:任何型別組合都有獨一無二的結果提升型別。
- 交換律:
a + b = b + a
- 結合律:
a + (b + c) = (a + b) = c
這三個屬性對於建構一致且可預測的型別提升語意至關重要。
類似 JAX 的格狀結構系統的優點
類似 JAX 的格狀結構系統的另一個關鍵優點是,在無號整數之外,它可以避免所有不必要的加寬提升。這表示如果沒有 64 位元輸入,您就無法獲得 64 位元結果。這對於在加速器上工作特別有利,因為它可以避免不必要的 64 位元值,這在舊型別提升中很常見。
但是,這會帶來一個權衡:混合浮點數/整數提升非常容易發生精確度損失。例如,在以下範例中,i64
+ f16
會導致將 i64
提升為 f16
。
# The first input is promoted to f16 in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
為了減輕此疑慮,我們引入了 SAFE
模式,此模式將不允許這些「風險」提升。
WeakTensor
總覽
WeakTensor
的 dtype 是由系統暫時推斷的,可能會延遲到其他 dtype。新的型別提升中引入此概念,以防止 TF 值與沒有明確使用者指定型別的值 (例如 Python 純量常值) 之間的二元運算中發生不必要的型別提升。
例如,在以下範例中,tf.constant(1.2)
被視為「弱」型別,因為它沒有特定的 dtype。因此,tf.constant(1.2)
會延遲到 tf.constant(3.1, tf.float16)
的型別,產生 f16
輸出。
tf.constant(1.2) + tf.constant(3.1, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>
<tf.Tensor: shape=(), dtype=float16, numpy=4.3>
WeakTensor 建構
如果您建立張量時未指定 dtype,則會建立 WeakTensor。您可以透過檢查張量字串表示法結尾的 weak 屬性,來檢查張量是否為「弱」型別。
第一種情況:當呼叫 tf.constant
時,輸入沒有使用者指定的 dtype。
tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
tf.constant([5.0, 10.0, 3]) # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
# A normal Tensor is created when dtype arg is specified.
tf.constant(5, tf.int32) # <tf.Tensor: shape=(), dtype=int32, numpy=5>
<tf.Tensor: shape=(), dtype=int32, numpy=5>
第二種情況:當沒有使用者指定的 dtype 的輸入傳遞至支援 WeakTensor 的 API 時。
tf.math.abs([100.0, 4.0]) # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
開啟新型別提升的效果
以下是非詳盡的清單,列出開啟新型別提升所產生的變更。
- 更一致且可預測的提升結果。
- 降低位元加寬的風險。
tf.Tensor
數學 Dunder 方法使用新型別提升。tf.constant
可以傳回WeakTensor
。tf.constant
允許在傳入 dtype 與dtype
引數不同的張量輸入時進行隱含轉換。tf.Variable
就地運算 (assign、assign-add、assign-sub) 允許隱含轉換。tnp.array(1)
和tnp.array(1.0)
傳回 32 位元 WeakTensor。WeakTensor
將針對支援 WeakTensor 的一元和二元 API 建立和使用。
更一致且可預測的提升結果
使用格狀結構系統可讓新型別提升產生一致且可預測的型別提升結果。
舊型別提升
使用舊型別提升時,變更運算順序會產生不一致的結果。
# Setup
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
print(f'{type(e)}: {e}') # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>
新型別提升
無論順序為何,新型別提升都會產生一致的結果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
# (a + b) + c returns a f16 result.
tf.add(tf.add(a, b), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns a f16 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
降低位元加寬的風險
舊型別提升
舊型別提升通常會產生 64 位元結果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
新型別提升
新型別提升會傳回位元數最少且必要的結果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>
<tf.Tensor: shape=(), dtype=float16, numpy=54.2>
tf.Tensor 數學 Dunder 方法
所有 tf.Tensor
數學 Dunder 方法都將遵循新型別提升。
-tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>
tf.Variable 就地運算
隱含轉換將在 tf.Variable
就地運算中允許。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, tf.int32)
a.assign_add(tf.constant(5, tf.int16)) # <tf.Variable shape=() dtype=int32, numpy=15>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>
tf.constant 隱含轉換
在舊型別提升中,tf.constant
需要輸入張量具有與 dtype 引數相同的 dtype。但是,在新型別提升中,我們會將張量隱含轉換為指定的 dtype。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
TF-NumPy 陣列
tnp.array
預設為 i32*
和 f32*
,適用於使用新型別提升的 Python 輸入。
tnp.array(1) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>
輸入型別推斷
以下說明如何在新型別提升中推斷不同輸入的型別。
tf.Tensor
:由於tf.Tensor
具有 dtype 屬性,因此我們不會執行進一步的推斷。- NumPy 型別:這包括
np.array(1)
、np.int16(1)
和np.float
等型別。由於 NumPy 輸入也有 dtype 屬性,因此我們將 dtype 屬性視為結果推斷型別。請注意,NumPy 預設為i64
和f64
。 - Python 純量/巢狀型別:這包括
1
、[1, 2, 3]
和(1.0, 2.0)
等型別。- Python
int
推斷為i32*
。 - Python
float
推斷為f32*
。 - Python
complex
推斷為c128*
。
- Python
- 如果輸入不屬於上述任何類別,但具有 dtype 屬性,我們會將 dtype 屬性視為結果推斷型別。
延伸閱讀
新型別提升與 JAX-NumPy 的型別提升非常相似。如果您想瞭解有關新型別提升和設計選擇的更多詳細資訊,請查看以下資源。
參考資料
支援 WeakTensor 的 API
以下是支援 WeakTensor
的 API 清單。
對於一元運算,這表示如果傳入沒有使用者指定型別的輸入,則會傳回 WeakTensor
。
對於二元運算,它將遵循此處的提升表。它可能會或可能不會傳回 WeakTensor
,具體取決於兩個輸入的提升結果。
tf.bitwise.invert
tf.clip_by_value
tf.debugging.check_numerics
tf.expand_dims
tf.identity
tf.image.adjust_brightness
tf.image.adjust_gamma
tf.image.extract_patches
tf.image.random_brightness
tf.image.stateless_random_brightness
tf.linalg.diag
tf.linalg.diag_part
tf.linalg.matmul
tf.linalg.matrix_transpose
tf.linalg.tensor_diag_part
tf.linalg.trace
tf.math.abs
tf.math.acos
tf.math.acosh
tf.math.add
tf.math.angle
tf.math.asin
tf.math.asinh
tf.math.atan
tf.math.atanh
tf.math.ceil
tf.math.conj
tf.math.cos
tf.math.cosh
tf.math.digamma
tf.math.divide_no_nan
tf.math.divide
tf.math.erf
tf.math.erfc
tf.math.erfcinv
tf.math.erfinv
tf.math.exp
tf.math.expm1
tf.math.floor
tf.math.floordiv
tf.math.floormod
tf.math.imag
tf.math.lgamma
tf.math.log1p
tf.math.log_sigmoid
tf.math.log
tf.math.multiply_no_nan
tf.math.multiply
tf.math.ndtri
tf.math.negative
tf.math.pow
tf.math.real
tf.math.real
tf.math.reciprocal_no_nan
tf.math.reciprocal
tf.math.reduce_euclidean_norm
tf.math.reduce_logsumexp
tf.math.reduce_max
tf.math.reduce_mean
tf.math.reduce_min
tf.math.reduce_prod
tf.math.reduce_std
tf.math.reduce_sum
tf.math.reduce_variance
tf.math.rint
tf.math.round
tf.math.rsqrt
tf.math.scalar_mul
tf.math.sigmoid
tf.math.sign
tf.math.sin
tf.math.sinh
tf.math.softplus
tf.math.special.bessel_i0
tf.math.special.bessel_i0e
tf.math.special.bessel_i1
tf.math.special.bessel_i1e
tf.math.special.bessel_j0
tf.math.special.bessel_j1
tf.math.special.bessel_k0
tf.math.special.bessel_k0e
tf.math.special.bessel_k1
tf.math.special.bessel_k1e
tf.math.special.bessel_y0
tf.math.special.bessel_y1
tf.math.special.dawsn
tf.math.special.expint
tf.math.special.fresnel_cos
tf.math.special.fresnel_sin
tf.math.special.spence
tf.math.sqrt
tf.math.square
tf.math.subtract
tf.math.tan
tf.math.tanh
tf.nn.depth_to_space
tf.nn.elu
tf.nn.gelu
tf.nn.leaky_relu
tf.nn.log_softmax
tf.nn.relu6
tf.nn.relu
tf.nn.selu
tf.nn.softsign
tf.nn.space_to_depth
tf.nn.swish
tf.ones_like
tf.realdiv
tf.reshape
tf.squeeze
tf.stop_gradient
tf.transpose
tf.truncatediv
tf.truncatemod
tf.zeros_like
tf.experimental.numpy.abs
tf.experimental.numpy.absolute
tf.experimental.numpy.amax
tf.experimental.numpy.amin
tf.experimental.numpy.angle
tf.experimental.numpy.arange
tf.experimental.numpy.arccos
tf.experimental.numpy.arccosh
tf.experimental.numpy.arcsin
tf.experimental.numpy.arcsinh
tf.experimental.numpy.arctan
tf.experimental.numpy.arctanh
tf.experimental.numpy.around
tf.experimental.numpy.array
tf.experimental.numpy.asanyarray
tf.experimental.numpy.asarray
tf.experimental.numpy.ascontiguousarray
tf.experimental.numpy.average
tf.experimental.numpy.bitwise_not
tf.experimental.numpy.cbrt
tf.experimental.numpy.ceil
tf.experimental.numpy.conj
tf.experimental.numpy.conjugate
tf.experimental.numpy.copy
tf.experimental.numpy.cos
tf.experimental.numpy.cosh
tf.experimental.numpy.cumprod
tf.experimental.numpy.cumsum
tf.experimental.numpy.deg2rad
tf.experimental.numpy.diag
tf.experimental.numpy.diagflat
tf.experimental.numpy.diagonal
tf.experimental.numpy.diff
tf.experimental.numpy.empty_like
tf.experimental.numpy.exp2
tf.experimental.numpy.exp
tf.experimental.numpy.expand_dims
tf.experimental.numpy.expm1
tf.experimental.numpy.fabs
tf.experimental.numpy.fix
tf.experimental.numpy.flatten
tf.experimental.numpy.flip
tf.experimental.numpy.fliplr
tf.experimental.numpy.flipud
tf.experimental.numpy.floor
tf.experimental.numpy.full_like
tf.experimental.numpy.imag
tf.experimental.numpy.log10
tf.experimental.numpy.log1p
tf.experimental.numpy.log2
tf.experimental.numpy.log
tf.experimental.numpy.max
tf.experimental.numpy.mean
tf.experimental.numpy.min
tf.experimental.numpy.moveaxis
tf.experimental.numpy.nanmean
tf.experimental.numpy.negative
tf.experimental.numpy.ones_like
tf.experimental.numpy.positive
tf.experimental.numpy.prod
tf.experimental.numpy.rad2deg
tf.experimental.numpy.ravel
tf.experimental.numpy.real
tf.experimental.numpy.reciprocal
tf.experimental.numpy.repeat
tf.experimental.numpy.reshape
tf.experimental.numpy.rot90
tf.experimental.numpy.round
tf.experimental.numpy.signbit
tf.experimental.numpy.sin
tf.experimental.numpy.sinc
tf.experimental.numpy.sinh
tf.experimental.numpy.sort
tf.experimental.numpy.sqrt
tf.experimental.numpy.square
tf.experimental.numpy.squeeze
tf.experimental.numpy.std
tf.experimental.numpy.sum
tf.experimental.numpy.swapaxes
tf.experimental.numpy.tan
tf.experimental.numpy.tanh
tf.experimental.numpy.trace
tf.experimental.numpy.transpose
tf.experimental.numpy.triu
tf.experimental.numpy.vander
tf.experimental.numpy.var
tf.experimental.numpy.zeros_like