安裝 TensorFlow Java

TensorFlow Java 可以在任何 JVM 上執行,以建構、訓練和部署機器學習模型。它支援在圖形或渴望模式下的 CPU 和 GPU 執行,並提供豐富的 API,以便在 JVM 環境中使用 TensorFlow。Java 和其他 JVM 語言(例如 Scala 和 Kotlin)在全球各地的大型和小型企業中經常使用,這使得 TensorFlow Java 成為大規模採用機器學習的策略性選擇。

需求條件

TensorFlow Java 在 Java 8 及更高版本上執行,並開箱即用支援以下平台

  • Ubuntu 16.04 或更高版本;64 位元,x86
  • macOS 10.12.6 (Sierra) 或更高版本;64 位元,x86
  • Windows 7 或更高版本;64 位元,x86

版本

TensorFlow Java 有自己的發布週期,獨立於 TensorFlow 執行階段。因此,其版本與其執行的 TensorFlow 執行階段版本不符。請參閱 TensorFlow Java 版本控制表,以列出所有可用版本及其與 TensorFlow 執行階段的對應關係。

構件

幾種方法可以將 TensorFlow Java 新增至您的專案。最簡單的方法是新增對 tensorflow-core-platform 構件的依附元件,其中包含 TensorFlow Java Core API 和在所有支援平台上執行所需的原生依附元件。

您也可以選擇以下其中一種擴充功能來取代純 CPU 版本

  • tensorflow-core-platform-mkl:在所有平台上支援 Intel® MKL-DNN
  • tensorflow-core-platform-gpu:在 Linux 和 Windows 平台上支援 CUDA®
  • tensorflow-core-platform-mkl-gpu:在 Linux 平台上支援 Intel® MKL-DNN 和 CUDA®。

此外,可以新增 tensorflow-framework 程式庫的個別依附元件,以受益於 JVM 上以 TensorFlow 為基礎的機器學習的豐富公用程式集。

使用 Maven 安裝

若要將 TensorFlow 包含在您的 Maven 應用程式中,請將其構件的依附元件新增至專案的 pom.xml 檔案。例如:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow-core-platform</artifactId>
  <version>0.3.3</version>
</dependency>

減少依附元件數量

務必注意,新增 tensorflow-core-platform 構件的依附元件將匯入所有支援平台的原生程式庫,這可能會大幅增加專案的大小。

如果您希望以可用平台的子集為目標,則可以使用 Maven 依附元件排除功能,從其他平台排除不必要的構件。

選擇要在應用程式中包含哪些平台的另一種方法是在 Maven 命令列或 pom.xml 中設定 JavaCPP 系統屬性。如需更多詳細資訊,請參閱 JavaCPP 文件

使用快照

來自 TensorFlow Java 來源儲存庫的最新 TensorFlow Java 開發快照可在 OSS Sonatype Nexus 儲存庫中取得。若要依附於這些構件,請務必在您的 pom.xml 中設定 OSS 快照儲存庫。

<repositories>
    <repository>
        <id>tensorflow-snapshots</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        <snapshots>
            <enabled>true</enabled>
        </snapshots>
    </repository>
</repositories>

<dependencies>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-core-platform</artifactId>
        <version>0.4.0-SNAPSHOT</version>
    </dependency>
</dependencies>

使用 Gradle 安裝

若要將 TensorFlow 包含在您的 Gradle 應用程式中,請將其構件的依附元件新增至專案的 build.gradle 檔案。例如:

repositories {
    mavenCentral()
}

dependencies {
    compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.3.3'
}

減少依附元件數量

使用 Gradle 從 TensorFlow Java 排除原生構件不如使用 Maven 容易。我們建議您使用 Gradle JavaCPP 外掛程式來減少此依附元件數量。

如需更多詳細資訊,請參閱 Gradle JavaCPP 文件

從來源安裝

若要從來源建置 TensorFlow Java,並可能進行自訂,請參閱以下指示

範例程式

此範例示範如何使用 TensorFlow 建置 Apache Maven 專案。首先,將 TensorFlow 依附元件新增至專案的 pom.xml 檔案

<project>
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.myorg</groupId>
    <artifactId>hellotensorflow</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <exec.mainClass>HelloTensorFlow</exec.mainClass>
        <!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <!-- Include TensorFlow (pure CPU only) for all supported platforms -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.3.3</version>
        </dependency>
    </dependencies>
</project>

建立原始碼檔案 src/main/java/HelloTensorFlow.java

import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;

public class HelloTensorFlow {

  public static void main(String[] args) throws Exception {
    System.out.println("Hello TensorFlow " + TensorFlow.version());

    try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
        TInt32 x = TInt32.scalarOf(10);
        Tensor dblX = dbl.call(x)) {
      System.out.println(x.getInt() + " doubled is " + ((TInt32)dblX).getInt());
    }
  }

  private static Signature dbl(Ops tf) {
    Placeholder<TInt32> x = tf.placeholder(TInt32.class);
    Add<TInt32> dblX = tf.math.add(x, x);
    return Signature.builder().input("x", x).output("dbl", dblX).build();
  }
}

編譯並執行

mvn -q compile exec:java

此命令會列印 TensorFlow 版本和一個簡單的計算。

成功!TensorFlow Java 已設定完成。