跳到主要内容

梯度累积与梯度裁剪

要解决的问题

目标全局 batch BglobalB_{\text{global}} 受显存限制,单步只能放 BlocalB_{\text{local}}梯度累积GaccG_{\text{acc}} 步后再更新,等价更大 batch。深层网络梯度范数偶发爆炸会导致 loss spike 或发散,梯度裁剪将更新幅度约束在稳定区间。

核心概念

累积 GaccG_{\text{acc}} 步后:

geff=1Gacci=1Gaccg(i)g_{\text{eff}} = \frac{1}{G_{\text{acc}}} \sum_{i=1}^{G_{\text{acc}}} g^{(i)}

等价全局 batch Bglobal=BlocalKGaccB_{\text{global}} = B_{\text{local}} \cdot K \cdot G_{\text{acc}}KK 为 DP 卡数)。

Global norm clipping:设所有参数梯度拼接向量 gg,阈值 τ\tau

ggmin(1,τg2+ϵ)g \leftarrow g \cdot \min\left(1, \frac{\tau}{\|g\|_2 + \epsilon}\right)
技术作用
累积显存 ↔ 有效 batch
Clip抑制尖刺梯度
LR warmup与累积步数配合

方法/算法

实现模式:

for i, batch in enumerate(loader):
loss = model(batch) / grad_accum_steps
loss.backward()
if (i + 1) % grad_accum_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()

FSDP/DDP:累积步内用 model.no_sync()(DDP)避免中间 AllReduce;最后一步同步。

学习率缩放:线性规则 η=ηBglobal/Bref\eta' = \eta \cdot B_{\text{global}}/B_{\text{ref}} 仅近似;大 batch 常需 warmup 更长 或 sqrt scaling。

工程实践

  • 典型 clipmax_norm=1.0(GPT、LLaMA 类);不稳定时可降到 0.5。
  • 监控:记录 grad_norm 直方图;突刺对应数据 batch 或数值问题。
  • 与 ZeRO:裁剪在 optimizer.step 前对 unshard 梯度或分片梯度统一处理(框架内封装)。
  • 有效吞吐:累积不增算力吞吐,只换等效 batch;要更快需加 GPU。

代表工作

局限与注意点

  • 过大 GaccG_{\text{acc}}:BN 统计(若有)偏差;优化器动量滞后感增强。
  • 裁剪过狠:限制学习速度,欠拟合风险。
  • 与 AdamW:裁剪作用于梯度,权重衰减解耦仍在 optimizer
  • 混合精度:FP16 下溢梯度在裁剪前可能已为 0,见 3.6.1

延伸说明

FSDP/DDP 仅在最后累积步同步梯度,避免通信翻倍。

实践检查清单

  • grad_norm
  • τ\tau
  • warmup

小结

本节核心:grad_norm 与全链路 τ\tau 协同;上线前用检查清单做回归。

相关章节