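In industrial quality inspection, remote sensing image analysis, and similar fields, conventional horizontal bounding boxes often cannot tightly enclose tilted or rotated targets. YOLO26 OBB (Oriented Bounding Box) detection addresses this by localizing each target with a rectangle that carries a rotation angle. This article walks through the full workflow step by step, from data annotation to model training.
Tip: compared with horizontal-box detection, OBB detection can improve detection accuracy by 10%-30% in scenarios such as aerial imagery and defect detection on industrial parts.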
LabelImg-Rotated or CVAT is recommended for rotated-box annotation. Taking LabelImg-Rotated as an example:
```bash
git clone https://github.com/cgvict/labelImg-Rotated
cd labelImg-Rotated
pip install -r requirements.txt
python labelImg.py
```
```json
{
  "version": "4.5.6",
  "flags": {},
  "shapes": [
    {
      "label": "crack",
      "points": [[x1,y1], [x2,y2], [x3,y3], [x4,y4]],
      "shape_type": "polygon"
    }
  ]
}
```
YOLO26 OBB expects a specific txt format, with one object per line:
```text
class_id x_center y_center width height angle_radians
```
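To make the label format concrete, here is a minimal sketch that parses one label line and recovers the four corner points in pixels. The function names `parse_obb_line` and `obb_to_corners` are illustrative and not part of YOLO26. The sketch also assumes the angle is measured from the x-axis in image coordinates, and that width and height are normalized by image width and image height respectively.

```python
import math


def parse_obb_line(line):
    """Split one label line into (class_id, xc, yc, w, h, angle_rad)."""
    parts = line.split()
    if len(parts) != 6:
        raise ValueError(f"expected 6 fields, got {len(parts)}")
    class_id = int(parts[0])
    xc, yc, w, h, angle = (float(p) for p in parts[1:])
    return class_id, xc, yc, w, h, angle


def obb_to_corners(xc, yc, w, h, angle, img_width, img_height):
    """Return the four corners of a normalized box in pixel coordinates."""
    # Undo the normalization first so the rotation happens in pixel space.
    cx, cy = xc * img_width, yc * img_height
    pw, ph = w * img_width, h * img_height
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    corners = []
    for dx, dy in ((-pw / 2, -ph / 2), (pw / 2, -ph / 2),
                   (pw / 2, ph / 2), (-pw / 2, ph / 2)):
        # Rotate the corner offset, then shift it to the box center.
        corners.append((cx + dx * cos_a - dy * sin_a,
                        cy + dx * sin_a + dy * cos_a))
    return corners
```

Drawing the returned corners on top of the image is a quick way to check that converted labels line up with the original annotations.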
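Core logic of the conversion script:
```python
import json
import math

import cv2
import numpy as np


def convert(json_path, txt_path, img_width, img_height, class_names):
    # img_width, img_height and class_names (ordered label list) are passed in
    # by the caller; the annotation sample above does not carry them.
    with open(json_path) as f:
        data = json.load(f)
    with open(txt_path, 'w') as f:
        for shape in data['shapes']:
            class_id = class_names.index(shape['label'])
            # cv2.minAreaRect needs float32 (or int32) points
            points = np.array(shape['points'], dtype=np.float32)
            (x, y), (w, h), angle = cv2.minAreaRect(points)
            # Make sure the long side is the width
            if w < h:
                w, h = h, w
                angle += 90.0
            # Convert to radians and wrap into [-pi/2, pi/2)
            angle_rad = (math.radians(angle) + math.pi / 2) % math.pi - math.pi / 2
            # Normalize by the image size
            x_center, y_center = x / img_width, y / img_height
            w_norm, h_norm = w / img_width, h / img_height
            line = f"{class_id} {x_center} {y_center} {w_norm} {h_norm} {angle_rad}\n"
            f.write(line)
```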
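Note: angle handling is the key part of the conversion. YOLO26 requires the angle to lie within [-π/2, π/2], and width must always correspond to the long side.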
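The two constraints in the note can be enforced with a small helper. This is a sketch under stated assumptions, not YOLO26's own code: `normalize_obb` is a hypothetical name, and it wraps the angle into the half-open range [-π/2, π/2), so an angle of exactly π/2 is mapped to -π/2, which describes the same rectangle.

```python
import math


def normalize_obb(w, h, angle_rad):
    """Return (w, h, angle) with w >= h and angle in [-pi/2, pi/2)."""
    if w < h:
        # Swapping the sides describes the same rectangle rotated by 90 degrees.
        w, h = h, w
        angle_rad += math.pi / 2
    # A rectangle is unchanged by a 180-degree turn, so wrap with period pi.
    angle_rad = (angle_rad + math.pi / 2) % math.pi - math.pi / 2
    return w, h, angle_rad
```

For example, `normalize_obb(2.0, 5.0, math.radians(60))` swaps the sides to (5.0, 2.0) and wraps the resulting 150° to about -30°.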
A Python 3.8 + PyTorch 1.12 + CUDA 11.3 environment is recommended:
```bash
conda create -n yolo26 python=3.8
conda activate yolo26
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
git clone https://github.com/xxx/YOLO26
cd YOLO26
pip install -r requirements.txt
```
Create the dataset.yaml file:
```yaml
path: /path/to/dataset
train: images/train
val: images/val
test: images/test
names:
  0: crack
  1: scratch
  2: corrosion
```
Example directory structure:
```text
dataset/
├── images/
│   ├── train/
│   ├── val/
│   └── test/
└── labels/
    ├── train/
    ├── val/
    └── test/
```
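Before training, it can help to confirm that the layout above is complete. The following sketch lists images that have no label file. It assumes each label file shares its image's base name with a .txt extension (the usual YOLO-style convention, not confirmed for YOLO26). `find_unlabeled` and `IMAGE_SUFFIXES` are illustrative names.

```python
from pathlib import Path

# Image extensions to check; adjust to match the dataset (assumption).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def find_unlabeled(dataset_root, split):
    """Return names of images in images/<split> with no txt in labels/<split>."""
    image_dir = Path(dataset_root) / "images" / split
    label_dir = Path(dataset_root) / "labels" / split
    missing = []
    for image_path in sorted(image_dir.iterdir()):
        if image_path.suffix.lower() not in IMAGE_SUFFIXES:
            continue
        if not (label_dir / f"{image_path.stem}.txt").is_file():
            missing.append(image_path.name)
    return missing
```

Calling `find_unlabeled("/path/to/dataset", "train")` for each split gives a list to review. An empty list means every image has a matching label file.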
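Key training parameters explained:
```bash
# --rect          rectangular training, which speeds up training
# --noautoanchor  disable automatic anchor computation
# --rotate        probability of the rotation augmentation
# --degrees       maximum rotation angle
# (Comments cannot follow a trailing backslash, so they are listed here.)
python train.py \
    --data dataset.yaml \
    --cfg models/yolov5s-obb.yaml \
    --weights '' \
    --batch-size 16 \
    --epochs 100 \
    --img-size 640 \
    --device 0 \
    --hyp data/hyps/hyp.obb.yaml \
    --rect \
    --noautoanchor \
    --rotate 0.5 \
    --degrees 180
```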
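Practical tip: for the first run, use a small model (yolov5s-obb) to validate the pipeline quickly, then switch to a larger model for the real training run.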
Symptom: the predicted box angle jumps between 0° and 90°
Fix:
```python
class OBBLoss:
    def __init__(self):
        self.angle_weight = 0.2  # default is 0.1; can be raised moderately
```
Optimization:
```yaml
head:
  [[-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # add the P2 branch
   [-1, 3, C3, [256, False]],
   ...]
```
Mitigation strategies:
```bash
python train.py --batch-size 8 --accumulate 2  # equivalent to bs=16
```
```bash
python train.py --half
```
```bash
python train.py --img-size 512
```
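The "equivalent to bs=16" comment on the first command rests on gradient accumulation. Below is a framework-free sketch of why averaging the gradients of two micro-batches of 8 matches one batch of 16, using a toy model y = w·x with a mean-squared-error loss. It assumes `--accumulate` in YOLO26 averages gradients this way, which the source does not confirm. Both function names are illustrative.

```python
def mse_grad(w, batch):
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, micro_batch_size):
    """Average the gradients of equally sized micro-batches."""
    chunks = [batch[i:i + micro_batch_size]
              for i in range(0, len(batch), micro_batch_size)]
    # Dividing by the number of chunks keeps the scale of a full-batch mean.
    return sum(mse_grad(w, chunk) for chunk in chunks) / len(chunks)
```

The equivalence holds for the gradient only. Layers that compute per-batch statistics, such as batch normalization, still see micro-batches of 8.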
Basic test command:
```bash
python detect.py \
    --weights runs/train/exp/weights/best.pt \
    --source test_images/ \
    --conf 0.25 \
    --iou 0.45 \
    --device 0 \
    --save-txt \
    --save-conf \
    --hide-labels
```
Advanced features:
--obb-metrics computes metrics specific to rotated boxes:
```text
Oriented IoU: 0.78
Angle Error: 5.2°
```
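How --obb-metrics computes its angle error is not shown here. One reasonable definition, sketched below as an assumption, treats the angle as periodic with period 180°, so that predictions of 89° and -89° count as 2° apart rather than 178°. The function names are hypothetical.

```python
import math


def angle_error_deg(pred_rad, target_rad):
    """Smallest absolute angle difference in degrees, with period 180 degrees."""
    diff = abs(pred_rad - target_rad) % math.pi
    return math.degrees(min(diff, math.pi - diff))


def mean_angle_error_deg(pairs):
    """Average angle_error_deg over (pred, target) pairs given in radians."""
    return sum(angle_error_deg(p, t) for p, t in pairs) / len(pairs)
```

Using the periodic difference keeps boxes near the ±90° boundary from dominating the metric.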
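Compress the model with pruning:
```python
import torch
import torch.nn.utils.prune as prune

# Assumes best.pt stores the full nn.Module (YOLOv5-style checkpoints keep it
# under the 'model' key instead).
model = torch.load('best.pt')
# global_unstructured expects (module, parameter_name) pairs, not the model
params = [(m, 'weight') for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
prune.global_unstructured(
    params,
    pruning_method=prune.L1Unstructured,
    amount=0.3,  # prune 30% of the conv weights
)
for module, name in params:
    prune.remove(module, name)  # make the pruning permanent
torch.save(model, 'pruned.pt')
```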
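To show what global L1 unstructured pruning does to the weights, here is a dependency-free sketch that zeroes the smallest-magnitude fraction across all layers at once. It illustrates the idea only and is not PyTorch's implementation. `global_l1_prune` is a hypothetical helper, and layers are represented as flat lists of floats.

```python
def global_l1_prune(layers, amount):
    """Zero the `amount` fraction of smallest-magnitude weights across all layers.

    `layers` maps a layer name to a flat list of weights; a new dict is returned.
    """
    all_magnitudes = sorted(abs(w) for weights in layers.values() for w in weights)
    n_prune = int(round(amount * len(all_magnitudes)))
    if n_prune == 0:
        return {name: list(weights) for name, weights in layers.items()}
    threshold = all_magnitudes[n_prune - 1]
    # Ties at the threshold are pruned too, so slightly more than `amount` may go.
    return {name: [0.0 if abs(w) <= threshold else w for w in weights]
            for name, weights in layers.items()}
```

Because the threshold is shared across layers, a layer whose weights are mostly small loses more of them than a layer with large weights. This is the difference between global and per-layer pruning.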
Example conversion command:
```bash
python export.py \
    --weights best.pt \
    --include engine \
    --device 0 \
    --half \
    --dynamic \
    --simplify
```
Notes for deployment:
```cpp
#include <cmath>

void decodeOBB(float* output, float* boxes) {
    // Convert to x, y, w, h, angle format
    float angle = std::atan2(output[4], output[5]);
    // Other decoding logic...
}
```
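The C++ snippet recovers the angle with atan2(output[4], output[5]), which suggests that channels 4 and 5 hold sine-like and cosine-like values. That reading is an assumption, since the head's output layout is not specified here. Under it, the encode/decode round trip can be sketched in Python as follows, with illustrative function names.

```python
import math


def encode_angle(angle_rad):
    """Encode an angle as the (sin, cos) pair a network head could regress."""
    return math.sin(angle_rad), math.cos(angle_rad)


def decode_angle(sin_value, cos_value):
    """Recover the angle; atan2 tolerates outputs that are not unit length."""
    return math.atan2(sin_value, cos_value)
```

Because atan2 depends only on the ratio and signs of its arguments, the decoded angle is unaffected when the network scales both channels by the same positive factor.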
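Configure OBB-specific augmentation in hyp.obb.yaml:
```yaml
rotate: 0.5         # probability of rotation augmentation
degrees: 180        # rotation angle range
perspective: 0.001  # perspective transform
mixup: 0.1          # image mixup
copy_paste: 0.2     # copy-paste object augmentation
```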
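Add an SE attention module in models/common.py:
```python
import torch.nn as nn


class SEBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        # Squeeze: global average pooling to one value per channel
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        hidden = max(c // 16, 1)  # keep at least one unit when c < 16
        # Excitation: bottleneck MLP producing per-channel weights in (0, 1)
        self.fc = nn.Sequential(
            nn.Linear(c, hidden),
            nn.ReLU(),
            nn.Linear(hidden, c),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y  # rescale each channel
```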
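Improved OBB loss computation:
```python
import torch.nn as nn
import torch.nn.functional as F


class OBBLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # KullbackLeiblerDivergence is not defined in this snippet; it must be
        # provided elsewhere in the project.
        self.kld_loss = KullbackLeiblerDivergence()

    def forward(self, pred, target):
        # Position loss
        xy_loss = F.mse_loss(pred[:, :2], target[:, :2])
        # Size loss
        wh_loss = F.smooth_l1_loss(pred[:, 2:4], target[:, 2:4])
        # Angle loss (using KLD)
        angle_loss = self.kld_loss(pred[:, 4:6], target[:, 4:6])
        return xy_loss + wh_loss + angle_loss
```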
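In rotated detection, "KLD" often refers to the KL divergence between two 2D Gaussians fitted to the predicted and target boxes. Whether the `KullbackLeiblerDivergence` used above does this is not stated, so the following is a sketch under that assumption rather than the article's loss. It works on whole boxes in (x, y, w, h, angle) form, takes the covariance as R·diag(w²/4, h²/4)·Rᵀ, and requires positive w and h.

```python
import math


def box_to_gaussian(x, y, w, h, angle):
    """Mean and 2x2 covariance (sxx, sxy, syy) of the Gaussian fitted to a box."""
    a, b = w * w / 4.0, h * h / 4.0
    c, s = math.cos(angle), math.sin(angle)
    sxx = a * c * c + b * s * s
    sxy = (a - b) * s * c
    syy = a * s * s + b * c * c
    return (x, y), (sxx, sxy, syy)


def gaussian_kld(pred_box, target_box):
    """KL(N_pred || N_target) for two (x, y, w, h, angle) boxes."""
    (px, py), (pxx, pxy, pyy) = box_to_gaussian(*pred_box)
    (tx, ty), (txx, txy, tyy) = box_to_gaussian(*target_box)
    det_p = pxx * pyy - pxy * pxy
    det_t = txx * tyy - txy * txy
    # Inverse of the symmetric 2x2 target covariance.
    ixx, ixy, iyy = tyy / det_t, -txy / det_t, txx / det_t
    trace = ixx * pxx + 2.0 * ixy * pxy + iyy * pyy
    dx, dy = tx - px, ty - py
    mahalanobis = ixx * dx * dx + 2.0 * ixy * dx * dy + iyy * dy * dy
    return 0.5 * (trace + mahalanobis - 2.0 + math.log(det_t / det_p))
```

A useful property of this formulation is that (w, h, θ) and (h, w, θ + π/2) map to the same Gaussian. The divergence is therefore insensitive to the width/height swap that makes raw angle regression jump by 90°.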
In real projects, I have found three key lessons: