第 9 章:预训练 Pretraining¶
1. 本章要解决的问题¶
第 8 章里,我们已经解决了一个关键入口问题:
原始文本怎样变成模型可以读取的 token 序列。
但到那里为止,我们只是把“数据喂进去”的前半段打通了,还没有真正回答另一个更核心的问题:
一个 GPT 类语言模型到底是怎么被训练出来的?
很多人在学习大模型时,会对“预训练”这个词有一种熟悉但模糊的感觉。
大家都知道:
- GPT、LLaMA、Qwen 这些模型都经历过预训练
- 预训练通常要吃海量文本
- 预训练成本很高
但如果继续追问:
- 它具体在学什么?
- loss 是怎么定义的?
- 一条文本样本进入模型后,训练循环里发生了什么?
- 为什么大家总在谈数据、算力、参数规模?
很多理解就会开始变得零碎。
所以这一章的任务,是把“预训练”从一个笼统概念,落到一个你可以真正讲清楚的训练闭环里。
从全书结构上看,这一章有三个作用:
- 它承接第 8 章,把 tokenizer 和数据处理真正接到模型训练上
- 它把第 7 章的 Mini-GPT 从“能定义、能 forward”推进到“能系统训练”
- 它也为第 10 章的 scaling law 做准备,因为只有先理解训练过程,才知道为什么模型会越做越大
如果第 8 章回答的是:
文本怎样进入模型。
那么第 9 章回答的就是:
模型怎样通过海量文本把“下一个 token 该是什么”这件事学出来。
2. 你学完后应该会什么¶
- 能解释预训练在 LLM 生命周期中的位置
- 能说清 causal language modeling 的训练目标
- 能理解
shift logits和shift labels为什么存在 - 能画出一个最小预训练流程的数据流和训练流
- 能理解 warmup、学习率衰减、mixed precision、checkpoint 的作用
- 能把“预训练为什么贵”这件事和数据、算力、参数规模联系起来
- 能把这一章内容自然衔接到第 10 章的 scaling law
3. 预训练到底是什么¶
先给一个尽量不绕的定义。
所谓预训练,指的是:
在大规模通用语料上,用一个相对通用的训练目标,先让模型学会语言中的统计规律和知识结构。
这里面有三个关键词。
3.1 大规模通用语料¶
预训练不是拿几千条标注数据做监督学习,而是使用规模大得多的文本集合,例如:
- 网页文本
- 书籍
- 代码
- 论文
- 问答数据
- 多语言语料
这些数据通常不需要人工逐条标注“正确答案”,因为语言模型的训练目标可以从文本本身自动构造出来。
3.2 相对通用的训练目标¶
对于 GPT 类模型,这个目标通常就是:
给定前文,预测下一个 token。
也就是说,训练信号并不是来自人工标签,而是来自文本序列本身。
一句话写出来后,前面的 token 就是输入,后面的 token 就能成为监督信号。
3.3 先学通用能力,再做后续适配¶
预训练之所以叫“预”训练,是因为它通常不是终点。
后面模型往往还会经历:
- 指令微调(SFT)
- 偏好对齐(RLHF / DPO)
- 领域微调
- 部署阶段的推理优化
所以预训练更像是在打地基。
它决定了模型有没有足够强的:
- 语言建模能力
- 基础世界知识
- 模式归纳能力
- 上下文理解能力
后面的 SFT 和对齐,更像是在这个基础上做“行为层”的塑形。
4. GPT 预训练到底在学什么¶
如果把 GPT 的预训练目标说得最朴素一点,其实就是一句话:
不断玩“猜下一个 token”的游戏。
例如一段文本是:
tokenize 之后,模型看到的可能是:
训练时,并不是把整句话作为一个分类问题来处理,而是把它拆成很多个局部预测任务:
- 看见
深度,预测下一个 token 是什么 - 看见
深度 学习,预测下一个 token 是什么 - 看见
深度 学习 改变,预测下一个 token 是什么
于是一个长度为 T 的序列,天然就能提供大约 T-1 个监督信号。
这也是语言模型特别适合吃海量无标注文本的原因:
监督信号是序列自己长出来的。
5. causal language modeling:训练目标到底怎么定义¶
GPT 预训练最经典的目标,就是 causal language modeling,简称 CLM。
5.1 什么叫 causal¶
这里的 causal,不是因果推断里的那个“因果图”意思,而是:
当前位置只能看见自己左边的上下文,不能偷看未来。
也就是第 7 章讲过的 decoder-only + causal mask 设定。
假设序列是:
那么模型在位置 t 上,只能基于:
来预测:
这就决定了它的训练目标可以写成:
这条公式的直觉其实不复杂。
它表达的是:
整段文本出现的概率,可以分解成一步一步“前文条件下,下一个 token 出现的概率”的连乘。
5.2 为什么是 next-token prediction¶
因为只要模型真的擅长预测下一个 token,它就会被迫学到很多东西:
- 语法
- 词法搭配
- 长距离依赖
- 常识模式
- 文档结构
- 一部分代码和数学模式
比如要预测一句话后面是“苹果”还是“微软”,模型必须利用前文语境; 要预测代码里下一个符号是什么,它又必须学会括号、缩进、关键字等结构规律。
所以“预测下一个 token”听起来简单,但它其实是一个非常强的自监督任务。
5.3 loss 一般怎么写¶
训练时,我们会让模型在每个位置输出对整个词表的 logits:
其中:
B是 batch sizeT是 sequence lengthV是 vocabulary size
然后用 cross entropy 去比较“模型预测的下一个 token 分布”和“真实下一个 token id”之间的差距。
6. 为什么训练时要做 shift¶
这一点是很多人第一次写 GPT 训练代码时最容易懵的地方。
模型输入是一整段 token:
但训练目标不是让模型在位置 x1 预测 x1 自己,而是:
- 在位置
x1之后预测x2 - 在位置
x2之后预测x3 - 在位置
x3之后预测x4
所以我们通常会写出:
logits = model(input_ids) # [B, T, V]
shift_logits = logits[:, :-1, :] # 位置 0..T-2 的预测
shift_labels = input_ids[:, 1:] # 真值是 1..T-1
loss = F.cross_entropy(
shift_logits.reshape(-1, vocab_size),
shift_labels.reshape(-1)
)
这里的核心思想是:
- 第
t个位置输出的 logits,对应的是“下一个 token”的预测 - 所以 labels 要整体左移一位
也可以反过来说:
输入序列和监督标签本来来自同一串 token,只是监督目标比输入晚一个位置。
6.1 一个更直观的例子¶
假设输入是:
那么对齐关系其实是:
最后一个位置通常没有“下一个 token”可以预测,所以在最小实现里,经常直接丢掉最后一个位置的 logits。
6.2 padding 时还要注意什么¶
如果 batch 内做了 padding,那么 padding 位置不应该参与 loss。
这时通常会:
- 使用
attention_mask控制模型不要关注 pad 的无效部分 - 使用
ignore_index让 pad 对应的 label 不参与 cross entropy
例如:
loss = F.cross_entropy(
shift_logits.reshape(-1, vocab_size),
shift_labels.reshape(-1),
ignore_index=pad_token_id
)
这也说明第 8 章讲的数据处理,不只是“把文本切开”而已,它会直接影响第 9 章 loss 是怎么正确计算的。
7. 一个最小预训练流程长什么样¶
把概念先压平之后,一个最小 GPT 预训练流程其实可以概括为下面这条链路:
原始文本
-> tokenizer
-> token ids
-> 切分 / packing 成固定长度样本
-> batch 化
-> 输入模型 forward
-> 得到 logits
-> shift 后计算 cross entropy loss
-> backward
-> optimizer step
-> 重复很多很多次
这条链路里,前四步偏数据,后五步偏训练。
真正让很多初学者感觉“大模型训练很复杂”的原因,不是单个步骤难,而是:
每一步都和前后步骤强耦合。
7.1 数据侧在做什么¶
数据侧至少要处理这些事情:
- 文本清洗
- tokenizer 编码
- 长文切块
- padding 或 packing
- 构造 batch
如果这些环节做得不好,哪怕模型结构是对的,训练效率和最终效果也会被明显拖累。
7.2 模型侧在做什么¶
模型侧相对更“纯粹”:
- 把 token ids 变成 embedding
- 经过多层 decoder blocks
- 输出每个位置对词表的 logits
7.3 优化侧在做什么¶
优化侧负责回答:
- loss 怎样反向传播
- 参数怎样更新
- 学习率怎样变化
- 训练中断了怎么续训
现实里很多训练成败,往往不只是模型结构对不对,更取决于优化策略稳不稳。
8. 从零实现时,一个最小训练循环怎么写¶
如果你沿着第 7 章的 Mini-GPT 继续往前走,那么最小训练循环大致会长这样:
model.train()
for step, batch in enumerate(train_loader):
input_ids = batch["input_ids"].to(device)
logits = model(input_ids)
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
这段代码虽然短,但已经包含了预训练最关键的骨架:
- 读取一批 token 序列
- 前向计算得到 logits
- 通过 shift 构造 next-token prediction loss
- 反向传播
- 更新参数
从学习路径看,能真正把这几步自己写出来、跑通、看懂,比一上来套复杂训练框架更重要。
因为这会让你对后续所有训练工具的本质都不陌生:
Trainer、DeepSpeed、FSDP、Megatron,本质上都是在这条主干上做规模化和工程化。
9. 预训练里真正贵的,不只是模型 forward¶
很多人第一次接触预训练,会把注意力几乎全部放在“模型结构”上。
但一旦真的开始训练,就会很快发现:
预训练是一个系统工程,不是一个单纯的 forward 函数。
成本主要来自四个方面。
9.1 数据规模¶
想让模型学到足够丰富的语言模式,就需要足够多的 token。
而 token 数一旦上去,意味着:
- 训练步数更多
- 数据加载压力更大
- 清洗、去重、过滤成本更高
9.2 模型规模¶
参数越多,模型容量越大,但也意味着:
- 显存更吃紧
- 通信开销更大
- 每一步训练更贵
9.3 上下文长度¶
序列越长,尤其在标准 self-attention 下,计算和显存压力会上升得很快。
这也是为什么“支持更长上下文”从来不是一个免费升级。
9.4 训练稳定性¶
即使数据和模型都准备好了,如果优化策略不稳定,训练仍然可能出现:
- loss 不下降
- 梯度爆炸
- nan
- 训练后期震荡
所以工程上会有很多看起来“不像模型原理”的细节,实际上非常关键。
10. 训练工程里几个你必须认识的关键词¶
这一节先不追求工业级展开,而是先把最常见的关键词建立起直觉。
10.1 warmup¶
训练刚开始时,如果学习率直接给太大,模型参数可能会更新得非常剧烈,导致不稳定。
所以常见做法是:
先用较小学习率起步,再在前若干 step 里逐渐升到目标学习率。
这就是 warmup。
它的作用可以粗略理解成“别让模型一上来就猛踩油门”。
10.2 learning rate decay¶
训练到后期,学习率往往不会一直保持不变,而会逐步下降。
常见策略包括:
- linear decay
- cosine decay
直觉上,这有点像先大步走,再小步微调。
前期更强调快速学到粗结构,后期更强调稳定收敛。
10.3 gradient accumulation¶
如果显存不够大,无法一次放下理想 batch size,可以把多个 mini-batch 的梯度累积起来,再统一更新一次参数。
这样做的效果是:
用更小的显存,近似实现更大的有效 batch size。
10.4 mixed precision¶
预训练里常用 FP16 或 BF16 来降低显存占用、提升吞吐。
这就是 mixed precision 训练常见的背景。
它不是“为了炫技”,而是大规模训练里非常现实的效率优化。
10.5 checkpointing¶
预训练通常跑得很久,不可能假设过程永远不会中断。
所以需要定期保存:
- 模型参数
- optimizer 状态
- scheduler 状态
- 当前 step
这样中断后才能从上次位置继续。
这里说的 checkpoint,既包括“保存训练进度”,也常常包括“保存阶段性模型权重用于评估和对比”。
11. 数据质量为什么会直接影响预训练效果¶
很多初学者容易有一个误区:
“只要数据量够大,模型总会学到东西。”
这句话只说对了一半。
因为预训练数据不是越多越好,而是:
规模、质量和分布都很重要。
11.1 噪声数据会浪费训练预算¶
如果语料里充满:
- 重复内容
- 乱码
- 模板垃圾文本
- 极低信息密度文本
那么模型会把大量计算预算花在学习这些低价值模式上。
11.2 数据分布会影响模型偏好¶
模型吃进去的语料分布,最后会明显影响它的输出风格和能力边界。
例如:
- 代码语料多,代码能力通常更强
- 学术文本多,正式写作和术语理解往往更好
- 多语言数据丰富,跨语言泛化可能更强
这也是为什么“预训练数据配方”在工业界通常非常重要,但也很少完全公开。
11.3 数据泄漏和污染要有意识¶
如果评测集内容提前出现在预训练语料里,那么后面的评估就会失真。
这件事在 LLM 时代尤其敏感,因为互联网规模语料很难做到完全可控。
所以我们在理解预训练时,也要开始建立一个意识:
训练数据不是背景板,而是模型能力塑造的一部分。
12. sequence packing 为什么重要¶
第 8 章已经讲过切块和 padding,这里再从训练效率角度补一句。
如果大量样本长度参差不齐,而我们又简单地把它们 pad 到同样长度,就会浪费很多算力在无效位置上。
sequence packing 的核心目标就是:
尽量减少 padding 带来的空耗,让每个 batch 里真正参与训练的 token 更多。
对于大规模预训练,这种效率差异不是小修小补,而会直接放大成显著的训练成本差异。
所以“数据管线优化”并不是附属工作,它本身就是预训练的一部分。
13. 一个更完整的预训练视角:不仅是会写 loop¶
从学习角度看,到这里你至少应该开始把预训练理解成三层东西,而不是只盯着那几行 PyTorch 代码。
13.1 第一层:目标函数¶
也就是:
- 模型输入是什么
- 监督信号是什么
- loss 怎么计算
13.2 第二层:训练系统¶
也就是:
- 数据如何高效喂入
- 参数如何稳定更新
- 显存和吞吐如何平衡
13.3 第三层:规模决策¶
也就是:
- 给定算力预算,该配多大模型
- 该训练多少 token
- 是数据不够,还是模型太小,还是算力不够
而这第三层,就是第 10 章 scaling law 要重点讨论的问题。
14. 从 Mini-GPT 到真正的大模型训练,中间还差什么¶
这一章我们刻意把重点放在“最小闭环”上,是为了先建立正确主干。
但你也要知道,真实工业训练和我们现在的最小版本之间,还有很长一段工程距离。
通常还会涉及:
- 更复杂的数据清洗和去重
- 更大的分布式训练系统
- 更精细的监控与评估
- 更稳定的并行训练策略
- 更系统的实验记录与复现
不过这并不意味着前面的最小实现价值不大。
恰恰相反。
因为只要主干理解扎实,你后面学到的大部分复杂工具,都会有明确落点:
它们不是在替代原理,而是在把原理扩展到更大规模。
15. 从求职和项目表达角度,这一章能怎么讲¶
如果你把第 7、8、9 章连起来,其实已经能形成一个很不错的 from-scratch 项目叙事:
15.1 你做了什么¶
- 手写了一个最小 GPT
- 实现了 tokenizer / 数据处理流程
- 跑通了 next-token prediction 的预训练闭环
15.2 你真正理解了什么¶
- GPT 为什么使用 causal mask
- 训练时为什么要 shift logits 和 labels
- 数据管线为什么直接影响训练效率
- 学习率调度和 mixed precision 为什么常见
15.3 你还能继续扩展什么¶
- 增加验证集 perplexity 评估
- 增加采样生成观察训练效果
- 接入 Hugging Face tokenizer 和 datasets
- 增加 checkpoint resume 能力
- 对比不同 sequence length 和 batch size 的训练现象
这样的项目,已经不只是“我看懂了几篇博客”,而是:
我能把语言模型训练最关键的骨架亲手搭出来。
16. 本章小结¶
这一章里,我们把“预训练”从一个抽象大词,拆回到了一个可以落地的训练闭环。
最重要的主线其实只有三句:
- GPT 预训练的核心任务是 causal language modeling
- 它本质上是在做 next-token prediction
- 一个最小预训练系统,就是把数据、模型、loss 和优化过程真正连起来
与此同时,我们也开始看到:
- 训练效果不只取决于模型结构
- 数据质量和数据管线会直接影响结果
- 学习率、精度、batch size、checkpoint 等工程细节并不是“实现杂质”,而是预训练本身的一部分
下一章我们会继续追问一个自然问题:
如果预训练就是不断堆数据、堆参数、堆算力,那模型为什么会随着规模增长而持续变强?
这就会进入第 10 章要讨论的主题:
scaling law 与大模型为什么变大。