在2025年NIPS会议上提出的S-GRPO(Serial-Group Decaying-Reward Policy Optimization)方法,针对当前大语言模型推理过程中普遍存在的"过度思考"现象提出了创新性解决方案。所谓"过度思考",指的是模型在生成思维链(Chain-of-Thought)时产生大量冗余推理步骤,这不仅增加了计算开销,有时甚至会因为错误累积而降低最终答案的准确性。
这种现象的根源在于传统强化学习训练范式。现有的结果奖励机制(Outcome Reward)只关注最终答案是否正确,而对中间推理过程缺乏有效调控。就像学生在解题时,老师只看最终答案给分,而不关心解题步骤是否简洁高效,这自然会导致"啰嗦式推理"的产生。
S-GRPO的核心创新在于将强化学习的调控粒度从结果级别细化到推理过程级别。通过三个关键技术——串行组生成、衰减奖励策略和优势计算更新,实现了对思维链生成过程的精细控制。实验证明,这种方法可以在保持甚至提升模型准确率的同时,显著减少推理序列长度(40.4%-61.1%的缩减),这对于降低推理成本、提升响应速度具有重要意义。
与传统GRPO(Group Relative Policy Optimization)的并行多路径采样不同,S-GRPO采用单路径串行生成策略。具体实现过程如下:
这种设计有两大优势:
注意:截断点的选择需要根据任务复杂度动态调整。简单问题可能只需要2-3个退出点,而复杂问题可能需要5个以上。
S-GRPO的核心创新之一是设计了时间衰减的奖励分配机制:
Rₜ = Rₘₐₓ × γ^(t-1)
其中:
这种设计实现了两个关键目标:
为了平衡"早期退出"和"完整推理"两种能力,S-GRPO采用独特的双阶段训练:
完整思维滚动(Full Rollout):
早期退出滚动(Early-exit Rollout):
这种混合训练策略既保留了模型原有的深度推理能力,又新增了智能退出的灵活性,类似于人类解题时在"快速判断"和"深入分析"之间的平衡。
论文中测试了多种主流推理模型,包括:
关键超参数设置:
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 学习率 | 5e-6 | 使用余弦退火调度 |
| γ衰减系数 | 0.7 | 任务简单时可增大至0.8 |
| 批大小 | 32 | 根据显存调整 |
| KL散度系数 | 0.05 | 控制策略更新幅度 |
python复制from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-72B")
python复制def serial_group_sampling(prompt, max_steps=5):
full_output = model.generate(prompt, max_length=1024)
exit_points = []
for t in range(1, max_steps+1):
truncated = early_stopping(full_output, t)
exit_points.append(truncated)
return exit_points
python复制def decaying_reward(exit_points, gamma=0.7):
rewards = []
max_reward = calculate_reward(exit_points[-1]) # 完整路径奖励
for i, point in enumerate(exit_points[:-1]):
r = max_reward * (gamma ** i)
rewards.append(r if is_correct(point) else 0)
rewards.append(max_reward) # 完整路径
return rewards
衰减系数γ的调整:
退出点数量选择:
混合训练比例:
在五个主流数据集上的表现:
| 数据集 | 序列缩减率 | 准确率变化 |
|---|---|---|
| GSM8K | 52.3% | +2.1% |
| AIME 2024 | 61.1% | +3.92% |
| MATH | 47.8% | +1.3% |
| TheoremQA | 40.4% | +0.72% |
| ARC-Challenge | 55.6% | +1.8% |
问题:早期退出过多导致复杂问题准确率下降
问题:奖励稀疏导致训练不稳定
问题:模型忽略中间推理直接猜答案
延迟与准确率的权衡:
计算资源节省:
与传统方法的兼容性:
在实际应用中,我们发现S-GRPO特别适合需要实时响应的场景,如教育领域的自动解题系统、客服机器人等。通过合理设置退出策略,可以在保持90%以上准确率的同时,将响应时间控制在传统方法的60%以内。