MLA是MHA的变体,因此先来看看MHA。

MHA(多头注意力)

MHA通过将输入向量分割成多个并行的注意力“头”,每个头独立地计算注意力权重并产生输出,然后将这些输出通过拼接和线性变换进行合并以生成最终的注意力表示。

将 Q Q Q分成了多个部分,每个部分进行注意力。比如 Q Q Q的形状 [ s e q , d i ] [seq,d_i] [seq,di​]、 K T K^T KT的形状 [ d i , s e q ] [d_i,seq] [di​,seq]、 V V V的形状 [ s e q , d i ] [seq,d_i] [seq,di​],则有 Q k T Qk^T QkT的形状为 [ s e q , s e q ] , s o f t m a x ( Q K T d k ) V [seq,seq],softmax(\frac{QK^T}{\sqrt{d_k}})V [seq,seq],softmax(dk​ ​QKT​)V的形状为 [ s e q , d i [seq,d_i [seq,di​ 也就是说每一个注意力之后的 h e a d i head_i headi​的形状都是 [ s e q , d i ] [seq,d_i] [seq,di​],这和 Q ‘ Q^{`} Q‘的形状一样,拼接起来得到的 H H H的形状和直接使用自注意力机制的形状是一样的。这里使用了一个 W O W^O WO,进行整合(合并头:将所有头的输出合并回一个大的张量)。最后一个线性层:对合并后的输出应用另一个线性变换。

其中权重矩阵 Q , K , V Q,K,V Q,K,V变化概括就是:将 Q , K , V Q,K,V Q,K,V划分成多头,并行处理。但这里的头并不是对 X X X进行多次线性变换,而是对之后的 Q , K , V Q,K,V Q,K,V划分成多个部分,每个部分进行计算,最后拼接。

h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) headi​=Attention(QWiQ​,KWiK​,VWiV​),每个头对 Q , K , V Q,K,V Q,K,V进行变换后进行注意力机制

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , h e a d 2 , … , h e a d h ) W O MultiHead(Q,K,V)=Concat(head_1,head_2,\ldots,head_h)W^O MultiHead(Q,K,V)=Concat(head1​,head2​,…,headh​)WO

MHA 能够理解输入不同部分之间的关系。然而,这种复杂性是有代价的——对内存带宽的需求很大,尤其是在解码器推理期间。主要问题的关键在于内存开销。在自回归模型中,每个解码步骤都需要加载解码器权重以及所有注意键和值。这个过程不仅计算量大,而且内存带宽也大。随着模型规模的扩大,这种开销也会增加,使得扩展变得越来越艰巨。

MLA(多头潜在注意力)

概念:

  • 多头注意力机制:Transformer 的核心模块,能够通过多个注意力头并行捕捉输入序列中的多样化特征。
  • 潜在表示学习:通过将高维输入映射到低维潜在空间,可以提取更抽象的语义特征,同时有效减少计算复杂度。

问题:

1.效率问题:传统多头注意力的计算复杂度为 O ( n 2 d ) O(n^2d) O(n2d),即随着序列长度的增长,键值(Key-Value,KV)缓存的大小也会线性增加,这给模型带来了显著的内存负担。)
2.表达能力瓶颈:难以充分捕捉复杂全局依赖。MLA 通过在潜在空间中执行注意力计算,有效降低复杂度,同时提升建模能力。

MLA 的提出:MLA 将多头注意力机制 与 潜在表示学习 相结合,解决MHA在高计算成本和KV缓存方面的局限性。

MLA的具体做法(创新点)
采用低秩联合压缩键值技术,优化了键值(KV)矩阵,显著减少了内存消耗并提高了推理效率。

如上图,在MHA、GQA中大量存在于keys values中的KV缓存——带阴影表示,到了MLA中时,只有一小部分的被压缩Compressed的Latent KV了。

并且,在推理阶段,MHA需要缓存独立的键(Key)和值(Value)矩阵,这会增加内存和计算开销。而MLA通过低秩矩阵分解技术,显著减小了存储的KV(Key-Value)的维度,从而降低了内存占用。

MLA的核心步骤:

  1. 输入映射->潜在空间

    给定输入 X ∈ R n × d X\in\mathbb{R}^{n\times d} X∈Rn×d (其中 n n n是序列长度, d d d是特征维度),通过映射函数 f f f将其投影到潜在空间:

    Z = f ( X ) ∈ R n × k , k ≪ d Z=f(X)\in\mathbb{R}^{n\times k},\quad k\ll d Z=f(X)∈Rn×k,k≪d

    f ( ⋅ ) f(\cdot) f(⋅)可为全连接层、卷积层等映射模块,潜在维度 k k k是显著降低计算复杂度的关键。

  2. 潜在空间中的多头注意力计算

    在潜在空间 Z Z Z 上进行多头注意力计算。对于第 i i i 个注意力头,其计算公式为:

    Attention i = Softmax ( Q i ⋅ K i T d k ) V i \begin{aligned}\text{Attention}_{i}&=\text{Softmax}\left(\frac{Q_{i}\cdot K^{T}_{i}}{\sqrt{d_{k}}}\right)V_{i}\end{aligned} Attentioni​​=Softmax(dk​ ​Qi​⋅KiT​​)Vi​​

    其中:

    • Q i = Z W i Q , K i = Z W i K , V i = Z W i V Q_{i}=ZW^{Q}_{i},K_{i}=ZW^{K}_{i},V_{i}=ZW^{V}_{i} Qi​=ZWiQ​,Ki​=ZWiK​,Vi​=ZWiV​ 分别为查询、键和值;
    • W i Q , W i K , W i V ∈ R k × d k W^{Q}_{i},W^{K}_{i},W^{V}_{i}\in R^{k\times d_{k}} WiQ​,WiK​,WiV​∈Rk×dk​ 是可学习的投影矩阵;
    • d k = k / h d_{k}=k/h dk​=k/h 是每个注意力头的维度 ( h h h 是头数)。

    将所有注意力头的输出拼接后再通过线性变换:

    MultiHead ( Z ) = Concat ( Attention 1 , … , Attention h ) W O \begin{aligned}\text{MultiHead}(Z)&=\text{Concat}\left(\text{Attention}_{1},\ldots,\text{Attention}_{h}\right)W^{O}\end{aligned} MultiHead(Z)​=Concat(Attention1​,…,Attentionh​)WO​

    其中 W O ∈ R h d k × k W^{O}\in R^{hd_{k}\times k} WO∈Rhdk​×k 是输出投影矩阵。

  3. 映射回原始空间

    将多头注意力结果从潜在空间映射回原始空间:

    Y = g ( MultiHead ( Z ) ) ∈ R n × d Y=g(\text{MultiHead}(Z))\in\mathbb{R}^{n\times d} Y=g(MultiHead(Z))∈Rn×d

    g ( ⋅ ) g(\cdot) g(⋅)为非线性变换,如全连接层。

参考文献:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf