TensorFlow Java 可以在任何 JVM 上執行,以建構、訓練和部署機器學習模型。它支援在圖形或渴望模式下的 CPU 和 GPU 執行,並提供豐富的 API,以便在 JVM 環境中使用 TensorFlow。Java 和其他 JVM 語言(例如 Scala 和 Kotlin)在全球各地的大型和小型企業中經常使用,這使得 TensorFlow Java 成為大規模採用機器學習的策略性選擇。
需求條件
TensorFlow Java 在 Java 8 及更高版本上執行,並開箱即用支援以下平台
- Ubuntu 16.04 或更高版本;64 位元,x86
- macOS 10.12.6 (Sierra) 或更高版本;64 位元,x86
- Windows 7 或更高版本;64 位元,x86
版本
TensorFlow Java 有自己的發布週期,獨立於 TensorFlow 執行階段。因此,其版本與其執行的 TensorFlow 執行階段版本不符。請參閱 TensorFlow Java 版本控制表,以列出所有可用版本及其與 TensorFlow 執行階段的對應關係。
構件
有幾種方法可以將 TensorFlow Java 新增至您的專案。最簡單的方法是新增對 tensorflow-core-platform
構件的依附元件,其中包含 TensorFlow Java Core API 和在所有支援平台上執行所需的原生依附元件。
您也可以選擇以下其中一種擴充功能來取代純 CPU 版本
tensorflow-core-platform-mkl
:在所有平台上支援 Intel® MKL-DNNtensorflow-core-platform-gpu
:在 Linux 和 Windows 平台上支援 CUDA®tensorflow-core-platform-mkl-gpu
:在 Linux 平台上支援 Intel® MKL-DNN 和 CUDA®。
此外,可以新增 tensorflow-framework
程式庫的個別依附元件,以受益於 JVM 上以 TensorFlow 為基礎的機器學習的豐富公用程式集。
使用 Maven 安裝
若要將 TensorFlow 包含在您的 Maven 應用程式中,請將其構件的依附元件新增至專案的 pom.xml
檔案。例如:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.3.3</version>
</dependency>
減少依附元件數量
務必注意,新增 tensorflow-core-platform
構件的依附元件將匯入所有支援平台的原生程式庫,這可能會大幅增加專案的大小。
如果您希望以可用平台的子集為目標,則可以使用 Maven 依附元件排除功能,從其他平台排除不必要的構件。
選擇要在應用程式中包含哪些平台的另一種方法是在 Maven 命令列或 pom.xml
中設定 JavaCPP 系統屬性。如需更多詳細資訊,請參閱 JavaCPP 文件。
使用快照
來自 TensorFlow Java 來源儲存庫的最新 TensorFlow Java 開發快照可在 OSS Sonatype Nexus 儲存庫中取得。若要依附於這些構件,請務必在您的 pom.xml
中設定 OSS 快照儲存庫。
<repositories>
<repository>
<id>tensorflow-snapshots</id>
<url>https://oss.sonatype.org/content/repositories/snapshots/</url>
<snapshots>
<enabled>true</enabled>
</snapshots>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.4.0-SNAPSHOT</version>
</dependency>
</dependencies>
使用 Gradle 安裝
若要將 TensorFlow 包含在您的 Gradle 應用程式中,請將其構件的依附元件新增至專案的 build.gradle
檔案。例如:
repositories {
mavenCentral()
}
dependencies {
compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.3.3'
}
減少依附元件數量
使用 Gradle 從 TensorFlow Java 排除原生構件不如使用 Maven 容易。我們建議您使用 Gradle JavaCPP 外掛程式來減少此依附元件數量。
如需更多詳細資訊,請參閱 Gradle JavaCPP 文件。
從來源安裝
若要從來源建置 TensorFlow Java,並可能進行自訂,請參閱以下指示。
範例程式
此範例示範如何使用 TensorFlow 建置 Apache Maven 專案。首先,將 TensorFlow 依附元件新增至專案的 pom.xml
檔案
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>org.myorg</groupId>
<artifactId>hellotensorflow</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<exec.mainClass>HelloTensorFlow</exec.mainClass>
<!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<!-- Include TensorFlow (pure CPU only) for all supported platforms -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.3.3</version>
</dependency>
</dependencies>
</project>
建立原始碼檔案 src/main/java/HelloTensorFlow.java
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;
public class HelloTensorFlow {
public static void main(String[] args) throws Exception {
System.out.println("Hello TensorFlow " + TensorFlow.version());
try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
TInt32 x = TInt32.scalarOf(10);
Tensor dblX = dbl.call(x)) {
System.out.println(x.getInt() + " doubled is " + ((TInt32)dblX).getInt());
}
}
private static Signature dbl(Ops tf) {
Placeholder<TInt32> x = tf.placeholder(TInt32.class);
Add<TInt32> dblX = tf.math.add(x, x);
return Signature.builder().input("x", x).output("dbl", dblX).build();
}
}
編譯並執行
mvn -q compile exec:java
此命令會列印 TensorFlow 版本和一個簡單的計算。
成功!TensorFlow Java 已設定完成。