大模型入门:从 MHA 到 GQA,一次讲清 KV Cache 为什么能省显存
大模型入门从 MHA 到 GQA一次讲清 KV Cache 为什么能省显存摘要上一篇讲 MHA 时我们已经知道 KV Cache 会缓存每一层历史 token 的 K/V。继续往下看问题就变成了为什么很多大模型的 Query Head 数量和 KV Head 数量不一样本文从 KV Cache 的显存公式开始拆清 MHA、MQA、GQA 的张量形状、显存差异、手写实现和 PyTorch 接口用法。一、推理显存经常卡在 KV Cache很多人第一次跑本地大模型会以为显存主要被模型参数吃掉。这当然没错。一个 7B 模型即使用 FP16也要十几 GB 级别的参数显存。但进入真实推理后你会发现另一个东西也会涨得很快prompt 越长KV Cache 越大 batch 越大KV Cache 越大 上下文窗口越长KV Cache 越大 并发请求越多KV Cache 越难管理模型参数是加载时就基本固定的KV Cache 是生成过程中随着请求、长度和 batch 增长的。这也是为什么服务端推理框架会认真做 KV Cache 管理vLLM 的 PagedAttention、Hugging Face 的 DynamicCache/StaticCache/QuantizedCache本质上都在处理同一类问题怎么让历史 K/V 既能被快速读取又不要把显存撑爆。而 GQA 正好站在这个问题中间。一句话理解GQA 让多个 Query Head 共享较少的 Key/Value Head从而减少 KV Cache 的存储和读取压力。1. 先回忆KV Cache 到底缓存了什么Decoder-only 大模型推理时一般分成两个阶段阶段输入主要动作Prefill完整 prompt一次性计算 prompt 的每层 K/V并写入 cacheDecode当前新 token只算新 token 的 Q/K/V用新 Q 查询历史 K/VHugging Face 的缓存文档也强调自回归生成是一个 token 一个 token 往后预测KV Cache 会保存过去 token 在注意力层里的 K/V后续 token 可以复用它们避免重复计算。上一篇文章里我们用的 MHA 张量形状是q.shape[batch,num_heads,seq_len,head_dim]k.shape[batch,num_heads,seq_len,head_dim]v.shape[batch,num_heads,seq_len,head_dim]每一层要缓存的是历史 token 的k和vpast_k.shape[batch,num_heads,past_len,head_dim]past_v.shape[batch,num_heads,past_len,head_dim]注意这里缓存的是每一层的 K/V。一个 32 层模型就有 32 份这样的缓存。所以 KV Cache 的显存可以粗略估算为KV Cache bytes batch_size * seq_len * num_layers * 2 * num_kv_heads * head_dim * bytes_per_element这里的2表示 K 和 V 两份。公式里最容易被忽略的是num_kv_heads。MHA 里num_kv_heads num_query_headsGQA 里num_kv_heads num_query_heads这就是 GQA 能省显存的入口。2. 用一组数字算清楚假设有一个简化配置batch_size1seq_len8192num_layers32num_query_heads32head_dim128dtypefp16# 2 bytes如果是传统 MHAnum_kv_heads32KV Cache 大约是1 * 8192 * 32 * 2 * 32 * 128 * 2 bytes 4 GiB如果换成 GQA假设num_kv_heads8KV Cache 大约是1 * 8192 * 32 * 2 * 8 * 128 * 2 bytes 1 GiB同样的 Query Head 数量同样的上下文长度只是把 KV Head 从 32 降到 8缓存就变成原来的四分之一。如果是 MQAnum_kv_heads1KV Cache 会进一步降到128 MiB这只是一个教学估算真实框架还会受到 allocator、block size、padding、并发调度、量化和 kernel 实现影响。但作为面试和工程理解这个公式足够抓住核心。3. MHA、MQA、GQA 的区别可以先用一张表记住结构Query HeadKV Head直觉MHA多个和 Query 一样多每个 Q head 独享一组 K/VMQA多个1 个所有 Q head 共享同一组 K/VGQA多个介于 1 和 Query Head 之间一组 Q head 共享一组 K/V假设num_query_heads32num_kv_heads8group_sizenum_query_heads//num_kv_heads# 4那么 GQA 的意思是Q heads: 0 1 2 3 | 4 5 6 7 | ... | 28 29 30 31 KV head: 0 | 1 | ... | 7每 4 个 Query Head 共享 1 个 KV Head。它不像 MQA 那样把所有 Query Head 都压到同一个 KV Head 上也不像 MHA 那样每个 Query Head 都保留独立 K/V。GQA 原论文的动机也在这里MQA 可以显著提升 decoder 推理速度但可能带来质量下降GQA 使用介于 1 和 Query Head 数之间的 KV Head 数量在效果和推理效率之间做折中。4. 张量形状怎么变MHA 的投影通常是q_proj:hidden_dim-num_q_heads*head_dim k_proj:hidden_dim-num_q_heads*head_dim v_proj:hidden_dim-num_q_heads*head_dimGQA 的投影变成q_proj:hidden_dim-num_q_heads*head_dim k_proj:hidden_dim-num_kv_heads*head_dim v_proj:hidden_dim-num_kv_heads*head_dim也就是说Q 还是很多头K/V 变少了。假设batch2seq_len5num_q_heads32num_kv_heads8head_dim128那么q.shape[2,32,5,128]k.shape[2,8,5,128]v.shape[2,8,5,128]但 attention 计算时q k.transpose(-2, -1)要求 head 维度能对齐。一个教学版做法是把 K/V 按组展开k_expanded.shape[2,32,5,128]v_expanded.shape[2,32,5,128]PyTorch 的scaled_dot_product_attention(enable_gqaTrue)文档里也展示了类似逻辑启用 GQA 时会按 Query Head 和 KV Head 的比例对 key/value 做repeat_interleave。但要注意真实高性能实现不一定真的物理复制 K/V。服务端推理更关心 cache 布局、访存和 kernel 的实现方式。5. 手写一个最小 GQA下面这份代码只保留核心逻辑适合面试讲法Q Head 数可以大于 KV Head 数KV Head 必须能整除 Query HeadK/V 先按较少 head 存储计算 attention 前按组展开cache 里只缓存较少的 KV Head。importmathimporttorchfromtorchimportnndefrepeat_kv(x:torch.Tensor,n_rep:int)-torch.Tensor:# x: [B, H_kv, T, D]ifn_rep1:returnx batch,num_kv_heads,seq_len,head_dimx.shape xx[:,:,None,:,:]xx.expand(batch,num_kv_heads,n_rep,seq_len,head_dim)returnx.reshape(batch,num_kv_heads*n_rep,seq_len,head_dim)classGroupedQueryAttention(nn.Module):def__init__(self,hidden_dim:int,num_q_heads:int,num_kv_heads:int,dropout:float0.0,):super().__init__()asserthidden_dim%num_q_heads0assertnum_q_heads%num_kv_heads0self.hidden_dimhidden_dim self.num_q_headsnum_q_heads self.num_kv_headsnum_kv_heads self.head_dimhidden_dim//num_q_heads self.num_groupsnum_q_heads//num_kv_heads self.q_projnn.Linear(hidden_dim,num_q_heads*self.head_dim)self.k_projnn.Linear(hidden_dim,num_kv_heads*self.head_dim)self.v_projnn.Linear(hidden_dim,num_kv_heads*self.head_dim)self.o_projnn.Linear(num_q_heads*self.head_dim,hidden_dim)self.dropoutnn.Dropout(dropout)def_split_heads(self,x:torch.Tensor,num_heads:int)-torch.Tensor:batch,seq_len,_x.shape xx.view(batch,seq_len,num_heads,self.head_dim)returnx.transpose(1,2)# [B, H, T, D]def_merge_heads(self,x:torch.Tensor)-torch.Tensor:batch,heads,seq_len,head_dimx.shape xx.transpose(1,2).contiguous()returnx.view(batch,seq_len,heads*head_dim)defforward(self,x:torch.Tensor,attn_mask:torch.Tensor|NoneNone,past_key_value:tuple[torch.Tensor,torch.Tensor]|NoneNone,use_cache:boolFalse,):qself._split_heads(self.q_proj(x),self.num_q_heads)kself._split_heads(self.k_proj(x),self.num_kv_heads)vself._split_heads(self.v_proj(x),self.num_kv_heads)ifpast_key_valueisnotNone:past_k,past_vpast_key_value ktorch.cat([past_k,k],dim2)vtorch.cat([past_v,v],dim2)present_key_value(k,v)ifuse_cacheelseNonek_for_attnrepeat_kv(k,self.num_groups)v_for_attnrepeat_kv(v,self.num_groups)scoresq k_for_attn.transpose(-2,-1)scoresscores/math.sqrt(self.head_dim)ifattn_maskisnotNone:scoresscores.masked_fill(attn_mask,float(-inf))weightstorch.softmax(scores,dim-1)weightsself.dropout(weights)outweights v_for_attn outself._merge_heads(out)outself.o_proj(out)returnout,weights,present_key_value测试一下形状xtorch.randn(2,5,4096)gqaGroupedQueryAttention(hidden_dim4096,num_q_heads32,num_kv_heads8,)out,weights,cachegqa(x,use_cacheTrue)print(out.shape)# [2, 5, 4096]print(weights.shape)# [2, 32, 5, 5]print(cache[0].shape)# [2, 8, 5, 128]print(cache[1].shape)# [2, 8, 5, 128]关键点在最后两行。注意力权重仍然是 32 个 Query Headweights.shape[2,32,5,5]但缓存里只有 8 个 KV Headcache[0].shape[2,8,5,128]cache[1].shape[2,8,5,128]这就是 GQA 在 KV Cache 上省显存的直接体现。6. 用 PyTorch 接口怎么写PyTorch 的torch.nn.functional.scaled_dot_product_attention已经有enable_gqa参数。一个最小示例importtorchimporttorch.nn.functionalasF querytorch.randn(2,32,5,128,devicecuda,dtypetorch.float16)keytorch.randn(2,8,5,128,devicecuda,dtypetorch.float16)valuetorch.randn(2,8,5,128,devicecuda,dtypetorch.float16)outF.scaled_dot_product_attention(query,key,value,is_causalTrue,enable_gqaTrue,)print(out.shape)# [2, 32, 5, 128]官方文档里有两个约束很重要number_of_heads_query % number_of_heads_key_value 0 number_of_heads_key number_of_heads_value也就是说Query Head 数必须能被 KV Head 数整除Key Head 数和 Value Head 数必须相同enable_gqa目前仍是实验特性后端支持和张量类型有限制。还有一个容易踩坑的点PyTorch 这个函数里的布尔attn_mask语义和一些 MHA 接口的 padding mask 语义相反。scaled_dot_product_attention里True表示参与 attention迁移代码时要小心。7. 为什么 GQA 主要影响推理如果只做一次完整 forward而且不使用 KV CacheGQA 对峰值显存的影响没有 KV Cache 场景那么直观。真正的收益集中在自回归 decode每一步都要读历史 K/V 历史越长读得越多 并发越高cache 越多 KV Head 越少cache 越小Hugging Face 的优化文档也提到减少 KV 向量数量只有在使用 KV Cache 的自回归解码场景里才特别有意义因为 decode 阶段会反复读取历史 K/V内存带宽很容易成为瓶颈。所以可以这样理解场景GQA 价值训练全序列并行不是主要优化目标Prefill可以减少写入 cache 的 K/V 体积Decode最关键减少每步读取的历史 K/V长上下文服务价值更明显高并发服务价值更明显这也是为什么讲 GQA 时不能只画 attention 公式。要把它放回推理服务的 KV Cache 场景里看。8. 和 vLLM、PagedAttention 有什么关系GQA 解决的是每个 token、每一层、每个请求要存多少 KV Head。PagedAttention 解决的是这些 KV Cache 在显存里怎么分配、分页、复用和读取。二者不是同一层优化但会一起影响推理效率。vLLM 的 PagedAttention 文档里提到key/value cache 会被拆成 block每个 block 存固定数量 token 的 cache。这样做的目标是用更适合服务端调度的方式管理 KV Cache而不是把每个请求都当成一大段连续显存。可以把它们放到同一张图里GQA减少每个 token 的 KV 体积 PagedAttention管理很多 token 的 KV 存放方式 Quantized Cache降低每个元素的字节数 Offloaded Cache把部分 cache 放到 CPU如果只看单次模型结构GQA 像是 attention 结构变化。如果从推理系统看GQA 是 KV Cache 成本控制的一环。9. 常见坑坑 1只改num_kv_heads忘了改投影层输出维度GQA 里 Q/K/V 的 projection 输出维度不一样q_proj-num_q_heads*head_dim k_proj-num_kv_heads*head_dim v_proj-num_kv_heads*head_dim如果还把 K/V 投影到num_q_heads * head_dimcache 就没有省下来。坑 2num_q_heads不能整除num_kv_headsGQA 要按组共享 K/V所以通常要求num_q_heads%num_kv_heads0否则每组 Query Head 没法均匀映射到 KV Head。坑 3把 repeat 后的 K/V 当成 cache 存教学代码为了看懂会在 attention 前做repeat_kv。但 cache 里应该保留较少的 KV Headcache_k.shape[B,H_kv,T,D]如果把展开后的 K/V 存进去cache_k.shape[B,H_q,T,D]显存又回到 MHA 级别了。坑 4只算 cache 容量不看内存带宽KV Cache 不只是占显存。Decode 每一步都要读取历史 K/V所以内存带宽也会成为瓶颈。GQA 的价值不只是少存也包括少读。坑 5把 GQA 当成无损替换GQA 是效果和效率的折中。GQA 原论文的结论是GQA 相比 MQA 更能保留 MHA 的质量同时接近 MQA 的速度收益。但具体效果仍然取决于模型、训练方式、上采样策略和任务。工程上不要把结构变化理解成“免费优化”。它通常是在模型设计或训练阶段就确定好的。10. 面试怎么讲如果面试官问“GQA 和 MHA 有什么区别”可以这样回答MHA 里 Query、Key、Value 的 head 数通常一样每个 Query Head 都有独立的 K/V Head。GQA 保留较多 Query Head但减少 Key/Value Head让一组 Query Head 共享一组 K/V。这样 attention 仍然有多个 Query 子空间但 KV Cache 只需要存较少的 K/V Head。如果继续问“为什么能省显存”可以接KV Cache 每层都会存历史 token 的 K/V大小和num_kv_heads成正比。MHA 里num_kv_heads num_q_headsGQA 里num_kv_heads更小所以 cache 的 K/V 张量更小。比如 32 个 Query Head、8 个 KV Head 时KV Cache 大约是 MHA 的四分之一。如果问“GQA、MQA 怎么区分”可以答MQA 是所有 Query Head 共享一个 KV Head省得最多但表达能力可能受影响。GQA 是折中方案多个 Query Head 分组共享多个 KV Head通常在效率和效果之间更平衡。如果问“代码里最容易错在哪里”可以答第一Q/K/V 投影维度不同第二Query Head 数要能整除 KV Head 数第三cache 里存的是未展开的 K/V不要把 repeat 后的 K/V 存进 cache第四使用 PyTorchenable_gqaTrue时要注意 mask 语义和后端限制。11. 一张速记表问题关键回答GQA 改了什么Query Head 多KV Head 少为什么能省显存KV Cache 大小和num_kv_heads成正比MHA 的 KV Head 数通常等于 Query Head 数MQA 的 KV Head 数1 个GQA 的 KV Head 数介于 1 和 Query Head 数之间代码核心约束num_q_heads % num_kv_heads 0cache 里存什么未展开的 K/V形状是[B, H_kv, T, D]attention 前做什么把 K/V 按组映射到 Query Head最适合讲的场景长上下文、自回归 decode、高并发推理PyTorch 接口scaled_dot_product_attention(..., enable_gqaTrue)总结GQA 可以用三句话记住MHA 每个 Query Head 通常都有自己的 K/VKV Cache 按 Query Head 数增长。GQA 让一组 Query Head 共享较少的 K/V HeadKV Cache 按 KV Head 数增长。它的主要价值出现在自回归推理尤其是长上下文和高并发服务里。所以学 GQA 不要只记住一个缩写。真正要记住的是这条线MHA 张量形状 - KV Cache 显存公式 - KV Head 数量 - Decode 访存压力 - GQA这条线讲清楚了GQA、MQA、KV Cache、长上下文推理优化就能串起来。参考资料Joshua Ainslie et al.GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpointshttps://arxiv.org/abs/2305.13245PyTorchtorch.nn.functional.scaled_dot_product_attentionhttps://docs.pytorch.org/docs/main/generated/torch.nn.functional.scaled_dot_product_attention.htmlHugging Face TransformersCachinghttps://huggingface.co/docs/transformers/main/cache_explanationHugging Face TransformersKV cache strategieshttps://huggingface.co/docs/transformers/main/kv_cacheHugging Face TransformersOptimizing LLMs for Speed and Memoryhttps://huggingface.co/docs/transformers/v4.35.2/llm_tutorial_optimizationvLLMPaged Attentionhttps://docs.vllm.ai/en/latest/design/paged_attention/