自注意力层(分组查询注意力)
初始化
class SelfAttention(nn.Module):def __init__(self, config, layer_idx):super().__init__()self.layer_idx = layer_idxself.n_head = config.n_head # 查询头的数量self.kv_head = config.kv_head # kv头的数量self.n_embed = config.n_embed # 嵌入维度self.h_dim = self.n_embed // self.n_headassert self.n_embed % self.n_head == 0assert self.kv_head < self.n_head and self.n_head % self.kv_head == 0self.q_linear = nn.Linear(self.n_embed, self.n_embed, bias=False)self.v_linear = nn.Linear(self.n_embed, self.kv_head*self.h_dim, bias=False)self.k_linear = nn.Linear(self.n_embed, self.kv_head*self.h_dim, bias=False)self.out = nn.Linear(self.n_embed, self.n_embed, bias=False)
layer_idx
的作用是作为一个索引,告诉当前这个 Attention 模块它是 Transformer 模型中的第几层。方便后续训练过程中的调试与日志记录以及kv缓存处理。
assert
的作用是断定某个条件必须为真,如果该条件为假,程序就会立即崩溃并抛出一个 AssertionError
异常。
nn.Linear
的两个必须参数为输入维度和输出维度
这里采用的是分组查询注意力(GQA),即多个查询头共享一个kv头。相比较与MHA,计算量和缓存压力要小很多,但理论上模型质量也会有所下降。
def forward(self, x, cos_sin, kv_cache):# 修改qkv矩阵的形状,方便后续计算B, T, C = x.size()q = self.q_linear(x).view(B, T, self.n_head, self.h_dim)v = self.v_linear(x).view(B, T, self.kv_head, self.h_dim)k = self.k_linear(x).view(B, T, self.kv_head, self.h_dim)# 进行旋转位置编码cos, sin = cos_sinq,k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)q,k = norm(q), norm(k)# 矩阵转置(B, T, H, D) -> (B, H, T, D)# 方便后续的注意力计算,让 PyTorch 可以将 H 个头看作一个批次维度进行高效的矩阵乘法q = q.transpose(1,2)v = v.transpose(1,2)k = k.transpose(1,2)
view()
本身是一个“零拷贝”操作,它不会移动数据,只是重新解释数据的形状。它并不会真的在内存中移动数据。只是修改了张量的元数据(比如 stride,即访问下一个元素需要跳过多少个内存位置),创建了一个新的“视图”指向原来的数据。
为了提高效率,像 transpose(), permute(), narrow() 等操作都是创建一个新的视图,并没有在内存中移动数据。
# 创建一个 2x3 的连续张量
x = torch.arange(6).view(2, 3)
# tensor([[0, 1, 2],
# [3, 4, 5]])
在内存中,x 的数据是这样存储的:[0, 1, 2, 3, 4, 5]。
# 对 x 进行转置
y = x.transpose(0, 1)
# tensor([[0, 3],
# [1, 4],
# [2, 5]])
虽然 y 的逻辑形状是 (3, 2),但它在内存中的数据仍然是 [0, 1, 2, 3, 4, 5]
- 为了读取 y 的第一行 [0, 3],程序需要先读取位置0的 0,然后跳过 1 和 2,去读取位置3的 3。
- 因为元素不再是紧挨着的,所以这个张量 y 不是连续的,对于不连续的数据不能直接进行
.view()
操作
旋转位置编码这里先不做赘述
现在来解释一下为什么又要进行矩阵转置:
假设我们向注意力层输入的数据x长这样,一次输入4个toekn,每个token15个维度:
tensor([[[7, 2, 0, 2, 3, 7, 1, 2, 5, 5, 6, 6, 2, 9, 6],
[0, 1, 1, 9, 3, 3, 0, 2, 6, 4, 7, 0, 3, 7, 1],
[2, 7, 7, 0, 2, 1, 1, 8, 9, 6, 6, 2, 5, 5, 5],
[3, 6, 1, 1, 0, 0, 5, 3, 2, 0, 0, 9, 6, 6, 6]]])
尺寸为(1,4,15)
以q矩阵为例,假设(B, T, self.n_head, self.h_dim)=(1,4,3,5)
将输入 x 乘以一个可学习的权重矩阵(W_q) 得到q。这个过程叫做线性投影(Linear Projection)。
则q初始化后长这样:
tensor([[[[7, 2, 0, 2, 3],# T0, H0
[7, 1, 2, 5, 5],# T0, H1
[6, 6, 2, 9, 6]],# T0, H2
`[[0, 1, 1, 9, 3],# T1, H0``[3, 0, 2, 6, 4],# T1, H1``[7, 0, 3, 7, 1]],# T1, H2``[[2, 7, 7, 0, 2],# T2, H0``[1, 1, 8, 9, 6],# T2, H1``[6, 2, 5, 5, 5]],# T2, H2``[[3, 6, 1, 1, 0],# T3, H0``[0, 5, 3, 2, 0],# T3, H1``[0, 9, 6, 6, 6]]]])# T3, H2`
相当于把每个token截成四份,我们知道MQA,包括MHA计算注意力矩阵时都是头矩阵之间进行计算,所以我们按照头来分组,更能利用GPU擅长处理并行运算的特点。
对q的第2,3个维度进行转置:
tensor([[[[7, 2, 0, 2, 3], # H0, T0
[0, 1, 1, 9, 3], # H0, T1
[2, 7, 7, 0, 2], # H0, T2
[3, 6, 1, 1, 0]], # H0, T3
`[[7, 1, 2, 5, 5], # H1, T0``[3, 0, 2, 6, 4], # H1, T1``[1, 1, 8, 9, 6], # H1, T2``[0, 5, 3, 2, 0]], # H1, T3``[[6, 6, 2, 9, 6], # H2, T0``[7, 0, 3, 7, 1], # H2, T1``[6, 2, 5, 5, 5], # H2, T2``[0, 9, 6, 6, 6]]]]) # H2, T3`
这样我们就得到了每个头要处理的内容。
现在来介绍一下注意力分数的计算
if kv_cache is not None:k,v = kv_cache.insert_kv(self.layer_idx, k, v)Tq = q.size(2)Tk = k.size(2)nrep = self.n_head // self.kv_headk,v = repeat_kv(k, nrep), repeat_kv(v, nrep)if kv_cache is None or Tq == Tk:y = F.scaled_dot_product_attention(q, k, v, is_causal=True)# scaled_dot_product_attention将注意力机制的多个步骤融合,同时会调用最优的实现,比如FlashAttention来实现算力优化elif Tq == 1:# Tq = 1说明是单token生成,即推理场景,所以不需要掩码y = F.scaled_dot_product_attention(q, k, v, is_causal=False)else:mask = torch.zeros((Tq,Tk), device=q.device, dtype=torch.bool)prefix_len = Tk - Tqif prefix_len > 0:mask[:, :prefix_len] = Truemask[:, prefix_len:] = torch.tril(torch.ones((Tq,Tq), device=q.device, dtype=torch.bool))y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)y.transpose(1,2).contiguous().view(B, T, -1) # (B, H, T, D) -> (B, T, C)y = self.out(y)return y
这里先不介绍kv缓存的实现
nrep是为了计算出一个kv头要对应多少个查询头,然后对kv头进行复制
def repeat_kv(x, nrep):if nrep == 1:return xbs, h, slen, dim = x.shapereturn(x[:, :, None, :, :].expand(bs, h, nrep, slen, dim).reshape(bs, h * nrep, slen, dim))
假设一个原先的k头长这样
k = tensor([[[[0,0,0,0], [0,0,0,1], [0,0,0,2]], # K_H0
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]]) # K_H1
在x[:, :, None, :, :]
后:
tensor([[[ [[0,0,0,0], [0,0,0,1], [0,0,0,2]] ], # K_H0
[ [[1,1,1,0], [1,1,1,1], [1,1,1,2]] ]]]) # K_H1expand后:
tensor([[[ [[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (original)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (view 1)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (view 2)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]]], # KV_H0 (view 3)`
`[ [[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # KV_H1 (original)``[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # KV_H1 (view 1)``[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # KV_H1 (view 2)``[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]]])# KV_H1 (view 3)`
reshape后:
tensor([[[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 0 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 1 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 2 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 3 (from K_H0)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 4 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 5 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 6 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]])# Head 7 (from K_H1)
这样在与q矩阵运算时就能一一对应了
scaled_dot_product_attention(q, k, v, is_causal=False)
是torch.nn.functional
的一个函数,只要输入q,k,v矩阵和是否进行掩码就能注意力分数的计算。这个函数会在底层使用flashattention机制来实现算力优化。
现在来介绍一下标准注意力的计算
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
attn_score = torch.matmul(q,k.transpose(-2, -1))
attn_score = attn_score / self.h_dim
if is_causal:mask =torch.triu(torch.ones(attn_score.size(-2),attn_score.size(-1)), diagonal=1)
attn_score = attn_score.mask_fill(mask, float('inf'))
attn_weights = F.softmax(attn_score, dim=-1)
output = torch.matmul(attn_weights, v)
matul是矩阵相乘,torch.triu用于创建一个上三角矩阵
假如说x长这样:
tensor([[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[1., 1., 1., 1.]])
torch.triu(x, diagonal=0)就长这样:
tensor([[ 1., 1., 1., 1.],
[ 0., 1., 1., 1.],
[ 0., 0., 1., 1.],
[ 0., 0., 0., 1.]])
torch.triu(x, diagonal=1)就长这样:
tensor([[ 0., 1., 1., 1.],
[ 0., 0., 1., 1.],
[ 0., 0., 0., 1.],
[ 0., 0., 0., 0.]])
.mask_fill表示应用掩码矩阵,float('inf')表示负无限,方便后续交给softmax处理
下面来介绍三个选择情况:
if kv_cache is None or Tq == Tk
表示没有kv缓存或q,k矩阵序列长度相同的情况,一般这个情况都是在进行模型训练,此时需要掩码。
elif Tq == 1:
q的序列长度为1,说明此时正在进行单token生成,即推理状态,此时不需要掩码。
else:
这个块处理的是模型在推理时,一次性处理一批(一个 chunk)新的查询(Tq > 1),并且 KV 缓存中已经存在一部分历史信息(Tk > Tq)。在这种模式下,一个小的、快速的“草稿模型”会先生成一小段文本(比如4个 tokens),然后主模型(就是我们正在分析的这个)会一次性地验证这4个 tokens。这时,Tq 就等于 4。即投机解码算法(speculative decoding)主要用于加速模型的推理。
对于该算法的详细流程,这里先不做过多赘述。
y.transpose(1,2).contiguous().view(B, T, -1) # (B, H, T, D) -> (B, T, C)y = self.out(y)
将所有头拼接起来,相当于前面转置qkv矩阵的逆操作
contiguous()是为了创建一个在内存上连续的y副本。view() 本身也是一个“零拷贝”操作,它不会移动数据,只是重新解释数据的形状。如果数据在内存中是“乱”的(非连续的),view() 就不知道该如何正确地、高效地重新解释它。这一点在前面解释view()时谈到过