Spring BootでTensorFlow Hello Worldしてみる

以前はMavenでTensorFlowのHello Worldを出してみました。

Maven で TensorFlowを使ってみる
eclipseでTensorFlowを使うプロジェクトの作成方法をご紹介します。前回の記事で作成した Maven プロジェクトを活用して T...

今回の記事では、Spring BootでTensorFlowのHello Worldを出してみます。

Spring Bootプロジェクト作成

まずはSpring Bootでプロジェクトを作成します。

TensorFlow含むライブラリ読み込み

Mavenでライブラリを読み込んでいきます。pom.xmlに次のように記載します。

pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-parent</artifactId>
    <version>2.2.7.RELEASE</version>
    <relativePath /> <!-- lookup parent from repository -->
  </parent>
  <groupId>com.demo</groupId>
  <artifactId>app</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  <name>HelloTensorFlow</name>
  <description>Demo project for Spring Boot</description>

  <properties>
    <java.version>11</java.version>
    <maven-jar-plugin.version>3.1.1</maven-jar-plugin.version>
  </properties>

  <dependencies>
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-thymeleaf</artifactId>
    </dependency>
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-web</artifactId>
    </dependency>

    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-devtools</artifactId>
      <scope>runtime</scope>
      <optional>true</optional>
    </dependency>
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-test</artifactId>
      <scope>test</scope>
      <exclusions>
        <exclusion>
          <groupId>org.junit.vintage</groupId>
          <artifactId>junit-vintage-engine</artifactId>
        </exclusion>
      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow</artifactId>
      <version>1.14.0</version>
    </dependency>
  </dependencies>

  <build>
    <plugins>
      <plugin>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-maven-plugin</artifactId>
      </plugin>
    </plugins>
  </build>

</project>

ソースコード

Spring Bootプロジェクトを作成した際に作成されるmainメソッドの中に、公式サイトのサンプルコードを記載します。

HelloTensorFlowApplication.java
package com.demo.app;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

@SpringBootApplication
public class HelloTensorFlowApplication {

  public static void main(String[] args) throws Exception {
    SpringApplication.run(HelloTensorFlowApplication.class, args);

    try (Graph g = new Graph()) {
      final String value = "Hello from " + TensorFlow.version();

      // Construct the computation graph with a single operation, a constant
      // named "MyConst" with a value "value".
      try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
        // The Java API doesn't yet include convenience functions for adding operations.
        g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
      }

      // Execute the "MyConst" operation in a Session.
      try (Session s = new Session(g);
          // Generally, there may be multiple output tensors,
          // all of them must be closed to prevent resource leaks.
          Tensor output = s.runner().fetch("MyConst").run().get(0)) {
        System.out.println(new String(output.bytesValue(), "UTF-8"));
      }
    }
  }

}

このとき、コンパイルのバージョンをJDK7にする必要があります。コンパイラーがJDK7でない場合は実行前にエラーが出ます。エラーを起こしている箇所を選択してeclipse解決によりコンパイラーをJDK7に設定します。

実行

作成したSpring Bootプロジェクトを実行します。

結果としてコンソール画面に「Hello from 1.14.0」が出力されれば成功です。

Java APIの注意事項

Java APIには次のような注意書きがあります。「Java APIはAPIの安定性の保証の対象外であり、Java APIは、Pythonで作成したモデルを読み込んでJavaアプリケーション内で実行する場合に特に便利です。」

ということは、機械学習はJava APIを使用して実施するのではなくPythonで実施することが推奨されている。そして、Pythonによる機械学習で得たモデルをJavaアプリケーションで活用するアーキテクチャが推奨。ということになりそうです。

そこで、学習済みモデルレジストリのようなサービスがないかと調査してみると、「TensorFlow Serving」という良さそうなサービスを見つけました。

TensorFlow Servingの概要
TensorFlow Servingは、TensorFlowの学習済みモデルを本番環境で運用するためのシステム(RPCサーバー)です。 ...

次回

上記の注意事項を踏まえて、次回はTensorFlow Servingを活用したシステム構築方法を学習してみようと思います。