
Test-Time Compute Scaling: CoT, ToT, MCTS, and Search-Based Reasoning

A comprehensive guide to inference-time scaling techniques—Chain of Thought, Tree of Thoughts, Monte Carlo Tree Search, Process Reward Models, and the HuggingFace search-and-learn framework.


The Test-Time Compute Revolution

For years, scaling AI meant one thing: bigger models trained on more data. But 2024-2025 introduced a paradigm shift—test-time compute scaling—where models improve by "thinking longer" at inference time rather than training longer.

The insight is profound: a small model that thinks carefully can outperform a large model that responds immediately. UC Berkeley research demonstrated that "with the right inference-time scaling approach, a 1B parameter model can outperform a 405B Llama 3 model that lacks inference-time scaling."

This post provides a comprehensive technical guide to test-time compute techniques, from basic Chain-of-Thought prompting to sophisticated tree search algorithms with Process Reward Models.

Understanding Test-Time Compute

What Is Test-Time Compute?

Test-time compute (also called inference-time scaling) refers to using additional computation during model inference to improve output quality. Instead of generating a response immediately, the model:

  1. Generates multiple candidate solutions
  2. Explores different reasoning paths
  3. Evaluates and selects the best output
  4. Iteratively refines its answer

The paradigm shift: For a decade, AI progress followed a simple recipe: train bigger models on more data. GPT-2 → GPT-3 → GPT-4 represented scaling training compute. But this approach has diminishing returns (scaling laws flatten) and massive fixed costs (billions of dollars per training run). Test-time compute offers a different lever: instead of building a smarter brain, give the brain more time to think. A smaller model that "thinks" for 10 seconds can outperform a larger model that responds instantly.

Why this is fundamentally different: Training compute is a fixed investment—you pay once and get a model. Test-time compute scales with usage—you pay per query, more for harder problems. This enables adaptive intelligence: easy questions get quick, cheap answers; hard questions get expensive, thorough reasoning. You couldn't achieve this by "training more" because training is uniform across all future queries.

The key trade-off: more compute at inference = better quality, but higher latency and cost.

Why It Matters

| Approach | Cost | Quality Scaling |
| --- | --- | --- |
| Larger model | Higher training + inference | Diminishing returns (scaling laws plateau) |
| More training data | Higher data collection + training | Data availability limits |
| Test-time compute | Variable per query | Scales with problem difficulty |

Test-time compute is adaptive—you can allocate more thinking to hard problems and less to easy ones. This is fundamentally more efficient than uniformly scaling model size.

The Compute-Optimal Frontier

Research shows optimal allocation depends on problem difficulty:

Code
Easy problems: Minimal test-time compute (fast response)
Medium problems: Moderate compute (some exploration)
Hard problems: Maximum compute (extensive search)

The challenge is determining difficulty before solving—leading to adaptive algorithms that start with less compute and scale up if needed.
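
One way to make this adaptive allocation concrete is an escalation loop: sample a small number of reasoning chains first, and only spend more compute when the answers disagree. The sketch below is illustrative rather than a specific library API; it reuses the extract_final_answer helper defined in the self-consistency section later in this post and treats answer disagreement as a proxy for difficulty.

Python
import collections

def adaptive_solve(question: str, model, budgets=(4, 16, 64),
                   agreement: float = 0.75) -> str:
    """Escalate test-time compute only for problems that look hard (sketch)."""
    prompt = f"{question}\nLet's think step by step."
    best_answer = None
    for n in budgets:
        # Sample n reasoning chains at moderate temperature for diversity.
        answers = [
            extract_final_answer(model.generate(prompt, temperature=0.7))
            for _ in range(n)
        ]
        best_answer, votes = collections.Counter(answers).most_common(1)[0]
        # A clear majority suggests the problem was easy enough: stop early.
        # Otherwise, move on to the next (larger) budget.
        if votes / n >= agreement:
            break
    return best_answer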

Chain-of-Thought (CoT) Prompting

The Foundation

Chain-of-Thought prompting is the simplest test-time compute technique. Instead of answering directly, the model generates intermediate reasoning steps.

Without CoT:

Code
Q: Roger has 5 tennis balls. He buys 2 more cans of 3 tennis balls each.
   How many tennis balls does he have now?
A: 11

With CoT:

Code
Q: Roger has 5 tennis balls. He buys 2 more cans of 3 tennis balls each.
   How many tennis balls does he have now?
A: Roger starts with 5 balls. 2 cans of 3 balls each means 2 × 3 = 6 new balls.
   5 + 6 = 11 balls total.
   The answer is 11.

Zero-Shot CoT

The famous "Let's think step by step" prompt enables CoT without examples:

Python
def zero_shot_cot(question: str, model) -> str:
    """Apply zero-shot chain-of-thought prompting."""
    prompt = f"""{question}

Let's think step by step."""

    response = model.generate(prompt)
    return response

Research showed this simple addition improves accuracy on math problems from ~17% to ~78% for certain models.

Few-Shot CoT

Providing examples of reasoning chains helps the model learn the expected format:

Python
def few_shot_cot(question: str, model, examples: list[dict]) -> str:
    """Apply few-shot chain-of-thought prompting."""
    prompt = ""
    for ex in examples:
        prompt += f"Q: {ex['question']}\n"
        prompt += f"A: {ex['reasoning']} The answer is {ex['answer']}.\n\n"

    prompt += f"Q: {question}\nA:"

    response = model.generate(prompt)
    return response

# Example usage
examples = [
    {
        "question": "There are 15 trees. Workers plant 6 more. How many trees?",
        "reasoning": "We start with 15 trees. Workers plant 6 more trees. 15 + 6 = 21.",
        "answer": "21"
    },
    {
        "question": "If there are 3 cars and each has 4 wheels, how many wheels total?",
        "reasoning": "Each car has 4 wheels. With 3 cars: 3 × 4 = 12 wheels.",
        "answer": "12"
    }
]

Self-Ask CoT

A variant where the model asks and answers sub-questions:

Python
def self_ask_cot(question: str, model) -> str:
    """Self-Ask chain-of-thought prompting."""
    prompt = f"""Question: {question}

Are follow up questions needed here: Yes

Follow up: [First sub-question the model should identify]
Intermediate answer: [Answer to follow-up]
Follow up: [Second sub-question if needed]
Intermediate answer: [Answer]
So the final answer is:"""

    response = model.generate(prompt)
    return response

CoT Limitations

  1. Linearity: CoT explores one path; if it goes wrong early, everything fails
  2. No backtracking: Can't revise earlier steps
  3. Hallucinated reasoning: Steps may be plausible-sounding but wrong
  4. Length sensitivity: Very long chains can degrade (overthinking)

Research found "an optimal reasoning token budget of approximately 4K, beyond which performance may degrade due to overthinking."
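
In practice, the overthinking failure mode is usually mitigated by capping the reasoning budget rather than letting chains run unbounded. A minimal sketch, assuming a generic model.generate interface that accepts a max_new_tokens argument (an assumption, not a specific API):

Python
def budgeted_cot(question: str, model, max_reasoning_tokens: int = 4096) -> str:
    """Zero-shot CoT with a hard cap on reasoning length (sketch)."""
    prompt = f"{question}\nLet's think step by step."
    # Cap the chain of thought near the ~4K-token budget cited above;
    # longer chains are not automatically better.
    return model.generate(prompt, max_new_tokens=max_reasoning_tokens)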

Self-Consistency

The Technique

Self-consistency generates multiple CoT paths and takes a majority vote on the final answer:

Python
import collections
import re

def self_consistency(question: str, model, n_samples: int = 10, temperature: float = 0.7) -> str:
    """Self-consistency decoding with majority voting."""

    prompt = f"{question}\nLet's think step by step."

    # Generate multiple reasoning chains
    responses = []
    for _ in range(n_samples):
        response = model.generate(prompt, temperature=temperature)
        responses.append(response)

    # Extract final answers
    answers = [extract_final_answer(r) for r in responses]

    # Majority vote
    answer_counts = collections.Counter(answers)
    best_answer = answer_counts.most_common(1)[0][0]

    return best_answer

def extract_final_answer(response: str) -> str:
    """Extract the final answer from a CoT response."""
    # Look for common patterns
    patterns = [
        r"The answer is[:\s]*(.+?)[\.\n]",
        r"Therefore[,:\s]*(.+?)[\.\n]",
        r"So[,:\s]*(.+?)[\.\n]",
        r"= (\d+)",
    ]
    for pattern in patterns:
        match = re.search(pattern, response, re.IGNORECASE)
        if match:
            return match.group(1).strip()
    return response.split('\n')[-1].strip()

The self-consistency algorithm:

  1. Diverse generation: Use temperature sampling (typically 0.7-0.9) to create variety. Each response takes a potentially different reasoning path.

  2. Answer extraction: Parse each response to find the final answer. The patterns handle common formats like "The answer is 42" or "= 42".

  3. Majority voting: The most frequent answer wins. This works because correct reasoning paths tend to converge on the same answer, while errors are more random.

Why Self-Consistency Works

Different reasoning paths may make different mistakes, but the correct answer is more likely to appear in multiple paths. The diversity of reasoning (from temperature sampling) helps cover different solution strategies.

| Problem Type | Without Self-Consistency | With Self-Consistency (n=40) |
| --- | --- | --- |
| GSM8K | 56.5% | 74.4% |
| SVAMP | 68.9% | 86.6% |
| AQuA | 35.8% | 48.3% |

Weighted Self-Consistency

Weight votes by confidence scores:

Python
import collections
import math

def weighted_self_consistency(
    question: str,
    model,
    n_samples: int = 10
) -> str:
    """Self-consistency with confidence weighting."""

    answers_with_confidence = []

    for _ in range(n_samples):
        response, logprobs = model.generate_with_logprobs(
            f"{question}\nLet's think step by step.",
            temperature=0.7
        )

        answer = extract_final_answer(response)

        # Use average log probability as confidence
        confidence = sum(logprobs) / len(logprobs) if logprobs else 0
        answers_with_confidence.append((answer, confidence))

    # Weighted aggregation
    answer_weights = collections.defaultdict(float)
    for answer, conf in answers_with_confidence:
        answer_weights[answer] += math.exp(conf)  # Convert log prob to prob

    best_answer = max(answer_weights, key=answer_weights.get)
    return best_answer

Tree of Thoughts (ToT)

Beyond Linear Reasoning

Tree of Thoughts extends CoT by exploring multiple reasoning branches simultaneously, with the ability to backtrack and try alternative paths.

Code
         [Problem]
            │
    ┌───────┼───────┐
    │       │       │
[Thought 1][Thought 2][Thought 3]
    │       │       │
    ▼       ▼       ▼
  [Dead   [Branch  [Branch
   End]    A]       B]
            │       │
           ...     ...

Core Algorithm

Python
from dataclasses import dataclass
from typing import Optional
import heapq
import re

@dataclass
class ThoughtNode:
    """A node in the thought tree."""
    state: str  # Current reasoning state
    parent: Optional['ThoughtNode']
    children: list['ThoughtNode']
    value: float  # Evaluation score
    depth: int

    def __lt__(self, other):
        return self.value > other.value  # Higher value = higher priority


class TreeOfThoughts:
    def __init__(
        self,
        model,
        evaluator,
        max_depth: int = 5,
        branching_factor: int = 3,
        beam_width: int = 5
    ):
        self.model = model
        self.evaluator = evaluator
        self.max_depth = max_depth
        self.branching_factor = branching_factor
        self.beam_width = beam_width

    def generate_thoughts(self, state: str, problem: str) -> list[str]:
        """Generate candidate next thoughts from current state."""
        prompt = f"""Problem: {problem}

Current reasoning:
{state}

Generate {self.branching_factor} distinct next steps for solving this problem.
Each step should explore a different approach or aspect.
Format: List each thought on a new line starting with "Thought:"
"""
        response = self.model.generate(prompt)

        # Parse thoughts
        thoughts = []
        for line in response.split('\n'):
            if line.strip().startswith('Thought:'):
                thoughts.append(line.replace('Thought:', '').strip())

        return thoughts[:self.branching_factor]

    def evaluate_state(self, state: str, problem: str) -> float:
        """Evaluate how promising a reasoning state is."""
        prompt = f"""Problem: {problem}

Reasoning so far:
{state}

Evaluate this reasoning on a scale of 1-10:
- Is it making progress toward the solution?
- Is the logic sound?
- Are there errors or dead ends?

Score (1-10):"""

        response = self.evaluator.generate(prompt)

        # Extract score
        try:
            score = float(re.search(r'(\d+)', response).group(1))
            return min(10, max(1, score)) / 10
        except (AttributeError, ValueError):
            return 0.5

    def solve_bfs(self, problem: str) -> str:
        """Breadth-first search through thought tree."""
        root = ThoughtNode(
            state="",
            parent=None,
            children=[],
            value=1.0,
            depth=0
        )

        current_level = [root]

        for depth in range(self.max_depth):
            next_level = []

            for node in current_level:
                # Generate thoughts
                thoughts = self.generate_thoughts(node.state, problem)

                for thought in thoughts:
                    new_state = f"{node.state}\n{thought}" if node.state else thought
                    value = self.evaluate_state(new_state, problem)

                    child = ThoughtNode(
                        state=new_state,
                        parent=node,
                        children=[],
                        value=value,
                        depth=depth + 1
                    )
                    node.children.append(child)
                    next_level.append(child)

            # Beam search: keep top-k
            next_level.sort(key=lambda x: x.value, reverse=True)
            current_level = next_level[:self.beam_width]

            # Check for solution
            for node in current_level:
                if self.is_solution(node.state, problem):
                    return self.extract_answer(node.state, problem)

        # Return best found
        best_node = max(current_level, key=lambda x: x.value)
        return self.extract_answer(best_node.state, problem)

    def solve_dfs(self, problem: str, max_iterations: int = 100) -> str:
        """Depth-first search with backtracking."""
        stack = [ThoughtNode(state="", parent=None, children=[], value=1.0, depth=0)]
        best_solution = None
        best_value = 0

        iterations = 0
        while stack and iterations < max_iterations:
            iterations += 1
            node = stack.pop()

            if node.depth >= self.max_depth:
                if node.value > best_value:
                    best_value = node.value
                    best_solution = node.state
                continue

            thoughts = self.generate_thoughts(node.state, problem)

            for thought in thoughts:
                new_state = f"{node.state}\n{thought}" if node.state else thought
                value = self.evaluate_state(new_state, problem)

                # Pruning: skip unpromising branches
                if value < 0.3:
                    continue

                child = ThoughtNode(
                    state=new_state,
                    parent=node,
                    children=[],
                    value=value,
                    depth=node.depth + 1
                )
                stack.append(child)

        return self.extract_answer(best_solution or "", problem)

    def is_solution(self, state: str, problem: str) -> bool:
        """Check if state contains a complete solution."""
        prompt = f"""Problem: {problem}

Reasoning:
{state}

Does this reasoning reach a final answer? (yes/no)"""
        response = self.model.generate(prompt).strip().lower()
        return 'yes' in response

    def extract_answer(self, state: str, problem: str) -> str:
        """Extract final answer from reasoning state."""
        prompt = f"""Problem: {problem}

Reasoning:
{state}

Based on this reasoning, what is the final answer?"""
        return self.model.generate(prompt)

BFS vs DFS for Tree of Thoughts:

  • BFS (solve_bfs): Explores all nodes at each depth before going deeper. Uses beam search to keep only the top-k most promising branches. Best when you want to thoroughly explore early reasoning steps before committing.

  • DFS (solve_dfs): Goes deep first, then backtracks. More memory-efficient. Good when solutions require many steps but branching factor is low.

Key design decisions:

  1. Evaluation as a separate model call: The evaluate_state method asks the LLM to score reasoning quality. This is expensive but necessary—we can't assume any reasoning is good just because it's coherent.

  2. Pruning threshold (0.3 in DFS): Low-scoring branches are abandoned. This prevents wasting compute on dead ends but risks missing unconventional solutions.

  3. Solution detection: The is_solution method asks the LLM if reasoning is complete. This avoids hardcoding answer formats.

  4. Beam width controls breadth-accuracy trade-off: Higher beam width explores more but costs more compute.

ToT for Different Problem Types

Game of 24 (combine 4 numbers to get 24):

Python
def tot_game_of_24(numbers: list[int], tot: TreeOfThoughts) -> str:
    problem = f"Use the numbers {numbers} with +, -, *, / to get 24."

    # Custom thought generator for this domain; in a full implementation it would
    # replace tot.generate_thoughts so the search enumerates arithmetic moves directly
    def generate_game_thoughts(state, remaining):
        thoughts = []
        for i, a in enumerate(remaining):
            for j, b in enumerate(remaining):
                if i != j:
                    for op in ['+', '-', '*', '/']:
                        if op == '/' and b == 0:
                            continue
                        result = eval(f"{a}{op}{b}")
                        new_remaining = [n for k, n in enumerate(remaining)
                                        if k != i and k != j] + [result]
                        thoughts.append(f"{a} {op} {b} = {result}")
        return thoughts

    return tot.solve_bfs(problem)

Creative Writing:

Python
def tot_creative_writing(topic: str, tot: TreeOfThoughts) -> str:
    problem = f"Write a coherent passage about: {topic}"

    # Thoughts are paragraph plans
    # Evaluation considers coherence, creativity, relevance
    return tot.solve_bfs(problem)

ToT vs CoT Comparison

| Aspect | Chain-of-Thought | Tree of Thoughts |
| --- | --- | --- |
| Structure | Linear | Branching tree |
| Backtracking | No | Yes |
| Exploration | Single path | Multiple paths |
| Compute cost | O(n) | O(b^d), where b = branching factor, d = depth |
| Best for | Simple reasoning | Complex multi-step problems |

Forest of Thoughts (FoT)

Forest of Thoughts runs multiple ToT searches in parallel and aggregates results:

Python
import collections
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

class ForestOfThoughts:
    def __init__(
        self,
        model,
        evaluator,
        n_trees: int = 5,
        tree_config: dict = None
    ):
        self.model = model
        self.evaluator = evaluator
        self.n_trees = n_trees
        self.tree_config = tree_config or {
            'max_depth': 5,
            'branching_factor': 3,
            'beam_width': 3
        }

    def solve(self, problem: str) -> str:
        """Run multiple ToT searches and aggregate."""

        trees = [
            TreeOfThoughts(
                self.model,
                self.evaluator,
                **self.tree_config
            )
            for _ in range(self.n_trees)
        ]

        # Run in parallel
        results = []
        with ThreadPoolExecutor(max_workers=self.n_trees) as executor:
            futures = {
                executor.submit(tree.solve_bfs, problem): i
                for i, tree in enumerate(trees)
            }

            for future in as_completed(futures):
                try:
                    result = future.result()
                    results.append(result)
                except Exception as e:
                    print(f"Tree search failed: {e}")

        # Aggregate results
        return self.aggregate_results(results, problem)

    def aggregate_results(self, results: list[str], problem: str) -> str:
        """Aggregate multiple tree search results."""
        if not results:
            return "Unable to solve."

        if len(results) == 1:
            return results[0]

        # For answers that can be compared directly (numbers, etc)
        answer_counts = collections.Counter(
            self.normalize_answer(r) for r in results
        )

        # If majority agreement, return that
        most_common, count = answer_counts.most_common(1)[0]
        if count > len(results) / 2:
            return most_common

        # Otherwise, use model to synthesize
        prompt = f"""Problem: {problem}

Multiple reasoning approaches produced these answers:
{chr(10).join(f"- {r}" for r in results)}

Analyze these answers and determine the most likely correct answer.
If they disagree, explain which reasoning is most sound."""

        return self.model.generate(prompt)

    def normalize_answer(self, answer: str) -> str:
        """Normalize answer for comparison."""
        # Remove whitespace, convert to lowercase
        normalized = ' '.join(answer.lower().split())

        # Try to extract numeric answer
        numbers = re.findall(r'-?\d+\.?\d*', normalized)
        if numbers:
            return numbers[-1]  # Usually final answer is last

        return normalized

Diverse Tree Initialization

For better coverage, initialize trees differently:

Python
class DiverseForestOfThoughts(ForestOfThoughts):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.initial_prompts = [
            "Let's solve this systematically.",
            "Let's try a creative approach.",
            "Let's start with what we know for certain.",
            "Let's work backwards from the goal.",
            "Let's break this into smaller parts."
        ]

    def create_diverse_trees(self):
        """Create trees with different initial strategies."""
        trees = []
        for i, init_prompt in enumerate(self.initial_prompts[:self.n_trees]):
            tree = TreeOfThoughts(
                self.model,
                self.evaluator,
                **self.tree_config
            )
            # In a full implementation, initial_prompt would seed the tree's root
            # state so each tree starts from a different strategy.
            tree.initial_prompt = init_prompt
            trees.append(tree)
        return trees

Monte Carlo Tree Search (MCTS)

MCTS for Language Models

MCTS balances exploration and exploitation using UCB (Upper Confidence Bound):

Python
import math
import random
import re
from dataclasses import dataclass
from typing import Optional

@dataclass
class MCTSNode:
    """Node for Monte Carlo Tree Search."""
    state: str
    parent: Optional['MCTSNode']
    children: dict[str, 'MCTSNode']  # action -> child
    visits: int = 0
    value: float = 0.0
    untried_actions: list[str] = None

    def ucb_score(self, exploration_weight: float = 1.414) -> float:
        """Upper Confidence Bound score."""
        if self.visits == 0:
            return float('inf')

        exploitation = self.value / self.visits
        exploration = exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )
        return exploitation + exploration

    def best_child(self, exploration_weight: float = 1.414) -> 'MCTSNode':
        """Select child with highest UCB score."""
        return max(
            self.children.values(),
            key=lambda c: c.ucb_score(exploration_weight)
        )

    def is_fully_expanded(self) -> bool:
        return len(self.untried_actions) == 0

    def is_terminal(self) -> bool:
        return self.is_fully_expanded() and len(self.children) == 0


class MCTSReasoner:
    def __init__(
        self,
        model,
        reward_model,
        num_simulations: int = 100,
        max_depth: int = 10,
        exploration_weight: float = 1.414
    ):
        self.model = model
        self.reward_model = reward_model
        self.num_simulations = num_simulations
        self.max_depth = max_depth
        self.exploration_weight = exploration_weight

    def solve(self, problem: str) -> str:
        """Solve using MCTS."""
        # Initialize root
        root = MCTSNode(
            state="",
            parent=None,
            children={},
            untried_actions=self.get_actions(problem, "")
        )

        for _ in range(self.num_simulations):
            # Selection
            node = self.select(root)

            # Expansion
            if not node.is_terminal() and node.visits > 0:
                node = self.expand(node, problem)

            # Simulation
            reward = self.simulate(node, problem)

            # Backpropagation
            self.backpropagate(node, reward)

        # Return best path
        return self.extract_solution(root, problem)

    def select(self, node: MCTSNode) -> MCTSNode:
        """Select most promising node using UCB."""
        while node.children and node.is_fully_expanded():
            node = node.best_child(self.exploration_weight)
        return node

    def expand(self, node: MCTSNode, problem: str) -> MCTSNode:
        """Expand node with one untried action."""
        if not node.untried_actions:
            return node

        action = node.untried_actions.pop()
        new_state = f"{node.state}\n{action}" if node.state else action

        child = MCTSNode(
            state=new_state,
            parent=node,
            children={},
            untried_actions=self.get_actions(problem, new_state)
        )
        node.children[action] = child
        return child

    def simulate(self, node: MCTSNode, problem: str) -> float:
        """Simulate from node to terminal state."""
        state = node.state
        depth = state.count('\n') if state else 0

        # Random rollout
        while depth < self.max_depth:
            actions = self.get_actions(problem, state)
            if not actions:
                break

            action = random.choice(actions)
            state = f"{state}\n{action}" if state else action
            depth += 1

        # Evaluate final state
        return self.evaluate(state, problem)

    def backpropagate(self, node: MCTSNode, reward: float):
        """Backpropagate reward up the tree."""
        while node:
            node.visits += 1
            node.value += reward
            node = node.parent

    def get_actions(self, problem: str, state: str) -> list[str]:
        """Get possible next reasoning steps."""
        prompt = f"""Problem: {problem}

Current reasoning:
{state if state else "(Starting fresh)"}

Generate 5 possible next reasoning steps. Each should be a single logical step.
Format each on a new line starting with "-"."""

        response = self.model.generate(prompt)
        actions = [
            line.lstrip('- ').strip()
            for line in response.split('\n')
            if line.strip().startswith('-')
        ]
        return actions[:5]

    def evaluate(self, state: str, problem: str) -> float:
        """Evaluate state using reward model."""
        prompt = f"""Problem: {problem}

Reasoning:
{state}

Rate this reasoning from 0 to 1:
- 1.0: Correct and complete solution
- 0.7-0.9: On the right track, mostly correct
- 0.4-0.6: Partially correct, some errors
- 0.1-0.3: Mostly wrong
- 0.0: Completely wrong or irrelevant"""

        response = self.reward_model.generate(prompt)
        try:
            score = float(re.search(r'(\d*\.?\d+)', response).group(1))
            return min(1.0, max(0.0, score))
        except (AttributeError, ValueError):
            return 0.5

The four phases of MCTS:

  1. Selection: Start at root, use UCB scores to navigate down to a promising node. UCB balances exploitation (high-value nodes) and exploration (less-visited nodes).

  2. Expansion: When you reach a node with unexplored actions, pick one and create a child node. This grows the tree incrementally.

  3. Simulation (Rollout): From the new node, play out randomly to a terminal state. This estimates the node's value without fully exploring its subtree.

  4. Backpropagation: Update visit counts and values for all nodes on the path from root to the expanded node.

The UCB formula (a worked example follows the list below):

$$\text{UCB} = \frac{V}{N} + c \cdot \sqrt{\frac{\ln N_{\text{parent}}}{N}}$$

  • First term: Exploitation (average reward)
  • Second term: Exploration bonus (decreases as node is visited more)
  • c=1.414 (sqrt(2)) is a common exploration weight
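
Plugging in some illustrative numbers shows how the two terms interact: a rarely visited node can outrank a frequently visited one even when its average reward is only slightly higher, because its exploration bonus is still large.

Python
import math

# Worked example of the UCB formula above (illustrative numbers, not from a real run).
parent_visits = 12
c = 1.414  # ~sqrt(2)

def ucb(total_value: float, visits: int) -> float:
    return total_value / visits + c * math.sqrt(math.log(parent_visits) / visits)

ucb_a = ucb(total_value=7.0, visits=10)  # average reward 0.70 -> UCB ~1.40
ucb_b = ucb(total_value=1.8, visits=2)   # average reward 0.90 -> UCB ~2.48

# Node B is selected next despite far fewer visits: the exploration term
# rewards under-visited branches until their value estimates firm up.
print(f"UCB(A) = {ucb_a:.2f}, UCB(B) = {ucb_b:.2f}")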

Why follow most-visited path for final answer? Visit counts reflect confidence better than raw values. A node visited 100 times with average value 0.7 is more reliable than one visited 5 times with value 0.9.

Python
    # Continuation of the MCTSReasoner class from above
    def extract_solution(self, root: MCTSNode, problem: str) -> str:
        """Extract best solution path."""
        node = root
        path = []

        while node.children:
            # Follow most visited path
            best_child = max(
                node.children.values(),
                key=lambda c: c.visits
            )
            path.append(best_child.state.split('\n')[-1])
            node = best_child

        solution = '\n'.join(path)

        # Generate final answer
        prompt = f"""Problem: {problem}

Reasoning:
{solution}

Based on this reasoning, what is the final answer?"""

        return self.model.generate(prompt)

MCTS Variants for Reasoning

AlphaGo-style with neural evaluation:

Python
class NeuralMCTS(MCTSReasoner):
    def __init__(self, policy_model, value_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy_model = policy_model  # P(action|state)
        self.value_model = value_model    # V(state)

    def get_actions(self, problem: str, state: str) -> list[str]:
        """Use policy model to get likely actions."""
        # Policy model outputs action probabilities
        actions, probs = self.policy_model.predict(problem, state)

        # Sort by probability
        sorted_actions = [a for _, a in sorted(
            zip(probs, actions), reverse=True
        )]
        return sorted_actions

    def evaluate(self, state: str, problem: str) -> float:
        """Use value model for evaluation."""
        return self.value_model.predict(problem, state)

Best-of-N Sampling

Simple but Effective

Best-of-N generates N complete solutions and selects the best using a verifier:

Python
import math
import random

class BestOfN:
    def __init__(
        self,
        model,
        verifier,
        n: int = 16,
        temperature: float = 0.8
    ):
        self.model = model
        self.verifier = verifier
        self.n = n
        self.temperature = temperature

    def solve(self, problem: str) -> tuple[str, float]:
        """Generate N solutions, return best."""
        solutions = []

        prompt = f"{problem}\nLet's solve this step by step:"

        for _ in range(self.n):
            solution = self.model.generate(
                prompt,
                temperature=self.temperature
            )
            score = self.verifier.score(problem, solution)
            solutions.append((solution, score))

        # Return best
        best_solution, best_score = max(solutions, key=lambda x: x[1])
        return best_solution, best_score

    def solve_with_early_stopping(
        self,
        problem: str,
        threshold: float = 0.95
    ) -> tuple[str, float]:
        """Stop early if we find a high-confidence solution."""
        best_solution = None
        best_score = 0

        prompt = f"{problem}\nLet's solve this step by step:"

        for i in range(self.n):
            solution = self.model.generate(
                prompt,
                temperature=self.temperature
            )
            score = self.verifier.score(problem, solution)

            if score > best_score:
                best_solution = solution
                best_score = score

            # Early stopping
            if score >= threshold:
                print(f"Early stopping at iteration {i+1}")
                return best_solution, best_score

        return best_solution, best_score


class WeightedBestOfN(BestOfN):
    def solve_weighted(self, problem: str) -> str:
        """Weighted sampling based on verifier scores."""
        solutions = []

        prompt = f"{problem}\nLet's solve this step by step:"

        # First pass: generate all
        for _ in range(self.n):
            solution = self.model.generate(
                prompt,
                temperature=self.temperature
            )
            solutions.append(solution)

        # Score all
        scores = [
            self.verifier.score(problem, sol)
            for sol in solutions
        ]

        # Weighted selection (softmax)
        probs = self.softmax(scores)
        selected_idx = random.choices(range(len(solutions)), weights=probs)[0]

        return solutions[selected_idx]

    def softmax(self, scores: list[float], temperature: float = 1.0) -> list[float]:
        exp_scores = [math.exp(s / temperature) for s in scores]
        total = sum(exp_scores)
        return [e / total for e in exp_scores]

Compute-Performance Trade-off

| N | Relative Compute | Typical Improvement |
| --- | --- | --- |
| 1 | 1x | Baseline |
| 4 | 4x | +5-10% accuracy |
| 16 | 16x | +10-20% accuracy |
| 64 | 64x | +15-25% accuracy |
| 256 | 256x | +18-30% accuracy |

Diminishing returns: doubling N gives progressively smaller gains.
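
The shape of this curve follows from simple probability. With a perfect verifier, Best-of-N succeeds if any of the N samples is correct, so coverage is 1 - (1 - p)^N for a per-sample success rate p; each doubling of N closes a shrinking fraction of the remaining gap. A quick illustration (an idealized upper bound, since real verifiers are imperfect):

Python
# Coverage of Best-of-N under a perfect verifier: 1 - (1 - p)^N.
p = 0.30  # illustrative per-sample success rate
for n in (1, 4, 16, 64, 256):
    coverage = 1 - (1 - p) ** n
    print(f"N={n:>3}: coverage ≈ {coverage:.1%}")
# N=1: 30.0%, N=4: 76.0%, N=16: 99.7%, then further gains become negligible.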

Beam Search

Traditional beam search at the token level:

Python
import torch
import torch.nn.functional as F

class TokenBeamSearch:
    def __init__(
        self,
        model,
        beam_width: int = 5,
        max_length: int = 512
    ):
        self.model = model
        self.beam_width = beam_width
        self.max_length = max_length

    def search(self, prompt: str) -> str:
        """Standard beam search over tokens."""
        # Initialize beams with prompt
        beams = [(prompt, 0.0)]  # (sequence, log_prob)

        for _ in range(self.max_length):
            all_candidates = []

            for seq, score in beams:
                if seq.endswith('<eos>'):
                    all_candidates.append((seq, score))
                    continue

                # Get next token probabilities
                logits = self.model.get_next_token_logits(seq)
                log_probs = F.log_softmax(logits, dim=-1)

                # Get top-k tokens
                top_log_probs, top_indices = torch.topk(
                    log_probs, self.beam_width
                )

                for log_prob, token_idx in zip(top_log_probs, top_indices):
                    token = self.model.decode(token_idx)
                    new_seq = seq + token
                    new_score = score + log_prob.item()
                    all_candidates.append((new_seq, new_score))

            # Keep top beams
            beams = sorted(
                all_candidates,
                key=lambda x: x[1],
                reverse=True
            )[:self.beam_width]

            # Check if all beams finished
            if all(seq.endswith('<eos>') for seq, _ in beams):
                break

        # Return best
        return beams[0][0]

Beam search over reasoning steps (more practical for LLMs):

Python
class StepBeamSearch:
    def __init__(
        self,
        model,
        verifier,
        beam_width: int = 5,
        max_steps: int = 10
    ):
        self.model = model
        self.verifier = verifier
        self.beam_width = beam_width
        self.max_steps = max_steps

    def search(self, problem: str) -> str:
        """Beam search over reasoning steps."""
        # Initialize beams
        beams = [("", 0.0)]  # (reasoning_so_far, cumulative_score)

        for step in range(self.max_steps):
            all_candidates = []

            for reasoning, score in beams:
                # Generate next step candidates
                next_steps = self.generate_next_steps(problem, reasoning)

                for next_step in next_steps:
                    new_reasoning = f"{reasoning}\n{next_step}" if reasoning else next_step

                    # Score with verifier
                    step_score = self.verifier.score_step(
                        problem, new_reasoning, step
                    )
                    new_score = score + step_score

                    all_candidates.append((new_reasoning, new_score))

            # Keep top beams
            beams = sorted(
                all_candidates,
                key=lambda x: x[1],
                reverse=True
            )[:self.beam_width]

            # Check for completion
            if self.is_complete(problem, beams[0][0]):
                break

        return self.extract_answer(problem, beams[0][0])

    def generate_next_steps(
        self,
        problem: str,
        reasoning: str,
        n_candidates: int = 10
    ) -> list[str]:
        """Generate candidate next reasoning steps."""
        prompt = f"""Problem: {problem}

Reasoning so far:
{reasoning if reasoning else "(Starting)"}

Generate the next reasoning step. Output only the single next step."""

        steps = []
        for _ in range(n_candidates):
            step = self.model.generate(prompt, temperature=0.8)
            steps.append(step.strip())

        # Deduplicate
        return list(set(steps))

    def is_complete(self, problem: str, reasoning: str) -> bool:
        """Check if reasoning is complete."""
        prompt = f"""Problem: {problem}
Reasoning: {reasoning}
Is this reasoning complete with a final answer? (yes/no)"""
        response = self.model.generate(prompt)
        return 'yes' in response.lower()

    def extract_answer(self, problem: str, reasoning: str) -> str:
        """Extract final answer."""
        prompt = f"""Problem: {problem}
Reasoning: {reasoning}
What is the final answer?"""
        return self.model.generate(prompt)

Process Reward Models (PRMs)

What Are PRMs?

Process Reward Models evaluate each reasoning step, not just the final answer. This enables step-by-step guidance during generation.

Outcome Reward Model (ORM): Score(final_answer) → 0 or 1
Process Reward Model (PRM): Score(step_1), Score(step_2), ..., Score(step_n)
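
The practical difference shows up in the scoring interface: an ORM returns a single score for a finished solution, while a PRM returns one score per step, which is what lets search algorithms prune a bad partial solution early. A minimal sketch of the two interfaces (illustrative only; the concrete PRM implementation follows below):

Python
from typing import Protocol

class OutcomeRewardModel(Protocol):
    def score(self, problem: str, solution: str) -> float:
        """One scalar for the complete solution: did it end up correct?"""
        ...

class ProcessRewardModelInterface(Protocol):
    def score_solution(self, problem: str, steps: list[str]) -> list[float]:
        """One score per reasoning step, enabling mid-solution pruning."""
        ...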

PRM Architecture

Python
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class ProcessRewardModel(nn.Module):
    """Process Reward Model for step-by-step evaluation."""

    def __init__(
        self,
        base_model: str = "meta-llama/Llama-3.1-8B",
        hidden_size: int = 4096
    ):
        super().__init__()

        self.backbone = AutoModel.from_pretrained(base_model)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.reward_head = nn.Linear(hidden_size, 1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        step_positions: list[int]
    ) -> torch.Tensor:
        """
        Compute reward for each reasoning step.

        step_positions: indices marking end of each step
        """
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = outputs.last_hidden_state

        # Get hidden state at each step position
        step_hidden = hidden_states[:, step_positions, :]

        # Compute rewards
        rewards = self.reward_head(step_hidden).squeeze(-1)
        return torch.sigmoid(rewards)

    def score_solution(
        self,
        problem: str,
        solution_steps: list[str]
    ) -> list[float]:
        """Score each step in a solution."""
        # Build input with step markers
        text = f"Problem: {problem}\n"
        step_positions = []

        for i, step in enumerate(solution_steps):
            text += f"Step {i+1}: {step}\n"
            step_positions.append(len(self.tokenizer.encode(text)) - 1)

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True
        )

        with torch.no_grad():
            rewards = self.forward(
                inputs.input_ids,
                inputs.attention_mask,
                step_positions
            )

        return rewards[0].tolist()


def train_prm(
    model: ProcessRewardModel,
    train_data: list[dict],
    epochs: int = 3
):
    """
    Train PRM on step-labeled data.

    train_data format:
    [{
        "problem": str,
        "steps": [str, ...],
        "labels": [0 or 1, ...]  # 1 = correct step, 0 = error
    }]
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        total_loss = 0

        for example in train_data:
            optimizer.zero_grad()

            # Get step positions (prepare_input is an assumed helper that builds the
            # "Problem/Step" text and returns step-end token positions, mirroring score_solution)
            text, step_positions = model.prepare_input(
                example["problem"],
                example["steps"]
            )

            inputs = model.tokenizer(text, return_tensors="pt")

            # Forward
            predicted = model(
                inputs.input_ids,
                inputs.attention_mask,
                step_positions
            )

            labels = torch.tensor(example["labels"], dtype=torch.float32)

            loss = criterion(predicted, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_data):.4f}")

How PRMs work:

  1. Architecture: A language model backbone (like Llama) processes the problem + partial solution. A linear head on top predicts a reward score for each step position.

  2. Step positions: The model needs to know where each step ends. The step_positions list marks token indices at step boundaries.

  3. Output interpretation: Sigmoid activation produces scores in [0, 1]. Higher = more likely to lead to correct answer.

Training data format:

Python
{
    "problem": "What is 2 + 2?",
    "steps": ["First, I identify this as addition", "2 + 2 = 4", "The answer is 4"],
    "labels": [1, 1, 1]  # All steps are correct
}

Labels are typically generated by:

  • Human annotation (expensive but high quality)
  • Completion-based labeling (Math-Shepherd approach)
  • Model self-evaluation (cheaper but noisier)

PRM-Guided Generation

Use PRM to guide generation at each step:

Python
class PRMGuidedGenerator:
    def __init__(
        self,
        generator_model,
        prm: ProcessRewardModel,
        num_candidates: int = 8,
        threshold: float = 0.5
    ):
        self.generator = generator_model
        self.prm = prm
        self.num_candidates = num_candidates
        self.threshold = threshold

    def generate_solution(self, problem: str, max_steps: int = 15) -> str:
        """Generate solution with PRM guidance."""
        steps = []

        for i in range(max_steps):
            # Generate candidates for next step
            candidates = self.generate_step_candidates(problem, steps)

            if not candidates:
                break

            # Score each candidate
            scored_candidates = []
            for candidate in candidates:
                test_steps = steps + [candidate]
                scores = self.prm.score_solution(problem, test_steps)
                step_score = scores[-1]  # Score of new step
                scored_candidates.append((candidate, step_score))

            # Select best candidate above threshold
            valid_candidates = [
                (c, s) for c, s in scored_candidates if s >= self.threshold
            ]

            if not valid_candidates:
                # All candidates below threshold - might need backtracking
                # For now, take best anyway
                best_candidate = max(scored_candidates, key=lambda x: x[1])[0]
            else:
                best_candidate = max(valid_candidates, key=lambda x: x[1])[0]

            steps.append(best_candidate)

            # Check if complete
            if self.is_final_answer(best_candidate):
                break

        return '\n'.join(steps)

    def generate_step_candidates(
        self,
        problem: str,
        previous_steps: list[str]
    ) -> list[str]:
        """Generate candidate next steps."""
        prompt = f"""Problem: {problem}

Previous steps:
{chr(10).join(f"{i+1}. {s}" for i, s in enumerate(previous_steps)) if previous_steps else "(Starting)"}

Generate the next reasoning step:"""

        candidates = []
        for _ in range(self.num_candidates):
            response = self.generator.generate(prompt, temperature=0.8)
            candidates.append(response.strip())

        return list(set(candidates))  # Deduplicate

    def is_final_answer(self, step: str) -> bool:
        """Check if step contains final answer."""
        indicators = [
            "the answer is",
            "therefore,",
            "thus,",
            "in conclusion",
            "final answer:",
            "= "
        ]
        step_lower = step.lower()
        return any(ind in step_lower for ind in indicators)

PRM-guided generation flow:

  1. Generate candidates: For each step, generate N candidate next steps (with temperature for diversity).

  2. Score candidates: Evaluate each candidate by appending it to the current solution and running the PRM.

  3. Threshold filtering: Only consider candidates scoring above the threshold (0.5 default). This prevents committing to likely-wrong steps.

  4. Best selection: Among valid candidates, pick the highest-scoring one.

  5. Fallback: If all candidates score below threshold, either backtrack or accept the best anyway (current implementation takes best).

Why threshold filtering? Without it, the generator might take a step with 0.2 score just because it's the "best" among bad options. The threshold forces the model to only proceed with reasonably confident steps.

Math-Shepherd PRM

The Math-Shepherd approach uses automated labeling:

Python
def generate_prm_training_data(
    problems: list[str],
    solutions: list[list[str]],
    ground_truth: list[str]
) -> list[dict]:
    """
    Generate PRM training data using completion-based labeling.

    For each step, complete to final answer multiple times.
    If at least half of the completions reach the correct answer, the step is labeled correct.
    """
    training_data = []

    for problem, steps, answer in zip(problems, solutions, ground_truth):
        step_labels = []

        for i, step in enumerate(steps):
            partial_solution = steps[:i+1]

            # Complete from this step multiple times
            # (complete_from_step is an assumed helper that samples n completions)
            completions = complete_from_step(problem, partial_solution, n=10)

            # Check how many reach correct answer
            correct_count = sum(
                1 for c in completions
                if extract_answer(c) == answer
            )

            # Label: 1 if majority correct, 0 otherwise
            label = 1 if correct_count >= 5 else 0
            step_labels.append(label)

        training_data.append({
            "problem": problem,
            "steps": steps,
            "labels": step_labels
        })

    return training_data

Diverse Verifier Tree Search (DVTS)

DVTS Algorithm

DVTS from HuggingFace's search-and-learn framework combines tree search with diverse exploration:

Python
import numpy as np

class DVTS:
    """
    Diverse Verifier Tree Search.

    Key innovations:
    1. Diversity-promoting selection
    2. Process reward model guidance
    3. Adaptive beam allocation
    """

    def __init__(
        self,
        model,
        prm: ProcessRewardModel,
        n_beams: int = 4,
        n_candidates: int = 8,
        diversity_penalty: float = 0.1
    ):
        self.model = model
        self.prm = prm
        self.n_beams = n_beams
        self.n_candidates = n_candidates
        self.diversity_penalty = diversity_penalty

    def search(self, problem: str, max_steps: int = 15) -> str:
        """Run DVTS search."""
        # Initialize beams
        beams = [
            {
                "steps": [],
                "score": 0.0,
                "embedding": None
            }
            for _ in range(self.n_beams)
        ]

        for step_idx in range(max_steps):
            all_candidates = []

            for beam in beams:
                # Generate candidates
                candidates = self.generate_candidates(
                    problem, beam["steps"]
                )

                for candidate in candidates:
                    new_steps = beam["steps"] + [candidate]

                    # Score with PRM
                    prm_scores = self.prm.score_solution(problem, new_steps)
                    step_score = prm_scores[-1]

                    # Compute embedding for diversity
                    embedding = self.get_embedding(new_steps)

                    all_candidates.append({
                        "steps": new_steps,
                        "score": step_score,
                        "embedding": embedding
                    })

            # Select diverse top-k
            beams = self.diverse_select(all_candidates, self.n_beams)

            # Check for completion
            if all(self.is_complete(b["steps"]) for b in beams):
                break

        # Return best beam
        best_beam = max(beams, key=lambda b: b["score"])
        return self.format_solution(problem, best_beam["steps"])

    def diverse_select(
        self,
        candidates: list[dict],
        k: int
    ) -> list[dict]:
        """Select top-k candidates while promoting diversity."""
        if len(candidates) <= k:
            return candidates

        # Sort by score
        sorted_candidates = sorted(
            candidates,
            key=lambda c: c["score"],
            reverse=True
        )

        selected = [sorted_candidates[0]]

        for _ in range(k - 1):
            best_candidate = None
            best_score = float('-inf')

            for candidate in sorted_candidates:
                # Identity check: comparing dicts that hold numpy arrays by value
                # is ambiguous; we only need to skip already-selected candidates.
                if any(candidate is s for s in selected):
                    continue

                # Compute diversity-adjusted score
                min_similarity = min(
                    self.cosine_similarity(
                        candidate["embedding"],
                        s["embedding"]
                    )
                    for s in selected
                )

                diversity_bonus = (1 - min_similarity) * self.diversity_penalty
                adjusted_score = candidate["score"] + diversity_bonus

                if adjusted_score > best_score:
                    best_score = adjusted_score
                    best_candidate = candidate

            if best_candidate:
                selected.append(best_candidate)

        return selected

    def generate_candidates(
        self,
        problem: str,
        previous_steps: list[str]
    ) -> list[str]:
        """Generate diverse candidate next steps."""
        prompt = f"""Problem: {problem}

Solution so far:
{self.format_steps(previous_steps)}

Generate the next step:"""

        candidates = []

        # Use different temperatures for diversity
        temperatures = [0.5, 0.7, 0.9, 1.1]

        for temp in temperatures:
            for _ in range(self.n_candidates // len(temperatures)):
                response = self.model.generate(prompt, temperature=temp)
                candidates.append(response.strip())

        return list(set(candidates))

    def get_embedding(self, steps: list[str]) -> np.ndarray:
        """Get embedding of solution state."""
        text = '\n'.join(steps)
        return self.model.encode(text)

    def cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Compute cosine similarity between embeddings."""
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def format_steps(self, steps: list[str]) -> str:
        if not steps:
            return "(Starting)"
        return '\n'.join(f"Step {i+1}: {s}" for i, s in enumerate(steps))

    def is_complete(self, steps: list[str]) -> bool:
        if not steps:
            return False
        last_step = steps[-1].lower()
        return any(x in last_step for x in ['answer is', 'therefore', '= '])

    def format_solution(self, problem: str, steps: list[str]) -> str:
        return f"Problem: {problem}\n\nSolution:\n{self.format_steps(steps)}"

DVTS key innovations:

  1. Diversity penalty: Standard beam search tends to keep similar solutions. DVTS adds a bonus for solutions that are different from already-selected ones, computed via embedding similarity.

  2. Embedding-based similarity: Each partial solution is embedded, and cosine similarity measures how alike two solutions are. This captures semantic diversity, not just surface-level text differences.

  3. Multi-temperature sampling: Candidates are generated at different temperatures (0.5, 0.7, 0.9, 1.1). Low temperatures produce safe, likely steps; high temperatures explore creative alternatives.

The diversity selection algorithm:

  1. Start by selecting the highest-scoring candidate
  2. For each remaining slot:
    • For each unselected candidate, compute: adjusted_score = raw_score + (1 - min_similarity_to_selected) * penalty
    • Select the candidate with highest adjusted score
  3. This creates a set that's both high-quality (score) and diverse (dissimilar)

Why diversity matters: Multiple similar high-scoring paths might all be wrong in the same way. Diversity ensures at least one path explores a different approach, increasing the chance of finding the correct answer.


HuggingFace Search-and-Learn Framework

Installation and Setup

Bash
pip install sal
# or
git clone https://github.com/huggingface/search-and-learn
cd search-and-learn
pip install -e .

Basic Usage

Python
from sal.config import Config
from sal.search import beam_search, best_of_n, dvts
from sal.models.reward_models import load_prm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configuration
config = Config()
config.n = 32  # Number of candidates/beams
config.temperature = 0.8
config.search_batch_size = 8

# Load models
llm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
prm = load_prm("RLHFlow/Llama3.1-8B-PRM-Deepseek-Data")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Prepare input
problem = "What is 17 * 24 + 156?"
input_batch = tokenizer(
    [f"Solve: {problem}"],
    return_tensors="pt",
    padding=True
)
Best-of-N

Python
# Best-of-N: Generate N solutions, pick best
result = best_of_n(
    x=input_batch,
    config=config,
    llm=llm,
    prm=prm
)

print(f"Best solution: {result.best_solution}")
print(f"Score: {result.best_score}")
print(f"Time: {result.elapsed_time:.2f}s")

Beam Search

Python
# Beam search with PRM guidance
config.beam_width = 4
config.max_steps = 20

result = beam_search(
    x=input_batch,
    config=config,
    llm=llm,
    prm=prm
)

print(f"Solution: {result.solution}")
print(f"Step scores: {result.step_scores}")

DVTS

Python
# Diverse Verifier Tree Search
config.n_beams = 4
config.diversity_penalty = 0.1

result = dvts(
    x=input_batch,
    config=config,
    llm=llm,
    prm=prm
)

print(f"Solution: {result.solution}")
print(f"Explored paths: {result.n_explored}")

Performance Comparison

From HuggingFace benchmarks on MATH500:

| Method | Accuracy | Time (s) | Notes |
| --- | --- | --- | --- |
| Greedy | 65.2% | 1.2 | Baseline |
| Best-of-N (N=32) | 73.4% | 3.5 | Simple, effective |
| Beam Search (k=4) | 75.8% | 10.1 | More compute |
| DVTS | 78.2% | 8.5 | Best accuracy/time trade-off |

Custom Search Configuration

Python
from sal.config import SearchConfig

# Fine-tuned configuration
search_config = SearchConfig(
    # Generation parameters
    temperature=0.8,
    top_p=0.95,
    max_new_tokens=1024,

    # Search parameters
    n_candidates=32,
    beam_width=4,
    max_steps=20,

    # PRM parameters
    prm_batch_size=16,
    min_step_score=0.3,  # Prune steps below this

    # Diversity
    diversity_penalty=0.1,
    n_diverse_beams=4,

    # Early stopping
    early_stop_threshold=0.95,
    patience=3
)

result = dvts(input_batch, search_config, llm, prm)

Self-Reflection and Iterative Refinement

Self-Critique Loop

Python
class SelfReflectionSolver:
    def __init__(
        self,
        model,
        max_iterations: int = 3,
        improvement_threshold: float = 0.1
    ):
        self.model = model
        self.max_iterations = max_iterations
        self.improvement_threshold = improvement_threshold

    def solve(self, problem: str) -> str:
        """Solve with self-reflection loop."""
        # Initial solution
        solution = self.generate_solution(problem)

        for i in range(self.max_iterations):
            # Critique
            critique = self.generate_critique(problem, solution)

            # Check if improvements are needed ("incorrect" contains "correct",
            # so exclude it explicitly)
            lowered = critique.lower()
            if "no issues" in lowered or ("correct" in lowered and "incorrect" not in lowered):
                break

            # Refine
            refined = self.refine_solution(problem, solution, critique)

            # Check improvement
            if self.is_better(problem, refined, solution):
                solution = refined
            else:
                break

        return solution

    def generate_solution(self, problem: str) -> str:
        prompt = f"""Problem: {problem}

Solve this step by step:"""
        return self.model.generate(prompt)

    def generate_critique(self, problem: str, solution: str) -> str:
        prompt = f"""Problem: {problem}

Proposed solution:
{solution}

Carefully check this solution for errors:
1. Are all calculations correct?
2. Is the logic sound?
3. Does it answer the question asked?
4. Are there any missing steps?

Critique:"""
        return self.model.generate(prompt)

    def refine_solution(
        self,
        problem: str,
        solution: str,
        critique: str
    ) -> str:
        prompt = f"""Problem: {problem}

Previous solution:
{solution}

Issues identified:
{critique}

Provide a corrected solution that addresses these issues:"""
        return self.model.generate(prompt)

    def is_better(
        self,
        problem: str,
        new_solution: str,
        old_solution: str
    ) -> bool:
        """Compare solutions using the model."""
        prompt = f"""Problem: {problem}

Solution A:
{old_solution}

Solution B:
{new_solution}

Which solution is better? Answer with just "A" or "B":"""
        response = self.model.generate(prompt).strip()
        return "B" in response.upper()
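
A minimal usage sketch, assuming model wraps your LLM behind a single generate(prompt) -> str helper (the same interface the class above relies on):

Python
# Hypothetical usage -- `model` is assumed to expose .generate(prompt) -> str
solver = SelfReflectionSolver(model, max_iterations=3)

problem = "A train travels 120 km in 1.5 hours. What is its average speed in km/h?"
answer = solver.solve(problem)
print(answer)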

Multi-Agent Debate

Python
class DebateSolver:
    def __init__(
        self,
        model,
        n_debaters: int = 3,
        n_rounds: int = 2
    ):
        self.model = model
        self.n_debaters = n_debaters
        self.n_rounds = n_rounds

    def solve(self, problem: str) -> str:
        """Solve through multi-agent debate."""
        # Initial solutions from each debater
        solutions = [
            self.generate_solution(problem, i)
            for i in range(self.n_debaters)
        ]

        # Debate rounds
        for round_idx in range(self.n_rounds):
            new_solutions = []

            for i, solution in enumerate(solutions):
                # Get other solutions
                other_solutions = [s for j, s in enumerate(solutions) if j != i]

                # Debate and update
                updated = self.debate_round(
                    problem, solution, other_solutions, i
                )
                new_solutions.append(updated)

            solutions = new_solutions

        # Final consensus
        return self.reach_consensus(problem, solutions)

    def generate_solution(self, problem: str, debater_id: int) -> str:
        personas = [
            "systematic and methodical",
            "creative and intuitive",
            "skeptical and rigorous"
        ]
        persona = personas[debater_id % len(personas)]

        prompt = f"""You are a {persona} problem solver.

Problem: {problem}

Solve this step by step:"""
        return self.model.generate(prompt)

    def debate_round(
        self,
        problem: str,
        my_solution: str,
        other_solutions: list[str],
        debater_id: int
    ) -> str:
        others_text = '\n\n'.join(
            f"Other solver's answer:\n{s}" for s in other_solutions
        )

        prompt = f"""Problem: {problem}

Your previous answer:
{my_solution}

{others_text}

Consider the other solutions. If they have valid points, incorporate them.
If you see errors in their reasoning, explain why your approach is better.
Provide your updated solution:"""

        return self.model.generate(prompt)

    def reach_consensus(self, problem: str, solutions: list[str]) -> str:
        solutions_text = '\n\n---\n\n'.join(
            f"Solution {i+1}:\n{s}" for i, s in enumerate(solutions)
        )

        prompt = f"""Problem: {problem}

After debate, here are the final positions:

{solutions_text}

Synthesize these into the best final answer, taking the strongest elements from each:"""

        return self.model.generate(prompt)

Multi-agent debate dynamics:

  1. Diverse personas: Each debater has a different style (systematic, creative, skeptical). This prevents groupthink and ensures different approaches are explored.

  2. Exposure to alternatives: In each round, debaters see other solutions. Strong arguments from others can change a debater's approach; weak arguments are critiqued.

  3. Iterative refinement: Multiple rounds allow positions to evolve. Initial disagreements often converge as debaters adopt valid points from each other.

  4. Consensus synthesis: The final step doesn't just pick a winner—it synthesizes the best elements from all positions.
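
A usage sketch for the DebateSolver class above; as with the other examples, model is assumed to expose a generate(prompt) -> str helper:

Python
# Hypothetical usage -- assumes `model` exposes .generate(prompt) -> str
debate = DebateSolver(model, n_debaters=3, n_rounds=2)

problem = "Should a small startup adopt a microservices architecture from day one?"
final_answer = debate.solve(problem)
print(final_answer)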

When to use debate vs. other methods:

  • Debate works best for open-ended problems where there's no single "correct" answer format (essays, analysis, recommendations)
  • Majority voting (self-consistency) works better for problems with verifiable answers (math, factual questions)
  • PRM-guided search is superior when you have a trained reward model for your domain

Budget Forcing

Controlling Test-Time Compute

Budget forcing allows explicit control over reasoning length:

Python
class BudgetForcedGenerator:
    def __init__(
        self,
        model,
        tokenizer
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.think_start = "<think>"
        self.think_end = "</think>"

    def generate_with_budget(
        self,
        problem: str,
        min_thinking_tokens: int = 100,
        max_thinking_tokens: int = 2000
    ) -> str:
        """Generate with controlled thinking budget."""

        prompt = f"{problem}\n{self.think_start}"
        generated = ""
        thinking_tokens = 0

        while thinking_tokens < max_thinking_tokens:
            # Generate next chunk
            next_tokens = self.model.generate(
                prompt + generated,
                max_new_tokens=50,
                temperature=0.7
            )

            generated += next_tokens
            thinking_tokens = len(self.tokenizer.encode(generated))

            # Check for premature ending
            if self.think_end in next_tokens:
                if thinking_tokens < min_thinking_tokens:
                    # Remove end token and continue
                    generated = generated.replace(self.think_end, "")
                    generated += "\nWait, let me think more carefully.\n"
                else:
                    # Accept the ending
                    break

        # Force ending if over budget
        if self.think_end not in generated:
            generated += f"\n{self.think_end}"

        # Generate final answer
        final_prompt = prompt + generated + "\nFinal answer:"
        answer = self.model.generate(final_prompt, max_new_tokens=100)

        return f"{self.think_start}{generated}{answer}"

    def extend_thinking(self, partial_response: str) -> str:
        """Extend thinking by suppressing end token."""
        if self.think_end in partial_response:
            # Remove end token
            partial_response = partial_response.replace(self.think_end, "")

        # Add continuation prompt
        partial_response += "\n\nWait, I should reconsider this."

        return partial_response

    def truncate_thinking(self, partial_response: str, max_tokens: int) -> str:
        """Truncate thinking to budget."""
        tokens = self.tokenizer.encode(partial_response)

        if len(tokens) <= max_tokens:
            return partial_response

        # Truncate and add ending
        truncated_tokens = tokens[:max_tokens]
        truncated = self.tokenizer.decode(truncated_tokens)

        # Clean ending
        truncated = truncated.rsplit('\n', 1)[0]  # Remove partial line
        truncated += f"\n{self.think_end}"

        return truncated
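
A sketch of how this could be called, assuming model.generate returns decoded text and tokenizer provides encode/decode, as the class above expects:

Python
# Hypothetical usage -- `model.generate()` is assumed to return decoded text,
# and `tokenizer` to provide encode()/decode()
generator = BudgetForcedGenerator(model, tokenizer)

response = generator.generate_with_budget(
    "What is the sum of the first 100 positive integers?",
    min_thinking_tokens=100,
    max_thinking_tokens=800
)
print(response)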

Adaptive Budget Allocation

Python
class AdaptiveBudget:
    def __init__(
        self,
        model,
        difficulty_estimator,
        budget_map: dict = None
    ):
        self.model = model
        self.difficulty_estimator = difficulty_estimator
        self.budget_map = budget_map or {
            "easy": (50, 200),
            "medium": (200, 1000),
            "hard": (1000, 4000)
        }

    def estimate_difficulty(self, problem: str) -> str:
        """Estimate problem difficulty."""
        prompt = f"""Rate the difficulty of this problem as easy, medium, or hard:

{problem}

Consider:
- Number of steps required
- Mathematical complexity
- Domain knowledge needed

Difficulty:"""

        response = self.difficulty_estimator.generate(prompt).strip().lower()

        if "easy" in response:
            return "easy"
        elif "hard" in response:
            return "hard"
        else:
            return "medium"

    def solve_with_adaptive_budget(self, problem: str) -> str:
        """Solve with difficulty-appropriate budget."""
        difficulty = self.estimate_difficulty(problem)
        min_tokens, max_tokens = self.budget_map[difficulty]

        generator = BudgetForcedGenerator(self.model, self.model.tokenizer)

        return generator.generate_with_budget(
            problem,
            min_thinking_tokens=min_tokens,
            max_thinking_tokens=max_tokens
        )
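
A usage sketch; here the generator model doubles as the difficulty estimator, which is an assumption rather than a requirement (any model with a generate helper would work):

Python
# Hypothetical usage -- reuse the same model as the difficulty estimator
# (the class also expects model.tokenizer to exist)
adaptive = AdaptiveBudget(model, difficulty_estimator=model)

quick = adaptive.solve_with_adaptive_budget("What is 17 * 6?")
thorough = adaptive.solve_with_adaptive_budget(
    "Prove that the sum of any two odd integers is even."
)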

Putting It All Together

Unified Test-Time Scaling System

Python
import collections
import re
import time

# Relies on the search classes defined earlier in this post
# (StepBeamSearch, TreeOfThoughts, MCTSReasoner, DVTS, SelfReflectionSolver)
class TestTimeScaler:
    """Unified system for test-time compute scaling."""

    def __init__(
        self,
        model,
        prm: ProcessRewardModel = None,
        default_method: str = "adaptive"
    ):
        self.model = model
        self.prm = prm
        self.default_method = default_method

        # Initialize methods
        self.methods = {
            "cot": lambda p: self.chain_of_thought(p),
            "self_consistency": lambda p: self.self_consistency(p),
            "best_of_n": lambda p: self.best_of_n(p),
            "beam_search": lambda p: self.beam_search(p),
            "tree_of_thoughts": lambda p: self.tree_of_thoughts(p),
            "mcts": lambda p: self.mcts(p),
            "dvts": lambda p: self.dvts(p),
            "self_reflection": lambda p: self.self_reflection(p),
            "adaptive": lambda p: self.adaptive_solve(p)
        }

    def solve(
        self,
        problem: str,
        method: str = None,
        compute_budget: str = "medium"
    ) -> dict:
        """
        Solve problem with specified method and budget.

        Returns dict with solution, method used, and metrics.
        """
        method = method or self.default_method

        if method not in self.methods:
            raise ValueError(f"Unknown method: {method}")

        start_time = time.time()
        solution = self.methods[method](problem)
        elapsed = time.time() - start_time

        return {
            "solution": solution,
            "method": method,
            "time": elapsed,
            "compute_budget": compute_budget
        }

    def adaptive_solve(self, problem: str) -> str:
        """Automatically select method based on problem."""
        # Estimate difficulty
        difficulty = self.estimate_difficulty(problem)

        if difficulty == "easy":
            return self.chain_of_thought(problem)
        elif difficulty == "medium":
            return self.best_of_n(problem)
        else:
            return self.dvts(problem)

    def estimate_difficulty(self, problem: str) -> str:
        """Quick difficulty estimation."""
        # Simple heuristics
        word_count = len(problem.split())
        has_multiple_parts = any(x in problem.lower() for x in ['and', 'then', 'also'])
        is_math = any(x in problem for x in ['+', '-', '*', '/', '='])

        if word_count < 20 and not has_multiple_parts:
            return "easy"
        elif word_count > 50 or (has_multiple_parts and is_math):
            return "hard"
        else:
            return "medium"

    # Individual method implementations
    def chain_of_thought(self, problem: str) -> str:
        prompt = f"{problem}\n\nLet's think step by step:"
        return self.model.generate(prompt)

    def self_consistency(self, problem: str, n: int = 10) -> str:
        responses = []
        for _ in range(n):
            resp = self.model.generate(
                f"{problem}\n\nLet's think step by step:",
                temperature=0.7
            )
            responses.append(resp)

        # Majority vote on final answer
        answers = [self.extract_answer(r) for r in responses]
        return collections.Counter(answers).most_common(1)[0][0]

    def best_of_n(self, problem: str, n: int = 16) -> str:
        if not self.prm:
            return self.self_consistency(problem, n)

        best_solution = None
        best_score = float('-inf')

        for _ in range(n):
            solution = self.model.generate(
                f"{problem}\n\nSolve step by step:",
                temperature=0.8
            )
            steps = self.parse_steps(solution)
            scores = self.prm.score_solution(problem, steps)
            avg_score = sum(scores) / len(scores) if scores else 0

            if avg_score > best_score:
                best_score = avg_score
                best_solution = solution

        return best_solution

    def beam_search(self, problem: str) -> str:
        searcher = StepBeamSearch(
            self.model,
            self.prm or self.model,  # Use model as verifier if no PRM
            beam_width=4,
            max_steps=15
        )
        return searcher.search(problem)

    def tree_of_thoughts(self, problem: str) -> str:
        tot = TreeOfThoughts(
            self.model,
            self.prm or self.model,
            max_depth=5,
            branching_factor=3,
            beam_width=3
        )
        return tot.solve_bfs(problem)

    def mcts(self, problem: str) -> str:
        searcher = MCTSReasoner(
            self.model,
            self.prm or self.model,
            num_simulations=50,
            max_depth=10
        )
        return searcher.solve(problem)

    def dvts(self, problem: str) -> str:
        searcher = DVTS(
            self.model,
            self.prm,
            n_beams=4,
            n_candidates=8
        )
        return searcher.search(problem)

    def self_reflection(self, problem: str) -> str:
        solver = SelfReflectionSolver(self.model, max_iterations=3)
        return solver.solve(problem)

    def extract_answer(self, response: str) -> str:
        """Extract final answer from response."""
        patterns = [
            r"answer is[:\s]*([^\n.]+)",
            r"= ([^\n.]+)$",
            r"therefore[,:\s]*([^\n.]+)"
        ]
        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                return match.group(1).strip()
        return response.split('\n')[-1].strip()

    def parse_steps(self, solution: str) -> list[str]:
        """Parse solution into steps."""
        lines = solution.strip().split('\n')
        steps = [l.strip() for l in lines if l.strip() and not l.startswith('#')]
        return steps
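
A usage sketch for the unified interface, assuming a model and (optionally) a ProcessRewardModel instance are already loaded:

Python
# Hypothetical usage -- `model` and `prm` are assumed to be loaded elsewhere
scaler = TestTimeScaler(model, prm=prm, default_method="adaptive")

result = scaler.solve(
    "A rectangle has perimeter 36 and area 80. Find its side lengths.",
    method="best_of_n"
)

print(result["solution"])
print(f"Method: {result['method']}, time: {result['time']:.2f}s")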

Comparison and Selection Guide

When to Use Each Method

| Method | Best For | Compute | Latency |
| --- | --- | --- | --- |
| CoT | Simple reasoning, quick responses | Low | Low |
| Self-Consistency | Math problems, verifiable answers | Medium | Medium |
| Best-of-N | When you have a good verifier | Medium | Medium |
| Beam Search | Multi-step problems, PRM available | High | High |
| Tree of Thoughts | Complex planning, backtracking needed | Very High | Very High |
| MCTS | Game-like problems, clear rewards | Very High | Very High |
| DVTS | Complex math, need diversity | High | High |
| Self-Reflection | Open-ended problems, no verifier | Medium | Medium |

Decision Flowchart

Code
START
  │
  ├─ Is answer easily verifiable? (math, code)
  │    ├─ YES → Do you have a PRM?
  │    │         ├─ YES → DVTS or Beam Search
  │    │         └─ NO  → Best-of-N with self-verification
  │    │
  │    └─ NO → Is backtracking helpful?
  │              ├─ YES → Tree of Thoughts or MCTS
  │              └─ NO  → Self-Reflection
  │
  ├─ Is latency critical?
  │    ├─ YES → CoT (single pass)
  │    └─ NO  → Self-Consistency
  │
  └─ Is the problem very complex?
       ├─ YES → DVTS with high beam count
       └─ NO  → Best-of-N (N=8-16)
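
One possible linearization of this flowchart as a small helper; the boolean flags are caller-supplied judgments, and the function is purely illustrative rather than part of any framework API:

Python
# A minimal sketch of the decision flowchart above -- all flags are assumptions
# the caller supplies for a given problem
def choose_method(
    verifiable: bool,          # math/code with a checkable answer?
    has_prm: bool,             # trained Process Reward Model available?
    backtracking_helps: bool,  # does exploring and abandoning paths help?
    latency_critical: bool,    # is a fast single-pass answer required?
    very_complex: bool         # is the problem unusually hard?
) -> str:
    if verifiable:
        return "dvts" if has_prm else "best_of_n"  # beam search is the other PRM option
    if backtracking_helps:
        return "tree_of_thoughts"                  # or "mcts" for game-like rewards
    if latency_critical:
        return "cot"                               # single pass
    if very_complex:
        return "dvts"                              # high beam count
    return "self_reflection"

# Example: a hard, verifiable math problem with a PRM available
print(choose_method(True, True, False, False, True))  # -> "dvts"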

Conclusion

Test-time compute scaling represents a fundamental shift in how we improve AI capabilities. Instead of relying solely on larger models, we can achieve better results by having models "think longer" at inference time.

Key takeaways:

  1. Start simple: Chain-of-Thought and Self-Consistency provide significant improvements with minimal complexity
  2. Add verifiers: Process Reward Models enable sophisticated search methods
  3. Match method to problem: Use simpler methods for easier problems, complex search for hard ones
  4. Budget wisely: More compute helps, but with diminishing returns
  5. Combine techniques: The best systems often combine multiple approaches

The HuggingFace search-and-learn framework provides production-ready implementations of these techniques, making sophisticated test-time scaling practical to adopt in your own applications.
