SegFormer是近年来Transformer架构在图像分割领域的成功应用代表。作为一款基于Vision Transformer的轻量级语义分割模型,它通过分层设计的多尺度特征提取和高效的注意力机制,在保持较低计算成本的同时实现了优异的性能表现。我在多个工业检测和医疗影像项目中采用SegFormer进行定制开发,发现其在处理小样本数据时展现出的迁移学习能力尤为突出。
训练自定义数据集的核心挑战在于数据准备与模型适配的完整链路打通。不同于使用现成的公开数据集,从原始图像标注到最终模型部署的每个环节都需要针对性处理。本文将基于我在遥感图像分割和病理切片分析中的实战经验,详细拆解从零开始训练SegFormer的全流程关键技术点。
推荐使用Python 3.8+和PyTorch 1.10+的组合,这是经过多个项目验证的稳定版本搭配。以下是关键依赖的安装命令:
bash复制pip install torch==1.10.0 torchvision==0.11.1
pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
pip install mmsegmentation==0.20.0
注意:MMCV与PyTorch版本必须严格匹配,这是导致90%环境问题的根源。建议先确定PyTorch版本后再安装对应MMCV。
SegFormer要求数据集遵循特定的目录结构。这是我为医疗影像项目设计的标准格式:
code复制custom_dataset/
├── img_dir/
│ ├── train/
│ │ ├── case_001.png
│ │ └── case_002.png
│ └── val/
│ ├── case_101.png
│ └── case_102.png
└── ann_dir/
├── train/
│ ├── case_001.png
│ └── case_002.png
└── val/
├── case_101.png
└── case_102.png
标注图像需要是单通道的PNG文件,每个像素值对应类别ID。例如在道路分割任务中:
在configs/_base_/datasets/custom.py中配置增强流水线。这是我针对小样本数据集优化的组合:
python复制train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
实操心得:
cat_max_ratio参数控制目标类别在裁剪区域的最小占比,对不平衡数据集特别有效。在细胞分割任务中设置为0.9可避免关键特征被裁切。
复制configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py为基础模板,关键修改点包括:
python复制data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='CustomDataset',
data_root='data/custom_dataset',
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type='CustomDataset',
data_root='data/custom_dataset',
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))
python复制model = dict(
decode_head=dict(
num_classes=3)) # 根据实际类别数调整
python复制optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006, # 初始学习率
betas=(0.9, 0.999),
weight_decay=0.01)
单卡训练使用:
bash复制python tools/train.py configs/segformer/custom_config.py --work-dir work_dirs/exp1
多卡分布式训练(推荐):
bash复制./tools/dist_train.sh configs/segformer/custom_config.py 4 --work-dir work_dirs/exp1
注意事项:当验证集较小时(<100样本),建议设置
evaluation = dict(interval=2000)减少验证频率,避免训练中断。
使用MMSegmentation内置的日志系统:
bash复制tensorboard --logdir work_dirs/exp1
重点关注以下指标曲线:
我在实际项目中发现,当val/mIoU连续3个epoch不提升时,可以提前终止训练(Early Stopping)。
使用官方测试脚本生成详细指标:
bash复制python tools/test.py configs/segformer/custom_config.py \
work_dirs/exp1/latest.pth \
--eval mIoU aAcc mDice
对于医疗影像等专业领域,建议额外计算类别特定的Dice系数:
python复制metrics = dict(
_delete_=True,
type='DiceMetric',
iou_metrics=['mDice'],
output_dir='eval_results')
将PyTorch模型转换为ONNX格式:
bash复制python tools/pytorch2onnx.py \
configs/segformer/custom_config.py \
work_dirs/exp1/latest.pth \
--output-file model.onnx \
--shape 512 512
避坑指南:遇到
UnsupportedOperatorError时,尝试添加--dynamic-export参数启用动态尺寸导出。
使用OpenVINO进行CPU加速推理的典型流程:
python复制from openvino.runtime import Core
core = Core()
model = core.read_model("model.xml")
compiled_model = core.compile_model(model, "CPU")
# 预处理输入图像
input_tensor = preprocess(image)
# 执行推理
results = compiled_model.infer_new_request({0: input_tensor})
# 后处理输出
mask = postprocess(results[0])
| 问题现象 | 可能原因 | 解决方案 |
|---|---|---|
| CUDA out of memory | 批次过大或图像尺寸过大 | 减小samples_per_gpu或调整crop_size |
| NaN loss | 学习率过高或数据异常 | 检查数据标注,降低学习率10倍 |
| mIoU始终为0 | 类别ID设置错误 | 验证标注像素值是否与num_classes匹配 |
| 训练震荡严重 | 数据分布不均衡 | 启用class_weight或使用Focal Loss |
python复制param_groups = [
dict(lr_mult=0.1, params=backbone_params),
dict(lr_mult=1.0, params=decode_head_params)
]
python复制fp16 = dict(
loss_scale=512.,
grad_clip=dict(max_norm=35, norm_type=2))
python复制loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
class_weight=[0.8, 1.2, 1.5]) # 根据类别频率调整
使用MMSeg的tools/analyze_results.py进行预测可视化:
bash复制python tools/analyze_results.py \
configs/segformer/custom_config.py \
work_dirs/exp1/latest.pth \
--show-dir results_vis
在遥感项目中,我通常会叠加原始图像与预测mask进行人工校验,特别关注边缘区域的预测一致性。