![]() |
![]() |
![]() |
![]() |
![]() |
在SNGP 教學課程中,您已學會如何在深度殘差網路的基礎上建構 SNGP 模型,以提升模型量化不確定性的能力。在本教學課程中,您會將 SNGP 應用於自然語言理解 (NLU) 工作,方法是在深度 BERT 編碼器的基礎上建構 SNGP,以提升深度 NLU 模型偵測範圍外查詢的能力。
具體來說,您將執行下列操作:
- 建構 BERT-SNGP,這是一種 SNGP 擴增的 BERT 模型。
- 載入 CLINC 範圍外 (OOS) 意圖偵測資料集。
- 訓練 BERT-SNGP 模型。
- 評估 BERT-SNGP 模型在不確定性校正和網域外偵測方面的效能。
除了 CLINC OOS 以外,SNGP 模型也已應用於大型資料集,例如 Jigsaw 毒性偵測,以及影像資料集,例如 CIFAR-100 和 ImageNet。如要查看 SNGP 和其他不確定性方法的基準測試結果,以及具備端對端訓練/評估指令碼的高品質實作,您可以查看 Uncertainty Baselines 基準測試。
設定
pip uninstall -y tensorflow tf-text
pip install "tensorflow-text==2.11.*"
pip install -U tf-models-official==2.11.0
import matplotlib.pyplot as plt
import sklearn.metrics
import sklearn.calibration
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import numpy as np
import tensorflow as tf
import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
本教學課程需要 GPU 才能有效率地執行。請檢查 GPU 是否可用。
tf.__version__
gpus = tf.config.list_physical_devices('GPU')
gpus
assert gpus, """
No GPU(s) found! This tutorial will take many hours to run without a GPU.
You may hit this error if the installed tensorflow package is not
compatible with the CUDA and CUDNN versions."""
首先,請按照使用 BERT 進行文字分類教學課程,實作標準 BERT 分類器。我們將使用 BERT-base 編碼器,以及內建的 ClassificationHead
作為分類器。
標準 BERT 模型
建構 SNGP 模型
如要實作 BERT-SNGP 模型,您只需要將 ClassificationHead
替換為內建的 GaussianProcessClassificationHead
即可。光譜正規化已預先封裝到這個分類標頭中。如同 SNGP 教學課程中所示,將共變異數重設回呼新增至模型,讓模型在新的 epoch 開始時自動重設共變異數估算器,避免重複計算相同的資料。
class ResetCovarianceCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
"""Resets covariance matrix at the beginning of the epoch."""
if epoch > 0:
self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):
def make_classification_head(self, num_classes, inner_dim, dropout_rate):
return layers.GaussianProcessClassificationHead(
num_classes=num_classes,
inner_dim=inner_dim,
dropout_rate=dropout_rate,
gp_cov_momentum=-1,
temperature=30.,
**self.classifier_kwargs)
def fit(self, *args, **kwargs):
"""Adds ResetCovarianceCallback to model callbacks."""
kwargs['callbacks'] = list(kwargs.get('callbacks', []))
kwargs['callbacks'].append(ResetCovarianceCallback())
return super().fit(*args, **kwargs)
載入 CLINC OOS 資料集
現在載入 CLINC OOS 意圖偵測資料集。這個資料集包含透過 150 個意圖類別收集的 15000 個使用者口語查詢,也包含 1000 個已知類別未涵蓋的網域外 (OOD) 句子。
(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)
製作訓練和測試資料。
train_examples = clinc_train['text']
train_labels = clinc_train['intent']
# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])
建立 OOD 評估資料集。為此,請結合網域內測試資料 clinc_test
和網域外資料 clinc_test_oos
。我們也會將標籤 0 指派給網域內範例,並將標籤 1 指派給網域外範例。
test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples
# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)
# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
{"text": oos_texts, "label": oos_labels})
訓練與評估
首先設定基本訓練組態。
TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256
optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
epochs=TRAIN_EPOCHS,
validation_batch_size=EVAL_BATCH_SIZE,
validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
評估 OOD 效能
評估模型偵測不熟悉的網域外查詢的效果。為了進行嚴謹的評估,請使用先前建立的 OOD 評估資料集 ood_eval_dataset
。
將 OOD 機率計算為 \(1 - p(x)\),其中 \(p(x)=softmax(logit(x))\) 是預測機率。
sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs
現在評估模型的不確定性分數 ood_probs
預測網域外標籤的效果。首先計算 OOD 機率與 OOD 偵測準確度的精確度-召回率曲線下面積 (AUPRC)。
precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
這符合 Uncertainty Baselines 下 CLINC OOS 基準測試中回報的 SNGP 效能。
接下來,檢查模型在不確定性校正方面的品質,也就是模型的預測機率是否與其預測準確度一致。校正良好的模型被認為值得信賴,例如,其預測機率 \(p(x)=0.8\) 表示模型在 80% 的時間內是正確的。
prob_true, prob_pred = sklearn.calibration.calibration_curve(
ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)
plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')
plt.show()
資源與延伸閱讀
- 如要詳細瞭解從頭開始實作 SNGP 的逐步解說,請參閱 SNGP 教學課程。
- 如要瞭解 SNGP 模型 (和許多其他不確定性方法) 在各種基準測試資料集 (例如 CIFAR、ImageNet、Jigsaw 毒性偵測等) 上的實作,請參閱 Uncertainty Baselines。
- 如要深入瞭解 SNGP 方法,請查看論文 Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness。