注意力机制图解

学习来源

类型 名称 作者/机构
论文 Attention Is All You Need Vaswani et al., Google Brain
图解教程 The Illustrated Transformer Jay Alammar
在线课程 Generative AI with LLMs DeepLearning.AI
技术博客 通俗理解Transformer架构(2026最新配图版) CSDN社区

核心收获

注意力机制是Transformer的灵魂,它让模型能够"看清"输入序列中每个位置与其他所有位置的关联,实现真正的并行计算和长距离依赖捕捉。

  1. QKV三角:Query(我想找什么)、Key(我有什么)、Value(我真正提供什么)——理解这三个向量的关系是掌握注意力机制的关键
  2. 缩放点积注意力:通过 softmax(QK^T / √d_k)V 的数学公式,实现动态加权求和,√d_k的缩放因子防止梯度消失
  3. 多头注意力:将注意力拆分到多个子空间并行计算,不同头捕捉不同类型的依赖关系(语法、语义、位置等)
  4. 位置编码:Transformer本身不感知位置,需要通过正弦/余弦函数注入位置信息
  5. 残差连接+层归一化:确保深层网络的稳定训练,防止梯度消失

正文内容:注意力机制原理详解

一、从RNN的局限到注意力机制的诞生

在2017年之前,序列处理任务主要由RNN(循环神经网络)和LSTM/GRU主导。这些模型的工作方式类似于一条"单行道"——必须按顺序处理序列中的每个元素,前一个token的输出必须等待才能处理下一个token。

这种设计带来了两个根本性问题:

  • 并行计算受限:无法充分利用GPU的并行能力,训练速度慢
  • 长距离依赖捕捉困难:信息在传递过程中会逐渐稀释或丢失

2014年,注意力机制首次被应用于机器翻译任务。核心思想很简单:模型在处理每个词时,应该动态地"关注"输入序列中最相关的部分,而不是均匀地处理所有信息。这就像人类阅读时的注意力分配——当被问"谁提出了相对论"时,我们会自动聚焦于"爱因斯坦"这个词。

核心洞察

Transformer的最大创新是:彻底摒弃循环结构,全部使用注意力机制 + 前馈神经网络。这使得并行计算成为可能,训练速度提升3-5倍,同时能更好地捕捉任意距离的依赖关系。

二、Transformer整体架构图解

Transformer采用经典的Encoder-Decoder(编码器-解码器)架构。以下是原论文《Attention Is All You Need》中的经典架构图:

架构概览
  • 左侧Encoder:接收源序列,提取语义特征,堆叠6层
  • 右侧Decoder:基于Encoder输出和已生成的部分,顺序生成目标序列
  • 核心组件:多头注意力、前馈网络、残差连接、层归一化、位置编码

三、Scaled Dot-Product Attention(缩放点积注意力)

这是Transformer最核心的计算单元。让我们一步步拆解其工作原理:

Attention(Q, K, V) = softmax(QK^T / √d_k) × V

3.1 QKV三角:三个向量的本质含义

理解Query(Q)、Key(K)、Value(V)这三个向量是掌握注意力机制的关键:

向量 含义 通俗比喻
Query(查询) "当前我想找什么" 你在图书馆找书时的需求描述
Key(键) "每个元素拥有什么特征" 每本书的标签和分类
Value(值) "每个元素的实际内容" 书的实际内容

3.2 计算流程分步图解

1 生成QKV向量

输入序列经过可学习的权重矩阵 W_Q、W_K、W_V 进行线性变换,生成Query、Key、Value三个矩阵。

# 线性投影生成Q、K、V import torch import torch.nn as nn # 假设嵌入维度 d_model=512,注意力维度 d_k=64 d_model = 512 d_k = 64 # 可学习的权重矩阵 W_Q = nn.Linear(d_model, d_k, bias=False) W_K = nn.Linear(d_model, d_k, bias=False) W_V = nn.Linear(d_model, d_k, bias=False) # 输入 x: (batch_size, seq_len, d_model) # 输出 Q, K, V: (batch_size, seq_len, d_k) Q = W_Q(x) K = W_K(x) V = W_V(x)
2 计算注意力分数(点积)

计算Query与所有Key的点积,得到相似度矩阵。点积越大,表示相关性越强。

# 计算点积注意力分数 # Q: (batch, seq_len, d_k) @ K^T: (batch, d_k, seq_len) # → scores: (batch, seq_len, seq_len) attention_scores = torch.matmul(Q, K.transpose(-2, -1)) # 示例:对于句子 "The animal didn't cross the street because it was tired" # "it" 与 "animal" 的点积会远高于其他词
3 缩放处理(防止梯度消失)

将点积结果除以 √d_k 进行缩放。当d_k较大时,点积的方差会很大,导致softmax输出过于极端(接近one-hot),梯度接近于零。

import math # 缩放因子:√d_k scaled_scores = attention_scores / math.sqrt(d_k) # d_k=64 → √64=8 # 缩放后softmax梯度更加稳定
4 Softmax归一化

对缩放后的分数应用Softmax,将相似度转换为概率分布(所有权重为正且和为1)。

# Softmax归一化 # 在最后一个维度(seq_len)上执行softmax attention_weights = torch.nn.functional.softmax(scaled_scores, dim=-1) # 输出形状: (batch, seq_len, seq_len) # 每一行表示当前位置对所有位置的注意力权重
5 加权求和

用注意力权重对Value进行加权求和,得到最终的注意力输出向量。

# 加权求和得到上下文向量 # attention_weights: (batch, seq_len, seq_len) @ V: (batch, seq_len, d_k) # → output: (batch, seq_len, d_k) output = torch.matmul(attention_weights, V) # output的每一行 = 所有value向量的加权和 # 包含了整个序列的上下文信息

3.3 经典例子:代词消解

注意力机制最经典的例子是代词消解。考虑这句话:

"The animal didn't cross the street because it was too tired."

这里的"it"应该指代"animal"。注意力机制会自动给"animal"分配极高的权重(如0.9),而其他词的权重很低(如0.05)。

而如果是另一句话:

"The animal didn't cross the street because it was too wide."

这里的"it"应该指代"street"。注意力权重会自动调整,给"street"分配更高权重。

四、Multi-Head Attention(多头注意力)

单头注意力只能在一个子空间建模关系,而多头注意力让模型能够在多个不同子空间中并行建模。这就像让多个不同科室的医生会诊——每个医生关注不同维度的特征。

4.1 为什么需要多头?

类型 视角 表达力 适用场景
单头注意力 单一 一般 简单依赖关系
多头注意力 多视角 复杂语义、语法、位置关系

在实际Transformer中,不同的注意力头可以学习到不同类型的依赖模式:

  • 有的头关注局部信息(相邻词的关系)
  • 有的头关注句法结构(主谓一致、修饰关系)
  • 有的头关注长距离语义(跨句子的指代关系)

4.2 多头注意力的数学原理

MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) × W^O
其中 headᵢ = Attention(QW_i^Q, KW_i^K, VW_i^V)
# 多头注意力PyTorch实现 class MultiHeadAttention(nn.Module): def __init__(self, d_model=512, num_heads=8): super().__init__() assert d_model % num_heads == 0 self.d_k = d_model // num_heads # 每头维度(如512/8=64) self.num_heads = num_heads # 线性投影层 self.W_Q = nn.Linear(d_model, d_model) self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.W_O = nn.Linear(d_model, d_model) def forward(self, Q, K, V, mask=None): batch_size = Q.size(0) # 1. 线性投影 Q = self.W_Q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) K = self.W_K(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) V = self.W_V(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # 2. Scaled Dot-Product Attention scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn_weights = torch.softmax(scores, dim=-1) context = torch.matmul(attn_weights, V) # 3. 拼接多头并投影 context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k) return self.W_O(context)

4.3 多头的计算效率

虽然头数增加了,但每个头的维度会等比例缩小,总计算量与单头全维度注意力相近:

  • d_model = 512,h = 8 → 每头维度 d_k = 64
  • 单头全维度:O(512²)
  • 8头独立计算:O(8 × 64²) = O(32768) ≈ O(512²)

五、Transformer中的三种注意力

Transformer架构中包含三种不同类型的注意力机制:

位置 注意力类型 Q来源 K/V来源 作用
Encoder层 多头自注意力 Encoder输入 Encoder输入 让输入序列的每个词看到其他所有词
Decoder第一层 掩码多头自注意力 Decoder输入 Decoder输入 防止看到未来位置(自回归生成)
Decoder第二层 交叉注意力 Decoder Encoder输出 让Decoder关注Encoder的全局信息
交叉注意力的本质

交叉注意力 = Decoder拿着"当前想找什么(Q)",去Encoder的信息库里"检索(匹配K)",然后把最相关的内容(V)加权取出来。这是机器翻译的核心机制。

六、工程实现:核心算子链

从工程视角看,Transformer的核心算子链如下:

# Transformer 核心算子链 Embedding → Add Positional Encoding # 注入位置信息 → Linear(Q/K/V) # 线性投影生成QKV → Reshape/Transpose # 张量变形准备多头计算 → MatMul(Q, K^T) # 计算注意力分数 → Scale # 除以√d_k缩放 → Mask (Decoder用) # 掩码未来位置 → Softmax # 归一化为概率 → MatMul(Attn, V) # 加权求和 → Concat # 拼接多头 → Linear # 最终投影 → Add & Norm # 残差连接+层归一化 → FFN # 前馈神经网络 → Add & Norm # 残差连接+层归一化

七、残差连接与层归一化

残差连接(Residual Connection)和层归一化(Layer Normalization)是Transformer训练稳定性的关键:

7.1 残差连接

将输入直接加到输出上:output = x + f(x)。这确保了即使注意力层什么都没学到,模型至少可以学习恒等函数,不会"遗忘"原始信息。

7.2 层归一化

对每个样本的向量进行归一化(零均值、单位方差),而不是像Batch Norm那样跨批次归一化。这对变长序列和小批次训练更友好。

# Encoder层完整实现 class EncoderLayer(nn.Module): def __init__(self, embed_size, num_heads, forward_expansion=4, dropout=0): super().__init__() self.attention = MultiHeadAttention(embed_size, num_heads) self.norm1 = nn.LayerNorm(embed_size) self.norm2 = nn.LayerNorm(embed_size) self.feed_forward = nn.Sequential( nn.Linear(embed_size, forward_expansion * embed_size), nn.ReLU(), nn.Linear(forward_expansion * embed_size, embed_size), ) self.dropout = nn.Dropout(dropout) def forward(self, x, mask): # 自注意力 + 残差 + 归一化 attention = self.attention(x, x, x, mask) x = self.dropout(self.norm1(attention + x)) # 前馈网络 + 残差 + 归一化 forward = self.feed_forward(x) out = self.dropout(self.norm2(forward + x)) return out

八、注意力机制 vs RNN/LSTM

特性 RNN/LSTM 注意力机制(Transformer)
计算方式 串行(必须等前一步) 并行(所有位置同时计算)
长距离依赖 梯度消失/爆炸问题 直接建模任意距离
可解释性 隐状态难以解释 注意力权重可直接可视化
训练速度 慢(无法利用GPU并行) 快(完全并行)
内存复杂度 O(n) O(n²)(n为序列长度)

相关链接

💭 思考与实践

深度思考

1. 为什么缩放因子是 √d_k?

点积的方差与d_k成正比。当d_k很大时,点积结果可能非常大,导致softmax进入饱和区(梯度接近0)。除以√d_k可以将方差归一化到1附近,确保softmax的梯度在合理范围内。

2. 多头注意力的头数如何选择?

Transformer原论文使用8头(Base)和16头(Large)。研究表明,头数太多会导致每个头的维度太小,学不到有效特征;头数太少则无法捕捉足够多样的模式。一般建议:d_model / num_heads ≥ 64。

3. 注意力机制的O(n²)复杂度问题

对于超长序列(如10K+ tokens),标准注意力的内存和计算量会爆炸。解决方案包括:Sparse Attention、Linear Attention、Flash Attention等。

实践建议

  1. 可视化注意力权重:使用BertViz或tensor2tensor查看BERT/Transformer的注意力分布,直观理解模型在关注什么
  2. 实现一个Mini Transformer:用PyTorch实现简化版的Transformer,加深对每个组件的理解
  3. 对比不同头数的效果:在相同数据集上训练1头、4头、8头的模型,观察性能差异
  4. 分析真实案例:用注意力机制分析代码补全、对话生成等场景,理解其在实际应用中的表现

下一步学习

明天(周二)我们将深入学习自注意力机制与多头注意力的进阶内容,包括:

  • 自注意力与普通注意力的区别
  • 掩码多头注意力的实现原理
  • 多头注意力的可视化分析
← 返回课程列表