はじめに
ディープラーニングを実務や趣味で活用したいと考えたときに、まず触れてみたいのが「画像分類」です。
本記事では、Python × PyTorch を使って、最短ステップで画像分類を実装する流れを解説します。
✅ 対象読者
- Pythonの基礎は理解している
- 機械学習ライブラリを少し触ったことがある
- 最小限のコードで動くサンプルが欲しい
この記事を読み終えると、画像分類モデルを学習 → 推論まで一通り実装できるようになります。
環境準備
まずはPyTorchをインストールしましょう。
GPU環境がある場合はCUDA対応版を推奨します。
pip install torch torchvision
学習用データとしてPyTorch公式の CIFAR-10 を使用します。
(小規模なデータセットなので、PCでもサクッと試せます)
データセットの準備
PyTorchではtorchvision.datasets
で簡単にデータセットを取得できます。
import torch
import torchvision
import torchvision.transforms as transforms
# データ前処理(正規化+Tensor変換)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 学習用・テスト用データセット
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
shuffle=False)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
モデル構築(CNN)
PyTorchではnn.Module
を継承してモデルを定義します。
シンプルなCNN(畳み込みニューラルネットワーク)を構築してみましょう。
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.fc1 = nn.Linear(64 * 8 * 8, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 64 * 8 * 8)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleCNN()
学習ループ
クロスエントロピー損失とAdam最適化を利用します。
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(5): # エポック数
running_loss = 0.0
for inputs, labels in trainloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"[Epoch {epoch+1}] loss: {running_loss/len(trainloader):.3f}")
精度評価
テストデータで精度を確認してみます。
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in testloader:
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Accuracy: {100 * correct / total:.2f}%")
推論(新しい画像でテスト)
学習済みモデルを使って任意の画像を分類してみましょう。
from PIL import Image
import torchvision.transforms as transforms
# 推論用の前処理
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 画像を読み込み
img = Image.open("test_image.jpg")
img = transform(img).unsqueeze(0) # バッチ次元を追加
# 推論
model.eval()
with torch.no_grad():
outputs = model(img)
_, predicted = torch.max(outputs, 1)
print(f"Predicted: {classes[predicted.item()]}")
まとめ
本記事では、
- データセットの準備
- CNNモデルの構築
- 学習と評価
- 推論
を通して、PyTorchで最速で画像分類を実装する方法を紹介しました。
次のステップとしては、
- 転移学習(ResNetやVGGを利用)
- データ拡張(Data Augmentation)
- GPUでの高速学習
などに取り組むと、より実用的なモデルを作れるようになります。