搭配 TensorFlow Lite Model Maker 的文字搜尋器

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本 查看 TF Hub 模型

在這個 Colab 筆記本中,您可以學習如何使用 TensorFlow Lite Model Maker 程式庫建立 TFLite Searcher 模型。您可以使用文字搜尋器模型,為您的應用程式建構語意搜尋或智慧回覆功能。這類模型可讓您擷取文字查詢,並在文字資料集 (例如網頁資料庫) 中搜尋最相關的條目。模型會傳回資料集中距離評分最小的條目清單,包括您指定的中繼資料,例如網址、網頁標題或其他文字條目 ID。建構完成後,您可以使用 Task Library Searcher API 將其部署到裝置 (例如 Android),只需幾行程式碼即可執行推論。

本教學課程利用 CNN/DailyMail 資料集作為範例,建立 TFLite Searcher 模型。您可以嘗試使用您自己的資料集,並採用相容的輸入逗號分隔值 (CSV) 格式。

使用可擴充最近鄰演算法的文字搜尋

本教學課程使用公開提供的 CNN/DailyMail 非匿名摘要資料集,該資料集是從 GitHub 存放區產生。這個資料集包含超過 30 萬篇新聞文章,使其成為建構搜尋器模型的好資料集,並在模型推論期間針對文字查詢傳回各種相關新聞。

本範例中的文字搜尋器模型使用 ScaNN (可擴充最近鄰演算法) 索引檔案,該檔案可以從預先定義的資料庫中搜尋類似項目。ScaNN 在大規模有效率的向量相似度搜尋方面,實現了最先進的效能。

本教學課程中使用這個資料集中的重點新聞和網址來建立模型

  1. 重點新聞是用於產生嵌入功能向量的文字,然後用於搜尋。
  2. 網址是在搜尋相關重點新聞後向使用者顯示的傳回結果。

本教學課程將這些資料儲存到 CSV 檔案中,然後使用 CSV 檔案來建構模型。以下是資料集中的幾個範例。

重點新聞 網址
夏威夷航空再次在準時起飛率方面名列第一。《航空公司品質排名報告》調查了美國 14 家最大的航空公司。ExpressJet
和美國航空的準時起飛率最差。維珍美國航空的行李處理表現最佳;西南航空的客訴率最低。
http://www.cnn.com/2013/04/08/travel/airline-quality-report
歐洲足球管理機構公布競標主辦 2020 年決賽的國家/地區名單。第 60 屆決賽將由 13 個國家/地區主辦
。32 個國家/地區正在考慮競標主辦 2020 年賽事。歐洲足總將於 9 月 25 日公布主辦城市。
http://edition.cnn.com:80/2013/09/20/sport/football/football-euro-2020-bid-countries/index.html?
從前的章魚獵人 Dylan Mayer 現在也簽署了一份由 5,000 名潛水員連署的請願書,禁止在海冠公園獵捕章魚。華盛頓的決定
魚類及野生動物部門可能需要數個月才能做出決定。
http://www.dailymail.co.uk:80/news/article-2238423/Dylan-Mayer-Washington-considers-ban-Octopus-hunting-diver-caught-ate-Puget-Sound.html?
在宇宙大爆炸後 4.2 億年觀測到星系。由 NASA 的哈伯太空望遠鏡、史匹哲太空望遠鏡和其中一個自然界的望遠鏡發現
在太空中的自然「變焦鏡頭」。
http://www.dailymail.co.uk/sciencetech/article-2233883/The-furthest-object-seen-Record-breaking-image-shows-galaxy-13-3-BILLION-light-years-Earth.html

設定

首先,安裝必要的套件,包括來自 GitHub 存放區的 Model Maker 套件。

sudo apt -y install libportaudio2
pip install -q tflite-model-maker
pip install gdown

匯入必要的套件。

from tflite_model_maker import searcher

準備資料集

本教學課程使用來自 GitHub 存放區的 CNN/Daily Mail 摘要資料集。

首先,下載 cnn 和 dailymail 的文字和網址,然後解壓縮。如果從 Google 雲端硬碟下載失敗,請稍候幾分鐘再試一次,或手動下載,然後將其上傳到 Colab。

gdown https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ
gdown https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs

wget -O all_train.txt https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt
tar xzf cnn_stories.tgz
tar xzf dailymail_stories.tgz

然後,將資料儲存到 CSV 檔案中,該檔案可以載入 tflite_model_maker 程式庫。程式碼是以在 tensorflow_datasets 中載入此資料的邏輯為基礎。我們無法直接使用 tensorflow_dataset,因為它不包含本 Colab 中使用的網址。

由於將資料處理成整個資料集的嵌入功能向量需要很長時間。預設只選取 CNN 和 Daily Mail 資料集的前 5% 的報導以供示範用途。您可以調整比例,或嘗試使用預先建構的 TFLite 模型,其中包含 CNN 和 Daily Mail 資料集 50% 的報導,以便進行搜尋。

將重點新聞和網址儲存到 CSV 檔案

建構文字搜尋器模型

透過載入資料集、使用資料建立模型並匯出 TFLite 模型,來建立文字搜尋器模型。

步驟 1:載入資料集

Model Maker 採用 CSV 格式的文字資料集和每個文字字串的對應中繼資料 (例如本範例中的網址)。它會使用使用者指定的嵌入器模型,將文字字串嵌入到功能向量中。

在本示範中,我們使用 Universal Sentence Encoder 建構搜尋器模型,這是一種最先進的句子嵌入模型,已從 colab 重新訓練。此模型已針對裝置端推論效能進行最佳化,並且僅需 6 毫秒即可嵌入查詢字串 (在 Pixel 6 上測得)。或者,您可以使用 這個量化版本,它較小,但每次嵌入需要 38 毫秒。

wget -O universal_sentence_encoder.tflite https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite

建立 searcher.TextDataLoader 執行個體,並使用 data_loader.load_from_csv 方法載入資料集。此步驟需要約 10 分鐘,因為它會逐一產生每個文字的嵌入功能向量。您可以嘗試上傳自己的 CSV 檔案並載入它,以建構自訂模型。

在 CSV 檔案中指定文字欄和中繼資料欄的名稱。

  • 文字用於產生嵌入功能向量。
  • 中繼資料是在您搜尋特定文字時要顯示的內容。

以下是上面產生的 CNN-DailyMail CSV 檔案的前 4 行。

重點新聞 網址
敘利亞官員:歐巴馬爬到樹頂,不知道如何下來。歐巴馬致函眾議院和參議院領導人。歐巴馬
將尋求國會批准對敘利亞採取軍事行動。聯合國發言人表示,目的是確定是否使用了化學武器,而不是由誰使用。
http://www.cnn.com/2013/08/31/world/meast/syria-civil-war/
Usain Bolt 贏得世界錦標賽第三面金牌。帶領牙買加隊贏得 4x100 公尺接力賽。Bolt 在錦標賽中獲得第八面金牌。牙買加隊在女子
4x100 公尺接力賽中再添一金。
http://edition.cnn.com/2013/08/18/sport/athletics-bolt-jamaica-gold
該機構堪薩斯市辦公室的員工是數百名「虛擬」員工之一。該員工去年往返美國本土的旅費
超過 24,000 美元。遠距辦公計畫與所有 GSA 實務一樣,都在審查中。
http://www.cnn.com:80/2012/08/23/politics/gsa-hawaii-teleworking
最新消息:一位加拿大醫生表示,她曾在 2010 年參與檢查 Harry Burkhart 的團隊。最新消息:診斷結果:「自閉症、嚴重焦慮症、創傷後壓力
症候群和憂鬱症」官員表示,Burkhart 也涉嫌德國縱火案調查。檢察官認為,這位德國國民在洛杉磯縱火多起。
在洛杉磯。
http://edition.cnn.com:80/2012/01/05/justice/california-arson/index.html?
data_loader = searcher.TextDataLoader.create("universal_sentence_encoder.tflite", l2_normalize=True)
data_loader.load_from_csv("cnn_dailymail.csv", text_column="highlights", metadata_column="urls")

對於圖片使用案例,您可以建立 searcher.ImageDataLoader 執行個體,然後使用 data_loader.load_from_folder 從資料夾載入圖片。searcher.ImageDataLoader 執行個體需要由 TFLite 嵌入器模型建立,因為它將用於將查詢編碼為功能向量,並與 TFLite Searcher 模型一起匯出。例如

data_loader = searcher.ImageDataLoader.create("mobilenet_v2_035_96_embedder_with_metadata.tflite")
data_loader.load_from_folder("food/")

步驟 2:建立搜尋器模型

  • 設定 ScaNN 選項。如需更多詳細資訊,請參閱 API 文件
  • 從資料和 ScaNN 選項建立搜尋器模型。您可以參閱深入探討,以進一步瞭解 ScaNN 演算法。
scann_options = searcher.ScaNNOptions(
      distance_measure="dot_product",
      tree=searcher.Tree(num_leaves=140, num_leaves_to_search=4),
      score_ah=searcher.ScoreAH(dimensions_per_block=1, anisotropic_quantization_threshold=0.2))
model = searcher.Searcher.create_from_data(data_loader, scann_options)

在上述範例中,我們定義了下列選項

  • distance_measure:我們使用「dot_product」來測量兩個嵌入向量之間的距離。請注意,我們實際上計算點積值,以保留「越小越接近」的概念。

  • tree:資料集被劃分為 140 個分割區 (大約是資料大小的平方根),並且在檢索期間搜尋其中 4 個分割區,大約佔資料集的 3%。

  • score_ah:我們使用相同的維度將浮點嵌入量化為 int8 值,以節省空間。

步驟 3:匯出 TFLite 模型

然後您可以匯出 TFLite Searcher 模型。

model.export(
      export_filename="searcher.tflite",
      userinfo="",
      export_format=searcher.ExportFormat.TFLITE)

在您的查詢中測試 TFLite 模型

您可以使用自訂查詢文字測試匯出的 TFLite 模型。若要使用搜尋器模型查詢文字,請初始化模型並使用文字詞組執行搜尋,如下所示

from tflite_support.task import text

# Initializes a TextSearcher object.
searcher = text.TextSearcher.create_from_file("searcher.tflite")

# Searches the input query.
results = searcher.search("The Airline Quality Rankings Report looks at the 14 largest U.S. airlines.")
print(results)

如需有關如何將模型整合到各種平台的詳細資訊,請參閱Task Library 文件

閱讀更多資訊

如需更多資訊,請參閱