Test-Time Compute Scaling: CoT, ToT, MCTS, and Search-Based Reasoning
A comprehensive guide to inference-time scaling techniques—Chain of Thought, Tree of Thoughts, Monte Carlo Tree Search, Process Reward Models, and the HuggingFace search-and-learn framework.
The Test-Time Compute Revolution
For years, scaling AI meant one thing: bigger models trained on more data. But 2024-2025 introduced a paradigm shift—test-time compute scaling—where models improve by "thinking longer" at inference time rather than training longer.
The insight is profound: a small model that thinks carefully can outperform a large model that responds immediately. UC Berkeley research demonstrated that "with the right inference-time scaling approach, a 1B parameter model can outperform a 405B Llama 3 model that lacks inference-time scaling."
This post provides a comprehensive technical guide to test-time compute techniques, from basic Chain-of-Thought prompting to sophisticated tree search algorithms with Process Reward Models.
Understanding Test-Time Compute
What Is Test-Time Compute?
Test-time compute (also called inference-time scaling) refers to using additional computation during model inference to improve output quality. Instead of generating a response immediately, the model:
- Generates multiple candidate solutions
- Explores different reasoning paths
- Evaluates and selects the best output
- Iteratively refines its answer
The paradigm shift: For a decade, AI progress followed a simple recipe: train bigger models on more data. GPT-2 → GPT-3 → GPT-4 represented scaling training compute. But this approach has diminishing returns (scaling laws flatten) and massive fixed costs (billions of dollars per training run). Test-time compute offers a different lever: instead of building a smarter brain, give the brain more time to think. A smaller model that "thinks" for 10 seconds can outperform a larger model that responds instantly.
Why this is fundamentally different: Training compute is a fixed investment—you pay once and get a model. Test-time compute scales with usage—you pay per query, more for harder problems. This enables adaptive intelligence: easy questions get quick, cheap answers; hard questions get expensive, thorough reasoning. You couldn't achieve this by training more, because training compute is spent uniformly across all future queries.
The key trade-off: more compute at inference = better quality, but higher latency and cost.
Why It Matters
| Approach | Cost | Quality Scaling |
|---|---|---|
| Larger model | Higher training + inference | Diminishing returns (scaling laws plateau) |
| More training data | Higher data collection + training | Data availability limits |
| Test-time compute | Variable per query | Scales with problem difficulty |
Test-time compute is adaptive—you can allocate more thinking to hard problems and less to easy ones. This is fundamentally more efficient than uniformly scaling model size.
The Compute-Optimal Frontier
Research shows optimal allocation depends on problem difficulty:
- Easy problems: Minimal test-time compute (fast response)
- Medium problems: Moderate compute (some exploration)
- Hard problems: Maximum compute (extensive search)
The challenge is determining difficulty before solving—leading to adaptive algorithms that start with less compute and scale up if needed.
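One way to implement this escalation is to start with a single sample and only buy more compute while a verifier remains unconvinced. The sketch below is illustrative only: generate_candidates and verifier are hypothetical stand-ins for whatever sampler and scorer you use (neither is defined elsewhere in this post).
def solve_with_escalation(problem, generate_candidates, verifier,
                          budgets=(1, 4, 16), confidence_threshold=0.9):
    """Start cheap; escalate the sampling budget only while confidence stays low."""
    best_solution, best_score = None, float("-inf")
    for n in budgets:
        # Spend more samples at each escalation level
        for candidate in generate_candidates(problem, n_samples=n):
            score = verifier.score(problem, candidate)
            if score > best_score:
                best_solution, best_score = candidate, score
        # Stop escalating once the verifier is confident enough
        if best_score >= confidence_threshold:
            break
    return best_solution, best_score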
Chain-of-Thought (CoT) Prompting
The Foundation
Chain-of-Thought prompting is the simplest test-time compute technique. Instead of answering directly, the model generates intermediate reasoning steps.
Without CoT:
Q: Roger has 5 tennis balls. He buys 2 more cans of 3 tennis balls each.
How many tennis balls does he have now?
A: 11
With CoT:
Q: Roger has 5 tennis balls. He buys 2 more cans of 3 tennis balls each.
How many tennis balls does he have now?
A: Roger starts with 5 balls. 2 cans of 3 balls each means 2 × 3 = 6 new balls.
5 + 6 = 11 balls total.
The answer is 11.
Zero-Shot CoT
The famous "Let's think step by step" prompt enables CoT without examples:
def zero_shot_cot(question: str, model) -> str:
"""Apply zero-shot chain-of-thought prompting."""
prompt = f"""{question}
Let's think step by step."""
response = model.generate(prompt)
return response
Research showed this simple addition improves accuracy on math problems from ~17% to ~78% for certain models.
Few-Shot CoT
Providing examples of reasoning chains helps the model learn the expected format:
def few_shot_cot(question: str, model, examples: list[dict]) -> str:
"""Apply few-shot chain-of-thought prompting."""
prompt = ""
for ex in examples:
prompt += f"Q: {ex['question']}\n"
prompt += f"A: {ex['reasoning']} The answer is {ex['answer']}.\n\n"
prompt += f"Q: {question}\nA:"
response = model.generate(prompt)
return response
# Example usage
examples = [
{
"question": "There are 15 trees. Workers plant 6 more. How many trees?",
"reasoning": "We start with 15 trees. Workers plant 6 more trees. 15 + 6 = 21.",
"answer": "21"
},
{
"question": "If there are 3 cars and each has 4 wheels, how many wheels total?",
"reasoning": "Each car has 4 wheels. With 3 cars: 3 × 4 = 12 wheels.",
"answer": "12"
}
]
Self-Ask CoT
A variant where the model asks and answers sub-questions:
def self_ask_cot(question: str, model) -> str:
"""Self-Ask chain-of-thought prompting."""
prompt = f"""Question: {question}
Are follow up questions needed here: Yes
Follow up: [First sub-question the model should identify]
Intermediate answer: [Answer to follow-up]
Follow up: [Second sub-question if needed]
Intermediate answer: [Answer]
So the final answer is:"""
response = model.generate(prompt)
return response
CoT Limitations
- Linearity: CoT explores one path; if it goes wrong early, everything fails
- No backtracking: Can't revise earlier steps
- Hallucinated reasoning: Steps may be plausible-sounding but wrong
- Length sensitivity: Very long chains can degrade (overthinking)
Research found "an optimal reasoning token budget of approximately 4K, beyond which performance may degrade due to overthinking."
Self-Consistency
The Technique
Self-consistency generates multiple CoT paths and takes a majority vote on the final answer:
import collections
import math
import re
def self_consistency(question: str, model, n_samples: int = 10, temperature: float = 0.7) -> str:
"""Self-consistency decoding with majority voting."""
prompt = f"{question}\nLet's think step by step."
# Generate multiple reasoning chains
responses = []
for _ in range(n_samples):
response = model.generate(prompt, temperature=temperature)
responses.append(response)
# Extract final answers
answers = [extract_final_answer(r) for r in responses]
# Majority vote
answer_counts = collections.Counter(answers)
best_answer = answer_counts.most_common(1)[0][0]
return best_answer
def extract_final_answer(response: str) -> str:
"""Extract the final answer from a CoT response."""
# Look for common patterns
patterns = [
r"The answer is[:\s]*(.+?)[\.\n]",
r"Therefore[,:\s]*(.+?)[\.\n]",
r"So[,:\s]*(.+?)[\.\n]",
r"= (\d+)",
]
for pattern in patterns:
match = re.search(pattern, response, re.IGNORECASE)
if match:
return match.group(1).strip()
return response.split('\n')[-1].strip()
The self-consistency algorithm:
- Diverse generation: Use temperature sampling (typically 0.7-0.9) to create variety. Each response takes a potentially different reasoning path.
- Answer extraction: Parse each response to find the final answer. The patterns handle common formats like "The answer is 42" or "= 42".
- Majority voting: The most frequent answer wins. This works because correct reasoning paths tend to converge on the same answer, while errors are more random.
Why Self-Consistency Works
Different reasoning paths may make different mistakes, but the correct answer is more likely to appear in multiple paths. The diversity of reasoning (from temperature sampling) helps cover different solution strategies.
| Problem Type | Without Self-Consistency | With Self-Consistency (n=40) |
|---|---|---|
| GSM8K | 56.5% | 74.4% |
| SVAMP | 68.9% | 86.6% |
| AQuA | 35.8% | 48.3% |
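A rough way to see why voting helps (an illustrative calculation only: it treats each chain as an independent binary correct/incorrect vote, which real chains are not):
from math import comb

def majority_vote_accuracy(p_single: float, n: int) -> float:
    """Probability that more than half of n independent chains are correct."""
    return sum(comb(n, k) * p_single**k * (1 - p_single)**(n - k)
               for k in range(n // 2 + 1, n + 1))

# If a single chain is right 60% of the time, 11 chains voting reach ~75%
# and 41 chains reach ~90% (toy numbers, not a benchmark claim).
print(majority_vote_accuracy(0.6, 11), majority_vote_accuracy(0.6, 41))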
Weighted Self-Consistency
Weight votes by confidence scores:
def weighted_self_consistency(
question: str,
model,
n_samples: int = 10
) -> str:
"""Self-consistency with confidence weighting."""
answers_with_confidence = []
for _ in range(n_samples):
response, logprobs = model.generate_with_logprobs(
f"{question}\nLet's think step by step.",
temperature=0.7
)
answer = extract_final_answer(response)
# Use average log probability as confidence
confidence = sum(logprobs) / len(logprobs) if logprobs else 0
answers_with_confidence.append((answer, confidence))
# Weighted aggregation
answer_weights = collections.defaultdict(float)
for answer, conf in answers_with_confidence:
answer_weights[answer] += math.exp(conf) # Convert log prob to prob
best_answer = max(answer_weights, key=answer_weights.get)
return best_answer
Tree of Thoughts (ToT)
Beyond Linear Reasoning
Tree of Thoughts extends CoT by exploring multiple reasoning branches simultaneously, with the ability to backtrack and try alternative paths.
[Problem]
│
┌───────┼───────┐
│ │ │
[Thought 1][Thought 2][Thought 3]
│ │ │
▼ ▼ ▼
[Dead [Branch [Branch
End] A] B]
│ │
... ...
Core Algorithm
from dataclasses import dataclass
from typing import Optional
import heapq
@dataclass
class ThoughtNode:
"""A node in the thought tree."""
state: str # Current reasoning state
parent: Optional['ThoughtNode']
children: list['ThoughtNode']
value: float # Evaluation score
depth: int
def __lt__(self, other):
return self.value > other.value # Higher value = higher priority
class TreeOfThoughts:
def __init__(
self,
model,
evaluator,
max_depth: int = 5,
branching_factor: int = 3,
beam_width: int = 5
):
self.model = model
self.evaluator = evaluator
self.max_depth = max_depth
self.branching_factor = branching_factor
self.beam_width = beam_width
def generate_thoughts(self, state: str, problem: str) -> list[str]:
"""Generate candidate next thoughts from current state."""
prompt = f"""Problem: {problem}
Current reasoning:
{state}
Generate {self.branching_factor} distinct next steps for solving this problem.
Each step should explore a different approach or aspect.
Format: List each thought on a new line starting with "Thought:"
"""
response = self.model.generate(prompt)
# Parse thoughts
thoughts = []
for line in response.split('\n'):
if line.strip().startswith('Thought:'):
thoughts.append(line.replace('Thought:', '').strip())
return thoughts[:self.branching_factor]
def evaluate_state(self, state: str, problem: str) -> float:
"""Evaluate how promising a reasoning state is."""
prompt = f"""Problem: {problem}
Reasoning so far:
{state}
Evaluate this reasoning on a scale of 1-10:
- Is it making progress toward the solution?
- Is the logic sound?
- Are there errors or dead ends?
Score (1-10):"""
response = self.evaluator.generate(prompt)
# Extract score
try:
score = float(re.search(r'(\d+)', response).group(1))
return min(10, max(1, score)) / 10
except:
return 0.5
def solve_bfs(self, problem: str) -> str:
"""Breadth-first search through thought tree."""
root = ThoughtNode(
state="",
parent=None,
children=[],
value=1.0,
depth=0
)
current_level = [root]
for depth in range(self.max_depth):
next_level = []
for node in current_level:
# Generate thoughts
thoughts = self.generate_thoughts(node.state, problem)
for thought in thoughts:
new_state = f"{node.state}\n{thought}" if node.state else thought
value = self.evaluate_state(new_state, problem)
child = ThoughtNode(
state=new_state,
parent=node,
children=[],
value=value,
depth=depth + 1
)
node.children.append(child)
next_level.append(child)
# Beam search: keep top-k
next_level.sort(key=lambda x: x.value, reverse=True)
current_level = next_level[:self.beam_width]
# Check for solution
for node in current_level:
if self.is_solution(node.state, problem):
return self.extract_answer(node.state, problem)
# Return best found
best_node = max(current_level, key=lambda x: x.value)
return self.extract_answer(best_node.state, problem)
def solve_dfs(self, problem: str, max_iterations: int = 100) -> str:
"""Depth-first search with backtracking."""
stack = [ThoughtNode(state="", parent=None, children=[], value=1.0, depth=0)]
best_solution = None
best_value = 0
iterations = 0
while stack and iterations < max_iterations:
iterations += 1
node = stack.pop()
if node.depth >= self.max_depth:
if node.value > best_value:
best_value = node.value
best_solution = node.state
continue
thoughts = self.generate_thoughts(node.state, problem)
for thought in thoughts:
new_state = f"{node.state}\n{thought}" if node.state else thought
value = self.evaluate_state(new_state, problem)
# Pruning: skip unpromising branches
if value < 0.3:
continue
child = ThoughtNode(
state=new_state,
parent=node,
children=[],
value=value,
depth=node.depth + 1
)
stack.append(child)
return self.extract_answer(best_solution or "", problem)
def is_solution(self, state: str, problem: str) -> bool:
"""Check if state contains a complete solution."""
prompt = f"""Problem: {problem}
Reasoning:
{state}
Does this reasoning reach a final answer? (yes/no)"""
response = self.model.generate(prompt).strip().lower()
return 'yes' in response
def extract_answer(self, state: str, problem: str) -> str:
"""Extract final answer from reasoning state."""
prompt = f"""Problem: {problem}
Reasoning:
{state}
Based on this reasoning, what is the final answer?"""
return self.model.generate(prompt)
BFS vs DFS for Tree of Thoughts:
- BFS (solve_bfs): Explores all nodes at each depth before going deeper. Uses beam search to keep only the top-k most promising branches. Best when you want to thoroughly explore early reasoning steps before committing.
- DFS (solve_dfs): Goes deep first, then backtracks. More memory-efficient. Good when solutions require many steps but the branching factor is low.
Key design decisions:
- Evaluation as a separate model call: The evaluate_state method asks the LLM to score reasoning quality. This is expensive but necessary—we can't assume reasoning is good just because it's coherent.
- Pruning threshold (0.3 in DFS): Low-scoring branches are abandoned. This prevents wasting compute on dead ends but risks missing unconventional solutions.
- Solution detection: The is_solution method asks the LLM whether the reasoning is complete. This avoids hardcoding answer formats.
- Beam width controls the breadth-accuracy trade-off: a higher beam width explores more but costs more compute (see the cost sketch below).
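To get a feel for that compute cost, here is a rough call-count estimate for the BFS variant (an illustrative calculation derived from the solve_bfs code above; it ignores the per-level is_solution checks and the final extract_answer call):
def tot_bfs_call_estimate(max_depth=5, branching_factor=3, beam_width=5):
    """Rough number of LLM calls solve_bfs makes: one generate_thoughts call per
    surviving node, plus one evaluate_state call per generated child."""
    calls = 0
    level_size = 1  # root
    for _ in range(max_depth):
        calls += level_size                      # generate_thoughts per node
        children = level_size * branching_factor
        calls += children                        # evaluate_state per child
        level_size = min(children, beam_width)   # beam pruning caps the next level
    return calls

print(tot_bfs_call_estimate())  # 76 calls with the class defaults above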
ToT for Different Problem Types
Game of 24 (combine 4 numbers to get 24):
def tot_game_of_24(numbers: list[int], tot: TreeOfThoughts) -> str:
problem = f"Use the numbers {numbers} with +, -, *, / to get 24."
# Custom thought generator for this domain
def generate_game_thoughts(state, remaining):
thoughts = []
for i, a in enumerate(remaining):
for j, b in enumerate(remaining):
if i != j:
for op in ['+', '-', '*', '/']:
if op == '/' and b == 0:
continue
result = eval(f"{a}{op}{b}")
new_remaining = [n for k, n in enumerate(remaining)
if k != i and k != j] + [result]
thoughts.append(f"{a} {op} {b} = {result}")
return thoughts
return tot.solve_bfs(problem)
Creative Writing:
def tot_creative_writing(topic: str, tot: TreeOfThoughts) -> str:
problem = f"Write a coherent passage about: {topic}"
# Thoughts are paragraph plans
# Evaluation considers coherence, creativity, relevance
return tot.solve_bfs(problem)
ToT vs CoT Comparison
| Aspect | Chain-of-Thought | Tree of Thoughts |
|---|---|---|
| Structure | Linear | Branching tree |
| Backtracking | No | Yes |
| Exploration | Single path | Multiple paths |
| Compute cost | O(n) | O(b^d) where b=branching, d=depth |
| Best for | Simple reasoning | Complex multi-step problems |
Forest of Thoughts (FoT)
Ensemble Tree Search
Forest of Thoughts runs multiple ToT searches in parallel and aggregates results:
from concurrent.futures import ThreadPoolExecutor, as_completed
class ForestOfThoughts:
def __init__(
self,
model,
evaluator,
n_trees: int = 5,
tree_config: dict = None
):
self.model = model
self.evaluator = evaluator
self.n_trees = n_trees
self.tree_config = tree_config or {
'max_depth': 5,
'branching_factor': 3,
'beam_width': 3
}
def solve(self, problem: str) -> str:
"""Run multiple ToT searches and aggregate."""
trees = [
TreeOfThoughts(
self.model,
self.evaluator,
**self.tree_config
)
for _ in range(self.n_trees)
]
# Run in parallel
results = []
with ThreadPoolExecutor(max_workers=self.n_trees) as executor:
futures = {
executor.submit(tree.solve_bfs, problem): i
for i, tree in enumerate(trees)
}
for future in as_completed(futures):
try:
result = future.result()
results.append(result)
except Exception as e:
print(f"Tree search failed: {e}")
# Aggregate results
return self.aggregate_results(results, problem)
def aggregate_results(self, results: list[str], problem: str) -> str:
"""Aggregate multiple tree search results."""
if not results:
return "Unable to solve."
if len(results) == 1:
return results[0]
# For answers that can be compared directly (numbers, etc)
answer_counts = collections.Counter(
self.normalize_answer(r) for r in results
)
# If majority agreement, return that
most_common, count = answer_counts.most_common(1)[0]
if count > len(results) / 2:
return most_common
# Otherwise, use model to synthesize
prompt = f"""Problem: {problem}
Multiple reasoning approaches produced these answers:
{chr(10).join(f"- {r}" for r in results)}
Analyze these answers and determine the most likely correct answer.
If they disagree, explain which reasoning is most sound."""
return self.model.generate(prompt)
def normalize_answer(self, answer: str) -> str:
"""Normalize answer for comparison."""
# Remove whitespace, convert to lowercase
normalized = ' '.join(answer.lower().split())
# Try to extract numeric answer
numbers = re.findall(r'-?\d+\.?\d*', normalized)
if numbers:
return numbers[-1] # Usually final answer is last
return normalized
Diverse Tree Initialization
For better coverage, initialize trees differently:
class DiverseForestOfThoughts(ForestOfThoughts):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.initial_prompts = [
"Let's solve this systematically.",
"Let's try a creative approach.",
"Let's start with what we know for certain.",
"Let's work backwards from the goal.",
"Let's break this into smaller parts."
]
def create_diverse_trees(self):
"""Create trees with different initial strategies."""
trees = []
for i, init_prompt in enumerate(self.initial_prompts[:self.n_trees]):
tree = TreeOfThoughts(
self.model,
self.evaluator,
**self.tree_config
)
tree.initial_prompt = init_prompt
trees.append(tree)
return trees
Monte Carlo Tree Search (MCTS)
MCTS for Language Models
MCTS balances exploration and exploitation using UCB (Upper Confidence Bound):
import math
import random
@dataclass
class MCTSNode:
"""Node for Monte Carlo Tree Search."""
state: str
parent: Optional['MCTSNode']
children: dict[str, 'MCTSNode'] # action -> child
visits: int = 0
value: float = 0.0
untried_actions: list[str] = None
def ucb_score(self, exploration_weight: float = 1.414) -> float:
"""Upper Confidence Bound score."""
if self.visits == 0:
return float('inf')
exploitation = self.value / self.visits
exploration = exploration_weight * math.sqrt(
math.log(self.parent.visits) / self.visits
)
return exploitation + exploration
def best_child(self, exploration_weight: float = 1.414) -> 'MCTSNode':
"""Select child with highest UCB score."""
return max(
self.children.values(),
key=lambda c: c.ucb_score(exploration_weight)
)
def is_fully_expanded(self) -> bool:
return len(self.untried_actions) == 0
def is_terminal(self) -> bool:
return self.is_fully_expanded() and len(self.children) == 0
class MCTSReasoner:
def __init__(
self,
model,
reward_model,
num_simulations: int = 100,
max_depth: int = 10,
exploration_weight: float = 1.414
):
self.model = model
self.reward_model = reward_model
self.num_simulations = num_simulations
self.max_depth = max_depth
self.exploration_weight = exploration_weight
def solve(self, problem: str) -> str:
"""Solve using MCTS."""
# Initialize root
root = MCTSNode(
state="",
parent=None,
children={},
untried_actions=self.get_actions(problem, "")
)
for _ in range(self.num_simulations):
# Selection
node = self.select(root)
# Expansion
if not node.is_terminal() and node.visits > 0:
node = self.expand(node, problem)
# Simulation
reward = self.simulate(node, problem)
# Backpropagation
self.backpropagate(node, reward)
# Return best path
return self.extract_solution(root, problem)
def select(self, node: MCTSNode) -> MCTSNode:
"""Select most promising node using UCB."""
while node.children and node.is_fully_expanded():
node = node.best_child(self.exploration_weight)
return node
def expand(self, node: MCTSNode, problem: str) -> MCTSNode:
"""Expand node with one untried action."""
if not node.untried_actions:
return node
action = node.untried_actions.pop()
new_state = f"{node.state}\n{action}" if node.state else action
child = MCTSNode(
state=new_state,
parent=node,
children={},
untried_actions=self.get_actions(problem, new_state)
)
node.children[action] = child
return child
def simulate(self, node: MCTSNode, problem: str) -> float:
"""Simulate from node to terminal state."""
state = node.state
depth = state.count('\n') if state else 0
# Random rollout
while depth < self.max_depth:
actions = self.get_actions(problem, state)
if not actions:
break
action = random.choice(actions)
state = f"{state}\n{action}" if state else action
depth += 1
# Evaluate final state
return self.evaluate(state, problem)
def backpropagate(self, node: MCTSNode, reward: float):
"""Backpropagate reward up the tree."""
while node:
node.visits += 1
node.value += reward
node = node.parent
def get_actions(self, problem: str, state: str) -> list[str]:
"""Get possible next reasoning steps."""
prompt = f"""Problem: {problem}
Current reasoning:
{state if state else "(Starting fresh)"}
Generate 5 possible next reasoning steps. Each should be a single logical step.
Format each on a new line starting with "-"."""
response = self.model.generate(prompt)
actions = [
line.lstrip('- ').strip()
for line in response.split('\n')
if line.strip().startswith('-')
]
return actions[:5]
def evaluate(self, state: str, problem: str) -> float:
"""Evaluate state using reward model."""
prompt = f"""Problem: {problem}
Reasoning:
{state}
Rate this reasoning from 0 to 1:
- 1.0: Correct and complete solution
- 0.7-0.9: On the right track, mostly correct
- 0.4-0.6: Partially correct, some errors
- 0.1-0.3: Mostly wrong
- 0.0: Completely wrong or irrelevant"""
response = self.reward_model.generate(prompt)
try:
score = float(re.search(r'(\d*\.?\d+)', response).group(1))
return min(1.0, max(0.0, score))
except:
return 0.5
The four phases of MCTS:
- Selection: Start at the root and use UCB scores to navigate down to a promising node. UCB balances exploitation (high-value nodes) and exploration (less-visited nodes).
- Expansion: When you reach a node with unexplored actions, pick one and create a child node. This grows the tree incrementally.
- Simulation (rollout): From the new node, play out randomly to a terminal state. This estimates the node's value without fully exploring its subtree.
- Backpropagation: Update visit counts and values for all nodes on the path from the root to the expanded node.
The UCB formula (the same one implemented in ucb_score above):
UCB(node) = value / visits + c * sqrt(ln(parent_visits) / visits)
- First term: exploitation (the node's average reward)
- Second term: exploration bonus (decreases as the node is visited more)
- c = 1.414 ≈ √2 is a common exploration weight
Why follow most-visited path for final answer? Visit counts reflect confidence better than raw values. A node visited 100 times with average value 0.7 is more reliable than one visited 5 times with value 0.9.
def extract_solution(self, root: MCTSNode, problem: str) -> str:
"""Extract best solution path."""
node = root
path = []
while node.children:
# Follow most visited path
best_child = max(
node.children.values(),
key=lambda c: c.visits
)
path.append(best_child.state.split('\n')[-1])
node = best_child
solution = '\n'.join(path)
# Generate final answer
prompt = f"""Problem: {problem}
Reasoning:
{solution}
Based on this reasoning, what is the final answer?"""
return self.model.generate(prompt)
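To make the selection phase concrete, here is the UCB calculation for two hypothetical children of a node with 20 visits (made-up numbers, using the same formula as MCTSNode.ucb_score):
import math

def ucb(total_value, visits, parent_visits, c=1.414):
    """Same formula as MCTSNode.ucb_score above."""
    return total_value / visits + c * math.sqrt(math.log(parent_visits) / visits)

# Child A: well explored, average reward 0.7
print(round(ucb(total_value=7.0, visits=10, parent_visits=20), 3))  # ~1.474
# Child B: barely explored, average reward 0.9
print(round(ucb(total_value=1.8, visits=2, parent_visits=20), 3))   # ~2.631
# B is selected: its exploration bonus dominates because it has so few visits.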
MCTS Variants for Reasoning
AlphaGo-style with neural evaluation:
class NeuralMCTS(MCTSReasoner):
def __init__(self, policy_model, value_model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.policy_model = policy_model # P(action|state)
self.value_model = value_model # V(state)
def get_actions(self, problem: str, state: str) -> list[str]:
"""Use policy model to get likely actions."""
# Policy model outputs action probabilities
actions, probs = self.policy_model.predict(problem, state)
# Sort by probability
sorted_actions = [a for _, a in sorted(
zip(probs, actions), reverse=True
)]
return sorted_actions
def evaluate(self, state: str, problem: str) -> float:
"""Use value model for evaluation."""
return self.value_model.predict(problem, state)
Best-of-N Sampling
Simple but Effective
Best-of-N generates N complete solutions and selects the best using a verifier:
class BestOfN:
def __init__(
self,
model,
verifier,
n: int = 16,
temperature: float = 0.8
):
self.model = model
self.verifier = verifier
self.n = n
self.temperature = temperature
def solve(self, problem: str) -> tuple[str, float]:
"""Generate N solutions, return best."""
solutions = []
prompt = f"{problem}\nLet's solve this step by step:"
for _ in range(self.n):
solution = self.model.generate(
prompt,
temperature=self.temperature
)
score = self.verifier.score(problem, solution)
solutions.append((solution, score))
# Return best
best_solution, best_score = max(solutions, key=lambda x: x[1])
return best_solution, best_score
def solve_with_early_stopping(
self,
problem: str,
threshold: float = 0.95
) -> tuple[str, float]:
"""Stop early if we find a high-confidence solution."""
best_solution = None
best_score = 0
prompt = f"{problem}\nLet's solve this step by step:"
for i in range(self.n):
solution = self.model.generate(
prompt,
temperature=self.temperature
)
score = self.verifier.score(problem, solution)
if score > best_score:
best_solution = solution
best_score = score
# Early stopping
if score >= threshold:
print(f"Early stopping at iteration {i+1}")
return best_solution, best_score
return best_solution, best_score
class WeightedBestOfN(BestOfN):
def solve_weighted(self, problem: str) -> str:
"""Weighted sampling based on verifier scores."""
solutions = []
prompt = f"{problem}\nLet's solve this step by step:"
# First pass: generate all
for _ in range(self.n):
solution = self.model.generate(
prompt,
temperature=self.temperature
)
solutions.append(solution)
# Score all
scores = [
self.verifier.score(problem, sol)
for sol in solutions
]
# Weighted selection (softmax)
probs = self.softmax(scores)
selected_idx = random.choices(range(len(solutions)), weights=probs)[0]
return solutions[selected_idx]
def softmax(self, scores: list[float], temperature: float = 1.0) -> list[float]:
exp_scores = [math.exp(s / temperature) for s in scores]
total = sum(exp_scores)
return [e / total for e in exp_scores]
Compute-Performance Trade-off
| N | Relative Compute | Typical Improvement |
|---|---|---|
| 1 | 1x | Baseline |
| 4 | 4x | +5-10% accuracy |
| 16 | 16x | +10-20% accuracy |
| 64 | 64x | +15-25% accuracy |
| 256 | 256x | +18-30% accuracy |
Diminishing returns: doubling N gives progressively smaller gains.
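One intuition for the shape of this curve is a back-of-the-envelope model that assumes independent samples and a perfect verifier (real verifiers and correlated samples do worse, which is why the table above shows smaller gains):
def best_of_n_ceiling(p_single: float, n: int) -> float:
    """Chance that at least one of n independent samples is correct."""
    return 1 - (1 - p_single) ** n

for n in (1, 4, 16, 64, 256):
    print(n, round(best_of_n_ceiling(0.3, n), 3))
# Each doubling of N closes a smaller fraction of the remaining gap to 1.0.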
Beam Search
Token-Level Beam Search
Traditional beam search at the token level:
import torch
import torch.nn.functional as F
class TokenBeamSearch:
def __init__(
self,
model,
beam_width: int = 5,
max_length: int = 512
):
self.model = model
self.beam_width = beam_width
self.max_length = max_length
def search(self, prompt: str) -> str:
"""Standard beam search over tokens."""
# Initialize beams with prompt
beams = [(prompt, 0.0)] # (sequence, log_prob)
for _ in range(self.max_length):
all_candidates = []
for seq, score in beams:
if seq.endswith('<eos>'):
all_candidates.append((seq, score))
continue
# Get next token probabilities
logits = self.model.get_next_token_logits(seq)
log_probs = F.log_softmax(logits, dim=-1)
# Get top-k tokens
top_log_probs, top_indices = torch.topk(
log_probs, self.beam_width
)
for log_prob, token_idx in zip(top_log_probs, top_indices):
token = self.model.decode(token_idx)
new_seq = seq + token
new_score = score + log_prob.item()
all_candidates.append((new_seq, new_score))
# Keep top beams
beams = sorted(
all_candidates,
key=lambda x: x[1],
reverse=True
)[:self.beam_width]
# Check if all beams finished
if all(seq.endswith('<eos>') for seq, _ in beams):
break
# Return best
return beams[0][0]
Step-Level Beam Search
Beam search over reasoning steps (more practical for LLMs):
class StepBeamSearch:
def __init__(
self,
model,
verifier,
beam_width: int = 5,
max_steps: int = 10
):
self.model = model
self.verifier = verifier
self.beam_width = beam_width
self.max_steps = max_steps
def search(self, problem: str) -> str:
"""Beam search over reasoning steps."""
# Initialize beams
beams = [("", 0.0)] # (reasoning_so_far, cumulative_score)
for step in range(self.max_steps):
all_candidates = []
for reasoning, score in beams:
# Generate next step candidates
next_steps = self.generate_next_steps(problem, reasoning)
for next_step in next_steps:
new_reasoning = f"{reasoning}\n{next_step}" if reasoning else next_step
# Score with verifier
step_score = self.verifier.score_step(
problem, new_reasoning, step
)
new_score = score + step_score
all_candidates.append((new_reasoning, new_score))
# Keep top beams
beams = sorted(
all_candidates,
key=lambda x: x[1],
reverse=True
)[:self.beam_width]
# Check for completion
if self.is_complete(problem, beams[0][0]):
break
return self.extract_answer(problem, beams[0][0])
def generate_next_steps(
self,
problem: str,
reasoning: str,
n_candidates: int = 10
) -> list[str]:
"""Generate candidate next reasoning steps."""
prompt = f"""Problem: {problem}
Reasoning so far:
{reasoning if reasoning else "(Starting)"}
Generate the next reasoning step. Output only the single next step."""
steps = []
for _ in range(n_candidates):
step = self.model.generate(prompt, temperature=0.8)
steps.append(step.strip())
# Deduplicate
return list(set(steps))
def is_complete(self, problem: str, reasoning: str) -> bool:
"""Check if reasoning is complete."""
prompt = f"""Problem: {problem}
Reasoning: {reasoning}
Is this reasoning complete with a final answer? (yes/no)"""
response = self.model.generate(prompt)
return 'yes' in response.lower()
def extract_answer(self, problem: str, reasoning: str) -> str:
"""Extract final answer."""
prompt = f"""Problem: {problem}
Reasoning: {reasoning}
What is the final answer?"""
return self.model.generate(prompt)
Process Reward Models (PRMs)
What Are PRMs?
Process Reward Models evaluate each reasoning step, not just the final answer. This enables step-by-step guidance during generation.
Outcome Reward Model (ORM): Score(final_answer) → 0 or 1
Process Reward Model (PRM): Score(step_1), Score(step_2), ..., Score(step_n)
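The practical difference, sketched with made-up scores: an ORM produces one number for a finished solution, while a PRM can localize where the solution went wrong.
# Hypothetical scores for a 3-step solution whose second step contains the error.
solution_steps = [
    "17 * 24 = 408",          # correct
    "408 + 156 = 554",        # wrong: should be 564
    "The answer is 554.",     # follows from the earlier error
]

orm_score = 0.0                     # ORM: a single outcome label (final answer is wrong)
prm_scores = [0.92, 0.08, 0.10]     # PRM (made-up values): step 2 is flagged as the problem

# A search procedure can use prm_scores to prune or resample right after step 2,
# whereas the ORM signal only arrives once the full solution has been generated.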
PRM Architecture
import torch
import torch.nn as nn
from transformers import AutoModel
class ProcessRewardModel(nn.Module):
"""Process Reward Model for step-by-step evaluation."""
def __init__(
self,
base_model: str = "meta-llama/Llama-3.1-8B",
hidden_size: int = 4096
):
super().__init__()
self.backbone = AutoModel.from_pretrained(base_model)
self.reward_head = nn.Linear(hidden_size, 1)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
step_positions: list[int]
) -> torch.Tensor:
"""
Compute reward for each reasoning step.
step_positions: indices marking end of each step
"""
outputs = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask
)
hidden_states = outputs.last_hidden_state
# Get hidden state at each step position
step_hidden = hidden_states[:, step_positions, :]
# Compute rewards
rewards = self.reward_head(step_hidden).squeeze(-1)
return torch.sigmoid(rewards)
def score_solution(
self,
problem: str,
solution_steps: list[str]
) -> list[float]:
"""Score each step in a solution."""
# Build input with step markers
text = f"Problem: {problem}\n"
step_positions = []
for i, step in enumerate(solution_steps):
text += f"Step {i+1}: {step}\n"
step_positions.append(len(self.tokenizer.encode(text)) - 1)
inputs = self.tokenizer(
text,
return_tensors="pt",
padding=True
)
with torch.no_grad():
rewards = self.forward(
inputs.input_ids,
inputs.attention_mask,
step_positions
)
return rewards[0].tolist()
def train_prm(
model: ProcessRewardModel,
train_data: list[dict],
epochs: int = 3
):
"""
Train PRM on step-labeled data.
train_data format:
[{
"problem": str,
"steps": [str, ...],
"labels": [0 or 1, ...] # 1 = correct step, 0 = error
}]
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.BCELoss()
for epoch in range(epochs):
total_loss = 0
for example in train_data:
optimizer.zero_grad()
# Get step positions
text, step_positions = model.prepare_input(
example["problem"],
example["steps"]
)
inputs = model.tokenizer(text, return_tensors="pt")
# Forward
predicted = model(
inputs.input_ids,
inputs.attention_mask,
step_positions
)
labels = torch.tensor(example["labels"], dtype=torch.float32)
loss = criterion(predicted, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_data):.4f}")
How PRMs work:
- Architecture: A language model backbone (like Llama) processes the problem plus the partial solution. A linear head on top predicts a reward score for each step position.
- Step positions: The model needs to know where each step ends. The step_positions list marks token indices at step boundaries.
- Output interpretation: Sigmoid activation produces scores in [0, 1]. Higher = more likely to lead to a correct answer.
Training data format:
{
"problem": "What is 2 + 2?",
"steps": ["First, I identify this as addition", "2 + 2 = 4", "The answer is 4"],
"labels": [1, 1, 1] # All steps are correct
}
Labels are typically generated by:
- Human annotation (expensive but high quality)
- Completion-based labeling (Math-Shepherd approach)
- Model self-evaluation (cheaper but noisier)
PRM-Guided Generation
Use PRM to guide generation at each step:
class PRMGuidedGenerator:
def __init__(
self,
generator_model,
prm: ProcessRewardModel,
num_candidates: int = 8,
threshold: float = 0.5
):
self.generator = generator_model
self.prm = prm
self.num_candidates = num_candidates
self.threshold = threshold
def generate_solution(self, problem: str, max_steps: int = 15) -> str:
"""Generate solution with PRM guidance."""
steps = []
for i in range(max_steps):
# Generate candidates for next step
candidates = self.generate_step_candidates(problem, steps)
if not candidates:
break
# Score each candidate
scored_candidates = []
for candidate in candidates:
test_steps = steps + [candidate]
scores = self.prm.score_solution(problem, test_steps)
step_score = scores[-1] # Score of new step
scored_candidates.append((candidate, step_score))
# Select best candidate above threshold
valid_candidates = [
(c, s) for c, s in scored_candidates if s >= self.threshold
]
if not valid_candidates:
# All candidates below threshold - might need backtracking
# For now, take best anyway
best_candidate = max(scored_candidates, key=lambda x: x[1])[0]
else:
best_candidate = max(valid_candidates, key=lambda x: x[1])[0]
steps.append(best_candidate)
# Check if complete
if self.is_final_answer(best_candidate):
break
return '\n'.join(steps)
def generate_step_candidates(
self,
problem: str,
previous_steps: list[str]
) -> list[str]:
"""Generate candidate next steps."""
prompt = f"""Problem: {problem}
Previous steps:
{chr(10).join(f"{i+1}. {s}" for i, s in enumerate(previous_steps)) if previous_steps else "(Starting)"}
Generate the next reasoning step:"""
candidates = []
for _ in range(self.num_candidates):
response = self.generator.generate(prompt, temperature=0.8)
candidates.append(response.strip())
return list(set(candidates)) # Deduplicate
def is_final_answer(self, step: str) -> bool:
"""Check if step contains final answer."""
indicators = [
"the answer is",
"therefore,",
"thus,",
"in conclusion",
"final answer:",
"= "
]
step_lower = step.lower()
return any(ind in step_lower for ind in indicators)
PRM-guided generation flow:
- Generate candidates: For each step, generate N candidate next steps (with temperature for diversity).
- Score candidates: Evaluate each candidate by appending it to the current solution and running the PRM.
- Threshold filtering: Only consider candidates scoring above the threshold (0.5 by default). This prevents committing to likely-wrong steps.
- Best selection: Among valid candidates, pick the highest-scoring one.
- Fallback: If all candidates score below the threshold, either backtrack or accept the best anyway (the current implementation takes the best).
Why threshold filtering? Without it, the generator might take a step with 0.2 score just because it's the "best" among bad options. The threshold forces the model to only proceed with reasonably confident steps.
Math-Shepherd PRM
The Math-Shepherd approach uses automated labeling:
def generate_prm_training_data(
problems: list[str],
solutions: list[list[str]],
ground_truth: list[str]
) -> list[dict]:
"""
Generate PRM training data using completion-based labeling.
For each step, complete to final answer multiple times.
If completions reach correct answer >50% of time, step is correct.
"""
training_data = []
for problem, steps, answer in zip(problems, solutions, ground_truth):
step_labels = []
for i, step in enumerate(steps):
partial_solution = steps[:i+1]
# Complete from this step multiple times
completions = complete_from_step(problem, partial_solution, n=10)
# Check how many reach correct answer
correct_count = sum(
1 for c in completions
if extract_answer(c) == answer
)
# Label: 1 if majority correct, 0 otherwise
label = 1 if correct_count >= 5 else 0
step_labels.append(label)
training_data.append({
"problem": problem,
"steps": steps,
"labels": step_labels
})
return training_data
Diverse Verifier Tree Search (DVTS)
DVTS Algorithm
DVTS from HuggingFace's search-and-learn framework combines tree search with diverse exploration:
import numpy as np
class DVTS:
"""
Diverse Verifier Tree Search.
Key innovations:
1. Diversity-promoting selection
2. Process reward model guidance
3. Adaptive beam allocation
"""
def __init__(
self,
model,
prm: ProcessRewardModel,
n_beams: int = 4,
n_candidates: int = 8,
diversity_penalty: float = 0.1
):
self.model = model
self.prm = prm
self.n_beams = n_beams
self.n_candidates = n_candidates
self.diversity_penalty = diversity_penalty
def search(self, problem: str, max_steps: int = 15) -> str:
"""Run DVTS search."""
# Initialize beams
beams = [
{
"steps": [],
"score": 0.0,
"embedding": None
}
for _ in range(self.n_beams)
]
for step_idx in range(max_steps):
all_candidates = []
for beam in beams:
# Generate candidates
candidates = self.generate_candidates(
problem, beam["steps"]
)
for candidate in candidates:
new_steps = beam["steps"] + [candidate]
# Score with PRM
prm_scores = self.prm.score_solution(problem, new_steps)
step_score = prm_scores[-1]
# Compute embedding for diversity
embedding = self.get_embedding(new_steps)
all_candidates.append({
"steps": new_steps,
"score": step_score,
"embedding": embedding
})
# Select diverse top-k
beams = self.diverse_select(all_candidates, self.n_beams)
# Check for completion
if all(self.is_complete(b["steps"]) for b in beams):
break
# Return best beam
best_beam = max(beams, key=lambda b: b["score"])
return self.format_solution(problem, best_beam["steps"])
def diverse_select(
self,
candidates: list[dict],
k: int
) -> list[dict]:
"""Select top-k candidates while promoting diversity."""
if len(candidates) <= k:
return candidates
# Sort by score
sorted_candidates = sorted(
candidates,
key=lambda c: c["score"],
reverse=True
)
selected = [sorted_candidates[0]]
for _ in range(k - 1):
best_candidate = None
best_score = float('-inf')
for candidate in sorted_candidates:
if candidate in selected:
continue
# Compute diversity-adjusted score
min_similarity = min(
self.cosine_similarity(
candidate["embedding"],
s["embedding"]
)
for s in selected
)
diversity_bonus = (1 - min_similarity) * self.diversity_penalty
adjusted_score = candidate["score"] + diversity_bonus
if adjusted_score > best_score:
best_score = adjusted_score
best_candidate = candidate
if best_candidate:
selected.append(best_candidate)
return selected
def generate_candidates(
self,
problem: str,
previous_steps: list[str]
) -> list[str]:
"""Generate diverse candidate next steps."""
prompt = f"""Problem: {problem}
Solution so far:
{self.format_steps(previous_steps)}
Generate the next step:"""
candidates = []
# Use different temperatures for diversity
temperatures = [0.5, 0.7, 0.9, 1.1]
for temp in temperatures:
for _ in range(self.n_candidates // len(temperatures)):
response = self.model.generate(prompt, temperature=temp)
candidates.append(response.strip())
return list(set(candidates))
def get_embedding(self, steps: list[str]) -> np.ndarray:
"""Get embedding of solution state."""
text = '\n'.join(steps)
return self.model.encode(text)
def cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between embeddings."""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def format_steps(self, steps: list[str]) -> str:
if not steps:
return "(Starting)"
return '\n'.join(f"Step {i+1}: {s}" for i, s in enumerate(steps))
def is_complete(self, steps: list[str]) -> bool:
if not steps:
return False
last_step = steps[-1].lower()
return any(x in last_step for x in ['answer is', 'therefore', '= '])
def format_solution(self, problem: str, steps: list[str]) -> str:
return f"Problem: {problem}\n\nSolution:\n{self.format_steps(steps)}"
DVTS key innovations:
- Diversity penalty: Standard beam search tends to keep similar solutions. DVTS adds a bonus for solutions that differ from already-selected ones, computed via embedding similarity.
- Embedding-based similarity: Each partial solution is embedded, and cosine similarity measures how alike two solutions are. This captures semantic diversity, not just surface-level text differences.
- Multi-temperature sampling: Candidates are generated at different temperatures (0.5, 0.7, 0.9, 1.1). Low temperatures produce safe, likely steps; high temperatures explore creative alternatives.
The diversity selection algorithm:
- Start by selecting the highest-scoring candidate.
- For each remaining slot, compute for every unselected candidate: adjusted_score = raw_score + (1 - min_similarity_to_selected) * penalty, then select the candidate with the highest adjusted score.
- This yields a set that is both high-quality (score) and diverse (dissimilar).
Why diversity matters: Multiple similar high-scoring paths might all be wrong in the same way. Diversity ensures at least one path explores a different approach, increasing the chance of finding the correct answer.
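A worked example of the adjusted score with made-up numbers, mirroring the diverse_select code above (diversity_penalty = 0.1):
# Two unselected candidates competing for the second beam slot.
penalty = 0.1
# Candidate X: higher raw score but nearly identical to the already-selected beam.
adjusted_x = 0.80 + (1 - 0.95) * penalty   # 0.805
# Candidate Y: slightly lower score but much more different from it.
adjusted_y = 0.76 + (1 - 0.40) * penalty   # 0.820
print(adjusted_x, adjusted_y)              # Y wins the slot despite the lower raw score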
HuggingFace Search-and-Learn Framework
Installation and Setup
pip install sal
# or
git clone https://github.com/huggingface/search-and-learn
cd search-and-learn
pip install -e .
Basic Usage
from sal.config import Config
from sal.search import beam_search, best_of_n, dvts
from sal.models.reward_models import load_prm
# Configuration
config = Config()
config.n = 32 # Number of candidates/beams
config.temperature = 0.8
config.search_batch_size = 8
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load models
llm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
prm = load_prm("RLHFlow/Llama3.1-8B-PRM-Deepseek-Data")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# Prepare input
problem = "What is 17 * 24 + 156?"
input_batch = tokenizer(
[f"Solve: {problem}"],
return_tensors="pt",
padding=True
)
Best-of-N Search
# Best-of-N: Generate N solutions, pick best
result = best_of_n(
x=input_batch,
config=config,
llm=llm,
prm=prm
)
print(f"Best solution: {result.best_solution}")
print(f"Score: {result.best_score}")
print(f"Time: {result.elapsed_time:.2f}s")
Beam Search
# Beam search with PRM guidance
config.beam_width = 4
config.max_steps = 20
result = beam_search(
x=input_batch,
config=config,
llm=llm,
prm=prm
)
print(f"Solution: {result.solution}")
print(f"Step scores: {result.step_scores}")
DVTS
# Diverse Verifier Tree Search
config.n_beams = 4
config.diversity_penalty = 0.1
result = dvts(
x=input_batch,
config=config,
llm=llm,
prm=prm
)
print(f"Solution: {result.solution}")
print(f"Explored paths: {result.n_explored}")
Performance Comparison
From HuggingFace benchmarks on MATH500:
| Method | Accuracy | Time (s) | Notes |
|---|---|---|---|
| Greedy | 65.2% | 1.2 | Baseline |
| Best-of-N (N=32) | 73.4% | 3.5 | Simple, effective |
| Beam Search (k=4) | 75.8% | 10.1 | More compute |
| DVTS | 78.2% | 8.5 | Best accuracy/time trade-off |
Custom Search Configuration
from sal.config import SearchConfig
# Fine-tuned configuration
search_config = SearchConfig(
# Generation parameters
temperature=0.8,
top_p=0.95,
max_new_tokens=1024,
# Search parameters
n_candidates=32,
beam_width=4,
max_steps=20,
# PRM parameters
prm_batch_size=16,
min_step_score=0.3, # Prune steps below this
# Diversity
diversity_penalty=0.1,
n_diverse_beams=4,
# Early stopping
early_stop_threshold=0.95,
patience=3
)
result = dvts(input_batch, search_config, llm, prm)
Self-Reflection and Iterative Refinement
Self-Critique Loop
class SelfReflectionSolver:
def __init__(
self,
model,
max_iterations: int = 3,
improvement_threshold: float = 0.1
):
self.model = model
self.max_iterations = max_iterations
self.improvement_threshold = improvement_threshold
def solve(self, problem: str) -> str:
"""Solve with self-reflection loop."""
# Initial solution
solution = self.generate_solution(problem)
for i in range(self.max_iterations):
# Critique
critique = self.generate_critique(problem, solution)
# Check if improvements needed
if "no issues" in critique.lower() or "correct" in critique.lower():
break
# Refine
refined = self.refine_solution(problem, solution, critique)
# Check improvement
if self.is_better(problem, refined, solution):
solution = refined
else:
break
return solution
def generate_solution(self, problem: str) -> str:
prompt = f"""Problem: {problem}
Solve this step by step:"""
return self.model.generate(prompt)
def generate_critique(self, problem: str, solution: str) -> str:
prompt = f"""Problem: {problem}
Proposed solution:
{solution}
Carefully check this solution for errors:
1. Are all calculations correct?
2. Is the logic sound?
3. Does it answer the question asked?
4. Are there any missing steps?
Critique:"""
return self.model.generate(prompt)
def refine_solution(
self,
problem: str,
solution: str,
critique: str
) -> str:
prompt = f"""Problem: {problem}
Previous solution:
{solution}
Issues identified:
{critique}
Provide a corrected solution that addresses these issues:"""
return self.model.generate(prompt)
def is_better(
self,
problem: str,
new_solution: str,
old_solution: str
) -> bool:
"""Compare solutions using the model."""
prompt = f"""Problem: {problem}
Solution A:
{old_solution}
Solution B:
{new_solution}
Which solution is better? Answer with just "A" or "B":"""
response = self.model.generate(prompt).strip()
return "B" in response.upper()
Multi-Agent Debate
class DebateSolver:
def __init__(
self,
model,
n_debaters: int = 3,
n_rounds: int = 2
):
self.model = model
self.n_debaters = n_debaters
self.n_rounds = n_rounds
def solve(self, problem: str) -> str:
"""Solve through multi-agent debate."""
# Initial solutions from each debater
solutions = [
self.generate_solution(problem, i)
for i in range(self.n_debaters)
]
# Debate rounds
for round_idx in range(self.n_rounds):
new_solutions = []
for i, solution in enumerate(solutions):
# Get other solutions
other_solutions = [s for j, s in enumerate(solutions) if j != i]
# Debate and update
updated = self.debate_round(
problem, solution, other_solutions, i
)
new_solutions.append(updated)
solutions = new_solutions
# Final consensus
return self.reach_consensus(problem, solutions)
def generate_solution(self, problem: str, debater_id: int) -> str:
personas = [
"systematic and methodical",
"creative and intuitive",
"skeptical and rigorous"
]
persona = personas[debater_id % len(personas)]
prompt = f"""You are a {persona} problem solver.
Problem: {problem}
Solve this step by step:"""
return self.model.generate(prompt)
def debate_round(
self,
problem: str,
my_solution: str,
other_solutions: list[str],
debater_id: int
) -> str:
others_text = '\n\n'.join(
f"Other solver's answer:\n{s}" for s in other_solutions
)
prompt = f"""Problem: {problem}
Your previous answer:
{my_solution}
{others_text}
Consider the other solutions. If they have valid points, incorporate them.
If you see errors in their reasoning, explain why your approach is better.
Provide your updated solution:"""
return self.model.generate(prompt)
def reach_consensus(self, problem: str, solutions: list[str]) -> str:
solutions_text = '\n\n---\n\n'.join(
f"Solution {i+1}:\n{s}" for i, s in enumerate(solutions)
)
prompt = f"""Problem: {problem}
After debate, here are the final positions:
{solutions_text}
Synthesize these into the best final answer, taking the strongest elements from each:"""
return self.model.generate(prompt)
Multi-agent debate dynamics:
- Diverse personas: Each debater has a different style (systematic, creative, skeptical). This prevents groupthink and ensures different approaches are explored.
- Exposure to alternatives: In each round, debaters see the other solutions. Strong arguments from others can change a debater's approach; weak arguments are critiqued.
- Iterative refinement: Multiple rounds allow positions to evolve. Initial disagreements often converge as debaters adopt valid points from each other.
- Consensus synthesis: The final step doesn't just pick a winner—it synthesizes the best elements from all positions.
When to use debate vs. other methods:
- Debate works best for open-ended problems where there's no single "correct" answer format (essays, analysis, recommendations)
- Majority voting (self-consistency) works better for problems with verifiable answers (math, factual questions)
- PRM-guided search is superior when you have a trained reward model for your domain
Budget Forcing
Controlling Test-Time Compute
Budget forcing allows explicit control over reasoning length:
class BudgetForcedGenerator:
def __init__(
self,
model,
tokenizer
):
self.model = model
self.tokenizer = tokenizer
self.think_start = "<think>"
self.think_end = "</think>"
def generate_with_budget(
self,
problem: str,
min_thinking_tokens: int = 100,
max_thinking_tokens: int = 2000
) -> str:
"""Generate with controlled thinking budget."""
prompt = f"{problem}\n{self.think_start}"
generated = ""
thinking_tokens = 0
while thinking_tokens < max_thinking_tokens:
# Generate next chunk
next_tokens = self.model.generate(
prompt + generated,
max_new_tokens=50,
temperature=0.7
)
generated += next_tokens
thinking_tokens = len(self.tokenizer.encode(generated))
# Check for premature ending
if self.think_end in next_tokens:
if thinking_tokens < min_thinking_tokens:
# Remove end token and continue
generated = generated.replace(self.think_end, "")
generated += "\nWait, let me think more carefully.\n"
else:
# Accept the ending
break
# Force ending if over budget
if self.think_end not in generated:
generated += f"\n{self.think_end}"
# Generate final answer
final_prompt = prompt + generated + "\nFinal answer:"
answer = self.model.generate(final_prompt, max_new_tokens=100)
return f"{self.think_start}{generated}{answer}"
def extend_thinking(self, partial_response: str) -> str:
"""Extend thinking by suppressing end token."""
if self.think_end in partial_response:
# Remove end token
partial_response = partial_response.replace(self.think_end, "")
# Add continuation prompt
partial_response += "\n\nWait, I should reconsider this."
return partial_response
def truncate_thinking(self, partial_response: str, max_tokens: int) -> str:
"""Truncate thinking to budget."""
tokens = self.tokenizer.encode(partial_response)
if len(tokens) <= max_tokens:
return partial_response
# Truncate and add ending
truncated_tokens = tokens[:max_tokens]
truncated = self.tokenizer.decode(truncated_tokens)
# Clean ending
truncated = truncated.rsplit('\n', 1)[0] # Remove partial line
truncated += f"\n{self.think_end}"
return truncated
Adaptive Budget Allocation
class AdaptiveBudget:
def __init__(
self,
model,
difficulty_estimator,
budget_map: dict = None
):
self.model = model
self.difficulty_estimator = difficulty_estimator
self.budget_map = budget_map or {
"easy": (50, 200),
"medium": (200, 1000),
"hard": (1000, 4000)
}
def estimate_difficulty(self, problem: str) -> str:
"""Estimate problem difficulty."""
prompt = f"""Rate the difficulty of this problem as easy, medium, or hard:
{problem}
Consider:
- Number of steps required
- Mathematical complexity
- Domain knowledge needed
Difficulty:"""
response = self.difficulty_estimator.generate(prompt).strip().lower()
if "easy" in response:
return "easy"
elif "hard" in response:
return "hard"
else:
return "medium"
def solve_with_adaptive_budget(self, problem: str) -> str:
"""Solve with difficulty-appropriate budget."""
difficulty = self.estimate_difficulty(problem)
min_tokens, max_tokens = self.budget_map[difficulty]
generator = BudgetForcedGenerator(self.model, self.model.tokenizer)
return generator.generate_with_budget(
problem,
min_thinking_tokens=min_tokens,
max_thinking_tokens=max_tokens
)
Putting It All Together
Unified Test-Time Scaling System
import collections
import re
import time
class TestTimeScaler:
"""Unified system for test-time compute scaling."""
def __init__(
self,
model,
prm: ProcessRewardModel = None,
default_method: str = "adaptive"
):
self.model = model
self.prm = prm
self.default_method = default_method
# Initialize methods
self.methods = {
"cot": lambda p: self.chain_of_thought(p),
"self_consistency": lambda p: self.self_consistency(p),
"best_of_n": lambda p: self.best_of_n(p),
"beam_search": lambda p: self.beam_search(p),
"tree_of_thoughts": lambda p: self.tree_of_thoughts(p),
"mcts": lambda p: self.mcts(p),
"dvts": lambda p: self.dvts(p),
"self_reflection": lambda p: self.self_reflection(p),
"adaptive": lambda p: self.adaptive_solve(p)
}
def solve(
self,
problem: str,
method: str = None,
compute_budget: str = "medium"
) -> dict:
"""
Solve problem with specified method and budget.
Returns dict with solution, method used, and metrics.
"""
method = method or self.default_method
if method not in self.methods:
raise ValueError(f"Unknown method: {method}")
start_time = time.time()
solution = self.methods[method](problem)
elapsed = time.time() - start_time
return {
"solution": solution,
"method": method,
"time": elapsed,
"compute_budget": compute_budget
}
def adaptive_solve(self, problem: str) -> str:
"""Automatically select method based on problem."""
# Estimate difficulty
difficulty = self.estimate_difficulty(problem)
if difficulty == "easy":
return self.chain_of_thought(problem)
elif difficulty == "medium":
return self.best_of_n(problem)
else:
return self.dvts(problem)
def estimate_difficulty(self, problem: str) -> str:
"""Quick difficulty estimation."""
# Simple heuristics
word_count = len(problem.split())
has_multiple_parts = any(x in problem.lower() for x in ['and', 'then', 'also'])
is_math = any(x in problem for x in ['+', '-', '*', '/', '='])
if word_count < 20 and not has_multiple_parts:
return "easy"
elif word_count > 50 or (has_multiple_parts and is_math):
return "hard"
else:
return "medium"
# Individual method implementations
def chain_of_thought(self, problem: str) -> str:
prompt = f"{problem}\n\nLet's think step by step:"
return self.model.generate(prompt)
def self_consistency(self, problem: str, n: int = 10) -> str:
responses = []
for _ in range(n):
resp = self.model.generate(
f"{problem}\n\nLet's think step by step:",
temperature=0.7
)
responses.append(resp)
# Majority vote on final answer
answers = [self.extract_answer(r) for r in responses]
return collections.Counter(answers).most_common(1)[0][0]
def best_of_n(self, problem: str, n: int = 16) -> str:
if not self.prm:
return self.self_consistency(problem, n)
best_solution = None
best_score = float('-inf')
for _ in range(n):
solution = self.model.generate(
f"{problem}\n\nSolve step by step:",
temperature=0.8
)
steps = self.parse_steps(solution)
scores = self.prm.score_solution(problem, steps)
avg_score = sum(scores) / len(scores) if scores else 0
if avg_score > best_score:
best_score = avg_score
best_solution = solution
return best_solution
def beam_search(self, problem: str) -> str:
searcher = StepBeamSearch(
self.model,
self.prm or self.model, # Use model as verifier if no PRM
beam_width=4,
max_steps=15
)
return searcher.search(problem)
def tree_of_thoughts(self, problem: str) -> str:
tot = TreeOfThoughts(
self.model,
self.prm or self.model,
max_depth=5,
branching_factor=3,
beam_width=3
)
return tot.solve_bfs(problem)
def mcts(self, problem: str) -> str:
searcher = MCTSReasoner(
self.model,
self.prm or self.model,
num_simulations=50,
max_depth=10
)
return searcher.solve(problem)
def dvts(self, problem: str) -> str:
searcher = DVTS(
self.model,
self.prm,
n_beams=4,
n_candidates=8
)
return searcher.search(problem)
def self_reflection(self, problem: str) -> str:
solver = SelfReflectionSolver(self.model, max_iterations=3)
return solver.solve(problem)
def extract_answer(self, response: str) -> str:
"""Extract final answer from response."""
patterns = [
r"answer is[:\s]*([^\n.]+)",
r"= ([^\n.]+)$",
r"therefore[,:\s]*([^\n.]+)"
]
for pattern in patterns:
match = re.search(pattern, response, re.IGNORECASE)
if match:
return match.group(1).strip()
return response.split('\n')[-1].strip()
def parse_steps(self, solution: str) -> list[str]:
"""Parse solution into steps."""
lines = solution.strip().split('\n')
steps = [l.strip() for l in lines if l.strip() and not l.startswith('#')]
return steps
Comparison and Selection Guide
When to Use Each Method
| Method | Best For | Compute | Latency |
|---|---|---|---|
| CoT | Simple reasoning, quick responses | Low | Low |
| Self-Consistency | Math problems, verifiable answers | Medium | Medium |
| Best-of-N | When you have a good verifier | Medium | Medium |
| Beam Search | Multi-step problems, PRM available | High | High |
| Tree of Thoughts | Complex planning, backtracking needed | Very High | Very High |
| MCTS | Game-like problems, clear rewards | Very High | Very High |
| DVTS | Complex math, need diversity | High | High |
| Self-Reflection | Open-ended problems, no verifier | Medium | Medium |
Decision Flowchart
START
│
├─ Is answer easily verifiable? (math, code)
│ ├─ YES → Do you have a PRM?
│ │ ├─ YES → DVTS or Beam Search
│ │ └─ NO → Best-of-N with self-verification
│ │
│ └─ NO → Is backtracking helpful?
│ ├─ YES → Tree of Thoughts or MCTS
│ └─ NO → Self-Reflection
│
├─ Is latency critical?
│ ├─ YES → CoT (single pass)
│ └─ NO → Self-Consistency
│
└─ Is the problem very complex?
├─ YES → DVTS with high beam count
└─ NO → Best-of-N (N=8-16)
Conclusion
Test-time compute scaling represents a fundamental shift in how we improve AI capabilities. Instead of relying solely on larger models, we can achieve better results by having models "think longer" at inference time.
Key takeaways:
- Start simple: Chain-of-Thought and Self-Consistency provide significant improvements with minimal complexity
- Add verifiers: Process Reward Models enable sophisticated search methods
- Match method to problem: Use simpler methods for easier problems, complex search for hard ones
- Budget wisely: More compute helps, but with diminishing returns
- Combine techniques: The best systems often combine multiple approaches
The HuggingFace search-and-learn framework provides production-ready implementations of these techniques, making it accessible to implement sophisticated test-time scaling in your own applications.
Related Articles
Reasoning Models: A Brief Framework
Understanding o1, o3, DeepSeek R1, and the shift from pre-training scaling to inference-time and training-time scaling—the defining trend of 2025.
Training Reasoning Models: PPO, GRPO, Reward Functions, and RLVR
A deep technical guide to training reasoning models like o1 and DeepSeek R1—covering PPO, GRPO, reward function design, RLVR, and distillation techniques.
GRPO: Group Relative Policy Optimization Explained
Understanding Group Relative Policy Optimization—the technique behind DeepSeek's training efficiency and a simpler alternative to PPO-based RLHF.