Skip to content

从零实现一个迷你 Transformer:Self-Attention 怎样读上下文

这篇笔记用 PyTorch 手写一个非常小的 decoder-only Transformer,不依赖 Hugging Face 或其他高层库。

你会看到:

  1. 文本如何变成 token 和 embedding。
  2. Self-Attention 如何用 QKV 从上下文里取信息。
  3. 为什么语言模型需要 causal mask,避免偷看未来。
  4. 如何把 Multi-Head Attention、Feed Forward、Residual、LayerNorm 组合成 Transformer block。
  5. 如何训练一个很小的字符级语言模型,并让它继续生成文本。

直观比喻:Self-Attention 像每个 token 都拿着一个问题 Q,去和其他 token 的关键词 K 匹配,最后按匹配程度读取它们携带的信息 V

0. 环境准备

python
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['Songti SC']
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(7)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("torch version:", torch.__version__)
print("device:", device)

1. 一个最小文本任务

为了专注理解 Transformer,我们不用大数据集,而是构造一小段有规律的文本。

任务是:给定前面一段字符,预测下一个字符。比如看到 graph neural,模型应该更倾向预测后面的字符。

这里使用字符级 tokenization:每个不同字符都是一个 token。真实大模型通常使用 BPE、SentencePiece 等子词 tokenization,但核心训练目标是一样的。

python
text = (
    "graph neural networks pass messages. "
    "attention lets tokens read context. "
    "transformers use self attention. "
) * 80

chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}


def encode(s):
    return [stoi[ch] for ch in s]


def decode(ids):
    return "".join(itos[int(i)] for i in ids)


data = torch.tensor(encode(text), dtype=torch.long)
vocab_size = len(chars)

print("text length:", len(data))
print("vocab size:", vocab_size)
print("vocab:", "".join(chars))
print("encoded example:", encode("attention"))

示例输出:

text length: 8480
vocab size: 22
vocab: .acdefghiklmnoprstuwx
encoded example: [2, 18, 18, 5, 13, 18, 9, 14, 13]

2. 训练样本:输入序列和目标序列

语言模型不是只预测最后一个字符,而是对序列中每个位置都预测"下一个字符"。

如果输入是:

text
a t t e n

目标就是:

text
t t e n t

注意:第 1 个位置只能看见第 1 个输入 token,第 2 个位置只能看见前 2 个输入 token,以此类推。这个限制由 causal mask 保证。

python
block_size = 32
batch_size = 54


def get_batch():
    starts = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in starts])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in starts])
    return x.to(device), y.to(device)


x, y = get_batch()
print("x shape:", tuple(x.shape))
print("y shape:", tuple(y.shape))
print("x[0]:", decode(x[0].cpu()))
print("y[0]:", decode(y[0].cpu()))

示例输出

x shape: (54, 32)
y shape: (54, 32)
x[0]: on. graph neural networks pass m
y[0]: n. graph neural networks pass me

3. 从矩阵乘法理解 Self-Attention

对一个长度为 T 的序列,每个 token 先变成一个向量,记作 x,形状是 (B, T, C)

  • B:batch size。
  • T:序列长度。
  • C:每个 token 向量的维度。

Self-Attention 会从 x 线性变换出三组向量:

  • Q / Query:当前位置想找什么信息。
  • K / Key:每个位置提供什么索引。
  • V / Value:每个位置真正携带的内容。

核心计算是:

python
class CausalSelfAttentionHead(nn.Module):
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        scores = q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(k.shape[-1]))
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        out = weights @ v
        return out, weights


demo_embedding = nn.Embedding(vocab_size, 16).to(device)
demo_head = CausalSelfAttentionHead(n_embd=16, head_size=8, block_size=block_size).to(
    device
)
demo_x = demo_embedding(x[:1])
demo_out, demo_weights = demo_head(demo_x)

print("input embedding:", tuple(demo_x.shape))
print("attention output:", tuple(demo_out.shape))
print("attention weights:", tuple(demo_weights.shape))
print("first 6x6 causal weights:")
print(demo_weights[0, :6, :6].detach().cpu().round(decimals=2))

示例输出:

input embedding: (1, 32, 16)
attention output: (1, 32, 8)
attention weights: (1, 32, 32)
first 6x6 causal weights:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5200, 0.4800, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3600, 0.2700, 0.3700, 0.0000, 0.0000, 0.0000],
        [0.2100, 0.1900, 0.4200, 0.1800, 0.0000, 0.0000],
        [0.3100, 0.1700, 0.2500, 0.1200, 0.1500, 0.0000],
        [0.1400, 0.1800, 0.1100, 0.1900, 0.1000, 0.2700]])

其中 score[b, i, j] 表示第 i 个 token 对第 j 个 token 的关注程度。

4. Multi-Head Attention

单个 attention head 只用一种方式读上下文。Multi-Head Attention 会并行使用多个 head:

  • 有的 head 可能关注前一个字符。
  • 有的 head 可能关注词内结构。
  • 有的 head 可能关注更远处的模式。

实现上就是把多个 head 的输出拼接起来,再过一个线性层。

python
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads, block_size):
        super().__init__()
        assert n_embd % num_heads == 0
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList(
            [
                CausalSelfAttentionHead(n_embd, head_size, block_size)
                for _ in range(num_heads)
            ]
        )
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x, return_weights=False):
        head_outputs = []
        head_weights = []
        for head in self.heads:
            out, weights = head(x)
            head_outputs.append(out)
            head_weights.append(weights)
        out = torch.cat(head_outputs, dim=-1)
        out = self.proj(out)
        if return_weights:
            return out, torch.stack(head_weights, dim=1)
        return out


class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)

5. Transformer Block

一个标准 Transformer block 通常包含两部分:

  1. Multi-Head Self-Attention:让 token 读取上下文。
  2. Feed Forward:对每个位置独立做非线性变换。

还要加上两个稳定训练的关键结构:

  • Residual connection:x = x + sublayer(x),保留原信息,让梯度更容易传播。
  • LayerNorm:把每个 token 的向量归一化,让训练更稳定。

这里使用现代常见的 pre-norm 写法:先 LayerNorm,再进入子层。

python
class TransformerBlock(nn.Module):
    def __init__(self, n_embd, num_heads, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

6. 组装一个迷你语言模型

完整模型流程:

  1. token_embedding_table:把 token id 变成向量。
  2. position_embedding_table:告诉模型每个 token 在第几个位置。没有位置信息时,attention 本身不知道顺序。
  3. 多个 TransformerBlock:反复读上下文并更新表示。
  4. lm_head:把每个位置的向量变成对词表中每个 token 的 logits。

训练时直接把 logits 交给 F.cross_entropy,不要提前手动 softmax

python
class MiniTransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, num_heads, num_layers):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(n_embd, num_heads, block_size)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx)
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(pos)
        x = token_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

7. 训练

这里的数据非常小,模型也很小,所以训练几十到几百步就能看到 loss 下降。

python
model = MiniTransformerLanguageModel(
    vocab_size=vocab_size,
    n_embd=64,
    block_size=block_size,
    num_heads=4,
    num_layers=2,
).to(device)

num_params = sum(p.numel() for p in model.parameters())

示例输出:

parameters: 104598

如果你在 CPU 上运行,下面也应该很快完成。真实语言模型的区别主要是:数据更大、模型更深、训练更久、工程细节更多。

python
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)

for step in range(201):
xb, yb = get_batch()
logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(f"step {step:3d} | loss {loss.item():.4f}")

示例输出:

step 0 | loss 3.2232
step 50 | loss 0.5614
step 100 | loss 0.1098
step 150 | loss 0.0841
step 200 | loss 0.0780

8. 生成文本

生成时从一个起始 token 开始:

  1. 用当前上下文预测下一个 token 的概率分布。
  2. 从概率分布里采样一个 token。
  3. 把新 token 接到上下文后面。
  4. 重复以上步骤。

模型只看过很小、重复的文本,所以生成结果不会像真实大模型一样丰富,但你应该能看到它学会了一些字符组合和短语模式。

python
model.eval()
start = torch.tensor([encode("at")], dtype=torch.long, device=device)
generated = model.generate(start, max_new_tokens=160)[0].cpu()
print(decode(generated))

示例输出:

attention lets tokens read context. transformers use self attention. graph neural networks pass messages. attention lets tokens read context. transformers use sel

9. 看一眼 Attention 权重

下面取第一层第一个 head 的 attention 权重。矩阵中第 i 行表示第 i 个位置在读哪些历史位置。

因为用了 causal mask,右上角未来位置的权重应该是0。

python
with torch.no_grad():
    sample = torch.tensor(
        [encode("attention lets tokens")], dtype=torch.long, device=device
    )
    token_emb = model.token_embedding_table(sample)
    pos = torch.arange(sample.shape[1], device=device)
    x_emb = token_emb + model.position_embedding_table(pos)
    x_norm = model.blocks[0].ln1(x_emb)
    _, weights = model.blocks[0].attn(x_norm, return_weights=True)

print("sample:", decode(sample[0].cpu()))
print("weights shape: (batch, heads, T, T) =", tuple(weights.shape))
print(weights[0, 0, :8, :8].cpu().round(decimals=2))

示例输出:

sample: attention lets tokens
weights shape: (batch, heads, T, T) = (1, 4, 21, 21)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6800, 0.3200, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7600, 0.1300, 0.1100, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6300, 0.1700, 0.1800, 0.0200, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.0800, 0.1500, 0.0700, 0.2100, 0.0000, 0.0000, 0.0000],
        [0.3200, 0.1000, 0.1100, 0.0300, 0.2400, 0.2000, 0.0000, 0.0000],
        [0.0400, 0.2800, 0.0800, 0.0100, 0.2400, 0.2400, 0.1100, 0.0000],
        [0.0100, 0.0100, 0.0100, 0.0000, 0.0000, 0.0300, 0.9400, 0.0000]])

用热力图显示完整 attention 矩阵。

python
tokens = list(decode(sample[0].cpu()))
attention_matrix = weights[0, 0].detach().cpu()

fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(attention_matrix, cmap="Blues", vmin=0, vmax=attention_matrix.max())
ax.set_title("Layer 1, Head 1 attention weights")
ax.set_xlabel("Key / 被读取的位置")
ax.set_ylabel("Query / 当前预测的位置")
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()

10. 小结

到这里,我们已经从零搭出了一个迷你 Transformer:

  • Token embedding 把离散 token 变成连续向量。
  • Position embedding 给模型提供顺序信息。
  • Causal Self-Attention 让当前位置读取过去上下文,但不能偷看未来。
  • Multi-Head Attention 让模型用多种方式读取上下文。
  • Feed Forward、Residual connection、LayerNorm 让每个 block 更有表达能力且更容易训练。
  • 语言模型训练时预测每个位置的下一个 token,使用 logits + cross entropy。