在 TensorFlow.js 中編寫自訂運算元、核心和梯度

總覽

本指南概述在 TensorFlow.js 中定義自訂運算元 (op)、核心和梯度的機制。旨在概述主要概念,並提供指向程式碼的指標,以示範實際運作的概念。

本指南適用於哪些對象?

這是一份相當進階的指南,涵蓋 TensorFlow.js 的一些內部機制,對於以下幾類人可能特別有用

  • 對自訂各種數學運算行為感興趣的 TensorFlow.js 進階使用者 (例如,覆寫現有梯度實作的研究人員,或需要修補程式庫中遺失功能的使用者)
  • 建構擴充 TensorFlow.js 的程式庫的使用者 (例如,建構於 TensorFlow.js 基本元件之上的通用線性代數程式庫,或新的 TensorFlow.js 後端)。
  • 對貢獻新運算元給 tensorflow.js 感興趣的使用者,他們想要大致瞭解這些機制如何運作。

本指南不是 TensorFlow.js 的一般使用指南,因為它深入探討了內部實作機制。您不需要瞭解這些機制即可使用 TensorFlow.js

您確實需要熟悉 (或願意嘗試) 閱讀 TensorFlow.js 原始碼,才能充分利用本指南。

術語

在本指南中,一些關鍵術語對於預先描述非常有用。

運算元 (Ops) — 對一個或多個張量執行數學運算,產生一個或多個張量作為輸出。運算元是「高階」程式碼,可以使用其他運算元來定義其邏輯。

核心 — 與特定硬體/平台功能相關聯的運算元特定實作。核心是「低階」且後端特定的。某些運算元具有從運算元到核心的一對一對應關係,而其他運算元則使用多個核心。

梯度 / GradFunc運算元/核心的「反向模式」定義,用於計算該函數相對於某些輸入的導數。梯度是「高階」程式碼 (非後端特定),可以呼叫其他運算元或核心。

核心登錄檔 - 從 (核心名稱、後端名稱) 元組到核心實作的對應。

梯度登錄檔 — 從核心名稱到梯度實作的對應。

程式碼組織

運算元梯度tfjs-core中定義。

核心是後端特定的,並在其各自的後端資料夾中定義 (例如,tfjs-backend-cpu)。

自訂運算元、核心和梯度不需要在這些套件內定義。但通常會在其實作中使用類似的符號。

實作自訂運算元

將自訂運算元視為只是傳回一些張量輸出的 JavaScript 函數的一種方式,通常以張量作為輸入。

  • 某些運算元可以完全根據現有運算元定義,並且應該只匯入並直接呼叫這些函數。這是一個範例
  • 運算元的實作也可以分派到後端特定的核心。這是透過 Engine.runKernel 完成的,將在「實作自訂核心」章節中進一步說明。這是一個範例

實作自訂核心

後端特定的核心實作允許針對給定運算最佳化邏輯實作。核心由呼叫 tf.engine().runKernel() 的運算元叫用。核心實作由四件事定義

  • 核心名稱。
  • 實作核心的後端。
  • 輸入:核心函數的張量引數。
  • 屬性:核心函數的非張量引數。

這是核心實作的範例。用於實作的慣例是後端特定的,最好透過查看每個特定後端的實作和文件來理解。

一般而言,核心在低於張量的層級運作,而是直接讀取和寫入記憶體,這些記憶體最終將由 tfjs-core 包裝到張量中。

一旦實作核心,就可以使用 tfjs-core 中的 registerKernel 函數向 TensorFlow.js 註冊。您可以為您希望該核心運作的每個後端註冊核心。註冊後,可以使用 tf.engine().runKernel(...) 叫用核心,TensorFlow.js 將確保分派到目前作用中後端的實作。

實作自訂梯度

梯度通常針對給定的核心定義 (由呼叫 tf.engine().runKernel(...) 中使用的相同核心名稱識別)。這允許 tfjs-core 使用登錄檔來查閱任何核心在執行階段的梯度定義。

實作自訂梯度對於以下情況很有用

  • 新增程式庫中可能不存在的梯度定義
  • 覆寫現有的梯度定義,以自訂給定核心的梯度計算。

您可以在這裡查看梯度實作範例。

一旦您為給定的呼叫實作了梯度,就可以使用 tfjs-core 中的 registerGradient 函數向 TensorFlow.js 註冊。

實作自訂梯度的另一種方法是繞過梯度登錄檔 (因此允許以任意方式計算任意函數的梯度),是使用 tf.customGrad

這是程式庫中運算元使用 customGrad 的範例