Training Reasoning Models: PPO, GRPO, Reward Functions, and RLVR
A deep technical guide to training reasoning models like o1 and DeepSeek R1—covering PPO, GRPO, reward function design, RLVR, and distillation techniques.
The Training Revolution Behind Reasoning Models
While test-time compute scaling (CoT, Tree of Thoughts, MCTS) improves what models can do at inference, training-time innovations determine what models learn to do in the first place. The breakthrough reasoning capabilities of o1, o3, and DeepSeek R1 come from novel training algorithms—particularly in how rewards are defined and optimized.
This post provides a comprehensive technical guide to training reasoning models, covering:
- Reward function design (outcome, process, rule-based, learned)
- PPO (Proximal Policy Optimization) and its components
- GRPO (Group Relative Policy Optimization) and why it's revolutionary
- RLVR (Reinforcement Learning with Verifiable Rewards)
- DPO and other alternatives
- How OpenAI o1/o3 and DeepSeek R1 are actually trained
- Distillation to compress reasoning into smaller models
Understanding Reward Functions
The Role of Rewards in RL for LLMs
In reinforcement learning for language models, the reward function defines what "good" outputs look like. The model learns to maximize expected reward through gradient updates.
Policy π(response | prompt) → Reward R(prompt, response) → Policy Update
The reward function is arguably the most critical design decision in training reasoning models. Different reward signals lead to fundamentally different model behaviors.
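Formally, RL fine-tuning maximizes expected reward while staying close to a reference policy, the standard KL-regularized objective:
J(θ) = E_{x∼D, y∼π_θ(·|x)}[ R(x, y) ] − β · KL(π_θ(·|x) ‖ π_ref(·|x))
where β controls how far the policy may drift from the reference model. Every algorithm in this post optimizes some variant of this objective; they differ in how the reward is defined and how the expectation is estimated.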
Types of Reward Functions
1. Outcome Reward Models (ORMs)
ORMs evaluate only the final answer—they don't care about reasoning steps.
Why ORMs are the simplest starting point: The appeal of ORMs is their simplicity—you need only a ground truth answer and a way to check equality. No step-level labels, no learned reward model, just binary correctness. This simplicity comes with a tradeoff: the model receives reward signal only at the end of potentially long reasoning chains, making it hard to learn which steps were good versus bad.
The credit assignment problem: Imagine a model produces a 20-step mathematical proof that arrives at the wrong answer. Which step caused the error? Was it step 3 where it made an arithmetic mistake, or step 15 where it applied the wrong formula? ORMs can't tell—they just report "wrong." The model must learn through many trials which patterns lead to correct answers, a slow and sample-inefficient process. Despite this limitation, ORMs work surprisingly well when combined with algorithms like GRPO that generate many samples per problem.
When ORMs shine: ORMs excel in domains with unambiguous correctness criteria: math problems with numeric answers, code that passes test cases, logic puzzles with definite solutions. They struggle in domains where partial credit matters or where "close" answers should receive some reward.
class OutcomeRewardModel:
"""
Evaluates correctness of final answer only.
Binary or continuous score based on answer quality.
"""
def __init__(self, verifier=None):
self.verifier = verifier
def compute_reward(
self,
problem: str,
response: str,
ground_truth: str = None
) -> float:
"""
Compute outcome-based reward.
Returns:
1.0 if correct, 0.0 if incorrect (binary)
or continuous score [0, 1]
"""
# Extract final answer from response
predicted_answer = self.extract_answer(response)
if ground_truth is not None:
# Exact match or semantic equivalence
if self.normalize(predicted_answer) == self.normalize(ground_truth):
return 1.0
return 0.0
# Use verifier model if no ground truth
if self.verifier:
return self.verifier.score(problem, predicted_answer)
raise ValueError("Need ground_truth or verifier")
def extract_answer(self, response: str) -> str:
"""Extract final answer from reasoning response."""
# Look for common answer patterns
patterns = [
r"\\boxed{([^}]+)}", # LaTeX boxed
r"[Aa]nswer[:\s]+([^\n]+)", # "Answer: X"
r"[Tt]herefore[,:\s]+([^\n]+)", # "Therefore, X"
r"= ([^\n]+)$" # Final equation
]
for pattern in patterns:
match = re.search(pattern, response)
if match:
return match.group(1).strip()
# Fallback: last line
return response.strip().split('\n')[-1]
def normalize(self, answer: str) -> str:
"""Normalize answer for comparison."""
# Remove whitespace, lowercase
answer = ' '.join(answer.lower().split())
# Numeric normalization
try:
num = float(answer.replace(',', ''))
return f"{num:.6f}"
except:
return answer
Advantages:
- Simple to implement
- No need for step-level labels
- Works with any verifiable domain
Disadvantages:
- Sparse signal (only at the end)
- Can't distinguish good reasoning with wrong answer from lucky guessing
- Credit assignment problem: which steps caused the error?
Understanding the answer extraction patterns: The extract_answer method uses a cascade of regex patterns to find the final answer in model responses. This is crucial because models format answers inconsistently—some use LaTeX \boxed{}, others write "Answer: X", and some bury the answer in prose. The pattern hierarchy matters: we check the most structured formats first (LaTeX boxed) because they're most reliable, falling back to less structured patterns only when needed. In production, you'll likely need to expand these patterns based on your model's output style.
Normalization is harder than it looks: The normalize method handles the surprisingly complex task of comparing answers. Two answers might be semantically identical but textually different: "1/2" and "0.5" and "0.500" and "50%" all represent the same value. The numeric normalization handles floating-point comparison, but real production systems need far more: unit conversion ("1 meter" vs "100 cm"), symbolic equivalence ("x^2" vs "x*x"), and domain-specific rules. This is where many RL training pipelines silently fail—answers that should match don't.
2. Process Reward Models (PRMs)
PRMs evaluate each reasoning step, providing dense reward signal.
Why dense rewards dramatically accelerate learning: The fundamental limitation of ORMs—sparse, end-of-sequence feedback—is precisely what PRMs address. By scoring each step, PRMs provide dense supervision throughout the reasoning chain. If step 5 introduces an error, the PRM flags it immediately rather than waiting until the final answer. This transforms the credit assignment problem: instead of asking "somewhere in these 20 steps, something went wrong," we can ask "was this specific step correct?"
The labeling challenge: PRMs require step-level labels, which are expensive to obtain. For a 20-step solution, you need 20 correctness judgments, not just one. OpenAI's PRM800K dataset required human mathematicians to verify 800,000 individual reasoning steps—a massive investment. This cost explains why rule-based rewards (like DeepSeek R1's approach) have become popular: they trade some precision for eliminating labeling costs entirely.
How PRMs enable search: PRMs aren't just for training—they're crucial for test-time compute scaling. During inference, you can generate multiple solution paths and use the PRM to score each step, pruning bad branches early. This is the foundation of techniques like beam search with process rewards, where computation focuses on promising reasoning paths rather than exploring dead ends.
class ProcessRewardModel(nn.Module):
"""
Process Reward Model that scores each reasoning step.
Architecture: Language model backbone + step-level classifier head.
"""
def __init__(
self,
backbone: str = "meta-llama/Llama-3.1-8B",
num_labels: int = 2 # correct/incorrect
):
super().__init__()
self.backbone = AutoModel.from_pretrained(backbone)
self.hidden_size = self.backbone.config.hidden_size
# Step classification head
self.step_classifier = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(self.hidden_size, num_labels)
)
# Optional: continuous reward head
self.reward_head = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.GELU(),
nn.Linear(self.hidden_size // 2, 1),
nn.Sigmoid()
)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
        step_end_positions: list[list[int]]  # one list of step-boundary token positions per batch element
) -> dict:
"""
Forward pass computing rewards at step boundaries.
Args:
input_ids: Tokenized problem + solution
attention_mask: Attention mask
step_end_positions: Token positions marking end of each step
Returns:
dict with step_rewards and step_labels
"""
# Get hidden states from backbone
outputs = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
hidden_states = outputs.last_hidden_state
# Extract representations at step boundaries
batch_size = hidden_states.size(0)
step_representations = []
for batch_idx in range(batch_size):
batch_step_reps = hidden_states[batch_idx, step_end_positions[batch_idx], :]
step_representations.append(batch_step_reps)
step_reps = torch.stack(step_representations) # [batch, num_steps, hidden]
# Compute step-level predictions
step_logits = self.step_classifier(step_reps) # [batch, num_steps, 2]
step_rewards = self.reward_head(step_reps).squeeze(-1) # [batch, num_steps]
return {
"step_logits": step_logits,
"step_rewards": step_rewards,
"step_probs": F.softmax(step_logits, dim=-1)[:, :, 1] # P(correct)
}
def score_solution(
self,
problem: str,
solution_steps: list[str],
tokenizer
) -> list[float]:
"""
Score each step in a solution.
Returns list of scores [0, 1] for each step.
"""
# Build input text
text = f"Problem: {problem}\n\nSolution:\n"
step_positions = []
for i, step in enumerate(solution_steps):
text += f"Step {i+1}: {step}\n"
# Mark position at end of step
tokens_so_far = len(tokenizer.encode(text))
step_positions.append(tokens_so_far - 1)
# Tokenize
inputs = tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
)
# Forward pass
with torch.no_grad():
outputs = self.forward(
inputs.input_ids,
inputs.attention_mask,
[step_positions]
)
return outputs["step_rewards"][0].tolist()
def train_prm(
model: ProcessRewardModel,
train_dataset: Dataset,
tokenizer,
epochs: int = 3,
learning_rate: float = 1e-5
):
"""
Train Process Reward Model on step-labeled data.
Dataset format:
{
"problem": str,
"steps": list[str],
"step_labels": list[int] # 1=correct, 0=incorrect
}
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(epochs):
total_loss = 0
correct = 0
total = 0
for batch in train_dataset:
optimizer.zero_grad()
# Prepare inputs
inputs, step_positions = prepare_prm_inputs(
batch["problems"],
batch["steps"],
tokenizer
)
# Forward
outputs = model(
inputs.input_ids.to(model.device),
inputs.attention_mask.to(model.device),
step_positions
)
# Compute loss
labels = torch.tensor(batch["step_labels"]).to(model.device)
loss = criterion(
outputs["step_logits"].view(-1, 2),
labels.view(-1)
)
# Backprop
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
# Accuracy
preds = outputs["step_logits"].argmax(dim=-1)
correct += (preds == labels).sum().item()
total += labels.numel()
acc = correct / total
print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.4f}")
Understanding the architecture: The PRM uses a language model backbone to understand the reasoning context, then adds classification heads specifically for step-level scoring. The key architectural decision is extracting representations at "step boundary" positions—the tokens where one reasoning step ends. These boundary positions carry information about what the model has understood so far, making them natural points to assess correctness.
Why two output heads: The architecture includes both a classifier head (step_classifier, binary correct/incorrect) and a continuous reward head (reward_head, scalar 0-1). The classifier is useful during training with binary labels, while the continuous head is better for search applications where you want fine-grained quality scores. In practice, many implementations use only one, but having both provides flexibility.
The training loop dissected: The train_prm function implements standard supervised learning, but notice the careful handling of variable-length sequences. Each example has a different number of steps, and each step has a label. The prepare_prm_inputs function (not shown) must handle this complexity—padding sequences, tracking which positions correspond to step boundaries, and aligning labels correctly. Getting this bookkeeping right is where many PRM implementations fail.
Gradient clipping is essential: The clip_grad_norm_ call prevents gradient explosions that commonly occur when training on long sequences with many step boundaries. Without clipping, a single long solution with many incorrect steps can produce gradients that destabilize training. The value 1.0 is a reasonable default, but you may need to tune this based on your sequence lengths.
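To make the search use case above concrete, here is a minimal sketch of step-level beam search guided by a PRM. It assumes the score_solution interface from the class above and a hypothetical propose_next_steps(problem, steps, n) helper that samples candidate next steps from the policy; it is an illustration, not any particular system's implementation.

```python
def prm_guided_search(problem, prm, tokenizer, propose_next_steps,
                      beam_width=4, expansions=4, max_steps=10):
    """Step-level beam search: expand each partial solution, score the new
    step with the PRM, and keep only the highest-scoring branches."""
    beams = [[]]  # each beam is a list of reasoning steps so far
    for _ in range(max_steps):
        candidates = []
        for steps in beams:
            # propose_next_steps (hypothetical) samples candidate next steps from the policy
            for next_step in propose_next_steps(problem, steps, n=expansions):
                new_steps = steps + [next_step]
                step_scores = prm.score_solution(problem, new_steps, tokenizer)
                candidates.append((step_scores[-1], new_steps))
        if not candidates:
            break
        # prune: keep only the beam_width branches with the best latest-step score
        candidates.sort(key=lambda pair: pair[0], reverse=True)
        beams = [steps for _, steps in candidates[:beam_width]]
    return beams[0] if beams else []
```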
3. Rule-Based Rewards
DeepSeek R1-Zero uses remarkably simple rule-based rewards.
The surprising power of simplicity: DeepSeek's R1-Zero demonstrated that you don't need learned reward models or human-labeled data to train reasoning. Their rule-based system uses just two signals: format correctness (did the model use <think> tags?) and answer correctness (did it get the right answer?). This simplicity seemed almost naive compared to the sophisticated reward models in prior work, yet it produced state-of-the-art reasoning. The key insight: for verifiable domains, the correctness signal alone provides enough gradient for learning.
Why format rewards matter: The format reward (checking for <think> tags, substantial reasoning content) might seem cosmetic, but it serves a crucial purpose: it encourages the model to "show its work." Without format incentives, models under RL pressure optimize directly for answers, often collapsing to short, uninterpretable outputs. The format reward acts as a regularizer, ensuring the model maintains readable reasoning chains even as it optimizes for correctness.
The verification bottleneck: Rule-based rewards are only as good as your verification logic. For math, numeric comparison with tolerance handles most cases, but edge cases abound: symbolic expressions, fractions, scientific notation, mixed units. For code, you need a sandboxed execution environment that's both secure and fast enough for RL training (millions of evaluations). The code_correctness method shown is simplified—production systems use containers, time limits, and memory limits to safely execute untrusted code.
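As a rough illustration of that point, here is one simple way to contain code execution using only the standard library: run the candidate program in a separate Python process with a hard wall-clock timeout via subprocess. This is only a sketch of the idea; real training infrastructure adds containers, memory caps, and network isolation.

```python
import subprocess
import sys

def run_untrusted_code(code: str, test_input: str, timeout_s: float = 5.0) -> str | None:
    """Run candidate code in a fresh interpreter with a wall-clock timeout.
    Returns stdout on success, None on crash or timeout. Not a full sandbox."""
    try:
        result = subprocess.run(
            [sys.executable, "-c", code],   # separate process, fresh interpreter
            input=test_input,
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return None
    if result.returncode != 0:
        return None
    return result.stdout.strip()
```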
class RuleBasedRewardSystem:
"""
Rule-based reward system as used in DeepSeek R1-Zero.
No neural network—just rules checking format and correctness.
"""
def __init__(
self,
format_weight: float = 0.1,
correctness_weight: float = 1.0
):
self.format_weight = format_weight
self.correctness_weight = correctness_weight
def compute_reward(
self,
problem: str,
response: str,
ground_truth: str,
problem_type: str = "math"
) -> dict:
"""
Compute rule-based reward.
Returns dict with reward breakdown.
"""
rewards = {}
# 1. Format reward: Check for proper thinking structure
format_reward = self.check_format(response)
rewards["format"] = format_reward * self.format_weight
# 2. Correctness reward: Binary signal for answer correctness
correctness_reward = self.check_correctness(
response, ground_truth, problem_type
)
rewards["correctness"] = correctness_reward * self.correctness_weight
# Total reward
rewards["total"] = rewards["format"] + rewards["correctness"]
return rewards
def check_format(self, response: str) -> float:
"""
Check if response follows expected format.
R1-Zero expected: <think>...</think> structure
"""
format_score = 0.0
# Check for thinking tags
has_think_start = "<think>" in response or "<thinking>" in response
has_think_end = "</think>" in response or "</thinking>" in response
if has_think_start and has_think_end:
format_score += 0.5
# Check thinking comes before answer
think_end_pos = response.find("</think>")
if think_end_pos == -1:
think_end_pos = response.find("</thinking>")
# Check there's content after thinking
content_after = response[think_end_pos:].strip()
if len(content_after) > 20: # Has substantial answer
format_score += 0.3
# Check thinking is substantial
think_start = response.find("<think>")
if think_start == -1:
think_start = response.find("<thinking>")
thinking_content = response[think_start:think_end_pos]
if len(thinking_content) > 100: # Substantial thinking
format_score += 0.2
return format_score
def check_correctness(
self,
response: str,
ground_truth: str,
problem_type: str
) -> float:
"""
Check answer correctness with type-specific rules.
"""
predicted = self.extract_answer(response)
expected = self.normalize_answer(ground_truth, problem_type)
predicted_norm = self.normalize_answer(predicted, problem_type)
if problem_type == "math":
return self.math_equivalence(predicted_norm, expected)
elif problem_type == "code":
return self.code_correctness(predicted, expected)
elif problem_type == "mcq":
return 1.0 if predicted_norm == expected else 0.0
else:
# String matching with normalization
return 1.0 if predicted_norm == expected else 0.0
def math_equivalence(self, pred: str, expected: str) -> float:
"""Check mathematical equivalence."""
try:
# Try numeric comparison
pred_num = float(pred.replace(',', ''))
exp_num = float(expected.replace(',', ''))
# Allow small tolerance for floating point
if abs(pred_num - exp_num) < 1e-6:
return 1.0
# Check relative error
if exp_num != 0:
rel_error = abs(pred_num - exp_num) / abs(exp_num)
if rel_error < 1e-4:
return 1.0
return 0.0
except ValueError:
# Non-numeric: exact string match
return 1.0 if pred == expected else 0.0
def code_correctness(self, pred_code: str, test_cases: str) -> float:
"""
Check code correctness via test execution.
This is a simplified version—production would sandbox execution.
"""
# Parse test cases
tests = self.parse_test_cases(test_cases)
if not tests:
return 0.0
passed = 0
for test_input, expected_output in tests:
try:
# Execute (in sandbox in production!)
actual = self.execute_code(pred_code, test_input)
if str(actual).strip() == str(expected_output).strip():
passed += 1
except:
pass
return passed / len(tests)
def extract_answer(self, response: str) -> str:
"""Extract final answer from response."""
# Check for boxed answer (LaTeX)
boxed_match = re.search(r"\\boxed{([^}]+)}", response)
if boxed_match:
return boxed_match.group(1)
# Check for explicit answer marker
answer_match = re.search(
r"(?:final\s+)?answer[:\s]+(.+?)(?:\n|$)",
response,
re.IGNORECASE
)
if answer_match:
return answer_match.group(1).strip()
# Look for content after thinking
if "</think>" in response:
after_think = response.split("</think>")[-1]
return after_think.strip()
# Fallback: last line
return response.strip().split('\n')[-1]
def normalize_answer(self, answer: str, problem_type: str) -> str:
"""Normalize answer for comparison."""
answer = answer.strip().lower()
# Remove common prefixes
prefixes = ["the answer is", "answer:", "therefore"]
for prefix in prefixes:
if answer.startswith(prefix):
answer = answer[len(prefix):].strip()
# Type-specific normalization
if problem_type == "math":
# Remove units, currency symbols
answer = re.sub(r"[\$£€%]", "", answer)
# Normalize fractions
if "/" in answer:
try:
parts = answer.split("/")
answer = str(float(parts[0]) / float(parts[1]))
except:
pass
return answer
The format scoring breakdown: The check_format method assigns partial credit for structural elements: 0.5 for having thinking tags, 0.3 for content after thinking, 0.2 for substantial thinking content. These weights are somewhat arbitrary—DeepSeek's actual weights aren't published. The key insight is that format scoring should be much smaller than correctness scoring (0.1 weight vs 1.0), so models don't game format while ignoring substance.
Math equivalence is tricky: The math_equivalence method first tries numeric comparison with tolerance (absolute and relative error), then falls back to string matching. This handles most cases but misses symbolic equivalence: "2+3" and "5" are mathematically equivalent but fail string comparison. Production systems often use symbolic math libraries (SymPy) for proper equivalence checking, though this adds latency.
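A minimal sketch of symbolic checking with SymPy (assuming SymPy is available in the training environment): it catches equivalences like "2 + 3" vs "5" or "x*(x+1)" vs "x**2 + x" that string and float comparison miss, at the cost of some latency and occasional parse failures.

```python
from sympy import simplify
from sympy.parsing.sympy_parser import parse_expr

def symbolically_equal(pred: str, expected: str) -> bool:
    """True if the two expressions simplify to the same thing.
    Falls back to False on anything SymPy cannot parse."""
    try:
        diff = simplify(parse_expr(pred) - parse_expr(expected))
        return diff == 0
    except Exception:
        return False

# symbolically_equal("2 + 3", "5")          -> True
# symbolically_equal("x*(x+1)", "x**2 + x") -> True
```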
Why relative error matters: Absolute tolerance (abs(pred - expected) < 1e-6) fails for large numbers—if the answer is 1 million, a difference of 0.001 should be fine, but it exceeds 1e-6. The relative error check (rel_error < 1e-4) handles this, accepting answers within 0.01% of the correct value regardless of magnitude. Using both catches cases where either alone would fail.
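Python's standard library already implements exactly this combined check; something like the following (using math.isclose, which passes if either the relative or the absolute tolerance is satisfied) avoids hand-rolled comparison bugs:

```python
import math

def numbers_match(pred: float, expected: float) -> bool:
    # True if within 0.01% relative error OR within 1e-6 absolute error
    return math.isclose(pred, expected, rel_tol=1e-4, abs_tol=1e-6)

# numbers_match(1_000_000.0, 1_000_000.001)  -> True (relative tolerance)
# numbers_match(0.0, 5e-7)                   -> True (absolute tolerance)
```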
4. Learned Reward Models
Neural networks trained to predict human preferences.
When you can't verify, you learn: For many tasks—creative writing, open-ended questions, nuanced reasoning—there's no ground truth to verify against. You can't write a rule to check if an explanation is "good." Learned reward models fill this gap by training neural networks to predict which responses humans prefer. This is the standard RLHF approach used in ChatGPT, Claude, and most instruction-following models.
The preference modeling assumption: Learned reward models assume human preferences can be captured by a scalar value—given a response, the model outputs a single number representing "quality." This is a strong assumption. Human preferences are often inconsistent, context-dependent, and multi-dimensional (a response might be helpful but verbose). Despite these limitations, scalar reward models work surprisingly well in practice, likely because they capture the dominant quality dimension.
The Bradley-Terry model: The training loss isn't "predict the reward for this response" but rather "given two responses, predict which one humans preferred." This pairwise comparison approach (Bradley-Terry) is more robust than absolute scoring: humans are better at saying "A is better than B" than assigning absolute quality numbers. The model learns reward values such that preferred responses score higher than non-preferred ones.
class LearnedRewardModel(nn.Module):
"""
Reward model learned from human preference data.
Standard approach in RLHF pipelines.
"""
def __init__(self, backbone: str = "meta-llama/Llama-3.1-8B"):
super().__init__()
self.backbone = AutoModelForCausalLM.from_pretrained(backbone)
self.hidden_size = self.backbone.config.hidden_size
# Reward head: maps final hidden state to scalar
self.reward_head = nn.Linear(self.hidden_size, 1)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor:
"""Compute reward for input sequence."""
outputs = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
# Use last token's hidden state
last_hidden = outputs.hidden_states[-1]
# Get position of last non-padding token
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_indices = torch.arange(input_ids.size(0), device=input_ids.device)
last_token_hidden = last_hidden[batch_indices, sequence_lengths]
# Compute reward
reward = self.reward_head(last_token_hidden)
return reward.squeeze(-1)
def train_reward_model_from_preferences(
model: LearnedRewardModel,
preference_data: Dataset,
tokenizer,
epochs: int = 1,
learning_rate: float = 1e-5
):
"""
Train reward model from human preference data.
Data format:
{
"prompt": str,
"chosen": str, # Preferred response
"rejected": str # Non-preferred response
}
Uses Bradley-Terry model: P(chosen > rejected) = σ(r_chosen - r_rejected)
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()
for epoch in range(epochs):
total_loss = 0
total_correct = 0
total_samples = 0
for batch in preference_data:
optimizer.zero_grad()
# Tokenize chosen and rejected
chosen_inputs = tokenizer(
[f"{p} {c}" for p, c in zip(batch["prompts"], batch["chosen"])],
return_tensors="pt",
padding=True,
truncation=True
).to(model.device)
rejected_inputs = tokenizer(
[f"{p} {r}" for p, r in zip(batch["prompts"], batch["rejected"])],
return_tensors="pt",
padding=True,
truncation=True
).to(model.device)
# Compute rewards
chosen_rewards = model(**chosen_inputs)
rejected_rewards = model(**rejected_inputs)
# Bradley-Terry loss: -log(σ(r_chosen - r_rejected))
loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
# Accuracy: how often does model prefer chosen?
correct = (chosen_rewards > rejected_rewards).sum().item()
total_correct += correct
total_samples += len(batch["prompts"])
acc = total_correct / total_samples
avg_loss = total_loss / len(preference_data)
print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Accuracy={acc:.4f}")
Why use the last token's hidden state: The model processes the entire prompt+response sequence, building up contextual representations. The last token's hidden state has "seen" everything that came before, making it a natural summary of the full response. Alternative approaches exist: averaging all token representations, using the [CLS] token if present, or attention pooling. The last-token approach is simplest and works well because modern language models are trained autoregressively, making the final position information-rich.
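For comparison, a sketch of the mean-pooling alternative mentioned above (not part of the class as written): average the hidden states over non-padding positions before applying the reward head.

```python
import torch

def mean_pool_reward(last_hidden: torch.Tensor,
                     attention_mask: torch.Tensor,
                     reward_head: torch.nn.Module) -> torch.Tensor:
    """Average token representations over non-padding positions, then score.
    last_hidden: [batch, seq_len, hidden]; attention_mask: [batch, seq_len]."""
    mask = attention_mask.unsqueeze(-1).float()            # [batch, seq_len, 1]
    pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return reward_head(pooled).squeeze(-1)                 # [batch]
```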
The logsigmoid loss explained: The loss -F.logsigmoid(chosen_rewards - rejected_rewards) implements the Bradley-Terry preference model. When chosen_rewards > rejected_rewards (correct preference), the loss is small (sigmoid approaches 1, logsigmoid approaches 0). When wrong, sigmoid approaches 0, and -logsigmoid becomes large. The margin matters: the model is rewarded for increasing the gap between chosen and rejected, not just getting the order right.
Accuracy as a sanity check: During training, we track how often the model correctly ranks chosen above rejected. Random chance would give 50% accuracy. Early in training, accuracy should climb quickly to 60-70%, then more slowly toward 80-90%. If accuracy doesn't improve, check your data (are labels correct?), learning rate (too high causes instability, too low doesn't learn), or architecture (is the model expressive enough?).
Reward Hacking and Mitigations
Models can learn to exploit reward functions in unintended ways.
The Goodhart's Law problem: "When a measure becomes a target, it ceases to be a good measure." Reward functions are proxies for what we actually want—a high-scoring response according to our reward model isn't necessarily a good response. Models under optimization pressure find these gaps. A verbose model might score higher because length correlates with helpfulness in training data, even when brevity is better. A flattering model might score higher because humans prefer agreeable responses, even when disagreement is warranted.
Common hacking patterns: Length exploitation (adding unnecessary content), sycophancy (excessive agreement), format gaming (using structures that score well regardless of content), repetition (repeating phrases that increase scores), and keyword stuffing (including words the reward model associates with quality). Each hack exploits a legitimate correlation in training data, taken to an extreme the training data didn't contain.
Defense in depth: No single mitigation prevents all hacking. Effective approaches combine multiple defenses: KL penalty to prevent drifting too far from a known-good reference policy, length normalization to remove length bias, ensemble rewards to make gaming harder (hard to simultaneously fool multiple models), and human evaluation to catch hacks that pass automated checks.
class RewardHackingMitigations:
"""
Techniques to prevent reward hacking in RL training.
"""
@staticmethod
def length_penalty(response: str, target_length: int = 500) -> float:
"""
Penalize responses that are too long or too short.
Prevents verbose padding or terse non-answers.
"""
length = len(response.split())
ratio = length / target_length
if ratio < 0.5:
return -0.1 * (0.5 - ratio) # Too short
elif ratio > 2.0:
return -0.1 * (ratio - 2.0) # Too long
return 0.0
@staticmethod
def repetition_penalty(response: str) -> float:
"""
Penalize repetitive text.
Detects sentence-level and phrase-level repetition.
"""
sentences = response.split('.')
unique_sentences = set(s.strip().lower() for s in sentences if s.strip())
if len(sentences) > 1:
repetition_ratio = 1 - (len(unique_sentences) / len(sentences))
if repetition_ratio > 0.3:
return -0.2 * repetition_ratio
# N-gram repetition
words = response.lower().split()
trigrams = [tuple(words[i:i+3]) for i in range(len(words)-2)]
unique_trigrams = set(trigrams)
if len(trigrams) > 0:
trigram_repetition = 1 - (len(unique_trigrams) / len(trigrams))
if trigram_repetition > 0.3:
return -0.1 * trigram_repetition
return 0.0
@staticmethod
def kl_penalty(
policy_logprobs: torch.Tensor,
reference_logprobs: torch.Tensor,
beta: float = 0.1
) -> torch.Tensor:
"""
KL divergence penalty to prevent drift from reference.
This is the standard RLHF regularization term.
"""
kl = policy_logprobs - reference_logprobs
return -beta * kl.mean()
@staticmethod
def ensemble_reward(
response: str,
reward_models: list,
problem: str
) -> float:
"""
Use ensemble of reward models to reduce hackability.
Hard to hack multiple different models simultaneously.
"""
rewards = [rm.score(problem, response) for rm in reward_models]
# Conservative: use minimum
min_reward = min(rewards)
# Or: use variance-penalized mean
mean_reward = sum(rewards) / len(rewards)
std_reward = (sum((r - mean_reward)**2 for r in rewards) / len(rewards)) ** 0.5
# Penalize high variance (disagreement might indicate hacking)
return mean_reward - 0.5 * std_reward
PPO: Proximal Policy Optimization
PPO Fundamentals
PPO is the dominant algorithm for RLHF. It balances exploration with stability through clipped objectives.
Why PPO dominates RLHF: Before PPO, policy gradient methods were notoriously unstable—a single bad update could collapse months of training. Vanilla policy gradients have high variance, and trust region methods (TRPO) worked but were computationally expensive with second-order optimization. PPO achieves trust region-like stability with only first-order methods by using a clever clipping mechanism. This combination of stability and simplicity made PPO the default choice for RLHF.
The actor-critic architecture: PPO uses two neural networks: the policy (actor) that generates text, and the value function (critic) that estimates expected future reward from any state. The critic enables better credit assignment—instead of asking "was this response good?" we can ask "at each token, was the remaining trajectory better than expected?" This token-level feedback accelerates learning compared to sparse episode-level rewards.
The memory challenge for LLMs: PPO's actor-critic setup means storing two full language models in memory: the policy being trained, plus the value model (typically the same size). Add a frozen reference policy for KL computation and you need ~3x model memory. This is why GRPO, which eliminates the critic, has become popular for large models: it removes one full model copy, along with that model's gradients and optimizer state, from the training setup.
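A back-of-the-envelope sketch of the weights-only footprint (bf16, ignoring optimizer states, activations, and KV caches, all of which add substantially more):

```python
def weights_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """GB needed just to hold one bf16 copy of the weights."""
    return n_params * bytes_per_param / 1e9

one_copy = weights_memory_gb(70e9)   # ~140 GB for a 70B model
ppo_setup = 3 * one_copy             # policy + critic + frozen reference ~ 420 GB
grpo_setup = 2 * one_copy            # policy + frozen reference ~ 280 GB
```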
class PPOConfig:
"""Configuration for PPO training."""
# Core hyperparameters
learning_rate: float = 1e-5
gamma: float = 0.99 # Discount factor
gae_lambda: float = 0.95 # GAE parameter
clip_epsilon: float = 0.2 # PPO clipping
value_clip_epsilon: float = 0.2 # Value function clipping
# Training parameters
ppo_epochs: int = 4 # Updates per batch
mini_batch_size: int = 64
max_grad_norm: float = 0.5
# KL penalty (alternative to clipping)
use_kl_penalty: bool = False
target_kl: float = 0.02
kl_coef: float = 0.1
# Entropy bonus
entropy_coef: float = 0.01
class PPOTrainer:
"""
Proximal Policy Optimization for language models.
Implements the standard PPO algorithm with:
- Clipped surrogate objective
- Value function with optional clipping
- GAE for advantage estimation
- KL penalty as regularization
"""
def __init__(
self,
policy_model: nn.Module,
value_model: nn.Module,
reward_model: nn.Module,
reference_model: nn.Module,
tokenizer,
config: PPOConfig
):
self.policy = policy_model
self.value = value_model
self.reward = reward_model
self.reference = reference_model
self.tokenizer = tokenizer
self.config = config
# Optimizers
self.policy_optimizer = torch.optim.AdamW(
self.policy.parameters(),
lr=config.learning_rate
)
self.value_optimizer = torch.optim.AdamW(
self.value.parameters(),
lr=config.learning_rate
)
def generate_and_score(
self,
prompts: list[str],
generation_config: dict = None
) -> dict:
"""
Generate responses and compute rewards.
Returns:
dict with responses, rewards, values, logprobs
"""
generation_config = generation_config or {
"max_new_tokens": 512,
"temperature": 1.0,
"do_sample": True
}
# Tokenize prompts
inputs = self.tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True
).to(self.policy.device)
# Generate responses
with torch.no_grad():
outputs = self.policy.generate(
**inputs,
**generation_config,
return_dict_in_generate=True,
output_scores=True
)
response_ids = outputs.sequences[:, inputs.input_ids.size(1):]
responses = self.tokenizer.batch_decode(
response_ids,
skip_special_tokens=True
)
# Compute log probabilities
policy_logprobs = self.compute_logprobs(
self.policy,
inputs.input_ids,
response_ids
)
reference_logprobs = self.compute_logprobs(
self.reference,
inputs.input_ids,
response_ids
)
# Compute rewards
rewards = []
for prompt, response in zip(prompts, responses):
r = self.reward.compute_reward(prompt, response)
rewards.append(r)
rewards = torch.tensor(rewards, device=self.policy.device)
# Compute values
values = self.compute_values(inputs.input_ids, response_ids)
        return {
            "prompts": prompts,
            "responses": responses,
            "input_ids": inputs.input_ids,  # needed later by ppo_step for recomputation
            "response_ids": response_ids,
            "policy_logprobs": policy_logprobs,
            "reference_logprobs": reference_logprobs,
            "rewards": rewards,
            "values": values
        }
    def compute_logprobs(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        response_ids: torch.Tensor,
        with_grad: bool = False
    ) -> torch.Tensor:
        """Compute log probabilities for responses.

        Gradients are only enabled when recomputing logprobs for the
        policy update inside ppo_step (with_grad=True)."""
        full_ids = torch.cat([input_ids, response_ids], dim=1)
        with torch.set_grad_enabled(with_grad):
            outputs = model(full_ids, output_hidden_states=False)
            logits = outputs.logits
# Get logprobs for response tokens
response_start = input_ids.size(1)
response_logits = logits[:, response_start-1:-1, :]
logprobs = F.log_softmax(response_logits, dim=-1)
# Gather logprobs for actual tokens
selected_logprobs = torch.gather(
logprobs,
dim=2,
index=response_ids.unsqueeze(-1)
).squeeze(-1)
return selected_logprobs
    def compute_values(
        self,
        input_ids: torch.Tensor,
        response_ids: torch.Tensor,
        with_grad: bool = False
    ) -> torch.Tensor:
        """Compute value estimates for each token."""
        full_ids = torch.cat([input_ids, response_ids], dim=1)
        with torch.set_grad_enabled(with_grad):
            # Value model predicts value at each position
            values = self.value(full_ids)
# Get values for response positions
response_start = input_ids.size(1)
return values[:, response_start:]
def compute_advantages(
self,
rewards: torch.Tensor,
values: torch.Tensor,
policy_logprobs: torch.Tensor,
reference_logprobs: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute GAE advantages and returns.
Includes KL penalty in rewards.
"""
batch_size, seq_len = policy_logprobs.shape
# KL penalty at each token
kl_penalty = self.config.kl_coef * (policy_logprobs - reference_logprobs)
# Construct per-token rewards
# Final reward at last token, KL penalty at each token
token_rewards = -kl_penalty # Penalty is subtracted
token_rewards[:, -1] += rewards # Add outcome reward at end
# GAE computation
advantages = torch.zeros_like(token_rewards)
returns = torch.zeros_like(token_rewards)
last_gae = 0
last_value = 0
for t in reversed(range(seq_len)):
if t == seq_len - 1:
next_value = 0 # Terminal
else:
next_value = values[:, t + 1]
delta = token_rewards[:, t] + self.config.gamma * next_value - values[:, t]
advantages[:, t] = last_gae = delta + self.config.gamma * self.config.gae_lambda * last_gae
returns = advantages + values
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
return advantages, returns
def ppo_step(self, batch: dict) -> dict:
"""
Perform one PPO update step.
Args:
batch: dict from generate_and_score
Returns:
dict with loss metrics
"""
# Compute advantages
advantages, returns = self.compute_advantages(
batch["rewards"],
batch["values"],
batch["policy_logprobs"],
batch["reference_logprobs"]
)
# Store old logprobs for ratio computation
old_logprobs = batch["policy_logprobs"].detach()
old_values = batch["values"].detach()
metrics = {
"policy_loss": 0,
"value_loss": 0,
"entropy": 0,
"kl": 0,
"clip_fraction": 0
}
# Multiple PPO epochs
for _ in range(self.config.ppo_epochs):
            # Recompute current policy logprobs (with gradients, for the update)
            current_logprobs = self.compute_logprobs(
                self.policy,
                batch["input_ids"],
                batch["response_ids"],
                with_grad=True
            )
# Ratio for importance sampling
ratio = torch.exp(current_logprobs - old_logprobs)
# Clipped surrogate objective
surr1 = ratio * advantages
surr2 = torch.clamp(
ratio,
1 - self.config.clip_epsilon,
1 + self.config.clip_epsilon
) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
            # Value loss with optional clipping (values recomputed with gradients)
            current_values = self.compute_values(
                batch["input_ids"],
                batch["response_ids"],
                with_grad=True
            )
if self.config.value_clip_epsilon > 0:
value_clipped = old_values + torch.clamp(
current_values - old_values,
-self.config.value_clip_epsilon,
self.config.value_clip_epsilon
)
value_loss1 = (current_values - returns) ** 2
value_loss2 = (value_clipped - returns) ** 2
value_loss = 0.5 * torch.max(value_loss1, value_loss2).mean()
else:
value_loss = 0.5 * ((current_values - returns) ** 2).mean()
# Entropy bonus (encourages exploration)
# Simplified: use negative logprob as proxy
entropy = -current_logprobs.mean()
# Total loss
total_loss = (
policy_loss
+ 0.5 * value_loss
- self.config.entropy_coef * entropy
)
            # Update policy and value with a single backward pass
            # (total_loss already includes the value term, so one graph covers both networks)
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.policy.parameters(),
                self.config.max_grad_norm
            )
            torch.nn.utils.clip_grad_norm_(
                self.value.parameters(),
                self.config.max_grad_norm
            )
            self.policy_optimizer.step()
            self.value_optimizer.step()
# Track metrics
with torch.no_grad():
kl = (old_logprobs - current_logprobs).mean()
clip_fraction = (
(torch.abs(ratio - 1) > self.config.clip_epsilon)
.float()
.mean()
)
metrics["policy_loss"] += policy_loss.item()
metrics["value_loss"] += value_loss.item()
metrics["entropy"] += entropy.item()
metrics["kl"] += kl.item()
metrics["clip_fraction"] += clip_fraction.item()
# Average over epochs
for k in metrics:
metrics[k] /= self.config.ppo_epochs
return metrics
def train(
self,
prompts: list[str],
num_iterations: int = 1000,
batch_size: int = 32
):
"""Main training loop."""
for iteration in range(num_iterations):
# Sample batch of prompts
batch_prompts = random.sample(prompts, min(batch_size, len(prompts)))
# Generate and score
batch = self.generate_and_score(batch_prompts)
# PPO update
metrics = self.ppo_step(batch)
# Log progress
if iteration % 10 == 0:
print(f"Iteration {iteration}")
print(f" Policy Loss: {metrics['policy_loss']:.4f}")
print(f" Value Loss: {metrics['value_loss']:.4f}")
print(f" KL: {metrics['kl']:.4f}")
print(f" Mean Reward: {batch['rewards'].mean():.4f}")
# Early stopping on KL
if self.config.use_kl_penalty and metrics["kl"] > self.config.target_kl * 1.5:
print(f"KL divergence too high ({metrics['kl']:.4f}), stopping")
break
Understanding the PPOTrainer class: The trainer orchestrates the complex dance of PPO: generate responses, compute rewards, estimate advantages, and update both policy and value networks. The generate_and_score method handles the data collection phase—producing responses and all the information needed for policy updates (logprobs from current and reference policies, rewards, value estimates). The ppo_step method then uses this data for multiple epochs of gradient updates, the key to PPO's sample efficiency.
GAE (Generalized Advantage Estimation) explained: The compute_advantages method implements GAE, a crucial technique for variance reduction. Raw advantages (reward minus value) have high variance because rewards are noisy and values are imperfect estimates. GAE smooths this by exponentially weighting temporal differences: recent steps use more "actual" reward signal, distant steps use more value estimates. The gae_lambda parameter (0.95 typical) controls this tradeoff—higher values trust actual rewards more, lower values trust value estimates more.
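In symbols, with per-token reward r_t, value estimates V, discount γ, and GAE parameter λ:
δ_t = r_t + γ·V(s_{t+1}) − V(s_t)
Â_t = Σ_{l≥0} (γλ)^l · δ_{t+l}
Setting λ = 1 recovers full Monte Carlo returns (trust actual rewards entirely); λ = 0 reduces to one-step TD, relying entirely on the value estimates.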
The KL penalty integration: Notice how KL penalty is incorporated into token_rewards rather than as a separate loss term. This treats KL divergence as a per-token "cost" of deviating from the reference policy. The final token also receives the outcome reward. This design ensures the advantage computation naturally accounts for both reward optimization and KL regularization, letting GAE handle the credit assignment for both objectives simultaneously.
Why multiple PPO epochs work: Unlike standard RL where each data point is used once, PPO reuses collected data for multiple gradient steps (typically 4). This is possible because the clipping mechanism prevents policy updates from going too far—if the policy changes too much, the clipped objective stops pushing further. This reuse dramatically improves sample efficiency (expensive generation is amortized over multiple updates) without sacrificing stability.
The PPO Objective Function
The clipped surrogate objective is:
L^CLIP(θ) = E[min(r(θ)Â, clip(r(θ), 1-ε, 1+ε)Â)]
where:
r(θ) = π_θ(a|s) / π_θ_old(a|s) (probability ratio)
 = advantage estimate
ε = clip parameter (typically 0.1-0.2)
def ppo_clipped_objective(
policy_logprobs: torch.Tensor,
old_logprobs: torch.Tensor,
advantages: torch.Tensor,
clip_epsilon: float = 0.2
) -> torch.Tensor:
"""
Compute PPO clipped surrogate objective.
The clipping prevents too large policy updates.
"""
# Probability ratio
ratio = torch.exp(policy_logprobs - old_logprobs)
# Unclipped objective
obj1 = ratio * advantages
# Clipped objective
clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)
obj2 = clipped_ratio * advantages
# Take minimum (pessimistic bound)
# This is the key to PPO's stability
return torch.min(obj1, obj2).mean()
Why PPO Works
- Trust region approximation: Clipping creates a soft trust region without explicit KL constraints
- Pessimistic bound: Taking minimum prevents overoptimistic updates
- Sample efficiency: Can reuse data for multiple gradient steps
- Stability: Less hyperparameter sensitive than vanilla policy gradient
GRPO: Group Relative Policy Optimization
The GRPO Innovation
GRPO, introduced by DeepSeek, eliminates the value model (critic) by using group statistics for advantage estimation, removing an entire model copy (plus its gradients and optimizer state) from GPU memory.
The key insight behind GRPO: PPO's value function estimates "how good is this state?"—but for language models, we don't actually need absolute values. We only need to know which responses are relatively better. GRPO exploits this by generating multiple responses to the same prompt and using their reward statistics as the baseline. If response A scores higher than the group average, it has positive advantage; if lower, negative. No learned value function required.
Why groups work as baselines: The value function's job is to answer "what's the expected reward from this state?" For a prompt, this means "across all possible completions, what's the average reward?" GRPO approximates this by sampling: generate K responses, compute their rewards, and use the mean as the baseline. With enough samples (K=8-16 typical), this Monte Carlo estimate is surprisingly accurate—often better than a learned value function that might generalize poorly to new prompts.
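Concretely, for a group of K responses to the same prompt with rewards r_1, …, r_K, the group-relative advantage in the DeepSeek formulation is:
Â_i = (r_i − mean(r_1, …, r_K)) / std(r_1, …, r_K)
so each response is scored only by how it compares to its siblings from the same prompt.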
The memory dividend: GRPO needs only the policy and reference model—no critic. For a 70B parameter model, this saves ~140GB of GPU memory (the critic would be another 70B). This memory savings is what enabled DeepSeek to train R1 on their available hardware. The tradeoff is computational: you must generate K responses per prompt instead of 1, increasing generation cost. But generation is highly parallelizable, while memory is a hard constraint.
When GRPO beats PPO: GRPO tends to outperform PPO when: (1) prompts are diverse and value generalization is hard, (2) rewards are reliable/verifiable so Monte Carlo estimates are accurate, (3) you're memory-constrained. PPO tends to win when: (1) value functions generalize well across similar prompts, (2) you want maximum sample efficiency (PPO reuses data for multiple updates), (3) rewards are noisy and you need variance reduction from learned baselines.
class GRPOConfig:
"""Configuration for GRPO training."""
learning_rate: float = 1e-6
group_size: int = 8 # Responses per prompt
clip_epsilon: float = 0.2
kl_coef: float = 0.1
max_grad_norm: float = 1.0
# GRPO-specific
normalize_advantages: bool = True
baseline: str = "mean" # "mean" or "median"
class GRPOTrainer:
"""
Group Relative Policy Optimization.
Key insight: Use group statistics instead of learned value function
to estimate advantages.
From DeepSeek R1 paper:
"GRPO foregoes the critic model that is typically the same size as
the policy model, and estimates the baseline from group scores instead."
"""
def __init__(
self,
policy_model: nn.Module,
reference_model: nn.Module,
reward_model, # Can be neural network or rule-based
tokenizer,
config: GRPOConfig
):
self.policy = policy_model
self.reference = reference_model
self.reward = reward_model
self.tokenizer = tokenizer
self.config = config
self.optimizer = torch.optim.AdamW(
self.policy.parameters(),
lr=config.learning_rate
)
def generate_group(
self,
prompt: str,
group_size: int = None
) -> dict:
"""
Generate a group of responses for one prompt.
Returns:
dict with responses, rewards, logprobs for the group
"""
group_size = group_size or self.config.group_size
# Tokenize prompt
inputs = self.tokenizer(
prompt,
return_tensors="pt",
padding=True
).to(self.policy.device)
# Generate multiple responses
responses = []
response_ids_list = []
policy_logprobs_list = []
reference_logprobs_list = []
rewards = []
for _ in range(group_size):
with torch.no_grad():
outputs = self.policy.generate(
**inputs,
max_new_tokens=1024,
temperature=1.0,
do_sample=True,
return_dict_in_generate=True,
output_scores=True
)
response_ids = outputs.sequences[:, inputs.input_ids.size(1):]
response = self.tokenizer.decode(
response_ids[0],
skip_special_tokens=True
)
# Compute logprobs
policy_lp = self.compute_sequence_logprob(
self.policy, inputs.input_ids, response_ids
)
reference_lp = self.compute_sequence_logprob(
self.reference, inputs.input_ids, response_ids
)
# Compute reward
reward = self.reward.compute_reward(prompt, response)
if isinstance(reward, dict):
reward = reward.get("total", reward.get("correctness", 0))
responses.append(response)
response_ids_list.append(response_ids)
policy_logprobs_list.append(policy_lp)
reference_logprobs_list.append(reference_lp)
rewards.append(reward)
return {
"prompt": prompt,
"responses": responses,
"response_ids": response_ids_list,
"policy_logprobs": torch.stack(policy_logprobs_list),
"reference_logprobs": torch.stack(reference_logprobs_list),
"rewards": torch.tensor(rewards, device=self.policy.device)
}
def compute_sequence_logprob(
self,
model: nn.Module,
input_ids: torch.Tensor,
response_ids: torch.Tensor
) -> torch.Tensor:
"""Compute total log probability of response sequence."""
full_ids = torch.cat([input_ids, response_ids], dim=1)
with torch.no_grad():
outputs = model(full_ids)
logits = outputs.logits
# Get logprobs for response tokens
response_start = input_ids.size(1)
response_logits = logits[:, response_start-1:-1, :]
logprobs = F.log_softmax(response_logits, dim=-1)
# Gather logprobs for actual tokens
token_logprobs = torch.gather(
logprobs,
dim=2,
index=response_ids.unsqueeze(-1)
).squeeze(-1)
# Sum for sequence logprob
return token_logprobs.sum(dim=-1)
def compute_group_advantages(
self,
rewards: torch.Tensor,
policy_logprobs: torch.Tensor,
reference_logprobs: torch.Tensor
) -> torch.Tensor:
"""
Compute advantages using group statistics.
This is the key GRPO innovation: no value network needed.
"""
# KL penalty
kl = policy_logprobs - reference_logprobs
# Adjusted rewards (reward - KL penalty)
adjusted_rewards = rewards - self.config.kl_coef * kl
# Group baseline (mean or median)
if self.config.baseline == "mean":
baseline = adjusted_rewards.mean()
else:
baseline = adjusted_rewards.median()
# Advantages relative to group
advantages = adjusted_rewards - baseline
# Optional normalization
if self.config.normalize_advantages:
std = advantages.std()
if std > 1e-8:
advantages = advantages / std
return advantages
def grpo_step(self, group_data: dict) -> dict:
"""
Perform one GRPO update step.
Args:
group_data: dict from generate_group
Returns:
dict with loss metrics
"""
# Compute group advantages
advantages = self.compute_group_advantages(
group_data["rewards"],
group_data["policy_logprobs"],
group_data["reference_logprobs"]
)
# Store old logprobs
old_logprobs = group_data["policy_logprobs"].detach()
# Compute current logprobs (requires grad)
self.policy.train()
current_logprobs = []
for response_ids in group_data["response_ids"]:
inputs = self.tokenizer(
group_data["prompt"],
return_tensors="pt"
).to(self.policy.device)
lp = self.compute_sequence_logprob_with_grad(
self.policy, inputs.input_ids, response_ids
)
current_logprobs.append(lp)
current_logprobs = torch.stack(current_logprobs)
# Probability ratio
ratio = torch.exp(current_logprobs - old_logprobs)
# Clipped objective
surr1 = ratio * advantages
surr2 = torch.clamp(
ratio,
1 - self.config.clip_epsilon,
1 + self.config.clip_epsilon
) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
# Update
self.optimizer.zero_grad()
policy_loss.backward()
torch.nn.utils.clip_grad_norm_(
self.policy.parameters(),
self.config.max_grad_norm
)
self.optimizer.step()
# Metrics
with torch.no_grad():
kl = (old_logprobs - current_logprobs).mean()
clip_fraction = (
(torch.abs(ratio - 1) > self.config.clip_epsilon)
.float()
.mean()
)
return {
"policy_loss": policy_loss.item(),
"mean_reward": group_data["rewards"].mean().item(),
"max_reward": group_data["rewards"].max().item(),
"min_reward": group_data["rewards"].min().item(),
"kl": kl.item(),
"clip_fraction": clip_fraction.item(),
"advantage_std": advantages.std().item()
}
def compute_sequence_logprob_with_grad(
self,
model: nn.Module,
input_ids: torch.Tensor,
response_ids: torch.Tensor
) -> torch.Tensor:
"""Compute sequence logprob with gradients."""
full_ids = torch.cat([input_ids, response_ids], dim=1)
outputs = model(full_ids)
logits = outputs.logits
response_start = input_ids.size(1)
response_logits = logits[:, response_start-1:-1, :]
logprobs = F.log_softmax(response_logits, dim=-1)
token_logprobs = torch.gather(
logprobs,
dim=2,
index=response_ids.unsqueeze(-1)
).squeeze(-1)
return token_logprobs.sum(dim=-1).squeeze(0)
def train(
self,
prompts: list[str],
num_iterations: int = 1000
):
"""Main GRPO training loop."""
for iteration in range(num_iterations):
# Sample prompt
prompt = random.choice(prompts)
# Generate group
group_data = self.generate_group(prompt)
# GRPO update
metrics = self.grpo_step(group_data)
# Log
if iteration % 10 == 0:
print(f"Iteration {iteration}")
print(f" Loss: {metrics['policy_loss']:.4f}")
print(f" Mean Reward: {metrics['mean_reward']:.4f}")
print(f" Reward Range: [{metrics['min_reward']:.2f}, {metrics['max_reward']:.2f}]")
print(f" KL: {metrics['kl']:.4f}")
Understanding the GRPOTrainer: The trainer follows the same pattern as PPO but with a critical simplification: no value network. The generate_group method creates multiple responses for a single prompt—this is the "group" in GRPO. Each response gets its own reward and logprob, giving us K data points from a single prompt. The compute_group_advantages method then uses these K rewards to compute relative advantages, replacing the value function entirely.
The advantage computation dissected: In compute_group_advantages, we first adjust rewards by subtracting the KL penalty (just like PPO). Then we subtract the group baseline (mean or median). The result is a relative score: positive if this response is better than the group average, negative if worse. The normalization step (dividing by standard deviation) ensures gradients are similarly scaled regardless of the reward magnitude.
Why median baselines can help: The code offers both mean and median baselines. Median is more robust to outliers—if one response in a group gets an unusually high or low reward (perhaps due to reward model quirks), it won't skew the baseline. This robustness matters when rewards are noisy. However, median can be problematic when most responses are wrong (reward 0) and only a few are right (reward 1): the median baseline is 0, so the wrong responses receive zero advantage and are never pushed down. The mean baseline assigns them a negative advantage, which handles this case better.
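A toy illustration of that failure mode (one correct response out of eight):

```python
import torch

rewards = torch.tensor([1., 0., 0., 0., 0., 0., 0., 0.])

print(rewards - rewards.mean())    # correct: +0.875, wrong: -0.125 (wrong answers pushed down)
print(rewards - rewards.median())  # correct: +1.0,   wrong:  0.0   (wrong answers get no signal)
```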
The training loop structure: Unlike PPO's batched approach, GRPO processes one prompt at a time with its full group. This is intentional—you need all K responses to compute the group baseline. Parallelism comes from generating the K responses concurrently, not from batching across prompts. For distributed training, different workers can process different prompts simultaneously.
GRPO vs PPO Comparison
| Aspect | PPO | GRPO |
|---|---|---|
| Value model | Required | Not needed |
| Memory (trained models, excluding the shared reference) | ~2x model size (policy + critic) | ~1x model size (policy only) |
| Advantage estimation | GAE with learned values | Group statistics |
| Training stability | Very stable | Stable with proper group size |
| Sample efficiency | Can reuse data | Needs fresh samples per prompt |
Why GRPO Works
The key insight is that for language model RL:
- We care about relative quality of responses, not absolute values
- Group statistics provide a good baseline without learning
- The policy already contains implicit value information
PPO: advantages are estimated with a learned value model (GAE), so a critic must be trained alongside the policy.
GRPO: advantages are estimated from group reward statistics, so no critic is needed.
RLVR: Reinforcement Learning with Verifiable Rewards
The RLVR Approach
RLVR focuses on domains where rewards can be verified automatically (math, code, logic).
The verification advantage: In domains like mathematics and coding, we can definitively check if an answer is correct—no reward model needed, no human labelers required, no risk of reward hacking a learned proxy. A math problem either has the right numerical answer or it doesn't. Code either passes tests or it doesn't. This binary, ground-truth signal is the cleanest reward imaginable.
Why RLVR is transformative for reasoning: Prior to RLVR-style training, teaching models to reason required expensive human-labeled chains of thought. RLVR inverts this: provide only the final answer, let the model discover its own reasoning. If the model stumbles upon better reasoning strategies during exploration, they get reinforced through correct answers. This is how DeepSeek R1-Zero developed emergent reasoning—no one taught it to verify its work or try alternative approaches; it discovered these behaviors because they led to more correct answers.
The scalability unlock: RLVR scales effortlessly. Want to train on a million math problems? Generate them programmatically (synthetic data) and their answers are automatically verifiable. Want to train on code? Use existing test suites. No annotation bottleneck, no human-in-the-loop slowdown. This scalability is why reasoning model progress has accelerated—compute translates directly to capability without labeling costs.
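A toy sketch of what "programmatically generated, automatically verifiable" means in practice (a hypothetical helper, not taken from any specific pipeline): generate linear-equation problems whose answers are known by construction, so the rule-based verifier from earlier sections can score them directly.

```python
import random

def make_linear_problem(rng: random.Random) -> dict:
    """Generate 'solve a*x + b = c for x' with an answer known by construction."""
    a = rng.randint(2, 12)
    x = rng.randint(-20, 20)
    b = rng.randint(-50, 50)
    c = a * x + b
    return {
        "prompt": f"Solve for x: {a}x + {b} = {c}. Show your reasoning.",
        "ground_truth": str(x),  # verifiable with exact or numeric comparison
    }

rng = random.Random(0)
dataset = [make_linear_problem(rng) for _ in range(100_000)]  # scales freely, no labeling required
```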
class RLVRTrainer(GRPOTrainer):
"""
Reinforcement Learning with Verifiable Rewards.
Extension of GRPO specifically for verifiable domains.
Key features:
1. Binary rewards from verification (no learned reward model)
2. Domain-specific verifiers
3. Optional process supervision
"""
def __init__(
self,
policy_model: nn.Module,
reference_model: nn.Module,
verifier: 'Verifier',
tokenizer,
config: GRPOConfig,
domain: str = "math"
):
# Use verifier instead of learned reward model
super().__init__(
policy_model,
reference_model,
verifier, # Verifier as reward model
tokenizer,
config
)
self.domain = domain
self.verifier = verifier
def compute_verifiable_reward(
self,
prompt: str,
response: str,
ground_truth: str = None
) -> dict:
"""
Compute reward through verification.
No neural network—just automated checking.
"""
if self.domain == "math":
return self.verify_math(prompt, response, ground_truth)
elif self.domain == "code":
return self.verify_code(prompt, response, ground_truth)
elif self.domain == "logic":
return self.verify_logic(prompt, response, ground_truth)
else:
raise ValueError(f"Unknown domain: {self.domain}")
def verify_math(
self,
problem: str,
response: str,
ground_truth: str
) -> dict:
"""Verify mathematical answer."""
extracted = self.extract_math_answer(response)
# Normalize both answers
try:
pred_val = self.parse_math_expression(extracted)
true_val = self.parse_math_expression(ground_truth)
correct = abs(pred_val - true_val) < 1e-6
except (ValueError, ZeroDivisionError):
# Fall back to string comparison
correct = self.normalize_string(extracted) == self.normalize_string(ground_truth)
# Format check
has_reasoning = len(response) > 100 and any(
kw in response.lower()
for kw in ["therefore", "because", "since", "so ", "thus"]
)
return {
"correctness": 1.0 if correct else 0.0,
"format": 0.1 if has_reasoning else 0.0,
"total": (1.0 if correct else 0.0) + (0.1 if has_reasoning else 0.0)
}
def verify_code(
self,
problem: str,
response: str,
test_cases: str
) -> dict:
"""
Verify code through test execution.
IMPORTANT: Run in sandbox in production!
"""
# Extract code from response
code = self.extract_code(response)
if not code:
return {"correctness": 0.0, "total": 0.0}
# Parse test cases
tests = self.parse_test_cases(test_cases)
# Run tests (sandboxed)
passed = 0
total = len(tests)
for test_input, expected_output in tests:
try:
actual = self.run_code_sandboxed(code, test_input)
if self.outputs_match(actual, expected_output):
passed += 1
except Exception:
pass  # Any crash, timeout, or wrong output type counts as a failed test
correctness = passed / total if total > 0 else 0.0
return {
"correctness": correctness,
"tests_passed": passed,
"tests_total": total,
"total": correctness
}
def verify_logic(
self,
problem: str,
response: str,
ground_truth: str
) -> dict:
"""Verify logical reasoning."""
# Extract conclusion
conclusion = self.extract_conclusion(response)
# Check against ground truth
correct = self.logic_match(conclusion, ground_truth)
# Check logical structure
has_valid_structure = self.check_logical_structure(response)
return {
"correctness": 1.0 if correct else 0.0,
"structure": 0.2 if has_valid_structure else 0.0,
"total": (1.0 if correct else 0.0) + (0.2 if has_valid_structure else 0.0)
}
# Helper methods
def extract_math_answer(self, response: str) -> str:
"""Extract math answer from response."""
# Check for boxed answer
boxed = re.search(r"\\boxed{([^}]+)}", response)
if boxed:
return boxed.group(1)
# Check for "answer is X"
answer_match = re.search(
r"(?:answer|result)[:\s]+([^\n.,]+)",
response,
re.IGNORECASE
)
if answer_match:
return answer_match.group(1).strip()
# Last number in response
numbers = re.findall(r"-?\d+\.?\d*", response)
if numbers:
return numbers[-1]
return response.strip().split()[-1]
def parse_math_expression(self, expr: str) -> float:
"""Parse mathematical expression to float."""
# Clean expression
expr = expr.strip()
expr = re.sub(r"[,\s]", "", expr)
# Handle fractions
if "/" in expr:
parts = expr.split("/")
return float(parts[0]) / float(parts[1])
# Handle percentages
if "%" in expr:
return float(expr.replace("%", "")) / 100
return float(expr)
def extract_code(self, response: str) -> str:
"""Extract code block from response."""
# Look for fenced code blocks
code_match = re.search(
r"```(?:python|py)?\n(.*?)```",
response,
re.DOTALL
)
if code_match:
return code_match.group(1)
# Look for indented code
lines = response.split('\n')
code_lines = [l for l in lines if l.startswith(' ') or l.startswith('\t')]
if code_lines:
return '\n'.join(l.lstrip() for l in code_lines)
return ""
class MathVerifier:
"""
Specialized math verifier supporting various formats.
"""
def __init__(self):
self.symbolic_engine = None # Optional: sympy for symbolic verification
def verify(
self,
problem: str,
response: str,
ground_truth: str
) -> float:
"""Main verification entry point."""
predicted = self.extract_answer(response)
# Try exact match first
if self.exact_match(predicted, ground_truth):
return 1.0
# Try numeric comparison
if self.numeric_match(predicted, ground_truth):
return 1.0
# Try symbolic equivalence
if self.symbolic_match(predicted, ground_truth):
return 1.0
return 0.0
def exact_match(self, pred: str, truth: str) -> bool:
"""Exact string match after normalization."""
return self.normalize(pred) == self.normalize(truth)
def numeric_match(self, pred: str, truth: str, tol: float = 1e-6) -> bool:
"""Numeric comparison with tolerance."""
try:
pred_val = self.to_number(pred)
truth_val = self.to_number(truth)
return abs(pred_val - truth_val) < tol
except (ValueError, ZeroDivisionError):
return False
def symbolic_match(self, pred: str, truth: str) -> bool:
"""Symbolic equivalence using SymPy."""
if self.symbolic_engine is None:
return False
try:
from sympy import simplify, sympify
pred_expr = sympify(pred)
truth_expr = sympify(truth)
return simplify(pred_expr - truth_expr) == 0
except Exception:  # sympify/simplify can fail in many ways
return False
def normalize(self, s: str) -> str:
"""Normalize answer string."""
s = s.strip().lower()
s = re.sub(r"\s+", " ", s)
s = re.sub(r"[$,]", "", s)
return s
def to_number(self, s: str) -> float:
"""Convert string to number."""
s = self.normalize(s)
# Handle fractions
if "/" in s:
num, denom = s.split("/")
return float(num) / float(denom)
# Handle percentages
if "%" in s:
return float(s.replace("%", "")) / 100
# Handle scientific notation
return float(s)
def extract_answer(self, response: str) -> str:
"""Extract final answer from response."""
# Pattern matching for common formats
patterns = [
r"\\boxed{([^}]+)}",
r"(?:final\s+)?answer[:\s]+([^\n.]+)",
r"therefore[,:\s]+([^\n.]+)",
r"=\s*([^\n.]+)$"
]
for pattern in patterns:
match = re.search(pattern, response, re.IGNORECASE | re.MULTILINE)
if match:
return match.group(1).strip()
# Fallback: last line
return response.strip().split('\n')[-1]
DPO and Other Alternatives
Direct Preference Optimization (DPO)
DPO eliminates the reward model entirely by optimizing preferences directly.
The theoretical insight behind DPO: Here's a surprising result from RL theory: if you have a Bradley-Terry preference model (humans prefer response A over B with probability σ(r(A) - r(B))), and you regularize with KL divergence to a reference policy, then the optimal policy has a closed-form solution. You can derive the reward function directly from the optimal policy. DPO exploits this by working backward: instead of learning a reward then optimizing against it (RLHF), optimize the policy directly against preferences. Same end result, fewer moving parts.
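In outline, using the same notation as the loss in the class docstring below: the KL-regularized objective

maximize over π:  E[r(x, y)] - β · KL(π(·|x) ‖ π_ref(·|x))

has the closed-form optimum

π*(y|x) = (1/Z(x)) · π_ref(y|x) · exp(r(x, y) / β)

Solving for the reward gives r(x, y) = β · log(π*(y|x) / π_ref(y|x)) + β · log Z(x). Substituting this into the Bradley-Terry probability σ(r(x, y_w) - r(x, y_l)) cancels the intractable log Z(x) term, leaving a loss written entirely in terms of policy and reference log-probabilities, which is exactly the expression implemented below.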
Why DPO is simpler: Traditional RLHF is a three-stage process: (1) train a reward model on preferences, (2) use RL to optimize against the reward, (3) manage the RL instabilities. DPO collapses this to one stage: supervised learning on preference pairs. No reward model to train, no RL hyperparameter tuning, no PPO stability concerns. This simplicity has made DPO extremely popular—many production systems now use DPO instead of PPO.
The hidden cost of simplicity: DPO assumes your preference pairs are good quality. With RLHF, the reward model can generalize beyond the training preferences—it learns a quality function applicable to novel responses. DPO has no such generalization; it only learns from the specific (chosen, rejected) pairs in your dataset. If your dataset doesn't cover the response space well, DPO may not learn the right policy. This is why some practitioners use DPO as a warm-start, followed by PPO for final polish.
class DPOTrainer:
"""
Direct Preference Optimization.
Optimizes policy directly from preferences without explicit reward modeling.
Key insight: The optimal policy under Bradley-Terry model has closed form,
so we can skip reward model training entirely.
Loss: -log σ(β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))))
where:
y_w = winning (preferred) response
y_l = losing (non-preferred) response
β = temperature parameter
"""
def __init__(
self,
policy_model: nn.Module,
reference_model: nn.Module,
tokenizer,
beta: float = 0.1,
learning_rate: float = 1e-6
):
self.policy = policy_model
self.reference = reference_model
self.tokenizer = tokenizer
self.beta = beta
self.optimizer = torch.optim.AdamW(
self.policy.parameters(),
lr=learning_rate
)
# Freeze reference model
for param in self.reference.parameters():
param.requires_grad = False
def compute_logprobs(
self,
model: nn.Module,
input_ids: torch.Tensor,
labels: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor:
"""Compute per-token log probabilities."""
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs.logits
# Shift for next-token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Compute log probs
log_probs = F.log_softmax(shift_logits, dim=-1)
# Gather actual token log probs
per_token_logps = torch.gather(
log_probs,
dim=2,
index=shift_labels.unsqueeze(-1)
).squeeze(-1)
return per_token_logps
def dpo_loss(
self,
chosen_input_ids: torch.Tensor,
chosen_attention_mask: torch.Tensor,
rejected_input_ids: torch.Tensor,
rejected_attention_mask: torch.Tensor,
chosen_labels: torch.Tensor,
rejected_labels: torch.Tensor
) -> tuple[torch.Tensor, dict]:
"""
Compute DPO loss.
Returns loss and metrics dict.
"""
# Policy log probs
policy_chosen_logps = self.compute_logprobs(
self.policy,
chosen_input_ids,
chosen_labels,
chosen_attention_mask
).sum(dim=-1)
policy_rejected_logps = self.compute_logprobs(
self.policy,
rejected_input_ids,
rejected_labels,
rejected_attention_mask
).sum(dim=-1)
# Reference log probs
with torch.no_grad():
ref_chosen_logps = self.compute_logprobs(
self.reference,
chosen_input_ids,
chosen_labels,
chosen_attention_mask
).sum(dim=-1)
ref_rejected_logps = self.compute_logprobs(
self.reference,
rejected_input_ids,
rejected_labels,
rejected_attention_mask
).sum(dim=-1)
# Log ratios
chosen_log_ratio = policy_chosen_logps - ref_chosen_logps
rejected_log_ratio = policy_rejected_logps - ref_rejected_logps
# DPO loss
logits = self.beta * (chosen_log_ratio - rejected_log_ratio)
loss = -F.logsigmoid(logits).mean()
# Metrics
with torch.no_grad():
chosen_rewards = self.beta * chosen_log_ratio
rejected_rewards = self.beta * rejected_log_ratio
accuracy = (chosen_rewards > rejected_rewards).float().mean()
reward_margin = (chosen_rewards - rejected_rewards).mean()
metrics = {
"loss": loss.item(),
"accuracy": accuracy.item(),
"reward_margin": reward_margin.item(),
"chosen_reward": chosen_rewards.mean().item(),
"rejected_reward": rejected_rewards.mean().item()
}
return loss, metrics
def train_step(self, batch: dict) -> dict:
"""Single training step."""
self.policy.train()
# Tokenize chosen and rejected
chosen_encodings = self.tokenizer(
batch["chosen"],
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
).to(self.policy.device)
rejected_encodings = self.tokenizer(
batch["rejected"],
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
).to(self.policy.device)
# Compute loss
loss, metrics = self.dpo_loss(
chosen_input_ids=chosen_encodings.input_ids,
chosen_attention_mask=chosen_encodings.attention_mask,
rejected_input_ids=rejected_encodings.input_ids,
rejected_attention_mask=rejected_encodings.attention_mask,
chosen_labels=chosen_encodings.input_ids,
rejected_labels=rejected_encodings.input_ids
)
# Update
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
self.optimizer.step()
return metrics
The DPO loss unpacked: The core of DPO is the `dpo_loss` function. It computes log-ratios for chosen and rejected responses: how much more (or less) likely are they under the current policy versus the reference? The difference between these ratios, scaled by β, becomes the "implicit reward margin." The logsigmoid loss encourages the chosen response's implicit reward to exceed the rejected response's. When the model already prefers the chosen response strongly, the loss is low; when it prefers the rejected one, the loss is high.
Why the reference model matters: DPO requires a frozen reference model—typically the pre-trained or instruction-tuned starting point. The log-ratios are computed relative to this reference. Without it, the model could trivially minimize the loss by assigning extreme probabilities (100% to chosen, 0% to rejected). The reference anchors the policy, ensuring it stays close to a known-good distribution while adjusting preferences.
The β hyperparameter: β is the KL coefficient in disguise: it controls how tightly the policy is anchored to the reference. High β (0.5+) keeps updates gentle and preserves more of the reference model's behavior, since the implicit reward β·log-ratio reaches a given margin with only small policy changes. Low β (0.01-0.1) lets the policy drift much further from the reference to exploit the preference signal, which risks overfitting to the specific pairs. Typical values range from 0.1 to 0.5 depending on dataset size and quality.
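A minimal usage sketch for the DPOTrainer above, assuming Hugging Face-style causal LMs; the checkpoint name and the toy preference pair are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-1.5B-Instruct"  # any instruction-tuned causal LM works here
tokenizer = AutoTokenizer.from_pretrained(name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

policy = AutoModelForCausalLM.from_pretrained(name)
reference = AutoModelForCausalLM.from_pretrained(name)  # frozen inside the trainer

trainer = DPOTrainer(policy, reference, tokenizer, beta=0.1)

batch = {
    "chosen":   ["Q: What is 17 x 24?\nA: 17 x 24 = 17 x 20 + 17 x 4 = 408."],
    "rejected": ["Q: What is 17 x 24?\nA: The answer is 398."],
}
metrics = trainer.train_step(batch)
print(metrics["accuracy"], metrics["reward_margin"])
```

In a real run the chosen/rejected texts carry the full prompt template, and prompt tokens are masked out of the log-probability sums rather than included as in the simplified trainer above.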
class IPOTrainer(DPOTrainer):
"""
Identity Preference Optimization.
Variant of DPO that uses different loss formulation:
(log_ratio_chosen - log_ratio_rejected - margin)^2
Less prone to overfitting than DPO when preferences are near-deterministic.
"""
def __init__(self, *args, margin: float = 0.5, **kwargs):
super().__init__(*args, **kwargs)
self.margin = margin
def ipo_loss(
self,
chosen_log_ratio: torch.Tensor,
rejected_log_ratio: torch.Tensor
) -> torch.Tensor:
"""IPO loss: squared difference from margin."""
diff = chosen_log_ratio - rejected_log_ratio
loss = (diff - self.margin) ** 2
return loss.mean()
class KTOTrainer:
"""
Kahneman-Tversky Optimization.
Uses prospect theory-inspired loss that doesn't require paired comparisons.
Works with just desirable/undesirable labels.
"""
def __init__(
self,
policy_model: nn.Module,
reference_model: nn.Module,
tokenizer,
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0
):
self.policy = policy_model
self.reference = reference_model
self.tokenizer = tokenizer
self.beta = beta
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight
def kto_loss(
self,
policy_logps: torch.Tensor,
reference_logps: torch.Tensor,
is_desirable: torch.Tensor,
kl_reference: float
) -> torch.Tensor:
"""
KTO loss.
Different treatment for desirable vs undesirable examples.
"""
log_ratio = policy_logps - reference_logps
# Desirable: encourage
desirable_loss = -F.logsigmoid(self.beta * (log_ratio - kl_reference))
# Undesirable: discourage
undesirable_loss = -F.logsigmoid(self.beta * (kl_reference - log_ratio))
# Combine
loss = torch.where(
is_desirable,
self.desirable_weight * desirable_loss,
self.undesirable_weight * undesirable_loss
)
return loss.mean()
Comparison of Training Methods
| Method | Reward Model | Value Model | Data Requirement | Complexity |
|---|---|---|---|---|
| PPO | Yes | Yes | Prompts + preferences | High |
| GRPO | Yes | No | Prompts + rewards | Medium |
| DPO | No | No | Preference pairs | Low |
| RLVR | Verifier only | No | Prompts + ground truth | Medium |
| KTO | No | No | Single labels | Low |
How o1 and R1 Are Trained
OpenAI o1/o3 Training
Based on public information from OpenAI:
class O1TrainingPipeline:
"""
Hypothetical O1 training pipeline based on public information.
Key elements:
1. Large-scale RL with chain-of-thought
2. Hidden reasoning (not shown to users)
3. Test-time search exploration
"""
def __init__(self):
self.stages = [
"pretraining",
"sft_on_reasoning_traces",
"large_scale_rl",
"safety_alignment"
]
def stage_1_pretraining(self, model, data):
"""Standard GPT pretraining."""
# Massive internet text
# Same as GPT-4 base
pass
def stage_2_sft(self, model, reasoning_data):
"""
SFT on high-quality reasoning traces.
Data includes:
- Math solutions with detailed steps
- Code with explanations
- Scientific reasoning
"""
# Format: problem → <think>reasoning</think> answer
pass
def stage_3_rl(self, model, problems, reward_model):
"""
Large-scale reinforcement learning.
From OpenAI: "o1 learns to hone its chain of thought and refine
the strategies it uses. It learns to recognize and correct its
mistakes. It learns to break down tricky steps into simpler ones.
It learns to try a different approach when the current one isn't
working."
Key features:
- Outcome-based rewards
- Likely uses PPO or similar
- Extended reasoning encouraged
"""
# Train to maximize answer correctness
# Reward model evaluates final answer quality
pass
def stage_4_safety(self, model):
"""
Safety alignment through RLHF.
Includes:
- Chain-of-thought monitoring
- Refusal training
- Red-teaming
"""
pass
class O3Innovations:
"""
O3 specific innovations.
From OpenAI: "o3 generates a diverse set of candidate CoTs, each
representing a distinct step-by-step reasoning pathway to solve the
task. This process mimics a human iterating over different drafts of
a solution before settling on the best one."
"""
def test_time_search(self, problem):
"""
Generate multiple reasoning paths, select best.
o3 can be run at different compute levels:
- Low: fewer paths explored
- Medium: moderate exploration
- High: extensive search
"""
pass
def tool_use_through_rl(self, model):
"""
From OpenAI: "trained to use tools through reinforcement learning—
teaching them not just how to use tools, but to reason about when
to use them."
"""
pass
DeepSeek R1 Training
DeepSeek published their training approach in detail:
class DeepSeekR1TrainingPipeline:
"""
DeepSeek R1 training pipeline.
Two variants:
1. R1-Zero: Pure RL without SFT
2. R1: Multi-stage with cold start data
Key innovation: GRPO with rule-based rewards
"""
def __init__(self, base_model: str = "DeepSeek-V3"):
self.base_model = base_model
def train_r1_zero(
self,
model,
math_problems: list[dict],
code_problems: list[dict]
):
"""
R1-Zero: Pure RL training.
From paper: "We directly applied reinforcement learning to the
base model without relying on supervised fine-tuning (SFT) as a
preliminary step. This approach allows the model to explore
chain-of-thought for solving complex problems."
Remarkably, reasoning emerges naturally!
"""
# Rule-based reward system
reward_system = RuleBasedRewardSystem(
format_weight=0.1,
correctness_weight=1.0
)
# GRPO training
trainer = GRPOTrainer(
policy_model=model,
reference_model=copy.deepcopy(model),  # frozen deep copy of the initial model as reference (requires `import copy`; nn.Module has no .copy())
reward_model=reward_system,
tokenizer=self.tokenizer,
config=GRPOConfig(
group_size=8,
kl_coef=0.04 # Low KL penalty
)
)
# Train on math + code
all_problems = math_problems + code_problems
trainer.train(all_problems, num_iterations=100000)
return model
def train_r1(self, model):
"""
R1: Multi-stage training for better readability.
Addresses R1-Zero issues:
- Language mixing
- Poor readability
- Inconsistent formatting
"""
# Stage 1: Cold start with high-quality reasoning examples
model = self.stage_1_cold_start(model)
# Stage 2: RL with reasoning
model = self.stage_2_reasoning_rl(model)
# Stage 3: Rejection sampling for diverse tasks
model = self.stage_3_rejection_sampling(model)
# Stage 4: Final RL with all data
model = self.stage_4_final_rl(model)
return model
def stage_1_cold_start(self, model):
"""
SFT on curated reasoning examples.
From paper: "A small amount of cold-start data is collected before
reinforcement learning to prevent the model from learning poor
readability styles at the beginning of RL."
Thousands of examples (not millions).
"""
cold_start_data = self.collect_cold_start_data()
# Format: Include thinking process
formatted_data = [
f"Problem: {d['problem']}\n<think>{d['reasoning']}</think>\n{d['answer']}"
for d in cold_start_data
]
# Standard SFT
model = supervised_finetune(model, formatted_data)
return model
def stage_2_reasoning_rl(self, model):
"""
RL focused on reasoning quality.
Verifiable domains: math, code, logic
"""
verifier = MathVerifier()
trainer = RLVRTrainer(
policy_model=model,
reference_model=copy.deepcopy(model),  # frozen copy of the pre-stage policy
verifier=verifier,
tokenizer=self.tokenizer,
config=GRPOConfig(group_size=16),
domain="math"
)
# Train until convergence
trainer.train(self.math_problems, num_iterations=50000)
return model
def stage_3_rejection_sampling(self, model):
"""
Use model to generate training data for diverse tasks.
For tasks without verifiable rewards (writing, open QA),
use rejection sampling:
1. Generate multiple responses
2. Filter by quality (using a judge model or heuristics)
3. Use top responses for SFT
"""
diverse_prompts = self.load_diverse_prompts()
accepted_data = []
for prompt in diverse_prompts:
# Generate multiple responses
responses = [
model.generate(prompt, temperature=0.8)
for _ in range(8)
]
# Score and filter
scored = [(r, self.quality_score(prompt, r)) for r in responses]
scored.sort(key=lambda x: x[1], reverse=True)
# Keep top responses
accepted_data.extend([
{"prompt": prompt, "response": r}
for r, score in scored[:2] # Top 2
if score > 0.7
])
# SFT on filtered data
model = supervised_finetune(model, accepted_data)
return model
def stage_4_final_rl(self, model):
"""
Final RL combining all reward signals.
- Verifiable rewards for math/code
- Preference rewards for other tasks
- Safety constraints
"""
# Combined reward
def combined_reward(prompt, response):
# Detect task type
if is_math(prompt):
return self.math_verifier.verify(prompt, response, ground_truth)
elif is_code(prompt):
return self.code_verifier.verify(prompt, response, test_cases)
else:
return self.preference_model.score(prompt, response)
# Final GRPO training
trainer = GRPOTrainer(
policy_model=model,
reference_model=copy.deepcopy(model),  # frozen copy of the pre-stage policy
reward_model=combined_reward,
tokenizer=self.tokenizer,
config=GRPOConfig(group_size=8)
)
trainer.train(self.all_prompts, num_iterations=10000)
return model
Emergent Capabilities
Both o1/o3 and R1 exhibit emergent reasoning behaviors not explicitly trained:
class EmergentBehaviors:
"""
Behaviors that emerge from RL training on reasoning tasks.
From DeepSeek: "it naturally emerges with numerous powerful and
intriguing reasoning behaviors."
"""
@staticmethod
def self_verification():
"""Model learns to check its own work."""
# Emerges from: Correctness rewards
# Model realizes checking catches errors
example = """
<think>
Let me calculate 17 × 24.
17 × 24 = 17 × (20 + 4) = 340 + 68 = 408
Let me verify: 408 ÷ 17 = 24 ✓
The calculation is correct.
</think>
"""
@staticmethod
def backtracking():
"""Model learns to try different approaches."""
# Emerges from: Exploring multiple paths during RL
example = """
<think>
First approach: direct multiplication
17 × 24 = ... hmm, let me try a different way.
Alternative: break down 24 = 25 - 1
17 × 25 - 17 × 1 = 425 - 17 = 408
Both give 408, confirming the answer.
</think>
"""
@staticmethod
def problem_decomposition():
"""Model learns to break complex problems into parts."""
# Emerges from: Complex problems requiring multiple steps
example = """
<think>
This problem has several parts:
1. First, I need to find X
2. Then, use X to calculate Y
3. Finally, combine to get the answer
Starting with part 1...
</think>
"""
@staticmethod
def metacognition():
"""Model learns to reason about its own reasoning."""
# Emerges from: Self-consistency and error correction
example = """
<think>
I'm not confident about this step. Let me reconsider.
The logic seems sound, but I should double-check the arithmetic.
Actually, I made an error earlier. Let me correct it.
</think>
"""
Distillation: Compressing Reasoning
Knowledge Distillation for Reasoning
Distillation transfers reasoning capabilities from large models to smaller ones.
Why distillation enables practical deployment: Training a 70B parameter reasoning model requires massive compute and produces a model too large for most deployment scenarios. Distillation offers an alternative: train the large model once, then transfer its capabilities to smaller models. A well-distilled 7B model can achieve 80-90% of the 70B model's reasoning performance at a fraction of the inference cost. This is how reasoning models become practical for real applications.
The mechanism of knowledge transfer: Unlike retraining from scratch, distillation gives the student a "cheat sheet"—the teacher's solutions. The student doesn't need to discover reasoning patterns through costly exploration; it just needs to imitate patterns the teacher already found. This is far more sample-efficient. The teacher's reasoning traces act as demonstrations, and the student learns via supervised fine-tuning on these traces.
Beyond mere imitation: Good distillation transfers not just answers but reasoning styles. When a student trains on thousands of teacher-generated proofs, it internalizes patterns: "when you see this type of problem, try this approach." The student may not understand why these patterns work, but it can replicate them. Interestingly, distilled students sometimes outperform teachers on specific task types—the compression forces the student to extract the most reliable patterns, pruning away teacher idiosyncrasies.
class ReasoningDistillation:
"""
Distill reasoning capabilities from teacher to student.
DeepSeek released distilled models at 1.5B, 7B, 8B, 14B, 32B, 70B.
"""
def __init__(
self,
teacher_model: nn.Module,
student_model: nn.Module,
tokenizer,
temperature: float = 2.0
):
self.teacher = teacher_model
self.student = student_model
self.tokenizer = tokenizer
self.temperature = temperature
# Freeze teacher
for param in self.teacher.parameters():
param.requires_grad = False
def generate_reasoning_traces(
self,
problems: list[str],
n_per_problem: int = 1
) -> list[dict]:
"""
Generate reasoning traces from teacher.
These become training data for student.
"""
traces = []
for problem in problems:
for _ in range(n_per_problem):
with torch.no_grad():
response = self.teacher.generate(
f"{problem}\n<think>",
max_new_tokens=2048,
temperature=0.7,
do_sample=True
)
traces.append({
"problem": problem,
"reasoning": response
})
return traces
def distill_on_traces(
self,
traces: list[dict],
epochs: int = 3,
learning_rate: float = 1e-5
):
"""
Train student to reproduce teacher's reasoning.
Standard SFT on teacher-generated traces.
"""
optimizer = torch.optim.AdamW(
self.student.parameters(),
lr=learning_rate
)
self.student.train()
for epoch in range(epochs):
total_loss = 0
for trace in traces:
# Format input
text = f"Problem: {trace['problem']}\n{trace['reasoning']}"
inputs = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=2048
).to(self.student.device)
# Forward pass
outputs = self.student(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
labels=inputs.input_ids
)
loss = outputs.loss
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss / len(traces):.4f}")
def distill_with_kl(
self,
prompts: list[str],
epochs: int = 3
):
"""
Distillation using KL divergence on logits.
More sophisticated than trace-based distillation.
"""
optimizer = torch.optim.AdamW(
self.student.parameters(),
lr=1e-5
)
for epoch in range(epochs):
total_loss = 0
for prompt in prompts:
inputs = self.tokenizer(
prompt,
return_tensors="pt"
).to(self.student.device)
# Generate from teacher
with torch.no_grad():
teacher_outputs = self.teacher.generate(
**inputs,
max_new_tokens=512,
return_dict_in_generate=True,
output_scores=True
)
teacher_logits = torch.stack(teacher_outputs.scores, dim=1)
# Student forward on same tokens
full_input = torch.cat([
inputs.input_ids,
teacher_outputs.sequences[:, inputs.input_ids.size(1):]
], dim=1)
student_outputs = self.student(full_input)
student_logits = student_outputs.logits[:, inputs.input_ids.size(1)-1:-1, :]
# KL divergence loss
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
kl_loss = F.kl_div(
student_log_probs,
teacher_probs,
reduction='batchmean'
) * (self.temperature ** 2)
# Update
optimizer.zero_grad()
kl_loss.backward()
optimizer.step()
total_loss += kl_loss.item()
print(f"Epoch {epoch+1}, KL Loss: {total_loss / len(prompts):.4f}")
Trace-based vs KL-based distillation: The class shows two approaches. `distill_on_traces` is simpler: generate reasoning traces from the teacher, then train the student with the standard language modeling loss to reproduce them. This is just SFT on teacher-generated data. `distill_with_kl` is more sophisticated: it matches the student's probability distribution to the teacher's at each token. KL-based distillation transfers more information (not just "what tokens to generate" but "how confident to be about each option") but is computationally heavier.
The temperature parameter in distillation: Sampling with a nonzero temperature during trace generation (the code above uses 0.7 with do_sample=True) produces more diverse reasoning paths than greedy decoding. The student sees varied approaches, not just the teacher's single greedy trace. This diversity acts as data augmentation, helping the student generalize. The separate distillation temperature (2.0 here) is applied during KL matching: it softens both probability distributions, so the student doesn't need to replicate exact confidence levels, just the rough shape of the distribution.
Why multiple traces per problem: Generating `n_per_problem > 1` traces creates diversity even for the same problem. Different random samples lead to different reasoning paths that arrive at the same answer. This teaches the student that there are multiple valid approaches, making it more robust than a student that memorizes a single solution pattern per problem type.
Progressive distillation rationale: Distilling directly from 70B to 1.5B loses significant capability—the capacity gap is too large. Progressive distillation chains the process: 70B → 32B → 7B → 1.5B. Each step bridges a smaller gap, preserving more capability. The 32B student learns from 70B traces, then becomes the teacher for 7B. Each intermediate model is optimized for its capacity, extracting what's transferable and discarding what isn't. This cascade typically preserves 10-20% more capability than direct distillation at the smallest scales.
```python
class ProgressiveDistillation:
"""
Distill in stages for better results.
Teacher (70B) → Medium (32B) → Small (7B) → Tiny (1.5B)
"""
def __init__(self, model_sizes: list[str]):
self.model_sizes = model_sizes # e.g., ["70B", "32B", "7B", "1.5B"]
def progressive_distill(self, problems: list[str]):
"""
Chain of distillation.
Each model teaches the next smaller one.
"""
current_teacher = load_model(self.model_sizes[0])
for i, size in enumerate(self.model_sizes[1:], 1):
print(f"Distilling {self.model_sizes[i-1]} → {size}")
student = load_model(size, initialize_fresh=True)
distiller = ReasoningDistillation(
teacher_model=current_teacher,
student_model=student,
tokenizer=self.tokenizer
)
# Generate traces
traces = distiller.generate_reasoning_traces(
problems,
n_per_problem=3
)
# Train student
distiller.distill_on_traces(traces, epochs=5)
# Student becomes teacher for next stage
current_teacher = student
# Save checkpoint
student.save_pretrained(f"distilled_{size}")
return current_teacher
```
Distillation Results
From DeepSeek's distilled models:
| Model | Base | Math (AIME) | Code (HumanEval) | Notes |
|---|---|---|---|---|
| R1-Distill-Qwen-1.5B | Qwen2.5-1.5B | 28.9% | 52.4% | Tiny but capable |
| R1-Distill-Qwen-7B | Qwen2.5-7B | 55.5% | 71.3% | Good balance |
| R1-Distill-Llama-8B | Llama-3.1-8B | 50.4% | 72.6% | Alternative base |
| R1-Distill-Qwen-14B | Qwen2.5-14B | 69.7% | 80.5% | Strong performance |
| R1-Distill-Qwen-32B | Qwen2.5-32B | 72.6% | 85.4% | Near full model |
| R1-Distill-Llama-70B | Llama-3.3-70B | 79.2% | 86.5% | Best distilled |
Key observation: Even 7B distilled models achieve competitive reasoning performance!
Practical Implementation Guide
Choosing Your Training Approach
def choose_training_approach(
has_preference_data: bool,
has_verifiable_rewards: bool,
compute_budget: str, # "low", "medium", "high"
model_size: str # "small", "medium", "large"
) -> str:
"""
Decision tree for choosing training approach.
"""
if compute_budget == "low":
if has_preference_data:
return "DPO" # No reward model needed
else:
return "SFT" # Just supervised finetuning
if has_verifiable_rewards:
if model_size == "large":
return "GRPO" # Memory efficient
else:
return "RLVR" # Verifiable rewards
if has_preference_data:
if compute_budget == "high":
return "PPO" # Full RLHF
else:
return "DPO" # Simpler alternative
# Default
return "SFT with rejection sampling"
def estimate_training_resources(
approach: str,
model_params: int, # in billions
dataset_size: int # number of examples
) -> dict:
"""
Estimate compute and memory requirements.
"""
estimates = {
"PPO": {
"gpu_memory_gb": model_params * 8, # Policy + Value + Reference
"training_hours_a100": model_params * 0.1 * dataset_size / 10000,
"complexity": "high"
},
"GRPO": {
"gpu_memory_gb": model_params * 4, # Policy + Reference only
"training_hours_a100": model_params * 0.08 * dataset_size / 10000,
"complexity": "medium"
},
"DPO": {
"gpu_memory_gb": model_params * 4, # Policy + Reference
"training_hours_a100": model_params * 0.05 * dataset_size / 10000,
"complexity": "low"
},
"RLVR": {
"gpu_memory_gb": model_params * 4,
"training_hours_a100": model_params * 0.08 * dataset_size / 10000,
"complexity": "medium"
},
"SFT": {
"gpu_memory_gb": model_params * 2, # Just policy
"training_hours_a100": model_params * 0.02 * dataset_size / 10000,
"complexity": "low"
}
}
return estimates.get(approach, estimates["SFT"])
Training Pipeline Template
class ReasoningModelTrainingPipeline:
"""
Complete pipeline for training a reasoning model.
"""
def __init__(
self,
base_model: str,
approach: str = "GRPO",
output_dir: str = "./trained_model"
):
self.base_model = base_model
self.approach = approach
self.output_dir = output_dir
def run(
self,
training_data: dict,
validation_data: dict = None
):
"""
Run complete training pipeline.
training_data format depends on approach:
- PPO/GRPO: {"prompts": [...], "ground_truth": [...]} for verifiable
- DPO: {"prompt": [...], "chosen": [...], "rejected": [...]}
- SFT: {"prompt": [...], "response": [...]}
"""
# 1. Load and prepare model
print("Loading model...")
model = self.load_model()
# 2. (Optional) Cold start SFT
if self.approach in ["GRPO", "PPO"]:
print("Cold start SFT...")
model = self.cold_start_sft(model, training_data)
# 3. Main training
print(f"Main training with {self.approach}...")
model = self.main_training(model, training_data)
# 4. Evaluation
if validation_data:
print("Evaluating...")
metrics = self.evaluate(model, validation_data)
print(f"Validation metrics: {metrics}")
# 5. Save
print(f"Saving to {self.output_dir}...")
model.save_pretrained(self.output_dir)
return model
def load_model(self):
"""Load base model with appropriate settings."""
model = AutoModelForCausalLM.from_pretrained(
self.base_model,
torch_dtype=torch.bfloat16,
device_map="auto"
)
return model
def cold_start_sft(self, model, data):
"""Light SFT before RL."""
# Use ~1000-5000 high-quality examples
sft_examples = data.get("cold_start", data["prompts"][:1000])
trainer = SFTTrainer(
model=model,
train_dataset=sft_examples,
max_steps=500
)
trainer.train()
return model
def main_training(self, model, data):
"""Main training based on approach."""
if self.approach == "GRPO":
return self.train_grpo(model, data)
elif self.approach == "PPO":
return self.train_ppo(model, data)
elif self.approach == "DPO":
return self.train_dpo(model, data)
elif self.approach == "RLVR":
return self.train_rlvr(model, data)
else:
raise ValueError(f"Unknown approach: {self.approach}")
def train_grpo(self, model, data):
"""GRPO training."""
# Create reward function
if "ground_truth" in data:
reward_fn = RuleBasedRewardSystem()
else:
reward_fn = LearnedRewardModel.from_pretrained("...")
trainer = GRPOTrainer(
policy_model=model,
reference_model=copy.deepcopy(model),  # frozen copy of the starting policy
reward_model=reward_fn,
tokenizer=self.tokenizer,
config=GRPOConfig()
)
trainer.train(data["prompts"], num_iterations=10000)
return model
def evaluate(self, model, validation_data):
"""Evaluate trained model."""
correct = 0
total = 0
for problem, answer in zip(validation_data["prompts"], validation_data["ground_truth"]):
response = model.generate(problem, max_new_tokens=1024)
predicted = extract_answer(response)
if predicted == answer:
correct += 1
total += 1
return {
"accuracy": correct / total,
"total": total
}
Conclusion
Training reasoning models involves three key innovations:
- Reward Functions: Rule-based verification enables scalable training without learned reward models
- Efficient Algorithms: GRPO eliminates the value model, halving memory requirements
- Emergent Capabilities: Reasoning behaviors emerge naturally from optimizing for answer correctness
Key takeaways:
- Start simple: Rule-based rewards (RLVR) are often sufficient for verifiable domains
- GRPO > PPO for reasoning: Same quality, half the memory
- Distillation works: 7B models can achieve strong reasoning through distillation
- Cold start matters: A few thousand examples before RL improves stability
- DPO for preferences: When you have preference data but no verifiable rewards
The field is evolving rapidly—2025 will likely bring further innovations in training efficiency and capability transfer.