
Transformers for Recommendation Systems: From SASRec to HSTU

A comprehensive deep dive into transformer-based recommendation systems. From the fundamentals of sequential recommendation to Meta's trillion-parameter HSTU, understand how attention mechanisms revolutionized personalization.


Why Transformers Changed Recommendation Systems

The success of transformers in NLP sparked a revolution in recommendation systems. Before 2018, sequential recommendation was dominated by RNNs and Markov chains. Then came SASRec, BERT4Rec, and a parade of transformer-based models that consistently outperformed their predecessors.

But why do transformers work so well for recommendations? The answer lies in what sequential recommendation actually requires:

  1. Capturing user intent from behavior sequences: Users don't interact with items randomly. Each click, purchase, or watch reveals preferences that evolve over time.

  2. Modeling complex dependencies: A user who bought running shoes, then a fitness tracker, then protein powder is on a fitness journey. These items relate to each other across many steps.

  3. Handling variable-length histories: Some users have 10 interactions, others have 10,000. The model must work for both.

  4. Real-time inference at scale: Recommendations must be computed in milliseconds for millions of users.

Transformers address all of these challenges through their attention mechanism—allowing direct connections between any two items in a user's history, regardless of how far apart they occurred.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                EVOLUTION OF SEQUENTIAL RECOMMENDATION                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  BEFORE TRANSFORMERS (2010-2017):                                       │
│  ─────────────────────────────────                                       │
│                                                                          │
│  Markov Chains     →  item₁ → item₂ → item₃ → ?                        │
│  (Only last item)     Limited to previous transition                    │
│                                                                          │
│  RNNs/GRUs/LSTMs   →  item₁ → h₁ → item₂ → h₂ → item₃ → h₃ → ?       │
│  (Sequential)         Hidden state compresses history                   │
│                       Information degrades over distance                │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TRANSFORMER ERA (2018-NOW):                                             │
│  ───────────────────────────                                             │
│                                                                          │
│  Self-Attention    →  item₁ ←──────────────────┐                       │
│  (All-to-all)         item₂ ←───────────┐      │                       │
│                       item₃ ←────┐      │      │                       │
│                         ?    ────┴──────┴──────┘                       │
│                                                                          │
│                       Every item directly attends to every other        │
│                       No information bottleneck                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY THIS MATTERS FOR RECOMMENDATIONS:                                   │
│                                                                          │
│  User history: [iPhone case, AirPods, MacBook charger, ..., iPad]      │
│                                                                          │
│  When predicting next item after "iPad":                                │
│  - Transformer sees ALL Apple products directly                         │
│  - RNN would have compressed early items into hidden state              │
│  - Markov only sees "iPad"                                              │
│                                                                          │
│  The "Apple ecosystem" pattern spans the entire sequence                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

2024-2025 State of the Art: Meta's HSTU (Hierarchical Sequential Transduction Units) demonstrated that transformers can scale to the trillion-parameter regime for recommendations, achieving 12.4%+ topline metric improvements in production. The key innovations: removing softmax normalization and using a generative (next-item prediction) training objective.


Part I: Foundations of Sequential Recommendation

The Sequential Recommendation Problem

Given a user's interaction history $S = (s_1, s_2, \ldots, s_t)$, predict the next item $s_{t+1}$ they will interact with. This seems simple, but the challenges are significant:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                SEQUENTIAL RECOMMENDATION SETUP                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  INPUT: User interaction sequence                                        │
│  ────────────────────────────────                                        │
│                                                                          │
│  User A: [item_23, item_156, item_89, item_42, item_301, ...]          │
│           ↑         ↑          ↑        ↑         ↑                     │
│          t=1       t=2        t=3      t=4       t=5                    │
│                                                                          │
│  Each item is typically represented as:                                  │
│  - Item ID (integer index into embedding table)                         │
│  - Optional: item features (category, price, etc.)                      │
│                                                                          │
│  OUTPUT: Probability distribution over all items                         │
│  ───────────────────────────────────────────────                         │
│                                                                          │
│  P(next_item | history) = softmax(scores over all items)                │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  CHALLENGES:                                                             │
│                                                                          │
│  1. SCALE                                                                │
│     - Millions of items, billions of users                              │
│     - Cannot compute full softmax over all items                        │
│     - Need efficient retrieval/ranking stages                           │
│                                                                          │
│  2. SPARSITY                                                             │
│     - Most users interact with tiny fraction of items                   │
│     - Many items have few interactions (long-tail)                      │
│     - Cold-start for new users/items                                    │
│                                                                          │
│  3. DYNAMICS                                                             │
│     - User preferences change over time                                 │
│     - New items added continuously                                      │
│     - Seasonal and trending effects                                     │
│                                                                          │
│  4. LATENCY                                                              │
│     - Recommendations must be computed in <100ms                        │
│     - Often much stricter (10-20ms)                                     │
│     - At massive scale (millions of QPS)                                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
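
To make this setup concrete, here is a minimal sketch of how one interaction history is typically turned into a next-item-prediction training example: the input is the history shifted by one position, the target is the item that actually came next, and shorter histories are left-padded. The item IDs and max_len below are made up for illustration, and ID 0 is assumed to be reserved for padding.

Python
import torch

def build_training_example(history: list[int], max_len: int = 8) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn one interaction history into (input, target) tensors for
    next-item prediction. Item ID 0 is assumed reserved for padding."""
    inputs = history[:-1][-max_len:]   # everything except the last item
    targets = history[1:][-max_len:]   # the same sequence shifted left by one
    pad = max_len - len(inputs)
    inputs = [0] * pad + inputs        # left-pad so examples can be batched
    targets = [0] * pad + targets      # padded targets are ignored in the loss
    return torch.tensor(inputs), torch.tensor(targets)

# Hypothetical item IDs for one user
history = [23, 156, 89, 42, 301]
x, y = build_training_example(history)
print(x)  # tensor([  0,   0,   0,   0,  23, 156,  89,  42])
print(y)  # tensor([  0,   0,   0,   0, 156,  89,  42, 301])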

From Matrix Factorization to Deep Learning

Before diving into transformers, it's worth understanding what came before:

Matrix Factorization (2009): The Netflix Prize era. Represent users and items as latent vectors, predict ratings as dot products. Simple, interpretable, but static—doesn't capture sequential patterns.

RNN-based Models (2016-2018): GRU4Rec and variants processed sequences with recurrent networks. Better at capturing order, but suffered from vanishing gradients and sequential bottlenecks.

CNN-based Models (2018): Caser used convolutional filters to capture local patterns in sequences. Fast, but limited receptive field.

Attention-based Models (2018+): SASRec, BERT4Rec, and descendants. This is where we'll focus.


Part II: Transformer Architecture Deep Dive

Before implementing SASRec and BERT4Rec, we need to thoroughly understand the transformer architecture. This section covers every component in detail with mathematical formulations, intuitions, and implementation examples.

The Attention Mechanism: Core Intuition

Attention answers: "Which parts of the input should I focus on when producing this output?" In recommendation, this becomes: "Which past items are relevant for predicting the next item?"

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    ATTENTION MECHANISM INTUITION                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  USER HISTORY: [Running Shoes, Fitness Tracker, Protein Powder, ...]    │
│                                                                          │
│  Question: "What should I recommend next?"                              │
│                                                                          │
│  ATTENTION PROCESS:                                                      │
│  ─────────────────                                                       │
│                                                                          │
│  Step 1: Create a QUERY from current position                           │
│          "I'm looking for items related to fitness..."                  │
│                                                                          │
│  Step 2: Compare query against all past items (KEYS)                    │
│          Running Shoes   → High relevance (fitness)                     │
│          Fitness Tracker → High relevance (fitness)                     │
│          Phone Case      → Low relevance (unrelated)                    │
│          Protein Powder  → High relevance (fitness)                     │
│                                                                          │
│  Step 3: Weight VALUES by relevance scores                              │
│          Output = 0.35×(Running Shoes) + 0.30×(Tracker) +              │
│                   0.05×(Phone Case) + 0.30×(Protein)                   │
│                                                                          │
│  Result: Context-aware representation for prediction                    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Scaled Dot-Product Attention: The Mathematics

The attention mechanism is the heart of transformer models, and understanding it deeply is essential for building effective recommendation systems. At its core, attention computes a weighted average of values, where the weights are determined by how well queries match keys.

Why "Scaled Dot-Product"?

The name describes exactly what happens: we compute dot products between queries and keys (measuring similarity), then scale the result. The dot product is a natural similarity measure—when two vectors point in similar directions, their dot product is large. When they're orthogonal (unrelated), the dot product is zero. This is precisely what we want: items that are "similar" to what we're looking for should have high attention weights.

The fundamental attention operation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Understanding Each Component:

  • Query (Q): "What am I looking for?" In recommendations, this represents the current context or position asking "what should come next?"
  • Key (K): "What do I have to offer?" Each item in the history advertises itself through its key representation.
  • Value (V): "What information do I carry?" The actual content that gets aggregated based on attention weights.
  • Scaling by $\sqrt{d_k}$: Critical for stable training. Without scaling, dot products grow with dimension, making softmax outputs extremely peaked (close to one-hot). This starves gradients and makes learning difficult.

The Information Flow:

Think of attention as a soft database lookup. The query asks a question, keys determine relevance (which database entries match?), and values provide the answer (what's stored in those entries?). Unlike hard lookups that return one result, soft attention returns a weighted blend of all values, proportional to relevance.

In recommendation systems specifically:

  • Queries come from the position we're predicting for
  • Keys and Values come from the user's interaction history
  • The output tells us: "Given where the user is in their journey, what aspects of their history are most relevant?"
Code
┌─────────────────────────────────────────────────────────────────────────┐
│              SCALED DOT-PRODUCT ATTENTION: STEP BY STEP                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  EXAMPLE: Sequence of 4 items, embedding dim = 8                        │
│                                                                          │
│  INPUT: X ∈ ℝ^(4×8)                                                     │
│  ─────                                                                   │
│  Item 1: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]                      │
│  Item 2: [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]                      │
│  Item 3: [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]                      │
│  Item 4: [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1]                      │
│                                                                          │
│  STEP 1: Linear projections to Q, K, V                                  │
│  ──────────────────────────────────────                                  │
│                                                                          │
│  Q = X × W_Q    where W_Q ∈ ℝ^(8×8)                                    │
│  K = X × W_K    where W_K ∈ ℝ^(8×8)                                    │
│  V = X × W_V    where W_V ∈ ℝ^(8×8)                                    │
│                                                                          │
│  Each item now has query, key, and value representations                │
│                                                                          │
│  STEP 2: Compute attention scores (QK^T)                                │
│  ─────────────────────────────────────────                               │
│                                                                          │
│  Scores = Q × K^T ∈ ℝ^(4×4)                                            │
│                                                                          │
│           │ Item1  Item2  Item3  Item4                                  │
│    ───────┼────────────────────────────                                 │
│    Item1  │  0.82   0.76   0.71   0.65   ← How much Item1 attends to   │
│    Item2  │  0.76   0.89   0.84   0.79      each other item            │
│    Item3  │  0.71   0.84   0.95   0.90                                  │
│    Item4  │  0.65   0.79   0.90   1.02                                  │
│                                                                          │
│  STEP 3: Scale by √d_k                                                  │
│  ────────────────────────                                                │
│                                                                          │
│  Why scale? Dot products grow with dimension, making softmax            │
│  extremely peaked. Division by √d_k keeps gradients stable.             │
│                                                                          │
│  Scaled = Scores / √8 = Scores / 2.83                                  │
│                                                                          │
│  STEP 4: Apply causal mask (for autoregressive)                         │
│  ──────────────────────────────────────────────                          │
│                                                                          │
│           │ Item1  Item2  Item3  Item4                                  │
│    ───────┼────────────────────────────                                 │
│    Item1  │  0.29   -∞     -∞     -∞     ← Item1 only sees itself      │
│    Item2  │  0.27   0.31   -∞     -∞     ← Item2 sees 1,2              │
│    Item3  │  0.25   0.30   0.34   -∞     ← Item3 sees 1,2,3            │
│    Item4  │  0.23   0.28   0.32   0.36   ← Item4 sees all              │
│                                                                          │
│  STEP 5: Softmax (row-wise normalization)                               │
│  ─────────────────────────────────────────                               │
│                                                                          │
│           │ Item1  Item2  Item3  Item4                                  │
│    ───────┼────────────────────────────                                 │
│    Item1  │  1.00   0.00   0.00   0.00   ← Weights sum to 1            │
│    Item2  │  0.49   0.51   0.00   0.00                                  │
│    Item3  │  0.30   0.33   0.37   0.00                                  │
│    Item4  │  0.22   0.24   0.26   0.28                                  │
│                                                                          │
│  STEP 6: Weighted sum of values                                         │
│  ─────────────────────────────────                                       │
│                                                                          │
│  Output[i] = Σ_j (attention_weights[i,j] × V[j])                       │
│                                                                          │
│  Output[4] = 0.22×V[1] + 0.24×V[2] + 0.26×V[3] + 0.28×V[4]            │
│                                                                          │
│  Each output is a weighted combination of all visible values            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Why Causal Masking Matters for Recommendations:

The causal mask (Step 4 above) is crucial for sequential recommendation. When predicting what a user will do next, we can only use information from items they've already interacted with—we can't "peek" at future interactions. This creates an autoregressive setup where each position only attends to previous positions.

This is different from bidirectional models like BERT4Rec, which allow items to attend to both past and future positions (useful during training with masked item prediction, but requires special handling at inference time).
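
As a minimal sketch of the difference, the snippet below builds a SASRec-style causal mask and a BERT4Rec-style randomly masked input. The 20% masking rate and the reserved ID 0 are illustrative assumptions, not values from either paper.

Python
import torch

seq_len = 5

# Causal mask (SASRec-style): True = masked, so position i can only attend
# to positions <= i. This matches the masking convention used further below.
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print(causal_mask)

# BERT4Rec-style training instead hides random items and lets every position
# attend in both directions; only the hidden positions are predicted.
item_ids = torch.tensor([23, 156, 89, 42, 301])
is_masked = torch.rand(seq_len) < 0.2          # illustrative masking rate
masked_input = item_ids.masked_fill(is_masked, 0)  # 0 stands in for [MASK]
print(masked_input)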

The Softmax Temperature:

The scaling factor $\sqrt{d_k}$ acts as a "temperature" for the softmax. A larger divisor (higher temperature) makes the distribution more uniform; a smaller divisor (lower temperature) makes it more peaked. The default $\sqrt{d_k}$ was found empirically to work well, but some applications tune this as a hyperparameter.
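
To see the effect of the divisor concretely, here is a small sketch (dimensions and the number of candidate keys chosen arbitrarily) comparing softmax weights with and without the √d_k scaling as dimension grows: unscaled scores become nearly one-hot, while scaled scores stay spread out.

Python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
for d_k in (8, 64, 512):
    q = torch.randn(d_k)
    keys = torch.randn(4, d_k)               # 4 candidate keys
    scores = keys @ q                        # raw dot products; variance grows with d_k
    unscaled = F.softmax(scores, dim=-1)
    scaled = F.softmax(scores / d_k ** 0.5, dim=-1)
    print(f"d_k={d_k:3d}  max weight unscaled={unscaled.max().item():.3f}  "
          f"scaled={scaled.max().item():.3f}")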

Implementation Considerations:

When implementing attention, several practical concerns arise:

  • Numerical stability: Subtracting the max before softmax prevents overflow
  • Memory efficiency: For long sequences, attention matrices can be huge (n² memory)
  • Masked positions: Setting to -inf before softmax ensures zero attention weight
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention with detailed implementation.

    This is the core building block of all transformer models.
    Understanding this deeply is essential for RecSys transformers.
    """

    def __init__(self, temperature: float | None = None):
        super().__init__()
        self.temperature = temperature  # If None, computed from d_k

    def forward(
        self,
        query: torch.Tensor,      # (batch, n_query, d_k)
        key: torch.Tensor,        # (batch, n_key, d_k)
        value: torch.Tensor,      # (batch, n_key, d_v)
        mask: torch.Tensor | None = None,  # (batch, n_query, n_key) or broadcastable
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: What we're looking for
            key: What we're comparing against
            value: What we retrieve
            mask: True = masked (will not attend), False = can attend

        Returns:
            output: (batch, n_query, d_v) - weighted sum of values
            attention_weights: (batch, n_query, n_key) - attention distribution
        """
        d_k = query.size(-1)
        temperature = self.temperature or math.sqrt(d_k)

        # Step 1: Compute attention scores
        # (batch, n_query, d_k) @ (batch, d_k, n_key) -> (batch, n_query, n_key)
        scores = torch.matmul(query, key.transpose(-2, -1))

        # Step 2: Scale by temperature (sqrt(d_k))
        scores = scores / temperature

        # Step 3: Apply mask (set masked positions to -inf)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        # Step 4: Softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)

        # Handle case where entire row is masked (all -inf -> nan after softmax)
        attention_weights = attention_weights.masked_fill(
            torch.isnan(attention_weights), 0.0
        )

        # Step 5: Weighted sum of values
        output = torch.matmul(attention_weights, value)

        # Attention weights are always returned alongside the output; they are
        # cheap to keep and useful for visualization and debugging.
        return output, attention_weights


# Demonstration
def demonstrate_attention():
    """Step-by-step demonstration of attention computation."""

    # Small example: 4 items, 8-dimensional embeddings
    batch_size = 1
    seq_len = 4
    d_model = 8

    # Random input (simulating item embeddings)
    torch.manual_seed(42)
    x = torch.randn(batch_size, seq_len, d_model)

    # Linear projections (normally these are learned)
    W_q = nn.Linear(d_model, d_model, bias=False)
    W_k = nn.Linear(d_model, d_model, bias=False)
    W_v = nn.Linear(d_model, d_model, bias=False)

    Q = W_q(x)  # (1, 4, 8)
    K = W_k(x)  # (1, 4, 8)
    V = W_v(x)  # (1, 4, 8)

    print("Query shape:", Q.shape)
    print("Key shape:", K.shape)
    print("Value shape:", V.shape)

    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (1, 4, 4)
    print("\nRaw attention scores:\n", scores[0])

    # Scale
    scaled_scores = scores / math.sqrt(d_model)
    print("\nScaled scores:\n", scaled_scores[0])

    # Create causal mask
    causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    print("\nCausal mask (True = masked):\n", causal_mask)

    # Apply mask
    masked_scores = scaled_scores.masked_fill(causal_mask, float('-inf'))
    print("\nMasked scores:\n", masked_scores[0])

    # Softmax
    attention_weights = F.softmax(masked_scores, dim=-1)
    print("\nAttention weights (sum to 1 per row):\n", attention_weights[0])
    print("Row sums:", attention_weights[0].sum(dim=-1))

    # Output
    output = torch.matmul(attention_weights, V)
    print("\nOutput shape:", output.shape)

    return output, attention_weights


# Run demonstration
output, weights = demonstrate_attention()

Multi-Head Attention: Capturing Different Relationships

A single attention head can only learn one way of relating items to each other. But user behavior is complex—items relate through categories, brands, complementarity, price ranges, and many other dimensions simultaneously. Multi-head attention solves this by running multiple attention operations in parallel, each learning different relationship patterns.

The Key Insight:

Imagine analyzing a user who bought running shoes, a fitness tracker, and protein powder. Different aspects of these items matter for different predictions:

  • Category view: All are fitness-related → recommend more fitness items
  • Complementary view: Shoes need socks, tracker needs charger → recommend accessories
  • Price tier view: All mid-range → recommend similar price points
  • Brand view: Prefers Nike → recommend more Nike products
  • Temporal view: Bought in sequence suggesting workout routine → recommend recovery items

A single attention head might capture one of these perspectives, but we want them all. Multi-head attention dedicates separate "heads" to learn these different relationship types, then combines their insights.

How It Works Mathematically:

Instead of one large attention operation, we run h smaller ones in parallel. Each head operates on a lower-dimensional subspace (d_k = d_model / h), learns its own Q/K/V projections, and captures its own relationship type. The outputs are concatenated and projected back to the original dimension.

This isn't just parallelization for efficiency—it's fundamentally about representational diversity. Different heads learn different patterns, and the combination is more expressive than a single head of the same total dimension would be.

Practical Implications for Recommendations:

Research has shown that different attention heads in trained RecSys models actually do specialize:

  • Some heads focus on recency (high weight on recent items)
  • Some heads focus on category similarity
  • Some heads learn positional patterns (periodic preferences)
  • Some heads capture user-specific idiosyncrasies

This interpretability is valuable—you can inspect attention patterns to understand why the model made a recommendation.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                     MULTI-HEAD ATTENTION                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  WHY MULTIPLE HEADS?                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  Single head might learn: "Items from same category"                    │
│                                                                          │
│  But we want to capture multiple relationship types:                    │
│  • Head 1: Category similarity (shoes → shoes)                          │
│  • Head 2: Complementary items (shoes → socks)                          │
│  • Head 3: Price range similarity                                       │
│  • Head 4: Temporal patterns (morning → evening items)                  │
│  • Head 5: Brand affinity                                               │
│  • Head 6: Style similarity                                             │
│  • Head 7: Seasonal patterns                                            │
│  • Head 8: Cross-category bundles                                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ARCHITECTURE:                                                           │
│                                                                          │
│  Input X ∈ ℝ^(seq_len × d_model)                                       │
│                                                                          │
│         ┌──────────────────────────────────────────────────────┐       │
│         │                                                        │       │
│         ▼                                                        ▼       │
│    ┌─────────┐  ┌─────────┐  ┌─────────┐      ┌─────────┐            │
│    │ Head 1  │  │ Head 2  │  │ Head 3  │ ...  │ Head h  │            │
│    │ d_k=64  │  │ d_k=64  │  │ d_k=64  │      │ d_k=64  │            │
│    └────┬────┘  └────┬────┘  └────┬────┘      └────┬────┘            │
│         │            │            │                 │                  │
│         └──────┬─────┴─────┬──────┴─────────┬──────┘                  │
│                │           │                │                          │
│                ▼           ▼                ▼                          │
│         ┌─────────────────────────────────────┐                       │
│         │     Concatenate: (seq_len, h×d_k)   │                       │
│         └─────────────────┬───────────────────┘                       │
│                           │                                            │
│                           ▼                                            │
│                   ┌───────────────┐                                   │
│                   │ Linear W_O    │                                   │
│                   │ (h×d_k → d_m) │                                   │
│                   └───────┬───────┘                                   │
│                           │                                            │
│                           ▼                                            │
│                  Output ∈ ℝ^(seq_len × d_model)                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MATHEMATICS:                                                            │
│                                                                          │
│  head_i = Attention(XW_Q^i, XW_K^i, XW_V^i)                            │
│                                                                          │
│  MultiHead(X) = Concat(head_1, ..., head_h) × W_O                      │
│                                                                          │
│  Where:                                                                  │
│  • W_Q^i, W_K^i ∈ ℝ^(d_model × d_k)                                   │
│  • W_V^i ∈ ℝ^(d_model × d_v)                                          │
│  • W_O ∈ ℝ^(h×d_v × d_model)                                          │
│  • Typically: d_k = d_v = d_model / h                                  │
│                                                                          │
│  PARAMETER COUNT (d_model=512, h=8):                                    │
│  • Per head: 3 × (512 × 64) = 98,304                                   │
│  • All heads: 8 × 98,304 = 786,432                                     │
│  • Output projection: 512 × 512 = 262,144                              │
│  • Total: 1,048,576 parameters                                         │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention with complete implementation details.

    Each head operates on a subspace of the input, learning different
    attention patterns. Outputs are concatenated and projected.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        bias: bool = True,
    ):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        self.d_v = d_model // num_heads

        # Separate projection matrices for Q, K, V
        # In practice, often combined into one large matrix for efficiency
        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

        # Dropout on attention weights
        self.dropout = nn.Dropout(dropout)

        # For storing attention weights (useful for visualization)
        self.attention_weights = None

    def forward(
        self,
        query: torch.Tensor,      # (batch, seq_len, d_model)
        key: torch.Tensor,        # (batch, seq_len, d_model)
        value: torch.Tensor,      # (batch, seq_len, d_model)
        mask: torch.Tensor | None = None,
        return_attention: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        For self-attention: query = key = value = x
        For cross-attention: query from decoder, key/value from encoder
        """
        batch_size = query.size(0)
        seq_len_q = query.size(1)
        seq_len_k = key.size(1)

        # Step 1: Linear projections
        Q = self.W_q(query)  # (batch, seq_len_q, d_model)
        K = self.W_k(key)    # (batch, seq_len_k, d_model)
        V = self.W_v(value)  # (batch, seq_len_k, d_model)

        # Step 2: Reshape for multi-head: split d_model into num_heads × d_k
        # (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_k)
        # -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_heads, self.d_v).transpose(1, 2)

        # Step 3: Scaled dot-product attention for each head
        # scores: (batch, num_heads, seq_len_q, seq_len_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Step 4: Apply mask
        if mask is not None:
            # Expand mask for num_heads dimension
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_q, seq_k)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, 1, seq_q, seq_k)
            scores = scores.masked_fill(mask, float('-inf'))

        # Step 5: Softmax and dropout
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Store for visualization
        self.attention_weights = attention_weights.detach()

        # Step 6: Apply attention to values
        # (batch, num_heads, seq_len_q, seq_len_k) @ (batch, num_heads, seq_len_k, d_v)
        # -> (batch, num_heads, seq_len_q, d_v)
        context = torch.matmul(attention_weights, V)

        # Step 7: Concatenate heads
        # (batch, num_heads, seq_len_q, d_v) -> (batch, seq_len_q, num_heads, d_v)
        # -> (batch, seq_len_q, d_model)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, self.d_model
        )

        # Step 8: Final linear projection
        output = self.W_o(context)

        if return_attention:
            return output, attention_weights
        return output


def visualize_attention_heads():
    """
    Demonstrate what different attention heads might learn.
    In practice, heads learn these patterns from data.
    """

    # Example: 4 items in sequence
    items = ["Running Shoes", "Yoga Mat", "Protein Powder", "Water Bottle"]

    # Hypothetical attention patterns from different heads
    patterns = {
        "Head 1 (Category)": [
            [1.0, 0.0, 0.0, 0.0],  # Shoes → Shoes
            [0.0, 1.0, 0.0, 0.0],  # Yoga → Yoga
            [0.2, 0.1, 0.7, 0.0],  # Protein → similar to shoes (fitness)
            [0.2, 0.2, 0.2, 0.4],  # Bottle → everything (universal)
        ],
        "Head 2 (Complementary)": [
            [0.3, 0.3, 0.4, 0.0],  # Shoes → Protein (workout stack)
            [0.5, 0.2, 0.3, 0.0],  # Yoga → Shoes (cross-training)
            [0.4, 0.2, 0.2, 0.2],  # Protein → workout gear
            [0.3, 0.3, 0.2, 0.2],  # Bottle → all fitness
        ],
        "Head 3 (Recency)": [
            [1.0, 0.0, 0.0, 0.0],  # Strong recency bias
            [0.2, 0.8, 0.0, 0.0],
            [0.1, 0.2, 0.7, 0.0],
            [0.1, 0.1, 0.3, 0.5],
        ],
        "Head 4 (Position)": [
            [1.0, 0.0, 0.0, 0.0],  # Always attend to first item
            [0.8, 0.2, 0.0, 0.0],
            [0.7, 0.2, 0.1, 0.0],
            [0.6, 0.2, 0.1, 0.1],
        ],
    }

    print("Attention Pattern Visualization")
    print("=" * 60)
    for head_name, pattern in patterns.items():
        print(f"\n{head_name}:")
        print("-" * 40)
        for i, row in enumerate(pattern):
            item = items[i]
            attending_to = [f"{items[j]}({row[j]:.1f})" for j in range(i+1) if row[j] > 0.1]
            print(f"  {item:15} attends to: {', '.join(attending_to)}")

Positional Encodings: Teaching Position Awareness

One of the most counterintuitive aspects of attention is that it is permutation-equivariant: without any additional information, attention treats sequences as unordered sets. The output for a sequence [A, B, C] would be identical to the output for [C, B, A], just reordered. This is a problem because order carries crucial meaning.

Why Order Matters in Recommendations:

Sequential patterns are everywhere in user behavior:

  • Intent evolution: Search → Browse → Compare → Purchase follows a decision journey
  • Session dynamics: Morning coffee browsing differs from evening relaxation browsing
  • Recency effects: Recent interactions predict immediate next actions better than older ones
  • Temporal patterns: Weekly shopping routines, seasonal preferences, time-of-day effects

If the model can't distinguish position, it loses all this information. A user who just purchased running shoes and then browsed socks has very different intent than one who browsed socks and then purchased running shoes.

Positional encodings solve this by adding position information to each item embedding. There are several approaches, each with trade-offs:

Sinusoidal encodings (original Transformer) use sine and cosine functions at different frequencies. They require no learned parameters and can theoretically extrapolate to longer sequences than seen in training. However, they provide fixed patterns that may not match the specific positional importance in your domain.

Learned embeddings (common in RecSys) treat position as another vocabulary to embed. Position 1, 2, 3... each get their own learned vector. This is flexible and learns domain-specific positional patterns, but can't extrapolate beyond the maximum training length.

Relative positional encodings (Transformer-XL, T5) encode the distance between positions rather than absolute positions. This often generalizes better—"3 positions ago" is meaningful regardless of whether you're at position 10 or position 100.

Rotary Position Embeddings (RoPE) (LLaMA, modern LLMs) rotate query and key vectors based on position. This elegantly combines absolute and relative encoding—the dot product naturally depends on relative position. RoPE has become the de facto standard for modern transformers due to excellent extrapolation properties.

For recommendation systems specifically:

Learned embeddings are most common because:

  • Sequences are typically bounded (50-200 items)
  • Different positions can have dramatically different importance (recency effects)
  • The specific positional patterns matter (position 1 might be special—the most recent item)

However, if your sequences can vary significantly in length, or if you expect to serve longer sequences than you trained on, relative or rotary encodings are worth considering.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    POSITIONAL ENCODING METHODS                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  PROBLEM: Attention treats [A, B, C] same as [C, B, A]                  │
│  ───────────────────────────────────────────────────────                 │
│                                                                          │
│  But order matters for recommendations!                                  │
│  [Searched Running Shoes → Added to Cart → Purchased]                   │
│  is very different from                                                  │
│  [Purchased → Added to Cart → Searched Running Shoes]                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  METHOD 1: SINUSOIDAL (Original Transformer)                            │
│  ───────────────────────────────────────────────                         │
│                                                                          │
│  PE(pos, 2i)   = sin(pos / 10000^(2i/d))                               │
│  PE(pos, 2i+1) = cos(pos / 10000^(2i/d))                               │
│                                                                          │
│  Properties:                                                             │
│  ✓ No learned parameters                                                │
│  ✓ Can extrapolate to longer sequences                                  │
│  ✓ Relative positions captured: PE(pos+k) = f(PE(pos))                 │
│  ✗ Fixed patterns, not task-specific                                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  METHOD 2: LEARNED EMBEDDINGS (Common in RecSys)                        │
│  ───────────────────────────────────────────────────                     │
│                                                                          │
│  PE = nn.Embedding(max_seq_len, d_model)                               │
│                                                                          │
│  Properties:                                                             │
│  ✓ Learns task-specific position patterns                              │
│  ✓ Simple to implement                                                  │
│  ✗ Cannot extrapolate beyond max_seq_len                               │
│  ✗ Requires more parameters                                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  METHOD 3: RELATIVE POSITIONAL ENCODING (Transformer-XL, T5)           │
│  ────────────────────────────────────────────────────────────            │
│                                                                          │
│  Instead of absolute positions, encode relative distance:               │
│  "Item A is 3 positions before item B"                                  │
│                                                                          │
│  Attention(Q, K) = softmax(QK^T + bias(i-j))                           │
│                                                                          │
│  Properties:                                                             │
│  ✓ Better generalization to different sequence lengths                 │
│  ✓ Captures "distance" rather than "absolute position"                 │
│  ✗ More complex implementation                                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  METHOD 4: ROTARY POSITION EMBEDDING (RoPE) - Modern Standard          │
│  ────────────────────────────────────────────────────────────            │
│                                                                          │
│  Rotate query/key vectors based on position:                            │
│  q'_m = R_θ,m × q_m                                                    │
│  k'_n = R_θ,n × k_n                                                    │
│                                                                          │
│  Then: q'_m · k'_n depends only on (m-n), the relative position        │
│                                                                          │
│  Properties:                                                             │
│  ✓ Combines benefits of absolute and relative encodings                │
│  ✓ Decays attention with distance (like a prior)                       │
│  ✓ Used in modern LLMs (LLaMA, Mistral, GPT-NeoX)                      │
│  ✓ Excellent extrapolation properties                                  │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class SinusoidalPositionalEncoding(nn.Module):
    """
    Original transformer positional encoding using sinusoidal functions.
    Deterministic, no learned parameters.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Compute div_term: 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add batch dimension and register as buffer (not a parameter)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model) - input embeddings
        Returns:
            (batch, seq_len, d_model) - embeddings + positional encoding
        """
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


class LearnedPositionalEncoding(nn.Module):
    """
    Learned positional embeddings - common in RecSys transformers.
    Simple and effective for fixed-length sequences.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.position_embedding = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        pos_emb = self.position_embedding(positions)  # (seq_len, d_model)
        x = x + pos_emb
        return self.dropout(x)


class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Position Embedding (RoPE) - modern standard.
    Used in LLaMA, Mistral, and increasingly in RecSys.

    Key idea: Rotate query and key vectors based on position.
    The dot product then naturally captures relative position.
    """

    def __init__(self, d_model: int, max_len: int = 2048, base: int = 10000):
        super().__init__()
        # Note: d_model here is the per-head dimension (head_dim), because the
        # rotation is applied to the last dimension of q and k in forward().
        self.d_model = d_model
        self.max_len = max_len

        # Compute rotation frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute rotation matrices
        self._build_cache(max_len)

    def _build_cache(self, seq_len: int):
        """Precompute sin and cos for all positions."""
        positions = torch.arange(seq_len)
        freqs = torch.outer(positions, self.inv_freq)  # (seq_len, d_model/2)

        # Create rotation matrix components
        self.register_buffer('cos_cached', freqs.cos())
        self.register_buffer('sin_cached', freqs.sin())

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to query and key.

        Args:
            q: (batch, num_heads, seq_len, head_dim)
            k: (batch, num_heads, seq_len, head_dim)

        Returns:
            Rotated q and k
        """
        seq_len = q.size(2)

        # Get cached values
        cos = self.cos_cached[:seq_len]  # (seq_len, d/2)
        sin = self.sin_cached[:seq_len]  # (seq_len, d/2)

        # Reshape for broadcasting
        cos = cos.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, d/2)
        sin = sin.unsqueeze(0).unsqueeze(0)

        # Apply rotation
        q_rot = self._rotate(q, cos, sin)
        k_rot = self._rotate(k, cos, sin)

        return q_rot, k_rot

    def _rotate(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply rotation to each position."""
        # Split into two halves
        x1, x2 = x[..., ::2], x[..., 1::2]

        # Apply rotation: [x1, x2] -> [x1*cos - x2*sin, x1*sin + x2*cos]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1).flatten(-2)

        return rotated


# Comparison demonstration
def compare_positional_encodings():
    """Show differences between positional encoding methods."""

    d_model = 64
    seq_len = 10
    batch_size = 2

    # Random input
    x = torch.randn(batch_size, seq_len, d_model)

    # Different encodings
    sinusoidal = SinusoidalPositionalEncoding(d_model)
    learned = LearnedPositionalEncoding(d_model)

    out_sin = sinusoidal(x)
    out_learned = learned(x)

    print("Positional Encoding Comparison")
    print("=" * 50)
    print(f"Input shape: {x.shape}")
    print(f"Output shape (both): {out_sin.shape}")

    # Show that learned embeddings are different each run (before training)
    print(f"\nLearned PE parameter shape: {learned.position_embedding.weight.shape}")
    print(f"Sinusoidal PE requires no parameters (deterministic)")

    # Visualize the encoding patterns
    print("\nSinusoidal encoding pattern (first 4 dims, first 5 positions):")
    print(sinusoidal.pe[0, :5, :4])

Layer Normalization: Stabilizing Deep Networks

Training deep neural networks is inherently unstable. As signals pass through many layers, they can explode (grow exponentially) or vanish (shrink to zero). Normalization techniques address this by constraining activations to reasonable ranges, ensuring gradients flow properly during backpropagation.

Why Layer Norm over Batch Norm?

Batch Normalization, invented for CNNs, normalizes across the batch dimension—computing statistics over all examples in a mini-batch. This works well for images but fails for transformers:

  1. Variable sequence lengths: Different sequences have different lengths, making it unclear how to compute batch statistics across positions
  2. Batch size dependence: Batch norm's statistics depend on batch size, causing training-inference mismatch
  3. Sequence modeling: In autoregressive generation, we process one token at a time—there's no batch to normalize over

Layer Normalization instead normalizes across the feature dimension. For each position in each sequence, we compute the mean and variance over the d_model features, then normalize. This is completely independent of batch size and handles variable-length sequences naturally.

The Pre-Norm vs Post-Norm Debate:

The original Transformer applied layer norm after the residual connection (post-norm):

Code
output = LayerNorm(x + Sublayer(x))

Modern transformers (GPT-2 onward) apply layer norm before the sublayer (pre-norm):

Code
output = x + Sublayer(LayerNorm(x))

Pre-norm has significant advantages:

  • Better gradient flow: The residual path is "clean"—gradients flow directly without passing through normalization
  • Training stability: Pre-norm models train more stably, often without learning rate warmup
  • Easier to scale: Deeper models train more reliably with pre-norm

The intuition: in post-norm, every layer modifies the residual path, accumulating changes. In pre-norm, the residual path is an "identity highway" that just accumulates additions, while normalization happens in a branch.
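
A minimal sketch of the two orderings, using a plain feed-forward network as a stand-in sublayer (in a real block the sublayers are attention and the FFN):

Python
import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    """Original Transformer ordering: normalize after the residual add."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))

class PreNormBlock(nn.Module):
    """Modern ordering: normalize the branch, keep the residual path clean."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))

# A simple feed-forward sublayer stands in for attention / FFN blocks.
ffn = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 10, 64)
print(PreNormBlock(64, ffn)(x).shape)   # torch.Size([2, 10, 64])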

RMSNorm: A Simpler Alternative:

RMSNorm (Root Mean Square Normalization) simplifies layer norm by removing the mean-centering step—it only divides by the root mean square. This is computationally cheaper and works just as well in practice. LLaMA and other modern architectures use RMSNorm.
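
For reference, a minimal RMSNorm sketch that makes the difference explicit: a learned scale, no mean subtraction, no bias.

Python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root Mean Square normalization: rescale by the RMS of the features,
    with a learned scale but no mean centering. Used by LLaMA-style models."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the feature dimension, computed per position.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.weight * (x / rms)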

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                      LAYER NORMALIZATION                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  BATCH NORM vs LAYER NORM:                                              │
│  ─────────────────────────                                               │
│                                                                          │
│  Input: (batch_size, seq_len, d_model)                                  │
│                                                                          │
│  Batch Norm: Normalize across batch dimension                           │
│              Mean/var computed over (batch_size × seq_len) samples      │
│              Problem: Varies with batch size, bad for variable-length   │
│                                                                          │
│  Layer Norm: Normalize across feature dimension                         │
│              Mean/var computed over d_model features, per position      │
│              Consistent regardless of batch size or sequence length     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  LAYER NORM FORMULA:                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  For input x ∈ ℝ^d:                                                     │
│                                                                          │
│  μ = (1/d) Σ x_i                          (mean across features)        │
│                                                                          │
│  σ² = (1/d) Σ (x_i - μ)²                  (variance across features)   │
│                                                                          │
│  y = γ × (x - μ) / √(σ² + ε) + β         (normalize and scale)         │
│                                                                          │
│  Where γ, β ∈ ℝ^d are learned parameters                               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  PRE-NORM vs POST-NORM:                                                  │
│  ──────────────────────                                                  │
│                                                                          │
│  POST-NORM (Original Transformer):                                      │
│  ─────────────────────────────────                                       │
│  output = LayerNorm(x + Sublayer(x))                                   │
│                                                                          │
│  PRE-NORM (GPT-2, Most Modern):                                         │
│  ────────────────────────────────                                        │
│  output = x + Sublayer(LayerNorm(x))                                   │
│                                                                          │
│  ┌─────────────────┐     ┌─────────────────┐                           │
│  │   POST-NORM     │     │    PRE-NORM     │                           │
│  ├─────────────────┤     ├─────────────────┤                           │
│  │       x         │     │       x         │                           │
│  │       │         │     │       │         │                           │
│  │       ▼         │     │       ├─────┐   │                           │
│  │  ┌─────────┐   │     │       ▼     │   │                           │
│  │  │Sublayer │   │     │  ┌─────────┐│   │                           │
│  │  └────┬────┘   │     │  │LayerNorm││   │                           │
│  │       │         │     │  └────┬────┘│   │                           │
│  │       │         │     │       ▼     │   │                           │
│  │   ┌───┴───┐    │     │  ┌─────────┐│   │                           │
│  │   │  Add  │◄───│     │  │Sublayer ││   │                           │
│  │   └───┬───┘    │     │  └────┬────┘│   │                           │
│  │       ▼         │     │       │     │   │                           │
│  │  ┌─────────┐   │     │   ┌───▼───┐ │   │                           │
│  │  │LayerNorm│   │     │   │  Add  │◄┘   │                           │
│  │  └────┬────┘   │     │   └───┬───┘     │                           │
│  │       ▼         │     │       ▼         │                           │
│  │    output      │     │    output      │                           │
│  └─────────────────┘     └─────────────────┘                           │
│                                                                          │
│  PRE-NORM ADVANTAGES:                                                    │
│  • Better gradient flow (residual path is "clean")                     │
│  • More stable training for deep networks                              │
│  • No warm-up often needed                                             │
│  • Default choice for modern transformers                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class LayerNorm(nn.Module):
    """
    Layer Normalization with detailed implementation.
    Normalizes across the last dimension (features).
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

        # Learned scale (γ) and shift (β) parameters
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (..., d_model) - input tensor
        Returns:
            Normalized tensor of same shape
        """
        # Compute mean and variance across last dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift
        output = self.gamma * x_norm + self.beta

        return output


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    Used in LLaMA and other modern architectures.

    Simplifies LayerNorm by removing mean centering.
    Slightly faster and works well in practice.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute RMS
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        # Normalize and scale
        return self.weight * (x / rms)


# Comparison
def compare_normalizations():
    """Compare LayerNorm and RMSNorm."""

    d_model = 64
    x = torch.randn(2, 10, d_model)  # (batch, seq, features)

    ln = LayerNorm(d_model)
    rms = RMSNorm(d_model)

    out_ln = ln(x)
    out_rms = rms(x)

    print("Normalization Comparison")
    print("=" * 50)
    print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
    print(f"LayerNorm output mean: {out_ln.mean():.4f}, std: {out_ln.std():.4f}")
    print(f"RMSNorm output mean: {out_rms.mean():.4f}, std: {out_rms.std():.4f}")

    print(f"\nLayerNorm params: {sum(p.numel() for p in ln.parameters())}")
    print(f"RMSNorm params: {sum(p.numel() for p in rms.parameters())}")

Feed-Forward Networks: The Memory of Transformers

The FFN in each transformer layer is surprisingly important. Recent research suggests FFN layers store factual knowledge, while attention routes information.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    FEED-FORWARD NETWORK (FFN)                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ARCHITECTURE:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  FFN(x) = Activation(x × W₁ + b₁) × W₂ + b₂                            │
│                                                                          │
│  Where:                                                                  │
│  • W₁ ∈ ℝ^(d_model × d_ff)     (expand)                                │
│  • W₂ ∈ ℝ^(d_ff × d_model)     (contract)                              │
│  • d_ff = 4 × d_model (typically)                                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  VISUAL:                                                                 │
│                                                                          │
│  Input                                                                   │
│    │                                                                     │
│    ▼                                                                     │
│  ┌────────────────┐                                                     │
│  │ d_model = 512  │                                                     │
│  └───────┬────────┘                                                     │
│          │                                                               │
│          ▼ W₁ (expand 4x)                                               │
│  ┌────────────────────────────────────────┐                             │
│  │           d_ff = 2048                   │                             │
│  │   (Each neuron is a "key-value" pair)  │                             │
│  └───────────────────┬────────────────────┘                             │
│                      │                                                   │
│                      ▼ Activation (ReLU/GELU/SwiGLU)                    │
│  ┌────────────────────────────────────────┐                             │
│  │           d_ff = 2048                   │                             │
│  │   (Sparse activation pattern)          │                             │
│  └───────────────────┬────────────────────┘                             │
│                      │                                                   │
│                      ▼ W₂ (contract)                                    │
│  ┌────────────────┐                                                     │
│  │ d_model = 512  │                                                     │
│  └────────────────┘                                                     │
│    │                                                                     │
│    ▼                                                                     │
│  Output                                                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ACTIVATION FUNCTIONS:                                                   │
│  ─────────────────────                                                   │
│                                                                          │
│  ReLU (Original):      max(0, x)                                        │
│                        Simple, fast, but "dead neurons" problem         │
│                                                                          │
│  GELU (BERT, GPT-2):   x × Φ(x)   where Φ is Gaussian CDF              │
│                        Smooth, probabilistic interpretation             │
│                                                                          │
│  SwiGLU (LLaMA, PaLM): Swish(xW₁) ⊙ (xV)                               │
│                        Gated, best performance in practice              │
│                        But requires extra parameters (V matrix)         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  FFN AS MEMORY:                                                          │
│  ──────────────                                                          │
│                                                                          │
│  Research shows FFN layers act as key-value memories:                   │
│  • W₁ rows are "keys" (patterns to match)                              │
│  • W₂ columns are "values" (outputs when key matches)                  │
│  • Activation creates sparse retrieval                                  │
│                                                                          │
│  In RecSys: FFN might store "item category → related patterns"         │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class FeedForward(nn.Module):
    """Standard feed-forward network with ReLU/GELU."""

    def __init__(
        self,
        d_model: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = 'gelu',
    ):
        super().__init__()
        d_ff = d_ff or 4 * d_model

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand -> Activate -> Contract
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class SwiGLUFeedForward(nn.Module):
    """
    SwiGLU activation: Used in LLaMA, PaLM, modern transformers.
    Generally performs better than GELU but has more parameters.

    SwiGLU(x) = Swish(xW) ⊙ (xV)
    where Swish(x) = x × sigmoid(x)
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        # d_ff is split between gate and value projections
        # So effective expansion is 2/3 of standard FFN for same param count
        d_ff = d_ff or int(4 * d_model * 2/3)

        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate path with Swish activation
        gate = F.silu(self.gate_proj(x))  # Swish = x * sigmoid(x) = SiLU

        # Up projection (no activation)
        up = self.up_proj(x)

        # Element-wise multiplication (gating)
        hidden = gate * up

        # Down projection
        output = self.down_proj(self.dropout(hidden))

        return output


# Parameter count comparison
def compare_ffn_variants():
    """Compare parameter counts and outputs of FFN variants."""

    d_model = 512
    d_ff = 2048

    standard = FeedForward(d_model, d_ff, activation='gelu')
    swiglu = SwiGLUFeedForward(d_model, int(d_ff * 2/3))  # Adjusted for fairness

    standard_params = sum(p.numel() for p in standard.parameters())
    swiglu_params = sum(p.numel() for p in swiglu.parameters())

    print("FFN Variant Comparison")
    print("=" * 50)
    print(f"Standard (GELU) parameters: {standard_params:,}")
    print(f"SwiGLU parameters: {swiglu_params:,}")
    print(f"Ratio: {swiglu_params/standard_params:.2f}")

    # Test forward pass
    x = torch.randn(2, 10, d_model)
    out_standard = standard(x)
    out_swiglu = swiglu(x)

    print(f"\nOutput shapes: {out_standard.shape} (both)")

Complete Transformer Block: Putting It All Together

Now we combine all components into a complete transformer block:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    COMPLETE TRANSFORMER BLOCK                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│                         Input (seq_len, d_model)                        │
│                                │                                         │
│                                ▼                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                      PRE-NORM BLOCK 1                            │   │
│  │                                                                   │   │
│  │    ┌─────────┐                                                   │   │
│  │    │LayerNorm│                                                   │   │
│  │    └────┬────┘                                                   │   │
│  │         │                                                         │   │
│  │         ▼                                                         │   │
│  │  ┌─────────────────────────────────────────────────┐            │   │
│  │  │           Multi-Head Self-Attention              │            │   │
│  │  │                                                   │            │   │
│  │  │  • num_heads heads, each with d_k = d_model/h   │            │   │
│  │  │  • Causal mask for autoregressive               │            │   │
│  │  │  • Padding mask for variable length             │            │   │
│  │  │                                                   │            │   │
│  │  └───────────────────────┬───────────────────────────┘            │   │
│  │                          │                                         │   │
│  │                          ▼                                         │   │
│  │                     ┌─────────┐                                   │   │
│  │                     │ Dropout │                                   │   │
│  │                     └────┬────┘                                   │   │
│  │                          │                                         │   │
│  └──────────────────────────┼───────────────────────────────────────┘   │
│                             │                                            │
│    Input ───────────────────┼────────────────────┐                      │
│                             │                     │                      │
│                             ▼                     │                      │
│                         ┌───────┐                │                      │
│                         │  Add  │◄───────────────┘                      │
│                         └───┬───┘                                        │
│                             │                                            │
│                             ▼                                            │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                      PRE-NORM BLOCK 2                            │   │
│  │                                                                   │   │
│  │    ┌─────────┐                                                   │   │
│  │    │LayerNorm│                                                   │   │
│  │    └────┬────┘                                                   │   │
│  │         │                                                         │   │
│  │         ▼                                                         │   │
│  │  ┌─────────────────────────────────────────────────┐            │   │
│  │  │            Feed-Forward Network                  │            │   │
│  │  │                                                   │            │   │
│  │  │  • Expand: d_model → d_ff (4× typical)          │            │   │
│  │  │  • Activation: GELU or SwiGLU                    │            │   │
│  │  │  • Contract: d_ff → d_model                      │            │   │
│  │  │                                                   │            │   │
│  │  └───────────────────────┬───────────────────────────┘            │   │
│  │                          │                                         │   │
│  │                          ▼                                         │   │
│  │                     ┌─────────┐                                   │   │
│  │                     │ Dropout │                                   │   │
│  │                     └────┬────┘                                   │   │
│  │                          │                                         │   │
│  └──────────────────────────┼───────────────────────────────────────┘   │
│                             │                                            │
│    (from Add) ──────────────┼────────────────────┐                      │
│                             │                     │                      │
│                             ▼                     │                      │
│                         ┌───────┐                │                      │
│                         │  Add  │◄───────────────┘                      │
│                         └───┬───┘                                        │
│                             │                                            │
│                             ▼                                            │
│                    Output (seq_len, d_model)                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class TransformerBlock(nn.Module):
    """
    Complete transformer block with all components.
    Pre-norm architecture (modern standard).
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = 'gelu',
        use_swiglu: bool = False,
    ):
        super().__init__()
        d_ff = d_ff or 4 * d_model

        # Layer norms (pre-norm architecture)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Multi-head attention
        self.attention = MultiHeadAttention(
            d_model=d_model,
            num_heads=num_heads,
            dropout=dropout,
        )

        # Feed-forward network
        if use_swiglu:
            self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)
        else:
            self.ffn = FeedForward(d_model, d_ff, dropout, activation)

        # Dropout for residual connections
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model) - input sequence
            mask: (batch, seq_len, seq_len) - attention mask
        """
        # Self-attention with pre-norm and residual
        normed = self.norm1(x)
        attn_output = self.attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_output)

        # FFN with pre-norm and residual
        normed = self.norm2(x)
        ffn_output = self.ffn(normed)
        x = x + self.dropout(ffn_output)

        return x


class TransformerEncoder(nn.Module):
    """
    Stack of transformer blocks for encoding sequences.
    This is the core of SASRec, BERT4Rec, and similar models.
    """

    def __init__(
        self,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = 'gelu',
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout=dropout,
                activation=activation,
            )
            for _ in range(num_layers)
        ])

        # Final layer norm (important for pre-norm architecture)
        self.final_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (seq_len, seq_len) - causal mask
        """
        for layer in self.layers:
            x = layer(x, mask)

        return self.final_norm(x)


# Complete demonstration
def build_complete_transformer():
    """Build and demonstrate a complete transformer encoder."""

    # Hyperparameters
    d_model = 256
    num_heads = 8
    num_layers = 6
    d_ff = 1024
    max_seq_len = 50
    vocab_size = 10000  # Number of items

    # Build model
    class SequentialRecommender(nn.Module):
        def __init__(self):
            super().__init__()

            # Embeddings
            self.item_embedding = nn.Embedding(vocab_size + 1, d_model, padding_idx=0)
            self.position_encoding = LearnedPositionalEncoding(d_model, max_seq_len)

            # Transformer encoder
            self.encoder = TransformerEncoder(
                num_layers=num_layers,
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            )

            # Output projection (predict next item)
            self.output_proj = nn.Linear(d_model, vocab_size)

        def forward(self, item_ids):
            # item_ids: (batch, seq_len)
            batch_size, seq_len = item_ids.shape

            # Embed items
            x = self.item_embedding(item_ids)
            x = self.position_encoding(x)

            # Create causal mask
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=item_ids.device),
                diagonal=1
            ).bool()

            # Encode
            hidden = self.encoder(x, causal_mask)

            # Predict next item
            logits = self.output_proj(hidden)

            return logits

    model = SequentialRecommender()

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Complete Transformer for Sequential Recommendation")
    print("=" * 60)
    print(f"d_model: {d_model}")
    print(f"num_heads: {num_heads}")
    print(f"num_layers: {num_layers}")
    print(f"d_ff: {d_ff}")
    print(f"vocab_size: {vocab_size}")
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Test forward pass
    batch_size = 4
    seq_len = 20
    item_ids = torch.randint(1, vocab_size, (batch_size, seq_len))

    logits = model(item_ids)
    print(f"\nInput shape: {item_ids.shape}")
    print(f"Output shape: {logits.shape}")
    print(f"Output shape meaning: (batch={batch_size}, seq_len={seq_len}, vocab={vocab_size})")

    return model


# Run
model = build_complete_transformer()

Part III: SASRec - The Foundation

Self-Attentive Sequential Recommendation

SASRec (Kang & McAuley, 2018) was the first successful application of self-attention to sequential recommendation. It's elegant and remains a strong baseline.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                         SASREC ARCHITECTURE                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  INPUT LAYER:                                                            │
│  ────────────                                                            │
│                                                                          │
│  User sequence: [item_1, item_2, item_3, ..., item_n]                   │
│                     ↓        ↓        ↓            ↓                     │
│  Item embeddings:  e_1      e_2      e_3    ...   e_n                   │
│                     +        +        +            +                     │
│  Position embeds:  p_1      p_2      p_3    ...   p_n                   │
│                     ↓        ↓        ↓            ↓                     │
│                   [e_1+p_1, e_2+p_2, e_3+p_3, ..., e_n+p_n]             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TRANSFORMER BLOCKS (x L layers):                                        │
│  ─────────────────────────────────                                       │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                                                                   │   │
│  │  Input → [Masked Self-Attention] → Add & Norm → [FFN] → Add & Norm│   │
│  │           (Causal mask: can                                       │   │
│  │            only attend to past)                                   │   │
│  │                                                                   │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  CAUSAL MASKING (Key design choice):                                     │
│  ───────────────────────────────────                                     │
│                                                                          │
│       Attention Matrix          What each position can see              │
│       ┌───┬───┬───┬───┐                                                 │
│       │ 1 │ 0 │ 0 │ 0 │  pos 1 → sees only item_1                      │
│       ├───┼───┼───┼───┤                                                 │
│       │ 1 │ 1 │ 0 │ 0 │  pos 2 → sees item_1, item_2                   │
│       ├───┼───┼───┼───┤                                                 │
│       │ 1 │ 1 │ 1 │ 0 │  pos 3 → sees item_1, item_2, item_3           │
│       ├───┼───┼───┼───┤                                                 │
│       │ 1 │ 1 │ 1 │ 1 │  pos 4 → sees all items                        │
│       └───┴───┴───┴───┘                                                 │
│                                                                          │
│  OUTPUT: Predict next item at each position                              │
│  ──────────────────────────────────────────                              │
│                                                                          │
│  pos 1 output → should predict item_2                                   │
│  pos 2 output → should predict item_3                                   │
│  pos n output → should predict item_{n+1}                               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

SASRec Implementation

Before diving into the code, let's understand why SASRec works and the key design decisions that made it successful.

Why self-attention for sequential recommendation?

Prior to SASRec, sequential recommendation used RNNs (GRU4Rec) or CNNs (Caser). These approaches had fundamental limitations:

  • RNNs: Process sequences left-to-right, one step at a time. The hidden state must compress all history into a fixed-size vector—a bottleneck. Also, sequential processing prevents parallelization.

  • CNNs: Capture local patterns through fixed-size filters. To capture long-range dependencies, you need many stacked layers, and even then, distant items interact only indirectly.

Self-attention solves both problems:

  1. Direct access to any past item: When predicting what comes after item 10, attention can directly look at item 1, item 5, or any other—no information bottleneck.

  2. Parallel computation: All positions compute attention simultaneously. Training is much faster on GPUs.

  3. Adaptive receptive field: The model learns which items to attend to. For a "laptop bag" purchase, it might attend strongly to a recent "laptop" purchase from 3 items ago, while ignoring unrelated items.

The key architectural decisions:

  1. Causal (unidirectional) attention: Unlike BERT which sees the full sequence, SASRec only lets each position attend to past positions. This matches the recommendation task: when a user is on item 5, we can only use items 1-4 to predict item 6.

  2. Learnable position embeddings: Unlike the original Transformer's sinusoidal positions, SASRec learns position embeddings. With short sequences (50-200 items), there's enough data to learn positions, and learned embeddings often work better.

  3. Shared item embeddings for input and output: The same embedding matrix is used to (a) represent input items and (b) score candidate items for prediction. This parameter tying reduces overfitting and ensures consistency.

  4. Dropout everywhere: Applied to embeddings, attention weights, and FFN layers. Critical for preventing overfitting on sparse user-item data.

Training objective:

SASRec uses binary cross-entropy with negative sampling. For each position $t$ in the sequence:

  • Positive: The actual next item $s_{t+1}$
  • Negatives: Randomly sampled items the user didn't interact with

The loss encourages the model to score true next items higher than random items:

$$\mathcal{L} = -\sum_{t=1}^{n-1} \left[ \log \sigma(\mathbf{h}_t \cdot \mathbf{e}_{s_{t+1}}) + \sum_{j \in \text{neg}} \log\bigl(1 - \sigma(\mathbf{h}_t \cdot \mathbf{e}_j)\bigr) \right]$$

Where $\mathbf{h}_t$ is the hidden state at position $t$ and $\mathbf{e}_i$ is the embedding of item $i$.
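
A minimal sketch of this objective in PyTorch, with one uniformly sampled negative per position and the item-embedding matrix shared with the input layer (the weight tying described above). The helper name and sampling scheme are illustrative; note that the reference training loop later in this section uses full softmax cross-entropy instead.

Python
import torch
import torch.nn.functional as F

def sasrec_bce_loss(hidden, target_seq, item_embedding, num_items):
    """
    hidden:         (B, L, D) transformer outputs h_t
    target_seq:     (B, L) next-item IDs, 0 = padding
    item_embedding: nn.Embedding shared with the input layer (weight tying)
    """
    # Positive scores: h_t · e_{s_{t+1}}
    pos_emb = item_embedding(target_seq)                  # (B, L, D)
    pos_scores = (hidden * pos_emb).sum(dim=-1)           # (B, L)

    # One uniformly sampled negative per position (IDs 1..num_items);
    # a real implementation would also exclude the positive item
    neg_ids = torch.randint(1, num_items + 1, target_seq.shape,
                            device=target_seq.device)
    neg_scores = (hidden * item_embedding(neg_ids)).sum(dim=-1)  # (B, L)

    # log σ(pos) + log(1 − σ(neg)), ignoring padding positions
    valid = (target_seq != 0).float()
    loss = -(F.logsigmoid(pos_scores) + F.logsigmoid(-neg_scores)) * valid
    return loss.sum() / valid.sum().clamp(min=1)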

Why SASRec remains a strong baseline:

Despite being published in 2018, SASRec consistently appears in top results:

  • Simple architecture with few hyperparameters
  • Efficient training (hours, not days)
  • Works across domains (e-commerce, streaming, news)
  • Easy to extend with additional features

Many "improvements" over SASRec fail to consistently outperform it when properly tuned.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SASRec(nn.Module):
    """
    Self-Attentive Sequential Recommendation
    Paper: https://arxiv.org/abs/1808.09781
    """

    def __init__(
        self,
        num_items: int,
        max_seq_len: int = 50,
        hidden_dim: int = 64,
        num_heads: int = 2,
        num_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.num_items = num_items
        self.max_seq_len = max_seq_len
        self.hidden_dim = hidden_dim

        # Item embeddings (0 is padding)
        self.item_embedding = nn.Embedding(
            num_items + 1,  # +1 for padding token
            hidden_dim,
            padding_idx=0
        )

        # Learnable position embeddings
        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)

        # Transformer encoder layers
        self.layers = nn.ModuleList([
            SASRecBlock(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # Layer normalization
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, item_seq: torch.Tensor) -> torch.Tensor:
        """
        Args:
            item_seq: (batch_size, seq_len) - sequence of item IDs

        Returns:
            (batch_size, seq_len, hidden_dim) - sequence representations
        """
        batch_size, seq_len = item_seq.shape

        # Create position indices
        positions = torch.arange(seq_len, device=item_seq.device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)

        # Embed items and positions
        item_emb = self.item_embedding(item_seq)  # (B, L, D)
        pos_emb = self.position_embedding(positions)  # (B, L, D)

        # Combine embeddings
        hidden = self.dropout(item_emb + pos_emb)

        # Create causal attention mask
        # True = masked (cannot attend), False = can attend
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=item_seq.device),
            diagonal=1
        ).bool()

        # Create padding mask
        padding_mask = (item_seq == 0)  # (B, L)

        # Apply transformer layers
        for layer in self.layers:
            hidden = layer(hidden, causal_mask, padding_mask)

        hidden = self.layer_norm(hidden)
        return hidden

    def predict(self, item_seq: torch.Tensor) -> torch.Tensor:
        """
        Predict scores for all items given a sequence.

        Returns:
            (batch_size, num_items) - scores for each item
        """
        hidden = self.forward(item_seq)  # (B, L, D)

        # Use last position's representation
        final_hidden = hidden[:, -1, :]  # (B, D)

        # Score all items via dot product with item embeddings
        # Exclude padding embedding (index 0)
        item_embeddings = self.item_embedding.weight[1:]  # (num_items, D)
        scores = torch.matmul(final_hidden, item_embeddings.T)  # (B, num_items)

        return scores


class SASRecBlock(nn.Module):
    """Single transformer block for SASRec."""

    def __init__(self, hidden_dim: int, num_heads: int, dropout: float):
        super().__init__()

        self.attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden: torch.Tensor,
        causal_mask: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self-attention with residual
        normed = self.norm1(hidden)
        attn_out, _ = self.attention(
            normed, normed, normed,
            attn_mask=causal_mask,
            key_padding_mask=padding_mask,
        )
        hidden = hidden + self.dropout(attn_out)

        # FFN with residual
        normed = self.norm2(hidden)
        ffn_out = self.ffn(normed)
        hidden = hidden + ffn_out

        return hidden


# Training example
def train_sasrec():
    """Example training loop for SASRec."""

    # Model setup
    model = SASRec(
        num_items=10000,
        max_seq_len=50,
        hidden_dim=64,
        num_heads=2,
        num_layers=2,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training with "shifted sequence" objective
    # Input: [item_1, item_2, ..., item_n]
    # Target: [item_2, item_3, ..., item_{n+1}]

    for epoch in range(100):
        for batch in dataloader:  # Your data loader
            input_seq = batch['input_seq']    # (B, L)
            target_seq = batch['target_seq']  # (B, L)

            # Forward pass
            hidden = model(input_seq)  # (B, L, D)

            # Compute scores for all items at each position
            # Use the full embedding table (row 0 = padding) so that class
            # index j corresponds exactly to item ID j
            item_emb = model.item_embedding.weight  # (num_items + 1, D)
            logits = torch.matmul(hidden, item_emb.T)  # (B, L, num_items + 1)

            # Cross-entropy loss
            loss = F.cross_entropy(
                logits.view(-1, model.num_items + 1),
                target_seq.view(-1),
                ignore_index=0,  # Ignore padding
            )

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Why Causal Masking?

SASRec uses causal (unidirectional) attention—each position can only attend to previous positions. This is crucial for two reasons:

  1. Training efficiency: With causal masking, we can train on the entire sequence at once using the "shifted sequence" objective. Position i predicts item i+1, using only items 1...i (see the sketch below).

  2. Inference correctness: At inference time, we only have past interactions. The model shouldn't learn to rely on future information that won't be available.
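
To make the shifted-sequence objective concrete, here is a minimal sketch of how an (input, target) pair could be built from a raw interaction history, left-padded with 0 to a fixed length (the helper name is illustrative):

Python
import torch

def make_training_pair(history, max_seq_len=50):
    """Turn [i1, i2, ..., in] into an (input, target) pair; 0 is padding."""
    # Input drops the last item; target is the same sequence shifted by one
    input_items = history[:-1][-max_seq_len:]
    target_items = history[1:][-max_seq_len:]

    pad = max_seq_len - len(input_items)
    input_seq = torch.tensor([0] * pad + input_items)
    target_seq = torch.tensor([0] * pad + target_items)
    return input_seq, target_seq

# A user who interacted with items 3, 7, 7, 42, 9 (in that order):
inp, tgt = make_training_pair([3, 7, 7, 42, 9], max_seq_len=8)
# inp = [0, 0, 0, 0, 3, 7, 7, 42]   → position k must predict tgt[k]
# tgt = [0, 0, 0, 0, 7, 7, 42,  9]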


Part IV: BERT4Rec - Bidirectional Attention

The Bidirectional Argument

BERT4Rec (Sun et al., 2019) argued that unidirectional attention limits the model's ability to learn representations. In NLP, BERT's bidirectional training dramatically outperformed GPT-style models on many tasks. Could the same apply to recommendations?

Code
┌─────────────────────────────────────────────────────────────────────────┐
│               SASREC vs BERT4REC: ATTENTION PATTERNS                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  SASREC (Unidirectional/Causal):                                        │
│  ─────────────────────────────────                                       │
│                                                                          │
│  Sequence: [shoes, socks, shorts, shirt]                                │
│                                                                          │
│  When representing "socks":                                              │
│  - Can see: shoes, socks                                                │
│  - Cannot see: shorts, shirt                                            │
│                                                                          │
│  "Socks" representation knows about previous items only                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BERT4REC (Bidirectional):                                               │
│  ──────────────────────────                                              │
│                                                                          │
│  Sequence: [shoes, [MASK], shorts, shirt]                               │
│                                                                          │
│  When predicting [MASK] (was "socks"):                                  │
│  - Can see: shoes, shorts, shirt                                        │
│  - Must predict: socks                                                  │
│                                                                          │
│  Prediction uses FULL context (past AND future)                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  THE TRADE-OFF:                                                          │
│                                                                          │
│  SASRec advantages:                                                      │
│  + Natural fit for sequential prediction                                │
│  + Efficient training (predict all positions)                           │
│  + Simpler inference (just use last position)                           │
│                                                                          │
│  BERT4Rec advantages:                                                    │
│  + Richer representations (bidirectional context)                       │
│  + Better modeling of item relationships                                │
│  + Can leverage future context during training                          │
│                                                                          │
│  BERT4Rec disadvantages:                                                 │
│  - Training is less efficient (only predict masked items)               │
│  - Inference is different from training (no masking)                    │
│  - More hyperparameters (mask ratio)                                    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

BERT4Rec Implementation

BERT4Rec brought the revolutionary BERT pretraining paradigm from NLP to recommendations. Understanding its mechanics reveals both its power and its limitations.

The core idea: Masked Language Modeling for items

Instead of predicting the next item given previous items (like SASRec), BERT4Rec randomly masks items in the sequence and predicts them from context—both left AND right:

Code
SASRec (Causal):     [A, B, C, D, ?] → predict E using A,B,C,D
BERT4Rec (Masked):   [A, [MASK], C, D, E] → predict B using A,C,D,E

This bidirectional context is powerful: to predict item B, the model can see what came before (A) AND after (C, D, E). In real user behavior, items often relate to both past and future actions.

Why bidirectional context might help:

Consider a user sequence: [Laptop → ??? → Laptop Bag → USB Hub → Monitor Stand]

  • Unidirectional (SASRec): To predict the masked item, only sees "Laptop"
  • Bidirectional (BERT4Rec): Sees "Laptop" AND "Laptop Bag, USB Hub, Monitor Stand"

The bidirectional model can infer the masked item is likely laptop-related (maybe "Laptop Case" or "Mouse") because it sees the surrounding accessory purchases.

The training/inference mismatch problem:

BERT4Rec has a fundamental issue that SASRec doesn't:

  • Training: Model sees sequences with [MASK] tokens, predicts masked items
  • Inference: Model sees sequences WITHOUT masks, must predict next item

This mismatch means the model is trained on a different distribution than it's tested on. Various solutions exist:

  1. Always append [MASK] at the end during inference
  2. Use the last position's hidden state (ignoring the masking objective)
  3. Fine-tune on next-item prediction after pretraining

The mask ratio hyperparameter:

BERT4Rec introduces a new hyperparameter: what fraction of items to mask? The original paper uses 20%, but optimal values vary by dataset:

  • Too low (5%): Not enough training signal per sequence
  • Too high (50%): Too much context removed, predictions become random

Bidirectional attention = No causal mask:

Unlike SASRec which uses a triangular attention mask (each position sees only past), BERT4Rec uses full attention—every position can attend to every other position (except padding). This is the source of both its power (richer context) and its limitation (can't do autoregressive generation).

Python
class BERT4Rec(nn.Module):
    """
    Bidirectional Encoder Representations from Transformers for Recommendation
    Paper: https://arxiv.org/abs/1904.06690

    Key difference from SASRec: Uses masked language modeling objective
    instead of causal (next-item) prediction.
    """

    def __init__(
        self,
        num_items: int,
        max_seq_len: int = 50,
        hidden_dim: int = 64,
        num_heads: int = 2,
        num_layers: int = 2,
        dropout: float = 0.2,
        mask_prob: float = 0.2,  # Probability of masking each item
    ):
        super().__init__()
        self.num_items = num_items
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob
        self.mask_token = num_items + 1  # Special [MASK] token

        # Item embeddings: 0=pad, 1...num_items=items, num_items+1=[MASK]
        self.item_embedding = nn.Embedding(
            num_items + 2,  # +1 for padding, +1 for mask
            hidden_dim,
            padding_idx=0
        )

        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)

        # BERT uses bidirectional attention (no causal mask)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, num_items + 1)  # class index = item ID; index 0 (padding) is never a target

    def forward(
        self,
        item_seq: torch.Tensor,
        masked_positions: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            item_seq: (batch_size, seq_len) - may contain [MASK] tokens
            masked_positions: (batch_size, num_masks) - positions to predict

        Returns:
            logits: (batch_size, num_masks, num_items + 1) if masked_positions given
                    else (batch_size, seq_len, num_items + 1)
        """
        batch_size, seq_len = item_seq.shape

        # Position indices
        positions = torch.arange(seq_len, device=item_seq.device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)

        # Embeddings
        item_emb = self.item_embedding(item_seq)
        pos_emb = self.position_embedding(positions)
        hidden = self.dropout(item_emb + pos_emb)

        # Padding mask only (no causal mask - bidirectional!)
        padding_mask = (item_seq == 0)

        # Transformer (bidirectional attention)
        hidden = self.transformer(hidden, src_key_padding_mask=padding_mask)
        hidden = self.layer_norm(hidden)

        # If specific positions requested, extract those
        if masked_positions is not None:
            # Gather hidden states at masked positions
            batch_indices = torch.arange(batch_size, device=item_seq.device)
            batch_indices = batch_indices.unsqueeze(1).expand(-1, masked_positions.shape[1])
            hidden = hidden[batch_indices, masked_positions]  # (B, num_masks, D)

        # Project to item vocabulary
        logits = self.output_proj(hidden)
        return logits

    def mask_sequence(self, item_seq: torch.Tensor):
        """
        Apply masking for training (Cloze task).

        Returns:
            masked_seq: sequence with some items replaced by [MASK]
            labels: original items at masked positions (-100 elsewhere)
            masked_positions: indices of masked positions
        """
        batch_size, seq_len = item_seq.shape

        # Create mask: which positions to mask
        mask = torch.rand(batch_size, seq_len, device=item_seq.device) < self.mask_prob
        mask = mask & (item_seq != 0)  # Don't mask padding

        # Create masked sequence
        masked_seq = item_seq.clone()
        masked_seq[mask] = self.mask_token

        # Create labels (-100 for non-masked, item_id for masked)
        labels = torch.full_like(item_seq, -100)
        labels[mask] = item_seq[mask]

        return masked_seq, labels, mask

    def predict_next(self, item_seq: torch.Tensor) -> torch.Tensor:
        """
        For inference: Append [MASK] and predict it.
        """
        batch_size = item_seq.shape[0]

        # Append [MASK] token
        mask_token = torch.full((batch_size, 1), self.mask_token, device=item_seq.device)
        seq_with_mask = torch.cat([item_seq, mask_token], dim=1)

        # Predict at mask position (last position)
        logits = self.forward(seq_with_mask)
        return logits[:, -1, :]  # (B, num_items + 1); column 0 (padding) can be ignored when ranking


def train_bert4rec():
    """Training loop for BERT4Rec."""

    model = BERT4Rec(
        num_items=10000,
        max_seq_len=50,
        hidden_dim=64,
        num_heads=2,
        num_layers=2,
        mask_prob=0.2,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(100):
        for batch in dataloader:
            item_seq = batch['item_seq']  # (B, L)

            # Apply random masking
            masked_seq, labels, mask = model.mask_sequence(item_seq)

            # Forward pass
            logits = model(masked_seq)  # (B, L, num_items)

            # Loss only on masked positions
            loss = F.cross_entropy(
                logits.view(-1, model.num_items + 1),
                labels.view(-1),
                ignore_index=-100,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

The Verdict: SASRec vs BERT4Rec

Despite BERT4Rec's intuitive appeal, the evidence is mixed. A systematic replication study found:

"BERT4Rec is not consistently superior compared to SASRec in the published literature. The reason might be that BERT4Rec is better on some datasets, but SASRec is better on others."

A 2023 RecSys paper ("Turning Dross Into Gold Loss") showed that with proper tuning, SASRec can match or exceed BERT4Rec while being simpler and more efficient.

Practical recommendation: Start with SASRec. It's simpler, faster to train, and has a more natural inference setup. Try BERT4Rec if you have specific evidence it helps on your data.


Part V: Advanced Architectures

Transformers4Rec: Production-Ready Framework

NVIDIA's Transformers4Rec bridges NLP transformers and RecSys. It allows using HuggingFace transformer architectures directly for recommendation.

Python
from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt

# Define schema
schema = tr.Schema(
    features=[
        tr.Feature("item_id", tr.CATEGORICAL, num_items),
        tr.Feature("category", tr.CATEGORICAL, num_categories),
        tr.Feature("price", tr.CONTINUOUS),
    ],
    targets=tr.Target("next_item", tr.CATEGORICAL, num_items),
)

# Build model with XLNet architecture
input_module = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=20,
    aggregation="concat",
)

transformer_config = tr.XLNetConfig.build(
    d_model=64,
    n_head=4,
    n_layer=2,
)

body = tr.SequentialBlock(
    input_module,
    tr.MLPBlock([64]),
    tr.TransformerBlock(transformer_config),
)

head = tr.Head(
    body,
    tr.NextItemPredictionTask(weight_tying=True),
    inputs=input_module,
)

model = tr.Model(head)

Key features:

  • Multiple transformer architectures: GPT-2, XLNet, BERT, Longformer, etc.
  • Side information integration: Item features, user features, context
  • Production-optimized: Integration with NVIDIA Triton for serving

HSTU: Meta's Trillion-Parameter Approach

HSTU (Hierarchical Sequential Transduction Units) represents the current state of the art for large-scale sequential recommendation. Published at ICML 2024 by Meta AI.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    HSTU KEY INNOVATIONS                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1. GENERATIVE FRAMING                                                   │
│  ────────────────────────                                                │
│                                                                          │
│  Traditional: Separate retrieval and ranking models                     │
│  HSTU: Single generative model that "generates" next interactions       │
│                                                                          │
│  Instead of: "Score these 1000 candidates"                              │
│  Do: "What will user do next?" → Generates item distribution            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  2. POINTWISE (NO SOFTMAX) ATTENTION                                    │
│  ─────────────────────────────────────                                   │
│                                                                          │
│  Standard Transformer:                                                   │
│  attention = softmax(QK^T / √d) × V                                     │
│                                                                          │
│  HSTU:                                                                   │
│  attention = φ(Q) × φ(K)^T × V   (pointwise, no softmax)               │
│                                                                          │
│  Why remove softmax?                                                     │
│  - Preserves preference intensity (softmax normalizes away)             │
│  - Better for non-stationary vocabularies (new items daily)             │
│  - Enables linear-time approximations                                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  3. HIERARCHICAL STRUCTURE                                               │
│  ──────────────────────────                                              │
│                                                                          │
│  Actions and items interleaved in single sequence:                      │
│                                                                          │
│  [item_1, click, item_2, purchase, item_3, view, item_4, ...]          │
│                                                                          │
│  This captures:                                                          │
│  - What items user interacted with                                      │
│  - How they interacted (click vs purchase vs view)                      │
│  - Temporal ordering of both                                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  4. SCALE & RESULTS                                                      │
│  ───────────────────                                                     │
│                                                                          │
│  - Scales to TRILLION parameters                                        │
│  - 12.4%+ improvement in production metrics at Meta                     │
│  - First demonstration of scaling laws in industrial RecSys             │
│  - 10-1000x efficiency gains via M-FALCON inference                     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Why HSTU represents a paradigm shift:

HSTU isn't just a bigger transformer—it fundamentally rethinks how transformers should work for recommendations. Here's why each innovation matters:

1. Removing softmax: The counterintuitive insight

Standard attention uses softmax to create a probability distribution over keys. But this has a problem for recommendations:

Code
Two users with the same relative preferences but different engagement intensity:

  Heavy engagement:  raw attention scores [8, 6, 4]
  Light engagement:  raw attention scores [4, 2, 0]

After softmax (adding the same constant to every score changes nothing):

  Both users get identical attention weights: [0.867, 0.117, 0.016]

The overall strength of engagement is normalized away.

HSTU's pointwise attention preserves these intensity differences. If a user engaged with an item a hundred times rather than once, that magnitude survives the attention computation instead of being squashed into a probability distribution.
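
To make the difference concrete, here is a tiny toy computation (the scores and value vectors are made up, and this illustrates only the normalization effect, not the actual HSTU attention kernel):

Python
import torch

V = torch.tensor([[1.0, 0.0],    # value vectors for three history items
                  [0.0, 1.0],
                  [1.0, 1.0]])

heavy = torch.tensor([8.0, 6.0, 4.0])   # strong engagement with all three items
light = torch.tensor([4.0, 2.0, 0.0])   # same relative preferences, far weaker

# Softmax attention: both users end up with the SAME pooled representation
print(torch.softmax(heavy, 0) @ V)      # tensor([0.8827, 0.1332])
print(torch.softmax(light, 0) @ V)      # identical output

# Pointwise (unnormalized) attention keeps the intensity difference
print(heavy @ V)                        # tensor([12., 10.])
print(light @ V)                        # tensor([4., 2.])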

2. Actions as first-class citizens

Most sequential models treat sequences as: [item₁, item₂, item₃, ...]

But user behavior is richer. HSTU models: [item₁, click, item₂, purchase, item₃, view, ...]

This captures that a purchase is more informative than a view. The model learns that a sequence ending with "purchase" signals different intent than one ending with "view."
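
One simple way to make actions first-class in code is to give them their own embedding table and interleave item and action tokens into a single sequence. A minimal sketch (the class and argument names are illustrative; HSTU's real input format is more involved):

Python
import torch
import torch.nn as nn

class InterleavedSequenceEmbedder(nn.Module):
    """Embed [item_1, action_1, item_2, action_2, ...] as one token sequence."""

    def __init__(self, num_items: int, num_actions: int, hidden_dim: int):
        super().__init__()
        self.item_embedding = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)
        self.action_embedding = nn.Embedding(num_actions, hidden_dim)

    def forward(self, item_ids: torch.Tensor, action_ids: torch.Tensor) -> torch.Tensor:
        # item_ids, action_ids: (B, L), aligned per interaction
        item_emb = self.item_embedding(item_ids)        # (B, L, D)
        action_emb = self.action_embedding(action_ids)  # (B, L, D)

        # Interleave along the sequence dimension -> (B, 2L, D)
        B, L, D = item_emb.shape
        interleaved = torch.stack([item_emb, action_emb], dim=2).view(B, 2 * L, D)
        return interleaved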

3. The scaling law discovery

HSTU demonstrated, for the first time, that scaling laws exist in recommendation—just like in LLMs. More parameters consistently improve performance:

Code
Parameters    Relative Improvement
───────────────────────────────────
100M          baseline
1B            +3.2%
10B           +6.8%
100B          +9.1%
1T            +12.4%

This validates the investment in large recommendation models and suggests we're far from diminishing returns.

4. M-FALCON: Making trillion-parameter serving feasible

Training a 1T model is hard; serving it is harder. M-FALCON (Microbatched-Fast Attention Leveraging Cacheable OperatioNs) enables efficient inference by:

  • Factorizing attention computation across devices
  • Caching key-value pairs for recent items
  • Using speculative decoding for next-item prediction

The result: 10-1000x efficiency gains, making trillion-parameter models practical for real-time serving.

Python
# Simplified HSTU-style attention (conceptual)
class PointwiseAttention(nn.Module):
    """
    HSTU-style attention without softmax normalization.
    Preserves preference intensity.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        # Feature map (e.g., ELU + 1 for positivity)
        self.feature_map = lambda x: F.elu(x) + 1

    def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
        B, L, D = x.shape

        Q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim)
        K = self.k_proj(x).view(B, L, self.num_heads, self.head_dim)
        V = self.v_proj(x).view(B, L, self.num_heads, self.head_dim)

        # Apply feature map (makes attention linear-time possible)
        Q = self.feature_map(Q)
        K = self.feature_map(K)

        if causal:
            # Causal linear attention via cumulative sum
            # O(L) instead of O(L^2)
            KV = torch.einsum('blhd,blhe->blhde', K, V)
            KV_cumsum = torch.cumsum(KV, dim=1)

            output = torch.einsum('blhd,blhde->blhe', Q, KV_cumsum)

            # Normalize by cumulative key sum
            K_cumsum = torch.cumsum(K, dim=1)
            normalizer = torch.einsum('blhd,blhd->blh', Q, K_cumsum)
            output = output / (normalizer.unsqueeze(-1) + 1e-6)
        else:
            # Non-causal: standard (but still pointwise)
            attn = torch.einsum('blhd,bmhd->blmh', Q, K)
            output = torch.einsum('blmh,bmhd->blhd', attn, V)

        output = output.reshape(B, L, D)
        return self.out_proj(output)

LiGR: LinkedIn's Generative Ranking Architecture

LiGR (LinkedIn Generative Ranking) is a production transformer architecture that powers recommendations at LinkedIn. It introduces key innovations that address limitations of standard transformers for ranking.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                      LiGR ARCHITECTURE                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  PROBLEM WITH STANDARD TRANSFORMERS:                                    │
│  ─────────────────────────────────────                                   │
│                                                                          │
│  Deep residual stacking can "forget" distant layer information          │
│  Standard: output = LayerNorm(x + Attention(x))                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  LiGR SOLUTION: Recurrent Gating Dynamics                               │
│  ─────────────────────────────────────────                               │
│                                                                          │
│  • Parameterized gates selectively preserve aggregated information      │
│  • Improved gradient propagation and training stability                 │
│  • Enables deep stacking (24+ layers) without degradation               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  KEY FEATURES:                                                           │
│                                                                          │
│  1. Historical Attention: Aggregates user behavior features             │
│  2. In-Session Attention: Models current session context                │
│  3. Set-wise Scoring: Jointly scores items for automatic diversity      │
│  4. Learned Normalization: Adapts to recommendation-specific patterns   │
│                                                                          │
│  RESULTS AT LINKEDIN:                                                    │
│  • Deprecated most manual feature engineering                           │
│  • Outperforms systems using hundreds of features with only a few       │
│  • Validated scaling laws for ranking systems                           │
│  • Automatic diversity via simultaneous set-wise scoring                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

eSASRec: Enhanced SASRec (RecSys 2025)

eSASRec (ACM RecSys 2025) systematically studies what makes transformer-based recommendation work. The key finding: combining the right components matters more than novel architectures.

The research question eSASRec answers:

Every year, new sequential recommendation models claim to beat SASRec. But do they really? Or are they just using better training tricks? eSASRec decomposes the improvements into their sources:

Code
Claimed improvement sources in literature:
─────────────────────────────────────────────────────────────────────────
- Novel attention mechanisms (local, sparse, linear)
- Better positional encodings (relative, rotary)
- Improved architectures (gating, dense layers)
- Training objectives (contrastive, generative)

Actual improvement sources (eSASRec findings):
─────────────────────────────────────────────────────────────────────────
- Loss function: BCE → Sampled Softmax (+3-5%)
- Layer design: Vanilla → LiGR-style gating (+2-3%)
- Regularization: Proper dropout tuning (+1-2%)
- The "novel" architectures? Often 0% improvement when controlled properly

The three components that actually matter:

  1. SASRec's training objective (next-item prediction with causal masking): Despite BERT4Rec's bidirectional attention, unidirectional training remains best for online serving.

  2. LiGR-style transformer layers (gated residuals): The gating mechanism from LiGR helps gradient flow in deeper networks. Standard residual connections (x + attention(x)) work fine for 2-3 layers but degrade at 6+ layers.

  3. Sampled softmax loss: Full softmax over millions of items is computationally prohibitive. Sampled softmax with hard negatives achieves comparable quality at fraction of the cost.

Why this matters for practitioners:

You don't need the latest fancy architecture. Take SASRec, add sampled softmax, use LiGR-style layers, tune your dropout—you'll match or beat most published "state-of-the-art" results. The recipe is:

  • Architecture: Standard transformer with gated residuals
  • Training: Causal masking, sampled softmax (256-1024 negatives)
  • Regularization: Dropout 0.2-0.5 depending on data sparsity
  • Depth: 2-4 layers (diminishing returns beyond)
Python
# eSASRec: The winning combination
class ESASRec(nn.Module):
    """
    Enhanced SASRec = SASRec objective + LiGR layers + Sampled Softmax
    Paper: https://arxiv.org/abs/2508.06450
    """

    def __init__(
        self,
        num_items: int,
        hidden_dim: int = 128,
        num_layers: int = 2,
        num_heads: int = 4,
        num_negatives: int = 256,  # For sampled softmax
    ):
        super().__init__()

        # Item embeddings
        self.item_embedding = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(50, hidden_dim)

        # LiGR Transformer layers (key difference from vanilla SASRec)
        self.layers = nn.ModuleList([
            LiGRBlock(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])

        self.num_negatives = num_negatives
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, item_seq):
        # Standard embedding + position
        seq_emb = self.item_embedding(item_seq)
        positions = torch.arange(item_seq.shape[1], device=item_seq.device)
        seq_emb = seq_emb + self.position_embedding(positions)

        # Causal mask (True = position blocked), created on the same device as the inputs
        mask = torch.triu(
            torch.ones(item_seq.shape[1], item_seq.shape[1], device=item_seq.device, dtype=torch.bool),
            diagonal=1,
        )

        # LiGR transformer layers
        hidden = seq_emb
        for layer in self.layers:
            hidden = layer(hidden, mask)

        return self.layer_norm(hidden)


class LiGRBlock(nn.Module):
    """
    LiGR Transformer block with gated residuals.
    Preserves information better across deep networks.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()

        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Gating mechanism (key LiGR innovation)
        self.gate_attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.gate_ffn = nn.Linear(hidden_dim * 2, hidden_dim)

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x, mask):
        # Attention with gated residual
        normed = self.norm1(x)
        attn_out, _ = self.attention(normed, normed, normed, attn_mask=mask)

        # Gated combination (not just addition)
        gate = torch.sigmoid(self.gate_attn(torch.cat([x, attn_out], dim=-1)))
        x = x * gate + attn_out * (1 - gate)

        # FFN with gated residual
        normed = self.norm2(x)
        ffn_out = self.ffn(normed)

        gate = torch.sigmoid(self.gate_ffn(torch.cat([x, ffn_out], dim=-1)))
        x = x * gate + ffn_out * (1 - gate)

        return x

Key findings from the eSASRec study:

Component                  Options Tested                       Winner
─────────────────────────────────────────────────────────────────────────
Training Objective         Next-item, Masked, Contrastive       Next-item (SASRec-style)
Transformer Architecture   Vanilla, LiGR, HSTU                  LiGR
Loss Function              Full softmax, BCE, Sampled Softmax   Sampled Softmax
Negative Sampling          Uniform, Popularity, In-batch        Popularity-weighted

Results: eSASRec achieves a 23% improvement over the previous SOTA (ActionPiece) and sits on the Pareto frontier alongside HSTU for the accuracy-coverage trade-off. Importantly, it requires no extra features (unlike HSTU, which needs timestamps).


Part VI: Beyond Attention - State Space Models

Mamba4Rec: SSMs for Sequential Recommendation

Mamba4Rec (Best Paper Award, RelKD@KDD 2024) applies Selective State Space Models to sequential recommendation, achieving transformer-level accuracy with linear complexity.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│               ATTENTION vs STATE SPACE MODELS                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  TRANSFORMER ATTENTION:                                                  │
│  ───────────────────────                                                 │
│                                                                          │
│  Complexity: O(n²) in sequence length                                   │
│  Memory: O(n²) for attention matrix                                     │
│  Parallelization: Excellent (all positions computed together)           │
│                                                                          │
│  Problem: Quadratic cost prohibitive for long sequences (1000+ items)   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  STATE SPACE MODELS (Mamba):                                            │
│  ────────────────────────────                                            │
│                                                                          │
│  Complexity: O(n) in sequence length                                    │
│  Memory: O(n) - no attention matrix                                     │
│  Parallelization: Good (via parallel scan)                              │
│                                                                          │
│  Key Innovation: Selective state transitions                            │
│  - Content-aware gating of what to remember                             │
│  - Learned dynamics for sequence modeling                               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY SSMs FOR RECOMMENDATIONS:                                           │
│                                                                          │
│  • Users with 1000+ interactions are common                             │
│  • Linear scaling enables full history modeling                         │
│  • Selective mechanism learns what's relevant                           │
│  • Efficient inference for real-time serving                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Understanding why Mamba works for recommendations:

State Space Models (SSMs) seem like a step backward—they process sequences left-to-right like RNNs, while transformers can see everything at once. So why do they work?

The key insight: Selective state transitions

Unlike traditional RNNs that blindly compress all history into a fixed-size state, Mamba selectively decides what to remember:

Code
Traditional RNN:
  h_t = tanh(W_h * h_{t-1} + W_x * x_t)
  ↑ Fixed update rule, can't prioritize

Mamba (Selective SSM):
  Δ_t = softplus(Linear(x_t))           ← Input-dependent discretization
  B_t = Linear(x_t)                      ← Input-dependent input matrix
  C_t = Linear(x_t)                      ← Input-dependent output matrix
  h_t = exp(-Δ_t * A) * h_{t-1} + Δ_t * B_t * x_t
        ↑ How much to forget       ↑ How much to incorporate

The selectivity means:

  • A boring "filler" item (user just browsing): Small Δ → retain previous state, mostly ignore this item
  • An important signal (user made a purchase): Large Δ → reset state, heavily weight this item
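
A few lines of arithmetic show the effect of Δ on a single state update (toy numbers; the real update also involves the input-dependent B and C projections):

Python
import torch

h_prev = torch.tensor([1.0, -0.5, 2.0])   # previous hidden state (per channel)
x_t = torch.tensor([0.3, 0.3, 0.3])       # current input
A = torch.tensor([-1.0, -1.0, -1.0])      # decay rates (negative for stability)

for delta in (0.05, 3.0):                  # small vs large input-dependent step size
    h_new = torch.exp(A * delta) * h_prev + delta * x_t
    print(delta, h_new)
# delta=0.05 -> state barely changes: previous context retained, input mostly ignored
# delta=3.0  -> old state decays to ~5% of its value: the new input dominates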

Mamba4Rec's specific innovations:

  1. Bidirectional Mamba: Unlike language models that only go left-to-right, Mamba4Rec processes sequences in both directions and combines the states. This captures patterns like "users who will buy X often first browse Y."

  2. Position embeddings preserved: Despite being an SSM, Mamba4Rec adds positional encodings—important because position matters differently in recommendations than in language.

  3. Item embedding sharing: Same embedding matrix for input and output, like SASRec.

When to use Mamba4Rec over transformers:

  • Long sequences (500+ items): Quadratic attention becomes expensive
  • Memory-constrained serving: No O(n²) attention matrix
  • Streaming scenarios: Can update state incrementally without recomputing

When to stick with transformers:

  • Short sequences (<100 items): Quadratic cost is negligible
  • Need for interpretability: Attention weights are interpretable; SSM states are not
  • Existing infrastructure: Most production systems are built around transformers
Python
# Mamba4Rec architecture (simplified)
class Mamba4Rec(nn.Module):
    """
    Sequential recommendation with Selective State Space Models.
    Paper: https://arxiv.org/abs/2403.03900
    """

    def __init__(
        self,
        num_items: int,
        hidden_dim: int = 64,
        state_dim: int = 16,
        num_layers: int = 2,
    ):
        super().__init__()

        self.item_embedding = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(200, hidden_dim)

        # Mamba layers instead of transformer
        self.layers = nn.ModuleList([
            MambaBlock(hidden_dim, state_dim)
            for _ in range(num_layers)
        ])

        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, item_seq):
        B, L = item_seq.shape

        # Embeddings
        seq_emb = self.item_embedding(item_seq)
        positions = torch.arange(L, device=item_seq.device)
        hidden = seq_emb + self.position_embedding(positions)

        # Mamba layers (O(L) instead of O(L²))
        for layer in self.layers:
            hidden = layer(hidden)

        return self.layer_norm(hidden)


class MambaBlock(nn.Module):
    """
    Selective State Space Model block.
    Core idea: Learn what information to retain in hidden state.
    """

    def __init__(self, hidden_dim: int, state_dim: int):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.state_dim = state_dim

        # Input projections
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 2)

        # SSM parameters (learned per-position)
        self.x_proj = nn.Linear(hidden_dim, state_dim * 2)  # B, C matrices
        self.dt_proj = nn.Linear(hidden_dim, hidden_dim)    # Δ (timestep)

        # State matrix A (fixed, learned)
        self.A = nn.Parameter(torch.randn(hidden_dim, state_dim))

        # Output projection
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        B, L, D = x.shape
        residual = x

        # Project input
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)

        # Compute selective (input-dependent) parameters
        x_dbl = self.x_proj(x_ssm)               # would yield B_t, C_t (unused in this simplified scan)
        delta = F.softplus(self.dt_proj(x_ssm))  # Δ: per-channel, per-step update size

        # Discretize and run SSM (simplified)
        # Full implementation uses parallel scan for efficiency
        A = -torch.exp(self.A)

        # Selective state space computation
        # This is where the "selection" happens - delta modulates
        # how much of each input affects the hidden state
        out = self._ssm_scan(x_ssm, A, delta)

        # Gated output
        out = out * F.silu(z)
        out = self.out_proj(out)

        return self.norm(residual + out)

    def _ssm_scan(self, x, A, delta):
        # Simplified selective scan: h_t = exp(A * Δ_t) * h_{t-1} + Δ_t * x_t
        # Real implementations use fused parallel-scan CUDA kernels for efficiency
        B, L, D = x.shape
        h = torch.zeros(B, D, self.state_dim, device=x.device)
        outputs = []

        for t in range(L):
            dt = delta[:, t, :, None]                   # (B, D, 1) input-dependent step size
            h = h * torch.exp(A * dt) + dt * x[:, t, :, None]
            outputs.append(h.sum(dim=-1))               # collapse state dim -> (B, D)

        return torch.stack(outputs, dim=1)              # (B, L, D)

Mamba4Rec Results:

  • Matches or exceeds SASRec/BERT4Rec accuracy
  • 10x faster inference on sequences >500 items
  • Lower memory footprint enables longer histories

Related SSM Models (2024-2025):

  • EchoMamba4Rec: Bidirectional SSM with spectral filtering
  • MaTrRec: Hybrid Mamba + Transformer
  • SSD4Rec: Structured State Space Duality for efficiency

Part VII: Generative Retrieval with Semantic IDs

TIGER: Treating Recommendation as Generation

TIGER (Transformer Index for GEnerative Recommenders) introduces a paradigm shift: instead of scoring items, generate item identifiers token-by-token.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│              TRADITIONAL vs GENERATIVE RETRIEVAL                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  TRADITIONAL (Two-Tower):                                               │
│  ─────────────────────────                                               │
│                                                                          │
│  User History → User Encoder → User Embedding                           │
│                                      ↓                                   │
│                                 Dot Product → Top-K items               │
│                                      ↑                                   │
│  Item Features → Item Encoder → Item Embeddings (pre-computed)          │
│                                                                          │
│  Requires: ANN index, separate item encoding, retrieval infrastructure  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  GENERATIVE (TIGER):                                                    │
│  ─────────────────────                                                   │
│                                                                          │
│  User History → Transformer → Generate Semantic ID tokens               │
│                                                                          │
│  [item₁, item₂, item₃] → Decoder → [tok₁, tok₂, tok₃, tok₄]           │
│                                      ↓                                   │
│                               Map to Item ID                            │
│                                                                          │
│  Benefits:                                                               │
│  • Unified model (no separate retrieval infrastructure)                 │
│  • Similar items share token prefixes                                   │
│  • Beam search explores item space efficiently                          │
│  • Natural handling of semantic similarity                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Semantic ID Tokenization

The key innovation: represent each item as a sequence of learned discrete tokens (Semantic ID) that capture semantic meaning.

Why Semantic IDs are revolutionary:

Traditional item IDs are arbitrary: item 12345 and item 12346 have no semantic relationship. But Semantic IDs encode meaning through hierarchical structure:

Code
Semantic ID structure (4 tokens per item):
─────────────────────────────────────────────────────────────────────────
Token 1 (coarse): Category-level (electronics, clothing, books)
Token 2: Subcategory (phones, laptops, headphones)
Token 3: Style/brand cluster (premium, budget, specific brand)
Token 4 (fine): Specific item identifier

Example:
  iPhone 15 Pro:    [47, 12, 89, 156]
  iPhone 15:        [47, 12, 89, 203]   ← Shares first 3 tokens!
  Samsung Galaxy:   [47, 12, 45, 78]    ← Shares first 2 tokens
  Running shoes:    [23, 67, 34, 12]    ← Different category, different prefix

This structure enables:

  • Efficient beam search: Prune entire categories early in generation
  • Semantic generalization: Model learns "electronics" behavior applies to new electronics
  • Natural diversity: Different beam paths explore different categories

How Semantic IDs are learned (RQ-VAE):

A Residual-Quantized Variational AutoEncoder (RQ-VAE) learns the tokenization:

  1. Encode item to embedding: Use content encoder (e.g., BERT on description, CNN on image) to get dense vector

  2. Quantize residually: At each level, find nearest codebook entry, compute residual (what's left), quantize residual at next level

  3. Train end-to-end: Reconstruction loss ensures semantic IDs can decode back to original embedding; quantization learns meaningful clusters
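
A sketch of that training objective, assuming a generic encoder/decoder pair and a list of codebook tensors (the commitment weight beta and the straight-through trick follow standard VQ-VAE practice; the exact values are illustrative):

Python
import torch
import torch.nn.functional as F

def rqvae_loss(encoder, decoder, codebooks, item_features, beta: float = 0.25):
    """Reconstruction + quantization losses for residual quantization."""
    z = encoder(item_features)                      # dense content embedding, (B, dim)
    residual = z
    quantized = torch.zeros_like(z)
    commit_loss, codebook_loss = 0.0, 0.0

    for codebook in codebooks:                      # each codebook: (codebook_size, dim)
        distances = torch.cdist(residual, codebook)
        tokens = distances.argmin(dim=-1)
        q = codebook[tokens]                        # nearest entries at this level

        commit_loss = commit_loss + F.mse_loss(residual, q.detach())       # encoder -> codes
        codebook_loss = codebook_loss + F.mse_loss(q, residual.detach())   # codes -> encoder
        quantized = quantized + q
        residual = residual - q.detach()            # what remains for the next level

    # Straight-through estimator: decoder sees the quantized codes,
    # gradients flow back to the encoder through z
    z_q = z + (quantized - z).detach()
    recon_loss = F.mse_loss(decoder(z_q), item_features)
    return recon_loss + codebook_loss + beta * commit_loss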

The hierarchical clustering effect:

RQ-VAE naturally creates a hierarchy because each level explains increasingly fine-grained details:

  • Level 1: Broad categories (high variance features)
  • Level 2: Subcategories (medium variance)
  • Level 3+: Specific attributes (low variance)

Similar items end up with shared prefixes, enabling efficient tree-like search during generation.

Python
class SemanticIDTokenizer(nn.Module):
    """
    Convert items to semantic IDs using RQ-VAE.
    Items with similar content get similar token prefixes.
    """

    def __init__(
        self,
        num_codebooks: int = 4,  # Number of tokens per item
        codebook_size: int = 256,  # Vocabulary size per position
        embedding_dim: int = 768,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size

        # Residual quantization codebooks
        self.codebooks = nn.ParameterList([
            nn.Parameter(torch.randn(codebook_size, embedding_dim))
            for _ in range(num_codebooks)
        ])

    def encode(self, item_embedding: torch.Tensor) -> torch.Tensor:
        """
        Convert dense embedding to sequence of discrete tokens.

        Args:
            item_embedding: (batch, embedding_dim) from content encoder

        Returns:
            semantic_id: (batch, num_codebooks) token indices
        """
        tokens = []
        residual = item_embedding

        for codebook in self.codebooks:
            # Find nearest codebook entry
            distances = torch.cdist(residual, codebook)
            token_idx = distances.argmin(dim=-1)
            tokens.append(token_idx)

            # Compute residual for next level
            quantized = codebook[token_idx]
            residual = residual - quantized

        return torch.stack(tokens, dim=-1)


class TIGERModel(nn.Module):
    """
    TIGER: Generative recommendation via semantic ID generation.
    Paper: https://arxiv.org/abs/2305.05065
    """

    def __init__(
        self,
        tokenizer: SemanticIDTokenizer,
        hidden_dim: int = 256,
        num_layers: int = 6,
    ):
        super().__init__()

        self.tokenizer = tokenizer
        vocab_size = tokenizer.codebook_size * tokenizer.num_codebooks

        # Token embeddings (shared across codebook positions)
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(200, hidden_dim)

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Output heads (one per codebook position)
        self.output_heads = nn.ModuleList([
            nn.Linear(hidden_dim, tokenizer.codebook_size)
            for _ in range(tokenizer.num_codebooks)
        ])

    def forward(self, history_ids: torch.Tensor) -> list[torch.Tensor]:
        """
        Generate next item's semantic ID given history.

        Args:
            history_ids: (batch, seq_len, num_codebooks) - flattened semantic IDs

        Returns:
            logits: List of (batch, codebook_size) for each codebook position
        """
        B, L, K = history_ids.shape

        # Flatten semantic IDs into one token sequence:
        # [item1_tok1, item1_tok2, ..., item2_tok1, ...]
        # Token IDs are assumed to carry a per-position offset, so position k
        # uses IDs in [k * codebook_size, (k + 1) * codebook_size).
        flat_ids = history_ids.view(B, L * K)

        # Embed tokens
        hidden = self.token_embedding(flat_ids)
        positions = torch.arange(L * K, device=hidden.device)
        hidden = hidden + self.position_embedding(positions)

        # Decode with a causal mask so each position only attends to earlier tokens
        causal_mask = torch.triu(
            torch.full((L * K, L * K), float('-inf'), device=hidden.device), diagonal=1
        )
        hidden = self.decoder(hidden, hidden, tgt_mask=causal_mask, memory_mask=causal_mask)

        # Generate semantic ID tokens autoregressively
        logits = []
        for i, head in enumerate(self.output_heads):
            # Use last position for each codebook
            pos = -K + i if i < K else -1
            logits.append(head(hidden[:, pos]))

        return logits

    def generate(self, history_ids: torch.Tensor, num_beams: int = 10) -> torch.Tensor:
        """
        Beam search generation for next item's semantic ID.
        """
        # Beam search over semantic ID space
        # Similar to text generation but in item token space
        pass

Recent Advances in Semantic IDs (2024-2025)

Model        Innovation                                              Key Result
─────────────────────────────────────────────────────────────────────────
LETTER       Collaborative + diversity regularization in RQ-VAE      Better codebook utilization
LC-Rec       Aligns LLM semantic space with collaborative signals    Cross-domain transfer
TIGER++      Whitening + EMA for codebook learning                   More stable training
RPG          Parallel (non-autoregressive) ID generation             4x faster inference
LLaDA-Rec    Discrete diffusion for semantic IDs                     Better diversity

Trade-offs with Semantic IDs:

  • Search vs Recommendation: IDs tuned for search hurt recommendation, and vice versa
  • Static vs Dynamic: Fixed IDs assume universal similarity; some work explores adaptive tokenization
  • Length: TIGER uses 4 tokens; longer IDs (32-64) capture more semantics but slow inference

Part VIII: Training Strategies

Loss Functions for Sequential Recommendation

The choice of loss function significantly impacts performance—often more than architectural changes. Understanding these losses helps you choose the right one for your constraints.

The fundamental tradeoff:

With millions of items, you can't compute scores for everything. Each loss function handles this differently:

Code
Loss Function        Complexity per Example    Quality      Use Case
─────────────────────────────────────────────────────────────────────────
Full Softmax         O(num_items)              Best         Small catalogs (<100K)
Sampled Softmax      O(num_negatives)          Good         Large catalogs
BPR (Pairwise)       O(num_negatives)          Good         Implicit feedback
InfoNCE              O(batch_size)             Good         Self-supervised

Full Softmax: The gold standard

Full softmax computes the probability distribution over ALL items. It's optimal because the model explicitly learns to rank the target higher than every other item. The problem: with 10M items, this means 10M forward passes per example—prohibitive for training.

Sampled Softmax: The practical choice

Instead of all items, sample a subset of negatives. With 256-1024 negatives, you get ~95% of full softmax quality at a fraction of the cost. The key insight: most items are easy negatives (clearly not relevant), so you don't need to see them all.

Critical choice: how to sample negatives?

  • Uniform: Simple, but wastes samples on trivially irrelevant items
  • Popularity-weighted: Sample popular items more often (they're harder negatives)
  • In-batch negatives: Use other users' positives as negatives (free, diverse)
  • Hard negatives: Mine items the model currently ranks incorrectly (most informative, but expensive)

BPR (Bayesian Personalized Ranking):

BPR doesn't try to predict absolute scores—it optimizes the relative ordering. The loss says: "score the positive higher than negatives, I don't care by how much."

\mathcal{L}_{BPR} = -\sum \log \sigma(s_{pos} - s_{neg})

This works well for implicit feedback (clicks, views) where we don't have explicit ratings. A click doesn't mean "great item"—it means "better than items not clicked."

InfoNCE (Contrastive):

InfoNCE treats recommendation as a contrastive learning problem: pull together (user, positive item) pairs, push apart (user, negative item) pairs. The temperature parameter τ controls how hard the model focuses on difficult negatives:

  • Low τ (0.01): Focus on very hard negatives
  • High τ (1.0): Treat all negatives more equally

InfoNCE is popular in self-supervised pretraining because it doesn't require labels—just pairs of "similar" items (e.g., items in the same session).

Practical recommendations:

  1. Start with sampled softmax (256-512 negatives, popularity-weighted)
  2. Use in-batch negatives when batch size is large (saves computation)
  3. Add hard negative mining once model converges (pushes accuracy further)
  4. Full softmax only if you have <100K items and GPU memory to spare
Python
# 1. Cross-Entropy (Full Softmax)
# Most accurate but expensive for large item catalogs
def full_softmax_loss(logits, targets):
    """
    logits: (B, num_items) - scores for all items
    targets: (B,) - target item IDs
    """
    return F.cross_entropy(logits, targets)


# 2. Sampled Softmax
# Approximate full softmax with negative sampling
def sampled_softmax_loss(
    hidden: torch.Tensor,
    targets: torch.Tensor,
    item_embeddings: torch.Tensor,
    num_negatives: int = 100,
):
    """
    hidden: (B, D) - sequence representations
    targets: (B,) - target item IDs
    item_embeddings: (num_items, D)
    """
    B = hidden.shape[0]

    # Positive scores
    target_emb = item_embeddings[targets]  # (B, D)
    pos_scores = (hidden * target_emb).sum(-1)  # (B,)

    # Sample negatives (uniform or popularity-based)
    neg_indices = torch.randint(0, item_embeddings.shape[0], (B, num_negatives))
    neg_emb = item_embeddings[neg_indices]  # (B, num_neg, D)
    neg_scores = torch.bmm(neg_emb, hidden.unsqueeze(-1)).squeeze(-1)  # (B, num_neg)

    # Combine and compute log-softmax
    all_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
    loss = -F.log_softmax(all_scores, dim=-1)[:, 0].mean()

    return loss


# 3. BPR (Bayesian Personalized Ranking)
# Pairwise: score positive higher than negatives
def bpr_loss(pos_scores, neg_scores):
    """
    pos_scores: (B,) - scores for positive items
    neg_scores: (B, num_neg) - scores for negative items
    """
    # For each positive, compare with each negative
    diff = pos_scores.unsqueeze(-1) - neg_scores  # (B, num_neg)
    loss = -F.logsigmoid(diff).mean()
    return loss


# 4. InfoNCE / Contrastive Loss
# Popular in self-supervised settings
def infonce_loss(hidden, targets, item_embeddings, temperature=0.07):
    """Contrastive loss with in-batch negatives."""
    # Positive pairs
    target_emb = item_embeddings[targets]  # (B, D)

    # Normalize
    hidden = F.normalize(hidden, dim=-1)
    target_emb = F.normalize(target_emb, dim=-1)

    # Similarity matrix (in-batch negatives)
    sim = torch.mm(hidden, target_emb.T) / temperature  # (B, B)

    # Diagonal elements are positives
    labels = torch.arange(hidden.shape[0], device=hidden.device)
    loss = F.cross_entropy(sim, labels)

    return loss

Negative Sampling Strategies

Not all negatives are equally informative. The choice of how to sample negatives can impact model quality as much as architecture choices.

The problem with uniform sampling:

With uniform sampling, most negatives are "easy"—clearly irrelevant items that the model quickly learns to score low. These easy negatives provide little training signal:

Code
User interested in: Electronics
─────────────────────────────────────────────────────────────────────────
Uniform negatives:    [Garden hose, Dog food, Baby clothes, Romance novel]
                      ↳ Model already scores these near 0. No learning.

Better negatives:     [Samsung Galaxy, Laptop stand, Wireless mouse, AirPods]
                      ↳ Model must work harder to distinguish. More learning.

Four sampling strategies, ranked by effectiveness:

  1. Uniform sampling: Baseline. Fast but wasteful. Most samples are trivially easy.

  2. Popularity-based sampling: Sample popular items more often. Popular items are harder negatives because they're plausibly relevant to many users. Adds ~5-10% improvement over uniform.

  3. In-batch negatives: Brilliant trick—use other users' positive items as your negatives. Zero additional computation (you already have their embeddings). Works well when batch size is large (512+). Popular in contrastive learning.

  4. Hard negative mining: Explicitly find items the model currently ranks incorrectly. Most expensive (requires scoring many items) but most effective. Can add 10-20% improvement. Common approach: periodically compute top-K similar items, use those as hard negatives.

The false negative problem:

A sampled "negative" might actually be a good recommendation:

  • User hasn't seen item X yet, but would love it
  • You sample X as negative, penalize model for scoring it highly
  • Model learns the wrong thing!

Solutions:

  • Filter user history: Never sample items the user interacted with
  • Debiasing: Weight loss by inverse propensity (how likely was this item to be seen?)
  • Accept some noise: With enough negatives, signal overwhelms noise
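
The history filter is the cheapest of these to implement. A per-user sketch (assuming the user's interaction history is available as a Python set):

Python
import torch

def sample_filtered_negatives(num_items: int, history: set, num_negatives: int) -> torch.Tensor:
    """Uniform negatives for one user, rejecting anything already in their history."""
    negatives = []
    while len(negatives) < num_negatives:
        candidates = torch.randint(0, num_items, (num_negatives * 2,))
        negatives.extend(int(c) for c in candidates if int(c) not in history)
    return torch.tensor(negatives[:num_negatives])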

Curriculum learning for negatives:

Start with easy negatives, gradually increase difficulty:

  • Epoch 1-5: Uniform sampling (fast convergence on obvious patterns)
  • Epoch 5-10: Popularity-weighted (harder distinction)
  • Epoch 10+: Hard negative mining (fine-grained ranking)

This curriculum often converges faster than starting with hard negatives.

Python
class NegativeSampler:
    """Different strategies for sampling negative items."""

    def __init__(self, num_items: int, item_popularity: torch.Tensor = None):
        self.num_items = num_items
        self.item_popularity = item_popularity

    def uniform(self, batch_size: int, num_negatives: int) -> torch.Tensor:
        """Random uniform sampling."""
        return torch.randint(0, self.num_items, (batch_size, num_negatives))

    def popularity(self, batch_size: int, num_negatives: int) -> torch.Tensor:
        """Sample proportional to item popularity (harder negatives)."""
        probs = self.item_popularity / self.item_popularity.sum()
        indices = torch.multinomial(probs, batch_size * num_negatives, replacement=True)
        return indices.view(batch_size, num_negatives)

    def in_batch(self, targets: torch.Tensor) -> torch.Tensor:
        """Use other items in batch as negatives (efficient)."""
        B = targets.shape[0]
        # All items in batch except self
        negatives = targets.unsqueeze(0).expand(B, -1)  # (B, B)
        mask = ~torch.eye(B, dtype=torch.bool, device=targets.device)
        return negatives[mask].view(B, B-1)

    def hard_negatives(
        self,
        hidden: torch.Tensor,
        item_embeddings: torch.Tensor,
        num_negatives: int,
        exclude: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sample items with high scores but wrong labels.
        Expensive but very effective.
        """
        # Score all items
        scores = torch.mm(hidden, item_embeddings.T)  # (B, num_items)

        # Mask out true positives
        scores.scatter_(1, exclude.unsqueeze(-1), float('-inf'))

        # Sample from top-k
        _, top_indices = scores.topk(num_negatives * 10, dim=-1)

        # Random sample from top-k
        rand_indices = torch.randint(
            0, num_negatives * 10, (hidden.shape[0], num_negatives), device=hidden.device
        )
        negatives = top_indices.gather(1, rand_indices)

        return negatives
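
Tying the strategies together, a curriculum schedule over epochs might look like the following (a sketch built on the NegativeSampler above; the epoch thresholds and the batch keys "targets", "user_repr", and "item_embeddings" are illustrative):

Python
def sample_negatives_for_epoch(sampler, epoch, batch, num_negatives=256):
    """Curriculum: easy negatives first, progressively harder as training proceeds."""
    batch_size = batch["targets"].shape[0]
    if epoch < 5:                                   # obvious patterns first
        return sampler.uniform(batch_size, num_negatives)
    if epoch < 10:                                  # harder, popularity-weighted negatives
        return sampler.popularity(batch_size, num_negatives)
    return sampler.hard_negatives(                  # fine-grained ranking
        hidden=batch["user_repr"],
        item_embeddings=batch["item_embeddings"],
        num_negatives=num_negatives,
        exclude=batch["targets"],
    )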

Part IX: Handling Scale and Efficiency

The Item Vocabulary Problem

Real recommendation systems have millions of items. Full softmax is infeasible:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SCALING CHALLENGES                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  FULL SOFTMAX COST:                                                      │
│  ───────────────────                                                     │
│                                                                          │
│  Items: 10 million                                                       │
│  Hidden dim: 256                                                         │
│  Batch size: 1024                                                        │
│                                                                          │
│  Embedding table: 10M × 256 = 2.56 GB                                   │
│  Output projection: 10M × 256 = 2.56 GB                                 │
│  Softmax computation: 1024 × 10M = 10.24 billion ops per batch         │
│                                                                          │
│  This is INFEASIBLE for training and inference.                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SOLUTIONS:                                                              │
│  ──────────                                                              │
│                                                                          │
│  1. SAMPLED SOFTMAX                                                      │
│     - Train with ~1000 negatives instead of all items                   │
│     - 10,000x reduction in computation                                  │
│                                                                          │
│  2. TWO-TOWER RETRIEVAL                                                  │
│     - Separate user and item encoders                                   │
│     - Use ANN (approximate nearest neighbor) for retrieval              │
│     - Inference: O(log N) instead of O(N)                               │
│                                                                          │
│  3. HIERARCHICAL SOFTMAX                                                 │
│     - Organize items in tree structure                                  │
│     - O(log N) path instead of O(N) softmax                            │
│                                                                          │
│  4. HASH EMBEDDINGS                                                      │
│     - Multiple items share embedding buckets                            │
│     - Reduces embedding table size 10-100x                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
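
Of the four, hash embeddings are the easiest to retrofit onto an existing model: each item ID is mapped by a few cheap hash functions into a much smaller bucket table, and the bucket vectors are summed. A minimal sketch (the multipliers are arbitrary hash seeds, not taken from any specific paper):

Python
import torch
import torch.nn as nn

class HashEmbedding(nn.Module):
    """Approximate a huge item embedding table with a small shared bucket table."""

    def __init__(self, num_buckets: int = 100_000, hidden_dim: int = 128, num_hashes: int = 2):
        super().__init__()
        self.num_buckets = num_buckets
        self.buckets = nn.Embedding(num_buckets, hidden_dim)
        # Arbitrary odd multipliers acting as cheap hash functions
        self.register_buffer("seeds", torch.tensor([2654435761, 40503][:num_hashes]))

    def forward(self, item_ids: torch.Tensor) -> torch.Tensor:
        # item_ids: int64 item IDs, possibly in the tens of millions
        emb = 0.0
        for seed in self.seeds:
            bucket_ids = (item_ids * seed) % self.num_buckets
            emb = emb + self.buckets(bucket_ids)   # items colliding in one hash rarely collide in both
        return emb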

Two-Tower Architecture

The standard production architecture separates user and item encoding. This pattern is universal in industrial recommendation systems—YouTube, Pinterest, LinkedIn, and virtually every large-scale system uses some variant.

Why two towers instead of one?

A unified model that jointly encodes (user, item) pairs seems more powerful—it can capture complex user-item interactions. But it has a fatal flaw:

Code
Unified Model (Dense Retrieval):
─────────────────────────────────────────────────────────────────────────
For each user, score all 10M items:
  for item in all_items:
    score = model(user_features, item_features)  # Full forward pass

Cost: 10M forward passes per user per request
At 1000 QPS = 10 BILLION forward passes per second = IMPOSSIBLE

Two-Tower Model:
─────────────────────────────────────────────────────────────────────────
OFFLINE (once per day):
  for item in all_items:
    item_embeddings[item] = item_tower(item_features)  # 10M passes total
  build_ann_index(item_embeddings)

ONLINE (per request):
  user_emb = user_tower(user_features)  # 1 forward pass
  candidates = ann_index.search(user_emb, k=100)  # O(log N) lookup

Cost: 1 forward pass + index lookup per request
At 1000 QPS = 1000 forward passes per second = EASY

The key constraint: Dot product interaction only

For ANN indexes to work, the score must decompose as:

\text{score}(u, i) = \mathbf{u} \cdot \mathbf{i}

No cross-features, no nonlinear combinations. This seems limiting, but the towers can be arbitrarily complex—SASRec, BERT4Rec, any transformer. All user-item interaction is compressed into the learned embeddings.

What each tower learns:

  • User tower: Learns to compress user history into a "preference vector" that captures what kinds of items this user likes
  • Item tower: Learns to represent items in the same space, such that "relevant" items are close to users who would like them

The embedding space is shared—users and items live in the same vector space, making similarity computation possible.

Temperature scaling for contrastive learning:

Two-tower models are typically trained with contrastive loss:

\mathcal{L} = -\log \frac{\exp(\mathbf{u} \cdot \mathbf{i}_{pos} / \tau)}{\sum_{j} \exp(\mathbf{u} \cdot \mathbf{i}_j / \tau)}

The temperature τ is critical:

  • High temperature (1.0): Softer distribution, model focuses on all negatives equally
  • Low temperature (0.05): Sharper distribution, model focuses on hardest negatives

Most systems use τ ∈ [0.05, 0.2] for sharp, discriminative embeddings.

Python
class TwoTowerModel(nn.Module):
    """
    Two-tower architecture for large-scale retrieval.
    User tower: Encodes user history
    Item tower: Encodes item features
    """

    def __init__(
        self,
        num_items: int,
        hidden_dim: int = 128,
        user_tower: nn.Module = None,  # e.g., SASRec
    ):
        super().__init__()
        self.user_tower = user_tower or SASRec(num_items, hidden_dim=hidden_dim)

        # Item tower: Could be simple embedding or full encoder
        self.item_embedding = nn.Embedding(num_items, hidden_dim)
        self.item_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def encode_user(self, item_seq: torch.Tensor) -> torch.Tensor:
        """Encode user from their interaction sequence."""
        hidden = self.user_tower(item_seq)
        return hidden[:, -1, :]  # Last position

    def encode_items(self, item_ids: torch.Tensor) -> torch.Tensor:
        """Encode items."""
        emb = self.item_embedding(item_ids)
        return self.item_mlp(emb)

    def forward(self, item_seq: torch.Tensor, candidate_items: torch.Tensor):
        """
        Score candidate items for users.

        Args:
            item_seq: (B, L) user sequences
            candidate_items: (B, C) candidate item IDs

        Returns:
            scores: (B, C) scores for each candidate
        """
        user_emb = self.encode_user(item_seq)  # (B, D)
        item_emb = self.encode_items(candidate_items)  # (B, C, D)

        scores = torch.bmm(item_emb, user_emb.unsqueeze(-1)).squeeze(-1)
        return scores


# At inference time:
# 1. Pre-compute all item embeddings
# 2. Build ANN index (FAISS, ScaNN, etc.)
# 3. For each user, encode once, retrieve top-k via ANN

import faiss

def build_item_index(model, num_items, hidden_dim):
    """Build FAISS index for fast retrieval."""
    # Encode all items
    all_items = torch.arange(num_items)
    item_embeddings = model.encode_items(all_items).detach().numpy()

    # Build index
    index = faiss.IndexFlatIP(hidden_dim)  # Inner product
    index.add(item_embeddings)

    return index

def retrieve_candidates(model, index, item_seq, k=100):
    """Retrieve top-k candidates for a user."""
    user_emb = model.encode_user(item_seq).detach().numpy()
    scores, indices = index.search(user_emb, k)
    return indices, scores
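
For completeness, a training step that combines the TwoTowerModel above with in-batch negatives and the temperature-scaled softmax discussed earlier (a sketch; the temperature value is just an example):

Python
import torch
import torch.nn.functional as F

def two_tower_training_step(model, item_seq, target_items, optimizer, temperature=0.1):
    """In-batch contrastive step: each user's positive is every other user's negative."""
    user_emb = F.normalize(model.encode_user(item_seq), dim=-1)        # (B, D)
    item_emb = F.normalize(model.encode_items(target_items), dim=-1)   # (B, D)

    logits = user_emb @ item_emb.T / temperature                       # (B, B) similarity matrix
    labels = torch.arange(logits.shape[0], device=logits.device)       # diagonal entries are positives

    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()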

Efficient Attention for Long Sequences

Users can have thousands of interactions. Standard attention is O(n²):

Python
# Linear attention approximations

class LinearAttention(nn.Module):
    """
    Linear attention via kernel feature maps.
    O(n) instead of O(n^2).
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def feature_map(self, x):
        """ELU-based feature map for positive attention."""
        return F.elu(x) + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, D = x.shape
        H, d = self.num_heads, self.head_dim

        Q = self.feature_map(self.q_proj(x)).view(B, L, H, d)
        K = self.feature_map(self.k_proj(x)).view(B, L, H, d)
        V = self.v_proj(x).view(B, L, H, d)

        # Linear attention: (Q @ K.T) @ V = Q @ (K.T @ V)
        # Compute K.T @ V first: O(L * d^2) instead of O(L^2 * d)
        KV = torch.einsum('blhd,blhe->bhde', K, V)  # (B, H, d, d)

        # Q @ KV
        output = torch.einsum('blhd,bhde->blhe', Q, KV)  # (B, L, H, d)

        # Normalize
        K_sum = K.sum(dim=1)  # (B, H, d)
        normalizer = torch.einsum('blhd,bhd->blh', Q, K_sum)  # (B, L, H)
        output = output / (normalizer.unsqueeze(-1) + 1e-6)

        output = output.reshape(B, L, D)
        return self.out_proj(output)

Part X: Incorporating Side Information

Beyond Item IDs

Pure ID-based models ignore rich item metadata. Modern systems incorporate:

Python
class SASRecWithFeatures(nn.Module):
    """SASRec with item and context features."""

    def __init__(
        self,
        num_items: int,
        num_categories: int,
        num_brands: int,
        hidden_dim: int = 64,
        **kwargs
    ):
        super().__init__()

        # ID embeddings
        self.item_embedding = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)

        # Feature embeddings
        self.category_embedding = nn.Embedding(num_categories, hidden_dim // 4)
        self.brand_embedding = nn.Embedding(num_brands, hidden_dim // 4)

        # Continuous features
        self.price_proj = nn.Linear(1, hidden_dim // 4)

        # Combine features
        feature_dim = hidden_dim + hidden_dim // 4 * 3
        self.feature_fusion = nn.Linear(feature_dim, hidden_dim)

        # Rest of transformer...
        self.transformer = ...

    def embed_items(
        self,
        item_ids: torch.Tensor,
        categories: torch.Tensor,
        brands: torch.Tensor,
        prices: torch.Tensor,
    ) -> torch.Tensor:
        """Combine ID and feature embeddings."""
        id_emb = self.item_embedding(item_ids)
        cat_emb = self.category_embedding(categories)
        brand_emb = self.brand_embedding(brands)
        price_emb = self.price_proj(prices.unsqueeze(-1))

        combined = torch.cat([id_emb, cat_emb, brand_emb, price_emb], dim=-1)
        return self.feature_fusion(combined)

Multi-Modal Features

For rich content like images and descriptions:

Python
class MultiModalSASRec(nn.Module):
    """SASRec with text and image features."""

    def __init__(self, num_items: int, hidden_dim: int = 256):
        super().__init__()

        # Pre-trained encoders (frozen or fine-tuned)
        self.text_encoder = ...  # e.g., sentence-transformers
        self.image_encoder = ...  # e.g., CLIP vision encoder

        # Project to common space
        self.text_proj = nn.Linear(768, hidden_dim)  # Assuming BERT-base
        self.image_proj = nn.Linear(512, hidden_dim)  # Assuming CLIP
        self.id_embedding = nn.Embedding(num_items, hidden_dim)

        # Fusion
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )

    def embed_items(self, item_ids, text_features, image_features):
        id_emb = self.id_embedding(item_ids)
        text_emb = self.text_proj(text_features)
        image_emb = self.image_proj(image_features)

        combined = torch.cat([id_emb, text_emb, image_emb], dim=-1)
        return self.fusion(combined)

Part XI: Evaluation and Metrics

Standard Metrics

Python
def evaluate_recommendations(
    predictions: torch.Tensor,  # (num_users, k) top-k item indices
    ground_truth: torch.Tensor,  # (num_users,) true next items
    k_values: list = [1, 5, 10, 20],
):
    """Compute standard recommendation metrics."""

    results = {}

    for k in k_values:
        top_k = predictions[:, :k]

        # Hit Rate (HR@k): Did we include the true item?
        hits = (top_k == ground_truth.unsqueeze(-1)).any(dim=-1)
        results[f'HR@{k}'] = hits.float().mean().item()

        # NDCG@k: Normalized Discounted Cumulative Gain
        # Accounts for position (higher rank = better)
        positions = (top_k == ground_truth.unsqueeze(-1)).float()
        ranks = torch.arange(1, k + 1, device=predictions.device).float()
        dcg = (positions / torch.log2(ranks + 1)).sum(dim=-1)
        idcg = 1.0  # Perfect ranking has single relevant item at position 1
        results[f'NDCG@{k}'] = (dcg / idcg).mean().item()

        # MRR: Mean Reciprocal Rank
        match_positions = (top_k == ground_truth.unsqueeze(-1)).float().argmax(dim=-1)
        has_match = (top_k == ground_truth.unsqueeze(-1)).any(dim=-1)
        rr = torch.where(has_match, 1.0 / (match_positions + 1), torch.zeros_like(match_positions.float()))
        results[f'MRR@{k}'] = rr.mean().item()

    return results

Beyond Accuracy: Coverage and Diversity

Python
def evaluate_coverage_diversity(
    predictions: torch.Tensor,  # (num_users, k)
    item_popularity: torch.Tensor,  # (num_items,) interaction counts
    num_items: int,
):
    """Evaluate beyond accuracy metrics."""

    # Coverage: What fraction of items are ever recommended?
    unique_items = predictions.unique()
    coverage = len(unique_items) / num_items

    # Gini coefficient: How unequal is the recommendation distribution?
    # Lower = more equal = better diversity
    rec_counts = torch.bincount(predictions.flatten(), minlength=num_items).float()
    sorted_counts = rec_counts.sort().values
    n = len(sorted_counts)
    index = torch.arange(1, n + 1, device=sorted_counts.device).float()
    gini = (2 * (index * sorted_counts).sum() / (n * sorted_counts.sum())) - (n + 1) / n

    # Popularity bias: Are we over-recommending popular items?
    rec_popularity = item_popularity[predictions].float().mean()
    overall_popularity = item_popularity.float().mean()
    popularity_bias = rec_popularity / overall_popularity

    return {
        'coverage': coverage,
        'gini': gini.item(),
        'popularity_bias': popularity_bias.item(),
    }

Part XII: Production Considerations

Serving Architecture

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                PRODUCTION SERVING ARCHITECTURE                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  REQUEST FLOW:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  User Request                                                            │
│       ↓                                                                  │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                     RETRIEVAL STAGE                              │   │
│  │  (Fast, broad filtering: 10M items → 1000 candidates)           │   │
│  │                                                                   │   │
│  │  • User tower encodes recent history                             │   │
│  │  • ANN search against pre-computed item embeddings               │   │
│  │  • Latency: 5-10ms                                               │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│       ↓                                                                  │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                      RANKING STAGE                               │   │
│  │  (Accurate, fine-grained scoring: 1000 → 100)                   │   │
│  │                                                                   │   │
│  │  • Full transformer model scores candidates                      │   │
│  │  • Uses rich features, cross-attention                           │   │
│  │  • Latency: 10-30ms                                              │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│       ↓                                                                  │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                     RE-RANKING STAGE                             │   │
│  │  (Business logic, diversity: 100 → 20)                          │   │
│  │                                                                   │   │
│  │  • Apply business rules (availability, margins)                  │   │
│  │  • Ensure diversity (categories, price ranges)                   │   │
│  │  • Personalization constraints                                   │   │
│  │  • Latency: 5ms                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│       ↓                                                                  │
│  Response (20 recommendations)                                          │
│  Total latency: 20-50ms                                                 │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
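
The retrieval stage above relies on approximate nearest-neighbor (ANN) search over pre-computed item embeddings. A minimal sketch with FAISS follows; the index type, sizes, and the assumption that a user tower already produced the query embedding are illustrative, not a specific production setup.

Python
import numpy as np
import faiss  # pip install faiss-cpu

embedding_dim, num_items = 128, 100_000

# Pre-computed item embeddings from the item tower (random stand-ins here)
item_embeddings = np.random.randn(num_items, embedding_dim).astype('float32')
faiss.normalize_L2(item_embeddings)

# Inverted-file index: cluster items, search only the closest clusters
index = faiss.index_factory(embedding_dim, "IVF1024,Flat", faiss.METRIC_INNER_PRODUCT)
index.train(item_embeddings)
index.add(item_embeddings)
index.nprobe = 32  # number of clusters probed per query

def retrieve_candidates(user_embedding: np.ndarray, k: int = 1000) -> np.ndarray:
    """Return the top-k candidate item ids for one user embedding."""
    query = user_embedding.astype('float32').reshape(1, -1)
    faiss.normalize_L2(query)
    _, candidate_ids = index.search(query, k)
    return candidate_ids[0]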

Model Updates and Freshness

Python
# Handling new items (cold start)
class ItemEmbeddingWithFallback:
    """
    Handle new items that weren't in training.
    """

    def __init__(self, model, content_encoder):
        self.model = model
        self.content_encoder = content_encoder  # e.g., BERT for item text

        # Cache for computed embeddings
        self.embedding_cache = {}

    def get_embedding(self, item_id: int, item_features: dict = None):
        """Get embedding, with fallback for new items."""

        # Check if item was in training
        if item_id < self.model.num_items:
            return self.model.item_embedding.weight[item_id]

        # New item: use content features
        if item_id in self.embedding_cache:
            return self.embedding_cache[item_id]

        if item_features is None:
            # No features: use category average or random
            return self.model.item_embedding.weight.mean(dim=0)

        # Encode content features
        content_emb = self.content_encoder(item_features)
        self.embedding_cache[item_id] = content_emb

        return content_emb

Part XIII: Production Systems at Scale

Industrial Deployments (2024-2025)

The gap between academic research and production is closing. Here are the transformer-based systems powering major platforms:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│           TRANSFORMER RECSYS IN PRODUCTION (2024-2025)                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  COMPANY        MODEL              KEY INNOVATION                       │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  Meta           HSTU               Trillion-param, no-softmax attention │
│                                    12.4%+ improvement in production     │
│                                                                          │
│  LinkedIn       LiGR               Gated residuals, set-wise scoring    │
│                                    Deprecated feature engineering       │
│                                                                          │
│  ByteDance      Monolith           Real-time online training            │
│                                    Collision-less embeddings            │
│                                                                          │
│  Kuaishou       OneRec             End-to-end generative (no retrieval) │
│                                    DPO alignment, 1.6% watch-time gain  │
│                                                                          │
│  Pinterest      PinnerFormer       Dense all-action loss                │
│                                    Long-term engagement prediction      │
│                                                                          │
│  Spotify        Semantic IDs       LLaMA fine-tuning for domain         │
│                                    Unified search + recommendation      │
│                                                                          │
│  Netflix        UniCoRn            Unified contextual ranker            │
│                                    FM-Intent for intent prediction      │
│                                                                          │
│  YouTube        Transformer Watch  Transformer layers for engagement    │
│                 Next               Multi-task learning + retrieval      │
│                                                                          │
│  Albatross      DenseRec           Sequential embeddings for cold-start │
│                                    Dual-path (content + behavior)       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Monolith: ByteDance's Real-Time System

Monolith powers TikTok's recommendation engine with real-time online training—the model updates continuously as users interact.

Python
# Monolith architecture concepts (simplified)
class MonolithSystem:
    """
    ByteDance's real-time recommendation system.
    Key: Continuous online learning from streaming data.
    """

    def __init__(self):
        # Collision-less embedding table using cuckoo hashing
        self.embedding_table = CollisionlessEmbeddingTable(
            num_buckets=1_000_000_000,  # Billion-scale
            embedding_dim=128,
        )

        # Parameter server architecture
        self.ps = ParameterServer()

        # Streaming training pipeline
        self.kafka_actions = KafkaConsumer('user_actions')
        self.kafka_features = KafkaConsumer('features')
        self.flink_joiner = FlinkStreamJoiner()

    def online_training_loop(self):
        """
        Continuous training from streaming data.
        Model updates propagate to serving within minutes.
        """
        while True:
            # Join action and feature streams
            training_examples = self.flink_joiner.join(
                self.kafka_actions,
                self.kafka_features,
                window_size='1_minute'
            )

            # Update model
            for batch in training_examples:
                gradients = self.compute_gradients(batch)
                self.ps.apply_gradients(gradients)

            # Sync to serving (near real-time)
            self.ps.sync_to_inference()


class CollisionlessEmbeddingTable:
    """
    Cuckoo hashing for collision-free embeddings.
    Handles billions of items without hash collisions.
    """

    def __init__(self, num_buckets: int, embedding_dim: int):
        # Two tables with different hash functions
        self.table1 = torch.zeros(num_buckets, embedding_dim)
        self.table2 = torch.zeros(num_buckets, embedding_dim)

        # Frequency filtering (remove rare items)
        self.access_counts = torch.zeros(num_buckets)
        self.min_frequency = 5

        # TTL for expiring old embeddings
        self.last_access = torch.zeros(num_buckets)
        self.ttl_days = 7

    def lookup(self, item_ids: torch.Tensor) -> torch.Tensor:
        # Look up the slot under the first hash function
        # (hash1/hash2 are two independent hash functions, omitted here)
        h1 = self.hash1(item_ids)
        result = self.table1[h1]

        # Cuckoo hashing stores each key in exactly one of the two tables:
        # fall back to the second table when the key isn't under the first
        # hash (approximated here by an untouched first slot)
        h2 = self.hash2(item_ids)
        missing = (self.access_counts[h1] == 0)
        result[missing] = self.table2[h2[missing]]

        return result

Monolith innovations:

  • Real-time updates: Model parameters sync to serving every ~1 minute
  • Collision-less embeddings: Cuckoo hashing eliminates hash collisions
  • Expiring embeddings: TTL removes stale items automatically
  • Frequency filtering: Ignores items with <5 interactions
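
The frequency filter and TTL above can be sketched as a small admission/eviction policy sitting in front of the embedding table. This is an illustrative approximation, not ByteDance's actual implementation:

Python
import time

class EmbeddingAdmissionPolicy:
    """Illustrative admission (frequency filter) and eviction (TTL) logic."""

    def __init__(self, min_frequency: int = 5, ttl_days: int = 7):
        self.min_frequency = min_frequency
        self.ttl_seconds = ttl_days * 24 * 3600
        self.counts = {}        # item_id -> observed interaction count
        self.last_access = {}   # item_id -> last access timestamp
        self.admitted = set()   # item_ids with a live embedding row

    def observe(self, item_id: int, now: float = None) -> bool:
        """Record an interaction; admit the item once it is frequent enough."""
        now = time.time() if now is None else now
        self.counts[item_id] = self.counts.get(item_id, 0) + 1
        self.last_access[item_id] = now
        if self.counts[item_id] >= self.min_frequency:
            self.admitted.add(item_id)  # allocate an embedding row here
        return item_id in self.admitted

    def evict_expired(self, now: float = None) -> list:
        """Drop embeddings that have not been touched within the TTL."""
        now = time.time() if now is None else now
        expired = [i for i in self.admitted
                   if now - self.last_access.get(i, 0) > self.ttl_seconds]
        for i in expired:
            self.admitted.discard(i)  # free the embedding row here
        return expired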

OneRec: Kuaishou's Unified Generative Recommender

OneRec (February 2025) is the first end-to-end generative recommender deployed at scale, replacing the traditional retrieve-rank pipeline.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│              TRADITIONAL vs ONEREC ARCHITECTURE                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  TRADITIONAL (3-Stage):                                                  │
│  ───────────────────────                                                 │
│                                                                          │
│  Retrieval → Pre-Ranking → Ranking → Results                           │
│  (10M→1K)    (1K→100)     (100→20)                                     │
│                                                                          │
│  • 3 separate models to maintain                                        │
│  • Complex infrastructure                                               │
│  • Information loss between stages                                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ONEREC (Unified):                                                       │
│  ─────────────────                                                       │
│                                                                          │
│  User History → Encoder-Decoder → Generate Session → Results            │
│                      ↓                                                   │
│                 Sparse MoE                                              │
│                      ↓                                                   │
│                 DPO Alignment                                           │
│                                                                          │
│  • Single model end-to-end                                              │
│  • Session-wise generation (not point-by-point)                         │
│  • 10.6% OpEx of traditional pipeline                                   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

OneRec key components:

  1. Session-wise Generation: Generates a coherent session of recommendations, not individual items
  2. Sparse MoE: Scales to 10x FLOPs with mixture of experts
  3. Iterative DPO Alignment: Uses reward models to generate preference pairs for Direct Preference Optimization

Production results at Kuaishou:

  • 1.6% watch-time increase (substantial at Kuaishou's scale)
  • 25% of total QPS served by OneRec
  • 10.6% OpEx compared to traditional pipeline
  • Demonstrated scaling laws for recommendation models
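
To make session-wise generation concrete, here is a minimal decoding sketch: an encoder-decoder transformer encodes the user history once, then autoregressively emits a short sequence of item tokens as the session. Model sizes, semantic-ID tokenization, the sparse MoE layers, and DPO alignment are all omitted; nothing here reflects Kuaishou's actual code.

Python
import torch
import torch.nn as nn

class SessionGenerator(nn.Module):
    """Sketch: generate a whole session of item tokens with an encoder-decoder."""

    def __init__(self, num_items: int, d_model: int = 256):
        super().__init__()
        self.item_embedding = nn.Embedding(num_items + 1, d_model)  # +1 for BOS token
        self.bos_id = num_items
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=4,
            num_encoder_layers=2, num_decoder_layers=2,
            batch_first=True,
        )
        self.output_head = nn.Linear(d_model, num_items)

    @torch.no_grad()
    def generate_session(self, history: torch.Tensor, session_len: int = 8) -> torch.Tensor:
        """Greedy-decode a session of item ids; history is (batch, seq_len) item ids."""
        memory = self.transformer.encoder(self.item_embedding(history))
        generated = torch.full(
            (history.size(0), 1), self.bos_id, dtype=torch.long, device=history.device
        )
        for _ in range(session_len):
            tgt = self.item_embedding(generated)
            tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(1))
            decoded = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask)
            next_item = self.output_head(decoded[:, -1]).argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_item], dim=1)
        return generated[:, 1:]  # drop BOS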

PinnerFormer: Pinterest's Long-Term Engagement Model

PinnerFormer (deployed since 2021) focuses on predicting long-term user engagement, not just next-click.

Python
class PinnerFormer(nn.Module):
    """
    Pinterest's user representation model.
    Key: Predict long-term engagement, not just next action.
    """

    def __init__(self, num_items: int, hidden_dim: int = 256, num_layers: int = 4):
        super().__init__()

        # Transformer encoder for user sequences
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim, nhead=8, batch_first=True),
            num_layers=num_layers
        )

        # Item embeddings, shared for history and future actions
        self.item_encoder = nn.Embedding(num_items, hidden_dim)

        # Dense all-action loss (key innovation)
        # Predict ALL future actions, not just the next one
        self.future_predictor = nn.Linear(hidden_dim, hidden_dim)

    def dense_all_action_loss(
        self,
        user_history: torch.Tensor,    # (batch, seq_len) item ids
        future_actions: torch.Tensor,  # (batch,) one action sampled from the next N days
    ) -> torch.Tensor:
        """
        Loss that considers future engagement over a window, not just the next action.

        Standard: Predict item at t+1 given items 1...t
        PinnerFormer: Predict items at t+1...t+N given items 1...t
        (one future action per user is sampled each step; the other users'
        sampled actions in the batch serve as negatives)
        """
        # Encode history
        history_repr = self.transformer(self.item_encoder(user_history))
        user_embedding = history_repr[:, -1]  # Last position

        # Predict future engagement
        future_pred = self.future_predictor(user_embedding)

        # Score against sampled future actions (in-batch softmax)
        future_embeddings = self.item_encoder(future_actions)  # (batch, hidden)
        scores = torch.mm(future_pred, future_embeddings.T)    # (batch, batch)

        # Each user's own sampled future action is the positive
        loss = -torch.log_softmax(scores, dim=-1).diag().mean()

        return loss

Why dense all-action loss matters:

  • Next-item prediction optimizes for immediate clicks
  • Dense all-action optimizes for long-term value
  • Better alignment with business metrics (retention, LTV)

YouTube: Transformer-Era Watch Next

YouTube evolved from its foundational 2016 DNN system to incorporate transformer architectures into its Watch Next recommendations. The 2024-2025 system uses multi-task learning with transformer layers to jointly optimize for engagement and satisfaction.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│            YOUTUBE TRANSFORMER ARCHITECTURE (2024-2025)                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  CANDIDATE GENERATION:                                                  │
│  ────────────────────                                                   │
│                                                                          │
│  User History → Transformer Encoder → User Embedding                    │
│       ↓                                                                  │
│  Watch sequence: [v₁, v₂, v₃, ..., vₙ]                                 │
│                    ↓                                                    │
│  Multi-head self-attention (captures viewing patterns)                 │
│                    ↓                                                    │
│  Video-level cross-attention (video features)                          │
│                    ↓                                                    │
│  [USER] token aggregation                                              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MULTI-TASK RANKING:                                                    │
│  ──────────────────                                                     │
│                                                                          │
│  User + Candidate → Transformer Ranking Model → Scores                 │
│                                                                          │
│  Objectives (jointly trained):                                          │
│  • Click probability (immediate engagement)                            │
│  • Watch time (session depth)                                          │
│  • Long-term satisfaction (surveys + implicit signals)                 │
│  • Creator fairness (distribution constraints)                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  KEY INNOVATIONS:                                                        │
│  • Sequential modeling with transformers (vs. bag-of-features)         │
│  • Multi-objective optimization with Pareto frontiers                  │
│  • Exploration bonuses for new/underserved content                     │
│  • Real-time feature updates via streaming pipelines                   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

YouTube's transformer evolution:

  1. 2016-2019: Original two-tower DNN with averaged watch history
  2. 2019-2022: RNN-based sequential modeling for temporal patterns
  3. 2022-2024: Transformer encoders replace RNNs for user sequences
  4. 2024-2025: Full transformer ranking with multi-task heads

Key production learnings from YouTube:

  • Causal masking is essential: User can only have seen previous videos
  • Timestamp positional embeddings: When videos were watched matters as much as what was watched
  • Multi-objective balance: Pure engagement optimization leads to "rabbit holes"; satisfaction signals are crucial
  • Scale challenges: Billions of users, millions of videos, sub-100ms latency requirements
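
One of those learnings, timestamp positional embeddings, can be sketched as a bucketized "time since watched" embedding added to each item embedding. The log-scale bucketing and dimensions below are illustrative assumptions, not YouTube's actual configuration.

Python
import math

import torch
import torch.nn as nn

class TimestampPositionalEmbedding(nn.Module):
    """Embed how long ago each watch happened, bucketized on a log scale."""

    def __init__(self, hidden_dim: int = 256, num_buckets: int = 32):
        super().__init__()
        self.num_buckets = num_buckets
        self.bucket_embedding = nn.Embedding(num_buckets, hidden_dim)

    def forward(
        self,
        item_emb: torch.Tensor,    # (batch, seq_len, hidden) item embeddings
        timestamps: torch.Tensor,  # (batch, seq_len) unix seconds of each watch
        now: torch.Tensor,         # (batch,) request time in unix seconds
    ) -> torch.Tensor:
        age_seconds = (now.unsqueeze(-1) - timestamps).clamp(min=1).float()
        # Map ~1 second ... ~1 year onto num_buckets log-spaced bins
        fraction = torch.log1p(age_seconds) / math.log1p(365 * 24 * 3600)
        bucket = (fraction * (self.num_buckets - 1)).long().clamp(0, self.num_buckets - 1)
        return item_emb + self.bucket_embedding(bucket)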

Albatross AI: DenseRec for Cold-Start

Albatross AI, founded by ex-Amazon recommendation scientists, addresses one of RecSys's hardest problems: cold-start for new items. Their DenseRec paper (RecSys 2025) introduces a dual-path embedding architecture that combines content signals with sequential behavior patterns.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    DENSEREC ARCHITECTURE                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE COLD-START PROBLEM:                                                 │
│  ───────────────────────                                                 │
│                                                                          │
│  New Item (no interactions)                                             │
│       ↓                                                                  │
│  Traditional RecSys: "I have no embedding for this!"                   │
│       ↓                                                                  │
│  DenseRec: "I can generate one from content + category behavior"       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DUAL-PATH EMBEDDING:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  PATH 1: Content Encoder                                                │
│  ─────────────────────                                                  │
│  • Text (titles, descriptions) → BERT/T5 encoder                       │
│  • Images → Vision transformer (ViT/CLIP)                              │
│  • Attributes → Categorical embeddings                                 │
│  • Combined via learned fusion                                          │
│                                                                          │
│  PATH 2: Sequential Behavior (Category-Level)                          │
│  ─────────────────────────────────────────────                          │
│  • Items in same category/brand → aggregated behavior                  │
│  • Temporal patterns → transformer encoder                             │
│  • Transfer signals from similar items                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  FUSION:                                                                 │
│                                                                          │
│  Content Path ──┐                                                       │
│                 ├─→ Gated Fusion ─→ Dense Embedding                    │
│  Behavior Path ─┘                                                       │
│                                                                          │
│  Gate learns when to trust content vs. behavior signals                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

DenseRec innovations:

  1. Category-level sequential patterns: New items inherit behavior patterns from their category
  2. Adaptive gating: Learns to weight content vs. behavior based on item maturity
  3. Temporal decay: Recent category behavior weighted more heavily
  4. Multi-modal content fusion: Text + image + attributes combined
Python
# DenseRec concept (simplified)
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import AutoModel, CLIPVisionModel

class DenseRecEmbedding(nn.Module):
    """
    Albatross's dual-path embedding for cold-start items.
    """

    def __init__(self, num_attributes: int, hidden_dim: int = 256):
        super().__init__()

        # Path 1: Content encoders (pretrained text and image backbones)
        self.text_encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        self.image_encoder = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
        self.attribute_encoder = nn.Embedding(num_attributes, hidden_dim)

        # Path 2: Sequential behavior encoder
        self.behavior_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim, nhead=4),
            num_layers=2
        )

        # Fusion gate
        self.fusion_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # Final projection
        self.projection = nn.Linear(hidden_dim, hidden_dim)

    def get_embedding(
        self,
        text: str,
        image: Optional[torch.Tensor],
        attributes: List[int],
        category_id: int,
        interaction_count: int
    ) -> torch.Tensor:
        # Content path (tokenization and output pooling omitted in this sketch)
        text_emb = self.text_encoder(text)
        if image is not None:
            image_emb = self.image_encoder(image)
            content_emb = (text_emb + image_emb) / 2
        else:
            content_emb = text_emb

        # Behavior path (category-level)
        category_sequence = self.get_category_behavior(category_id)
        behavior_emb = self.behavior_transformer(category_sequence)[:, -1]

        # Adaptive fusion based on item maturity
        gate_input = torch.cat([content_emb, behavior_emb], dim=-1)
        gate = self.fusion_gate(gate_input)

        # Cold items rely more on content, warm items on behavior
        combined = gate * content_emb + (1 - gate) * behavior_emb

        return self.projection(combined)

    def get_category_behavior(self, category_id: int) -> torch.Tensor:
        """
        Aggregate sequential behavior patterns from the item's category.
        New items inherit these patterns for warm-start embeddings.
        (category_to_items, item_sequences, and temporal_aggregate are
        assumed lookup structures, omitted from this sketch.)
        """
        # Get recent interactions for items in this category
        category_items = self.category_to_items[category_id]
        category_sequences = [self.item_sequences[i] for i in category_items]

        # Aggregate with temporal weighting
        aggregated = self.temporal_aggregate(category_sequences)
        return aggregated

Why DenseRec matters for production:

  • Immediate recommendations for new products: No waiting for interaction data
  • Catalog turnover handling: E-commerce sites add thousands of items daily
  • Seasonal/trending items: New items can surface immediately based on content similarity
  • Long-tail coverage: Items with few interactions get meaningful embeddings

Albatross AI background:

  • Founded by former Amazon personalization scientists
  • €12.5M Series A funding (2024)
  • Customers include major European retailers
  • Focus on e-commerce and marketplace recommendations

Part XIV: Reinforcement Learning for Recommendation Systems

Traditional supervised learning optimizes for immediate metrics (next-click prediction). But recommendations have long-term effects—showing clickbait might get clicks but hurts user retention. Reinforcement Learning (RL) models recommendations as a sequential decision problem, optimizing for long-term user value.

Why RL for Recommendations?

Code
┌─────────────────────────────────────────────────────────────────────────┐
│            SUPERVISED LEARNING vs REINFORCEMENT LEARNING                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  SUPERVISED LEARNING (Standard RecSys):                                  │
│  ─────────────────────────────────────                                   │
│                                                                          │
│  Training: Log data → Predict clicked items                             │
│  Objective: Maximize P(click | user, context)                           │
│                                                                          │
│  Problems:                                                               │
│  • Optimizes for immediate engagement, not long-term value              │
│  • Ignores sequential effects of recommendations                        │
│  • Exploitation bias: Shows what users already like                     │
│  • Can't explore to discover new user preferences                       │
│  • Feedback loops: Popular items get more popular                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  REINFORCEMENT LEARNING:                                                 │
│  ──────────────────────                                                  │
│                                                                          │
│  Formulation:                                                            │
│  • State: User history + context                                        │
│  • Action: Item(s) to recommend                                         │
│  • Reward: User engagement (click, purchase, watch time, return visit)  │
│  • Policy: π(action | state) - the recommendation model                 │
│                                                                          │
│  Objective: Maximize cumulative discounted reward                       │
│  max E[Σ γ^t × r_t]  where γ ∈ [0,1] is discount factor                │
│                                                                          │
│  Benefits:                                                               │
│  • Optimizes for long-term user value (retention, LTV)                 │
│  • Natural exploration-exploitation trade-off                           │
│  • Models sequential nature of user sessions                            │
│  • Can incorporate delayed rewards (purchase after browsing)            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  EXAMPLE: Video Recommendation                                           │
│  ────────────────────────────────                                        │
│                                                                          │
│  Supervised: Recommend video with highest click probability             │
│  Result: Clickbait thumbnails, sensational content                      │
│                                                                          │
│  RL: Maximize (watch time + return visits + subscriptions)              │
│  Result: Quality content that builds long-term engagement               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

The MDP Formulation

Recommendation as a Markov Decision Process (MDP):

Code
┌─────────────────────────────────────────────────────────────────────────┐
│              MDP FOR RECOMMENDATION SYSTEMS                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  COMPONENTS:                                                             │
│  ───────────                                                             │
│                                                                          │
│  State (s_t):                                                            │
│  • User interaction history: [item_1, item_2, ..., item_t]              │
│  • User features: demographics, preferences, context                    │
│  • Session features: time of day, device, location                      │
│  • System state: already-shown items, diversity budget                  │
│                                                                          │
│  Action (a_t):                                                           │
│  • Single item: Select one item to recommend                            │
│  • Slate: Select K items for a page                                     │
│  • Ranking: Order K candidate items                                     │
│                                                                          │
│  Reward (r_t):                                                           │
│  • Immediate: click, add-to-cart, watch start                           │
│  • Delayed: purchase, subscription, return visit                        │
│  • Composite: weighted sum of multiple signals                          │
│  • Negative: skip, dislike, unsubscribe                                 │
│                                                                          │
│  Transition (P(s_{t+1} | s_t, a_t)):                                    │
│  • User's response to recommendation                                    │
│  • State evolution based on user behavior                               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  TRAJECTORY:                                                             │
│                                                                          │
│  s_0 → a_0 → r_0 → s_1 → a_1 → r_1 → s_2 → ... → s_T                  │
│   │      │     │     │      │     │                                     │
│   │      │     │     │      │     └── User watched 80% of video        │
│   │      │     │     │      └── Recommended video B                     │
│   │      │     │     └── User clicked, updated history                  │
│   │      │     └── User clicked (+1 reward)                             │
│   │      └── Recommended video A                                         │
│   └── Initial user state                                                 │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  CHALLENGES IN RECSYS MDP:                                               │
│                                                                          │
│  1. LARGE ACTION SPACE                                                   │
│     Millions of items = millions of actions                             │
│     Can't enumerate all actions like in games                           │
│                                                                          │
│  2. PARTIAL OBSERVABILITY                                                │
│     Don't see user's full intent or preferences                         │
│     State is an approximation from observed behavior                    │
│                                                                          │
│  3. DELAYED AND SPARSE REWARDS                                           │
│     User might purchase days after viewing                              │
│     Most interactions have no explicit feedback                         │
│                                                                          │
│  4. NON-STATIONARITY                                                     │
│     User preferences change over time                                   │
│     Item catalog changes daily                                          │
│                                                                          │
│  5. SAFETY CONSTRAINTS                                                   │
│     Can't show harmful content while exploring                          │
│     Business rules must be respected                                    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
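
The objective is just the discounted sum of rewards along one trajectory; a tiny helper makes the arithmetic explicit (the example rewards below are made up):

Python
def discounted_return(rewards: list, gamma: float = 0.99) -> float:
    """G_0 = r_0 + gamma * r_1 + gamma^2 * r_2 + ..."""
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
    return G

# Toy session: click (+1), watched 80% of a video (+0.8), skip (-0.2), return visit (+2)
print(discounted_return([1.0, 0.8, -0.2, 2.0], gamma=0.9))
# 1.0 + 0.9*0.8 + 0.81*(-0.2) + 0.729*2.0 = 3.016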

Bandit Algorithms: The Foundation

Before full RL, contextual bandits provide a simpler framework that's widely deployed:

Python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by the actor-critic and CQL losses below

class ContextualBandit:
    """
    Contextual bandit for recommendations.

    Simpler than full RL:
    - No state transitions (each decision is independent)
    - Immediate rewards only
    - Still handles exploration vs exploitation
    """

    def __init__(
        self,
        num_items: int,
        context_dim: int,
    ):
        self.num_items = num_items
        self.context_dim = context_dim
        # Thompson Sampling explores via posterior sampling, so no explicit
        # exploration-rate hyperparameter is needed here.

        # Thompson Sampling: Maintain posterior for each item
        # Using Bayesian linear regression
        self.A = [np.eye(context_dim) for _ in range(num_items)]  # Precision matrices
        self.b = [np.zeros(context_dim) for _ in range(num_items)]  # Mean accumulators

    def select_item(self, context: np.ndarray) -> int:
        """
        Select item using Thompson Sampling.

        Args:
            context: User/context features

        Returns:
            Selected item index
        """
        # Sample from posterior for each item
        sampled_rewards = []

        for i in range(self.num_items):
            # Posterior mean and covariance
            A_inv = np.linalg.inv(self.A[i])
            theta_mean = A_inv @ self.b[i]
            theta_cov = A_inv

            # Sample theta from posterior
            theta_sample = np.random.multivariate_normal(theta_mean, theta_cov)

            # Predicted reward
            reward = context @ theta_sample
            sampled_rewards.append(reward)

        return np.argmax(sampled_rewards)

    def update(self, context: np.ndarray, item: int, reward: float):
        """Update posterior after observing reward."""
        self.A[item] += np.outer(context, context)
        self.b[item] += reward * context


class LinUCB:
    """
    Linear Upper Confidence Bound (LinUCB).

    Classic bandit algorithm for recommendations.
    Used in Yahoo! News personalization (2010).
    """

    def __init__(
        self,
        num_items: int,
        context_dim: int,
        alpha: float = 1.0,  # Exploration parameter
    ):
        self.num_items = num_items
        self.context_dim = context_dim
        self.alpha = alpha

        # Initialize parameters
        self.A = [np.eye(context_dim) for _ in range(num_items)]
        self.b = [np.zeros(context_dim) for _ in range(num_items)]

    def select_item(self, context: np.ndarray) -> int:
        """
        Select item with highest UCB.

        UCB = predicted_reward + alpha * uncertainty
        """
        ucb_scores = []

        for i in range(self.num_items):
            A_inv = np.linalg.inv(self.A[i])
            theta = A_inv @ self.b[i]

            # Predicted reward
            pred_reward = context @ theta

            # Uncertainty (confidence interval width)
            uncertainty = self.alpha * np.sqrt(context @ A_inv @ context)

            ucb = pred_reward + uncertainty
            ucb_scores.append(ucb)

        return np.argmax(ucb_scores)

    def update(self, context: np.ndarray, item: int, reward: float):
        """Update parameters after observing reward."""
        self.A[item] += np.outer(context, context)
        self.b[item] += reward * context


class NeuralContextualBandit(nn.Module):
    """
    Neural network-based contextual bandit.

    Uses neural network to model reward function,
    with dropout for uncertainty estimation.
    """

    def __init__(
        self,
        context_dim: int,
        num_items: int,
        hidden_dim: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.num_items = num_items

        # Shared context encoder
        self.context_encoder = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Item embeddings
        self.item_embedding = nn.Embedding(num_items, hidden_dim)

        # Reward predictor
        self.reward_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        context: torch.Tensor,
        item_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Predict rewards for items given context.

        Args:
            context: (batch, context_dim)
            item_ids: (batch, num_candidates) or None for all items
        """
        # Encode context
        context_emb = self.context_encoder(context)  # (batch, hidden)

        if item_ids is None:
            # Score all items
            item_emb = self.item_embedding.weight  # (num_items, hidden)
            # Expand for batch
            context_emb = context_emb.unsqueeze(1)  # (batch, 1, hidden)
            item_emb = item_emb.unsqueeze(0)  # (1, num_items, hidden)
        else:
            item_emb = self.item_embedding(item_ids)  # (batch, num_candidates, hidden)
            context_emb = context_emb.unsqueeze(1)  # (batch, 1, hidden)

        # Combine and predict
        combined = torch.cat([
            context_emb.expand(-1, item_emb.size(1), -1),
            item_emb.expand(context_emb.size(0), -1, -1),
        ], dim=-1)

        rewards = self.reward_head(combined).squeeze(-1)  # (batch, num_items)
        return rewards

    def select_with_uncertainty(
        self,
        context: torch.Tensor,
        num_samples: int = 10,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Select items with uncertainty estimation via MC Dropout.
        """
        self.train()  # Enable dropout

        # Multiple forward passes
        reward_samples = []
        for _ in range(num_samples):
            rewards = self.forward(context)
            reward_samples.append(rewards)

        reward_samples = torch.stack(reward_samples)  # (num_samples, batch, num_items)

        # Mean and uncertainty
        mean_rewards = reward_samples.mean(dim=0)
        uncertainty = reward_samples.std(dim=0)

        # UCB-style selection
        ucb = mean_rewards + uncertainty

        return ucb.argmax(dim=-1), uncertainty
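
A short usage sketch of LinUCB on synthetic data, showing the select → observe → update loop (the simulated click model is entirely made up for illustration):

Python
import numpy as np

rng = np.random.default_rng(0)
num_items, context_dim = 20, 16

# Hidden "true" preference vector per item, used only to simulate clicks
true_theta = rng.normal(size=(num_items, context_dim))

bandit = LinUCB(num_items=num_items, context_dim=context_dim, alpha=1.0)

clicks = 0
for step in range(1000):
    context = rng.normal(size=context_dim)                   # user/context features
    item = bandit.select_item(context)                       # UCB selection
    p_click = 1 / (1 + np.exp(-context @ true_theta[item]))  # simulated user response
    reward = float(rng.random() < p_click)
    bandit.update(context, item, reward)
    clicks += reward

print(f"Observed CTR over 1000 rounds: {clicks / 1000:.3f}")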

Policy Gradient Methods for Recommendations

For full RL with sequential states, we use policy gradient methods:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    POLICY GRADIENT FOR RECSYS                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  REINFORCE (Williams, 1992):                                            │
│  ───────────────────────────                                             │
│                                                                          │
│  Policy: π_θ(a|s) - probability of recommending item a given state s   │
│                                                                          │
│  Objective: J(θ) = E_π[Σ γ^t r_t]                                      │
│                                                                          │
│  Gradient: ∇J(θ) = E_π[Σ ∇log π_θ(a_t|s_t) × G_t]                     │
│                                                                          │
│  Where G_t = Σ_{k=0}^∞ γ^k r_{t+k}  (return from time t)              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ACTOR-CRITIC:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  Add a value function V(s) to reduce variance:                          │
│                                                                          │
│  Advantage: A(s,a) = Q(s,a) - V(s)                                     │
│                                                                          │
│  Actor (policy): π_θ(a|s)                                              │
│  Critic (value): V_φ(s)                                                 │
│                                                                          │
│  Actor update: ∇J(θ) = E[∇log π_θ(a|s) × A(s,a)]                      │
│  Critic update: Minimize (V_φ(s) - G_t)²                               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  OFF-POLICY LEARNING (Crucial for RecSys):                              │
│  ──────────────────────────────────────────                              │
│                                                                          │
│  Problem: Can't interact with users in real-time during training       │
│  Solution: Learn from logged data (offline RL)                          │
│                                                                          │
│  Importance Sampling:                                                    │
│  J(π) = E_{π_old}[π(a|s)/π_old(a|s) × r]                              │
│                                                                          │
│  Challenges:                                                             │
│  • High variance when π differs from π_old                              │
│  • Need logged propensities (probability of showing item)               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class REINFORCERecommender(nn.Module):
    """
    REINFORCE algorithm for sequential recommendations.

    The policy network outputs a distribution over items,
    and we train it to maximize expected cumulative reward.
    """

    def __init__(
        self,
        num_items: int,
        state_dim: int,
        hidden_dim: int = 256,
        gamma: float = 0.99,
    ):
        super().__init__()

        self.gamma = gamma
        self.num_items = num_items

        # State encoder (could be transformer for sequential state)
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Policy head: outputs logits over items
        self.policy_head = nn.Linear(hidden_dim, num_items)

        # Storage for episode
        self.log_probs = []
        self.rewards = []

    def forward(self, state: torch.Tensor) -> torch.distributions.Categorical:
        """
        Compute policy distribution over items.
        """
        state_emb = self.state_encoder(state)
        logits = self.policy_head(state_emb)
        return torch.distributions.Categorical(logits=logits)

    def select_action(self, state: torch.Tensor) -> tuple[int, torch.Tensor]:
        """
        Sample action from policy and store log probability.
        """
        dist = self.forward(state)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        self.log_probs.append(log_prob)
        return action.item(), log_prob

    def store_reward(self, reward: float):
        """Store reward for current step."""
        self.rewards.append(reward)

    def compute_loss(self) -> torch.Tensor:
        """
        Compute REINFORCE loss at end of episode.

        Loss = -Σ log π(a|s) × G_t
        where G_t is the return from step t
        """
        # Compute returns (discounted cumulative rewards)
        returns = []
        G = 0
        for r in reversed(self.rewards):
            G = r + self.gamma * G
            returns.insert(0, G)

        returns = torch.tensor(returns)

        # Normalize returns (variance reduction)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # Policy gradient loss
        loss = 0
        for log_prob, G in zip(self.log_probs, returns):
            loss -= log_prob * G

        return loss

    def clear_episode(self):
        """Clear stored episode data."""
        self.log_probs = []
        self.rewards = []


class ActorCriticRecommender(nn.Module):
    """
    Actor-Critic for recommendations.

    Actor: Policy network π(a|s)
    Critic: Value network V(s)

    Advantage = r + γV(s') - V(s)
    """

    def __init__(
        self,
        num_items: int,
        state_dim: int,
        hidden_dim: int = 256,
        gamma: float = 0.99,
    ):
        super().__init__()

        self.gamma = gamma

        # Shared state encoder
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Actor (policy)
        self.actor = nn.Linear(hidden_dim, num_items)

        # Critic (value function)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state: torch.Tensor) -> tuple[torch.distributions.Categorical, torch.Tensor]:
        """
        Compute policy and value for state.
        """
        state_emb = self.state_encoder(state)

        # Policy distribution
        logits = self.actor(state_emb)
        policy = torch.distributions.Categorical(logits=logits)

        # Value estimate, squeezed to (batch,) so it broadcasts with rewards and dones
        value = self.critic(state_emb).squeeze(-1)

        return policy, value

    def compute_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        dones: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute actor and critic losses.
        """
        # Current policy and values
        policy, values = self.forward(states)

        # Next state values (for TD target)
        with torch.no_grad():
            _, next_values = self.forward(next_states)
            next_values = next_values * (1 - dones.float())

        # TD target and advantage
        td_target = rewards + self.gamma * next_values
        advantage = td_target - values

        # Actor loss (policy gradient with advantage)
        log_probs = policy.log_prob(actions)
        actor_loss = -(log_probs * advantage.detach()).mean()

        # Critic loss (MSE on value prediction)
        critic_loss = F.mse_loss(values, td_target.detach())

        return actor_loss, critic_loss
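
The off-policy correction sketched in the diagram above can be written as an importance-weighted REINFORCE loss. A minimal version, assuming the log stores π_old's propensity for each shown item (weights are clipped to control variance):

Python
import torch

def off_policy_pg_loss(
    log_probs_new: torch.Tensor,     # log pi_new(a|s) for the logged actions, (batch,)
    propensities_old: torch.Tensor,  # pi_old(a|s) recorded at logging time, (batch,)
    rewards: torch.Tensor,           # observed rewards, (batch,)
    clip: float = 10.0,
) -> torch.Tensor:
    """REINFORCE-style loss with clipped importance weights."""
    importance = torch.exp(log_probs_new) / propensities_old.clamp(min=1e-6)
    importance = importance.clamp(max=clip)  # cap variance from rarely-logged actions
    # Treat the weight as a constant with respect to the policy parameters
    return -(importance.detach() * log_probs_new * rewards).mean()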

Offline RL: Learning from Logged Data

In practice, we usually can't run free-form exploration on live users, so the policy must be learned from logged interaction data:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                      OFFLINE RL FOR RECOMMENDATIONS                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE OFFLINE RL PROBLEM:                                                 │
│  ────────────────────────                                                │
│                                                                          │
│  We have: Logged data D = {(s, a, r, s')} from old policy π_old        │
│  We want: Learn new policy π_new that maximizes rewards                 │
│                                                                          │
│  Challenge: Distribution shift                                           │
│  • π_new might choose actions never seen in D                           │
│  • Q-value overestimation for unseen actions                            │
│  • Need to constrain π_new to stay close to π_old                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SOLUTION 1: Conservative Q-Learning (CQL)                              │
│  ──────────────────────────────────────────                              │
│                                                                          │
│  Penalize Q-values for actions not in dataset:                          │
│                                                                          │
│  L_CQL = E_s[log Σ_a exp(Q(s,a))] - E_{s,a~D}[Q(s,a)]                  │
│                                                                          │
│  First term: Pushes down Q for ALL actions                              │
│  Second term: Pulls up Q for actions IN dataset                         │
│  Net effect: Conservatively low Q for unseen actions                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SOLUTION 2: Batch Constrained Q-Learning (BCQ)                         │
│  ──────────────────────────────────────────────                          │
│                                                                          │
│  Only consider actions similar to those in dataset:                     │
│                                                                          │
│  a* = argmax_a Q(s,a)  subject to  π_old(a|s) > threshold              │
│                                                                          │
│  Learns a generative model of π_old to filter actions                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SOLUTION 3: Implicit Q-Learning (IQL)                                  │
│  ──────────────────────────────────────                                  │
│                                                                          │
│  Never query Q for actions outside dataset:                             │
│                                                                          │
│  Learn V(s) separately from Q(s,a)                                     │
│  V(s) = E_a~π_old[Q(s,a)] using expectile regression                   │
│                                                                          │
│  Extract policy: π(a|s) ∝ exp(Q(s,a) - V(s))                           │
│  Only uses in-dataset (s,a) pairs for policy extraction                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Python
class ConservativeQLearning(nn.Module):
    """
    Conservative Q-Learning (CQL) for offline recommendation.

    Key idea: Penalize Q-values for out-of-distribution actions
    to prevent overestimation of unseen items.
    """

    def __init__(
        self,
        num_items: int,
        state_dim: int,
        hidden_dim: int = 256,
        cql_alpha: float = 1.0,  # CQL regularization weight
        gamma: float = 0.99,
    ):
        super().__init__()

        self.num_items = num_items
        self.cql_alpha = cql_alpha
        self.gamma = gamma

        # Q-network
        self.q_network = nn.Sequential(
            nn.Linear(state_dim + num_items, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        # Target Q-network (for stability)
        self.target_q_network = nn.Sequential(
            nn.Linear(state_dim + num_items, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        # Item embeddings (for one-hot or learned)
        self.item_embedding = nn.Embedding(num_items, num_items)
        self.item_embedding.weight.data = torch.eye(num_items)  # One-hot

    def get_q_values(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        use_target: bool = False,
    ) -> torch.Tensor:
        """
        Compute Q(s, a) for given state-action pairs.
        """
        network = self.target_q_network if use_target else self.q_network

        action_emb = self.item_embedding(actions)
        sa = torch.cat([states, action_emb], dim=-1)
        return network(sa)

    def get_all_q_values(self, states: torch.Tensor) -> torch.Tensor:
        """
        Compute Q(s, a) for all actions.
        """
        batch_size = states.size(0)

        # Expand states for all items
        states_expanded = states.unsqueeze(1).expand(-1, self.num_items, -1)

        # All item embeddings
        all_items = torch.arange(self.num_items, device=states.device)
        item_emb = self.item_embedding(all_items)
        item_emb = item_emb.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate and compute Q
        sa = torch.cat([states_expanded, item_emb], dim=-1)
        q_values = self.q_network(sa.view(-1, sa.size(-1)))

        return q_values.view(batch_size, self.num_items)

    def compute_cql_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        dones: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute CQL loss.

        CQL Loss = TD Loss + α × (logsumexp(Q(s,·)) - Q(s,a))
        """
        batch_size = states.size(0)

        # Standard TD loss
        current_q = self.get_q_values(states, actions)

        with torch.no_grad():
            # Max Q for next state (using target network)
            next_q_all = self.get_all_q_values(next_states)
            next_q_max = next_q_all.max(dim=1, keepdim=True)[0]
            target_q = rewards + self.gamma * next_q_max * (1 - dones.float())

        td_loss = F.mse_loss(current_q, target_q)

        # CQL regularization
        # Push down Q for all actions, pull up Q for dataset actions
        all_q = self.get_all_q_values(states)  # (batch, num_items)

        # logsumexp term (pushes down all Q values)
        logsumexp_q = torch.logsumexp(all_q, dim=1, keepdim=True)

        # Dataset Q term (pulls up Q for observed actions)
        dataset_q = current_q

        cql_loss = (logsumexp_q - dataset_q).mean()

        # Total loss
        total_loss = td_loss + self.cql_alpha * cql_loss

        return total_loss, td_loss, cql_loss
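

# Usage sketch (illustrative, not from the original post): one offline CQL
# training step on a batch of logged interactions, then a softmax policy over
# the learned Q-values in the spirit of pi(a|s) ∝ exp(Q(s,a) - V(s)) from the
# diagram above. Sizes, hyperparameters, and the random batch are assumptions.
def cql_training_step_example():
    num_items, state_dim, batch_size = 1000, 64, 32
    model = ConservativeQLearning(num_items=num_items, state_dim=state_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Stand-in for a batch sampled from logged interaction data
    states = torch.randn(batch_size, state_dim)
    actions = torch.randint(0, num_items, (batch_size,))
    rewards = torch.rand(batch_size, 1)
    next_states = torch.randn(batch_size, state_dim)
    dones = torch.zeros(batch_size, 1)

    total_loss, td_loss, cql_loss = model.compute_cql_loss(
        states, actions, rewards, next_states, dones
    )
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Periodically refresh the target network to keep TD targets stable
    model.target_q_network.load_state_dict(model.q_network.state_dict())

    # Recommendation policy: softmax over Q-values, which is proportional to
    # exp(Q(s,a) - V(s)) since V(s) cancels in the normalization
    with torch.no_grad():
        policy_probs = F.softmax(model.get_all_q_values(states), dim=-1)
    return total_loss.item(), policy_probs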


class SlateRL(nn.Module):
    """
    RL for slate recommendation (multiple items at once).

    Challenges:
    - Combinatorial action space: C(N, K) possible slates
    - Item interactions within slate
    - Position bias
    """

    def __init__(
        self,
        num_items: int,
        slate_size: int,
        state_dim: int,
        hidden_dim: int = 256,
    ):
        super().__init__()

        self.num_items = num_items
        self.slate_size = slate_size

        # State encoder
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Item scorer (for ranking)
        self.item_encoder = nn.Embedding(num_items, hidden_dim)
        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        # Position encoding (for modeling position bias; not wired into the
        # simple sequential selection below, but available for extensions)
        self.position_encoding = nn.Embedding(slate_size, hidden_dim)

    def score_items(self, state: torch.Tensor) -> torch.Tensor:
        """
        Score all items for a given state.
        """
        batch_size = state.size(0)

        # Encode state
        state_emb = self.state_encoder(state)  # (batch, hidden)

        # Encode all items
        all_items = torch.arange(self.num_items, device=state.device)
        item_emb = self.item_encoder(all_items)  # (num_items, hidden)

        # Score each item
        state_expanded = state_emb.unsqueeze(1).expand(-1, self.num_items, -1)
        item_expanded = item_emb.unsqueeze(0).expand(batch_size, -1, -1)

        combined = torch.cat([state_expanded, item_expanded], dim=-1)
        scores = self.scorer(combined).squeeze(-1)  # (batch, num_items)

        return scores

    def select_slate(
        self,
        state: torch.Tensor,
        explore: bool = True,
    ) -> torch.Tensor:
        """
        Select slate of K items.

        Uses sequential selection with position-aware scoring.
        """
        batch_size = state.size(0)

        # Initial item scores
        scores = self.score_items(state)  # (batch, num_items)

        selected = []
        mask = torch.zeros(batch_size, self.num_items, device=state.device)

        for pos in range(self.slate_size):
            # Mask already selected items
            masked_scores = scores - 1e9 * mask

            if explore:
                # Sample from softmax
                probs = F.softmax(masked_scores, dim=-1)
                item = torch.multinomial(probs, 1).squeeze(-1)
            else:
                # Greedy selection
                item = masked_scores.argmax(dim=-1)

            selected.append(item)

            # Update mask
            mask.scatter_(1, item.unsqueeze(1), 1)

        return torch.stack(selected, dim=1)  # (batch, slate_size)
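
To make the slate policy concrete, here is a brief usage sketch. The catalog size, slate size, and user-state vectors below are illustrative assumptions, not values from a real system: sample slates during training for exploration, and select greedily at serving time.

Python
# Illustrative usage of SlateRL (hypothetical sizes, random user states)
num_items, slate_size, state_dim = 1000, 5, 64
policy = SlateRL(num_items=num_items, slate_size=slate_size, state_dim=state_dim)

user_states = torch.randn(8, state_dim)  # batch of 8 user states

train_slates = policy.select_slate(user_states, explore=True)   # sampled
serve_slates = policy.select_slate(user_states, explore=False)  # greedy

print(train_slates.shape)  # torch.Size([8, 5]): one 5-item slate per user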

Production RL Systems (2024-2025)

Major platforms using RL for recommendations:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│              RL IN PRODUCTION RECOMMENDATION SYSTEMS                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  YOUTUBE (2019-present):                                                 │
│  ────────────────────────                                                │
│  • "Reinforcement Learning for Slate-based Recommender Systems"        │
│  • SlateQ: Handles slate-level rewards (not per-item)                  │
│  • Optimizes for long-term watch time, not just clicks                 │
│  • Uses REINFORCE with baseline for variance reduction                 │
│                                                                          │
│  SPOTIFY (2022-present):                                                 │
│  ────────────────────────                                                │
│  • Contextual bandits for playlist generation                          │
│  • Thompson Sampling for song selection                                │
│  • Optimizes for listening time + skip rate                            │
│                                                                          │
│  NETFLIX (2020-present):                                                 │
│  ────────────────────────                                                │
│  • Counterfactual evaluation for offline RL                            │
│  • Inverse propensity scoring for unbiased learning                    │
│  • Multi-objective RL (engagement + diversity + freshness)             │
│                                                                          │
│  ALIBABA (2018-present):                                                 │
│  ────────────────────────                                                │
│  • Virtual Taobao: Simulated environment for RL                        │
│  • Batch RL for e-commerce recommendations                             │
│  • Multi-agent RL for marketplace optimization                         │
│                                                                          │
│  KUAISHOU (2024-2025):                                                  │
│  ────────────────────────                                                │
│  • OneRec uses DPO (Direct Preference Optimization)                    │
│  • Iterative reward model training                                     │
│  • RLHF-style alignment for video recommendations                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  KEY LESSONS FROM PRODUCTION:                                            │
│                                                                          │
│  1. Start with bandits, not full RL                                    │
│     Simpler, more stable, often sufficient                              │
│                                                                          │
│  2. Offline RL is essential                                             │
│     Can't do online exploration at scale                                │
│                                                                          │
│  3. Reward shaping is critical                                          │
│     Raw metrics often don't capture business value                      │
│                                                                          │
│  4. Safety constraints are non-negotiable                               │
│     RL without guardrails can recommend harmful content                 │
│                                                                          │
│  5. Hybrid approaches work best                                          │
│     RL for exploration, supervised for exploitation                     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
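
Lesson 1 is worth making concrete. A per-item Thompson Sampling bandit, the kind of approach referenced in the Spotify entry above, is often the right starting point before committing to full offline RL. The sketch below is a minimal, illustrative version assuming Bernoulli (click / no-click) rewards and no contextual features; production systems are contextual and considerably more elaborate.

Python
import numpy as np

class BernoulliThompsonBandit:
    """Thompson Sampling over items with Beta posteriors (illustrative sketch)."""

    def __init__(self, num_items: int):
        # Beta(1, 1) priors: alpha counts successes, beta counts failures
        self.alpha = np.ones(num_items)
        self.beta = np.ones(num_items)

    def select(self) -> int:
        # Sample a plausible click rate for each item from its posterior
        # and recommend the item with the highest sampled value.
        samples = np.random.beta(self.alpha, self.beta)
        return int(samples.argmax())

    def update(self, item: int, clicked: bool) -> None:
        # Bayesian update of the chosen item's posterior.
        if clicked:
            self.alpha[item] += 1
        else:
            self.beta[item] += 1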

Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
