
Training Reasoning Models: PPO, GRPO, Reward Functions, and RLVR

A deep technical guide to training reasoning models like o1 and DeepSeek R1—covering PPO, GRPO, reward function design, RLVR, and distillation techniques.


The Training Revolution Behind Reasoning Models

While test-time compute scaling (CoT, Tree of Thoughts, MCTS) improves what models can do at inference, training-time innovations determine what models learn to do in the first place. The breakthrough reasoning capabilities of o1, o3, and DeepSeek R1 come from novel training algorithms—particularly in how rewards are defined and optimized.

This post provides a comprehensive technical guide to training reasoning models, covering:

  • Reward function design (outcome, process, rule-based, learned)
  • PPO (Proximal Policy Optimization) and its components
  • GRPO (Group Relative Policy Optimization) and why it's revolutionary
  • RLVR (Reinforcement Learning with Verifiable Rewards)
  • DPO and other alternatives
  • How OpenAI o1/o3 and DeepSeek R1 are actually trained
  • Distillation to compress reasoning into smaller models

Understanding Reward Functions

The Role of Rewards in RL for LLMs

In reinforcement learning for language models, the reward function defines what "good" outputs look like. The model learns to maximize expected reward through gradient updates.

Code
Policy π(response | prompt) → Reward R(prompt, response) → Policy Update

The reward function is arguably the most critical design decision in training reasoning models. Different reward signals lead to fundamentally different model behaviors.

Types of Reward Functions

1. Outcome Reward Models (ORMs)

ORMs evaluate only the final answer—they don't care about reasoning steps.

Why ORMs are the simplest starting point: The appeal of ORMs is their simplicity—you need only a ground truth answer and a way to check equality. No step-level labels, no learned reward model, just binary correctness. This simplicity comes with a tradeoff: the model receives reward signal only at the end of potentially long reasoning chains, making it hard to learn which steps were good versus bad.

The credit assignment problem: Imagine a model produces a 20-step mathematical proof that arrives at the wrong answer. Which step caused the error? Was it step 3 where it made an arithmetic mistake, or step 15 where it applied the wrong formula? ORMs can't tell—they just report "wrong." The model must learn through many trials which patterns lead to correct answers, a slow and sample-inefficient process. Despite this limitation, ORMs work surprisingly well when combined with algorithms like GRPO that generate many samples per problem.

When ORMs shine: ORMs excel in domains with unambiguous correctness criteria: math problems with numeric answers, code that passes test cases, logic puzzles with definite solutions. They struggle in domains where partial credit matters or where "close" answers should receive some reward.

Python
import re


class OutcomeRewardModel:
    """
    Evaluates correctness of final answer only.
    Binary or continuous score based on answer quality.
    """

    def __init__(self, verifier=None):
        self.verifier = verifier

    def compute_reward(
        self,
        problem: str,
        response: str,
        ground_truth: str = None
    ) -> float:
        """
        Compute outcome-based reward.

        Returns:
            1.0 if correct, 0.0 if incorrect (binary)
            or continuous score [0, 1]
        """
        # Extract final answer from response
        predicted_answer = self.extract_answer(response)

        if ground_truth is not None:
            # Exact match or semantic equivalence
            if self.normalize(predicted_answer) == self.normalize(ground_truth):
                return 1.0
            return 0.0

        # Use verifier model if no ground truth
        if self.verifier:
            return self.verifier.score(problem, predicted_answer)

        raise ValueError("Need ground_truth or verifier")

    def extract_answer(self, response: str) -> str:
        """Extract final answer from reasoning response."""
        # Look for common answer patterns
        patterns = [
            r"\\boxed{([^}]+)}",              # LaTeX boxed
            r"[Aa]nswer[:\s]+([^\n]+)",       # "Answer: X"
            r"[Tt]herefore[,:\s]+([^\n]+)",   # "Therefore, X"
            r"= ([^\n]+)$"                     # Final equation
        ]

        for pattern in patterns:
            match = re.search(pattern, response)
            if match:
                return match.group(1).strip()

        # Fallback: last line
        return response.strip().split('\n')[-1]

    def normalize(self, answer: str) -> str:
        """Normalize answer for comparison."""
        # Remove whitespace, lowercase
        answer = ' '.join(answer.lower().split())

        # Numeric normalization
        try:
            num = float(answer.replace(',', ''))
            return f"{num:.6f}"
        except ValueError:
            return answer

Advantages:

  • Simple to implement
  • No need for step-level labels
  • Works with any verifiable domain

Disadvantages:

  • Sparse signal (only at the end)
  • Can't distinguish good reasoning with wrong answer from lucky guessing
  • Credit assignment problem: which steps caused the error?

Understanding the answer extraction patterns: The extract_answer method uses a cascade of regex patterns to find the final answer in model responses. This is crucial because models format answers inconsistently—some use LaTeX \boxed{}, others write "Answer: X", and some bury the answer in prose. The pattern hierarchy matters: we check the most structured formats first (LaTeX boxed) because they're most reliable, falling back to less structured patterns only when needed. In production, you'll likely need to expand these patterns based on your model's output style.

Normalization is harder than it looks: The normalize method handles the surprisingly complex task of comparing answers. Two answers might be semantically identical but textually different: "1/2" and "0.5" and "0.500" and "50%" all represent the same value. The numeric normalization handles floating-point comparison, but real production systems need far more: unit conversion ("1 meter" vs "100 cm"), symbolic equivalence ("x^2" vs "x*x"), and domain-specific rules. This is where many RL training pipelines silently fail—answers that should match don't.
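
To make the normalization pitfalls concrete, here is a minimal sketch of a slightly more robust normalizer that maps fractions, percentages, and trailing zeros onto one canonical decimal form. The function name robust_normalize is illustrative and not part of the class above; real pipelines need many more cases (units, symbolic forms).

Python
from fractions import Fraction

def robust_normalize(answer: str) -> str:
    """Illustrative normalizer: '1/2', '0.5', '0.500', and '50%' map to the same string."""
    s = answer.strip().lower().replace(",", "").replace("$", "")

    # Percentages: "50%" -> 0.5
    if s.endswith("%"):
        try:
            return f"{float(s[:-1]) / 100:.6f}"
        except ValueError:
            return s

    # Fractions: "1/2" -> 0.5
    if "/" in s:
        try:
            return f"{float(Fraction(s)):.6f}"
        except (ValueError, ZeroDivisionError):
            return s

    # Plain numbers: "0.500" -> 0.5
    try:
        return f"{float(s):.6f}"
    except ValueError:
        return s

assert robust_normalize("1/2") == robust_normalize("0.5") == robust_normalize("50%")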

2. Process Reward Models (PRMs)

PRMs evaluate each reasoning step, providing dense reward signal.

Why dense rewards dramatically accelerate learning: The fundamental limitation of ORMs—sparse, end-of-sequence feedback—is precisely what PRMs address. By scoring each step, PRMs provide dense supervision throughout the reasoning chain. If step 5 introduces an error, the PRM flags it immediately rather than waiting until the final answer. This transforms the credit assignment problem: instead of asking "somewhere in these 20 steps, something went wrong," we can ask "was this specific step correct?"

The labeling challenge: PRMs require step-level labels, which are expensive to obtain. For a 20-step solution, you need 20 correctness judgments, not just one. OpenAI's PRM800K dataset required human annotators to label 800,000 individual reasoning steps—a massive investment. This cost explains why rule-based rewards (like DeepSeek R1's approach) have become popular: they trade some precision for eliminating labeling costs entirely.

How PRMs enable search: PRMs aren't just for training—they're crucial for test-time compute scaling. During inference, you can generate multiple solution paths and use the PRM to score each step, pruning bad branches early. This is the foundation of techniques like beam search with process rewards, where computation focuses on promising reasoning paths rather than exploring dead ends.
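
To make this concrete, the sketch below prunes reasoning paths using step-level scores. The generate_step and prm_score callables are assumed interfaces (a sampler that proposes one next step, and a scorer such as ProcessRewardModel.score_solution defined later in this post), not functions from any specific library.

Python
def prm_guided_beam_search(
    problem: str,
    generate_step,      # assumed: (problem, steps_so_far) -> next step string
    prm_score,          # assumed: (problem, steps) -> list of per-step scores
    beam_width: int = 4,
    num_candidates: int = 8,
    max_steps: int = 10,
) -> list[str]:
    """Keep only the highest-scoring partial solutions at each reasoning step."""
    beams = [([], 0.0)]  # (steps_so_far, cumulative_score)

    for _ in range(max_steps):
        candidates = []
        for steps, score in beams:
            for _ in range(max(num_candidates // len(beams), 1)):
                next_step = generate_step(problem, steps)
                new_steps = steps + [next_step]
                step_scores = prm_score(problem, new_steps)
                # Score only the newly added step
                candidates.append((new_steps, score + step_scores[-1]))

        # Prune: keep the top beam_width partial solutions
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]

    return beams[0][0]  # best reasoning chain found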

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import AutoModel


class ProcessRewardModel(nn.Module):
    """
    Process Reward Model that scores each reasoning step.

    Architecture: Language model backbone + step-level classifier head.
    """

    def __init__(
        self,
        backbone: str = "meta-llama/Llama-3.1-8B",
        num_labels: int = 2  # correct/incorrect
    ):
        super().__init__()

        self.backbone = AutoModel.from_pretrained(backbone)
        self.hidden_size = self.backbone.config.hidden_size

        # Step classification head
        self.step_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_size, num_labels)
        )

        # Optional: continuous reward head
        self.reward_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.GELU(),
            nn.Linear(self.hidden_size // 2, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        step_end_positions: list[list[int]]
    ) -> dict:
        """
        Forward pass computing rewards at step boundaries.

        Args:
            input_ids: Tokenized problem + solution
            attention_mask: Attention mask
            step_end_positions: Token positions marking end of each step

        Returns:
            dict with step_rewards and step_labels
        """
        # Get hidden states from backbone
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        hidden_states = outputs.last_hidden_state

        # Extract representations at step boundaries
        batch_size = hidden_states.size(0)
        step_representations = []

        for batch_idx in range(batch_size):
            batch_step_reps = hidden_states[batch_idx, step_end_positions[batch_idx], :]
            step_representations.append(batch_step_reps)

        step_reps = torch.stack(step_representations)  # [batch, num_steps, hidden]

        # Compute step-level predictions
        step_logits = self.step_classifier(step_reps)  # [batch, num_steps, 2]
        step_rewards = self.reward_head(step_reps).squeeze(-1)  # [batch, num_steps]

        return {
            "step_logits": step_logits,
            "step_rewards": step_rewards,
            "step_probs": F.softmax(step_logits, dim=-1)[:, :, 1]  # P(correct)
        }

    def score_solution(
        self,
        problem: str,
        solution_steps: list[str],
        tokenizer
    ) -> list[float]:
        """
        Score each step in a solution.

        Returns list of scores [0, 1] for each step.
        """
        # Build input text
        text = f"Problem: {problem}\n\nSolution:\n"
        step_positions = []

        for i, step in enumerate(solution_steps):
            text += f"Step {i+1}: {step}\n"
            # Mark position at end of step
            tokens_so_far = len(tokenizer.encode(text))
            step_positions.append(tokens_so_far - 1)

        # Tokenize
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )

        # Forward pass
        with torch.no_grad():
            outputs = self.forward(
                inputs.input_ids,
                inputs.attention_mask,
                [step_positions]
            )

        return outputs["step_rewards"][0].tolist()


def train_prm(
    model: ProcessRewardModel,
    train_dataset: Dataset,
    tokenizer,
    epochs: int = 3,
    learning_rate: float = 1e-5
):
    """
    Train Process Reward Model on step-labeled data.

    Dataset format:
    {
        "problem": str,
        "steps": list[str],
        "step_labels": list[int]  # 1=correct, 0=incorrect
    }
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    model.train()

    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for batch in train_dataset:
            optimizer.zero_grad()

            # Prepare inputs
            inputs, step_positions = prepare_prm_inputs(
                batch["problems"],
                batch["steps"],
                tokenizer
            )

            # Forward
            outputs = model(
                inputs.input_ids.to(model.device),
                inputs.attention_mask.to(model.device),
                step_positions
            )

            # Compute loss
            labels = torch.tensor(batch["step_labels"]).to(model.device)
            loss = criterion(
                outputs["step_logits"].view(-1, 2),
                labels.view(-1)
            )

            # Backprop
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

            # Accuracy
            preds = outputs["step_logits"].argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.numel()

        acc = correct / total
        print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.4f}")

Understanding the architecture: The PRM uses a language model backbone to understand the reasoning context, then adds classification heads specifically for step-level scoring. The key architectural decision is extracting representations at "step boundary" positions—the tokens where one reasoning step ends. These boundary positions carry information about what the model has understood so far, making them natural points to assess correctness.

Why two output heads: The architecture includes both a classifier head (step_classifier, binary correct/incorrect) and a continuous reward head (reward_head, scalar 0-1). The classifier is useful during training with binary labels, while the continuous head is better for search applications where you want fine-grained quality scores. In practice, many implementations use only one, but having both provides flexibility.

The training loop dissected: The train_prm function implements standard supervised learning, but notice the careful handling of variable-length sequences. Each example has a different number of steps, and each step has a label. The prepare_prm_inputs function (not shown) must handle this complexity—padding sequences, tracking which positions correspond to step boundaries, and aligning labels correctly. Getting this bookkeeping right is where many PRM implementations fail.
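
For reference, one plausible sketch of that helper is shown below. It is hypothetical (the post does not define prepare_prm_inputs), glosses over truncation of long solutions, and assumes the same "Step i:" formatting used in score_solution.

Python
def prepare_prm_inputs(problems, steps_per_problem, tokenizer, max_length=2048):
    """Hypothetical sketch: build texts, record step-boundary token positions, batch-tokenize."""
    texts, step_positions = [], []

    for problem, steps in zip(problems, steps_per_problem):
        text = f"Problem: {problem}\n\nSolution:\n"
        positions = []
        for i, step in enumerate(steps):
            text += f"Step {i+1}: {step}\n"
            # Approximate position of the last token of this step
            positions.append(len(tokenizer.encode(text)) - 1)
        texts.append(text)
        step_positions.append(positions)

    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length
    )
    return inputs, step_positions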

Gradient clipping is essential: The clip_grad_norm_ call prevents gradient explosions that commonly occur when training on long sequences with many step boundaries. Without clipping, a single long solution with many incorrect steps can produce gradients that destabilize training. The value 1.0 is a reasonable default, but you may need to tune this based on your sequence lengths.

3. Rule-Based Rewards

DeepSeek R1-Zero uses remarkably simple rule-based rewards.

The surprising power of simplicity: DeepSeek's R1-Zero demonstrated that you don't need learned reward models or human-labeled data to train reasoning. Their rule-based system uses just two signals: format correctness (did the model use <think> tags?) and answer correctness (did it get the right answer?). This simplicity seemed almost naive compared to the sophisticated reward models in prior work, yet it produced state-of-the-art reasoning. The key insight: for verifiable domains, the correctness signal alone provides enough gradient for learning.

Why format rewards matter: The format reward (checking for <think> tags, substantial reasoning content) might seem cosmetic, but it serves a crucial purpose: it encourages the model to "show its work." Without format incentives, models under RL pressure optimize directly for answers, often collapsing to short, uninterpretable outputs. The format reward acts as a regularizer, ensuring the model maintains readable reasoning chains even as it optimizes for correctness.

The verification bottleneck: Rule-based rewards are only as good as your verification logic. For math, numeric comparison with tolerance handles most cases, but edge cases abound: symbolic expressions, fractions, scientific notation, mixed units. For code, you need a sandboxed execution environment that's both secure and fast enough for RL training (millions of evaluations). The code_correctness method shown is simplified—production systems use containers, time limits, and memory limits to safely execute untrusted code.
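
As a rough illustration of the execution side, the sketch below runs untrusted code in a subprocess with a wall-clock limit. This is not a real sandbox (no memory caps, no network or filesystem isolation); it only shows where a container-based sandbox would slot into the reward computation.

Python
import subprocess
import tempfile

def run_untrusted_code(code: str, stdin_text: str, timeout_s: float = 2.0) -> str:
    """Execute candidate code in a subprocess with a time limit and capture stdout."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(code)
        path = f.name

    try:
        result = subprocess.run(
            ["python", path],
            input=stdin_text,
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
        return result.stdout
    except subprocess.TimeoutExpired:
        return ""  # a timeout counts as a failed test case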

Python
import re


class RuleBasedRewardSystem:
    """
    Rule-based reward system as used in DeepSeek R1-Zero.

    No neural network—just rules checking format and correctness.
    """

    def __init__(
        self,
        format_weight: float = 0.1,
        correctness_weight: float = 1.0
    ):
        self.format_weight = format_weight
        self.correctness_weight = correctness_weight

    def compute_reward(
        self,
        problem: str,
        response: str,
        ground_truth: str,
        problem_type: str = "math"
    ) -> dict:
        """
        Compute rule-based reward.

        Returns dict with reward breakdown.
        """
        rewards = {}

        # 1. Format reward: Check for proper thinking structure
        format_reward = self.check_format(response)
        rewards["format"] = format_reward * self.format_weight

        # 2. Correctness reward: Binary signal for answer correctness
        correctness_reward = self.check_correctness(
            response, ground_truth, problem_type
        )
        rewards["correctness"] = correctness_reward * self.correctness_weight

        # Total reward
        rewards["total"] = rewards["format"] + rewards["correctness"]

        return rewards

    def check_format(self, response: str) -> float:
        """
        Check if response follows expected format.

        R1-Zero expected: <think>...</think> structure
        """
        format_score = 0.0

        # Check for thinking tags
        has_think_start = "<think>" in response or "<thinking>" in response
        has_think_end = "</think>" in response or "</thinking>" in response

        if has_think_start and has_think_end:
            format_score += 0.5

            # Check thinking comes before answer
            think_end_pos = response.find("</think>")
            if think_end_pos == -1:
                think_end_pos = response.find("</thinking>")

            # Check there's content after thinking
            content_after = response[think_end_pos:].strip()
            if len(content_after) > 20:  # Has substantial answer
                format_score += 0.3

            # Check thinking is substantial
            think_start = response.find("<think>")
            if think_start == -1:
                think_start = response.find("<thinking>")

            thinking_content = response[think_start:think_end_pos]
            if len(thinking_content) > 100:  # Substantial thinking
                format_score += 0.2

        return format_score

    def check_correctness(
        self,
        response: str,
        ground_truth: str,
        problem_type: str
    ) -> float:
        """
        Check answer correctness with type-specific rules.
        """
        predicted = self.extract_answer(response)
        expected = self.normalize_answer(ground_truth, problem_type)
        predicted_norm = self.normalize_answer(predicted, problem_type)

        if problem_type == "math":
            return self.math_equivalence(predicted_norm, expected)
        elif problem_type == "code":
            return self.code_correctness(predicted, expected)
        elif problem_type == "mcq":
            return 1.0 if predicted_norm == expected else 0.0
        else:
            # String matching with normalization
            return 1.0 if predicted_norm == expected else 0.0

    def math_equivalence(self, pred: str, expected: str) -> float:
        """Check mathematical equivalence."""
        try:
            # Try numeric comparison
            pred_num = float(pred.replace(',', ''))
            exp_num = float(expected.replace(',', ''))

            # Allow small tolerance for floating point
            if abs(pred_num - exp_num) < 1e-6:
                return 1.0

            # Check relative error
            if exp_num != 0:
                rel_error = abs(pred_num - exp_num) / abs(exp_num)
                if rel_error < 1e-4:
                    return 1.0

            return 0.0

        except ValueError:
            # Non-numeric: exact string match
            return 1.0 if pred == expected else 0.0

    def code_correctness(self, pred_code: str, test_cases: str) -> float:
        """
        Check code correctness via test execution.

        This is a simplified version—production would sandbox execution.
        """
        # Parse test cases
        tests = self.parse_test_cases(test_cases)

        if not tests:
            return 0.0

        passed = 0
        for test_input, expected_output in tests:
            try:
                # Execute (in sandbox in production!)
                actual = self.execute_code(pred_code, test_input)
                if str(actual).strip() == str(expected_output).strip():
                    passed += 1
            except Exception:
                pass  # execution errors count as failed tests

        return passed / len(tests)

    def extract_answer(self, response: str) -> str:
        """Extract final answer from response."""
        # Check for boxed answer (LaTeX)
        boxed_match = re.search(r"\\boxed{([^}]+)}", response)
        if boxed_match:
            return boxed_match.group(1)

        # Check for explicit answer marker
        answer_match = re.search(
            r"(?:final\s+)?answer[:\s]+(.+?)(?:\n|$)",
            response,
            re.IGNORECASE
        )
        if answer_match:
            return answer_match.group(1).strip()

        # Look for content after thinking
        if "</think>" in response:
            after_think = response.split("</think>")[-1]
            return after_think.strip()

        # Fallback: last line
        return response.strip().split('\n')[-1]

    def normalize_answer(self, answer: str, problem_type: str) -> str:
        """Normalize answer for comparison."""
        answer = answer.strip().lower()

        # Remove common prefixes
        prefixes = ["the answer is", "answer:", "therefore"]
        for prefix in prefixes:
            if answer.startswith(prefix):
                answer = answer[len(prefix):].strip()

        # Type-specific normalization
        if problem_type == "math":
            # Remove units, currency symbols
            answer = re.sub(r"[\$£€%]", "", answer)
            # Normalize fractions
            if "/" in answer:
                try:
                    parts = answer.split("/")
                    answer = str(float(parts[0]) / float(parts[1]))
                except (ValueError, ZeroDivisionError):
                    pass

        return answer

The format scoring breakdown: The check_format method assigns partial credit for structural elements: 0.5 for having thinking tags, 0.3 for content after thinking, 0.2 for substantial thinking content. These weights are somewhat arbitrary—DeepSeek's actual weights aren't published. The key insight is that format scoring should be much smaller than correctness scoring (0.1 weight vs 1.0), so models don't game format while ignoring substance.

Math equivalence is tricky: The math_equivalence method first tries numeric comparison with tolerance (absolute and relative error), then falls back to string matching. This handles most cases but misses symbolic equivalence: "2+3" and "5" are mathematically equivalent but fail string comparison. Production systems often use symbolic math libraries (SymPy) for proper equivalence checking, though this adds latency.
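
A minimal sketch of the SymPy route, assuming sympy is installed (this is not part of the rule-based class above):

Python
from sympy import simplify, sympify
from sympy.core.sympify import SympifyError

def symbolic_equivalent(pred: str, expected: str) -> bool:
    """True if two expressions simplify to the same thing, e.g. '2+3' vs '5'
    or '(x-1)*(x+1)' vs 'x**2 - 1'. Falls back to False if parsing fails."""
    try:
        return simplify(sympify(pred) - sympify(expected)) == 0
    except (SympifyError, TypeError):
        return False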

Why relative error matters: Absolute tolerance (abs(pred - expected) < 1e-6) fails for large numbers—if the answer is 1 million, a difference of 0.001 should be fine, but it exceeds 1e-6. The relative error check (rel_error < 1e-4) handles this, accepting answers within 0.01% of the correct value regardless of magnitude. Using both catches cases where either alone would fail.

4. Learned Reward Models

Neural networks trained to predict human preferences.

When you can't verify, you learn: For many tasks—creative writing, open-ended questions, nuanced reasoning—there's no ground truth to verify against. You can't write a rule to check if an explanation is "good." Learned reward models fill this gap by training neural networks to predict which responses humans prefer. This is the standard RLHF approach used in ChatGPT, Claude, and most instruction-following models.

The preference modeling assumption: Learned reward models assume human preferences can be captured by a scalar value—given a response, the model outputs a single number representing "quality." This is a strong assumption. Human preferences are often inconsistent, context-dependent, and multi-dimensional (a response might be helpful but verbose). Despite these limitations, scalar reward models work surprisingly well in practice, likely because they capture the dominant quality dimension.

The Bradley-Terry model: The training loss isn't "predict the reward for this response" but rather "given two responses, predict which one humans preferred." This pairwise comparison approach (Bradley-Terry) is more robust than absolute scoring: humans are better at saying "A is better than B" than assigning absolute quality numbers. The model learns reward values such that preferred responses score higher than non-preferred ones.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM


class LearnedRewardModel(nn.Module):
    """
    Reward model learned from human preference data.

    Standard approach in RLHF pipelines.
    """

    def __init__(self, backbone: str = "meta-llama/Llama-3.1-8B"):
        super().__init__()

        self.backbone = AutoModelForCausalLM.from_pretrained(backbone)
        self.hidden_size = self.backbone.config.hidden_size

        # Reward head: maps final hidden state to scalar
        self.reward_head = nn.Linear(self.hidden_size, 1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute reward for input sequence."""
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        # Use last token's hidden state
        last_hidden = outputs.hidden_states[-1]

        # Get position of last non-padding token
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_indices = torch.arange(input_ids.size(0), device=input_ids.device)

        last_token_hidden = last_hidden[batch_indices, sequence_lengths]

        # Compute reward
        reward = self.reward_head(last_token_hidden)

        return reward.squeeze(-1)


def train_reward_model_from_preferences(
    model: LearnedRewardModel,
    preference_data: Dataset,
    tokenizer,
    epochs: int = 1,
    learning_rate: float = 1e-5
):
    """
    Train reward model from human preference data.

    Data format:
    {
        "prompt": str,
        "chosen": str,      # Preferred response
        "rejected": str     # Non-preferred response
    }

    Uses Bradley-Terry model: P(chosen > rejected) = σ(r_chosen - r_rejected)
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        total_samples = 0

        for batch in preference_data:
            optimizer.zero_grad()

            # Tokenize chosen and rejected
            chosen_inputs = tokenizer(
                [f"{p} {c}" for p, c in zip(batch["prompts"], batch["chosen"])],
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            rejected_inputs = tokenizer(
                [f"{p} {r}" for p, r in zip(batch["prompts"], batch["rejected"])],
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            # Compute rewards
            chosen_rewards = model(**chosen_inputs)
            rejected_rewards = model(**rejected_inputs)

            # Bradley-Terry loss: -log(σ(r_chosen - r_rejected))
            loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

            # Accuracy: how often does model prefer chosen?
            correct = (chosen_rewards > rejected_rewards).sum().item()
            total_correct += correct
            total_samples += len(batch["prompts"])

        acc = total_correct / total_samples
        avg_loss = total_loss / len(preference_data)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Accuracy={acc:.4f}")

Why use the last token's hidden state: The model processes the entire prompt+response sequence, building up contextual representations. The last token's hidden state has "seen" everything that came before, making it a natural summary of the full response. Alternative approaches exist: averaging all token representations, using the [CLS] token if present, or attention pooling. The last-token approach is simplest and works well because modern language models are trained autoregressively, making the final position information-rich.
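
For comparison, a mean-pooling variant would look roughly like the sketch below (an alternative for illustration, not part of LearnedRewardModel); the pooled vector would feed the same reward_head.

Python
import torch

def mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average hidden states over non-padding tokens instead of taking the last token."""
    mask = attention_mask.unsqueeze(-1).float()   # [batch, seq, 1]
    summed = (last_hidden * mask).sum(dim=1)      # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1.0)       # avoid division by zero
    return summed / counts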

The logsigmoid loss explained: The loss -F.logsigmoid(chosen_rewards - rejected_rewards) implements the Bradley-Terry preference model. When chosen_rewards > rejected_rewards (correct preference), the loss is small (sigmoid approaches 1, logsigmoid approaches 0). When wrong, sigmoid approaches 0, and -logsigmoid becomes large. The margin matters: the model is rewarded for increasing the gap between chosen and rejected, not just getting the order right.

Accuracy as a sanity check: During training, we track how often the model correctly ranks chosen above rejected. Random chance would give 50% accuracy. Early in training, accuracy should climb quickly to 60-70%, then more slowly toward 80-90%. If accuracy doesn't improve, check your data (are labels correct?), learning rate (too high causes instability, too low doesn't learn), or architecture (is the model expressive enough?).

Reward Hacking and Mitigations

Models can learn to exploit reward functions in unintended ways.

The Goodhart's Law problem: "When a measure becomes a target, it ceases to be a good measure." Reward functions are proxies for what we actually want—a high-scoring response according to our reward model isn't necessarily a good response. Models under optimization pressure find these gaps. A verbose model might score higher because length correlates with helpfulness in training data, even when brevity is better. A flattering model might score higher because humans prefer agreeable responses, even when disagreement is warranted.

Common hacking patterns: Length exploitation (adding unnecessary content), sycophancy (excessive agreement), format gaming (using structures that score well regardless of content), repetition (repeating phrases that increase scores), and keyword stuffing (including words the reward model associates with quality). Each hack exploits a legitimate correlation in training data, taken to an extreme the training data didn't contain.

Defense in depth: No single mitigation prevents all hacking. Effective approaches combine multiple defenses: KL penalty to prevent drifting too far from a known-good reference policy, length normalization to remove length bias, ensemble rewards to make gaming harder (hard to simultaneously fool multiple models), and human evaluation to catch hacks that pass automated checks.

Python
import torch


class RewardHackingMitigations:
    """
    Techniques to prevent reward hacking in RL training.
    """

    @staticmethod
    def length_penalty(response: str, target_length: int = 500) -> float:
        """
        Penalize responses that are too long or too short.

        Prevents verbose padding or terse non-answers.
        """
        length = len(response.split())
        ratio = length / target_length

        if ratio < 0.5:
            return -0.1 * (0.5 - ratio)  # Too short
        elif ratio > 2.0:
            return -0.1 * (ratio - 2.0)  # Too long
        return 0.0

    @staticmethod
    def repetition_penalty(response: str) -> float:
        """
        Penalize repetitive text.

        Detects sentence-level and phrase-level repetition.
        """
        sentences = response.split('.')
        unique_sentences = set(s.strip().lower() for s in sentences if s.strip())

        if len(sentences) > 1:
            repetition_ratio = 1 - (len(unique_sentences) / len(sentences))
            if repetition_ratio > 0.3:
                return -0.2 * repetition_ratio

        # N-gram repetition
        words = response.lower().split()
        trigrams = [tuple(words[i:i+3]) for i in range(len(words)-2)]
        unique_trigrams = set(trigrams)

        if len(trigrams) > 0:
            trigram_repetition = 1 - (len(unique_trigrams) / len(trigrams))
            if trigram_repetition > 0.3:
                return -0.1 * trigram_repetition

        return 0.0

    @staticmethod
    def kl_penalty(
        policy_logprobs: torch.Tensor,
        reference_logprobs: torch.Tensor,
        beta: float = 0.1
    ) -> torch.Tensor:
        """
        KL divergence penalty to prevent drift from reference.

        This is the standard RLHF regularization term.
        """
        kl = policy_logprobs - reference_logprobs
        return -beta * kl.mean()

    @staticmethod
    def ensemble_reward(
        response: str,
        reward_models: list,
        problem: str
    ) -> float:
        """
        Use ensemble of reward models to reduce hackability.

        Hard to hack multiple different models simultaneously.
        """
        rewards = [rm.score(problem, response) for rm in reward_models]

        # Conservative: use minimum
        min_reward = min(rewards)

        # Or: use variance-penalized mean
        mean_reward = sum(rewards) / len(rewards)
        std_reward = (sum((r - mean_reward)**2 for r in rewards) / len(rewards)) ** 0.5

        # Penalize high variance (disagreement might indicate hacking)
        return mean_reward - 0.5 * std_reward

PPO: Proximal Policy Optimization

PPO Fundamentals

PPO is the dominant algorithm for RLHF. It balances exploration with stability through clipped objectives.

Why PPO dominates RLHF: Before PPO, policy gradient methods were notoriously unstable—a single bad update could collapse months of training. Vanilla policy gradients have high variance, and trust region methods (TRPO) worked but were computationally expensive with second-order optimization. PPO achieves trust region-like stability with only first-order methods by using a clever clipping mechanism. This combination of stability and simplicity made PPO the default choice for RLHF.

The actor-critic architecture: PPO uses two neural networks: the policy (actor) that generates text, and the value function (critic) that estimates expected future reward from any state. The critic enables better credit assignment—instead of asking "was this response good?" we can ask "at each token, was the remaining trajectory better than expected?" This token-level feedback accelerates learning compared to sparse episode-level rewards.

The memory challenge for LLMs: PPO's actor-critic setup means storing two full language models in memory: the policy being trained, plus the value model (typically the same size). Add a frozen reference policy for KL computation and you need ~3x model memory. This is why GRPO, which eliminates the critic, has become popular for large models—it halves the memory requirement.
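
As rough back-of-the-envelope arithmetic (bf16 weights only; gradients, optimizer states, and activations add substantially more):

Python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate GPU memory for model weights alone (bf16 = 2 bytes per parameter)."""
    return num_params * bytes_per_param / 1e9

policy = weight_memory_gb(70e9)      # ~140 GB
critic = weight_memory_gb(70e9)      # ~140 GB (PPO only)
reference = weight_memory_gb(70e9)   # ~140 GB (frozen; often offloaded or quantized)
print(policy + critic + reference)   # ~420 GB of weights for PPO on a 70B model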

Python
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class PPOConfig:
    """Configuration for PPO training."""

    # Core hyperparameters
    learning_rate: float = 1e-5
    gamma: float = 0.99                 # Discount factor
    gae_lambda: float = 0.95            # GAE parameter
    clip_epsilon: float = 0.2           # PPO clipping
    value_clip_epsilon: float = 0.2     # Value function clipping

    # Training parameters
    ppo_epochs: int = 4                 # Updates per batch
    mini_batch_size: int = 64
    max_grad_norm: float = 0.5

    # KL penalty (alternative to clipping)
    use_kl_penalty: bool = False
    target_kl: float = 0.02
    kl_coef: float = 0.1

    # Entropy bonus
    entropy_coef: float = 0.01


class PPOTrainer:
    """
    Proximal Policy Optimization for language models.

    Implements the standard PPO algorithm with:
    - Clipped surrogate objective
    - Value function with optional clipping
    - GAE for advantage estimation
    - KL penalty as regularization
    """

    def __init__(
        self,
        policy_model: nn.Module,
        value_model: nn.Module,
        reward_model: nn.Module,
        reference_model: nn.Module,
        tokenizer,
        config: PPOConfig
    ):
        self.policy = policy_model
        self.value = value_model
        self.reward = reward_model
        self.reference = reference_model
        self.tokenizer = tokenizer
        self.config = config

        # Optimizers
        self.policy_optimizer = torch.optim.AdamW(
            self.policy.parameters(),
            lr=config.learning_rate
        )
        self.value_optimizer = torch.optim.AdamW(
            self.value.parameters(),
            lr=config.learning_rate
        )

    def generate_and_score(
        self,
        prompts: list[str],
        generation_config: dict = None
    ) -> dict:
        """
        Generate responses and compute rewards.

        Returns:
            dict with responses, rewards, values, logprobs
        """
        generation_config = generation_config or {
            "max_new_tokens": 512,
            "temperature": 1.0,
            "do_sample": True
        }

        # Tokenize prompts
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.policy.device)

        # Generate responses
        with torch.no_grad():
            outputs = self.policy.generate(
                **inputs,
                **generation_config,
                return_dict_in_generate=True,
                output_scores=True
            )

        response_ids = outputs.sequences[:, inputs.input_ids.size(1):]
        responses = self.tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=True
        )

        # Compute log probabilities
        policy_logprobs = self.compute_logprobs(
            self.policy,
            inputs.input_ids,
            response_ids
        )

        reference_logprobs = self.compute_logprobs(
            self.reference,
            inputs.input_ids,
            response_ids
        )

        # Compute rewards
        rewards = []
        for prompt, response in zip(prompts, responses):
            r = self.reward.compute_reward(prompt, response)
            rewards.append(r)
        rewards = torch.tensor(rewards, device=self.policy.device)

        # Compute values
        values = self.compute_values(inputs.input_ids, response_ids)

        return {
            "prompts": prompts,
            "responses": responses,
            "input_ids": inputs.input_ids,
            "response_ids": response_ids,
            "policy_logprobs": policy_logprobs,
            "reference_logprobs": reference_logprobs,
            "rewards": rewards,
            "values": values
        }

    def compute_logprobs(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        response_ids: torch.Tensor,
        with_grad: bool = False
    ) -> torch.Tensor:
        """Compute per-token log probabilities for response tokens."""
        full_ids = torch.cat([input_ids, response_ids], dim=1)

        # Gradients are only needed when recomputing logprobs during PPO updates
        grad_context = torch.enable_grad() if with_grad else torch.no_grad()
        with grad_context:
            outputs = model(full_ids, output_hidden_states=False)
            logits = outputs.logits

            # Logits at position t predict token t+1, so shift by one
            response_start = input_ids.size(1)
            response_logits = logits[:, response_start-1:-1, :]

            logprobs = F.log_softmax(response_logits, dim=-1)

            # Gather logprobs of the tokens actually generated
            selected_logprobs = torch.gather(
                logprobs,
                dim=2,
                index=response_ids.unsqueeze(-1)
            ).squeeze(-1)

        return selected_logprobs

    def compute_values(
        self,
        input_ids: torch.Tensor,
        response_ids: torch.Tensor,
        with_grad: bool = False
    ) -> torch.Tensor:
        """Compute value estimates for each response token."""
        full_ids = torch.cat([input_ids, response_ids], dim=1)

        # Gradients are only needed when updating the value model
        grad_context = torch.enable_grad() if with_grad else torch.no_grad()
        with grad_context:
            # Value model predicts a value at each position
            values = self.value(full_ids)

        # Keep values for response positions only
        response_start = input_ids.size(1)
        return values[:, response_start:]

    def compute_advantages(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        policy_logprobs: torch.Tensor,
        reference_logprobs: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute GAE advantages and returns.

        Includes KL penalty in rewards.
        """
        batch_size, seq_len = policy_logprobs.shape

        # KL penalty at each token
        kl_penalty = self.config.kl_coef * (policy_logprobs - reference_logprobs)

        # Construct per-token rewards
        # Final reward at last token, KL penalty at each token
        token_rewards = -kl_penalty  # Penalty is subtracted
        token_rewards[:, -1] += rewards  # Add outcome reward at end

        # GAE computation
        advantages = torch.zeros_like(token_rewards)
        returns = torch.zeros_like(token_rewards)

        last_gae = 0
        last_value = 0

        for t in reversed(range(seq_len)):
            if t == seq_len - 1:
                next_value = 0  # Terminal
            else:
                next_value = values[:, t + 1]

            delta = token_rewards[:, t] + self.config.gamma * next_value - values[:, t]
            advantages[:, t] = last_gae = delta + self.config.gamma * self.config.gae_lambda * last_gae

        returns = advantages + values

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        return advantages, returns

    def ppo_step(self, batch: dict) -> dict:
        """
        Perform one PPO update step.

        Args:
            batch: dict from generate_and_score

        Returns:
            dict with loss metrics
        """
        # Compute advantages
        advantages, returns = self.compute_advantages(
            batch["rewards"],
            batch["values"],
            batch["policy_logprobs"],
            batch["reference_logprobs"]
        )

        # Store old logprobs for ratio computation
        old_logprobs = batch["policy_logprobs"].detach()
        old_values = batch["values"].detach()

        metrics = {
            "policy_loss": 0,
            "value_loss": 0,
            "entropy": 0,
            "kl": 0,
            "clip_fraction": 0
        }

        # Multiple PPO epochs
        for _ in range(self.config.ppo_epochs):
            # Recompute current policy logprobs (gradients enabled for the update)
            current_logprobs = self.compute_logprobs(
                self.policy,
                batch["input_ids"],
                batch["response_ids"],
                with_grad=True
            )

            # Ratio for importance sampling
            ratio = torch.exp(current_logprobs - old_logprobs)

            # Clipped surrogate objective
            surr1 = ratio * advantages
            surr2 = torch.clamp(
                ratio,
                1 - self.config.clip_epsilon,
                1 + self.config.clip_epsilon
            ) * advantages

            policy_loss = -torch.min(surr1, surr2).mean()

            # Value loss with optional clipping (gradients enabled for the value update)
            current_values = self.compute_values(
                batch["input_ids"],
                batch["response_ids"],
                with_grad=True
            )

            if self.config.value_clip_epsilon > 0:
                value_clipped = old_values + torch.clamp(
                    current_values - old_values,
                    -self.config.value_clip_epsilon,
                    self.config.value_clip_epsilon
                )
                value_loss1 = (current_values - returns) ** 2
                value_loss2 = (value_clipped - returns) ** 2
                value_loss = 0.5 * torch.max(value_loss1, value_loss2).mean()
            else:
                value_loss = 0.5 * ((current_values - returns) ** 2).mean()

            # Entropy bonus (encourages exploration)
            # Simplified: use negative logprob as proxy
            entropy = -current_logprobs.mean()

            # Total policy-side loss (the value model is updated separately below,
            # so value_loss is excluded here to avoid backpropagating through it twice)
            total_loss = (
                policy_loss
                - self.config.entropy_coef * entropy
            )

            # Update policy
            self.policy_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.policy.parameters(),
                self.config.max_grad_norm
            )
            self.policy_optimizer.step()

            # Update value (separate optimizer)
            self.value_optimizer.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.value.parameters(),
                self.config.max_grad_norm
            )
            self.value_optimizer.step()

            # Track metrics
            with torch.no_grad():
                kl = (old_logprobs - current_logprobs).mean()
                clip_fraction = (
                    (torch.abs(ratio - 1) > self.config.clip_epsilon)
                    .float()
                    .mean()
                )

            metrics["policy_loss"] += policy_loss.item()
            metrics["value_loss"] += value_loss.item()
            metrics["entropy"] += entropy.item()
            metrics["kl"] += kl.item()
            metrics["clip_fraction"] += clip_fraction.item()

        # Average over epochs
        for k in metrics:
            metrics[k] /= self.config.ppo_epochs

        return metrics

    def train(
        self,
        prompts: list[str],
        num_iterations: int = 1000,
        batch_size: int = 32
    ):
        """Main training loop."""
        for iteration in range(num_iterations):
            # Sample batch of prompts
            batch_prompts = random.sample(prompts, min(batch_size, len(prompts)))

            # Generate and score
            batch = self.generate_and_score(batch_prompts)

            # PPO update
            metrics = self.ppo_step(batch)

            # Log progress
            if iteration % 10 == 0:
                print(f"Iteration {iteration}")
                print(f"  Policy Loss: {metrics['policy_loss']:.4f}")
                print(f"  Value Loss: {metrics['value_loss']:.4f}")
                print(f"  KL: {metrics['kl']:.4f}")
                print(f"  Mean Reward: {batch['rewards'].mean():.4f}")

            # Early stopping on KL
            if self.config.use_kl_penalty and metrics["kl"] > self.config.target_kl * 1.5:
                print(f"KL divergence too high ({metrics['kl']:.4f}), stopping")
                break

Understanding the PPOTrainer class: The trainer orchestrates the complex dance of PPO: generate responses, compute rewards, estimate advantages, and update both policy and value networks. The generate_and_score method handles the data collection phase—producing responses and all the information needed for policy updates (logprobs from current and reference policies, rewards, value estimates). The ppo_step method then uses this data for multiple epochs of gradient updates, the key to PPO's sample efficiency.

GAE (Generalized Advantage Estimation) explained: The compute_advantages method implements GAE, a crucial technique for variance reduction. Raw advantages (reward minus value) have high variance because rewards are noisy and values are imperfect estimates. GAE smooths this by exponentially weighting temporal differences: recent steps use more "actual" reward signal, distant steps use more value estimates. The gae_lambda parameter (0.95 typical) controls this tradeoff—higher values trust actual rewards more, lower values trust value estimates more.
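
The same recursion for a single sequence, stripped of batching and KL terms, reads as follows (a standalone sketch mirroring compute_advantages above):

Python
def gae(rewards: list[float], values: list[float], gamma: float = 0.99, lam: float = 0.95) -> list[float]:
    """Generalized Advantage Estimation for one sequence of per-token rewards and values."""
    advantages = [0.0] * len(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0  # terminal state has value 0
        delta = rewards[t] + gamma * next_value - values[t]         # one-step TD error
        last_gae = delta + gamma * lam * last_gae                   # exponentially weighted sum
        advantages[t] = last_gae
    return advantages

# lam=0 trusts the value estimates (one-step TD); lam=1 trusts actual rewards (Monte Carlo)
print(gae([0.0, 0.0, 1.0], [0.2, 0.4, 0.6]))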

The KL penalty integration: Notice how KL penalty is incorporated into token_rewards rather than as a separate loss term. This treats KL divergence as a per-token "cost" of deviating from the reference policy. The final token also receives the outcome reward. This design ensures the advantage computation naturally accounts for both reward optimization and KL regularization, letting GAE handle the credit assignment for both objectives simultaneously.

Why multiple PPO epochs work: Unlike standard RL where each data point is used once, PPO reuses collected data for multiple gradient steps (typically 4). This is possible because the clipping mechanism prevents policy updates from going too far—if the policy changes too much, the clipped objective stops pushing further. This reuse dramatically improves sample efficiency (expensive generation is amortized over multiple updates) without sacrificing stability.

The PPO Objective Function

The clipped surrogate objective is:

Code
L^CLIP(θ) = E[min(r(θ)Â, clip(r(θ), 1-ε, 1+ε)Â)]

where:
  r(θ) = π_θ(a|s) / π_θ_old(a|s)  (probability ratio)
  Â = advantage estimate
  ε = clip parameter (typically 0.1-0.2)
Python
def ppo_clipped_objective(
    policy_logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    clip_epsilon: float = 0.2
) -> torch.Tensor:
    """
    Compute PPO clipped surrogate objective.

    The clipping prevents too large policy updates.
    """
    # Probability ratio
    ratio = torch.exp(policy_logprobs - old_logprobs)

    # Unclipped objective
    obj1 = ratio * advantages

    # Clipped objective
    clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)
    obj2 = clipped_ratio * advantages

    # Take minimum (pessimistic bound)
    # This is the key to PPO's stability
    return torch.min(obj1, obj2).mean()

Why PPO Works

  1. Trust region approximation: Clipping creates a soft trust region without explicit KL constraints
  2. Pessimistic bound: Taking minimum prevents overoptimistic updates
  3. Sample efficiency: Can reuse data for multiple gradient steps
  4. Stability: Less hyperparameter sensitive than vanilla policy gradient

GRPO: Group Relative Policy Optimization

The GRPO Innovation

GRPO, introduced by DeepSeek, eliminates the value model (critic) by using group statistics for advantage estimation, roughly halving the memory devoted to trainable models.

The key insight behind GRPO: PPO's value function estimates "how good is this state?"—but for language models, we don't actually need absolute values. We only need to know which responses are relatively better. GRPO exploits this by generating multiple responses to the same prompt and using their reward statistics as the baseline. If response A scores higher than the group average, it has positive advantage; if lower, negative. No learned value function required.

Why groups work as baselines: The value function's job is to answer "what's the expected reward from this state?" For a prompt, this means "across all possible completions, what's the average reward?" GRPO approximates this by sampling: generate K responses, compute their rewards, and use the mean as the baseline. With enough samples (K=8-16 typical), this Monte Carlo estimate is surprisingly accurate—often better than a learned value function that might generalize poorly to new prompts.
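
A tiny numeric illustration of the group baseline, using the mean of 8 sampled rewards as the Monte Carlo value estimate:

Python
import torch

# Suppose 8 sampled responses to one prompt receive these verifiable rewards
rewards = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0])

baseline = rewards.mean()                              # 0.5, the group's estimate of the prompt's value
advantages = rewards - baseline                        # +0.5 for correct, -0.5 for incorrect
advantages = advantages / (advantages.std() + 1e-8)    # optional normalization, as in GRPOConfig
print(advantages)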

The memory dividend: GRPO needs only the policy and reference model—no critic. For a 70B parameter model, this saves ~140GB of GPU memory (the critic would be another 70B). This memory savings is what enabled DeepSeek to train R1 on their available hardware. The tradeoff is computational: you must generate K responses per prompt instead of 1, increasing generation cost. But generation is highly parallelizable, while memory is a hard constraint.

When GRPO beats PPO: GRPO tends to outperform PPO when: (1) prompts are diverse and value generalization is hard, (2) rewards are reliable/verifiable so Monte Carlo estimates are accurate, (3) you're memory-constrained. PPO tends to win when: (1) value functions generalize well across similar prompts, (2) you want maximum sample efficiency (PPO reuses data for multiple updates), (3) rewards are noisy and you need variance reduction from learned baselines.

Python
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class GRPOConfig:
    """Configuration for GRPO training."""

    learning_rate: float = 1e-6
    group_size: int = 8              # Responses per prompt
    clip_epsilon: float = 0.2
    kl_coef: float = 0.1
    max_grad_norm: float = 1.0

    # GRPO-specific
    normalize_advantages: bool = True
    baseline: str = "mean"           # "mean" or "median"


class GRPOTrainer:
    """
    Group Relative Policy Optimization.

    Key insight: Use group statistics instead of learned value function
    to estimate advantages.

    From DeepSeek R1 paper:
    "GRPO foregoes the critic model that is typically the same size as
    the policy model, and estimates the baseline from group scores instead."
    """

    def __init__(
        self,
        policy_model: nn.Module,
        reference_model: nn.Module,
        reward_model,  # Can be neural network or rule-based
        tokenizer,
        config: GRPOConfig
    ):
        self.policy = policy_model
        self.reference = reference_model
        self.reward = reward_model
        self.tokenizer = tokenizer
        self.config = config

        self.optimizer = torch.optim.AdamW(
            self.policy.parameters(),
            lr=config.learning_rate
        )

    def generate_group(
        self,
        prompt: str,
        group_size: int = None
    ) -> dict:
        """
        Generate a group of responses for one prompt.

        Returns:
            dict with responses, rewards, logprobs for the group
        """
        group_size = group_size or self.config.group_size

        # Tokenize prompt
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True
        ).to(self.policy.device)

        # Generate multiple responses
        responses = []
        response_ids_list = []
        policy_logprobs_list = []
        reference_logprobs_list = []
        rewards = []

        for _ in range(group_size):
            with torch.no_grad():
                outputs = self.policy.generate(
                    **inputs,
                    max_new_tokens=1024,
                    temperature=1.0,
                    do_sample=True,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            response_ids = outputs.sequences[:, inputs.input_ids.size(1):]
            response = self.tokenizer.decode(
                response_ids[0],
                skip_special_tokens=True
            )

            # Compute logprobs
            policy_lp = self.compute_sequence_logprob(
                self.policy, inputs.input_ids, response_ids
            )
            reference_lp = self.compute_sequence_logprob(
                self.reference, inputs.input_ids, response_ids
            )

            # Compute reward
            reward = self.reward.compute_reward(prompt, response)
            if isinstance(reward, dict):
                reward = reward.get("total", reward.get("correctness", 0))

            responses.append(response)
            response_ids_list.append(response_ids)
            policy_logprobs_list.append(policy_lp)
            reference_logprobs_list.append(reference_lp)
            rewards.append(reward)

        return {
            "prompt": prompt,
            "responses": responses,
            "response_ids": response_ids_list,
            "policy_logprobs": torch.stack(policy_logprobs_list),
            "reference_logprobs": torch.stack(reference_logprobs_list),
            "rewards": torch.tensor(rewards, device=self.policy.device)
        }

    def compute_sequence_logprob(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        response_ids: torch.Tensor
    ) -> torch.Tensor:
        """Compute total log probability of response sequence."""
        full_ids = torch.cat([input_ids, response_ids], dim=1)

        with torch.no_grad():
            outputs = model(full_ids)
            logits = outputs.logits

        # Get logprobs for response tokens
        response_start = input_ids.size(1)
        response_logits = logits[:, response_start-1:-1, :]

        logprobs = F.log_softmax(response_logits, dim=-1)

        # Gather logprobs for actual tokens
        token_logprobs = torch.gather(
            logprobs,
            dim=2,
            index=response_ids.unsqueeze(-1)
        ).squeeze(-1)

        # Sum over response tokens to get a scalar sequence logprob
        return token_logprobs.sum(dim=-1).squeeze(0)

    def compute_group_advantages(
        self,
        rewards: torch.Tensor,
        policy_logprobs: torch.Tensor,
        reference_logprobs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute advantages using group statistics.

        This is the key GRPO innovation: no value network needed.
        """
        # KL penalty
        kl = policy_logprobs - reference_logprobs

        # Adjusted rewards (reward - KL penalty)
        adjusted_rewards = rewards - self.config.kl_coef * kl

        # Group baseline (mean or median)
        if self.config.baseline == "mean":
            baseline = adjusted_rewards.mean()
        else:
            baseline = adjusted_rewards.median()

        # Advantages relative to group
        advantages = adjusted_rewards - baseline

        # Optional normalization
        if self.config.normalize_advantages:
            std = advantages.std()
            if std > 1e-8:
                advantages = advantages / std

        return advantages

    def grpo_step(self, group_data: dict) -> dict:
        """
        Perform one GRPO update step.

        Args:
            group_data: dict from generate_group

        Returns:
            dict with loss metrics
        """
        # Compute group advantages
        advantages = self.compute_group_advantages(
            group_data["rewards"],
            group_data["policy_logprobs"],
            group_data["reference_logprobs"]
        )

        # Store old logprobs
        old_logprobs = group_data["policy_logprobs"].detach()

        # Compute current logprobs (requires grad)
        self.policy.train()
        current_logprobs = []

        for response_ids in group_data["response_ids"]:
            inputs = self.tokenizer(
                group_data["prompt"],
                return_tensors="pt"
            ).to(self.policy.device)

            lp = self.compute_sequence_logprob_with_grad(
                self.policy, inputs.input_ids, response_ids
            )
            current_logprobs.append(lp)

        current_logprobs = torch.stack(current_logprobs)

        # Probability ratio
        ratio = torch.exp(current_logprobs - old_logprobs)

        # Clipped objective
        surr1 = ratio * advantages
        surr2 = torch.clamp(
            ratio,
            1 - self.config.clip_epsilon,
            1 + self.config.clip_epsilon
        ) * advantages

        policy_loss = -torch.min(surr1, surr2).mean()

        # Update
        self.optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.policy.parameters(),
            self.config.max_grad_norm
        )
        self.optimizer.step()

        # Metrics
        with torch.no_grad():
            kl = (old_logprobs - current_logprobs).mean()
            clip_fraction = (
                (torch.abs(ratio - 1) > self.config.clip_epsilon)
                .float()
                .mean()
            )

        return {
            "policy_loss": policy_loss.item(),
            "mean_reward": group_data["rewards"].mean().item(),
            "max_reward": group_data["rewards"].max().item(),
            "min_reward": group_data["rewards"].min().item(),
            "kl": kl.item(),
            "clip_fraction": clip_fraction.item(),
            "advantage_std": advantages.std().item()
        }

    def compute_sequence_logprob_with_grad(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        response_ids: torch.Tensor
    ) -> torch.Tensor:
        """Compute sequence logprob with gradients."""
        full_ids = torch.cat([input_ids, response_ids], dim=1)

        outputs = model(full_ids)
        logits = outputs.logits

        response_start = input_ids.size(1)
        response_logits = logits[:, response_start-1:-1, :]

        logprobs = F.log_softmax(response_logits, dim=-1)

        token_logprobs = torch.gather(
            logprobs,
            dim=2,
            index=response_ids.unsqueeze(-1)
        ).squeeze(-1)

        return token_logprobs.sum(dim=-1).squeeze(0)

    def train(
        self,
        prompts: list[str],
        num_iterations: int = 1000
    ):
        """Main GRPO training loop."""
        for iteration in range(num_iterations):
            # Sample prompt
            prompt = random.choice(prompts)

            # Generate group
            group_data = self.generate_group(prompt)

            # GRPO update
            metrics = self.grpo_step(group_data)

            # Log
            if iteration % 10 == 0:
                print(f"Iteration {iteration}")
                print(f"  Loss: {metrics['policy_loss']:.4f}")
                print(f"  Mean Reward: {metrics['mean_reward']:.4f}")
                print(f"  Reward Range: [{metrics['min_reward']:.2f}, {metrics['max_reward']:.2f}]")
                print(f"  KL: {metrics['kl']:.4f}")

Understanding the GRPOTrainer: The trainer follows the same pattern as PPO but with a critical simplification: no value network. The generate_group method creates multiple responses for a single prompt—this is the "group" in GRPO. Each response gets its own reward and logprob, giving us K data points from a single prompt. The compute_group_advantages method then uses these K rewards to compute relative advantages, replacing the value function entirely.

The advantage computation dissected: In compute_group_advantages, we first adjust rewards by subtracting the KL penalty (just like PPO). Then we subtract the group baseline (mean or median). The result is a relative score: positive if this response is better than the group average, negative if worse. The normalization step (dividing by standard deviation) ensures gradients are similarly scaled regardless of the reward magnitude.

Why median baselines can help: The code offers both mean and median baselines. Median is more robust to outliers—if one response in a group gets an unusually high or low reward (perhaps due to reward model quirks), it won't skew the baseline. This robustness matters when rewards are noisy. However, median has a failure mode with sparse binary rewards: if most responses are wrong (reward 0) and only a few are right (reward 1), the median is 0, so the wrong majority receives zero advantage and is never pushed down; only the rare correct responses carry any gradient signal. The mean assigns a small negative advantage to every wrong response and a large positive one to each correct response, spreading the learning signal across the whole group.

The training loop structure: Unlike PPO's batched approach, GRPO processes one prompt at a time with its full group. This is intentional—you need all K responses to compute the group baseline. Parallelism comes from generating the K responses concurrently, not from batching across prompts. For distributed training, different workers can process different prompts simultaneously.
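
To make the baseline choice concrete, here is a minimal standalone sketch (with made-up rewards, not output from the trainer above) of group-relative advantages under both baselines:

Python
import torch

# Hypothetical rewards for one group of 8 responses: two correct, six incorrect
rewards = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0])

mean_advantages = rewards - rewards.mean()      # wrong: -0.25, correct: +0.75
median_advantages = rewards - rewards.median()  # wrong:  0.00, correct: +1.00

# Optional normalization, mirroring GRPOConfig.normalize_advantages
normalized = mean_advantages / mean_advantages.std().clamp_min(1e-8)

print(mean_advantages.tolist(), median_advantages.tolist(), normalized.tolist())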

GRPO vs PPO Comparison

| Aspect | PPO | GRPO |
| --- | --- | --- |
| Value model | Required | Not needed |
| Memory | 2x model size | 1x model size |
| Advantage estimation | GAE with learned values | Group statistics |
| Training stability | Very stable | Stable with proper group size |
| Sample efficiency | Can reuse data | Needs fresh samples per prompt |

Why GRPO Works

The key insight is that for language model RL:

  1. We care about relative quality of responses, not absolute values
  2. Group statistics provide a good baseline without learning
  3. The policy already contains implicit value information

PPO:  A = r − V(s)  (learned value function)

GRPO: A = r − μ_g  (baseline μ_g computed from the group)

RLVR: Reinforcement Learning with Verifiable Rewards

The RLVR Approach

RLVR focuses on domains where rewards can be verified automatically (math, code, logic).

The verification advantage: In domains like mathematics and coding, we can definitively check if an answer is correct—no reward model needed, no human labelers required, no risk of reward hacking a learned proxy. A math problem either has the right numerical answer or it doesn't. Code either passes tests or it doesn't. This binary, ground-truth signal is the cleanest reward imaginable.

Why RLVR is transformative for reasoning: Prior to RLVR-style training, teaching models to reason required expensive human-labeled chains of thought. RLVR inverts this: provide only the final answer, let the model discover its own reasoning. If the model stumbles upon better reasoning strategies during exploration, they get reinforced through correct answers. This is how DeepSeek R1-Zero developed emergent reasoning—no one taught it to verify its work or try alternative approaches; it discovered these behaviors because they led to more correct answers.

The scalability unlock: RLVR scales effortlessly. Want to train on a million math problems? Generate them programmatically (synthetic data) and their answers are automatically verifiable. Want to train on code? Use existing test suites. No annotation bottleneck, no human-in-the-loop slowdown. This scalability is why reasoning model progress has accelerated—compute translates directly to capability without labeling costs.
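
As an illustration of that scalability, here is a minimal sketch of programmatic problem generation (hypothetical, not taken from any published pipeline); every generated example carries its own automatically checkable answer:

Python
import random

def generate_arithmetic_problem(seed: int) -> dict:
    """Generate a (prompt, ground_truth) pair checkable by simple string equality."""
    rng = random.Random(seed)
    a, b, c = rng.randint(2, 99), rng.randint(2, 99), rng.randint(2, 9)
    prompt = f"Compute ({a} + {b}) * {c}. Show your reasoning, then state the final answer."
    return {"prompt": prompt, "ground_truth": str((a + b) * c)}

# Scale is limited only by compute: every example comes with a free verifier
dataset = [generate_arithmetic_problem(i) for i in range(100_000)]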

Python
class RLVRTrainer(GRPOTrainer):
    """
    Reinforcement Learning with Verifiable Rewards.

    Extension of GRPO specifically for verifiable domains.

    Key features:
    1. Binary rewards from verification (no learned reward model)
    2. Domain-specific verifiers
    3. Optional process supervision
    """

    def __init__(
        self,
        policy_model: nn.Module,
        reference_model: nn.Module,
        verifier: 'Verifier',
        tokenizer,
        config: GRPOConfig,
        domain: str = "math"
    ):
        # Use verifier instead of learned reward model
        super().__init__(
            policy_model,
            reference_model,
            verifier,  # Verifier as reward model
            tokenizer,
            config
        )
        self.domain = domain
        self.verifier = verifier

    def compute_verifiable_reward(
        self,
        prompt: str,
        response: str,
        ground_truth: str = None
    ) -> dict:
        """
        Compute reward through verification.

        No neural network—just automated checking.
        """
        if self.domain == "math":
            return self.verify_math(prompt, response, ground_truth)
        elif self.domain == "code":
            return self.verify_code(prompt, response, ground_truth)
        elif self.domain == "logic":
            return self.verify_logic(prompt, response, ground_truth)
        else:
            raise ValueError(f"Unknown domain: {self.domain}")

    def verify_math(
        self,
        problem: str,
        response: str,
        ground_truth: str
    ) -> dict:
        """Verify mathematical answer."""
        extracted = self.extract_math_answer(response)

        # Normalize both answers
        try:
            pred_val = self.parse_math_expression(extracted)
            true_val = self.parse_math_expression(ground_truth)

            correct = abs(pred_val - true_val) < 1e-6
        except Exception:
            # Fall back to string comparison
            correct = self.normalize_string(extracted) == self.normalize_string(ground_truth)

        # Format check
        has_reasoning = len(response) > 100 and any(
            kw in response.lower()
            for kw in ["therefore", "because", "since", "so ", "thus"]
        )

        return {
            "correctness": 1.0 if correct else 0.0,
            "format": 0.1 if has_reasoning else 0.0,
            "total": (1.0 if correct else 0.0) + (0.1 if has_reasoning else 0.0)
        }

    def verify_code(
        self,
        problem: str,
        response: str,
        test_cases: str
    ) -> dict:
        """
        Verify code through test execution.

        IMPORTANT: Run in sandbox in production!
        """
        # Extract code from response
        code = self.extract_code(response)

        if not code:
            return {"correctness": 0.0, "total": 0.0}

        # Parse test cases
        tests = self.parse_test_cases(test_cases)

        # Run tests (sandboxed)
        passed = 0
        total = len(tests)

        for test_input, expected_output in tests:
            try:
                actual = self.run_code_sandboxed(code, test_input)
                if self.outputs_match(actual, expected_output):
                    passed += 1
            except Exception:
                pass  # Test failed

        correctness = passed / total if total > 0 else 0.0

        return {
            "correctness": correctness,
            "tests_passed": passed,
            "tests_total": total,
            "total": correctness
        }

    def verify_logic(
        self,
        problem: str,
        response: str,
        ground_truth: str
    ) -> dict:
        """Verify logical reasoning."""
        # Extract conclusion
        conclusion = self.extract_conclusion(response)

        # Check against ground truth
        correct = self.logic_match(conclusion, ground_truth)

        # Check logical structure
        has_valid_structure = self.check_logical_structure(response)

        return {
            "correctness": 1.0 if correct else 0.0,
            "structure": 0.2 if has_valid_structure else 0.0,
            "total": (1.0 if correct else 0.0) + (0.2 if has_valid_structure else 0.0)
        }

    # Helper methods
    def extract_math_answer(self, response: str) -> str:
        """Extract math answer from response."""
        # Check for boxed answer
        boxed = re.search(r"\\boxed{([^}]+)}", response)
        if boxed:
            return boxed.group(1)

        # Check for "answer is X"
        answer_match = re.search(
            r"(?:answer|result)[:\s]+([^\n.,]+)",
            response,
            re.IGNORECASE
        )
        if answer_match:
            return answer_match.group(1).strip()

        # Last number in response
        numbers = re.findall(r"-?\d+\.?\d*", response)
        if numbers:
            return numbers[-1]

        return response.strip().split()[-1]

    def parse_math_expression(self, expr: str) -> float:
        """Parse mathematical expression to float."""
        # Clean expression
        expr = expr.strip()
        expr = re.sub(r"[,\s]", "", expr)

        # Handle fractions
        if "/" in expr:
            parts = expr.split("/")
            return float(parts[0]) / float(parts[1])

        # Handle percentages
        if "%" in expr:
            return float(expr.replace("%", "")) / 100

        return float(expr)

    def extract_code(self, response: str) -> str:
        """Extract code block from response."""
        # Look for fenced code blocks
        code_match = re.search(
            r"```(?:python|py)?\n(.*?)```",
            response,
            re.DOTALL
        )
        if code_match:
            return code_match.group(1)

        # Look for indented code
        lines = response.split('\n')
        code_lines = [l for l in lines if l.startswith('    ') or l.startswith('\t')]
        if code_lines:
            return '\n'.join(l.lstrip() for l in code_lines)

        return ""


class MathVerifier:
    """
    Specialized math verifier supporting various formats.
    """

    def __init__(self):
        self.symbolic_engine = None  # Optional: sympy for symbolic verification

    def verify(
        self,
        problem: str,
        response: str,
        ground_truth: str
    ) -> float:
        """Main verification entry point."""
        predicted = self.extract_answer(response)

        # Try exact match first
        if self.exact_match(predicted, ground_truth):
            return 1.0

        # Try numeric comparison
        if self.numeric_match(predicted, ground_truth):
            return 1.0

        # Try symbolic equivalence
        if self.symbolic_match(predicted, ground_truth):
            return 1.0

        return 0.0

    def exact_match(self, pred: str, truth: str) -> bool:
        """Exact string match after normalization."""
        return self.normalize(pred) == self.normalize(truth)

    def numeric_match(self, pred: str, truth: str, tol: float = 1e-6) -> bool:
        """Numeric comparison with tolerance."""
        try:
            pred_val = self.to_number(pred)
            truth_val = self.to_number(truth)
            return abs(pred_val - truth_val) < tol
        except Exception:
            return False

    def symbolic_match(self, pred: str, truth: str) -> bool:
        """Symbolic equivalence using SymPy."""
        if self.symbolic_engine is None:
            return False

        try:
            from sympy import simplify, sympify
            pred_expr = sympify(pred)
            truth_expr = sympify(truth)
            return simplify(pred_expr - truth_expr) == 0
        except Exception:
            return False

    def normalize(self, s: str) -> str:
        """Normalize answer string."""
        s = s.strip().lower()
        s = re.sub(r"\s+", " ", s)
        s = re.sub(r"[$,]", "", s)
        return s

    def to_number(self, s: str) -> float:
        """Convert string to number."""
        s = self.normalize(s)

        # Handle fractions
        if "/" in s:
            num, denom = s.split("/")
            return float(num) / float(denom)

        # Handle percentages
        if "%" in s:
            return float(s.replace("%", "")) / 100

        # Handle scientific notation
        return float(s)

    def extract_answer(self, response: str) -> str:
        """Extract final answer from response."""
        # Pattern matching for common formats
        patterns = [
            r"\\boxed{([^}]+)}",
            r"(?:final\s+)?answer[:\s]+([^\n.]+)",
            r"therefore[,:\s]+([^\n.]+)",
            r"=\s*([^\n.]+)$"
        ]

        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE | re.MULTILINE)
            if match:
                return match.group(1).strip()

        # Fallback: last line
        return response.strip().split('\n')[-1]
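
A quick usage sketch of the verifier above (the problem and response are made up for illustration):

Python
verifier = MathVerifier()

response = (
    "17 x 24 = 17 x (20 + 4) = 340 + 68 = 408.\n"
    "Answer: 408"
)

score = verifier.verify(
    problem="What is 17 * 24?",
    response=response,
    ground_truth="408",
)
print(score)  # 1.0 -- "408" is extracted by the answer pattern and matches exactly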

DPO and Other Alternatives

Direct Preference Optimization (DPO)

DPO eliminates the reward model entirely by optimizing preferences directly.

The theoretical insight behind DPO: Here's a surprising result from RL theory: if you have a Bradley-Terry preference model (humans prefer response A over B with probability σ(r(A) - r(B))), and you regularize with KL divergence to a reference policy, then the optimal policy has a closed-form solution. You can derive the reward function directly from the optimal policy. DPO exploits this by working backward: instead of learning a reward then optimizing against it (RLHF), optimize the policy directly against preferences. Same end result, fewer moving parts.
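
Sketching the key step: with KL coefficient β, the optimal policy under a reward r has the closed form

π*(y|x) ∝ π_ref(y|x) · exp(r(x, y) / β)

Inverting this expresses the reward in terms of the policy itself:

r(x, y) = β · log(π*(y|x) / π_ref(y|x)) + β · log Z(x)

Substituting this into the Bradley-Terry preference probability σ(r(x, y_w) − r(x, y_l)) cancels the intractable normalizer Z(x), leaving a loss that depends only on the policy and reference log-ratios, which is exactly what the DPO loss implemented below optimizes.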

Why DPO is simpler: Traditional RLHF is a three-stage process: (1) train a reward model on preferences, (2) use RL to optimize against the reward, (3) manage the RL instabilities. DPO collapses this to one stage: supervised learning on preference pairs. No reward model to train, no RL hyperparameter tuning, no PPO stability concerns. This simplicity has made DPO extremely popular—many production systems now use DPO instead of PPO.

The hidden cost of simplicity: DPO assumes your preference pairs are good quality. With RLHF, the reward model can generalize beyond the training preferences—it learns a quality function applicable to novel responses. DPO has no such generalization; it only learns from the specific (chosen, rejected) pairs in your dataset. If your dataset doesn't cover the response space well, DPO may not learn the right policy. This is why some practitioners use DPO as a warm-start, followed by PPO for final polish.

Python
class DPOTrainer:
    """
    Direct Preference Optimization.

    Optimizes policy directly from preferences without explicit reward modeling.

    Key insight: The optimal policy under Bradley-Terry model has closed form,
    so we can skip reward model training entirely.

    Loss: -log σ(β * (log π(y_w|x)/π_ref(y_w|x) - log π(y_l|x)/π_ref(y_l|x)))

    where:
        y_w = winning (preferred) response
        y_l = losing (non-preferred) response
        β = temperature parameter
    """

    def __init__(
        self,
        policy_model: nn.Module,
        reference_model: nn.Module,
        tokenizer,
        beta: float = 0.1,
        learning_rate: float = 1e-6
    ):
        self.policy = policy_model
        self.reference = reference_model
        self.tokenizer = tokenizer
        self.beta = beta

        self.optimizer = torch.optim.AdamW(
            self.policy.parameters(),
            lr=learning_rate
        )

        # Freeze reference model
        for param in self.reference.parameters():
            param.requires_grad = False

    def compute_logprobs(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute per-token log probabilities."""
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        logits = outputs.logits

        # Shift for next-token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Compute log probs
        log_probs = F.log_softmax(shift_logits, dim=-1)

        # Gather actual token log probs
        per_token_logps = torch.gather(
            log_probs,
            dim=2,
            index=shift_labels.unsqueeze(-1)
        ).squeeze(-1)

        return per_token_logps

    def dpo_loss(
        self,
        chosen_input_ids: torch.Tensor,
        chosen_attention_mask: torch.Tensor,
        rejected_input_ids: torch.Tensor,
        rejected_attention_mask: torch.Tensor,
        chosen_labels: torch.Tensor,
        rejected_labels: torch.Tensor
    ) -> tuple[torch.Tensor, dict]:
        """
        Compute DPO loss.

        Returns loss and metrics dict.
        """
        # Policy log probs
        policy_chosen_logps = self.compute_logprobs(
            self.policy,
            chosen_input_ids,
            chosen_labels,
            chosen_attention_mask
        ).sum(dim=-1)

        policy_rejected_logps = self.compute_logprobs(
            self.policy,
            rejected_input_ids,
            rejected_labels,
            rejected_attention_mask
        ).sum(dim=-1)

        # Reference log probs
        with torch.no_grad():
            ref_chosen_logps = self.compute_logprobs(
                self.reference,
                chosen_input_ids,
                chosen_labels,
                chosen_attention_mask
            ).sum(dim=-1)

            ref_rejected_logps = self.compute_logprobs(
                self.reference,
                rejected_input_ids,
                rejected_labels,
                rejected_attention_mask
            ).sum(dim=-1)

        # Log ratios
        chosen_log_ratio = policy_chosen_logps - ref_chosen_logps
        rejected_log_ratio = policy_rejected_logps - ref_rejected_logps

        # DPO loss
        logits = self.beta * (chosen_log_ratio - rejected_log_ratio)
        loss = -F.logsigmoid(logits).mean()

        # Metrics
        with torch.no_grad():
            chosen_rewards = self.beta * chosen_log_ratio
            rejected_rewards = self.beta * rejected_log_ratio

            accuracy = (chosen_rewards > rejected_rewards).float().mean()
            reward_margin = (chosen_rewards - rejected_rewards).mean()

        metrics = {
            "loss": loss.item(),
            "accuracy": accuracy.item(),
            "reward_margin": reward_margin.item(),
            "chosen_reward": chosen_rewards.mean().item(),
            "rejected_reward": rejected_rewards.mean().item()
        }

        return loss, metrics

    def train_step(self, batch: dict) -> dict:
        """Single training step."""
        self.policy.train()

        # Tokenize chosen and rejected
        chosen_encodings = self.tokenizer(
            batch["chosen"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(self.policy.device)

        rejected_encodings = self.tokenizer(
            batch["rejected"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(self.policy.device)

        # Compute loss
        loss, metrics = self.dpo_loss(
            chosen_input_ids=chosen_encodings.input_ids,
            chosen_attention_mask=chosen_encodings.attention_mask,
            rejected_input_ids=rejected_encodings.input_ids,
            rejected_attention_mask=rejected_encodings.attention_mask,
            chosen_labels=chosen_encodings.input_ids,
            rejected_labels=rejected_encodings.input_ids
        )

        # Update
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
        self.optimizer.step()

        return metrics


The DPO loss unpacked: The core of DPO is the loss function dpo_loss. It computes log-ratios for chosen and rejected responses: how much more (or less) likely are they under the current policy versus the reference? The difference between these ratios, scaled by β, becomes the "implicit reward margin." The logsigmoid loss encourages the chosen response's implicit reward to exceed the rejected response's. When the model already prefers the chosen response strongly, the loss is low; when it prefers the rejected one, the loss is high.

Why the reference model matters: DPO requires a frozen reference model—typically the pre-trained or instruction-tuned starting point. The log-ratios are computed relative to this reference. Without it, the model could trivially minimize the loss by assigning extreme probabilities (100% to chosen, 0% to rejected). The reference anchors the policy, ensuring it stays close to a known-good distribution while adjusting preferences.

The β hyperparameter: β controls how aggressively the model exploits preference signals. High β (0.5+) makes the model strongly prefer chosen over rejected, but risks overfitting to the specific pairs. Low β (0.01-0.1) makes gentler adjustments, preserving more of the reference model's behavior. Typical values range from 0.1 to 0.5 depending on dataset size and quality.
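
A tiny numeric sketch (the log-ratios below are made up) of how β scales the implicit reward margin and the resulting loss:

Python
import torch
import torch.nn.functional as F

# Hypothetical summed log-ratios: log pi(y|x) - log pi_ref(y|x)
chosen_log_ratio = torch.tensor([0.8])     # policy favors the chosen response more than the reference does
rejected_log_ratio = torch.tensor([-0.3])  # and the rejected response less

for beta in (0.1, 0.5):
    margin = beta * (chosen_log_ratio - rejected_log_ratio)
    loss = -F.logsigmoid(margin)
    print(f"beta={beta}: margin={margin.item():.2f}, loss={loss.item():.2f}")

# beta=0.1: margin=0.11, loss=0.64  (gentle adjustment)
# beta=0.5: margin=0.55, loss=0.46  (stronger preference signal)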

Python
class IPOTrainer(DPOTrainer):
    """
    Identity Preference Optimization.

    Variant of DPO that uses different loss formulation:
    (log_ratio_chosen - log_ratio_rejected - margin)^2

    More robust to label noise.
    """

    def __init__(self, *args, margin: float = 0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.margin = margin

    def ipo_loss(
        self,
        chosen_log_ratio: torch.Tensor,
        rejected_log_ratio: torch.Tensor
    ) -> torch.Tensor:
        """IPO loss: squared difference from margin."""
        diff = chosen_log_ratio - rejected_log_ratio
        loss = (diff - self.margin) ** 2
        return loss.mean()


class KTOTrainer:
    """
    Kahneman-Tversky Optimization.

    Uses prospect theory-inspired loss that doesn't require paired comparisons.
    Works with just desirable/undesirable labels.
    """

    def __init__(
        self,
        policy_model: nn.Module,
        reference_model: nn.Module,
        tokenizer,
        beta: float = 0.1,
        desirable_weight: float = 1.0,
        undesirable_weight: float = 1.0
    ):
        self.policy = policy_model
        self.reference = reference_model
        self.tokenizer = tokenizer
        self.beta = beta
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight

    def kto_loss(
        self,
        policy_logps: torch.Tensor,
        reference_logps: torch.Tensor,
        is_desirable: torch.Tensor,
        kl_reference: float
    ) -> torch.Tensor:
        """
        KTO loss.

        Different treatment for desirable vs undesirable examples.
        """
        log_ratio = policy_logps - reference_logps

        # Desirable: encourage
        desirable_loss = -F.logsigmoid(self.beta * (log_ratio - kl_reference))

        # Undesirable: discourage
        undesirable_loss = -F.logsigmoid(self.beta * (kl_reference - log_ratio))

        # Combine
        loss = torch.where(
            is_desirable,
            self.desirable_weight * desirable_loss,
            self.undesirable_weight * undesirable_loss
        )

        return loss.mean()

Comparison of Training Methods

| Method | Reward Model | Value Model | Data Requirement | Complexity |
| --- | --- | --- | --- | --- |
| PPO | Yes | Yes | Prompts + preferences | High |
| GRPO | Yes | No | Prompts + rewards | Medium |
| DPO | No | No | Preference pairs | Low |
| RLVR | Verifier only | No | Prompts + ground truth | Medium |
| KTO | No | No | Single labels | Low |

How o1 and R1 Are Trained

OpenAI o1/o3 Training

Based on public information from OpenAI:

Python
class O1TrainingPipeline:
    """
    Hypothetical O1 training pipeline based on public information.

    Key elements:
    1. Large-scale RL with chain-of-thought
    2. Hidden reasoning (not shown to users)
    3. Test-time search exploration
    """

    def __init__(self):
        self.stages = [
            "pretraining",
            "sft_on_reasoning_traces",
            "large_scale_rl",
            "safety_alignment"
        ]

    def stage_1_pretraining(self, model, data):
        """Standard GPT pretraining."""
        # Massive internet text
        # Same as GPT-4 base
        pass

    def stage_2_sft(self, model, reasoning_data):
        """
        SFT on high-quality reasoning traces.

        Data includes:
        - Math solutions with detailed steps
        - Code with explanations
        - Scientific reasoning
        """
        # Format: problem → <think>reasoning</think> answer
        pass

    def stage_3_rl(self, model, problems, reward_model):
        """
        Large-scale reinforcement learning.

        From OpenAI: "o1 learns to hone its chain of thought and refine
        the strategies it uses. It learns to recognize and correct its
        mistakes. It learns to break down tricky steps into simpler ones.
        It learns to try a different approach when the current one isn't
        working."

        Key features:
        - Outcome-based rewards
        - Likely uses PPO or similar
        - Extended reasoning encouraged
        """
        # Train to maximize answer correctness
        # Reward model evaluates final answer quality
        pass

    def stage_4_safety(self, model):
        """
        Safety alignment through RLHF.

        Includes:
        - Chain-of-thought monitoring
        - Refusal training
        - Red-teaming
        """
        pass


class O3Innovations:
    """
    O3 specific innovations.

    From OpenAI: "o3 generates a diverse set of candidate CoTs, each
    representing a distinct step-by-step reasoning pathway to solve the
    task. This process mimics a human iterating over different drafts of
    a solution before settling on the best one."
    """

    def test_time_search(self, problem):
        """
        Generate multiple reasoning paths, select best.

        o3 can be run at different compute levels:
        - Low: fewer paths explored
        - Medium: moderate exploration
        - High: extensive search
        """
        pass

    def tool_use_through_rl(self, model):
        """
        From OpenAI: "trained to use tools through reinforcement learning—
        teaching them not just how to use tools, but to reason about when
        to use them."
        """
        pass

DeepSeek R1 Training

DeepSeek published their training approach in detail:

Python
class DeepSeekR1TrainingPipeline:
    """
    DeepSeek R1 training pipeline.

    Two variants:
    1. R1-Zero: Pure RL without SFT
    2. R1: Multi-stage with cold start data

    Key innovation: GRPO with rule-based rewards
    """

    def __init__(self, base_model: str = "DeepSeek-V3"):
        self.base_model = base_model

    def train_r1_zero(
        self,
        model,
        math_problems: list[dict],
        code_problems: list[dict]
    ):
        """
        R1-Zero: Pure RL training.

        From paper: "We directly applied reinforcement learning to the
        base model without relying on supervised fine-tuning (SFT) as a
        preliminary step. This approach allows the model to explore
        chain-of-thought for solving complex problems."

        Remarkably, reasoning emerges naturally!
        """
        # Rule-based reward system
        reward_system = RuleBasedRewardSystem(
            format_weight=0.1,
            correctness_weight=1.0
        )

        # GRPO training
        trainer = GRPOTrainer(
            policy_model=model,
            reference_model=model.copy(),  # Initial model as reference
            reward_model=reward_system,
            tokenizer=self.tokenizer,
            config=GRPOConfig(
                group_size=8,
                kl_coef=0.04  # Low KL penalty
            )
        )

        # Train on math + code
        all_problems = math_problems + code_problems
        trainer.train(all_problems, num_iterations=100000)

        return model

    def train_r1(self, model):
        """
        R1: Multi-stage training for better readability.

        Addresses R1-Zero issues:
        - Language mixing
        - Poor readability
        - Inconsistent formatting
        """
        # Stage 1: Cold start with high-quality reasoning examples
        model = self.stage_1_cold_start(model)

        # Stage 2: RL with reasoning
        model = self.stage_2_reasoning_rl(model)

        # Stage 3: Rejection sampling for diverse tasks
        model = self.stage_3_rejection_sampling(model)

        # Stage 4: Final RL with all data
        model = self.stage_4_final_rl(model)

        return model

    def stage_1_cold_start(self, model):
        """
        SFT on curated reasoning examples.

        From paper: "A small amount of cold-start data is collected before
        reinforcement learning to prevent the model from learning poor
        readability styles at the beginning of RL."

        Thousands of examples (not millions).
        """
        cold_start_data = self.collect_cold_start_data()

        # Format: Include thinking process
        formatted_data = [
            f"Problem: {d['problem']}\n<think>{d['reasoning']}</think>\n{d['answer']}"
            for d in cold_start_data
        ]

        # Standard SFT
        model = supervised_finetune(model, formatted_data)
        return model

    def stage_2_reasoning_rl(self, model):
        """
        RL focused on reasoning quality.

        Verifiable domains: math, code, logic
        """
        verifier = MathVerifier()

        trainer = RLVRTrainer(
            policy_model=model,
            reference_model=model.copy(),
            verifier=verifier,
            tokenizer=self.tokenizer,
            config=GRPOConfig(group_size=16),
            domain="math"
        )

        # Train until convergence
        trainer.train(self.math_problems, num_iterations=50000)

        return model

    def stage_3_rejection_sampling(self, model):
        """
        Use model to generate training data for diverse tasks.

        For tasks without verifiable rewards (writing, open QA),
        use rejection sampling:
        1. Generate multiple responses
        2. Filter by quality (using a judge model or heuristics)
        3. Use top responses for SFT
        """
        diverse_prompts = self.load_diverse_prompts()

        accepted_data = []
        for prompt in diverse_prompts:
            # Generate multiple responses
            responses = [
                model.generate(prompt, temperature=0.8)
                for _ in range(8)
            ]

            # Score and filter
            scored = [(r, self.quality_score(prompt, r)) for r in responses]
            scored.sort(key=lambda x: x[1], reverse=True)

            # Keep top responses
            accepted_data.extend([
                {"prompt": prompt, "response": r}
                for r, score in scored[:2]  # Top 2
                if score > 0.7
            ])

        # SFT on filtered data
        model = supervised_finetune(model, accepted_data)
        return model

    def stage_4_final_rl(self, model):
        """
        Final RL combining all reward signals.

        - Verifiable rewards for math/code
        - Preference rewards for other tasks
        - Safety constraints
        """
        # Combined reward
        def combined_reward(prompt, response):
            # Detect task type
            if is_math(prompt):
                return self.math_verifier.verify(prompt, response, ground_truth)
            elif is_code(prompt):
                return self.code_verifier.verify(prompt, response, test_cases)
            else:
                return self.preference_model.score(prompt, response)

        # Final GRPO training
        trainer = GRPOTrainer(
            policy_model=model,
            reference_model=model.copy(),
            reward_model=combined_reward,
            tokenizer=self.tokenizer,
            config=GRPOConfig(group_size=8)
        )

        trainer.train(self.all_prompts, num_iterations=10000)
        return model

Emergent Capabilities

Both o1/o3 and R1 exhibit emergent reasoning behaviors not explicitly trained:

Python
class EmergentBehaviors:
    """
    Behaviors that emerge from RL training on reasoning tasks.

    From DeepSeek: "it naturally emerges with numerous powerful and
    intriguing reasoning behaviors."
    """

    @staticmethod
    def self_verification():
        """Model learns to check its own work."""
        # Emerges from: Correctness rewards
        # Model realizes checking catches errors
        example = """
        <think>
        Let me calculate 17 × 24.
        17 × 24 = 17 × (20 + 4) = 340 + 68 = 408

        Let me verify: 408 ÷ 17 = 24 ✓
        The calculation is correct.
        </think>
        """

    @staticmethod
    def backtracking():
        """Model learns to try different approaches."""
        # Emerges from: Exploring multiple paths during RL
        example = """
        <think>
        First approach: direct multiplication
        17 × 24 = ... hmm, let me try a different way.

        Alternative: break down 24 = 25 - 1
        17 × 25 - 17 × 1 = 425 - 17 = 408

        Both give 408, confirming the answer.
        </think>
        """

    @staticmethod
    def problem_decomposition():
        """Model learns to break complex problems into parts."""
        # Emerges from: Complex problems requiring multiple steps
        example = """
        <think>
        This problem has several parts:
        1. First, I need to find X
        2. Then, use X to calculate Y
        3. Finally, combine to get the answer

        Starting with part 1...
        </think>
        """

    @staticmethod
    def metacognition():
        """Model learns to reason about its own reasoning."""
        # Emerges from: Self-consistency and error correction
        example = """
        <think>
        I'm not confident about this step. Let me reconsider.

        The logic seems sound, but I should double-check the arithmetic.

        Actually, I made an error earlier. Let me correct it.
        </think>
        """

Distillation: Compressing Reasoning

Knowledge Distillation for Reasoning

Distillation transfers reasoning capabilities from large models to smaller ones.

Why distillation enables practical deployment: Training a 70B parameter reasoning model requires massive compute and produces a model too large for most deployment scenarios. Distillation offers an alternative: train the large model once, then transfer its capabilities to smaller models. A well-distilled 7B model can achieve 80-90% of the 70B model's reasoning performance at a fraction of the inference cost. This is how reasoning models become practical for real applications.

The mechanism of knowledge transfer: Unlike retraining from scratch, distillation gives the student a "cheat sheet"—the teacher's solutions. The student doesn't need to discover reasoning patterns through costly exploration; it just needs to imitate patterns the teacher already found. This is far more sample-efficient. The teacher's reasoning traces act as demonstrations, and the student learns via supervised fine-tuning on these traces.

Beyond mere imitation: Good distillation transfers not just answers but reasoning styles. When a student trains on thousands of teacher-generated proofs, it internalizes patterns: "when you see this type of problem, try this approach." The student may not understand why these patterns work, but it can replicate them. Interestingly, distilled students sometimes outperform teachers on specific task types—the compression forces the student to extract the most reliable patterns, pruning away teacher idiosyncrasies.

Python
class ReasoningDistillation:
    """
    Distill reasoning capabilities from teacher to student.

    DeepSeek released distilled models at 1.5B, 7B, 8B, 14B, 32B, 70B.
    """

    def __init__(
        self,
        teacher_model: nn.Module,
        student_model: nn.Module,
        tokenizer,
        temperature: float = 2.0
    ):
        self.teacher = teacher_model
        self.student = student_model
        self.tokenizer = tokenizer
        self.temperature = temperature

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False

    def generate_reasoning_traces(
        self,
        problems: list[str],
        n_per_problem: int = 1
    ) -> list[dict]:
        """
        Generate reasoning traces from teacher.

        These become training data for student.
        """
        traces = []

        for problem in problems:
            for _ in range(n_per_problem):
                with torch.no_grad():
                    response = self.teacher.generate(
                        f"{problem}\n<think>",
                        max_new_tokens=2048,
                        temperature=0.7,
                        do_sample=True
                    )

                traces.append({
                    "problem": problem,
                    "reasoning": response
                })

        return traces

    def distill_on_traces(
        self,
        traces: list[dict],
        epochs: int = 3,
        learning_rate: float = 1e-5
    ):
        """
        Train student to reproduce teacher's reasoning.

        Standard SFT on teacher-generated traces.
        """
        optimizer = torch.optim.AdamW(
            self.student.parameters(),
            lr=learning_rate
        )

        self.student.train()

        for epoch in range(epochs):
            total_loss = 0

            for trace in traces:
                # Format input
                text = f"Problem: {trace['problem']}\n{trace['reasoning']}"
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=2048
                ).to(self.student.device)

                # Forward pass
                outputs = self.student(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    labels=inputs.input_ids
                )

                loss = outputs.loss

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}, Loss: {total_loss / len(traces):.4f}")

    def distill_with_kl(
        self,
        prompts: list[str],
        epochs: int = 3
    ):
        """
        Distillation using KL divergence on logits.

        More sophisticated than trace-based distillation.
        """
        optimizer = torch.optim.AdamW(
            self.student.parameters(),
            lr=1e-5
        )

        for epoch in range(epochs):
            total_loss = 0

            for prompt in prompts:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt"
                ).to(self.student.device)

                # Generate from teacher
                with torch.no_grad():
                    teacher_outputs = self.teacher.generate(
                        **inputs,
                        max_new_tokens=512,
                        return_dict_in_generate=True,
                        output_scores=True
                    )
                    teacher_logits = torch.stack(teacher_outputs.scores, dim=1)

                # Student forward on same tokens
                full_input = torch.cat([
                    inputs.input_ids,
                    teacher_outputs.sequences[:, inputs.input_ids.size(1):]
                ], dim=1)

                student_outputs = self.student(full_input)
                student_logits = student_outputs.logits[:, inputs.input_ids.size(1)-1:-1, :]

                # KL divergence loss
                teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
                student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)

                kl_loss = F.kl_div(
                    student_log_probs,
                    teacher_probs,
                    reduction='batchmean'
                ) * (self.temperature ** 2)

                # Update
                optimizer.zero_grad()
                kl_loss.backward()
                optimizer.step()

                total_loss += kl_loss.item()

            print(f"Epoch {epoch+1}, KL Loss: {total_loss / len(prompts):.4f}")


Trace-based vs KL-based distillation: The class shows two approaches. distill_on_traces is simpler: generate reasoning traces from the teacher, then train the student with standard language modeling loss to reproduce them. This is just SFT on teacher-generated data. distill_with_kl is more sophisticated: it matches the student's probability distribution to the teacher's at each token. KL-based distillation transfers more information (not just "what tokens to generate" but "how confident to be about each option") but is computationally heavier.

The temperature parameter in distillation: Higher temperature (2.0+) during trace generation produces more diverse reasoning paths. The student sees varied approaches, not just the teacher's greedy decode. This diversity acts as data augmentation, helping the student generalize. During KL distillation, temperature softens the probability distributions, making them easier to match—the student doesn't need to replicate exact confidence levels, just the rough shape of the distribution.
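
A minimal sketch (with made-up logits) of what temperature does to the target distribution:

Python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 1.0, 0.5])  # hypothetical next-token logits

for T in (1.0, 2.0):
    probs = F.softmax(logits / T, dim=-1)
    print(f"T={T}:", [round(p, 3) for p in probs.tolist()])

# T=1.0 is sharply peaked on the top token; T=2.0 spreads mass across
# alternatives, giving the student a softer target to match.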

Why multiple traces per problem: Generating n_per_problem > 1 traces creates diversity even for the same problem. Different random samples lead to different reasoning paths that arrive at the same answer. This teaches the student that there are multiple valid approaches, making it more robust than a student that memorizes a single solution pattern per problem type.

Progressive distillation rationale: Distilling directly from 70B to 1.5B loses significant capability—the capacity gap is too large. Progressive distillation chains the process: 70B → 32B → 7B → 1.5B. Each step bridges a smaller gap, preserving more capability. The 32B student learns from 70B traces, then becomes the teacher for 7B. Each intermediate model is optimized for its capacity, extracting what's transferable and discarding what isn't. This cascade typically preserves 10-20% more capability than direct distillation at the smallest scales.

Python
class ProgressiveDistillation:
    """
    Distill in stages for better results.

    Teacher (70B) → Medium (32B) → Small (7B) → Tiny (1.5B)
    """

    def __init__(self, model_sizes: list[str]):
        self.model_sizes = model_sizes  # e.g., ["70B", "32B", "7B", "1.5B"]

    def progressive_distill(self, problems: list[str]):
        """
        Chain of distillation.

        Each model teaches the next smaller one.
        """
        current_teacher = load_model(self.model_sizes[0])

        for i, size in enumerate(self.model_sizes[1:], 1):
            print(f"Distilling {self.model_sizes[i-1]}{size}")

            student = load_model(size, initialize_fresh=True)

            distiller = ReasoningDistillation(
                teacher_model=current_teacher,
                student_model=student,
                tokenizer=self.tokenizer
            )

            # Generate traces
            traces = distiller.generate_reasoning_traces(
                problems,
                n_per_problem=3
            )

            # Train student
            distiller.distill_on_traces(traces, epochs=5)

            # Student becomes teacher for next stage
            current_teacher = student

            # Save checkpoint
            student.save_pretrained(f"distilled_{size}")

        return current_teacher

Distillation Results

From DeepSeek's distilled models:

| Model | Base | Math (AIME) | Code (HumanEval) | Notes |
| --- | --- | --- | --- | --- |
| R1-Distill-Qwen-1.5B | Qwen2.5-1.5B | 28.9% | 52.4% | Tiny but capable |
| R1-Distill-Qwen-7B | Qwen2.5-7B | 55.5% | 71.3% | Good balance |
| R1-Distill-Llama-8B | Llama-3.1-8B | 50.4% | 72.6% | Alternative base |
| R1-Distill-Qwen-14B | Qwen2.5-14B | 69.7% | 80.5% | Strong performance |
| R1-Distill-Qwen-32B | Qwen2.5-32B | 72.6% | 85.4% | Near full model |
| R1-Distill-Llama-70B | Llama-3.3-70B | 79.2% | 86.5% | Best distilled |

Key observation: Even 7B distilled models achieve competitive reasoning performance!

Practical Implementation Guide

Choosing Your Training Approach

Python
def choose_training_approach(
    has_preference_data: bool,
    has_verifiable_rewards: bool,
    compute_budget: str,  # "low", "medium", "high"
    model_size: str  # "small", "medium", "large"
) -> str:
    """
    Decision tree for choosing training approach.
    """
    if compute_budget == "low":
        if has_preference_data:
            return "DPO"  # No reward model needed
        else:
            return "SFT"  # Just supervised finetuning

    if has_verifiable_rewards:
        if model_size == "large":
            return "GRPO"  # Memory efficient
        else:
            return "RLVR"  # Verifiable rewards

    if has_preference_data:
        if compute_budget == "high":
            return "PPO"  # Full RLHF
        else:
            return "DPO"  # Simpler alternative

    # Default
    return "SFT with rejection sampling"


def estimate_training_resources(
    approach: str,
    model_params: int,  # in billions
    dataset_size: int  # number of examples
) -> dict:
    """
    Estimate compute and memory requirements.
    """
    estimates = {
        "PPO": {
            "gpu_memory_gb": model_params * 8,  # Policy + Value + Reference
            "training_hours_a100": model_params * 0.1 * dataset_size / 10000,
            "complexity": "high"
        },
        "GRPO": {
            "gpu_memory_gb": model_params * 4,  # Policy + Reference only
            "training_hours_a100": model_params * 0.08 * dataset_size / 10000,
            "complexity": "medium"
        },
        "DPO": {
            "gpu_memory_gb": model_params * 4,  # Policy + Reference
            "training_hours_a100": model_params * 0.05 * dataset_size / 10000,
            "complexity": "low"
        },
        "RLVR": {
            "gpu_memory_gb": model_params * 4,
            "training_hours_a100": model_params * 0.08 * dataset_size / 10000,
            "complexity": "medium"
        },
        "SFT": {
            "gpu_memory_gb": model_params * 2,  # Just policy
            "training_hours_a100": model_params * 0.02 * dataset_size / 10000,
            "complexity": "low"
        }
    }

    return estimates.get(approach, estimates["SFT"])
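
For example, using the rough heuristics encoded above (the numbers are ballpark estimates, not measurements):

Python
approach = choose_training_approach(
    has_preference_data=False,
    has_verifiable_rewards=True,
    compute_budget="medium",
    model_size="medium"
)
print(approach)  # "RLVR"

resources = estimate_training_resources(
    approach=approach,
    model_params=7,         # 7B-parameter model
    dataset_size=100_000    # verifiable problems
)
print(resources)
# approximately: {'gpu_memory_gb': 28, 'training_hours_a100': 5.6, 'complexity': 'medium'}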

Training Pipeline Template

Python
class ReasoningModelTrainingPipeline:
    """
    Complete pipeline for training a reasoning model.
    """

    def __init__(
        self,
        base_model: str,
        approach: str = "GRPO",
        output_dir: str = "./trained_model"
    ):
        self.base_model = base_model
        self.approach = approach
        self.output_dir = output_dir

    def run(
        self,
        training_data: dict,
        validation_data: dict = None
    ):
        """
        Run complete training pipeline.

        training_data format depends on approach:
        - PPO/GRPO: {"prompts": [...], "ground_truth": [...]} for verifiable
        - DPO: {"prompt": [...], "chosen": [...], "rejected": [...]}
        - SFT: {"prompt": [...], "response": [...]}
        """
        # 1. Load and prepare model
        print("Loading model...")
        model = self.load_model()

        # 2. (Optional) Cold start SFT
        if self.approach in ["GRPO", "PPO"]:
            print("Cold start SFT...")
            model = self.cold_start_sft(model, training_data)

        # 3. Main training
        print(f"Main training with {self.approach}...")
        model = self.main_training(model, training_data)

        # 4. Evaluation
        if validation_data:
            print("Evaluating...")
            metrics = self.evaluate(model, validation_data)
            print(f"Validation metrics: {metrics}")

        # 5. Save
        print(f"Saving to {self.output_dir}...")
        model.save_pretrained(self.output_dir)

        return model

    def load_model(self):
        """Load base model with appropriate settings."""
        model = AutoModelForCausalLM.from_pretrained(
            self.base_model,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        return model

    def cold_start_sft(self, model, data):
        """Light SFT before RL."""
        # Use ~1000-5000 high-quality examples
        sft_examples = data.get("cold_start", data["prompts"][:1000])

        trainer = SFTTrainer(
            model=model,
            train_dataset=sft_examples,
            max_steps=500
        )
        trainer.train()

        return model

    def main_training(self, model, data):
        """Main training based on approach."""
        if self.approach == "GRPO":
            return self.train_grpo(model, data)
        elif self.approach == "PPO":
            return self.train_ppo(model, data)
        elif self.approach == "DPO":
            return self.train_dpo(model, data)
        elif self.approach == "RLVR":
            return self.train_rlvr(model, data)
        else:
            raise ValueError(f"Unknown approach: {self.approach}")

    def train_grpo(self, model, data):
        """GRPO training."""
        # Create reward function
        if "ground_truth" in data:
            reward_fn = RuleBasedRewardSystem()
        else:
            reward_fn = LearnedRewardModel.from_pretrained("...")

        trainer = GRPOTrainer(
            policy_model=model,
            reference_model=model.copy(),
            reward_model=reward_fn,
            tokenizer=self.tokenizer,
            config=GRPOConfig()
        )

        trainer.train(data["prompts"], num_iterations=10000)
        return model

    def evaluate(self, model, validation_data):
        """Evaluate trained model."""
        correct = 0
        total = 0

        for problem, answer in zip(validation_data["prompts"], validation_data["ground_truth"]):
            response = model.generate(problem, max_new_tokens=1024)
            predicted = extract_answer(response)

            if predicted == answer:
                correct += 1
            total += 1

        return {
            "accuracy": correct / total,
            "total": total
        }

Conclusion

Training reasoning models involves three key innovations:

  1. Reward Functions: Rule-based verification enables scalable training without learned reward models
  2. Efficient Algorithms: GRPO eliminates the value model, halving memory requirements
  3. Emergent Capabilities: Reasoning behaviors emerge naturally from optimizing for answer correctness

Key takeaways:

  • Start simple: Rule-based rewards (RLVR) are often sufficient for verifiable domains
  • GRPO > PPO for reasoning: Same quality, half the memory
  • Distillation works: 7B models can achieve strong reasoning through distillation
  • Cold start matters: A few thousand examples before RL improves stability
  • DPO for preferences: When you have preference data but no verifiable rewards

The field is evolving rapidly—2025 will likely bring further innovations in training efficiency and capability transfer.


Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
