基於 BERT-SNGP 的不確定性感知深度語言學習

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視 下載筆記本 查看 TF Hub 模型

SNGP 教學課程中,您已學會如何在深度殘差網路的基礎上建構 SNGP 模型,以提升模型量化不確定性的能力。在本教學課程中,您會將 SNGP 應用於自然語言理解 (NLU) 工作,方法是在深度 BERT 編碼器的基礎上建構 SNGP,以提升深度 NLU 模型偵測範圍外查詢的能力。

具體來說,您將執行下列操作:

  • 建構 BERT-SNGP,這是一種 SNGP 擴增的 BERT 模型。
  • 載入 CLINC 範圍外 (OOS) 意圖偵測資料集。
  • 訓練 BERT-SNGP 模型。
  • 評估 BERT-SNGP 模型在不確定性校正和網域外偵測方面的效能。

除了 CLINC OOS 以外,SNGP 模型也已應用於大型資料集,例如 Jigsaw 毒性偵測,以及影像資料集,例如 CIFAR-100ImageNet。如要查看 SNGP 和其他不確定性方法的基準測試結果,以及具備端對端訓練/評估指令碼的高品質實作,您可以查看 Uncertainty Baselines 基準測試。

設定

pip uninstall -y tensorflow tf-text
pip install "tensorflow-text==2.11.*"
pip install -U tf-models-official==2.11.0
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization

本教學課程需要 GPU 才能有效率地執行。請檢查 GPU 是否可用。

tf.__version__
gpus = tf.config.list_physical_devices('GPU')
gpus
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

首先,請按照使用 BERT 進行文字分類教學課程,實作標準 BERT 分類器。我們將使用 BERT-base 編碼器,以及內建的 ClassificationHead 作為分類器。

標準 BERT 模型

建構 SNGP 模型

如要實作 BERT-SNGP 模型,您只需要將 ClassificationHead 替換為內建的 GaussianProcessClassificationHead 即可。光譜正規化已預先封裝到這個分類標頭中。如同 SNGP 教學課程中所示,將共變異數重設回呼新增至模型,讓模型在新的 epoch 開始時自動重設共變異數估算器,避免重複計算相同的資料。

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the beginning of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

載入 CLINC OOS 資料集

現在載入 CLINC OOS 意圖偵測資料集。這個資料集包含透過 150 個意圖類別收集的 15000 個使用者口語查詢,也包含 1000 個已知類別未涵蓋的網域外 (OOD) 句子。

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

製作訓練和測試資料。

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

建立 OOD 評估資料集。為此,請結合網域內測試資料 clinc_test 和網域外資料 clinc_test_oos。我們也會將標籤 0 指派給網域內範例,並將標籤 1 指派給網域外範例。

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

訓練與評估

首先設定基本訓練組態。

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)

評估 OOD 效能

評估模型偵測不熟悉的網域外查詢的效果。為了進行嚴謹的評估,請使用先前建立的 OOD 評估資料集 ood_eval_dataset

將 OOD 機率計算為 \(1 - p(x)\),其中 \(p(x)=softmax(logit(x))\) 是預測機率。

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

現在評估模型的不確定性分數 ood_probs 預測網域外標籤的效果。首先計算 OOD 機率與 OOD 偵測準確度的精確度-召回率曲線下面積 (AUPRC)。

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')

這符合 Uncertainty Baselines 下 CLINC OOS 基準測試中回報的 SNGP 效能。

接下來,檢查模型在不確定性校正方面的品質,也就是模型的預測機率是否與其預測準確度一致。校正良好的模型被認為值得信賴,例如,其預測機率 \(p(x)=0.8\) 表示模型在 80% 的時間內是正確的。

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

資源與延伸閱讀