JAX 上的 TensorFlow Probability

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

TensorFlow Probability (TFP) 是一個用於機率推理和統計分析的程式庫,現在也適用於 JAX!對於不熟悉的使用者,JAX 是一個以可組合函式轉換為基礎的加速數值運算程式庫。

JAX 上的 TFP 支援常規 TFP 的許多最實用功能,同時保留許多 TFP 使用者現在已習慣的抽象概念和 API。

設定

JAX 上的 TFP 依賴 TensorFlow;讓我們從這個 Colab 完全解除安裝 TensorFlow。

pip uninstall tensorflow -y -q

我們可以透過 TFP 的最新每夜建置版本在 JAX 上安裝 TFP。

pip install -Uq tfp-nightly[jax] > /dev/null

讓我們匯入一些實用的 Python 程式庫。

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

我們也匯入一些基本的 JAX 功能。

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

在 JAX 上匯入 TFP

若要在 JAX 上使用 TFP,只需匯入 jax「基底」,然後像平常一樣使用 tfp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

示範:貝氏邏輯迴歸

為了示範我們可以使用 JAX 後端做什麼,我們將實作應用於經典鳶尾花資料集的貝氏邏輯迴歸。

首先,讓我們匯入鳶尾花資料集並擷取一些中繼資料。

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

我們可以使用 tfd.JointDistributionCoroutine 定義模型。我們將在權重和偏差項上放置標準常態先驗,然後撰寫一個 target_log_prob 函式,將取樣標籤釘選到資料。

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

我們從 dist 取樣以產生 MCMC 的初始狀態。然後,我們可以定義一個函式,該函式接受隨機金鑰和初始狀態,並從 No-U-Turn-Sampler (NUTS) 產生 500 個樣本。請注意,我們可以使用 JAX 轉換 (例如 jit) 來編譯使用 XLA 的 NUTS 取樣器。

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

png

讓我們使用我們的樣本執行貝氏模型平均 (BMA),方法是平均每組權重的預測機率。

首先,讓我們撰寫一個函式,針對給定的一組參數,產生每個類別的機率。我們可以使用 dist.sample_distributions 來取得模型中的最終分佈。

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

我們可以對樣本集 vmap(classifier_probs),以取得每個樣本的預測類別機率。然後,我們計算每個樣本的平均準確度,以及貝氏模型平均的準確度。

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

看起來 BMA 將我們的錯誤率降低了將近三分之一!

基本原理

JAX 上的 TFP 具有與 TF 相同的 API,其中 JAX 上的 TFP 接受 JAX 對應物,而不是接受 tf.Tensor 等 TF 物件。例如,先前使用 tf.Tensor 作為輸入的任何位置,API 現在都預期為 JAX DeviceArray。TFP 方法會傳回 DeviceArray,而不是傳回 tf.Tensor。JAX 上的 TFP 也適用於 JAX 物件的巢狀結構,例如 DeviceArray 的清單或字典。

分佈

JAX 支援大多數 TFP 的分佈,其語意與 TF 對應物非常相似。它們也註冊為 JAX Pytrees,因此它們可以是 JAX 轉換函式的輸入和輸出。

基本分佈

分佈的 log_prob 方法運作方式相同。

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

從分佈取樣需要明確傳入 PRNGKey (或整數清單) 作為 seed 關鍵字引數。未能明確傳入種子將會擲回錯誤。

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

分佈的形狀語意在 JAX 中保持不變,其中分佈將各自具有 event_shapebatch_shape,而繪製許多樣本將會新增額外的 sample_shape 維度。

例如,具有向量參數的 tfd.MultivariateNormalDiag 將具有向量事件形狀和空批次形狀。

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

另一方面,以向量參數化的 tfd.Normal 將具有純量事件形狀和向量批次形狀。

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

取得樣本 log_prob 的語意在 JAX 中也以相同方式運作。

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

由於 JAX DeviceArray 與 NumPy 和 Matplotlib 等程式庫相容,因此我們可以將樣本直接饋送到繪圖函式中。

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribution 方法與 JAX 轉換相容。

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

png

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

png

由於 TFP 分佈已註冊為 JAX pytree 節點,因此我們可以撰寫以分佈作為輸入或輸出的函式,並使用 jit 轉換它們,但它們尚不支援作為 vmap 函式的引數。

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

轉換分佈

轉換分佈,即樣本透過 Bijector 傳遞的分佈,也可以立即使用 (bijector 也適用!請參閱下方)。

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

聯合分佈

TFP 提供 JointDistribution,以啟用將元件分佈組合成多個隨機變數的單一分佈。目前,TFP 提供三種核心變體 (JointDistributionSequentialJointDistributionNamedJointDistributionCoroutine),所有這些變體都在 JAX 中受到支援。AutoBatched 變體也全部受到支援。

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

png

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

其他分佈

高斯過程在 JAX 模式下也適用!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

png

隱藏式馬可夫模型也受到支援。

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

由於嚴格依賴 TensorFlow 或 XLA 不相容性,因此目前不支援少數分佈 (例如 PixelCNN)。

Bijector

JAX 今天支援大多數 TFP 的 bijector!

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

Bijector 與 JAX 轉換相容,例如 jitgradvmap

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

png

某些 bijector (例如 RealNVPFFJORD) 尚不受支援。

MCMC

我們也已將 tfp.mcmc 移植到 JAX,因此我們可以在 JAX 中執行 Hamiltonian Monte Carlo (HMC) 和 No-U-Turn-Sampler (NUTS) 等演算法。

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

與 TF 上的 TFP 不同,我們需要使用 seed 關鍵字引數將 PRNGKey 傳遞到 sample_chain 中。

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

png

png

若要執行多個鏈,我們可以將一批狀態傳遞到 sample_chain 中,或使用 vmap (雖然我們尚未探索這兩種方法之間的效能差異)。

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

png

png

最佳化工具

JAX 上的 TFP 支援一些重要的最佳化工具,例如 BFGS 和 L-BFGS。讓我們設定一個簡單的縮放二次損失函數。

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

BFGS 可以找到此損失的最小值。

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

L-BFGS 也可以。

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

若要 vmap L-BFGS,讓我們設定一個函式,該函式針對單一起點最佳化損失。

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

注意事項

TF 和 JAX 之間存在一些基本差異,某些 TFP 行為在兩個基底之間會有所不同,而且並非所有功能都受到支援。例如:

  • JAX 上的 TFP 不支援任何類似 tf.Variable 的項目,因為 JAX 中不存在類似項目。這也表示 tfp.util.TransformedVariable 等公用程式也不受支援。
  • 由於 tfp.layers 依賴 Keras 和 tf.Variable,因此後端尚不支援 tfp.layers
  • 由於 tfp.math.minimize 依賴 tf.Variable,因此 TFP on JAX 無法運作 tfp.math.minimize
  • 使用 JAX 上的 TFP 時,張量形狀一律為具體的整數值,而且絕不會像 TF 上的 TFP 那樣未知/動態。
  • 虛擬隨機性在 TF 和 JAX 中的處理方式不同 (請參閱附錄)。
  • tfp.experimental 中的程式庫不保證存在於 JAX 基底中。
  • TF 和 JAX 之間的 Dtype 升級規則不同。JAX 上的 TFP 會嘗試在內部遵循 TF 的 dtype 語意,以保持一致性。
  • Bijector 尚未註冊為 JAX pytree。

若要查看 JAX 上的 TFP 支援的完整清單,請參閱 API 文件

結論

我們已將許多 TFP 的功能移植到 JAX,並很高興看到大家將建構什麼。某些功能尚不受支援;如果您發現我們遺漏了對您重要的項目 (或您發現錯誤!),請與我們聯絡 -- 您可以傳送電子郵件至 tfprobability@tensorflow.org,或在 我們的 Github 存放區上提出問題。

附錄:JAX 中的虛擬隨機性

JAX 的虛擬隨機數字產生 (PRNG) 模型是無狀態的。與有狀態模型不同,沒有在每次隨機繪製後演變的可變全域狀態。在 JAX 的模型中,我們從 PRNG 金鑰開始,其作用類似於一對 32 位元整數。我們可以使用 jax.random.PRNGKey 建構這些金鑰。

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

JAX 中的隨機函式會使用金鑰來確定性產生隨機變數,表示它們不應再次使用。例如,我們可以使用 key 取樣常態分佈值,但我們不應在其他地方再次使用 key。此外,將相同的值傳遞到 random.normal 將會產生相同的值。

print(random.normal(key))
-0.20584226

那麼,我們如何從單一金鑰中繪製多個樣本?答案是金鑰分割。基本概念是我們可以將 PRNGKey 分割成多個,而每個新金鑰都可以視為獨立的隨機性來源。

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

金鑰分割是確定性的,但具有混沌性,因此現在每個新金鑰都可以用來繪製不同的隨機樣本。

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

如需 JAX 確定性金鑰分割模型的詳細資訊,請參閱本指南