深入理解 Attention 機制:從原理到實作
前言
在深度學習的發展歷程中,Attention 機制(注意力機制)無疑是近十年來最具革命性的突破之一。從 2015 年首次被引入 Seq2Seq 模型,到 2017 年 Transformer 架構的誕生,再到今天的 GPT、BERT、ChatGPT,Attention 機制始終是這些模型的核心引擎。
但 Attention 到底是什麼?它解決了什麼問題?數學原理是什麼?本文將從零開始,帶你完整理解 Attention 機制的來龍去脈,並用 Python 從頭實作一個完整的 Self-Attention 模組。
一、為什麼需要 Attention?
1.1 傳統 Seq2Seq 的瓶頸
在 Attention 出現之前,處理序列資料(如機器翻譯、語音辨識)的主流方法是 Seq2Seq(Encoder-Decoder) 架構,搭配 LSTM 或 GRU 作為基礎單元。
這個架構的運作方式如下:
- Encoder 讀入整個輸入序列(例如一句英文),並將所有資訊壓縮成一個固定長度的向量,稱為 Context Vector(上下文向量)。
- Decoder 接收這個 Context Vector,逐步生成輸出序列(例如對應的中文翻譯)。
"I love deep learning"
|
Encoder
|
[Context Vector] ← 所有資訊都被壓縮在這裡
|
Decoder
|
"我愛深度學習"
這個架構的核心問題在於:所有的輸入資訊都被強制壓縮進一個固定維度的向量。當輸入序列很短時(5 個詞)還好,但當輸入是一篇 500 字的文章時,要求一個向量記住所有細節就顯得力不從心。
這個現象被稱為 Information Bottleneck(資訊瓶頸),是傳統 Seq2Seq 最大的缺陷。
1.2 人類的閱讀方式
反觀人類在理解語言時,並不是把整句話「壓縮成一個概念」再開始翻譯。以翻譯「The cat sat on the mat」為例,當我們翻譯「貓」這個字時,我們的注意力主要集中在 "cat" 上;翻譯「坐」時,注意力轉移到 "sat" 上。
這就是 Attention 的核心直覺:在生成每個輸出 token 時,動態地決定應該「關注」輸入序列中的哪些位置,而不是依賴一個固定的壓縮向量。
二、Attention 機制的數學原理
2.1 基本 Attention 公式
Attention 機制的核心計算可以用一個簡潔的公式表達:
這個公式看起來簡短,但包含了深刻的設計思想。讓我們逐步拆解每個元素。
2.2 Query、Key、Value 的直覺
Attention 機制使用三個核心概念,可以用「資料庫查詢」的比喻來理解:
- Query(查詢,Q):你正在「問」什麼。例如 Decoder 當前要生成的 token 想知道什麼。
- Key(鍵,K):資料庫裡每筆資料的「索引標籤」。每個輸入 token 都有一個 Key。
- Value(值,V):資料庫裡每筆資料的「實際內容」。每個輸入 token 都有一個 Value。
計算流程如下:
- 用 Query 對每個 Key 計算相似度(點積),得到一組分數。
- 對分數做 Softmax,得到注意力權重(加總為 1)。
- 用注意力權重對所有 Value 做加權平均,得到最終輸出。
Query: "我要生成哪個中文字?"
Keys: ["I", "love", "deep", "learning"] → 計算相似度
Scores: [0.1, 0.6, 0.2, 0.1] → Softmax
Weights: [0.1, 0.6, 0.2, 0.1]
Output: 0.1×V_I + 0.6×V_love + 0.2×V_deep + 0.1×V_learning
2.3 為什麼要除以 ?
點積的結果會隨著向量維度 的增大而變大,導致 Softmax 的梯度非常小(梯度消失)。除以 是一種縮放(scaling),確保點積值不會過大,讓 Softmax 能夠在合理的範圍內運作。
這種 Attention 也因此被稱為 Scaled Dot-Product Attention。
2.4 Self-Attention
Self-Attention(自注意力) 是 Attention 機制的一個特殊且重要的形式:Query、Key、Value 都來自同一個序列。
這使得序列中的每個 token 都可以「看到」序列中的所有其他 token,從而學習到長距離的依賴關係。
例如,在句子「The animal didn't cross the street because it was too tired」中,Self-Attention 可以讓 "it" 這個 token 學到它指的是 "animal" 而不是 "street",即使它們相距較遠。
Q、K、V 的計算方式:
其中 是輸入序列的嵌入矩陣, 是可學習的投影矩陣。
2.5 Multi-Head Attention
單一的 Attention head 每次只能關注序列中的一種模式。Multi-Head Attention 則是同時運行 個獨立的 Attention head,每個 head 學習不同的關注模式,最後將所有 head 的輸出拼接起來。
舉例來說,在翻譯任務中:
- Head 1 可能學會關注句法結構(主詞和動詞的關係)
- Head 2 可能學會關注語義相似性
- Head 3 可能學會關注位置鄰近性
這種並行的多視角學習是 Multi-Head Attention 強大的原因。
2.6 Masked Attention
在自回歸生成任務(如語言模型)中,Decoder 在生成第 個 token 時,不能看到未來的 token(),否則會造成「作弊」。
Masked Attention 通過在 Softmax 之前將未來位置的分數設為負無窮大(),使其在 Softmax 後的權重為 0,從而實現這一約束:
其中 是遮罩矩陣,對應「不應看到」的位置填入 ,其餘填入 。
三、Positional Encoding
Attention 機制本身是位置不敏感的,也就是說,如果你把輸入序列的順序打亂,Attention 的輸出不會改變。這顯然不合理,因為「我愛你」和「你愛我」的意思完全不同。
Positional Encoding(位置編碼) 的作用是給每個 token 的嵌入向量加上位置資訊。Transformer 原始論文使用正弦與餘弦函數:
其中 是位置(0, 1, 2, ...), 是維度索引, 是嵌入維度。
選擇三角函數的原因在於:不同頻率的正弦波可以唯一表示任意長度的位置,而且模型可以透過線性變換從絕對位置推算相對位置。
四、Python 實作
現在讓我們從頭實作一個完整的 Multi-Head Self-Attention 模組。
4.1 環境設定
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
4.2 Scaled Dot-Product Attention
def scaled_dot_product_attention(query, key, value, mask=None):
"""
計算 Scaled Dot-Product Attention。
Args:
query: shape (batch, heads, seq_len_q, d_k)
key: shape (batch, heads, seq_len_k, d_k)
value: shape (batch, heads, seq_len_v, d_v)
mask: shape (batch, 1, seq_len_q, seq_len_k), optional
Returns:
output: shape (batch, heads, seq_len_q, d_v)
weights: shape (batch, heads, seq_len_q, seq_len_k)
"""
d_k = query.size(-1) # Key 的維度
# Step 1: 計算 Q 和 K 的點積,並縮放
# scores shape: (batch, heads, seq_len_q, seq_len_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Step 2: 套用遮罩(如果有的話)
if mask is not None:
# 將 mask 為 0 的位置設為極小值,Softmax 後趨近於 0
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 3: Softmax 取得注意力權重
# 沿最後一個維度(key 方向)做 Softmax
weights = F.softmax(scores, dim=-1)
# Step 4: 用注意力權重對 Value 做加權平均
output = torch.matmul(weights, value)
return output, weights
4.3 Multi-Head Attention 模組
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
"""
Multi-Head Attention 模組。
Args:
d_model: 模型的總維度(嵌入維度)
num_heads: Attention Head 的數量
"""
super(MultiHeadAttention, self).__init__()
# 確保 d_model 可以被 num_heads 整除
assert d_model % num_heads == 0, "d_model 必須能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每個 head 的維度
# 定義 Q、K、V 和輸出的線性投影層
# 注意:一次性投影整個 d_model,之後再拆分成多個 head
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def split_heads(self, x, batch_size):
"""
將最後一個維度拆分成 (num_heads, d_k),並調整維度順序。
Args:
x: shape (batch, seq_len, d_model)
Returns:
shape (batch, num_heads, seq_len, d_k)
"""
x = x.view(batch_size, -1, self.num_heads, self.d_k)
return x.permute(0, 2, 1, 3) # (batch, heads, seq_len, d_k)
def forward(self, query, key, value, mask=None):
"""
前向傳播。
Args:
query: shape (batch, seq_len_q, d_model)
key: shape (batch, seq_len_k, d_model)
value: shape (batch, seq_len_v, d_model)
mask: optional mask
Returns:
output: shape (batch, seq_len_q, d_model)
"""
batch_size = query.size(0)
# Step 1: 線性投影 Q、K、V
Q = self.W_q(query) # (batch, seq_len_q, d_model)
K = self.W_k(key) # (batch, seq_len_k, d_model)
V = self.W_v(value) # (batch, seq_len_v, d_model)
# Step 2: 拆分成多個 head
Q = self.split_heads(Q, batch_size) # (batch, heads, seq_len_q, d_k)
K = self.split_heads(K, batch_size) # (batch, heads, seq_len_k, d_k)
V = self.split_heads(V, batch_size) # (batch, heads, seq_len_v, d_k)
# Step 3: 計算每個 head 的 Attention
attn_output, self.attention_weights = scaled_dot_product_attention(Q, K, V, mask)
# attn_output shape: (batch, heads, seq_len_q, d_k)
# Step 4: 合併所有 head 的輸出
# 先調換維度,再 reshape 成 (batch, seq_len_q, d_model)
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
attn_output = attn_output.view(batch_size, -1, self.d_model)
# Step 5: 最終線性投影
output = self.W_o(attn_output)
return output
4.4 Positional Encoding
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
"""
Sinusoidal Positional Encoding。
Args:
d_model: 嵌入維度
max_seq_len: 最大序列長度
dropout: Dropout 比例
"""
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# 建立位置編碼矩陣
pe = torch.zeros(max_seq_len, d_model) # (max_seq_len, d_model)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
# position shape: (max_seq_len, 1)
# 計算分母的除數項
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# 偶數維度用 sin,奇數維度用 cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 增加 batch 維度:(1, max_seq_len, d_model)
pe = pe.unsqueeze(0)
# 使用 register_buffer 讓 pe 不參與梯度更新,但隨模型移動(to device)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: shape (batch, seq_len, d_model)
Returns:
shape (batch, seq_len, d_model)
"""
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
4.5 Feed-Forward Network
Transformer 中每個 Attention 層後面都接了一個 Feed-Forward Network(FFN):
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
"""
Position-wise Feed-Forward Network。
Args:
d_model: 輸入/輸出維度
d_ff: 隱藏層維度(通常是 d_model 的 4 倍)
dropout: Dropout 比例
"""
super(FeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 兩層線性變換,中間有 ReLU 激活和 Dropout
return self.linear2(self.dropout(F.relu(self.linear1(x))))
4.6 Transformer Encoder Layer
將 Self-Attention 和 FFN 組合成一個完整的 Encoder Layer:
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
"""
單一 Transformer Encoder 層。
包含 Multi-Head Self-Attention + Feed-Forward Network,
以及殘差連接(Residual Connection)和層正規化(Layer Normalization)。
"""
super(EncoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: shape (batch, seq_len, d_model)
mask: optional padding mask
"""
# Sub-layer 1: Multi-Head Self-Attention + Residual + LayerNorm
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output)) # Add & Norm
# Sub-layer 2: Feed-Forward + Residual + LayerNorm
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output)) # Add & Norm
return x
4.7 完整的 Transformer Encoder
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, d_model=512, num_heads=8,
d_ff=2048, num_layers=6, max_seq_len=512, dropout=0.1):
"""
完整的 Transformer Encoder。
Args:
vocab_size: 詞彙表大小
d_model: 嵌入維度
num_heads: Attention head 數量
d_ff: FFN 隱藏層維度
num_layers: Encoder layer 數量
max_seq_len: 最大序列長度
dropout: Dropout 比例
"""
super(TransformerEncoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
self.d_model = d_model
def forward(self, x, mask=None):
"""
Args:
x: token indices, shape (batch, seq_len)
mask: optional padding mask
"""
# 嵌入 + 縮放 + 位置編碼
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
# 逐層通過 Encoder Layer
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
4.8 完整測試與視覺化
def visualize_attention(attention_weights, tokens, layer=0, head=0):
"""
視覺化 Attention 權重熱力圖。
Args:
attention_weights: shape (batch, heads, seq_len, seq_len)
tokens: token 列表(字串)
layer: 要視覺化的 layer 索引
head: 要視覺化的 head 索引
"""
weights = attention_weights[0, head].detach().numpy()
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(weights, cmap='Blues', aspect='auto')
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)
ax.set_xlabel('Key (被關注的 token)')
ax.set_ylabel('Query (正在生成的 token)')
ax.set_title(f'Attention Weights (Head {head})')
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
def test_transformer():
"""完整測試 Transformer Encoder。"""
print("=" * 50)
print("Transformer Encoder 測試")
print("=" * 50)
# 超參數設定
vocab_size = 1000
d_model = 128 # 嵌入維度(示範用,實際通常 512 或 768)
num_heads = 4 # Attention head 數量
d_ff = 512 # FFN 隱藏層維度
num_layers = 2 # Encoder layer 數量
batch_size = 2
seq_len = 10
# 建立模型
model = TransformerEncoder(
vocab_size=vocab_size,
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
num_layers=num_layers,
)
# 計算參數量
total_params = sum(p.numel() for p in model.parameters())
print(f"模型總參數量: {total_params:,}")
# 建立假資料
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
print(f"\n輸入 shape: {input_ids.shape}")
# 建立 Padding Mask(假設最後 2 個 token 是 padding)
# mask 為 0 的位置表示 padding,Attention 會忽略這些位置
mask = torch.ones(batch_size, 1, 1, seq_len)
mask[:, :, :, -2:] = 0 # 最後 2 個位置是 padding
# 前向傳播
model.eval()
with torch.no_grad():
output = model(input_ids, mask)
print(f"輸出 shape: {output.shape}")
print(f"期望 shape: ({batch_size}, {seq_len}, {d_model})")
# 取得 Attention 權重(第一個 Encoder Layer)
attn_weights = model.layers[0].self_attn.attention_weights
print(f"\nAttention weights shape: {attn_weights.shape}")
print(f"(batch={batch_size}, heads={num_heads}, "
f"seq_len={seq_len}, seq_len={seq_len})")
# 驗證 Attention 權重加總為 1
weight_sum = attn_weights[0, 0].sum(dim=-1) # 第 0 個 batch,第 0 個 head
print(f"\nAttention 權重每行加總(應為 1):")
print(weight_sum.numpy().round(4))
return model, attn_weights, input_ids
# 執行測試
if __name__ == "__main__":
model, attn_weights, input_ids = test_transformer()
# 視覺化(使用假的 token 名稱)
tokens = [f"tok_{i}" for i in range(10)]
visualize_attention(attn_weights, tokens, head=0)
五、Attention 機制的變種與發展
5.1 Cross-Attention
在 Encoder-Decoder 架構中,Decoder 需要同時處理兩個輸入:
- 已生成的輸出序列(Self-Attention)
- Encoder 的輸出(Cross-Attention)
Cross-Attention 中,Query 來自 Decoder,而 Key 和 Value 來自 Encoder:
# Cross-Attention 的使用方式
decoder_output = cross_attention(
query=decoder_hidden, # Q 來自 Decoder
key=encoder_output, # K 來自 Encoder
value=encoder_output # V 來自 Encoder
)
5.2 Efficient Attention 變種
標準 Self-Attention 的計算複雜度是 ,對長序列非常耗時。近年來出現了多種高效變種:
| 方法 | 複雜度 | 核心思想 |
|---|---|---|
| Longformer | 局部視窗 + 全局 token | |
| Linformer | 低秩近似 Key/Value | |
| Performer | 隨機特徵近似 Softmax | |
| Flash Attention | IO 感知計算(GPU 記憶體優化) | |
| Sliding Window | 固定視窗大小 |
5.3 相對位置編碼
原始 Transformer 使用絕對位置編碼,但相對位置編碼(Relative Positional Encoding) 直接在 Attention 分數中加入相對位置資訊,在某些任務上效果更好:
其中 表示位置 和位置 之間的相對距離編碼。GPT-Neo、T5 等模型都使用了相對位置編碼的變種。
六、從 Attention 到 Transformer 到 LLM
理解了 Attention 機制,你就掌握了現代大型語言模型(LLM)的核心:
Attention (2015)
↓
Self-Attention (2017, Transformer)
↓
BERT (2018, Encoder-only, 雙向)
GPT (2018, Decoder-only, 單向)
↓
GPT-2 / GPT-3 (2019/2020, 規模擴展)
↓
ChatGPT / GPT-4 (2022/2023, RLHF 對齊)
↓
Llama / Claude / Gemini (2023-至今)
每一步的進化,底層的 Attention 機制都沒有根本性的改變,改變的是:
- 規模(參數量從百萬到兆)
- 訓練方式(預訓練 + 微調 + RLHF)
- 工程優化(Flash Attention、GQA、RoPE)
七、重點整理
-
Attention 解決了什麼問題:傳統 Seq2Seq 的 Information Bottleneck,讓模型在生成時動態選擇關注哪些輸入。
-
核心公式:,除以 是為了穩定梯度。
-
Self-Attention:Q、K、V 來自同一序列,讓每個 token 都能直接與序列中所有其他 token 互動。
-
Multi-Head Attention:同時運行多個 Attention head,各自學習不同的關注模式,最後拼接。
-
Positional Encoding:使用三角函數為每個位置添加唯一的位置資訊,補足 Attention 缺乏位置感知的缺陷。
-
計算複雜度:標準 Self-Attention 是 ,這是長序列的主要瓶頸,催生了大量 Efficient Attention 的研究。
參考資料
- Bahdanau et al. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. [原始 Attention 論文]
- Vaswani et al. (2017). Attention Is All You Need. [Transformer 論文]
- Devlin et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
- Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.
- The Illustrated Transformer by Jay Alammar — 最好的 Transformer 圖解教學
留言
使用 GitHub 帳號登入即可留言