
PyTorch in Practice with CIFAR-10: Hand-Writing a ResNet and Pushing Accuracy to Around 94%

This article uses a complete CIFAR-10 classification project to walk you through the full hands-on PyTorch workflow: data augmentation, a hand-written ResNet, the training loop, and saving weights.

March 31, 2026 · 1577 · Updated March 31, 2026

This is the hands-on installment split out from the "Deep Learning Knowledge Framework" series, and the goal is simple: use one complete, runnable, understandable CIFAR-10 project to tie together everything covered so far.

If you have already read the overview, or finished the article on training engineering, this one is a good next step for hands-on practice.

Contents

  1. Data loading and augmentation
  2. Hand-writing a ResNet
  3. Loss function and optimizer
  4. Training and validation loops
  5. Inference and saving the model
  6. Expected results and checklist

1. Data Loading and Augmentation

Let's set up the data pipeline first.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Training set: normalization + augmentation (augmentation on the training set only)
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    ),
])

# Validation set: deterministic preprocessing only
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    ),
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
val_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=val_transform)

train_loader = DataLoader(
    train_dataset, batch_size=128,
    shuffle=True, num_workers=4, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=256,
    shuffle=False, num_workers=4, pin_memory=True
)

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

Why is padding=4 + RandomCrop(32) such a common recipe?

Because it effectively pads the image outward first and then randomly crops it back to the original size, which adds useful positional variation; it is one of the most common augmentation combinations for CIFAR-10.
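The geometry behind that combination can be sketched in a few lines of plain arithmetic (no PyTorch needed):

```python
# With padding=4 on a 32x32 image, RandomCrop first pads to 40x40,
# then cuts a 32x32 window at a random top-left offset.
size, pad = 32, 4
padded = size + 2 * pad                # 40x40 after padding
offsets_per_axis = padded - size + 1   # 9 valid offsets per axis
positions = offsets_per_axis ** 2      # 81 distinct crop positions
print(padded, positions)  # 40 81
```

So a single training image can appear at 81 different shifts, which is where the extra positional variation comes from.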


2. Hand-Writing a ResNet

Instead of calling the ready-made torchvision.models.resnet18(), we hand-write a simplified version adapted to CIFAR-10; this makes it much easier to understand what the residual connection is actually doing. (The stock ImageNet stem, a 7x7 stride-2 convolution followed by max-pooling, also discards too much detail from a 32x32 image.)

class ResidualBlock(nn.Module):
    """ResNet 的基本单元:两层卷积 + 一条跳跃连接"""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.main(x)
        out = out + self.shortcut(x)
        out = self.relu(out)
        return out


class ResNet18_CIFAR(nn.Module):
    """适配 CIFAR-10 的 ResNet-18"""

    def __init__(self, num_classes=10):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

        self._init_weights()

    def _make_layer(self, in_ch, out_ch, stride):
        return nn.Sequential(
            ResidualBlock(in_ch, out_ch, stride),
            ResidualBlock(out_ch, out_ch, stride=1)
        )

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.head(x)
        return x


model = ResNet18_CIFAR(num_classes=10)
print(sum(p.numel() for p in model.parameters()))

Why bias=False: each convolution is immediately followed by BatchNorm, whose per-channel shift (the learnable beta) makes a separate convolution bias redundant.
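A tiny numeric sketch in plain Python shows why: BatchNorm begins by subtracting the per-channel mean, and a constant bias shifts every value and the mean by the same amount, so it cancels exactly.

```python
# BatchNorm starts by centering: x - mean(x).
# Adding a constant bias b shifts every value AND the mean by b,
# so the centered result is identical with or without the bias.
xs = [1.0, 2.0, 3.0, 4.0]
b = 5.0

mean = sum(xs) / len(xs)
centered = [x - mean for x in xs]

shifted = [x + b for x in xs]
mean_b = sum(shifted) / len(shifted)
centered_b = [x - mean_b for x in shifted]

print(centered == centered_b)  # True: the bias cancels out
```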


3. Loss Function and Optimizer

For a classification task like CIFAR-10, a reliable starting configuration is:

  • CrossEntropyLoss
  • SGD + momentum
  • weight_decay
  • cosine annealing learning-rate schedule

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=200, eta_min=1e-4
)
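The schedule above follows a closed form; this small sketch mirrors the CosineAnnealingLR formula (no restarts) to show how the learning rate moves over the 200 epochs:

```python
import math

# CosineAnnealingLR without restarts:
# lr(t) = eta_min + (lr0 - eta_min) * (1 + cos(pi * t / T_max)) / 2
lr0, eta_min, T_max = 0.1, 1e-4, 200

def cosine_lr(t):
    return eta_min + (lr0 - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2

print(cosine_lr(0))    # starts at the base lr of 0.1
print(cosine_lr(100))  # roughly halfway down at mid-run
print(cosine_lr(200))  # decays smoothly to eta_min
```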

The label_smoothing=0.1 here can be read as:

don't let the model push any single class probability to an extreme; this usually improves generalization slightly.
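Concretely, with 10 classes, PyTorch's label_smoothing=0.1 replaces the one-hot target with a slightly softened distribution; the arithmetic is just:

```python
# With smoothing eps over K classes, the one-hot target becomes
# (1 - eps) * one_hot plus eps / K spread uniformly over all classes.
eps, K = 0.1, 10
true_class = (1 - eps) + eps / K   # about 0.91 instead of 1.0
other_class = eps / K              # about 0.01 instead of 0.0

print(true_class, other_class)
print(true_class + (K - 1) * other_class)  # the target still sums to 1
```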


4. Training and Validation Loops

This is the core of the whole project.

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / total, 100. * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / total, 100. * correct / total


EPOCHS = 200
best_val_acc = 0
history = {
    'train_loss': [], 'val_loss': [],
    'train_acc': [], 'val_acc': []
}

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(
        model, val_loader, criterion, device)

    scheduler.step()

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')

    if (epoch + 1) % 10 == 0:
        lr = optimizer.param_groups[0]['lr']
        print(
            f"Epoch {epoch+1:3d}/{EPOCHS} | "
            f"lr={lr:.5f} | "
            f"Train Loss={train_loss:.3f} Acc={train_acc:.1f}% | "
            f"Val Loss={val_loss:.3f} Acc={val_acc:.1f}%"
        )

print(f"Best validation accuracy: {best_val_acc:.2f}%")

If during training you see:

  • good training accuracy but poor validation accuracy → most likely overfitting
  • both training and validation are poor → first check the learning rate, the data pipeline, and model capacity

5. Inference and Saving the Model

After training, don't save just the model parameters; save the optimizer state and the training history as well, so you can resume from a checkpoint later.

import torch.nn.functional as F

model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()


def predict(image_tensor, model, device):
    with torch.no_grad():
        x = image_tensor.unsqueeze(0).to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
    return classes[pred.item()], conf.item()


torch.save({
    'epoch': EPOCHS,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_val_acc': best_val_acc,
    'history': history,
}, 'checkpoint_full.pth')
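Resuming from such a checkpoint then looks like this. The sketch below uses a tiny stand-in nn.Linear so it runs on its own; in the real script you would rebuild ResNet18_CIFAR and load 'checkpoint_full.pth' instead:

```python
import torch
import torch.nn as nn

# Stand-in model/optimizer; swap in ResNet18_CIFAR + the SGD config above.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'ckpt_demo.pth')

# Later (possibly in a new process): rebuild identical objects, then restore.
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)

ckpt = torch.load('ckpt_demo.pth', map_location='cpu')
model2.load_state_dict(ckpt['model_state_dict'])
optimizer2.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1  # continue from the next epoch

print(start_epoch)
print(torch.equal(model.weight, model2.weight))
```

Restoring the optimizer state matters: it brings back the momentum buffers, so training continues as if it had never stopped.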

6. Expected Results and Checklist

6.1 A reasonable set of expectations

Stage                                Val accuracy   Notes
Code runs end to end                 ~85%           the overall pipeline is basically correct
Plus augmentation and the scheduler  ~92%           a commonly reachable level
Plus label smoothing etc.            ~94%           the target range for this article's code
Plus CutMix / Mixup etc.             ~96%+          more of a competition-grade setup

6.2 Three things to always check after a run

  1. Does the training loss decrease steadily?
  2. Can the model quickly overfit 100 samples?
  3. Is the gap between training and validation accuracy reasonable?

If all three check out, your model, data pipeline, and training loop are most likely in good shape.
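Check 2 deserves its own snippet. The idea: a healthy model-and-loop pair should memorize a tiny fixed batch almost perfectly. The sketch below uses a small MLP on random data so it runs standalone; in practice you would feed 100 real CIFAR-10 samples through ResNet18_CIFAR with the same loop:

```python
import torch
import torch.nn as nn

# Overfit-a-tiny-batch sanity check. If the loss refuses to collapse
# on 100 fixed samples, something upstream (labels, loss wiring, the
# optimizer step) is broken. MLP + random data stand in for the real setup.
torch.manual_seed(0)
x = torch.randn(100, 32)
y = torch.randint(0, 10, (100,))

tiny = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 10))
opt = torch.optim.Adam(tiny.parameters(), lr=1e-2)
crit = nn.CrossEntropyLoss()

for _ in range(500):
    opt.zero_grad()
    loss = crit(tiny(x), y)
    loss.backward()
    opt.step()

# Should end far below the random-guess loss of ln(10) ≈ 2.3
print(f"final loss: {loss.item():.4f}")
```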


The best first deep learning project is not training a large model right away; it is getting a small but complete project like this one to genuinely make sense. Once you have digested CIFAR-10 + ResNet, many other vision tasks will come much more easily.
