![]() |
![]() |
![]() |
![]() |
簡介
TensorFlow Probability (TFP) 提供許多 JointDistribution
抽象化,讓機率推論更輕鬆,使用者可以近乎數學的形式輕鬆表達機率圖形模型;抽象化產生從模型取樣和評估模型樣本對數機率的方法。在本教學課程中,我們會檢閱「自動批次」變體,這些變體是在原始 JointDistribution
抽象化之後開發的。相對於原始的非自動批次抽象化,自動批次版本更簡單易用且更符合人體工學,讓許多模型能夠以更少的重複性程式碼來表達。在此 Colab 中,我們會詳細探討一個簡單的模型 (可能很乏味),闡明自動批次處理解決的問題,並 (希望) 在此過程中教導讀者更多關於 TFP 形狀概念的知識。
在自動批次處理推出之前,JointDistribution
有幾種不同的變體,對應於表達機率模型的不同語法樣式:JointDistributionSequential
、JointDistributionNamed
和 JointDistributionCoroutine
。自動批次處理以混合類別的形式存在,因此我們現在擁有所有這些的 AutoBatched
變體。在本教學課程中,我們會探討 JointDistributionSequential
和 JointDistributionSequentialAutoBatched
之間的差異;不過,我們在此處所做的一切都適用於其他變體,基本上沒有任何變更。
依附元件與先決條件
匯入和設定
先決條件:貝氏迴歸問題
我們將考量非常簡單的貝氏迴歸情境
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
在此模型中,m
和 b
是從標準常態分佈中抽取的,而觀察值 Y
是從常態分佈中抽取的,其平均值取決於隨機變數 m
和 b
,以及一些 (非隨機、已知) 共變數 X
。(為了簡化,在此範例中,我們假設所有隨機變數的尺度都是已知的。)
若要在本模型中執行推論,我們需要知道共變數 X
和觀察值 Y
,但為了本教學課程的目的,我們只需要 X
,因此我們定義一個簡單的虛擬 X
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
需求
在機率推論中,我們通常想要執行兩個基本運算
sample
:從模型中繪製樣本。log_prob
:計算模型樣本的對數機率。
TFP 的 JointDistribution
抽象化 (以及許多其他機率程式設計方法) 的主要貢獻是讓使用者可以一次編寫模型,並能存取 sample
和 log_prob
計算。
注意到我們的資料集中有 7 個點 (X.shape = (7,)
),我們現在可以說明優秀 JointDistribution
的需求
sample()
應產生Tensors
清單,其形狀為[(), (), (7,)
],分別對應於純量斜率、純量偏差和向量觀察值。log_prob(sample())
應產生純量:特定斜率、偏差和觀察值的對數機率。sample([5, 3])
應產生Tensors
清單,其形狀為[(5, 3), (5, 3), (5, 3, 7)]
,代表模型樣本的(5, 3)
-批次。log_prob(sample([5, 3]))
應產生形狀為 (5, 3) 的Tensor
。
我們現在將檢視一系列 JointDistribution
模型,看看如何達成上述需求,並希望在此過程中多瞭解一些關於 TFP 形狀的知識。
劇透警告:滿足上述需求且無需額外重複性程式碼的方法是自動批次處理。
首次嘗試;JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
這或多或少是將模型直接翻譯成程式碼。斜率 m
和偏差 b
很簡單。Y
是使用 lambda
函數定義的:一般模式是,JointDistributionSequential
(JDS) 中的 \(k\) 個引數的 lambda
函數使用模型中先前的 \(k\) 個分佈。請注意「反向」順序。
我們將呼叫 sample_distributions
,它會同時傳回樣本和用於產生樣本的基礎「子分佈」。(我們可以只呼叫 sample
來產生樣本;在本教學課程稍後部分,擁有分佈也會很方便。) 我們產生的樣本很好
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
但 log_prob
產生的結果具有不想要的形狀
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy= array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684, -4.4368567, -4.480562 ], dtype=float32)>
而且多次取樣無效
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
讓我們嘗試瞭解問題出在哪裡。
簡要回顧:批次形狀和事件形狀
在 TFP 中,一般 (非 JointDistribution
) 機率分佈具有事件形狀和批次形狀,瞭解兩者之間的差異對於有效使用 TFP 至關重要
- 事件形狀描述從分佈中單次繪製的形狀;繪製可能在維度之間具有依賴性。對於純量分佈,事件形狀為 []。對於 5 維 MultivariateNormal,事件形狀為 [5]。
- 批次形狀描述獨立、非同分佈的繪製,又稱為「分佈批次」。在單一 Python 物件中表示分佈批次是 TFP 實現大規模效率的關鍵方法之一。
就我們的目的而言,要記住的關鍵事實是,如果我們在分佈的單一範例上呼叫 log_prob
,則結果的形狀將始終與批次形狀 (也就是說,最右側維度) 相符。
如需更深入探討形狀,請參閱「瞭解 TensorFlow 分佈形狀」教學課程。
為什麼 log_prob(sample())
未產生純量?
讓我們運用我們對批次形狀和事件形狀的知識,來探索 log_prob(sample())
發生了什麼事。這是我們的樣本,再次
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
以下是我們的分佈
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
對數機率是透過加總子分佈在零件 (相符) 元素的對數機率來計算的
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>, <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014, -0.9897899, -1.0334952], dtype=float32)>]
sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
因此,一種解釋層次是,對數機率計算傳回 7-Tensor,因為 log_prob_parts
的第三個子元件是 7-Tensor。但為什麼呢?
嗯,我們看到 dists
的最後一個元素 (對應於數學公式中 Y
的分佈) 的 batch_shape
為 [7]
。換句話說,Y
的分佈是一批 7 個獨立常態分佈 (具有不同的平均值,在此情況下,尺度相同)。
我們現在瞭解問題出在哪裡:在 JDS 中,Y
的分佈具有 batch_shape=[7]
,JDS 的樣本代表 m
和 b
的純量,以及 7 個獨立常態分佈的「批次」。而 log_prob
計算 7 個獨立的對數機率,每個對數機率都代表在某些 X[i]
繪製 m
和 b
以及單一觀察值 Y[i]
的對數機率。
使用 Independent
修正 log_prob(sample())
回想一下,dists[2]
具有 event_shape=[]
和 batch_shape=[7]
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
透過使用 TFP 的 Independent
元分佈 (將批次維度轉換為事件維度),我們可以將其轉換為具有 event_shape=[7]
和 batch_shape=[]
的分佈 (我們將其重新命名為 y_dist_i
,因為它是 Y
的分佈,而 _i
代表我們的 Independent
包裝)
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
現在,7 維向量的 log_prob
是純量
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
在底層,Independent
會加總批次
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
實際上,我們可以使用它來建構新的 jds_i
(i
再次代表 Independent
),其中 log_prob
傳回純量
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
幾點注意事項
jds_i.log_prob(s)
與tf.reduce_sum(jds.log_prob(s))
不同。前者產生聯合分佈的「正確」對數機率。後者加總了 7-Tensor,每個元素都是m
、b
和Y
對數機率的單一元素的對數機率之總和,因此它會過度計算m
和b
。(log_prob(m) + log_prob(b) + log_prob(Y)
傳回結果,而不是擲回例外狀況,因為 TFP 遵循 TF 和 NumPy 的廣播規則;將純量新增至向量會產生向量大小的結果。)- 在這個特定情況下,我們可以解決問題,並使用
MultivariateNormalDiag
而非Independent(Normal(...))
達成相同的結果。MultivariateNormalDiag
是向量值分佈 (也就是說,它已經具有向量事件形狀)。實際上,MultivariateNormalDiag
可以 (但並非如此) 實作為Independent
和Normal
的組合。值得記住的是,給定向量V
,來自n1 = Normal(loc=V)
和n2 = MultivariateNormalDiag(loc=V)
的樣本是無法區分的;這些分佈之間的差異在於n1.log_prob(n1.sample())
是向量,而n2.log_prob(n2.sample())
是純量。
多個樣本?
繪製多個樣本仍然無效
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
讓我們思考一下原因。當我們呼叫 jds_i.sample([5, 3])
時,我們首先會繪製 m
和 b
的樣本,每個樣本的形狀為 (5, 3)
。接下來,我們將嘗試透過以下方式建構 Normal
分佈
tfd.Normal(loc=m*X + b, scale=1.)
但是如果 m
的形狀為 (5, 3)
,而 X
的形狀為 7
,我們就無法將它們相乘,而這確實是我們遇到的錯誤
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
為了解決這個問題,讓我們思考一下 Y
分佈必須具備哪些屬性。如果我們呼叫了 jds_i.sample([5, 3])
,那麼我們知道 m
和 b
的形狀都將為 (5, 3)
。在 Y
分佈上呼叫 sample
應該產生什麼形狀?顯而易見的答案是 (5, 3, 7)
:對於每個批次點,我們都希望樣本的大小與 X
相同。我們可以透過使用 TensorFlow 的廣播功能 (新增額外維度) 來達成此目的
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
在 m
和 b
中新增軸,我們可以定義新的 JDS,以支援多個樣本
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-1.1133379 , 0.16390413, -0.24177533], [-1.1312429 , -0.6224666 , -1.8182136 ], [-0.31343174, -0.32932565, 0.5164407 ], [-0.0119963 , -0.9079621 , 2.3655841 ], [-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-0.02876974, 1.0872147 , 1.0138507 ], [ 0.27367726, -1.331534 , -0.09084719], [ 1.3349475 , -0.68765205, 1.680652 ], [ 0.75436825, 1.3050154 , -0.9415123 ], [-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy= array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00, -4.8197951e+00, -5.2986512e+00, -6.6931367e+00], [ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01, 1.8868095e+00, 2.3877139e+00, 1.0195159e+00], [-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00, -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]], [[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00, -3.2269263e+00, -6.0213070e+00, -7.4806519e+00], [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00, -4.5065498e+00, -5.6719379e+00, -4.8012795e+00], [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00, -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]], [[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00, 3.3790007e+00, 1.1619035e+00, -8.9105040e-01], [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00, -2.7150445e+00, -2.4633870e+00, -2.1841538e+00], [ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00, 4.0690269e+00, 4.0605016e+00, 5.1753912e+00]], [[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01, 9.1219616e-01, -3.1220305e-01, -3.2643962e-01], [-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00, -1.8507402e+00, -3.6830821e+00, -5.4907336e+00], [-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00, 9.3528280e+00, 1.0526063e+01, 1.5262107e+01]], [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00, -3.2361765e+00, -3.3434000e+00, -2.6849220e+00], [ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00, 1.4191239e+00, 2.6655579e+00, 4.4663467e+00], [ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00, 2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.483114 , -10.139662 , -11.514159 ], [-11.656767 , -17.201958 , -12.132455 ], [-17.838818 , -9.474525 , -11.24898 ], [-13.95219 , -12.490049 , -17.123957 ], [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
作為額外檢查,我們將驗證單一批次點的對數機率是否與我們之前的對數機率相符
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
自動批次處理大獲全勝
太棒了!我們現在有了 JointDistribution 的版本,可以處理我們所有的需求:由於使用了 tfd.Independent
,log_prob
傳回純量,而且在我們透過新增額外軸修正廣播後,多個樣本現在也能運作。
如果我告訴您有更簡單、更好的方法呢?確實有,它稱為 JointDistributionSequentialAutoBatched
(JDSAB)
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.191533 , -10.43885 , -16.371655 ], [-13.292994 , -11.97949 , -16.788685 ], [-15.987699 , -13.435732 , -10.6029 ], [-10.184758 , -11.969714 , -14.275676 ], [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=float32)>
這是如何運作的?雖然您可以嘗試閱讀程式碼以深入瞭解,但我們將簡要概述,這對於大多數使用案例來說已足夠
- 回想一下,我們的第一個問題是
Y
的分佈具有batch_shape=[7]
和event_shape=[]
,而我們使用Independent
將批次維度轉換為事件維度。JDSAB 忽略元件分佈的批次形狀;相反地,它將批次形狀視為模型的整體屬性,假設為[]
(除非透過設定batch_ndims > 0
另行指定)。效果相當於使用 tfd.Independent 將元件分佈的所有批次維度轉換為事件維度,就像我們在上面手動執行的那樣。 - 我們的第二個問題是需要調整
m
和b
的形狀,以便在建立多個樣本時,它們可以與X
適當地廣播。使用 JDSAB,您可以編寫模型以產生單一樣本,而我們使用 TensorFlow 的 vectorized_map「提升」整個模型以產生多個樣本。(此功能類似於 JAX 的 vmap。)
為了更詳細地探討批次形狀問題,我們可以比較原始「不良」聯合分佈 jds
、批次修正分佈 jds_i
和 jds_ia
以及自動批次處理 jds_ab
的批次形狀
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
我們看到原始 jds
具有具有不同批次形狀的子分佈。jds_i
和 jds_ia
透過建立具有相同 (空) 批次形狀的子分佈來修正此問題。jds_ab
只有單一 (空) 批次形狀。
值得注意的是,JointDistributionSequentialAutoBatched
免費提供了一些額外的通用性。假設我們使共變數 X
(以及隱含的觀察值 Y
) 成為二維
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]])
我們的 JointDistributionSequentialAutoBatched
無需變更即可運作 (我們需要重新定義模型,因為 X
的形狀會由 jds_ab.log_prob
進行快取)
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.1813647 , -0.85994506, 0.27593774], [-0.73323774, 1.1153806 , 0.8841938 ], [ 0.5127983 , -0.29271227, 0.63733214], [ 0.2362284 , -0.919168 , 1.6648189 ], [ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.09636458, 2.0138032 , -0.5054413 ], [ 0.63941646, -1.0785882 , -0.6442188 ], [ 1.2310615 , -0.3293852 , 0.77637213], [ 1.2115169 , -0.98906034, -0.07816773], [-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy= array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01, 8.5992378e-01, -5.3123581e-01, 3.1584005e+00, 2.9044402e+00], [-2.5645006e-01, 3.1554163e-01, 3.1186538e+00, 1.4272424e+00, 1.2843871e+00, 1.2266440e+00, 1.2798605e+00]], [[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03, -1.4910895e+00, -2.1568544e+00, -2.0513713e+00, -3.1663666e+00], [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00, -5.6543546e+00, -7.2378774e+00, -8.1577444e+00, -9.3582869e+00]], [[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00, 1.6622503e+00, -1.9197327e-01, 1.8647723e+00, 6.4322817e-01], [ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00, 2.1952972e+00, 1.7517658e+00, 2.9666045e+00, 2.5468128e+00]]], [[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01, -9.9894643e-01, -3.4606690e+00, -3.4810467e+00, -4.4315586e+00], [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00, -6.8091092e+00, -7.7134805e+00, -8.6319380e+00, -8.6904278e+00]], [[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01, 3.1422038e+00, 4.1483402e+00, 3.5642972e+00, 4.8709240e+00], [ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00, 7.8112822e+00, 1.2022618e+01, 1.2411858e+01, 1.4323385e+01]], [[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00, 8.2378983e-01, 3.0765080e+00, 3.0170646e+00, 5.1899948e+00], [ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00, 9.0899811e+00, 1.0040427e+01, 9.1404457e+00, 1.0411951e+01]]], [[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00, 2.9777462e+00, 2.8620450e+00, 3.4745665e+00, 3.8295493e+00], [ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00, 6.3180594e+00, 6.0838981e+00, 8.2257290e+00, 9.6548376e+00]], [[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01, -2.3301599e+00, -5.0374687e-01, -2.8338656e+00, -3.4453444e+00], [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00, -4.0196013e+00, -5.8831010e+00, -4.2965469e+00, -4.1388311e+00]], [[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00, 1.8392437e+00, 2.8367062e+00, 4.8600502e+00, 4.2273531e+00], [ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00, 5.6517797e+00, 8.9979610e+00, 7.5938139e+00, 9.7918644e+00]]], [[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01, 3.0762129e+00, 1.5128503e+00, 3.5204377e+00, 2.4760864e+00], [ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00, 4.5797420e+00, 4.5271711e+00, 2.8774328e+00, 4.7288942e+00]], [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00, -3.8594103e+00, -4.9681158e+00, -6.4256043e+00, -5.5345035e+00], [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00, -1.0417805e+01, -1.1727266e+01, -1.1196255e+01, -1.1333830e+01]], [[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00, 4.8489718e+00, 6.6498446e+00, 9.0224218e+00, 1.1153137e+01], [ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01, 1.7957514e+01, 1.8323889e+01, 2.0160881e+01, 2.1269085e+01]]], [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01, 2.3558271e-01, -1.0381399e+00, 1.9387857e+00, -3.3694571e-01], [ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01, -1.7770648e-01, 2.1919653e+00, 1.3015286e+00, 1.3877077e+00]], [[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01, 4.6554832e+00, 5.7781887e+00, 4.9115267e+00, 4.8446012e+00], [ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00, 8.4291229e+00, 7.1309576e+00, 1.0395646e+01, 8.5736713e+00]], [[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00, 8.9993315e+00, 1.0794343e+01, 1.4039831e+01, 1.5731170e+01], [ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01, 2.5803375e+01, 2.8632090e+01, 3.0234968e+01, 3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-28.90071 , -23.052422, -19.851362], [-19.775568, -25.894997, -20.302256], [-21.10754 , -23.667885, -20.973007], [-19.249458, -20.87892 , -20.573763], [-22.351208, -25.457762, -24.648403]], dtype=float32)>
另一方面,我們精心製作的 JointDistributionSequential
不再運作
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
為了修正此問題,我們必須新增第二個 tf.newaxis
到 m
和 b
,以符合形狀,並在呼叫 Independent
時將 reinterpreted_batch_ndims
增加到 2。在這種情況下,讓自動批次處理機制處理形狀問題更簡短、更輕鬆且更符合人體工學。
再次強調,我們注意到雖然本筆記本探討了 JointDistributionSequentialAutoBatched
,但 JointDistribution
的其他變體也具有等效的 AutoBatched
。(對於 JointDistributionCoroutine
的使用者,JointDistributionCoroutineAutoBatched
的額外優點是您不再需要指定 Root
節點;如果您從未使用過 JointDistributionCoroutine
,您可以安全地忽略此陳述。)
結論
在本筆記本中,我們介紹了 JointDistributionSequentialAutoBatched
,並詳細探討了一個簡單的範例。希望您學到了一些關於 TFP 形狀和自動批次處理的知識!