资讯 | Deepseek-V2多头潜在注意力(Multi-head Latent Attention)原理及PyTorch实现

GS Lab 图科学实验室Graph Science Lab 2025年01月23日 22:48 广东

探索 DeepSeekV2 中的 GPU 利用率瓶颈和多头潜在注意力实现。

在本文中,我们将探讨两个关键主题。首先,我们将讨论和了解 Transformer 模型(也称为大型语言模型 (LLM))在训练和推理过程中遇到的瓶颈问题。

然后,我们将深入研究 LLM 架构中有关 KV 缓存的特定瓶颈问题,以及 DeepSeek 的创新方法多头潜在注意力如何解决这个问题。

http://arxiv.org/pdf/2405.04434v5

DeepSeek-V2 是一种强大的开源混合专家 ( MoE ) 语言模型,其特点是通过创新的 Transformer 架构进行经济的训练和高效的推理。它包含236B 个总参数,其中每个 token 激活21B ,并支持128K 个token的上下文长度。

DeepSeek-V2 性能表现位居开源模型前列,成为最强开源 MoE 语言模型。在 MMLU 上,DeepSeek-V2 仅用少量激活参数就取得了顶级性能。与 DeepSeek 67B 相比,DeepSeek-V2 性能大幅提升,节省了 42.5% 的训练成本,减少了93.3% 的KV 缓存,最大生成吞吐量提升至 5.76 倍。

GPU 处理中的瓶颈问题

近年来,图形处理单元 (GPU) 的研究和投资激增。事实上,当我写这篇文章时,我偶然看到了NVIDIA已成为第二大最有价值的上市公司的消息,其估值高达3 万亿美元。

但为什么会这样呢?答案在于 AI 模型需要运行大量的数学运算,而 GPU 有助于加快运行速度。它们在执行计算方面已经变得非常高效,因此,它们的需求量很大,这并不奇怪。

GPU 的速度已经变得太快了。它们能够以惊人的速度执行计算,以每秒浮点运算次数 (FLOP) 来衡量。

但存在一个问题——计算速度远远超过了内存带宽(GB/s),即 GPU 中不同内存区域之间传输数据的速度。这种不匹配造成了瓶颈,拖慢了整个过程。

为了帮助说明瓶颈问题,让我们看一下图 1 和图 2。

图片 1

图1显示了 NVIDIA A100 GPU 的规格。如您所见,这款强大的 GPU 可以在 FP32 模式下执行令人印象深刻的 19.5 TFLOP(万亿次浮点运算)。但是,其 GPU 内存带宽限制在 2 TB/s 左右。虽然这两者是两码事,但还是有联系的。

这突出了一个关键点,瓶颈并不总是在于我们可以执行多少操作,而是在于我们在 GPU 的不同部分之间传输数据的速度有多快。如果数据传输的数量随着内存的增加而增加,延迟也会增加。*我们计算中涉及的张量的大小和数量在其中起着重要作用。

例如,对同一个张量多次计算同一个操作可能比对大小相同的不同张量计算同一个操作更快。这是因为 GPU 需要移动张量,这会降低速度。因此,我们的目标不应该只是优化我们执行的操作数量(KV 缓存、MQA、GQA),还应该尽量减少我们需要进行的内存访问和传输。

图 2 将有助于澄清这些概念,并让您更好地理解这些操作的工作原理。图 2 说明了 GPU 中 Attention 操作的发生方式。

图 2:GPU 中的注意力操作

现在我们已经探讨了瓶颈问题,很明显它可以显著增加延迟。为了解决这个问题并减少推理过程中的延迟,研究人员提出了各种方法,包括闪存注意、多查询注意、分组查询注意、KV 缓存方法、滚动缓冲区 KV 缓存等。即使是最近流行的方法 MAMBA 也间接解决了这个问题。

仔细查看图 1,您会发现 GPU 内存为 80 GB HBM2e(高带宽第二代增强型)。然而,在训练和推理大型语言模型(近年来,尤其是随着LLM 和多模态模型的兴起,这些模型的参数呈指数级增长 )时,这种 HBM 内存很快就会成为限制因素,造成流程瓶颈并增加延迟。

那么,DeepSeek 解决了什么瓶颈问题?

众所周知,KV 缓存推理是一种有助于减少注意力机制(vanilla transformer — 注意力就是你所需要的全部论文架构)中计算负载的解决方案。它通过在 Key 和 value 中缓存 token 来生成下一个 token。然而,在处理长序列的实际场景中,KV 缓存会变得非常大且占用大量内存。这限制了最大批处理大小和序列长度,从而造成了瓶颈。

为了解决这一瓶颈并减少延迟,DeepSeek 的研究人员提出了一种名为多头潜在注意力的新方法。这种新方法旨在缓解瓶颈问题并加快推理过程。

DeepSeek-V2 采用了两种创新架构

用于前馈网络 (FFN) 的DeepSeekMoE架构。

多头潜在注意力(MLA)用于注意力机制。

DeepSeekV2 Transformer Block由DeepSeekMoE+MLA组成

DeepSeekMoE

在标准 MoE 架构中,每个 token 被分配一个(或两个)专家,并且每个 MoE 层都有多个专家,所有专家在结构上都与标准 FFN 相同。这种设置体现了两个问题:一个 token 的指定专家将打算在其参数中汇集截然不同的知识,这些知识很难同时利用;其次,分配给不同专家的 token 可能需要共同的知识,从而导致多个专家聚集在一起获取各自参数中的共享知识。

为了解决这两个问题,DeepSeekMoE 引入了两种策略来提高专家的专业化程度:

  • 细粒度专家细分:为了更有针对性地获取每个专家的知识,通过分割 FFN 中间隐藏维度,将所有专家细分为更细的粒度。因此

  • 共享专家隔离:隔离某些专家作为始终处于激活状态的共享专家,旨在捕获不同情境下的共同知识,并通过将共同知识压缩到这些共享专家中,减少其他路由专家之间的冗余。

让我们在 DeepSeekMoE 中为第 t个 token制定专家任务。如果u_t是此 token 的 FFN 输入,则输出h`_ t 将为:

其中 𝑁𝑠 和 𝑁𝑟 分别是共享专家和路由专家的数量;同样,FFN(𝑠)_𝑖 和 FFN(𝑟)_𝑖 分别表示第 𝑖 位共享专家和第 𝑖 位路由专家。

然后对于路由专家,g_i,t是第 i 个此类专家的门值, s_i,t是 token 到专家的亲和力分数,Topk(. , Kr)由Kr个最高亲和力分数组成 ,其中Kr是路由专家的活跃数量

Gate Implementation

class MoEGate(torch.nn.Module):
    def __init__(self, num_experts_per_tok: int, n_routed_experts: int, routed_scaling_factor: int, topk_method: str, n_group: int, topk_group: int, hidden_size: int):
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = n_routed_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.weight = torch.nn.Parameter(torch.empty((self.n_routed_experts, hidden_size)))
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    def forward(self, x: torch.Tensor):
        batch, seq_len, h = x.shape
        hidden_states = x.view(-1, h)
        logits = torch.nn.functional.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
        scores = logits.softmax(dim=-1, dtype=torch.float32)
        if self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        elif self.topk_method == "group_limited_greedy":
            group_scores = (scores.view(batch * seq_len, self.n_group, -1).max(dim=-1).values)
            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    batch * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(batch * seq_len, -1)
            )  # [n, e]
            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
            topk_weight, topk_idx = torch.topk(
                tmp_scores, k=self.top_k, dim=-1, sorted=False
            )
        return topk_idx, topk_weight

MoE Implementation

class MoE(torch.nn.Module):
    def __init__(self, dim: int, routed_scaling_factor: int, topk_method: str, n_group: int, topk_group: int, hidden_dim: int | None = None, n_routed_experts: int = 12, num_experts_per_tok: int = 4, n_shared_experts: int = 2, mlp: str = "swiglu"):
        super().__init__()
        self.experts_per_rank = n_routed_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.n_shared_experts = n_shared_experts
        mlp_block = SwiGLU
        self.experts = torch.nn.ModuleList([mlp_block(dim, hidden_dim) for i in range(n_routed_experts)])
        self.gate = MoEGate(num_experts_per_tok, n_routed_experts, routed_scaling_factor, topk_method, n_group, topk_group, dim)
        self.shared_experts = mlp_block(dim, hidden_dim * n_shared_experts)
        
    def forward(self, x: torch.Tensor):
        identity = x
        orig_shape = x.shape
        topk_idx, topk_weight = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
        y = torch.empty_like(x)
        y = y.type(x.dtype)
        for i, expert in enumerate(self.experts):
            y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=x.dtype)
        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        
        y = y.view(*orig_shape)
        output = y + self.shared_experts(identity)
        return output

多头潜在注意力 (MLA)

假设您熟悉大型语言模型中的自注意力和多头注意力。

简单来说,研究人员在多头潜在注意力方面取得了突破,通过降低空间复杂度或内存使用量来降低时间复杂度,最终降低延迟。

我想简单介绍一下什么是潜在向量。潜在向量表示影响或促成模型观测数据或结果的潜在或不可观察因素。

在某些模型中,例如 Mistral,模型的维度设置为 4096。这意味着所有向量、层输出和嵌入维度都将与模型维度大小相同。

然而,DeepSeek 的研究人员采用了不同的方法。他们通过降低模型维度来高效地存储 KV 缓存。

例如,他们将向量维度从 4096 缩小到 1024。这允许 KV 缓存存储在 1024 维度中,而其他层仍使用原始模型维度。让我们从数学和视觉上深入了解它的工作原理。

图 3:DeepSeekv2[1] 中的多头潜在注意力

这篇文章的灵感来自 DeepSeek V2论文,论文中描述了一种减少 KV 缓存大小的过程。这是使用 Key 和 Value 的投影矩阵(projection matrices)的低秩近似值(low rank approximation)来执行的。论文没有详细说明如何实现这一点——具体来说,损失函数应该是什么,以及如何实际训练它们。这让我开始理解和构建一个实现,用于训练一个网络来生成 key 和 value 权重矩阵的低秩近似值。在这篇文章的其余部分,将提供我对如何实现这一点的解释,并提供一些背后的高级数学推理。已经使用一个小型且非常简单的 LLM 实现了这一点——只是为了看看如何实现这一点。

多头潜在注意力 (MLA) 的性能优于标准多头注意力 (MHA),同时显著降低了 KV 缓存增强推理效率。MLA 不会像多查询注意力 (MQA) 和分组查询注意力 (GQA) 那样减少 KV 头,而是将键和值联合压缩为潜在向量。

在深入研究细节之前,我们必须要明白低秩矩阵近似是作为训练过程的一部分实现的——这一点很重要,因为第一个想法是将低秩矩阵近似计算为训练后的工作——尽管计算效率低下,但这是第一个想法。但是,重读了几次论文——很明显,他们正在将键和值的 W 矩阵投影到较低维空间。这意味着,从正常的内部隐藏大小到较低的维度,或者如论文中所说是压缩维度。接下来,必须弄清楚如何训练它们——这促使设计了可以捕获重建的损失函数,以及交叉熵损失。

低秩投影(Low-Rank Projection)

图 4:键值向下和向上投影。

""" I want to Give a Latent vector illustration with model dimension of 4096 
    and Latent dimension of 1024 """
# dim = 4096 , latent_dim = 1024
self.c^KV_t = nn.Linear(dim, latent_dim, bias=False) # Equation 1
self.k^C_t = nn.Linear(latent_dim, dim, bias=False) # Equation 2
self.v^C_t = nn.Linear(latent_dim, dim, bias=False) # Equation 3

MLA(Multi-Head latenttention)背后的核心思想是对key和value进行低秩联合压缩,以减少KV缓存。

注意:(1)在此示例中,为简单起见,我使用了 1024 的潜在向量大小。(2)您在图像(插图图像)中看到的数字只是随机示例,并不代表模型中的实际数字。(3)图像和方程中显示的权重矩阵是可学习的参数,这意味着它们在训练的反向传播过程中会得到更新。

c^KV_t是键和值的压缩潜在向量(公式 1)。这是原始键和值向量的低维表示。(正如我之前提到的,这个维度将像 1024(d_c — 压缩维度),潜在向量与模型维度相比较低。它在推理期间被缓存。

在图 4 中的公式 1 中,下投影权重矩阵W^DKV将输入向量 (h_t) 从模型维度 (4096) 投影到潜在维度 (1024)。

理解线性层(Linear Layer)的重要说明:权重矩阵(当前层神经元数量、前一层或输入层神经元数量)。我们还用输入大小和输出大小(nn.linear(输入向量大小(1024),输出向量大小(4096)))初始化线性层。例如,如果我们在 MLP/FFN 中有 1024(输入向量大小)个输入单元,在隐藏层中有 4096 个神经元,则权重矩阵大小将为W(4096,1024)。此矩阵(转置)与输入向量相乘,得出一个新向量。然后将偏差向量添加到线性变换后的向量中。这是向量的线性变换。向量从一个维度(1024)变换到其他维度(4096)或相同维度。

图 5:图 4 中公式 1 的直观表示。

图 5说明图 4 中的公式 1。我取了 9 个 token,模型维度为 4096。

当我们使用模型进行预测(推理)时,我们不需要存储原始的高维键和值向量。相反,我们可以存储一个压缩的潜在键和值向量,称为c^KV_t。这更有效率,因为 c^KV_t 只有1024 个元素(每个向量),而原始元素有 4096 个(模型维度)。这种方法减少了我们所需的内存量,也加快了进程。

图 4 中所示的方程式(方程式 2 和 3)用于向上投影压缩的键和值向量表示(c^KV_t)。这意味着我们正在获取压缩的潜在向量并将其投影回模型维度。

得到的键和值向量 k^C_t 和 v^C_t 现在位于模型维度 (h_d * n_h)中。我们使用线性层权重矩阵 W^UK 和 W^UV 实现这种向上投影。

但是,在推理过程中,向上投影矩阵 W^UK 和 W^UV 可以被吸收(矩阵乘法结合律)到查询(W^Q)和输出(W^O)投影矩阵中,从而无需明确计算键和值。

这种方法不同于传统方法,传统方法是分别投影键和值向量,然后将它们与头部的权重矩阵(例如 W^Q、W^K 和 W^V)相乘。相反,我们将 k^C_t 和 v^C_t 直接投影到头部的键和值矩阵中。我将用下图来说明这一点。

图 6:键和值的向上投影。

从图 6 中,您可以了解到,使用两个权重矩阵 (W^UK) 和 (W^UV) 将低维 kv 表示 (c^KV_t) 投影到 32 个值和 32 个键向量中,每个向量的大小为(9, 128)。W^UK 和 W^UV 的大小为 ( d_h X n_ h X d_c (潜在向量维度))。这些查询和值用于多头注意力。

d_h → 头部尺寸,n_h → 头部数量

注意:我再次投影到模型维度。但我们可以根据需要向上投影到任何维度(d_h X n_ h)。但这会影响注意力层输出投影。所以我们可能需要调整输出权重矩阵以投影到模型维度。

现在我们需要研究查询压缩

图 7:查询向下和向上投影。

在图7中,公式4表示将查询压缩到潜在向量空间(从4096到1024)。公式5是上投影公式(从1024到d_h X n_h)。

等式 4 和 5 与您在图 5 和 6 中看到的图示相同。

到目前为止,我们已经看到了查询、键和值向量的向下投影和向上投影。但我们必须看到一个重要的主题,即如何将旋转位置嵌入集成到这些向量中。

解耦旋转位置嵌入(Decoupled Rotary Positional Embedding)

RoPE(旋转位置嵌入)通过在多维空间中旋转向量,将位置信息引入键和查询。RoPE 既是相对的,又是绝对的。

简单说明一下 RoPE 如何旋转向量。

图 8:二维空间中的矢量旋转。

图 8 给出了向量在二维空间中如何旋转的方程。m 是序列中标记/向量的位置。theta 是角度(对于所有标记/向量都相同)。旋转矩阵 (cos m.theta, sin m.theta) 通常记为r。

图片 9:二维空间中的 RoPE

在图 9 中,我以我们享受音乐为例。在这个序列中,“我们”是第一个标记,因此位置 m 变为1,而对于享受和音乐,位置变为2和3。Theta对于所有向量都是相同的。在大多数架构(python)中,序列位置从 0 开始,为了说明向量旋转如何在二维中发生,我从 1 开始获取位置。

我们仅使用 RoPE 方法旋转向量。我们知道向量有大小和方向。在 RoPE 中,我们仅通过 m * theta 的角度改变方向,向量的大小保持不变。它为键和值向量提供绝对(该特定标记的位置)和相对(相对于序列中其他标记的位置)位置信息。这增强了注意力输出。

我们只旋转查询和键,而不是值向量。因此,旋转位置嵌入有助于增强输入之间的注意力机制。在 vanilla Transformer 架构中存在一个问题,它将位置信息(正弦函数)添加到字嵌入中,从而改变了所有向量的大小,即使相同的词在输入中也有不同的向量(不同的大小)。

但在 RoPE 中,同样的单词 Magnitude 在旋转后没有变化。旋转提供了信息,因此向量之间的注意力增强了注意力(语义和句法关系)输出。但 Value 向量没有经过任何旋转(真正的词嵌入),因此注意力层输出具有关于 token 序列的丰富信息。

但是,当 RoPE 对键和值都具有位置敏感性时,就会出现问题。如果我们在方程 1 和方程 2 中应用 RoPE(图 4),它将与位置敏感的权重矩阵相结合。因此,W^UK 无法再被吸收到 W^Q 中,因为乘法不交换。这需要在推理过程中重新计算所有前缀标记的键,这是低效的。

为了解决这个问题,提出了一种解耦的 RoPE 策略。该策略使用额外的多头查询和共享密钥来将 RoPE 信息与压缩密钥和查询分开携带。然后,这些解耦的 RoPE 查询和密钥与向上投影的查询和密钥连接起来。

图 10:应用于键和查询以及连接的解耦 RoPE

图 10:W^QR 和 W^KR 是权重矩阵,用于产生解耦查询和密钥,如公式 6 和 7 所示。RoPE(.)表示应用 RoPE 旋转矩阵的操作,如您可以与图 8 进行比较。公式 8 和 9 表示解耦的 RoPE 应用查询和密钥向量与向上投影的查询和密钥向量的连接。

注意:RoPE 不会改变向量的大小或维度。权重矩阵 W^QR 和 W^KR 会改变与 q^C_t 和 k^C_t 兼容的输入向量。

与潜在向量相比,q^R_t 和 k^R_t 的大小将更小。查询(n_h RoPE 向量)的大小为 (d^R_h * n_h),键的大小为 d^R_h。q^R_t 是根据压缩查询向量 c^Q_t 计算得出的。但 k^R_t 是根据输入向量 h_t 本身计算得出的。

连接保留了查询和键中的位置信息。查询和键向量应位于同一维度,才能发生注意操作,如图 11 所示。

图 11:注意力机制。

图11:u_t只不过是一个带有模型维度的注意力层输出。W_O权重矩阵注意力层输出投影矩阵。

将键值缓存与其他方法进行比较

为了证明多头潜在注意力机制效果更好,DeepSeek 的研究人员对不同注意力机制中每个 token 的 KV 缓存进行了比较(如图 12 所示)。我们可以理解,MLA 只需要很少量的内存用于 KV 缓存,但却能实现比 MHA(多头注意力机制)更强的性能。

图 12:每个 token 的 KV 缓存比较。

图 12:n_h 表示注意力头的数量,d_h 是注意力头的维度,l 表示 LLM 架构中的解码器层数,n_g 是分组查询注意力中的组数。

MLA 完整计算过程(总结):

  • 公式 4(图 7)→输入向量(h_t)被投影到潜在维度(查询的压缩版本)。

  • 公式 5(图 7)→然后将潜在向量投影到多个查询(多个头)中。

  • 公式 6(图片 10)→为了捕获查询输入向量的位置信息,研究人员使用了创建位置向量的解耦旋转位置嵌入。

  • 公式 8(图 10)→位置向量与查询连接在一起

  • 公式 1(图 4)→输入向量(h_t)被投影到潜在维度(键和值的压缩版本)[推理期间缓存]。

  • 公式 2(图片 4)→使用线性层权重矩阵将键和值的潜在向量投影到压缩键中。

  • 公式 7(图片 10)→对于键的位置信息,RoPE 已应用于输入向量 (h_t) 以创建位置向量[推理期间缓存]。

  • 公式 9(图 10)→位置向量被连接成压缩密钥向量。

  • 公式 3(图片 4)→使用线性层权重投影矩阵将键和值的潜在向量投影到压缩值向量中。

  • 公式 10(图片 11)→注意力机制在所有头脑中发生。

  • 公式 11(图片 11)→注意力头被连接起来。然后使用名为 W_O 的权重矩阵对该连接的输出进行线性投影。

  • 最终输出随后被输入到 MoE(细粒度专家和共享专家隔离)层。

压缩质量损失(Compression quality Loss)——定义重建损失,由以下表达式给出:

L_comp = ||K — W^PRO_k^T * c_k||²_F + ||V — W^PRO_v^T * c_v||²_F

K、V 是原始键和值矩阵,c_k, c_v 是压缩表示,||·||²_F 表示 Frobenius 范数,W^PRO_k、W^PRO_v 是投影矩阵。

正交性(Orthogonality)——这使我们能够确保降维能够创建或保留给定矩阵的独立特征。

  • 作用于投影矩阵 (W^PRO)

  • 确保投影中的不同维度捕捉不同的方面

  • 保留输入之间的关系

  • 转型空间的作品

L_ortho = ||W^PRO_k * W^PRO_k^T — I||²_F + ||W^PRO_v * W^PRO_v^T — I||²_F

稀疏性(Sparsity)——这是确保我们保留给定矩阵的独立维度的另一种形式。

  • 作用于压缩表示(c_k,c_v)

  • 使实际压缩值变得稀疏

  • 降低计算复杂度

  • 在表征空间中工作

L_sparse = ||c_k||₁ + ||c_v||₁

详细说明正交正则化和稀疏正则化(orthogonal and sparsity regularization)的组合——稀疏性有助于将矩阵中的许多条目推向接近零(这既是压缩,也是确保内存优化)。正交性有助于确定独立特征的存在并在变换空间中找到它们。两者在学习找到将输入映射到低维空间的潜在空间方面相互补充。

注意力一致性(Attention Consistency)

注意一致性可以保持原始注意模式和压缩注意模式之间的关系或一致性。

L_attn = ||A_orig — A_comp||²_F

综合所有这些,我们可以得到最终的损失目标

L_total = λ_task * L_task + λ_comp * L_comp + λ_attn * L_attn + λ_ortho * L_ortho + λ_sparse * L_sparse

这里L_task是交叉熵损失项。

MLA Implementation

class MLA(torch.nn.Module):
    def __init__(self, model_args: DeepseekConfig):
        super().__init__()
        d_model = model_args.d_model
        self.num_heads = model_args.num_heads
        self.head_dim = model_args.d_model // model_args.num_heads
        self.attn_dropout = torch.nn.Dropout(model_args.dropout)
        self.res_dropout = torch.nn.Dropout(model_args.dropout)
        self.flash_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        
        self.q_lora_rank = model_args.q_lora_rank
        self.qk_rope_head_dim = model_args.qk_rope_head_dim
        self.kv_lora_rank = model_args.kv_lora_rank
        self.v_head_dim = model_args.v_head_dim
        self.qk_nope_head_dim = model_args.qk_nope_head_dim
        self.q_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim
        self.q_a_proj = torch.nn.Linear(d_model, model_args.q_lora_rank, bias=False)
        self.q_a_layernorm = RMSNorm(model_args.q_lora_rank)
        self.q_b_proj = torch.nn.Linear(model_args.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
        self.kv_a_proj_with_mqa = torch.nn.Linear(d_model,model_args.kv_lora_rank + model_args.qk_rope_head_dim,bias=False,)
        self.kv_a_layernorm = RMSNorm(model_args.kv_lora_rank)
        self.kv_b_proj = torch.nn.Linear(model_args.kv_lora_rank,self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + 
            self.v_head_dim),bias=False,)
        self.o_proj = torch.nn.Linear(self.num_heads * self.v_head_dim,d_model, bias=False,)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
        q = q.view(batch, seq_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        compressed_kv = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(batch, seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(batch, seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2))
        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        q_pe, k_pe = apply_rope(q_pe, k_pe, freqs_cis)
        k_pe = k_pe.transpose(2, 1)
        q_pe = q_pe.transpose(2, 1)
        query_states = k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
        key_states = k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
        attn_mtx = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_mtx = attn_mtx + mask[:, :, :seq_len, :seq_len]
        attn_mtx = torch.nn.functional.softmax(attn_mtx.float(), dim=-1).type_as(key_states)
        attn_mtx = self.attn_dropout(attn_mtx)
        output = torch.matmul(attn_mtx, value_states)  # (batch, n_head, seq_len, head_dim)
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.num_heads * self.v_head_dim)
        output = self.o_proj(output)
        output = self.res_dropout(output)
        return output

参考

  1. DeepSeek-AI, Aixin Liu, Bei Feng, Bin Wang, Bingxuan Wang, Bo Liu, Chenggang Zhao, Chengqi Dengr, DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (2024), Research paper (arxiv)

  2. Flash Attention Conceptual Guide, Huggingface.co

  3. Exploring the GPU Architecture, vmware.com

  4. Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu, RoFormer: Enhanced Transformer with Rotary Position Embedding (2021), Research Paper (arxiv)

  5. https://huggingface.co/abideen