張量切片簡介

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

在處理物件偵測和 NLP 等 ML 應用程式時,有時需要處理張量的子區段 (切片)。例如,如果您的模型架構包含路由,其中一個層可能會控制哪個訓練範例路由到下一層。在這種情況下,您可以使用張量切片運算子來分割張量,然後以正確的順序將它們放回一起。

在 NLP 應用程式中,您可以使用張量切片在訓練期間執行字詞遮罩。例如,您可以從句子清單中產生訓練資料,方法是在每個句子中選擇要遮罩的字詞索引,取出字詞作為標籤,然後將選取的字詞替換為遮罩符號。

在本指南中,您將學習如何使用 TensorFlow API 來

  • 從張量中擷取切片
  • 在張量的特定索引處插入資料

本指南假設您已熟悉張量索引。在開始使用本指南之前,請先閱讀 張量TensorFlow NumPy 指南的索引章節。

設定

import tensorflow as tf
import numpy as np

擷取張量切片

使用 tf.slice 執行類似 NumPy 的張量切片。

t1 = tf.constant([0, 1, 2, 3, 4, 5, 6, 7])

print(tf.slice(t1,
               begin=[1],
               size=[3]))

或者,您可以使用更 Pythonic 的語法。請注意,張量切片在開始-停止範圍內均勻間隔。

print(t1[1:4])

print(t1[-3:])

對於二維張量,您可以使用類似

t2 = tf.constant([[0, 1, 2, 3, 4],
                  [5, 6, 7, 8, 9],
                  [10, 11, 12, 13, 14],
                  [15, 16, 17, 18, 19]])

print(t2[:-1, 1:3])

您也可以在高維度張量上使用 tf.slice

t3 = tf.constant([[[1, 3, 5, 7],
                   [9, 11, 13, 15]],
                  [[17, 19, 21, 23],
                   [25, 27, 29, 31]]
                  ])

print(tf.slice(t3,
               begin=[1, 1, 0],
               size=[1, 1, 2]))

您也可以使用 tf.strided_slice,透過「跨步」張量維度來擷取張量切片。

使用 tf.gather 從張量的單一軸擷取特定索引。

print(tf.gather(t1,
                indices=[0, 3, 6]))

# This is similar to doing

t1[::3]

tf.gather 不需要索引均勻間隔。

alphabet = tf.constant(list('abcdefghijklmnopqrstuvwxyz'))

print(tf.gather(alphabet,
                indices=[2, 0, 19, 18]))

若要從張量的多個軸擷取切片,請使用 tf.gather_nd。當您想要收集矩陣的元素而不僅僅是其列或欄時,這非常有用。

t4 = tf.constant([[0, 5],
                  [1, 6],
                  [2, 7],
                  [3, 8],
                  [4, 9]])

print(tf.gather_nd(t4,
                   indices=[[2], [3], [0]]))

t5 = np.reshape(np.arange(18), [2, 3, 3])

print(tf.gather_nd(t5,
                   indices=[[0, 0, 0], [1, 2, 1]]))
# Return a list of two matrices

print(tf.gather_nd(t5,
                   indices=[[[0, 0], [0, 2]], [[1, 0], [1, 2]]]))
# Return one matrix

print(tf.gather_nd(t5,
                   indices=[[0, 0], [0, 2], [1, 0], [1, 2]]))

將資料插入張量

使用 tf.scatter_nd 在張量的特定切片/索引處插入資料。請注意,您在其中插入值的張量已初始化為零。

t6 = tf.constant([10])
indices = tf.constant([[1], [3], [5], [7], [9]])
data = tf.constant([2, 4, 6, 8, 10])

print(tf.scatter_nd(indices=indices,
                    updates=data,
                    shape=t6))

類似 tf.scatter_nd 等需要零初始化的張量的方法,與稀疏張量初始化工具類似。您可以使用 tf.gather_ndtf.scatter_nd 來模擬稀疏張量運算子的行為。

考慮一個範例,您可以在其中結合使用這兩種方法來建構稀疏張量。

# Gather values from one tensor by specifying indices

new_indices = tf.constant([[0, 2], [2, 1], [3, 3]])
t7 = tf.gather_nd(t2, indices=new_indices)

# Add these values into a new tensor

t8 = tf.scatter_nd(indices=new_indices, updates=t7, shape=tf.constant([4, 5]))

print(t8)

這類似於

t9 = tf.SparseTensor(indices=[[0, 2], [2, 1], [3, 3]],
                     values=[2, 11, 18],
                     dense_shape=[4, 5])

print(t9)
# Convert the sparse tensor into a dense tensor

t10 = tf.sparse.to_dense(t9)

print(t10)

若要將資料插入具有預先存在值的張量中,請使用 tf.tensor_scatter_nd_add

t11 = tf.constant([[2, 7, 0],
                   [9, 0, 1],
                   [0, 3, 8]])

# Convert the tensor into a magic square by inserting numbers at appropriate indices

t12 = tf.tensor_scatter_nd_add(t11,
                               indices=[[0, 2], [1, 1], [2, 0]],
                               updates=[6, 5, 4])

print(t12)

同樣地,使用 tf.tensor_scatter_nd_sub 從具有預先存在值的張量中減去值。

# Convert the tensor into an identity matrix

t13 = tf.tensor_scatter_nd_sub(t11,
                               indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 1], [2, 2]],
                               updates=[1, 7, 9, -1, 1, 3, 7])

print(t13)

使用 tf.tensor_scatter_nd_min 將元素方式最小值從一個張量複製到另一個張量。

t14 = tf.constant([[-2, -7, 0],
                   [-9, 0, 1],
                   [0, -3, -8]])

t15 = tf.tensor_scatter_nd_min(t14,
                               indices=[[0, 2], [1, 1], [2, 0]],
                               updates=[-6, -5, -4])

print(t15)

同樣地,使用 tf.tensor_scatter_nd_max 將元素方式最大值從一個張量複製到另一個張量。

t16 = tf.tensor_scatter_nd_max(t14,
                               indices=[[0, 2], [1, 1], [2, 0]],
                               updates=[6, 5, 4])

print(t16)

延伸閱讀與資源

在本指南中,您已學習如何使用 TensorFlow 提供的張量切片運算子,以便更精細地控制張量中的元素。