はじめに
Pythonで構築・学習したAIモデルをJavaアプリケーションで再利用したいケースは増えています。特に業務システムやバッチ処理、組込系のJava環境にAIを組み込むには、「モデルの移植」が課題になります。
この記事では、ONNX(Open Neural Network Exchange)という共通フォーマットを使って、Pythonで学習したモデルをJavaから利用する方法を解説します。
使用する技術
- Python(学習用)
- ONNX(モデル変換・保存形式)
- Java + ONNX Runtime for Java
【1】Pythonで学習済みモデルをONNX形式に変換
例:scikit-learnの分類モデルをONNXに変換
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# モデル学習
iris = load_iris()
X, y = iris.data, iris.target
model = LogisticRegression().fit(X, y)
# ONNX変換
initial_type = [('input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)
with open("iris_model.onnx", "wb") as f:
f.write(onnx_model.SerializeToString())
ONNXファイル(例:iris_model.onnx
)が生成されます。
【2】JavaでONNXモデルを読み込む
依存ライブラリ(Maven)
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.17.0</version>
</dependency>
Javaコード:ONNX推論
import ai.onnxruntime.*;
import java.util.*;
public class OnnxPredictor {
public static void main(String[] args) throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession session = env.createSession("iris_model.onnx", new OrtSession.SessionOptions())) {
float[][] inputData = {{5.1f, 3.5f, 1.4f, 0.2f}};
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);
Map<String, OnnxTensor> inputMap = new HashMap<>();
inputMap.put("input", inputTensor);
OrtSession.Result result = session.run(inputMap);
float[][] output = (float[][]) result.get(0).getValue();
System.out.println("予測結果:");
for (float val : output[0]) {
System.out.printf("%.4f ", val);
}
}
}
}
【3】Spring Bootとの組み合わせ(REST API化)
Spring Bootでエンドポイントを作り、ONNX推論をAPI化することも可能です。
@RestController
public class PredictController {
@PostMapping("/predict")
public float[] predict(@RequestBody float[] input) throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession session = env.createSession("iris_model.onnx", new OrtSession.SessionOptions())) {
float[][] inputData = {input};
OnnxTensor tensor = OnnxTensor.createTensor(env, inputData);
Map<String, OnnxTensor> map = Map.of("input", tensor);
OrtSession.Result result = session.run(map);
float[][] output = (float[][]) result.get(0).getValue();
return output[0];
}
}
}
まとめ
Pythonで作成したAIモデルは、ONNXを使うことでJavaアプリに簡単に組み込むことが可能です。ONNX Runtimeは軽量で高速な推論が可能で、業務システムやクラウド環境でも活用できます。