
Building AI Agents to Win Math Olympiads: From Training to Inference Optimization

A comprehensive guide to building AI systems that achieve gold-medal performance on IMO and other advanced math competitions—covering training techniques (GRPO, RLVR), inference-time optimization, and agentic architectures with tool use.

30 min read

The Gold Medal Frontier

In July 2025, an advanced version of Gemini with Deep Think officially achieved gold-medal standard at the International Mathematical Olympiad, solving 5 of 6 problems perfectly within the 4.5-hour time limit. DeepSeekMath-V2 followed with gold-level scores on both IMO 2025 and the China Mathematical Olympiad. AI has crossed a threshold that was once considered decades away.

This post provides a comprehensive technical guide to building AI systems capable of competition-level mathematical reasoning. We'll cover every layer of the stack:

  1. Theoretical foundations: Why these techniques work for math
  2. Training approaches: GRPO, RLVR, curriculum learning, and data strategies
  3. Process Reward Models: Step-level verification for mathematical reasoning
  4. Inference-time optimization: Maximizing performance without training
  5. MCTS and tree search: Structured exploration for complex proofs
  6. Agentic architectures: Tool use, code execution, and formal verification
  7. Real problem walkthroughs: Step-by-step IMO problem solutions
  8. Production deployment: Cost, latency, and scaling considerations

Understanding the Mathematical Reasoning Challenge

The Landscape of Math Competitions

Before diving into techniques, let's understand what we're optimizing for:

| Competition | Format | Difficulty | Time per Problem | Proof Required |
|---|---|---|---|---|
| AMC 10/12 | 25 multiple choice | High school | 3 min avg | No |
| AIME | 15 integer answers (0-999) | Advanced high school | 10 min avg | No |
| USAMO/USAJMO | 6 proof problems | National olympiad | 45 min avg | Yes |
| IMO | 6 proof problems | International olympiad | 45 min avg | Yes |
| Putnam | 12 proof problems | Undergraduate | 25 min avg | Yes |

The jump from AIME (answer-only) to IMO (full proofs) is crucial. AIME problems can be brute-forced with compute—generate many answers, verify numerically. IMO problems require constructing valid mathematical proofs that a human jury would accept.
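
To make the contrast concrete, here is a toy sketch (assumptions mine, not from the competitions themselves) of what "verify numerically" means for answer-only formats: when a question reduces to a checkable condition over a finite search space, pure enumeration settles it, whereas no amount of enumeration produces an IMO-style proof.

Python
def brute_force_answer(condition, search_space):
    """Return every candidate in the search space satisfying the condition."""
    return [n for n in search_space if condition(n)]

# Hypothetical example: "how many integers in [1, 1000] are perfect squares?"
perfect_squares = brute_force_answer(
    lambda n: round(n ** 0.5) ** 2 == n,
    range(1, 1001),
)
print(len(perfect_squares))  # 31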

Current State of the Art (2025)

The landscape has evolved dramatically. Here's the current state as of late 2025:

| System | IMO Performance | AIME 2025 | Training | Open Source |
|---|---|---|---|---|
| Gemini Deep Think | 35/42 (Gold) | 98%+ | Parallel thinking + Novel RL | No |
| DeepSeekMath-V2 | 35/42 (Gold) | 96% | GRPO + Self-verification | Weights only |
| o4-mini (with Python) | Not tested | 99.5% | Unknown | No |
| o3 (with Python) | Not tested | 98.4% | Unknown | No |
| rStar-Math (7B) | N/A | 90% MATH | Self-play MCTS | Yes |
| rStar2-Agent (14B) | N/A | 80.6% | Agentic RL | Yes |
| DeepSeek-Prover-V2 | N/A | 88.9% MiniF2F | RL + Lean 4 | Yes |
| AlphaGeometry2 | 84% geometry | N/A | Neuro-symbolic | Code only |
| QwQ-32B | N/A | 50% | Pure RL | Yes |
| NuminaMath-7B (TIR) | N/A | 58% (AIMO) | SFT + Tools | Yes |

Key observations:

  • Tool integration matters: o4-mini jumps from 92.7% to 99.5% with Python access
  • Small models can compete: rStar-Math 7B matches o1-preview on MATH
  • Self-verification is crucial: DeepSeekMath-V2's generator-verifier loop achieves gold
  • Formal verification advancing: DeepSeek-Prover-V2 solves 49 Putnam problems in Lean

What Makes Math Olympiads Fundamentally Hard

Competition mathematics is qualitatively different from benchmarks like GSM8K or MATH:

Code
GSM8K (Grade School Math):
  "Sarah has 5 apples. She gives 2 to Tom. How many does she have left?"

  → Pattern: Extract numbers, identify operation, compute
  → Single reasoning step
  → 100% verifiable answer

MATH Dataset (High School Competition):
  "Find all positive integers n such that n² + 2n + 2 divides n³ + 3n² + 3n + 2"

  → Pattern: Factor, find divisibility conditions
  → 2-5 reasoning steps
  → Answer verifiable (finite solution set)

IMO Problem (International Olympiad):
  "Let n ≥ 3 be an integer, and consider a circle with n + 1 equally spaced
   points marked on it. Label these points with the numbers 0, 1, ..., n
   so that each label is used exactly once; two labelings are considered
   the same if one can be obtained from the other by a rotation of the circle.

   A labeling is called beautiful if, for any four labels a < b < c < d
   with a + d = b + c, the chord joining the points labeled a and d does
   not intersect the chord joining the points labeled b and c.

   Let M be the number of beautiful labelings. Prove that M is odd."

  → Pattern: None obvious—requires creative insight
  → 10-50+ reasoning steps
  → Must construct valid proof, not just find answer
  → Multiple valid proof approaches exist
  → Partial credit for incomplete but insightful attempts

The key challenges:

  1. Novelty: Each IMO features entirely new problems—pattern matching fails
  2. Proof construction: Must justify every step, not just produce answer
  3. Creative insight: Often requires discovering non-obvious lemmas
  4. Multi-domain: Combines algebra, geometry, number theory, combinatorics
  5. Rigor: Proofs must handle all edge cases and be logically sound

Theoretical Foundations: Why These Techniques Work

Before implementing, let's understand the theory behind math reasoning in LLMs.

The Reasoning-as-Search Framework

Mathematical reasoning can be formalized as search over a proof tree:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    PROOF TREE STRUCTURE                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│                      ┌─────────────────────┐                            │
│                      │  Problem Statement  │                            │
│                      └──────────┬──────────┘                            │
│                                 │                                       │
│          ┌──────────────────────┼──────────────────────┐                │
│          │                      │                      │                │
│          ▼                      ▼                      ▼                │
│   ┌─────────────┐       ┌─────────────┐       ┌─────────────┐           │
│   │ Approach 1  │       │ Approach 2  │       │ Approach 3  │           │
│   │  (Algebra)  │       │ (Geometry)  │       │(Num Theory) │           │
│   └──────┬──────┘       └──────┬──────┘       └──────┬──────┘           │
│          │                     │                     │                  │
│     ┌────┴────┐           ┌────┴────┐           ┌────┴────┐             │
│     │         │           │         │           │         │             │
│     ▼         ▼           ▼         ▼           ▼         ▼             │
│  Step 1.1  Step 1.2     ...       ...         ...       ...             │
│     │                                                                   │
│  ┌──┴──┐                                                                │
│  │     │                                                                │
│  ▼     ▼                                                                │
│ 1.1.1 1.1.2                                                             │
│  │                                                                      │
│  ▼                                                                      │
│ ┌──────────┐                                                            │
│ │ SOLUTION │  ← One successful path through the tree                   │
│ └──────────┘                                                            │
│                                                                         │
│  Each node = reasoning step; each path = potential proof                │
│  Search algorithms explore this tree to find valid proofs               │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

This perspective reveals why different techniques work:

| Technique | Search Mechanism | Strength |
|---|---|---|
| Chain-of-Thought | Depth-first, single path | Fast, simple |
| Self-Consistency | Multiple DFS, vote on leaves | Covers different paths |
| Tree-of-Thought | BFS/DFS with explicit branching | Systematic exploration |
| MCTS | UCB-guided tree search | Balances exploration/exploitation |
| Verification-Refinement | Iterative deepening with backtracking | Error recovery |
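
As a concrete example of the "multiple DFS, vote on leaves" row, here is a minimal self-consistency sketch; generate and extract_answer are hypothetical stand-ins for a sampling call and an answer parser, not APIs defined in this post.

Python
from collections import Counter

def self_consistency(problem, generate, extract_answer, n_samples: int = 16):
    """Sample independent reasoning chains and majority-vote on their final answers."""
    votes = Counter()
    for _ in range(n_samples):
        chain = generate(problem, temperature=0.8)  # one sampled path through the tree
        answer = extract_answer(chain)              # the leaf of that path
        if answer is not None:
            votes[answer] += 1
    return votes.most_common(1)[0][0] if votes else None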

Why Reinforcement Learning Works for Math

The key insight from DeepSeek R1: math problems have verifiable rewards.

In standard RLHF, we need humans to judge response quality—expensive and noisy. For math:

Python
def math_reward(response: str, ground_truth: str) -> float:
    """
    Perfect reward signal—no learned reward model needed.

    This is why RLVR (RL with Verifiable Rewards) works so well
    for mathematics.
    """
    predicted_answer = extract_answer(response)
    return 1.0 if verify_equal(predicted_answer, ground_truth) else 0.0

This creates a perfect training signal:

  • No reward hacking: Can't game the reward by being confident but wrong
  • Dense signal: Every problem provides clear feedback
  • Scalable: Can train on millions of problems automatically
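
The math_reward function above assumes extract_answer and verify_equal helpers that the post does not define; a minimal sketch (my assumption, using sympy so that equivalent forms like 1/2 and 0.5 both count as matches) could look like this:

Python
import re
import sympy

def extract_answer(response: str):
    """Take the contents of the last \\boxed{...} in the response."""
    matches = re.findall(r"\\boxed\{([^{}]+)\}", response)
    return matches[-1].strip() if matches else None

def verify_equal(predicted, ground_truth: str) -> bool:
    """Check symbolic equality rather than raw string equality."""
    if predicted is None:
        return False
    try:
        return sympy.simplify(sympy.sympify(predicted) - sympy.sympify(ground_truth)) == 0
    except (sympy.SympifyError, TypeError):
        return str(predicted).strip() == ground_truth.strip()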

The Emergence of Mathematical Reasoning

DeepSeek R1-Zero demonstrated that pure RL without SFT can develop sophisticated reasoning behaviors:

Code
Training progression (from R1 paper):

Early training:
  "2 + 3 = 5"  (direct answer)

Mid training:
  "Let me calculate 2 + 3.
   2 + 3 = 5"  (minimal reasoning)

Late training:
  "<think>
   I need to find 2 + 3.
   Starting with 2 and adding 3.
   2 + 1 = 3
   3 + 1 = 4
   4 + 1 = 5
   So 2 + 3 = 5.
   Let me verify: 5 - 3 = 2. Correct.
   </think>
   The answer is 5."  (emergent self-verification)

Key emergent behaviors:

  1. Extended reasoning: Longer chains of thought for harder problems
  2. Self-verification: Checking answers before submitting
  3. Alternative approaches: Trying different methods when stuck
  4. Backtracking: Recognizing and correcting errors

Part 1: Training Approaches for Math Reasoning

Training a model for mathematical reasoning is fundamentally different from standard language model fine-tuning. You need a multi-phase approach that gradually builds up reasoning capabilities: first establish the format, then optimize for correctness through reinforcement learning, and finally refine through iterative self-improvement. This section walks through each phase with complete implementations.

The Complete Training Pipeline

Building a math reasoning model requires four distinct phases, each building on the previous. Data collection provides the raw material. Cold-start SFT establishes the reasoning format (the <think>...</think> structure). Reinforcement learning optimizes for actual correctness using verifiable rewards. Finally, rejection sampling creates high-quality training data from the model's own successful solutions. Skip any phase and you'll see degraded performance—they're all necessary.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    TRAINING PIPELINE OVERVIEW                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Phase 1: Data Collection & Preprocessing                               │
│  ┌───────────────────────────────────────────────────────────────────┐ │
│  │  • Competition problems (AMC, AIME, IMO, etc.)                    │ │
│  │  • Textbook exercises with solutions                              │ │
│  │  • Synthetic problem generation                                   │ │
│  │  • Difficulty labeling and domain classification                  │ │
│  └───────────────────────────────────────────────────────────────────┘ │
│                              │                                          │
│                              ▼                                          │
│  Phase 2: Cold-Start SFT (Optional but Recommended)                     │
│  ┌───────────────────────────────────────────────────────────────────┐ │
│  │  • Train on high-quality reasoning traces                         │ │
│  │  • Establish <think>...</think> format                            │ │
│  │  • Bootstrap from stronger model outputs                          │ │
│  └───────────────────────────────────────────────────────────────────┘ │
│                              │                                          │
│                              ▼                                          │
│  Phase 3: Reinforcement Learning (GRPO/PPO)                             │
│  ┌───────────────────────────────────────────────────────────────────┐ │
│  │  • Rule-based rewards for correctness                             │ │
│  │  • Format rewards for reasoning structure                         │ │
│  │  • KL penalty to prevent mode collapse                            │ │
│  │  • Curriculum: easy → hard problems                               │ │
│  └───────────────────────────────────────────────────────────────────┘ │
│                              │                                          │
│                              ▼                                          │
│  Phase 4: Rejection Sampling & Iterative Training                       │
│  ┌───────────────────────────────────────────────────────────────────┐ │
│  │  • Generate many solutions per problem                            │ │
│  │  • Filter for correct + high-quality reasoning                    │ │
│  │  • Fine-tune on filtered data                                     │ │
│  │  • Repeat with updated model                                      │ │
│  └───────────────────────────────────────────────────────────────────┘ │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Data Collection and Preparation

The quality of training data fundamentally limits model capability.

Why data diversity trumps data quantity: A model trained on 100,000 similar algebra problems will be excellent at algebra but fail on geometry. A model trained on 10,000 diverse problems spanning algebra, geometry, number theory, and combinatorics develops flexible problem-solving skills. For math reasoning, you want coverage across problem types, difficulty levels, and solution approaches. The distribution of your training data directly shapes the distribution of problems your model can solve.

The difficulty distribution tradeoff: Too many easy problems, and the model never learns hard reasoning. Too many hard problems, and the model can't learn the fundamentals that hard problems build on. The 30/40/25/5 split (easy/medium/hard/very hard) is empirically effective—it provides a solid foundation while still exposing the model to olympiad-level challenges. Think of it as curriculum learning encoded in your data distribution.

Synthetic data as a force multiplier: Competition problems are scarce (a few thousand IMO problems exist in total). Synthetic generation creates unlimited training data with verified solutions. The key is quality control: generated problems must be solvable, non-trivial, and mathematically interesting. Template-based generation with symbolic verification ensures every synthetic problem has a known-correct answer.

Python
class MathDatasetBuilder:
    """
    Build comprehensive math training dataset.

    Key insight: Diversity matters more than size.
    1000 diverse problems > 100,000 similar problems.
    """

    def __init__(self):
        self.sources = {
            "competition": CompetitionProblemLoader(),
            "textbook": TextbookExerciseLoader(),
            "synthetic": SyntheticProblemGenerator(),
            "web": WebScrapedMathLoader()
        }

        self.difficulty_estimator = DifficultyEstimator()
        self.domain_classifier = DomainClassifier()
        self.deduplicator = SemanticDeduplicator()

    def build_dataset(
        self,
        target_size: int = 500_000,
        difficulty_distribution: dict = None
    ) -> Dataset:
        """
        Build balanced training dataset.

        Target distribution:
        - 30% easy (AMC 10 level)
        - 40% medium (AMC 12 / AIME level)
        - 25% hard (USAMO level)
        - 5% very hard (IMO level)
        """
        difficulty_distribution = difficulty_distribution or {
            "easy": 0.30,
            "medium": 0.40,
            "hard": 0.25,
            "very_hard": 0.05
        }

        raw_problems = []

        # Collect from all sources
        for source_name, loader in self.sources.items():
            problems = loader.load_all()
            print(f"Loaded {len(problems)} from {source_name}")
            raw_problems.extend(problems)

        # Deduplicate (semantic similarity, not exact match)
        unique_problems = self.deduplicator.deduplicate(
            raw_problems,
            similarity_threshold=0.85
        )
        print(f"After deduplication: {len(unique_problems)}")

        # Classify each problem
        classified = []
        for problem in unique_problems:
            classified.append({
                **problem,
                "difficulty": self.difficulty_estimator.estimate(problem["text"]),
                "domain": self.domain_classifier.classify(problem["text"])
            })

        # Balance by difficulty
        balanced = self.balance_by_difficulty(
            classified,
            difficulty_distribution,
            target_size
        )

        return Dataset.from_list(balanced)

    def balance_by_difficulty(
        self,
        problems: list[dict],
        distribution: dict,
        target_size: int
    ) -> list[dict]:
        """
        Balance dataset to match target difficulty distribution.
        """
        by_difficulty = defaultdict(list)
        for p in problems:
            by_difficulty[p["difficulty"]].append(p)

        balanced = []
        for difficulty, target_fraction in distribution.items():
            target_count = int(target_size * target_fraction)
            available = by_difficulty[difficulty]

            if len(available) >= target_count:
                balanced.extend(random.sample(available, target_count))
            else:
                # Oversample if not enough
                balanced.extend(available)
                balanced.extend(random.choices(
                    available,
                    k=target_count - len(available)
                ))

        random.shuffle(balanced)
        return balanced


class SyntheticProblemGenerator:
    """
    Generate synthetic math problems for training.

    Key techniques:
    1. Template-based generation with variable substitution
    2. Problem composition (combine simpler problems)
    3. Difficulty scaling (add constraints, increase numbers)
    4. LLM-based generation with verification
    """

    def __init__(self, llm=None):
        self.llm = llm
        self.templates = self.load_templates()
        self.verifier = SymbolicVerifier()

    def generate_from_template(
        self,
        template: str,
        difficulty: str = "medium"
    ) -> dict:
        """
        Generate problem from parameterized template.

        Example template:
        "Find all positive integers n < {max_n} such that {expression} is a perfect square."
        """
        # Sample parameters based on difficulty
        params = self.sample_parameters(template, difficulty)

        # Instantiate template
        problem_text = template.format(**params)

        # Solve symbolically to get ground truth
        solution = self.verifier.solve(problem_text)

        if solution is None:
            return None  # Skip unsolvable

        return {
            "text": problem_text,
            "answer": solution["answer"],
            "solution": solution["steps"],
            "difficulty": difficulty,
            "synthetic": True
        }

    def generate_with_llm(
        self,
        domain: str,
        difficulty: str,
        style_examples: list[str]
    ) -> dict:
        """
        Use LLM to generate novel problems.

        Critical: Must verify generated problems are valid.
        """
        prompt = f"""Generate a {difficulty} {domain} problem in the style of these examples:

{chr(10).join(style_examples[:3])}

Requirements:
1. Problem must have a well-defined answer
2. Difficulty should match {difficulty} level
3. Problem should be novel (not copy of examples)
4. Include the correct answer

Output format:
PROBLEM: <problem text>
ANSWER: <correct answer>
SOLUTION: <step-by-step solution>
"""

        response = self.llm.generate(prompt, temperature=0.8)

        # Parse response
        problem = self.parse_generated(response)

        # Verify answer is correct
        if not self.verify_generated_problem(problem):
            return None

        return problem

    def verify_generated_problem(self, problem: dict) -> bool:
        """
        Verify that generated problem is valid and solvable.
        """
        # Try to solve with symbolic solver
        symbolic_solution = self.verifier.solve(problem["text"])

        if symbolic_solution:
            # Check if answers match
            return self.answers_match(
                symbolic_solution["answer"],
                problem["answer"]
            )

        # Fallback: Use LLM verification (less reliable)
        return self.llm_verify(problem)


class DifficultyEstimator:
    """
    Estimate problem difficulty from text.

    Based on:
    1. Linguistic complexity
    2. Mathematical concepts mentioned
    3. Problem length and structure
    4. Comparison to known-difficulty problems
    """

    def __init__(self):
        # Load embedding model for similarity
        self.embedder = SentenceTransformer("all-mpnet-base-v2")

        # Reference problems with known difficulties
        self.reference_problems = self.load_reference_problems()
        self.reference_embeddings = self.embedder.encode(
            [p["text"] for p in self.reference_problems]
        )

    def estimate(self, problem_text: str) -> str:
        """
        Estimate difficulty level.

        Returns: "easy", "medium", "hard", or "very_hard"
        """
        # Method 1: Similarity to reference problems
        embedding = self.embedder.encode(problem_text)
        similarities = cosine_similarity([embedding], self.reference_embeddings)[0]

        # Find most similar reference problems
        top_k_idx = np.argsort(similarities)[-5:]
        similar_difficulties = [
            self.reference_problems[i]["difficulty"]
            for i in top_k_idx
        ]

        # Weighted vote based on similarity
        similarity_estimate = self.weighted_vote(
            similar_difficulties,
            similarities[top_k_idx]
        )

        # Method 2: Feature-based estimation
        features = self.extract_features(problem_text)
        feature_estimate = self.feature_based_estimate(features)

        # Combine estimates
        return self.combine_estimates(similarity_estimate, feature_estimate)

    def extract_features(self, text: str) -> dict:
        """Extract difficulty-indicating features."""
        features = {}

        # Length features
        features["word_count"] = len(text.split())
        features["sentence_count"] = len(text.split("."))

        # Mathematical concept indicators
        hard_concepts = [
            "prove", "show that", "if and only if", "for all",
            "there exists", "infinitely many", "uniquely determined"
        ]
        features["hard_concept_count"] = sum(
            1 for c in hard_concepts if c in text.lower()
        )

        # Domain-specific keywords
        number_theory = ["prime", "divisible", "gcd", "modulo", "congruent"]
        combinatorics = ["permutation", "combination", "count", "arrange"]
        geometry = ["triangle", "circle", "angle", "perpendicular"]
        algebra = ["polynomial", "equation", "inequality", "sequence"]

        features["number_theory_score"] = sum(1 for k in number_theory if k in text.lower())
        features["combinatorics_score"] = sum(1 for k in combinatorics if k in text.lower())
        features["geometry_score"] = sum(1 for k in geometry if k in text.lower())
        features["algebra_score"] = sum(1 for k in algebra if k in text.lower())

        # Proof requirement (typically harder)
        features["requires_proof"] = any(
            p in text.lower() for p in ["prove", "show that", "demonstrate"]
        )

        return features

GRPO: Group Relative Policy Optimization (Deep Dive)

GRPO is DeepSeek's key innovation for training reasoning models. Let's understand it deeply.

Why GRPO is ideal for math training: Math problems have a unique property: answers are verifiable. Unlike creative writing where "better" is subjective, math has ground truth. This means we don't need a learned reward model—just check if the answer is correct. GRPO exploits this by generating multiple solutions per problem, scoring each as correct/incorrect, and using the group statistics as a baseline. No critic network needed, halving memory requirements.

The group baseline intuition: In standard RL, you need a value function to answer "how good is this state?" But for math, we can approximate this by asking "how hard is this problem?" If a problem has 8 solutions and 6 are correct, the "baseline" expectation is ~75% success. A correct solution gets positive advantage (better than baseline); an incorrect one gets negative advantage. The group provides a per-problem baseline automatically.

Why this works better than DPO for reasoning: DPO requires preference pairs—but for math, we don't have "chosen" and "rejected" responses, just correct and incorrect. GRPO naturally handles this: all correct solutions are relatively good, all incorrect are relatively bad. The continuous advantage (not binary) also provides richer signal than DPO's chosen/rejected dichotomy.
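
To make the group baseline concrete, here is the arithmetic for the eight-sample example above (toy numbers, KL term omitted for clarity):

Python
import torch

# A group of G = 8 sampled solutions for one problem: 6 correct, 2 incorrect.
rewards = torch.tensor([1., 1., 1., 1., 1., 1., 0., 0.])

baseline = rewards.mean()        # 0.75, the per-problem difficulty estimate
advantages = rewards - baseline  # correct: +0.25 each, incorrect: -0.75 each

# Standard GRPO additionally divides by the group std (the step Dr. GRPO later removes).
normalized = advantages / (advantages.std() + 1e-8)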

Python
class GRPOTrainer:
    """
    GRPO: Group Relative Policy Optimization

    Key insight from DeepSeek R1 paper:
    "GRPO foregoes the critic model and estimates the baseline
     from group scores instead"

    Why this matters for math:
    1. No need to train a value network (halves memory)
    2. More stable training (no value estimation errors)
    3. Natural handling of sparse rewards (correct/incorrect)

    Mathematical formulation:
    For each prompt q, generate G responses {o_1, ..., o_G}

    Advantage for response o_i:
    A_i = r(q, o_i) - mean(r(q, o_j) for j in 1..G)

    This is equivalent to using group mean as baseline instead
    of learned value function.
    """

    def __init__(
        self,
        policy_model,
        reference_model,
        tokenizer,
        group_size: int = 16,
        kl_coef: float = 0.04,
        clip_epsilon: float = 0.2,
        learning_rate: float = 1e-6,
        max_grad_norm: float = 1.0
    ):
        self.policy = policy_model
        self.reference = reference_model
        self.tokenizer = tokenizer

        # GRPO hyperparameters
        self.group_size = group_size  # G in the paper
        self.kl_coef = kl_coef  # β in KL penalty
        self.clip_epsilon = clip_epsilon  # ε in clipped objective

        # Optimization
        self.optimizer = torch.optim.AdamW(
            self.policy.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )
        self.max_grad_norm = max_grad_norm

        # Reward function
        self.reward_fn = self.create_math_reward_function()

    def create_math_reward_function(self) -> callable:
        """
        Create the rule-based reward function for math.

        From DeepSeek R1 paper:
        "For reasoning tasks, we adopt a rule-based reward system
         that verifies whether the final answer is correct"
        """
        def reward(
            prompt: str,
            response: str,
            ground_truth: str
        ) -> float:
            total_reward = 0.0

            # 1. Correctness reward (primary signal)
            predicted = self.extract_answer(response)
            is_correct = self.verify_answer(predicted, ground_truth)
            total_reward += 1.0 if is_correct else 0.0

            # 2. Format reward (encourage thinking structure)
            has_thinking = "<think>" in response and "</think>" in response
            total_reward += 0.1 if has_thinking else 0.0

            # 3. Minimum-length check (discourage trivially short thinking)
            # Only penalize if the thinking section is too short
            if has_thinking:
                think_content = self.extract_thinking(response)
                if len(think_content.split()) < 50:
                    total_reward -= 0.05

            # 4. Language consistency penalty
            # DeepSeek found models mix languages without this
            has_mixed = self.detect_language_mixing(response)
            total_reward -= 0.1 if has_mixed else 0.0

            return total_reward

        return reward

    def train_step(
        self,
        problem: str,
        ground_truth: str
    ) -> dict:
        """
        Single GRPO training step.

        Algorithm:
        1. Generate G responses for the problem
        2. Compute rewards for each response
        3. Calculate group-relative advantages
        4. Update policy with clipped objective
        """
        # Phase 1: Generate group of responses
        responses = []
        policy_logprobs = []

        for _ in range(self.group_size):
            # Generate with policy model
            response, logprob = self.generate_with_logprob(
                self.policy,
                problem,
                max_new_tokens=4096,
                temperature=1.0  # High temp for diversity
            )
            responses.append(response)
            policy_logprobs.append(logprob)

        # Phase 2: Compute rewards
        rewards = torch.tensor([
            self.reward_fn(problem, r, ground_truth)
            for r in responses
        ])

        # Phase 3: Compute reference model logprobs (for KL)
        with torch.no_grad():
            reference_logprobs = torch.stack([
                self.compute_logprob(self.reference, problem, r)
                for r in responses
            ])

        policy_logprobs = torch.stack(policy_logprobs)

        # Phase 4: Compute GRPO advantages
        advantages = self.compute_grpo_advantages(
            rewards,
            policy_logprobs,
            reference_logprobs
        )

        # Phase 5: PPO-style update with clipped objective
        loss = self.compute_clipped_loss(
            problem,
            responses,
            policy_logprobs.detach(),
            advantages
        )

        # Phase 6: Gradient update
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.policy.parameters(),
            self.max_grad_norm
        )
        self.optimizer.step()

        return {
            "loss": loss.item(),
            "mean_reward": rewards.mean().item(),
            "max_reward": rewards.max().item(),
            "std_reward": rewards.std().item(),
            "correct_count": (rewards > 0.9).sum().item(),
            "mean_advantage": advantages.mean().item()
        }

    def compute_grpo_advantages(
        self,
        rewards: torch.Tensor,
        policy_logprobs: torch.Tensor,
        reference_logprobs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute group-relative advantages.

        Formula:
        A_i = (r_i - β * KL_i) - mean(r_j - β * KL_j)

        where KL_i = log π(o_i|q) - log π_ref(o_i|q)
        """
        # Per-response KL divergence
        kl = policy_logprobs - reference_logprobs

        # KL-adjusted rewards
        adjusted_rewards = rewards - self.kl_coef * kl

        # Group baseline (GRPO's key innovation)
        baseline = adjusted_rewards.mean()

        # Advantages
        advantages = adjusted_rewards - baseline

        # Normalize for training stability
        if advantages.std() > 1e-8:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        return advantages

    def compute_clipped_loss(
        self,
        problem: str,
        responses: list[str],
        old_logprobs: torch.Tensor,
        advantages: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute PPO clipped objective.

        L = -min(r_t * A_t, clip(r_t, 1-ε, 1+ε) * A_t)

        where r_t = π(a|s) / π_old(a|s)
        """
        # Recompute logprobs with gradient
        new_logprobs = torch.stack([
            self.compute_logprob_with_grad(self.policy, problem, r)
            for r in responses
        ])

        # Probability ratio
        ratio = torch.exp(new_logprobs - old_logprobs)

        # Clipped objective
        surr1 = ratio * advantages
        surr2 = torch.clamp(
            ratio,
            1 - self.clip_epsilon,
            1 + self.clip_epsilon
        ) * advantages

        # Take minimum (pessimistic bound)
        loss = -torch.min(surr1, surr2).mean()

        return loss

    def generate_with_logprob(
        self,
        model,
        prompt: str,
        max_new_tokens: int,
        temperature: float
    ) -> tuple[str, torch.Tensor]:
        """
        Generate response and return log probability.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")

        # Generate token by token to accumulate logprobs
        generated_ids = []
        total_logprob = 0.0

        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = model(
                    torch.cat([input_ids] + generated_ids, dim=1)
                    if generated_ids else input_ids
                )

            logits = outputs.logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)

            # Accumulate log probability
            total_logprob += torch.log(probs[0, next_token[0, 0]])

            generated_ids.append(next_token)

            # Check for EOS
            if next_token[0, 0] == self.tokenizer.eos_token_id:
                break

        response = self.tokenizer.decode(
            torch.cat(generated_ids, dim=1)[0],
            skip_special_tokens=True
        )

        return response, total_logprob


Dr. GRPO: Fixing GRPO's Hidden Biases

Standard GRPO has two critical biases discovered by researchers at Sea AI Lab that cause problems in practice:

Code
GRPO BIAS PROBLEMS

PROBLEM 1: RESPONSE LENGTH BIAS
───────────────────────────────
Standard GRPO normalizes each response's loss by its own length:

    loss_i = (1 / len(response_i)) * Σ_t token_loss_t

This creates an implicit length bias:
  • Wrong answers: longer responses are penalized less per token
  • Correct answers: shorter responses are rewarded more per token
  • Net effect: incorrect responses grow progressively longer over training

PROBLEM 2: DIFFICULTY LEVEL BIAS
────────────────────────────────
GRPO normalizes advantages by the per-question standard deviation of rewards:

    advantage_i = (r_i - mean(r)) / std(r)

This causes problems:
  • Easy questions (all correct, low std)  → overweighted
  • Hard questions (all wrong, low std)    → overweighted
  • Medium questions (mixed, high std)     → underweighted

The model spends too much gradient on trivial or impossible problems.

OBSERVED FAILURE MODE
─────────────────────
Training step 0:    "The answer is 5"        (wrong, 10 tokens)
Training step 1000: "Let me think... 5"      (wrong, 50 tokens)
Training step 5000: "I'll analyze... ... 5"  (wrong, 200 tokens)

The model learns that long wrong answers are punished less than short ones.
Result: verbose wrong answers that LOOK like reasoning.

Dr. GRPO (GRPO Done Right) fixes both biases with simple modifications:

Python
class DrGRPOTrainer(GRPOTrainer):
    """
    Dr. GRPO: GRPO Done Right

    Fixes two biases in standard GRPO:
    1. Response length bias → use global constant normalization
    2. Difficulty level bias → remove per-question std normalization

    Paper: "Understanding R1-Zero-Like Training: A Critical Perspective"
    Available in TRL library: loss_type="dr_grpo"

    Results:
    - AIME 2024: 43.3% (vs SimpleRL-Zero 36.0%)
    - Stable output lengths (no inflation)
    - Better token efficiency
    """

    def __init__(
        self,
        policy_model,
        reference_model,
        tokenizer,
        group_size: int = 16,
        kl_coef: float = 0.04,
        clip_epsilon: float = 0.2,
        max_completion_length: int = 4096,  # Global constant for normalization
        **kwargs
    ):
        super().__init__(
            policy_model, reference_model, tokenizer,
            group_size, kl_coef, clip_epsilon, **kwargs
        )
        self.max_completion_length = max_completion_length

    def compute_dr_grpo_loss(
        self,
        policy_logprobs: torch.Tensor,  # Shape: (batch, seq_len)
        old_logprobs: torch.Tensor,
        advantages: torch.Tensor,  # Shape: (batch,)
        response_lengths: torch.Tensor  # Shape: (batch,)
    ) -> torch.Tensor:
        """
        Compute Dr. GRPO loss with bias corrections.

        Key changes from standard GRPO:
        1. Normalize by GLOBAL constant (max_completion_length), not response length
        2. Remove per-question std normalization from advantages
        """
        # Probability ratio (same as standard GRPO)
        ratio = torch.exp(policy_logprobs - old_logprobs)

        # Clipped objective (same as standard GRPO)
        surr1 = ratio * advantages.unsqueeze(-1)
        surr2 = torch.clamp(
            ratio,
            1 - self.clip_epsilon,
            1 + self.clip_epsilon
        ) * advantages.unsqueeze(-1)

        # Per-token loss
        token_loss = -torch.min(surr1, surr2)

        # KEY CHANGE: Normalize by GLOBAL CONSTANT, not response length
        # This eliminates length bias entirely
        normalized_loss = token_loss.sum(dim=-1) / self.max_completion_length

        return normalized_loss.mean()

    def compute_dr_grpo_advantages(
        self,
        rewards: torch.Tensor,
        policy_logprobs: torch.Tensor,
        reference_logprobs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute advantages WITHOUT per-question normalization.

        Standard GRPO: A_i = (r_i - mean) / std  ← BIASED
        Dr. GRPO:      A_i = r_i - mean          ← UNBIASED
        """
        # KL penalty
        kl = policy_logprobs - reference_logprobs
        adjusted_rewards = rewards - self.kl_coef * kl

        # Group baseline (keep this)
        baseline = adjusted_rewards.mean()

        # Advantages WITHOUT std normalization (key change)
        advantages = adjusted_rewards - baseline

        # Optional: clip extreme advantages for stability
        advantages = torch.clamp(advantages, -5.0, 5.0)

        return advantages


# Using Dr. GRPO with HuggingFace TRL library
from trl import GRPOConfig, GRPOTrainer

config = GRPOConfig(
    output_dir="dr_grpo_math",
    loss_type="dr_grpo",  # Enable Dr. GRPO
    max_completion_length=4096,  # Global normalization constant
    num_generations=16,  # Group size
    beta=0.04,  # KL coefficient
    # ... other config
)

trainer = GRPOTrainer(
    model=model,
    ref_model=ref_model,
    args=config,
    train_dataset=math_dataset,
    tokenizer=tokenizer,
)

trainer.train()

Empirical Results of Dr. GRPO:

| Method | AIME 2024 | MATH500 | Avg Response Tokens |
|---|---|---|---|
| SimpleRL-Zero-7B | 36.0% | 35.2% | 850 → 2400 (inflated) |
| Standard GRPO | 39.1% | 38.5% | 800 → 1900 (inflated) |
| Dr. GRPO | 43.3% | 40.9% | 800 → 850 (stable) |

The key insight: without length bias, models learn to reason efficiently rather than verbosely. Wrong answers stay short; correct answers get the reasoning they need.


Curriculum Learning

Why curriculum learning accelerates math training: Imagine learning calculus before arithmetic—you'd fail. Neural networks face similar challenges. If you train on IMO problems from the start, the model rarely sees correct solutions (they're too hard) and learning signal is sparse. Starting with AMC-level problems lets the model learn basic reasoning patterns, which then transfer to harder problems. This isn't just faster—it often achieves higher final performance.

The schedule matters: Too slow a curriculum wastes compute on problems the model has mastered. Too fast and the model never solidifies fundamentals. The "linear" schedule increases difficulty steadily; "exponential" stays easy longer, then ramps up quickly. Empirically, linear works well for most cases. Monitor training metrics: if accuracy on medium problems drops when you introduce hard ones, slow down the schedule.

Python
class CurriculumScheduler:
    """
    Curriculum learning for math training.

    Key insight: Start with easy problems, gradually increase difficulty.

    From research: Models trained with curriculum learning converge
    faster and achieve higher final performance.
    """

    def __init__(
        self,
        dataset: Dataset,
        initial_difficulty: str = "easy",
        warmup_steps: int = 1000,
        schedule: str = "linear"
    ):
        self.dataset = dataset
        self.current_difficulty = initial_difficulty
        self.warmup_steps = warmup_steps
        self.schedule = schedule
        self.step = 0

        # Group problems by difficulty
        self.by_difficulty = {
            "easy": [p for p in dataset if p["difficulty"] == "easy"],
            "medium": [p for p in dataset if p["difficulty"] == "medium"],
            "hard": [p for p in dataset if p["difficulty"] == "hard"],
            "very_hard": [p for p in dataset if p["difficulty"] == "very_hard"]
        }

    def sample_problem(self) -> dict:
        """
        Sample problem according to current curriculum stage.
        """
        if self.schedule == "linear":
            probs = self.linear_schedule()
        elif self.schedule == "exponential":
            probs = self.exponential_schedule()
        else:
            probs = self.constant_schedule()

        # Sample difficulty level
        difficulty = np.random.choice(
            ["easy", "medium", "hard", "very_hard"],
            p=probs
        )

        # Sample problem from that difficulty
        return random.choice(self.by_difficulty[difficulty])

    def linear_schedule(self) -> list[float]:
        """
        Linear increase in difficulty over warmup period.

        Start: [0.7, 0.2, 0.08, 0.02]
        End:   [0.2, 0.4, 0.3, 0.1]
        """
        progress = min(1.0, self.step / self.warmup_steps)

        start = np.array([0.7, 0.2, 0.08, 0.02])
        end = np.array([0.2, 0.4, 0.3, 0.1])

        probs = start + progress * (end - start)
        return probs / probs.sum()

    def exponential_schedule(self) -> list[float]:
        """
        Exponential difficulty increase (faster ramp-up).
        """
        progress = min(1.0, self.step / self.warmup_steps)
        progress = progress ** 2  # Square for faster initial progress

        start = np.array([0.7, 0.2, 0.08, 0.02])
        end = np.array([0.2, 0.4, 0.3, 0.1])

        probs = start + progress * (end - start)
        return probs / probs.sum()

    def advance(self):
        """Advance curriculum step."""
        self.step += 1

Cold-Start Supervised Fine-Tuning

Before jumping into reinforcement learning, you need to establish the reasoning format through supervised fine-tuning. This "cold start" phase is crucial: RL from a raw base model often fails because the model doesn't know how to structure its reasoning. By first training on examples with proper <think>...</think> formatting, you give RL a reasonable starting point. The data can come from stronger models (distillation) or human-written solutions. Even a few thousand high-quality examples can dramatically improve RL convergence.

Python
class ColdStartSFT:
    """
    Cold-start SFT before reinforcement learning.

    Purpose:
    1. Establish thinking format (<think>...</think>)
    2. Provide reasonable starting policy for RL
    3. Reduce RL training time

    From DeepSeek R1:
    "A small amount of cold-start data is constructed to
     fine-tune the base model before RL"
    """

    def __init__(
        self,
        base_model,
        tokenizer,
        learning_rate: float = 5e-6,
        epochs: int = 3
    ):
        self.model = base_model
        self.tokenizer = tokenizer
        self.lr = learning_rate
        self.epochs = epochs

    def prepare_cold_start_data(
        self,
        problems: list[dict],
        strong_model=None
    ) -> list[dict]:
        """
        Prepare high-quality reasoning traces for SFT.

        Methods:
        1. Distill from stronger model
        2. Human-written exemplars
        3. Filtered self-generated (if model is already decent)
        """
        training_examples = []

        for problem in problems:
            if strong_model:
                # Distill from stronger model
                trace = self.generate_reasoning_trace(
                    strong_model,
                    problem["text"]
                )
            else:
                # Use provided solution
                trace = self.format_as_reasoning_trace(
                    problem["text"],
                    problem.get("solution", ""),
                    problem["answer"]
                )

            training_examples.append({
                "prompt": f"Problem: {problem['text']}\n\n",
                "completion": trace
            })

        return training_examples

    def generate_reasoning_trace(
        self,
        model,
        problem: str
    ) -> str:
        """
        Generate high-quality reasoning trace from strong model.
        """
        prompt = f"""Solve this math problem step by step.

Problem: {problem}

Show your complete reasoning process inside <think></think> tags,
then provide the final answer.

<think>
"""
        response = model.generate(
            prompt,
            max_new_tokens=4096,
            temperature=0.3  # Low temp for quality
        )

        # Verify answer is correct before using as training data
        # (Filter out incorrect traces)
        return response

    def format_as_reasoning_trace(
        self,
        problem: str,
        solution: str,
        answer: str
    ) -> str:
        """
        Format existing solution as reasoning trace.
        """
        return f"""<think>
Let me solve this step by step.

{solution}

Let me verify my answer.
The answer is {answer}.
</think>

The answer is \\boxed{{{answer}}}"""

    def train(self, training_data: list[dict]) -> dict:
        """
        Fine-tune on reasoning traces.
        """
        dataset = Dataset.from_list(training_data)

        training_args = TrainingArguments(
            output_dir="./cold_start_sft",
            num_train_epochs=self.epochs,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            learning_rate=self.lr,
            warmup_ratio=0.1,
            logging_steps=10,
            save_strategy="epoch"
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
            data_collator=DataCollatorForSeq2Seq(self.tokenizer)
        )

        trainer.train()

        return trainer.state.log_history

Part 2: Process Reward Models (PRMs)

Process Reward Models (PRMs) represent one of the most important advances in training mathematical reasoning systems. While traditional reward models give a single score for the entire solution, PRMs score each intermediate reasoning step. This seemingly simple change has profound implications for training dynamics, error detection, and inference-time search.

The Fundamental Problem with Outcome-Only Rewards

Consider a 20-step mathematical proof where step 15 contains an error. With outcome-only rewards:

  • The model receives reward = 0 for the entire solution
  • There's no signal about which of the 20 steps caused the failure
  • Training may incorrectly penalize the good steps (1-14), and the model can drift toward reward-hacking on easier problems

This is the credit assignment problem in RL, and it's especially severe for long-form mathematical reasoning. If you can only say "this proof is wrong" without saying "step 15 is where it went wrong," learning is extremely inefficient.

How PRMs Solve This

PRMs assign a score to each step, providing dense feedback throughout the reasoning chain. This enables:

  1. Precise error localization: Know exactly which step introduced an error
  2. Step-level beam search: Prune unpromising reasoning paths early, before wasting compute on doomed approaches
  3. Better training signal: Instead of one sparse reward at the end, get continuous feedback
  4. Debugging capability: When a model fails, inspect which steps had low scores

Two Approaches to Building PRMs

Monte Carlo scoring (no separate model needed): For each step in a partial solution, sample many completions and count how often they reach the correct answer. If step 5 usually leads to correct solutions but step 6 rarely does, step 6 probably introduced an error. This is slow but requires no additional training.

Model-based scoring (trained PRM): Train a separate model to predict step correctness directly. This requires labeled data (steps marked correct/incorrect) but is much faster at inference time. The model learns to recognize patterns like "this algebraic manipulation is valid" or "this case analysis missed a case."

The best systems use both: Monte Carlo to generate training labels, then train a fast model-based PRM for deployment.
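
A minimal sketch of that bootstrapping step (the label format and the build_prm_training_set helper are my assumptions; score_solution refers to the ProcessRewardModel implementation further below):

Python
def build_prm_training_set(problems, mc_scorer, threshold: float = 0.5) -> list[dict]:
    """
    Turn Monte Carlo step scores into (text, label) pairs for training
    a fast model-based PRM.
    """
    examples = []
    for p in problems:  # each p: {"text": ..., "solution": ...}
        for step in mc_scorer.score_solution(p["text"], p["solution"]):
            examples.append({
                "input": f"Problem: {p['text']}\nStep: {step['step']}",
                "label": int(step["score"] > threshold),  # 1 = step likely correct
            })
    return examples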

Why PRMs Matter for Math

Code
Without PRM (Outcome Reward Only):
  Step 1: Correct setup                    → No reward
  Step 2: Correct algebraic manipulation   → No reward
  Step 3: Arithmetic error (5+3=9)         → No reward
  Step 4: Wrong conclusion from error      → No reward
  Final: Wrong answer                      → Reward = 0

  Problem: Model gets no signal about where it went wrong

With PRM (Process Reward):
  Step 1: Correct setup                    → Reward = 1.0
  Step 2: Correct algebraic manipulation   → Reward = 1.0
  Step 3: Arithmetic error (5+3=9)         → Reward = 0.0 ← Error located!
  Step 4: Wrong conclusion from error      → Reward = 0.0
  Final: Wrong answer                      → Reward = 0.0

  Benefit: Clear signal about which step is wrong

Complete PRM Implementation

Now let's look at a full implementation. The ProcessRewardModel class below provides both Monte Carlo scoring (slow but no training needed) and model-based scoring (fast but requires training data). In practice, you'd use Monte Carlo to bootstrap training labels, then train a fast model for deployment. The implementation includes step parsing, score computation, and the complete interface needed for integration with MCTS or beam search.

Python
class ProcessRewardModel:
    """
    Process Reward Model for mathematical reasoning.

    Assigns rewards to each reasoning step, enabling:
    1. Better credit assignment (know which step failed)
    2. Step-level beam search (prune bad paths early)
    3. More stable RL training (denser reward signal)

    Based on OpenAI's "Let's Verify Step by Step" and
    "Solving Math Word Problems with Process- and Outcome-Based Feedback"
    """

    def __init__(
        self,
        base_model,
        tokenizer,
        scoring_method: str = "monte_carlo"
    ):
        self.model = base_model
        self.tokenizer = tokenizer
        self.scoring_method = scoring_method

        # Special tokens for step boundaries
        self.step_token = "<step>"
        self.end_step_token = "</step>"

    def score_solution(
        self,
        problem: str,
        solution: str
    ) -> list[dict]:
        """
        Score each step in a solution.

        Returns list of {step: str, score: float, is_correct: bool}
        """
        # Parse solution into steps
        steps = self.parse_steps(solution)

        if self.scoring_method == "monte_carlo":
            return self.monte_carlo_scoring(problem, steps)
        elif self.scoring_method == "trained_model":
            return self.model_based_scoring(problem, steps)
        else:
            return self.heuristic_scoring(problem, steps)

    def monte_carlo_scoring(
        self,
        problem: str,
        steps: list[str],
        n_rollouts: int = 32
    ) -> list[dict]:
        """
        Score steps by completing from each step and measuring success rate.

        This is the approach from "Let's Verify Step by Step":
        - For each step, complete the solution multiple times
        - Step score = fraction of completions that reach correct answer

        Advantage: Doesn't require labeled step-level data
        Disadvantage: Expensive (requires many completions)
        """
        results = []

        for i, step in enumerate(steps):
            # Build prefix: problem + steps 0..i
            prefix = f"Problem: {problem}\n\nSolution:\n"
            prefix += "\n".join(steps[:i+1])

            # Complete from this point multiple times
            correct_count = 0

            for _ in range(n_rollouts):
                completion = self.model.generate(
                    prefix,
                    max_new_tokens=2048,
                    temperature=0.8
                )

                # Check if completion reaches correct answer
                full_solution = prefix + completion
                if self.verify_solution(problem, full_solution):
                    correct_count += 1

            score = correct_count / n_rollouts

            results.append({
                "step": step,
                "step_index": i,
                "score": score,
                "is_correct": score > 0.5,  # Majority of completions succeed
                "n_rollouts": n_rollouts
            })

        return results

    def model_based_scoring(
        self,
        problem: str,
        steps: list[str]
    ) -> list[dict]:
        """
        Score using a trained step classifier.

        Requires labeled data: (problem, step, is_correct_step)

        Training approaches:
        1. Human annotation of step correctness
        2. Automatic labeling using Monte Carlo estimates
        3. Synthetic errors + correction labels
        """
        results = []

        for i, step in enumerate(steps):
            # Build input for classifier
            input_text = self.format_for_classifier(problem, steps[:i], step)

            # Get model prediction
            with torch.no_grad():
                inputs = self.tokenizer(
                    input_text,
                    return_tensors="pt",
                    truncation=True
                )
                outputs = self.model(**inputs)

                # Assume model outputs logits for [incorrect, correct]
                probs = F.softmax(outputs.logits, dim=-1)
                score = probs[0, 1].item()  # P(correct)

            results.append({
                "step": step,
                "step_index": i,
                "score": score,
                "is_correct": score > 0.5
            })

        return results

    def heuristic_scoring(
        self,
        problem: str,
        steps: list[str]
    ) -> list[dict]:
        """
        Fast heuristic scoring without model inference.

        Useful for real-time verification during generation.
        Less accurate but much faster.
        """
        results = []

        for i, step in enumerate(steps):
            score = 0.5  # Base score

            # Check for mathematical validity
            if self.has_valid_equations(step):
                score += 0.1

            # Check for logical connectives (indicates reasoning)
            if self.has_logical_flow(step):
                score += 0.1

            # Check for consistency with previous steps
            if i > 0 and self.is_consistent(steps[i-1], step):
                score += 0.1

            # Penalize obvious errors
            if self.has_obvious_errors(step):
                score -= 0.3

            results.append({
                "step": step,
                "step_index": i,
                "score": max(0, min(1, score)),
                "is_correct": score > 0.5
            })

        return results

    def parse_steps(self, solution: str) -> list[str]:
        """
        Parse solution into individual reasoning steps.
        """
        # Method 1: Explicit step markers
        if self.step_token in solution:
            steps = solution.split(self.step_token)
            return [s.strip() for s in steps if s.strip()]

        # Method 2: Newline-separated
        lines = solution.split("\n")
        steps = []
        current_step = []

        for line in lines:
            line = line.strip()
            if not line:
                if current_step:
                    steps.append(" ".join(current_step))
                    current_step = []
            else:
                current_step.append(line)

        if current_step:
            steps.append(" ".join(current_step))

        return steps

    def has_valid_equations(self, step: str) -> bool:
        """Check if equations in step are syntactically valid."""
        # Extract equations
        equations = re.findall(r'[\w\d\+\-\*\/\(\)\^\s]+=[\w\d\+\-\*\/\(\)\^\s]+', step)

        for eq in equations:
            try:
                left, right = eq.split("=")
                # Try to parse both sides
                sympy.sympify(left.strip())
                sympy.sympify(right.strip())
            except Exception:
                # Unparsable expression: treat the step's equations as invalid
                return False

        return True

    def has_logical_flow(self, step: str) -> bool:
        """Check for logical reasoning indicators."""
        indicators = [
            "therefore", "thus", "hence", "so", "because",
            "since", "implies", "follows", "we have", "this means"
        ]
        return any(ind in step.lower() for ind in indicators)

    def is_consistent(self, prev_step: str, current_step: str) -> bool:
        """Check if current step is consistent with previous."""
        # Extract variables/values from both
        prev_values = self.extract_values(prev_step)
        curr_values = self.extract_values(current_step)

        # Check for contradictions
        for var, val in curr_values.items():
            if var in prev_values and prev_values[var] != val:
                # Variable changed value unexpectedly
                return False

        return True

    def has_obvious_errors(self, step: str) -> bool:
        """Detect obvious mathematical errors."""
        # Check for common arithmetic errors
        arithmetic_patterns = [
            (r'(\d+)\s*\+\s*(\d+)\s*=\s*(\d+)', lambda a, b, c: int(a) + int(b) != int(c)),
            (r'(\d+)\s*\-\s*(\d+)\s*=\s*(\d+)', lambda a, b, c: int(a) - int(b) != int(c)),
            (r'(\d+)\s*\*\s*(\d+)\s*=\s*(\d+)', lambda a, b, c: int(a) * int(b) != int(c)),
        ]

        for pattern, is_error in arithmetic_patterns:
            matches = re.findall(pattern, step)
            for match in matches:
                try:
                    if is_error(*match):
                        return True
                except Exception:
                    pass  # skip matches that fail to parse as integers

        return False


class PRMGuidedGeneration:
    """
    Use PRM to guide step-by-step generation.

    At each step:
    1. Generate candidate next steps
    2. Score each candidate with PRM
    3. Select best (or sample weighted by score)
    """

    def __init__(
        self,
        generator_model,
        prm: ProcessRewardModel,
        beam_width: int = 4,
        n_candidates: int = 8
    ):
        self.generator = generator_model
        self.prm = prm
        self.beam_width = beam_width
        self.n_candidates = n_candidates

    def solve_with_prm_guidance(
        self,
        problem: str,
        max_steps: int = 20
    ) -> dict:
        """
        Generate solution with PRM-guided beam search.
        """
        # Initialize beam with empty solution
        beam = [{"steps": [], "score": 1.0}]

        for step_idx in range(max_steps):
            candidates = []

            for beam_item in beam:
                # Generate candidate next steps
                next_steps = self.generate_next_steps(
                    problem,
                    beam_item["steps"],
                    n_candidates=self.n_candidates
                )

                for step in next_steps:
                    # Score with PRM
                    step_score = self.prm.score_single_step(
                        problem,
                        beam_item["steps"],
                        step
                    )

                    candidates.append({
                        "steps": beam_item["steps"] + [step],
                        "score": beam_item["score"] * step_score,
                        "last_step_score": step_score
                    })

            # Keep top beam_width candidates
            candidates.sort(key=lambda x: x["score"], reverse=True)
            beam = candidates[:self.beam_width]

            # Check if best candidate is complete
            if self.is_complete(beam[0]["steps"]):
                break

            # Early stopping if all scores are low
            if beam[0]["score"] < 0.01:
                break

        # Return best solution
        best = beam[0]
        return {
            "solution": "\n".join(best["steps"]),
            "score": best["score"],
            "steps": len(best["steps"])
        }

    def generate_next_steps(
        self,
        problem: str,
        current_steps: list[str],
        n_candidates: int
    ) -> list[str]:
        """
        Generate candidate next steps.
        """
        prefix = f"Problem: {problem}\n\nSolution:\n"
        if current_steps:
            prefix += "\n".join(current_steps) + "\n"

        candidates = []

        for _ in range(n_candidates):
            # Generate one step
            response = self.generator.generate(
                prefix + "Next step: ",
                max_new_tokens=256,
                temperature=0.8,
                stop_sequences=["\n\n", "Next step:"]
            )

            step = response.split("\n")[0].strip()
            if step:
                candidates.append(step)

        # Deduplicate
        return list(set(candidates))

Part 3: Inference-Time Optimization (No Training Required)

One of the most exciting developments in mathematical reasoning is that you can dramatically improve any model's performance without any training—purely through smarter inference-time algorithms. This is called "test-time compute scaling": spending more compute at inference time to get better answers.

Why Inference-Time Optimization Matters

The basic insight: a single generation from an LLM is often suboptimal. The model might:

  • Take a wrong turn early and be unable to recover
  • Miss the key insight needed for a particular problem
  • Make arithmetic errors that invalidate the rest of the solution

But if you generate multiple attempts and use smart selection, you can dramatically improve success rates. The question is: how do you efficiently explore the space of possible solutions?

The Landscape of Inference-Time Techniques

| Technique | Compute Cost | When to Use |
| --- | --- | --- |
| Best-of-N | N × single generation | Quick improvement, any model |
| Beam Search | Moderate | When you have a good value function |
| MCTS | High | Hard problems, need to backtrack |
| Iterative Refinement | Moderate | When errors are detectable |
| Self-Consistency | N × single generation | Multiple valid solution paths |

Each technique represents a different tradeoff between compute cost and solution quality. For easy problems, Best-of-N with N=4 might suffice. For IMO problems, you might need MCTS with thousands of simulations.
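
To make the cheapest end of this spectrum concrete, here is a minimal Best-of-N / self-consistency sketch: sample several solutions at nonzero temperature and majority-vote over the final answers. The model.generate call mirrors the usage elsewhere in this post; extract_answer is a hypothetical helper standing in for whatever answer parser your pipeline uses.

Python
from collections import Counter

def best_of_n(model, problem: str, n: int = 4, temperature: float = 0.8) -> str:
    """Sample n solutions and majority-vote over their final answers."""
    answers = []
    for _ in range(n):
        solution = model.generate(
            f"Problem: {problem}\n\nSolve step by step:\n",
            max_new_tokens=2048,
            temperature=temperature,  # nonzero temperature provides the diversity voting relies on
        )
        answers.append(extract_answer(solution))  # hypothetical answer-extraction helper
    # Self-consistency: the most common final answer wins
    return Counter(answers).most_common(1)[0][0]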

Monte Carlo Tree Search for Math

MCTS is the most sophisticated inference-time technique. Originally developed for game-playing AI (it powered AlphaGo), MCTS has been adapted for mathematical reasoning with remarkable success.

Why MCTS works for math:

Mathematical problem-solving is naturally tree-structured. Each reasoning step is a "move" that leads to a new state. Some moves lead to dead ends; others open up promising paths. MCTS systematically explores this tree, using past experience to guide search toward promising regions.

The four phases of MCTS:

  1. Selection: Starting from the root (the problem statement), traverse the tree by picking the most promising child at each node. "Promising" balances exploitation (high-scoring nodes) with exploration (under-explored nodes) using the UCB formula.

  2. Expansion: When you reach a node with unexplored children, add a new child by generating the next reasoning step.

  3. Simulation/Evaluation: Estimate the value of the new node. This can be done by completing the solution and checking correctness, or using a PRM to score the partial solution.

  4. Backpropagation: Update the value estimates of all nodes on the path from root to the new node. This propagates information about what works back up the tree.

The UCB formula:

\text{UCB}(n) = \frac{V(n)}{N(n)} + c \cdot \sqrt{\frac{\ln N_{\text{parent}}}{N(n)}}

The first term favors nodes that have worked well. The second term favors nodes that haven't been explored much. The constant c controls the exploration-exploitation tradeoff.
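
As a quick numerical illustration (the values are made up), compare a heavily visited node with a high average value against a barely visited one with a lower average. With c = 1.414, the exploration term flips the ranking in favor of the under-explored node:

Python
import math

def ucb(value_sum: float, visits: int, parent_visits: int, c: float = 1.414) -> float:
    return value_sum / visits + c * math.sqrt(math.log(parent_visits) / visits)

# Node A: well explored, strong average value
print(ucb(value_sum=8.0, visits=10, parent_visits=50))  # ~1.68 = 0.80 exploitation + ~0.88 exploration
# Node B: barely explored, weaker average value
print(ucb(value_sum=1.0, visits=2, parent_visits=50))   # ~2.48 = 0.50 exploitation + ~1.98 exploration
# MCTS selects B next despite its lower average: low visit counts inflate the exploration term.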

Python
class MathMCTS:
    """
    Monte Carlo Tree Search adapted for mathematical reasoning.

    Key adaptations for math:
    1. States = partial solutions (sequence of reasoning steps)
    2. Actions = possible next steps
    3. Rewards = correctness of final answer
    4. Heuristic value = PRM score or completion success rate

    MCTS provides better exploration than beam search:
    - Balances exploration vs exploitation (UCB)
    - Can backtrack from dead ends
    - Allocates more compute to promising branches
    """

    def __init__(
        self,
        model,
        prm: ProcessRewardModel = None,
        c_puct: float = 1.414,  # Exploration constant
        n_simulations: int = 100,
        max_depth: int = 20
    ):
        self.model = model
        self.prm = prm
        self.c_puct = c_puct
        self.n_simulations = n_simulations
        self.max_depth = max_depth

    def solve(self, problem: str, ground_truth: str = None) -> dict:
        """
        Solve problem using MCTS.

        Returns best solution found and search statistics.
        """
        # Initialize root node
        root = MCTSNode(
            state=MCTSState(problem=problem, steps=[]),
            parent=None
        )

        # Run simulations
        for sim in range(self.n_simulations):
            # Phase 1: Selection - traverse tree using UCB
            node = self.select(root)

            # Phase 2: Expansion - add new child nodes
            if not node.is_terminal and not node.is_fully_expanded:
                node = self.expand(node)

            # Phase 3: Simulation - complete solution and evaluate
            value = self.simulate(node, ground_truth)

            # Phase 4: Backpropagation - update statistics
            self.backpropagate(node, value)

        # Return best solution
        best_path = self.get_best_path(root)

        return {
            "solution": "\n".join(best_path),
            "simulations": self.n_simulations,
            "tree_size": self.count_nodes(root),
            "confidence": root.value / root.visits if root.visits > 0 else 0
        }

    def select(self, node: 'MCTSNode') -> 'MCTSNode':
        """
        Select leaf node using UCB1 formula.

        UCB1 = value/visits + c * sqrt(log(parent_visits) / visits)

        Balances exploitation (high value) with exploration (low visits).
        """
        while node.children and not node.is_terminal:
            if not node.is_fully_expanded:
                return node

            # UCB1 selection
            node = max(
                node.children,
                key=lambda c: self.ucb_score(c, node.visits)
            )

        return node

    def ucb_score(self, node: 'MCTSNode', parent_visits: int) -> float:
        """Compute UCB1 score for node selection."""
        if node.visits == 0:
            return float('inf')  # Prioritize unexplored

        exploitation = node.value / node.visits
        exploration = self.c_puct * math.sqrt(
            math.log(parent_visits) / node.visits
        )

        return exploitation + exploration

    def expand(self, node: 'MCTSNode') -> 'MCTSNode':
        """
        Expand node by generating new child states.
        """
        # Generate possible next steps
        next_steps = self.generate_next_steps(node.state)

        for step in next_steps:
            if step not in [c.state.steps[-1] for c in node.children if c.state.steps]:
                # Create new child
                new_state = MCTSState(
                    problem=node.state.problem,
                    steps=node.state.steps + [step]
                )

                child = MCTSNode(state=new_state, parent=node)
                node.children.append(child)

                # Mark terminal if solution is complete
                if self.is_solution_complete(new_state):
                    child.is_terminal = True

        # Return first unexplored child
        for child in node.children:
            if child.visits == 0:
                return child

        node.is_fully_expanded = True
        return node.children[0] if node.children else node

    def simulate(
        self,
        node: 'MCTSNode',
        ground_truth: str = None
    ) -> float:
        """
        Simulate to terminal state and return value.

        Methods:
        1. Complete with model, check answer (most accurate)
        2. Use PRM score as heuristic (faster)
        3. Lightweight rollout policy (fastest)
        """
        state = node.state

        # If already terminal, evaluate directly
        if node.is_terminal:
            return self.evaluate_solution(state, ground_truth)

        # Option 1: Full completion
        completion = self.complete_solution(state)
        full_solution = "\n".join(state.steps) + "\n" + completion

        if ground_truth:
            # Check against ground truth
            predicted = self.extract_answer(full_solution)
            return 1.0 if self.answers_match(predicted, ground_truth) else 0.0

        # Option 2: Use PRM score as proxy
        if self.prm:
            scores = self.prm.score_solution(state.problem, full_solution)
            return sum(s["score"] for s in scores) / len(scores)

        # Option 3: Heuristic evaluation
        return self.heuristic_value(full_solution)

    def backpropagate(self, node: 'MCTSNode', value: float):
        """
        Backpropagate value up the tree.
        """
        while node is not None:
            node.visits += 1
            node.value += value
            node = node.parent

    def generate_next_steps(
        self,
        state: 'MCTSState',
        n_candidates: int = 5
    ) -> list[str]:
        """
        Generate candidate next reasoning steps.
        """
        prompt = f"Problem: {state.problem}\n\n"
        if state.steps:
            prompt += "Progress so far:\n" + "\n".join(state.steps) + "\n\n"
        prompt += "Next step in the solution:\n"

        candidates = []
        for _ in range(n_candidates):
            response = self.model.generate(
                prompt,
                max_new_tokens=256,
                temperature=0.9,  # High for diversity
                stop_sequences=["\n\n"]
            )
            step = response.strip()
            if step and step not in candidates:
                candidates.append(step)

        return candidates

    def get_best_path(self, root: 'MCTSNode') -> list[str]:
        """
        Extract best solution path from tree.
        """
        path = []
        node = root

        while node.children:
            # Select child with highest visit count (most explored = most confident)
            node = max(node.children, key=lambda c: c.visits)
            if node.state.steps:
                path.append(node.state.steps[-1])

        return path


@dataclass
class MCTSState:
    """State in MCTS = problem + partial solution."""
    problem: str
    steps: list[str]


class MCTSNode:
    """Node in MCTS tree."""

    def __init__(self, state: MCTSState, parent: 'MCTSNode' = None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.prior_score = 0.0  # PRM prior; set during ConstrainedMCTS expansion
        self.is_terminal = False
        self.is_fully_expanded = False

Constrained MCTS (CMCTS) - 2025 Improvements

Standard MCTS explores the solution space uniformly, but mathematical reasoning has structure we can exploit. Not all reasoning steps are equally promising—a step that contradicts earlier work, makes an obvious error, or goes in circles is unlikely to lead anywhere. Constrained MCTS (CMCTS) uses a Process Reward Model to prune unpromising branches early, before wasting compute exploring them. This isn't just an optimization—it fundamentally changes what problems become tractable. By focusing search on high-potential paths, CMCTS achieves results that would require orders of magnitude more compute with unconstrained search.

CMCTS achieves 83.4% on MATH with a 7B model by constraining the search space intelligently:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    CONSTRAINED MCTS (CMCTS)                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KEY INNOVATIONS:                                                        │
│  ────────────────                                                        │
│                                                                          │
│  1. Constrained Action Space                                            │
│     • Limit candidate steps to high-PRM-score options                  │
│     • Prune obviously wrong branches early                             │
│                                                                          │
│  2. Process Reward Model Integration                                    │
│     • Use PRM scores to guide expansion                                │
│     • Weighted UCB with PRM prior                                      │
│                                                                          │
│  3. Partial Order Rules                                                 │
│     • Mathematical reasoning has ordering constraints                  │
│     • Can't use theorem before establishing premises                   │
│     • Prune violations automatically                                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RESULTS:                                                                │
│  ────────                                                                │
│                                                                          │
│  Model          Standard MCTS    CMCTS         Improvement             │
│  ─────────────────────────────────────────────────────────────────────  │
│  7B             72.1%            83.4%         +11.3%                   │
│  72B            78.5%            87.2%         +8.7%                    │
│                                                                          │
│  Speculative Contrastive MCTS (SC-MCTS*): 51.9% faster                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class ConstrainedMCTS(MathMCTS):
    """
    Constrained MCTS for improved math reasoning.

    Key changes from standard MCTS:
    1. Constrained action space (prune low-PRM steps)
    2. PRM-guided expansion (prefer high-quality steps)
    3. Partial order constraints (mathematical dependencies)

    Results: 7B model achieves 83.4% on MATH (vs 72.1% standard MCTS)
    """

    def __init__(
        self,
        model,
        prm: ProcessRewardModel,
        c_puct: float = 1.414,
        n_simulations: int = 100,
        prm_threshold: float = 0.3,  # Minimum PRM score to expand
        max_candidates: int = 5  # Max candidate actions per node
    ):
        super().__init__(model, prm, c_puct, n_simulations)
        self.prm_threshold = prm_threshold
        self.max_candidates = max_candidates

    def expand(self, node: 'MCTSNode') -> 'MCTSNode':
        """
        Constrained expansion: only add high-PRM candidates.
        """
        # Generate candidate next steps
        candidates = self.generate_next_steps(node.state)

        # Score with PRM
        scored_candidates = []
        for step in candidates:
            new_state = MCTSState(
                problem=node.state.problem,
                steps=node.state.steps + [step]
            )
            prm_score = self.prm.score_step(
                node.state.problem,
                node.state.steps,
                step
            )
            scored_candidates.append((step, prm_score))

        # Filter by threshold (CONSTRAINT 1)
        valid_candidates = [
            (step, score) for step, score in scored_candidates
            if score >= self.prm_threshold
        ]

        # Take top candidates (CONSTRAINT 2)
        valid_candidates.sort(key=lambda x: x[1], reverse=True)
        valid_candidates = valid_candidates[:self.max_candidates]

        # Apply partial order constraints (CONSTRAINT 3)
        valid_candidates = [
            (step, score) for step, score in valid_candidates
            if self.satisfies_partial_order(node.state, step)
        ]

        # Create child nodes
        for step, prm_score in valid_candidates:
            new_state = MCTSState(
                problem=node.state.problem,
                steps=node.state.steps + [step]
            )
            child = MCTSNode(state=new_state, parent=node)
            child.prior_score = prm_score  # Store PRM prior
            node.children.append(child)

        # Return first unexplored
        for child in node.children:
            if child.visits == 0:
                return child

        node.is_fully_expanded = True
        return node.children[0] if node.children else node

    def ucb_score(self, node: 'MCTSNode', parent_visits: int) -> float:
        """
        UCB with PRM prior (like AlphaZero policy prior).
        """
        if node.visits == 0:
            # Use PRM prior for unexplored nodes
            return node.prior_score * 10  # Prioritize high-PRM unexplored

        exploitation = node.value / node.visits
        exploration = self.c_puct * math.sqrt(
            math.log(parent_visits) / node.visits
        )

        # Add PRM prior bonus (decays with visits)
        prior_bonus = node.prior_score / (1 + node.visits)

        return exploitation + exploration + prior_bonus

    def satisfies_partial_order(
        self,
        state: MCTSState,
        new_step: str
    ) -> bool:
        """
        Check if new step satisfies mathematical partial order.

        Examples of violations:
        - Using variable before defining it
        - Applying theorem without establishing premises
        - Referencing result from future step
        """
        # Extract variables used in new step
        new_vars = self.extract_variables(new_step)

        # Extract variables defined so far
        defined_vars = set()
        for step in state.steps:
            defined_vars.update(self.extract_defined_variables(step))

        # Check all used variables are defined
        for var in new_vars:
            if var not in defined_vars and not self.is_given(var, state.problem):
                return False  # Using undefined variable

        return True


class SpeculativeContrastiveMCTS:
    """
    SC-MCTS*: Speculative Contrastive MCTS.

    51.9% faster than standard MCTS by:
    1. Speculative expansion (predict multiple levels)
    2. Contrastive pruning (compare sibling branches)
    """

    def __init__(self, model, prm, n_speculative: int = 3):
        self.model = model
        self.prm = prm
        self.n_speculative = n_speculative

    def speculative_expand(self, node: MCTSNode) -> list[str]:
        """
        Expand multiple levels speculatively.

        Generate not just next step, but next N steps in one go.
        If PRM scores are consistently high, accept all.
        If PRM drops, backtrack to branching point.
        """

        # Generate N-step continuations
        prompt = f"""Problem: {node.state.problem}

Current solution:
{chr(10).join(node.state.steps)}

Generate the next {self.n_speculative} reasoning steps:"""

        continuation = self.model.generate(prompt, max_tokens=1024)
        steps = self.parse_steps(continuation)

        # Score each step with PRM
        scores = []
        cumulative_state = node.state.steps.copy()
        for step in steps[:self.n_speculative]:
            score = self.prm.score_step(
                node.state.problem,
                cumulative_state,
                step
            )
            scores.append(score)
            cumulative_state.append(step)

        # Accept steps while PRM score stays high
        accepted_steps = []
        for step, score in zip(steps, scores):
            if score < 0.3:  # PRM threshold
                break
            accepted_steps.append(step)

        return accepted_steps  # May be 0 to N steps

Best-of-N with Weighted Verification

Simple majority voting treats all solutions equally—but not all verifications are equally reliable. A solution verified by code execution is more trustworthy than one that just "looks right." Weighted Best-of-N combines multiple verification signals: self-consistency (do multiple attempts agree?), PRM scores (does each step look valid?), and tool verification (does code confirm the answer?). By weighting these signals appropriately, you can extract more value from the same N samples.

Python
class WeightedBestOfN:
    """
    Best-of-N with multiple verification signals.

    Combines:
    1. Self-consistency (majority voting)
    2. PRM scores
    3. Symbolic verification
    4. Code execution verification

    Weights different signals based on reliability.
    """

    def __init__(
        self,
        model,
        n_samples: int = 32,
        use_prm: bool = True,
        use_symbolic: bool = True,
        use_code: bool = True
    ):
        self.model = model
        self.n_samples = n_samples

        # Initialize verifiers
        if use_prm:
            self.prm = ProcessRewardModel(model, model.tokenizer)
        if use_symbolic:
            self.symbolic = SymbolicVerifier()
        if use_code:
            self.code_executor = CodeVerifier()

        # Verifier weights (tuned on validation set)
        self.weights = {
            "consistency": 1.0,
            "prm": 0.8,
            "symbolic": 1.5,  # High weight—very reliable when applicable
            "code": 1.2
        }

    def solve(self, problem: str) -> dict:
        """
        Solve with weighted verification ensemble.
        """
        # Generate N solutions
        solutions = []
        for i in range(self.n_samples):
            temp = 0.6 + (i % 5) * 0.1  # Vary temperature
            solution = self.model.generate(
                f"Problem: {problem}\n\nSolve step by step:\n",
                max_new_tokens=2048,
                temperature=temp
            )
            solutions.append({
                "text": solution,
                "answer": self.extract_answer(solution),
                "temperature": temp
            })

        # Score each solution with multiple verifiers
        for sol in solutions:
            sol["scores"] = self.compute_all_scores(problem, sol)
            sol["total_score"] = self.aggregate_scores(sol["scores"])

        # Group by answer and aggregate
        answer_scores = self.aggregate_by_answer(solutions)

        # Select best answer
        best_answer = max(answer_scores, key=answer_scores.get)

        # Find best solution with that answer
        best_solution = max(
            [s for s in solutions if self.normalize_answer(s["answer"]) == best_answer],
            key=lambda s: s["total_score"]
        )

        return {
            "answer": best_answer,
            "solution": best_solution["text"],
            "confidence": answer_scores[best_answer],
            "scores": best_solution["scores"],
            "n_samples": len(solutions),
            "answer_distribution": answer_scores
        }

    def compute_all_scores(
        self,
        problem: str,
        solution: dict
    ) -> dict:
        """
        Compute scores from all verifiers.
        """
        scores = {}

        # 1. PRM score
        if hasattr(self, 'prm'):
            prm_results = self.prm.score_solution(problem, solution["text"])
            scores["prm"] = sum(r["score"] for r in prm_results) / len(prm_results)

        # 2. Symbolic verification
        if hasattr(self, 'symbolic'):
            symbolic_result = self.symbolic.verify(problem, solution["answer"])
            scores["symbolic"] = 1.0 if symbolic_result["verified"] else 0.0

        # 3. Code execution verification
        if hasattr(self, 'code_executor'):
            code_result = self.code_executor.verify(problem, solution["answer"])
            scores["code"] = 1.0 if code_result["verified"] else 0.0

        return scores

    def aggregate_scores(self, scores: dict) -> float:
        """Weighted aggregation of verification scores."""
        total = 0.0
        weight_sum = 0.0

        for verifier, score in scores.items():
            weight = self.weights.get(verifier, 1.0)
            total += weight * score
            weight_sum += weight

        return total / weight_sum if weight_sum > 0 else 0.0

    def aggregate_by_answer(self, solutions: list[dict]) -> dict:
        """
        Aggregate scores by unique answer.
        """
        answer_scores = defaultdict(float)
        answer_counts = defaultdict(int)

        for sol in solutions:
            ans = self.normalize_answer(sol["answer"])

            # Add solution score
            answer_scores[ans] += sol["total_score"]

            # Consistency bonus (more samples = higher confidence)
            answer_counts[ans] += 1

        # Add consistency component
        for ans in answer_scores:
            consistency = answer_counts[ans] / len(solutions)
            answer_scores[ans] += self.weights["consistency"] * consistency

        return dict(answer_scores)

Verification-Guided Iterative Refinement

The gap between 75% accuracy (good) and 95% accuracy (gold medal) comes from error recovery. A model that generates wrong answers and stops there will plateau. A model that can detect its own errors and fix them keeps improving with more compute. Iterative refinement is the key: generate a solution, verify it, identify any errors, and refine. Repeat until verification passes or you hit a compute budget. This is how DeepSeekMath-V2 and Gemini Deep Think achieved gold at IMO 2025.

Python
class IterativeRefinementSolver:
    """
    Iterative refinement with structured verification.

    Based on the insight that gold-medal performance requires
    both generating solutions AND finding/fixing errors.

    Pipeline:
    1. Generate initial solution
    2. Verify step-by-step
    3. Identify first error
    4. Regenerate from error point
    5. Repeat until verified or max iterations
    """

    def __init__(
        self,
        model,
        verifier_model=None,
        max_iterations: int = 5,
        max_regenerations_per_step: int = 3
    ):
        self.generator = model
        self.verifier = verifier_model or model
        self.max_iterations = max_iterations
        self.max_regenerations = max_regenerations_per_step

    def solve(self, problem: str) -> dict:
        """
        Solve with iterative verification and refinement.
        """
        history = []

        # Initial generation
        solution = self.generate_initial(problem)

        for iteration in range(self.max_iterations):
            # Verify current solution
            verification = self.verify_detailed(problem, solution)

            history.append({
                "iteration": iteration,
                "solution_preview": solution[:200] + "...",
                "verification": verification
            })

            # Check if solution is correct
            if verification["is_correct"]:
                return {
                    "solution": solution,
                    "answer": self.extract_answer(solution),
                    "iterations": iteration + 1,
                    "status": "verified_correct",
                    "history": history
                }

            # Find first error and regenerate from there
            if verification["first_error_step"] is not None:
                solution = self.regenerate_from_error(
                    problem,
                    solution,
                    verification["first_error_step"],
                    verification["error_description"]
                )
            else:
                # No specific error found, regenerate entirely
                solution = self.generate_alternative(problem, solution)

        # Max iterations reached
        return {
            "solution": solution,
            "answer": self.extract_answer(solution),
            "iterations": self.max_iterations,
            "status": "max_iterations_reached",
            "history": history
        }

    def verify_detailed(self, problem: str, solution: str) -> dict:
        """
        Detailed verification identifying specific errors.
        """
        verification_prompt = f"""You are a rigorous mathematical proof verifier.
Carefully check each step of this solution for errors.

Problem: {problem}

Solution to verify:
{solution}

For each step:
1. Check logical validity
2. Verify calculations
3. Ensure no cases are missed

If you find an error:
- Identify the FIRST step with an error
- Explain exactly what is wrong
- Do not try to fix it, just identify it

Response format:
STEP_BY_STEP_ANALYSIS:
[Your detailed analysis of each step]

FIRST_ERROR_STEP: [step number, or "none" if no errors]
ERROR_DESCRIPTION: [what's wrong, or "none"]
OVERALL_VERDICT: [CORRECT / INCORRECT / UNCERTAIN]
CONFIDENCE: [0.0 to 1.0]
"""

        response = self.verifier.generate(
            verification_prompt,
            max_new_tokens=2048,
            temperature=0.2  # Low temp for reliable verification
        )

        return self.parse_verification_response(response)

    def parse_verification_response(self, response: str) -> dict:
        """Parse structured verification response."""
        result = {
            "is_correct": False,
            "first_error_step": None,
            "error_description": None,
            "confidence": 0.5,
            "full_analysis": response
        }

        # Extract fields
        if "FIRST_ERROR_STEP:" in response:
            match = re.search(r"FIRST_ERROR_STEP:\s*(\d+|none)", response, re.IGNORECASE)
            if match:
                val = match.group(1).lower()
                if val != "none":
                    result["first_error_step"] = int(val)

        if "ERROR_DESCRIPTION:" in response:
            match = re.search(r"ERROR_DESCRIPTION:\s*(.+?)(?=OVERALL_VERDICT|$)", response, re.DOTALL)
            if match:
                desc = match.group(1).strip()
                if desc.lower() != "none":
                    result["error_description"] = desc

        if "OVERALL_VERDICT:" in response:
            match = re.search(r"OVERALL_VERDICT:\s*(\w+)", response)
            if match:
                result["is_correct"] = match.group(1).upper() == "CORRECT"

        if "CONFIDENCE:" in response:
            match = re.search(r"CONFIDENCE:\s*([\d.]+)", response)
            if match:
                result["confidence"] = float(match.group(1))

        return result

    def regenerate_from_error(
        self,
        problem: str,
        current_solution: str,
        error_step: int,
        error_description: str
    ) -> str:
        """
        Regenerate solution from the point of error.

        Key insight: Keep correct prefix, regenerate suffix.
        """
        # Parse into steps
        steps = self.parse_steps(current_solution)

        # Keep steps before the error (error_step is 1-indexed in the verifier's output)
        correct_prefix = "\n".join(steps[:error_step - 1])

        # Generate new completion
        regenerate_prompt = f"""Problem: {problem}

Partial solution (verified correct):
{correct_prefix}

The next step had an error: {error_description}

Continue the solution correctly from here:"""

        best_completion = None
        best_score = -1

        for _ in range(self.max_regenerations):
            completion = self.generator.generate(
                regenerate_prompt,
                max_new_tokens=2048,
                temperature=0.7
            )

            # Score this completion
            full_solution = correct_prefix + "\n" + completion
            score = self.quick_score(problem, full_solution)

            if score > best_score:
                best_score = score
                best_completion = completion

        return correct_prefix + "\n" + best_completion

    def generate_alternative(
        self,
        problem: str,
        failed_solution: str
    ) -> str:
        """
        Generate completely different approach.
        """
        prompt = f"""Problem: {problem}

A previous attempt at this problem was incorrect:
{failed_solution[:500]}...

Please try a DIFFERENT approach to solve this problem.
Consider alternative methods, different algebraic manipulations,
or viewing the problem from another angle.

Alternative solution:"""

        return self.generator.generate(
            prompt,
            max_new_tokens=2048,
            temperature=0.9  # High temp for diversity
        )

Part 4: Real IMO Problem Walkthrough

Theory and implementations are important, but nothing beats seeing these techniques work on a real problem. In this section, we trace through an actual IMO problem step-by-step, showing how MCTS explores different approaches, how PRMs guide the search, and how verification-refinement catches and fixes errors. This gives you a concrete picture of what "AI solves IMO problems" actually looks like in practice.

IMO 2024 Problem 1

This problem is from the 2024 International Mathematical Olympiad. It's a Number Theory problem that requires finding all functions satisfying a divisibility condition—a classic IMO style that tests both creativity (finding the right approach) and rigor (proving the solution is complete).

Code
Problem: Let N = {1, 2, 3, ...} be the positive integers.
A function f: N → N is called "good" if for all positive integers
m and n, the sum f(m) + f(n) divides m + n.

Find all good functions.
Python
def solve_imo_2024_p1():
    """
    Walkthrough of solving IMO 2024 Problem 1.
    """

    # Step 1: Problem Analysis
    problem = """
    Let N = {1, 2, 3, ...} be the positive integers.
    A function f: N → N is called "good" if for all positive integers
    m and n, the sum f(m) + f(n) divides m + n.

    Find all good functions.
    """

    # System classifies as: Number Theory, Functional Equation
    # Strategy: Start with special cases, find pattern, prove general

    # Step 2: Initial Generation (Temperature 0.7)
    initial_solution = """
    <think>
    Let me analyze this functional equation systematically.

    First, let's try m = n:
    f(m) + f(m) = 2f(m) divides m + m = 2m
    So f(m) | m for all m (since 2f(m) | 2m implies f(m) | m)

    Now let's try m = 1:
    f(1) + f(n) divides 1 + n = n + 1

    Since f(1) | 1 (from above), we have f(1) = 1.

    So 1 + f(n) divides n + 1.
    This means f(n) + 1 ≤ n + 1, so f(n) ≤ n.

    Combined with f(n) | n, the candidates are divisors of n that are ≤ n.

    Let me check if f(n) = n works:
    f(m) + f(n) = m + n divides m + n ✓

    Let me check if f(n) = 1 works:
    f(m) + f(n) = 2 divides m + n
    This requires m + n to be even for all m, n - FALSE (e.g., m=1, n=2)

    What about f(n) = n for some values and different for others?

    Actually, let me be more careful. We have:
    - f(n) | n for all n
    - f(1) = 1

    Let's try m = 2, n = 1:
    f(2) + f(1) = f(2) + 1 divides 3

    So f(2) + 1 ∈ {1, 3}, meaning f(2) ∈ {0, 2}
    Since f: N → N, f(2) ≥ 1, so f(2) = 2.

    Let's try m = 3, n = 1:
    f(3) + 1 divides 4, so f(3) + 1 ∈ {1, 2, 4}
    f(3) ∈ {0, 1, 3}
    Since f(3) | 3 and f(3) ≥ 1, f(3) ∈ {1, 3}

    Case f(3) = 1:
    Check m = 2, n = 3: f(2) + f(3) = 2 + 1 = 3 divides 5? NO.

    So f(3) = 3.

    This suggests f(n) = n for all n.

    Proof that f(n) = n:
    Suppose f(k) < k for some minimal k > 1.
    We have f(k) | k and f(k) < k, so f(k) ≤ k/2.

    Take m = k, n = 1:
    f(k) + 1 divides k + 1

    Since f(k) ≤ k/2, we have f(k) + 1 ≤ k/2 + 1.
    For f(k) + 1 to divide k + 1, we need f(k) + 1 ≤ (k+1)/2 + 1... hmm.

    Let me try a different approach.

    Take m = n = k:
    2f(k) | 2k, so f(k) | k ✓

    Take m = k, n = k-1 (assuming k > 1):
    f(k) + f(k-1) | 2k - 1

    If f(k-1) = k-1 (by minimality of k), then:
    f(k) + k - 1 | 2k - 1

    Since f(k) | k and f(k) < k, let f(k) = k/d for some d ≥ 2 (d | k).

    k/d + k - 1 | 2k - 1
    k(1/d + 1) - 1 | 2k - 1

    This is getting complicated. Let me try specific cases.

    If k = 4 and f(4) = 2 (not 4):
    f(4) + f(3) = 2 + 3 = 5 should divide 7. But 5 ∤ 7. Contradiction.

    If k = 6 and f(6) = 3:
    f(6) + f(5) = 3 + 5 = 8 should divide 11. But 8 ∤ 11. Contradiction.

    If k = 6 and f(6) = 2:
    f(6) + f(5) = 2 + 5 = 7 should divide 11. But 7 ∤ 11. Contradiction.

    These contradictions suggest f(n) = n is the only solution.

    Formal proof by strong induction:
    Base: f(1) = 1 ✓

    Inductive step: Assume f(j) = j for all j < k. Show f(k) = k.

    We know f(k) | k.
    Suppose f(k) = k/d for some d > 1.

    Take n = k - 1 (if k > 1):
    f(k) + f(k-1) = k/d + (k-1) | 2k - 1

    Since gcd(k, 2k-1) = gcd(k, -1) = 1 (as 2k-1 = 2k - 1),
    we need (k/d + k - 1) | 2k - 1.

    k/d + k - 1 = k(1/d + 1) - 1 = k(d+1)/d - 1

    For this to divide 2k - 1, and noting d ≥ 2:
    k(d+1)/d - 1 ≥ k(3/2) - 1 = 3k/2 - 1

    But 2k - 1 < 2(3k/2 - 1) = 3k - 2 only when... actually let's check:
    We need k(d+1)/d - 1 ≤ 2k - 1
    k(d+1)/d ≤ 2k
    (d+1)/d ≤ 2
    d + 1 ≤ 2d
    1 ≤ d ✓

    So the divisibility condition could hold. Need more careful analysis.

    Actually, for d = 2: k/2 + k - 1 = 3k/2 - 1
    This divides 2k - 1.
    3k/2 - 1 | 2k - 1
    3k - 2 | 4k - 2 = 2(2k-1)

    If 3k - 2 | 4k - 2:
    3k - 2 | 4k - 2 - (4/3)(3k-2) = 4k - 2 - 4k + 8/3 = -2 + 8/3 = 2/3

    This is not an integer, so 3k - 2 ∤ 4k - 2 for k > 1.

    Wait, that's not quite right. Let me redo:
    If 3k - 2 | 4k - 2:
    4k - 2 = q(3k - 2) for some integer q.
    4k - 2 = 3qk - 2q
    4k - 3qk = 2 - 2q
    k(4 - 3q) = 2(1 - q)

    If q = 1: k(1) = 0, impossible.
    If q = 0: 4k = 2, k = 1/2, impossible.
    If q = 2: k(-2) = -2, k = 1. But we assumed k > 1.
    If q ≥ 3 or q < 0: LHS and RHS have different signs.

    So for k > 1, 3k - 2 ∤ 4k - 2, meaning f(k) ≠ k/2.

    By similar analysis for other divisors, f(k) = k.
    </think>

    The only good function is f(n) = n for all positive integers n.

    Proof:
    1. Setting m = n gives 2f(n) | 2n, so f(n) | n.
    2. Setting m = 1, n = 1 gives 2f(1) | 2, so f(1) = 1.
    3. By strong induction, assuming f(j) = j for j < k and f(k) = k/d
       for d > 1, we derive a contradiction using m = k, n = k-1.
    4. Verification: f(n) = n satisfies f(m) + f(n) = m + n | m + n. ✓

    \\boxed{f(n) = n}
    """

    # Step 3: Verification
    verification_result = {
        "step_1": {"correct": True, "note": "Valid derivation of f(n) | n"},
        "step_2": {"correct": True, "note": "Correct that f(1) = 1"},
        "step_3": {"correct": True, "note": "Induction structure is valid"},
        "step_4": {"correct": True, "note": "Verification confirms solution"},
        "overall": "CORRECT"
    }

    # Step 4: Symbolic verification (if applicable)
    # For this problem, we verify the claimed solution works
    def verify_f_equals_n():
        """Verify that f(n) = n satisfies the condition."""
        # f(m) + f(n) = m + n always divides m + n
        return True

    # Output
    return {
        "problem": "IMO 2024 P1",
        "answer": "f(n) = n",
        "solution": initial_solution,
        "verification": verification_result,
        "symbolic_check": verify_f_equals_n(),
        "iterations": 1,  # Solved on first attempt
        "confidence": 0.95
    }

Trace of Error Detection and Refinement

The previous example worked on the first try. But what happens when the model makes a mistake? This trace shows the refinement loop in action: the model generates a solution with an error, the verifier catches it, and the model fixes the problem in a subsequent iteration. This is where iterative refinement earns its keep—without it, the first error would be final.

Python
def demonstrate_refinement():
    """
    Show how the system detects and fixes errors.
    """

    problem = "Find all positive integers n such that n² + 2n + 2 is a perfect square."

    # Iteration 1: Initial (incorrect) attempt
    attempt_1 = """
    <think>
    We need n² + 2n + 2 = k² for some positive integer k.

    Note that n² + 2n + 1 = (n+1)², so:
    n² + 2n + 2 = (n+1)² + 1

    We need (n+1)² + 1 = k²
    k² - (n+1)² = 1
    (k - n - 1)(k + n + 1) = 1

    Since k, n are positive integers:
    k - n - 1 = 1 and k + n + 1 = 1

    From second equation: k + n + 1 = 1, so k + n = 0.
    Since k, n > 0, this is impossible.

    Therefore, there are no solutions.
    </think>

    There are no positive integers n such that n² + 2n + 2 is a perfect square.

    \\boxed{\\text{No solutions}}
    """

    # Verification finds error
    verification_1 = {
        "step_analysis": [
            {"step": 1, "correct": True, "note": "Correct setup"},
            {"step": 2, "correct": True, "note": "Correct factorization"},
            {"step": 3, "correct": False, "note": "ERROR: (k-n-1)(k+n+1)=1 has solutions (1,1) and (-1,-1)"},
        ],
        "first_error_step": 3,
        "error_description": "The equation (k-n-1)(k+n+1)=1 can be satisfied by k-n-1=1, k+n+1=1 OR k-n-1=-1, k+n+1=-1. Only considered one case."
    }

    # Iteration 2: Regenerate from error
    attempt_2 = """
    <think>
    [Previous correct work kept]
    We have (k - n - 1)(k + n + 1) = 1

    Case 1: k - n - 1 = 1 and k + n + 1 = 1
    Adding: 2k = 2, so k = 1
    Subtracting: 2n + 2 = 0, so n = -1 (not positive)

    Case 2: k - n - 1 = -1 and k + n + 1 = -1
    Adding: 2k = -2, so k = -1 (not positive)

    Wait, I should also check if there are solutions over integers (not just factoring as 1×1).

    Actually, the only factorizations of 1 in integers are 1×1 and (-1)×(-1).

    But wait - let me reconsider. We have:
    k² - (n+1)² = 1

    This is a Pell-like equation. The difference of two squares equals 1 means:
    k² = (n+1)² + 1

    For consecutive squares to differ by 1, we need... actually, consecutive squares differ by 2n+1.

    Hmm, k² and (n+1)² differ by 1. The only consecutive squares are 0,1 (differ by 1), 1,4 (differ by 3)...

    Actually, no two perfect squares differ by exactly 1 except 0 and 1!

    So k² = 1 and (n+1)² = 0, giving k = 1, n = -1 (not positive).
    Or k² = 0 and (n+1)² = -1, impossible.

    Therefore, there are no positive integer solutions.

    Let me verify by testing small values:
    n = 1: 1 + 2 + 2 = 5 (not a square)
    n = 2: 4 + 4 + 2 = 10 (not a square)
    n = 3: 9 + 6 + 2 = 17 (not a square)
    n = 4: 16 + 8 + 2 = 26 (not a square)
    n = 5: 25 + 10 + 2 = 37 (not a square)

    The conclusion holds.
    </think>

    There are no positive integers n such that n² + 2n + 2 is a perfect square.

    Proof: n² + 2n + 2 = (n+1)² + 1. For this to be a perfect square k²,
    we need k² - (n+1)² = 1. The only squares differing by 1 are 0 and 1,
    which gives n = -1 (not positive). ✓

    \\boxed{\\text{No positive integer solutions}}
    """

    # Verification passes
    verification_2 = {
        "step_analysis": [
            {"step": 1, "correct": True},
            {"step": 2, "correct": True},
            {"step": 3, "correct": True, "note": "Now correctly handles all cases"},
            {"step": 4, "correct": True, "note": "Good verification with examples"},
        ],
        "overall": "CORRECT",
        "confidence": 0.95
    }

    return {
        "iterations": [
            {"attempt": attempt_1, "verification": verification_1, "status": "error_found"},
            {"attempt": attempt_2, "verification": verification_2, "status": "verified_correct"}
        ],
        "final_answer": "No positive integer solutions"
    }

Part 5: AlphaGeometry2 and Geometry-Specific Approaches

Geometry problems require specialized handling. AlphaGeometry2 (2025) achieves 84% solve rate on IMO geometry problems, up from 54% in the original version.

Why Geometry Is Different From Other Math Domains

IMO geometry problems present unique challenges that don't exist in algebra, number theory, or combinatorics:

The auxiliary construction problem:

Most IMO geometry proofs require introducing new elements—auxiliary points, lines, or circles that aren't in the original problem statement. These constructions are often the key insight that unlocks the entire proof. The search space is infinite: you could draw any line, any circle, through any combination of points. Finding the right construction is what separates a gold medalist from someone who can't start the problem.

For example, a typical IMO geometry problem might say "prove that points A, B, C, D are concyclic." A human solver recognizes this as needing to show a cyclic quadrilateral—perhaps by constructing the circumcircle, proving equal angles, or using power of a point. The auxiliary construction (which circumcircle? which angles?) isn't given.

The neuro-symbolic insight:

AlphaGeometry's breakthrough was recognizing that geometry requires a hybrid approach:

  1. Neural model: Suggests which auxiliary constructions might be useful (creative insight)
  2. Symbolic engine: Verifies whether those constructions lead to valid proofs (rigorous deduction)

The neural component handles the infinite search space by learning patterns from human proofs. "When proving concyclicity, try constructing the circumcircle of three points" is a pattern a neural network can learn. The symbolic engine then exhaustively explores the consequences of each construction.
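
To give a flavor of the symbolic side, here is a toy forward-chaining sketch, not AlphaGeometry's actual engine: deduction rules are applied repeatedly to a growing set of facts until the goal appears or a fixed point is reached. The fact encoding and the single, simplified inscribed-angle rule are illustrative only.

Python
def forward_chain(facts: set, rules, goal, max_rounds: int = 100) -> bool:
    """Apply deduction rules until the goal is derived or no new facts appear."""
    for _ in range(max_rounds):
        derived = set()
        for rule in rules:
            derived |= rule(facts)      # each rule maps current facts to newly derived facts
        if goal in facts | derived:
            return True
        if derived <= facts:            # fixed point: nothing new was derived
            return False
        facts |= derived
    return False

# Toy rule: equal angles subtended by segment AB from X and Y imply concyclicity
# (side conditions of the real inscribed-angle converse are omitted here).
def inscribed_angle_rule(facts: set) -> set:
    out = set()
    for f in facts:
        if f[0] == "eq_angle":          # encoded as ("eq_angle", A, X, B, A, Y, B)
            _, a, x, b, _, y, _ = f
            out.add(("concyclic", a, b, x, y))
    return out

facts = {("eq_angle", "A", "X", "B", "A", "Y", "B")}
print(forward_chain(facts, [inscribed_angle_rule], ("concyclic", "A", "B", "X", "Y")))  # True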

Why coordinate geometry often wins:

A surprising finding: many geometry problems that seem to require clever synthetic proofs can be solved by brute-force coordinate bashing. Set up a coordinate system, assign coordinates to known points, express the conditions algebraically, and let the computer crunch the equations.

This is inelegant but reliable. AlphaGeometry2's coordinate geometry mode achieves high accuracy on problems that would require creative insight in synthetic geometry. The tradeoff is that coordinate proofs are long, unreadable by humans, and don't generalize to similar problems—but they work.
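
As a small illustration of the coordinate strategy (a sketch using SymPy, with a setup of our own choosing rather than AlphaGeometry2's machinery), here is a symbolic proof that the midpoint of the hypotenuse of a right triangle is equidistant from all three vertices:

Python
import sympy as sp

a, b = sp.symbols("a b", positive=True)

# Right triangle with the right angle at the origin C
A, B, C = sp.Matrix([a, 0]), sp.Matrix([0, b]), sp.Matrix([0, 0])
M = (A + B) / 2  # midpoint of the hypotenuse AB

def dist2(P, Q):
    """Squared distance between two points."""
    d = P - Q
    return (d.T * d)[0]

# Both differences simplify to 0: M is equidistant from A, B, and C,
# so the midpoint of the hypotenuse is the circumcenter.
print(sp.simplify(dist2(M, A) - dist2(M, C)))  # 0
print(sp.simplify(dist2(M, B) - dist2(M, C)))  # 0

Once the configuration is coordinatized, "let the computer crunch the equations" really is the entire proof.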

The counterintuitive vision finding:

You might expect that showing a model the geometric diagram would help. After all, humans use diagrams heavily when solving geometry problems. AlphaGeometry2's team tested multimodal approaches—feeding both the problem text and the rendered diagram to vision-language models.

The result: no significant improvement. The algebraic description contains all the information; the visual representation adds little. This suggests that human use of diagrams is more about cognitive convenience (working memory) than about extracting information that isn't in the text.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ALPHAGEOMETRY2 ADVANCES                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KEY IMPROVEMENTS:                                                       │
│  ─────────────────                                                       │
│                                                                          │
│  1. Expanded domain-specific language                                   │
│     • More construction primitives                                      │
│     • Richer theorem library                                           │
│                                                                          │
│  2. C++ symbolic engine (100x faster)                                  │
│     • Enables deeper search                                            │
│     • More candidate proofs per problem                                │
│                                                                          │
│  3. Larger neural model (Gemini-based)                                 │
│     • Better construction suggestions                                   │
│     • More diverse proof strategies                                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SURPRISING FINDING:                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  "Vision models and diagram understanding DON'T significantly help"    │
│                                                                          │
│  The team tested multimodal approaches (reading geometry diagrams)     │
│  but found algebraic/symbolic reasoning dominates. Visual input        │
│  adds little beyond what can be parsed from text descriptions.        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RESULTS:                                                                │
│  ────────                                                                │
│                                                                          │
│  IMO Geometry Problems (30 years): 54% → 84%                           │
│  Combined with Gemini: 35/42 points at IMO 2025                       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

The GeometryAgent implementation below demonstrates the neuro-symbolic approach that powers AlphaGeometry2. The agent has two solving strategies:

Coordinate geometry (always tried first): Place points in a coordinate system, express all geometric entities as algebraic equations, and solve symbolically. This is reliable but produces inelegant proofs. For problems asking "prove X", it works but doesn't satisfy the aesthetic standards of competition mathematics.

Synthetic geometry with neural suggestions: The more sophisticated approach. The neural model proposes auxiliary constructions (draw line AB, construct circumcircle, etc.), and the symbolic engine verifies whether those constructions lead to valid proofs. The suggest_auxiliary_construction method generates candidates; the symbolic engine tests each one. When a construction enables a proof, the agent traces back to produce human-readable steps.

The key data structures are GeometricStatement (facts like "A, B, C are collinear") and GeometricTheorem (applicable theorems like "Power of a point"). The symbolic engine applies theorems to derive new statements until the goal is reached.

Python
class GeometryAgent:
    """
    Specialized agent for geometry problems.

    Key insight from AlphaGeometry:
    Combine neural suggestion with symbolic deduction.

    Neural: Propose auxiliary constructions, lemmas
    Symbolic: Verify deductions, compute coordinates
    """

    def __init__(self, model, use_coordinates: bool = True):
        self.model = model
        self.use_coordinates = use_coordinates
        self.symbolic_engine = GeometricSymbolicEngine()

    def solve(self, problem: str) -> dict:
        """
        Solve geometry problem using neuro-symbolic approach.
        """
        # Parse geometric entities
        entities = self.parse_geometry(problem)

        # Strategy 1: Coordinate geometry (reliable but inelegant)
        if self.use_coordinates:
            coord_solution = self.solve_with_coordinates(problem, entities)
            if coord_solution["success"]:
                return coord_solution

        # Strategy 2: Synthetic approach with neural suggestions
        return self.solve_synthetic(problem, entities)

    def solve_with_coordinates(
        self,
        problem: str,
        entities: dict
    ) -> dict:
        """
        Translate to coordinates and solve algebraically.
        """
        # Step 1: Set up coordinate system
        coord_system = self.establish_coordinates(entities)

        # Step 2: Express all entities in coordinates
        expressions = self.express_in_coordinates(entities, coord_system)

        # Step 3: Translate goal to algebraic equation
        goal_equation = self.translate_goal(problem, expressions)

        # Step 4: Solve symbolically
        solution = self.symbolic_engine.solve(goal_equation)

        return {
            "success": solution is not None,
            "method": "coordinate_geometry",
            "coordinate_system": coord_system,
            "solution": solution
        }

    def solve_synthetic(
        self,
        problem: str,
        entities: dict
    ) -> dict:
        """
        Synthetic geometry with neural-guided construction.

        Loop:
        1. Neural model suggests construction or lemma
        2. Symbolic engine verifies and derives consequences
        3. Check if goal is reached
        """
        constructions = []
        derived_facts = set(entities.get("given_facts", []))

        for iteration in range(20):  # Max iterations
            # Neural suggestion
            suggestion = self.neural_suggest(
                problem,
                entities,
                constructions,
                derived_facts
            )

            if suggestion["type"] == "construction":
                # Add auxiliary point/line
                new_entity = self.add_construction(
                    suggestion["description"],
                    entities
                )
                constructions.append(new_entity)

                # Derive new facts from construction
                new_facts = self.symbolic_engine.derive_facts(
                    entities,
                    constructions,
                    derived_facts
                )
                derived_facts.update(new_facts)

            elif suggestion["type"] == "lemma":
                # Apply known theorem
                result = self.apply_lemma(
                    suggestion["lemma_name"],
                    suggestion["arguments"],
                    entities
                )
                if result:
                    derived_facts.add(result)

            elif suggestion["type"] == "conclude":
                # Check if we can prove the goal
                if self.goal_satisfied(problem, derived_facts):
                    return {
                        "success": True,
                        "method": "synthetic",
                        "constructions": constructions,
                        "proof_steps": list(derived_facts)
                    }

        return {"success": False, "reason": "max_iterations"}

    def neural_suggest(
        self,
        problem: str,
        entities: dict,
        constructions: list,
        derived_facts: set
    ) -> dict:
        """
        Get neural suggestion for next step.
        """
        prompt = f"""Geometry problem: {problem}

Current constructions: {constructions}
Known facts: {list(derived_facts)[:20]}

What construction or lemma would help solve this problem?

Options:
1. Add auxiliary construction (point, line, circle)
2. Apply a theorem (similar triangles, power of point, etc.)
3. Conclude the proof

Suggest the most helpful next step:"""

        response = self.model.generate(prompt, max_new_tokens=256)
        return self.parse_suggestion(response)


class GeometricSymbolicEngine:
    """
    Symbolic geometry engine for verification and deduction.

    Uses:
    1. Coordinate geometry computations
    2. Automated theorem proving rules
    3. Known geometry theorems
    """

    def __init__(self):
        self.theorems = self.load_geometry_theorems()

    def derive_facts(
        self,
        entities: dict,
        constructions: list,
        existing_facts: set
    ) -> set:
        """
        Derive new geometric facts from current configuration.
        """
        new_facts = set()

        # Apply all applicable theorems
        for theorem in self.theorems:
            applications = theorem.find_applications(
                entities,
                constructions,
                existing_facts
            )
            for app in applications:
                new_facts.add(app.conclusion)

        return new_facts

    def load_geometry_theorems(self) -> list:
        """
        Load library of geometry theorems.

        Examples:
        - Thales' theorem
        - Power of a point
        - Similar triangles
        - Angle bisector theorem
        - etc.
        """
        return [
            ThalesTheorem(),
            PowerOfPointTheorem(),
            SimilarTrianglesTheorem(),
            AngleBisectorTheorem(),
            PtolemyTheorem(),
            MenelausTheorem(),
            CevaTheorem(),
            # ... many more
        ]

Part 6: Multi-Agent Debate for Math

Multi-agent debate is one of the most promising techniques for improving mathematical reasoning accuracy. The core idea: instead of trusting a single model's output, have multiple "agents" with different roles critically examine each other's work.

Why Debate Works for Math

Mathematical proofs have an asymmetry: checking is easier than generating. A proof can take hours to discover, but verifying its correctness takes minutes. This asymmetry is perfect for debate—a critic agent can catch errors that the solver missed without needing to solve the problem itself.

The debate participants:

| Agent | Role | Capability |
| --- | --- | --- |
| Solver | Generates initial solution | Strong at problem-solving |
| Critic | Finds errors, gaps, unjustified steps | Strong at verification |
| Defender | Addresses criticisms, clarifies reasoning | Strong at explanation |
| Judge | Decides if criticisms are valid | Balanced evaluation |

The Debate Protocol

Understanding how debate works in practice helps you implement it correctly. The protocol below is iterated—each round surfaces and addresses issues, progressively improving solution quality. Most problems converge within 2-3 rounds, but complex proofs may need more. The key is having clear stopping criteria: either the critic finds no issues (consensus reached) or you hit a maximum round limit.

  1. Solver generates an initial solution
  2. Critic examines the solution, looking for:
    • Logical gaps (steps that don't follow)
    • Calculation errors
    • Missing cases
    • Unjustified claims
  3. Defender responds to each criticism:
    • Accept and fix valid criticisms
    • Rebut invalid criticisms with justification
  4. Judge evaluates:
    • Were the criticisms valid?
    • Were the defenses adequate?
    • Is the revised solution correct?
  5. Repeat until consensus or max rounds

Why Multiple Agents Beat Single-Agent Self-Correction

You might ask: why not just have one model check its own work? Research shows multi-agent debate outperforms single-agent self-correction because:

  1. Different "perspectives": Even with the same base model, different prompts (critic vs. defender) elicit different reasoning patterns
  2. Adversarial pressure: A critic is incentivized to find errors; a self-checker might be biased toward confirming its own work
  3. Information surfacing: The debate format forces explicit justification of each step, surfacing implicit assumptions

The implementation below shows a complete multi-agent debate system. The MathDebateSystem class orchestrates solver, critic, and judge agents through multiple rounds. Note how each agent has a specialized prompt that shapes its behavior—the critic actively looks for flaws while the defender tries to justify the reasoning. This adversarial dynamic is what makes debate effective.

Python
class MathDebateSystem:
    """
    Multi-agent debate for mathematical reasoning.

    Agents:
    1. Solver: Generates solutions
    2. Critic: Finds errors in solutions
    3. Defender: Addresses criticisms
    4. Judge: Evaluates debate outcome

    Research shows debate improves accuracy on hard problems.
    """

    def __init__(
        self,
        solver_model,
        critic_model=None,
        judge_model=None,
        max_rounds: int = 3
    ):
        self.solver = solver_model
        self.critic = critic_model or solver_model
        self.judge = judge_model or solver_model
        self.max_rounds = max_rounds

    def solve_with_debate(self, problem: str) -> dict:
        """
        Solve problem through multi-agent debate.
        """
        # Initial solution
        solution = self.generate_solution(problem)

        debate_history = []

        for round_num in range(self.max_rounds):
            # Critic finds issues
            criticism = self.generate_criticism(problem, solution)

            debate_history.append({
                "round": round_num + 1,
                "solution": solution[:500] + "...",
                "criticism": criticism
            })

            # If no valid criticism, solution is likely correct
            if not criticism["has_valid_criticism"]:
                break

            # Defender responds
            defense = self.generate_defense(
                problem,
                solution,
                criticism
            )

            # Judge evaluates
            judgment = self.judge_round(
                problem,
                solution,
                criticism,
                defense
            )

            if judgment["criticism_valid"]:
                # Revise solution based on criticism
                solution = self.revise_solution(
                    problem,
                    solution,
                    criticism,
                    defense
                )
            else:
                # Criticism rejected, keep solution
                break

        # Final judgment
        final_verdict = self.final_judgment(problem, solution)

        return {
            "solution": solution,
            "answer": self.extract_answer(solution),
            "debate_rounds": len(debate_history),
            "debate_history": debate_history,
            "final_verdict": final_verdict
        }

    def generate_criticism(
        self,
        problem: str,
        solution: str
    ) -> dict:
        """
        Critic agent finds potential errors.
        """
        prompt = f"""You are a critical mathematics reviewer.
Your job is to find errors in mathematical proofs.

Problem: {problem}

Proposed Solution:
{solution}

Carefully analyze this solution and identify ANY errors:
1. Logical errors (non-sequiturs, unjustified steps)
2. Calculation errors
3. Missing cases
4. Incorrect theorems/facts used
5. Gaps in reasoning

If you find errors, explain them clearly.
If the solution is correct, state that explicitly.

Your critique:"""

        response = self.critic.generate(
            prompt,
            max_new_tokens=1024,
            temperature=0.3
        )

        has_criticism = not any(
            phrase in response.lower()
            for phrase in ["no errors", "solution is correct", "proof is valid"]
        )

        return {
            "text": response,
            "has_valid_criticism": has_criticism
        }

    def generate_defense(
        self,
        problem: str,
        solution: str,
        criticism: dict
    ) -> str:
        """
        Defender addresses criticism.
        """
        prompt = f"""You are defending a mathematical solution.

Problem: {problem}

Your Solution:
{solution}

Criticism received:
{criticism['text']}

Defend your solution. Either:
1. Explain why the criticism is invalid
2. Acknowledge the error and explain how to fix it
3. Provide additional justification for questioned steps

Your defense:"""

        return self.solver.generate(
            prompt,
            max_new_tokens=1024,
            temperature=0.4
        )

    def judge_round(
        self,
        problem: str,
        solution: str,
        criticism: dict,
        defense: str
    ) -> dict:
        """
        Judge evaluates debate round.
        """
        prompt = f"""You are an impartial mathematical judge.

Problem: {problem}

Solution:
{solution}

Criticism:
{criticism['text']}

Defense:
{defense}

Evaluate this debate:
1. Is the criticism valid? (Does it identify a real error?)
2. Is the defense successful? (Does it address the criticism?)
3. What is your verdict?

Respond with:
CRITICISM_VALID: [yes/no]
DEFENSE_SUCCESSFUL: [yes/no]
VERDICT: [solution_correct/solution_needs_revision/unclear]
EXPLANATION: [your reasoning]"""

        response = self.judge.generate(
            prompt,
            max_new_tokens=512,
            temperature=0.2
        )

        # Parse response
        criticism_valid = "CRITICISM_VALID: yes" in response.lower()
        defense_successful = "DEFENSE_SUCCESSFUL: yes" in response.lower()

        return {
            "criticism_valid": criticism_valid,
            "defense_successful": defense_successful,
            "full_response": response
        }

    def revise_solution(
        self,
        problem: str,
        solution: str,
        criticism: dict,
        defense: str
    ) -> str:
        """
        Revise solution based on valid criticism.
        """
        prompt = f"""Revise the solution to address the valid criticism.

Problem: {problem}

Original Solution:
{solution}

Valid Criticism:
{criticism['text']}

Defense attempt:
{defense}

Please provide a REVISED solution that fixes the identified errors
while keeping correct parts of the original.

Revised Solution:"""

        return self.solver.generate(
            prompt,
            max_new_tokens=2048,
            temperature=0.5
        )

Part 7: Production Deployment

Deploying mathematical reasoning systems in production introduces challenges beyond pure accuracy: cost control, latency requirements, reliability, and observability. A system that achieves 90% accuracy but costs $10 per query isn't viable for most applications.

The Economics of AI Math Solving

Cost breakdown for a typical math query:

| Component | Tokens | Cost (GPT-4) |
| --- | --- | --- |
| Problem input | ~200 | $0.006 |
| Reasoning chain | ~2000 | $0.060 |
| Self-verification | ~1000 | $0.030 |
| Total per attempt | ~3200 | ~$0.10 |
| With 8 attempts | ~25600 | ~$0.80 |

At scale (1,000 queries/day), this becomes $800/day, or roughly $24,000/month, just for the LLM costs. Add inference infrastructure, monitoring, and engineering time, and costs can easily exceed $50,000/month.
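
A quick back-of-the-envelope check of those figures (the per-attempt cost is the ~$0.10 from the table above):

Python
# Rough daily/monthly LLM spend for Best-of-8 solving at GPT-4-class pricing.
cost_per_attempt = 0.10        # ~3,200 tokens per attempt (see table above)
attempts_per_query = 8
queries_per_day = 1_000

daily_cost = cost_per_attempt * attempts_per_query * queries_per_day
print(f"${daily_cost:,.0f}/day -> ${daily_cost * 30:,.0f}/month")   # $800/day -> $24,000/month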

Optimization Strategies

1. Tiered solving (cheap first): Start with a fast, cheap model (GPT-4o-mini, Claude Haiku). If it solves the problem confidently, stop. Only escalate to expensive models (GPT-4, Claude Opus, o1) for hard problems. This can reduce costs by 60-80% while maintaining accuracy.

2. Difficulty estimation: Before solving, estimate problem difficulty. Easy problems don't need MCTS or multiple attempts. Hard problems justify the extra compute. A simple classifier trained on problem features can route efficiently (see the sketch after this list).

3. Caching: Mathematical problems have patterns. "Solve x^2 - 4 = 0" might appear in many forms. Cache both exact matches and semantic equivalents (embedding similarity). Cache hit rates of 20-40% are common for educational platforms.

4. Early stopping: If a solution is verified correct, stop generating alternatives. If 5 attempts all fail, escalate rather than continuing with diminishing returns.
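
As an example of the routing in item 2, here is a minimal difficulty router. The surface features and thresholds are illustrative assumptions, a stand-in for a small trained classifier:

Python
import re


def estimate_difficulty(problem: str) -> float:
    """Cheap heuristic difficulty score in [0, 1] from surface features.

    A production router would use a small trained classifier (e.g. logistic
    regression over embeddings); these handcrafted features are a stand-in.
    """
    text = problem.lower()
    score = 0.0
    score += 0.3 if "prove" in text else 0.0                          # proofs are harder
    score += 0.2 if len(problem.split()) > 80 else 0.0                # long statements
    score += 0.2 if re.search(r"\bfor all\b|\bthere exists\b", text) else 0.0
    score += 0.3 * min(problem.count("$") / 20, 1.0)                  # heavy notation
    return min(score, 1.0)


def route(problem: str) -> str:
    """Map estimated difficulty to a solving tier."""
    d = estimate_difficulty(problem)
    if d < 0.3:
        return "cheap_single"      # one attempt with the small model
    if d < 0.6:
        return "cheap_verified"    # self-consistency with the small model
    return "full_pipeline"         # MCTS / Best-of-N with the expensive model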

Cost and Latency Optimization

The techniques above achieve high accuracy—but at what cost? Running MCTS with 1000 simulations or Best-of-64 with PRM reranking can cost $5-50 per problem. For production systems, you need to balance accuracy against cost and latency. The implementation below shows how to build a tiered system: start cheap (GPT-4o-mini), escalate if needed, cache aggressively, and parallelize where possible. A well-designed tiered system achieves 80-90% of maximum accuracy at 10-20% of maximum cost.

Python
import time
from collections import Counter

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class ProductionMathSolver:
    """
    Production-ready math solver with cost/latency optimization.

    Key considerations:
    1. Tiered solving (cheap first, expensive if needed)
    2. Caching for common patterns
    3. Parallel execution where possible
    4. Early stopping when confident
    """

    def __init__(
        self,
        cheap_model,  # e.g., Claude Haiku, GPT-4o-mini
        expensive_model,  # e.g., Claude Opus, o1
        max_budget_dollars: float = 1.0,
        target_latency_seconds: float = 30.0
    ):
        self.cheap = cheap_model
        self.expensive = expensive_model
        self.max_budget = max_budget_dollars
        self.target_latency = target_latency_seconds

        # Pricing (per 1M tokens, approximate)
        self.pricing = {
            "cheap_input": 0.25,
            "cheap_output": 1.25,
            "expensive_input": 15.0,
            "expensive_output": 75.0
        }

        # Cache for common problem patterns
        self.cache = SolutionCache()

    def solve(self, problem: str) -> dict:
        """
        Solve with cost-aware strategy.
        """
        start_time = time.time()
        total_cost = 0.0

        # Check cache first
        cached = self.cache.get(problem)
        if cached:
            return {
                **cached,
                "source": "cache",
                "latency": time.time() - start_time,
                "cost": 0.0
            }

        # Tier 1: Single cheap attempt
        result, cost = self.try_cheap_single(problem)
        total_cost += cost

        if result["confidence"] > 0.9:
            self.cache.set(problem, result)
            return {
                **result,
                "source": "cheap_single",
                "latency": time.time() - start_time,
                "cost": total_cost
            }

        # Tier 2: Cheap model with verification
        if total_cost < self.max_budget * 0.3:
            result, cost = self.try_cheap_with_verification(problem)
            total_cost += cost

            if result["confidence"] > 0.85:
                self.cache.set(problem, result)
                return {
                    **result,
                    "source": "cheap_verified",
                    "latency": time.time() - start_time,
                    "cost": total_cost
                }

        # Tier 3: Expensive model
        if total_cost < self.max_budget * 0.7:
            result, cost = self.try_expensive(problem)
            total_cost += cost

            if result["confidence"] > 0.8:
                self.cache.set(problem, result)
                return {
                    **result,
                    "source": "expensive",
                    "latency": time.time() - start_time,
                    "cost": total_cost
                }

        # Tier 4: Full pipeline (most expensive)
        if total_cost < self.max_budget:
            result, cost = self.try_full_pipeline(problem)
            total_cost += cost

            self.cache.set(problem, result)
            return {
                **result,
                "source": "full_pipeline",
                "latency": time.time() - start_time,
                "cost": total_cost
            }

        # Budget exceeded, return best so far
        return {
            **result,
            "source": "budget_limited",
            "latency": time.time() - start_time,
            "cost": total_cost
        }

    def try_cheap_single(self, problem: str) -> tuple[dict, float]:
        """Single attempt with cheap model."""
        response = self.cheap.generate(
            f"Solve: {problem}\n\nSolution:",
            max_new_tokens=1024
        )

        # Estimate cost
        cost = self.estimate_cost("cheap", len(problem), len(response))

        # Quick confidence estimate
        confidence = self.quick_confidence(response)

        return {
            "solution": response,
            "answer": self.extract_answer(response),
            "confidence": confidence
        }, cost

    def try_cheap_with_verification(self, problem: str) -> tuple[dict, float]:
        """Multiple attempts + majority voting."""
        solutions = []
        total_cost = 0.0

        for _ in range(5):  # 5 samples
            response = self.cheap.generate(
                f"Solve step by step: {problem}\n\n",
                max_new_tokens=1024,
                temperature=0.7
            )
            solutions.append(response)
            total_cost += self.estimate_cost("cheap", len(problem), len(response))

        # Majority vote
        answers = [self.extract_answer(s) for s in solutions]
        answer_counts = Counter(answers)
        best_answer, count = answer_counts.most_common(1)[0]

        confidence = count / len(solutions)

        # Find best solution with that answer
        best_solution = next(
            s for s in solutions
            if self.extract_answer(s) == best_answer
        )

        return {
            "solution": best_solution,
            "answer": best_answer,
            "confidence": confidence,
            "agreement": f"{count}/{len(solutions)}"
        }, total_cost

    def estimate_cost(
        self,
        model_tier: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """Estimate API cost."""
        input_price = self.pricing[f"{model_tier}_input"]
        output_price = self.pricing[f"{model_tier}_output"]

        return (input_tokens * input_price + output_tokens * output_price) / 1_000_000


class SolutionCache:
    """
    Cache for math solutions.

    Uses semantic similarity to find cached solutions
    for similar problems.
    """

    def __init__(self, similarity_threshold: float = 0.95):
        self.threshold = similarity_threshold
        self.embedder = SentenceTransformer("all-mpnet-base-v2")
        self.cache = {}  # embedding -> solution
        self.problem_texts = {}  # embedding -> original problem

    def get(self, problem: str) -> dict | None:
        """Find cached solution for similar problem."""
        if not self.cache:
            return None

        embedding = self.embedder.encode(problem)

        for cached_emb, solution in self.cache.items():
            similarity = cosine_similarity([embedding], [cached_emb])[0][0]
            if similarity > self.threshold:
                return solution

        return None

    def set(self, problem: str, solution: dict):
        """Cache solution."""
        embedding = tuple(self.embedder.encode(problem).tolist())
        self.cache[embedding] = solution
        self.problem_texts[embedding] = problem

Monitoring and Observability

Production systems fail silently without proper monitoring. A math solver might gradually degrade as problem distributions shift, or a model update might introduce subtle regressions. The monitoring system below tracks accuracy on known-answer problems, detects distribution shift, and alerts on anomalies. It also logs detailed metrics for debugging: which problem types fail most, where in the reasoning chain errors occur, and how much compute each problem consumes.

Python
import time
from collections import Counter, defaultdict

import numpy as np


class MathSolverMonitor:
    """
    Monitoring for production math solver.

    Tracks:
    1. Accuracy on known problems
    2. Latency distribution
    3. Cost per problem
    4. Error patterns
    """

    def __init__(self, solver: ProductionMathSolver):
        self.solver = solver
        self.metrics = defaultdict(list)

    def solve_with_monitoring(
        self,
        problem: str,
        ground_truth: str = None
    ) -> dict:
        """Solve with full monitoring."""
        start_time = time.time()

        try:
            result = self.solver.solve(problem)

            # Record metrics
            self.metrics["latency"].append(result["latency"])
            self.metrics["cost"].append(result["cost"])
            self.metrics["source"].append(result["source"])
            self.metrics["confidence"].append(result["confidence"])

            if ground_truth:
                is_correct = self.verify_answer(result["answer"], ground_truth)
                self.metrics["correct"].append(is_correct)

                if not is_correct:
                    self.log_error(problem, result, ground_truth)

            return result

        except Exception as e:
            self.metrics["errors"].append({
                "problem": problem[:100],
                "error": str(e),
                "timestamp": time.time()
            })
            raise

    def get_summary(self) -> dict:
        """Get monitoring summary."""
        return {
            "total_problems": len(self.metrics["latency"]),
            "accuracy": (
                sum(self.metrics["correct"]) / len(self.metrics["correct"])
                if self.metrics["correct"] else None
            ),
            "avg_latency": np.mean(self.metrics["latency"]),
            "p95_latency": np.percentile(self.metrics["latency"], 95),
            "avg_cost": np.mean(self.metrics["cost"]),
            "total_cost": sum(self.metrics["cost"]),
            "source_distribution": Counter(self.metrics["source"]),
            "error_count": len(self.metrics["errors"])
        }

    def log_error(self, problem: str, result: dict, ground_truth: str):
        """Log detailed error for analysis."""
        error_record = {
            "timestamp": time.time(),
            "problem": problem,
            "predicted": result["answer"],
            "ground_truth": ground_truth,
            "solution": result["solution"][:1000],
            "confidence": result["confidence"],
            "source": result["source"]
        }

        # Classify error type
        error_record["error_type"] = self.classify_error(
            result["solution"],
            ground_truth
        )

        self.metrics["detailed_errors"].append(error_record)

    def classify_error(self, solution: str, ground_truth: str) -> str:
        """Classify type of error for analysis."""
        # Common error patterns
        if "arithmetic" in solution.lower() and "=" in solution:
            # Check for calculation errors
            return "calculation_error"

        if any(w in solution.lower() for w in ["assume", "suppose", "let"]):
            return "wrong_assumption"

        if "therefore" in solution.lower() or "thus" in solution.lower():
            return "logical_error"

        return "unknown"

Part 8: rStar-Math: Self-Play Mutual Reasoning

One of the most significant breakthroughs of 2025 is rStar-Math from Microsoft Research, demonstrating that small models can achieve frontier-level math reasoning through self-play without distillation from larger models.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    rStar-Math ARCHITECTURE                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KEY INSIGHT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  "Small LLMs can rival o1 without distillation from superior models"   │
│                                                                          │
│  rStar-Math 7B achieves 90.0% on MATH (vs 58.8% baseline)              │
│  rStar2-Agent 14B matches DeepSeek-R1 671B on AIME                     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  CORE ARCHITECTURE:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                 Self-Play Mutual Reasoning                       │   │
│  ├─────────────────────────────────────────────────────────────────┤   │
│  │                                                                  │   │
│  │  Problem                                                        │   │
│  │     │                                                           │   │
│  │     ▼                                                           │   │
│  │  ┌─────────────────────────────────────────────────────────┐  │   │
│  │  │           Generator (Policy SLM)                         │  │   │
│  │  │  Uses MCTS to explore reasoning trajectories            │  │   │
│  │  │  Human-like actions: decompose, verify, backtrack       │  │   │
│  │  └─────────────────────────────────────────────────────────┘  │   │
│  │     │                                                           │   │
│  │     │ Multiple trajectories                                    │   │
│  │     ▼                                                           │   │
│  │  ┌─────────────────────────────────────────────────────────┐  │   │
│  │  │           Discriminator (Verifier SLM)                   │  │   │
│  │  │  Scores each trajectory for correctness                 │  │   │
│  │  │  Process Preference Model (PPM)                         │  │   │
│  │  └─────────────────────────────────────────────────────────┘  │   │
│  │     │                                                           │   │
│  │     │ Mutually consistent trajectories                         │   │
│  │     ▼                                                           │   │
│  │  ┌─────────────────────────────────────────────────────────┐  │   │
│  │  │           Self-Evolution Loop                            │  │   │
│  │  │  4 rounds: Generator ↔ Verifier co-improve              │  │   │
│  │  │  747K problems, millions of solutions                   │  │   │
│  │  └─────────────────────────────────────────────────────────┘  │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Three Key Innovations

rStar-Math's success comes from three innovations that work together. First, code-augmented data synthesis generates training data with verified solutions—using Python to check answers ensures no incorrect examples leak into training. Second, a Process Preference Model (PPM) learns to evaluate reasoning trajectories, enabling better search guidance than simple outcome rewards. Third, self-evolution through mutual reasoning lets the policy and PPM co-improve over multiple rounds without any external supervision.

1. Code-Augmented CoT Data Synthesis

The key to generating high-quality training data without a teacher model: generate reasoning traces that include Python code, then execute the code to verify the answer. Only traces with correct, verified answers become training data. This creates a self-supervised loop where the model's own successful reasoning becomes its training signal.

Python
class RStarMathDataSynthesis:
    """
    rStar-Math's novel data synthesis pipeline.

    Key insight: Generate step-by-step verified reasoning using:
    1. MCTS exploration with human-like actions
    2. Code execution for verification
    3. Extensive rollouts to find correct paths
    """

    def __init__(
        self,
        policy_model,
        code_executor,
        n_rollouts: int = 64
    ):
        self.policy = policy_model
        self.executor = code_executor
        self.n_rollouts = n_rollouts

        # Human-like reasoning actions
        self.actions = [
            "decompose",      # Break into subproblems
            "analyze",        # Examine given information
            "compute",        # Perform calculation (with code)
            "verify",         # Check intermediate result
            "backtrack",      # Abandon current approach
            "conclude"        # State final answer
        ]

    def synthesize_trajectory(self, problem: str) -> dict:
        """
        Generate verified reasoning trajectory using MCTS.
        """
        root = MCTSNode(state=problem, action=None)

        for _ in range(self.n_rollouts):
            # Selection: traverse tree using UCB
            node = self.select(root)

            # Expansion: generate next reasoning step
            if not node.is_terminal:
                child = self.expand(node)

                # Simulation: complete solution from this point
                trajectory, final_answer = self.simulate(child)

                # Verification: check with code execution
                is_correct = self.verify_with_code(
                    problem,
                    trajectory,
                    final_answer
                )

                # Backpropagation: update tree with result
                self.backpropagate(child, is_correct)

        # Extract best trajectory
        best_trajectory = self.extract_best_path(root)

        return {
            "problem": problem,
            "trajectory": best_trajectory,
            "verified": True,
            "code_verified_steps": self.count_code_verified(best_trajectory)
        }

    def expand(self, node: MCTSNode) -> MCTSNode:
        """
        Expand node with human-like reasoning action.
        """
        # Select action based on current state
        action = self.select_action(node.state)

        # Generate reasoning step
        if action == "compute":
            # Generate code for computation
            step = self.generate_code_step(node.state)
            result = self.executor.run(step["code"])
            step["code_result"] = result
        else:
            # Generate natural language step
            step = self.generate_nl_step(node.state, action)

        new_state = node.state + "\n" + step["text"]
        child = MCTSNode(
            state=new_state,
            action=action,
            parent=node
        )
        node.children.append(child)
        return child

    def verify_with_code(
        self,
        problem: str,
        trajectory: str,
        answer: str
    ) -> bool:
        """
        Verify solution using Python code execution.

        This is crucial: code provides ground-truth verification
        that doesn't require a learned reward model.
        """
        # Extract numerical answer
        predicted = self.extract_numerical(answer)
        if predicted is None:
            return False

        # Try to verify computationally
        verification_code = f'''
# Verify the solution to: {problem[:200]}
# Claimed answer: {predicted}

# Re-derive answer from scratch
{self.generate_verification_code(problem)}

# Check if answer matches
is_correct = abs(computed_answer - {predicted}) < 1e-6
print(is_correct)
'''
        try:
            result = self.executor.run(verification_code)
            return "True" in result
        except:
            # Fall back to answer matching
            return self.answer_matches(predicted, problem)

2. Process Preference Model (PPM) Training

The Process Preference Model (PPM) is rStar-Math's alternative to standard Process Reward Models. The key insight: instead of training a model to score individual steps (which requires expensive step-level annotations), PPM learns to compare entire reasoning trajectories. Given two solution attempts, which one is better? This pairwise comparison is much easier to label—if one solution is correct and the other wrong, the correct one is preferred. If both are correct, prefer the one with clearer reasoning.

PPM training data comes directly from MCTS exploration: as the policy explores different reasoning paths, some lead to correct answers and others don't. These naturally form preference pairs. No human annotation required—the mathematical correctness provides the signal. Over training rounds, PPM learns to recognize not just correctness but quality of reasoning, enabling it to guide search toward elegant solutions.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProcessPreferenceModel:
    """
    rStar-Math's novel reward model: Process Preference Model (PPM).

    Unlike standard PRMs that score steps independently,
    PPM learns preferences between reasoning trajectories.

    Key innovation: Trained on trajectory pairs from MCTS exploration,
    where correct paths are preferred over incorrect ones.
    """

    def __init__(self, base_model, tokenizer):
        self.model = base_model
        self.tokenizer = tokenizer

        # Add preference head
        self.preference_head = nn.Linear(
            base_model.config.hidden_size,
            1
        )

    def train_ppm(
        self,
        problem: str,
        trajectories: list[dict]
    ) -> torch.Tensor:
        """
        Train PPM on trajectory preferences from MCTS.

        Positive trajectories: reached correct answer
        Negative trajectories: failed or incorrect
        """
        # Separate by outcome
        positive = [t for t in trajectories if t["correct"]]
        negative = [t for t in trajectories if not t["correct"]]

        if not positive or not negative:
            return None

        # Create preference pairs
        pairs = []
        for pos in positive:
            for neg in negative:
                pairs.append({
                    "problem": problem,
                    "chosen": pos["trajectory"],
                    "rejected": neg["trajectory"]
                })

        # Bradley-Terry preference loss
        total_loss = 0
        for pair in pairs:
            # Score both trajectories
            chosen_score = self.score_trajectory(
                pair["problem"],
                pair["chosen"]
            )
            rejected_score = self.score_trajectory(
                pair["problem"],
                pair["rejected"]
            )

            # Preference loss: chosen should score higher
            loss = -F.logsigmoid(chosen_score - rejected_score)
            total_loss += loss

        return total_loss / len(pairs)

    def score_trajectory(
        self,
        problem: str,
        trajectory: str
    ) -> torch.Tensor:
        """
        Score a reasoning trajectory.

        Returns scalar score indicating trajectory quality.
        """
        input_text = f"Problem: {problem}\n\nSolution:\n{trajectory}"

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=4096
        )

        outputs = self.model(**inputs)

        # Use last token's hidden state for scoring
        last_hidden = outputs.last_hidden_state[:, -1, :]
        score = self.preference_head(last_hidden)

        return score.squeeze()

    def guide_mcts(
        self,
        problem: str,
        partial_trajectory: str,
        candidates: list[str]
    ) -> list[float]:
        """
        Use PPM to guide MCTS expansion.

        Score candidate next steps to prioritize exploration.
        """
        scores = []
        for candidate in candidates:
            full_trajectory = partial_trajectory + "\n" + candidate
            score = self.score_trajectory(problem, full_trajectory)
            scores.append(score.item())

        return scores

3. Self-Evolution Recipe

The self-evolution loop is where rStar-Math achieves its remarkable results without any teacher model. The core idea: the policy model and PPM improve each other in alternating rounds. In each round:

  1. Policy generates many solution attempts using MCTS guided by the current PPM
  2. Filter by correctness: only keep solutions that produce correct answers (verified by code execution)
  3. PPM trains on the new (correct trajectory, incorrect trajectory) pairs from MCTS exploration
  4. Policy trains on the filtered correct solutions

This creates a virtuous cycle: better PPM → better MCTS guidance → higher quality solutions → better PPM training data → better PPM. After 4 rounds, a 7B model achieves 90% on MATH—comparable to GPT-4 and better than most 70B models. The key is that code execution provides ground truth, so the system never needs external supervision.

Python
class RStarSelfEvolution:
    """
    rStar-Math's self-evolution: policy and PPM co-improve.

    4 rounds of evolution:
    Round 1: Initial training on synthetic data
    Round 2: Generate new solutions, filter by PPM, retrain
    Round 3: Improve PPM with new data, generate better solutions
    Round 4: Final refinement

    Result: 7B model achieves 90% on MATH
    """

    def __init__(
        self,
        policy_model,
        ppm_model,
        problem_bank,  # 747K math problems
        n_rounds: int = 4
    ):
        self.policy = policy_model
        self.ppm = ppm_model
        self.problems = problem_bank
        self.n_rounds = n_rounds

    def evolve(self) -> dict:
        """
        Run self-evolution loop.
        """
        history = []

        for round_num in range(self.n_rounds):
            print(f"Evolution Round {round_num + 1}")

            # Step 1: Generate solutions with current policy
            solutions = self.generate_solutions(
                n_per_problem=64 if round_num < 2 else 32
            )

            # Step 2: Filter with PPM
            high_quality = self.filter_solutions(solutions)

            # Step 3: Retrain policy on filtered solutions
            policy_metrics = self.train_policy(high_quality)

            # Step 4: Update PPM with new preference pairs
            ppm_metrics = self.train_ppm(solutions)

            # Step 5: Evaluate
            eval_metrics = self.evaluate()

            history.append({
                "round": round_num + 1,
                "n_solutions": len(solutions),
                "n_high_quality": len(high_quality),
                "policy_metrics": policy_metrics,
                "ppm_metrics": ppm_metrics,
                "eval": eval_metrics
            })

            print(f"  MATH: {eval_metrics['math']:.1%}")
            print(f"  AIME: {eval_metrics['aime']:.1%}")

        return history

    def filter_solutions(
        self,
        solutions: list[dict],
        threshold: float = 0.7
    ) -> list[dict]:
        """
        Filter solutions using PPM scores.

        Only keep high-scoring, verified-correct solutions
        for policy training.
        """
        filtered = []

        for sol in solutions:
            # Must be correct
            if not sol["correct"]:
                continue

            # Must have high PPM score
            score = self.ppm.score_trajectory(
                sol["problem"],
                sol["trajectory"]
            )

            if score > threshold:
                filtered.append(sol)

        return filtered


# Results of rStar-Math self-evolution
RSTAR_MATH_RESULTS = {
    "qwen2.5_math_7b": {
        "baseline": 58.8,
        "round_1": 72.3,
        "round_2": 81.5,
        "round_3": 87.2,
        "round_4": 90.0,  # Final: 90.0% on MATH
    },
    "phi3_mini_3.8b": {
        "baseline": 41.4,
        "round_1": 58.7,
        "round_2": 71.2,
        "round_3": 80.1,
        "round_4": 86.4,  # 3.8B model achieves 86.4%!
    }
}

rStar2-Agent: Thinking Smarter, Not Longer

The initial rStar-Math showed that self-play can achieve remarkable results. rStar2-Agent takes this further with a crucial insight: thinking smarter beats thinking longer. DeepSeek-R1 achieves its results with massive models and very long reasoning chains (10K+ tokens). rStar2-Agent matches this performance with a 14B model by generating more focused, efficient reasoning. The key is better training and search, not bigger models or longer outputs.

With just 14B parameters, rStar2-Agent matches the 671B DeepSeek-R1 on AIME:

Python
class RStar2Agent:
    """
    rStar2-Agent: Agentic RL for math reasoning.

    Key insight: "Think smarter, not longer"

    Achieves 80.6% on AIME24 with 14B parameters,
    matching 671B DeepSeek-R1, but with shorter responses.

    Trained in only 510 RL steps within one week.
    """

    def __init__(
        self,
        base_model,  # 14B pretrained
        ppm,
        code_executor
    ):
        self.model = base_model
        self.ppm = ppm
        self.executor = code_executor

        # Agentic tools
        self.tools = {
            "python": self.execute_python,
            "verify": self.verify_step,
            "search": self.search_similar_problems
        }

    def solve_with_agentic_rl(self, problem: str) -> dict:
        """
        Solve problem using agentic reasoning.

        The model learns to:
        1. Decide when to use tools vs pure reasoning
        2. Verify intermediate steps proactively
        3. Backtrack efficiently when stuck
        """
        state = {
            "problem": problem,
            "reasoning": [],
            "tool_calls": [],
            "current_approach": None
        }

        max_steps = 20
        for step in range(max_steps):
            # Model decides next action
            action = self.decide_action(state)

            if action["type"] == "reason":
                # Pure reasoning step
                reasoning = self.generate_reasoning(state)
                state["reasoning"].append(reasoning)

            elif action["type"] == "tool":
                # Use tool (code execution, verification)
                tool_name = action["tool"]
                tool_input = action["input"]
                result = self.tools[tool_name](tool_input)
                state["tool_calls"].append({
                    "tool": tool_name,
                    "input": tool_input,
                    "result": result
                })

            elif action["type"] == "backtrack":
                # Abandon current approach
                state["reasoning"] = state["reasoning"][:-3]
                state["current_approach"] = None

            elif action["type"] == "conclude":
                # Extract and return answer
                answer = self.extract_answer(state)
                return {
                    "answer": answer,
                    "reasoning": state["reasoning"],
                    "tool_calls": state["tool_calls"],
                    "n_steps": step + 1
                }

        return {"answer": None, "error": "max_steps_exceeded"}

    def decide_action(self, state: dict) -> dict:
        """
        Model decides what to do next.

        Trained via RL to maximize correct answers
        while minimizing reasoning length.
        """
        prompt = f"""Problem: {state['problem']}

Current reasoning:
{chr(10).join(state['reasoning'][-5:])}

Recent tool results:
{state['tool_calls'][-2:] if state['tool_calls'] else 'None'}

What should I do next?
Options:
- reason: Continue reasoning
- tool:python: Execute Python code
- tool:verify: Verify current step
- backtrack: Try different approach
- conclude: State final answer

Decision:"""

        response = self.model.generate(prompt, max_tokens=100)
        return self.parse_action(response)


# rStar2-Agent Results
RSTAR2_AGENT_RESULTS = {
    "model": "14B",
    "training": "510 RL steps, 1 week",
    "aime24": 80.6,  # Average pass@1
    "aime25": 69.8,  # Average pass@1
    "comparison": {
        "deepseek_r1_671b": {"aime24": 79.8, "aime25": 70.0},
        "o1_preview": {"aime24": 44.6, "aime25": None},
    },
    "key_insight": "Shorter responses, same accuracy"
}

Part 9: DeepSeekMath-V2: Self-Verifiable Reasoning

DeepSeekMath-V2 achieved gold-medal scores on IMO 2025 through a novel self-verification framework:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    DeepSeekMath-V2 ARCHITECTURE                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KEY INSIGHT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  "Incentivize the generator to identify and resolve as many issues     │
│   as possible in their own proofs before finalizing them"              │
│                                                                          │
│  Results:                                                               │
│  • Gold-level on IMO 2025                                              │
│  • Gold-level on CMO 2024                                              │
│  • 118/120 on Putnam 2024                                              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  GENERATOR-VERIFIER TRAINING LOOP:                                      │
│  ─────────────────────────────────                                       │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                                                                  │   │
│  │   ┌──────────────┐        ┌──────────────┐                     │   │
│  │   │  Generator   │───────▶│   Verifier   │                     │   │
│  │   │    (LLM)     │        │    (LLM)     │                     │   │
│  │   └──────────────┘        └──────────────┘                     │   │
│  │          │                        │                             │   │
│  │          │ Generates proof        │ Identifies issues          │   │
│  │          │                        │                             │   │
│  │          ▼                        ▼                             │   │
│  │   ┌──────────────┐        ┌──────────────┐                     │   │
│  │   │  Self-Refine │◀───────│   Feedback   │                     │   │
│  │   │    Loop      │        │    Signal    │                     │   │
│  │   └──────────────┘        └──────────────┘                     │   │
│  │          │                                                      │   │
│  │          │ Improved proof                                       │   │
│  │          ▼                                                      │   │
│  │   ┌──────────────┐                                             │   │
│  │   │    Final     │  (Only when verifier finds no issues)      │   │
│  │   │   Output     │                                             │   │
│  │   └──────────────┘                                             │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class DeepSeekMathV2:
    """
    DeepSeekMath-V2: Self-Verifiable Mathematical Reasoning.

    Core innovation: Train generator to self-correct before
    submitting final answer, using verifier as reward signal.

    Key insight: Don't just train for correct answers—train
    for finding and fixing your own mistakes.
    """

    def __init__(
        self,
        generator_model,
        verifier_model,
        max_refinement_rounds: int = 5
    ):
        self.generator = generator_model
        self.verifier = verifier_model
        self.max_rounds = max_refinement_rounds

    def solve_with_self_verification(self, problem: str) -> dict:
        """
        Solve problem with iterative self-verification.
        """
        # Initial solution attempt
        solution = self.generator.generate(
            f"Solve this problem with rigorous step-by-step derivation:\n\n{problem}"
        )

        refinement_history = []

        for round_num in range(self.max_rounds):
            # Verify current solution
            verification = self.verify(problem, solution)

            refinement_history.append({
                "round": round_num,
                "solution_preview": solution[:500],
                "issues_found": verification["issues"],
                "confidence": verification["confidence"]
            })

            # If no issues found, we're done
            if not verification["issues"]:
                return {
                    "solution": solution,
                    "answer": self.extract_answer(solution),
                    "verified": True,
                    "refinement_rounds": round_num,
                    "history": refinement_history
                }

            # Refine based on issues
            solution = self.refine(problem, solution, verification["issues"])

        # Max rounds reached
        return {
            "solution": solution,
            "answer": self.extract_answer(solution),
            "verified": False,
            "refinement_rounds": self.max_rounds,
            "history": refinement_history
        }

    def verify(self, problem: str, solution: str) -> dict:
        """
        Verify solution and identify issues.

        The verifier is trained to be highly critical—
        it's better to flag false positives than miss errors.
        """
        prompt = f"""Carefully verify this mathematical solution.

Problem: {problem}

Proposed Solution:
{solution}

Check for:
1. Logical errors (unjustified steps, non-sequiturs)
2. Calculation mistakes
3. Missing cases or edge conditions
4. Incorrect theorem applications
5. Gaps in reasoning

For each issue found, explain:
- What the issue is
- Where it occurs
- Why it's wrong
- How to fix it

If the solution is correct, state "NO ISSUES FOUND".

Verification:"""

        response = self.verifier.generate(prompt, max_tokens=2048)

        issues = self.parse_issues(response)
        confidence = self.estimate_confidence(response)

        return {
            "issues": issues,
            "confidence": confidence,
            "raw_response": response
        }

    def refine(
        self,
        problem: str,
        solution: str,
        issues: list[dict]
    ) -> str:
        """
        Refine solution based on identified issues.
        """
        issues_text = "\n".join([
            f"- {issue['description']}: {issue['explanation']}"
            for issue in issues
        ])

        prompt = f"""Fix the following issues in your solution.

Problem: {problem}

Your Previous Solution:
{solution}

Issues Found:
{issues_text}

Provide a corrected solution that addresses ALL issues.
Be rigorous and show your reasoning clearly.

Corrected Solution:"""

        return self.generator.generate(prompt, max_tokens=4096)


class SelfVerificationTraining:
    """
    Training procedure for self-verifiable reasoning.

    Key insight: Train generator with verifier as reward signal,
    but also train generator to internalize verification.
    """

    def __init__(
        self,
        generator,
        verifier,
        problems_dataset
    ):
        self.generator = generator
        self.verifier = verifier
        self.problems = problems_dataset

    def train_step(self, problem: str, ground_truth: str) -> dict:
        """
        Single training step with self-verification reward.
        """
        # Generate initial solution
        solution = self.generator.generate(problem)

        # Get verification feedback
        verification = self.verify_solution(solution, ground_truth)

        # Compute rewards
        rewards = {
            # Correctness reward
            "correctness": 1.0 if verification["correct"] else 0.0,

            # Self-identification reward: did generator flag its own errors?
            "self_awareness": self.compute_self_awareness_reward(
                solution,
                verification
            ),

            # Efficiency reward: fewer refinement rounds is better
            "efficiency": 1.0 / (1 + verification["refinements_needed"])
        }

        total_reward = (
            0.6 * rewards["correctness"] +
            0.3 * rewards["self_awareness"] +
            0.1 * rewards["efficiency"]
        )

        return {
            "reward": total_reward,
            "rewards_breakdown": rewards,
            "solution": solution,
            "verification": verification
        }

    def compute_self_awareness_reward(
        self,
        solution: str,
        verification: dict
    ) -> float:
        """
        Reward generator for self-identifying issues.

        If the solution includes self-doubt markers where
        errors actually exist, that's good.
        If it's confidently wrong, that's bad.
        """
        # Check for self-doubt markers in solution
        doubt_markers = [
            "let me verify",
            "checking this",
            "wait, this might be wrong",
            "hmm, let me reconsider",
            "i should double-check"
        ]

        has_self_doubt = any(
            marker in solution.lower()
            for marker in doubt_markers
        )

        if verification["correct"]:
            # Correct solution: self-doubt is neutral
            return 0.5
        else:
            # Wrong solution: self-doubt is good, confidence is bad
            return 0.8 if has_self_doubt else 0.0


class ScaledVerificationCompute:
    """
    Scale verification compute to handle harder problems.

    Key insight: When generator starts solving problems the
    verifier can't evaluate, generate new training data
    for the verifier from these hard cases.
    """

    def __init__(
        self,
        generator,
        verifier,
        code_executor
    ):
        self.generator = generator
        self.verifier = verifier
        self.executor = code_executor

    def scale_verification(
        self,
        hard_problems: list[str]
    ) -> list[dict]:
        """
        Generate training data for verifier from hard problems.

        When the verifier can't evaluate a proof:
        1. Generate multiple solutions
        2. Use code execution to determine correctness
        3. Create training examples for verifier
        """
        verifier_training_data = []

        for problem in hard_problems:
            # Generate many solutions
            solutions = [
                self.generator.generate(problem)
                for _ in range(32)
            ]

            # Use code execution to verify
            for solution in solutions:
                # Try to verify with code
                code_result = self.code_verify(problem, solution)

                if code_result["verified"]:
                    # Create training example for verifier
                    verifier_training_data.append({
                        "problem": problem,
                        "solution": solution,
                        "is_correct": code_result["correct"],
                        "issues": code_result.get("issues", [])
                    })

        return verifier_training_data

    def code_verify(
        self,
        problem: str,
        solution: str
    ) -> dict:
        """
        Verify solution using code execution.

        This provides ground truth for training the verifier
        on problems it previously couldn't evaluate.
        """
        # Extract answer from solution
        answer = self.extract_answer(solution)

        # Generate verification code
        verification_code = self.generate_verification_code(
            problem,
            answer
        )

        try:
            result = self.executor.run(verification_code)
            return {
                "verified": True,
                "correct": "CORRECT" in result,
                "code_output": result
            }
        except Exception as e:
            return {
                "verified": False,
                "error": str(e)
            }

DeepSeekMath-V2: The Complete Technical Picture

DeepSeekMath-V2's gold medal at IMO 2025 came from a carefully designed verifier-first training pipeline that represents a fundamental shift in how mathematical reasoning models are trained.

Why DeepSeekMath-V2 Succeeded Where Others Failed

To understand DeepSeekMath-V2's breakthrough, we need to understand why previous approaches struggled with IMO-level mathematics.

The fundamental challenge of mathematical reasoning:

IMO problems aren't like typical AI benchmarks. They require:

  1. Multi-step logical chains: Solutions often span 20-50 reasoning steps, where each step must be logically valid. A single error anywhere invalidates the entire proof.

  2. Creative insight: Unlike routine exercises, IMO problems require discovering non-obvious key insights—the "aha moments" that unlock solutions. These insights often involve seeing connections between seemingly unrelated mathematical concepts.

  3. Rigorous justification: In competition mathematics, "I think this is true" isn't enough. Every claim must be proven. Olympiad graders penalize hand-waving mercilessly.

  4. Case coverage: Many problems require exhaustive case analysis. Missing even one case means the proof is incomplete.

Why previous approaches failed:

| Approach | Limitation | Result |
| --- | --- | --- |
| Fine-tuning on math data | Model learns to pattern-match, not reason | Good on routine problems, fails on novel ones |
| Chain-of-thought prompting | No error correction mechanism | Errors propagate and compound |
| Reward models (standard RLHF) | Reward hacking: model learns to fool the verifier | Superficially correct but logically flawed proofs |
| Formal verification (Lean) | Translation bottleneck | ~30% of problems can't be formalized cleanly |

DeepSeek's key insight:

The problem with standard RLHF is that the generator and verifier are trained together. This creates a co-evolution where the generator learns to exploit weaknesses in the verifier—producing proofs that look correct but contain subtle flaws the verifier misses.

DeepSeek inverted this: train the verifier first, completely independently, on human-annotated proofs. The verifier learns what "rigorous proof" means from human mathematicians, not from the generator's outputs. Only then is the generator trained to satisfy this pre-trained verifier.

This seemingly simple change has profound implications:

  • The verifier's standards are fixed during generator training
  • The generator can't "game" the verifier because the verifier was trained on real proofs, not generator outputs
  • The verifier learns to spot actual mathematical errors, not just patterns correlated with errors

The 685B parameter model:

DeepSeekMath-V2 is built on DeepSeek-V3, a massive Mixture-of-Experts (MoE) model. The architecture matters:

  • 685B total parameters: Comparable to GPT-4 scale
  • 37B active per token: MoE routing means only a fraction of parameters are used for each prediction, keeping inference tractable
  • 256 experts with top-8 routing: For each token, the 8 most relevant experts are activated
  • 128K context window: Long contexts are essential for multi-step proofs

The MoE architecture is particularly well-suited for mathematics because different experts can specialize in different mathematical domains (algebra, geometry, number theory, combinatorics). When solving a geometry problem, geometry-specialized experts activate more frequently.
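To make the routing concrete, here is a minimal top-k gating sketch in PyTorch. It is illustrative only: the hidden size, the function name top_k_route, and the omission of auxiliary load-balancing losses are assumptions, not DeepSeek-V3's actual implementation.

Python
import torch
import torch.nn.functional as F

def top_k_route(hidden: torch.Tensor, gate: torch.nn.Linear, k: int = 8):
    """Pick the top-k experts per token and return normalized routing weights.

    hidden: (batch, seq, d_model) token representations
    gate:   linear layer mapping d_model -> num_experts
    """
    logits = gate(hidden)                          # (batch, seq, num_experts)
    weights, expert_ids = logits.topk(k, dim=-1)   # keep the k highest-scoring experts
    weights = F.softmax(weights, dim=-1)           # renormalize over the selected experts
    return weights, expert_ids

# Toy usage mirroring the description above: 256 experts, top-8 routing.
gate = torch.nn.Linear(1024, 256)
hidden = torch.randn(2, 16, 1024)
weights, expert_ids = top_k_route(hidden, gate, k=8)
# Each token's MoE output is the weighted sum of its 8 selected experts' outputs.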

Math-specific pre-training:

Before the verifier-first training, the base model underwent extensive math pre-training:

  • 500B tokens of mathematical content: Research papers, textbooks, competition solutions, proofs
  • Extended tokenizer: Special tokens for mathematical symbols (∑, ∫, √, ∀, ∃) and LaTeX expressions
  • Structured problem-solution pairs: The model learns the format of mathematical reasoning

This pre-training gives the model "mathematical common sense"—it knows that √2 is irrational, that prime numbers have exactly two divisors, and thousands of other mathematical facts that would otherwise need to be derived from scratch.
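The released tokenizer details aren't public at this level of detail, but the general technique (registering math symbols as dedicated tokens and resizing the embedding matrix) looks like this with the Hugging Face transformers API. The base checkpoint and the symbol list below are placeholders.

Python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Illustrative only: base checkpoint and symbol list are placeholders.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-llm-7b-base")
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-llm-7b-base")

math_symbols = ["∑", "∫", "√", "∀", "∃", "≤", "≥", "≡", "\\frac", "\\binom"]
num_added = tokenizer.add_tokens(math_symbols)

# Resize embeddings so the new tokens get trainable rows.
model.resize_token_embeddings(len(tokenizer))
print(f"Added {num_added} math tokens; vocab size is now {len(tokenizer)}")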

The Four-Phase Training Pipeline in Detail

DeepSeekMath-V2's training proceeds through four carefully sequenced phases. Each phase builds on the previous one, and the order is critical—reversing or parallelizing phases would undermine the entire approach.

Phase 1: Verifier Pre-Training (The Foundation)

The verifier is trained on a curated dataset of human-annotated mathematical proofs. This dataset is special:

  • Source: 50,000+ competition mathematics solutions from IMO, Putnam, national olympiads
  • Annotation: Each solution is annotated by expert mathematicians with:
    • Overall correctness label (correct/incorrect)
    • Step-by-step correctness (which steps are valid/invalid)
    • Error classifications (logical gap, calculation error, missing case, unjustified claim)
    • Rigor scores (1-10 for each step)

The verifier learns through multiple objectives:

  1. Binary classification: Given a proof, predict correct/incorrect
  2. Error localization: Given an incorrect proof, identify which step contains the first error
  3. Error typing: Classify what kind of error it is
  4. Rigor scoring: Rate the rigor of each step on a continuous scale

The key is that this training uses only human annotations, never generator outputs. The verifier learns human mathematicians' standards for rigor.

Phase 2: Meta-Verifier Training (Combining Signals)

A single correctness score isn't enough for training. The meta-verifier learns to combine multiple quality dimensions into a single scalar reward for GRPO:

| Dimension | Weight | What It Measures |
| --- | --- | --- |
| Correctness | 0.40 | Is the final answer right? |
| Rigor | 0.30 | Are all steps fully justified? |
| Completeness | 0.20 | Are all cases handled? |
| Clarity | 0.10 | Is the logic chain followable? |

The weights are learned, not hand-tuned: the meta-verifier is calibrated to predict human quality judgments on a held-out set of scored proofs.

Why separate rigor from correctness? A proof can reach the right answer through flawed reasoning (lucky errors cancel out), or have rigorous steps but miss a case. Competition graders care about both.
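A minimal sketch of how such dimension weights could be learned, assuming a softmax-weighted combination fit with MSE against human overall-quality scores. The class name, parameterization, and the toy numbers are assumptions, not the paper's procedure.

Python
import torch
import torch.nn as nn

class LearnedDimensionWeights(nn.Module):
    """Combine per-dimension scores into one scalar with learned weights."""

    DIMS = ["correctness", "rigor", "completeness", "clarity"]

    def __init__(self):
        super().__init__()
        # Unconstrained logits; softmax keeps the weights positive and summing to 1.
        self.logits = nn.Parameter(torch.zeros(len(self.DIMS)))

    def forward(self, dim_scores: torch.Tensor) -> torch.Tensor:
        # dim_scores: (batch, 4) verifier scores in [0, 1], one column per dimension
        weights = torch.softmax(self.logits, dim=0)
        return dim_scores @ weights  # (batch,) combined reward

# Fit against human overall-quality judgments (toy values, for illustration only).
combiner = LearnedDimensionWeights()
optim = torch.optim.Adam(combiner.parameters(), lr=1e-2)
dim_scores = torch.tensor([[0.9, 0.8, 0.7, 0.9], [0.4, 0.3, 0.5, 0.6]])
human_scores = torch.tensor([0.85, 0.40])
for _ in range(200):
    loss = nn.functional.mse_loss(combiner(dim_scores), human_scores)
    optim.zero_grad()
    loss.backward()
    optim.step()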

Phase 3: Generator Initialization from Verifier (The Key Innovation)

This is what makes DeepSeekMath-V2 different. Instead of initializing the generator from the base model, they initialize it from the trained verifier's weights.

Why does this work? The verifier has learned to represent "what makes a good proof." These representations—encoded in the model's weights—transfer to generation. The generator starts with an internal model of proof quality before it generates a single proof.

The initialization procedure:

  1. Copy all weights from the trained verifier
  2. The verifier's classification head is discarded
  3. A new language modeling head is added
  4. Brief supervised fine-tuning (1000 steps) on proof generation to "unlock" generation capability

The generator now has:

  • Deep understanding of proof structure (from pre-training)
  • Internal model of what makes proofs rigorous (from verifier)
  • Generation capability (from fine-tuning)

Phase 4: Iterative GRPO with Verifier Co-Evolution

Now the generator and verifier train together, but with the verifier's pre-training providing a strong foundation that resists reward hacking.

The GRPO (Group Relative Policy Optimization) procedure:

Code
For each training iteration:
    1. Sample a batch of 64 problems
    2. For each problem, generate 8 candidate solutions
    3. Meta-verifier scores all 512 solutions
    4. Normalize scores within each problem's 8 solutions (group relative)
    5. Update generator to increase probability of high-scoring solutions
    6. Find "hard cases" where generator and verifier disagree strongly
    7. Optionally update verifier on hard cases with human labels

Why group relative scoring?

Standard reward models give absolute scores. The problem: as the generator improves, its solutions cluster in a narrow score range, providing weak gradient signal.

GRPO normalizes within each group of 8 solutions. Even if all 8 solutions are excellent, the best one gets positive advantage, the worst gets negative. This provides consistent training signal throughout training.
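A tiny NumPy sketch of the group-relative normalization; the eight example scores are made up to show that even a tightly clustered group still yields a usable advantage signal.

Python
import numpy as np

def group_relative_advantages(scores: list[float], eps: float = 1e-6) -> np.ndarray:
    """Normalize rewards within one problem's group of candidate solutions.

    Even if every candidate scores highly in absolute terms, the best one
    gets a positive advantage and the worst a negative one.
    """
    s = np.asarray(scores, dtype=np.float64)
    return (s - s.mean()) / (s.std() + eps)

# Eight candidates for one problem, all scoring in a narrow high band:
print(group_relative_advantages([0.91, 0.93, 0.90, 0.95, 0.92, 0.94, 0.90, 0.93]))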

The self-verification loop at inference time:

At inference, DeepSeekMath-V2 doesn't just generate once. It implements a generate-verify-refine loop:

Code
1. Generate initial solution
2. Verifier checks for issues
3. If issues found:
   a. Verifier identifies the problematic step
   b. Generator regenerates from that point
   c. Repeat verification
4. If no issues found:
   a. Solution passes to output
5. Maximum 5 refinement rounds

This is different from standard "self-consistency" (majority voting over multiple generations). The verifier provides targeted feedback about what's wrong, allowing surgical fixes rather than complete regeneration.

Specific Techniques That Made the Difference

1. Error-Aware Training Data

Most math datasets only have correct solutions. DeepSeek collected and annotated incorrect solutions too:

  • Intentionally flawed solutions with labeled errors
  • Student solutions with common mistakes
  • Model-generated errors from earlier training runs

This lets the verifier learn what errors look like, not just what correct proofs look like.
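The released dataset schema isn't public, but an error-annotated training record for the verifier might look roughly like this; the field names and the worked example are illustrative.

Python
from dataclasses import dataclass, field

@dataclass
class AnnotatedSolution:
    """One verifier training example: a solution plus human error annotations."""
    problem: str
    solution: str
    is_correct: bool
    # For incorrect solutions: index of the first bad step and an error type,
    # e.g. "logical_gap", "calculation_error", "missing_case", "unjustified_claim".
    first_error_step: int | None = None
    error_type: str | None = None
    step_rigor_scores: list[float] = field(default_factory=list)

# Example: a deliberately flawed solution with a labeled calculation error at step 2.
example = AnnotatedSolution(
    problem="Prove that the sum of two odd integers is even.",
    solution=(
        "Step 1: Write a = 2m+1 and b = 2n+1. "
        "Step 2: Then a + b = 2m + 2n + 1. "
        "Step 3: So a + b is odd."
    ),
    is_correct=False,
    first_error_step=2,
    error_type="calculation_error",
    step_rigor_scores=[0.9, 0.1, 0.2],
)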

2. Step-Level Credit Assignment

For long proofs, which step caused success/failure? DeepSeek's training assigns credit at the step level:

  • Each step gets a "contribution score" based on how it affects the final proof validity
  • Steps that introduce key insights get higher credit
  • Steps that introduce errors get negative credit

This is implemented via a learned "step importance" model trained on human annotations of "key steps" in solutions.
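DeepSeek's learned step-importance model isn't available, so as a stand-in, here is a simple leave-one-out sketch that approximates each step's contribution by how much a verifier score drops when that step is removed. The verifier.score interface is assumed; this is an alternative illustration, not the paper's method.

Python
def leave_one_out_step_credit(problem: str, steps: list[str], verifier) -> list[float]:
    """Approximate per-step contribution scores via leave-one-out verification.

    verifier.score(problem, solution_text) is assumed to return a float in [0, 1].
    Steps whose removal hurts the score most get the highest credit; steps whose
    removal helps (negative credit) likely introduced an error.
    """
    full_score = verifier.score(problem, "\n".join(steps))
    credits = []
    for i in range(len(steps)):
        ablated = "\n".join(steps[:i] + steps[i + 1:])
        credits.append(full_score - verifier.score(problem, ablated))
    return credits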

3. Diverse Generation Strategies

At inference, multiple solution attempts use different strategies:

  • Forward reasoning: Start from given, derive conclusion
  • Backward reasoning: Start from what we want to prove, work backward
  • Case analysis: Enumerate cases systematically
  • Contradiction: Assume the negation, derive contradiction
  • Induction: Base case + inductive step
  • Extremal principle: Consider minimum/maximum elements
  • Invariants: Find quantities preserved under operations
  • Construction: Build an explicit example/counter-example

Each strategy corresponds to a different prompt prefix, encouraging diverse solution attempts.
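A sketch of the prompt-prefix mechanism; the exact wording of the prefixes is an assumption.

Python
import random

STRATEGY_PREFIXES = {
    "forward":       "Work forward from the given conditions toward the conclusion.",
    "backward":      "Start from the statement to prove and work backward to known facts.",
    "cases":         "Split the problem into exhaustive cases and handle each one.",
    "contradiction": "Assume the negation of the claim and derive a contradiction.",
    "induction":     "Prove a base case, then an inductive step.",
    "extremal":      "Consider a minimal or maximal element and what its extremality forces.",
    "invariant":     "Look for a quantity preserved by the allowed operations.",
    "construction":  "Construct an explicit example or counterexample.",
}

def strategy_prompt(problem: str, strategy: str | None = None) -> str:
    """Build a prompt that nudges the generator toward one proof strategy."""
    if strategy is None:
        strategy = random.choice(list(STRATEGY_PREFIXES))
    return f"{STRATEGY_PREFIXES[strategy]}\n\nProblem: {problem}\n\nProof:"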

4. Verification-Guided Tree Search

For the hardest problems, DeepSeekMath-V2 uses a tree search where the verifier guides exploration:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│              VERIFICATION-GUIDED TREE SEARCH                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│                         ┌───────────────┐                               │
│                         │    Problem    │                               │
│                         └───────┬───────┘                               │
│                                 │                                       │
│               ┌─────────────────┴─────────────────┐                     │
│               │                                   │                     │
│               ▼                                   ▼                     │
│     ┌─────────────────┐                ┌─────────────────┐              │
│     │     Path A      │                │     Path B      │              │
│     │  Score: 0.8 ✓   │                │  Score: 0.3 ✗   │              │
│     └────────┬────────┘                └─────────────────┘              │
│              │                                  │                       │
│           EXPAND                              PRUNE                     │
│              │                           (too low score)                │
│     ┌────────┴────────┐                                                 │
│     │                 │                                                 │
│     ▼                 ▼                                                 │
│  ┌──────┐         ┌──────┐                                              │
│  │  A1  │         │  A2  │                                              │
│  │ 0.85 │         │ 0.72 │                                              │
│  └──────┘         └──────┘                                              │
│                                                                         │
│  The verifier scores each partial proof, enabling:                      │
│  • Early pruning of unpromising paths (Path B)                          │
│  • Focused compute on high-potential branches (Path A)                  │
│  • More efficient than exhaustive or random search                      │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

The verifier scores partial proofs, allowing aggressive pruning of unpromising paths. This is more efficient than exhaustive search or random sampling.
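One way to implement this kind of search is a best-first loop ordered by verifier score, as sketched below. The generator.extend interface, beam width, and pruning threshold are assumptions rather than DeepSeek's published procedure.

Python
import heapq

def verifier_guided_search(problem, generator, verifier,
                           beam_width=2, prune_below=0.4, max_expansions=50):
    """Best-first search over partial proofs, ordered by verifier score.

    generator.extend(problem, partial, n) is assumed to return n candidate
    continuations; verifier.score(problem, partial) returns a float in [0, 1].
    """
    # Max-heap via negated scores; start from an empty partial proof.
    frontier = [(-verifier.score(problem, ""), "")]
    best = (0.0, "")

    for _ in range(max_expansions):
        if not frontier:
            break
        neg_score, partial = heapq.heappop(frontier)
        score = -neg_score
        if score > best[0]:
            best = (score, partial)
        if score < prune_below:
            continue  # prune unpromising branches (like Path B above)
        for continuation in generator.extend(problem, partial, n=beam_width):
            candidate = partial + continuation
            heapq.heappush(frontier, (-verifier.score(problem, candidate), candidate))

    return best  # (verifier score, best partial/complete proof found)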

Code
┌─────────────────────────────────────────────────────────────────────────┐
│           DEEPSEEKMATH-V2: VERIFIER-FIRST TRAINING PIPELINE            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  CORE INSIGHT:                                                          │
│  ─────────────                                                          │
│                                                                          │
│  "Train the verifier FIRST, then train the generator to satisfy the    │
│   verifier. This is the opposite of typical reward model approaches."  │
│                                                                          │
│  Why this works:                                                        │
│  • Verifier learns what "rigorous proof" means without generator bias │
│  • Generator then learns to produce proofs the verifier accepts        │
│  • Creates virtuous cycle: better verifier → better generator → ...    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MODEL ARCHITECTURE:                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  ┌──────────────────────────────────────────────────────────────┐      │
│  │                    DeepSeekMath-V2 (685B)                     │      │
│  ├──────────────────────────────────────────────────────────────┤      │
│  │                                                               │      │
│  │  Base: DeepSeek-V3 Mixture of Experts                        │      │
│  │  • 685B total parameters (37B active per token)             │      │
│  │  • 256 experts, top-8 routing                                │      │
│  │  • 128K context window                                       │      │
│  │                                                               │      │
│  │  Math-specific adaptations:                                  │      │
│  │  • Extended math tokenizer (special symbols, LaTeX)         │      │
│  │  • Math-focused pre-training (500B math tokens)             │      │
│  │  • Self-verification heads for each expert                   │      │
│  │                                                               │      │
│  └──────────────────────────────────────────────────────────────┘      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TRAINING PHASES:                                                        │
│  ────────────────                                                        │
│                                                                          │
│  Phase 1: Verifier Pre-Training (Independent of Generator)              │
│  ─────────────────────────────────────────────────────────              │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                                                                  │   │
│  │  Training data: Human-annotated proof evaluations              │   │
│  │  • IMO/Putnam solutions labeled correct/incorrect             │   │
│  │  • Step-by-step rigor annotations                              │   │
│  │  • Error type classifications                                   │   │
│  │                                                                  │   │
│  │  Objective: Predict proof correctness + identify error locations │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  Phase 2: Meta-Verifier Training                                        │
│  ────────────────────────────────                                       │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                                                                  │   │
│  │  The meta-verifier scores proofs on multiple dimensions:       │   │
│  │                                                                  │   │
│  │  • Rigor score: Are all steps fully justified?                 │   │
│  │  • Completeness score: Are all cases handled?                  │   │
│  │  • Clarity score: Is the logic chain clear?                    │   │
│  │  • Novelty score: Does it use creative approaches?             │   │
│  │                                                                  │   │
│  │  Combined into single "proof quality" signal for GRPO          │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  Phase 3: Generator Training from Verifier Checkpoint                   │
│  ───────────────────────────────────────────────────                    │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                                                                  │   │
│  │  Initialize generator from VERIFIER weights (not base model!)  │   │
│  │  • Generator inherits verifier's understanding of "good proof" │   │
│  │  • Then fine-tune for generation while preserving verification │   │
│  │                                                                  │   │
│  │  This is the key innovation: generator already "knows" what   │   │
│  │  makes a proof rigorous before it learns to generate          │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  Phase 4: Iterative GRPO Refinement                                     │
│  ──────────────────────────────────                                     │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                                                                  │   │
│  │  for epoch in range(num_epochs):                               │   │
│  │      # Generator produces solutions                            │   │
│  │      solutions = generator.sample(problems, n=64)             │   │
│  │                                                                  │   │
│  │      # Verifier scores each solution                           │   │
│  │      scores = verifier.score(solutions)                       │   │
│  │                                                                  │   │
│  │      # GRPO update: push generator toward high-scoring proofs │   │
│  │      generator = grpo_update(generator, solutions, scores)    │   │
│  │                                                                  │   │
│  │      # CRITICAL: Also update verifier on hard cases            │   │
│  │      hard_cases = find_disagreements(generator, verifier)     │   │
│  │      verifier = update_verifier(verifier, hard_cases)         │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  RESULTS:                                                               │
│  ────────                                                               │
│  • Gold medal (35/42) at IMO 2025                                      │
│  • Gold medal at China Mathematical Olympiad 2024                       │
│  • 118/120 on Putnam 2024                                              │
│  • 96%+ on MATH benchmark                                               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
import copy

from torch.nn.functional import mse_loss


class DeepSeekMathV2VerifierFirstTraining:
    """
    DeepSeekMath-V2's actual training procedure.

    The key insight: train verifier completely before touching generator.
    This prevents the "reward hacking" problem where generators learn
    to fool the verifier rather than produce rigorous proofs.
    """

    def __init__(
        self,
        base_model,  # DeepSeek-V3 685B MoE
        proof_dataset,  # Human-annotated IMO/Putnam solutions
        meta_verifier_config
    ):
        self.base = base_model
        self.proofs = proof_dataset
        self.meta_config = meta_verifier_config

    def phase1_train_verifier(self):
        """
        Phase 1: Train verifier on human-annotated proofs.

        The verifier learns to identify:
        - Correct vs incorrect proofs
        - Types of errors (logical gap, calculation error, missing case)
        - Quality dimensions (rigor, completeness, clarity)

        CRITICAL: No generator involved yet. Verifier learns from
        human annotations, not from generator outputs.
        """
        verifier = copy.deepcopy(self.base)

        # Train on human-labeled proof correctness
        for proof_example in self.proofs.correctness_labels:
            loss = verifier.train_step(
                input=f"Proof:\n{proof_example['proof']}\n\nIs this proof correct?",
                target=proof_example['label'],  # "correct" or "incorrect"
            )

        # Train on error identification
        for proof_example in self.proofs.error_annotations:
            loss = verifier.train_step(
                input=f"Proof:\n{proof_example['proof']}\n\nIdentify all errors:",
                target=proof_example['errors'],  # List of errors with locations
            )

        # Train on step-by-step rigor assessment
        for proof_example in self.proofs.rigor_annotations:
            loss = verifier.train_step(
                input=f"Proof:\n{proof_example['proof']}\n\nRate rigor of each step:",
                target=proof_example['step_scores'],
            )

        return verifier

    def phase2_train_meta_verifier(self, verifier):
        """
        Phase 2: Train meta-verifier that combines multiple quality signals.

        The meta-verifier learns to produce a single score that captures:
        - Correctness (is it right?)
        - Rigor (are steps justified?)
        - Completeness (all cases covered?)
        - Clarity (is it understandable?)
        """
        meta_verifier = MetaVerifier(verifier, self.meta_config)

        # Calibrate weights for each dimension
        for proof_example in self.proofs.quality_annotations:
            predicted_scores = meta_verifier.score_all_dimensions(
                proof_example['proof']
            )

            # Multi-task loss: predict each quality dimension
            loss = sum([
                mse_loss(predicted_scores[dim], proof_example[dim])
                for dim in ['rigor', 'completeness', 'clarity', 'correctness']
            ])

            meta_verifier.update(loss)

        return meta_verifier

    def phase3_initialize_generator_from_verifier(self, verifier):
        """
        Phase 3: Initialize generator from verifier weights.

        THIS IS THE KEY INNOVATION.

        Instead of initializing generator from base model,
        initialize from the TRAINED VERIFIER.

        Why? The generator now starts with deep understanding of
        what makes a proof rigorous, before it even learns to generate.

        It's like learning to critique essays before writing them.
        """
        generator = copy.deepcopy(verifier)

        # Adapt verifier architecture for generation
        # (Verifier was trained as classifier, generator needs to produce text)
        generator = self.adapt_for_generation(generator)

        # Brief generation warm-up (don't train too long, preserve verifier knowledge)
        for problem in self.proofs.generation_warmup[:1000]:
            loss = generator.train_step(
                input=f"Problem: {problem['problem']}\n\nProvide a rigorous proof:",
                target=problem['solution'],
            )

        return generator

    def phase4_iterative_grpo(self, generator, meta_verifier, epochs=10):
        """
        Phase 4: Iterative GRPO with co-evolution of generator and verifier.

        Unlike standard GRPO, we also update the verifier during training.
        This prevents the generator from finding reward-hacking strategies.
        """
        for epoch in range(epochs):
            # Sample problems
            problems = self.proofs.sample_problems(batch_size=64)

            for problem in problems:
                # Generator produces multiple solutions
                solutions = [
                    generator.generate(problem, temperature=0.7)
                    for _ in range(8)
                ]

                # Meta-verifier scores each solution
                scores = [
                    meta_verifier.score(problem, solution)
                    for solution in solutions
                ]

                # GRPO update: increase probability of high-scoring solutions
                grpo_loss = self.grpo_objective(
                    generator,
                    solutions,
                    scores,
                    reference_model=self.base
                )
                generator.update(grpo_loss)

                # Find cases where generator and verifier disagree
                disagreements = self.find_disagreements(
                    generator, meta_verifier, problem
                )

                # Update verifier on disagreements (with human labels if needed)
                if disagreements:
                    verifier_loss = self.update_verifier_on_hard_cases(
                        meta_verifier, disagreements
                    )
                    meta_verifier.update(verifier_loss)

        return generator, meta_verifier

    def grpo_objective(
        self,
        generator,
        solutions: list[str],
        scores: list[float],
        reference_model
    ) -> float:
        """
        Group Relative Policy Optimization.

        Key: Compare solutions to each other, not to absolute threshold.
        This provides stable learning signal even as model improves.
        """
        # Normalize scores within group
        mean_score = sum(scores) / len(scores)
        std_score = (sum((s - mean_score)**2 for s in scores) / len(scores)) ** 0.5

        if std_score < 1e-6:
            return 0.0  # All solutions have same score, no gradient

        normalized_scores = [(s - mean_score) / std_score for s in scores]

        # GRPO loss: push toward high-scoring solutions
        loss = 0.0
        for solution, norm_score in zip(solutions, normalized_scores):
            # Log probability under generator
            log_prob = generator.log_prob(solution)

            # Log probability under reference (for KL constraint)
            ref_log_prob = reference_model.log_prob(solution)

            # Advantage-weighted policy gradient
            loss -= norm_score * (log_prob - 0.1 * (log_prob - ref_log_prob))

        return loss / len(solutions)

    def find_disagreements(self, generator, verifier, problem):
        """
        Find cases where generator produces solutions that
        verifier scores very differently than expected.

        These are valuable for improving the verifier.
        """
        disagreements = []

        # Generate diverse solutions
        solutions = [
            generator.generate(problem, temperature=1.0)
            for _ in range(32)
        ]

        for solution in solutions:
            verifier_score = verifier.score(problem, solution)
            generator_confidence = generator.confidence(solution)

            # Disagreement: generator confident, verifier uncertain
            if generator_confidence > 0.9 and verifier_score < 0.5:
                disagreements.append({
                    'problem': problem,
                    'solution': solution,
                    'generator_confidence': generator_confidence,
                    'verifier_score': verifier_score,
                    'type': 'generator_overconfident'
                })

            # Disagreement: generator uncertain, verifier confident
            if generator_confidence < 0.5 and verifier_score > 0.9:
                disagreements.append({
                    'problem': problem,
                    'solution': solution,
                    'generator_confidence': generator_confidence,
                    'verifier_score': verifier_score,
                    'type': 'verifier_overconfident'
                })

        return disagreements


class MetaVerifier:
    """
    Meta-verifier that combines multiple quality signals into
    a single score for GRPO training.

    This is what makes DeepSeekMath-V2 different from standard
    reward models: it scores on RIGOR, not just correctness.
    """

    def __init__(self, base_verifier, config):
        self.verifier = base_verifier
        self.config = config

        # Learned weights for combining dimensions
        self.dimension_weights = {
            'correctness': 0.4,   # Is the final answer right?
            'rigor': 0.3,        # Are all steps fully justified?
            'completeness': 0.2, # Are all cases handled?
            'clarity': 0.1       # Is the logic chain clear?
        }

    def score(self, problem: str, solution: str) -> float:
        """
        Score a solution on multiple dimensions, combine into single value.
        """
        scores = self.score_all_dimensions(solution)

        combined = sum(
            self.dimension_weights[dim] * scores[dim]
            for dim in scores
        )

        return combined

    def score_all_dimensions(self, solution: str) -> dict:
        """
        Score solution on each quality dimension.

        This is the key to training for RIGOR, not just correctness.
        A proof can be correct but not rigorous (missing justifications).
        A proof can be rigorous but incomplete (missing cases).
        """
        # Correctness: Is the final answer right?
        correctness_prompt = f"""
Solution:
{solution}

Is the final answer correct? Score 0-1:"""

        correctness = self.verifier.generate(correctness_prompt)

        # Rigor: Are all steps fully justified?
        rigor_prompt = f"""
Solution:
{solution}

Rate the rigor of this proof. A rigorous proof:
- Justifies every step explicitly
- Cites theorems/lemmas when used
- Has no logical gaps

Score 0-1:"""

        rigor = self.verifier.generate(rigor_prompt)

        # Completeness: Are all cases handled?
        completeness_prompt = f"""
Solution:
{solution}

Are all cases handled? Does the proof:
- Consider all possibilities?
- Handle edge cases?
- Account for boundary conditions?

Score 0-1:"""

        completeness = self.verifier.generate(completeness_prompt)

        # Clarity: Is the logic chain clear?
        clarity_prompt = f"""
Solution:
{solution}

Is this proof clear and well-organized?
- Are steps in logical order?
- Is notation consistent?
- Can you follow the argument?

Score 0-1:"""

        clarity = self.verifier.generate(clarity_prompt)

        return {
            'correctness': float(correctness),
            'rigor': float(rigor),
            'completeness': float(completeness),
            'clarity': float(clarity)
        }

Part 10: Formal Verification with Lean 4

Formal verification provides absolute guarantees of correctness. DeepSeek-Prover-V2 achieved 88.9% on MiniF2F-test using Lean 4.
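For readers who haven't used Lean 4, here is a trivially small example of what "Lean accepts it" means: each statement comes with a proof term that the kernel checks mechanically, and the file compiles only if every proof is valid (no imports are needed for these facts).

Lean
-- The kernel checks these proofs; the file compiles only if every proof type-checks.
example : 2 + 2 = 4 := rfl

theorem add_comm' (a b : Nat) : a + b = b + a := Nat.add_comm a b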

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    FORMAL VERIFICATION LANDSCAPE                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  WHY FORMAL VERIFICATION?                                                │
│  ─────────────────────────                                               │
│                                                                          │
│  LLM reasoning can be wrong even when it looks right.                  │
│  Formal verification is ABSOLUTE: if Lean accepts, it's correct.       │
│                                                                          │
│  "When an answer comes with a Lean4 proof, you don't have to trust    │
│   the AI—you can check it." - Lean4 research                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  KEY SYSTEMS (2025):                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  System                  Performance           Approach                 │
│  ─────────────────────────────────────────────────────────────────────  │
│  DeepSeek-Prover-V2     88.9% MiniF2F         RL + Subgoal decomp     │
│  (671B)                 49/658 PutnamBench                              │
│                                                                          │
│  Harmonic Aristotle     Gold IMO 2025         Formally verified        │
│                                                                          │
│  AlphaProof             Silver IMO 2024        Lean + RL               │
│                                                                          │
│  Goedel-Prover          1.64M formal stmts    Statement formalizer    │
│                                                                          │
│  APOLLO Framework       3-7% → 40%+           Syntax cleaning + LLM    │
│                                                                          │
│  Lean Copilot           Tactic suggestion     LLM in Lean             │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

DeepSeek-Prover-V2: Subgoal Decomposition

DeepSeek-Prover-V2 achieved 88.9% on MiniF2F-test through a clever decomposition strategy: instead of trying to prove theorems in one shot, it breaks complex proofs into subgoals, proves each subgoal separately, then combines them. This mirrors how human mathematicians tackle hard proofs—find intermediate lemmas that, once proven, make the main theorem tractable. The implementation uses reinforcement learning to learn good decomposition strategies.

Python
class DeepSeekProverV2:
    """
    DeepSeek-Prover-V2: Formal theorem proving in Lean 4.

    Key innovations:
    1. Subgoal decomposition with RL
    2. Two-stage: 671B for decomposition, 7B for proof search
    3. Non-CoT mode for fast generation

    Results:
    - 88.9% on MiniF2F-test
    - 49/658 PutnamBench problems
    - 6/15 AIME problems (formalized)
    """

    def __init__(
        self,
        decomposer_model,  # Large model (671B)
        prover_model,      # Small model (7B)
        lean_env
    ):
        self.decomposer = decomposer_model
        self.prover = prover_model
        self.lean = lean_env

    def prove_theorem(self, theorem: str) -> dict:
        """
        Prove theorem using subgoal decomposition.

        1. Decomposer generates high-level proof sketch
        2. Decomposer formalizes sketch into subgoals
        3. Prover solves each subgoal
        4. Combine into complete proof
        """
        # Step 1: Generate proof sketch (CoT)
        sketch = self.decomposer.generate(f"""
Theorem: {theorem}

Generate a high-level proof sketch with key steps:
""")

        # Step 2: Decompose into formal subgoals
        subgoals = self.decompose_to_subgoals(theorem, sketch)

        # Step 3: Prove each subgoal
        proofs = []
        for subgoal in subgoals:
            proof = self.prove_subgoal(subgoal)
            if proof is None:
                return {"success": False, "failed_at": subgoal}
            proofs.append(proof)

        # Step 4: Combine proofs
        complete_proof = self.combine_proofs(theorem, subgoals, proofs)

        # Step 5: Verify in Lean
        verification = self.lean.verify(complete_proof)

        return {
            "success": verification["valid"],
            "proof": complete_proof,
            "subgoals": subgoals,
            "subproofs": proofs
        }

    def decompose_to_subgoals(
        self,
        theorem: str,
        sketch: str
    ) -> list[str]:
        """
        Convert proof sketch to formal Lean 4 subgoals.
        """
        prompt = f"""
Theorem: {theorem}

Proof sketch:
{sketch}

Convert this into a sequence of Lean 4 subgoals.
Each subgoal should be:
1. A valid Lean 4 statement
2. Simpler than the main theorem
3. Together they imply the main theorem

Subgoals (in Lean 4 syntax):
"""
        response = self.decomposer.generate(prompt)
        return self.parse_lean_statements(response)

    def prove_subgoal(
        self,
        subgoal: str,
        max_attempts: int = 64
    ) -> str | None:
        """
        Prove subgoal using smaller prover model.

        Uses best-first search with Lean verification.
        """
        candidates = []

        for _ in range(max_attempts):
            # Generate proof candidate (non-CoT mode for speed)
            candidate = self.prover.generate(
                f"-- Prove: {subgoal}\n",
                max_tokens=512,
                temperature=0.8
            )

            # Verify in Lean
            result = self.lean.verify(f"{subgoal}\n{candidate}")

            if result["valid"]:
                return candidate

            candidates.append({
                "proof": candidate,
                "error": result.get("error")
            })

        # All attempts failed
        return None


class LeanEnvironment:
    """
    Interface to Lean 4 theorem prover.
    """

    def __init__(self, project_path: str):
        self.project_path = project_path

    def verify(self, code: str) -> dict:
        """
        Verify Lean 4 code.

        Returns:
        - valid: True if code type-checks
        - errors: List of type errors if invalid
        """
        # Write to temp file
        temp_file = f"{self.project_path}/temp_verify.lean"
        with open(temp_file, "w") as f:
            f.write(code)

        # Run Lean
        import subprocess
        result = subprocess.run(
            ["lake", "build", "temp_verify"],
            cwd=self.project_path,
            capture_output=True,
            text=True
        )

        if result.returncode == 0:
            return {"valid": True}
        else:
            return {
                "valid": False,
                "error": result.stderr
            }

    def get_tactics(self, goal: str) -> list[str]:
        """
        Suggest candidate tactics for a goal state.

        Useful for guiding proof search. A full implementation would query
        the Lean language server for the current goal; as a placeholder,
        return common closing tactics to try first.
        """
        return ["simp", "ring", "linarith", "omega", "norm_num", "decide", "aesop"]


class APOLLOFramework:
    """
    APOLLO: Automated LLM and Lean Collaboration.

    Pipeline that transforms LLM proof sketches into
    verified Lean 4 proofs.

    Results: 3-7% baseline → 40%+ accuracy

    Pipeline stages:
    1. LLM generates proof sketch
    2. Syntax cleaning (fix common Lean errors)
    3. Auto-solvers (for simple subgoals)
    4. LLM-driven sub-proof generation
    5. Verification and iteration
    """

    def __init__(
        self,
        llm,
        lean_env,
        auto_solvers
    ):
        self.llm = llm
        self.lean = lean_env
        self.solvers = auto_solvers

    def prove(self, theorem: str) -> dict:
        """
        APOLLO proof pipeline.
        """
        # Stage 1: LLM generates initial proof
        initial = self.llm.generate(f"""
theorem {theorem} := by
  -- Prove this theorem
""")

        # Stage 2: Syntax cleaning
        cleaned = self.clean_syntax(initial)

        # Stage 3: Try auto-solvers on subgoals
        with_autos = self.apply_auto_solvers(cleaned)

        # Stage 4: LLM fills remaining gaps
        complete = self.fill_gaps(theorem, with_autos)

        # Stage 5: Verify
        result = self.lean.verify(complete)

        if result["valid"]:
            return {"success": True, "proof": complete}

        # Iterate if needed
        for _ in range(3):
            complete = self.fix_errors(complete, result["error"])
            result = self.lean.verify(complete)
            if result["valid"]:
                return {"success": True, "proof": complete}

        return {"success": False, "best_attempt": complete}

    def clean_syntax(self, proof: str) -> str:
        """
        Fix common Lean 4 syntax errors.

        Common issues:
        - Wrong indentation
        - Missing imports
        - Lean 3 vs Lean 4 syntax
        - Incorrect tactic names
        """
        # Indentation fixes
        proof = self.fix_indentation(proof)

        # Lean 3 → Lean 4 conversions (string replacement is a rough heuristic;
        # a robust cleaner would parse the proof instead)
        conversions = {
            "begin": "by",   # Lean 3 `begin ... end` blocks become `by` blocks
            "end": "",
            "←": "← ",       # ensure a space after rewrite arrows
        }
        for old, new in conversions.items():
            proof = proof.replace(old, new)

        return proof

    def apply_auto_solvers(self, proof: str) -> str:
        """
        Apply automatic solvers to simple subgoals.

        Common auto tactics:
        - simp: simplification
        - ring: ring arithmetic
        - linarith: linear arithmetic
        - omega: integer arithmetic
        - norm_num: numeric normalization
        """
        # Parse proof to find sorry/holes
        holes = self.find_holes(proof)

        for hole in holes:
            # Try each auto solver
            for solver in ["simp", "ring", "linarith", "omega", "norm_num"]:
                test = proof.replace(hole, solver)
                if self.lean.verify(test)["valid"]:
                    proof = test
                    break

        return proof

Part 11: Tool-Integrated Reasoning (TIR)

Tool-Integrated Reasoning (TIR) is a paradigm in which the language model interleaves natural language reasoning with calls to external tools: calculators, symbolic math engines, code interpreters. This isn't just a convenience; recent theoretical work proves that TIR strictly expands the class of problems a model can solve.

Why LLMs Need Tools for Math

The arithmetic bottleneck:

LLMs are trained to predict the next token, not to compute arithmetic. While they learn addition and multiplication patterns, their accuracy drops sharply:

  • 2-digit × 2-digit: ~95% accuracy
  • 3-digit × 3-digit: ~60% accuracy
  • 4-digit × 4-digit: ~20% accuracy

For competition math, a single arithmetic error invalidates the entire solution. Tool use eliminates this bottleneck entirely.

What tools enable:

| Tool | Capability | Example Use |
| --- | --- | --- |
| Calculator | Exact arithmetic | "What's 2^1024 mod 17?" |
| SymPy | Symbolic algebra | Factor polynomials, solve equations |
| Code interpreter | Enumeration | Check all cases for n < 100 |
| Lean/Coq | Formal verification | Prove intermediate lemmas |
| Wolfram Alpha | Knowledge + compute | Series expansions, limits |
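As a concrete instance of the Calculator and SymPy rows, the computations that trip up pure token prediction are one-liners for an interpreter:

Python
# Exact arithmetic the model would otherwise have to "predict" token by token.
print(pow(2, 1024, 17))          # fast modular exponentiation: 2^1024 mod 17 -> 1

from sympy import factorint, solve, symbols
x = symbols("x")
print(factorint(2**32 + 1))      # integer factorization: {641: 1, 6700417: 1}
print(solve(x**2 - 5*x + 6, x))  # symbolic solve: [2, 3]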

Self-Consistency TIR (SC-TIR):

The winning approach in AIMO Progress Prize 1 (NuminaMath) combined TIR with self-consistency:

  1. Generate N solutions, each interleaving reasoning with code
  2. Execute all code blocks, replacing them with results
  3. Extract final answers
  4. Vote across solutions (majority wins)

This combines the flexibility of natural language reasoning with the precision of programmatic execution.

Theoretical Foundation

Recent work by Chen et al. (2024) proves that TIR is strictly more powerful than pure language modeling:

  • There exist problems solvable in polynomial time with TIR that require exponential time without
  • The separation holds even for problems with polynomial-length solutions
  • Tool access isn't just faster—it enables solving fundamentally different problems

The landscape of TIR systems, and the capability gains they report, are summarized below.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    TOOL-INTEGRATED REASONING                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THEORETICAL BREAKTHROUGH (2025):                                        │
│  ─────────────────────────────────                                       │
│                                                                          │
│  "Tools enable a STRICT expansion of the model's capabilities,         │
│   breaking the capability ceiling of pure-text models by unlocking     │
│   problem-solving strategies that are otherwise impossible or          │
│   intractably verbose."                                                 │
│                                                                          │
│  Key insight: TIR advantage is NOT confined to computation-intensive   │
│  problems—it extends to those requiring significant ABSTRACT INSIGHT.  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TIR MODELS AND RESULTS:                                                │
│  ───────────────────────                                                 │
│                                                                          │
│  Model               Approach                 Result                   │
│  ─────────────────────────────────────────────────────────────────────  │
│  ToRA                SFT + output shaping     MATH: 68% → 84%         │
│  NuminaMath (TIR)    SC-TIR decoding         AIMO: 29/50              │
│  START               Long CoT + Python       GPQA: 61%, AIME: 73%    │
│  AutoTIR             Autonomous tool choice   Adaptive tool use       │
│  ToRL                RLVR + code gen         Emergent tool use        │
│  o4-mini + Python    Tool use                AIME: 92.7% → 99.5%     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

The following implementation suite provides four complete TIR approaches, each with different tradeoffs:

ToolIntegratedReasoning: The basic TIR implementation. The model generates natural language reasoning interleaved with <code>...</code> blocks. When a code block is detected, it's extracted and executed, with the result injected back into the context. The model continues reasoning with access to the computation result. This loop continues until the model produces a final <answer> tag.

SelfConsistencyTIR: SC-TIR combines TIR with ensemble voting. Generate N solutions (typically 32-64), each using the basic TIR loop. Execute all code blocks and verify results. Filter out solutions where code execution failed. Among remaining solutions, extract numerical answers and take the majority vote. This three-stage filtering (tool verification → execution success → consensus) is what won AIMO Progress Prize 1.

AutoTIR: Instead of telling the model when to use tools, let it learn through RL. The model receives task success reward and discovers that tool use improves outcomes. This leads to emergent, adaptive tool use—the model learns when tools help vs. when direct reasoning is faster.

TIRWithRefinement: Adds a self-critique loop. After generating a TIR solution, the model reviews its own reasoning and code, looking for errors. If issues are found, it generates a refined solution. This catches errors that even code execution might miss (e.g., solving the wrong problem correctly).

Python
from collections import Counter


class ToolIntegratedReasoning:
    """
    Tool-Integrated Reasoning (TIR) for mathematical problem solving.

    Combines natural language reasoning with code execution
    for precise computation and verification.
    """

    def __init__(
        self,
        model,
        code_executor,
        tools: list[str] = ["python", "sympy", "wolfram"]
    ):
        self.model = model
        self.executor = code_executor
        self.tools = tools

    def solve_with_tir(self, problem: str) -> dict:
        """
        Solve problem with tool-integrated reasoning.

        The model interleaves:
        1. Natural language reasoning (analysis, strategy)
        2. Code execution (computation, verification)
        """
        prompt = f"""Solve this problem using both reasoning and code.

Problem: {problem}

Use <think> tags for reasoning and <code> tags for Python code.
Execute code to verify calculations.

Solution:"""

        # Generate with tool use
        response = ""
        reasoning_steps = []
        code_executions = []

        while True:
            # Generate next chunk
            chunk = self.model.generate(
                prompt + response,
                max_tokens=1024,
                stop=["</code>", "<answer>"]
            )
            response += chunk

            # Check for code block
            if "<code>" in chunk:
                # Extract and execute code
                code = self.extract_code(response)
                result = self.executor.run(code)

                # Add result to context
                response += f"</code>\n<output>{result}</output>\n"
                code_executions.append({
                    "code": code,
                    "result": result
                })

            elif "<answer>" in response:
                # Done
                break

            elif len(response) > 8000:
                # Safety limit
                break

        # Extract answer
        answer = self.extract_answer(response)

        return {
            "answer": answer,
            "reasoning": response,
            "code_executions": code_executions
        }


class SelfConsistencyTIR:
    """
    Self-Consistency with Tool-Integrated Reasoning (SC-TIR).

    NuminaMath's winning approach for AIMO Progress Prize.

    Key insight: Generate multiple TIR solutions,
    use code verification to filter, then majority vote.
    """

    def __init__(
        self,
        model,
        code_executor,
        n_samples: int = 32
    ):
        self.model = model
        self.executor = code_executor
        self.n_samples = n_samples

    def solve_with_sc_tir(self, problem: str) -> dict:
        """
        Solve with Self-Consistency TIR.

        1. Generate N solutions with TIR
        2. Filter by code verification
        3. Majority vote on remaining
        """
        solutions = []

        for i in range(self.n_samples):
            # Generate solution with TIR
            solution = self.generate_tir_solution(problem)

            # Verify with code
            verified = self.code_verify(problem, solution)

            solutions.append({
                "solution": solution,
                "answer": self.extract_answer(solution["reasoning"]),
                "verified": verified,
                "code_passed": solution["code_executions"]
            })

        # Filter to verified solutions
        verified_solutions = [s for s in solutions if s["verified"]]

        if not verified_solutions:
            # Fall back to all solutions
            verified_solutions = solutions

        # Majority vote on answers
        answer_counts = Counter(
            s["answer"] for s in verified_solutions
            if s["answer"] is not None
        )

        if not answer_counts:
            return {"answer": None, "confidence": 0}

        best_answer = answer_counts.most_common(1)[0][0]
        confidence = answer_counts[best_answer] / len(verified_solutions)

        return {
            "answer": best_answer,
            "confidence": confidence,
            "n_verified": len(verified_solutions),
            "n_total": self.n_samples,
            "vote_distribution": dict(answer_counts)
        }

    def code_verify(self, problem: str, solution: dict) -> bool:
        """
        Verify solution using code execution.
        """
        # Check if all code executions succeeded
        for execution in solution["code_executions"]:
            if "Error" in execution["result"]:
                return False

        # Try to verify final answer
        answer = self.extract_answer(solution["reasoning"])
        if answer is None:
            return False

        # Generate verification code
        verify_code = f"""
# Verify answer to: {problem[:200]}
answer = {answer}

# Re-solve from scratch
{self.generate_solution_code(problem)}

# Check
print("VERIFIED" if abs(computed - answer) < 1e-6 else "MISMATCH")
"""

        try:
            result = self.executor.run(verify_code)
            return "VERIFIED" in result
        except Exception:
            return False


class STARTReasoning:
    """
    START: Seamless Long CoT + Tool Integration.

    First open-source LLM combining extended reasoning
    with Python tool use.

    Results:
    - GPQA: 61% (PhD-level science)
    - AIME: 73%
    - LiveCodeBench: 47%
    """

    def __init__(self, model, executor):
        self.model = model
        self.executor = executor

    def solve(self, problem: str) -> dict:
        """
        START-style reasoning with tools.

        Key difference from basic TIR:
        - Much longer reasoning chains (8K+ tokens)
        - Strategic tool invocation (not every step)
        - Self-correction within reasoning
        """
        system_prompt = """You are a mathematical reasoning expert.

Think step by step, showing ALL your work.
Use <code>...</code> when computation would help.
Use <verify>...</verify> to check intermediate results.
Continue until you're confident in your answer.

Be thorough—better to reason more than miss something."""

        response = self.model.generate(
            f"{system_prompt}\n\nProblem: {problem}\n\nSolution:",
            max_tokens=8192,  # Long context for extended reasoning
            temperature=0.7
        )

        # Process tool calls
        processed = self.process_tool_calls(response)

        return {
            "answer": self.extract_answer(processed),
            "reasoning": processed,
            "tool_calls": self.count_tool_calls(response)
        }


class AutoTIR:
    """
    AutoTIR: Autonomous Tool-Integrated Reasoning via RL.

    The model learns WHEN and WHICH tool to use
    through reinforcement learning, without explicit instruction.

    Key insight: Let the model discover tool use
    through maximizing task success reward.
    """

    def __init__(
        self,
        model,
        tools: dict,  # {"python": executor, "wolfram": api, ...}
        reward_model
    ):
        self.model = model
        self.tools = tools
        self.reward = reward_model

    def train_autonomous_tool_use(
        self,
        problems: list[str],
        n_epochs: int = 3
    ):
        """
        Train model to autonomously use tools via RL.

        No tool-use demonstrations needed—model learns
        from reward signal alone.
        """
        for epoch in range(n_epochs):
            for problem in problems:
                # Generate solution (model decides tool use)
                solution = self.model.generate(
                    f"Solve: {problem}",
                    max_tokens=4096,
                    tools_available=list(self.tools.keys())
                )

                # Execute any tool calls
                processed = self.execute_tools(solution)

                # Compute reward
                reward = self.reward.score(problem, processed)

                # RL update (simplified)
                self.rl_update(problem, solution, reward)

    def execute_tools(self, solution: str) -> str:
        """
        Execute any tool calls in the solution.
        """
        # Find tool invocations
        for tool_name, executor in self.tools.items():
            pattern = f"<{tool_name}>(.*?)</{tool_name}>"
            matches = re.findall(pattern, solution, re.DOTALL)

            for match in matches:
                result = executor(match)
                solution = solution.replace(
                    f"<{tool_name}>{match}</{tool_name}>",
                    f"<{tool_name}>{match}</{tool_name}>\n<result>{result}</result>"
                )

        return solution

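To complement the classes above, here is a minimal sketch of the TIRWithRefinement variant described earlier. It wraps a basic TIR solver in a self-critique pass; the tir_solver/model interfaces, the critique prompt wording, and the "OK" convention are illustrative assumptions, not a specific published implementation.

Python
class TIRWithRefinement:
    """
    TIR with a self-critique loop (sketch).

    After producing a TIR solution, the model reviews its own reasoning
    and code; if it finds issues, it regenerates with the critique in
    context. Helper interfaces here are illustrative assumptions.
    """

    def __init__(self, tir_solver, model, max_refinements: int = 2):
        self.tir = tir_solver          # e.g. a ToolIntegratedReasoning instance
        self.model = model
        self.max_refinements = max_refinements

    def solve(self, problem: str) -> dict:
        solution = self.tir.solve_with_tir(problem)

        for _ in range(self.max_refinements):
            critique = self.model.generate(
                "Review this solution for errors. Check that the code solves "
                "the problem that was actually asked, not a similar one.\n\n"
                f"Problem: {problem}\n\nSolution:\n{solution['reasoning']}\n\n"
                "Reply with OK if the solution is correct; otherwise list the issues.",
                max_tokens=512,
            )
            if critique.strip().upper().startswith("OK"):
                break  # no issues found; keep the current solution

            # Regenerate with the critique in context
            solution = self.tir.solve_with_tir(
                f"{problem}\n\nA previous attempt had these issues:\n{critique}\n"
                "Avoid them in your solution."
            )

        return solution
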
Part 12: Verification-Refinement Pipelines for IMO 2025

Multiple teams achieved gold-medal performance at IMO 2025 using carefully designed verification-refinement pipelines.

The Core Insight: Single-Shot Generation Isn't Enough

The naive approach to AI math solving is straightforward: prompt a strong model, get an answer. This works surprisingly well on easy problems—GPT-4 class models achieve ~75% on AMC-level problems with zero-shot prompting.

But for IMO-level problems, single-shot accuracy drops below 10%. Why? Because hard math problems require:

  1. Exploration: Trying multiple approaches before finding one that works
  2. Error detection: Recognizing when a reasoning path is going wrong
  3. Refinement: Fixing errors without starting from scratch
  4. Verification: Confirming the final proof is actually correct

Human mathematicians do all four constantly. They don't write a proof linearly from start to finish—they draft, check, revise, and often restart. Verification-refinement pipelines encode this human workflow into AI systems.

The Pipeline Architecture

A typical verification-refinement pipeline has four stages:

Stage 1: Diverse Generation

Generate many candidate solutions using different strategies. Diversity matters more than quantity: 8 diverse approaches beat 64 slight variations. Techniques include:

  • Different prompts ("use algebra", "try coordinates", "work backwards")
  • Different temperatures (0.3 for focused, 0.9 for creative)
  • Different models (ensemble across GPT-5, Claude, Gemini)

Stage 2: Verification

Each candidate is verified by a separate model (or the same model with a "critic" prompt). Verification looks for:

  • Logical validity: Does each step follow from the previous?
  • Calculation correctness: Are the computations right?
  • Completeness: Are all cases handled?
  • Goal achievement: Does the proof actually prove what's claimed?

Verification outputs a confidence score and a list of issues.

Stage 3: Refinement

Candidates with high potential but specific issues enter a refinement loop. The model receives:

  • The original problem
  • Its attempted solution
  • The specific issues identified
  • Instructions to fix those issues

This is far more efficient than regenerating from scratch—often 80% of a solution is correct, and only a specific step needs fixing.

Stage 4: Selection

After verification and refinement, select the best solution. Criteria include:

  • Verification confidence (higher is better)
  • Refinement stability (solutions that improve with refinement are promising)
  • Answer agreement (if multiple solutions reach the same answer, it's likely correct)
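
A small sketch of how these three criteria can be combined into a single selection score. The weights and the extra record fields it reads (initial_score, answer) are assumptions layered on top of the pipeline records produced by the code shown below, not part of the published pipeline.

Python
from collections import Counter


def select_best(refined: list[dict]) -> dict:
    """
    Score refined candidates by verification confidence, refinement
    stability (did the verifier score improve?), and answer agreement
    across candidates. Weights are illustrative.
    """
    answers = Counter(r.get("answer") for r in refined if r.get("answer"))

    def score(r: dict) -> float:
        confidence = r["final_score"] / 100.0                       # verifier confidence
        stability = max(
            0.0, r["final_score"] - r.get("initial_score", r["final_score"])
        ) / 100.0                                                   # improvement during refinement
        agreement = (
            answers[r.get("answer")] / len(refined) if r.get("answer") else 0.0
        )                                                           # consensus across candidates
        return 0.5 * confidence + 0.2 * stability + 0.3 * agreement

    return max(refined, key=score)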

Why This Works Better Than Pure Generation

Verification-refinement pipelines improve accuracy from 21-38% (single-shot) to 85.7% (5/6 problems at IMO 2025). The improvement comes from:

  1. Error recovery: Single-shot either works or fails. Refinement can recover from errors.
  2. Computational allocation: Spend more compute on hard problems, less on easy ones.
  3. Ensemble effects: Different approaches find different solutions; verification picks the best.
  4. Human-like iteration: Mimics how expert mathematicians actually work.
Python
class IMO2025Pipeline:
    """
    Model-agnostic verification-and-refinement pipeline.

    Based on arxiv 2507.15855:
    "Winning Gold at IMO 2025 with a Model-Agnostic
     Verification-and-Refinement Pipeline"

    Results: 21-38% baseline → 85.7% (5/6 problems)
    Tested with: Gemini 2.5 Pro, GPT-5, Grok-4
    """

    def __init__(
        self,
        model,
        n_candidates: int = 32,
        max_refinement_rounds: int = 5
    ):
        self.model = model
        self.n_candidates = n_candidates
        self.max_rounds = max_refinement_rounds

    def solve_imo_problem(self, problem: str) -> dict:
        """
        Full pipeline for IMO problem.

        Steps:
        1. Generate diverse candidate solutions
        2. Verify each candidate
        3. Refine promising candidates
        4. Select best verified solution
        """
        # Stage 1: Diverse generation
        candidates = self.generate_diverse_candidates(problem)

        # Stage 2: Initial verification
        verified = []
        for candidate in candidates:
            verification = self.verify_solution(problem, candidate)
            verified.append({
                "solution": candidate,
                "verification": verification,
                "score": verification["confidence"]
            })

        # Sort by verification score
        verified.sort(key=lambda x: x["score"], reverse=True)

        # Stage 3: Refine top candidates
        refined = []
        for candidate in verified[:8]:  # Top 8
            refined_solution = self.refine_until_verified(
                problem,
                candidate["solution"]
            )
            refined.append(refined_solution)

        # Stage 4: Select best
        best = max(refined, key=lambda x: x["final_score"])

        return {
            "solution": best["solution"],
            "confidence": best["final_score"],
            "refinement_rounds": best["rounds"],
            "verified": best["verified"]
        }

    def generate_diverse_candidates(self, problem: str) -> list[str]:
        """
        Generate diverse solution candidates.

        Key prompts for diversity:
        1. Different approaches (algebraic, geometric, combinatorial)
        2. Different starting points
        3. Temperature variation
        """
        candidates = []

        # Approach-specific prompts
        approaches = [
            "Solve using algebraic manipulation:",
            "Solve using coordinate geometry:",
            "Solve using combinatorial arguments:",
            "Solve using number theory techniques:",
            "Solve by considering extreme cases first:",
            "Solve by working backwards from the goal:",
        ]

        for approach in approaches:
            for temp in [0.3, 0.7, 1.0]:
                prompt = f"""{approach}

Problem: {problem}

Provide a complete, rigorous proof.

Solution:"""

                solution = self.model.generate(
                    prompt,
                    max_tokens=4096,
                    temperature=temp
                )
                candidates.append(solution)

        # Respect the configured candidate budget
        return candidates[:self.n_candidates]

    def verify_solution(
        self,
        problem: str,
        solution: str
    ) -> dict:
        """
        Verify solution with detailed checking.
        """
        prompt = f"""Carefully verify this IMO solution.

Problem: {problem}

Proposed Solution:
{solution}

Check each step for:
1. Mathematical correctness
2. Logical validity
3. Completeness of argument
4. Proper handling of all cases

Rate your confidence that this proof is correct (0-100).
If there are errors, describe them precisely.

Verification:"""

        response = self.model.generate(prompt, max_tokens=2048)

        # Parse confidence and issues
        confidence = self.extract_confidence(response)
        issues = self.extract_issues(response)

        return {
            "confidence": confidence,
            "issues": issues,
            "is_valid": confidence > 80 and len(issues) == 0
        }

    def refine_until_verified(
        self,
        problem: str,
        solution: str
    ) -> dict:
        """
        Iteratively refine solution until verified.
        """
        current = solution

        for round_num in range(self.max_rounds):
            verification = self.verify_solution(problem, current)

            if verification["is_valid"]:
                return {
                    "solution": current,
                    "final_score": verification["confidence"],
                    "rounds": round_num,
                    "verified": True
                }

            # Refine based on issues
            current = self.refine_solution(
                problem,
                current,
                verification["issues"]
            )

        # Max rounds, return best attempt
        final_verification = self.verify_solution(problem, current)
        return {
            "solution": current,
            "final_score": final_verification["confidence"],
            "rounds": self.max_rounds,
            "verified": final_verification["is_valid"]
        }

    def refine_solution(
        self,
        problem: str,
        solution: str,
        issues: list[str]
    ) -> str:
        """
        Refine solution to address specific issues.
        """
        issues_text = "\n".join(f"- {issue}" for issue in issues)

        prompt = f"""Fix the following issues in this IMO solution.

Problem: {problem}

Current Solution:
{solution}

Issues to Fix:
{issues_text}

Provide a corrected, complete proof that addresses ALL issues.

Corrected Solution:"""

        return self.model.generate(prompt, max_tokens=4096)


class GeminiDeepThinkPipeline:
    """
    Gemini Deep Think's approach for IMO 2025 gold.

    Key innovations:
    1. Parallel thinking: explore multiple paths simultaneously
    2. Novel RL for multi-step reasoning
    3. End-to-end natural language (no Lean translation)
    """

    def __init__(self, model):
        self.model = model

    def solve_with_parallel_thinking(
        self,
        problem: str,
        n_paths: int = 8
    ) -> dict:
        """
        Explore multiple solution paths in parallel.

        Unlike sequential refinement, parallel thinking
        explores diverse approaches simultaneously,
        then combines insights.
        """
        # Generate parallel solution attempts
        paths = []
        for i in range(n_paths):
            path = self.generate_solution_path(problem, seed=i)
            paths.append(path)

        # Cross-pollinate insights
        insights = self.extract_insights(paths)

        # Generate final solution using all insights
        final = self.synthesize_solution(problem, insights, paths)

        return {
            "solution": final,
            "paths_explored": len(paths),
            "key_insights": insights
        }

    def generate_solution_path(
        self,
        problem: str,
        seed: int
    ) -> dict:
        """
        Generate one solution path.

        Different seeds lead to different starting strategies.
        """
        strategies = [
            "start_with_examples",
            "start_with_extreme_cases",
            "start_with_algebraic_setup",
            "start_with_geometric_intuition",
            "start_with_invariants",
            "start_with_working_backwards",
            "start_with_contradiction",
            "start_with_induction"
        ]

        strategy = strategies[seed % len(strategies)]

        prompt = f"""Solve this IMO problem using the {strategy} approach.

Problem: {problem}

Think deeply and show your complete reasoning.

Solution:"""

        solution = self.model.generate(
            prompt,
            max_tokens=8192,  # Extended thinking
            temperature=0.8
        )

        return {
            "strategy": strategy,
            "solution": solution,
            "key_steps": self.extract_key_steps(solution)
        }

    def extract_insights(self, paths: list[dict]) -> list[str]:
        """
        Extract key insights from all solution paths.
        """
        all_steps = []
        for path in paths:
            all_steps.extend(path["key_steps"])

        # Find common themes
        prompt = f"""These are key steps from multiple solution attempts:

{chr(10).join(all_steps[:30])}

What are the most important mathematical insights?
What techniques appear most promising?

Key Insights:"""

        response = self.model.generate(prompt, max_tokens=1024)
        return self.parse_insights(response)

    def synthesize_solution(
        self,
        problem: str,
        insights: list[str],
        paths: list[dict]
    ) -> str:
        """
        Synthesize final solution from insights and paths.
        """
        insights_text = "\n".join(f"- {i}" for i in insights)

        # Find most promising path
        best_path = max(paths, key=lambda p: len(p["key_steps"]))

        prompt = f"""Create a complete, rigorous solution.

Problem: {problem}

Key Insights to Use:
{insights_text}

Most Promising Approach:
{best_path['solution'][:2000]}

Create a clear, complete proof incorporating these insights.

Final Solution:"""

        return self.model.generate(prompt, max_tokens=8192)

How Gemini Deep Think Actually Works: The Full Technical Picture

Gemini Deep Think achieved 35/42 points (gold medal) at IMO 2025 using an approach fundamentally different from previous systems. Here's the complete technical breakdown.

The Paradigm Shift: Why Google Abandoned Formal Verification

To understand Gemini Deep Think, we need to understand what Google tried before and why they changed direction.

The AlphaProof approach (IMO 2024):

At IMO 2024, Google's AlphaProof achieved silver medal using formal verification in Lean 4:

  1. Problem (natural language) → Translate to Lean statement
  2. Neural model generates candidate proof steps
  3. Lean type-checker verifies each step
  4. MCTS explores the proof space
  5. Valid proof → Extract answer

This worked, but had critical limitations:

  • Translation bottleneck: ~30% of IMO problems resist clean formalization. Combinatorics and number theory problems with complex constructions often require ad-hoc Lean definitions that are as hard to find as the proof itself.

  • Formalization overhead: Even for problems that can be formalized, the Lean proof is often 10x longer than the natural language proof. This inflates the search space.

  • No transfer to new domains: Lean proofs require domain-specific libraries. A geometry proof uses different tactics than an algebra proof. Each IMO has novel problem types that may not have good library support.

  • Speed: Formal verification is slow. Each step requires Lean compilation and type-checking, limiting how many paths can be explored in competition time.

Google's realization:

After IMO 2024, Google's team made a key observation: the best human mathematicians don't use formal verification. They reason in natural language, using intuition, pattern recognition, and self-correction.

What if, instead of making AI more like Lean, they made AI more like a human mathematician?

The Deep Think Architecture

Gemini Deep Think is built on Gemini 2.0, Google's most capable foundation model, with specialized training for extended mathematical reasoning.

Core capabilities:

Capability | How It's Achieved
Extended thinking | Can generate up to 32K tokens of reasoning before answering
Self-verification | Checks its own work, identifies errors, and corrects them
Parallel exploration | Considers multiple solution approaches simultaneously
Strategy adaptation | Recognizes which proof techniques suit which problem types
Natural language rigor | Produces proofs that are rigorous despite not being formally verified

Why "Deep Think"?

The name reflects the core innovation: thinking deeper before answering. Standard LLMs generate tokens left-to-right, each token depending only on previous tokens. Deep Think can:

  1. Generate a long internal reasoning chain (thousands of tokens)
  2. Evaluate whether it's making progress
  3. Backtrack and try different approaches
  4. Synthesize insights from failed attempts
  5. Only produce a final answer when confident

This is closer to human expert reasoning than typical LLM behavior.

The Training Pipeline: How Google Built Deep Think

Stage 1: Mathematical Pre-Training

Before Deep Think training, the base Gemini model underwent extensive mathematical pre-training:

  • Corpus: Mathematical research papers (arXiv math), textbooks, competition solutions, proof assistant formalizations, educational content
  • Scale: Estimated 100B+ tokens of mathematical content
  • Special handling: LaTeX rendering, mathematical notation, structured proofs
  • Verification: Solutions in training data were verified by independent systems where possible

Stage 2: Curated Solution Corpus

Google built a high-quality corpus specifically for IMO-style reasoning. This corpus was hand-curated by mathematicians, not scraped from the web:

  • Sources:

    • IMO shortlist problems and official solutions (1959-2024)
    • Putnam competition solutions (1938-2024)
    • National olympiad solutions from 60+ countries
    • Art of Problem Solving "gold standard" solutions
    • Solutions from mathematical olympiad textbooks
  • Quality criteria:

    • Every solution verified correct by at least 2 reviewers
    • Solutions must be complete (no gaps, no "obvious" steps)
    • Multiple alternative solutions per problem where available
    • Step-by-step reasoning chains explicitly marked
  • Annotation:

    • Key insight annotations: "This is where the crucial observation happens"
    • Difficulty ratings per step
    • Alternative approaches noted
    • Common mistakes on similar problems

This corpus is tiny by LLM standards: tens of thousands of curated problems, versus the billions of web tokens used in general pre-training, but of vastly higher quality.

Stage 3: Multi-Step Reinforcement Learning

Standard RLHF gives one reward at the end. For 30-step proofs, this provides almost no signal about which steps were good.

Google's innovation: reward at every step.

They trained a Process Reward Model (PRM) that scores intermediate reasoning steps:

Code
Step 1: "Let x = the smallest element..." → PRM score: 0.7
Step 2: "Since x is minimal, x ≤ f(x)..." → PRM score: 0.85
Step 3: "Therefore f(x) = x + 1..." → PRM score: 0.3 (error introduced)

The generator is trained with rewards at every step, allowing credit assignment deep in the reasoning chain. Steps that introduce key insights get high reward; steps that introduce errors get low reward.

The reward function:

Code
R(step, context) = 0.4 × correctness(step)
                 + 0.3 × progress_toward_solution(step)
                 + 0.2 × insight_value(step)
                 + 0.1 × coherence_with_previous(step)

Each component is learned from human annotations:

  • Correctness: Is this step mathematically valid?
  • Progress: Does this step get us closer to the solution?
  • Insight value: Does this step contain a key observation?
  • Coherence: Does this step follow logically from previous steps?

Stage 4: Parallel Exploration Training

Humans solving hard problems don't try one approach and give up. They explore multiple angles, extract insights from failed attempts, and combine ideas.

Google trained Deep Think to do the same:

  1. Generate N parallel solution attempts (N=8 typically)
  2. Each attempt uses a different initial strategy
  3. Evaluate which approaches made progress
  4. Extract insights from partially successful attempts
  5. Synthesize a final solution combining the best ideas

The key innovation: rewarding exploration diversity. The model is rewarded not just for finding solutions, but for trying genuinely different approaches. This prevents mode collapse where the model always tries the same strategy.

Diversity metric:

D = \text{avg\_pairwise\_dist}(\mathbf{E}_{\text{approaches}})

The first few reasoning steps of each path are embedded, and paths with high pairwise distances (meaning they're trying different things) get bonus reward.

Why Natural Language Works for IMO

This seems counterintuitive. Formal verification gives 100% correctness guarantees. How can natural language compete?

The key insight: IMO doesn't need 100% correctness guarantee.

IMO scoring gives partial credit. A mostly-correct solution with a minor gap can still earn 5-6 points out of 7. Formal verification's all-or-nothing approach gives up these partial credit points.

More importantly, natural language allows reasoning about concepts that are hard to formalize:

Concept | Natural Language | Lean 4
"Obviously continuous" | Immediate | Define continuity, prove epsilon-delta
"Standard argument" | 1 sentence | Import library, match pattern
"By symmetry" | 1 word | Define symmetry group, prove invariance
"Without loss of generality" | 1 phrase | Complex case analysis

Human mathematicians use these shortcuts constantly. They work because mathematicians share conventions about what "obvious" and "standard" mean. Deep Think has learned these conventions from training data.

Self-verification compensates for formality:

Without formal verification, how do we know the proof is correct? Deep Think uses multi-round self-verification:

  1. Generate complete proof
  2. A separate "critic" pass reads the proof and looks for errors
  3. If errors found, identify the problematic step
  4. Regenerate from that point
  5. Re-verify
  6. Repeat until clean or max iterations

At IMO 2025, Deep Think's self-verification caught ~15% of the errors that would have been in single-shot generation.
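
Below is a minimal sketch of such a critic-and-regenerate loop. Unlike whole-solution refinement, it keeps the prefix the critic accepted and regenerates only from the first flagged step. The step-splitting convention (blank-line-separated steps) and the critic prompt are assumptions for illustration, not Deep Think's actual internals.

Python
import re


def critic_regenerate_loop(model, problem: str, max_rounds: int = 3) -> str:
    """
    Generate a proof, let a critic pass locate the first faulty step,
    then regenerate from that point onward (sketch).

    Assumes steps are separated by blank lines and that the critic
    answers with either "OK" or the number of the first bad step.
    """
    proof = model.generate(f"Prove:\n{problem}\n\nProof:", max_tokens=4096)

    for _ in range(max_rounds):
        steps = [s for s in proof.split("\n\n") if s.strip()]
        numbered = "\n".join(f"[{i + 1}] {s}" for i, s in enumerate(steps))

        verdict = model.generate(
            f"Problem:\n{problem}\n\nProof steps:\n{numbered}\n\n"
            "If every step is correct, reply OK. Otherwise reply with the "
            "number of the first incorrect step.",
            max_tokens=16,
        ).strip()

        match = re.search(r"\d+", verdict)
        if verdict.upper().startswith("OK") or match is None:
            return proof  # clean (or unparseable verdict): keep the current proof

        bad_idx = int(match.group()) - 1
        prefix = "\n\n".join(steps[:max(bad_idx, 0)])

        # Keep the verified prefix, regenerate from the flagged step onward
        proof = prefix + "\n\n" + model.generate(
            f"Problem:\n{problem}\n\nPartial proof (correct so far):\n{prefix}\n\n"
            "Continue the proof from here, avoiding the earlier mistake:",
            max_tokens=4096,
        )

    return proof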

Competition-Time Procedure

At the actual IMO, Deep Think operated under competition constraints: 4.5 hours for 6 problems.

Time allocation:

Problem Difficulty | Time Budget | Parallel Paths
Easy (P1, P4) | 20-30 min | 4 paths
Medium (P2, P5) | 40-60 min | 8 paths
Hard (P3, P6) | 60-90 min | 12 paths

The competition procedure:

Code
For each problem:
    1. Quick assessment (2-3 min)
       - Problem type classification
       - Difficulty estimation
       - Likely proof techniques

    2. Parallel exploration (60% of time budget)
       - Launch N parallel solution attempts
       - Each uses different strategy seed
       - Monitor progress, prune failing paths
       - Extract insights from partial progress

    3. Solution synthesis (20% of time budget)
       - Combine insights from all paths
       - Generate complete proof using best approach
       - Self-verify

    4. Refinement (20% of time budget)
       - Multiple verification passes
       - Fix any issues found
       - Polish presentation

Early termination:

If a path achieves >0.95 confidence score before time budget expires, Deep Think terminates early and moves to verification. This optimization saved ~15% of total time at IMO 2025.

Results Breakdown: What Deep Think Solved at IMO 2025

Problem | Type | Points | How It Was Solved
P1 | Algebra | 7/7 | Direct approach, found key inequality
P2 | Combinatorics | 7/7 | Case analysis with invariant
P3 | Number Theory | 7/7 | Modular arithmetic + descent
P4 | Geometry | 7/7 | Coordinate geometry
P5 | Algebra | 7/7 | Functional equation techniques
P6 | Combinatorics | 0/7 | Failed to find construction
Total | | 35/42 | Gold medal

P6 was the only problem Deep Think failed. It required constructing a specific combinatorial object with unusual properties—exactly the kind of creative construction that remains challenging.

Comparison with DeepSeekMath-V2

Aspect | Gemini Deep Think | DeepSeekMath-V2
Approach | End-to-end natural language | Self-verification loop
Formal verification | None | None
Model size | Gemini 2.0 scale (~1T?) | 685B (37B active)
Training innovation | Multi-step RL, diversity reward | Verifier-first pipeline
Inference innovation | Parallel exploration | Generate-verify-refine loop
IMO 2025 score | 35/42 (Gold) | 35/42 (Gold)
Strengths | Novel problems, geometry | Rigorous proofs, algebra
Weaknesses | Some constructions | Some geometry

Both systems achieved gold, using fundamentally different approaches. This suggests multiple viable paths to mathematical reasoning, not a single "correct" architecture.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│              GEMINI DEEP THINK: COMPLETE ARCHITECTURE                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KEY BREAKTHROUGH: End-to-End Natural Language                          │
│  ─────────────────────────────────────────────────                      │
│                                                                          │
│  Previous approaches (AlphaProof, DeepSeek-Prover):                    │
│  • Problem (NL) → Translate to Lean → Prove in Lean → Extract answer   │
│  • Translation bottleneck: ~30% of problems lost in translation        │
│  • Requires formal verification infrastructure                          │
│                                                                          │
│  Gemini Deep Think:                                                     │
│  • Problem (NL) → Reason in natural language → Answer                   │
│  • No formal verification dependency                                    │
│  • Scales with general reasoning capability                             │
│  • Handles ALL problem types (including those hard to formalize)        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  PARALLEL THINKING ARCHITECTURE:                                        │
│  ───────────────────────────────                                        │
│                                                                          │
│                         ┌─────────────┐                                 │
│                         │   Problem   │                                 │
│                         └──────┬──────┘                                 │
│                                │                                        │
│         ┌──────────────────────┼──────────────────────┐                │
│         ▼                      ▼                      ▼                │
│  ┌────────────┐        ┌────────────┐        ┌────────────┐           │
│  │  Path A:   │        │  Path B:   │        │  Path C:   │           │
│  │ Algebraic  │        │ Geometric  │        │ Invariant  │           │
│  │ Transform  │        │ Intuition  │        │  Search    │  ... N    │
│  └────────────┘        └────────────┘        └────────────┘           │
│         │                      │                      │                │
│         │ Explore              │ Explore              │ Explore        │
│         ▼                      ▼                      ▼                │
│  ┌────────────┐        ┌────────────┐        ┌────────────┐           │
│  │   Dead     │        │  Promising │        │  Promising │           │
│  │    End     │        │    Lead    │        │    Lead    │           │
│  └────────────┘        └────────────┘        └────────────┘           │
│                               │                      │                 │
│                               └──────────┬───────────┘                 │
│                                          ▼                             │
│                               ┌─────────────────┐                      │
│                               │ Cross-Pollinate │                      │
│                               │    Insights     │                      │
│                               └────────┬────────┘                      │
│                                        ▼                               │
│                               ┌─────────────────┐                      │
│                               │   Synthesize    │                      │
│                               │ Final Solution  │                      │
│                               └─────────────────┘                      │
│                                                                          │
│  Unlike sequential CoT, paths explore SIMULTANEOUSLY                    │
│  Insights flow between paths during exploration                         │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Training Pipeline: Novel RL for Multi-Step Mathematical Reasoning

Python
import time

# Cosine distance between approach embeddings (used in compute_diversity)
from scipy.spatial.distance import cosine as cosine_distance


class GeminiDeepThinkTraining:
    """
    How Gemini Deep Think was actually trained.

    Key innovations:
    1. Curated high-quality corpus: Unlike typical web scraping,
       solutions were carefully selected for correctness and clarity
    2. Multi-step RL: Reward signal at each reasoning step, not just final
    3. Parallel exploration reward: Bonus for exploring diverse approaches
    """

    def create_training_corpus(self):
        """
        Step 1: Build high-quality solution corpus.

        Google's approach differs from typical math training:
        - Human mathematicians reviewed and curated solutions
        - Only "gold standard" proofs included (correct + elegant)
        - Multiple solution approaches per problem
        - Step-by-step reasoning chains annotated
        """
        corpus = {
            "sources": [
                "IMO shortlist solutions (1959-2024)",
                "Putnam exam solutions (1938-2024)",
                "National olympiad solutions (60+ countries)",
                "Art of Problem Solving curated solutions",
                "Mathematical olympiad textbook proofs",
                "Peer-reviewed competition mathematics papers"
            ],
            "quality_criteria": {
                "correctness": "Verified by multiple reviewers",
                "completeness": "All steps justified, no gaps",
                "clarity": "Each step flows logically",
                "elegance": "Efficient approach without redundancy"
            },
            "annotation": {
                "step_boundaries": "Where each reasoning step starts/ends",
                "key_insights": "Critical observations that unlock solution",
                "alternative_approaches": "Other valid paths to solution",
                "difficulty_scores": "Per-step difficulty assessment"
            }
        }
        return corpus

    def multi_step_rl_objective(self, trajectory, ground_truth):
        """
        Step 2: Multi-step reinforcement learning.

        Key difference from standard RLHF:
        - Reward at EACH reasoning step, not just final answer
        - Process reward model scores intermediate steps
        - Allows credit assignment deep in reasoning chain
        """
        total_reward = 0

        for step_idx, step in enumerate(trajectory.steps):
            # Step-level correctness reward
            step_correct = self.step_verifier(
                step,
                context=trajectory.steps[:step_idx],
                problem=trajectory.problem
            )

            # Progress reward: is this step making progress toward solution?
            progress = self.progress_estimator(
                step,
                ground_truth_solution=ground_truth
            )

            # Insight reward: does this step contain key mathematical insight?
            insight_score = self.insight_detector(step)

            # Combine rewards with temporal discounting
            step_reward = (
                0.4 * step_correct +
                0.3 * progress +
                0.2 * insight_score +
                0.1 * self.coherence_score(step, trajectory.steps[:step_idx])
            )

            total_reward += step_reward * (0.99 ** step_idx)

        # Final answer correctness (still important!)
        final_correct = self.verify_final_answer(trajectory, ground_truth)
        total_reward += 2.0 * final_correct  # Higher weight for getting it right

        return total_reward

    def parallel_exploration_training(self, problem, solutions):
        """
        Step 3: Train for parallel exploration.

        Novel RL objective: reward exploring DIVERSE approaches,
        not just finding one that works.
        """
        # Generate N parallel solution attempts
        paths = []
        for i in range(8):  # 8 parallel paths
            path = self.generate_path(problem, strategy_seed=i)
            paths.append(path)

        # Measure diversity of approaches
        diversity_score = self.compute_diversity(paths)

        # Find successful paths
        successful_paths = [p for p in paths if p["reached_solution"]]

        # Cross-pollination reward: did insights transfer between paths?
        cross_pollination = self.measure_insight_transfer(paths)

        # Final objective
        objective = (
            0.3 * diversity_score +      # Reward exploring different approaches
            0.4 * len(successful_paths) / len(paths) +  # Success rate
            0.3 * cross_pollination       # Insight sharing between paths
        )

        return objective

    def compute_diversity(self, paths):
        """
        Measure how different the approaches are.

        Want model to try genuinely different strategies,
        not just variations of the same approach.
        """
        # Encode each path's approach
        approach_embeddings = [
            self.approach_encoder(path["steps"][:5])  # First 5 steps define approach
            for path in paths
        ]

        # Compute pairwise distances
        total_distance = 0
        for i in range(len(paths)):
            for j in range(i+1, len(paths)):
                total_distance += cosine_distance(
                    approach_embeddings[i],
                    approach_embeddings[j]
                )

        # Normalize
        num_pairs = len(paths) * (len(paths) - 1) / 2
        return total_distance / num_pairs


class IMOCompetitionInference:
    """
    How Gemini Deep Think operates during actual IMO competition.

    Key constraint: 4.5 hours for 6 problems (45 min average)
    """

    def __init__(self):
        self.time_budget_per_problem = 45 * 60  # 45 minutes in seconds
        self.n_parallel_paths = 8

    def solve_problem_under_time_constraint(self, problem, deadline):
        """
        Competition-mode solving with time management.
        """
        start_time = time.time()

        # Phase 1: Quick assessment (2-3 minutes)
        # Determine problem type, likely difficulty, promising approaches
        assessment = self.quick_assess(problem)

        if assessment["likely_difficulty"] == "easy":
            # For easier problems, go deeper on fewer paths
            n_paths = 4
            time_per_path = (deadline - time.time()) / n_paths
        else:
            # For harder problems, explore more broadly
            n_paths = 12
            time_per_path = (deadline - time.time()) / n_paths

        # Phase 2: Parallel exploration
        paths = []
        for i in range(n_paths):
            if time.time() > deadline - 60:  # Keep 1 min for synthesis
                break

            path = self.explore_path_with_timeout(
                problem,
                strategy_seed=i,
                timeout=time_per_path
            )
            paths.append(path)

            # Early exit if high-confidence solution found
            if path.get("confidence", 0) > 0.95:
                break

        # Phase 3: Synthesize best solution
        remaining_time = deadline - time.time()
        solution = self.synthesize_with_verification(
            problem, paths, time_budget=remaining_time
        )

        return solution

    def quick_assess(self, problem):
        """
        Rapid problem assessment (no deep computation).
        """
        # Pattern matching against known problem types
        problem_type = self.classify_problem_type(problem)  # geometry, algebra, etc.

        # Historical difficulty of similar problems
        difficulty = self.estimate_difficulty(problem)

        # Identify most promising approaches for this type
        approaches = self.suggest_approaches(problem_type, problem)

        return {
            "problem_type": problem_type,
            "likely_difficulty": difficulty,
            "suggested_approaches": approaches
        }

Why Natural Language Works Better Than Formal Proofs for IMO

Python
# The surprising finding: for IMO specifically, natural language
# reasoning outperformed formal verification systems

formal_vs_natural = {
    "formal_proof_systems": {
        "advantages": [
            "100% correctness guarantee when proof compiles",
            "No hand-waving or gaps possible",
            "Can leverage existing proof libraries"
        ],
        "disadvantages_for_IMO": [
            "Translation bottleneck: ~30% of problems hard to formalize",
            "Combinatorics/number theory: formalization often longer than proof",
            "Novel problem types may lack library support",
            "Creativity constrained by formal language limitations"
        ]
    },
    "natural_language_reasoning": {
        "advantages_for_IMO": [
            "No translation loss - directly reason about problem as stated",
            "Full flexibility in approach and notation",
            "Can handle any problem type including novel constructions",
            "Scales with general language model capability"
        ],
        "challenges": [
            "Risk of subtle errors (mitigated by parallel verification)",
            "Need extensive training on high-quality proofs",
            "Harder to guarantee correctness"
        ]
    }
}

# Result: Gemini Deep Think (natural language) scored 35/42 = Gold at IMO 2025.
# For comparison, AlphaProof (Lean 4) reached silver at IMO 2024 and required
# formal translation; Deep Think handled problems that were difficult to formalize.

Part 13: AIMO Competition: Winning Techniques

The AI Mathematical Olympiad (AIMO) Progress Prizes have become the premier benchmark for practical AI math systems. Unlike academic benchmarks, AIMO prizes real money ($1M+ total) for systems that can solve new, unseen problems under competition conditions.

What Makes AIMO Different from Academic Benchmarks

Real competition conditions:

  • Problems are new: Not in any training data
  • Time pressure: Limited submissions, time constraints
  • Adversarial design: Problems designed to resist current techniques
  • Generalization test: Can't overfit to known problem styles

Prize structure incentivizes practical systems:

  • Progress Prize 1: $131,072 top prize, decided on a 50-problem private test
  • Progress Prize 2: Increasing thresholds, larger prizes
  • Grand Prize: $5M for IMO gold-medal level

Lessons from Winning Systems

AIMO competitions have produced hard-won knowledge about what actually works. Unlike academic benchmarks where you can optimize for a fixed test set, AIMO prizes require solving genuinely new problems under competition constraints. The winning systems share common patterns: tool integration for calculation reliability, fine-tuning on diverse math data, and ensemble/voting strategies for error detection. Here's what the winners did and why it worked.

NuminaMath (AIMO 1 Winner):

The winning team's key insight: combine fine-tuning with tool use. Their recipe:

  1. Base model: DeepSeekMath-Base 7B (already strong at math)
  2. Fine-tuning data: ~1M problems including synthetic, competitions, textbooks
  3. Tool integration: Python code interpreter for calculations
  4. Inference: SC-TIR with 64 samples, majority vote

This relatively simple recipe achieved 29/50 on the private test—significantly above the 20/50 threshold for the prize.

Why 7B with tools beat larger models:

  • Tool use eliminates arithmetic errors (the #1 failure mode)
  • Fine-tuning specializes the model for competition math
  • Self-consistency catches reasoning errors
  • Smaller models allow more samples within time/cost budget
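
A back-of-the-envelope sketch of the sampling-budget point: if per-token cost grows roughly with parameter count, a fixed budget buys an order of magnitude more SC-TIR samples from a 7B model than from a 70B one. All numbers below are illustrative assumptions, not AIMO measurements.

Python
# Illustrative assumptions only: budget, tokens per sample, and the linear
# cost-vs-parameter-count scaling are placeholders, not measured values.
BUDGET_TOKENS = 2_000_000                   # total generation budget per problem
TOKENS_PER_SAMPLE = 3_000                   # one TIR solution incl. code + outputs
RELATIVE_COST = {"7B": 1.0, "70B": 10.0}    # cost per token relative to the 7B model

for size, cost in RELATIVE_COST.items():
    n_samples = int(BUDGET_TOKENS / (TOKENS_PER_SAMPLE * cost))
    print(f"{size}: ~{n_samples} SC-TIR samples within the same budget")
# 7B: ~666 samples vs. 70B: ~66 samples -> majority voting has far more to work with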

The technique stack that works:

Technique | Contribution | Cost
Math fine-tuning | +15-20% accuracy | Training once
Tool use (TIR) | +10-15% accuracy | Minimal
Self-consistency (64 samples) | +5-10% accuracy | 64× compute
Better prompting | +3-5% accuracy | Free

Stacking all techniques turns a mediocre 7B model into a competition winner.

The AI Mathematical Olympiad (AIMO) Progress Prizes have driven practical advances in mathematical reasoning. Here's what winning teams used:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    AIMO PROGRESS PRIZE RESULTS                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  AIMO 1 (July 2024): NuminaMath - $131,072                             │
│  • Fine-tuned DeepSeekMath-Base 7B                                     │
│  • Tool-Integrated Reasoning (TIR)                                     │
│  • Self-Consistency TIR (SC-TIR) decoding                             │
│  • 29/50 problems solved                                               │
│                                                                          │
│  AIMO 2 (2024): NVIDIA NemoSkills                                      │
│  • Larger models with improved prompting                               │
│  • Enhanced code execution                                             │
│                                                                          │
│  AIMO 3 (2025): Increased difficulty, larger prizes                   │
│  • Problems designed to resist current techniques                     │
│  • Focus on genuine mathematical reasoning                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPETITION STATISTICS (AIMO 1):                                       │
│  ─────────────────────────────────                                       │
│                                                                          │
│  • 16,104 registrations                                                │
│  • 1,401 participants on 1,161 teams                                  │
│  • 1,831 submissions from 81 countries                                │
│  • Prize awarded by Terence Tao at IMO Bath                           │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
from collections import Counter

from tqdm import tqdm

# PythonCodeExecutor: sandboxed Python execution helper, assumed defined elsewhere.


class NuminaMathWinningApproach:
    """
    NuminaMath: Winner of AIMO Progress Prize 1.

    Key components:
    1. Fine-tuned DeepSeekMath-Base 7B
    2. Tool-Integrated Reasoning (TIR)
    3. Self-Consistency TIR (SC-TIR) decoding
    4. ~1M math problems training dataset

    Repository: github.com/project-numina/aimo-progress-prize
    """

    def __init__(
        self,
        model_path: str = "AI-MO/NuminaMath-7B-TIR",
        n_samples: int = 64
    ):
        self.model = self.load_model(model_path)
        self.executor = PythonCodeExecutor()
        self.n_samples = n_samples

    def solve(self, problem: str) -> dict:
        """
        Full NuminaMath solving pipeline.
        """
        # Generate multiple TIR solutions
        solutions = []
        for i in range(self.n_samples):
            solution = self.generate_tir_solution(problem)
            solutions.append(solution)

        # SC-TIR: Self-Consistency with code verification
        return self.sc_tir_aggregate(solutions)

    def generate_tir_solution(self, problem: str) -> dict:
        """
        Generate solution with Tool-Integrated Reasoning.

        Format:
        <think>reasoning</think>
        <code>python code</code>
        <output>execution result</output>
        ... repeat as needed ...
        <answer>final answer</answer>
        """
        prompt = f"""Solve this math problem step by step.
Use Python code for calculations.

Problem: {problem}

Solution:
<think>"""

        response = ""
        code_results = []

        while True:
            # Generate next chunk
            chunk = self.model.generate(
                prompt + response,
                max_tokens=512,
                stop=["</code>", "</answer>"]
            )
            response += chunk

            # Handle code execution
            if response.endswith("</code>"):
                # Execute code block
                code = self.extract_latest_code(response)
                result = self.executor.run(code)
                response += f"\n<output>\n{result}\n</output>\n<think>"
                code_results.append({"code": code, "output": result})

            elif "</answer>" in response or len(response) > 4000:
                break

        # Extract answer
        answer = self.extract_answer(response)

        return {
            "reasoning": response,
            "answer": answer,
            "code_results": code_results,
            "valid": answer is not None
        }

    def sc_tir_aggregate(self, solutions: list[dict]) -> dict:
        """
        Self-Consistency TIR aggregation.

        1. Filter to solutions with valid code execution
        2. Filter to solutions with extracted answers
        3. Majority vote
        """
        # Filter valid solutions
        valid = [s for s in solutions if s["valid"]]

        # Filter by code execution success
        code_verified = [
            s for s in valid
            if all("Error" not in r["output"] for r in s["code_results"])
        ]

        # Use code-verified if available, else all valid
        pool = code_verified if code_verified else valid

        if not pool:
            return {"answer": None, "confidence": 0}

        # Majority vote
        answers = [s["answer"] for s in pool]
        answer_counts = Counter(answers)
        best_answer, count = answer_counts.most_common(1)[0]

        return {
            "answer": best_answer,
            "confidence": count / len(pool),
            "n_agreeing": count,
            "n_total": len(pool),
            "n_code_verified": len(code_verified)
        }


class NuminaMathTrainingData:
    """
    NuminaMath's training data pipeline.

    ~1M math problems with solutions:
    - Synthetic problems generated by LLMs
    - Filtered for correctness via code execution
    - Diverse difficulty levels
    """

    def __init__(self, strong_model):
        self.model = strong_model
        self.executor = PythonCodeExecutor()

    def generate_training_pair(
        self,
        seed_problem: str
    ) -> dict | None:
        """
        Generate training (problem, solution) pair.
        """
        # Generate solution with TIR format
        solution = self.generate_solution(seed_problem)

        # Verify with code execution
        if not self.verify_solution(seed_problem, solution):
            return None

        return {
            "problem": seed_problem,
            "solution": solution["reasoning"],
            "answer": solution["answer"]
        }

    def create_dataset(
        self,
        seed_problems: list[str],
        augmentation_factor: int = 10
    ) -> list[dict]:
        """
        Create large training dataset.

        1. Start with seed problems
        2. Augment by varying numbers/contexts
        3. Filter by code verification
        """
        dataset = []

        for problem in tqdm(seed_problems):
            # Original
            pair = self.generate_training_pair(problem)
            if pair:
                dataset.append(pair)

            # Augmentations
            for _ in range(augmentation_factor):
                augmented = self.augment_problem(problem)
                pair = self.generate_training_pair(augmented)
                if pair:
                    dataset.append(pair)

        return dataset

    def augment_problem(self, problem: str) -> str:
        """
        Create variation of problem.

        - Change numbers
        - Change context
        - Modify constraints
        """
        prompt = f"""Create a variation of this math problem.
Change the numbers and some context, but keep the same mathematical structure.

Original: {problem}

Variation:"""

        return self.model.generate(prompt, max_tokens=200)

Part 14: Reasoning Model Distillation

Distilling reasoning capabilities from large models to smaller ones enables practical deployment.

Why Distillation Matters for Mathematical Reasoning

The best mathematical reasoning systems are enormous. DeepSeek-R1 has 671B parameters. OpenAI's o1 and Gemini Deep Think are similarly massive. Running these models requires multiple high-end GPUs, costs dollars per query, and takes tens of seconds per response.

For most applications—homework help, competition practice, educational tools—this is impractical. Distillation offers a solution: transfer the reasoning capabilities of large models into smaller, deployable ones.

The Surprising Effectiveness of Simple Distillation

The DeepSeek team discovered something counterintuitive: simple supervised fine-tuning (SFT) on reasoning traces is remarkably effective.

The traditional view was that reasoning requires the model to "discover" chain-of-thought through reinforcement learning. You can't just show a model reasoning examples and expect it to reason—the model needs to learn the underlying process.

DeepSeek proved this wrong. When they fine-tuned small models (1.5B-70B) on 800K reasoning traces from R1-671B, the small models acquired strong reasoning capabilities. A 7B model trained this way achieves 55.5% on AIME—comparable to much larger models.

Why does this work? The key insight: reasoning patterns are learnable from examples. The large model has discovered effective reasoning strategies (backtracking, self-verification, trying alternatives). When small models see thousands of examples of these strategies, they learn to apply them. The small model doesn't need to discover reasoning from scratch—it just needs to imitate demonstrated reasoning.
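
A minimal sketch of this kind of trace distillation with Hugging Face transformers: plain supervised fine-tuning of a small causal LM on (problem, reasoning-trace) pairs. The student checkpoint, toy data format, and hyperparameters are placeholders, not DeepSeek's published recipe.

Python
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Placeholder student checkpoint and toy data (not DeepSeek's exact setup)
STUDENT = "Qwen/Qwen2.5-1.5B"
traces = [
    # each item: a problem plus a verified teacher reasoning trace
    {"problem": "Find all real x with x^2 = 4.",
     "trace": "<think>x^2 = 4 so x = 2 or x = -2.</think>\nAnswer: x = 2 or x = -2"},
]

tokenizer = AutoTokenizer.from_pretrained(STUDENT)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding needed by the collator
model = AutoModelForCausalLM.from_pretrained(STUDENT)

def to_features(example):
    # Standard next-token objective over prompt + full reasoning trace
    text = (
        f"Problem: {example['problem']}\n\nSolution:\n{example['trace']}"
        f"{tokenizer.eos_token}"
    )
    return tokenizer(text, truncation=True, max_length=4096)

dataset = Dataset.from_list(traces).map(to_features, remove_columns=["problem", "trace"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="r1-style-distill",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        num_train_epochs=2,
        learning_rate=1e-5,
        logging_steps=10,
    ),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()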

The Distillation vs. RL Tradeoff

You might ask: why not just train small models with RL directly? The DeepSeek paper compared both approaches:

Approach | 7B AIME | Training Cost | Notes
RL from scratch | ~35% | 2 weeks GPU cluster | Discovers own reasoning patterns
Distillation from R1 | 55.5% | 2 days, 8 GPUs | Learns teacher's patterns

Distillation is both cheaper and more effective. The small model benefits from reasoning patterns that the large model discovered over extensive RL training. It's like a student learning from a master's solved examples rather than rediscovering mathematics from first principles.

Quality Over Quantity

The s1 paper demonstrated that very small datasets can work. Their "budget distillation" approach:

  • Only ~1,000 high-quality examples
  • Total training cost under $50
  • Performance competitive with o1-preview on math/coding

The key is quality filtering. Each training example must have:

  1. Correct final answer (verified by code execution)
  2. Clear reasoning chain (explicit step-by-step logic)
  3. No unnecessary verbosity (concise but complete)
  4. Diverse problem coverage (different domains and difficulty levels)

A small dataset of perfect examples beats a large dataset of mediocre ones.
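
A sketch of the per-example quality filter these criteria suggest, assuming each candidate carries its reasoning trace plus an executable verification snippet that prints VERIFIED when the answer checks out. Field names and thresholds are illustrative; diversity (criterion 4) is a dataset-level check and is not shown here.

Python
import subprocess
import sys


def passes_quality_filter(example: dict,
                          max_words: int = 4096,
                          min_steps: int = 3) -> bool:
    """
    Keep only examples that (1) verify via code execution, (2) have an
    explicit multi-step reasoning chain, and (3) aren't bloated.
    Field names ("trace", "verify_code") are assumptions for this sketch.
    """
    trace = example["trace"]

    # (3) No unnecessary verbosity: rough word-count cap on the trace
    if len(trace.split()) > max_words:
        return False

    # (2) Clear reasoning chain: require several non-empty steps/lines
    n_steps = sum(1 for line in trace.splitlines() if line.strip())
    if n_steps < min_steps:
        return False

    # (1) Correct final answer: run the verification snippet in a subprocess
    #     (timeout guards against infinite loops)
    try:
        result = subprocess.run(
            [sys.executable, "-c", example["verify_code"]],
            capture_output=True, text=True, timeout=20,
        )
    except subprocess.TimeoutExpired:
        return False
    return "VERIFIED" in result.stdout

# Usage: filtered = [ex for ex in raw_examples if passes_quality_filter(ex)]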

On-Policy vs. Off-Policy Distillation

Standard distillation is "off-policy": generate all training data upfront, then train the student. On-policy distillation is more sophisticated:

  1. Student generates partial solution
  2. Teacher corrects/completes it
  3. Student learns from correction
  4. Repeat with updated student

This addresses distribution shift: off-policy training uses the teacher's distribution, but the student needs to perform on its own distribution. On-policy distillation progressively aligns the training distribution with the student's actual behavior.

Results show 4-14% improvement from on-policy over standard distillation.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    REASONING DISTILLATION                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  KEY INSIGHT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  "Reasoning patterns of larger models can be distilled into smaller    │
│   models, resulting in better performance compared to the reasoning    │
│   patterns discovered through RL on small models."                     │
│   - DeepSeek R1 Paper                                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DISTILLED MODELS:                                                       │
│  ─────────────────                                                       │
│                                                                          │
│  Source: DeepSeek-R1 (671B)                                            │
│  Dataset: 800K reasoning examples                                       │
│                                                                          │
│  Target          Base Model              AIME    MATH                   │
│  ─────────────────────────────────────────────────────────────────────  │
│  R1-Distill-1.5B   Qwen2.5-1.5B         28.9%   83.9%                  │
│  R1-Distill-7B     Qwen2.5-7B           55.5%   92.8%                  │
│  R1-Distill-8B     Llama-3.1-8B         50.4%   89.1%                  │
│  R1-Distill-14B    Qwen2.5-14B          69.7%   93.9%                  │
│  R1-Distill-32B    Qwen2.5-32B          72.6%   94.3%                  │
│  R1-Distill-70B    Llama-3.3-70B        70.0%   94.5%                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  s1 MODEL (Budget Distillation):                                        │
│  ────────────────────────────────                                        │
│                                                                          │
│  "Researchers created an open rival to o1 for under $50"               │
│                                                                          │
│  • Distilled from Gemini 2.0 Flash Thinking                            │
│  • Small dataset + SFT                                                  │
│  • Competitive with o1-preview on math/coding                          │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
import random

from datasets import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


class ReasoningDistillation:
    """
    Distill reasoning capabilities from large to small models.

    Key insight: SFT on reasoning traces is surprisingly effective.
    No additional RL needed for strong performance.
    """

    def __init__(
        self,
        teacher_model,  # Large reasoning model
        student_model,  # Small base model
        tokenizer
    ):
        self.teacher = teacher_model
        self.student = student_model
        self.tokenizer = tokenizer

    def generate_distillation_data(
        self,
        problems: list[str],
        samples_per_problem: int = 3
    ) -> list[dict]:
        """
        Generate distillation dataset from teacher.

        For each problem:
        1. Teacher generates reasoning traces
        2. Filter for correct answers
        3. Keep high-quality traces
        """
        dataset = []

        for problem in tqdm(problems):
            correct_traces = []

            for _ in range(samples_per_problem * 2):  # Oversample
                # Generate with teacher
                trace = self.teacher.generate(
                    f"Solve step by step:\n\n{problem}\n\nSolution:",
                    max_tokens=4096,
                    temperature=0.7
                )

                # Check if correct
                answer = self.extract_answer(trace)
                ground_truth = self.get_ground_truth(problem)

                if self.verify_answer(answer, ground_truth):
                    correct_traces.append({
                        "problem": problem,
                        "trace": trace,
                        "answer": answer
                    })

                if len(correct_traces) >= samples_per_problem:
                    break

            dataset.extend(correct_traces)

        return dataset

    def distill(
        self,
        dataset: list[dict],
        epochs: int = 3,
        learning_rate: float = 2e-5
    ) -> dict:
        """
        Distill reasoning into student model via SFT.

        Simple but effective: just train student to
        reproduce teacher's reasoning traces.
        """
        # Format for training
        training_data = []
        for item in dataset:
            training_data.append({
                "input": f"Solve step by step:\n\n{item['problem']}\n\nSolution:",
                "output": item["trace"]
            })

        # Standard SFT
        train_dataset = Dataset.from_list(training_data)

        training_args = TrainingArguments(
            output_dir="./distilled_model",
            num_train_epochs=epochs,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=8,
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            logging_steps=10,
            save_strategy="epoch",
            bf16=True
        )

        trainer = Trainer(
            model=self.student,
            args=training_args,
            train_dataset=train_dataset,
            tokenizer=self.tokenizer
        )

        trainer.train()

        return trainer.state.log_history

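# --- Hypothetical usage sketch (model loading and the answer-checking helpers
# --- used above are not shown; names here are illustrative) ---
# teacher = ...  # large reasoning model wrapper exposing .generate()
# student = ...  # small base model, e.g. a Qwen2.5-7B causal LM
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
#
# distiller = ReasoningDistillation(teacher, student, tokenizer)
# data = distiller.generate_distillation_data(problems, samples_per_problem=3)
# log_history = distiller.distill(data, epochs=3, learning_rate=2e-5)
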

class OnPolicyDistillation:
    """
    On-Policy Distillation: More effective than offline.

    Instead of using static teacher traces,
    generate new traces based on student's current policy.

    Results: 4-14% improvement over standard distillation.
    """

    def __init__(
        self,
        teacher,
        student,
        tokenizer
    ):
        self.teacher = teacher
        self.student = student
        self.tokenizer = tokenizer

    def distill_on_policy(
        self,
        problems: list[str],
        iterations: int = 5
    ) -> dict:
        """
        On-policy distillation loop.

        1. Student generates partial solution
        2. Teacher completes/corrects
        3. Student learns from corrected version
        4. Repeat
        """
        history = []

        for iteration in range(iterations):
            print(f"Iteration {iteration + 1}/{iterations}")

            # Generate training data on-policy
            training_data = []

            for problem in tqdm(problems):
                # Student generates partial solution
                student_attempt = self.student.generate(
                    f"Solve:\n{problem}\n\nSolution:",
                    max_tokens=1024,
                    temperature=0.8
                )

                # Teacher corrects/completes
                teacher_completion = self.teacher.generate(
                    f"""The student attempted this problem but may have errors.

Problem: {problem}

Student's attempt:
{student_attempt}

Provide the correct, complete solution:

Solution:""",
                    max_tokens=4096,
                    temperature=0.3
                )

                # Verify teacher's solution
                if self.verify_solution(problem, teacher_completion):
                    training_data.append({
                        "input": f"Solve:\n{problem}\n\nSolution:",
                        "output": teacher_completion
                    })

            # Train student on this iteration's data
            metrics = self.train_student(training_data)
            history.append({
                "iteration": iteration,
                "n_examples": len(training_data),
                "metrics": metrics
            })

            # Evaluate
            eval_results = self.evaluate()
            print(f"  MATH: {eval_results['math']:.1%}")

        return history


class BudgetDistillation:
    """
    Budget Distillation: Create reasoning model under $50.

    Based on s1 paper: small dataset + careful SFT
    can match expensive models.

    Key findings:
    - ~1000 high-quality examples sufficient
    - Quality > quantity for reasoning
    - Chain-of-thought format crucial
    """

    def __init__(self, base_model_name: str = "Qwen/Qwen2.5-7B"):
        self.model = AutoModelForCausalLM.from_pretrained(base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    def create_budget_dataset(
        self,
        source_model,  # e.g., Gemini Flash Thinking
        seed_problems: list[str],
        n_examples: int = 1000
    ) -> list[dict]:
        """
        Create small but high-quality distillation dataset.

        Focus on:
        1. Diverse problem types
        2. High-quality reasoning traces
        3. Verified correct answers
        """
        dataset = []

        # Sample diverse problems
        sampled = random.sample(seed_problems, min(n_examples * 2, len(seed_problems)))

        for problem in tqdm(sampled):
            if len(dataset) >= n_examples:
                break

            # Generate reasoning trace
            trace = source_model.generate(
                f"Think step by step and solve:\n\n{problem}",
                max_tokens=4096
            )

            # Quality checks
            if not self.quality_check(trace):
                continue

            # Verify answer
            if not self.verify_answer(problem, trace):
                continue

            dataset.append({
                "problem": problem,
                "trace": trace
            })

        return dataset

    def quality_check(self, trace: str) -> bool:
        """
        Check reasoning trace quality.
        """
        # Minimum reasoning length
        if len(trace) < 500:
            return False

        # Has structured thinking
        thinking_indicators = [
            "let me", "first", "then", "therefore",
            "because", "since", "this means"
        ]
        if sum(1 for ind in thinking_indicators if ind in trace.lower()) < 3:
            return False

        # Not too repetitive
        sentences = trace.split(".")
        unique_ratio = len(set(sentences)) / len(sentences)
        if unique_ratio < 0.7:
            return False

        return True

    def train_budget_model(
        self,
        dataset: list[dict],
        epochs: int = 3
    ) -> dict:
        """
        Train with minimal compute.

        Cost breakdown (approximate):
        - Dataset generation: ~$30 (API calls)
        - Training: ~$20 (GPU time)
        - Total: <$50
        """
        # Estimate cost
        n_tokens = sum(len(d["trace"].split()) * 1.3 for d in dataset)
        api_cost = n_tokens / 1_000_000 * 0.15  # Rough estimate
        print(f"Estimated API cost: ${api_cost:.2f}")

        # Training
        training_args = TrainingArguments(
            output_dir="./budget_model",
            num_train_epochs=epochs,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=16,
            learning_rate=1e-5,
            warmup_ratio=0.1,
            bf16=True,
            # Minimal logging for cost
            logging_steps=50,
            save_strategy="epoch"
        )

        # Quick training
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.prepare_dataset(dataset),
            tokenizer=self.tokenizer
        )

        trainer.train()

        return {"cost_estimate": api_cost + 20}  # API + compute

Conclusion: The Path to Gold

Building AI systems that win math competitions in 2025 requires combining multiple techniques across training, inference, and system architecture:

For Training (if you have resources)

  1. Dr. GRPO for bias-free RL training (avoid length inflation)
  2. Rule-based rewards (RLVR) for verifiable domains
  3. Curriculum learning from easy to hard
  4. Cold-start SFT before RL
  5. Process Reward Models for step-level feedback
  6. Self-verification training (DeepSeekMath-V2 approach)
  7. Self-play evolution (rStar-Math: 7B model achieves 90% on MATH)

For Inference (no training required)

  1. Constrained MCTS (CMCTS) for principled search
  2. Tool-Integrated Reasoning (SC-TIR) with code verification
  3. Verification-refinement pipelines for IMO-level problems
  4. Parallel thinking (explore multiple approaches simultaneously)
  5. Self-consistency through majority voting (see the sketch after this list)
  6. Multi-agent debate for hard problems
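
Of these, self-consistency is the cheapest to implement: sample several solutions at nonzero temperature, extract each final answer, and keep the most frequent one. A minimal sketch (answer extraction is assumed to have already happened):

Python
from collections import Counter


def majority_vote(final_answers: list[str]) -> str:
    """Self-consistency: return the most common final answer across samples."""
    counts = Counter(a.strip() for a in final_answers if a)
    return counts.most_common(1)[0][0]


# e.g., eight sampled solutions to an answer-only (AIME-style) problem
print(majority_vote(["204", "204", "197", "204", "112", "204", "197", "204"]))  # -> 204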

For Agentic Systems

  1. Tool use (Python, SymPy, Lean 4)
  2. Formal verification (DeepSeek-Prover-V2 for guaranteed correctness)
  3. Domain-specific strategies (AlphaGeometry2 for geometry: 84%)
  4. Autonomous tool selection (AutoTIR learns when to use tools)
  5. Agentic RL (rStar2-Agent: 14B matches 671B models)

What Works in 2025

| Technique | When to Use | Expected Gain |
|---|---|---|
| Dr. GRPO | Training from scratch | +5-10% vs standard GRPO |
| rStar self-play | No teacher model available | 7B → 90% MATH |
| CMCTS | Inference on hard problems | +8-11% vs standard MCTS |
| SC-TIR | Answer-only problems | +15-25% with code verification |
| Self-verification | Full proof problems | Gold at IMO 2025 |
| Formal verification | When certainty needed | 100% if Lean accepts |
| Distillation | Limited compute | 7B achieves 55% AIME |

Key Takeaways

  1. Self-verification is the breakthrough: DeepSeekMath-V2 reaches IMO gold by training the model to find its own errors before submitting, and rStar applies the same self-checking idea at small scale
  2. Tools are mandatory at the frontier: o4-mini jumps 92.7% → 99.5% with Python—the gap is too large to ignore
  3. Small models can compete: rStar-Math 7B matches o1-preview; rStar2-Agent 14B matches R1 671B
  4. Formal verification is maturing: DeepSeek-Prover-V2 solves 49 Putnam problems in Lean 4
  5. Multiple paths are essential: Parallel thinking + verification beats sequential refinement
  6. Distillation democratizes access: Budget distillation ($50) creates competitive reasoning models

The field advances at unprecedented speed—2025 saw multiple teams achieve IMO gold. The techniques in this guide represent the current frontier; expect rapid iteration.

Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
