跳到主要内容

3D 并行与序列并行

要解决的问题

万亿 token 级预训练需同时解决:样本并行(DP)、层内切分(TP)、层间流水(PP),并进一步处理超长序列导致的激活显存爆炸。3D 并行(DP+TP+PP)是 Megatron 系标配;序列并行(SP)/ Context Parallel(CP) 在序列维切分 attention 计算与激活。

核心概念

全局进程网格(示意):

World=DP×TP×PP\text{World} = \text{DP} \times \text{TP} \times \text{PP}

可选第四维 CP(Context Parallel):序列长度 SS 切为 S/CS / C per rank,Attention 需跨 rank 交换 KV(Ring Attention、DeepSpeed Ulysses 等)。

并行解决瓶颈
DP吞吐、batch size
TP单层参数/激活
PP层数深度
SP/CP序列长度 SS

激活显存(Attention,粗略):

MactShCP 后每卡 S/CM_{\text{act}} \propto S \cdot h \quad \Rightarrow \quad \text{CP 后每卡 } \propto S/C

方法/算法

配置流程:

  1. 定模型规模与目标 SS(如 32k);
  2. 选 TP 组大小(常 ≤8,同机);
  3. 若仍 OOM → 开 CP 或 activation checkpoint;
  4. PP stage 数 SppS_{\text{pp}} 使每 stage 层数均衡;
  5. 剩余 GPU 划为 DP 组数。

Ring Attention:KV 块在 CP rank 间环形传递,完成全局 attention,通信与计算重叠。

工程实践

  • Megatron-Coretensor_model_parallel_sizepipeline_model_parallel_sizecontext_parallel_size
  • DeepSpeed Ulysses:序列维 all-to-all 重排 QKV。
  • 长上下文:与 RoPE 扩展FlashAttention 联调。
  • profiling:Nsight、PyTorch profiler 看 bubble 与 comm 占比。

代表工作

局限与注意点

  • 配置组合爆炸:错误 WORLD_SIZE 整除关系导致启动失败。
  • Checkpoint 复杂:需按 TP/PP/EP 切分规则转换 HF 格式。
  • 小集群:3D 过度切分反而降效;7B/13B 常 FSDP 即可。
  • MoE EP:专家并行是第五维,与 3D 正交,文档见技术报告。

相关章节