探索 TF-Hub CORD-19 Swivel 嵌入

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視 下載筆記本 查看 TF Hub 模型

TF-Hub 的 CORD-19 Swivel 文字嵌入模組 ( https://tfhub.dev/tensorflow/cord-19/swivel-128d/3 ) 旨在支援研究人員分析與 COVID-19 相關的自然語言文字。這些嵌入是在 CORD-19 資料集 中文章的標題、作者、摘要、內文和參考文獻標題上訓練而成。

在此 Colab 中,我們將:

  • 分析嵌入空間中語意相似的字詞
  • 使用 CORD-19 嵌入在 SciCite 資料集上訓練分類器

設定

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange
2023-10-09 22:29:39.063949: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-09 22:29:39.064004: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-09 22:29:39.064046: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

分析嵌入

首先,我們先分析嵌入,方法是計算並繪製不同詞彙之間的相關矩陣。如果嵌入成功學習到擷取不同字詞的含義,則語意相似的字詞的嵌入向量應該彼此接近。讓我們看看一些與 COVID-19 相關的詞彙。

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)
2023-10-09 22:29:44.568345: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

png

我們可以看到嵌入成功擷取了不同詞彙的含義。每個字詞都與其叢集中的其他字詞相似 (亦即「coronavirus」與「SARS」和「MERS」高度相關),但與其他叢集的詞彙不同 (亦即「SARS」與「Spain」之間的相似度接近 0)。

現在讓我們看看如何使用這些嵌入來解決特定任務。

SciCite:引用意圖分類

本節說明如何將嵌入用於下游任務,例如文字分類。我們將使用 TensorFlow Datasets 中的 SciCite 資料集,對學術論文中的引用意圖進行分類。給定學術論文中帶有引用的句子,判斷引用的主要意圖是作為背景資訊、方法的使用還是結果的比較。

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

讓我們看看訓練集中的幾個標記範例

訓練引用意圖分類器

我們將使用 Keras 在 SciCite 資料集 上訓練分類器。讓我們建構一個模型,該模型使用 CORD-19 嵌入,並在頂部新增一個分類層。

超參數

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer (KerasLayer)    (None, 128)               17301632  
                                                                 
 dense (Dense)               (None, 3)                 387       
                                                                 
=================================================================
Total params: 17302019 (132.00 MB)
Trainable params: 387 (1.51 KB)
Non-trainable params: 17301632 (132.00 MB)
_________________________________________________________________

訓練和評估模型

讓我們訓練和評估模型,以查看在 SciCite 任務上的效能

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 2s 4ms/step - loss: 0.8391 - accuracy: 0.6421 - val_loss: 0.7495 - val_accuracy: 0.7020
Epoch 2/35
257/257 [==============================] - 1s 3ms/step - loss: 0.6784 - accuracy: 0.7282 - val_loss: 0.6634 - val_accuracy: 0.7380
Epoch 3/35
257/257 [==============================] - 1s 3ms/step - loss: 0.6175 - accuracy: 0.7562 - val_loss: 0.6269 - val_accuracy: 0.7478
Epoch 4/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5858 - accuracy: 0.7706 - val_loss: 0.6035 - val_accuracy: 0.7533
Epoch 5/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5674 - accuracy: 0.7780 - val_loss: 0.5914 - val_accuracy: 0.7576
Epoch 6/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5553 - accuracy: 0.7817 - val_loss: 0.5822 - val_accuracy: 0.7653
Epoch 7/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5464 - accuracy: 0.7847 - val_loss: 0.5784 - val_accuracy: 0.7609
Epoch 8/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5399 - accuracy: 0.7872 - val_loss: 0.5723 - val_accuracy: 0.7707
Epoch 9/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5352 - accuracy: 0.7906 - val_loss: 0.5690 - val_accuracy: 0.7707
Epoch 10/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5303 - accuracy: 0.7924 - val_loss: 0.5630 - val_accuracy: 0.7806
Epoch 11/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5268 - accuracy: 0.7939 - val_loss: 0.5610 - val_accuracy: 0.7773
Epoch 12/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5236 - accuracy: 0.7929 - val_loss: 0.5601 - val_accuracy: 0.7762
Epoch 13/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5213 - accuracy: 0.7952 - val_loss: 0.5586 - val_accuracy: 0.7773
Epoch 14/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5188 - accuracy: 0.7959 - val_loss: 0.5560 - val_accuracy: 0.7751
Epoch 15/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5169 - accuracy: 0.7963 - val_loss: 0.5566 - val_accuracy: 0.7817
Epoch 16/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5150 - accuracy: 0.7950 - val_loss: 0.5521 - val_accuracy: 0.7795
Epoch 17/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5136 - accuracy: 0.7974 - val_loss: 0.5551 - val_accuracy: 0.7795
Epoch 18/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5122 - accuracy: 0.7966 - val_loss: 0.5490 - val_accuracy: 0.7795
Epoch 19/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5106 - accuracy: 0.7973 - val_loss: 0.5508 - val_accuracy: 0.7849
Epoch 20/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5097 - accuracy: 0.7974 - val_loss: 0.5503 - val_accuracy: 0.7806
Epoch 21/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5086 - accuracy: 0.7981 - val_loss: 0.5467 - val_accuracy: 0.7817
Epoch 22/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5072 - accuracy: 0.8003 - val_loss: 0.5518 - val_accuracy: 0.7838
Epoch 23/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5066 - accuracy: 0.7994 - val_loss: 0.5485 - val_accuracy: 0.7871
Epoch 24/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5060 - accuracy: 0.7991 - val_loss: 0.5477 - val_accuracy: 0.7849
Epoch 25/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5049 - accuracy: 0.8003 - val_loss: 0.5481 - val_accuracy: 0.7849
Epoch 26/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5046 - accuracy: 0.7985 - val_loss: 0.5465 - val_accuracy: 0.7871
Epoch 27/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5035 - accuracy: 0.7999 - val_loss: 0.5457 - val_accuracy: 0.7828
Epoch 28/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5030 - accuracy: 0.8011 - val_loss: 0.5474 - val_accuracy: 0.7838
Epoch 29/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5025 - accuracy: 0.8007 - val_loss: 0.5484 - val_accuracy: 0.7871
Epoch 30/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5016 - accuracy: 0.8023 - val_loss: 0.5440 - val_accuracy: 0.7904
Epoch 31/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5011 - accuracy: 0.8003 - val_loss: 0.5487 - val_accuracy: 0.7849
Epoch 32/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5011 - accuracy: 0.8012 - val_loss: 0.5451 - val_accuracy: 0.7882
Epoch 33/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5005 - accuracy: 0.8011 - val_loss: 0.5464 - val_accuracy: 0.7882
Epoch 34/35
257/257 [==============================] - 1s 3ms/step - loss: 0.5000 - accuracy: 0.8014 - val_loss: 0.5486 - val_accuracy: 0.7871
Epoch 35/35
257/257 [==============================] - 1s 3ms/step - loss: 0.4995 - accuracy: 0.8006 - val_loss: 0.5485 - val_accuracy: 0.7871
from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

png

評估模型

讓我們看看模型的效能。將傳回兩個值。損失 (代表我們錯誤的數字,值越低越好) 和準確度。

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5406 - accuracy: 0.7805 - 293ms/epoch - 73ms/step
loss: 0.541
accuracy: 0.781

我們可以看到損失快速下降,尤其是準確度迅速提高。讓我們繪製一些範例,以檢查預測如何與真實標籤相關

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [
    label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
1/1 [==============================] - 0s 122ms/step

我們可以看到,對於這個隨機樣本,模型在大多數情況下預測了正確的標籤,這表示它可以很好地嵌入科學句子。

下一步?

現在您已進一步瞭解 TF-Hub 的 CORD-19 Swivel 嵌入,我們鼓勵您參與 CORD-19 Kaggle 競賽,為從 COVID-19 相關學術文本中獲得科學見解做出貢獻。