使用 TensorFlow Lite Support Library 處理輸入和輸出資料

行動應用程式開發人員通常與類型物件 (例如點陣圖) 或基本類型 (例如整數) 互動。然而,執行裝置端機器學習模型的 TensorFlow Lite Interpreter API 使用 ByteBuffer 形式的張量,這可能難以偵錯和操作。TensorFlow Lite Android Support Library 旨在協助處理 TensorFlow Lite 模型的輸入和輸出,並讓 TensorFlow Lite Interpreter 更容易使用。

開始使用

匯入 Gradle 依附元件和其他設定

.tflite 模型檔案複製到將執行模型的 Android 模組的 assets 目錄。指定檔案不應壓縮,並將 TensorFlow Lite 程式庫新增至模組的 build.gradle 檔案

android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import tflite dependencies
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
    // The GPU delegate library is optional. Depend on it as needed.
    implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly-SNAPSHOT'
    implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly-SNAPSHOT'
}

瀏覽 MavenCentral 上託管的 TensorFlow Lite Support Library AAR,以取得 Support Library 的不同版本。

基本圖片操作和轉換

TensorFlow Lite Support Library 具有一套基本圖片操作方法,例如裁剪和調整大小。若要使用這些方法,請建立 ImagePreprocessor 並新增必要的運算。若要將圖片轉換為 TensorFlow Lite Interpreter 所需的張量格式,請建立要用作輸入的 TensorImage

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;

// Initialization code
// Create an ImageProcessor with all ops required. For more ops, please
// refer to the ImageProcessor Architecture section in this README.
ImageProcessor imageProcessor =
    new ImageProcessor.Builder()
        .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
        .build();

// Create a TensorImage object. This creates the tensor of the corresponding
// tensor type (uint8 in this case) that the TensorFlow Lite interpreter needs.
TensorImage tensorImage = new TensorImage(DataType.UINT8);

// Analysis code for every frame
// Preprocess the image
tensorImage.load(bitmap);
tensorImage = imageProcessor.process(tensorImage);

張量的 DataType 可以透過中繼資料擷取器程式庫以及其他模型資訊讀取。

基本音訊資料處理

TensorFlow Lite Support Library 也定義了 TensorAudio 類別,其中封裝了一些基本音訊資料處理方法。它主要與 AudioRecord 搭配使用,並在環形緩衝區中擷取音訊樣本。

import android.media.AudioRecord;
import org.tensorflow.lite.support.audio.TensorAudio;

// Create an `AudioRecord` instance.
AudioRecord record = AudioRecord(...)

// Create a `TensorAudio` object from Android AudioFormat.
TensorAudio tensorAudio = new TensorAudio(record.getFormat(), size)

// Load all audio samples available in the AudioRecord without blocking.
tensorAudio.load(record)

// Get the `TensorBuffer` for inference.
TensorBuffer buffer = tensorAudio.getTensorBuffer()

建立輸出物件並執行模型

執行模型之前,我們需要建立將儲存結果的容器物件

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

// Create a container for the result and specify that this is a quantized model.
// Hence, the 'DataType' is defined as UINT8 (8-bit unsigned integer)
TensorBuffer probabilityBuffer =
    TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);

載入模型並執行推論

import java.nio.MappedByteBuffer;
import org.tensorflow.lite.InterpreterFactory;
import org.tensorflow.lite.InterpreterApi;

// Initialise the model
try{
    MappedByteBuffer tfliteModel
        = FileUtil.loadMappedFile(activity,
            "mobilenet_v1_1.0_224_quant.tflite");
    InterpreterApi tflite = new InterpreterFactory().create(
        tfliteModel, new InterpreterApi.Options());
} catch (IOException e){
    Log.e("tfliteSupport", "Error reading model", e);
}

// Running inference
if(null != tflite) {
    tflite.run(tImage.getBuffer(), probabilityBuffer.getBuffer());
}

存取結果

開發人員可以直接透過 probabilityBuffer.getFloatArray() 存取輸出。如果模型產生量化輸出,請記得轉換結果。對於 MobileNet 量化模型,開發人員需要將每個輸出值除以 255,以取得每個類別的機率,範圍從 0 (最不可能) 到 1 (最有可能)。

選用:將結果對應至標籤

開發人員也可以選擇性地將結果對應至標籤。首先,將包含標籤的文字檔複製到模組的 assets 目錄。接下來,使用以下程式碼載入標籤檔案

import org.tensorflow.lite.support.common.FileUtil;

final String ASSOCIATED_AXIS_LABELS = "labels.txt";
List<String> associatedAxisLabels = null;

try {
    associatedAxisLabels = FileUtil.loadLabels(this, ASSOCIATED_AXIS_LABELS);
} catch (IOException e) {
    Log.e("tfliteSupport", "Error reading label file", e);
}

以下程式碼片段示範如何將機率與類別標籤建立關聯

import java.util.Map;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.label.TensorLabel;

// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
    new TensorProcessor.Builder().add(new NormalizeOp(0, 255)).build();

if (null != associatedAxisLabels) {
    // Map of labels and their corresponding probability
    TensorLabel labels = new TensorLabel(associatedAxisLabels,
        probabilityProcessor.process(probabilityBuffer));

    // Create a map to access the result based on label
    Map<String, Float> floatMap = labels.getMapWithFloatValue();
}

目前的使用案例涵蓋範圍

目前版本的 TensorFlow Lite Support Library 涵蓋

  • 常見資料類型 (浮點數、uint8、圖片、音訊和這些物件的陣列) 作為 tflite 模型的輸入和輸出。
  • 基本圖片運算 (裁剪圖片、調整大小和旋轉)。
  • 標準化和量化
  • 檔案公用程式

未來版本將改善對文字相關應用程式的支援。

ImageProcessor 架構

ImageProcessor 的設計允許預先定義圖片操作運算,並在建構過程中進行最佳化。ImageProcessor 目前支援三種基本預先處理運算,如下列程式碼片段中的三個註解所述

import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.common.ops.QuantizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;

int width = bitmap.getWidth();
int height = bitmap.getHeight();

int size = height > width ? width : height;

ImageProcessor imageProcessor =
    new ImageProcessor.Builder()
        // Center crop the image to the largest square possible
        .add(new ResizeWithCropOrPadOp(size, size))
        // Resize using Bilinear or Nearest neighbour
        .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR));
        // Rotation counter-clockwise in 90 degree increments
        .add(new Rot90Op(rotateDegrees / 90))
        .add(new NormalizeOp(127.5, 127.5))
        .add(new QuantizeOp(128.0, 1/128.0))
        .build();

如需標準化和量化的更多詳細資訊,請參閱此處

支援程式庫的最終目標是支援所有 tf.image 轉換。這表示轉換將與 TensorFlow 相同,且實作將獨立於作業系統。

也歡迎開發人員建立自訂處理器。在這些情況下,務必與訓練流程保持一致,亦即相同的預先處理應同時適用於訓練和推論,以提高重現性。

量化

當初始化輸入或輸出物件 (例如 TensorImageTensorBuffer) 時,您需要將其類型指定為 DataType.UINT8DataType.FLOAT32

TensorImage tensorImage = new TensorImage(DataType.UINT8);
TensorBuffer probabilityBuffer =
    TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);

TensorProcessor 可用於量化輸入張量或解量化輸出張量。例如,在處理量化輸出 TensorBuffer 時,開發人員可以使用 DequantizeOp 將結果解量化為介於 0 和 1 之間的浮點機率

import org.tensorflow.lite.support.common.TensorProcessor;

// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
    new TensorProcessor.Builder().add(new DequantizeOp(0, 1/255.0)).build();
TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer);

張量的量化參數可以透過中繼資料擷取器程式庫讀取。