ResNet-Based Image Classification in Practice: From Principles to Deployment

菜头粿子园 · 04-09 09:00

I. Deep Learning and Image Classification Basics

Image classification is one of the core tasks in computer vision. Traditional approaches rely on hand-crafted features (such as SIFT and HOG), whereas deep learning uses convolutional neural networks (CNNs) to learn a feature hierarchy automatically. Taking ResNet as an example, its residual connections effectively mitigate the vanishing-gradient problem in deep networks, which is why it remains a mainstream model today.

Key formula: the residual block computation
If the desired underlying mapping is \( H(x) \), ResNet instead learns the residual \( F(x) = H(x) - x \), so the block's output becomes:
\( H(x) = F(x) + x \)
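
As a minimal sketch of this identity shortcut (a toy example, not the residual block defined later): if the learned branch \( F \) outputs zero, the block collapses to an identity mapping, which is why gradients can flow through very deep stacks largely undisturbed.

import torch
from torch import nn

# Toy residual step: out = F(x) + x, with F as an arbitrary small branch
residual_branch = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(4, 8)
out = residual_branch(x) + x  # the "+ x" is the identity shortcut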

II. Implementation: The Full PyTorch Pipeline

1. Environment setup

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

# Check GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

2. Data loading and augmentation

We use the CIFAR-10 dataset, which contains 32x32 images from 10 object classes:

transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),  # data augmentation
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_data = torchvision.datasets.CIFAR10(
    root="data", train=True, download=True, transform=transform
)
test_data = torchvision.datasets.CIFAR10(
    root="data", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)
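
As a quick sanity check on the pipeline (a small sketch, not in the original post), pull one batch and confirm its shape matches what the model below expects: 64 RGB images of 32x32 pixels and 64 integer labels.

# Inspect one batch from the training loader
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])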

3. Defining the ResNet model

A simplified, ResNet-18-style implementation (only two residual stages, so much shallower than the real ResNet-18):

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = self.shortcut(x)
        x = nn.ReLU()(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual  # residual (shortcut) connection
        return nn.ReLU()(x)

class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def _make_layer(self, in_channels, out_channels, stride):
        return nn.Sequential(
            ResidualBlock(in_channels, out_channels, stride),
            ResidualBlock(out_channels, out_channels, stride=1)
        )

    def forward(self, x):
        x = nn.ReLU()(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
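
Before training, it is worth checking the model wiring with a dummy input (a small sketch, assuming the 32x32 CIFAR-10 input size); the output should contain one logit per class:

# Shape check: two fake CIFAR-10 images in, two rows of 10 logits out
model = ResNet(num_classes=10)
dummy = torch.randn(2, 3, 32, 32)
print(model(dummy).shape)  # torch.Size([2, 10])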

4. Training and evaluation

model = ResNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(epochs):
    for epoch in range(epochs):
        model.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on the test set
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        
        print(f"Epoch {epoch+1}: Test Accuracy = {correct/len(test_loader.dataset):.3f}")

train(epochs=10)
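
Once training finishes, the learned parameters can be saved and reloaded for inference (a minimal sketch; the file name is an arbitrary choice, not from the original post):

# Persist the trained weights, then restore them into a fresh model for inference
torch.save(model.state_dict(), "resnet_cifar10.pth")

model = ResNet().to(device)
model.load_state_dict(torch.load("resnet_cifar10.pth", map_location=device))
model.eval()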

III. Optimization Strategies and Deployment Tips

  1. Hyperparameter tuning
    • Learning-rate scheduling (e.g. CosineAnnealingLR; see the sketch after this list)
    • Richer data augmentation (CutMix, AutoAugment)
  2. Model lightweighting

# Use a pretrained ResNet-18 and fine-tune it
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)  # newer API; equivalent to the older pretrained=True
model.fc = nn.Linear(512, 10)  # replace the final layer for 10 classes

  3. Deployment to production (export sketches below)
    • Export the model to ONNX format
    • Use TorchScript for cross-platform inference
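
A sketch of the learning-rate schedule from item 1, assuming it wraps the Adam optimizer defined in the training section and is stepped once per epoch:

# Cosine annealing: the LR decays along a cosine curve over T_max epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

for epoch in range(10):
    # ... run one training epoch as in train() above ...
    scheduler.step()

And a minimal export sketch for item 3 (the file names are illustrative); both TorchScript tracing and ONNX export need an example input of the expected shape:

model = model.to("cpu").eval()
example = torch.randn(1, 3, 32, 32)  # example input of the expected shape

# TorchScript: trace the model so it can run without this Python source
scripted = torch.jit.trace(model, example)
scripted.save("resnet_cifar10_scripted.pt")

# ONNX: export for runtimes such as ONNX Runtime or TensorRT
torch.onnx.export(model, example, "resnet_cifar10.onnx",
                  input_names=["image"], output_names=["logits"])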

IV. Summary

This post implemented ResNet-based classification on CIFAR-10, with code covering the full pipeline: data loading, model definition, training, and evaluation. In real applications, the network depth and the data-augmentation strategy should be adjusted to the target scenario.
