MLA

2025726

16:17

参考:
苏剑林. (May. 13, 2024). 《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA 》[Blog post].   

1.MHA
MHA(Multi-Head Attention),也就是多头注意力,是开山之作《Attention is all you need》所提出的一种Attention形式,可以说它是当前主流LLM的基础工作。在数学上,多头注意力MHA等价于多个独立的单头注意力的拼接,
未命名图片.jpg 计算机生成了可选文字:
Ot
Ot
ton,<“鬱0
eIRdk
eIRdk
Wq
W
=CiVV
=CiVV
eRdxdk
eRdxdk
eRdxd
其中,﷐𝒐﷮𝑡﷯表示第t个token的attention之后的embedding,﷐﷐𝒐﷮𝑡﷯﷮(𝑠)﷯表示第t个token在第s个attention head的embedding。
简单起见,这里省略了Attention矩阵的缩放因子。实践上,常见的设置是dk=dv=d/h,对于LLAMA2-7b有d=4096,h=32,dk=dv=128,LLAMA2-70b则是d=8192,h=64,dk=dv=128

由于这里只考虑了主流的自回归LLM所用的Causal Attention,因此在token by token递归生成时,新预测出来的第t+1
个token,并不会影响到已经算好的k(s)≤t,v(s)≤t,因此这部分结果我们可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的KV Cache。而后面的MQA、GQA、MLA,都是围绕“如何减少KV Cache同时尽可能地保证效果”这个主题发展而来的产物。

为什么降低KV Cache的大小如此重要?
众所周知,一般情况下LLM的推理都是在GPU上进行,单张GPU的显存是有限的,一部分我们要用来存放模型的参数和前向计算的激活值,这部分依赖于模型的体量,选定模型后它就是个常数;另外一部分我们要用来存放模型的KV Cache,这部分不仅依赖于模型的体量,还依赖于模型的输入长度,也就是在推理过程中是动态增长的,当Context长度足够长时,它的大小就会占主导地位,可能超出一张卡甚至一台机(8张卡)的总显存量。

在GPU上部署模型的原则是:能一张卡部署的,就不要跨多张卡;能一台机部署的,就不要跨多台机。这是因为“卡内通信带宽 > 卡间通信带宽 > 机间通信带宽”,由于“木桶效应”,模型部署时跨的设备越多,受设备间通信带宽的的“拖累”就越大,事实上即便是单卡H100内SRAM与HBM的带宽已经达到了3TB/s,但对于Short Context来说这个速度依然还是推理的瓶颈,更不用说更慢的卡间、机间通信了。

所以,减少KV Cache的目的就是要实现在更少的设备上推理更长的Context,或者在相同的Context长度下让推理的batch size更大,从而实现更快的推理速度或者更大的吞吐总量。当然,最终目的都是为了实现更低的推理成本。

2.MQA
MQA,即“Multi-Query Attention”,是减少KV Cache的一次非常朴素的尝试,首次提出自《Fast Transformer Decoding: One Write-Head is All You Need》,这已经是2019年的论文了,这也意味着早在LLM火热之前,减少KV Cache就已经是研究人员非常关注的一个课题了。
MQA的思路很简单,直接让所有Attention Head共享同一个K、V,用公式来说,就是取消MHA所有的k,v
的上标(s):
未命名图片.jpg 计算机生成了可选文字:
O
ot
Ot
ton,
=W
=W
,Ot
eIRdk
eRdk
Et<t。№(q
Wq
WV
eRdxdk
eIR×
eRdxd
个token,并不会影响到已经算好的k(s)≤t,v(s)≤t,因此这部分结果我们可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的KV Cache。而后面的MQA、GQA、MLA,都是围绕“如何减少KV Cache同时尽可能地保证效果”这个主题发展而来的产物。
众所周知,一般情况下LLM的推理都是在GPU上进行,单张GPU的显存是有限的,一部分我们要用来存放模型的参数和前向计算的激活值,这部分依赖于模型的体量,选定模型后它就是个常数;另外一部分我们要用来存放模型的KV Cache,这部分不仅依赖于模型的体量,还依赖于模型的输入长度,也就是在推理过程中是动态增长的,当Context长度足够长时,它的大小就会占主导地位,可能超出一张卡甚至一台机(8张卡)的总显存量。
在GPU上部署模型的原则是:能一张卡部署的,就不要跨多张卡;能一台机部署的,就不要跨多台机。这是因为“卡内通信带宽 > 卡间通信带宽 > 机间通信带宽”,由于“木桶效应”,模型部署时跨的设备越多,受设备间通信带宽的的“拖累”就越大,事实上即便是单卡H100内SRAM与HBM的带宽已经达到了3TB/s,但对于Short Context来说这个速度依然还是推理的瓶颈,更不用说更慢的卡间、机间通信了。
MQA,即“Multi-Query Attention”,是减少KV Cache的一次非常朴素的尝试,首次提出自《Fast Transformer Decoding: One Write-Head is All You Need》,这已经是2019年的论文了,这也意味着早在LLM火热之前,减少KV Cache就已经是研究人员非常关注的一个课题了。
未命名图片.jpg 计算机生成了可选文字:
O
ot
Ot
ton,
=W
=W
,Ot
eIRdk
eRdk
Et<t。№(q
Wq
WV
eRdxdk
eIR×
eRdxd
使用MQA的模型包括PaLM、StarCoder、Gemini等。很明显,MQA直接将KV Cache减少到了原来的1/h
,这是非常可观的,单从节省显存角度看已经是天花板了。效果方面,目前看来大部分任务的损失都比较有限,且MQA的支持者相信这部分损失可以通过进一步训练来弥补回。此外,注意到MQA由于共享了K、V,将会导致Attention的参数量减少了将近一半,而为了模型总参数量的不变,通常会相应地增大FFN/GLU的规模,这也能弥补一部分效果损失。

3.GQA
然而,也有人担心MQA对KV Cache的压缩太严重,以至于会影响模型的学习效率以及最终效果。为此,一个MHA与MQA之间的过渡版本GQA(Grouped-Query Attention)应运而生,出自论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》。
GQA的思想也很朴素,它就是将所有Head分为g个组(g可以整除h),每组共享同一对K、V,用数学公式表示为
未命名图片.jpg 计算机生成了可选文字:
Ot
<老
tonqt,
qt
([sg/hl)
k
([sg/hl)
Ot
([sg/hl)
<t
e]Rdk
(Isg/hl)T([sg/hl)
([sg/hl)T
'eR×
([sg/hl)eRdk厂
([sg/hl)
eIR×
([sg/hl)eRd",厂
([sg/hl)
eIR×

总结:MHA有n_head个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯,MQA有1个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯,GQA有n_group个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯。
,这是非常可观的,单从节省显存角度看已经是天花板了。效果方面,目前看来大部分任务的损失都比较有限,且MQA的支持者相信这部分损失可以通过进一步训练来弥补回。此外,注意到MQA由于共享了K、V,将会导致Attention的参数量减少了将近一半,而为了模型总参数量的不变,通常会相应地增大FFN/GLU的规模,这也能弥补一部分效果损失。
然而,也有人担心MQA对KV Cache的压缩太严重,以至于会影响模型的学习效率以及最终效果。为此,一个MHA与MQA之间的过渡版本GQA(Grouped-Query Attention)应运而生,出自论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》。
总结:MHA有n_head个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯,MQA有1个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯,GQA有n_group个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯。

GQA最知名的使用者,大概是Meta开源的LLAMA2-70B,以及LLAMA3全系列,此外使用GQA的模型还有TigerBot、DeepSeek-V1、StarCoder2、Yi、ChatGLM2、ChatGLM3等,相比使用MQA的模型更多。

在llama2/3-70B中,GQA的g=8,其他用了GQA的同体量模型基本上也保持了这个设置,这并非偶然,而是同样出于推理效率的考虑。我们知道,70B这个体量的模型,如果不进行极端的量化,那么不可能部署到单卡(A100/H100 80G)上。单卡不行,那么就能单机了,一般情况下一台机可以装8张卡,刚才我们说了,Attention的每个Head实际上是独立运算然后拼接起来的,当g=8
时,正好可以每张卡负责计算一组K、V对应的Attention Head,这样可以在尽可能保证K、V多样性的同时最大程度上减少卡间通信。

4.MLA
  1)PART1
GQA虽然对MQA进行了改进,但相比于MHA,仍然存在对KV的压缩。
MLA的做法是,假设token embedding ﷐𝒙﷮𝑖﷯存在隐空间latent vector的表示﷐𝒄﷮𝑖﷯, 在缓存时,能否只缓存这个低维的﷐𝒄﷮𝒊﷯ ,而不缓存k和v,这样就能大大减少kv cache。与此同时,仍然采用MHA的做法,保留n个head的每个head具有独立的k和v(即有n_head个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯),而不是像MQA或GQA,存在多个头共享k和v的情况。
在计算k和v时,先将﷐𝒙﷮𝑖﷯投影到潜空间(低秩空间),得到﷐𝒄﷮𝑖﷯,然后再乘以﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯,得到k和v:
未命名图片.jpg 计算机生成了可选文字:
Ot
Ot
æt•VVeIRdk
q
'(司eRdk
k
=eRd"
CiVVceRd
0
'eR×
厂(司eIRdc×
厂(司eIRdcxd
!!注意,这里的﷐﷐𝑾﷮𝑘﷯﷮(𝑠)﷯是﷐𝑑﷮𝑐﷯∗﷐𝑑﷮𝑘﷯的,因为是对隐向量﷐𝒄﷮𝑖﷯进行操作,而不是﷐𝒙﷮𝑖﷯,而MHA和MQA中的﷐﷐𝑾﷮𝑘﷯﷮(𝑠)﷯是𝑑∗﷐𝑑﷮𝑘﷯的!

进行了这样的修改后,MLA在推理阶段计算q*k时,
未命名图片.jpg 
在llama2/3-70B中,GQA的g=8,其他用了GQA的同体量模型基本上也保持了这个设置,这并非偶然,而是同样出于推理效率的考虑。我们知道,70B这个体量的模型,如果不进行极端的量化,那么不可能部署到单卡(A100/H100 80G)上。单卡不行,那么就能单机了,一般情况下一台机可以装8张卡,刚才我们说了,Attention的每个Head实际上是独立运算然后拼接起来的,当g=8
MLA的做法是,假设token embedding ﷐𝒙﷮𝑖﷯存在隐空间latent vector的表示﷐𝒄﷮𝑖﷯, 在缓存时,能否只缓存这个低维的﷐𝒄﷮𝒊﷯ ,而不缓存k和v,这样就能大大减少kv cache。与此同时,仍然采用MHA的做法,保留n个head的每个head具有独立的k和v(即有n_head个﷐𝑾﷮𝒌﷯和﷐𝑾﷮𝒗﷯),而不是像MQA或GQA,存在多个头共享k和v的情况。
未命名图片.jpg 
相比于MHA的推理阶段计算q*k时(只需要计算﷐﷐﷐𝒙﷮𝑖﷯∗𝑾﷮𝑞﷯﷮(𝑠)﷯,k是缓存的,不需要计算),MLA多了计算﷐﷐﷐𝒄﷮𝑖﷯∗𝑾﷮𝑘﷯﷮(𝑠)﷯这一步,引入了额外计算量。解决方法是通过结合律,将﷐﷐𝑾﷮𝑞﷯﷮(𝑠)﷯﷐﷐𝑾﷮𝑘﷯﷮﷐𝑠﷯𝑇﷯合并为一个新的矩阵。同理,在﷐𝒐﷮𝑡﷯后面我们还有一个投影矩阵﷐𝑾﷮𝒐﷯,于是﷐﷐𝒗﷮𝑖﷯﷮(𝑠)﷯= ﷐﷐﷐𝒄﷮𝑖﷯∗𝑾﷮𝑣﷯﷮(𝑠)﷯中的﷐﷐𝑾﷮𝑣﷯﷮(𝑠)﷯也可以吸收到后面的投影矩阵﷐𝑾﷮𝒐﷯中去(这个地方没明白,维度不一致怎么吸收?),也就是说此时KV Cache只需要存下所有的﷐𝒄﷮𝑖﷯就行,而不至于存下所有的k和v。注意到﷐𝒄﷮𝒊﷯跟(s)无关,也就是说是所有头共享的。

2)PART2
MLA有一个难以绕开的缺陷——不兼容RoPE(旋转位置编码)。
刚才我们说了,MLA的关键一步是将﷐﷐𝑾﷮𝑞﷯﷮(𝑠)﷯﷐﷐𝑾﷮𝑘﷯﷮﷐𝑠﷯𝑇﷯合并为一个新的矩阵作为Q的投影矩阵,这个矩阵是固定不变的,但如果加了RoPE的话,这一步就无法实现了。这是因为RoPE是跟位置相关的,加入RoPE后又引入了额外的计算量﷐𝓡﷮𝑡−𝑖﷯:
未命名图片.jpg 计算机生成了可选文字:
(司T
W(s)T
T
前段时间,笔者也很荣幸跟DeepSeek团队讨论过这个问题,但这个问题可以说非常本质,所以当时笔者实际上也没能提出什么有效的建议。最简单的方式是放弃RoPE,换用其他基于Attention Bias的位置编码,如﷟HYPERLINK "https://spaces.ac.cn/archives/9431#ALIBI"ALIBI,但DeepSeek的实验显示它明显不如RoPE(注意,MLA不是不能加RoPE,而是加了RoPE之后无法用恒等变换技巧来减少KV Cache),笔者也提议过换﷟HYPERLINK "https://spaces.ac.cn/archives/9431#Sandwich"Sandwich,它不像ALIBI单调衰减到负无穷,估计效果会好些,但感觉是治标不治本。还有一个折中的办法是将旋转这个操作在q,k之前进行,也就是对﷐𝒄﷮𝑖﷯进行旋转,而不是对﷐𝒒﷮𝑖﷯和﷐𝒌﷮𝑖﷯旋转,此时需要将﷐𝒒﷮𝑖﷯输入也改为﷐𝒄﷮𝑖﷯,然后RoPE加在﷐𝒄﷮𝑖﷯之后,即
未命名图片.jpg 计算机生成了可选文字:
(8)
这样'Rt就可以吸收到0中去,但这样就没有R,,,R,,=Ran-n的运篡了,此时的ROPE不再是诵讨绝
对亻立置实现相对亻立置,而单纯是在Q、K上加绝对位置,让模型自己想办法提炼相对位置信息。
相比于原始的ROPE,这种方法使得q和k失去了明面上的相对位置这个信息(虽然﷐𝒄﷮𝒕﷯和﷐𝒄﷮𝒊﷯进行了不同角度的旋转,但是投影后得到的﷐𝒒﷮𝒕﷯和﷐𝒌﷮𝒊﷯相差的角度不再是(t-i)*θ),而是让模型自己想办法提炼出相对位置信息。

最后发布的MLA,采取了一种混合的方法——每个Attention Head的Q、K新增dr个维度作为position embedding用来添加RoPE,其中K新增的维度每个Head共享:
未命名图片.jpg 计算机生成了可选文字:
Ot
Ot
,Ot
曰onqt,<“鬱0
,ahVV
eR+碍
eIRdk+dr
一C*WV(s)eRd"
0=CiWce
WeRd><dkW(s)e
厂eIRdcxdkWeIR
dxdr
W以到eRdcXdu
WceRdxd
相比于MHA的推理阶段计算q*k时(只需要计算﷐﷐﷐𝒙﷮𝑖﷯∗𝑾﷮𝑞﷯﷮(𝑠)﷯,k是缓存的,不需要计算),MLA多了计算﷐﷐﷐𝒄﷮𝑖﷯∗𝑾﷮𝑘﷯﷮(𝑠)﷯这一步,引入了额外计算量。解决方法是通过结合律,将﷐﷐𝑾﷮𝑞﷯﷮(𝑠)﷯﷐﷐𝑾﷮𝑘﷯﷮﷐𝑠﷯𝑇﷯合并为一个新的矩阵。同理,在﷐𝒐﷮𝑡﷯后面我们还有一个投影矩阵﷐𝑾﷮𝒐﷯,于是﷐﷐𝒗﷮𝑖﷯﷮(𝑠)﷯= ﷐﷐﷐𝒄﷮𝑖﷯∗𝑾﷮𝑣﷯﷮(𝑠)﷯中的﷐﷐𝑾﷮𝑣﷯﷮(𝑠)﷯也可以吸收到后面的投影矩阵﷐𝑾﷮𝒐﷯中去(这个地方没明白,维度不一致怎么吸收?),也就是说此时KV Cache只需要存下所有的﷐𝒄﷮𝑖﷯就行,而不至于存下所有的k和v。注意到﷐𝒄﷮𝒊﷯跟(s)无关,也就是说是所有头共享的。
前段时间,笔者也很荣幸跟DeepSeek团队讨论过这个问题,但这个问题可以说非常本质,所以当时笔者实际上也没能提出什么有效的建议。最简单的方式是放弃RoPE,换用其他基于Attention Bias的位置编码,如﷟HYPERLINK "https://spaces.ac.cn/archives/9431#ALIBI"ALIBI,但DeepSeek的实验显示它明显不如RoPE(注意,MLA不是不能加RoPE,而是加了RoPE之后无法用恒等变换技巧来减少KV Cache),笔者也提议过换﷟HYPERLINK "https://spaces.ac.cn/archives/9431#Sandwich"Sandwich,它不像ALIBI单调衰减到负无穷,估计效果会好些,但感觉是治标不治本。还有一个折中的办法是将旋转这个操作在q,k之前进行,也就是对﷐𝒄﷮𝑖﷯进行旋转,而不是对﷐𝒒﷮𝑖﷯和﷐𝒌﷮𝑖﷯旋转,此时需要将﷐𝒒﷮𝑖﷯输入也改为﷐𝒄﷮𝑖﷯,然后RoPE加在﷐𝒄﷮𝑖﷯之后,即
未命名图片.jpg 计算机生成了可选文字:
Ot
Ot
,Ot
曰onqt,<“鬱0
,ahVV
eR+碍
eIRdk+dr
一C*WV(s)eRd"
0=CiWce
WeRd><dkW(s)e
厂eIRdcxdkWeIR
dxdr
W以到eRdcXdu
WceRdxd

这样一来,没有RoPE的维度就可以重复“Part 1”的操作,在推理时KV Cache只需要存﷐𝒄﷮𝒊﷯,新增的带RoPE的维度就可以用来补充位置信息,并且由于所有Head共享,所以也就只有在K Cache这里增加了dr个维度(也就是将﷐𝒙﷮𝑖﷯﷐𝑾﷮k𝒓﷯﷐𝓡﷮𝑖﷯的结果缓存)。(为什么这里k的rope部分,要所有Head共享,而不是每个head独立,因为这部分是要缓存的,为了减少缓存)

3)PART3
最后有一个细节,就是MLA的最终版本,还将Q的输入也改为了低秩投影形式,这与减少KV Cache无关,主要是为了减少训练期间参数量和相应的梯度(原论文说的是激活值,个人表示不大理解)所占的显存:
在训练阶段:
未命名图片.jpg 计算机生成了可选文字:
Ot
Ot
,Ot
eonqt,<“鬱<老
eIRdk+dr
eIRdk+dr
=eRd"
c'.=W'eIRdc
CiWceIR
0
WeIR×W(s)e
WeIR×WeIR×碍
厂(司eRdcxd
W'eIRdXdc
WceRdxd
注意﷐﷐𝒌﷮𝑖﷯﷮(𝑠)﷯中的第二项,带RoPE的部分,其输入还是﷐𝒙﷮𝑖﷯而不是﷐𝒄﷮𝑖﷯。(为什么﷐﷐𝒒﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分,输入是低维隐向量﷐﷐𝒄﷮𝒊﷯﷮′﷯,而﷐﷐𝒌﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分,输入是原始embedding﷐𝒙﷮𝒊﷯,我的回答是因为﷐﷐𝒌﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分是多头共享的,﷐𝑾﷮𝒌𝒓﷯只有一个,所以要用﷐𝒙﷮𝒊﷯作为输入进行计算,得到更多的信息(保证表达能力),而﷐﷐𝒒﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分是每个头独立的,每个头的﷐𝑾﷮𝒒𝒓﷯是不同的,所以可以用低维隐向量﷐﷐𝒄﷮𝒊﷯﷮′﷯作为输入,每个头也能得到各自的信息。)

我们把带RoPE的MHA放在下面,方便大家对比,可以发现,其实在训练阶段,除了多了一步低秩投影以及只在部分维度加RoPE外,MLA与Q、K的Head Size由dk换成dk+dr的MHA基本无异。
未命名图片.jpg 计算机生成了可选文字:
Ot
Ot
ton,<“鬱0
eIRdk
eIRdk
Wq
W
=CiVV
=CiVV
eRdxdk
eRdxdk
eRdxd
墨迹绘图
这样一来,没有RoPE的维度就可以重复“Part 1”的操作,在推理时KV Cache只需要存﷐𝒄﷮𝒊﷯,新增的带RoPE的维度就可以用来补充位置信息,并且由于所有Head共享,所以也就只有在K Cache这里增加了dr个维度(也就是将﷐𝒙﷮𝑖﷯﷐𝑾﷮k𝒓﷯﷐𝓡﷮𝑖﷯的结果缓存)。(为什么这里k的rope部分,要所有Head共享,而不是每个head独立,因为这部分是要缓存的,为了减少缓存)
注意﷐﷐𝒌﷮𝑖﷯﷮(𝑠)﷯中的第二项,带RoPE的部分,其输入还是﷐𝒙﷮𝑖﷯而不是﷐𝒄﷮𝑖﷯。(为什么﷐﷐𝒒﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分,输入是低维隐向量﷐﷐𝒄﷮𝒊﷯﷮′﷯,而﷐﷐𝒌﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分,输入是原始embedding﷐𝒙﷮𝒊﷯,我的回答是因为﷐﷐𝒌﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分是多头共享的,﷐𝑾﷮𝒌𝒓﷯只有一个,所以要用﷐𝒙﷮𝒊﷯作为输入进行计算,得到更多的信息(保证表达能力),而﷐﷐𝒒﷮𝒊﷯﷮(𝒔)﷯中的第二项的RoPE部分是每个头独立的,每个头的﷐𝑾﷮𝒒𝒓﷯是不同的,所以可以用低维隐向量﷐﷐𝒄﷮𝒊﷯﷮′﷯作为输入,每个头也能得到各自的信息。)
未命名图片.jpg 计算机生成了可选文字:
Ot
Ot
ton,<“鬱0
eIRdk
eIRdk
Wq
W
=CiVV
=CiVV
eRdxdk
eRdxdk
eRdxd
在解码阶段的MLA则改为MQA形式,
未命名图片.jpg 计算机生成了可选文字:
0Wt,
0(2)只2)
E。。№@戾
E,。№@戾
eRdc+dr
[G,w,R]e
W(s)eR×W@)eR×以司eWe
c'=CiW'eRdc
XiVVeRdc
0
WceIR
推理阶段与训练阶段的唯一区别在于,将﷐﷐𝒌﷮𝒊﷯﷮(𝒔)﷯的第一部分 ﷐𝒄﷮𝒊﷯﷐﷐𝑾﷮𝒌𝒄﷯﷮(𝒔)﷯中的﷐﷐𝑾﷮𝒌𝒄﷯﷮(𝒔)﷯挪到了﷐﷐𝒒﷮𝒊﷯﷮(𝒔)﷯的第一部分中,进行合并。由于﷐﷐𝒌﷮𝒊﷯﷮(𝒔)﷯的第二部分rope部分是多头共享的,所以﷐﷐𝒌﷮𝒊﷯﷮(𝒔)﷯整体变成了﷐𝒌﷮𝑖﷯,变成了多头共享,即MQA的形式。

总结:MLA主要是将q和k分为了两部分,token embedding query/key和position embedding query/key,将rope部分单独拎了出来。将q的 第一部分 和k的 第一部分  进行点积,将q的 第二部分 和k的 第二部分  进行点积。k的 第一部分是MHA,在推理时,通过矩阵吸收,将这部分变成了多头共享,每个头都是﷐𝒄﷮𝑖 ﷯ 即MQA,第二部分rope本身就是MQA,所以在推理时,k整体变成了MQA。
所以在推理时MLA的缓存(训练时不需要缓存),只需要缓存﷐𝒄﷮𝒊 ﷯和长度为dr的rope向量(k的第二部分)。MLA的KV Cache大小跟h无关。
未命名图片.jpg 计算机生成了可选文字:
补充说明:
1w小)w小)T合并成一个矩阵的恒等变换,理论上只有在无限精度下才成立,实际上如果我们使用单
精度尤其是BF16的话,经过变换后的精度损失往往还是挺明显的,经过多层累积后可能放大到比较可观
2、实际上我们一般不按照Wq)w小)')来计算Q,而是按照ætWq)w小)'来计算,这样虽然
是串行的,但在亻失假设下计篡最更少,并且理论精度的损失也更少,不过在文章中,我们亻乃按照
小)W'固T合并成一个矩阵来介绍。
总结:MLA主要是将q和k分为了两部分,token embedding query/key和position embedding query/key,将rope部分单独拎了出来。将q的 第一部分 和k的 第一部分  进行点积,将q的 第二部分 和k的 第二部分  进行点积。k的 第一部分是MHA,在推理时,通过矩阵吸收,将这部分变成了多头共享,每个头都是﷐𝒄﷮𝑖 ﷯ 即MQA,第二部分rope本身就是MQA,所以在推理时,k整体变成了MQA。
未命名图片.jpg 计算机生成了可选文字:
补充说明:
1w小)w小)T合并成一个矩阵的恒等变换,理论上只有在无限精度下才成立,实际上如果我们使用单
精度尤其是BF16的话,经过变换后的精度损失往往还是挺明显的,经过多层累积后可能放大到比较可观
2、实际上我们一般不按照Wq)w小)')来计算Q,而是按照ætWq)w小)'来计算,这样虽然
是串行的,但在亻失假设下计篡最更少,并且理论精度的损失也更少,不过在文章中,我们亻乃按照
小)W'固T合并成一个矩阵来介绍。


未命名图片.png 计算机生成了可选文字:
Multi-HeadLatentAttention(MLA)
[》CachedDuringInference《
Attention
OutputHidden0000,以以0000
RMSNorm
:{[qf.i;q{il)000
Multi-HeadAttention
{时叾k000
{qf.i)00O
ROPE
00“00Latent-ct
InputHidden0000
00
Latentc5V母
以0000
00
官方代码:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) Layer.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        if attn_impl == "naive":
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x

 

已使用 OneNote 创建。