Appearance
从零实现一个迷你 Transformer:Self-Attention 怎样读上下文
这篇笔记用 PyTorch 手写一个非常小的 decoder-only Transformer,不依赖 Hugging Face 或其他高层库。
你会看到:
- 文本如何变成 token 和 embedding。
- Self-Attention 如何用
Q、K、V从上下文里取信息。 - 为什么语言模型需要 causal mask,避免偷看未来。
- 如何把 Multi-Head Attention、Feed Forward、Residual、LayerNorm 组合成 Transformer block。
- 如何训练一个很小的字符级语言模型,并让它继续生成文本。
直观比喻:Self-Attention 像每个 token 都拿着一个问题 Q,去和其他 token 的关键词 K 匹配,最后按匹配程度读取它们携带的信息 V。
0. 环境准备
python
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Songti SC']
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
import torch
from torch import nn
import torch.nn.functional as F
torch.manual_seed(7)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("torch version:", torch.__version__)
print("device:", device)1. 一个最小文本任务
为了专注理解 Transformer,我们不用大数据集,而是构造一小段有规律的文本。
任务是:给定前面一段字符,预测下一个字符。比如看到 graph neural,模型应该更倾向预测后面的字符。
这里使用字符级 tokenization:每个不同字符都是一个 token。真实大模型通常使用 BPE、SentencePiece 等子词 tokenization,但核心训练目标是一样的。
python
text = (
"graph neural networks pass messages. "
"attention lets tokens read context. "
"transformers use self attention. "
) * 80
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
def encode(s):
return [stoi[ch] for ch in s]
def decode(ids):
return "".join(itos[int(i)] for i in ids)
data = torch.tensor(encode(text), dtype=torch.long)
vocab_size = len(chars)
print("text length:", len(data))
print("vocab size:", vocab_size)
print("vocab:", "".join(chars))
print("encoded example:", encode("attention"))示例输出:
text length: 8480
vocab size: 22
vocab: .acdefghiklmnoprstuwx
encoded example: [2, 18, 18, 5, 13, 18, 9, 14, 13]2. 训练样本:输入序列和目标序列
语言模型不是只预测最后一个字符,而是对序列中每个位置都预测"下一个字符"。
如果输入是:
text
a t t e n目标就是:
text
t t e n t注意:第 1 个位置只能看见第 1 个输入 token,第 2 个位置只能看见前 2 个输入 token,以此类推。这个限制由 causal mask 保证。
python
block_size = 32
batch_size = 54
def get_batch():
starts = torch.randint(0, len(data) - block_size - 1, (batch_size,))
x = torch.stack([data[i : i + block_size] for i in starts])
y = torch.stack([data[i + 1 : i + block_size + 1] for i in starts])
return x.to(device), y.to(device)
x, y = get_batch()
print("x shape:", tuple(x.shape))
print("y shape:", tuple(y.shape))
print("x[0]:", decode(x[0].cpu()))
print("y[0]:", decode(y[0].cpu()))示例输出
x shape: (54, 32)
y shape: (54, 32)
x[0]: on. graph neural networks pass m
y[0]: n. graph neural networks pass me3. 从矩阵乘法理解 Self-Attention
对一个长度为 T 的序列,每个 token 先变成一个向量,记作 x,形状是 (B, T, C):
B:batch size。T:序列长度。C:每个 token 向量的维度。
Self-Attention 会从 x 线性变换出三组向量:
Q/ Query:当前位置想找什么信息。K/ Key:每个位置提供什么索引。V/ Value:每个位置真正携带的内容。
核心计算是:
python
class CausalSelfAttentionHead(nn.Module):
def __init__(self, n_embd, head_size, block_size):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
def forward(self, x):
B, T, C = x.shape
k = self.key(x)
q = self.query(x)
v = self.value(x)
scores = q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(k.shape[-1]))
scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
out = weights @ v
return out, weights
demo_embedding = nn.Embedding(vocab_size, 16).to(device)
demo_head = CausalSelfAttentionHead(n_embd=16, head_size=8, block_size=block_size).to(
device
)
demo_x = demo_embedding(x[:1])
demo_out, demo_weights = demo_head(demo_x)
print("input embedding:", tuple(demo_x.shape))
print("attention output:", tuple(demo_out.shape))
print("attention weights:", tuple(demo_weights.shape))
print("first 6x6 causal weights:")
print(demo_weights[0, :6, :6].detach().cpu().round(decimals=2))示例输出:
input embedding: (1, 32, 16)
attention output: (1, 32, 8)
attention weights: (1, 32, 32)
first 6x6 causal weights:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5200, 0.4800, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3600, 0.2700, 0.3700, 0.0000, 0.0000, 0.0000],
[0.2100, 0.1900, 0.4200, 0.1800, 0.0000, 0.0000],
[0.3100, 0.1700, 0.2500, 0.1200, 0.1500, 0.0000],
[0.1400, 0.1800, 0.1100, 0.1900, 0.1000, 0.2700]])其中 score[b, i, j] 表示第 i 个 token 对第 j 个 token 的关注程度。
4. Multi-Head Attention
单个 attention head 只用一种方式读上下文。Multi-Head Attention 会并行使用多个 head:
- 有的 head 可能关注前一个字符。
- 有的 head 可能关注词内结构。
- 有的 head 可能关注更远处的模式。
实现上就是把多个 head 的输出拼接起来,再过一个线性层。
python
class MultiHeadAttention(nn.Module):
def __init__(self, n_embd, num_heads, block_size):
super().__init__()
assert n_embd % num_heads == 0
head_size = n_embd // num_heads
self.heads = nn.ModuleList(
[
CausalSelfAttentionHead(n_embd, head_size, block_size)
for _ in range(num_heads)
]
)
self.proj = nn.Linear(n_embd, n_embd)
def forward(self, x, return_weights=False):
head_outputs = []
head_weights = []
for head in self.heads:
out, weights = head(x)
head_outputs.append(out)
head_weights.append(weights)
out = torch.cat(head_outputs, dim=-1)
out = self.proj(out)
if return_weights:
return out, torch.stack(head_weights, dim=1)
return out
class FeedForward(nn.Module):
def __init__(self, n_embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
)
def forward(self, x):
return self.net(x)5. Transformer Block
一个标准 Transformer block 通常包含两部分:
- Multi-Head Self-Attention:让 token 读取上下文。
- Feed Forward:对每个位置独立做非线性变换。
还要加上两个稳定训练的关键结构:
- Residual connection:
x = x + sublayer(x),保留原信息,让梯度更容易传播。 - LayerNorm:把每个 token 的向量归一化,让训练更稳定。
这里使用现代常见的 pre-norm 写法:先 LayerNorm,再进入子层。
python
class TransformerBlock(nn.Module):
def __init__(self, n_embd, num_heads, block_size):
super().__init__()
self.ln1 = nn.LayerNorm(n_embd)
self.attn = MultiHeadAttention(n_embd, num_heads, block_size)
self.ln2 = nn.LayerNorm(n_embd)
self.ffwd = FeedForward(n_embd)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x6. 组装一个迷你语言模型
完整模型流程:
token_embedding_table:把 token id 变成向量。position_embedding_table:告诉模型每个 token 在第几个位置。没有位置信息时,attention 本身不知道顺序。- 多个
TransformerBlock:反复读上下文并更新表示。 lm_head:把每个位置的向量变成对词表中每个 token 的 logits。
训练时直接把 logits 交给 F.cross_entropy,不要提前手动 softmax。
python
class MiniTransformerLanguageModel(nn.Module):
def __init__(self, vocab_size, n_embd, block_size, num_heads, num_layers):
super().__init__()
self.block_size = block_size
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(
*[
TransformerBlock(n_embd, num_heads, block_size)
for _ in range(num_layers)
]
)
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size)
def forward(self, idx, targets=None):
B, T = idx.shape
token_emb = self.token_embedding_table(idx)
pos = torch.arange(T, device=idx.device)
pos_emb = self.position_embedding_table(pos)
x = token_emb + pos_emb
x = self.blocks(x)
x = self.ln_f(x)
logits = self.lm_head(x)
loss = None
if targets is not None:
B, T, C = logits.shape
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.block_size :]
logits, _ = self(idx_cond)
logits = logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat([idx, idx_next], dim=1)
return idx7. 训练
这里的数据非常小,模型也很小,所以训练几十到几百步就能看到 loss 下降。
python
model = MiniTransformerLanguageModel(
vocab_size=vocab_size,
n_embd=64,
block_size=block_size,
num_heads=4,
num_layers=2,
).to(device)
num_params = sum(p.numel() for p in model.parameters())示例输出:
parameters: 104598如果你在 CPU 上运行,下面也应该很快完成。真实语言模型的区别主要是:数据更大、模型更深、训练更久、工程细节更多。
python
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
for step in range(201):
xb, yb = get_batch()
logits, loss = model(xb, yb)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
if step % 50 == 0:
print(f"step {step:3d} | loss {loss.item():.4f}")示例输出:
step 0 | loss 3.2232
step 50 | loss 0.5614
step 100 | loss 0.1098
step 150 | loss 0.0841
step 200 | loss 0.07808. 生成文本
生成时从一个起始 token 开始:
- 用当前上下文预测下一个 token 的概率分布。
- 从概率分布里采样一个 token。
- 把新 token 接到上下文后面。
- 重复以上步骤。
模型只看过很小、重复的文本,所以生成结果不会像真实大模型一样丰富,但你应该能看到它学会了一些字符组合和短语模式。
python
model.eval()
start = torch.tensor([encode("at")], dtype=torch.long, device=device)
generated = model.generate(start, max_new_tokens=160)[0].cpu()
print(decode(generated))示例输出:
attention lets tokens read context. transformers use self attention. graph neural networks pass messages. attention lets tokens read context. transformers use sel9. 看一眼 Attention 权重
下面取第一层第一个 head 的 attention 权重。矩阵中第 i 行表示第 i 个位置在读哪些历史位置。
因为用了 causal mask,右上角未来位置的权重应该是0。
python
with torch.no_grad():
sample = torch.tensor(
[encode("attention lets tokens")], dtype=torch.long, device=device
)
token_emb = model.token_embedding_table(sample)
pos = torch.arange(sample.shape[1], device=device)
x_emb = token_emb + model.position_embedding_table(pos)
x_norm = model.blocks[0].ln1(x_emb)
_, weights = model.blocks[0].attn(x_norm, return_weights=True)
print("sample:", decode(sample[0].cpu()))
print("weights shape: (batch, heads, T, T) =", tuple(weights.shape))
print(weights[0, 0, :8, :8].cpu().round(decimals=2))示例输出:
sample: attention lets tokens
weights shape: (batch, heads, T, T) = (1, 4, 21, 21)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.6800, 0.3200, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.7600, 0.1300, 0.1100, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.6300, 0.1700, 0.1800, 0.0200, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5000, 0.0800, 0.1500, 0.0700, 0.2100, 0.0000, 0.0000, 0.0000],
[0.3200, 0.1000, 0.1100, 0.0300, 0.2400, 0.2000, 0.0000, 0.0000],
[0.0400, 0.2800, 0.0800, 0.0100, 0.2400, 0.2400, 0.1100, 0.0000],
[0.0100, 0.0100, 0.0100, 0.0000, 0.0000, 0.0300, 0.9400, 0.0000]])用热力图显示完整 attention 矩阵。
python
tokens = list(decode(sample[0].cpu()))
attention_matrix = weights[0, 0].detach().cpu()
fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(attention_matrix, cmap="Blues", vmin=0, vmax=attention_matrix.max())
ax.set_title("Layer 1, Head 1 attention weights")
ax.set_xlabel("Key / 被读取的位置")
ax.set_ylabel("Query / 当前预测的位置")
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()
10. 小结
到这里,我们已经从零搭出了一个迷你 Transformer:
- Token embedding 把离散 token 变成连续向量。
- Position embedding 给模型提供顺序信息。
- Causal Self-Attention 让当前位置读取过去上下文,但不能偷看未来。
- Multi-Head Attention 让模型用多种方式读取上下文。
- Feed Forward、Residual connection、LayerNorm 让每个 block 更有表达能力且更容易训练。
- 语言模型训练时预测每个位置的下一个 token,使用 logits + cross entropy。