Pythonで学習したAIモデルをJavaに組み込む方法【ONNX活用】

はじめに

Pythonで構築・学習したAIモデルをJavaアプリケーションで再利用したいケースは増えています。特に業務システムやバッチ処理、組込系のJava環境にAIを組み込むには、「モデルの移植」が課題になります。

この記事では、ONNX(Open Neural Network Exchange)という共通フォーマットを使って、Pythonで学習したモデルをJavaから利用する方法を解説します。


使用する技術

  • Python(学習用)
  • ONNX(モデル変換・保存形式)
  • Java + ONNX Runtime for Java

【1】Pythonで学習済みモデルをONNX形式に変換

例:scikit-learnの分類モデルをONNXに変換

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# モデル学習
iris = load_iris()
X, y = iris.data, iris.target
model = LogisticRegression().fit(X, y)

# ONNX変換
initial_type = [('input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)

with open("iris_model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

ONNXファイル(例:iris_model.onnx)が生成されます。


【2】JavaでONNXモデルを読み込む

依存ライブラリ(Maven)

<dependency>
  <groupId>com.microsoft.onnxruntime</groupId>
  <artifactId>onnxruntime</artifactId>
  <version>1.17.0</version>
</dependency>

Javaコード:ONNX推論

import ai.onnxruntime.*;

import java.util.*;

public class OnnxPredictor {
  public static void main(String[] args) throws OrtException {
    try (OrtEnvironment env = OrtEnvironment.getEnvironment();
         OrtSession session = env.createSession("iris_model.onnx", new OrtSession.SessionOptions())) {

      float[][] inputData = {{5.1f, 3.5f, 1.4f, 0.2f}};
      OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);

      Map<String, OnnxTensor> inputMap = new HashMap<>();
      inputMap.put("input", inputTensor);

      OrtSession.Result result = session.run(inputMap);
      float[][] output = (float[][]) result.get(0).getValue();

      System.out.println("予測結果:");
      for (float val : output[0]) {
        System.out.printf("%.4f ", val);
      }
    }
  }
}

【3】Spring Bootとの組み合わせ(REST API化)

Spring Bootでエンドポイントを作り、ONNX推論をAPI化することも可能です。

@RestController
public class PredictController {

  @PostMapping("/predict")
  public float[] predict(@RequestBody float[] input) throws OrtException {
    try (OrtEnvironment env = OrtEnvironment.getEnvironment();
         OrtSession session = env.createSession("iris_model.onnx", new OrtSession.SessionOptions())) {

      float[][] inputData = {input};
      OnnxTensor tensor = OnnxTensor.createTensor(env, inputData);
      Map<String, OnnxTensor> map = Map.of("input", tensor);

      OrtSession.Result result = session.run(map);
      float[][] output = (float[][]) result.get(0).getValue();
      return output[0];
    }
  }
}

まとめ

Pythonで作成したAIモデルは、ONNXを使うことでJavaアプリに簡単に組み込むことが可能です。ONNX Runtimeは軽量で高速な推論が可能で、業務システムやクラウド環境でも活用できます。

タイトルとURLをコピーしました