Building AI Agents to Win Math Olympiads: From Training to Inference Optimization
A comprehensive guide to building AI systems that achieve gold-medal performance on IMO and other advanced math competitions—covering training techniques (GRPO, RLVR), inference-time optimization, and agentic architectures with tool use.
The Gold Medal Frontier
In July 2025, an advanced version of Gemini with Deep Think officially achieved gold-medal standard at the International Mathematical Olympiad, solving 5 of 6 problems perfectly within the 4.5-hour time limit. DeepSeekMath-V2 followed with gold-level scores on both IMO 2025 and the China Mathematical Olympiad. AI has crossed a threshold once considered decades away.
This post provides a comprehensive technical guide to building AI systems capable of competition-level mathematical reasoning. We'll cover every layer of the stack:
- Theoretical foundations: Why these techniques work for math
- Training approaches: GRPO, RLVR, curriculum learning, and data strategies
- Process Reward Models: Step-level verification for mathematical reasoning
- Inference-time optimization: Maximizing performance without training
- MCTS and tree search: Structured exploration for complex proofs
- Agentic architectures: Tool use, code execution, and formal verification
- Real problem walkthroughs: Step-by-step IMO problem solutions
- Production deployment: Cost, latency, and scaling considerations
Understanding the Mathematical Reasoning Challenge
The Landscape of Math Competitions
Before diving into techniques, let's understand what we're optimizing for:
| Competition | Format | Difficulty | Time per Problem | Proof Required |
|---|---|---|---|---|
| AMC 10/12 | 25 multiple choice | High school | 3 min avg | No |
| AIME | 15 integer answers (0-999) | Advanced high school | 12 min avg | No |
| USAMO/USAJMO | 6 proof problems | National olympiad | 90 min avg | Yes |
| IMO | 6 proof problems | International olympiad | 90 min avg | Yes |
| Putnam | 12 proof problems | Undergraduate | 30 min avg | Yes |
The jump from AIME (answer-only) to IMO (full proofs) is crucial. AIME problems can be brute-forced with compute—generate many answers, verify numerically. IMO problems require constructing valid mathematical proofs that a human jury would accept.
Current State of the Art (2025)
The landscape has evolved dramatically. Here's the current state as of late 2025:
| System | IMO Performance | AIME 2025 (unless noted) | Training | Open Source |
|---|---|---|---|---|
| Gemini Deep Think | 35/42 (Gold) | 98%+ | Parallel thinking + Novel RL | No |
| DeepSeekMath-V2 | 35/42 (Gold) | 96% | GRPO + Self-verification | Weights only |
| o4-mini (with Python) | Not tested | 99.5% | Unknown | No |
| o3 (with Python) | Not tested | 98.4% | Unknown | No |
| rStar-Math (7B) | N/A | 90% MATH | Self-play MCTS | Yes |
| rStar2-Agent (14B) | N/A | 80.6% | Agentic RL | Yes |
| DeepSeek-Prover-V2 | N/A | 88.9% MiniF2F | RL + Lean 4 | Yes |
| AlphaGeometry2 | 84% geometry | N/A | Neuro-symbolic | Code only |
| QwQ-32B | N/A | 50% | Pure RL | Yes |
| NuminaMath-7B (TIR) | N/A | 58% (AIMO) | SFT + Tools | Yes |
Key observations:
- Tool integration matters: o4-mini jumps from 92.7% to 99.5% with Python access
- Small models can compete: rStar-Math 7B matches o1-preview on MATH
- Self-verification is crucial: DeepSeekMath-V2's generator-verifier loop achieves gold
- Formal verification advancing: DeepSeek-Prover-V2 solves 49 Putnam problems in Lean
What Makes Math Olympiads Fundamentally Hard
Competition mathematics is qualitatively different from benchmarks like GSM8K or MATH:
GSM8K (Grade School Math):
"Sarah has 5 apples. She gives 2 to Tom. How many does she have left?"
→ Pattern: Extract numbers, identify operation, compute
→ Single reasoning step
→ 100% verifiable answer
MATH Dataset (High School Competition):
"Find all positive integers n such that n² + 2n + 2 divides n³ + 3n² + 3n + 2"
→ Pattern: Factor, find divisibility conditions
→ 2-5 reasoning steps
→ Answer verifiable (finite solution set)
IMO Problem (International Olympiad):
"Let n ≥ 3 be an integer, and consider a circle with n + 1 equally spaced
points marked on it. Label these points with the numbers 0, 1, ..., n
so that each label is used exactly once; two labelings are considered
the same if one can be obtained from the other by a rotation of the circle.
A labeling is called beautiful if, for any four labels a < b < c < d
with a + d = b + c, the chord joining the points labeled a and d does
not intersect the chord joining the points labeled b and c.
Let M be the number of beautiful labelings. Prove that M is odd."
→ Pattern: None obvious—requires creative insight
→ 10-50+ reasoning steps
→ Must construct valid proof, not just find answer
→ Multiple valid proof approaches exist
→ Partial credit for incomplete but insightful attempts
The key challenges:
- Novelty: Each IMO features entirely new problems—pattern matching fails
- Proof construction: Must justify every step, not just produce answer
- Creative insight: Often requires discovering non-obvious lemmas
- Multi-domain: Combines algebra, geometry, number theory, combinatorics
- Rigor: Proofs must handle all edge cases and be logically sound
Theoretical Foundations: Why These Techniques Work
Before implementing, let's understand the theory behind math reasoning in LLMs.
The Reasoning-as-Search Framework
Mathematical reasoning can be formalized as search over a proof tree:
┌─────────────────────────────────────────────────────────────────────────┐
│ PROOF TREE STRUCTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────┐ │
│ │ Problem Statement │ │
│ └──────────┬──────────┘ │
│ │ │
│ ┌──────────────────────┼──────────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Approach 1 │ │ Approach 2 │ │ Approach 3 │ │
│ │ (Algebra) │ │ (Geometry) │ │(Num Theory) │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ ┌────┴────┐ ┌────┴────┐ ┌────┴────┐ │
│ │ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ ▼ │
│ Step 1.1 Step 1.2 ... ... ... ... │
│ │ │
│ ┌──┴──┐ │
│ │ │ │
│ ▼ ▼ │
│ 1.1.1 1.1.2 │
│ │ │
│ ▼ │
│ ┌──────────┐ │
│ │ SOLUTION │ ← One successful path through the tree │
│ └──────────┘ │
│ │
│ Each node = reasoning step; each path = potential proof │
│ Search algorithms explore this tree to find valid proofs │
│ │
└─────────────────────────────────────────────────────────────────────────┘
This perspective reveals why different techniques work:
| Technique | Search Mechanism | Strength |
|---|---|---|
| Chain-of-Thought | Depth-first, single path | Fast, simple |
| Self-Consistency | Multiple DFS, vote on leaves | Covers different paths |
| Tree-of-Thought | BFS/DFS with explicit branching | Systematic exploration |
| MCTS | UCB-guided tree search | Balances exploration/exploitation |
| Verification-Refinement | Iterative deepening with backtracking | Error recovery |
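Self-consistency is the cheapest of these to bolt onto any model, and it is not implemented elsewhere in this post, so here is a minimal sketch. It assumes a model.generate(prompt, temperature, max_new_tokens) interface like the one used in the later code samples, and that final answers are reported in \boxed{}; both are illustrative assumptions rather than any specific system's API.
import re
from collections import Counter

def extract_boxed(text: str) -> str | None:
    """Return the last \\boxed{...} expression in a response, if any."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None

def self_consistency_solve(model, problem: str, n_samples: int = 16) -> str | None:
    """Sample several independent chains of thought and majority-vote on the final answer."""
    answers = []
    for _ in range(n_samples):
        response = model.generate(
            f"Problem: {problem}\nThink step by step, then give the final answer in \\boxed{{}}.",
            temperature=0.8,  # non-zero temperature so the sampled paths diverge
            max_new_tokens=2048,
        )
        answer = extract_boxed(response)
        if answer is not None:
            answers.append(answer)
    # "Multiple DFS, vote on leaves": the most common final answer wins
    return Counter(answers).most_common(1)[0][0] if answers else None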
Why Reinforcement Learning Works for Math
The key insight from DeepSeek R1: math problems have verifiable rewards.
In standard RLHF, we need humans to judge response quality—expensive and noisy. For math:
def math_reward(response: str, ground_truth: str) -> float:
"""
Perfect reward signal—no learned reward model needed.
This is why RLVR (RL with Verifiable Rewards) works so well
for mathematics.
"""
predicted_answer = extract_answer(response)
return 1.0 if verify_equal(predicted_answer, ground_truth) else 0.0
This creates a perfect training signal:
- No reward hacking: Can't game the reward by being confident but wrong
- Dense signal: Every problem provides clear feedback
- Scalable: Can train on millions of problems automatically
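The extract_answer and verify_equal helpers above are left abstract. One minimal way to implement them, assuming answers appear in \boxed{} and can be parsed by SymPy (an illustrative sketch, not any particular system's exact verifier):
import re
import sympy

def extract_answer(response: str) -> str | None:
    """Take the last \\boxed{...} expression in the response as the final answer."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", response)
    return matches[-1].strip() if matches else None

def verify_equal(predicted: str | None, ground_truth: str) -> bool:
    """Symbolic equality check so that '1/2', '0.5', and '2/4' all count as the same answer."""
    if predicted is None:
        return False
    try:
        diff = sympy.simplify(sympy.sympify(predicted) - sympy.sympify(ground_truth))
        return diff == 0
    except (sympy.SympifyError, TypeError):
        # Fall back to normalized string comparison for non-symbolic answers
        return predicted.strip().lower() == ground_truth.strip().lower()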
The Emergence of Mathematical Reasoning
DeepSeek R1-Zero demonstrated that pure RL without SFT can develop sophisticated reasoning behaviors:
Training progression (from R1 paper):
Early training:
"2 + 3 = 5" (direct answer)
Mid training:
"Let me calculate 2 + 3.
2 + 3 = 5" (minimal reasoning)
Late training:
"<think>
I need to find 2 + 3.
Starting with 2 and adding 3.
2 + 1 = 3
3 + 1 = 4
4 + 1 = 5
So 2 + 3 = 5.
Let me verify: 5 - 3 = 2. Correct.
</think>
The answer is 5." (emergent self-verification)
Key emergent behaviors:
- Extended reasoning: Longer chains of thought for harder problems
- Self-verification: Checking answers before submitting
- Alternative approaches: Trying different methods when stuck
- Backtracking: Recognizing and correcting errors
Part 1: Training Approaches for Math Reasoning
Training a model for mathematical reasoning is fundamentally different from standard language model fine-tuning. You need a multi-phase approach that gradually builds up reasoning capabilities: first establish the format, then optimize for correctness through reinforcement learning, and finally refine through iterative self-improvement. This section walks through each phase with complete implementations.
The Complete Training Pipeline
Building a math reasoning model requires four distinct phases, each building on the previous. Data collection provides the raw material. Cold-start SFT establishes the reasoning format (the <think>...</think> structure). Reinforcement learning optimizes for actual correctness using verifiable rewards. Finally, rejection sampling creates high-quality training data from the model's own successful solutions. Skip any phase and you'll see degraded performance—they're all necessary.
┌─────────────────────────────────────────────────────────────────────────┐
│ TRAINING PIPELINE OVERVIEW │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Phase 1: Data Collection & Preprocessing │
│ ┌───────────────────────────────────────────────────────────────────┐ │
│ │ • Competition problems (AMC, AIME, IMO, etc.) │ │
│ │ • Textbook exercises with solutions │ │
│ │ • Synthetic problem generation │ │
│ │ • Difficulty labeling and domain classification │ │
│ └───────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Phase 2: Cold-Start SFT (Optional but Recommended) │
│ ┌───────────────────────────────────────────────────────────────────┐ │
│ │ • Train on high-quality reasoning traces │ │
│ │ • Establish <think>...</think> format │ │
│ │ • Bootstrap from stronger model outputs │ │
│ └───────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Phase 3: Reinforcement Learning (GRPO/PPO) │
│ ┌───────────────────────────────────────────────────────────────────┐ │
│ │ • Rule-based rewards for correctness │ │
│ │ • Format rewards for reasoning structure │ │
│ │ • KL penalty to prevent mode collapse │ │
│ │ • Curriculum: easy → hard problems │ │
│ └───────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Phase 4: Rejection Sampling & Iterative Training │
│ ┌───────────────────────────────────────────────────────────────────┐ │
│ │ • Generate many solutions per problem │ │
│ │ • Filter for correct + high-quality reasoning │ │
│ │ • Fine-tune on filtered data │ │
│ │ • Repeat with updated model │ │
│ └───────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Data Collection and Preparation
The quality of training data fundamentally limits model capability.
Why data diversity trumps data quantity: A model trained on 100,000 similar algebra problems will be excellent at algebra but fail on geometry. A model trained on 10,000 diverse problems spanning algebra, geometry, number theory, and combinatorics develops flexible problem-solving skills. For math reasoning, you want coverage across problem types, difficulty levels, and solution approaches. The distribution of your training data directly shapes the distribution of problems your model can solve.
The difficulty distribution tradeoff: Too many easy problems, and the model never learns hard reasoning. Too many hard problems, and the model can't learn the fundamentals that hard problems build on. The 30/40/25/5 split (easy/medium/hard/very hard) is empirically effective—it provides a solid foundation while still exposing the model to olympiad-level challenges. Think of it as curriculum learning encoded in your data distribution.
Synthetic data as a force multiplier: Competition problems are scarce (a few thousand IMO problems exist in total). Synthetic generation creates unlimited training data with verified solutions. The key is quality control: generated problems must be solvable, non-trivial, and mathematically interesting. Template-based generation with symbolic verification ensures every synthetic problem has a known-correct answer.
class MathDatasetBuilder:
"""
Build comprehensive math training dataset.
Key insight: Diversity matters more than size.
1000 diverse problems > 100,000 similar problems.
"""
def __init__(self):
self.sources = {
"competition": CompetitionProblemLoader(),
"textbook": TextbookExerciseLoader(),
"synthetic": SyntheticProblemGenerator(),
"web": WebScrapedMathLoader()
}
self.difficulty_estimator = DifficultyEstimator()
self.domain_classifier = DomainClassifier()
self.deduplicator = SemanticDeduplicator()
def build_dataset(
self,
target_size: int = 500_000,
difficulty_distribution: dict = None
) -> Dataset:
"""
Build balanced training dataset.
Target distribution:
- 30% easy (AMC 10 level)
- 40% medium (AMC 12 / AIME level)
- 25% hard (USAMO level)
- 5% very hard (IMO level)
"""
difficulty_distribution = difficulty_distribution or {
"easy": 0.30,
"medium": 0.40,
"hard": 0.25,
"very_hard": 0.05
}
raw_problems = []
# Collect from all sources
for source_name, loader in self.sources.items():
problems = loader.load_all()
print(f"Loaded {len(problems)} from {source_name}")
raw_problems.extend(problems)
# Deduplicate (semantic similarity, not exact match)
unique_problems = self.deduplicator.deduplicate(
raw_problems,
similarity_threshold=0.85
)
print(f"After deduplication: {len(unique_problems)}")
# Classify each problem
classified = []
for problem in unique_problems:
classified.append({
**problem,
"difficulty": self.difficulty_estimator.estimate(problem["text"]),
"domain": self.domain_classifier.classify(problem["text"])
})
# Balance by difficulty
balanced = self.balance_by_difficulty(
classified,
difficulty_distribution,
target_size
)
return Dataset.from_list(balanced)
def balance_by_difficulty(
self,
problems: list[dict],
distribution: dict,
target_size: int
) -> list[dict]:
"""
Balance dataset to match target difficulty distribution.
"""
by_difficulty = defaultdict(list)
for p in problems:
by_difficulty[p["difficulty"]].append(p)
balanced = []
for difficulty, target_fraction in distribution.items():
target_count = int(target_size * target_fraction)
available = by_difficulty[difficulty]
if len(available) >= target_count:
balanced.extend(random.sample(available, target_count))
else:
# Oversample if not enough
balanced.extend(available)
balanced.extend(random.choices(
available,
k=target_count - len(available)
))
random.shuffle(balanced)
return balanced
class SyntheticProblemGenerator:
"""
Generate synthetic math problems for training.
Key techniques:
1. Template-based generation with variable substitution
2. Problem composition (combine simpler problems)
3. Difficulty scaling (add constraints, increase numbers)
4. LLM-based generation with verification
"""
def __init__(self, llm=None):
self.llm = llm
self.templates = self.load_templates()
self.verifier = SymbolicVerifier()
def generate_from_template(
self,
template: str,
difficulty: str = "medium"
) -> dict:
"""
Generate problem from parameterized template.
Example template:
"Find all positive integers n < {max_n} such that {expression} is a perfect square."
"""
# Sample parameters based on difficulty
params = self.sample_parameters(template, difficulty)
# Instantiate template
problem_text = template.format(**params)
# Solve symbolically to get ground truth
solution = self.verifier.solve(problem_text)
if solution is None:
return None # Skip unsolvable
return {
"text": problem_text,
"answer": solution["answer"],
"solution": solution["steps"],
"difficulty": difficulty,
"synthetic": True
}
def generate_with_llm(
self,
domain: str,
difficulty: str,
style_examples: list[str]
) -> dict:
"""
Use LLM to generate novel problems.
Critical: Must verify generated problems are valid.
"""
prompt = f"""Generate a {difficulty} {domain} problem in the style of these examples:
{chr(10).join(style_examples[:3])}
Requirements:
1. Problem must have a well-defined answer
2. Difficulty should match {difficulty} level
3. Problem should be novel (not copy of examples)
4. Include the correct answer
Output format:
PROBLEM: <problem text>
ANSWER: <correct answer>
SOLUTION: <step-by-step solution>
"""
response = self.llm.generate(prompt, temperature=0.8)
# Parse response
problem = self.parse_generated(response)
# Verify answer is correct
if not self.verify_generated_problem(problem):
return None
return problem
def verify_generated_problem(self, problem: dict) -> bool:
"""
Verify that generated problem is valid and solvable.
"""
# Try to solve with symbolic solver
symbolic_solution = self.verifier.solve(problem["text"])
if symbolic_solution:
# Check if answers match
return self.answers_match(
symbolic_solution["answer"],
problem["answer"]
)
# Fallback: Use LLM verification (less reliable)
return self.llm_verify(problem)
class DifficultyEstimator:
"""
Estimate problem difficulty from text.
Based on:
1. Linguistic complexity
2. Mathematical concepts mentioned
3. Problem length and structure
4. Comparison to known-difficulty problems
"""
def __init__(self):
# Load embedding model for similarity
self.embedder = SentenceTransformer("all-mpnet-base-v2")
# Reference problems with known difficulties
self.reference_problems = self.load_reference_problems()
self.reference_embeddings = self.embedder.encode(
[p["text"] for p in self.reference_problems]
)
def estimate(self, problem_text: str) -> str:
"""
Estimate difficulty level.
Returns: "easy", "medium", "hard", or "very_hard"
"""
# Method 1: Similarity to reference problems
embedding = self.embedder.encode(problem_text)
similarities = cosine_similarity([embedding], self.reference_embeddings)[0]
# Find most similar reference problems
top_k_idx = np.argsort(similarities)[-5:]
similar_difficulties = [
self.reference_problems[i]["difficulty"]
for i in top_k_idx
]
# Weighted vote based on similarity
similarity_estimate = self.weighted_vote(
similar_difficulties,
similarities[top_k_idx]
)
# Method 2: Feature-based estimation
features = self.extract_features(problem_text)
feature_estimate = self.feature_based_estimate(features)
# Combine estimates
return self.combine_estimates(similarity_estimate, feature_estimate)
def extract_features(self, text: str) -> dict:
"""Extract difficulty-indicating features."""
features = {}
# Length features
features["word_count"] = len(text.split())
features["sentence_count"] = len(text.split("."))
# Mathematical concept indicators
hard_concepts = [
"prove", "show that", "if and only if", "for all",
"there exists", "infinitely many", "uniquely determined"
]
features["hard_concept_count"] = sum(
1 for c in hard_concepts if c in text.lower()
)
# Domain-specific keywords
number_theory = ["prime", "divisible", "gcd", "modulo", "congruent"]
combinatorics = ["permutation", "combination", "count", "arrange"]
geometry = ["triangle", "circle", "angle", "perpendicular"]
algebra = ["polynomial", "equation", "inequality", "sequence"]
features["number_theory_score"] = sum(1 for k in number_theory if k in text.lower())
features["combinatorics_score"] = sum(1 for k in combinatorics if k in text.lower())
features["geometry_score"] = sum(1 for k in geometry if k in text.lower())
features["algebra_score"] = sum(1 for k in algebra if k in text.lower())
# Proof requirement (typically harder)
features["requires_proof"] = any(
p in text.lower() for p in ["prove", "show that", "demonstrate"]
)
return features
GRPO: Group Relative Policy Optimization (Deep Dive)
GRPO is DeepSeek's key innovation for training reasoning models. Let's understand it deeply.
Why GRPO is ideal for math training: Math problems have a unique property: answers are verifiable. Unlike creative writing where "better" is subjective, math has ground truth. This means we don't need a learned reward model—just check if the answer is correct. GRPO exploits this by generating multiple solutions per problem, scoring each as correct/incorrect, and using the group statistics as a baseline. No critic network needed, halving memory requirements.
The group baseline intuition: In standard RL, you need a value function to answer "how good is this state?" But for math, we can approximate this by asking "how hard is this problem?" If a problem has 8 solutions and 6 are correct, the "baseline" expectation is ~75% success. A correct solution gets positive advantage (better than baseline); an incorrect one gets negative advantage. The group provides a per-problem baseline automatically.
Why this works better than DPO for reasoning: DPO requires preference pairs—but for math, we don't have "chosen" and "rejected" responses, just correct and incorrect. GRPO naturally handles this: all correct solutions are relatively good, all incorrect are relatively bad. The continuous advantage (not binary) also provides richer signal than DPO's chosen/rejected dichotomy.
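To make the group baseline concrete before diving into the full trainer, here is a tiny standalone numeric sketch of the advantage computation for the 8-sample example above (the KL term is omitted for clarity):
import torch

# One problem, a group of G = 8 sampled solutions: 1.0 = correct answer, 0.0 = incorrect
rewards = torch.tensor([1., 1., 1., 0., 1., 1., 0., 1.])

baseline = rewards.mean()         # 0.75 — the per-problem difficulty baseline
advantages = rewards - baseline   # correct solutions: +0.25, incorrect: -0.75

print(advantages)
# tensor([ 0.2500,  0.2500,  0.2500, -0.7500,  0.2500,  0.2500, -0.7500,  0.2500])
# Standard GRPO additionally divides by the group's std; Dr. GRPO (later in this post) drops that step.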
class GRPOTrainer:
"""
GRPO: Group Relative Policy Optimization
Key insight from DeepSeek R1 paper:
"GRPO foregoes the critic model and estimates the baseline
from group scores instead"
Why this matters for math:
1. No need to train a value network (halves memory)
2. More stable training (no value estimation errors)
3. Natural handling of sparse rewards (correct/incorrect)
Mathematical formulation:
For each prompt q, generate G responses {o_1, ..., o_G}
Advantage for response o_i:
A_i = r(q, o_i) - mean(r(q, o_j) for j in 1..G)
This is equivalent to using group mean as baseline instead
of learned value function.
"""
def __init__(
self,
policy_model,
reference_model,
tokenizer,
group_size: int = 16,
kl_coef: float = 0.04,
clip_epsilon: float = 0.2,
learning_rate: float = 1e-6,
max_grad_norm: float = 1.0
):
self.policy = policy_model
self.reference = reference_model
self.tokenizer = tokenizer
# GRPO hyperparameters
self.group_size = group_size # G in the paper
self.kl_coef = kl_coef # β in KL penalty
self.clip_epsilon = clip_epsilon # ε in clipped objective
# Optimization
self.optimizer = torch.optim.AdamW(
self.policy.parameters(),
lr=learning_rate,
weight_decay=0.01
)
self.max_grad_norm = max_grad_norm
# Reward function
self.reward_fn = self.create_math_reward_function()
def create_math_reward_function(self) -> callable:
"""
Create the rule-based reward function for math.
From DeepSeek R1 paper:
"For reasoning tasks, we adopt a rule-based reward system
that verifies whether the final answer is correct"
"""
def reward(
prompt: str,
response: str,
ground_truth: str
) -> float:
total_reward = 0.0
# 1. Correctness reward (primary signal)
predicted = self.extract_answer(response)
is_correct = self.verify_answer(predicted, ground_truth)
total_reward += 1.0 if is_correct else 0.0
# 2. Format reward (encourage thinking structure)
has_thinking = "<think>" in response and "</think>" in response
total_reward += 0.1 if has_thinking else 0.0
# 3. Length penalty (discourage padding)
# Only penalize if thinking section is too short
if has_thinking:
think_content = self.extract_thinking(response)
if len(think_content.split()) < 50:
total_reward -= 0.05
# 4. Language consistency penalty
# DeepSeek found models mix languages without this
has_mixed = self.detect_language_mixing(response)
total_reward -= 0.1 if has_mixed else 0.0
return total_reward
return reward
def train_step(
self,
problem: str,
ground_truth: str
) -> dict:
"""
Single GRPO training step.
Algorithm:
1. Generate G responses for the problem
2. Compute rewards for each response
3. Calculate group-relative advantages
4. Update policy with clipped objective
"""
# Phase 1: Generate group of responses
responses = []
policy_logprobs = []
for _ in range(self.group_size):
# Generate with policy model
response, logprob = self.generate_with_logprob(
self.policy,
problem,
max_new_tokens=4096,
temperature=1.0 # High temp for diversity
)
responses.append(response)
policy_logprobs.append(logprob)
# Phase 2: Compute rewards
rewards = torch.tensor([
self.reward_fn(problem, r, ground_truth)
for r in responses
])
# Phase 3: Compute reference model logprobs (for KL)
with torch.no_grad():
reference_logprobs = torch.stack([
self.compute_logprob(self.reference, problem, r)
for r in responses
])
policy_logprobs = torch.stack(policy_logprobs)
# Phase 4: Compute GRPO advantages
advantages = self.compute_grpo_advantages(
rewards,
policy_logprobs,
reference_logprobs
)
# Phase 5: PPO-style update with clipped objective
loss = self.compute_clipped_loss(
problem,
responses,
policy_logprobs.detach(),
advantages
)
# Phase 6: Gradient update
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.policy.parameters(),
self.max_grad_norm
)
self.optimizer.step()
return {
"loss": loss.item(),
"mean_reward": rewards.mean().item(),
"max_reward": rewards.max().item(),
"std_reward": rewards.std().item(),
"correct_count": (rewards > 0.9).sum().item(),
"mean_advantage": advantages.mean().item()
}
def compute_grpo_advantages(
self,
rewards: torch.Tensor,
policy_logprobs: torch.Tensor,
reference_logprobs: torch.Tensor
) -> torch.Tensor:
"""
Compute group-relative advantages.
Formula:
A_i = (r_i - β * KL_i) - mean(r_j - β * KL_j)
where KL_i = log π(o_i|q) - log π_ref(o_i|q)
"""
# Per-response KL divergence
kl = policy_logprobs - reference_logprobs
# KL-adjusted rewards
adjusted_rewards = rewards - self.kl_coef * kl
# Group baseline (GRPO's key innovation)
baseline = adjusted_rewards.mean()
# Advantages
advantages = adjusted_rewards - baseline
# Normalize for training stability
if advantages.std() > 1e-8:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
return advantages
def compute_clipped_loss(
self,
problem: str,
responses: list[str],
old_logprobs: torch.Tensor,
advantages: torch.Tensor
) -> torch.Tensor:
"""
Compute PPO clipped objective.
L = -min(r_t * A_t, clip(r_t, 1-ε, 1+ε) * A_t)
where r_t = π(a|s) / π_old(a|s)
"""
# Recompute logprobs with gradient
new_logprobs = torch.stack([
self.compute_logprob_with_grad(self.policy, problem, r)
for r in responses
])
# Probability ratio
ratio = torch.exp(new_logprobs - old_logprobs)
# Clipped objective
surr1 = ratio * advantages
surr2 = torch.clamp(
ratio,
1 - self.clip_epsilon,
1 + self.clip_epsilon
) * advantages
# Take minimum (pessimistic bound)
loss = -torch.min(surr1, surr2).mean()
return loss
def generate_with_logprob(
self,
model,
prompt: str,
max_new_tokens: int,
temperature: float
) -> tuple[str, torch.Tensor]:
"""
Generate response and return log probability.
"""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
# Generate token by token to accumulate logprobs
generated_ids = []
total_logprob = 0.0
for _ in range(max_new_tokens):
with torch.no_grad():
outputs = model(
torch.cat([input_ids] + generated_ids, dim=1)
if generated_ids else input_ids
)
logits = outputs.logits[:, -1, :] / temperature
probs = F.softmax(logits, dim=-1)
# Sample next token
next_token = torch.multinomial(probs, num_samples=1)
# Accumulate log probability
total_logprob += torch.log(probs[0, next_token[0, 0]])
generated_ids.append(next_token)
# Check for EOS
if next_token[0, 0] == self.tokenizer.eos_token_id:
break
response = self.tokenizer.decode(
torch.cat(generated_ids, dim=1)[0],
skip_special_tokens=True
)
return response, total_logprob
Dr. GRPO: Fixing GRPO's Hidden Biases
Standard GRPO has two critical biases discovered by researchers at Sea AI Lab that cause problems in practice:
┌─────────────────────────────────────────────────────────────────────────┐
│ GRPO BIAS PROBLEMS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ PROBLEM 1: RESPONSE LENGTH BIAS │
│ ───────────────────────────────── │
│ │
│ Standard GRPO divides advantage by response length: │
│ │
│ loss = Σ(advantage_i / len(response_i)) │
│ │
│ This creates EXPLICIT length bias: │
│ • Longer responses get MORE total gradient magnitude │
│ • Model learns to be verbose even when wrong │
│ • Incorrect responses get progressively longer over training │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PROBLEM 2: DIFFICULTY LEVEL BIAS │
│ ───────────────────────────────── │
│ │
│ GRPO normalizes by standard deviation of rewards per question: │
│ │
│ advantage_i = (r_i - mean(r)) / std(r) │
│ │
│ This causes problems: │
│ • Easy questions (all correct, low std) → overweighted │
│ • Hard questions (all wrong, low std) → overweighted │
│ • Medium questions (mixed, high std) → underweighted │
│ │
│ The model spends too much gradient on trivial/impossible problems! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ OBSERVED FAILURE MODE: │
│ ────────────────────── │
│ │
│ Training step 0: "The answer is 5" (wrong, 10 tokens) │
│ Training step 1000: "Let me think... 5" (wrong, 50 tokens) │
│ Training step 5000: "I'll analyze..." (wrong, and still getting longer) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Dr. GRPO (GRPO Done Right) fixes both biases with simple modifications:
class DrGRPOTrainer(GRPOTrainer):
"""
Dr. GRPO: GRPO Done Right
Fixes two biases in standard GRPO:
1. Response length bias → use global constant normalization
2. Difficulty level bias → remove per-question std normalization
Paper: "Understanding R1-Zero-Like Training: A Critical Perspective"
Available in TRL library: loss_type="dr_grpo"
Results:
- AIME 2024: 43.3% (vs SimpleRL-Zero 36.0%)
- Stable output lengths (no inflation)
- Better token efficiency
"""
def __init__(
self,
policy_model,
reference_model,
tokenizer,
group_size: int = 16,
kl_coef: float = 0.04,
clip_epsilon: float = 0.2,
max_completion_length: int = 4096, # Global constant for normalization
**kwargs
):
super().__init__(
policy_model, reference_model, tokenizer,
group_size, kl_coef, clip_epsilon, **kwargs
)
self.max_completion_length = max_completion_length
def compute_dr_grpo_loss(
self,
policy_logprobs: torch.Tensor, # Shape: (batch, seq_len)
old_logprobs: torch.Tensor,
advantages: torch.Tensor, # Shape: (batch,)
response_lengths: torch.Tensor # Shape: (batch,)
) -> torch.Tensor:
"""
Compute Dr. GRPO loss with bias corrections.
Key changes from standard GRPO:
1. Normalize by GLOBAL constant (max_completion_length), not response length
2. Remove per-question std normalization from advantages
"""
# Probability ratio (same as standard GRPO)
ratio = torch.exp(policy_logprobs - old_logprobs)
# Clipped objective (same as standard GRPO)
surr1 = ratio * advantages.unsqueeze(-1)
surr2 = torch.clamp(
ratio,
1 - self.clip_epsilon,
1 + self.clip_epsilon
) * advantages.unsqueeze(-1)
# Per-token loss
token_loss = -torch.min(surr1, surr2)
# KEY CHANGE: Normalize by GLOBAL CONSTANT, not response length
# This eliminates length bias entirely
normalized_loss = token_loss.sum(dim=-1) / self.max_completion_length
return normalized_loss.mean()
def compute_dr_grpo_advantages(
self,
rewards: torch.Tensor,
policy_logprobs: torch.Tensor,
reference_logprobs: torch.Tensor
) -> torch.Tensor:
"""
Compute advantages WITHOUT per-question normalization.
Standard GRPO: A_i = (r_i - mean) / std ← BIASED
Dr. GRPO: A_i = r_i - mean ← UNBIASED
"""
# KL penalty
kl = policy_logprobs - reference_logprobs
adjusted_rewards = rewards - self.kl_coef * kl
# Group baseline (keep this)
baseline = adjusted_rewards.mean()
# Advantages WITHOUT std normalization (key change)
advantages = adjusted_rewards - baseline
# Optional: clip extreme advantages for stability
advantages = torch.clamp(advantages, -5.0, 5.0)
return advantages
# Using Dr. GRPO with HuggingFace TRL library
from trl import GRPOConfig, GRPOTrainer
config = GRPOConfig(
output_dir="dr_grpo_math",
loss_type="dr_grpo", # Enable Dr. GRPO
max_completion_length=4096, # Global normalization constant
num_generations=16, # Group size
beta=0.04, # KL coefficient
# ... other config
)
trainer = GRPOTrainer(
    model=model,
    reward_funcs=correctness_reward,  # placeholder for a rule-based reward function like math_reward above
    args=config,
    train_dataset=math_dataset,
    processing_class=tokenizer,       # recent TRL versions take processing_class; the reference model is handled internally
)
trainer.train()
Empirical Results of Dr. GRPO:
| Method | AIME 2024 | MATH500 | Avg Response Tokens |
|---|---|---|---|
| SimpleRL-Zero-7B | 36.0% | 35.2% | 850 → 2400 (inflated) |
| Standard GRPO | 39.1% | 38.5% | 800 → 1900 (inflated) |
| Dr. GRPO | 43.3% | 40.9% | 800 → 850 (stable) |
The key insight: without length bias, models learn to reason efficiently rather than verbosely. Wrong answers stay short; correct answers get the reasoning they need.
Why curriculum learning accelerates math training: Imagine learning calculus before arithmetic—you'd fail. Neural networks face similar challenges. If you train on IMO problems from the start, the model rarely sees correct solutions (they're too hard) and learning signal is sparse. Starting with AMC-level problems lets the model learn basic reasoning patterns, which then transfer to harder problems. This isn't just faster—it often achieves higher final performance.
The schedule matters: Too slow a curriculum wastes compute on problems the model has mastered. Too fast and the model never solidifies fundamentals. The "linear" schedule increases difficulty steadily; "exponential" stays easy longer, then ramps up quickly. Empirically, linear works well for most cases. Monitor training metrics: if accuracy on medium problems drops when you introduce hard ones, slow down the schedule.
class CurriculumScheduler:
"""
Curriculum learning for math training.
Key insight: Start with easy problems, gradually increase difficulty.
From research: Models trained with curriculum learning converge
faster and achieve higher final performance.
"""
def __init__(
self,
dataset: Dataset,
initial_difficulty: str = "easy",
warmup_steps: int = 1000,
schedule: str = "linear"
):
self.dataset = dataset
self.current_difficulty = initial_difficulty
self.warmup_steps = warmup_steps
self.schedule = schedule
self.step = 0
# Group problems by difficulty
self.by_difficulty = {
"easy": [p for p in dataset if p["difficulty"] == "easy"],
"medium": [p for p in dataset if p["difficulty"] == "medium"],
"hard": [p for p in dataset if p["difficulty"] == "hard"],
"very_hard": [p for p in dataset if p["difficulty"] == "very_hard"]
}
def sample_problem(self) -> dict:
"""
Sample problem according to current curriculum stage.
"""
if self.schedule == "linear":
probs = self.linear_schedule()
elif self.schedule == "exponential":
probs = self.exponential_schedule()
else:
probs = self.constant_schedule()
# Sample difficulty level
difficulty = np.random.choice(
["easy", "medium", "hard", "very_hard"],
p=probs
)
# Sample problem from that difficulty
return random.choice(self.by_difficulty[difficulty])
def linear_schedule(self) -> list[float]:
"""
Linear increase in difficulty over warmup period.
Start: [0.7, 0.2, 0.08, 0.02]
End: [0.2, 0.4, 0.3, 0.1]
"""
progress = min(1.0, self.step / self.warmup_steps)
start = np.array([0.7, 0.2, 0.08, 0.02])
end = np.array([0.2, 0.4, 0.3, 0.1])
probs = start + progress * (end - start)
return probs / probs.sum()
def exponential_schedule(self) -> list[float]:
"""
Exponential difficulty increase (faster ramp-up).
"""
progress = min(1.0, self.step / self.warmup_steps)
progress = progress ** 2 # Square for faster initial progress
start = np.array([0.7, 0.2, 0.08, 0.02])
end = np.array([0.2, 0.4, 0.3, 0.1])
probs = start + progress * (end - start)
return probs / probs.sum()
def advance(self):
"""Advance curriculum step."""
self.step += 1
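A hypothetical sketch of how the scheduler plugs into the GRPO training loop (names are reused from the classes above; the per-problem keys text and answer follow the dataset builder's format and are assumptions):
scheduler = CurriculumScheduler(
    dataset=train_dataset,
    warmup_steps=5_000,
    schedule="linear",
)
trainer = GRPOTrainer(policy_model, reference_model, tokenizer, group_size=16)

for step in range(20_000):
    problem = scheduler.sample_problem()   # difficulty mix shifts as the curriculum advances
    metrics = trainer.train_step(problem["text"], problem["answer"])
    scheduler.advance()

    if step % 100 == 0:
        # If accuracy on medium problems drops as hard problems ramp up, slow the schedule
        print(step, metrics["mean_reward"], metrics["correct_count"])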
Cold-Start Supervised Fine-Tuning
Before jumping into reinforcement learning, you need to establish the reasoning format through supervised fine-tuning. This "cold start" phase is crucial: RL from a raw base model often fails because the model doesn't know how to structure its reasoning. By first training on examples with proper <think>...</think> formatting, you give RL a reasonable starting point. The data can come from stronger models (distillation) or human-written solutions. Even a few thousand high-quality examples can dramatically improve RL convergence.
class ColdStartSFT:
"""
Cold-start SFT before reinforcement learning.
Purpose:
1. Establish thinking format (<think>...</think>)
2. Provide reasonable starting policy for RL
3. Reduce RL training time
From DeepSeek R1:
"A small amount of cold-start data is constructed to
fine-tune the base model before RL"
"""
def __init__(
self,
base_model,
tokenizer,
learning_rate: float = 5e-6,
epochs: int = 3
):
self.model = base_model
self.tokenizer = tokenizer
self.lr = learning_rate
self.epochs = epochs
def prepare_cold_start_data(
self,
problems: list[dict],
strong_model=None
) -> list[dict]:
"""
Prepare high-quality reasoning traces for SFT.
Methods:
1. Distill from stronger model
2. Human-written exemplars
3. Filtered self-generated (if model is already decent)
"""
training_examples = []
for problem in problems:
if strong_model:
# Distill from stronger model
trace = self.generate_reasoning_trace(
strong_model,
problem["text"]
)
else:
# Use provided solution
trace = self.format_as_reasoning_trace(
problem["text"],
problem.get("solution", ""),
problem["answer"]
)
training_examples.append({
"prompt": f"Problem: {problem['text']}\n\n",
"completion": trace
})
return training_examples
def generate_reasoning_trace(
self,
model,
problem: str
) -> str:
"""
Generate high-quality reasoning trace from strong model.
"""
prompt = f"""Solve this math problem step by step.
Problem: {problem}
Show your complete reasoning process inside <think></think> tags,
then provide the final answer.
<think>
"""
response = model.generate(
prompt,
max_new_tokens=4096,
temperature=0.3 # Low temp for quality
)
# Verify answer is correct before using as training data
# (Filter out incorrect traces)
return response
def format_as_reasoning_trace(
self,
problem: str,
solution: str,
answer: str
) -> str:
"""
Format existing solution as reasoning trace.
"""
return f"""<think>
Let me solve this step by step.
{solution}
Let me verify my answer.
The answer is {answer}.
</think>
The answer is \\boxed{{{answer}}}"""
def train(self, training_data: list[dict]) -> dict:
"""
Fine-tune on reasoning traces.
"""
dataset = Dataset.from_list(training_data)
training_args = TrainingArguments(
output_dir="./cold_start_sft",
num_train_epochs=self.epochs,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=self.lr,
warmup_ratio=0.1,
logging_steps=10,
save_strategy="epoch"
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=dataset,
data_collator=DataCollatorForSeq2Seq(self.tokenizer)
)
trainer.train()
return trainer.state.log_history
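Putting the cold-start phase in sequence, a hedged usage sketch (seed_problems and teacher_model are placeholder names for your curated problems and a stronger distillation source):
sft = ColdStartSFT(base_model, tokenizer, learning_rate=5e-6, epochs=3)

# A few thousand problems with reasoning traces distilled from a stronger teacher
cold_start_data = sft.prepare_cold_start_data(seed_problems, strong_model=teacher_model)
sft.train(cold_start_data)

# The SFT'd model becomes the starting policy for GRPO (Phase 3);
# a frozen copy of it serves as the reference model for the KL penalty.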
Part 2: Process Reward Models (PRMs)
Process Reward Models (PRMs) represent one of the most important advances in training mathematical reasoning systems. While traditional reward models give a single score for the entire solution, PRMs score each intermediate reasoning step. This seemingly simple change has profound implications for training dynamics, error detection, and inference-time search.
The Fundamental Problem with Outcome-Only Rewards
Consider a 20-step mathematical proof where step 15 contains an error. With outcome-only rewards:
- The model receives reward = 0 for the entire solution
- There's no signal about which of the 20 steps caused the failure
- The model might incorrectly penalize the good steps (1-14) and reward-hack on easier problems
This is the credit assignment problem in RL, and it's especially severe for long-form mathematical reasoning. If you can only say "this proof is wrong" without saying "step 15 is where it went wrong," learning is extremely inefficient.
How PRMs Solve This
PRMs assign a score to each step, providing dense feedback throughout the reasoning chain. This enables:
- Precise error localization: Know exactly which step introduced an error
- Step-level beam search: Prune unpromising reasoning paths early, before wasting compute on doomed approaches
- Better training signal: Instead of one sparse reward at the end, get continuous feedback
- Debugging capability: When a model fails, inspect which steps had low scores
Two Approaches to Building PRMs
Monte Carlo scoring (no separate model needed): For each step in a partial solution, sample many completions and count how often they reach the correct answer. If step 5 usually leads to correct solutions but step 6 rarely does, step 6 probably introduced an error. This is slow but requires no additional training.
Model-based scoring (trained PRM): Train a separate model to predict step correctness directly. This requires labeled data (steps marked correct/incorrect) but is much faster at inference time. The model learns to recognize patterns like "this algebraic manipulation is valid" or "this case analysis missed a case."
The best systems use both: Monte Carlo to generate training labels, then train a fast model-based PRM for deployment.
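A sketch of that combined recipe, using the ProcessRewardModel class defined below: Monte Carlo rollouts produce step labels, which then become training data for the fast classifier behind model_based_scoring (the 0.5 threshold and the output format are illustrative assumptions).
def build_prm_training_data(prm, labeled_problems, n_rollouts=32):
    """Turn Monte Carlo step scores into (problem, step, label) examples for a trained PRM."""
    examples = []
    for item in labeled_problems:   # each item: {"problem": ..., "solution": ...}
        steps = prm.parse_steps(item["solution"])
        step_scores = prm.monte_carlo_scoring(item["problem"], steps, n_rollouts=n_rollouts)
        for s in step_scores:
            examples.append({
                "problem": item["problem"],
                "step_index": s["step_index"],
                "step": s["step"],
                "label": 1 if s["score"] > 0.5 else 0,  # majority of rollouts succeed → correct
            })
    return examples
These (problem, step, label) examples then train the classifier used by model_based_scoring, which is cheap enough to call inside beam search or MCTS.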
Why PRMs Matter for Math
Without PRM (Outcome Reward Only):
Step 1: Correct setup → No reward
Step 2: Correct algebraic manipulation → No reward
Step 3: Arithmetic error (5+3=9) → No reward
Step 4: Wrong conclusion from error → No reward
Final: Wrong answer → Reward = 0
Problem: Model gets no signal about where it went wrong
With PRM (Process Reward):
Step 1: Correct setup → Reward = 1.0
Step 2: Correct algebraic manipulation → Reward = 1.0
Step 3: Arithmetic error (5+3=9) → Reward = 0.0 ← Error located!
Step 4: Wrong conclusion from error → Reward = 0.0
Final: Wrong answer → Reward = 0.0
Benefit: Clear signal about which step is wrong
Complete PRM Implementation
Now let's look at a full implementation. The ProcessRewardModel class below provides both Monte Carlo scoring (slow but no training needed) and model-based scoring (fast but requires training data). In practice, you'd use Monte Carlo to bootstrap training labels, then train a fast model for deployment. The implementation includes step parsing, score computation, and the complete interface needed for integration with MCTS or beam search.
class ProcessRewardModel:
"""
Process Reward Model for mathematical reasoning.
Assigns rewards to each reasoning step, enabling:
1. Better credit assignment (know which step failed)
2. Step-level beam search (prune bad paths early)
3. More stable RL training (denser reward signal)
Based on OpenAI's "Let's Verify Step by Step" and
"Solving Math Word Problems with Process- and Outcome-Based Feedback"
"""
def __init__(
self,
base_model,
tokenizer,
scoring_method: str = "monte_carlo"
):
self.model = base_model
self.tokenizer = tokenizer
self.scoring_method = scoring_method
# Special tokens for step boundaries
self.step_token = "<step>"
self.end_step_token = "</step>"
def score_solution(
self,
problem: str,
solution: str
) -> list[dict]:
"""
Score each step in a solution.
Returns list of {step: str, score: float, is_correct: bool}
"""
# Parse solution into steps
steps = self.parse_steps(solution)
if self.scoring_method == "monte_carlo":
return self.monte_carlo_scoring(problem, steps)
elif self.scoring_method == "trained_model":
return self.model_based_scoring(problem, steps)
else:
return self.heuristic_scoring(problem, steps)
def monte_carlo_scoring(
self,
problem: str,
steps: list[str],
n_rollouts: int = 32
) -> list[dict]:
"""
Score steps by completing from each step and measuring success rate.
This is the approach from "Let's Verify Step by Step":
- For each step, complete the solution multiple times
- Step score = fraction of completions that reach correct answer
Advantage: Doesn't require labeled step-level data
Disadvantage: Expensive (requires many completions)
"""
results = []
for i, step in enumerate(steps):
# Build prefix: problem + steps 0..i
prefix = f"Problem: {problem}\n\nSolution:\n"
prefix += "\n".join(steps[:i+1])
# Complete from this point multiple times
correct_count = 0
for _ in range(n_rollouts):
completion = self.model.generate(
prefix,
max_new_tokens=2048,
temperature=0.8
)
# Check if completion reaches correct answer
full_solution = prefix + completion
if self.verify_solution(problem, full_solution):
correct_count += 1
score = correct_count / n_rollouts
results.append({
"step": step,
"step_index": i,
"score": score,
"is_correct": score > 0.5, # Majority of completions succeed
"n_rollouts": n_rollouts
})
return results
def model_based_scoring(
self,
problem: str,
steps: list[str]
) -> list[dict]:
"""
Score using a trained step classifier.
Requires labeled data: (problem, step, is_correct_step)
Training approaches:
1. Human annotation of step correctness
2. Automatic labeling using Monte Carlo estimates
3. Synthetic errors + correction labels
"""
results = []
for i, step in enumerate(steps):
# Build input for classifier
input_text = self.format_for_classifier(problem, steps[:i], step)
# Get model prediction
with torch.no_grad():
inputs = self.tokenizer(
input_text,
return_tensors="pt",
truncation=True
)
outputs = self.model(**inputs)
# Assume model outputs logits for [incorrect, correct]
probs = F.softmax(outputs.logits, dim=-1)
score = probs[0, 1].item() # P(correct)
results.append({
"step": step,
"step_index": i,
"score": score,
"is_correct": score > 0.5
})
return results
def heuristic_scoring(
self,
problem: str,
steps: list[str]
) -> list[dict]:
"""
Fast heuristic scoring without model inference.
Useful for real-time verification during generation.
Less accurate but much faster.
"""
results = []
for i, step in enumerate(steps):
score = 0.5 # Base score
# Check for mathematical validity
if self.has_valid_equations(step):
score += 0.1
# Check for logical connectives (indicates reasoning)
if self.has_logical_flow(step):
score += 0.1
# Check for consistency with previous steps
if i > 0 and self.is_consistent(steps[i-1], step):
score += 0.1
# Penalize obvious errors
if self.has_obvious_errors(step):
score -= 0.3
results.append({
"step": step,
"step_index": i,
"score": max(0, min(1, score)),
"is_correct": score > 0.5
})
return results
def parse_steps(self, solution: str) -> list[str]:
"""
Parse solution into individual reasoning steps.
"""
# Method 1: Explicit step markers
if self.step_token in solution:
steps = solution.split(self.step_token)
return [s.strip() for s in steps if s.strip()]
# Method 2: Newline-separated
lines = solution.split("\n")
steps = []
current_step = []
for line in lines:
line = line.strip()
if not line:
if current_step:
steps.append(" ".join(current_step))
current_step = []
else:
current_step.append(line)
if current_step:
steps.append(" ".join(current_step))
return steps
def has_valid_equations(self, step: str) -> bool:
"""Check if equations in step are syntactically valid."""
# Extract equations
equations = re.findall(r'[\w\d\+\-\*\/\(\)\^\s]+=[\w\d\+\-\*\/\(\)\^\s]+', step)
for eq in equations:
try:
left, right = eq.split("=")
# Try to parse both sides
sympy.sympify(left.strip())
sympy.sympify(right.strip())
except:
return False
return True
def has_logical_flow(self, step: str) -> bool:
"""Check for logical reasoning indicators."""
indicators = [
"therefore", "thus", "hence", "so", "because",
"since", "implies", "follows", "we have", "this means"
]
return any(ind in step.lower() for ind in indicators)
def is_consistent(self, prev_step: str, current_step: str) -> bool:
"""Check if current step is consistent with previous."""
# Extract variables/values from both
prev_values = self.extract_values(prev_step)
curr_values = self.extract_values(current_step)
# Check for contradictions
for var, val in curr_values.items():
if var in prev_values and prev_values[var] != val:
# Variable changed value unexpectedly
return False
return True
def has_obvious_errors(self, step: str) -> bool:
"""Detect obvious mathematical errors."""
# Check for common arithmetic errors
arithmetic_patterns = [
(r'(\d+)\s*\+\s*(\d+)\s*=\s*(\d+)', lambda a, b, c: int(a) + int(b) != int(c)),
(r'(\d+)\s*\-\s*(\d+)\s*=\s*(\d+)', lambda a, b, c: int(a) - int(b) != int(c)),
(r'(\d+)\s*\*\s*(\d+)\s*=\s*(\d+)', lambda a, b, c: int(a) * int(b) != int(c)),
]
for pattern, is_error in arithmetic_patterns:
matches = re.findall(pattern, step)
for match in matches:
try:
if is_error(*match):
return True
except:
pass
return False
class PRMGuidedGeneration:
"""
Use PRM to guide step-by-step generation.
At each step:
1. Generate candidate next steps
2. Score each candidate with PRM
3. Select best (or sample weighted by score)
"""
def __init__(
self,
generator_model,
prm: ProcessRewardModel,
beam_width: int = 4,
n_candidates: int = 8
):
self.generator = generator_model
self.prm = prm
self.beam_width = beam_width
self.n_candidates = n_candidates
def solve_with_prm_guidance(
self,
problem: str,
max_steps: int = 20
) -> dict:
"""
Generate solution with PRM-guided beam search.
"""
# Initialize beam with empty solution
beam = [{"steps": [], "score": 1.0}]
for step_idx in range(max_steps):
candidates = []
for beam_item in beam:
# Generate candidate next steps
next_steps = self.generate_next_steps(
problem,
beam_item["steps"],
n_candidates=self.n_candidates
)
for step in next_steps:
# Score with PRM
step_score = self.prm.score_single_step(
problem,
beam_item["steps"],
step
)
candidates.append({
"steps": beam_item["steps"] + [step],
"score": beam_item["score"] * step_score,
"last_step_score": step_score
})
# Keep top beam_width candidates
candidates.sort(key=lambda x: x["score"], reverse=True)
beam = candidates[:self.beam_width]
# Check if best candidate is complete
if self.is_complete(beam[0]["steps"]):
break
# Early stopping if all scores are low
if beam[0]["score"] < 0.01:
break
# Return best solution
best = beam[0]
return {
"solution": "\n".join(best["steps"]),
"score": best["score"],
"steps": len(best["steps"])
}
def generate_next_steps(
self,
problem: str,
current_steps: list[str],
n_candidates: int
) -> list[str]:
"""
Generate candidate next steps.
"""
prefix = f"Problem: {problem}\n\nSolution:\n"
if current_steps:
prefix += "\n".join(current_steps) + "\n"
candidates = []
for _ in range(n_candidates):
# Generate one step
response = self.generator.generate(
prefix + "Next step: ",
max_new_tokens=256,
temperature=0.8,
stop_sequences=["\n\n", "Next step:"]
)
step = response.split("\n")[0].strip()
if step:
candidates.append(step)
# Deduplicate
return list(set(candidates))
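One wrinkle: PRMGuidedGeneration calls prm.score_single_step, which the ProcessRewardModel above does not define. A minimal version, plus a hedged usage sketch (the monkey-patch is only to keep the example self-contained; in practice you would add the method to the class):
def score_single_step(self, problem, previous_steps, step):
    """Score one candidate step given the steps taken so far; returns P(correct) in [0, 1]."""
    results = self.model_based_scoring(problem, previous_steps + [step])
    return results[-1]["score"]

ProcessRewardModel.score_single_step = score_single_step  # attach for this sketch

prm = ProcessRewardModel(prm_model, prm_tokenizer, scoring_method="trained_model")
searcher = PRMGuidedGeneration(generator_model, prm, beam_width=4, n_candidates=8)
result = searcher.solve_with_prm_guidance(problem_text, max_steps=20)
print(result["score"], result["solution"])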
Part 3: Inference-Time Optimization (No Training Required)
One of the most exciting developments in mathematical reasoning is that you can dramatically improve any model's performance without any training—purely through smarter inference-time algorithms. This is called "test-time compute scaling": spending more compute at inference time to get better answers.
Why Inference-Time Optimization Matters
The basic insight: a single generation from an LLM is often suboptimal. The model might:
- Take a wrong turn early and be unable to recover
- Miss the key insight needed for a particular problem
- Make arithmetic errors that invalidate the rest of the solution
But if you generate multiple attempts and use smart selection, you can dramatically improve success rates. The question is: how do you efficiently explore the space of possible solutions?
The Landscape of Inference-Time Techniques
| Technique | Compute Cost | When to Use |
|---|---|---|
| Best-of-N | N × single generation | Quick improvement, any model |
| Beam Search | Moderate | When you have a good value function |
| MCTS | High | Hard problems, need to backtrack |
| Iterative Refinement | Moderate | When errors are detectable |
| Self-Consistency | N × single generation | Multiple valid solution paths |
Each technique represents a different tradeoff between compute cost and solution quality. For easy problems, Best-of-N with N=4 might suffice. For IMO problems, you might need MCTS with thousands of simulations.
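The simplest entry in the table, Best-of-N, can be sketched in a few lines. The model.generate interface mirrors the other code in this post; scorer is any callable that rates a full solution, for example an outcome reward model or an aggregated PRM score (both names are illustrative assumptions):
def best_of_n_solve(model, problem: str, scorer, n: int = 8) -> dict:
    """Generate N independent candidate solutions and keep the one the scorer likes best."""
    candidates = [
        model.generate(
            f"Problem: {problem}\nSolve step by step and give a final answer.",
            temperature=0.8,        # diversity across the N attempts
            max_new_tokens=4096,
        )
        for _ in range(n)
    ]
    scored = [(scorer(problem, c), c) for c in candidates]
    best_score, best_solution = max(scored, key=lambda pair: pair[0])
    return {"solution": best_solution, "score": best_score, "n_candidates": n}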
Monte Carlo Tree Search for Math
MCTS is the most sophisticated inference-time technique. Originally developed for game-playing AI (it powered AlphaGo), MCTS has been adapted for mathematical reasoning with remarkable success.
Why MCTS works for math:
Mathematical problem-solving is naturally tree-structured. Each reasoning step is a "move" that leads to a new state. Some moves lead to dead ends; others open up promising paths. MCTS systematically explores this tree, using past experience to guide search toward promising regions.
The four phases of MCTS:
1. Selection: Starting from the root (the problem statement), traverse the tree by picking the most promising child at each node. "Promising" balances exploitation (high-scoring nodes) with exploration (under-explored nodes) using the UCB formula.
2. Expansion: When you reach a node with unexplored children, add a new child by generating the next reasoning step.
3. Simulation/Evaluation: Estimate the value of the new node. This can be done by completing the solution and checking correctness, or using a PRM to score the partial solution.
4. Backpropagation: Update the value estimates of all nodes on the path from root to the new node. This propagates information about what works back up the tree.
The UCB formula:
UCB1 = value / visits + c * sqrt(log(parent_visits) / visits)
The first term favors nodes that have worked well. The second term favors nodes that haven't been explored much. The constant c controls the exploration-exploitation tradeoff.
class MathMCTS:
"""
Monte Carlo Tree Search adapted for mathematical reasoning.
Key adaptations for math:
1. States = partial solutions (sequence of reasoning steps)
2. Actions = possible next steps
3. Rewards = correctness of final answer
4. Heuristic value = PRM score or completion success rate
MCTS provides better exploration than beam search:
- Balances exploration vs exploitation (UCB)
- Can backtrack from dead ends
- Allocates more compute to promising branches
"""
def __init__(
self,
model,
prm: ProcessRewardModel = None,
c_puct: float = 1.414, # Exploration constant
n_simulations: int = 100,
max_depth: int = 20
):
self.model = model
self.prm = prm
self.c_puct = c_puct
self.n_simulations = n_simulations
self.max_depth = max_depth
def solve(self, problem: str, ground_truth: str = None) -> dict:
"""
Solve problem using MCTS.
Returns best solution found and search statistics.
"""
# Initialize root node
root = MCTSNode(
state=MCTSState(problem=problem, steps=[]),
parent=None
)
# Run simulations
for sim in range(self.n_simulations):
# Phase 1: Selection - traverse tree using UCB
node = self.select(root)
# Phase 2: Expansion - add new child nodes
if not node.is_terminal and not node.is_fully_expanded:
node = self.expand(node)
# Phase 3: Simulation - complete solution and evaluate
value = self.simulate(node, ground_truth)
# Phase 4: Backpropagation - update statistics
self.backpropagate(node, value)
# Return best solution
best_path = self.get_best_path(root)
return {
"solution": "\n".join(best_path),
"simulations": self.n_simulations,
"tree_size": self.count_nodes(root),
"confidence": root.value / root.visits if root.visits > 0 else 0
}
def select(self, node: 'MCTSNode') -> 'MCTSNode':
"""
Select leaf node using UCB1 formula.
UCB1 = value/visits + c * sqrt(log(parent_visits) / visits)
Balances exploitation (high value) with exploration (low visits).
"""
while node.children and not node.is_terminal:
if not node.is_fully_expanded:
return node
# UCB1 selection
node = max(
node.children,
key=lambda c: self.ucb_score(c, node.visits)
)
return node
def ucb_score(self, node: 'MCTSNode', parent_visits: int) -> float:
"""Compute UCB1 score for node selection."""
if node.visits == 0:
return float('inf') # Prioritize unexplored
exploitation = node.value / node.visits
exploration = self.c_puct * math.sqrt(
math.log(parent_visits) / node.visits
)
return exploitation + exploration
def expand(self, node: 'MCTSNode') -> 'MCTSNode':
"""
Expand node by generating new child states.
"""
# Generate possible next steps
next_steps = self.generate_next_steps(node.state)
for step in next_steps:
if step not in [c.state.steps[-1] for c in node.children if c.state.steps]:
# Create new child
new_state = MCTSState(
problem=node.state.problem,
steps=node.state.steps + [step]
)
child = MCTSNode(state=new_state, parent=node)
node.children.append(child)
# Mark terminal if solution is complete
if self.is_solution_complete(new_state):
child.is_terminal = True
# Return first unexplored child
for child in node.children:
if child.visits == 0:
return child
node.is_fully_expanded = True
return node.children[0] if node.children else node
def simulate(
self,
node: 'MCTSNode',
ground_truth: str = None
) -> float:
"""
Simulate to terminal state and return value.
Methods:
1. Complete with model, check answer (most accurate)
2. Use PRM score as heuristic (faster)
3. Lightweight rollout policy (fastest)
"""
state = node.state
# If already terminal, evaluate directly
if node.is_terminal:
return self.evaluate_solution(state, ground_truth)
# Option 1: Full completion
completion = self.complete_solution(state)
full_solution = "\n".join(state.steps) + "\n" + completion
if ground_truth:
# Check against ground truth
predicted = self.extract_answer(full_solution)
return 1.0 if self.answers_match(predicted, ground_truth) else 0.0
# Option 2: Use PRM score as proxy
if self.prm:
scores = self.prm.score_solution(state.problem, full_solution)
return sum(s["score"] for s in scores) / len(scores)
# Option 3: Heuristic evaluation
return self.heuristic_value(full_solution)
def backpropagate(self, node: 'MCTSNode', value: float):
"""
Backpropagate value up the tree.
"""
while node is not None:
node.visits += 1
node.value += value
node = node.parent
def generate_next_steps(
self,
state: 'MCTSState',
n_candidates: int = 5
) -> list[str]:
"""
Generate candidate next reasoning steps.
"""
prompt = f"Problem: {state.problem}\n\n"
if state.steps:
prompt += "Progress so far:\n" + "\n".join(state.steps) + "\n\n"
prompt += "Next step in the solution:\n"
candidates = []
for _ in range(n_candidates):
response = self.model.generate(
prompt,
max_new_tokens=256,
temperature=0.9, # High for diversity
stop_sequences=["\n\n"]
)
step = response.strip()
if step and step not in candidates:
candidates.append(step)
return candidates
def get_best_path(self, root: 'MCTSNode') -> list[str]:
"""
Extract best solution path from tree.
"""
path = []
node = root
while node.children:
# Select child with highest visit count (most explored = most confident)
node = max(node.children, key=lambda c: c.visits)
if node.state.steps:
path.append(node.state.steps[-1])
return path
@dataclass
class MCTSState:
"""State in MCTS = problem + partial solution."""
problem: str
steps: list[str]
class MCTSNode:
"""Node in MCTS tree."""
def __init__(self, state: MCTSState, parent: 'MCTSNode' = None):
self.state = state
self.parent = parent
self.children = []
self.visits = 0
self.value = 0.0
self.is_terminal = False
self.is_fully_expanded = False
Constrained MCTS (CMCTS) - 2025 Improvements
Standard MCTS explores the solution space uniformly, but mathematical reasoning has structure we can exploit. Not all reasoning steps are equally promising—a step that contradicts earlier work, makes an obvious error, or goes in circles is unlikely to lead anywhere. Constrained MCTS (CMCTS) uses a Process Reward Model to prune unpromising branches early, before wasting compute exploring them. This isn't just an optimization—it fundamentally changes what problems become tractable. By focusing search on high-potential paths, CMCTS achieves results that would require orders of magnitude more compute with unconstrained search.
CMCTS achieves 83.4% on MATH with a 7B model by constraining the search space intelligently:
┌─────────────────────────────────────────────────────────────────────────┐
│ CONSTRAINED MCTS (CMCTS) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ KEY INNOVATIONS: │
│ ──────────────── │
│ │
│ 1. Constrained Action Space │
│ • Limit candidate steps to high-PRM-score options │
│ • Prune obviously wrong branches early │
│ │
│ 2. Process Reward Model Integration │
│ • Use PRM scores to guide expansion │
│ • Weighted UCB with PRM prior │
│ │
│ 3. Partial Order Rules │
│ • Mathematical reasoning has ordering constraints │
│ • Can't use theorem before establishing premises │
│ • Prune violations automatically │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RESULTS: │
│ ──────── │
│ │
│ Model Standard MCTS CMCTS Improvement │
│ ───────────────────────────────────────────────────────────────────── │
│ 7B 72.1% 83.4% +11.3% │
│ 72B 78.5% 87.2% +8.7% │
│ │
│ Speculative Contrastive MCTS (SC-MCTS*): 51.9% faster │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class ConstrainedMCTS(MathMCTS):
"""
Constrained MCTS for improved math reasoning.
Key changes from standard MCTS:
1. Constrained action space (prune low-PRM steps)
2. PRM-guided expansion (prefer high-quality steps)
3. Partial order constraints (mathematical dependencies)
Results: 7B model achieves 83.4% on MATH (vs 72.1% standard MCTS)
"""
def __init__(
self,
model,
prm: ProcessRewardModel,
c_puct: float = 1.414,
n_simulations: int = 100,
prm_threshold: float = 0.3, # Minimum PRM score to expand
max_candidates: int = 5 # Max candidate actions per node
):
super().__init__(model, prm, c_puct, n_simulations)
self.prm_threshold = prm_threshold
self.max_candidates = max_candidates
def expand(self, node: 'MCTSNode') -> 'MCTSNode':
"""
Constrained expansion: only add high-PRM candidates.
"""
# Generate candidate next steps
candidates = self.generate_next_steps(node.state)
# Score with PRM
scored_candidates = []
for step in candidates:
prm_score = self.prm.score_step(
node.state.problem,
node.state.steps,
step
)
scored_candidates.append((step, prm_score))
# Filter by threshold (CONSTRAINT 1)
valid_candidates = [
(step, score) for step, score in scored_candidates
if score >= self.prm_threshold
]
# Take top candidates (CONSTRAINT 2)
valid_candidates.sort(key=lambda x: x[1], reverse=True)
valid_candidates = valid_candidates[:self.max_candidates]
# Apply partial order constraints (CONSTRAINT 3)
valid_candidates = [
(step, score) for step, score in valid_candidates
if self.satisfies_partial_order(node.state, step)
]
# Create child nodes
for step, prm_score in valid_candidates:
new_state = MCTSState(
problem=node.state.problem,
steps=node.state.steps + [step]
)
child = MCTSNode(state=new_state, parent=node)
child.prior_score = prm_score # Store PRM prior
node.children.append(child)
# Return first unexplored
for child in node.children:
if child.visits == 0:
return child
node.is_fully_expanded = True
return node.children[0] if node.children else node
def ucb_score(self, node: 'MCTSNode', parent_visits: int) -> float:
"""
UCB with PRM prior (like AlphaZero policy prior).
"""
if node.visits == 0:
# Use PRM prior for unexplored nodes
return node.prior_score * 10 # Prioritize high-PRM unexplored
exploitation = node.value / node.visits
exploration = self.c_puct * math.sqrt(
math.log(parent_visits) / node.visits
)
# Add PRM prior bonus (decays with visits)
prior_bonus = node.prior_score / (1 + node.visits)
return exploitation + exploration + prior_bonus
def satisfies_partial_order(
self,
state: MCTSState,
new_step: str
) -> bool:
"""
Check if new step satisfies mathematical partial order.
Examples of violations:
- Using variable before defining it
- Applying theorem without establishing premises
- Referencing result from future step
"""
# Extract variables used in new step
new_vars = self.extract_variables(new_step)
# Extract variables defined so far
defined_vars = set()
for step in state.steps:
defined_vars.update(self.extract_defined_variables(step))
# Check all used variables are defined
for var in new_vars:
if var not in defined_vars and not self.is_given(var, state.problem):
return False # Using undefined variable
return True
class SpeculativeContrastiveMCTS:
"""
SC-MCTS*: Speculative Contrastive MCTS.
51.9% faster than standard MCTS by:
1. Speculative expansion (predict multiple levels)
2. Contrastive pruning (compare sibling branches)
"""
def __init__(self, model, prm, n_speculative: int = 3):
self.model = model
self.prm = prm
self.n_speculative = n_speculative
def speculative_expand(self, node: MCTSNode) -> list[str]:
"""
Expand multiple levels speculatively.
Generate not just next step, but next N steps in one go.
If PRM scores are consistently high, accept all.
If PRM drops, backtrack to branching point.
"""
# Generate N-step continuations
prompt = f"""Problem: {node.state.problem}
Current solution:
{chr(10).join(node.state.steps)}
Generate the next {self.n_speculative} reasoning steps:"""
continuation = self.model.generate(prompt, max_tokens=1024)
steps = self.parse_steps(continuation)
# Score each step with PRM
scores = []
cumulative_state = node.state.steps.copy()
for step in steps[:self.n_speculative]:
score = self.prm.score_step(
node.state.problem,
cumulative_state,
step
)
scores.append(score)
cumulative_state.append(step)
# Accept steps while PRM score stays high
accepted_steps = []
for step, score in zip(steps, scores):
if score < 0.3: # PRM threshold
break
accepted_steps.append(step)
return accepted_steps # May be 0 to N steps
Best-of-N with Weighted Verification
Simple majority voting treats all solutions equally—but not all verifications are equally reliable. A solution verified by code execution is more trustworthy than one that just "looks right." Weighted Best-of-N combines multiple verification signals: self-consistency (do multiple attempts agree?), PRM scores (does each step look valid?), and tool verification (does code confirm the answer?). By weighting these signals appropriately, you can extract more value from the same N samples.
class WeightedBestOfN:
"""
Best-of-N with multiple verification signals.
Combines:
1. Self-consistency (majority voting)
2. PRM scores
3. Symbolic verification
4. Code execution verification
Weights different signals based on reliability.
"""
def __init__(
self,
model,
n_samples: int = 32,
use_prm: bool = True,
use_symbolic: bool = True,
use_code: bool = True
):
self.model = model
self.n_samples = n_samples
# Initialize verifiers
if use_prm:
self.prm = ProcessRewardModel(model, model.tokenizer)
if use_symbolic:
self.symbolic = SymbolicVerifier()
if use_code:
self.code_executor = CodeVerifier()
# Verifier weights (tuned on validation set)
self.weights = {
"consistency": 1.0,
"prm": 0.8,
"symbolic": 1.5, # High weight—very reliable when applicable
"code": 1.2
}
def solve(self, problem: str) -> dict:
"""
Solve with weighted verification ensemble.
"""
# Generate N solutions
solutions = []
for i in range(self.n_samples):
temp = 0.6 + (i % 5) * 0.1 # Vary temperature
solution = self.model.generate(
f"Problem: {problem}\n\nSolve step by step:\n",
max_new_tokens=2048,
temperature=temp
)
solutions.append({
"text": solution,
"answer": self.extract_answer(solution),
"temperature": temp
})
# Score each solution with multiple verifiers
for sol in solutions:
sol["scores"] = self.compute_all_scores(problem, sol)
sol["total_score"] = self.aggregate_scores(sol["scores"])
# Group by answer and aggregate
answer_scores = self.aggregate_by_answer(solutions)
# Select best answer
best_answer = max(answer_scores, key=answer_scores.get)
# Find best solution with that answer
best_solution = max(
[s for s in solutions if self.normalize_answer(s["answer"]) == best_answer],
key=lambda s: s["total_score"]
)
return {
"answer": best_answer,
"solution": best_solution["text"],
"confidence": answer_scores[best_answer],
"scores": best_solution["scores"],
"n_samples": len(solutions),
"answer_distribution": answer_scores
}
def compute_all_scores(
self,
problem: str,
solution: dict
) -> dict:
"""
Compute scores from all verifiers.
"""
scores = {}
# 1. PRM score
if hasattr(self, 'prm'):
prm_results = self.prm.score_solution(problem, solution["text"])
scores["prm"] = sum(r["score"] for r in prm_results) / len(prm_results)
# 2. Symbolic verification
if hasattr(self, 'symbolic'):
symbolic_result = self.symbolic.verify(problem, solution["answer"])
scores["symbolic"] = 1.0 if symbolic_result["verified"] else 0.0
# 3. Code execution verification
if hasattr(self, 'code_executor'):
code_result = self.code_executor.verify(problem, solution["answer"])
scores["code"] = 1.0 if code_result["verified"] else 0.0
return scores
def aggregate_scores(self, scores: dict) -> float:
"""Weighted aggregation of verification scores."""
total = 0.0
weight_sum = 0.0
for verifier, score in scores.items():
weight = self.weights.get(verifier, 1.0)
total += weight * score
weight_sum += weight
return total / weight_sum if weight_sum > 0 else 0.0
def aggregate_by_answer(self, solutions: list[dict]) -> dict:
"""
Aggregate scores by unique answer.
"""
answer_scores = defaultdict(float)
answer_counts = defaultdict(int)
for sol in solutions:
ans = self.normalize_answer(sol["answer"])
# Add solution score
answer_scores[ans] += sol["total_score"]
# Consistency bonus (more samples = higher confidence)
answer_counts[ans] += 1
# Add consistency component
for ans in answer_scores:
consistency = answer_counts[ans] / len(solutions)
answer_scores[ans] += self.weights["consistency"] * consistency
return dict(answer_scores)
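The class above calls a normalize_answer helper that the snippet never defines. One minimal way to write it, sketched here as a standalone function and assuming sympy is available (your answer formats may need more cases), is to canonicalize numerically equivalent strings so that "0.5", "1/2", and "2/4" all vote for the same bucket:
import sympy as sp

def normalize_answer(answer: str | None) -> str:
    """Map numerically equivalent answer strings to one canonical form."""
    if not answer:
        return ""
    text = answer.strip().rstrip(".").replace("$", "")
    try:
        value = sp.nsimplify(text)           # "0.5" -> 1/2, "2/4" -> 1/2
        return sp.srepr(sp.simplify(value))  # canonical string representation
    except (sp.SympifyError, TypeError, ValueError):
        return text.lower()                  # non-numeric answers: case-insensitive text match
Without a step like this, majority voting fragments across superficially different spellings of the same answer and under-counts the true consensus.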
Verification-Guided Iterative Refinement
The gap between 75% accuracy (good) and 95% accuracy (gold medal) comes from error recovery. A model that generates wrong answers and stops there will plateau. A model that can detect its own errors and fix them keeps improving with more compute. Iterative refinement is the key: generate a solution, verify it, identify any errors, and refine. Repeat until verification passes or you hit a compute budget. This is how DeepSeekMath-V2 and Gemini Deep Think achieved gold at IMO 2025.
class IterativeRefinementSolver:
"""
Iterative refinement with structured verification.
Based on the insight that gold-medal performance requires
both generating solutions AND finding/fixing errors.
Pipeline:
1. Generate initial solution
2. Verify step-by-step
3. Identify first error
4. Regenerate from error point
5. Repeat until verified or max iterations
"""
def __init__(
self,
model,
verifier_model=None,
max_iterations: int = 5,
max_regenerations_per_step: int = 3
):
self.generator = model
self.verifier = verifier_model or model
self.max_iterations = max_iterations
self.max_regenerations = max_regenerations_per_step
def solve(self, problem: str) -> dict:
"""
Solve with iterative verification and refinement.
"""
history = []
# Initial generation
solution = self.generate_initial(problem)
for iteration in range(self.max_iterations):
# Verify current solution
verification = self.verify_detailed(problem, solution)
history.append({
"iteration": iteration,
"solution_preview": solution[:200] + "...",
"verification": verification
})
# Check if solution is correct
if verification["is_correct"]:
return {
"solution": solution,
"answer": self.extract_answer(solution),
"iterations": iteration + 1,
"status": "verified_correct",
"history": history
}
# Find first error and regenerate from there
if verification["first_error_step"] is not None:
solution = self.regenerate_from_error(
problem,
solution,
verification["first_error_step"],
verification["error_description"]
)
else:
# No specific error found, regenerate entirely
solution = self.generate_alternative(problem, solution)
# Max iterations reached
return {
"solution": solution,
"answer": self.extract_answer(solution),
"iterations": self.max_iterations,
"status": "max_iterations_reached",
"history": history
}
def verify_detailed(self, problem: str, solution: str) -> dict:
"""
Detailed verification identifying specific errors.
"""
verification_prompt = f"""You are a rigorous mathematical proof verifier.
Carefully check each step of this solution for errors.
Problem: {problem}
Solution to verify:
{solution}
For each step:
1. Check logical validity
2. Verify calculations
3. Ensure no cases are missed
If you find an error:
- Identify the FIRST step with an error
- Explain exactly what is wrong
- Do not try to fix it, just identify it
Response format:
STEP_BY_STEP_ANALYSIS:
[Your detailed analysis of each step]
FIRST_ERROR_STEP: [step number, or "none" if no errors]
ERROR_DESCRIPTION: [what's wrong, or "none"]
OVERALL_VERDICT: [CORRECT / INCORRECT / UNCERTAIN]
CONFIDENCE: [0.0 to 1.0]
"""
response = self.verifier.generate(
verification_prompt,
max_new_tokens=2048,
temperature=0.2 # Low temp for reliable verification
)
return self.parse_verification_response(response)
def parse_verification_response(self, response: str) -> dict:
"""Parse structured verification response."""
result = {
"is_correct": False,
"first_error_step": None,
"error_description": None,
"confidence": 0.5,
"full_analysis": response
}
# Extract fields
if "FIRST_ERROR_STEP:" in response:
match = re.search(r"FIRST_ERROR_STEP:\s*(\d+|none)", response, re.IGNORECASE)
if match:
val = match.group(1).lower()
if val != "none":
result["first_error_step"] = int(val)
if "ERROR_DESCRIPTION:" in response:
match = re.search(r"ERROR_DESCRIPTION:\s*(.+?)(?=OVERALL_VERDICT|$)", response, re.DOTALL)
if match:
desc = match.group(1).strip()
if desc.lower() != "none":
result["error_description"] = desc
if "OVERALL_VERDICT:" in response:
match = re.search(r"OVERALL_VERDICT:\s*(\w+)", response)
if match:
result["is_correct"] = match.group(1).upper() == "CORRECT"
if "CONFIDENCE:" in response:
match = re.search(r"CONFIDENCE:\s*([\d.]+)", response)
if match:
result["confidence"] = float(match.group(1))
return result
def regenerate_from_error(
self,
problem: str,
current_solution: str,
error_step: int,
error_description: str
) -> str:
"""
Regenerate solution from the point of error.
Key insight: Keep correct prefix, regenerate suffix.
"""
# Parse into steps
steps = self.parse_steps(current_solution)
# Keep steps before the error (step numbers from the verifier are 1-indexed)
correct_prefix = "\n".join(steps[:error_step - 1])
# Generate new completion
regenerate_prompt = f"""Problem: {problem}
Partial solution (verified correct):
{correct_prefix}
The next step had an error: {error_description}
Continue the solution correctly from here:"""
best_completion = None
best_score = -1
for _ in range(self.max_regenerations):
completion = self.generator.generate(
regenerate_prompt,
max_new_tokens=2048,
temperature=0.7
)
# Score this completion
full_solution = correct_prefix + "\n" + completion
score = self.quick_score(problem, full_solution)
if score > best_score:
best_score = score
best_completion = completion
return correct_prefix + "\n" + best_completion
def generate_alternative(
self,
problem: str,
failed_solution: str
) -> str:
"""
Generate completely different approach.
"""
prompt = f"""Problem: {problem}
A previous attempt at this problem was incorrect:
{failed_solution[:500]}...
Please try a DIFFERENT approach to solve this problem.
Consider alternative methods, different algebraic manipulations,
or viewing the problem from another angle.
Alternative solution:"""
return self.generator.generate(
prompt,
max_new_tokens=2048,
temperature=0.9 # High temp for diversity
)
Part 4: Real IMO Problem Walkthrough
Theory and implementations are important, but nothing beats seeing these techniques work on a real problem. In this section, we trace through an actual IMO problem step-by-step, showing how MCTS explores different approaches, how PRMs guide the search, and how verification-refinement catches and fixes errors. This gives you a concrete picture of what "AI solves IMO problems" actually looks like in practice.
IMO 2024 Problem 1
This problem is from the 2024 International Mathematical Olympiad. It's a Number Theory problem that requires finding all functions satisfying a divisibility condition—a classic IMO style that tests both creativity (finding the right approach) and rigor (proving the solution is complete).
Problem: Let N = {1, 2, 3, ...} be the positive integers.
A function f: N → N is called "good" if for all positive integers
m and n, the sum f(m) + f(n) divides m + n.
Find all good functions.
def solve_imo_2024_p1():
"""
Walkthrough of solving IMO 2024 Problem 1.
"""
# Step 1: Problem Analysis
problem = """
Let N = {1, 2, 3, ...} be the positive integers.
A function f: N → N is called "good" if for all positive integers
m and n, the sum f(m) + f(n) divides m + n.
Find all good functions.
"""
# System classifies as: Number Theory, Functional Equation
# Strategy: Start with special cases, find pattern, prove general
# Step 2: Initial Generation (Temperature 0.7)
initial_solution = """
<think>
Let me analyze this functional equation systematically.
First, let's try m = n:
f(m) + f(m) = 2f(m) divides m + m = 2m
So f(m) | m for all m (since 2f(m) | 2m implies f(m) | m)
Now let's try m = 1:
f(1) + f(n) divides 1 + n = n + 1
Since f(1) | 1 (from above), we have f(1) = 1.
So 1 + f(n) divides n + 1.
This means f(n) + 1 ≤ n + 1, so f(n) ≤ n.
Combined with f(n) | n, the candidates are divisors of n that are ≤ n.
Let me check if f(n) = n works:
f(m) + f(n) = m + n divides m + n ✓
Let me check if f(n) = 1 works:
f(m) + f(n) = 2 divides m + n
This requires m + n to be even for all m, n - FALSE (e.g., m=1, n=2)
What about f(n) = n for some values and different for others?
Actually, let me be more careful. We have:
- f(n) | n for all n
- f(1) = 1
Let's try m = 2, n = 1:
f(2) + f(1) = f(2) + 1 divides 3
So f(2) + 1 ∈ {1, 3}, meaning f(2) ∈ {0, 2}
Since f: N → N, f(2) ≥ 1, so f(2) = 2.
Let's try m = 3, n = 1:
f(3) + 1 divides 4, so f(3) + 1 ∈ {1, 2, 4}
f(3) ∈ {0, 1, 3}
Since f(3) | 3 and f(3) ≥ 1, f(3) ∈ {1, 3}
Case f(3) = 1:
Check m = 2, n = 3: f(2) + f(3) = 2 + 1 = 3 divides 5? NO.
So f(3) = 3.
This suggests f(n) = n for all n.
Proof that f(n) = n:
Suppose f(k) < k for some minimal k > 1.
We have f(k) | k and f(k) < k, so f(k) ≤ k/2.
Take m = k, n = 1:
f(k) + 1 divides k + 1
Since f(k) ≤ k/2, we have f(k) + 1 ≤ k/2 + 1.
For f(k) + 1 to divide k + 1, we need f(k) + 1 ≤ (k+1)/2 + 1... hmm.
Let me try a different approach.
Take m = n = k:
2f(k) | 2k, so f(k) | k ✓
Take m = k, n = k-1 (assuming k > 1):
f(k) + f(k-1) | 2k - 1
If f(k-1) = k-1 (by minimality of k), then:
f(k) + k - 1 | 2k - 1
Since f(k) | k and f(k) < k, let f(k) = k/d for some d ≥ 2 (d | k).
k/d + k - 1 | 2k - 1
k(1/d + 1) - 1 | 2k - 1
This is getting complicated. Let me try specific cases.
If k = 4 and f(4) = 2 (not 4):
f(4) + f(3) = 2 + 3 = 5 should divide 7. But 5 ∤ 7. Contradiction.
If k = 6 and f(6) = 3:
f(6) + f(5) = 3 + 5 = 8 should divide 11. But 8 ∤ 11. Contradiction.
If k = 6 and f(6) = 2:
f(6) + f(5) = 2 + 5 = 7 should divide 11. But 7 ∤ 11. Contradiction.
These contradictions suggest f(n) = n is the only solution.
Formal proof by strong induction:
Base: f(1) = 1 ✓
Inductive step: Assume f(j) = j for all j < k. Show f(k) = k.
We know f(k) | k.
Suppose f(k) = k/d for some d > 1.
Take n = k - 1 (if k > 1):
f(k) + f(k-1) = k/d + (k-1) | 2k - 1
Since gcd(k, 2k-1) = gcd(k, -1) = 1 (as 2k - 1 ≡ -1 (mod k)),
we need (k/d + k - 1) | 2k - 1.
k/d + k - 1 = k(1/d + 1) - 1 = k(d+1)/d - 1
For this to divide 2k - 1, and noting d ≥ 2:
k(d+1)/d - 1 ≥ k(3/2) - 1 = 3k/2 - 1
But 2k - 1 < 2(3k/2 - 1) = 3k - 2 only when... actually let's check:
We need k(d+1)/d - 1 ≤ 2k - 1
k(d+1)/d ≤ 2k
(d+1)/d ≤ 2
d + 1 ≤ 2d
1 ≤ d ✓
So the divisibility condition could hold. Need more careful analysis.
Actually, for d = 2: k/2 + k - 1 = 3k/2 - 1
This divides 2k - 1.
3k/2 - 1 | 2k - 1
3k - 2 | 4k - 2 = 2(2k-1)
If 3k - 2 | 4k - 2:
3k - 2 | 4k - 2 - (4/3)(3k-2) = 4k - 2 - 4k + 8/3 = -2 + 8/3 = 2/3
This is not an integer, so 3k - 2 ∤ 4k - 2 for k > 1.
Wait, that's not quite right. Let me redo:
If 3k - 2 | 4k - 2:
4k - 2 = q(3k - 2) for some integer q.
4k - 2 = 3qk - 2q
4k - 3qk = 2 - 2q
k(4 - 3q) = 2(1 - q)
If q = 1: k(1) = 0, impossible.
If q = 0: 4k = 2, k = 1/2, impossible.
If q = 2: k(-2) = -2, k = 1. But we assumed k > 1.
If q ≥ 3 or q < 0: LHS and RHS have different signs.
So for k > 1, 3k - 2 ∤ 4k - 2, meaning f(k) ≠ k/2.
By similar analysis for other divisors, f(k) = k.
</think>
The only good function is f(n) = n for all positive integers n.
Proof:
1. Setting m = n gives 2f(n) | 2n, so f(n) | n.
2. Setting m = 1, n = 1 gives 2f(1) | 2, so f(1) = 1.
3. By strong induction, assuming f(j) = j for j < k and f(k) = k/d
for d > 1, we derive a contradiction using m = k, n = k-1.
4. Verification: f(n) = n satisfies f(m) + f(n) = m + n | m + n. ✓
\\boxed{f(n) = n}
"""
# Step 3: Verification
verification_result = {
"step_1": {"correct": True, "note": "Valid derivation of f(n) | n"},
"step_2": {"correct": True, "note": "Correct that f(1) = 1"},
"step_3": {"correct": True, "note": "Induction structure is valid"},
"step_4": {"correct": True, "note": "Verification confirms solution"},
"overall": "CORRECT"
}
# Step 4: Symbolic verification (if applicable)
# For this problem, we verify the claimed solution works
def verify_f_equals_n():
"""Verify that f(n) = n satisfies the condition."""
# f(m) + f(n) = m + n always divides m + n
return True
# Output
return {
"problem": "IMO 2024 P1",
"answer": "f(n) = n",
"solution": initial_solution,
"verification": verification_result,
"symbolic_check": verify_f_equals_n(),
"iterations": 1, # Solved on first attempt
"confidence": 0.95
}
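As an illustrative sanity check on the claimed answer (not a replacement for the proof), the brute-force search below enumerates every candidate f on {1, ..., 6} with f(n) | n and tests the pairwise divisibility condition; only the identity survives:
from itertools import product

N = 6
# Step 1 of the proof: f(n) must divide n, so only divisors are candidates.
divisor_choices = [[d for d in range(1, n + 1) if n % d == 0] for n in range(1, N + 1)]

survivors = []
for values in product(*divisor_choices):              # values[i] is f(i + 1)
    f = dict(zip(range(1, N + 1), values))
    if all((m + n) % (f[m] + f[n]) == 0 for m in f for n in f):
        survivors.append(values)

print(survivors)  # [(1, 2, 3, 4, 5, 6)], i.e. only f(n) = n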
Trace of Error Detection and Refinement
The previous example worked on the first try. But what happens when the model makes a mistake? This trace shows the refinement loop in action: the model generates a solution with an error, the verifier catches it, and the model fixes the problem in a subsequent iteration. This is where iterative refinement earns its keep—without it, the first error would be final.
def demonstrate_refinement():
"""
Show how the system detects and fixes errors.
"""
problem = "Find all positive integers n such that n² + 2n + 2 is a perfect square."
# Iteration 1: Initial (incorrect) attempt
attempt_1 = """
<think>
We need n² + 2n + 2 = k² for some positive integer k.
Note that n² + 2n + 1 = (n+1)², so:
n² + 2n + 2 = (n+1)² + 1
We need (n+1)² + 1 = k²
k² - (n+1)² = 1
(k - n - 1)(k + n + 1) = 1
Since k, n are positive integers:
k - n - 1 = 1 and k + n + 1 = 1
From second equation: k + n + 1 = 1, so k + n = 0.
Since k, n > 0, this is impossible.
Therefore, there are no solutions.
</think>
There are no positive integers n such that n² + 2n + 2 is a perfect square.
\\boxed{\\text{No solutions}}
"""
# Verification finds error
verification_1 = {
"step_analysis": [
{"step": 1, "correct": True, "note": "Correct setup"},
{"step": 2, "correct": True, "note": "Correct factorization"},
{"step": 3, "correct": False, "note": "ERROR: (k-n-1)(k+n+1)=1 has solutions (1,1) and (-1,-1)"},
],
"first_error_step": 3,
"error_description": "The equation (k-n-1)(k+n+1)=1 can be satisfied by k-n-1=1, k+n+1=1 OR k-n-1=-1, k+n+1=-1. Only considered one case."
}
# Iteration 2: Regenerate from error
attempt_2 = """
<think>
[Previous correct work kept]
We have (k - n - 1)(k + n + 1) = 1
Case 1: k - n - 1 = 1 and k + n + 1 = 1
Adding: 2k = 2, so k = 1
Subtracting: 2n + 2 = 0, so n = -1 (not positive)
Case 2: k - n - 1 = -1 and k + n + 1 = -1
Adding: 2k = -2, so k = -1 (not positive)
Wait, I should also check if there are solutions over integers (not just factoring as 1×1).
Actually, the only factorizations of 1 in integers are 1×1 and (-1)×(-1).
But wait - let me reconsider. We have:
k² - (n+1)² = 1
This is a Pell-like equation. The difference of two squares equals 1 means:
k² = (n+1)² + 1
For consecutive squares to differ by 1, we need... actually, consecutive squares differ by 2n+1.
Hmm, k² and (n+1)² differ by 1. The only consecutive squares are 0,1 (differ by 1), 1,4 (differ by 3)...
Actually, no two perfect squares differ by exactly 1 except 0 and 1!
So k² = 1 and (n+1)² = 0, giving k = 1, n = -1 (not positive).
Or k² = 0 and (n+1)² = -1, impossible.
Therefore, there are no positive integer solutions.
Let me verify by testing small values:
n = 1: 1 + 2 + 2 = 5 (not a square)
n = 2: 4 + 4 + 2 = 10 (not a square)
n = 3: 9 + 6 + 2 = 17 (not a square)
n = 4: 16 + 8 + 2 = 26 (not a square)
n = 5: 25 + 10 + 2 = 37 (not a square)
The conclusion holds.
</think>
There are no positive integers n such that n² + 2n + 2 is a perfect square.
Proof: n² + 2n + 2 = (n+1)² + 1. For this to be a perfect square k²,
we need k² - (n+1)² = 1. The only squares differing by 1 are 0 and 1,
which gives n = -1 (not positive). ✓
\\boxed{\\text{No positive integer solutions}}
"""
# Verification passes
verification_2 = {
"step_analysis": [
{"step": 1, "correct": True},
{"step": 2, "correct": True},
{"step": 3, "correct": True, "note": "Now correctly handles all cases"},
{"step": 4, "correct": True, "note": "Good verification with examples"},
],
"overall": "CORRECT",
"confidence": 0.95
}
return {
"iterations": [
{"attempt": attempt_1, "verification": verification_1, "status": "error_found"},
{"attempt": attempt_2, "verification": verification_2, "status": "verified_correct"}
],
"final_answer": "No positive integer solutions"
}
Part 5: AlphaGeometry2 and Geometry-Specific Approaches
Geometry problems require specialized handling. AlphaGeometry2 (2025) achieves 84% solve rate on IMO geometry problems, up from 54% in the original version.
Why Geometry Is Different From Other Math Domains
IMO geometry problems present unique challenges that don't exist in algebra, number theory, or combinatorics:
The auxiliary construction problem:
Most IMO geometry proofs require introducing new elements—auxiliary points, lines, or circles that aren't in the original problem statement. These constructions are often the key insight that unlocks the entire proof. The search space is infinite: you could draw any line, any circle, through any combination of points. Finding the right construction is what separates a gold medalist from someone who can't start the problem.
For example, a typical IMO geometry problem might say "prove that points A, B, C, D are concyclic." A human solver recognizes this as needing to show a cyclic quadrilateral—perhaps by constructing the circumcircle, proving equal angles, or using power of a point. The auxiliary construction (which circumcircle? which angles?) isn't given.
The neuro-symbolic insight:
AlphaGeometry's breakthrough was recognizing that geometry requires a hybrid approach:
- Neural model: Suggests which auxiliary constructions might be useful (creative insight)
- Symbolic engine: Verifies whether those constructions lead to valid proofs (rigorous deduction)
The neural component handles the infinite search space by learning patterns from human proofs. "When proving concyclicity, try constructing the circumcircle of three points" is a pattern a neural network can learn. The symbolic engine then exhaustively explores the consequences of each construction.
Why coordinate geometry often wins:
A surprising finding: many geometry problems that seem to require clever synthetic proofs can be solved by brute-force coordinate bashing. Set up a coordinate system, assign coordinates to known points, express the conditions algebraically, and let the computer crunch the equations.
This is inelegant but reliable. AlphaGeometry2's coordinate geometry mode achieves high accuracy on problems that would require creative insight in synthetic geometry. The tradeoff is that coordinate proofs are long, unreadable by humans, and don't generalize to similar problems—but they work.
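As a tiny illustration of what coordinate bashing looks like in code, the sympy sketch below proves a deliberately simple stand-in statement (that the three medians of any triangle pass through the centroid) by assigning symbolic coordinates and simplifying a cross product:
import sympy as sp

ax, ay, bx, by, cx, cy = sp.symbols("ax ay bx by cx cy", real=True)
A, B, C = (ax, ay), (bx, by), (cx, cy)

def midpoint(p, q):
    return ((p[0] + q[0]) / 2, (p[1] + q[1]) / 2)

def collinear(p, q, r) -> bool:
    """p, q, r are collinear iff the cross product of (q - p) and (r - p) vanishes."""
    cross = (q[0] - p[0]) * (r[1] - p[1]) - (q[1] - p[1]) * (r[0] - p[0])
    return sp.simplify(cross) == 0

G = ((ax + bx + cx) / 3, (ay + by + cy) / 3)  # centroid
print(collinear(A, midpoint(B, C), G),        # True
      collinear(B, midpoint(A, C), G),        # True
      collinear(C, midpoint(A, B), G))        # True
Because the coordinates are symbolic, the check covers every triangle at once; that generality is exactly what makes the approach reliable, if unreadable, on harder statements.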
The counterintuitive vision finding:
You might expect that showing a model the geometric diagram would help. After all, humans use diagrams heavily when solving geometry problems. AlphaGeometry2's team tested multimodal approaches—feeding both the problem text and the rendered diagram to vision-language models.
The result: no significant improvement. The algebraic description contains all the information; the visual representation adds little. This suggests that human use of diagrams is more about cognitive convenience (working memory) than about extracting information that isn't in the text.
┌─────────────────────────────────────────────────────────────────────────┐
│ ALPHAGEOMETRY2 ADVANCES │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ KEY IMPROVEMENTS: │
│ ───────────────── │
│ │
│ 1. Expanded domain-specific language │
│ • More construction primitives │
│ • Richer theorem library │
│ │
│ 2. C++ symbolic engine (100x faster) │
│ • Enables deeper search │
│ • More candidate proofs per problem │
│ │
│ 3. Larger neural model (Gemini-based) │
│ • Better construction suggestions │
│ • More diverse proof strategies │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SURPRISING FINDING: │
│ ─────────────────── │
│ │
│ "Vision models and diagram understanding DON'T significantly help" │
│ │
│ The team tested multimodal approaches (reading geometry diagrams) │
│ but found algebraic/symbolic reasoning dominates. Visual input │
│ adds little beyond what can be parsed from text descriptions. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ RESULTS: │
│ ──────── │
│ │
│ IMO Geometry Problems (30 years): 54% → 84% │
│ Combined with Gemini: 35/42 points at IMO 2025 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
The GeometryAgent implementation below demonstrates the neuro-symbolic approach that powers AlphaGeometry2. The agent has two solving strategies:
Coordinate geometry (always tried first): Place points in a coordinate system, express all geometric entities as algebraic equations, and solve symbolically. This is reliable but produces inelegant proofs. For problems asking "prove X", it works but doesn't satisfy the aesthetic standards of competition mathematics.
Synthetic geometry with neural suggestions: The more sophisticated approach. The neural model proposes auxiliary constructions (draw line AB, construct circumcircle, etc.), and the symbolic engine verifies whether those constructions lead to valid proofs. The suggest_auxiliary_construction method generates candidates; the symbolic engine tests each one. When a construction enables a proof, the agent traces back to produce human-readable steps.
The key data structures are GeometricStatement (facts like "A, B, C are collinear") and GeometricTheorem (applicable theorems like "Power of a point"). The symbolic engine applies theorems to derive new statements until the goal is reached.
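Neither class is spelled out in the code that follows, so here is one hedged way they could look. The field names and the internal logic are illustrative assumptions, not AlphaGeometry2's actual schema; only the find_applications signature is chosen to match how the engine below calls it.
from dataclasses import dataclass, field

@dataclass(frozen=True)
class GeometricStatement:
    """A single geometric fact, e.g. predicate="concyclic", args=("A", "B", "C", "D")."""
    predicate: str
    args: tuple[str, ...]

@dataclass
class GeometricTheorem:
    """A rule that yields its conclusion once all of its premises are known facts."""
    name: str                                            # e.g. "Power of a point"
    premises: list[GeometricStatement] = field(default_factory=list)
    conclusion: GeometricStatement | None = None

    def find_applications(self, entities, constructions, known_facts):
        """Return applications whose premises are already in known_facts (sketch)."""
        if self.conclusion and all(p in known_facts for p in self.premises):
            return [self]                                # the engine reads .conclusion
        return []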
class GeometryAgent:
"""
Specialized agent for geometry problems.
Key insight from AlphaGeometry:
Combine neural suggestion with symbolic deduction.
Neural: Propose auxiliary constructions, lemmas
Symbolic: Verify deductions, compute coordinates
"""
def __init__(self, model, use_coordinates: bool = True):
self.model = model
self.use_coordinates = use_coordinates
self.symbolic_engine = GeometricSymbolicEngine()
def solve(self, problem: str) -> dict:
"""
Solve geometry problem using neuro-symbolic approach.
"""
# Parse geometric entities
entities = self.parse_geometry(problem)
# Strategy 1: Coordinate geometry (reliable but inelegant)
if self.use_coordinates:
coord_solution = self.solve_with_coordinates(problem, entities)
if coord_solution["success"]:
return coord_solution
# Strategy 2: Synthetic approach with neural suggestions
return self.solve_synthetic(problem, entities)
def solve_with_coordinates(
self,
problem: str,
entities: dict
) -> dict:
"""
Translate to coordinates and solve algebraically.
"""
# Step 1: Set up coordinate system
coord_system = self.establish_coordinates(entities)
# Step 2: Express all entities in coordinates
expressions = self.express_in_coordinates(entities, coord_system)
# Step 3: Translate goal to algebraic equation
goal_equation = self.translate_goal(problem, expressions)
# Step 4: Solve symbolically
solution = self.symbolic_engine.solve(goal_equation)
return {
"success": solution is not None,
"method": "coordinate_geometry",
"coordinate_system": coord_system,
"solution": solution
}
def solve_synthetic(
self,
problem: str,
entities: dict
) -> dict:
"""
Synthetic geometry with neural-guided construction.
Loop:
1. Neural model suggests construction or lemma
2. Symbolic engine verifies and derives consequences
3. Check if goal is reached
"""
constructions = []
derived_facts = set(entities.get("given_facts", []))
for iteration in range(20): # Max iterations
# Neural suggestion
suggestion = self.neural_suggest(
problem,
entities,
constructions,
derived_facts
)
if suggestion["type"] == "construction":
# Add auxiliary point/line
new_entity = self.add_construction(
suggestion["description"],
entities
)
constructions.append(new_entity)
# Derive new facts from construction
new_facts = self.symbolic_engine.derive_facts(
entities,
constructions,
derived_facts
)
derived_facts.update(new_facts)
elif suggestion["type"] == "lemma":
# Apply known theorem
result = self.apply_lemma(
suggestion["lemma_name"],
suggestion["arguments"],
entities
)
if result:
derived_facts.add(result)
elif suggestion["type"] == "conclude":
# Check if we can prove the goal
if self.goal_satisfied(problem, derived_facts):
return {
"success": True,
"method": "synthetic",
"constructions": constructions,
"proof_steps": list(derived_facts)
}
return {"success": False, "reason": "max_iterations"}
def neural_suggest(
self,
problem: str,
entities: dict,
constructions: list,
derived_facts: set
) -> dict:
"""
Get neural suggestion for next step.
"""
prompt = f"""Geometry problem: {problem}
Current constructions: {constructions}
Known facts: {list(derived_facts)[:20]}
What construction or lemma would help solve this problem?
Options:
1. Add auxiliary construction (point, line, circle)
2. Apply a theorem (similar triangles, power of point, etc.)
3. Conclude the proof
Suggest the most helpful next step:"""
response = self.model.generate(prompt, max_new_tokens=256)
return self.parse_suggestion(response)
class GeometricSymbolicEngine:
"""
Symbolic geometry engine for verification and deduction.
Uses:
1. Coordinate geometry computations
2. Automated theorem proving rules
3. Known geometry theorems
"""
def __init__(self):
self.theorems = self.load_geometry_theorems()
def derive_facts(
self,
entities: dict,
constructions: list,
existing_facts: set
) -> set:
"""
Derive new geometric facts from current configuration.
"""
new_facts = set()
# Apply all applicable theorems
for theorem in self.theorems:
applications = theorem.find_applications(
entities,
constructions,
existing_facts
)
for app in applications:
new_facts.add(app.conclusion)
return new_facts
def load_geometry_theorems(self) -> list:
"""
Load library of geometry theorems.
Examples:
- Thales' theorem
- Power of a point
- Similar triangles
- Angle bisector theorem
- etc.
"""
return [
ThalesTheorem(),
PowerOfPointTheorem(),
SimilarTrianglesTheorem(),
AngleBisectorTheorem(),
PtolemyTheorem(),
MenelausTheorem(),
CevaTheorem(),
# ... many more
]
Part 6: Multi-Agent Debate for Math
Multi-agent debate is one of the most promising techniques for improving mathematical reasoning accuracy. The core idea: instead of trusting a single model's output, have multiple "agents" with different roles critically examine each other's work.
Why Debate Works for Math
Mathematical proofs have an asymmetry: checking is easier than generating. A proof can take hours to discover, but verifying its correctness takes minutes. This asymmetry is perfect for debate—a critic agent can catch errors that the solver missed without needing to solve the problem itself.
The debate participants:
| Agent | Role | Capability |
|---|---|---|
| Solver | Generates initial solution | Strong at problem-solving |
| Critic | Finds errors, gaps, unjustified steps | Strong at verification |
| Defender | Addresses criticisms, clarifies reasoning | Strong at explanation |
| Judge | Decides if criticisms are valid | Balanced evaluation |
The Debate Protocol
Understanding how debate works in practice helps you implement it correctly. The protocol below is iterated—each round surfaces and addresses issues, progressively improving solution quality. Most problems converge within 2-3 rounds, but complex proofs may need more. The key is having clear stopping criteria: either the critic finds no issues (consensus reached) or you hit a maximum round limit.
- Solver generates an initial solution
- Critic examines the solution, looking for:
- Logical gaps (steps that don't follow)
- Calculation errors
- Missing cases
- Unjustified claims
- Defender responds to each criticism:
- Accept and fix valid criticisms
- Rebut invalid criticisms with justification
- Judge evaluates:
- Were the criticisms valid?
- Were the defenses adequate?
- Is the revised solution correct?
- Repeat until consensus or max rounds
Why Multiple Agents Beat Single-Agent Self-Correction
You might ask: why not just have one model check its own work? Research shows multi-agent debate outperforms single-agent self-correction because:
- Different "perspectives": Even with the same base model, different prompts (critic vs. defender) elicit different reasoning patterns
- Adversarial pressure: A critic is incentivized to find errors; a self-checker might be biased toward confirming its own work
- Information surfacing: The debate format forces explicit justification of each step, surfacing implicit assumptions
The implementation below shows a complete multi-agent debate system. The MathDebateSystem class orchestrates solver, critic, and judge agents through multiple rounds. Note how each agent has a specialized prompt that shapes its behavior—the critic actively looks for flaws while the defender tries to justify the reasoning. This adversarial dynamic is what makes debate effective.
class MathDebateSystem:
"""
Multi-agent debate for mathematical reasoning.
Agents:
1. Solver: Generates solutions
2. Critic: Finds errors in solutions
3. Defender: Addresses criticisms
4. Judge: Evaluates debate outcome
Research shows debate improves accuracy on hard problems.
"""
def __init__(
self,
solver_model,
critic_model=None,
judge_model=None,
max_rounds: int = 3
):
self.solver = solver_model
self.critic = critic_model or solver_model
self.judge = judge_model or solver_model
self.max_rounds = max_rounds
def solve_with_debate(self, problem: str) -> dict:
"""
Solve problem through multi-agent debate.
"""
# Initial solution
solution = self.generate_solution(problem)
debate_history = []
for round_num in range(self.max_rounds):
# Critic finds issues
criticism = self.generate_criticism(problem, solution)
debate_history.append({
"round": round_num + 1,
"solution": solution[:500] + "...",
"criticism": criticism
})
# If no valid criticism, solution is likely correct
if not criticism["has_valid_criticism"]:
break
# Defender responds
defense = self.generate_defense(
problem,
solution,
criticism
)
# Judge evaluates
judgment = self.judge_round(
problem,
solution,
criticism,
defense
)
if judgment["criticism_valid"]:
# Revise solution based on criticism
solution = self.revise_solution(
problem,
solution,
criticism,
defense
)
else:
# Criticism rejected, keep solution
break
# Final judgment
final_verdict = self.final_judgment(problem, solution)
return {
"solution": solution,
"answer": self.extract_answer(solution),
"debate_rounds": len(debate_history),
"debate_history": debate_history,
"final_verdict": final_verdict
}
def generate_criticism(
self,
problem: str,
solution: str
) -> dict:
"""
Critic agent finds potential errors.
"""
prompt = f"""You are a critical mathematics reviewer.
Your job is to find errors in mathematical proofs.
Problem: {problem}
Proposed Solution:
{solution}
Carefully analyze this solution and identify ANY errors:
1. Logical errors (non-sequiturs, unjustified steps)
2. Calculation errors
3. Missing cases
4. Incorrect theorems/facts used
5. Gaps in reasoning
If you find errors, explain them clearly.
If the solution is correct, state that explicitly.
Your critique:"""
response = self.critic.generate(
prompt,
max_new_tokens=1024,
temperature=0.3
)
has_criticism = not any(
phrase in response.lower()
for phrase in ["no errors", "solution is correct", "proof is valid"]
)
return {
"text": response,
"has_valid_criticism": has_criticism
}
def generate_defense(
self,
problem: str,
solution: str,
criticism: dict
) -> str:
"""
Defender addresses criticism.
"""
prompt = f"""You are defending a mathematical solution.
Problem: {problem}
Your Solution:
{solution}
Criticism received:
{criticism['text']}
Defend your solution. Either:
1. Explain why the criticism is invalid
2. Acknowledge the error and explain how to fix it
3. Provide additional justification for questioned steps
Your defense:"""
return self.solver.generate(
prompt,
max_new_tokens=1024,
temperature=0.4
)
def judge_round(
self,
problem: str,
solution: str,
criticism: dict,
defense: str
) -> dict:
"""
Judge evaluates debate round.
"""
prompt = f"""You are an impartial mathematical judge.
Problem: {problem}
Solution:
{solution}
Criticism:
{criticism['text']}
Defense:
{defense}
Evaluate this debate:
1. Is the criticism valid? (Does it identify a real error?)
2. Is the defense successful? (Does it address the criticism?)
3. What is your verdict?
Respond with:
CRITICISM_VALID: [yes/no]
DEFENSE_SUCCESSFUL: [yes/no]
VERDICT: [solution_correct/solution_needs_revision/unclear]
EXPLANATION: [your reasoning]"""
response = self.judge.generate(
prompt,
max_new_tokens=512,
temperature=0.2
)
# Parse response
criticism_valid = "criticism_valid: yes" in response.lower()
defense_successful = "defense_successful: yes" in response.lower()
return {
"criticism_valid": criticism_valid,
"defense_successful": defense_successful,
"full_response": response
}
def revise_solution(
self,
problem: str,
solution: str,
criticism: dict,
defense: str
) -> str:
"""
Revise solution based on valid criticism.
"""
prompt = f"""Revise the solution to address the valid criticism.
Problem: {problem}
Original Solution:
{solution}
Valid Criticism:
{criticism['text']}
Defense attempt:
{defense}
Please provide a REVISED solution that fixes the identified errors
while keeping correct parts of the original.
Revised Solution:"""
return self.solver.generate(
prompt,
max_new_tokens=2048,
temperature=0.5
)
Part 7: Production Deployment
Deploying mathematical reasoning systems in production introduces challenges beyond pure accuracy: cost control, latency requirements, reliability, and observability. A system that achieves 90% accuracy but costs $10 per query isn't viable for most applications.
The Economics of AI Math Solving
Cost breakdown for a typical math query:
| Component | Tokens | Cost (GPT-4) |
|---|---|---|
| Problem input | ~200 | $0.006 |
| Reasoning chain | ~2000 | $0.060 |
| Self-verification | ~1000 | $0.030 |
| Total per attempt | ~3200 | ~$0.10 |
| With 8 attempts | ~25600 | ~$0.80 |
At scale (1,000 queries/day), this becomes roughly $24,000/month—just for the LLM costs. Add inference infrastructure, monitoring, and engineering time, and costs can easily exceed $50,000/month.
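A few lines of arithmetic make it easy to re-run these numbers for your own prices and volumes; the blended rate, token counts, and attempt count below are the table's rough assumptions, not measurements:
PRICE_PER_TOKEN = 0.03 / 1000           # ~$0.03 per 1K tokens, blended GPT-4-class rate
TOKENS_PER_ATTEMPT = 200 + 2000 + 1000  # input + reasoning chain + self-verification
ATTEMPTS_PER_QUERY = 8
QUERIES_PER_DAY = 1000

cost_per_attempt = TOKENS_PER_ATTEMPT * PRICE_PER_TOKEN      # ~$0.10
cost_per_query = cost_per_attempt * ATTEMPTS_PER_QUERY       # ~$0.77
monthly_llm_cost = cost_per_query * QUERIES_PER_DAY * 30     # ~$23,000 (the table rounds up to ~$24,000)
print(f"${cost_per_attempt:.2f}/attempt, ${cost_per_query:.2f}/query, "
      f"${monthly_llm_cost:,.0f}/month for the LLM alone")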
Optimization Strategies
1. Tiered solving (cheap first): Start with a fast, cheap model (GPT-4o-mini, Claude Haiku). If it solves the problem confidently, stop. Only escalate to expensive models (GPT-4, Claude Opus, o1) for hard problems. This can reduce costs by 60-80% while maintaining accuracy.
2. Difficulty estimation: Before solving, estimate problem difficulty. Easy problems don't need MCTS or multiple attempts. Hard problems justify the extra compute. A simple classifier trained on problem features can route efficiently (see the sketch after this list).
3. Caching: Mathematical problems have patterns. "Solve x^2 - 4 = 0" might appear in many forms. Cache both exact matches and semantic equivalents (embedding similarity). Cache hit rates of 20-40% are common for educational platforms.
4. Early stopping: If a solution is verified correct, stop generating alternatives. If 5 attempts all fail, escalate rather than continuing with diminishing returns.
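Strategy 2 is the one the code later in this section doesn't cover, so here is a minimal sketch of a difficulty router, assuming scikit-learn is available. The surface features, the toy training examples, and the 0.5 threshold are illustrative placeholders rather than a tuned production setup:
import re
import numpy as np
from sklearn.linear_model import LogisticRegression

def difficulty_features(problem: str) -> list[float]:
    """Cheap surface features that tend to correlate with difficulty."""
    return [
        float(len(problem.split())),                                        # statement length
        float(bool(re.search(r"\bprove\b|\bshow that\b", problem, re.I))),  # proof required
        float(bool(re.search(r"\bfor all\b|\bevery\b", problem, re.I))),    # universal claims
    ]

# Toy labeled set: 0 = easy (stay on the cheap tier), 1 = hard (escalate).
toy_train = [
    ("Compute 17 * 24.", 0),
    ("Solve x^2 - 5x + 6 = 0.", 0),
    ("Find the remainder when 7^100 is divided by 13.", 0),
    ("Prove that for every integer n, n^5 - n is divisible by 30.", 1),
    ("Show that there are infinitely many primes congruent to 3 mod 4.", 1),
    ("Prove that the sum of the first n odd numbers is n^2 for all n.", 1),
]
X = np.array([difficulty_features(p) for p, _ in toy_train])
y = np.array([label for _, label in toy_train])
router = LogisticRegression().fit(X, y)

def pick_tier(problem: str) -> str:
    p_hard = router.predict_proba(np.array([difficulty_features(problem)]))[0][1]
    return "expensive" if p_hard > 0.5 else "cheap"

print(pick_tier("Prove that for every positive integer n, n^3 + 2n is divisible by 3."))
# likely "expensive" with these toy features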
Cost and Latency Optimization
The techniques above achieve high accuracy—but at what cost? Running MCTS with 1000 simulations or Best-of-64 with PRM reranking can cost $5-50 per problem. For production systems, you need to balance accuracy against cost and latency. The implementation below shows how to build a tiered system: start cheap (GPT-4o-mini), escalate if needed, cache aggressively, and parallelize where possible. A well-designed tiered system achieves 80-90% of maximum accuracy at 10-20% of maximum cost.
class ProductionMathSolver:
"""
Production-ready math solver with cost/latency optimization.
Key considerations:
1. Tiered solving (cheap first, expensive if needed)
2. Caching for common patterns
3. Parallel execution where possible
4. Early stopping when confident
"""
def __init__(
self,
cheap_model, # e.g., Claude Haiku, GPT-4o-mini
expensive_model, # e.g., Claude Opus, o1
max_budget_dollars: float = 1.0,
target_latency_seconds: float = 30.0
):
self.cheap = cheap_model
self.expensive = expensive_model
self.max_budget = max_budget_dollars
self.target_latency = target_latency_seconds
# Pricing (per 1M tokens, approximate)
self.pricing = {
"cheap_input": 0.25,
"cheap_output": 1.25,
"expensive_input": 15.0,
"expensive_output": 75.0
}
# Cache for common problem patterns
self.cache = SolutionCache()
def solve(self, problem: str) -> dict:
"""
Solve with cost-aware strategy.
"""
start_time = time.time()
total_cost = 0.0
# Check cache first
cached = self.cache.get(problem)
if cached:
return {
**cached,
"source": "cache",
"latency": time.time() - start_time,
"cost": 0.0
}
# Tier 1: Single cheap attempt
result, cost = self.try_cheap_single(problem)
total_cost += cost
if result["confidence"] > 0.9:
self.cache.set(problem, result)
return {
**result,
"source": "cheap_single",
"latency": time.time() - start_time,
"cost": total_cost
}
# Tier 2: Cheap model with verification
if total_cost < self.max_budget * 0.3:
result, cost = self.try_cheap_with_verification(problem)
total_cost += cost
if result["confidence"] > 0.85:
self.cache.set(problem, result)
return {
**result,
"source": "cheap_verified",
"latency": time.time() - start_time,
"cost": total_cost
}
# Tier 3: Expensive model
if total_cost < self.max_budget * 0.7:
result, cost = self.try_expensive(problem)
total_cost += cost
if result["confidence"] > 0.8:
self.cache.set(problem, result)
return {
**result,
"source": "expensive",
"latency": time.time() - start_time,
"cost": total_cost
}
# Tier 4: Full pipeline (most expensive)
if total_cost < self.max_budget:
result, cost = self.try_full_pipeline(problem)
total_cost += cost
self.cache.set(problem, result)
return {
**result,
"source": "full_pipeline",
"latency": time.time() - start_time,
"cost": total_cost
}
# Budget exceeded, return best so far
return {
**result,
"source": "budget_limited",
"latency": time.time() - start_time,
"cost": total_cost
}
def try_cheap_single(self, problem: str) -> tuple[dict, float]:
"""Single attempt with cheap model."""
response = self.cheap.generate(
f"Solve: {problem}\n\nSolution:",
max_new_tokens=1024
)
# Estimate cost (rough token counts: ~4 characters per token)
cost = self.estimate_cost("cheap", len(problem) // 4, len(response) // 4)
# Quick confidence estimate
confidence = self.quick_confidence(response)
return {
"solution": response,
"answer": self.extract_answer(response),
"confidence": confidence
}, cost
def try_cheap_with_verification(self, problem: str) -> tuple[dict, float]:
"""Multiple attempts + majority voting."""
solutions = []
total_cost = 0.0
for _ in range(5): # 5 samples
response = self.cheap.generate(
f"Solve step by step: {problem}\n\n",
max_new_tokens=1024,
temperature=0.7
)
solutions.append(response)
total_cost += self.estimate_cost("cheap", len(problem) // 4, len(response) // 4)
# Majority vote
answers = [self.extract_answer(s) for s in solutions]
answer_counts = Counter(answers)
best_answer, count = answer_counts.most_common(1)[0]
confidence = count / len(solutions)
# Find best solution with that answer
best_solution = next(
s for s in solutions
if self.extract_answer(s) == best_answer
)
return {
"solution": best_solution,
"answer": best_answer,
"confidence": confidence,
"agreement": f"{count}/{len(solutions)}"
}, total_cost
def estimate_cost(
self,
model_tier: str,
input_tokens: int,
output_tokens: int
) -> float:
"""Estimate API cost."""
input_price = self.pricing[f"{model_tier}_input"]
output_price = self.pricing[f"{model_tier}_output"]
return (input_tokens * input_price + output_tokens * output_price) / 1_000_000
class SolutionCache:
"""
Cache for math solutions.
Uses semantic similarity to find cached solutions
for similar problems.
"""
def __init__(self, similarity_threshold: float = 0.95):
self.threshold = similarity_threshold
self.embedder = SentenceTransformer("all-mpnet-base-v2")
self.cache = {} # embedding -> solution
self.problem_texts = {} # embedding -> original problem
def get(self, problem: str) -> dict | None:
"""Find cached solution for similar problem."""
if not self.cache:
return None
embedding = self.embedder.encode(problem)
for cached_emb, solution in self.cache.items():
similarity = cosine_similarity([embedding], [cached_emb])[0][0]
if similarity > self.threshold:
return solution
return None
def set(self, problem: str, solution: dict):
"""Cache solution."""
embedding = tuple(self.embedder.encode(problem).tolist())
self.cache[embedding] = solution
self.problem_texts[embedding] = problem
Monitoring and Observability
Production systems fail silently without proper monitoring. A math solver might gradually degrade as problem distributions shift, or a model update might introduce subtle regressions. The monitoring system below tracks accuracy on known-answer problems, detects distribution shift, and alerts on anomalies. It also logs detailed metrics for debugging: which problem types fail most, where in the reasoning chain errors occur, and how much compute each problem consumes.
class MathSolverMonitor:
"""
Monitoring for production math solver.
Tracks:
1. Accuracy on known problems
2. Latency distribution
3. Cost per problem
4. Error patterns
"""
def __init__(self, solver: ProductionMathSolver):
self.solver = solver
self.metrics = defaultdict(list)
def solve_with_monitoring(
self,
problem: str,
ground_truth: str = None
) -> dict:
"""Solve with full monitoring."""
start_time = time.time()
try:
result = self.solver.solve(problem)
# Record metrics
self.metrics["latency"].append(result["latency"])
self.metrics["cost"].append(result["cost"])
self.metrics["source"].append(result["source"])
self.metrics["confidence"].append(result["confidence"])
if ground_truth:
is_correct = self.verify_answer(result["answer"], ground_truth)
self.metrics["correct"].append(is_correct)
if not is_correct:
self.log_error(problem, result, ground_truth)
return result
except Exception as e:
self.metrics["errors"].append({
"problem": problem[:100],
"error": str(e),
"timestamp": time.time()
})
raise
def get_summary(self) -> dict:
"""Get monitoring summary."""
return {
"total_problems": len(self.metrics["latency"]),
"accuracy": (
sum(self.metrics["correct"]) / len(self.metrics["correct"])
if self.metrics["correct"] else None
),
"avg_latency": np.mean(self.metrics["latency"]),
"p95_latency": np.percentile(self.metrics["latency"], 95),
"avg_cost": np.mean(self.metrics["cost"]),
"total_cost": sum(self.metrics["cost"]),
"source_distribution": Counter(self.metrics["source"]),
"error_count": len(self.metrics["errors"])
}
def log_error(self, problem: str, result: dict, ground_truth: str):
"""Log detailed error for analysis."""
error_record = {
"timestamp": time.time(),
"problem": problem,
"predicted": result["answer"],
"ground_truth": ground_truth,
"solution": result["solution"][:1000],
"confidence": result["confidence"],
"source": result["source"]
}
# Classify error type
error_record["error_type"] = self.classify_error(
result["solution"],
ground_truth
)
self.metrics["detailed_errors"].append(error_record)
def classify_error(self, solution: str, ground_truth: str) -> str:
"""Classify type of error for analysis."""
# Common error patterns
if "arithmetic" in solution.lower() and "=" in solution:
# Check for calculation errors
return "calculation_error"
if any(w in solution.lower() for w in ["assume", "suppose", "let"]):
return "wrong_assumption"
if "therefore" in solution.lower() or "thus" in solution.lower():
return "logical_error"
return "unknown"
Part 8: rStar-Math: Self-Play Mutual Reasoning
One of the most significant breakthroughs of 2025 is rStar-Math from Microsoft Research, demonstrating that small models can achieve frontier-level math reasoning through self-play without distillation from larger models.
┌─────────────────────────────────────────────────────────────────────────┐
│ rStar-Math ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ KEY INSIGHT: │
│ ──────────── │
│ │
│ "Small LLMs can rival o1 without distillation from superior models" │
│ │
│ rStar-Math 7B achieves 90.0% on MATH (vs 58.8% baseline) │
│ rStar2-Agent 14B matches DeepSeek-R1 671B on AIME │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ CORE ARCHITECTURE: │
│ ────────────────── │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Self-Play Mutual Reasoning │ │
│ ├─────────────────────────────────────────────────────────────────┤ │
│ │ │ │
│ │ Problem │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Generator (Policy SLM) │ │ │
│ │ │ Uses MCTS to explore reasoning trajectories │ │ │
│ │ │ Human-like actions: decompose, verify, backtrack │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ │ Multiple trajectories │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Discriminator (Verifier SLM) │ │ │
│ │ │ Scores each trajectory for correctness │ │ │
│ │ │ Process Preference Model (PPM) │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ │ Mutually consistent trajectories │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Self-Evolution Loop │ │ │
│ │ │ 4 rounds: Generator ↔ Verifier co-improve │ │ │
│ │ │ 747K problems, millions of solutions │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Three Key Innovations
rStar-Math's success comes from three innovations that work together. First, code-augmented data synthesis generates training data with verified solutions—using Python to check answers ensures no incorrect examples leak into training. Second, a Process Preference Model (PPM) learns to evaluate reasoning trajectories, enabling better search guidance than simple outcome rewards. Third, self-evolution through mutual reasoning lets the policy and PPM co-improve over multiple rounds without any external supervision.
1. Code-Augmented CoT Data Synthesis
The key to generating high-quality training data without a teacher model: generate reasoning traces that include Python code, then execute the code to verify the answer. Only traces with correct, verified answers become training data. This creates a self-supervised loop where the model's own successful reasoning becomes its training signal.
class RStarMathDataSynthesis:
"""
rStar-Math's novel data synthesis pipeline.
Key insight: Generate step-by-step verified reasoning using:
1. MCTS exploration with human-like actions
2. Code execution for verification
3. Extensive rollouts to find correct paths
"""
def __init__(
self,
policy_model,
code_executor,
n_rollouts: int = 64
):
self.policy = policy_model
self.executor = code_executor
self.n_rollouts = n_rollouts
# Human-like reasoning actions
self.actions = [
"decompose", # Break into subproblems
"analyze", # Examine given information
"compute", # Perform calculation (with code)
"verify", # Check intermediate result
"backtrack", # Abandon current approach
"conclude" # State final answer
]
def synthesize_trajectory(self, problem: str) -> dict:
"""
Generate verified reasoning trajectory using MCTS.
"""
root = MCTSNode(state=problem, action=None)
for _ in range(self.n_rollouts):
# Selection: traverse tree using UCB
node = self.select(root)
# Expansion: generate next reasoning step
if not node.is_terminal:
child = self.expand(node)
# Simulation: complete solution from this point
trajectory, final_answer = self.simulate(child)
# Verification: check with code execution
is_correct = self.verify_with_code(
problem,
trajectory,
final_answer
)
# Backpropagation: update tree with result
self.backpropagate(child, is_correct)
# Extract best trajectory
best_trajectory = self.extract_best_path(root)
return {
"problem": problem,
"trajectory": best_trajectory,
"verified": True,
"code_verified_steps": self.count_code_verified(best_trajectory)
}
def expand(self, node: MCTSNode) -> MCTSNode:
"""
Expand node with human-like reasoning action.
"""
# Select action based on current state
action = self.select_action(node.state)
# Generate reasoning step
if action == "compute":
# Generate code for computation
step = self.generate_code_step(node.state)
result = self.executor.run(step["code"])
step["code_result"] = result
else:
# Generate natural language step
step = self.generate_nl_step(node.state, action)
new_state = node.state + "\n" + step["text"]
child = MCTSNode(
state=new_state,
action=action,
parent=node
)
node.children.append(child)
return child
def verify_with_code(
self,
problem: str,
trajectory: str,
answer: str
) -> bool:
"""
Verify solution using Python code execution.
This is crucial: code provides ground-truth verification
that doesn't require a learned reward model.
"""
# Extract numerical answer
predicted = self.extract_numerical(answer)
if predicted is None:
return False
# Try to verify computationally
verification_code = f'''
# Verify the solution to: {problem[:200]}
# Claimed answer: {predicted}
# Re-derive answer from scratch
{self.generate_verification_code(problem)}
# Check if answer matches
is_correct = abs(computed_answer - {predicted}) < 1e-6
print(is_correct)
'''
try:
result = self.executor.run(verification_code)
return "True" in result
        except Exception:
# Fall back to answer matching
return self.answer_matches(predicted, problem)
2. Process Preference Model (PPM) Training
The Process Preference Model (PPM) is rStar-Math's alternative to standard Process Reward Models. The key insight: instead of training a model to score individual steps (which requires expensive step-level annotations), PPM learns to compare entire reasoning trajectories. Given two solution attempts, which one is better? This pairwise comparison is much easier to label—if one solution is correct and the other wrong, the correct one is preferred. If both are correct, prefer the one with clearer reasoning.
PPM training data comes directly from MCTS exploration: as the policy explores different reasoning paths, some lead to correct answers and others don't. These naturally form preference pairs. No human annotation required—the mathematical correctness provides the signal. Over training rounds, PPM learns to recognize not just correctness but quality of reasoning, enabling it to guide search toward elegant solutions.
class ProcessPreferenceModel:
"""
rStar-Math's novel reward model: Process Preference Model (PPM).
Unlike standard PRMs that score steps independently,
PPM learns preferences between reasoning trajectories.
Key innovation: Trained on trajectory pairs from MCTS exploration,
where correct paths are preferred over incorrect ones.
"""
def __init__(self, base_model, tokenizer):
self.model = base_model
self.tokenizer = tokenizer
# Add preference head
self.preference_head = nn.Linear(
base_model.config.hidden_size,
1
)
def train_ppm(
self,
problem: str,
trajectories: list[dict]
) -> torch.Tensor:
"""
Train PPM on trajectory preferences from MCTS.
Positive trajectories: reached correct answer
Negative trajectories: failed or incorrect
"""
# Separate by outcome
positive = [t for t in trajectories if t["correct"]]
negative = [t for t in trajectories if not t["correct"]]
if not positive or not negative:
return None
# Create preference pairs
pairs = []
for pos in positive:
for neg in negative:
pairs.append({
"problem": problem,
"chosen": pos["trajectory"],
"rejected": neg["trajectory"]
})
# Bradley-Terry preference loss
total_loss = 0
for pair in pairs:
# Score both trajectories
chosen_score = self.score_trajectory(
pair["problem"],
pair["chosen"]
)
rejected_score = self.score_trajectory(
pair["problem"],
pair["rejected"]
)
# Preference loss: chosen should score higher
loss = -F.logsigmoid(chosen_score - rejected_score)
total_loss += loss
return total_loss / len(pairs)
def score_trajectory(
self,
problem: str,
trajectory: str
) -> torch.Tensor:
"""
Score a reasoning trajectory.
Returns scalar score indicating trajectory quality.
"""
input_text = f"Problem: {problem}\n\nSolution:\n{trajectory}"
inputs = self.tokenizer(
input_text,
return_tensors="pt",
truncation=True,
max_length=4096
)
outputs = self.model(**inputs)
# Use last token's hidden state for scoring
last_hidden = outputs.last_hidden_state[:, -1, :]
score = self.preference_head(last_hidden)
return score.squeeze()
def guide_mcts(
self,
problem: str,
partial_trajectory: str,
candidates: list[str]
) -> list[float]:
"""
Use PPM to guide MCTS expansion.
Score candidate next steps to prioritize exploration.
"""
scores = []
for candidate in candidates:
full_trajectory = partial_trajectory + "\n" + candidate
score = self.score_trajectory(problem, full_trajectory)
scores.append(score.item())
return scores
3. Self-Evolution Recipe
The self-evolution loop is where rStar-Math achieves its remarkable results without any teacher model. The core idea: the policy model and PPM improve each other in alternating rounds. In each round:
- Policy generates many solution attempts using MCTS guided by the current PPM
- Filter by correctness: only keep solutions that produce correct answers (verified by code execution)
- PPM trains on the new (correct trajectory, incorrect trajectory) pairs from MCTS exploration
- Policy trains on the filtered correct solutions
This creates a virtuous cycle: better PPM → better MCTS guidance → higher quality solutions → better PPM training data → better PPM. After 4 rounds, a 7B model achieves 90% on MATH—comparable to GPT-4 and better than most 70B models. The key is that code execution provides ground truth, so the system never needs external supervision.
class RStarSelfEvolution:
"""
rStar-Math's self-evolution: policy and PPM co-improve.
4 rounds of evolution:
Round 1: Initial training on synthetic data
Round 2: Generate new solutions, filter by PPM, retrain
Round 3: Improve PPM with new data, generate better solutions
Round 4: Final refinement
Result: 7B model achieves 90% on MATH
"""
def __init__(
self,
policy_model,
ppm_model,
problem_bank, # 747K math problems
n_rounds: int = 4
):
self.policy = policy_model
self.ppm = ppm_model
self.problems = problem_bank
self.n_rounds = n_rounds
def evolve(self) -> dict:
"""
Run self-evolution loop.
"""
history = []
for round_num in range(self.n_rounds):
print(f"Evolution Round {round_num + 1}")
# Step 1: Generate solutions with current policy
solutions = self.generate_solutions(
n_per_problem=64 if round_num < 2 else 32
)
# Step 2: Filter with PPM
high_quality = self.filter_solutions(solutions)
# Step 3: Retrain policy on filtered solutions
policy_metrics = self.train_policy(high_quality)
# Step 4: Update PPM with new preference pairs
ppm_metrics = self.train_ppm(solutions)
# Step 5: Evaluate
eval_metrics = self.evaluate()
history.append({
"round": round_num + 1,
"n_solutions": len(solutions),
"n_high_quality": len(high_quality),
"policy_metrics": policy_metrics,
"ppm_metrics": ppm_metrics,
"eval": eval_metrics
})
print(f" MATH: {eval_metrics['math']:.1%}")
print(f" AIME: {eval_metrics['aime']:.1%}")
return history
def filter_solutions(
self,
solutions: list[dict],
threshold: float = 0.7
) -> list[dict]:
"""
Filter solutions using PPM scores.
Only keep high-scoring, verified-correct solutions
for policy training.
"""
filtered = []
for sol in solutions:
# Must be correct
if not sol["correct"]:
continue
# Must have high PPM score
score = self.ppm.score_trajectory(
sol["problem"],
sol["trajectory"]
)
if score > threshold:
filtered.append(sol)
return filtered
# Results of rStar-Math self-evolution
RSTAR_MATH_RESULTS = {
"qwen2.5_math_7b": {
"baseline": 58.8,
"round_1": 72.3,
"round_2": 81.5,
"round_3": 87.2,
"round_4": 90.0, # Final: 90.0% on MATH
},
"phi3_mini_3.8b": {
"baseline": 41.4,
"round_1": 58.7,
"round_2": 71.2,
"round_3": 80.1,
"round_4": 86.4, # 3.8B model achieves 86.4%!
}
}
rStar2-Agent: Thinking Smarter, Not Longer
The initial rStar-Math showed that self-play can achieve remarkable results. rStar2-Agent takes this further with a crucial insight: thinking smarter beats thinking longer. DeepSeek-R1 reaches its scores with a 671B model and very long reasoning chains (10K+ tokens); rStar2-Agent matches that performance at 14B parameters by generating shorter, more focused reasoning. The gains come from better training and search, not from bigger models or longer outputs.
The agentic solving loop looks like this:
class RStar2Agent:
"""
rStar2-Agent: Agentic RL for math reasoning.
Key insight: "Think smarter, not longer"
Achieves 80.6% on AIME24 with 14B parameters,
matching 671B DeepSeek-R1, but with shorter responses.
Trained in only 510 RL steps within one week.
"""
def __init__(
self,
base_model, # 14B pretrained
ppm,
code_executor
):
self.model = base_model
self.ppm = ppm
self.executor = code_executor
# Agentic tools
self.tools = {
"python": self.execute_python,
"verify": self.verify_step,
"search": self.search_similar_problems
}
def solve_with_agentic_rl(self, problem: str) -> dict:
"""
Solve problem using agentic reasoning.
The model learns to:
1. Decide when to use tools vs pure reasoning
2. Verify intermediate steps proactively
3. Backtrack efficiently when stuck
"""
state = {
"problem": problem,
"reasoning": [],
"tool_calls": [],
"current_approach": None
}
max_steps = 20
for step in range(max_steps):
# Model decides next action
action = self.decide_action(state)
if action["type"] == "reason":
# Pure reasoning step
reasoning = self.generate_reasoning(state)
state["reasoning"].append(reasoning)
elif action["type"] == "tool":
# Use tool (code execution, verification)
tool_name = action["tool"]
tool_input = action["input"]
result = self.tools[tool_name](tool_input)
state["tool_calls"].append({
"tool": tool_name,
"input": tool_input,
"result": result
})
elif action["type"] == "backtrack":
# Abandon current approach
state["reasoning"] = state["reasoning"][:-3]
state["current_approach"] = None
elif action["type"] == "conclude":
# Extract and return answer
answer = self.extract_answer(state)
return {
"answer": answer,
"reasoning": state["reasoning"],
"tool_calls": state["tool_calls"],
"n_steps": step + 1
}
return {"answer": None, "error": "max_steps_exceeded"}
def decide_action(self, state: dict) -> dict:
"""
Model decides what to do next.
Trained via RL to maximize correct answers
while minimizing reasoning length.
"""
prompt = f"""Problem: {state['problem']}
Current reasoning:
{chr(10).join(state['reasoning'][-5:])}
Recent tool results:
{state['tool_calls'][-2:] if state['tool_calls'] else 'None'}
What should I do next?
Options:
- reason: Continue reasoning
- tool:python: Execute Python code
- tool:verify: Verify current step
- backtrack: Try different approach
- conclude: State final answer
Decision:"""
response = self.model.generate(prompt, max_tokens=100)
return self.parse_action(response)
# rStar2-Agent Results
RSTAR2_AGENT_RESULTS = {
"model": "14B",
"training": "510 RL steps, 1 week",
"aime24": 80.6, # Average pass@1
"aime25": 69.8, # Average pass@1
"comparison": {
"deepseek_r1_671b": {"aime24": 79.8, "aime25": 70.0},
"o1_preview": {"aime24": 44.6, "aime25": None},
},
"key_insight": "Shorter responses, same accuracy"
}
Part 9: DeepSeekMath-V2: Self-Verifiable Reasoning
DeepSeekMath-V2 achieved gold-medal scores on IMO 2025 through a novel self-verification framework:
┌─────────────────────────────────────────────────────────────────────────┐
│ DeepSeekMath-V2 ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ KEY INSIGHT: │
│ ──────────── │
│ │
│ "Incentivize the generator to identify and resolve as many issues │
│ as possible in their own proofs before finalizing them" │
│ │
│ Results: │
│ • Gold-level on IMO 2025 │
│ • Gold-level on CMO 2024 │
│ • 118/120 on Putnam 2024 │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ GENERATOR-VERIFIER TRAINING LOOP: │
│ ───────────────────────────────── │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ ┌──────────────┐ ┌──────────────┐ │ │
│ │ │ Generator │───────▶│ Verifier │ │ │
│ │ │ (LLM) │ │ (LLM) │ │ │
│ │ └──────────────┘ └──────────────┘ │ │
│ │ │ │ │ │
│ │ │ Generates proof │ Identifies issues │ │
│ │ │ │ │ │
│ │ ▼ ▼ │ │
│ │ ┌──────────────┐ ┌──────────────┐ │ │
│ │ │ Self-Refine │◀───────│ Feedback │ │ │
│ │ │ Loop │ │ Signal │ │ │
│ │ └──────────────┘ └──────────────┘ │ │
│ │ │ │ │
│ │ │ Improved proof │ │
│ │ ▼ │ │
│ │ ┌──────────────┐ │ │
│ │ │ Final │ (Only when verifier finds no issues) │ │
│ │ │ Output │ │ │
│ │ └──────────────┘ │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class DeepSeekMathV2:
"""
DeepSeekMath-V2: Self-Verifiable Mathematical Reasoning.
Core innovation: Train generator to self-correct before
submitting final answer, using verifier as reward signal.
Key insight: Don't just train for correct answers—train
for finding and fixing your own mistakes.
"""
def __init__(
self,
generator_model,
verifier_model,
max_refinement_rounds: int = 5
):
self.generator = generator_model
self.verifier = verifier_model
self.max_rounds = max_refinement_rounds
def solve_with_self_verification(self, problem: str) -> dict:
"""
Solve problem with iterative self-verification.
"""
# Initial solution attempt
solution = self.generator.generate(
f"Solve this problem with rigorous step-by-step derivation:\n\n{problem}"
)
refinement_history = []
for round_num in range(self.max_rounds):
# Verify current solution
verification = self.verify(problem, solution)
refinement_history.append({
"round": round_num,
"solution_preview": solution[:500],
"issues_found": verification["issues"],
"confidence": verification["confidence"]
})
# If no issues found, we're done
if not verification["issues"]:
return {
"solution": solution,
"answer": self.extract_answer(solution),
"verified": True,
"refinement_rounds": round_num,
"history": refinement_history
}
# Refine based on issues
solution = self.refine(problem, solution, verification["issues"])
# Max rounds reached
return {
"solution": solution,
"answer": self.extract_answer(solution),
"verified": False,
"refinement_rounds": self.max_rounds,
"history": refinement_history
}
def verify(self, problem: str, solution: str) -> dict:
"""
Verify solution and identify issues.
The verifier is trained to be highly critical—
it's better to flag false positives than miss errors.
"""
prompt = f"""Carefully verify this mathematical solution.
Problem: {problem}
Proposed Solution:
{solution}
Check for:
1. Logical errors (unjustified steps, non-sequiturs)
2. Calculation mistakes
3. Missing cases or edge conditions
4. Incorrect theorem applications
5. Gaps in reasoning
For each issue found, explain:
- What the issue is
- Where it occurs
- Why it's wrong
- How to fix it
If the solution is correct, state "NO ISSUES FOUND".
Verification:"""
response = self.verifier.generate(prompt, max_tokens=2048)
issues = self.parse_issues(response)
confidence = self.estimate_confidence(response)
return {
"issues": issues,
"confidence": confidence,
"raw_response": response
}
def refine(
self,
problem: str,
solution: str,
issues: list[dict]
) -> str:
"""
Refine solution based on identified issues.
"""
issues_text = "\n".join([
f"- {issue['description']}: {issue['explanation']}"
for issue in issues
])
prompt = f"""Fix the following issues in your solution.
Problem: {problem}
Your Previous Solution:
{solution}
Issues Found:
{issues_text}
Provide a corrected solution that addresses ALL issues.
Be rigorous and show your reasoning clearly.
Corrected Solution:"""
return self.generator.generate(prompt, max_tokens=4096)
class SelfVerificationTraining:
"""
Training procedure for self-verifiable reasoning.
Key insight: Train generator with verifier as reward signal,
but also train generator to internalize verification.
"""
def __init__(
self,
generator,
verifier,
problems_dataset
):
self.generator = generator
self.verifier = verifier
self.problems = problems_dataset
def train_step(self, problem: str, ground_truth: str) -> dict:
"""
Single training step with self-verification reward.
"""
# Generate initial solution
solution = self.generator.generate(problem)
# Get verification feedback
verification = self.verify_solution(solution, ground_truth)
# Compute rewards
rewards = {
# Correctness reward
"correctness": 1.0 if verification["correct"] else 0.0,
# Self-identification reward: did generator flag its own errors?
"self_awareness": self.compute_self_awareness_reward(
solution,
verification
),
# Efficiency reward: fewer refinement rounds is better
"efficiency": 1.0 / (1 + verification["refinements_needed"])
}
total_reward = (
0.6 * rewards["correctness"] +
0.3 * rewards["self_awareness"] +
0.1 * rewards["efficiency"]
)
return {
"reward": total_reward,
"rewards_breakdown": rewards,
"solution": solution,
"verification": verification
}
def compute_self_awareness_reward(
self,
solution: str,
verification: dict
) -> float:
"""
Reward generator for self-identifying issues.
If the solution includes self-doubt markers where
errors actually exist, that's good.
If it's confidently wrong, that's bad.
"""
# Check for self-doubt markers in solution
doubt_markers = [
"let me verify",
"checking this",
"wait, this might be wrong",
"hmm, let me reconsider",
"i should double-check"
]
has_self_doubt = any(
marker in solution.lower()
for marker in doubt_markers
)
if verification["correct"]:
# Correct solution: self-doubt is neutral
return 0.5
else:
# Wrong solution: self-doubt is good, confidence is bad
return 0.8 if has_self_doubt else 0.0
class ScaledVerificationCompute:
"""
Scale verification compute to handle harder problems.
Key insight: When generator starts solving problems the
verifier can't evaluate, generate new training data
for the verifier from these hard cases.
"""
def __init__(
self,
generator,
verifier,
code_executor
):
self.generator = generator
self.verifier = verifier
self.executor = code_executor
def scale_verification(
self,
hard_problems: list[str]
) -> list[dict]:
"""
Generate training data for verifier from hard problems.
When the verifier can't evaluate a proof:
1. Generate multiple solutions
2. Use code execution to determine correctness
3. Create training examples for verifier
"""
verifier_training_data = []
for problem in hard_problems:
# Generate many solutions
solutions = [
self.generator.generate(problem)
for _ in range(32)
]
# Use code execution to verify
for solution in solutions:
# Try to verify with code
code_result = self.code_verify(problem, solution)
if code_result["verified"]:
# Create training example for verifier
verifier_training_data.append({
"problem": problem,
"solution": solution,
"is_correct": code_result["correct"],
"issues": code_result.get("issues", [])
})
return verifier_training_data
def code_verify(
self,
problem: str,
solution: str
) -> dict:
"""
Verify solution using code execution.
This provides ground truth for training the verifier
on problems it previously couldn't evaluate.
"""
# Extract answer from solution
answer = self.extract_answer(solution)
# Generate verification code
verification_code = self.generate_verification_code(
problem,
answer
)
try:
result = self.executor.run(verification_code)
return {
"verified": True,
"correct": "CORRECT" in result,
"code_output": result
}
except Exception as e:
return {
"verified": False,
"error": str(e)
}
DeepSeekMath-V2: The Complete Technical Picture
DeepSeekMath-V2's gold medal at IMO 2025 came from a carefully designed verifier-first training pipeline that represents a fundamental shift in how mathematical reasoning models are trained.
Why DeepSeekMath-V2 Succeeded Where Others Failed
To understand DeepSeekMath-V2's breakthrough, we need to understand why previous approaches struggled with IMO-level mathematics.
The fundamental challenge of mathematical reasoning:
IMO problems aren't like typical AI benchmarks. They require:
- Multi-step logical chains: Solutions often span 20-50 reasoning steps, where each step must be logically valid. A single error anywhere invalidates the entire proof.
- Creative insight: Unlike routine exercises, IMO problems require discovering non-obvious key insights—the "aha moments" that unlock solutions. These insights often involve seeing connections between seemingly unrelated mathematical concepts.
- Rigorous justification: In competition mathematics, "I think this is true" isn't enough. Every claim must be proven. Olympiad graders penalize hand-waving mercilessly.
- Case coverage: Many problems require exhaustive case analysis. Missing even one case means the proof is incomplete.
Why previous approaches failed:
| Approach | Limitation | Result |
|---|---|---|
| Fine-tuning on math data | Model learns to pattern-match, not reason | Good on routine problems, fails on novel ones |
| Chain-of-thought prompting | No error correction mechanism | Errors propagate and compound |
| Reward models (standard RLHF) | Reward hacking: model learns to fool verifier | Superficially correct but logically flawed proofs |
| Formal verification (Lean) | Translation bottleneck | ~30% of problems can't be formalized cleanly |
DeepSeek's key insight:
The problem with standard RLHF is that the generator and verifier are trained together. This creates a co-evolution where the generator learns to exploit weaknesses in the verifier—producing proofs that look correct but contain subtle flaws the verifier misses.
DeepSeek inverted this: train the verifier first, completely independently, on human-annotated proofs. The verifier learns what "rigorous proof" means from human mathematicians, not from the generator's outputs. Only then is the generator trained to satisfy this pre-trained verifier.
This seemingly simple change has profound implications:
- The verifier's standards are fixed during generator training
- The generator can't "game" the verifier because the verifier was trained on real proofs, not generator outputs
- The verifier learns to spot actual mathematical errors, not just patterns correlated with errors
The 685B parameter model:
DeepSeekMath-V2 is built on DeepSeek-V3, a massive Mixture-of-Experts (MoE) model. The architecture matters:
- 685B total parameters: Comparable to GPT-4 scale
- 37B active per token: MoE routing means only a fraction of parameters are used for each prediction, keeping inference tractable
- 256 experts with top-8 routing: For each token, the 8 most relevant experts are activated
- 128K context window: Long contexts are essential for multi-step proofs
The MoE architecture is particularly well-suited for mathematics because different experts can specialize in different mathematical domains (algebra, geometry, number theory, combinatorics). When solving a geometry problem, geometry-specialized experts activate more frequently.
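The routing mechanics are easy to see in a few lines. Below is a minimal, scaled-down sketch of top-k expert routing in PyTorch; it illustrates the idea only and is not DeepSeek's implementation (layer sizes, expert counts, and class names here are placeholders).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoELayer(nn.Module):
    """Minimal top-k mixture-of-experts feed-forward layer (illustration only).

    The text above describes 256 experts with top-8 routing; the defaults
    here are scaled down so the sketch runs cheaply on a laptop.
    """
    def __init__(self, d_model: int = 512, n_experts: int = 16, top_k: int = 2):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.GELU(),
                nn.Linear(4 * d_model, d_model),
            )
            for _ in range(n_experts)
        )
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_tokens, d_model)
        gate_logits = self.router(x)                             # (n_tokens, n_experts)
        weights, indices = gate_logits.topk(self.top_k, dim=-1)  # pick k experts per token
        weights = F.softmax(weights, dim=-1)                     # renormalize over chosen experts
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for expert_id in indices[:, slot].unique().tolist():
                mask = indices[:, slot] == expert_id
                out[mask] += weights[mask, slot].unsqueeze(-1) * self.experts[expert_id](x[mask])
        return out

# Only top_k experts run per token, so active parameters are a small
# fraction of the total -- the same idea as "37B active of 685B".
print(TopKMoELayer()(torch.randn(10, 512)).shape)  # torch.Size([10, 512])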
Math-specific pre-training:
Before the verifier-first training, the base model underwent extensive math pre-training:
- 500B tokens of mathematical content: Research papers, textbooks, competition solutions, proofs
- Extended tokenizer: Special tokens for mathematical symbols (∑, ∫, √, ∀, ∃) and LaTeX expressions
- Structured problem-solution pairs: The model learns the format of mathematical reasoning
This pre-training gives the model "mathematical common sense"—it knows that √2 is irrational, that prime numbers have exactly two divisors, and thousands of other mathematical facts that would otherwise need to be derived from scratch.
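Extending a tokenizer with math symbols is mechanically simple. The sketch below uses the Hugging Face transformers API on GPT-2 purely for illustration; DeepSeek's actual tokenizer changes are not public, and the specific token list is an assumption.
from transformers import AutoTokenizer, AutoModelForCausalLM

# GPT-2 stands in for the real base model purely to keep the sketch small.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Symbols and LaTeX commands that would otherwise fragment into many
# byte-level pieces (the exact list is an assumption).
math_tokens = ["∑", "∫", "√", "∀", "∃", "≤", "≥", "≠", "∈", "⊆", "\\frac", "\\binom", "\\sqrt"]
n_added = tokenizer.add_tokens(math_tokens)

# Grow the embedding matrix to cover the new ids; the new rows start as
# random vectors and are learned during the math pre-training stage.
model.resize_token_embeddings(len(tokenizer))
print(f"added {n_added} tokens, vocab size now {len(tokenizer)}")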
The Four-Phase Training Pipeline in Detail
DeepSeekMath-V2's training proceeds through four carefully sequenced phases. Each phase builds on the previous one, and the order is critical—reversing or parallelizing phases would undermine the entire approach.
Phase 1: Verifier Pre-Training (The Foundation)
The verifier is trained on a curated dataset of human-annotated mathematical proofs. This dataset is special:
- Source: 50,000+ competition mathematics solutions from IMO, Putnam, national olympiads
- Annotation: Each solution is annotated by expert mathematicians with:
- Overall correctness label (correct/incorrect)
- Step-by-step correctness (which steps are valid/invalid)
- Error classifications (logical gap, calculation error, missing case, unjustified claim)
- Rigor scores (1-10 for each step)
The verifier learns through multiple objectives:
- Binary classification: Given a proof, predict correct/incorrect
- Error localization: Given an incorrect proof, identify which step contains the first error
- Error typing: Classify what kind of error it is
- Rigor scoring: Rate the rigor of each step on a continuous scale
The key is that this training uses only human annotations, never generator outputs. The verifier learns human mathematicians' standards for rigor.
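The annotation schema itself has not been released, but a hypothetical record along the lines described above might look like this (all field names are illustrative):
from dataclasses import dataclass, field
from enum import Enum

class ErrorType(str, Enum):
    LOGICAL_GAP = "logical_gap"
    CALCULATION_ERROR = "calculation_error"
    MISSING_CASE = "missing_case"
    UNJUSTIFIED_CLAIM = "unjustified_claim"

@dataclass
class StepAnnotation:
    text: str                      # the step as written in the solution
    is_valid: bool                 # step-level correctness
    rigor_score: int               # 1-10, as described above
    error_type: ErrorType | None = None

@dataclass
class ProofAnnotation:
    problem: str
    solution: str
    is_correct: bool               # overall correctness label
    steps: list[StepAnnotation] = field(default_factory=list)
    source: str = "IMO"            # competition of origin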
Phase 2: Meta-Verifier Training (Combining Signals)
A single correctness score isn't enough for training. The meta-verifier learns to combine multiple quality dimensions into a single scalar reward for GRPO:
| Dimension | Weight | What It Measures |
|---|---|---|
| Correctness | 0.40 | Is the final answer right? |
| Rigor | 0.30 | Are all steps fully justified? |
| Completeness | 0.20 | Are all cases handled? |
| Clarity | 0.10 | Is the logic chain followable? |
The weights are learned, not hand-tuned: the meta-verifier is trained to predict human quality judgments on a held-out set of expert-annotated proofs.
Why separate rigor from correctness? A proof can reach the right answer through flawed reasoning (lucky errors cancel out), or have rigorous steps but miss a case. Competition graders care about both.
Phase 3: Generator Initialization from Verifier (The Key Innovation)
This is what makes DeepSeekMath-V2 different. Instead of initializing the generator from the base model, they initialize it from the trained verifier's weights.
Why does this work? The verifier has learned to represent "what makes a good proof." These representations—encoded in the model's weights—transfer to generation. The generator starts with an internal model of proof quality before it generates a single proof.
The initialization procedure:
- Copy all weights from the trained verifier
- The verifier's classification head is discarded
- A new language modeling head is added
- Brief supervised fine-tuning (1000 steps) on proof generation to "unlock" generation capability
The generator now has:
- Deep understanding of proof structure (from pre-training)
- Internal model of what makes proofs rigorous (from verifier)
- Generation capability (from fine-tuning)
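A minimal sketch of that initialization procedure, with hypothetical attribute names (`backbone`, `classifier_head`, `lm_head`) standing in for whatever the real model uses:
import copy
import torch.nn as nn

def init_generator_from_verifier(verifier: nn.Module, vocab_size: int) -> nn.Module:
    """Sketch of Phase 3: start the generator from the trained verifier."""
    generator = copy.deepcopy(verifier)

    # Steps 1-2: keep the backbone weights, drop the verifier's classification head.
    del generator.classifier_head

    # Step 3: attach a fresh language-modeling head over the vocabulary.
    hidden_size = generator.backbone.config.hidden_size
    generator.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    # Step 4: a brief supervised warm-up on proof generation follows (not shown);
    # training too long here would wash out the verifier's learned notion of rigor.
    return generator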
Phase 4: Iterative GRPO with Verifier Co-Evolution
Now the generator and verifier train together, but with the verifier's pre-training providing a strong foundation that resists reward hacking.
The GRPO (Group Relative Policy Optimization) procedure:
For each training iteration:
1. Sample a batch of 64 problems
2. For each problem, generate 8 candidate solutions
3. Meta-verifier scores all 512 solutions
4. Normalize scores within each problem's 8 solutions (group relative)
5. Update generator to increase probability of high-scoring solutions
6. Find "hard cases" where generator and verifier disagree strongly
7. Optionally update verifier on hard cases with human labels
Why group relative scoring?
Standard reward models give absolute scores. The problem: as the generator improves, its solutions cluster in a narrow score range, providing weak gradient signal.
GRPO normalizes within each group of 8 solutions. Even if all 8 solutions are excellent, the best one gets positive advantage, the worst gets negative. This provides consistent training signal throughout training.
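The normalization itself is a one-liner; a minimal sketch:
import numpy as np

def group_relative_advantages(scores: list[float], eps: float = 1e-6) -> np.ndarray:
    """Normalize verifier scores within one problem's group of candidates.

    Even if all candidates score highly, the best still gets a positive
    advantage and the worst a negative one, so the gradient never vanishes.
    """
    s = np.asarray(scores, dtype=np.float64)
    return (s - s.mean()) / (s.std() + eps)

# 8 candidates for one problem, all scored in a narrow 0.81-0.95 band:
print(group_relative_advantages([0.92, 0.88, 0.95, 0.81, 0.90, 0.85, 0.93, 0.87]))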
The self-verification loop at inference time:
At inference, DeepSeekMath-V2 doesn't just generate once. It implements a generate-verify-refine loop:
1. Generate initial solution
2. Verifier checks for issues
3. If issues found:
a. Verifier identifies the problematic step
b. Generator regenerates from that point
c. Repeat verification
4. If no issues found:
a. Solution passes to output
5. Maximum 5 refinement rounds
This is different from standard "self-consistency" (majority voting over multiple generations). The verifier provides targeted feedback about what's wrong, allowing surgical fixes rather than complete regeneration.
Specific Techniques That Made the Difference
1. Error-Aware Training Data
Most math datasets only have correct solutions. DeepSeek collected and annotated incorrect solutions too:
- Intentionally flawed solutions with labeled errors
- Student solutions with common mistakes
- Model-generated errors from earlier training runs
This lets the verifier learn what errors look like, not just what correct proofs look like.
2. Step-Level Credit Assignment
For long proofs, which step caused success/failure? DeepSeek's training assigns credit at the step level:
- Each step gets a "contribution score" based on how it affects the final proof validity
- Steps that introduce key insights get higher credit
- Steps that introduce errors get negative credit
This is implemented via a learned "step importance" model trained on human annotations of "key steps" in solutions.
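The learned step-importance model is not public. One simple way to approximate step-level credit is a value-difference scheme, sketched below with a hypothetical `step_scorer` callable (for example, the meta-verifier scoring a partial proof):
def step_contributions(problem: str, steps: list[str], step_scorer) -> list[float]:
    """Credit each step with the change it causes in the verifier's score.

    `step_scorer(problem, partial_proof) -> float` is hypothetical.
    Steps that add a key insight raise the score (positive credit);
    steps that introduce an error lower it (negative credit).
    """
    contributions = []
    partial = ""
    prev_score = step_scorer(problem, partial)
    for step in steps:
        partial += step + "\n"
        score = step_scorer(problem, partial)
        contributions.append(score - prev_score)
        prev_score = score
    return contributions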
3. Diverse Generation Strategies
At inference, multiple solution attempts use different strategies:
- Forward reasoning: Start from given, derive conclusion
- Backward reasoning: Start from what we want to prove, work backward
- Case analysis: Enumerate cases systematically
- Contradiction: Assume the negation, derive contradiction
- Induction: Base case + inductive step
- Extremal principle: Consider minimum/maximum elements
- Invariants: Find quantities preserved under operations
- Construction: Build an explicit example/counter-example
Each strategy corresponds to a different prompt prefix, encouraging diverse solution attempts.
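A minimal sketch of how strategy-specific prefixes might be wired up; the exact prompt wording used by DeepSeek is not public, so the prefixes below are illustrative:
import random

STRATEGY_PREFIXES = {
    "forward": "Work forward from the given conditions toward the conclusion.",
    "backward": "Start from the statement to prove and work backward to known facts.",
    "cases": "Split the problem into an exhaustive set of cases and handle each one.",
    "contradiction": "Assume the statement is false and derive a contradiction.",
    "induction": "Prove a base case, then prove the inductive step.",
    "extremal": "Consider a minimal or maximal element and exploit its properties.",
    "invariant": "Find a quantity preserved by the allowed operations.",
    "construction": "Construct an explicit example or counterexample.",
}

def diverse_prompts(problem: str, n: int = 8) -> list[str]:
    """Build n prompts, each nudging the model toward a different proof strategy."""
    chosen = random.sample(list(STRATEGY_PREFIXES), k=min(n, len(STRATEGY_PREFIXES)))
    return [
        f"{STRATEGY_PREFIXES[s]}\n\nProblem: {problem}\n\nRigorous proof:"
        for s in chosen
    ]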
4. Verification-Guided Tree Search
For the hardest problems, DeepSeekMath-V2 uses a tree search where the verifier guides exploration:
┌─────────────────────────────────────────────────────────────────────────┐
│ VERIFICATION-GUIDED TREE SEARCH │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────┐ │
│ │ Problem │ │
│ └───────┬───────┘ │
│ │ │
│ ┌─────────────────┴─────────────────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ Path A │ │ Path B │ │
│ │ Score: 0.8 ✓ │ │ Score: 0.3 ✗ │ │
│ └────────┬────────┘ └─────────────────┘ │
│ │ │ │
│ EXPAND PRUNE │
│ │ (too low score) │
│ ┌────────┴────────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────┐ ┌──────┐ │
│ │ A1 │ │ A2 │ │
│ │ 0.85 │ │ 0.72 │ │
│ └──────┘ └──────┘ │
│ │
│ The verifier scores each partial proof, enabling: │
│ • Early pruning of unpromising paths (Path B) │
│ • Focused compute on high-potential branches (Path A) │
│ • More efficient than exhaustive or random search │
│ │
└─────────────────────────────────────────────────────────────────────────┘
The verifier scores partial proofs, allowing aggressive pruning of unpromising paths. This is more efficient than exhaustive search or random sampling.
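A minimal beam-search sketch of this idea, with hypothetical `extend_step`, `score_partial`, and `is_complete` callables standing in for the generator and verifier:
def verification_guided_search(
    problem: str,
    extend_step,           # (problem, partial_proof) -> list[str] of candidate next steps
    score_partial,         # (problem, partial_proof) -> float in [0, 1] from the verifier
    is_complete,           # (partial_proof) -> bool
    beam_width: int = 4,
    prune_below: float = 0.4,
    max_depth: int = 30,
):
    """Expand high-scoring partial proofs; prune low-scoring ones early."""
    beam = [("", score_partial(problem, ""))]
    for _ in range(max_depth):
        candidates = []
        for partial, score in beam:
            if is_complete(partial):
                candidates.append((partial, score))   # keep finished proofs in the beam
                continue
            for step in extend_step(problem, partial):
                extended = partial + step + "\n"
                new_score = score_partial(problem, extended)
                if new_score >= prune_below:          # aggressive pruning ("Path B" above)
                    candidates.append((extended, new_score))
        if not candidates:
            break
        beam = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(is_complete(p) for p, _ in beam):
            break
    return max(beam, key=lambda c: c[1]) if beam else (None, 0.0)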
┌─────────────────────────────────────────────────────────────────────────┐
│ DEEPSEEKMATH-V2: VERIFIER-FIRST TRAINING PIPELINE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ CORE INSIGHT: │
│ ───────────── │
│ │
│ "Train the verifier FIRST, then train the generator to satisfy the │
│ verifier. This is the opposite of typical reward model approaches." │
│ │
│ Why this works: │
│ • Verifier learns what "rigorous proof" means without generator bias │
│ • Generator then learns to produce proofs the verifier accepts │
│ • Creates virtuous cycle: better verifier → better generator → ... │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MODEL ARCHITECTURE: │
│ ─────────────────── │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ DeepSeekMath-V2 (685B) │ │
│ ├──────────────────────────────────────────────────────────────┤ │
│ │ │ │
│ │ Base: DeepSeek-V3 Mixture of Experts │ │
│ │ • 685B total parameters (37B active per token) │ │
│ │ • 256 experts, top-8 routing │ │
│ │ • 128K context window │ │
│ │ │ │
│ │ Math-specific adaptations: │ │
│ │ • Extended math tokenizer (special symbols, LaTeX) │ │
│ │ • Math-focused pre-training (500B math tokens) │ │
│ │ • Self-verification heads for each expert │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TRAINING PHASES: │
│ ──────────────── │
│ │
│ Phase 1: Verifier Pre-Training (Independent of Generator) │
│ ───────────────────────────────────────────────────────── │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Training data: Human-annotated proof evaluations │ │
│ │ • IMO/Putnam solutions labeled correct/incorrect │ │
│ │ • Step-by-step rigor annotations │ │
│ │ • Error type classifications │ │
│ │ │ │
│ │ Objective: Predict proof correctness + identify error locations │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ Phase 2: Meta-Verifier Training │
│ ──────────────────────────────── │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ The meta-verifier scores proofs on multiple dimensions: │ │
│ │ │ │
│ │ • Rigor score: Are all steps fully justified? │ │
│ │ • Completeness score: Are all cases handled? │ │
│ │ • Clarity score: Is the logic chain clear? │ │
│ │ • Novelty score: Does it use creative approaches? │ │
│ │ │ │
│ │ Combined into single "proof quality" signal for GRPO │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ Phase 3: Generator Training from Verifier Checkpoint │
│ ─────────────────────────────────────────────────── │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Initialize generator from VERIFIER weights (not base model!) │ │
│ │ • Generator inherits verifier's understanding of "good proof" │ │
│ │ • Then fine-tune for generation while preserving verification │ │
│ │ │ │
│ │ This is the key innovation: generator already "knows" what │ │
│ │ makes a proof rigorous before it learns to generate │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ Phase 4: Iterative GRPO Refinement │
│ ────────────────────────────────── │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ for epoch in range(num_epochs): │ │
│ │ # Generator produces solutions │ │
│ │ solutions = generator.sample(problems, n=64) │ │
│ │ │ │
│ │ # Verifier scores each solution │ │
│ │ scores = verifier.score(solutions) │ │
│ │ │ │
│ │ # GRPO update: push generator toward high-scoring proofs │ │
│ │ generator = grpo_update(generator, solutions, scores) │ │
│ │ │ │
│ │ # CRITICAL: Also update verifier on hard cases │ │
│ │ hard_cases = find_disagreements(generator, verifier) │ │
│ │ verifier = update_verifier(verifier, hard_cases) │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ RESULTS: │
│ ──────── │
│ • Gold medal (35/42) at IMO 2025 │
│ • Gold medal at China Mathematical Olympiad 2024 │
│ • 118/120 on Putnam 2024 │
│ • 96%+ on MATH benchmark │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class DeepSeekMathV2VerifierFirstTraining:
"""
DeepSeekMath-V2's actual training procedure.
The key insight: train verifier completely before touching generator.
This prevents the "reward hacking" problem where generators learn
to fool the verifier rather than produce rigorous proofs.
"""
def __init__(
self,
base_model, # DeepSeek-V3 685B MoE
proof_dataset, # Human-annotated IMO/Putnam solutions
meta_verifier_config
):
self.base = base_model
self.proofs = proof_dataset
self.meta_config = meta_verifier_config
def phase1_train_verifier(self):
"""
Phase 1: Train verifier on human-annotated proofs.
The verifier learns to identify:
- Correct vs incorrect proofs
- Types of errors (logical gap, calculation error, missing case)
- Quality dimensions (rigor, completeness, clarity)
CRITICAL: No generator involved yet. Verifier learns from
human annotations, not from generator outputs.
"""
verifier = copy.deepcopy(self.base)
# Train on human-labeled proof correctness
for proof_example in self.proofs.correctness_labels:
loss = verifier.train_step(
input=f"Proof:\n{proof_example['proof']}\n\nIs this proof correct?",
target=proof_example['label'], # "correct" or "incorrect"
)
# Train on error identification
for proof_example in self.proofs.error_annotations:
loss = verifier.train_step(
input=f"Proof:\n{proof_example['proof']}\n\nIdentify all errors:",
target=proof_example['errors'], # List of errors with locations
)
# Train on step-by-step rigor assessment
for proof_example in self.proofs.rigor_annotations:
loss = verifier.train_step(
input=f"Proof:\n{proof_example['proof']}\n\nRate rigor of each step:",
target=proof_example['step_scores'],
)
return verifier
def phase2_train_meta_verifier(self, verifier):
"""
Phase 2: Train meta-verifier that combines multiple quality signals.
The meta-verifier learns to produce a single score that captures:
- Correctness (is it right?)
- Rigor (are steps justified?)
- Completeness (all cases covered?)
- Clarity (is it understandable?)
"""
meta_verifier = MetaVerifier(verifier, self.meta_config)
# Calibrate weights for each dimension
for proof_example in self.proofs.quality_annotations:
predicted_scores = meta_verifier.score_all_dimensions(
proof_example['proof']
)
# Multi-task loss: predict each quality dimension
loss = sum([
mse_loss(predicted_scores[dim], proof_example[dim])
for dim in ['rigor', 'completeness', 'clarity', 'correctness']
])
meta_verifier.update(loss)
return meta_verifier
def phase3_initialize_generator_from_verifier(self, verifier):
"""
Phase 3: Initialize generator from verifier weights.
THIS IS THE KEY INNOVATION.
Instead of initializing generator from base model,
initialize from the TRAINED VERIFIER.
Why? The generator now starts with deep understanding of
what makes a proof rigorous, before it even learns to generate.
It's like learning to critique essays before writing them.
"""
generator = copy.deepcopy(verifier)
# Adapt verifier architecture for generation
# (Verifier was trained as classifier, generator needs to produce text)
generator = self.adapt_for_generation(generator)
# Brief generation warm-up (don't train too long, preserve verifier knowledge)
for problem in self.proofs.generation_warmup[:1000]:
loss = generator.train_step(
input=f"Problem: {problem['problem']}\n\nProvide a rigorous proof:",
target=problem['solution'],
)
return generator
def phase4_iterative_grpo(self, generator, meta_verifier, epochs=10):
"""
Phase 4: Iterative GRPO with co-evolution of generator and verifier.
Unlike standard GRPO, we also update the verifier during training.
This prevents the generator from finding reward-hacking strategies.
"""
for epoch in range(epochs):
# Sample problems
problems = self.proofs.sample_problems(batch_size=64)
for problem in problems:
# Generator produces multiple solutions
solutions = [
generator.generate(problem, temperature=0.7)
for _ in range(8)
]
# Meta-verifier scores each solution
scores = [
meta_verifier.score(problem, solution)
for solution in solutions
]
# GRPO update: increase probability of high-scoring solutions
grpo_loss = self.grpo_objective(
generator,
solutions,
scores,
reference_model=self.base
)
generator.update(grpo_loss)
# Find cases where generator and verifier disagree
disagreements = self.find_disagreements(
generator, meta_verifier, problem
)
# Update verifier on disagreements (with human labels if needed)
if disagreements:
verifier_loss = self.update_verifier_on_hard_cases(
meta_verifier, disagreements
)
meta_verifier.update(verifier_loss)
return generator, meta_verifier
def grpo_objective(
self,
generator,
solutions: list[str],
scores: list[float],
reference_model
) -> float:
"""
Group Relative Policy Optimization.
Key: Compare solutions to each other, not to absolute threshold.
This provides stable learning signal even as model improves.
"""
# Normalize scores within group
mean_score = sum(scores) / len(scores)
std_score = (sum((s - mean_score)**2 for s in scores) / len(scores)) ** 0.5
if std_score < 1e-6:
return 0.0 # All solutions have same score, no gradient
normalized_scores = [(s - mean_score) / std_score for s in scores]
# GRPO loss: push toward high-scoring solutions
loss = 0.0
for solution, norm_score in zip(solutions, normalized_scores):
# Log probability under generator
log_prob = generator.log_prob(solution)
# Log probability under reference (for KL constraint)
ref_log_prob = reference_model.log_prob(solution)
            # Advantage-weighted policy gradient plus a KL penalty toward the reference model
            loss -= norm_score * log_prob - 0.1 * (log_prob - ref_log_prob)
return loss / len(solutions)
def find_disagreements(self, generator, verifier, problem):
"""
Find cases where generator produces solutions that
verifier scores very differently than expected.
These are valuable for improving the verifier.
"""
disagreements = []
# Generate diverse solutions
solutions = [
generator.generate(problem, temperature=1.0)
for _ in range(32)
]
for solution in solutions:
verifier_score = verifier.score(problem, solution)
generator_confidence = generator.confidence(solution)
# Disagreement: generator confident, verifier uncertain
if generator_confidence > 0.9 and verifier_score < 0.5:
disagreements.append({
'problem': problem,
'solution': solution,
'generator_confidence': generator_confidence,
'verifier_score': verifier_score,
'type': 'generator_overconfident'
})
# Disagreement: generator uncertain, verifier confident
if generator_confidence < 0.5 and verifier_score > 0.9:
disagreements.append({
'problem': problem,
'solution': solution,
'generator_confidence': generator_confidence,
'verifier_score': verifier_score,
'type': 'verifier_overconfident'
})
return disagreements
class MetaVerifier:
"""
Meta-verifier that combines multiple quality signals into
a single score for GRPO training.
This is what makes DeepSeekMath-V2 different from standard
reward models: it scores on RIGOR, not just correctness.
"""
def __init__(self, base_verifier, config):
self.verifier = base_verifier
self.config = config
# Learned weights for combining dimensions
self.dimension_weights = {
'correctness': 0.4, # Is the final answer right?
'rigor': 0.3, # Are all steps fully justified?
'completeness': 0.2, # Are all cases handled?
'clarity': 0.1 # Is the logic chain clear?
}
def score(self, problem: str, solution: str) -> float:
"""
Score a solution on multiple dimensions, combine into single value.
"""
scores = self.score_all_dimensions(solution)
combined = sum(
self.dimension_weights[dim] * scores[dim]
for dim in scores
)
return combined
def score_all_dimensions(self, solution: str) -> dict:
"""
Score solution on each quality dimension.
This is the key to training for RIGOR, not just correctness.
A proof can be correct but not rigorous (missing justifications).
A proof can be rigorous but incomplete (missing cases).
"""
# Correctness: Is the final answer right?
correctness_prompt = f"""
Solution:
{solution}
Is the final answer correct? Score 0-1:"""
correctness = self.verifier.generate(correctness_prompt)
# Rigor: Are all steps fully justified?
rigor_prompt = f"""
Solution:
{solution}
Rate the rigor of this proof. A rigorous proof:
- Justifies every step explicitly
- Cites theorems/lemmas when used
- Has no logical gaps
Score 0-1:"""
rigor = self.verifier.generate(rigor_prompt)
# Completeness: Are all cases handled?
completeness_prompt = f"""
Solution:
{solution}
Are all cases handled? Does the proof:
- Consider all possibilities?
- Handle edge cases?
- Account for boundary conditions?
Score 0-1:"""
completeness = self.verifier.generate(completeness_prompt)
# Clarity: Is the logic chain clear?
clarity_prompt = f"""
Solution:
{solution}
Is this proof clear and well-organized?
- Are steps in logical order?
- Is notation consistent?
- Can you follow the argument?
Score 0-1:"""
clarity = self.verifier.generate(clarity_prompt)
return {
'correctness': float(correctness),
'rigor': float(rigor),
'completeness': float(completeness),
'clarity': float(clarity)
}
Part 10: Formal Verification with Lean 4
Formal verification provides absolute guarantees of correctness. DeepSeek-Prover-V2 achieved 88.9% on MiniF2F-test using Lean 4.
┌─────────────────────────────────────────────────────────────────────────┐
│ FORMAL VERIFICATION LANDSCAPE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ WHY FORMAL VERIFICATION? │
│ ───────────────────────── │
│ │
│ LLM reasoning can be wrong even when it looks right. │
│ Formal verification is ABSOLUTE: if Lean accepts, it's correct. │
│ │
│ "When an answer comes with a Lean4 proof, you don't have to trust │
│ the AI—you can check it." - Lean4 research │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ KEY SYSTEMS (2025): │
│ ─────────────────── │
│ │
│ System Performance Approach │
│ ───────────────────────────────────────────────────────────────────── │
│ DeepSeek-Prover-V2 88.9% MiniF2F RL + Subgoal decomp │
│ (671B) 49/658 PutnamBench │
│ │
│ Harmonic Aristotle Gold IMO 2025 Formally verified │
│ │
│ AlphaProof Silver IMO 2024 Lean + RL │
│ │
│ Goedel-Prover 1.64M formal stmts Statement formalizer │
│ │
│ APOLLO Framework 3-7% → 40%+ Syntax cleaning + LLM │
│ │
│ Lean Copilot Tactic suggestion LLM in Lean │
│ │
└─────────────────────────────────────────────────────────────────────────┘
DeepSeek-Prover-V2: Subgoal Decomposition
DeepSeek-Prover-V2 achieved 88.9% on MiniF2F-test through a clever decomposition strategy: instead of trying to prove theorems in one shot, it breaks complex proofs into subgoals, proves each subgoal separately, then combines them. This mirrors how human mathematicians tackle hard proofs—find intermediate lemmas that, once proven, make the main theorem tractable. The implementation uses reinforcement learning to learn good decomposition strategies.
class DeepSeekProverV2:
"""
DeepSeek-Prover-V2: Formal theorem proving in Lean 4.
Key innovations:
1. Subgoal decomposition with RL
2. Two-stage: 671B for decomposition, 7B for proof search
3. Non-CoT mode for fast generation
Results:
- 88.9% on MiniF2F-test
- 49/658 PutnamBench problems
- 6/15 AIME problems (formalized)
"""
def __init__(
self,
decomposer_model, # Large model (671B)
prover_model, # Small model (7B)
lean_env
):
self.decomposer = decomposer_model
self.prover = prover_model
self.lean = lean_env
def prove_theorem(self, theorem: str) -> dict:
"""
Prove theorem using subgoal decomposition.
1. Decomposer generates high-level proof sketch
2. Decomposer formalizes sketch into subgoals
3. Prover solves each subgoal
4. Combine into complete proof
"""
# Step 1: Generate proof sketch (CoT)
sketch = self.decomposer.generate(f"""
Theorem: {theorem}
Generate a high-level proof sketch with key steps:
""")
# Step 2: Decompose into formal subgoals
subgoals = self.decompose_to_subgoals(theorem, sketch)
# Step 3: Prove each subgoal
proofs = []
for subgoal in subgoals:
proof = self.prove_subgoal(subgoal)
if proof is None:
return {"success": False, "failed_at": subgoal}
proofs.append(proof)
# Step 4: Combine proofs
complete_proof = self.combine_proofs(theorem, subgoals, proofs)
# Step 5: Verify in Lean
verification = self.lean.verify(complete_proof)
return {
"success": verification["valid"],
"proof": complete_proof,
"subgoals": subgoals,
"subproofs": proofs
}
def decompose_to_subgoals(
self,
theorem: str,
sketch: str
) -> list[str]:
"""
Convert proof sketch to formal Lean 4 subgoals.
"""
prompt = f"""
Theorem: {theorem}
Proof sketch:
{sketch}
Convert this into a sequence of Lean 4 subgoals.
Each subgoal should be:
1. A valid Lean 4 statement
2. Simpler than the main theorem
3. Together they imply the main theorem
Subgoals (in Lean 4 syntax):
"""
response = self.decomposer.generate(prompt)
return self.parse_lean_statements(response)
def prove_subgoal(
self,
subgoal: str,
max_attempts: int = 64
) -> str | None:
"""
Prove subgoal using smaller prover model.
Uses best-first search with Lean verification.
"""
candidates = []
for _ in range(max_attempts):
# Generate proof candidate (non-CoT mode for speed)
candidate = self.prover.generate(
f"-- Prove: {subgoal}\n",
max_tokens=512,
temperature=0.8
)
# Verify in Lean
result = self.lean.verify(f"{subgoal}\n{candidate}")
if result["valid"]:
return candidate
candidates.append({
"proof": candidate,
"error": result.get("error")
})
# All attempts failed
return None
class LeanEnvironment:
"""
Interface to Lean 4 theorem prover.
"""
def __init__(self, project_path: str):
self.project_path = project_path
def verify(self, code: str) -> dict:
"""
Verify Lean 4 code.
Returns:
- valid: True if code type-checks
- errors: List of type errors if invalid
"""
# Write to temp file
temp_file = f"{self.project_path}/temp_verify.lean"
with open(temp_file, "w") as f:
f.write(code)
# Run Lean
import subprocess
result = subprocess.run(
["lake", "build", "temp_verify"],
cwd=self.project_path,
capture_output=True,
text=True
)
if result.returncode == 0:
return {"valid": True}
else:
return {
"valid": False,
"error": result.stderr
}
def get_tactics(self, goal: str) -> list[str]:
"""
Get available tactics for a goal state.
Useful for guiding proof search.
"""
        # Querying Lean for applicable tactics is not implemented in this sketch;
        # return an empty list so callers can handle the stub gracefully.
        return []
class APOLLOFramework:
"""
APOLLO: Automated LLM and Lean Collaboration.
Pipeline that transforms LLM proof sketches into
verified Lean 4 proofs.
Results: 3-7% baseline → 40%+ accuracy
Pipeline stages:
1. LLM generates proof sketch
2. Syntax cleaning (fix common Lean errors)
3. Auto-solvers (for simple subgoals)
4. LLM-driven sub-proof generation
5. Verification and iteration
"""
def __init__(
self,
llm,
lean_env,
auto_solvers
):
self.llm = llm
self.lean = lean_env
self.solvers = auto_solvers
def prove(self, theorem: str) -> dict:
"""
APOLLO proof pipeline.
"""
# Stage 1: LLM generates initial proof
initial = self.llm.generate(f"""
theorem {theorem} := by
-- Prove this theorem
""")
# Stage 2: Syntax cleaning
cleaned = self.clean_syntax(initial)
# Stage 3: Try auto-solvers on subgoals
with_autos = self.apply_auto_solvers(cleaned)
# Stage 4: LLM fills remaining gaps
complete = self.fill_gaps(theorem, with_autos)
# Stage 5: Verify
result = self.lean.verify(complete)
if result["valid"]:
return {"success": True, "proof": complete}
# Iterate if needed
for _ in range(3):
complete = self.fix_errors(complete, result["error"])
result = self.lean.verify(complete)
if result["valid"]:
return {"success": True, "proof": complete}
return {"success": False, "best_attempt": complete}
def clean_syntax(self, proof: str) -> str:
"""
Fix common Lean 4 syntax errors.
Common issues:
- Wrong indentation
- Missing imports
- Lean 3 vs Lean 4 syntax
- Incorrect tactic names
"""
# Indentation fixes
proof = self.fix_indentation(proof)
# Lean 3 → Lean 4 conversions
        conversions = {
            "begin": "by",    # Lean 4 tactic blocks start with `by`
            "end": "",        # `begin ... end` blocks have no closing keyword in Lean 4
            "nat.": "Nat.",   # Lean 4 / Mathlib namespaces are capitalized
            "int.": "Int.",
        }
        # NOTE: naive string replacement is only a sketch; a robust converter
        # would operate on tokens to avoid mangling identifiers.
for old, new in conversions.items():
proof = proof.replace(old, new)
return proof
def apply_auto_solvers(self, proof: str) -> str:
"""
Apply automatic solvers to simple subgoals.
Common auto tactics:
- simp: simplification
- ring: ring arithmetic
- linarith: linear arithmetic
- omega: integer arithmetic
- norm_num: numeric normalization
"""
# Parse proof to find sorry/holes
holes = self.find_holes(proof)
for hole in holes:
# Try each auto solver
for solver in ["simp", "ring", "linarith", "omega", "norm_num"]:
test = proof.replace(hole, solver)
if self.lean.verify(test)["valid"]:
proof = test
break
return proof
Part 11: Tool-Integrated Reasoning (TIR)
Tool-Integrated Reasoning (TIR) is a paradigm where the language model interleaves natural language reasoning with calls to external tools—calculators, symbolic math engines, code interpreters. This isn't just a convenience: recent theoretical work shows that TIR strictly expands the class of problems a model can solve efficiently.
Why LLMs Need Tools for Math
The arithmetic bottleneck:
LLMs are trained to predict the next token, not to compute arithmetic. While they learn addition and multiplication patterns, their accuracy drops sharply:
- 2-digit × 2-digit: ~95% accuracy
- 3-digit × 3-digit: ~60% accuracy
- 4-digit × 4-digit: ~20% accuracy
For competition math, a single arithmetic error invalidates the entire solution. Tool use eliminates this bottleneck entirely.
What tools enable:
| Tool | Capability | Example Use |
|---|---|---|
| Calculator | Exact arithmetic | "What's 2^1024 mod 17?" |
| SymPy | Symbolic algebra | Factor polynomials, solve equations |
| Code interpreter | Enumeration | Check all cases for n < 100 |
| Lean/Coq | Formal verification | Prove intermediate lemmas |
| Wolfram Alpha | Knowledge + compute | Series expansions, limits |
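The first two rows of the table are one-liners in Python and SymPy; a quick illustration of what the model gets for free by emitting code instead of computing in its head:
# Exact modular arithmetic: 2^1024 mod 17, instant and exact.
print(pow(2, 1024, 17))          # -> 1

# Symbolic algebra with SymPy: factor a polynomial exactly.
from sympy import symbols, factor
x = symbols("x")
print(factor(x**4 - 1))          # -> (x - 1)*(x + 1)*(x**2 + 1)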
Self-Consistency TIR (SC-TIR):
The winning approach in AIMO Progress Prize 1 (NuminaMath) combined TIR with self-consistency:
- Generate N solutions, each interleaving reasoning with code
- Execute all code blocks, replacing them with results
- Extract final answers
- Vote across solutions (majority wins)
This combines the flexibility of natural language reasoning with the precision of programmatic execution.
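The final voting stage reduces to a few lines. A minimal sketch follows; the full SC-TIR loop appears in the implementation suite later in this section.
from collections import Counter

def sc_tir_vote(candidate_answers: list[str | None]) -> str | None:
    """Majority vote over answers from N independent TIR generations.

    Candidates whose code execution failed are represented as None and
    excluded before voting.
    """
    valid = [a for a in candidate_answers if a is not None]
    if not valid:
        return None
    answer, _ = Counter(valid).most_common(1)[0]
    return answer

# Example: 6 of 8 runs agree on "204" after executing their code blocks.
print(sc_tir_vote(["204", "204", None, "204", "197", "204", "204", "204"]))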
Theoretical Foundation
Recent work by Chen et al. (2024) proves that TIR is strictly more powerful than pure language modeling:
- There exist problems solvable in polynomial time with TIR that require exponential time without
- The separation holds even for problems with polynomial-length solutions
- Tool access isn't just faster—it enables solving fundamentally different problems
Tool-Integrated Reasoning fundamentally expands what LLMs can solve. Recent theoretical work proves this expansion is strict—not just helpful, but necessary.
┌─────────────────────────────────────────────────────────────────────────┐
│ TOOL-INTEGRATED REASONING │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THEORETICAL BREAKTHROUGH (2025): │
│ ───────────────────────────────── │
│ │
│ "Tools enable a STRICT expansion of the model's capabilities, │
│ breaking the capability ceiling of pure-text models by unlocking │
│ problem-solving strategies that are otherwise impossible or │
│ intractably verbose." │
│ │
│ Key insight: TIR advantage is NOT confined to computation-intensive │
│ problems—it extends to those requiring significant ABSTRACT INSIGHT. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TIR MODELS AND RESULTS: │
│ ─────────────────────── │
│ │
│ Model Approach Result │
│ ───────────────────────────────────────────────────────────────────── │
│ ToRA SFT + output shaping MATH: 68% → 84% │
│ NuminaMath (TIR) SC-TIR decoding AIMO: 29/50 │
│ START Long CoT + Python GPQA: 61%, AIME: 73% │
│ AutoTIR Autonomous tool choice Adaptive tool use │
│ ToRL RLVR + code gen Emergent tool use │
│ o4-mini + Python Tool use AIME: 92.7% → 99.5% │
│ │
└─────────────────────────────────────────────────────────────────────────┘
The following implementation suite sketches four TIR approaches (plus a START-style long-chain-of-thought variant), each with different tradeoffs:
ToolIntegratedReasoning: The basic TIR implementation. The model generates natural language reasoning interleaved with <code>...</code> blocks. When a code block is detected, it's extracted and executed, with the result injected back into the context. The model continues reasoning with access to the computation result. This loop continues until the model produces a final <answer> tag.
SelfConsistencyTIR: SC-TIR combines TIR with ensemble voting. Generate N solutions (typically 32-64), each using the basic TIR loop. Execute all code blocks and verify results. Filter out solutions where code execution failed. Among remaining solutions, extract numerical answers and take the majority vote. This three-stage filtering (tool verification → execution success → consensus) is what won AIMO Progress Prize 1.
AutoTIR: Instead of telling the model when to use tools, let it learn through RL. The model receives task success reward and discovers that tool use improves outcomes. This leads to emergent, adaptive tool use—the model learns when tools help vs. when direct reasoning is faster.
TIRWithRefinement: Adds a self-critique loop. After generating a TIR solution, the model reviews its own reasoning and code, looking for errors. If issues are found, it generates a refined solution. This catches errors that even code execution might miss (e.g., solving the wrong problem correctly). A minimal sketch of this class appears after the AutoTIR code below.
class ToolIntegratedReasoning:
"""
Tool-Integrated Reasoning (TIR) for mathematical problem solving.
Combines natural language reasoning with code execution
for precise computation and verification.
"""
def __init__(
self,
model,
code_executor,
tools: list[str] = ["python", "sympy", "wolfram"]
):
self.model = model
self.executor = code_executor
self.tools = tools
def solve_with_tir(self, problem: str) -> dict:
"""
Solve problem with tool-integrated reasoning.
The model interleaves:
1. Natural language reasoning (analysis, strategy)
2. Code execution (computation, verification)
"""
prompt = f"""Solve this problem using both reasoning and code.
Problem: {problem}
Use <think> tags for reasoning and <code> tags for Python code.
Execute code to verify calculations.
Solution:"""
# Generate with tool use
response = ""
reasoning_steps = []
code_executions = []
while True:
# Generate next chunk
chunk = self.model.generate(
prompt + response,
max_tokens=1024,
stop=["</code>", "<answer>"]
)
response += chunk
# Check for code block
if "<code>" in chunk:
# Extract and execute code
code = self.extract_code(response)
result = self.executor.run(code)
# Add result to context
response += f"</code>\n<output>{result}</output>\n"
code_executions.append({
"code": code,
"result": result
})
elif "<answer>" in response:
# Done
break
elif len(response) > 8000:
# Safety limit
break
# Extract answer
answer = self.extract_answer(response)
return {
"answer": answer,
"reasoning": response,
"code_executions": code_executions
}
class SelfConsistencyTIR:
"""
Self-Consistency with Tool-Integrated Reasoning (SC-TIR).
NuminaMath's winning approach for AIMO Progress Prize.
Key insight: Generate multiple TIR solutions,
use code verification to filter, then majority vote.
"""
def __init__(
self,
model,
code_executor,
n_samples: int = 32
):
self.model = model
self.executor = code_executor
self.n_samples = n_samples
def solve_with_sc_tir(self, problem: str) -> dict:
"""
Solve with Self-Consistency TIR.
1. Generate N solutions with TIR
2. Filter by code verification
3. Majority vote on remaining
"""
solutions = []
for i in range(self.n_samples):
# Generate solution with TIR
solution = self.generate_tir_solution(problem)
# Verify with code
verified = self.code_verify(problem, solution)
solutions.append({
"solution": solution,
"answer": self.extract_answer(solution["reasoning"]),
"verified": verified,
"code_passed": solution["code_executions"]
})
# Filter to verified solutions
verified_solutions = [s for s in solutions if s["verified"]]
if not verified_solutions:
# Fall back to all solutions
verified_solutions = solutions
# Majority vote on answers
answer_counts = Counter(
s["answer"] for s in verified_solutions
if s["answer"] is not None
)
if not answer_counts:
return {"answer": None, "confidence": 0}
best_answer = answer_counts.most_common(1)[0][0]
confidence = answer_counts[best_answer] / len(verified_solutions)
return {
"answer": best_answer,
"confidence": confidence,
"n_verified": len(verified_solutions),
"n_total": self.n_samples,
"vote_distribution": dict(answer_counts)
}
def code_verify(self, problem: str, solution: dict) -> bool:
"""
Verify solution using code execution.
"""
# Check if all code executions succeeded
for execution in solution["code_executions"]:
if "Error" in execution["result"]:
return False
# Try to verify final answer
answer = self.extract_answer(solution["reasoning"])
if answer is None:
return False
# Generate verification code
verify_code = f"""
# Verify answer to: {problem[:200]}
answer = {answer}
# Re-solve from scratch
{self.generate_solution_code(problem)}
# Check
print("VERIFIED" if abs(computed - answer) < 1e-6 else "MISMATCH")
"""
try:
result = self.executor.run(verify_code)
return "VERIFIED" in result
        except Exception:
return False
class STARTReasoning:
"""
START: Seamless Long CoT + Tool Integration.
First open-source LLM combining extended reasoning
with Python tool use.
Results:
- GPQA: 61% (PhD-level science)
- AIME: 73%
- LiveCodeBench: 47%
"""
def __init__(self, model, executor):
self.model = model
self.executor = executor
def solve(self, problem: str) -> dict:
"""
START-style reasoning with tools.
Key difference from basic TIR:
- Much longer reasoning chains (8K+ tokens)
- Strategic tool invocation (not every step)
- Self-correction within reasoning
"""
system_prompt = """You are a mathematical reasoning expert.
Think step by step, showing ALL your work.
Use <code>...</code> when computation would help.
Use <verify>...</verify> to check intermediate results.
Continue until you're confident in your answer.
Be thorough—better to reason more than miss something."""
response = self.model.generate(
f"{system_prompt}\n\nProblem: {problem}\n\nSolution:",
max_tokens=8192, # Long context for extended reasoning
temperature=0.7
)
# Process tool calls
processed = self.process_tool_calls(response)
return {
"answer": self.extract_answer(processed),
"reasoning": processed,
"tool_calls": self.count_tool_calls(response)
}
class AutoTIR:
"""
AutoTIR: Autonomous Tool-Integrated Reasoning via RL.
The model learns WHEN and WHICH tool to use
through reinforcement learning, without explicit instruction.
Key insight: Let the model discover tool use
through maximizing task success reward.
"""
def __init__(
self,
model,
tools: dict, # {"python": executor, "wolfram": api, ...}
reward_model
):
self.model = model
self.tools = tools
self.reward = reward_model
def train_autonomous_tool_use(
self,
problems: list[str],
n_epochs: int = 3
):
"""
Train model to autonomously use tools via RL.
No tool-use demonstrations needed—model learns
from reward signal alone.
"""
for epoch in range(n_epochs):
for problem in problems:
# Generate solution (model decides tool use)
solution = self.model.generate(
f"Solve: {problem}",
max_tokens=4096,
tools_available=list(self.tools.keys())
)
# Execute any tool calls
processed = self.execute_tools(solution)
# Compute reward
reward = self.reward.score(problem, processed)
# RL update (simplified)
self.rl_update(problem, solution, reward)
def execute_tools(self, solution: str) -> str:
"""
Execute any tool calls in the solution.
"""
# Find tool invocations
for tool_name, executor in self.tools.items():
pattern = f"<{tool_name}>(.*?)</{tool_name}>"
matches = re.findall(pattern, solution, re.DOTALL)
for match in matches:
result = executor(match)
solution = solution.replace(
f"<{tool_name}>{match}</{tool_name}>",
f"<{tool_name}>{match}</{tool_name}>\n<result>{result}</result>"
)
return solution
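For completeness, here is a minimal sketch of the TIRWithRefinement idea described above. It reuses the ToolIntegratedReasoning class from this section; the critique prompt and stopping rule are illustrative choices, not a reference implementation.
class TIRWithRefinement:
    """
    TIR with a self-critique pass (sketch).
    After a TIR solution is produced, the model reviews its own
    reasoning and code; if it finds issues, it regenerates with the
    critique in context. Interfaces mirror ToolIntegratedReasoning above.
    """
    def __init__(self, model, code_executor, max_refinements: int = 2):
        self.tir = ToolIntegratedReasoning(model, code_executor)
        self.model = model
        self.max_refinements = max_refinements

    def solve(self, problem: str) -> dict:
        result = self.tir.solve_with_tir(problem)
        for _ in range(self.max_refinements):
            # Critique pass: look for reasoning/code errors, including
            # "solved the wrong problem correctly"
            critique = self.model.generate(
                f"Problem: {problem}\n\nSolution:\n{result['reasoning']}\n\n"
                "Check the reasoning and code for errors, including solving "
                "the wrong problem. Reply 'OK' if correct, otherwise list issues:",
                max_tokens=512
            )
            if critique.strip().startswith("OK"):
                break
            # Refine: re-run TIR with the identified issues in context
            result = self.tir.solve_with_tir(
                f"{problem}\n\nA previous attempt had these issues:\n{critique}"
            )
        return result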
Part 12: Verification-Refinement Pipelines for IMO 2025
Multiple teams achieved gold-medal performance at IMO 2025 using carefully designed verification-refinement pipelines.
The Core Insight: Single-Shot Generation Isn't Enough
The naive approach to AI math solving is straightforward: prompt a strong model, get an answer. This works surprisingly well on easy problems—GPT-4 class models achieve ~75% on AMC-level problems with zero-shot prompting.
But for IMO-level problems, single-shot accuracy drops below 10%. Why? Because hard math problems require:
- Exploration: Trying multiple approaches before finding one that works
- Error detection: Recognizing when a reasoning path is going wrong
- Refinement: Fixing errors without starting from scratch
- Verification: Confirming the final proof is actually correct
Human mathematicians do all four constantly. They don't write a proof linearly from start to finish—they draft, check, revise, and often restart. Verification-refinement pipelines encode this human workflow into AI systems.
The Pipeline Architecture
A typical verification-refinement pipeline has four stages:
Stage 1: Diverse Generation
Generate many candidate solutions using different strategies. Diversity matters more than quantity: eight diverse approaches beat 64 slight variations. Techniques include:
- Different prompts ("use algebra", "try coordinates", "work backwards")
- Different temperatures (0.3 for focused, 0.9 for creative)
- Different models (ensemble across GPT-5, Claude, Gemini)
Stage 2: Verification
Each candidate is verified by a separate model (or the same model with a "critic" prompt). Verification looks for:
- Logical validity: Does each step follow from the previous?
- Calculation correctness: Are the computations right?
- Completeness: Are all cases handled?
- Goal achievement: Does the proof actually prove what's claimed?
Verification outputs a confidence score and a list of issues.
Stage 3: Refinement
Candidates with high potential but specific issues enter a refinement loop. The model receives:
- The original problem
- Its attempted solution
- The specific issues identified
- Instructions to fix those issues
This is far more efficient than regenerating from scratch—often 80% of a solution is correct, and only a specific step needs fixing.
Stage 4: Selection
After verification and refinement, select the best solution. Criteria include:
- Verification confidence (higher is better)
- Refinement stability (solutions that improve with refinement are promising)
- Answer agreement (if multiple solutions reach the same answer, it's likely correct)
Why This Works Better Than Pure Generation
Verification-refinement pipelines lift performance from 21-38% single-shot accuracy to 5 of 6 problems solved at IMO 2025. The improvement comes from:
- Error recovery: Single-shot either works or fails. Refinement can recover from errors.
- Computational allocation: Spend more compute on hard problems, less on easy ones.
- Ensemble effects: Different approaches find different solutions; verification picks the best.
- Human-like iteration: Mimics how expert mathematicians actually work.
class IMO2025Pipeline:
"""
Model-agnostic verification-and-refinement pipeline.
Based on arxiv 2507.15855:
"Winning Gold at IMO 2025 with a Model-Agnostic
Verification-and-Refinement Pipeline"
Results: 21-38% baseline → 5/6 problems solved
Tested with: Gemini 2.5 Pro, GPT-5, Grok-4
"""
def __init__(
self,
model,
n_candidates: int = 32,
max_refinement_rounds: int = 5
):
self.model = model
self.n_candidates = n_candidates
self.max_rounds = max_refinement_rounds
def solve_imo_problem(self, problem: str) -> dict:
"""
Full pipeline for IMO problem.
Steps:
1. Generate diverse candidate solutions
2. Verify each candidate
3. Refine promising candidates
4. Select best verified solution
"""
# Stage 1: Diverse generation
candidates = self.generate_diverse_candidates(problem)
# Stage 2: Initial verification
verified = []
for candidate in candidates:
verification = self.verify_solution(problem, candidate)
verified.append({
"solution": candidate,
"verification": verification,
"score": verification["confidence"]
})
# Sort by verification score
verified.sort(key=lambda x: x["score"], reverse=True)
# Stage 3: Refine top candidates
refined = []
for candidate in verified[:8]: # Top 8
refined_solution = self.refine_until_verified(
problem,
candidate["solution"]
)
refined.append(refined_solution)
# Stage 4: Select best
best = max(refined, key=lambda x: x["final_score"])
return {
"solution": best["solution"],
"confidence": best["final_score"],
"refinement_rounds": best["rounds"],
"verified": best["verified"]
}
def generate_diverse_candidates(self, problem: str) -> list[str]:
"""
Generate diverse solution candidates.
Key prompts for diversity:
1. Different approaches (algebraic, geometric, combinatorial)
2. Different starting points
3. Temperature variation
"""
candidates = []
# Approach-specific prompts
approaches = [
"Solve using algebraic manipulation:",
"Solve using coordinate geometry:",
"Solve using combinatorial arguments:",
"Solve using number theory techniques:",
"Solve by considering extreme cases first:",
"Solve by working backwards from the goal:",
]
for approach in approaches:
for temp in [0.3, 0.7, 1.0]:
prompt = f"""{approach}
Problem: {problem}
Provide a complete, rigorous proof.
Solution:"""
solution = self.model.generate(
prompt,
max_tokens=4096,
temperature=temp
)
candidates.append(solution)
return candidates
def verify_solution(
self,
problem: str,
solution: str
) -> dict:
"""
Verify solution with detailed checking.
"""
prompt = f"""Carefully verify this IMO solution.
Problem: {problem}
Proposed Solution:
{solution}
Check each step for:
1. Mathematical correctness
2. Logical validity
3. Completeness of argument
4. Proper handling of all cases
Rate your confidence that this proof is correct (0-100).
If there are errors, describe them precisely.
Verification:"""
response = self.model.generate(prompt, max_tokens=2048)
# Parse confidence and issues
confidence = self.extract_confidence(response)
issues = self.extract_issues(response)
return {
"confidence": confidence,
"issues": issues,
"is_valid": confidence > 80 and len(issues) == 0
}
def refine_until_verified(
self,
problem: str,
solution: str
) -> dict:
"""
Iteratively refine solution until verified.
"""
current = solution
for round_num in range(self.max_rounds):
verification = self.verify_solution(problem, current)
if verification["is_valid"]:
return {
"solution": current,
"final_score": verification["confidence"],
"rounds": round_num,
"verified": True
}
# Refine based on issues
current = self.refine_solution(
problem,
current,
verification["issues"]
)
# Max rounds, return best attempt
final_verification = self.verify_solution(problem, current)
return {
"solution": current,
"final_score": final_verification["confidence"],
"rounds": self.max_rounds,
"verified": final_verification["is_valid"]
}
def refine_solution(
self,
problem: str,
solution: str,
issues: list[str]
) -> str:
"""
Refine solution to address specific issues.
"""
issues_text = "\n".join(f"- {issue}" for issue in issues)
prompt = f"""Fix the following issues in this IMO solution.
Problem: {problem}
Current Solution:
{solution}
Issues to Fix:
{issues_text}
Provide a corrected, complete proof that addresses ALL issues.
Corrected Solution:"""
return self.model.generate(prompt, max_tokens=4096)
class GeminiDeepThinkPipeline:
"""
Gemini Deep Think's approach for IMO 2025 gold.
Key innovations:
1. Parallel thinking: explore multiple paths simultaneously
2. Novel RL for multi-step reasoning
3. End-to-end natural language (no Lean translation)
"""
def __init__(self, model):
self.model = model
def solve_with_parallel_thinking(
self,
problem: str,
n_paths: int = 8
) -> dict:
"""
Explore multiple solution paths in parallel.
Unlike sequential refinement, parallel thinking
explores diverse approaches simultaneously,
then combines insights.
"""
# Generate parallel solution attempts
paths = []
for i in range(n_paths):
path = self.generate_solution_path(problem, seed=i)
paths.append(path)
# Cross-pollinate insights
insights = self.extract_insights(paths)
# Generate final solution using all insights
final = self.synthesize_solution(problem, insights, paths)
return {
"solution": final,
"paths_explored": len(paths),
"key_insights": insights
}
def generate_solution_path(
self,
problem: str,
seed: int
) -> dict:
"""
Generate one solution path.
Different seeds lead to different starting strategies.
"""
strategies = [
"start_with_examples",
"start_with_extreme_cases",
"start_with_algebraic_setup",
"start_with_geometric_intuition",
"start_with_invariants",
"start_with_working_backwards",
"start_with_contradiction",
"start_with_induction"
]
strategy = strategies[seed % len(strategies)]
prompt = f"""Solve this IMO problem using the {strategy} approach.
Problem: {problem}
Think deeply and show your complete reasoning.
Solution:"""
solution = self.model.generate(
prompt,
max_tokens=8192, # Extended thinking
temperature=0.8
)
return {
"strategy": strategy,
"solution": solution,
"key_steps": self.extract_key_steps(solution)
}
def extract_insights(self, paths: list[dict]) -> list[str]:
"""
Extract key insights from all solution paths.
"""
all_steps = []
for path in paths:
all_steps.extend(path["key_steps"])
# Find common themes
prompt = f"""These are key steps from multiple solution attempts:
{chr(10).join(all_steps[:30])}
What are the most important mathematical insights?
What techniques appear most promising?
Key Insights:"""
response = self.model.generate(prompt, max_tokens=1024)
return self.parse_insights(response)
def synthesize_solution(
self,
problem: str,
insights: list[str],
paths: list[dict]
) -> str:
"""
Synthesize final solution from insights and paths.
"""
insights_text = "\n".join(f"- {i}" for i in insights)
# Find most promising path
best_path = max(paths, key=lambda p: len(p["key_steps"]))
prompt = f"""Create a complete, rigorous solution.
Problem: {problem}
Key Insights to Use:
{insights_text}
Most Promising Approach:
{best_path['solution'][:2000]}
Create a clear, complete proof incorporating these insights.
Final Solution:"""
return self.model.generate(prompt, max_tokens=8192)
How Gemini Deep Think Actually Works: The Full Technical Picture
Gemini Deep Think achieved 35/42 points (gold medal) at IMO 2025 using an approach fundamentally different from previous systems. Here's the complete technical breakdown.
The Paradigm Shift: Why Google Abandoned Formal Verification
To understand Gemini Deep Think, we need to understand what Google tried before and why they changed direction.
The AlphaProof approach (IMO 2024):
At IMO 2024, Google's AlphaProof achieved silver medal using formal verification in Lean 4:
- Problem (natural language) → Translate to Lean statement
- Neural model generates candidate proof steps
- Lean type-checker verifies each step
- MCTS explores the proof space
- Valid proof → Extract answer
This worked, but had critical limitations:
- Translation bottleneck: ~30% of IMO problems resist clean formalization. Combinatorics and number theory problems with complex constructions often require ad-hoc Lean definitions that are as hard to find as the proof itself.
- Formalization overhead: Even for problems that can be formalized, the Lean proof is often 10x longer than the natural language proof. This inflates the search space.
- No transfer to new domains: Lean proofs require domain-specific libraries. A geometry proof uses different tactics than an algebra proof. Each IMO has novel problem types that may not have good library support.
- Speed: Formal verification is slow. Each step requires Lean compilation and type-checking, limiting how many paths can be explored in competition time.
Google's realization:
After IMO 2024, Google's team made a key observation: the best human mathematicians don't use formal verification. They reason in natural language, using intuition, pattern recognition, and self-correction.
What if, instead of making AI more like Lean, they made AI more like a human mathematician?
The Deep Think Architecture
Gemini Deep Think is built on Google's frontier Gemini foundation model, with specialized training for extended mathematical reasoning.
Core capabilities:
| Capability | How It's Achieved |
|---|---|
| Extended thinking | Can generate up to 32K tokens of reasoning before answering |
| Self-verification | Checks its own work, identifies errors, and corrects them |
| Parallel exploration | Considers multiple solution approaches simultaneously |
| Strategy adaptation | Recognizes which proof techniques suit which problem types |
| Natural language rigor | Produces proofs that are rigorous despite not being formally verified |
Why "Deep Think"?
The name reflects the core innovation: thinking deeper before answering. Standard LLMs generate tokens left-to-right, each token depending only on previous tokens. Deep Think can:
- Generate a long internal reasoning chain (thousands of tokens)
- Evaluate whether it's making progress
- Backtrack and try different approaches
- Synthesize insights from failed attempts
- Only produce a final answer when confident
This is closer to human expert reasoning than typical LLM behavior.
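Google has not published Deep Think's internals, so the following is only a conceptual sketch of the think/evaluate/backtrack loop described above. Every interface here (the generate and score_progress callables, the thresholds) is a hypothetical stand-in, not Google's implementation.
from typing import Callable

def extended_thinking(
    generate: Callable[[str, int], str],           # (prompt, max_tokens) -> text
    score_progress: Callable[[list[str]], float],  # reasoning chain -> score in [0, 1]
    problem: str,
    max_rounds: int = 6,
) -> str:
    """Illustrative think/evaluate/backtrack loop (hypothetical interfaces)."""
    chain: list[str] = []
    for _ in range(max_rounds):
        thought = generate(
            f"Problem: {problem}\nReasoning so far:\n" + "\n".join(chain) + "\nContinue:",
            2048,
        )
        # Self-assess: is the new step making progress?
        if score_progress(chain + [thought]) < 0.3 and chain:
            chain.pop()                  # backtrack: drop the weakest recent step
        else:
            chain.append(thought)        # keep the step and go deeper
        if score_progress(chain) > 0.95:
            break                        # confident enough to answer
    return generate(
        f"Problem: {problem}\nReasoning:\n" + "\n".join(chain) + "\nFinal answer:", 1024
    )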
The Training Pipeline: How Google Built Deep Think
Stage 1: Mathematical Pre-Training
Before Deep Think training, the base Gemini model underwent extensive mathematical pre-training:
- Corpus: Mathematical research papers (arXiv math), textbooks, competition solutions, proof assistant formalizations, educational content
- Scale: Estimated 100B+ tokens of mathematical content
- Special handling: LaTeX rendering, mathematical notation, structured proofs
- Verification: Solutions in training data were verified by independent systems where possible
Stage 2: Curated Solution Corpus
Google built a high-quality corpus specifically for IMO-style reasoning. This corpus was hand-curated by mathematicians, not scraped from the web:
- Sources:
  - IMO shortlist problems and official solutions (1959-2024)
  - Putnam competition solutions (1938-2024)
  - National olympiad solutions from 60+ countries
  - Art of Problem Solving "gold standard" solutions
  - Solutions from mathematical olympiad textbooks
- Quality criteria:
  - Every solution verified correct by at least 2 reviewers
  - Solutions must be complete (no gaps, no "obvious" steps)
  - Multiple alternative solutions per problem where available
  - Step-by-step reasoning chains explicitly marked
- Annotation:
  - Key insight annotations: "This is where the crucial observation happens"
  - Difficulty ratings per step
  - Alternative approaches noted
  - Common mistakes on similar problems
This corpus is much smaller than typical LLM training data (tens of thousands of problems vs. billions of tokens) but vastly higher quality.
Stage 3: Multi-Step Reinforcement Learning
Standard RLHF gives one reward at the end. For 30-step proofs, this provides almost no signal about which steps were good.
Google's innovation: reward at every step.
They trained a Process Reward Model (PRM) that scores intermediate reasoning steps:
Step 1: "Let x = the smallest element..." → PRM score: 0.7
Step 2: "Since x is minimal, x ≤ f(x)..." → PRM score: 0.85
Step 3: "Therefore f(x) = x + 1..." → PRM score: 0.3 (error introduced)
The generator is trained with rewards at every step, allowing credit assignment deep in the reasoning chain. Steps that introduce key insights get high reward; steps that introduce errors get low reward.
The reward function:
R(step, context) = 0.4 × correctness(step)
+ 0.3 × progress_toward_solution(step)
+ 0.2 × insight_value(step)
+ 0.1 × coherence_with_previous(step)
Each component is learned from human annotations:
- Correctness: Is this step mathematically valid?
- Progress: Does this step get us closer to the solution?
- Insight value: Does this step contain a key observation?
- Coherence: Does this step follow logically from previous steps?
Stage 4: Parallel Exploration Training
Humans solving hard problems don't try one approach and give up. They explore multiple angles, extract insights from failed attempts, and combine ideas.
Google trained Deep Think to do the same:
- Generate N parallel solution attempts (N=8 typically)
- Each attempt uses a different initial strategy
- Evaluate which approaches made progress
- Extract insights from partially successful attempts
- Synthesize a final solution combining the best ideas
The key innovation: rewarding exploration diversity. The model is rewarded not just for finding solutions, but for trying genuinely different approaches. This prevents mode collapse where the model always tries the same strategy.
Diversity metric:
The first few reasoning steps of each path are embedded, and paths with high pairwise distances (meaning they're trying different things) get bonus reward.
Why Natural Language Works for IMO
This seems counterintuitive. Formal verification gives 100% correctness guarantees. How can natural language compete?
The key insight: IMO doesn't need 100% correctness guarantee.
IMO scoring gives partial credit. A mostly-correct solution with a minor gap can still earn 5-6 points out of 7. Formal verification's all-or-nothing approach gives up these partial credit points.
More importantly, natural language allows reasoning about concepts that are hard to formalize:
| Concept | Natural Language | Lean 4 |
|---|---|---|
| "Obviously continuous" | Immediate | Define continuity, prove epsilon-delta |
| "Standard argument" | 1 sentence | Import library, match pattern |
| "By symmetry" | 1 word | Define symmetry group, prove invariance |
| "Without loss of generality" | 1 phrase | Complex case analysis |
Human mathematicians use these shortcuts constantly. They work because mathematicians share conventions about what "obvious" and "standard" mean. Deep Think has learned these conventions from training data.
Self-verification compensates for formality:
Without formal verification, how do we know the proof is correct? Deep Think uses multi-round self-verification:
- Generate complete proof
- A separate "critic" pass reads the proof and looks for errors
- If errors found, identify the problematic step
- Regenerate from that point
- Re-verify
- Repeat until clean or max iterations
At IMO 2025, Deep Think's self-verification caught ~15% of the errors that would have been in single-shot generation.
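A minimal sketch of this critic loop, under the assumption that the proof is kept as numbered steps, one per line; the prompts and the "STEP k: issue" critique format are illustrative, not Deep Think's actual interface. The key difference from whole-proof regeneration is that only the suffix after the first flawed step is rewritten:
def self_verify_and_repair(model, problem: str, max_rounds: int = 3) -> str:
    """Sketch of multi-round self-verification: a critic pass locates the
    first flawed step, and the proof is regenerated from that point."""
    proof_steps = model.generate(
        f"Problem: {problem}\nWrite a proof as numbered steps, one per line:",
        max_tokens=4096
    ).split("\n")
    for _ in range(max_rounds):
        critique = model.generate(
            "Review this proof. Reply 'CLEAN' if correct, or 'STEP <k>: <issue>' "
            f"for the first flawed step.\nProblem: {problem}\nProof:\n"
            + "\n".join(proof_steps),
            max_tokens=512,
        )
        if critique.strip().startswith("CLEAN"):
            break
        # Parse the flawed step index (assumes the critique follows the format above)
        k = int(critique.split("STEP")[1].split(":")[0])
        prefix = proof_steps[: k - 1]                 # keep the verified prefix
        repaired = model.generate(
            f"Problem: {problem}\nVerified steps:\n" + "\n".join(prefix)
            + f"\nIssue found: {critique}\nContinue the proof from step {k}:",
            max_tokens=4096,
        )
        proof_steps = prefix + repaired.split("\n")
    return "\n".join(proof_steps)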
Competition-Time Procedure
At the actual IMO, Deep Think operated under competition constraints: 4.5 hours for 6 problems.
Time allocation:
| Problem Difficulty | Time Budget | Parallel Paths |
|---|---|---|
| Easy (P1, P4) | 20-30 min | 4 paths |
| Medium (P2, P5) | 40-60 min | 8 paths |
| Hard (P3, P6) | 60-90 min | 12 paths |
The competition procedure:
For each problem:
1. Quick assessment (2-3 min)
- Problem type classification
- Difficulty estimation
- Likely proof techniques
2. Parallel exploration (60% of time budget)
- Launch N parallel solution attempts
- Each uses different strategy seed
- Monitor progress, prune failing paths
- Extract insights from partial progress
3. Solution synthesis (20% of time budget)
- Combine insights from all paths
- Generate complete proof using best approach
- Self-verify
4. Refinement (20% of time budget)
- Multiple verification passes
- Fix any issues found
- Polish presentation
Early termination:
If a path achieves >0.95 confidence score before time budget expires, Deep Think terminates early and moves to verification. This optimization saved ~15% of total time at IMO 2025.
Results Breakdown: What Deep Think Solved at IMO 2025
| Problem | Type | Points | How It Was Solved |
|---|---|---|---|
| P1 | Algebra | 7/7 | Direct approach, found key inequality |
| P2 | Combinatorics | 7/7 | Case analysis with invariant |
| P3 | Number Theory | 7/7 | Modular arithmetic + descent |
| P4 | Geometry | 7/7 | Coordinate geometry |
| P5 | Algebra | 7/7 | Functional equation techniques |
| P6 | Combinatorics | 0/7 | Failed to find construction |
| Total | | 35/42 | Gold medal |
P6 was the only problem Deep Think failed. It required constructing a specific combinatorial object with unusual properties—exactly the kind of creative construction that remains challenging.
Comparison with DeepSeekMath-V2
| Aspect | Gemini Deep Think | DeepSeekMath-V2 |
|---|---|---|
| Approach | End-to-end natural language | Self-verification loop |
| Formal verification | None | None |
| Model size | Gemini 2.0 scale (~1T?) | 685B (37B active) |
| Training innovation | Multi-step RL, diversity reward | Verifier-first pipeline |
| Inference innovation | Parallel exploration | Generate-verify-refine loop |
| IMO 2025 score | 35/42 (Gold) | 35/42 (Gold) |
| Strengths | Novel problems, geometry | Rigorous proofs, algebra |
| Weaknesses | Some constructions | Some geometry |
Both systems achieved gold, using fundamentally different approaches. This suggests multiple viable paths to mathematical reasoning, not a single "correct" architecture.
┌─────────────────────────────────────────────────────────────────────────┐
│ GEMINI DEEP THINK: COMPLETE ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ KEY BREAKTHROUGH: End-to-End Natural Language │
│ ───────────────────────────────────────────────── │
│ │
│ Previous approaches (AlphaProof, DeepSeek-Prover): │
│ • Problem (NL) → Translate to Lean → Prove in Lean → Extract answer │
│ • Translation bottleneck: ~30% of problems lost in translation │
│ • Requires formal verification infrastructure │
│ │
│ Gemini Deep Think: │
│ • Problem (NL) → Reason in natural language → Answer │
│ • No formal verification dependency │
│ • Scales with general reasoning capability │
│ • Handles ALL problem types (including those hard to formalize) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PARALLEL THINKING ARCHITECTURE: │
│ ─────────────────────────────── │
│ │
│ ┌─────────────┐ │
│ │ Problem │ │
│ └──────┬──────┘ │
│ │ │
│ ┌──────────────────────┼──────────────────────┐ │
│ ▼ ▼ ▼ │
│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │
│ │ Path A: │ │ Path B: │ │ Path C: │ │
│ │ Algebraic │ │ Geometric │ │ Invariant │ │
│ │ Transform │ │ Intuition │ │ Search │ ... N │
│ └────────────┘ └────────────┘ └────────────┘ │
│ │ │ │ │
│ │ Explore │ Explore │ Explore │
│ ▼ ▼ ▼ │
│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │
│ │ Dead │ │ Promising │ │ Promising │ │
│ │ End │ │ Lead │ │ Lead │ │
│ └────────────┘ └────────────┘ └────────────┘ │
│ │ │ │
│ └──────────┬───────────┘ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Cross-Pollinate │ │
│ │ Insights │ │
│ └────────┬────────┘ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Synthesize │ │
│ │ Final Solution │ │
│ └─────────────────┘ │
│ │
│ Unlike sequential CoT, paths explore SIMULTANEOUSLY │
│ Insights flow between paths during exploration │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Training Pipeline: Novel RL for Multi-Step Mathematical Reasoning
class GeminiDeepThinkTraining:
"""
How Gemini Deep Think was actually trained.
Key innovations:
1. Curated high-quality corpus: Unlike typical web scraping,
solutions were carefully selected for correctness and clarity
2. Multi-step RL: Reward signal at each reasoning step, not just final
3. Parallel exploration reward: Bonus for exploring diverse approaches
"""
def create_training_corpus(self):
"""
Step 1: Build high-quality solution corpus.
Google's approach differs from typical math training:
- Human mathematicians reviewed and curated solutions
- Only "gold standard" proofs included (correct + elegant)
- Multiple solution approaches per problem
- Step-by-step reasoning chains annotated
"""
corpus = {
"sources": [
"IMO shortlist solutions (1959-2024)",
"Putnam exam solutions (1938-2024)",
"National olympiad solutions (60+ countries)",
"Art of Problem Solving curated solutions",
"Mathematical olympiad textbook proofs",
"Peer-reviewed competition mathematics papers"
],
"quality_criteria": {
"correctness": "Verified by multiple reviewers",
"completeness": "All steps justified, no gaps",
"clarity": "Each step flows logically",
"elegance": "Efficient approach without redundancy"
},
"annotation": {
"step_boundaries": "Where each reasoning step starts/ends",
"key_insights": "Critical observations that unlock solution",
"alternative_approaches": "Other valid paths to solution",
"difficulty_scores": "Per-step difficulty assessment"
}
}
return corpus
def multi_step_rl_objective(self, trajectory, ground_truth):
"""
Step 2: Multi-step reinforcement learning.
Key difference from standard RLHF:
- Reward at EACH reasoning step, not just final answer
- Process reward model scores intermediate steps
- Allows credit assignment deep in reasoning chain
"""
total_reward = 0
for step_idx, step in enumerate(trajectory.steps):
# Step-level correctness reward
step_correct = self.step_verifier(
step,
context=trajectory.steps[:step_idx],
problem=trajectory.problem
)
# Progress reward: is this step making progress toward solution?
progress = self.progress_estimator(
step,
ground_truth_solution=ground_truth
)
# Insight reward: does this step contain key mathematical insight?
insight_score = self.insight_detector(step)
# Combine rewards with temporal discounting
step_reward = (
0.4 * step_correct +
0.3 * progress +
0.2 * insight_score +
0.1 * self.coherence_score(step, trajectory.steps[:step_idx])
)
total_reward += step_reward * (0.99 ** step_idx)
# Final answer correctness (still important!)
final_correct = self.verify_final_answer(trajectory, ground_truth)
total_reward += 2.0 * final_correct # Higher weight for getting it right
return total_reward
def parallel_exploration_training(self, problem, solutions):
"""
Step 3: Train for parallel exploration.
Novel RL objective: reward exploring DIVERSE approaches,
not just finding one that works.
"""
# Generate N parallel solution attempts
paths = []
for i in range(8): # 8 parallel paths
path = self.generate_path(problem, strategy_seed=i)
paths.append(path)
# Measure diversity of approaches
diversity_score = self.compute_diversity(paths)
# Find successful paths
successful_paths = [p for p in paths if p["reached_solution"]]
# Cross-pollination reward: did insights transfer between paths?
cross_pollination = self.measure_insight_transfer(paths)
# Final objective
objective = (
0.3 * diversity_score + # Reward exploring different approaches
0.4 * len(successful_paths) / len(paths) + # Success rate
0.3 * cross_pollination # Insight sharing between paths
)
return objective
def compute_diversity(self, paths):
"""
Measure how different the approaches are.
Want model to try genuinely different strategies,
not just variations of the same approach.
"""
# Encode each path's approach
approach_embeddings = [
self.approach_encoder(path["steps"][:5]) # First 5 steps define approach
for path in paths
]
# Compute pairwise distances
total_distance = 0
for i in range(len(paths)):
for j in range(i+1, len(paths)):
total_distance += cosine_distance(
approach_embeddings[i],
approach_embeddings[j]
)
# Normalize
num_pairs = len(paths) * (len(paths) - 1) / 2
return total_distance / num_pairs
class IMOCompetitionInference:
"""
How Gemini Deep Think operates during actual IMO competition.
Key constraint: 4.5 hours for 6 problems (45 min average)
"""
def __init__(self):
self.time_budget_per_problem = 45 * 60 # 45 minutes in seconds
self.n_parallel_paths = 8
def solve_problem_under_time_constraint(self, problem, deadline):
"""
Competition-mode solving with time management.
"""
start_time = time.time()
# Phase 1: Quick assessment (2-3 minutes)
# Determine problem type, likely difficulty, promising approaches
assessment = self.quick_assess(problem)
if assessment["likely_difficulty"] == "easy":
# For easier problems, go deeper on fewer paths
n_paths = 4
time_per_path = (deadline - time.time()) / n_paths
else:
# For harder problems, explore more broadly
n_paths = 12
time_per_path = (deadline - time.time()) / n_paths
# Phase 2: Parallel exploration
paths = []
for i in range(n_paths):
if time.time() > deadline - 60: # Keep 1 min for synthesis
break
path = self.explore_path_with_timeout(
problem,
strategy_seed=i,
timeout=time_per_path
)
paths.append(path)
# Early exit if high-confidence solution found
if path.get("confidence", 0) > 0.95:
break
# Phase 3: Synthesize best solution
remaining_time = deadline - time.time()
solution = self.synthesize_with_verification(
problem, paths, time_budget=remaining_time
)
return solution
def quick_assess(self, problem):
"""
Rapid problem assessment (no deep computation).
"""
# Pattern matching against known problem types
problem_type = self.classify_problem_type(problem) # geometry, algebra, etc.
# Historical difficulty of similar problems
difficulty = self.estimate_difficulty(problem)
# Identify most promising approaches for this type
approaches = self.suggest_approaches(problem_type, problem)
return {
"problem_type": problem_type,
"likely_difficulty": difficulty,
"suggested_approaches": approaches
}
Why Natural Language Works Better Than Formal Proofs for IMO
# The surprising finding: for IMO specifically, natural language
# reasoning outperformed formal verification systems
formal_vs_natural = {
"formal_proof_systems": {
"advantages": [
"100% correctness guarantee when proof compiles",
"No hand-waving or gaps possible",
"Can leverage existing proof libraries"
],
"disadvantages_for_IMO": [
"Translation bottleneck: ~30% of problems hard to formalize",
"Combinatorics/number theory: formalization often longer than proof",
"Novel problem types may lack library support",
"Creativity constrained by formal language limitations"
]
},
"natural_language_reasoning": {
"advantages_for_IMO": [
"No translation loss - directly reason about problem as stated",
"Full flexibility in approach and notation",
"Can handle any problem type including novel constructions",
"Scales with general language model capability"
],
"challenges": [
"Risk of subtle errors (mitigated by parallel verification)",
"Need extensive training on high-quality proofs",
"Harder to guarantee correctness"
]
}
}
# Result at IMO 2025:
# Gemini Deep Think (natural language): 35/42 = Gold
# AlphaProof (Lean 4, IMO 2024): Silver medal, and required formal translation
# Deep Think solved problems that were difficult to formalize
Part 13: AIMO Competition Winning Techniques
The AI Mathematical Olympiad (AIMO) Progress Prizes have become the premier benchmark for practical AI math systems. Unlike academic benchmarks, AIMO prizes real money ($1M+ total) for systems that can solve new, unseen problems under competition conditions.
What Makes AIMO Different from Academic Benchmarks
Real competition conditions:
- Problems are new: Not in any training data
- Time pressure: Limited submissions, time constraints
- Adversarial design: Problems designed to resist current techniques
- Generalization test: Can't overfit to known problem styles
Prize structure incentivizes practical systems:
- Progress Prize 1: $131,072 awarded for top performance on the private test set
- Progress Prize 2: Increasing thresholds, larger prizes
- Grand Prize: $5M for IMO gold-medal level
Lessons from Winning Systems
AIMO competitions have produced hard-won knowledge about what actually works. Unlike academic benchmarks where you can optimize for a fixed test set, AIMO prizes require solving genuinely new problems under competition constraints. The winning systems share common patterns: tool integration for calculation reliability, fine-tuning on diverse math data, and ensemble/voting strategies for error detection. Here's what the winners did and why it worked.
NuminaMath (AIMO 1 Winner):
The winning team's key insight: combine fine-tuning with tool use. Their recipe:
- Base model: DeepSeekMath-Base 7B (already strong at math)
- Fine-tuning data: ~1M problems including synthetic, competitions, textbooks
- Tool integration: Python code interpreter for calculations
- Inference: SC-TIR with 64 samples, majority vote
This relatively simple recipe achieved 29/50 on the private test—significantly above the 20/50 threshold for the prize.
Why 7B with tools beat larger models:
- Tool use eliminates arithmetic errors (the #1 failure mode)
- Fine-tuning specializes the model for competition math
- Self-consistency catches reasoning errors
- Smaller models allow more samples within time/cost budget
The technique stack that works:
| Technique | Contribution | Cost |
|---|---|---|
| Math fine-tuning | +15-20% accuracy | Training once |
| Tool use (TIR) | +10-15% accuracy | Minimal |
| Self-consistency (64 samples) | +5-10% accuracy | 64× compute |
| Better prompting | +3-5% accuracy | Free |
Stacking all techniques turns a mediocre 7B model into a competition winner.
The AI Mathematical Olympiad (AIMO) Progress Prizes have driven practical advances in mathematical reasoning. Here's what winning teams used:
┌─────────────────────────────────────────────────────────────────────────┐
│ AIMO PROGRESS PRIZE RESULTS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ AIMO 1 (July 2024): NuminaMath - $131,072 │
│ • Fine-tuned DeepSeekMath-Base 7B │
│ • Tool-Integrated Reasoning (TIR) │
│ • Self-Consistency TIR (SC-TIR) decoding │
│ • 29/50 problems solved │
│ │
│ AIMO 2 (2024): NVIDIA NemoSkills │
│ • Larger models with improved prompting │
│ • Enhanced code execution │
│ │
│ AIMO 3 (2025): Increased difficulty, larger prizes │
│ • Problems designed to resist current techniques │
│ • Focus on genuine mathematical reasoning │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ COMPETITION STATISTICS (AIMO 1): │
│ ───────────────────────────────── │
│ │
│ • 16,104 registrations │
│ • 1,401 participants on 1,161 teams │
│ • 1,831 submissions from 81 countries │
│ • Prize awarded by Terence Tao at IMO Bath │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class NuminaMathWinningApproach:
"""
NuminaMath: Winner of AIMO Progress Prize 1.
Key components:
1. Fine-tuned DeepSeekMath-Base 7B
2. Tool-Integrated Reasoning (TIR)
3. Self-Consistency TIR (SC-TIR) decoding
4. ~1M math problems training dataset
Repository: github.com/project-numina/aimo-progress-prize
"""
def __init__(
self,
model_path: str = "AI-MO/NuminaMath-7B-TIR",
n_samples: int = 64
):
self.model = self.load_model(model_path)
self.executor = PythonCodeExecutor()
self.n_samples = n_samples
def solve(self, problem: str) -> dict:
"""
Full NuminaMath solving pipeline.
"""
# Generate multiple TIR solutions
solutions = []
for i in range(self.n_samples):
solution = self.generate_tir_solution(problem)
solutions.append(solution)
# SC-TIR: Self-Consistency with code verification
return self.sc_tir_aggregate(solutions)
def generate_tir_solution(self, problem: str) -> dict:
"""
Generate solution with Tool-Integrated Reasoning.
Format:
<think>reasoning</think>
<code>python code</code>
<output>execution result</output>
... repeat as needed ...
<answer>final answer</answer>
"""
prompt = f"""Solve this math problem step by step.
Use Python code for calculations.
Problem: {problem}
Solution:
<think>"""
response = ""
code_results = []
while True:
# Generate next chunk
chunk = self.model.generate(
prompt + response,
max_tokens=512,
stop=["</code>", "</answer>"]
)
response += chunk
# Handle code execution
if response.endswith("</code>"):
# Execute code block
code = self.extract_latest_code(response)
result = self.executor.run(code)
response += f"\n<output>\n{result}\n</output>\n<think>"
code_results.append({"code": code, "output": result})
elif "</answer>" in response or len(response) > 4000:
break
# Extract answer
answer = self.extract_answer(response)
return {
"reasoning": response,
"answer": answer,
"code_results": code_results,
"valid": answer is not None
}
def sc_tir_aggregate(self, solutions: list[dict]) -> dict:
"""
Self-Consistency TIR aggregation.
1. Filter to solutions with valid code execution
2. Filter to solutions with extracted answers
3. Majority vote
"""
# Filter valid solutions
valid = [s for s in solutions if s["valid"]]
# Filter by code execution success
code_verified = [
s for s in valid
if all("Error" not in r["output"] for r in s["code_results"])
]
# Use code-verified if available, else all valid
pool = code_verified if code_verified else valid
if not pool:
return {"answer": None, "confidence": 0}
# Majority vote
answers = [s["answer"] for s in pool]
answer_counts = Counter(answers)
best_answer, count = answer_counts.most_common(1)[0]
return {
"answer": best_answer,
"confidence": count / len(pool),
"n_agreeing": count,
"n_total": len(pool),
"n_code_verified": len(code_verified)
}
class NuminaMathTrainingData:
"""
NuminaMath's training data pipeline.
~1M math problems with solutions:
- Synthetic problems generated by LLMs
- Filtered for correctness via code execution
- Diverse difficulty levels
"""
def __init__(self, strong_model):
self.model = strong_model
self.executor = PythonCodeExecutor()
def generate_training_pair(
self,
seed_problem: str
) -> dict | None:
"""
Generate training (problem, solution) pair.
"""
# Generate solution with TIR format
solution = self.generate_solution(seed_problem)
# Verify with code execution
if not self.verify_solution(seed_problem, solution):
return None
return {
"problem": seed_problem,
"solution": solution["reasoning"],
"answer": solution["answer"]
}
def create_dataset(
self,
seed_problems: list[str],
augmentation_factor: int = 10
) -> list[dict]:
"""
Create large training dataset.
1. Start with seed problems
2. Augment by varying numbers/contexts
3. Filter by code verification
"""
dataset = []
for problem in tqdm(seed_problems):
# Original
pair = self.generate_training_pair(problem)
if pair:
dataset.append(pair)
# Augmentations
for _ in range(augmentation_factor):
augmented = self.augment_problem(problem)
pair = self.generate_training_pair(augmented)
if pair:
dataset.append(pair)
return dataset
def augment_problem(self, problem: str) -> str:
"""
Create variation of problem.
- Change numbers
- Change context
- Modify constraints
"""
prompt = f"""Create a variation of this math problem.
Change the numbers and some context, but keep the same mathematical structure.
Original: {problem}
Variation:"""
return self.model.generate(prompt, max_tokens=200)
Part 14: Reasoning Model Distillation
Distilling reasoning capabilities from large models to smaller ones enables practical deployment.
Why Distillation Matters for Mathematical Reasoning
The best mathematical reasoning systems are enormous. DeepSeek-R1 is 671B parameters. GPT-o1 and Gemini Deep Think are similarly massive. Running these models requires multiple high-end GPUs, costs dollars per query, and takes tens of seconds per response.
For most applications—homework help, competition practice, educational tools—this is impractical. Distillation offers a solution: transfer the reasoning capabilities of large models into smaller, deployable ones.
The Surprising Effectiveness of Simple Distillation
The DeepSeek team discovered something counterintuitive: simple supervised fine-tuning (SFT) on reasoning traces is remarkably effective.
The traditional view was that reasoning requires the model to "discover" chain-of-thought through reinforcement learning. You can't just show a model reasoning examples and expect it to reason—the model needs to learn the underlying process.
DeepSeek proved this wrong. When they fine-tuned small models (1.5B-70B) on 800K reasoning traces from R1-671B, the small models acquired strong reasoning capabilities. A 7B model trained this way achieves 55.5% on AIME—comparable to much larger models.
Why does this work? The key insight: reasoning patterns are learnable from examples. The large model has discovered effective reasoning strategies (backtracking, self-verification, trying alternatives). When small models see thousands of examples of these strategies, they learn to apply them. The small model doesn't need to discover reasoning from scratch—it just needs to imitate demonstrated reasoning.
The Distillation vs. RL Tradeoff
You might ask: why not just train small models with RL directly? The DeepSeek paper compared both approaches:
| Approach | 7B AIME | Training Cost | Notes |
|---|---|---|---|
| RL from scratch | ~35% | 2 weeks GPU cluster | Discovers own reasoning patterns |
| Distillation from R1 | 55.5% | 2 days, 8 GPUs | Learns teacher's patterns |
Distillation is both cheaper and more effective. The small model benefits from reasoning patterns that the large model discovered over extensive RL training. It's like a student learning from a master's solved examples rather than rediscovering mathematics from first principles.
Quality Over Quantity
The s1 paper demonstrated that very small datasets can work. Their "budget distillation" approach:
- Only ~1,000 high-quality examples
- Total training cost under $50
- Performance competitive with o1-preview on math/coding
The key is quality filtering. Each training example must have:
- Correct final answer (verified by code execution)
- Clear reasoning chain (explicit step-by-step logic)
- No unnecessary verbosity (concise but complete)
- Diverse problem coverage (different domains and difficulty levels)
A small dataset of perfect examples beats a large dataset of mediocre ones.
On-Policy vs. Off-Policy Distillation
Standard distillation is "off-policy": generate all training data upfront, then train the student. On-policy distillation is more sophisticated:
- Student generates partial solution
- Teacher corrects/completes it
- Student learns from correction
- Repeat with updated student
This addresses distribution shift: off-policy training uses the teacher's distribution, but the student needs to perform on its own distribution. On-policy distillation progressively aligns the training distribution with the student's actual behavior.
Results show 4-14% improvement from on-policy over standard distillation.
┌─────────────────────────────────────────────────────────────────────────┐
│ REASONING DISTILLATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ KEY INSIGHT: │
│ ──────────── │
│ │
│ "Reasoning patterns of larger models can be distilled into smaller │
│ models, resulting in better performance compared to the reasoning │
│ patterns discovered through RL on small models." │
│ - DeepSeek R1 Paper │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DISTILLED MODELS: │
│ ───────────────── │
│ │
│ Source: DeepSeek-R1 (671B) │
│ Dataset: 800K reasoning examples │
│ │
│ Target Base Model AIME MATH │
│ ───────────────────────────────────────────────────────────────────── │
│ R1-Distill-1.5B Qwen2.5-1.5B 28.9% 83.9% │
│ R1-Distill-7B Qwen2.5-7B 55.5% 92.8% │
│ R1-Distill-8B Llama-3.1-8B 50.4% 89.1% │
│ R1-Distill-14B Qwen2.5-14B 69.7% 93.9% │
│ R1-Distill-32B Qwen2.5-32B 72.6% 94.3% │
│ R1-Distill-70B Llama-3.3-70B 79.8% 94.5% │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ s1 MODEL (Budget Distillation): │
│ ──────────────────────────────── │
│ │
│ "Researchers created an open rival to o1 for under $50" │
│ │
│ • Distilled from Gemini 2.0 Flash Thinking │
│ • Small dataset + SFT │
│ • Competitive with o1-preview on math/coding │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class ReasoningDistillation:
"""
Distill reasoning capabilities from large to small models.
Key insight: SFT on reasoning traces is surprisingly effective.
No additional RL needed for strong performance.
"""
def __init__(
self,
teacher_model, # Large reasoning model
student_model, # Small base model
tokenizer
):
self.teacher = teacher_model
self.student = student_model
self.tokenizer = tokenizer
def generate_distillation_data(
self,
problems: list[str],
samples_per_problem: int = 3
) -> list[dict]:
"""
Generate distillation dataset from teacher.
For each problem:
1. Teacher generates reasoning traces
2. Filter for correct answers
3. Keep high-quality traces
"""
dataset = []
for problem in tqdm(problems):
correct_traces = []
for _ in range(samples_per_problem * 2): # Oversample
# Generate with teacher
trace = self.teacher.generate(
f"Solve step by step:\n\n{problem}\n\nSolution:",
max_tokens=4096,
temperature=0.7
)
# Check if correct
answer = self.extract_answer(trace)
ground_truth = self.get_ground_truth(problem)
if self.verify_answer(answer, ground_truth):
correct_traces.append({
"problem": problem,
"trace": trace,
"answer": answer
})
if len(correct_traces) >= samples_per_problem:
break
dataset.extend(correct_traces)
return dataset
def distill(
self,
dataset: list[dict],
epochs: int = 3,
learning_rate: float = 2e-5
) -> dict:
"""
Distill reasoning into student model via SFT.
Simple but effective: just train student to
reproduce teacher's reasoning traces.
"""
# Format for training
training_data = []
for item in dataset:
training_data.append({
"input": f"Solve step by step:\n\n{item['problem']}\n\nSolution:",
"output": item["trace"]
})
# Standard SFT
train_dataset = Dataset.from_list(training_data)
training_args = TrainingArguments(
output_dir="./distilled_model",
num_train_epochs=epochs,
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
learning_rate=learning_rate,
warmup_ratio=0.1,
logging_steps=10,
save_strategy="epoch",
bf16=True
)
trainer = Trainer(
model=self.student,
args=training_args,
train_dataset=train_dataset,
tokenizer=self.tokenizer
)
trainer.train()
return trainer.state.log_history
class OnPolicyDistillation:
"""
On-Policy Distillation: More effective than offline.
Instead of using static teacher traces,
generate new traces based on student's current policy.
Results: 4-14% improvement over standard distillation.
"""
def __init__(
self,
teacher,
student,
tokenizer
):
self.teacher = teacher
self.student = student
self.tokenizer = tokenizer
def distill_on_policy(
self,
problems: list[str],
iterations: int = 5
) -> dict:
"""
On-policy distillation loop.
1. Student generates partial solution
2. Teacher completes/corrects
3. Student learns from corrected version
4. Repeat
"""
history = []
for iteration in range(iterations):
print(f"Iteration {iteration + 1}/{iterations}")
# Generate training data on-policy
training_data = []
for problem in tqdm(problems):
# Student generates partial solution
student_attempt = self.student.generate(
f"Solve:\n{problem}\n\nSolution:",
max_tokens=1024,
temperature=0.8
)
# Teacher corrects/completes
teacher_completion = self.teacher.generate(
f"""The student attempted this problem but may have errors.
Problem: {problem}
Student's attempt:
{student_attempt}
Provide the correct, complete solution:
Solution:""",
max_tokens=4096,
temperature=0.3
)
# Verify teacher's solution
if self.verify_solution(problem, teacher_completion):
training_data.append({
"input": f"Solve:\n{problem}\n\nSolution:",
"output": teacher_completion
})
# Train student on this iteration's data
metrics = self.train_student(training_data)
history.append({
"iteration": iteration,
"n_examples": len(training_data),
"metrics": metrics
})
# Evaluate
eval_results = self.evaluate()
print(f" MATH: {eval_results['math']:.1%}")
return history
class BudgetDistillation:
"""
Budget Distillation: create a reasoning model for under $50.
Based on the s1 paper: a small dataset plus careful SFT
can match far more expensive models.
Key findings:
- ~1000 high-quality examples sufficient
- Quality > quantity for reasoning
- Chain-of-thought format crucial
"""
def __init__(self, base_model_name: str = "Qwen/Qwen2.5-7B"):
self.model = AutoModelForCausalLM.from_pretrained(base_model_name)
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
def create_budget_dataset(
self,
source_model, # e.g., Gemini Flash Thinking
seed_problems: list[str],
n_examples: int = 1000
) -> list[dict]:
"""
Create small but high-quality distillation dataset.
Focus on:
1. Diverse problem types
2. High-quality reasoning traces
3. Verified correct answers
"""
dataset = []
# Sample diverse problems
sampled = random.sample(seed_problems, min(n_examples * 2, len(seed_problems)))
for problem in tqdm(sampled):
if len(dataset) >= n_examples:
break
# Generate reasoning trace
trace = source_model.generate(
f"Think step by step and solve:\n\n{problem}",
max_tokens=4096
)
# Quality checks
if not self.quality_check(trace):
continue
# Verify answer
if not self.verify_answer(problem, trace):
continue
dataset.append({
"problem": problem,
"trace": trace
})
return dataset
def quality_check(self, trace: str) -> bool:
"""
Check reasoning trace quality.
"""
# Minimum reasoning length
if len(trace) < 500:
return False
# Has structured thinking
thinking_indicators = [
"let me", "first", "then", "therefore",
"because", "since", "this means"
]
if sum(1 for ind in thinking_indicators if ind in trace.lower()) < 3:
return False
# Not too repetitive
sentences = trace.split(".")
unique_ratio = len(set(sentences)) / len(sentences)
if unique_ratio < 0.7:
return False
return True
def train_budget_model(
self,
dataset: list[dict],
epochs: int = 3
) -> dict:
"""
Train with minimal compute.
Cost breakdown (approximate):
- Dataset generation: ~$30 (API calls)
- Training: ~$20 (GPU time)
- Total: <$50
"""
# Estimate cost
n_tokens = sum(len(d["trace"].split()) * 1.3 for d in dataset)
api_cost = n_tokens / 1_000_000 * 0.15 # Rough estimate
print(f"Estimated API cost: ${api_cost:.2f}")
# Training
training_args = TrainingArguments(
output_dir="./budget_model",
num_train_epochs=epochs,
per_device_train_batch_size=2,
gradient_accumulation_steps=16,
learning_rate=1e-5,
warmup_ratio=0.1,
bf16=True,
# Minimal logging for cost
logging_steps=50,
save_strategy="epoch"
)
# Quick training
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=self.prepare_dataset(dataset),
tokenizer=self.tokenizer
)
trainer.train()
return {"cost_estimate": api_cost + 20} # API + compute
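To tie the budget pipeline together, here is a minimal usage sketch. The names teacher_api and seed_problems are placeholders for a strong reasoning model (or API client) and a list of problems with known answers, and the class's prepare_dataset and verify_answer helpers are assumed to be implemented as described above; treat this as an illustration of the call sequence rather than a drop-in script.
budget = BudgetDistillation(base_model_name="Qwen/Qwen2.5-7B")
# Build ~1000 verified, quality-filtered reasoning traces from the teacher
sft_dataset = budget.create_budget_dataset(
    source_model=teacher_api,     # placeholder: e.g. a hosted reasoning model
    seed_problems=seed_problems,  # placeholder: problems with known answers
    n_examples=1000
)
# Fine-tune the 7B base model on the traces and report the rough cost
report = budget.train_budget_model(sft_dataset, epochs=3)
print(f"Approximate total cost: ${report['cost_estimate']:.2f}")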
Conclusion: The Path to Gold
Building AI systems that win math competitions in 2025 requires combining multiple techniques across training, inference, and system architecture:
For Training (if you have resources)
- Dr. GRPO for bias-free RL training (avoids the response-length inflation of standard GRPO)
- Rule-based rewards (RLVR) for verifiable domains
- Curriculum learning from easy to hard
- Cold-start SFT before RL
- Process Reward Models for step-level feedback
- Self-verification training (DeepSeekMath-V2 approach)
- Self-play evolution (rStar-Math: 7B model achieves 90% on MATH)
For Inference (no training required)
- Constrained MCTS (CMCTS) for principled search
- Tool-Integrated Reasoning (SC-TIR) with code verification
- Verification-refinement pipelines for IMO-level problems
- Parallel thinking (explore multiple approaches simultaneously)
- Self-consistency through majority voting (a minimal voting sketch follows this list)
- Multi-agent debate for hard problems
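Self-consistency is the cheapest of these techniques to implement. Below is a minimal sketch, assuming a sample_answer(problem) callable that runs one temperature-sampled generation and returns the extracted final answer; that helper is hypothetical, not something defined in this post.
from collections import Counter

def self_consistency(problem: str, sample_answer, n_samples: int = 16) -> str:
    """Majority vote over independently sampled final answers."""
    answers = []
    for _ in range(n_samples):
        answer = sample_answer(problem)  # one sampled generation -> extracted final answer
        if answer is not None:
            answers.append(str(answer))
    if not answers:
        return ""
    # The most frequently produced answer wins
    return Counter(answers).most_common(1)[0][0]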
For Agentic Systems
- Tool use (Python, SymPy, Lean 4); see the SymPy checking sketch after this list
- Formal verification (DeepSeek-Prover-V2 for guaranteed correctness)
- Domain-specific strategies (AlphaGeometry2 for geometry: 84%)
- Autonomous tool selection (AutoTIR learns when to use tools)
- Agentic RL (rStar2-Agent: 14B matches 671B models)
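As a concrete flavor of the Python/SymPy tool use listed above, the sketch below checks whether a model's closed-form answer is symbolically equivalent to a reference answer. The check_equivalent helper is illustrative and is not taken from any of the systems named in this post.
import sympy as sp

def check_equivalent(candidate: str, reference: str) -> bool:
    """Return True if two closed-form expressions simplify to the same value."""
    try:
        diff = sp.simplify(sp.sympify(candidate) - sp.sympify(reference))
        return diff == 0
    except (sp.SympifyError, TypeError):
        return False

# Example: "sin(x)**2 + cos(x)**2" is accepted as equivalent to "1"
print(check_equivalent("sin(x)**2 + cos(x)**2", "1"))  # True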
What Works in 2025
| Technique | When to Use | Expected Gain |
|---|---|---|
| Dr. GRPO | Training from scratch | +5-10% vs standard GRPO |
| rStar self-play | No teacher model available | 7B → 90% MATH |
| CMCTS | Inference on hard problems | +8-11% vs standard MCTS |
| SC-TIR | Answer-only problems | +15-25% with code verification |
| Self-verification | Full proof problems | Gold at IMO 2025 |
| Formal verification | When certainty needed | 100% if Lean accepts |
| Distillation | Limited compute | 7B achieves 55% AIME |
Key Takeaways
- Self-verification is the breakthrough: DeepSeekMath-V2 reaches IMO gold and rStar-Math lifts a 7B model to 90% on MATH by training models to find their own errors before submitting
- Tools are mandatory at the frontier: o4-mini jumps 92.7% → 99.5% with Python—the gap is too large to ignore
- Small models can compete: rStar-Math 7B matches o1-preview; rStar2-Agent 14B matches R1 671B
- Formal verification is maturing: DeepSeek-Prover-V2 solves 49 Putnam problems in Lean 4
- Multiple paths are essential: Parallel thinking + verification beats sequential refinement
- Distillation democratizes access: Budget distillation ($50) creates competitive reasoning models
The field advances at unprecedented speed—2025 saw multiple teams achieve IMO gold. The techniques in this guide represent the current frontier; expect rapid iteration.
Related Articles
Test-Time Compute Scaling: CoT, ToT, MCTS, and Search-Based Reasoning
A comprehensive guide to inference-time scaling techniques—Chain of Thought, Tree of Thoughts, Monte Carlo Tree Search, Process Reward Models, and the HuggingFace search-and-learn framework.
Training Reasoning Models: PPO, GRPO, Reward Functions, and RLVR
A deep technical guide to training reasoning models like o1 and DeepSeek R1—covering PPO, GRPO, reward function design, RLVR, and distillation techniques.
Building Agentic AI Systems: A Complete Implementation Guide
A comprehensive guide to building AI agents—tool use, ReAct pattern, planning, memory, context management, MCP integration, and multi-agent orchestration. With full prompt examples and production patterns.