1. 强化学习基础:行业黑话
想象你正在和一个刚训练好的语言模型聊天。你问:“今天过得怎么样?”
模型可能回:“还行。” 也可能回:“我是个 AI,没有感情。”
人类觉得前者更自然、更友好——这就是偏好反馈。强化学习(RL)在 LLM 中的核心任务,就是让模型学会生成“人类更喜欢”的回复。
为了做到这一点,我们需要一套语言来描述这个过程。下面我们以 LLM 场景为基础介绍几个 RL 的“行业黑话”。
1.1 基本概念
-
时刻 \(t\) :就是对话的第几步。比如:
- \(t=0\):用户输入 “今天过得怎么样?” → 这是初始状态 \(s_0\)
- \(t=1\):模型输出第一个词 “今” → 动作 \(a_0 = \text{“今”}\)
- \(t=2\):模型输出第二个词 “天” → 动作 \(a_1 = \text{“天”}\)
- … 直到生成完整回复,比如 “今天过得不错!”
-
在 LLM 中,状态 \(s_t\) 通常就是到第 \(t\) 步为止已生成的 token 序列(包括用户输入和模型已输出的部分)
-
动作 \(a_t\) 就是模型在第 \(t\) 步选择的下一个 token。
-
奖励 \(r_t\):这是人类(或奖励模型)对模型行为的真实反馈信号。比如:
- 如果模型最终生成了“今天过得不错!”,人类觉得回答的不错,打 5 分 → 这个分数会折算成一个最终奖励 \(r_T\)(通常只在序列结束时给,即最后一个 token)
- 中间步骤一般没有即时奖励(\(r_t = 0\) for \(t < T\))
-
价值 \(v\):奖励 \(r\) 是真实的、来自外部的信号(比如人类打分),相对应的,价值(value)是对未来奖励的估计——因为模型不能预知未来,只能靠猜。
1.2. 价值(Value):对未来奖励的“预判”
既然模型不能看到未来,它就需要一个“预判能力”:我现在处在某个状态,未来大概能拿多少分?
这就引出了两个核心函数:
1) 状态价值函数 \(V(s_t)\)
它表达的是:在当前已生成的对话上下文 \(s_t\)(比如用户刚问完 “今天过得怎么样?”,而模型还没开始回答,或已输出“今”),模型按照当前策略继续生成后续内容,平均能获得多少人类打分。
- \(\pi(a|s_t)\) 是模型在状态 \(s_t\) 下选择下一个词 \(a\) 的概率(例如在“今天过得怎么样?”之后,选“今”还是“还”);
- \(Q(s_t, a)\) 表示如果此刻选了某个具体词 \(a\),最终能拿到的预期总分;
- 把所有可能的下一个词按模型当前的偏好加权平均,就得到了该状态的整体“预期得分”——也就是 \(V(s_t)\)。
举个例子:当模型已经输出 “今天过得”,它会评估:“按我现在的风格继续回答,人类大概率会觉得自然,可能打 4 分”,于是 \(V(s_t) \approx 4\)。
2) 动作价值函数 \(Q(s_t, a)\)
它表达的是:如果我现在处于状态 \(s_t\)(比如上下文是“今天过得”),并选择动作 \(a\)(比如生成“不”),那么我能获得当前的真实奖励 \(r_t\)(通常是 0,因为回复还没结束),再加上未来所有状态价值的折扣和。
对应到 LLM 应用场景就表示:
“如果我现在在‘今天过得’后面接‘不’,形成‘今天过得不’,那接下来我大概率会说‘错!’,组成一句完整、积极的回复,最终人类可能会打 5 分。”
其中:
- \(r_t\) 是真实发生的奖励,但在 LLM 生成过程中,只有完整回复结束后才有非零值(例如人类打分 \(r_T = 5\));在中间步骤(如生成“今”“天”时),\(r_t = 0\);
- \(V(s_{t+1}), V(s_{t+2}), \dots\) 是模型自己估计的未来价值(比如生成“不”之后,预估“今天过得不错!”能拿 4.9 分);
- \(\gamma \in [0,1]\) 是折扣因子(如 0.95),表示“未来的分不如现在的分值钱”——越靠后的 token 对当前决策的影响越小。
虽然中间每一步的 \(r_t = 0\),但 \(Q(s_t, a)\) 依然非常关键:它通过 \(V(s_{t+1})\) 等未来价值,把对最终人类反馈的预判传递回当前决策。这正是 LLM 在生成每个词时具备“前瞻能力”的来源——它不是随机选词,而是基于“这样说人类会不会喜欢”的长期预期来做选择。
为什么估计的价值函数 Q 里包含真实的 \(r_t\)?
因为 RL 的目标是用真实奖励来校准价值估计。模型通过不断对比“预测的未来得分”和“实际拿到的奖励”,来修正自己的 \(V\) 和 \(Q\) 函数。
2. PPO:RLHF 的“老大哥”
PPO(Proximal Policy Optimization)是传统 RLHF(基于人类反馈的强化学习)流程中的核心算法,是 openai 在 2016年左右提出来的。原来 closeAI 的成功在那个时候就开始蓄力了。PPO的目标很直接:让语言模型生成更受人类欢迎的回复。
PPO 中的几个关键角色
模型 | 是否训练 | 输入 | 输出 | 输出维度说明 |
---|---|---|---|---|
Policy Model \(\pi_\theta\) | ✅ 是 | prompt \(x\)(token IDs,长度 \(L_x\)) | 生成回复 \(y = (a_1,\dots,a_T)\),以及每个 token 的 log-prob \(\log \pi_\theta(a_t | s_t)\) | \(y\): \([T]\) logprobs: \([T]\) |
Reference Model \(\pi_{\text{ref}}\) | ❌ 冻结 | 同上 \(x\) | 同上 log-prob \(\log \pi_{\text{ref}}(a_t | s_t)\) | \([T]\) |
Critic Model \(V_\psi\) | ✅ 是 | 状态序列 \(s_t = x \oplus y_{\le t}\)(token IDs,长度 \(L_x + t\)) | 价值估计 \(V_\psi(s_t)\) | 标量(或 \([1]\)),对每个 \(t=0,\dots,T\) 输出一个值 → 总输出 \([T+1]\) |
Reward Model \(r_\phi\) | ❌ 冻结 | \((x, y)\)(完整 prompt + response) | 标量奖励 \(R = r_\phi(x, y)\) | 标量(或 \([1]\)) |
注:\(\oplus\) 表示 token 拼接;\(T\) 是生成回复的长度(可变,但训练时通常 padding 到固定长度)。
PPO 的两阶段训练流程
PPO通过分阶段解耦“数据生成”和“策略学习”,在保证训练稳定性的同时,让模型逐步学会生成更符合人类偏好的回复。整个流程分为如下两个阶段:
阶段 1:采样与反馈(Sample + Label)
✅ 目标
用当前策略模型生成一批回复,并利用冻结的奖励模型打分,再结合当前评论家模型估计价值,最终为每个 token 动作计算出优势(Advantage) 和回报(Return),作为后续训练的监督信号。
📌 关键点:此阶段不更新任何模型参数,只是“收集数据”。Policy 和 Critic 在采样时使用的是当前最新参数,但输出会被 detach(视为常数),作为“旧策略”和“旧评论家”的快照。
🧩 参与模型与接口
模型 | 是否更新 | 输入 | 输出 | 输出维度 |
---|---|---|---|---|
Policy Model \(\pi_\theta\) | ❌(采样时不更新) | prompt \(x \in \mathbb{Z}^{L_x}\) | 生成回复 \(y \in \mathbb{Z}^T\) 及每个 token 的 log-prob \(\log \pi_\theta(a_t | s_t)\) |
\(y\): \([T]\) logprobs: \([T]\) |
Critic Model \(V_\psi\) | ❌(采样时不更新) | 状态 \(s_t = x \oplus y_{\le t} \in \mathbb{Z}^{L_x + t}\) | 价值估计 \(V_\psi(s_t) \in \mathbb{R}\) | 对 \(t=0,\dots,T\) 输出 \([T+1]\) 个标量 |
Reward Model \(r_\phi\) | ❌(始终冻结) | \((x, y)\) | 标量奖励 \(R = r_\phi(x, y)\) | \([1]\) |
注:\(L_x\) 是 prompt 长度,\(T\) 是生成回复长度(实际中常 padding 到固定 max_len)。
🔢 核心计算逻辑
-
生成轨迹:对每个 prompt \(x\),用当前策略生成完整回复 \(y = (a_1, ..., a_T)\),形成状态序列:
\[s_0 = x,\quad s_1 = x \oplus a_1,\quad \dots,\quad s_T = x \oplus y \] -
获取最终奖励:调用冻结的 Reward Model:
\[R = r_\phi(x, y) \](中间步骤无奖励,即 \(r_t = 0\) for \(t < T\))
-
计算回报(Return):
\[\hat{R}_t = \sum_{k=t}^{T} \gamma^{k-t} r_k = \gamma^{T - t} R \]因为只有最后一步有奖励。回报序列 \(\hat{R}_0, \hat{R}_1, ..., \hat{R}_T\) 构成目标值。
-
计算优势(Advantage):
\[A_t = \hat{R}_t - V_\psi(s_t), \quad t = 0, 1, ..., T-1 \]表示:在状态 \(s_t\) 下执行动作 \(a_t\),比“平均水平”好多少。
-
保存“旧”值:将当前策略的 log-prob 和评论家的 value detach,作为阶段 2 的基准(即“old policy”和“old critic”)。
💻 伪代码(阶段 1)
trajectories = []for x in prompts: # x: [L_x]# 1. 用当前策略生成回复 y 和 log-proby, logprobs = policy_model.generate_with_logprobs(x) # y: [T], logprobs: [T]# 2. 构建状态序列 s_0 ... s_Tstates = [torch.cat([x, y[:t]]) for t in range(len(y) + 1)] # len = T+1# 3. 用当前评论家估计每个状态的价值values = torch.stack([critic_model(s) for s in states]) # [T+1]# 4. 奖励模型打分(仅最终奖励)R = reward_model(x, y) # scalar# 5. 计算回报:R_t = γ^{T−t} * RT_len = len(y)returns = torch.zeros(T_len + 1)returns[T_len] = Rfor t in reversed(range(T_len)):returns[t] = gamma * returns[t + 1]# 6. 计算优势:A_t = R_t - V(s_t),仅对 t=0..T-1 有效advantages = returns[:-1] - values[:-1] # [T]# 7. 保存“旧”值(detach 阻断梯度)trajectories.append({'x': x,'y': y,'logprobs_old': logprobs.detach(), # [T]'values_old': values.detach(), # [T+1]'advantages': advantages.detach(), # [T]'returns': returns.detach() # [T+1]})
✅ 此阶段结束时,我们得到一个固定的数据集,后续训练将在此数据上多次迭代。
阶段 2:策略与评论家更新(Policy & Critic Learning)
✅ 目标
利用阶段 1 收集的固定轨迹数据,更新策略模型(Policy)和评论家模型(Critic),使得:
- 策略更倾向于选择高优势的动作;
- 评论家更准确地预测未来回报;
- 同时通过 PPO-clip 和 KL 正则防止策略突变或偏离合理语言分布。
🧩 参与模型与接口
模型 | 是否更新 | 作用 |
---|---|---|
Policy Model \(\pi_\theta\) | ✅ | 被优化的主模型 |
Critic Model \(V_\psi\) | ✅ | 被优化的价值估计器 |
Reference Model \(\pi_{\text{ref}}\) | ❌(始终冻结) | 提供 KL 正则基准(通常是 SFT 后的初始模型) |
Reward Model | ❌ | 不参与此阶段 |
🔢 核心计算逻辑
-
策略损失(PPO-Clip)
定义概率比:\[r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\text{old}}(a_t | s_t)} = \exp\left( \log \pi_\theta(a_t|s_t) - \log \pi_{\text{old}}(a_t|s_t) \right) \]PPO 损失为:
\[\mathcal{L}^{\text{PPO}} = \mathbb{E}_t \left[ \min\left( r_t(\theta) A_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] \]- 若 \(A_t > 0\):鼓励增加动作概率,但最多增加 \((1+\epsilon)\) 倍;
- 若 \(A_t < 0\):鼓励减少概率,但最多减少到 \((1-\epsilon)\) 倍。
-
KL 散度正则(防止语言退化)
\[\mathcal{L}^{\text{KL}} = \beta \cdot \mathbb{E}_t \left[ \log \pi_\theta(a_t|s_t) - \log \pi_{\text{ref}}(a_t|s_t) \right] \]- \(\pi_{\text{ref}}\) 是冻结的 SFT 模型;
- \(\beta\) 控制正则强度(如 0.01~0.1)。
-
评论家损失(Value MSE)
\[\mathcal{L}^{\text{value}} = \mathbb{E}_t \left[ \left( V_\psi(s_t) - \hat{R}_t \right)^2 \right] \]- 目标是让评论家准确预测阶段 1 计算出的回报 \(\hat{R}_t\)。
-
总损失:
\[\mathcal{L}_{\text{total}} = -\mathcal{L}^{\text{PPO}} + \beta \cdot \text{KL} + c_1 \cdot \mathcal{L}^{\text{value}} \]
💻 伪代码(阶段 2)
for epoch in range(K): # K=2~4,对同一数据集多轮优化for traj in trajectories:x, y = traj['x'], traj['y'] # x: [L_x], y: [T]logprobs_old = traj['logprobs_old'] # [T]advantages = traj['advantages'] # [T]returns = traj['returns'] # [T+1]# --- 1. 策略损失 ---logprobs_curr = policy_model.get_logprobs(x, y) # [T]ratio = torch.exp(logprobs_curr - logprobs_old) # [T]surr1 = ratio * advantagessurr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantagesppo_loss = -torch.mean(torch.min(surr1, surr2))# KL 正则(ref_model 冻结)with torch.no_grad():logprobs_ref = ref_model.get_logprobs(x, y) # [T]kl_loss = torch.mean(logprobs_curr - logprobs_ref)policy_loss = ppo_loss + beta * kl_loss# --- 2. 评论家损失 ---states = [torch.cat([x, y[:t]]) for t in range(len(y) + 1)]values_pred = torch.stack([critic_model(s) for s in states]) # [T+1]value_loss = F.mse_loss(values_pred, returns)# --- 3. 优化 ---total_loss = policy_loss + c1 * value_lossoptimizer.zero_grad()total_loss.backward()optimizer.step()
🎯 总结:PPO 的设计哲学
- 阶段 1 是“探索”:用当前策略生成多样回复,用外部信号(RM)和内部估计(Critic)打标签;
- 阶段 2 是“学习”:在固定数据上保守更新,通过 clip 和 KL 防止“学歪”;
- Reference Model 是安全网:确保语言依然流畅、合理;
- 整个流程可迭代:每轮 PPO 后,策略更强,下一轮采样质量更高。
这种“采样-学习”交替的模式,正是 PPO 能在 LLM 对齐中兼顾效果、稳定性和安全性的关键。
3. DPO:绕过 RL 的“聪明办法”
DPO(Direct Preference Optimization)发现:其实不需要显式训练 Reward Model + PPO,可以直接从人类偏好数据中优化策略。
DPO 的核心洞察
人类偏好数据是成对的:\((x, y_w, y_l)\),其中:
- \(x\):用户输入(prompt)
- \(y_w\):人类偏好的回复(win)
- \(y_l\):较差的回复(lose)
DPO 证明:最大化人类偏好等价于最小化下面这个损失:
这个公式到底在算什么?
- \(\pi_\theta(y|x)\):当前训练模型在 prompt \(x\) 下生成完整回复 \(y\) 的概率
→ 实际计算时,是把 \(y\) 拆成 token 序列,求 \(\prod_t \pi_\theta(y_t | x, y_{<t})\) - \(\pi_{\text{ref}}(y|x)\):参考模型(SFT 模型)生成 \(y\) 的概率
- \(\beta\):温度参数,控制优化强度(越大越激进)
通俗理解:DPO 希望模型对“好回复”的相对概率(相比参考模型)比“坏回复”更高。
DPO 伪代码
for batch in preference_data:x, y_w, y_l = batch# 计算当前模型和参考模型对两个回复的 log 概率logp_w = policy_model.log_prob(x, y_w)logp_l = policy_model.log_prob(x, y_l)ref_logp_w = ref_model.log_prob(x, y_w)ref_logp_l = ref_model.log_prob(x, y_l)# 计算 logits 差logits = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))# 二分类损失:希望 logits 越大越好loss = -F.logsigmoid(logits).mean()optimizer.step(loss)
DPO 本质是一个带参考模型的对比学习(contrastive learning),完全不需要 RL 循环,所以训练快、稳定。
4. GRPO:在 PPO 和 DPO 之间找平衡
PPO vs DPO:各自的痛
方法 | 优点 | 缺点 |
---|---|---|
PPO | 支持 online learning(边生成边学),样本利用率高;可结合多种奖励(如安全性、事实性) | 需要训练 4 个模型(Policy, Critic, RM, Reference),流程复杂;RM 质量直接影响效果 |
DPO | 训练简单,只需 2 个模型(Policy + Reference);效果接近 PPO | 完全依赖离线(offline)偏好数据;容易过拟合(尤其数据少时);无法引入动态奖励 |
GRPO:群体相对优化
GRPO(Group Relative Policy Optimization)的思路是:
既然人类经常面对多个选项做判断(比如从 4 个回复中选最好的 2 个),那就直接建模这种“群体偏好”。
GRPO 的做法
- 对每个 prompt \(x\),用当前策略生成 \(K\) 个回复(比如 \(K=4\))
- 根据 Reward Model(或人类)将这些回复分成“好组”和“坏组”
- 优化目标:拉大组间差异,缩小组内差异
GRPO 损失函数(简化版)
这其实是一个带参考模型的 softmax 分类损失:希望“好回复”的归一化概率更高。
GRPO 伪代码
for x in prompts:# 1. 生成 K 个回复responses = [policy_model.generate(x) for _ in range(K)]# 2. 用 RM 打分并分组(比如 top-2 为 good)scores = [reward_model.score(x, y) for y in responses]good_mask = get_top_k_mask(scores, k=2)# 3. 计算每个回复的 log ratioratios = []for y in responses:logp = policy_model.log_prob(x, y)ref_logp = ref_model.log_prob(x, y)ratios.append(beta * (logp - ref_logp))# 4. softmax 分类损失logits = torch.stack(ratios)loss = F.cross_entropy(logits.unsqueeze(0), target=good_mask)optimizer.step(loss)
GRPO 的优势:
- 保留了 PPO 的 online 生成能力(自己造数据)
- 像 DPO 一样只优化策略模型,无需 Critic
- 对 RM 的依赖比 PPO 弱(只需排序,不要求绝对分数准确)
总结:选哪个?
方法 | 模型数量 | 是否需要 RM | 是否 RL | 适合场景 |
---|---|---|---|---|
PPO | 4(Policy, Critic, RM, Ref) | ✅ | ✅ | 高质量对齐,多目标奖励 |
DPO | 2(Policy, Ref) | ❌ | ❌ | 快速迭代,偏好数据充足 |
GRPO | 3(Policy, RM, Ref) | ✅(弱依赖) | ⚠️(类 RL) | 平衡效率与效果,支持 online 学习 |
强化学习在 LLM 中,早已不是“必须用 PPO”的时代。DPO 让对齐变得像 SFT 一样简单,GRPO 则试图把 PPO 的灵活性和 DPO 的简洁性结合起来。
技术在进化,我们的工具箱也在变丰富。选对方法,比盲目堆资源更重要。