在目标检测和图像分割任务中,类别不平衡和难易样本不平衡是长期困扰研究者的两大难题。以目标检测为例,一张图片中可能包含几十个物体,但背景区域(负样本)往往占据绝大多数像素。这种极端不平衡会导致模型训练时被大量简单负样本主导,难以有效学习关键的正样本特征。
传统交叉熵损失函数对所有样本"一视同仁",无法应对这种不平衡场景。2017年何恺明团队在RetinaNet论文中提出的Focal Loss,通过动态调整样本权重,巧妙解决了这一问题。其核心创新在于:
这种双重调节机制使模型能够:
实际测试表明,在COCO数据集上使用Focal Loss的RetinaNet,其AP指标比当时主流方法提高了3-5个百分点,尤其在小物体检测上提升显著。
二分类交叉熵(BCE)的数学表达式为:
code复制BCE = -[y·ln(p) + (1-y)·ln(1-p)]
其中y∈{0,1}是真实标签,p∈(0,1)是预测概率。
多分类交叉熵(CE)的一般形式为:
code复制CE = -Σ(y_i·ln(p_i))
交叉熵的本质是惩罚预测概率与真实标签的偏离程度。但它存在两个明显缺陷:
为解决类别不平衡问题,常见做法是引入α平衡因子:
code复制CE_α = -α·ln(p) 其中α∈[0,1]
设置原则:
但这种方法仅解决了数量不平衡,未考虑样本难易程度的差异。
Focal Loss的核心创新是引入调制因子(1-p)^γ:
code复制FL = -(1-p)^γ·ln(p)
γ>0时,该因子会产生三种关键效果:
实验表明,γ=2时能在多数任务取得最佳平衡。
结合α和γ的Focal Loss最终形式:
code复制FL = -α(1-p)^γ·ln(p)
参数说明:
当固定α=0.25时:
实验发现:
虽然直觉上α应与类别频率成反比,但实际发现:
建议初始值:
python复制class BinaryFocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, preds, targets):
eps = 1e-7 # 数值稳定性
loss_pos = -self.alpha * (1-preds)**self.gamma * torch.log(preds+eps) * targets
loss_neg = -(1-self.alpha) * preds**self.gamma * torch.log(1-preds+eps) * (1-targets)
return (loss_pos + loss_neg).mean()
关键点说明:
python复制class FocalLoss(nn.Module):
def __init__(self, alpha=None, gamma=2):
super().__init__()
self.alpha = alpha # 应为各类别权重Tensor
self.gamma = gamma
def forward(self, preds, targets):
log_probs = F.log_softmax(preds, dim=1)
probs = torch.exp(log_probs)
# 获取目标类别对应的概率
batch_probs = probs.gather(1, targets.view(-1,1)).squeeze()
batch_log_probs = log_probs.gather(1, targets.view(-1,1)).squeeze()
if self.alpha is not None:
batch_alpha = self.alpha.gather(0, targets)
loss = -batch_alpha * (1-batch_probs)**self.gamma * batch_log_probs
else:
loss = -(1-batch_probs)**self.gamma * batch_log_probs
return loss.mean()
注意事项:
初始设置:
调整策略:
典型组合:
Focal Loss的梯度有两个重要特性:
对易样本(p>0.6):
对难样本(p<0.4):
这种特性带来两个优势:
推荐使用场景:
不推荐场景:
可能原因:
γ过大导致梯度爆炸
α设置不合理
症状:多数类准确率骤降
处理方法:
常见表现:出现NaN值
预防措施:
python复制preds = torch.clamp(preds, 1e-7, 1-1e-7)
改进点:动态调整α使其与类别频率的平方根成反比
code复制α_t = (1 - β)/(1 - β^n_t)
其中n_t是类别t的样本数,β∈[0,1)是超参数
让γ根据类别动态调整:
code复制γ_t = γ_base + λ·log(f_t)
f_t是类别频率,λ是调节系数
将Focal Loss与梯度均衡结合:
实际效果:比原始FL更稳定
监控指标:
与其他技术的配合:
部署注意事项:
在实际项目中,我通常会先使用标准交叉熵训练几个epoch作为baseline,然后逐步引入Focal Loss的参数调节。一个实用的技巧是监控每个batch中难易样本的loss比例,理想状态下难样本的loss应占总loss的40%-60%。如果发现比例异常,就需要及时调整γ参数。