深入理解 Attention 機制:從原理到實作

前言

在深度學習的發展歷程中,Attention 機制(注意力機制)無疑是近十年來最具革命性的突破之一。從 2015 年首次被引入 Seq2Seq 模型,到 2017 年 Transformer 架構的誕生,再到今天的 GPT、BERT、ChatGPT,Attention 機制始終是這些模型的核心引擎。

但 Attention 到底是什麼?它解決了什麼問題?數學原理是什麼?本文將從零開始,帶你完整理解 Attention 機制的來龍去脈,並用 Python 從頭實作一個完整的 Self-Attention 模組。


一、為什麼需要 Attention?

1.1 傳統 Seq2Seq 的瓶頸

在 Attention 出現之前,處理序列資料(如機器翻譯、語音辨識)的主流方法是 Seq2Seq(Encoder-Decoder) 架構,搭配 LSTM 或 GRU 作為基礎單元。

這個架構的運作方式如下:

  1. Encoder 讀入整個輸入序列(例如一句英文),並將所有資訊壓縮成一個固定長度的向量,稱為 Context Vector(上下文向量)
  2. Decoder 接收這個 Context Vector,逐步生成輸出序列(例如對應的中文翻譯)。
"I love deep learning"
         |
      Encoder
         |
   [Context Vector]  ← 所有資訊都被壓縮在這裡
         |
      Decoder
         |
   "我愛深度學習"

這個架構的核心問題在於:所有的輸入資訊都被強制壓縮進一個固定維度的向量。當輸入序列很短時(5 個詞)還好,但當輸入是一篇 500 字的文章時,要求一個向量記住所有細節就顯得力不從心。

這個現象被稱為 Information Bottleneck(資訊瓶頸),是傳統 Seq2Seq 最大的缺陷。

1.2 人類的閱讀方式

反觀人類在理解語言時,並不是把整句話「壓縮成一個概念」再開始翻譯。以翻譯「The cat sat on the mat」為例,當我們翻譯「貓」這個字時,我們的注意力主要集中在 "cat" 上;翻譯「坐」時,注意力轉移到 "sat" 上。

這就是 Attention 的核心直覺:在生成每個輸出 token 時,動態地決定應該「關注」輸入序列中的哪些位置,而不是依賴一個固定的壓縮向量。


二、Attention 機制的數學原理

2.1 基本 Attention 公式

Attention 機制的核心計算可以用一個簡潔的公式表達:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

這個公式看起來簡短,但包含了深刻的設計思想。讓我們逐步拆解每個元素。

2.2 Query、Key、Value 的直覺

Attention 機制使用三個核心概念,可以用「資料庫查詢」的比喻來理解:

  • Query(查詢,Q):你正在「問」什麼。例如 Decoder 當前要生成的 token 想知道什麼。
  • Key(鍵,K):資料庫裡每筆資料的「索引標籤」。每個輸入 token 都有一個 Key。
  • Value(值,V):資料庫裡每筆資料的「實際內容」。每個輸入 token 都有一個 Value。

計算流程如下:

  1. 用 Query 對每個 Key 計算相似度(點積),得到一組分數。
  2. 對分數做 Softmax,得到注意力權重(加總為 1)。
  3. 用注意力權重對所有 Value 做加權平均,得到最終輸出。
Query: "我要生成哪個中文字?"
Keys:  ["I", "love", "deep", "learning"]  → 計算相似度
Scores: [0.1, 0.6, 0.2, 0.1]             → Softmax
Weights: [0.1, 0.6, 0.2, 0.1]
Output:  0.1×V_I + 0.6×V_love + 0.2×V_deep + 0.1×V_learning

2.3 為什麼要除以 dk\sqrt{d_k}

點積的結果會隨著向量維度 dkd_k 的增大而變大,導致 Softmax 的梯度非常小(梯度消失)。除以 dk\sqrt{d_k} 是一種縮放(scaling),確保點積值不會過大,讓 Softmax 能夠在合理的範圍內運作。

這種 Attention 也因此被稱為 Scaled Dot-Product Attention

2.4 Self-Attention

Self-Attention(自注意力) 是 Attention 機制的一個特殊且重要的形式:Query、Key、Value 都來自同一個序列

這使得序列中的每個 token 都可以「看到」序列中的所有其他 token,從而學習到長距離的依賴關係。

例如,在句子「The animal didn't cross the street because it was too tired」中,Self-Attention 可以讓 "it" 這個 token 學到它指的是 "animal" 而不是 "street",即使它們相距較遠。

Q、K、V 的計算方式:

Q=XWQ,K=XWK,V=XWVQ = XW^Q, \quad K = XW^K, \quad V = XW^V

其中 XX 是輸入序列的嵌入矩陣,WQ,WK,WVW^Q, W^K, W^V 是可學習的投影矩陣。

2.5 Multi-Head Attention

單一的 Attention head 每次只能關注序列中的一種模式。Multi-Head Attention 則是同時運行 hh 個獨立的 Attention head,每個 head 學習不同的關注模式,最後將所有 head 的輸出拼接起來。

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O

whereheadi=Attention(QWiQ,KWiK,VWiV)\text{where} \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

舉例來說,在翻譯任務中:

  • Head 1 可能學會關注句法結構(主詞和動詞的關係)
  • Head 2 可能學會關注語義相似性
  • Head 3 可能學會關注位置鄰近性

這種並行的多視角學習是 Multi-Head Attention 強大的原因。

2.6 Masked Attention

在自回歸生成任務(如語言模型)中,Decoder 在生成第 tt 個 token 時,不能看到未來的 token(t+1,t+2,t+1, t+2, \ldots),否則會造成「作弊」。

Masked Attention 通過在 Softmax 之前將未來位置的分數設為負無窮大(-\infty),使其在 Softmax 後的權重為 0,從而實現這一約束:

MaskedAttention(Q,K,V)=softmax(QKTdk+M)V\text{MaskedAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V

其中 MM 是遮罩矩陣,對應「不應看到」的位置填入 -\infty,其餘填入 00


三、Positional Encoding

Attention 機制本身是位置不敏感的,也就是說,如果你把輸入序列的順序打亂,Attention 的輸出不會改變。這顯然不合理,因為「我愛你」和「你愛我」的意思完全不同。

Positional Encoding(位置編碼) 的作用是給每個 token 的嵌入向量加上位置資訊。Transformer 原始論文使用正弦與餘弦函數:

PE(pos,2i)=sin(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)

PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)

其中 pospos 是位置(0, 1, 2, ...),ii 是維度索引,dmodeld_{model} 是嵌入維度。

選擇三角函數的原因在於:不同頻率的正弦波可以唯一表示任意長度的位置,而且模型可以透過線性變換從絕對位置推算相對位置。


四、Python 實作

現在讓我們從頭實作一個完整的 Multi-Head Self-Attention 模組。

4.1 環境設定

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

4.2 Scaled Dot-Product Attention

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    計算 Scaled Dot-Product Attention。
    
    Args:
        query: shape (batch, heads, seq_len_q, d_k)
        key:   shape (batch, heads, seq_len_k, d_k)
        value: shape (batch, heads, seq_len_v, d_v)
        mask:  shape (batch, 1, seq_len_q, seq_len_k), optional
    
    Returns:
        output:  shape (batch, heads, seq_len_q, d_v)
        weights: shape (batch, heads, seq_len_q, seq_len_k)
    """
    d_k = query.size(-1)  # Key 的維度
    
    # Step 1: 計算 Q 和 K 的點積,並縮放
    # scores shape: (batch, heads, seq_len_q, seq_len_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Step 2: 套用遮罩(如果有的話)
    if mask is not None:
        # 將 mask 為 0 的位置設為極小值,Softmax 後趨近於 0
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 3: Softmax 取得注意力權重
    # 沿最後一個維度(key 方向)做 Softmax
    weights = F.softmax(scores, dim=-1)
    
    # Step 4: 用注意力權重對 Value 做加權平均
    output = torch.matmul(weights, value)
    
    return output, weights

4.3 Multi-Head Attention 模組

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        """
        Multi-Head Attention 模組。
        
        Args:
            d_model:   模型的總維度(嵌入維度)
            num_heads: Attention Head 的數量
        """
        super(MultiHeadAttention, self).__init__()
        
        # 確保 d_model 可以被 num_heads 整除
        assert d_model % num_heads == 0, "d_model 必須能被 num_heads 整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每個 head 的維度
        
        # 定義 Q、K、V 和輸出的線性投影層
        # 注意:一次性投影整個 d_model,之後再拆分成多個 head
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
    def split_heads(self, x, batch_size):
        """
        將最後一個維度拆分成 (num_heads, d_k),並調整維度順序。
        
        Args:
            x: shape (batch, seq_len, d_model)
        Returns:
            shape (batch, num_heads, seq_len, d_k)
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.permute(0, 2, 1, 3)  # (batch, heads, seq_len, d_k)
    
    def forward(self, query, key, value, mask=None):
        """
        前向傳播。
        
        Args:
            query: shape (batch, seq_len_q, d_model)
            key:   shape (batch, seq_len_k, d_model)
            value: shape (batch, seq_len_v, d_model)
            mask:  optional mask
        
        Returns:
            output: shape (batch, seq_len_q, d_model)
        """
        batch_size = query.size(0)
        
        # Step 1: 線性投影 Q、K、V
        Q = self.W_q(query)  # (batch, seq_len_q, d_model)
        K = self.W_k(key)    # (batch, seq_len_k, d_model)
        V = self.W_v(value)  # (batch, seq_len_v, d_model)
        
        # Step 2: 拆分成多個 head
        Q = self.split_heads(Q, batch_size)  # (batch, heads, seq_len_q, d_k)
        K = self.split_heads(K, batch_size)  # (batch, heads, seq_len_k, d_k)
        V = self.split_heads(V, batch_size)  # (batch, heads, seq_len_v, d_k)
        
        # Step 3: 計算每個 head 的 Attention
        attn_output, self.attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        # attn_output shape: (batch, heads, seq_len_q, d_k)
        
        # Step 4: 合併所有 head 的輸出
        # 先調換維度,再 reshape 成 (batch, seq_len_q, d_model)
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        
        # Step 5: 最終線性投影
        output = self.W_o(attn_output)
        
        return output

4.4 Positional Encoding

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        """
        Sinusoidal Positional Encoding。
        
        Args:
            d_model:     嵌入維度
            max_seq_len: 最大序列長度
            dropout:     Dropout 比例
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # 建立位置編碼矩陣
        pe = torch.zeros(max_seq_len, d_model)  # (max_seq_len, d_model)
        
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # position shape: (max_seq_len, 1)
        
        # 計算分母的除數項
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        
        # 偶數維度用 sin,奇數維度用 cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # 增加 batch 維度:(1, max_seq_len, d_model)
        pe = pe.unsqueeze(0)
        
        # 使用 register_buffer 讓 pe 不參與梯度更新,但隨模型移動(to device)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: shape (batch, seq_len, d_model)
        Returns:
            shape (batch, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

4.5 Feed-Forward Network

Transformer 中每個 Attention 層後面都接了一個 Feed-Forward Network(FFN)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Position-wise Feed-Forward Network。
        
        Args:
            d_model: 輸入/輸出維度
            d_ff:    隱藏層維度(通常是 d_model 的 4 倍)
            dropout: Dropout 比例
        """
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # 兩層線性變換,中間有 ReLU 激活和 Dropout
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

4.6 Transformer Encoder Layer

將 Self-Attention 和 FFN 組合成一個完整的 Encoder Layer:

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        單一 Transformer Encoder 層。
        包含 Multi-Head Self-Attention + Feed-Forward Network,
        以及殘差連接(Residual Connection)和層正規化(Layer Normalization)。
        """
        super(EncoderLayer, self).__init__()
        
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        """
        Args:
            x:    shape (batch, seq_len, d_model)
            mask: optional padding mask
        """
        # Sub-layer 1: Multi-Head Self-Attention + Residual + LayerNorm
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))  # Add & Norm
        
        # Sub-layer 2: Feed-Forward + Residual + LayerNorm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))    # Add & Norm
        
        return x

4.7 完整的 Transformer Encoder

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, 
                 d_ff=2048, num_layers=6, max_seq_len=512, dropout=0.1):
        """
        完整的 Transformer Encoder。
        
        Args:
            vocab_size:  詞彙表大小
            d_model:     嵌入維度
            num_heads:   Attention head 數量
            d_ff:        FFN 隱藏層維度
            num_layers:  Encoder layer 數量
            max_seq_len: 最大序列長度
            dropout:     Dropout 比例
        """
        super(TransformerEncoder, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        self.d_model = d_model
        
    def forward(self, x, mask=None):
        """
        Args:
            x:    token indices, shape (batch, seq_len)
            mask: optional padding mask
        """
        # 嵌入 + 縮放 + 位置編碼
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # 逐層通過 Encoder Layer
        for layer in self.layers:
            x = layer(x, mask)
        
        return self.norm(x)

4.8 完整測試與視覺化

def visualize_attention(attention_weights, tokens, layer=0, head=0):
    """
    視覺化 Attention 權重熱力圖。
    
    Args:
        attention_weights: shape (batch, heads, seq_len, seq_len)
        tokens:           token 列表(字串)
        layer:            要視覺化的 layer 索引
        head:             要視覺化的 head 索引
    """
    weights = attention_weights[0, head].detach().numpy()
    
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(weights, cmap='Blues', aspect='auto')
    
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    ax.set_yticklabels(tokens)
    
    ax.set_xlabel('Key (被關注的 token)')
    ax.set_ylabel('Query (正在生成的 token)')
    ax.set_title(f'Attention Weights (Head {head})')
    
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()


def test_transformer():
    """完整測試 Transformer Encoder。"""
    print("=" * 50)
    print("Transformer Encoder 測試")
    print("=" * 50)
    
    # 超參數設定
    vocab_size = 1000
    d_model = 128      # 嵌入維度(示範用,實際通常 512 或 768)
    num_heads = 4      # Attention head 數量
    d_ff = 512         # FFN 隱藏層維度
    num_layers = 2     # Encoder layer 數量
    batch_size = 2
    seq_len = 10
    
    # 建立模型
    model = TransformerEncoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_layers=num_layers,
    )
    
    # 計算參數量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型總參數量: {total_params:,}")
    
    # 建立假資料
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    print(f"\n輸入 shape: {input_ids.shape}")
    
    # 建立 Padding Mask(假設最後 2 個 token 是 padding)
    # mask 為 0 的位置表示 padding,Attention 會忽略這些位置
    mask = torch.ones(batch_size, 1, 1, seq_len)
    mask[:, :, :, -2:] = 0  # 最後 2 個位置是 padding
    
    # 前向傳播
    model.eval()
    with torch.no_grad():
        output = model(input_ids, mask)
    
    print(f"輸出 shape: {output.shape}")
    print(f"期望 shape: ({batch_size}, {seq_len}, {d_model})")
    
    # 取得 Attention 權重(第一個 Encoder Layer)
    attn_weights = model.layers[0].self_attn.attention_weights
    print(f"\nAttention weights shape: {attn_weights.shape}")
    print(f"(batch={batch_size}, heads={num_heads}, "
          f"seq_len={seq_len}, seq_len={seq_len})")
    
    # 驗證 Attention 權重加總為 1
    weight_sum = attn_weights[0, 0].sum(dim=-1)  # 第 0 個 batch,第 0 個 head
    print(f"\nAttention 權重每行加總(應為 1):")
    print(weight_sum.numpy().round(4))
    
    return model, attn_weights, input_ids


# 執行測試
if __name__ == "__main__":
    model, attn_weights, input_ids = test_transformer()
    
    # 視覺化(使用假的 token 名稱)
    tokens = [f"tok_{i}" for i in range(10)]
    visualize_attention(attn_weights, tokens, head=0)

五、Attention 機制的變種與發展

5.1 Cross-Attention

在 Encoder-Decoder 架構中,Decoder 需要同時處理兩個輸入:

  • 已生成的輸出序列(Self-Attention)
  • Encoder 的輸出(Cross-Attention)

Cross-Attention 中,Query 來自 Decoder,而 Key 和 Value 來自 Encoder:

# Cross-Attention 的使用方式
decoder_output = cross_attention(
    query=decoder_hidden,   # Q 來自 Decoder
    key=encoder_output,     # K 來自 Encoder
    value=encoder_output    # V 來自 Encoder
)

5.2 Efficient Attention 變種

標準 Self-Attention 的計算複雜度是 O(n2)O(n^2),對長序列非常耗時。近年來出現了多種高效變種:

方法複雜度核心思想
LongformerO(n)O(n)局部視窗 + 全局 token
LinformerO(n)O(n)低秩近似 Key/Value
PerformerO(n)O(n)隨機特徵近似 Softmax
Flash AttentionO(n2)O(n^2)IO 感知計算(GPU 記憶體優化)
Sliding WindowO(nw)O(n \cdot w)固定視窗大小 ww

5.3 相對位置編碼

原始 Transformer 使用絕對位置編碼,但相對位置編碼(Relative Positional Encoding) 直接在 Attention 分數中加入相對位置資訊,在某些任務上效果更好:

score(qi,kj)=qikj+qirijdk\text{score}(q_i, k_j) = \frac{q_i \cdot k_j + q_i \cdot r_{i-j}}{\sqrt{d_k}}

其中 rijr_{i-j} 表示位置 ii 和位置 jj 之間的相對距離編碼。GPT-Neo、T5 等模型都使用了相對位置編碼的變種。


六、從 Attention 到 Transformer 到 LLM

理解了 Attention 機制,你就掌握了現代大型語言模型(LLM)的核心:

Attention (2015)

Self-Attention (2017, Transformer)

BERT (2018, Encoder-only, 雙向)
GPT (2018, Decoder-only, 單向)

GPT-2 / GPT-3 (2019/2020, 規模擴展)

ChatGPT / GPT-4 (2022/2023, RLHF 對齊)

Llama / Claude / Gemini (2023-至今)

每一步的進化,底層的 Attention 機制都沒有根本性的改變,改變的是:

  • 規模(參數量從百萬到兆)
  • 訓練方式(預訓練 + 微調 + RLHF)
  • 工程優化(Flash Attention、GQA、RoPE)

七、重點整理

  1. Attention 解決了什麼問題:傳統 Seq2Seq 的 Information Bottleneck,讓模型在生成時動態選擇關注哪些輸入。

  2. 核心公式Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,除以 dk\sqrt{d_k} 是為了穩定梯度。

  3. Self-Attention:Q、K、V 來自同一序列,讓每個 token 都能直接與序列中所有其他 token 互動。

  4. Multi-Head Attention:同時運行多個 Attention head,各自學習不同的關注模式,最後拼接。

  5. Positional Encoding:使用三角函數為每個位置添加唯一的位置資訊,補足 Attention 缺乏位置感知的缺陷。

  6. 計算複雜度:標準 Self-Attention 是 O(n2)O(n^2),這是長序列的主要瓶頸,催生了大量 Efficient Attention 的研究。


參考資料

  • Bahdanau et al. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. [原始 Attention 論文]
  • Vaswani et al. (2017). Attention Is All You Need. [Transformer 論文]
  • Devlin et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
  • Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.
  • The Illustrated Transformer by Jay Alammar — 最好的 Transformer 圖解教學

留言

使用 GitHub 帳號登入即可留言