在计算机视觉领域,图像语义分割一直是最具挑战性的任务之一。简单来说,它就像给照片中的每个像素"贴标签"——让计算机不仅能认出图中有什么物体,还要精确勾勒出它们的轮廓。这项技术在自动驾驶、医疗影像分析等领域有着广泛的应用前景。
记得我第一次接触语义分割是在2016年,当时FCN(全卷积网络)刚刚问世。那时的模型虽然能完成基本的分割任务,但边缘总是毛毛糙糙的,小物体经常被漏掉。经过这些年的发展,U-Net、DeepLab等架构不断刷新着性能记录。但直到今天,如何在复杂场景下实现精确分割,仍然是业界的研究热点。
我们的模型以经典的U-Net为基础框架,但做了几个关键改进:
骨干网络升级:将原来的简单编码器替换为ResNet50。这个选择背后有深思熟虑:
混合注意力模块(HAM):
python复制class HybridAttentionModule(nn.Module):
def __init__(self, in_channels):
super().__init__()
# 通道注意力分支
self.channel_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, in_channels//4, 1),
nn.ReLU(),
nn.Conv2d(in_channels//4, in_channels, 1),
nn.Sigmoid()
)
# 空间注意力分支
self.spatial_att = nn.Sequential(
nn.Conv2d(2, 1, 3, padding=1),
nn.Sigmoid()
)
def forward(self, x):
# 通道注意力
channel_att = self.channel_att(x)
# 空间注意力
avg_pool = torch.mean(x, dim=1, keepdim=True)
max_pool = torch.max(x, dim=1, keepdim=True)[0]
spatial_att = self.spatial_att(torch.cat([avg_pool, max_pool], dim=1))
# 融合
return x * channel_att * spatial_att
这个模块的创新点在于:
我们采用了改进版的金字塔池化模块(PPM):
python复制class PyramidPoolingModule(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.pool_sizes = [1, 2, 3, 6]
self.convs = nn.ModuleList([
nn.Sequential(
nn.AdaptiveAvgPool2d(size),
nn.Conv2d(in_channels, out_channels, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
) for size in self.pool_sizes
])
def forward(self, x):
x_size = x.size()[2:]
features = [x]
for conv in self.convs:
feat = conv(x)
feat = F.interpolate(feat, x_size, mode='bilinear', align_corners=True)
features.append(feat)
return torch.cat(features, dim=1)
关键改进包括:
在PASCAL VOC 2012数据集上,我们采用了强化的数据增强策略:
python复制train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.2),
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
transforms.RandomAffine(degrees=15, translate=(0.1,0.1), scale=(0.8,1.2)),
transforms.RandomResizedCrop(256, scale=(0.5, 1.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
几个实用技巧:
我们采用联合损失函数:
python复制class DiceLoss(nn.Module):
def __init__(self, smooth=1e-6):
super().__init__()
self.smooth = smooth
def forward(self, pred, target):
pred = F.softmax(pred, dim=1)
target = F.one_hot(target, num_classes=21).permute(0,3,1,2).float()
intersection = (pred * target).sum(dim=(2,3))
union = pred.sum(dim=(2,3)) + target.sum(dim=(2,3))
dice = (2.*intersection + self.smooth)/(union + self.smooth)
return 1. - dice.mean()
criterion = lambda pred, target: 0.5*F.cross_entropy(pred, target) + 0.5*DiceLoss()(pred, target)
这样设计的原因是:
我们采用余弦退火配合热重启:
python复制scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=20, # 初始周期长度
T_mult=2, # 周期倍增因子
eta_min=1e-6
)
这种策略的优势:
使用Apex库的混合精度训练:
python复制from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
实测效果:
在PASCAL VOC 2012测试集上的表现:
| 模型 | mIoU(%) | 参数量(M) | FPS |
|---|---|---|---|
| FCN-8s | 62.2 | 134.5 | 12.3 |
| U-Net | 68.4 | 31.0 | 23.1 |
| 我们的模型 | 73.8 | 48.7 | 18.6 |
关键发现:
标签处理陷阱:
python复制label[label == 255] = 20 # 假设20是背景类
显存优化技巧:
python复制for i, (x,y) in enumerate(train_loader):
pred = model(x)
loss = criterion(pred, y)/4 # 假设累积4步
loss.backward()
if (i+1)%4 == 0:
optimizer.step()
optimizer.zero_grad()
调试建议:
这个项目最让我意外的是注意力机制对细小物体的提升效果。记得有张测试图片中有只远处的小鸟,普通U-Net完全漏检,而我们的模型却能准确标出。这也让我意识到,在计算机视觉领域,有时候模仿人类的"注意力"机制,确实能带来质的飞跃。