MTP

202596

15:50

Deepseek MTP借鉴了两个工作:1)Meta 的MTP,2)EAGLE,将这两个工作结合起来,将Meta的MTP的parallel head,换成causal head,解决LLM在training时的问题。

LLM在训练时存在的问题:
1)语言模型在训练时,和推理时有些不同,推理时是一个token一个token进行预测,将当前步生成的token拼接到序列中,然后再预测下一个token,如果当前生成token是错误的,那么会影响后面的token;但是训练时,我们强制输入正确的token,ground truth,即teacher forcing,这就导致训练和推理时存在一个mismatch不匹配。
2)模型每次只预测一个token,导致模型是近视的,只看到未来一个token,模型planning的能力较差,和人类的思考不同,人类通常是说一个字就想到下一个句子要说什么,而不是一个字一个字地往外蹦。(解决了模型planning的问题,就相当于解决了问题1),因为有了planning的能力,模型就能更好地预测一些难的transition(下图中的5->A),也就更能使预测的下一个token和ground truth一致。)

未命名图片.jpg 计算机生成了可选文字:
Model
predictions
5一,A
Groundtruth
4
3
2
1
5
4
3
A
5
4
2一,3一4
B
A
5
C
B
A
D
C
B
t
E
D
C
t
B

3)训练时只用到下一个token来计算loss,数据利用效率较低,或者说用来supervise的信息较少。

解决方案:有了这些问题,一个自然的方法就是一次预测多个token。
Meta 的MTP:用并行的多头,每个头来预测未来第i个位置的token,即:head1预测下一个token, head2预测下下一个token,head3预测下下下个token,,,,以此类推。

未命名图片.jpg 计算机生成了可选文字:
Discardedatinference()rusedtospeedupmodelupto3times)
4-token
targets
Head1
Shared
Inputs
未命名图片.jpg 计算机生成了可选文字:
Discardedatinference()rusedtospeedupmodelupto3times)
4-token
targets
Head1
Shared
Inputs

这里的head1,2,3,4和transformer块里的多头注意力multi head attention不同,这里不同的头是用来预测未来不同位置的token的,我觉得这里的头应该是分类头。每个头的训练都是之前的标准的transformer训练。
这样做存在的问题:违背了auto regressive的性质,不是基于….t-3,t-2,t-1时刻的token来生成t 时刻的 token,而是基于….t-6,t-5,t-4时刻的token来生成t 时刻的 token,中间缺失了几个token。因为t时刻其实是依赖于t-1的,所以这种方式从直观上看是有问题的。

LLM在推理时存在的问题:一个字,慢
解决方案:
1)kv cache
2)Speculative decoding:用一个小的语言模型或者是模型自身的某个module去猜测可能的输出,加速推理。主流是用模型自身的某一部分module来猜测,这种方法有2个代表作:1)medusa(多个并行头,和Meta mtp类似) 2)Eagle(causal的方式)

GPU中,data和模型的parameter存储在GPU的显存中,而真正的矩阵运算是在cache中,所以需要先将data和parameter从显存传输到cache中,才能进行计算。LLM在推理时的速度瓶颈主要是从显存到cache的传输速度,在推理时,data很小,因此传输主要是parameter,对于大模型来说,参数量大,传输慢,导致推理速度很慢。

未命名图片.jpg 计算机生成了可选文字:
々昼/\s
所以,可以用先用小模型进行推理,然后用大模型来验证小模型的推理结果,如果结果不对,就用大模型再推理。(大模型验证的速度是很快的,因为可以一次验证一整个序列,不需要一个个token的验证,就不需要进行cache和显存之间的传输。)
所以,可以用先用小模型进行推理,然后用大模型来验证小模型的推理结果,如果结果不对,就用大模型再推理。(大模型验证的速度是很快的,因为可以一次验证一整个序列,不需要一个个token的验证,就不需要进行cache和显存之间的传输。)
speculative decode就是两个部分:1)小模型quick guess 2)大模型cheap verification。

未命名图片.jpg 计算机生成了可选文字:
[STARTI2
!STARTI
[STARTI
!STARTI
ISTARTI
sbenchmarkbondn
Sbenchmarknikkei22
Sbenchmarknikkei225indexrose2276
@sbenchmarknikkei225indexrose226
.sbenchmarknik1225indexrose226
Sbenchmarknikkei225indexrose226
Sbenchmarknikkei225indexrose226
sbenchmarknikkei225indexrose226
@sbenchmarknikkei225indexrose226
69points
,0re1
69points.
or1
69points
0r1
69points
0r1
69points
or1
,to10
,9859
》to10
989
989
,9的
,in
79intekyøtate
79intatemorningtrading
,(END]
绿色是小模型生成的,蓝色是大模型修正的。

用来guess的小模型可以是Independent的,也可以是大模型其中的一个module(更推荐,因为只一个模型,系统复杂度更低。),两个代表工作,medusa 和EAGLE,都是用在推理阶段的,具体原理没搞明白,反正就是medusa是并行头,EAGLE是casual的。causal要好于并行的,Deepseek就采用了causal。

未命名图片.jpg 计算机生成了可选文字:
TargetTokens
Cross-EntropyLOSS
MainModel
《(NextTokenPCt/0司
OutputHead
TransformerBlock×L
EmbeddingLayer
Input丆bke們
•CMain
Shared
Shared
Cross-EntropyLOSS
MTPModule1
(Next?Tokened化00n丿
OutputHead
TransformerBlock
LinearProjection
concatenation
RMSNormRMSNorm
EmbeddingLayer
Cross-EntropyLOSS
MTPModule2
《e冠3Token件ed忆on丿
0utputHead
TransformerBlock
LinearProjection
concatenation
RMSNormRMSNorm
EmbeddingLayer
MTP
Shared
Shared

墨迹绘图

Deepseek的MTP训练,包含一个主模型Main model(这个就是正常的transformer部分),和两个MTP Module1和MTP Module2,MTP Module1是在预测未来第2个token的,MTP Module2是在预测未来第3个token的,这个和Meta MTP的做法一致,但是不同的是,主模块的feature被输入到MTP Module1,并和下一个token序列t2,t3,t4,t5的embedding,进行concat,即MTP Module1的输入既包含上一个模块的feature,又包含当前token的embedding信息,然后只经过一个transformer块(也就是说MTP Module是很轻量的,只包含了一个transformer块,而主模块仍然包含L个transformer块),然后经过output head输出(output head应该是分类头)。在训练时,loss是三个模块的loss之和,
在推理中,只用到主模块预测下一个token即可。由于模型在学习时,主模块出来的feature是包含有对future token有帮助的信息的,所以这种方式是能够增加模型的planning和reasoning的能力的。

墨迹绘图

 

已使用 OneNote 创建。