在计算机视觉领域,姿态估计(Pose Estimation)一直是个热门且实用的研究方向。相比传统的OpenPose等方案,YOLOv8 Pose模型将目标检测和关键点检测集成到一个端到端的网络中,实现了速度和精度的完美平衡。最近在实际项目中需要训练一个自定义的YOLOv8 Pose模型来识别特定场景下的人体姿态,这里记录下完整的训练流程和踩坑经验。
提示:YOLOv8 Pose是Ultralytics公司在YOLOv8目标检测基础上扩展的姿态估计模型,支持17个COCO格式的关键点检测,单模型同时输出检测框和关键点坐标。
训练姿态估计模型对硬件要求较高,建议配置:
实测在RTX 3090上,输入640x640图像:
推荐使用conda创建虚拟环境:
bash复制conda create -n yolov8_pose python=3.8
conda activate yolov8_pose
pip install ultralytics albumentations lap
注意:务必安装lap库,这是关键点匹配的依赖,官方文档经常漏提这点。
使用LabelImg或CVAT标注工具时:
关键点标注注意事项:
数据集目录结构建议:
code复制dataset/
├── images/
│ ├── train/
│ └── val/
└── labels/
├── train/
└── val/
创建dataset.yaml:
yaml复制path: /path/to/dataset
train: images/train
val: images/val
# 关键点配置
kpt_shape: [17, 3] # 17个关键点,每个点(x,y,visibility)
flip_idx: [1,0,3,2,5,4,7,6,9,8,11,10,13,12,15,14,16] # 水平翻转时配对的关键点索引
基础训练(从预训练模型微调):
bash复制yolo train pose \
data=dataset.yaml \
model=yolov8n-pose.pt \
epochs=100 \
imgsz=640 \
batch=32 \
device=0 \
workers=8 \
optimizer=Adam \
lr0=0.001 \
cos_lr=True
关键参数解析:
cos_lr: 使用余弦退火学习率调度fliplr=0.5: 默认启用水平翻转增强mask_ratio=4: 关键点可见性采样比例bash复制tensorboard --logdir runs/pose/train
重点关注指标:
metrics/precision(B): 边界框精度metrics/kpt_precision: 关键点精度metrics/kpt_recall: 关键点召回率python复制from ultralytics import YOLO
model = YOLO('runs/pose/train/weights/best.pt')
model.val(conf=0.25, iou=0.6)
在dataset.yaml中添加增强配置:
yaml复制augment:
degrees: 10.0
translate: 0.1
scale: 0.5
shear: 2.0
perspective: 0.0005
flipud: 0.0
fliplr: 0.5
mosaic: 1.0
mixup: 0.0
重要:姿态估计任务中mixup增强通常效果不佳,建议设为0
修改关键点损失权重(源码修改):
python复制# ultralytics/models/utils/loss.py
class PoseLoss(DetLoss):
def __init__(self, model):
super().__init__(model)
self.kpt_loss_gain = 0.1 # 原始值为0.1,可适当增大
self.kpt_obj_loss = True # 是否使用关键点可见性预测
python复制model = YOLO('best.pt')
model.prune(importance_threshold=0.1)
bash复制yolo export model=best.pt format=onnx int8
python复制from ultralytics import YOLO
import cv2
model = YOLO('best.pose.pt')
cap = cv2.VideoCapture(0)
while cap.isOpened():
ret, frame = cap.read()
results = model(frame, conf=0.5)
annotated = results[0].plot()
cv2.imshow('YOLOv8 Pose', annotated)
if cv2.waitKey(1) == ord('q'):
break
bash复制yolo export model=best.pt format=engine device=0
症状:关键点位置偏离实际关节
解决方案:
kpt_loss_gain权重(见4.2节)fliplr增强概率错误信息:CUDA out of memory
处理方法:
batch_size(最低可到8)imgsz=320较小输入尺寸bash复制yolo train ... batch=16 accumulate=2
症状:视频中关键点位置跳动
优化方案:
torch.jit.trace优化推理速度tracker="bytetrack.yaml"python复制results = model.track(frame, persist=True, tracker="bytetrack.yaml")
训练自定义YOLOv8 Pose模型最关键的三个经验: