KV Cache 原理
要解决的问题
自回归 Decode 每步只新增一个 token,若每步对整段历史重算注意力,复杂度为 且浪费已算过的 Key/Value。KV Cache 缓存各层历史 ,使每步仅对最新 token 做投影并 attend 到缓存,是长上下文推理的默认标配。
核心概念
对第 层、第 个头,时刻 的注意力:
Cache 内容:每层存储已生成位置的 (或 GQA 下更少的 KV 头)。
显存估算(单请求、FP16/BF16,忽略碎片):
其中 =层数,=序列长度,=KV 头数(GQA 时 ),=每头维度,因子 2 为 K 与 V。
| 变量 | Llama-3-8B 量级( 示意) | 说明 |
|---|---|---|
| 32 | 层数 | |
| 8(GQA) | 减少 KV 头 | |
| 128 | hidden/head | |
| KV ≈ 32×8192×8×128×2×2B ≈ 1GB | 随 线性增 |
方法 / 与 Prefill-Decode 配合
- Prefill:对 prompt 一次前向,填充全部层的 KV。
- Decode:每步输入 shape
[batch, 1],past_key_values传入 transformer;输出 logits 仅最后一位置。 - GQA/MQA:Query 头数多、KV 头数少,Cache 按 KV 头存储(见 2.3 多头注意力)。
与 5.2.2 PagedAttention 结合:逻辑上仍是 KV Cache,物理上用非连续页管理显存。
工程实践
- 带宽瓶颈:Decode 常 memory-bound(读 KV 量 ),优化方向为 FP8 KV、量化(5.3)、FlashAttention(5.2.3)。
- 多请求:每会话独立 KV;连续批处理下 batch 维合并(5.6.2)。
- 可观测:监控
gpu_cache_usage_perc、OOM 时优先减max_model_len或启用 paging。
代表工作
- Papernot 等早期 RNN cache;Transformer 时代见 Hugging Face
use_cache=True - Dao et al., FlashAttention(IO 感知,仍兼容 KV 语义)
- vLLM PagedAttention 论文与实现
实践检查清单
- 固定评测/推理配置(温度、max_tokens、parser 版本)便于回归
- 记录硬件:GPU 型号、驱动、框架 commit
- 对比基线:未优化前 TTFT/TPOT 或 Acc
- 文档化失败案例:OOM、解析失败率、拒答率
- 交叉阅读本章「相关章节」避免孤立优化