跳到主要内容

Ring Attention 与序列并行

总览见 稀疏注意力总览。Ring 解决 「单卡装不下、但语义仍需全连接 attention」;与 DSANSA减少参与计算的 token 数 正交。

要解决的问题

当序列长度 LL 达到 32K、128K 甚至更长时,即便使用 Flash Attention单层激活与 KV 仍可能超出单卡 HBM。典型瓶颈:

  • 训练:保存 L×LL\times L 中间量或分块累加时的 activation;batch 与序列长度无法同时拉大。
  • 推理:KV Cache 随 LL 线性增长;多卡推理若不做切分,单卡仍要承载完整 KV。

Ring Attention 与 Sequence Parallel(序列并行) 不把 attention 改成稀疏掩码,而是把 序列维上的 K/V(及中间结果)切到多卡,通过环形通信拼出全局 attention 结果。

Ring Attention 核心机制

Ring Attention

直觉

PP 张 GPU,将序列均分为 PP 段,每卡 initially 只持有本段 Query 及对应 K/V 块。要计算全局 attention,每卡的 Query 需要「见过」所有位置的 K/V

Ring 协议:在 P1P-1 轮通信中,每轮将本卡的 K/V block 发给环上下一张卡;各卡用收到的 K/V 块与本地 Q 做 attention 分块累加(配合 online softmax),PP 轮后每卡得到其负责 Query 位置上的完整 attention 输出。

伪代码(概念级)

# 每卡 p 持有 Q_p, K_p, V_p(序列的一段)
for step in 0 .. P-1:
K_recv, V_recv = ring_receive_from_prev_gpu()
partial_attn = flash_attn(Q_local, K_recv, V_recv) # 分块累加
accumulate_online_softmax(partial_attn)
ring_send_to_next_gpu(K_local, V_local) # 轮转本段 KV

复杂度与通信

维度单卡稠密Ring(PP 卡)
每卡峰值 KV/激活O(L)O(L)O(L/P)O(L/P)
全局总 FLOPsO(L2)O(L^2)O(L2)O(L^2)(未减少总计算)
通信量/步与 KV block 大小成正比,约 O(P)O(P)

结论:Ring 是 显存与并行度 优化,不是算法渐近阶优化。

Sequence Parallel

序列并行(SP) 将 LayerNorm、Dropout、部分 FFN 输入等 沿序列维切分,使 activation 不集中在单卡。与 Ring Attention 组合时:

  • SP 降低 非 attention 算子 的激活峰值;
  • Ring 降低 attention 内 KV 与分块计算 的峰值。

性能指标

实践中(如 Megatron-LM、DeepSpeed-Ulysses 生态的多种组合),Sequence Parallel + Ring Attention 可将单卡峰值显存显著压低,使 更长 LL 进入训练

与其它并行方式对比

方式切分维度主要缓解是否保持全连接 attention
Tensor Parallel隐藏维 / 头维参数与单层矩阵
Pipeline Parallel层间激活
Ring / Sequence Parallel序列KV、序列 activation
Ulysses / Context Parallel头或序列+通信重排长上下文 attention 通信
滑动窗口 / DSA算法掩码FLOPs 与访存否(子集 token)
Ulysses 与 Ring 的选型(经验判断)
  • Ring:实现路径成熟,适合「KV 放不下」的训练集群。
  • Ulysses 类:通过 all-to-all 重排 Q/K,有时在特定 PPHH 下通信更优;需框架支持(如 DeepSpeed-Ulysses)。

二者常与 Flash Attention 内核叠加,与 token 稀疏 也可叠加(稀疏掩码在分块内核内实现)。

工程落地

  • 训练:长上下文预训练/SFT 常用多卡 Ring + SP;需关注 NCCL 带宽 与 ring 步数带来的 wall-clock 开销。
  • 推理:多卡推理若不做 tensor parallel,Ring 类方案较少见;推理侧更常先做 KV 压缩(GQA/MLA)token 稀疏(DSA)
  • 与 Flash:分块 attention 累加天然适配 FlashAttention 的 tiling;工业实现多在分布式框架内封装。

局限与风险

  1. 总 FLOPs 不降:集群总算力需求仍随 L2L^2 增长。
  2. 通信瓶颈PP 过大或网络慢时,ring 步数拖累扩展效率。
  3. 实现复杂度:online softmax 跨块合并、梯度同步需与框架深度集成。
  4. 不等同稀疏:不能替代 滑动窗口DSA 降低单卡推理成本。

参考链接