当前位置: 首页 > news >正文

分组查询注意力(GQA)的Pytorch实现

自注意力层(分组查询注意力)

初始化

class SelfAttention(nn.Module):def __init__(self, config, layer_idx):super().__init__()self.layer_idx = layer_idxself.n_head = config.n_head # 查询头的数量self.kv_head = config.kv_head # kv头的数量self.n_embed = config.n_embed # 嵌入维度self.h_dim = self.n_embed // self.n_headassert self.n_embed % self.n_head == 0assert self.kv_head < self.n_head and self.n_head % self.kv_head == 0self.q_linear = nn.Linear(self.n_embed, self.n_embed, bias=False)self.v_linear = nn.Linear(self.n_embed, self.kv_head*self.h_dim, bias=False)self.k_linear = nn.Linear(self.n_embed, self.kv_head*self.h_dim, bias=False)self.out = nn.Linear(self.n_embed, self.n_embed, bias=False)

layer_idx 的作用是作为一个索引,告诉当前这个 Attention 模块它是 Transformer 模型中的第几层。方便后续训练过程中的调试与日志记录以及kv缓存处理。
assert的作用是断定某个条件必须为真,如果该条件为假,程序就会立即崩溃并抛出一个 AssertionError 异常。
nn.Linear的两个必须参数为输入维度和输出维度
这里采用的是分组查询注意力(GQA),即多个查询头共享一个kv头。相比较与MHA,计算量和缓存压力要小很多,但理论上模型质量也会有所下降。

    def forward(self, x, cos_sin, kv_cache):# 修改qkv矩阵的形状,方便后续计算B, T, C = x.size()q = self.q_linear(x).view(B, T, self.n_head, self.h_dim)v = self.v_linear(x).view(B, T, self.kv_head, self.h_dim)k = self.k_linear(x).view(B, T, self.kv_head, self.h_dim)# 进行旋转位置编码cos, sin = cos_sinq,k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)q,k = norm(q), norm(k)# 矩阵转置(B, T, H, D) -> (B, H, T, D)# 方便后续的注意力计算,让 PyTorch 可以将 H 个头看作一个批次维度进行高效的矩阵乘法q = q.transpose(1,2)v = v.transpose(1,2)k = k.transpose(1,2)

view() 本身是一个“零拷贝”操作,它不会移动数据,只是重新解释数据的形状。它并不会真的在内存中移动数据。只是修改了张量的元数据(比如 stride,即访问下一个元素需要跳过多少个内存位置),创建了一个新的“视图”指向原来的数据。
为了提高效率,像 transpose(), permute(), narrow() 等操作都是创建一个新的视图,并没有在内存中移动数据。

# 创建一个 2x3 的连续张量 
x = torch.arange(6).view(2, 3) 
# tensor([[0, 1, 2], 
#		 [3, 4, 5]])

在内存中,x 的数据是这样存储的:[0, 1, 2, 3, 4, 5]。

# 对 x 进行转置
y = x.transpose(0, 1)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])

虽然 y 的逻辑形状是 (3, 2),但它在内存中的数据仍然是 [0, 1, 2, 3, 4, 5]

  • 为了读取 y 的第一行 [0, 3],程序需要先读取位置0的 0,然后跳过 1 和 2,去读取位置3的 3。
  • 因为元素不再是紧挨着的,所以这个张量 y 不是连续的,对于不连续的数据不能直接进行.view()操作
    旋转位置编码这里先不做赘述

现在来解释一下为什么又要进行矩阵转置
假设我们向注意力层输入的数据x长这样,一次输入4个toekn,每个token15个维度:
tensor([[[7, 2, 0, 2, 3, 7, 1, 2, 5, 5, 6, 6, 2, 9, 6],
[0, 1, 1, 9, 3, 3, 0, 2, 6, 4, 7, 0, 3, 7, 1],
[2, 7, 7, 0, 2, 1, 1, 8, 9, 6, 6, 2, 5, 5, 5],
[3, 6, 1, 1, 0, 0, 5, 3, 2, 0, 0, 9, 6, 6, 6]]])
尺寸为(1,4,15)
以q矩阵为例,假设(B, T, self.n_head, self.h_dim)=(1,4,3,5)
将输入 x 乘以一个可学习的权重矩阵(W_q) 得到q。这个过程叫做线性投影(Linear Projection)
则q初始化后长这样:
tensor([[[[7, 2, 0, 2, 3],# T0, H0
[7, 1, 2, 5, 5],# T0, H1
[6, 6, 2, 9, 6]],# T0, H2

     `[[0, 1, 1, 9, 3],# T1, H0``[3, 0, 2, 6, 4],# T1, H1``[7, 0, 3, 7, 1]],# T1, H2``[[2, 7, 7, 0, 2],# T2, H0``[1, 1, 8, 9, 6],# T2, H1``[6, 2, 5, 5, 5]],# T2, H2``[[3, 6, 1, 1, 0],# T3, H0``[0, 5, 3, 2, 0],# T3, H1``[0, 9, 6, 6, 6]]]])# T3, H2`

相当于把每个token截成四份,我们知道MQA,包括MHA计算注意力矩阵时都是头矩阵之间进行计算,所以我们按照头来分组,更能利用GPU擅长处理并行运算的特点。
对q的第2,3个维度进行转置:
tensor([[[[7, 2, 0, 2, 3], # H0, T0
[0, 1, 1, 9, 3], # H0, T1
[2, 7, 7, 0, 2], # H0, T2
[3, 6, 1, 1, 0]], # H0, T3

     `[[7, 1, 2, 5, 5],   # H1, T0``[3, 0, 2, 6, 4],   # H1, T1``[1, 1, 8, 9, 6],   # H1, T2``[0, 5, 3, 2, 0]],  # H1, T3``[[6, 6, 2, 9, 6],   # H2, T0``[7, 0, 3, 7, 1],   # H2, T1``[6, 2, 5, 5, 5],   # H2, T2``[0, 9, 6, 6, 6]]]]) # H2, T3`

这样我们就得到了每个头要处理的内容。

现在来介绍一下注意力分数的计算

if kv_cache is not None:k,v = kv_cache.insert_kv(self.layer_idx, k, v)Tq = q.size(2)Tk = k.size(2)nrep = self.n_head // self.kv_headk,v = repeat_kv(k, nrep), repeat_kv(v, nrep)if kv_cache is None or Tq == Tk:y = F.scaled_dot_product_attention(q, k, v, is_causal=True)# scaled_dot_product_attention将注意力机制的多个步骤融合,同时会调用最优的实现,比如FlashAttention来实现算力优化elif Tq == 1:# Tq = 1说明是单token生成,即推理场景,所以不需要掩码y = F.scaled_dot_product_attention(q, k, v, is_causal=False)else:mask = torch.zeros((Tq,Tk), device=q.device, dtype=torch.bool)prefix_len = Tk - Tqif prefix_len > 0:mask[:, :prefix_len] = Truemask[:, prefix_len:] = torch.tril(torch.ones((Tq,Tq), device=q.device, dtype=torch.bool))y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)y.transpose(1,2).contiguous().view(B, T, -1) # (B, H, T, D) -> (B, T, C)y = self.out(y)return y

这里先不介绍kv缓存的实现
nrep是为了计算出一个kv头要对应多少个查询头,然后对kv头进行复制

def repeat_kv(x, nrep):if nrep == 1:return xbs, h, slen, dim = x.shapereturn(x[:, :, None, :, :].expand(bs, h, nrep, slen, dim).reshape(bs, h * nrep, slen, dim))

假设一个原先的k头长这样
k = tensor([[[[0,0,0,0], [0,0,0,1], [0,0,0,2]], # K_H0
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]]) # K_H1
x[:, :, None, :, :]后:
tensor([[[ [[0,0,0,0], [0,0,0,1], [0,0,0,2]] ], # K_H0 [ [[1,1,1,0], [1,1,1,1], [1,1,1,2]] ]]]) # K_H1expand后:tensor([[[ [[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (original) [[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (view 1) [[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (view 2) [[0,0,0,0], [0,0,0,1], [0,0,0,2]]]], # KV_H0 (view 3)`

     `[ [[1,1,1,0], [1,1,1,1], [1,1,1,2]]],  # KV_H1 (original)``[[1,1,1,0], [1,1,1,1], [1,1,1,2]]],  # KV_H1 (view 1)``[[1,1,1,0], [1,1,1,1], [1,1,1,2]]],  # KV_H1 (view 2)``[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]]])# KV_H1 (view 3)`

reshape后:
tensor([[[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 0 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 1 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 2 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 3 (from K_H0)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 4 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 5 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 6 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]])# Head 7 (from K_H1)
这样在与q矩阵运算时就能一一对应了
scaled_dot_product_attention(q, k, v, is_causal=False)torch.nn.functional的一个函数,只要输入q,k,v矩阵和是否进行掩码就能注意力分数的计算。这个函数会在底层使用flashattention机制来实现算力优化。
现在来介绍一下标准注意力的计算
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

attn_score = torch.matmul(q,k.transpose(-2, -1))
attn_score = attn_score / self.h_dim
if is_causal:mask =torch.triu(torch.ones(attn_score.size(-2),attn_score.size(-1)), diagonal=1)
attn_score = attn_score.mask_fill(mask, float('inf'))
attn_weights = F.softmax(attn_score, dim=-1)
output = torch.matmul(attn_weights, v)

matul是矩阵相乘,torch.triu用于创建一个上三角矩阵
假如说x长这样:
tensor([[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[1., 1., 1., 1.]])
torch.triu(x, diagonal=0)就长这样:
tensor([[ 1., 1., 1., 1.],
[ 0., 1., 1., 1.],
[ 0., 0., 1., 1.],
[ 0., 0., 0., 1.]])
torch.triu(x, diagonal=1)就长这样:
tensor([[ 0., 1., 1., 1.],
[ 0., 0., 1., 1.],
[ 0., 0., 0., 1.],
[ 0., 0., 0., 0.]])
.mask_fill表示应用掩码矩阵,float('inf')表示负无限,方便后续交给softmax处理
下面来介绍三个选择情况:
if kv_cache is None or Tq == Tk
表示没有kv缓存或q,k矩阵序列长度相同的情况,一般这个情况都是在进行模型训练,此时需要掩码。
elif Tq == 1:
q的序列长度为1,说明此时正在进行单token生成,即推理状态,此时不需要掩码。
 else:
这个块处理的是模型在推理时,一次性处理一批(一个 chunk)新的查询(Tq > 1),并且 KV 缓存中已经存在一部分历史信息(Tk > Tq)。在这种模式下,一个小的、快速的“草稿模型”会先生成一小段文本(比如4个 tokens),然后主模型(就是我们正在分析的这个)会一次性地验证这4个 tokens。这时,Tq 就等于 4。即投机解码算法(speculative decoding)主要用于加速模型的推理。
对于该算法的详细流程,这里先不做过多赘述。

        y.transpose(1,2).contiguous().view(B, T, -1) # (B, H, T, D) -> (B, T, C)y = self.out(y)

将所有头拼接起来,相当于前面转置qkv矩阵的逆操作
contiguous()是为了创建一个在内存上连续的y副本。view() 本身也是一个“零拷贝”操作,它不会移动数据,只是重新解释数据的形状。如果数据在内存中是“乱”的(非连续的),view() 就不知道该如何正确地、高效地重新解释它。这一点在前面解释view()时谈到过

http://www.hskmm.com/?act=detail&tid=35552

相关文章:

  • 基于TV模型利用Bregman分裂算法迭代对图像进行滤波和复原处理
  • 2025 年大路灯品牌最新推荐榜,技术实力与市场口碑深度解析,精选优质源头厂家
  • LangChain4j 比 SolonAI 强在哪?弱在哪?
  • 2025.10.20__2023秋季联赛题解(第11题)
  • docker怎么更新版本
  • B树和B+树的解析应用
  • 2025 年快速退火炉优质厂家最新推荐榜单:真空 / 半导体 / 晶圆 / 高温 / 桌面等多类型设备企业权威评选
  • 2025年10月河南园区招商扶持公司推荐:五强对比评测榜
  • 2025 年广州心理疏导机构推荐:桥恩心理多维度服务满足不同人群心理健康需求
  • 2025 年深圳心理疏导机构推荐,桥恩心理:专业心理疏导服务的优质选择与全体系诊疗优势
  • 2025年10月手操器公司推荐:对比评测榜揭示工业诊断选型要点
  • OIFC NOI2023省队集训
  • 实战案例:职行力如何利用纷享销客CRM实现人效管理数字化突围?
  • 2025年10月素材平台对比评测榜:高品图像领衔五强深度解析
  • 2025年10月儿童面霜品牌推荐:五强榜单对比评测与选购指南
  • Ansible
  • 示波器接地环路与电磁脉冲干扰:原理、影响及应对策略
  • 2025 年国内传感器厂家最新推荐排行榜:聚焦磁致伸缩 / 防爆 / 防水 / 线性 / 液位等多类型传感器,精选优质企业
  • 2025 年钢结构厂家最新推荐:优质品牌权威榜单发布,助力客户精准选择可靠合作伙伴
  • Palantir实体工程实践
  • 施普林格论文集:发展中国家城市废物流资源化利用与回收洞察
  • 0.9B PaddleOCR-VL 登顶 SOTA!GPUStack 高效推理部署实战指南
  • 【URP】Unity中的[摩尔纹]问题解决方案
  • 打印机已发送,但是不打印?一份全面的故障排除指南!
  • 2025 年雕塑源头厂家最新推荐排行榜:聚焦婚庆泡沫 / 玻璃钢 / 城市地标不锈钢等多品类,精选优质企业
  • SOAR技术与高效网络安全运营 - 教程
  • 2025《中国科学:信息科学》前沿学术沙龙暨2025年智能控制与计算科学国际学术会议
  • 2025 年板材厂家最新推荐排行榜:聚焦 ENF 级环保、零醛添加等高品质板材,精选前 6 强深度解析品牌优势与产品亮点
  • 在 .NET 9 中使用 Mapster 快速、高效的实现对象映射
  • 列出 Redux 的组件?