通用句子編碼器

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視 下載筆記本 查看 TF Hub 模型

這個筆記本說明如何存取通用句子編碼器,並將其用於句子相似度和句子分類任務。

通用句子編碼器讓取得句子層級嵌入變得像查詢個別字詞的嵌入一樣容易。句子嵌入可以輕鬆用於計算句子層級語義相似度,並在下游分類任務中使用較少的監督式訓練資料來提升效能。

設定

這個章節設定存取 TF Hub 上通用句子編碼器的環境,並提供將編碼器應用於字詞、句子和段落的範例。

%%capture
!pip3 install seaborn

關於安裝 TensorFlow 的更詳細資訊,請參閱 https://tensorflow.dev.org.tw/install/

載入通用句子編碼器的 TF Hub 模組

2024-03-10 12:03:32.159319: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
module https://tfhub.dev/google/universal-sentence-encoder/4 loaded

為每則訊息計算表示法,顯示支援的各種長度。

Message: Elephant
Embedding size: 512
Embedding: [0.008344484493136406, 0.0004808559315279126, 0.06595249474048615, ...]

Message: I am a sentence for which I would like to get its embedding.
Embedding size: 512
Embedding: [0.050808604806661606, -0.016524329781532288, 0.01573779620230198, ...]

Message: Universal Sentence Encoder embeddings also support short paragraphs. There is no hard limit on how long the paragraph is. Roughly, the longer the more 'diluted' the embedding will be.
Embedding size: 512
Embedding: [-0.028332693502306938, -0.0558621808886528, -0.012941482476890087, ...]

語義文本相似度任務範例

通用句子編碼器產生的嵌入已大致正規化。兩個句子的語義相似度可以輕鬆計算為編碼的內積。

def plot_similarity(labels, features, rotation):
  corr = np.inner(features, features)
  sns.set(font_scale=1.2)
  g = sns.heatmap(
      corr,
      xticklabels=labels,
      yticklabels=labels,
      vmin=0,
      vmax=1,
      cmap="YlOrRd")
  g.set_xticklabels(labels, rotation=rotation)
  g.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
  message_embeddings_ = embed(messages_)
  plot_similarity(messages_, message_embeddings_, 90)

相似度視覺化

我們在這裡以熱圖顯示相似度。最終圖表是一個 9x9 矩陣,其中每個條目 [i, j] 的顏色都是根據句子 ij 的編碼內積而定。

messages = [
    # Smartphones
    "I like my phone",
    "My phone is not good.",
    "Your cellphone looks great.",

    # Weather
    "Will it snow tomorrow?",
    "Recently a lot of hurricanes have hit the US",
    "Global warming is real",

    # Food and health
    "An apple a day, keeps the doctors away",
    "Eating strawberries is healthy",
    "Is paleo better than keto?",

    # Asking about age
    "How old are you?",
    "what is your age?",
]

run_and_plot(messages)

png

評估:STS (語義文本相似度) 基準

STS 基準提供內在評估,評估使用句子嵌入計算出的相似度分數與人類判斷的一致程度。此基準要求系統傳回各種句子配對的相似度分數。Pearson 相關性然後使用 Pearson 相關性來評估機器相似度分數相對於人類判斷的品質。

下載資料

import pandas
import scipy
import math
import csv

sts_dataset = tf.keras.utils.get_file(
    fname="Stsbenchmark.tar.gz",
    origin="http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz",
    extract=True)
sts_dev = pandas.read_table(
    os.path.join(os.path.dirname(sts_dataset), "stsbenchmark", "sts-dev.csv"),
    skip_blank_lines=True,
    usecols=[4, 5, 6],
    names=["sim", "sent_1", "sent_2"])
sts_test = pandas.read_table(
    os.path.join(
        os.path.dirname(sts_dataset), "stsbenchmark", "sts-test.csv"),
    quoting=csv.QUOTE_NONE,
    skip_blank_lines=True,
    usecols=[4, 5, 6],
    names=["sim", "sent_1", "sent_2"])
# cleanup some NaN values in sts_dev
sts_dev = sts_dev[[isinstance(s, str) for s in sts_dev['sent_2']]]
Downloading data from http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz
409630/409630 ━━━━━━━━━━━━━━━━━━━━ 1s 2us/step

評估句子嵌入

sts_data = sts_dev

def run_sts_benchmark(batch):
  sts_encode1 = tf.nn.l2_normalize(embed(tf.constant(batch['sent_1'].tolist())), axis=1)
  sts_encode2 = tf.nn.l2_normalize(embed(tf.constant(batch['sent_2'].tolist())), axis=1)
  cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
  clip_cosine_similarities = tf.clip_by_value(cosine_similarities, -1.0, 1.0)
  scores = 1.0 - tf.acos(clip_cosine_similarities) / math.pi
  """Returns the similarity scores"""
  return scores

dev_scores = sts_data['sim'].tolist()
scores = []
for batch in np.array_split(sts_data, 10):
  scores.extend(run_sts_benchmark(batch))

pearson_correlation = scipy.stats.pearsonr(scores, dev_scores)
print('Pearson correlation coefficient = {0}\np-value = {1}'.format(
    pearson_correlation[0], pearson_correlation[1]))
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/numpy/core/fromnumeric.py:59: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
  return bound(*args, **kwds)
Pearson correlation coefficient = 0.8036396940028219
p-value = 0.0