Deberta

2025918

10:21

将标准的注意力得分计算分解为4个部分:

 

  • 内容-内容项 (Content-to-Content): 基于词的语义内容向量Q_cK_c计算注意力。这与原始Transformer的计算方式相同。
  • 内容-位置项 (Content-to-Position): 基于当前词的Query内容Q_c和相对位置的Key向量K_r计算注意力。

位置-内容项 (Position-to-Content): 基于当前词的相对位置的Query向量Q_r和内容K_c计算注意力。

  • 位置-位置项 (Position-to-Position): 由于本身就是相对位置,所以这一项没有必要了。

 

 

P表示相对位置的position embedding,论文(上图)中说P的大小是2k*d,2k就是最大相对距离,超出这个距离的就进行截断,其实和绝对位置的position embedding是类似的,只不过相对位置rposition embedding的需要由两个token的位置ij得到,而绝对位置ipositio embedding只需要一个位置i即可。举例来解释:

 

下表表示绝对位置position embedding 矩阵AA_i表示序列第itokenposition embedding

位置

embedding_dim0

embedding_dim1

embedding_dim2

绝对位置0(i=0

0.1

0.2

0.1

绝对位置1(i=1

..

..

..

绝对位置2(i=2

..

..

..

绝对位置3(i=3

..

..

..

绝对位置4(i=4

..

..

..

 

下表表示相对位置position embedding 矩阵PP_r表示相对距离为r=i-j的位置position embedding

对于下表这个例子来说,最大相对距离是6,即2*k=6k=3。为了使r都是正数,可以让r+k,这样r最小就是0,最大是2k

位置

embedding_dim0

embedding_dim1

embedding_dim2

相对位置-2(i-j=-2

0.1

0.2

0.1

相对位置-1(i-j=-1

..

..

..

相对位置0(i-j=-0

..

..

..

相对位置1(i-j=1

..

..

..

相对位置2(i-j=2

..

..

..

相对位置3(i-j=3

 

 

 

 

所以说,对于绝对位置编码来说,位置iposition embeddingA_i,可以直接和该tokencontent embedding相加来表示token的整体embeddingA_i+C_i;但是对于相对位置编码来说,P_r其实是P_(i-j),表示相对距离为r=i-j的位置position embedding,无法直接和token icontent embedding相加,A_i+P_(i-j)不能表示 token i的整体embedding

也就是说,如果采用相对位置编码,就只能在计算Attention score时引入相对位置信息,因为只有在计算attention score时,才会用到两个token ijAttention_i,j = …. ,才能引入相对位置P_(i-j)

 

上图中说到,P表示相对位置的position embedding,并且对于所有层是共享的,也就是说和绝对位置编码的做法不同,deberta中每一层都需要注入相对位置信息,而标准的bert中只需要在初始输入时将content embeddingposition embedding相加作为输入即可,后续层不需要显式注入位置信息。

 

 

 

相对位置编码的核心是在 attention 计算中引入 token 间的相对距离信息,除了deberta这种disentangle的方式,还有:

  • 直接将相对位置编码作为偏置项融入 content-based 的 attention;
  • ROPE:核心思想: 绝对位置信息可以通过旋转矩阵编码到Query和Key向量中。对于位置为 m 的词元,它的Query和Key向量会用一个依赖于 m 的旋转矩阵 R_θ, m 进行旋转。

计算过程:

对Q和K向量应用旋转变换:Q'_m = R_θ, m * Q_m, K'_n = R_θ, n * K_n。

用变换后的Q'和K'计算注意力分数。

数学魔法: 变换后的注意力分数天然地包含了相对位置信息。

Q'_m^T * K'_n = (R_θ, m * Q_m)^T * (R_θ, n * K_n) = Q_m^T R_θ, {m-n} K_n

最终的结果 R_θ, {m-n} 只依赖于相对位置 (m-n)!

 

 

 

 

 

 

已使用 OneNote 创建。