Transformers for Recommendation Systems: From SASRec to HSTU
A comprehensive deep dive into transformer-based recommendation systems. From the fundamentals of sequential recommendation to Meta's trillion-parameter HSTU, understand how attention mechanisms revolutionized personalization.
Why Transformers Changed Recommendation Systems
The success of transformers in NLP sparked a revolution in recommendation systems. Before 2018, sequential recommendation was dominated by RNNs and Markov chains. Then came SASRec, BERT4Rec, and a parade of transformer-based models that consistently outperformed their predecessors.
But why do transformers work so well for recommendations? The answer lies in what sequential recommendation actually requires:
- Capturing user intent from behavior sequences: Users don't interact with items randomly. Each click, purchase, or watch reveals preferences that evolve over time.
- Modeling complex dependencies: A user who bought running shoes, then a fitness tracker, then protein powder is on a fitness journey. These items relate to each other across many steps.
- Handling variable-length histories: Some users have 10 interactions, others have 10,000. The model must work for both.
- Real-time inference at scale: Recommendations must be computed in milliseconds for millions of users.
Transformers address all of these challenges through their attention mechanism—allowing direct connections between any two items in a user's history, regardless of how far apart they occurred.
┌─────────────────────────────────────────────────────────────────────────┐
│ EVOLUTION OF SEQUENTIAL RECOMMENDATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ BEFORE TRANSFORMERS (2010-2017): │
│ ───────────────────────────────── │
│ │
│ Markov Chains → item₁ → item₂ → item₃ → ? │
│ (Only last item) Limited to previous transition │
│ │
│ RNNs/GRUs/LSTMs → item₁ → h₁ → item₂ → h₂ → item₃ → h₃ → ? │
│ (Sequential) Hidden state compresses history │
│ Information degrades over distance │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TRANSFORMER ERA (2018-NOW): │
│ ─────────────────────────── │
│ │
│ Self-Attention → item₁ ←──────────────────┐ │
│ (All-to-all) item₂ ←───────────┐ │ │
│ item₃ ←────┐ │ │ │
│ ? ────┴──────┴──────┘ │
│ │
│ Every item directly attends to every other │
│ No information bottleneck │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY THIS MATTERS FOR RECOMMENDATIONS: │
│ │
│ User history: [iPhone case, AirPods, MacBook charger, ..., iPad] │
│ │
│ When predicting next item after "iPad": │
│ - Transformer sees ALL Apple products directly │
│ - RNN would have compressed early items into hidden state │
│ - Markov only sees "iPad" │
│ │
│ The "Apple ecosystem" pattern spans the entire sequence │
│ │
└─────────────────────────────────────────────────────────────────────────┘
2024-2025 State of the Art: Meta's HSTU (Hierarchical Sequential Transduction Units) demonstrated that transformers can scale to a trillion parameters for recommendations, achieving 12.4%+ topline metric improvements in production. The key innovations: removing softmax normalization and using generative (next-item prediction) training objectives.
Part I: Foundations of Sequential Recommendation
The Sequential Recommendation Problem
Given a user's interaction history (item_1, item_2, ..., item_t), predict the next item item_{t+1} they will interact with. This seems simple, but the challenges are significant:
┌─────────────────────────────────────────────────────────────────────────┐
│ SEQUENTIAL RECOMMENDATION SETUP │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ INPUT: User interaction sequence │
│ ──────────────────────────────── │
│ │
│ User A: [item_23, item_156, item_89, item_42, item_301, ...] │
│ ↑ ↑ ↑ ↑ ↑ │
│ t=1 t=2 t=3 t=4 t=5 │
│ │
│ Each item is typically represented as: │
│ - Item ID (integer index into embedding table) │
│ - Optional: item features (category, price, etc.) │
│ │
│ OUTPUT: Probability distribution over all items │
│ ─────────────────────────────────────────────── │
│ │
│ P(next_item | history) = softmax(scores over all items) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ CHALLENGES: │
│ │
│ 1. SCALE │
│ - Millions of items, billions of users │
│ - Cannot compute full softmax over all items │
│ - Need efficient retrieval/ranking stages │
│ │
│ 2. SPARSITY │
│ - Most users interact with tiny fraction of items │
│ - Many items have few interactions (long-tail) │
│ - Cold-start for new users/items │
│ │
│ 3. DYNAMICS │
│ - User preferences change over time │
│ - New items added continuously │
│ - Seasonal and trending effects │
│ │
│ 4. LATENCY │
│ - Recommendations must be computed in <100ms │
│ - Often much stricter (10-20ms) │
│ - At massive scale (millions of QPS) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
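To make the input/output setup concrete, here is a small sketch (the helper name is hypothetical) of turning one user's history into next-item prediction examples: every prefix of the sequence is an input and the item that follows it is the target.
def make_training_examples(history, max_len=50):
    """Turn one interaction sequence into (prefix, next_item) pairs."""
    examples = []
    for t in range(1, len(history)):
        prefix = history[max(0, t - max_len):t]   # items observed so far (truncated)
        target = history[t]                        # the item to predict
        examples.append((prefix, target))
    return examples

user_a = [23, 156, 89, 42, 301]
for prefix, target in make_training_examples(user_a):
    print(f"input={prefix} -> predict item_{target}")
# input=[23] -> predict item_156
# input=[23, 156] -> predict item_89
# ...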
From Matrix Factorization to Deep Learning
Before diving into transformers, it's worth understanding what came before:
Matrix Factorization (2009): The Netflix Prize era. Represent users and items as latent vectors, predict ratings as dot products. Simple, interpretable, but static—doesn't capture sequential patterns.
RNN-based Models (2016-2018): GRU4Rec and variants processed sequences with recurrent networks. Better at capturing order, but suffered from vanishing gradients and sequential bottlenecks.
CNN-based Models (2018): Caser used convolutional filters to capture local patterns in sequences. Fast, but limited receptive field.
Attention-based Models (2018+): SASRec, BERT4Rec, and descendants. This is where we'll focus.
Part II: Transformer Architecture Deep Dive
Before implementing SASRec and BERT4Rec, we need to thoroughly understand the transformer architecture. This section covers every component in detail with mathematical formulations, intuitions, and implementation examples.
The Attention Mechanism: Core Intuition
Attention answers: "Which parts of the input should I focus on when producing this output?" In recommendation, this becomes: "Which past items are relevant for predicting the next item?"
┌─────────────────────────────────────────────────────────────────────────┐
│ ATTENTION MECHANISM INTUITION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ USER HISTORY: [Running Shoes, Fitness Tracker, Protein Powder, ...] │
│ │
│ Question: "What should I recommend next?" │
│ │
│ ATTENTION PROCESS: │
│ ───────────────── │
│ │
│ Step 1: Create a QUERY from current position │
│ "I'm looking for items related to fitness..." │
│ │
│ Step 2: Compare query against all past items (KEYS) │
│ Running Shoes → High relevance (fitness) │
│ Fitness Tracker → High relevance (fitness) │
│ Phone Case → Low relevance (unrelated) │
│ Protein Powder → High relevance (fitness) │
│ │
│ Step 3: Weight VALUES by relevance scores │
│ Output = 0.35×(Running Shoes) + 0.30×(Tracker) + │
│ 0.05×(Phone Case) + 0.30×(Protein) │
│ │
│ Result: Context-aware representation for prediction │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Scaled Dot-Product Attention: The Mathematics
The attention mechanism is the heart of transformer models, and understanding it deeply is essential for building effective recommendation systems. At its core, attention computes a weighted average of values, where the weights are determined by how well queries match keys.
Why "Scaled Dot-Product"?
The name describes exactly what happens: we compute dot products between queries and keys (measuring similarity), then scale the result. The dot product is a natural similarity measure—when two vectors point in similar directions, their dot product is large. When they're orthogonal (unrelated), the dot product is zero. This is precisely what we want: items that are "similar" to what we're looking for should have high attention weights.
The fundamental attention operation:

Attention(Q, K, V) = softmax(QK^T / √d_k) × V
Understanding Each Component:
- Query (Q): "What am I looking for?" In recommendations, this represents the current context or position asking "what should come next?"
- Key (K): "What do I have to offer?" Each item in the history advertises itself through its key representation.
- Value (V): "What information do I carry?" The actual content that gets aggregated based on attention weights.
- Scaling by √d_k: Critical for stable training. Without scaling, dot products grow with dimension, making softmax outputs extremely peaked (close to one-hot). This starves gradients and makes learning difficult.
The Information Flow:
Think of attention as a soft database lookup. The query asks a question, keys determine relevance (which database entries match?), and values provide the answer (what's stored in those entries?). Unlike hard lookups that return one result, soft attention returns a weighted blend of all values, proportional to relevance.
In recommendation systems specifically:
- Queries come from the position we're predicting for
- Keys and Values come from the user's interaction history
- The output tells us: "Given where the user is in their journey, what aspects of their history are most relevant?"
┌─────────────────────────────────────────────────────────────────────────┐
│ SCALED DOT-PRODUCT ATTENTION: STEP BY STEP │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ EXAMPLE: Sequence of 4 items, embedding dim = 8 │
│ │
│ INPUT: X ∈ ℝ^(4×8) │
│ ───── │
│ Item 1: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] │
│ Item 2: [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] │
│ Item 3: [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] │
│ Item 4: [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1] │
│ │
│ STEP 1: Linear projections to Q, K, V │
│ ────────────────────────────────────── │
│ │
│ Q = X × W_Q where W_Q ∈ ℝ^(8×8) │
│ K = X × W_K where W_K ∈ ℝ^(8×8) │
│ V = X × W_V where W_V ∈ ℝ^(8×8) │
│ │
│ Each item now has query, key, and value representations │
│ │
│ STEP 2: Compute attention scores (QK^T) │
│ ───────────────────────────────────────── │
│ │
│ Scores = Q × K^T ∈ ℝ^(4×4) │
│ │
│ │ Item1 Item2 Item3 Item4 │
│ ───────┼──────────────────────────── │
│ Item1 │ 0.82 0.76 0.71 0.65 ← How much Item1 attends to │
│ Item2 │ 0.76 0.89 0.84 0.79 each other item │
│ Item3 │ 0.71 0.84 0.95 0.90 │
│ Item4 │ 0.65 0.79 0.90 1.02 │
│ │
│ STEP 3: Scale by √d_k │
│ ──────────────────────── │
│ │
│ Why scale? Dot products grow with dimension, making softmax │
│ extremely peaked. Division by √d_k keeps gradients stable. │
│ │
│ Scaled = Scores / √8 = Scores / 2.83 │
│ │
│ STEP 4: Apply causal mask (for autoregressive) │
│ ────────────────────────────────────────────── │
│ │
│ │ Item1 Item2 Item3 Item4 │
│ ───────┼──────────────────────────── │
│ Item1 │ 0.29 -∞ -∞ -∞ ← Item1 only sees itself │
│ Item2 │ 0.27 0.31 -∞ -∞ ← Item2 sees 1,2 │
│ Item3 │ 0.25 0.30 0.34 -∞ ← Item3 sees 1,2,3 │
│ Item4 │ 0.23 0.28 0.32 0.36 ← Item4 sees all │
│ │
│ STEP 5: Softmax (row-wise normalization) │
│ ───────────────────────────────────────── │
│ │
│ │ Item1 Item2 Item3 Item4 │
│ ───────┼──────────────────────────── │
│ Item1 │ 1.00 0.00 0.00 0.00 ← Weights sum to 1 │
│ Item2 │ 0.49 0.51 0.00 0.00 │
│ Item3 │ 0.30 0.33 0.37 0.00 │
│ Item4 │ 0.22 0.24 0.26 0.28 │
│ │
│ STEP 6: Weighted sum of values │
│ ───────────────────────────────── │
│ │
│ Output[i] = Σ_j (attention_weights[i,j] × V[j]) │
│ │
│ Output[4] = 0.22×V[1] + 0.24×V[2] + 0.26×V[3] + 0.28×V[4] │
│ │
│ Each output is a weighted combination of all visible values │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Why Causal Masking Matters for Recommendations:
The causal mask (Step 4 above) is crucial for sequential recommendation. When predicting what a user will do next, we can only use information from items they've already interacted with—we can't "peek" at future interactions. This creates an autoregressive setup where each position only attends to previous positions.
This is different from bidirectional models like BERT4Rec, which allow items to attend to both past and future positions (useful during training with masked item prediction, but requires special handling at inference time).
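As a small illustration (a sketch, not code from either paper), the two masking regimes differ only in the boolean mask handed to attention:
import torch

seq_len = 4

# Causal mask used for autoregressive (SASRec-style) training: True = masked.
# Position i can only attend to positions j <= i.
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Bidirectional attention (BERT4Rec-style training) applies no causal constraint;
# only padded positions would be masked. With no padding, nothing is masked.
bidirectional_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)

print("Causal mask:\n", causal_mask)
print("Bidirectional mask:\n", bidirectional_mask)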
The Softmax Temperature:
The scaling factor √d_k acts as a "temperature" for the softmax: a larger divisor (higher temperature) makes the distribution more uniform, while a smaller divisor makes it more peaked. The default of √d_k was found empirically to work well, but some applications tune this as a hyperparameter.
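A quick numerical sketch of why the √d_k divisor matters: as the dimension grows, raw dot products of random vectors grow in magnitude, and an unscaled softmax drifts toward one-hot.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
for d_k in [8, 64, 512]:
    q = torch.randn(d_k)
    keys = torch.randn(4, d_k)
    scores = keys @ q                                   # raw dot products grow with d_k
    unscaled = F.softmax(scores, dim=-1)
    scaled = F.softmax(scores / math.sqrt(d_k), dim=-1)
    print(f"d_k={d_k:4d}  max weight: unscaled={unscaled.max():.3f}, scaled={scaled.max():.3f}")
# The unscaled softmax tends toward one-hot as d_k grows; the scaled version stays smoother.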
Implementation Considerations:
When implementing attention, several practical concerns arise:
- Numerical stability: Subtracting the max before softmax prevents overflow
- Memory efficiency: For long sequences, attention matrices can be huge (n² memory)
- Masked positions: Setting to -inf before softmax ensures zero attention weight
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ScaledDotProductAttention(nn.Module):
"""
Scaled Dot-Product Attention with detailed implementation.
This is the core building block of all transformer models.
Understanding this deeply is essential for RecSys transformers.
"""
def __init__(self, temperature: float = None):
super().__init__()
self.temperature = temperature # If None, computed from d_k
def forward(
self,
query: torch.Tensor, # (batch, n_query, d_k)
key: torch.Tensor, # (batch, n_key, d_k)
value: torch.Tensor, # (batch, n_key, d_v)
mask: torch.Tensor = None, # (batch, n_query, n_key) or broadcastable
return_attention: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query: What we're looking for
key: What we're comparing against
value: What we retrieve
mask: True = masked (will not attend), False = can attend
Returns:
output: (batch, n_query, d_v) - weighted sum of values
attention_weights: (batch, n_query, n_key) - attention distribution
"""
d_k = query.size(-1)
temperature = self.temperature or math.sqrt(d_k)
# Step 1: Compute attention scores
# (batch, n_query, d_k) @ (batch, d_k, n_key) -> (batch, n_query, n_key)
scores = torch.matmul(query, key.transpose(-2, -1))
# Step 2: Scale by temperature (sqrt(d_k))
scores = scores / temperature
# Step 3: Apply mask (set masked positions to -inf)
if mask is not None:
scores = scores.masked_fill(mask, float('-inf'))
# Step 4: Softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1)
# Handle case where entire row is masked (all -inf -> nan after softmax)
attention_weights = attention_weights.masked_fill(
torch.isnan(attention_weights), 0.0
)
# Step 5: Weighted sum of values
output = torch.matmul(attention_weights, value)
if return_attention:
return output, attention_weights
return output, attention_weights
# Demonstration
def demonstrate_attention():
"""Step-by-step demonstration of attention computation."""
# Small example: 4 items, 8-dimensional embeddings
batch_size = 1
seq_len = 4
d_model = 8
# Random input (simulating item embeddings)
torch.manual_seed(42)
x = torch.randn(batch_size, seq_len, d_model)
# Linear projections (normally these are learned)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
Q = W_q(x) # (1, 4, 8)
K = W_k(x) # (1, 4, 8)
V = W_v(x) # (1, 4, 8)
print("Query shape:", Q.shape)
print("Key shape:", K.shape)
print("Value shape:", V.shape)
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) # (1, 4, 4)
print("\nRaw attention scores:\n", scores[0])
# Scale
scaled_scores = scores / math.sqrt(d_model)
print("\nScaled scores:\n", scaled_scores[0])
# Create causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print("\nCausal mask (True = masked):\n", causal_mask)
# Apply mask
masked_scores = scaled_scores.masked_fill(causal_mask, float('-inf'))
print("\nMasked scores:\n", masked_scores[0])
# Softmax
attention_weights = F.softmax(masked_scores, dim=-1)
print("\nAttention weights (sum to 1 per row):\n", attention_weights[0])
print("Row sums:", attention_weights[0].sum(dim=-1))
# Output
output = torch.matmul(attention_weights, V)
print("\nOutput shape:", output.shape)
return output, attention_weights
# Run demonstration
output, weights = demonstrate_attention()
Multi-Head Attention: Capturing Different Relationships
A single attention head can only learn one way of relating items to each other. But user behavior is complex—items relate through categories, brands, complementarity, price ranges, and many other dimensions simultaneously. Multi-head attention solves this by running multiple attention operations in parallel, each learning different relationship patterns.
The Key Insight:
Imagine analyzing a user who bought running shoes, a fitness tracker, and protein powder. Different aspects of these items matter for different predictions:
- Category view: All are fitness-related → recommend more fitness items
- Complementary view: Shoes need socks, tracker needs charger → recommend accessories
- Price tier view: All mid-range → recommend similar price points
- Brand view: Prefers Nike → recommend more Nike products
- Temporal view: Bought in sequence suggesting workout routine → recommend recovery items
A single attention head might capture one of these perspectives, but we want them all. Multi-head attention dedicates separate "heads" to learn these different relationship types, then combines their insights.
How It Works Mathematically:
Instead of one large attention operation, we run h smaller ones in parallel. Each head operates on a lower-dimensional subspace (d_k = d_model / h), learns its own Q/K/V projections, and captures its own relationship type. The outputs are concatenated and projected back to the original dimension.
This isn't just parallelization for efficiency—it's fundamentally about representational diversity. Different heads learn different patterns, and the combination is more expressive than a single head of the same total dimension would be.
Practical Implications for Recommendations:
Research has shown that different attention heads in trained RecSys models actually do specialize:
- Some heads focus on recency (high weight on recent items)
- Some heads focus on category similarity
- Some heads learn positional patterns (periodic preferences)
- Some heads capture user-specific idiosyncrasies
This interpretability is valuable—you can inspect attention patterns to understand why the model made a recommendation.
┌─────────────────────────────────────────────────────────────────────────┐
│ MULTI-HEAD ATTENTION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ WHY MULTIPLE HEADS? │
│ ─────────────────── │
│ │
│ Single head might learn: "Items from same category" │
│ │
│ But we want to capture multiple relationship types: │
│ • Head 1: Category similarity (shoes → shoes) │
│ • Head 2: Complementary items (shoes → socks) │
│ • Head 3: Price range similarity │
│ • Head 4: Temporal patterns (morning → evening items) │
│ • Head 5: Brand affinity │
│ • Head 6: Style similarity │
│ • Head 7: Seasonal patterns │
│ • Head 8: Cross-category bundles │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ARCHITECTURE: │
│ │
│ Input X ∈ ℝ^(seq_len × d_model) │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Head 1 │ │ Head 2 │ │ Head 3 │ ... │ Head h │ │
│ │ d_k=64 │ │ d_k=64 │ │ d_k=64 │ │ d_k=64 │ │
│ └────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘ │
│ │ │ │ │ │
│ └──────┬─────┴─────┬──────┴─────────┬──────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────┐ │
│ │ Concatenate: (seq_len, h×d_k) │ │
│ └─────────────────┬───────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────┐ │
│ │ Linear W_O │ │
│ │ (h×d_k → d_m) │ │
│ └───────┬───────┘ │
│ │ │
│ ▼ │
│ Output ∈ ℝ^(seq_len × d_model) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MATHEMATICS: │
│ │
│ head_i = Attention(XW_Q^i, XW_K^i, XW_V^i) │
│ │
│ MultiHead(X) = Concat(head_1, ..., head_h) × W_O │
│ │
│ Where: │
│ • W_Q^i, W_K^i ∈ ℝ^(d_model × d_k) │
│ • W_V^i ∈ ℝ^(d_model × d_v) │
│ • W_O ∈ ℝ^(h×d_v × d_model) │
│ • Typically: d_k = d_v = d_model / h │
│ │
│ PARAMETER COUNT (d_model=512, h=8): │
│ • Per head: 3 × (512 × 64) = 98,304 │
│ • All heads: 8 × 98,304 = 786,432 │
│ • Output projection: 512 × 512 = 262,144 │
│ • Total: 1,048,576 parameters │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention with complete implementation details.
Each head operates on a subspace of the input, learning different
attention patterns. Outputs are concatenated and projected.
"""
def __init__(
self,
d_model: int,
num_heads: int,
dropout: float = 0.1,
bias: bool = True,
):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension per head
self.d_v = d_model // num_heads
# Separate projection matrices for Q, K, V
# In practice, often combined into one large matrix for efficiency
self.W_q = nn.Linear(d_model, d_model, bias=bias)
self.W_k = nn.Linear(d_model, d_model, bias=bias)
self.W_v = nn.Linear(d_model, d_model, bias=bias)
# Output projection
self.W_o = nn.Linear(d_model, d_model, bias=bias)
# Dropout on attention weights
self.dropout = nn.Dropout(dropout)
# For storing attention weights (useful for visualization)
self.attention_weights = None
def forward(
self,
query: torch.Tensor, # (batch, seq_len, d_model)
key: torch.Tensor, # (batch, seq_len, d_model)
value: torch.Tensor, # (batch, seq_len, d_model)
mask: torch.Tensor = None,
return_attention: bool = False,
) -> torch.Tensor:
"""
For self-attention: query = key = value = x
For cross-attention: query from decoder, key/value from encoder
"""
batch_size = query.size(0)
seq_len_q = query.size(1)
seq_len_k = key.size(1)
# Step 1: Linear projections
Q = self.W_q(query) # (batch, seq_len_q, d_model)
K = self.W_k(key) # (batch, seq_len_k, d_model)
V = self.W_v(value) # (batch, seq_len_k, d_model)
# Step 2: Reshape for multi-head: split d_model into num_heads × d_k
# (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_k)
# -> (batch, num_heads, seq_len, d_k)
Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len_k, self.num_heads, self.d_v).transpose(1, 2)
# Step 3: Scaled dot-product attention for each head
# scores: (batch, num_heads, seq_len_q, seq_len_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# Step 4: Apply mask
if mask is not None:
# Expand mask for num_heads dimension
if mask.dim() == 2:
mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_q, seq_k)
elif mask.dim() == 3:
mask = mask.unsqueeze(1) # (batch, 1, seq_q, seq_k)
scores = scores.masked_fill(mask, float('-inf'))
# Step 5: Softmax and dropout
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Store for visualization
self.attention_weights = attention_weights.detach()
# Step 6: Apply attention to values
# (batch, num_heads, seq_len_q, seq_len_k) @ (batch, num_heads, seq_len_k, d_v)
# -> (batch, num_heads, seq_len_q, d_v)
context = torch.matmul(attention_weights, V)
# Step 7: Concatenate heads
# (batch, num_heads, seq_len_q, d_v) -> (batch, seq_len_q, num_heads, d_v)
# -> (batch, seq_len_q, d_model)
context = context.transpose(1, 2).contiguous().view(
batch_size, seq_len_q, self.d_model
)
# Step 8: Final linear projection
output = self.W_o(context)
if return_attention:
return output, attention_weights
return output
def visualize_attention_heads():
"""
Demonstrate what different attention heads might learn.
In practice, heads learn these patterns from data.
"""
# Example: 4 items in sequence
items = ["Running Shoes", "Yoga Mat", "Protein Powder", "Water Bottle"]
# Hypothetical attention patterns from different heads
patterns = {
"Head 1 (Category)": [
[1.0, 0.0, 0.0, 0.0], # Shoes → Shoes
[0.0, 1.0, 0.0, 0.0], # Yoga → Yoga
[0.2, 0.1, 0.7, 0.0], # Protein → similar to shoes (fitness)
[0.2, 0.2, 0.2, 0.4], # Bottle → everything (universal)
],
"Head 2 (Complementary)": [
[0.3, 0.3, 0.4, 0.0], # Shoes → Protein (workout stack)
[0.5, 0.2, 0.3, 0.0], # Yoga → Shoes (cross-training)
[0.4, 0.2, 0.2, 0.2], # Protein → workout gear
[0.3, 0.3, 0.2, 0.2], # Bottle → all fitness
],
"Head 3 (Recency)": [
[1.0, 0.0, 0.0, 0.0], # Strong recency bias
[0.2, 0.8, 0.0, 0.0],
[0.1, 0.2, 0.7, 0.0],
[0.1, 0.1, 0.3, 0.5],
],
"Head 4 (Position)": [
[1.0, 0.0, 0.0, 0.0], # Always attend to first item
[0.8, 0.2, 0.0, 0.0],
[0.7, 0.2, 0.1, 0.0],
[0.6, 0.2, 0.1, 0.1],
],
}
print("Attention Pattern Visualization")
print("=" * 60)
for head_name, pattern in patterns.items():
print(f"\n{head_name}:")
print("-" * 40)
for i, row in enumerate(pattern):
item = items[i]
attending_to = [f"{items[j]}({row[j]:.1f})" for j in range(i+1) if row[j] > 0.1]
print(f" {item:15} attends to: {', '.join(attending_to)}")
Positional Encodings: Teaching Position Awareness
One of the most counterintuitive aspects of attention is that it's permutation-invariant—without any additional information, attention treats sequences as unordered sets. The output for a sequence [A, B, C] would be identical to [C, B, A], just reordered. This is a problem because order carries crucial meaning.
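A quick sketch of this property using PyTorch's built-in nn.MultiheadAttention on inputs with no positional information: shuffling the sequence just shuffles the outputs.
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 16, 5
attn = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
x = torch.randn(1, seq_len, d_model)           # item embeddings, no positions added

perm = torch.randperm(seq_len)
out, _ = attn(x, x, x)                          # attention on the original order
out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])  # attention on the shuffled order

# The shuffled run produces the same vectors, just reordered by the same permutation
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True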
Why Order Matters in Recommendations:
Sequential patterns are everywhere in user behavior:
- Intent evolution: Search → Browse → Compare → Purchase follows a decision journey
- Session dynamics: Morning coffee browsing differs from evening relaxation browsing
- Recency effects: Recent interactions predict immediate next actions better than older ones
- Temporal patterns: Weekly shopping routines, seasonal preferences, time-of-day effects
If the model can't distinguish position, it loses all this information. A user who just purchased running shoes and then browsed socks has very different intent than one who browsed socks and then purchased running shoes.
Positional encodings solve this by adding position information to each item embedding. There are several approaches, each with trade-offs:
Sinusoidal encodings (original Transformer) use sine and cosine functions at different frequencies. They require no learned parameters and can theoretically extrapolate to longer sequences than seen in training. However, they provide fixed patterns that may not match the specific positional importance in your domain.
Learned embeddings (common in RecSys) treat position as another vocabulary to embed. Position 1, 2, 3... each get their own learned vector. This is flexible and learns domain-specific positional patterns, but can't extrapolate beyond the maximum training length.
Relative positional encodings (Transformer-XL, T5) encode the distance between positions rather than absolute positions. This often generalizes better—"3 positions ago" is meaningful regardless of whether you're at position 10 or position 100.
Rotary Position Embeddings (RoPE) (LLaMA, modern LLMs) rotate query and key vectors based on position. This elegantly combines absolute and relative encoding—the dot product naturally depends on relative position. RoPE has become the de facto standard for modern transformers due to excellent extrapolation properties.
For recommendation systems specifically:
Learned embeddings are most common because:
- Sequences are typically bounded (50-200 items)
- Different positions can have dramatically different importance (recency effects)
- The specific positional patterns matter (position 1 might be special—the most recent item)
However, if your sequences can vary significantly in length, or if you expect to serve longer sequences than you trained on, relative or rotary encodings are worth considering.
┌─────────────────────────────────────────────────────────────────────────┐
│ POSITIONAL ENCODING METHODS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ PROBLEM: Attention treats [A, B, C] same as [C, B, A] │
│ ─────────────────────────────────────────────────────── │
│ │
│ But order matters for recommendations! │
│ [Searched Running Shoes → Added to Cart → Purchased] │
│ is very different from │
│ [Purchased → Added to Cart → Searched Running Shoes] │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ METHOD 1: SINUSOIDAL (Original Transformer) │
│ ─────────────────────────────────────────────── │
│ │
│ PE(pos, 2i) = sin(pos / 10000^(2i/d)) │
│ PE(pos, 2i+1) = cos(pos / 10000^(2i/d)) │
│ │
│ Properties: │
│ ✓ No learned parameters │
│ ✓ Can extrapolate to longer sequences │
│ ✓ Relative positions captured: PE(pos+k) = f(PE(pos)) │
│ ✗ Fixed patterns, not task-specific │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ METHOD 2: LEARNED EMBEDDINGS (Common in RecSys) │
│ ─────────────────────────────────────────────────── │
│ │
│ PE = nn.Embedding(max_seq_len, d_model) │
│ │
│ Properties: │
│ ✓ Learns task-specific position patterns │
│ ✓ Simple to implement │
│ ✗ Cannot extrapolate beyond max_seq_len │
│ ✗ Requires more parameters │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ METHOD 3: RELATIVE POSITIONAL ENCODING (Transformer-XL, T5) │
│ ──────────────────────────────────────────────────────────── │
│ │
│ Instead of absolute positions, encode relative distance: │
│ "Item A is 3 positions before item B" │
│ │
│ Attention(Q, K) = softmax(QK^T + bias(i-j)) │
│ │
│ Properties: │
│ ✓ Better generalization to different sequence lengths │
│ ✓ Captures "distance" rather than "absolute position" │
│ ✗ More complex implementation │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ METHOD 4: ROTARY POSITION EMBEDDING (RoPE) - Modern Standard │
│ ──────────────────────────────────────────────────────────── │
│ │
│ Rotate query/key vectors based on position: │
│ q'_m = R_θ,m × q_m │
│ k'_n = R_θ,n × k_n │
│ │
│ Then: q'_m · k'_n depends only on (m-n), the relative position │
│ │
│ Properties: │
│ ✓ Combines benefits of absolute and relative encodings │
│ ✓ Decays attention with distance (like a prior) │
│ ✓ Used in modern LLMs (LLaMA, Mistral, GPT-NeoX) │
│ ✓ Excellent extrapolation properties │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class SinusoidalPositionalEncoding(nn.Module):
"""
Original transformer positional encoding using sinusoidal functions.
Deterministic, no learned parameters.
"""
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create positional encoding matrix
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# Compute div_term: 10000^(2i/d_model)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# Apply sin to even indices, cos to odd indices
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# Add batch dimension and register as buffer (not a parameter)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, d_model) - input embeddings
Returns:
(batch, seq_len, d_model) - embeddings + positional encoding
"""
seq_len = x.size(1)
x = x + self.pe[:, :seq_len, :]
return self.dropout(x)
class LearnedPositionalEncoding(nn.Module):
"""
Learned positional embeddings - common in RecSys transformers.
Simple and effective for fixed-length sequences.
"""
def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
super().__init__()
self.position_embedding = nn.Embedding(max_len, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, d_model)
"""
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device)
pos_emb = self.position_embedding(positions) # (seq_len, d_model)
x = x + pos_emb
return self.dropout(x)
class RotaryPositionalEncoding(nn.Module):
"""
Rotary Position Embedding (RoPE) - modern standard.
Used in LLaMA, Mistral, and increasingly in RecSys.
Key idea: Rotate query and key vectors based on position.
The dot product then naturally captures relative position.
"""
def __init__(self, d_model: int, max_len: int = 2048, base: int = 10000):
super().__init__()
self.d_model = d_model
self.max_len = max_len
# Compute rotation frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
self.register_buffer('inv_freq', inv_freq)
# Precompute rotation matrices
self._build_cache(max_len)
def _build_cache(self, seq_len: int):
"""Precompute sin and cos for all positions."""
positions = torch.arange(seq_len)
freqs = torch.outer(positions, self.inv_freq) # (seq_len, d_model/2)
# Create rotation matrix components
self.register_buffer('cos_cached', freqs.cos())
self.register_buffer('sin_cached', freqs.sin())
def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to query and key.
Args:
q: (batch, num_heads, seq_len, head_dim)
k: (batch, num_heads, seq_len, head_dim)
Returns:
Rotated q and k
"""
seq_len = q.size(2)
# Get cached values
cos = self.cos_cached[:seq_len] # (seq_len, d/2)
sin = self.sin_cached[:seq_len] # (seq_len, d/2)
# Reshape for broadcasting
cos = cos.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, d/2)
sin = sin.unsqueeze(0).unsqueeze(0)
# Apply rotation
q_rot = self._rotate(q, cos, sin)
k_rot = self._rotate(k, cos, sin)
return q_rot, k_rot
def _rotate(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""Apply rotation to each position."""
# Split into two halves
x1, x2 = x[..., ::2], x[..., 1::2]
# Apply rotation: [x1, x2] -> [x1*cos - x2*sin, x1*sin + x2*cos]
rotated = torch.stack([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1).flatten(-2)
return rotated
# Comparison demonstration
def compare_positional_encodings():
"""Show differences between positional encoding methods."""
d_model = 64
seq_len = 10
batch_size = 2
# Random input
x = torch.randn(batch_size, seq_len, d_model)
# Different encodings
sinusoidal = SinusoidalPositionalEncoding(d_model)
learned = LearnedPositionalEncoding(d_model)
out_sin = sinusoidal(x)
out_learned = learned(x)
print("Positional Encoding Comparison")
print("=" * 50)
print(f"Input shape: {x.shape}")
print(f"Output shape (both): {out_sin.shape}")
# Show that learned embeddings are different each run (before training)
print(f"\nLearned PE parameter shape: {learned.position_embedding.weight.shape}")
print(f"Sinusoidal PE requires no parameters (deterministic)")
# Visualize the encoding patterns
print("\nSinusoidal encoding pattern (first 4 dims, first 5 positions):")
print(sinusoidal.pe[0, :5, :4])
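Method 3 (relative positional encoding) has no implementation above, so here is a minimal sketch of a T5-style relative position bias under simple assumptions: one learned bias table, relative distances clipped to a maximum, and the bias added directly to the attention scores.
import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    """
    Sketch of a T5-style relative position bias (class name and clipping
    scheme are illustrative). Produces a (num_heads, seq_len, seq_len)
    tensor added to attention scores: softmax(QK^T / sqrt(d_k) + bias).
    """
    def __init__(self, num_heads: int, max_distance: int = 128):
        super().__init__()
        self.max_distance = max_distance
        # One learned bias per (clipped relative distance, head)
        self.bias = nn.Embedding(2 * max_distance + 1, num_heads)

    def forward(self, seq_len: int) -> torch.Tensor:
        positions = torch.arange(seq_len)
        relative = positions[:, None] - positions[None, :]           # i - j
        relative = relative.clamp(-self.max_distance, self.max_distance)
        relative = relative + self.max_distance                      # shift to >= 0
        return self.bias(relative).permute(2, 0, 1)                  # (heads, L, L)

rel_bias = RelativePositionBias(num_heads=8)
print(rel_bias(10).shape)  # torch.Size([8, 10, 10])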
Layer Normalization: Stabilizing Deep Networks
Training deep neural networks is inherently unstable. As signals pass through many layers, they can explode (grow exponentially) or vanish (shrink to zero). Normalization techniques address this by constraining activations to reasonable ranges, ensuring gradients flow properly during backpropagation.
Why Layer Norm over Batch Norm?
Batch Normalization, invented for CNNs, normalizes across the batch dimension—computing statistics over all examples in a mini-batch. This works well for images but fails for transformers:
- Variable sequence lengths: Different sequences have different lengths, making it unclear how to compute batch statistics across positions
- Batch size dependence: Batch norm's statistics depend on batch size, causing training-inference mismatch
- Sequence modeling: In autoregressive generation, we process one token at a time—there's no batch to normalize over
Layer Normalization instead normalizes across the feature dimension. For each position in each sequence, we compute the mean and variance over the d_model features, then normalize. This is completely independent of batch size and handles variable-length sequences naturally.
The Pre-Norm vs Post-Norm Debate:
The original Transformer applied layer norm after the residual connection (post-norm):
output = LayerNorm(x + Sublayer(x))
Modern transformers (GPT-2 onward) apply layer norm before the sublayer (pre-norm):
output = x + Sublayer(LayerNorm(x))
Pre-norm has significant advantages:
- Better gradient flow: The residual path is "clean"—gradients flow directly without passing through normalization
- Training stability: Pre-norm models train more stably, often without learning rate warmup
- Easier to scale: Deeper models train more reliably with pre-norm
The intuition: in post-norm, every layer modifies the residual path, accumulating changes. In pre-norm, the residual path is an "identity highway" that just accumulates additions, while normalization happens in a branch.
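A minimal sketch of the two orderings, with a plain linear layer standing in for the attention or FFN sublayer:
import torch
import torch.nn as nn

d_model = 64
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)     # stand-in for attention or FFN
x = torch.randn(2, 10, d_model)

# Post-norm (original Transformer): normalize after the residual addition
post_norm_out = norm(x + sublayer(x))

# Pre-norm (GPT-2 and later): normalize inside the branch; the residual
# path carries x through unchanged
pre_norm_out = x + sublayer(norm(x))

print(post_norm_out.shape, pre_norm_out.shape)  # both (2, 10, 64)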
RMSNorm: A Simpler Alternative:
RMSNorm (Root Mean Square Normalization) simplifies layer norm by removing the mean-centering step—it only divides by the root mean square. This is computationally cheaper and works just as well in practice. LLaMA and other modern architectures use RMSNorm.
┌─────────────────────────────────────────────────────────────────────────┐
│ LAYER NORMALIZATION │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ BATCH NORM vs LAYER NORM: │
│ ───────────────────────── │
│ │
│ Input: (batch_size, seq_len, d_model) │
│ │
│ Batch Norm: Normalize across batch dimension │
│ Mean/var computed over (batch_size × seq_len) samples │
│ Problem: Varies with batch size, bad for variable-length │
│ │
│ Layer Norm: Normalize across feature dimension │
│ Mean/var computed over d_model features, per position │
│ Consistent regardless of batch size or sequence length │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ LAYER NORM FORMULA: │
│ ─────────────────── │
│ │
│ For input x ∈ ℝ^d: │
│ │
│ μ = (1/d) Σ x_i (mean across features) │
│ │
│ σ² = (1/d) Σ (x_i - μ)² (variance across features) │
│ │
│ y = γ × (x - μ) / √(σ² + ε) + β (normalize and scale) │
│ │
│ Where γ, β ∈ ℝ^d are learned parameters │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ PRE-NORM vs POST-NORM: │
│ ────────────────────── │
│ │
│ POST-NORM (Original Transformer): │
│ ───────────────────────────────── │
│ output = LayerNorm(x + Sublayer(x)) │
│ │
│ PRE-NORM (GPT-2, Most Modern): │
│ ──────────────────────────────── │
│ output = x + Sublayer(LayerNorm(x)) │
│ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ POST-NORM │ │ PRE-NORM │ │
│ ├─────────────────┤ ├─────────────────┤ │
│ │ x │ │ x │ │
│ │ │ │ │ │ │ │
│ │ ▼ │ │ ├─────┐ │ │
│ │ ┌─────────┐ │ │ ▼ │ │ │
│ │ │Sublayer │ │ │ ┌─────────┐│ │ │
│ │ └────┬────┘ │ │ │LayerNorm││ │ │
│ │ │ │ │ └────┬────┘│ │ │
│ │ │ │ │ ▼ │ │ │
│ │ ┌───┴───┐ │ │ ┌─────────┐│ │ │
│ │ │ Add │◄───│ │ │Sublayer ││ │ │
│ │ └───┬───┘ │ │ └────┬────┘│ │ │
│ │ ▼ │ │ │ │ │ │
│ │ ┌─────────┐ │ │ ┌───▼───┐ │ │ │
│ │ │LayerNorm│ │ │ │ Add │◄┘ │ │
│ │ └────┬────┘ │ │ └───┬───┘ │ │
│ │ ▼ │ │ ▼ │ │
│ │ output │ │ output │ │
│ └─────────────────┘ └─────────────────┘ │
│ │
│ PRE-NORM ADVANTAGES: │
│ • Better gradient flow (residual path is "clean") │
│ • More stable training for deep networks │
│ • No warm-up often needed │
│ • Default choice for modern transformers │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class LayerNorm(nn.Module):
"""
Layer Normalization with detailed implementation.
Normalizes across the last dimension (features).
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# Learned scale (γ) and shift (β) parameters
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (..., d_model) - input tensor
Returns:
Normalized tensor of same shape
"""
# Compute mean and variance across last dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# Normalize
x_norm = (x - mean) / torch.sqrt(var + self.eps)
# Scale and shift
output = self.gamma * x_norm + self.beta
return output
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Used in LLaMA and other modern architectures.
Simplifies LayerNorm by removing mean centering.
Slightly faster and works well in practice.
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute RMS
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
# Normalize and scale
return self.weight * (x / rms)
# Comparison
def compare_normalizations():
"""Compare LayerNorm and RMSNorm."""
d_model = 64
x = torch.randn(2, 10, d_model) # (batch, seq, features)
ln = LayerNorm(d_model)
rms = RMSNorm(d_model)
out_ln = ln(x)
out_rms = rms(x)
print("Normalization Comparison")
print("=" * 50)
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"LayerNorm output mean: {out_ln.mean():.4f}, std: {out_ln.std():.4f}")
print(f"RMSNorm output mean: {out_rms.mean():.4f}, std: {out_rms.std():.4f}")
print(f"\nLayerNorm params: {sum(p.numel() for p in ln.parameters())}")
print(f"RMSNorm params: {sum(p.numel() for p in rms.parameters())}")
Feed-Forward Networks: The Memory of Transformers
The FFN in each transformer layer is surprisingly important. Recent research suggests FFN layers store factual knowledge, while attention routes information.
┌─────────────────────────────────────────────────────────────────────────┐
│ FEED-FORWARD NETWORK (FFN) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ARCHITECTURE: │
│ ───────────── │
│ │
│ FFN(x) = Activation(x × W₁ + b₁) × W₂ + b₂ │
│ │
│ Where: │
│ • W₁ ∈ ℝ^(d_model × d_ff) (expand) │
│ • W₂ ∈ ℝ^(d_ff × d_model) (contract) │
│ • d_ff = 4 × d_model (typically) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ VISUAL: │
│ │
│ Input │
│ │ │
│ ▼ │
│ ┌────────────────┐ │
│ │ d_model = 512 │ │
│ └───────┬────────┘ │
│ │ │
│ ▼ W₁ (expand 4x) │
│ ┌────────────────────────────────────────┐ │
│ │ d_ff = 2048 │ │
│ │ (Each neuron is a "key-value" pair) │ │
│ └───────────────────┬────────────────────┘ │
│ │ │
│ ▼ Activation (ReLU/GELU/SwiGLU) │
│ ┌────────────────────────────────────────┐ │
│ │ d_ff = 2048 │ │
│ │ (Sparse activation pattern) │ │
│ └───────────────────┬────────────────────┘ │
│ │ │
│ ▼ W₂ (contract) │
│ ┌────────────────┐ │
│ │ d_model = 512 │ │
│ └────────────────┘ │
│ │ │
│ ▼ │
│ Output │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ACTIVATION FUNCTIONS: │
│ ───────────────────── │
│ │
│ ReLU (Original): max(0, x) │
│ Simple, fast, but "dead neurons" problem │
│ │
│ GELU (BERT, GPT-2): x × Φ(x) where Φ is Gaussian CDF │
│ Smooth, probabilistic interpretation │
│ │
│ SwiGLU (LLaMA, PaLM): Swish(xW₁) ⊙ (xV) │
│ Gated, best performance in practice │
│ But requires extra parameters (V matrix) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ FFN AS MEMORY: │
│ ────────────── │
│ │
│ Research shows FFN layers act as key-value memories: │
│ • W₁ rows are "keys" (patterns to match) │
│ • W₂ columns are "values" (outputs when key matches) │
│ • Activation creates sparse retrieval │
│ │
│ In RecSys: FFN might store "item category → related patterns" │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class FeedForward(nn.Module):
"""Standard feed-forward network with ReLU/GELU."""
def __init__(
self,
d_model: int,
d_ff: int = None,
dropout: float = 0.1,
activation: str = 'gelu',
):
super().__init__()
d_ff = d_ff or 4 * d_model
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'gelu':
self.activation = nn.GELU()
else:
raise ValueError(f"Unknown activation: {activation}")
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expand -> Activate -> Contract
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.linear2(x)
return x
class SwiGLUFeedForward(nn.Module):
"""
SwiGLU activation: Used in LLaMA, PaLM, modern transformers.
Generally performs better than GELU but has more parameters.
SwiGLU(x) = Swish(xW) ⊙ (xV)
where Swish(x) = x × sigmoid(x)
"""
def __init__(
self,
d_model: int,
d_ff: int = None,
dropout: float = 0.1,
):
super().__init__()
# d_ff is split between gate and value projections
# So effective expansion is 2/3 of standard FFN for same param count
d_ff = d_ff or int(4 * d_model * 2/3)
self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
self.up_proj = nn.Linear(d_model, d_ff, bias=False)
self.down_proj = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Gate path with Swish activation
gate = F.silu(self.gate_proj(x)) # Swish = x * sigmoid(x) = SiLU
# Up projection (no activation)
up = self.up_proj(x)
# Element-wise multiplication (gating)
hidden = gate * up
# Down projection
output = self.down_proj(self.dropout(hidden))
return output
# Parameter count comparison
def compare_ffn_variants():
"""Compare parameter counts and outputs of FFN variants."""
d_model = 512
d_ff = 2048
standard = FeedForward(d_model, d_ff, activation='gelu')
swiglu = SwiGLUFeedForward(d_model, int(d_ff * 2/3)) # Adjusted for fairness
standard_params = sum(p.numel() for p in standard.parameters())
swiglu_params = sum(p.numel() for p in swiglu.parameters())
print("FFN Variant Comparison")
print("=" * 50)
print(f"Standard (GELU) parameters: {standard_params:,}")
print(f"SwiGLU parameters: {swiglu_params:,}")
print(f"Ratio: {swiglu_params/standard_params:.2f}")
# Test forward pass
x = torch.randn(2, 10, d_model)
out_standard = standard(x)
out_swiglu = swiglu(x)
print(f"\nOutput shapes: {out_standard.shape} (both)")
Complete Transformer Block: Putting It All Together
Now we combine all components into a complete transformer block:
┌─────────────────────────────────────────────────────────────────────────┐
│ COMPLETE TRANSFORMER BLOCK │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Input (seq_len, d_model) │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ PRE-NORM BLOCK 1 │ │
│ │ │ │
│ │ ┌─────────┐ │ │
│ │ │LayerNorm│ │ │
│ │ └────┬────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ Multi-Head Self-Attention │ │ │
│ │ │ │ │ │
│ │ │ • num_heads heads, each with d_k = d_model/h │ │ │
│ │ │ • Causal mask for autoregressive │ │ │
│ │ │ • Padding mask for variable length │ │ │
│ │ │ │ │ │
│ │ └───────────────────────┬───────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────┐ │ │
│ │ │ Dropout │ │ │
│ │ └────┬────┘ │ │
│ │ │ │ │
│ └──────────────────────────┼───────────────────────────────────────┘ │
│ │ │
│ Input ───────────────────┼────────────────────┐ │
│ │ │ │
│ ▼ │ │
│ ┌───────┐ │ │
│ │ Add │◄───────────────┘ │
│ └───┬───┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ PRE-NORM BLOCK 2 │ │
│ │ │ │
│ │ ┌─────────┐ │ │
│ │ │LayerNorm│ │ │
│ │ └────┬────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ Feed-Forward Network │ │ │
│ │ │ │ │ │
│ │ │ • Expand: d_model → d_ff (4× typical) │ │ │
│ │ │ • Activation: GELU or SwiGLU │ │ │
│ │ │ • Contract: d_ff → d_model │ │ │
│ │ │ │ │ │
│ │ └───────────────────────┬───────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────┐ │ │
│ │ │ Dropout │ │ │
│ │ └────┬────┘ │ │
│ │ │ │ │
│ └──────────────────────────┼───────────────────────────────────────┘ │
│ │ │
│ (from Add) ──────────────┼────────────────────┐ │
│ │ │ │
│ ▼ │ │
│ ┌───────┐ │ │
│ │ Add │◄───────────────┘ │
│ └───┬───┘ │
│ │ │
│ ▼ │
│ Output (seq_len, d_model) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class TransformerBlock(nn.Module):
"""
Complete transformer block with all components.
Pre-norm architecture (modern standard).
"""
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int = None,
dropout: float = 0.1,
activation: str = 'gelu',
use_swiglu: bool = False,
):
super().__init__()
d_ff = d_ff or 4 * d_model
# Layer norms (pre-norm architecture)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Multi-head attention
self.attention = MultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
dropout=dropout,
)
# Feed-forward network
if use_swiglu:
self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)
else:
self.ffn = FeedForward(d_model, d_ff, dropout, activation)
# Dropout for residual connections
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
return_attention: bool = False,
) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, d_model) - input sequence
mask: (batch, seq_len, seq_len) - attention mask
"""
# Self-attention with pre-norm and residual
normed = self.norm1(x)
attn_output = self.attention(normed, normed, normed, mask)
x = x + self.dropout(attn_output)
# FFN with pre-norm and residual
normed = self.norm2(x)
ffn_output = self.ffn(normed)
x = x + self.dropout(ffn_output)
return x
class TransformerEncoder(nn.Module):
"""
Stack of transformer blocks for encoding sequences.
This is the core of SASRec, BERT4Rec, and similar models.
"""
def __init__(
self,
num_layers: int,
d_model: int,
num_heads: int,
d_ff: int = None,
dropout: float = 0.1,
activation: str = 'gelu',
):
super().__init__()
self.layers = nn.ModuleList([
TransformerBlock(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
activation=activation,
)
for _ in range(num_layers)
])
# Final layer norm (important for pre-norm architecture)
self.final_norm = nn.LayerNorm(d_model)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, d_model)
mask: (seq_len, seq_len) - causal mask
"""
for layer in self.layers:
x = layer(x, mask)
return self.final_norm(x)
# Complete demonstration
def build_complete_transformer():
"""Build and demonstrate a complete transformer encoder."""
# Hyperparameters
d_model = 256
num_heads = 8
num_layers = 6
d_ff = 1024
max_seq_len = 50
vocab_size = 10000 # Number of items
# Build model
class SequentialRecommender(nn.Module):
def __init__(self):
super().__init__()
# Embeddings
self.item_embedding = nn.Embedding(vocab_size + 1, d_model, padding_idx=0)
self.position_encoding = LearnedPositionalEncoding(d_model, max_seq_len)
# Transformer encoder
self.encoder = TransformerEncoder(
num_layers=num_layers,
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
)
# Output projection (predict next item)
self.output_proj = nn.Linear(d_model, vocab_size)
def forward(self, item_ids):
# item_ids: (batch, seq_len)
batch_size, seq_len = item_ids.shape
# Embed items
x = self.item_embedding(item_ids)
x = self.position_encoding(x)
# Create causal mask
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=item_ids.device),
diagonal=1
).bool()
# Encode
hidden = self.encoder(x, causal_mask)
# Predict next item
logits = self.output_proj(hidden)
return logits
model = SequentialRecommender()
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Complete Transformer for Sequential Recommendation")
print("=" * 60)
print(f"d_model: {d_model}")
print(f"num_heads: {num_heads}")
print(f"num_layers: {num_layers}")
print(f"d_ff: {d_ff}")
print(f"vocab_size: {vocab_size}")
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Test forward pass
batch_size = 4
seq_len = 20
item_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
logits = model(item_ids)
print(f"\nInput shape: {item_ids.shape}")
print(f"Output shape: {logits.shape}")
print(f"Output shape meaning: (batch={batch_size}, seq_len={seq_len}, vocab={vocab_size})")
return model
# Run
model = build_complete_transformer()
Part III: SASRec - The Foundation
Self-Attentive Sequential Recommendation
SASRec (Kang & McAuley, 2018) was the first successful application of self-attention to sequential recommendation. It's elegant and remains a strong baseline.
┌─────────────────────────────────────────────────────────────────────────┐
│ SASREC ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ INPUT LAYER: │
│ ──────────── │
│ │
│ User sequence: [item_1, item_2, item_3, ..., item_n] │
│ ↓ ↓ ↓ ↓ │
│ Item embeddings: e_1 e_2 e_3 ... e_n │
│ + + + + │
│ Position embeds: p_1 p_2 p_3 ... p_n │
│ ↓ ↓ ↓ ↓ │
│ [e_1+p_1, e_2+p_2, e_3+p_3, ..., e_n+p_n] │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TRANSFORMER BLOCKS (x L layers): │
│ ───────────────────────────────── │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Input → [Masked Self-Attention] → Add & Norm → [FFN] → Add & Norm│ │
│ │ (Causal mask: can │ │
│ │ only attend to past) │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ CAUSAL MASKING (Key design choice): │
│ ─────────────────────────────────── │
│ │
│ Attention Matrix What each position can see │
│ ┌───┬───┬───┬───┐ │
│ │ 1 │ 0 │ 0 │ 0 │ pos 1 → sees only item_1 │
│ ├───┼───┼───┼───┤ │
│ │ 1 │ 1 │ 0 │ 0 │ pos 2 → sees item_1, item_2 │
│ ├───┼───┼───┼───┤ │
│ │ 1 │ 1 │ 1 │ 0 │ pos 3 → sees item_1, item_2, item_3 │
│ ├───┼───┼───┼───┤ │
│ │ 1 │ 1 │ 1 │ 1 │ pos 4 → sees all items │
│ └───┴───┴───┴───┘ │
│ │
│ OUTPUT: Predict next item at each position │
│ ────────────────────────────────────────── │
│ │
│ pos 1 output → should predict item_2 │
│ pos 2 output → should predict item_3 │
│ pos n output → should predict item_{n+1} │
│ │
└─────────────────────────────────────────────────────────────────────────┘
SASRec Implementation
Before diving into the code, let's understand why SASRec works and the key design decisions that made it successful.
Why self-attention for sequential recommendation?
Prior to SASRec, sequential recommendation used RNNs (GRU4Rec) or CNNs (Caser). These approaches had fundamental limitations:
- RNNs: Process sequences left-to-right, one step at a time. The hidden state must compress all history into a fixed-size vector—a bottleneck. Also, sequential processing prevents parallelization.
- CNNs: Capture local patterns through fixed-size filters. To capture long-range dependencies, you need many stacked layers, and even then, distant items interact only indirectly.
Self-attention solves both problems:
- Direct access to any past item: When predicting what comes after item 10, attention can directly look at item 1, item 5, or any other—no information bottleneck.
- Parallel computation: All positions compute attention simultaneously. Training is much faster on GPUs.
- Adaptive receptive field: The model learns which items to attend to. For a "laptop bag" purchase, it might attend strongly to a recent "laptop" purchase from 3 items ago, while ignoring unrelated items.
The key architectural decisions:
- Causal (unidirectional) attention: Unlike BERT which sees the full sequence, SASRec only lets each position attend to past positions. This matches the recommendation task: when a user is on item 5, we can only use items 1-4 to predict item 6.
- Learnable position embeddings: Unlike the original Transformer's sinusoidal positions, SASRec learns position embeddings. With short sequences (50-200 items), there's enough data to learn positions, and learned embeddings often work better.
- Shared item embeddings for input and output: The same embedding matrix is used to (a) represent input items and (b) score candidate items for prediction. This parameter tying reduces overfitting and ensures consistency.
- Dropout everywhere: Applied to embeddings, attention weights, and FFN layers. Critical for preventing overfitting on sparse user-item data.
Training objective:
SASRec uses binary cross-entropy with negative sampling. For each position in the sequence:
- Positive: The actual next item
- Negatives: Randomly sampled items the user didn't interact with
The loss encourages the model to score true next items higher than random items:

L = − Σ_t [ log σ(h_t · e_pos) + log(1 − σ(h_t · e_neg)) ]

Where h_t is the hidden state at position t, e_pos is the embedding of the true next item, e_neg is the embedding of a sampled negative item, and σ is the sigmoid function.
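A minimal sketch of this objective (the function and its uniform negative sampler are illustrative, not taken from the paper's code): score the true next item and one sampled negative at every position, and apply binary cross-entropy to both.
import torch
import torch.nn.functional as F

def sasrec_bce_loss(hidden, item_embedding, pos_items, num_items):
    """
    hidden:         (B, L, D) outputs of SASRec.forward
    item_embedding: the model's nn.Embedding for items (index 0 = padding)
    pos_items:      (B, L) true next item at each position (0 = padding)
    """
    # One uniform negative per position (may occasionally collide with the
    # positive item; ignored in this sketch)
    neg_items = torch.randint(1, num_items + 1, pos_items.shape, device=pos_items.device)

    pos_logits = (hidden * item_embedding(pos_items)).sum(-1)   # h_t · e_pos
    neg_logits = (hidden * item_embedding(neg_items)).sum(-1)   # h_t · e_neg

    mask = (pos_items != 0).float()                              # skip padded positions
    loss = F.binary_cross_entropy_with_logits(
        pos_logits, torch.ones_like(pos_logits), reduction='none'
    ) + F.binary_cross_entropy_with_logits(
        neg_logits, torch.zeros_like(neg_logits), reduction='none'
    )
    return (loss * mask).sum() / mask.sum()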
Why SASRec remains a strong baseline:
Despite being published in 2018, SASRec consistently appears in top results:
- Simple architecture with few hyperparameters
- Efficient training (hours, not days)
- Works across domains (e-commerce, streaming, news)
- Easy to extend with additional features
Many "improvements" over SASRec fail to consistently outperform it when properly tuned.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SASRec(nn.Module):
"""
Self-Attentive Sequential Recommendation
Paper: https://arxiv.org/abs/1808.09781
"""
def __init__(
self,
num_items: int,
max_seq_len: int = 50,
hidden_dim: int = 64,
num_heads: int = 2,
num_layers: int = 2,
dropout: float = 0.2,
):
super().__init__()
self.num_items = num_items
self.max_seq_len = max_seq_len
self.hidden_dim = hidden_dim
# Item embeddings (0 is padding)
self.item_embedding = nn.Embedding(
num_items + 1, # +1 for padding token
hidden_dim,
padding_idx=0
)
# Learnable position embeddings
self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
# Transformer encoder layers
self.layers = nn.ModuleList([
SASRecBlock(hidden_dim, num_heads, dropout)
for _ in range(num_layers)
])
# Layer normalization
self.layer_norm = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, item_seq: torch.Tensor) -> torch.Tensor:
"""
Args:
item_seq: (batch_size, seq_len) - sequence of item IDs
Returns:
(batch_size, seq_len, hidden_dim) - sequence representations
"""
batch_size, seq_len = item_seq.shape
# Create position indices
positions = torch.arange(seq_len, device=item_seq.device)
positions = positions.unsqueeze(0).expand(batch_size, -1)
# Embed items and positions
item_emb = self.item_embedding(item_seq) # (B, L, D)
pos_emb = self.position_embedding(positions) # (B, L, D)
# Combine embeddings
hidden = self.dropout(item_emb + pos_emb)
# Create causal attention mask
# True = masked (cannot attend), False = can attend
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=item_seq.device),
diagonal=1
).bool()
# Create padding mask
padding_mask = (item_seq == 0) # (B, L)
# Apply transformer layers
for layer in self.layers:
hidden = layer(hidden, causal_mask, padding_mask)
hidden = self.layer_norm(hidden)
return hidden
def predict(self, item_seq: torch.Tensor) -> torch.Tensor:
"""
Predict scores for all items given a sequence.
Returns:
(batch_size, num_items) - scores for each item
"""
hidden = self.forward(item_seq) # (B, L, D)
# Use last position's representation
final_hidden = hidden[:, -1, :] # (B, D)
# Score all items via dot product with item embeddings
# Exclude padding embedding (index 0)
item_embeddings = self.item_embedding.weight[1:] # (num_items, D)
scores = torch.matmul(final_hidden, item_embeddings.T) # (B, num_items)
return scores
class SASRecBlock(nn.Module):
"""Single transformer block for SASRec."""
def __init__(self, hidden_dim: int, num_heads: int, dropout: float):
super().__init__()
self.attention = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=dropout, batch_first=True
)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 4, hidden_dim),
nn.Dropout(dropout),
)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(
self,
hidden: torch.Tensor,
causal_mask: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
# Self-attention with residual
normed = self.norm1(hidden)
attn_out, _ = self.attention(
normed, normed, normed,
attn_mask=causal_mask,
key_padding_mask=padding_mask,
)
hidden = hidden + self.dropout(attn_out)
# FFN with residual
normed = self.norm2(hidden)
ffn_out = self.ffn(normed)
hidden = hidden + ffn_out
return hidden
# Training example
def train_sasrec():
"""Example training loop for SASRec."""
# Model setup
model = SASRec(
num_items=10000,
max_seq_len=50,
hidden_dim=64,
num_heads=2,
num_layers=2,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Training with "shifted sequence" objective
# Input: [item_1, item_2, ..., item_n]
# Target: [item_2, item_3, ..., item_{n+1}]
for epoch in range(100):
for batch in dataloader: # Your data loader
input_seq = batch['input_seq'] # (B, L)
target_seq = batch['target_seq'] # (B, L)
# Forward pass
hidden = model(input_seq) # (B, L, D)
# Compute scores for all items at each position
item_emb = model.item_embedding.weight[1:] # (num_items, D)
logits = torch.matmul(hidden, item_emb.T) # (B, L, num_items)
# Cross-entropy loss
loss = F.cross_entropy(
logits.view(-1, model.num_items),
target_seq.view(-1),
ignore_index=0, # Ignore padding
)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
Why Causal Masking?
SASRec uses causal (unidirectional) attention—each position can only attend to previous positions. This is crucial for two reasons:
-
Training efficiency: With causal masking, we can train on the entire sequence at once using the "shifted sequence" objective. Position $t$ predicts item $i_{t+1}$, using only items $i_1, \dots, i_t$ (a small helper below shows how these shifted pairs are built).
-
Inference correctness: At inference time, we only have past interactions. The model shouldn't learn to rely on future information that won't be available.
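For concreteness, here is one way to construct these shifted (input, target) pairs from a raw user history (a small sketch; padding and truncation conventions vary by implementation):

```python
import torch

def make_training_pair(history: list, max_seq_len: int = 50):
    """
    history: chronologically ordered item IDs, e.g. [12, 7, 99, 4, ...]
    Returns left-padded (input_seq, target_seq), each of length max_seq_len,
    where target_seq[t] is the item that followed input_seq[t] (0 = padding).
    """
    input_ids = history[:-1][-max_seq_len:]   # items 1..n-1
    target_ids = history[1:][-max_seq_len:]   # items 2..n
    pad = max_seq_len - len(input_ids)
    input_seq = torch.tensor([0] * pad + input_ids)
    target_seq = torch.tensor([0] * pad + target_ids)
    return input_seq, target_seq

# Example: history [5, 9, 2, 7] -> input [..., 5, 9, 2], target [..., 9, 2, 7]
```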
Part IV: BERT4Rec - Bidirectional Attention
The Bidirectional Argument
BERT4Rec (Sun et al., 2019) argued that unidirectional attention limits the model's ability to learn representations. In NLP, BERT's bidirectional training dramatically outperformed GPT-style models on many tasks. Could the same apply to recommendations?
┌─────────────────────────────────────────────────────────────────────────┐
│ SASREC vs BERT4REC: ATTENTION PATTERNS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ SASREC (Unidirectional/Causal): │
│ ───────────────────────────────── │
│ │
│ Sequence: [shoes, socks, shorts, shirt] │
│ │
│ When representing "socks": │
│ - Can see: shoes, socks │
│ - Cannot see: shorts, shirt │
│ │
│ "Socks" representation knows about previous items only │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ BERT4REC (Bidirectional): │
│ ────────────────────────── │
│ │
│ Sequence: [shoes, [MASK], shorts, shirt] │
│ │
│ When predicting [MASK] (was "socks"): │
│ - Can see: shoes, shorts, shirt │
│ - Must predict: socks │
│ │
│ Prediction uses FULL context (past AND future) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ THE TRADE-OFF: │
│ │
│ SASRec advantages: │
│ + Natural fit for sequential prediction │
│ + Efficient training (predict all positions) │
│ + Simpler inference (just use last position) │
│ │
│ BERT4Rec advantages: │
│ + Richer representations (bidirectional context) │
│ + Better modeling of item relationships │
│ + Can leverage future context during training │
│ │
│ BERT4Rec disadvantages: │
│ - Training is less efficient (only predict masked items) │
│ - Inference is different from training (no masking) │
│ - More hyperparameters (mask ratio) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
BERT4Rec Implementation
BERT4Rec brought the revolutionary BERT pretraining paradigm from NLP to recommendations. Understanding its mechanics reveals both its power and its limitations.
The core idea: Masked Language Modeling for items
Instead of predicting the next item given previous items (like SASRec), BERT4Rec randomly masks items in the sequence and predicts them from context—both left AND right:
SASRec (Causal): [A, B, C, D, ?] → predict E using A,B,C,D
BERT4Rec (Masked): [A, [MASK], C, D, E] → predict B using A,C,D,E
This bidirectional context is powerful: to predict item B, the model can see what came before (A) AND after (C, D, E). In real user behavior, items often relate to both past and future actions.
Why bidirectional context might help:
Consider a user sequence: [Laptop → ??? → Laptop Bag → USB Hub → Monitor Stand]
- Unidirectional (SASRec): To predict the masked item, only sees "Laptop"
- Bidirectional (BERT4Rec): Sees "Laptop" AND "Laptop Bag, USB Hub, Monitor Stand"
The bidirectional model can infer the masked item is likely laptop-related (maybe "Laptop Case" or "Mouse") because it sees the surrounding accessory purchases.
The training/inference mismatch problem:
BERT4Rec has a fundamental issue that SASRec doesn't:
- Training: Model sees sequences with [MASK] tokens, predicts masked items
- Inference: Model sees sequences WITHOUT masks, must predict next item
This mismatch means the model is trained on a different distribution than it's tested on. Various solutions exist:
- Always append [MASK] at the end during inference
- Use the last position's hidden state (ignoring the masking objective)
- Fine-tune on next-item prediction after pretraining
The mask ratio hyperparameter:
BERT4Rec introduces a new hyperparameter: what fraction of items to mask? The original paper uses 20%, but optimal values vary by dataset:
- Too low (5%): Not enough training signal per sequence
- Too high (50%): Too much context removed, predictions become random
Bidirectional attention = No causal mask:
Unlike SASRec which uses a triangular attention mask (each position sees only past), BERT4Rec uses full attention—every position can attend to every other position (except padding). This is the source of both its power (richer context) and its limitation (can't do autoregressive generation).
class BERT4Rec(nn.Module):
"""
Bidirectional Encoder Representations from Transformers for Recommendation
Paper: https://arxiv.org/abs/1904.06690
Key difference from SASRec: Uses masked language modeling objective
instead of causal (next-item) prediction.
"""
def __init__(
self,
num_items: int,
max_seq_len: int = 50,
hidden_dim: int = 64,
num_heads: int = 2,
num_layers: int = 2,
dropout: float = 0.2,
mask_prob: float = 0.2, # Probability of masking each item
):
super().__init__()
self.num_items = num_items
self.max_seq_len = max_seq_len
self.mask_prob = mask_prob
self.mask_token = num_items + 1 # Special [MASK] token
# Item embeddings: 0=pad, 1...num_items=items, num_items+1=[MASK]
self.item_embedding = nn.Embedding(
num_items + 2, # +1 for padding, +1 for mask
hidden_dim,
padding_idx=0
)
self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
# BERT uses bidirectional attention (no causal mask)
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
activation='gelu',
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
self.layer_norm = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
# Output projection over the item vocabulary
# (+1 so that logit index i corresponds to item ID i; index 0 = padding, never a target)
self.output_proj = nn.Linear(hidden_dim, num_items + 1)
def forward(
self,
item_seq: torch.Tensor,
masked_positions: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
item_seq: (batch_size, seq_len) - may contain [MASK] tokens
masked_positions: (batch_size, num_masks) - positions to predict
Returns:
logits: (batch_size, num_masks, num_items + 1) if masked_positions given
else (batch_size, seq_len, num_items + 1)
"""
batch_size, seq_len = item_seq.shape
# Position indices
positions = torch.arange(seq_len, device=item_seq.device)
positions = positions.unsqueeze(0).expand(batch_size, -1)
# Embeddings
item_emb = self.item_embedding(item_seq)
pos_emb = self.position_embedding(positions)
hidden = self.dropout(item_emb + pos_emb)
# Padding mask only (no causal mask - bidirectional!)
padding_mask = (item_seq == 0)
# Transformer (bidirectional attention)
hidden = self.transformer(hidden, src_key_padding_mask=padding_mask)
hidden = self.layer_norm(hidden)
# If specific positions requested, extract those
if masked_positions is not None:
# Gather hidden states at masked positions
batch_indices = torch.arange(batch_size, device=item_seq.device)
batch_indices = batch_indices.unsqueeze(1).expand(-1, masked_positions.shape[1])
hidden = hidden[batch_indices, masked_positions] # (B, num_masks, D)
# Project to item vocabulary
logits = self.output_proj(hidden)
return logits
def mask_sequence(self, item_seq: torch.Tensor):
"""
Apply masking for training (Cloze task).
Returns:
masked_seq: sequence with some items replaced by [MASK]
labels: original items at masked positions (-100 elsewhere)
masked_positions: indices of masked positions
"""
batch_size, seq_len = item_seq.shape
# Create mask: which positions to mask
mask = torch.rand(batch_size, seq_len, device=item_seq.device) < self.mask_prob
mask = mask & (item_seq != 0) # Don't mask padding
# Create masked sequence
masked_seq = item_seq.clone()
masked_seq[mask] = self.mask_token
# Create labels (-100 for non-masked, item_id for masked)
labels = torch.full_like(item_seq, -100)
labels[mask] = item_seq[mask]
return masked_seq, labels, mask
def predict_next(self, item_seq: torch.Tensor) -> torch.Tensor:
"""
For inference: Append [MASK] and predict it.
"""
batch_size = item_seq.shape[0]
# Append [MASK] token (assumes seq_len < max_seq_len so the new position still has a position embedding)
mask_token = torch.full((batch_size, 1), self.mask_token, device=item_seq.device)
seq_with_mask = torch.cat([item_seq, mask_token], dim=1)
# Predict at mask position (last position)
logits = self.forward(seq_with_mask)
return logits[:, -1, :] # (B, num_items + 1); column 0 (padding) is never a real item
def train_bert4rec():
"""Training loop for BERT4Rec."""
model = BERT4Rec(
num_items=10000,
max_seq_len=50,
hidden_dim=64,
num_heads=2,
num_layers=2,
mask_prob=0.2,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(100):
for batch in dataloader:
item_seq = batch['item_seq'] # (B, L)
# Apply random masking
masked_seq, labels, mask = model.mask_sequence(item_seq)
# Forward pass
logits = model(masked_seq) # (B, L, num_items + 1)
# Loss only on masked positions
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=-100,
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
The Verdict: SASRec vs BERT4Rec
Despite BERT4Rec's intuitive appeal, the evidence is mixed. A systematic replication study found:
"BERT4Rec is not consistently superior compared to SASRec in the published literature. The reason might be that BERT4Rec is better on some datasets, but SASRec is better on others."
A 2023 RecSys paper ("Turning Dross Into Gold Loss") showed that with proper tuning, SASRec can match or exceed BERT4Rec while being simpler and more efficient.
Practical recommendation: Start with SASRec. It's simpler, faster to train, and has a more natural inference setup. Try BERT4Rec if you have specific evidence it helps on your data.
Part V: Advanced Architectures
Transformers4Rec: Production-Ready Framework
NVIDIA's Transformers4Rec bridges NLP transformers and RecSys. It allows using HuggingFace transformer architectures directly for recommendation.
from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt
# Define schema
schema = tr.Schema(
features=[
tr.Feature("item_id", tr.CATEGORICAL, num_items),
tr.Feature("category", tr.CATEGORICAL, num_categories),
tr.Feature("price", tr.CONTINUOUS),
],
targets=tr.Target("next_item", tr.CATEGORICAL, num_items),
)
# Build model with XLNet architecture
input_module = tr.TabularSequenceFeatures.from_schema(
schema,
max_sequence_length=20,
aggregation="concat",
)
transformer_config = tr.XLNetConfig.build(
d_model=64,
n_head=4,
n_layer=2,
)
body = tr.SequentialBlock(
input_module,
tr.MLPBlock([64]),
tr.TransformerBlock(transformer_config),
)
head = tr.Head(
body,
tr.NextItemPredictionTask(weight_tying=True),
inputs=input_module,
)
model = tr.Model(head)
Key features:
- Multiple transformer architectures: GPT-2, XLNet, BERT, Longformer, etc.
- Side information integration: Item features, user features, context
- Production-optimized: Integration with NVIDIA Triton for serving
HSTU: Meta's Trillion-Parameter Approach
HSTU (Hierarchical Sequential Transduction Units) represents the current state of the art for large-scale sequential recommendation. Published at ICML 2024 by Meta AI.
┌─────────────────────────────────────────────────────────────────────────┐
│ HSTU KEY INNOVATIONS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. GENERATIVE FRAMING │
│ ──────────────────────── │
│ │
│ Traditional: Separate retrieval and ranking models │
│ HSTU: Single generative model that "generates" next interactions │
│ │
│ Instead of: "Score these 1000 candidates" │
│ Do: "What will user do next?" → Generates item distribution │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 2. POINTWISE (NO SOFTMAX) ATTENTION │
│ ───────────────────────────────────── │
│ │
│ Standard Transformer: │
│ attention = softmax(QK^T / √d) × V │
│ │
│ HSTU: │
│ attention = φ(Q) × φ(K)^T × V (pointwise, no softmax) │
│ │
│ Why remove softmax? │
│ - Preserves preference intensity (softmax normalizes away) │
│ - Better for non-stationary vocabularies (new items daily) │
│ - Enables linear-time approximations │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 3. HIERARCHICAL STRUCTURE │
│ ────────────────────────── │
│ │
│ Actions and items interleaved in single sequence: │
│ │
│ [item_1, click, item_2, purchase, item_3, view, item_4, ...] │
│ │
│ This captures: │
│ - What items user interacted with │
│ - How they interacted (click vs purchase vs view) │
│ - Temporal ordering of both │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 4. SCALE & RESULTS │
│ ─────────────────── │
│ │
│ - Scales to TRILLION parameters │
│ - 12.4%+ improvement in production metrics at Meta │
│ - First demonstration of scaling laws in industrial RecSys │
│ - 10-1000x efficiency gains via M-FALCON inference │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Why HSTU represents a paradigm shift:
HSTU isn't just a bigger transformer—it fundamentally rethinks how transformers should work for recommendations. Here's why each innovation matters:
1. Removing softmax: The counterintuitive insight
Standard attention uses softmax to create a probability distribution over keys. But this has a problem for recommendations:
Compare a heavily engaged user with raw attention scores [Item A: 10, Item B: 10, Item C: 10] to a lightly engaged user with scores [Item A: 1, Item B: 1, Item C: 1].
After softmax, both users get identical weights:
- Item A: ~33%
- Item B: ~33%
- Item C: ~33%
The absolute intensity of engagement is normalized away; only the relative weights within a sequence survive.
HSTU's pointwise attention preserves these intensity differences. If a user clicked item A 100 times and item C once, that 100:1 ratio matters.
2. Actions as first-class citizens
Most sequential models treat sequences as: [item₁, item₂, item₃, ...]
But user behavior is richer. HSTU models: [item₁, click, item₂, purchase, item₃, view, ...]
This captures that a purchase is more informative than a view. The model learns that a sequence ending with "purchase" signals different intent than one ending with "view."
3. The scaling law discovery
HSTU demonstrated, for the first time, that scaling laws exist in recommendation—just like in LLMs. More parameters consistently improve performance:
| Parameters | Relative Improvement |
|---|---|
| 100M | baseline |
| 1B | +3.2% |
| 10B | +6.8% |
| 100B | +9.1% |
| 1T | +12.4% |
This validates the investment in large recommendation models and suggests we're far from diminishing returns.
4. M-FALCON: Making trillion-parameter serving feasible
Training a 1T-parameter model is hard; serving it is harder. M-FALCON, Meta's micro-batched inference algorithm for HSTU, enables efficient serving by:
- Scoring candidate items in micro-batches rather than one at a time
- Caching and reusing the user sequence's attention key-value state across candidates
- Amortizing the cost of encoding the history over many candidate scores
The result: 10-1000x efficiency gains, making trillion-parameter models practical for real-time serving.
# Simplified HSTU-style attention (conceptual)
class PointwiseAttention(nn.Module):
"""
HSTU-style attention without softmax normalization.
Preserves preference intensity.
"""
def __init__(self, hidden_dim: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.out_proj = nn.Linear(hidden_dim, hidden_dim)
# Feature map (e.g., ELU + 1 for positivity)
self.feature_map = lambda x: F.elu(x) + 1
def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
B, L, D = x.shape
Q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim)
K = self.k_proj(x).view(B, L, self.num_heads, self.head_dim)
V = self.v_proj(x).view(B, L, self.num_heads, self.head_dim)
# Apply feature map (makes attention linear-time possible)
Q = self.feature_map(Q)
K = self.feature_map(K)
if causal:
# Causal linear attention via cumulative sum
# O(L) instead of O(L^2)
KV = torch.einsum('blhd,blhe->blhde', K, V)
KV_cumsum = torch.cumsum(KV, dim=1)
output = torch.einsum('blhd,blhde->blhe', Q, KV_cumsum)
# Normalize by cumulative key sum
K_cumsum = torch.cumsum(K, dim=1)
normalizer = torch.einsum('blhd,blhd->blh', Q, K_cumsum)
output = output / (normalizer.unsqueeze(-1) + 1e-6)
else:
# Non-causal: standard (but still pointwise)
attn = torch.einsum('blhd,bmhd->blmh', Q, K)
output = torch.einsum('blmh,bmhd->blhd', attn, V)
output = output.reshape(B, L, D)
return self.out_proj(output)
LiGR: LinkedIn's Generative Ranking Architecture
LiGR (LinkedIn Generative Ranking) is a production transformer architecture that powers recommendations at LinkedIn. It introduces key innovations that address limitations of standard transformers for ranking.
┌─────────────────────────────────────────────────────────────────────────┐
│ LiGR ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ PROBLEM WITH STANDARD TRANSFORMERS: │
│ ───────────────────────────────────── │
│ │
│ Deep residual stacking can "forget" distant layer information │
│ Standard: output = LayerNorm(x + Attention(x)) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ LiGR SOLUTION: Recurrent Gating Dynamics │
│ ───────────────────────────────────────── │
│ │
│ • Parameterized gates selectively preserve aggregated information │
│ • Improved gradient propagation and training stability │
│ • Enables deep stacking (24+ layers) without degradation │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ KEY FEATURES: │
│ │
│ 1. Historical Attention: Aggregates user behavior features │
│ 2. In-Session Attention: Models current session context │
│ 3. Set-wise Scoring: Jointly scores items for automatic diversity │
│ 4. Learned Normalization: Adapts to recommendation-specific patterns │
│ │
│ RESULTS AT LINKEDIN: │
│ • Deprecated most manual feature engineering │
│ • Outperforms systems using hundreds of features with only a few │
│ • Validated scaling laws for ranking systems │
│ • Automatic diversity via simultaneous set-wise scoring │
│ │
└─────────────────────────────────────────────────────────────────────────┘
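Written out, the gated residual replaces the plain x + f(x) update with a learned interpolation (the notation below is mine, mirroring the LiGRBlock sketch later in this post):

$$g = \sigma\big(W_g\,[\,x ;\, f(x)\,]\big), \qquad \text{output} = g \odot x + (1 - g) \odot f(x)$$

where $f$ is the attention or FFN sub-layer and $[\cdot\,;\cdot]$ denotes concatenation. When the gate saturates near 1, the layer passes its input through unchanged, which is what keeps very deep stacks trainable.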
eSASRec: Enhanced SASRec (RecSys 2025)
eSASRec (ACM RecSys 2025) systematically studies what makes transformer-based recommendation work. The key finding: combining the right components matters more than novel architectures.
The research question eSASRec answers:
Every year, new sequential recommendation models claim to beat SASRec. But do they really? Or are they just using better training tricks? eSASRec decomposes the improvements into their sources:
Claimed improvement sources in literature:
─────────────────────────────────────────────────────────────────────────
- Novel attention mechanisms (local, sparse, linear)
- Better positional encodings (relative, rotary)
- Improved architectures (gating, dense layers)
- Training objectives (contrastive, generative)
Actual improvement sources (eSASRec findings):
─────────────────────────────────────────────────────────────────────────
- Loss function: BCE → Sampled Softmax (+3-5%)
- Layer design: Vanilla → LiGR-style gating (+2-3%)
- Regularization: Proper dropout tuning (+1-2%)
- The "novel" architectures? Often 0% improvement when controlled properly
The three components that actually matter:
-
SASRec's training objective (next-item prediction with causal masking): Despite BERT4Rec's bidirectional attention, unidirectional training remains best for online serving.
-
LiGR-style transformer layers (gated residuals): The gating mechanism from LiGR helps gradient flow in deeper networks. Standard residual connections (x + attention(x)) work fine for 2-3 layers but degrade at 6+ layers.
-
Sampled softmax loss: Full softmax over millions of items is computationally prohibitive. Sampled softmax with hard negatives achieves comparable quality at fraction of the cost.
Why this matters for practitioners:
You don't need the latest fancy architecture. Take SASRec, add sampled softmax, use LiGR-style layers, tune your dropout—you'll match or beat most published "state-of-the-art" results. The recipe is:
- Architecture: Standard transformer with gated residuals
- Training: Causal masking, sampled softmax (256-1024 negatives)
- Regularization: Dropout 0.2-0.5 depending on data sparsity
- Depth: 2-4 layers (diminishing returns beyond)
# eSASRec: The winning combination
class ESASRec(nn.Module):
"""
Enhanced SASRec = SASRec objective + LiGR layers + Sampled Softmax
Paper: https://arxiv.org/abs/2508.06450
"""
def __init__(
self,
num_items: int,
hidden_dim: int = 128,
num_layers: int = 2,
num_heads: int = 4,
num_negatives: int = 256, # For sampled softmax
):
super().__init__()
# Item embeddings
self.item_embedding = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)
self.position_embedding = nn.Embedding(50, hidden_dim)
# LiGR Transformer layers (key difference from vanilla SASRec)
self.layers = nn.ModuleList([
LiGRBlock(hidden_dim, num_heads)
for _ in range(num_layers)
])
self.num_negatives = num_negatives
self.layer_norm = nn.LayerNorm(hidden_dim)
def forward(self, item_seq):
# Standard embedding + position
seq_emb = self.item_embedding(item_seq)
positions = torch.arange(item_seq.shape[1], device=item_seq.device)
seq_emb = seq_emb + self.position_embedding(positions)
# Causal mask (on the same device as the inputs)
mask = torch.triu(
torch.ones(item_seq.shape[1], item_seq.shape[1], device=item_seq.device),
diagonal=1
).bool()
# LiGR transformer layers
hidden = seq_emb
for layer in self.layers:
hidden = layer(hidden, mask)
return self.layer_norm(hidden)
class LiGRBlock(nn.Module):
"""
LiGR Transformer block with gated residuals.
Preserves information better across deep networks.
"""
def __init__(self, hidden_dim: int, num_heads: int):
super().__init__()
self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Linear(hidden_dim * 4, hidden_dim),
)
# Gating mechanism (key LiGR innovation)
self.gate_attn = nn.Linear(hidden_dim * 2, hidden_dim)
self.gate_ffn = nn.Linear(hidden_dim * 2, hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
def forward(self, x, mask):
# Attention with gated residual
normed = self.norm1(x)
attn_out, _ = self.attention(normed, normed, normed, attn_mask=mask)
# Gated combination (not just addition)
gate = torch.sigmoid(self.gate_attn(torch.cat([x, attn_out], dim=-1)))
x = x * gate + attn_out * (1 - gate)
# FFN with gated residual
normed = self.norm2(x)
ffn_out = self.ffn(normed)
gate = torch.sigmoid(self.gate_ffn(torch.cat([x, ffn_out], dim=-1)))
x = x * gate + ffn_out * (1 - gate)
return x
Key findings from the eSASRec study:
| Component | Options Tested | Winner |
|---|---|---|
| Training Objective | Next-item, Masked, Contrastive | Next-item (SASRec-style) |
| Transformer Architecture | Vanilla, LiGR, HSTU | LiGR |
| Loss Function | Full softmax, BCE, Sampled Softmax | Sampled Softmax |
| Negative Sampling | Uniform, Popularity, In-batch | Popularity-weighted |
Results: eSASRec achieves 23% improvement over previous SOTA (ActionPiece) and resides on the Pareto frontier alongside HSTU for accuracy-coverage trade-off. Importantly, it requires no extra features (unlike HSTU which needs timestamps).
Part VI: Beyond Attention - State Space Models
Mamba4Rec: SSMs for Sequential Recommendation
Mamba4Rec (Best Paper Award, RelKD@KDD 2024) applies Selective State Space Models to sequential recommendation, achieving transformer-level accuracy with linear complexity.
┌─────────────────────────────────────────────────────────────────────────┐
│ ATTENTION vs STATE SPACE MODELS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ TRANSFORMER ATTENTION: │
│ ─────────────────────── │
│ │
│ Complexity: O(n²) in sequence length │
│ Memory: O(n²) for attention matrix │
│ Parallelization: Excellent (all positions computed together) │
│ │
│ Problem: Quadratic cost prohibitive for long sequences (1000+ items) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ STATE SPACE MODELS (Mamba): │
│ ──────────────────────────── │
│ │
│ Complexity: O(n) in sequence length │
│ Memory: O(n) - no attention matrix │
│ Parallelization: Good (via parallel scan) │
│ │
│ Key Innovation: Selective state transitions │
│ - Content-aware gating of what to remember │
│ - Learned dynamics for sequence modeling │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ WHY SSMs FOR RECOMMENDATIONS: │
│ │
│ • Users with 1000+ interactions are common │
│ • Linear scaling enables full history modeling │
│ • Selective mechanism learns what's relevant │
│ • Efficient inference for real-time serving │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Understanding why Mamba works for recommendations:
State Space Models (SSMs) seem like a step backward—they process sequences left-to-right like RNNs, while transformers can see everything at once. So why do they work?
The key insight: Selective state transitions
Unlike traditional RNNs that blindly compress all history into a fixed-size state, Mamba selectively decides what to remember:
Traditional RNN:
h_t = tanh(W_h * h_{t-1} + W_x * x_t)
↑ Fixed update rule, can't prioritize
Mamba (Selective SSM):
Δ_t = softplus(Linear(x_t)) ← Input-dependent discretization
B_t = Linear(x_t) ← Input-dependent input matrix
C_t = Linear(x_t) ← Input-dependent output matrix
h_t = exp(-Δ_t * A) * h_{t-1} + Δ_t * B_t * x_t
↑ How much to forget ↑ How much to incorporate
The selectivity means:
- A boring "filler" item (user just browsing): Small Δ → retain previous state, mostly ignore this item
- An important signal (user made a purchase): Large Δ → reset state, heavily weight this item
Mamba4Rec's specific innovations:
-
Bidirectional Mamba: Unlike language models that only go left-to-right, Mamba4Rec processes sequences in both directions and combines the states. This captures patterns like "users who will buy X often first browse Y."
-
Position embeddings preserved: Despite being an SSM, Mamba4Rec adds positional encodings—important because position matters differently in recommendations than in language.
-
Item embedding sharing: Same embedding matrix for input and output, like SASRec.
When to use Mamba4Rec over transformers:
- Long sequences (500+ items): Quadratic attention becomes expensive
- Memory-constrained serving: No O(n²) attention matrix
- Streaming scenarios: Can update state incrementally without recomputing
When to stick with transformers:
- Short sequences (<100 items): Quadratic cost is negligible
- Need for interpretability: Attention weights are interpretable; SSM states are not
- Existing infrastructure: Most production systems are built around transformers
# Mamba4Rec architecture (simplified)
class Mamba4Rec(nn.Module):
"""
Sequential recommendation with Selective State Space Models.
Paper: https://arxiv.org/abs/2403.03900
"""
def __init__(
self,
num_items: int,
hidden_dim: int = 64,
state_dim: int = 16,
num_layers: int = 2,
):
super().__init__()
self.item_embedding = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)
self.position_embedding = nn.Embedding(200, hidden_dim)
# Mamba layers instead of transformer
self.layers = nn.ModuleList([
MambaBlock(hidden_dim, state_dim)
for _ in range(num_layers)
])
self.layer_norm = nn.LayerNorm(hidden_dim)
def forward(self, item_seq):
B, L = item_seq.shape
# Embeddings
seq_emb = self.item_embedding(item_seq)
positions = torch.arange(L, device=item_seq.device)
hidden = seq_emb + self.position_embedding(positions)
# Mamba layers (O(L) instead of O(L²))
for layer in self.layers:
hidden = layer(hidden)
return self.layer_norm(hidden)
class MambaBlock(nn.Module):
"""
Selective State Space Model block.
Core idea: Learn what information to retain in hidden state.
"""
def __init__(self, hidden_dim: int, state_dim: int):
super().__init__()
self.hidden_dim = hidden_dim
self.state_dim = state_dim
# Input projections
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 2)
# SSM parameters (learned per-position)
self.x_proj = nn.Linear(hidden_dim, state_dim * 2) # B, C matrices
self.dt_proj = nn.Linear(hidden_dim, hidden_dim) # Δ (timestep)
# State matrix A (fixed, learned)
self.A = nn.Parameter(torch.randn(hidden_dim, state_dim))
# Output projection
self.out_proj = nn.Linear(hidden_dim, hidden_dim)
self.norm = nn.LayerNorm(hidden_dim)
def forward(self, x):
B, L, D = x.shape
residual = x
# Project input
xz = self.in_proj(x)
x_ssm, z = xz.chunk(2, dim=-1)
# Compute selective (input-dependent) parameters
x_dbl = self.x_proj(x_ssm) # B_t, C_t matrices (unused in the simplified scan below)
delta = F.softplus(self.dt_proj(x_ssm)) # Δ: input-dependent step size
# Discretize and run SSM (simplified)
# Full implementation uses parallel scan for efficiency
A = -torch.exp(self.A)
# Selective state space computation
# This is where the "selection" happens - delta modulates
# how much of each input affects the hidden state
out = self._ssm_scan(x_ssm, A, delta)
# Gated output
out = out * F.silu(z)
out = self.out_proj(out)
return self.norm(residual + out)
def _ssm_scan(self, x, A, delta):
# Simplified selective scan
# Real implementation uses CUDA kernels for efficiency
B, L, D = x.shape
h = torch.zeros(B, D, self.state_dim, device=x.device)
outputs = []
for t in range(L):
# delta[:, t, :, None] has shape (B, D, 1): per-feature decay rate at step t
h = h * torch.exp(A * delta[:, t, :, None]) + x[:, t, :, None]
outputs.append(h.sum(dim=-1))
return torch.stack(outputs, dim=1)
Mamba4Rec Results:
- Matches or exceeds SASRec/BERT4Rec accuracy
- 10x faster inference on sequences >500 items
- Lower memory footprint enables longer histories
Related SSM Models (2024-2025):
- EchoMamba4Rec: Bidirectional SSM with spectral filtering
- MaTrRec: Hybrid Mamba + Transformer
- SSD4Rec: Structured State Space Duality for efficiency
Part VII: Generative Retrieval with Semantic IDs
TIGER: Treating Recommendation as Generation
TIGER (Transformer Index for GEnerative Recommenders) introduces a paradigm shift: instead of scoring items, generate item identifiers token-by-token.
┌─────────────────────────────────────────────────────────────────────────┐
│ TRADITIONAL vs GENERATIVE RETRIEVAL │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ TRADITIONAL (Two-Tower): │
│ ───────────────────────── │
│ │
│ User History → User Encoder → User Embedding │
│ ↓ │
│ Dot Product → Top-K items │
│ ↑ │
│ Item Features → Item Encoder → Item Embeddings (pre-computed) │
│ │
│ Requires: ANN index, separate item encoding, retrieval infrastructure │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ GENERATIVE (TIGER): │
│ ───────────────────── │
│ │
│ User History → Transformer → Generate Semantic ID tokens │
│ │
│ [item₁, item₂, item₃] → Decoder → [tok₁, tok₂, tok₃, tok₄] │
│ ↓ │
│ Map to Item ID │
│ │
│ Benefits: │
│ • Unified model (no separate retrieval infrastructure) │
│ • Similar items share token prefixes │
│ • Beam search explores item space efficiently │
│ • Natural handling of semantic similarity │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Semantic ID Tokenization
The key innovation: represent each item as a sequence of learned discrete tokens (Semantic ID) that capture semantic meaning.
Why Semantic IDs are revolutionary:
Traditional item IDs are arbitrary: item 12345 and item 12346 have no semantic relationship. But Semantic IDs encode meaning through hierarchical structure:
Semantic ID structure (4 tokens per item):
─────────────────────────────────────────────────────────────────────────
Token 1 (coarse): Category-level (electronics, clothing, books)
Token 2: Subcategory (phones, laptops, headphones)
Token 3: Style/brand cluster (premium, budget, specific brand)
Token 4 (fine): Specific item identifier
Example:
iPhone 15 Pro: [47, 12, 89, 156]
iPhone 15: [47, 12, 89, 203] ← Shares first 3 tokens!
Samsung Galaxy: [47, 12, 45, 78] ← Shares first 2 tokens
Running shoes: [23, 67, 34, 12] ← Different category, different prefix
This structure enables:
- Efficient beam search: Prune entire categories early in generation
- Semantic generalization: Model learns "electronics" behavior applies to new electronics
- Natural diversity: Different beam paths explore different categories
How Semantic IDs are learned (RQ-VAE):
Residual Quantization Vector-Autoencoder (RQ-VAE) learns the tokenization:
-
Encode item to embedding: Use content encoder (e.g., BERT on description, CNN on image) to get dense vector
-
Quantize residually: At each level, find nearest codebook entry, compute residual (what's left), quantize residual at next level
-
Train end-to-end: Reconstruction loss ensures semantic IDs can decode back to original embedding; quantization learns meaningful clusters
The hierarchical clustering effect:
RQ-VAE naturally creates a hierarchy because each level explains increasingly fine-grained details:
- Level 1: Broad categories (high variance features)
- Level 2: Subcategories (medium variance)
- Level 3+: Specific attributes (low variance)
Similar items end up with shared prefixes, enabling efficient tree-like search during generation.
class SemanticIDTokenizer(nn.Module):
"""
Convert items to semantic IDs using RQ-VAE.
Items with similar content get similar token prefixes.
"""
def __init__(
self,
num_codebooks: int = 4, # Number of tokens per item
codebook_size: int = 256, # Vocabulary size per position
embedding_dim: int = 768,
):
super().__init__()
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
# Residual quantization codebooks
self.codebooks = nn.ParameterList([
nn.Parameter(torch.randn(codebook_size, embedding_dim))
for _ in range(num_codebooks)
])
def encode(self, item_embedding: torch.Tensor) -> torch.Tensor:
"""
Convert dense embedding to sequence of discrete tokens.
Args:
item_embedding: (batch, embedding_dim) from content encoder
Returns:
semantic_id: (batch, num_codebooks) token indices
"""
tokens = []
residual = item_embedding
for codebook in self.codebooks:
# Find nearest codebook entry
distances = torch.cdist(residual, codebook)
token_idx = distances.argmin(dim=-1)
tokens.append(token_idx)
# Compute residual for next level
quantized = codebook[token_idx]
residual = residual - quantized
return torch.stack(tokens, dim=-1)
class TIGERModel(nn.Module):
"""
TIGER: Generative recommendation via semantic ID generation.
Paper: https://arxiv.org/abs/2305.05065
"""
def __init__(
self,
tokenizer: SemanticIDTokenizer,
hidden_dim: int = 256,
num_layers: int = 6,
):
super().__init__()
self.tokenizer = tokenizer
vocab_size = tokenizer.codebook_size * tokenizer.num_codebooks
# Token embeddings (shared across codebook positions)
self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
self.position_embedding = nn.Embedding(200, hidden_dim)
# Transformer decoder
decoder_layer = nn.TransformerDecoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=hidden_dim * 4,
batch_first=True,
)
self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
# Output heads (one per codebook position)
self.output_heads = nn.ModuleList([
nn.Linear(hidden_dim, tokenizer.codebook_size)
for _ in range(tokenizer.num_codebooks)
])
def forward(self, history_ids: torch.Tensor) -> list[torch.Tensor]:
"""
Generate next item's semantic ID given history.
Args:
history_ids: (batch, seq_len, num_codebooks) - flattened semantic IDs
Returns:
logits: List of (batch, codebook_size) for each codebook position
"""
B, L, K = history_ids.shape
# Flatten semantic IDs to sequence
# [item1_tok1, item1_tok2, ..., item2_tok1, ...]
# Offset each codebook position into its own vocabulary range to avoid token collisions
offsets = torch.arange(K, device=history_ids.device) * self.tokenizer.codebook_size
flat_ids = (history_ids + offsets).view(B, L * K)
# Embed tokens
hidden = self.token_embedding(flat_ids)
positions = torch.arange(L * K, device=hidden.device)
hidden = hidden + self.position_embedding(positions)
# Decode with causal (autoregressive) attention
causal_mask = torch.triu(torch.ones(L * K, L * K, device=hidden.device), diagonal=1).bool()
hidden = self.decoder(hidden, hidden, tgt_mask=causal_mask)
# Score the next item's semantic ID tokens (one output head per codebook position)
logits = []
for i, head in enumerate(self.output_heads):
# Use the hidden state at the last item's i-th token position
logits.append(head(hidden[:, -K + i]))
return logits
def generate(self, history_ids: torch.Tensor, num_beams: int = 10) -> torch.Tensor:
"""
Beam search generation for next item's semantic ID.
"""
# Beam search over semantic ID space
# Similar to text generation but in item token space
pass
Recent Advances in Semantic IDs (2024-2025)
| Model | Innovation | Key Result |
|---|---|---|
| LETTER | Collaborative + diversity regularization in RQ-VAE | Better codebook utilization |
| LC-Rec | Aligns LLM semantic space with collaborative signals | Cross-domain transfer |
| TIGER++ | Whitening + EMA for codebook learning | More stable training |
| RPG | Parallel (non-autoregressive) ID generation | 4x faster inference |
| LLaDA-Rec | Discrete diffusion for semantic IDs | Better diversity |
Trade-offs with Semantic IDs:
- Search vs Recommendation: IDs tuned for search hurt recommendation, and vice versa
- Static vs Dynamic: Fixed IDs assume universal similarity; some work explores adaptive tokenization
- Length: TIGER uses 4 tokens; longer IDs (32-64) capture more semantics but slow inference
Part VIII: Training Strategies
Loss Functions for Sequential Recommendation
The choice of loss function significantly impacts performance—often more than architectural changes. Understanding these losses helps you choose the right one for your constraints.
The fundamental tradeoff:
With millions of items, you can't compute scores for everything. Each loss function handles this differently:
| Loss Function | Complexity per Example | Quality | Use Case |
|---|---|---|---|
| Full Softmax | O(num_items) | Best | Small catalogs (<100K) |
| Sampled Softmax | O(num_negatives) | Good | Large catalogs |
| BPR (Pairwise) | O(num_negatives) | Good | Implicit feedback |
| InfoNCE | O(batch_size) | Good | Self-supervised |
Full Softmax: The gold standard
Full softmax computes the probability distribution over ALL items. It's optimal because the model explicitly learns to rank the target higher than every other item. The problem: with 10M items, this means computing and normalizing 10M scores for every training example, which is prohibitive.
Sampled Softmax: The practical choice
Instead of all items, sample a subset of negatives. With 256-1024 negatives, you get ~95% of full softmax quality at a fraction of the cost. The key insight: most items are easy negatives (clearly not relevant), so you don't need to see them all.
Critical choice: how to sample negatives?
- Uniform: Simple, but wastes samples on trivially irrelevant items
- Popularity-weighted: Sample popular items more often (they're harder negatives)
- In-batch negatives: Use other users' positives as negatives (free, diverse)
- Hard negatives: Mine items the model currently ranks incorrectly (most informative, but expensive)
BPR (Bayesian Personalized Ranking):
BPR doesn't try to predict absolute scores—it optimizes the relative ordering. The loss says: "score the positive higher than negatives, I don't care by how much."
This works well for implicit feedback (clicks, views) where we don't have explicit ratings. A click doesn't mean "great item"—it means "better than items not clicked."
InfoNCE (Contrastive):
InfoNCE treats recommendation as a contrastive learning problem: pull together (user, positive item) pairs, push apart (user, negative item) pairs. The temperature parameter controls how hard the model focuses on difficult negatives:
- Low (0.01): Focus on very hard negatives
- High (1.0): Treat all negatives more equally
InfoNCE is popular in self-supervised pretraining because it doesn't require labels—just pairs of "similar" items (e.g., items in the same session).
Practical recommendations:
- Start with sampled softmax (256-512 negatives, popularity-weighted)
- Use in-batch negatives when batch size is large (saves computation)
- Add hard negative mining once model converges (pushes accuracy further)
- Full softmax only if you have <100K items and GPU memory to spare
# 1. Cross-Entropy (Full Softmax)
# Most accurate but expensive for large item catalogs
def full_softmax_loss(logits, targets):
"""
logits: (B, num_items) - scores for all items
targets: (B,) - target item IDs
"""
return F.cross_entropy(logits, targets)
# 2. Sampled Softmax
# Approximate full softmax with negative sampling
def sampled_softmax_loss(
hidden: torch.Tensor,
targets: torch.Tensor,
item_embeddings: torch.Tensor,
num_negatives: int = 100,
):
"""
hidden: (B, D) - sequence representations
targets: (B,) - target item IDs
item_embeddings: (num_items, D)
"""
B = hidden.shape[0]
# Positive scores
target_emb = item_embeddings[targets] # (B, D)
pos_scores = (hidden * target_emb).sum(-1) # (B,)
# Sample negatives (uniform here; popularity-based also works)
neg_indices = torch.randint(0, item_embeddings.shape[0], (B, num_negatives), device=hidden.device)
neg_emb = item_embeddings[neg_indices] # (B, num_neg, D)
neg_scores = torch.bmm(neg_emb, hidden.unsqueeze(-1)).squeeze(-1) # (B, num_neg)
# Combine and compute log-softmax
all_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
loss = -F.log_softmax(all_scores, dim=-1)[:, 0].mean()
return loss
# 3. Binary Cross-Entropy (BPR-style)
# Pairwise: score positive higher than negatives
def bpr_loss(pos_scores, neg_scores):
"""
pos_scores: (B,) - scores for positive items
neg_scores: (B, num_neg) - scores for negative items
"""
# For each positive, compare with each negative
diff = pos_scores.unsqueeze(-1) - neg_scores # (B, num_neg)
loss = -F.logsigmoid(diff).mean()
return loss
# 4. InfoNCE / Contrastive Loss
# Popular in self-supervised settings
def infonce_loss(hidden, targets, item_embeddings, temperature=0.07):
"""Contrastive loss with in-batch negatives."""
# Positive pairs
target_emb = item_embeddings[targets] # (B, D)
# Normalize
hidden = F.normalize(hidden, dim=-1)
target_emb = F.normalize(target_emb, dim=-1)
# Similarity matrix (in-batch negatives)
sim = torch.mm(hidden, target_emb.T) / temperature # (B, B)
# Diagonal elements are positives
labels = torch.arange(hidden.shape[0], device=hidden.device)
loss = F.cross_entropy(sim, labels)
return loss
Negative Sampling Strategies
Not all negatives are equally informative. The choice of how to sample negatives can impact model quality as much as architecture choices.
The problem with uniform sampling:
With uniform sampling, most negatives are "easy"—clearly irrelevant items that the model quickly learns to score low. These easy negatives provide little training signal:
User interested in: Electronics
─────────────────────────────────────────────────────────────────────────
Uniform negatives: [Garden hose, Dog food, Baby clothes, Romance novel]
↳ Model already scores these near 0. No learning.
Better negatives: [Samsung Galaxy, Laptop stand, Wireless mouse, AirPods]
↳ Model must work harder to distinguish. More learning.
Four sampling strategies, ranked by effectiveness:
-
Uniform sampling: Baseline. Fast but wasteful. Most samples are trivially easy.
-
Popularity-based sampling: Sample popular items more often. Popular items are harder negatives because they're plausibly relevant to many users. Adds ~5-10% improvement over uniform.
-
In-batch negatives: Brilliant trick—use other users' positive items as your negatives. Zero additional computation (you already have their embeddings). Works well when batch size is large (512+). Popular in contrastive learning.
-
Hard negative mining: Explicitly find items the model currently ranks incorrectly. Most expensive (requires scoring many items) but most effective. Can add 10-20% improvement. Common approach: periodically compute top-K similar items, use those as hard negatives.
The false negative problem:
A sampled "negative" might actually be a good recommendation:
- User hasn't seen item X yet, but would love it
- You sample X as negative, penalize model for scoring it highly
- Model learns the wrong thing!
Solutions:
- Filter user history: Never sample items the user interacted with (see the sketch after this list)
- Debiasing: Weight loss by inverse propensity (how likely was this item to be seen?)
- Accept some noise: With enough negatives, signal overwhelms noise
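Here is a minimal sketch of the history-filtering fix (shapes and the bounded resampling loop are assumptions; production pipelines often pre-build per-user exclusion sets instead):

```python
import torch

def filter_seen_negatives(negatives, history, num_items, max_rounds=5):
    """
    negatives: (B, N) sampled negative item IDs
    history:   (B, L) user histories, 0-padded
    Resamples any negative that already appears in that user's history.
    """
    for _ in range(max_rounds):
        # (B, N, L) comparison -> (B, N): is each negative in the user's history?
        seen = (negatives.unsqueeze(-1) == history.unsqueeze(1)).any(dim=-1)
        if not seen.any():
            break
        resampled = torch.randint(1, num_items + 1, negatives.shape, device=negatives.device)
        negatives = torch.where(seen, resampled, negatives)
    return negatives
```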
Curriculum learning for negatives:
Start with easy negatives, gradually increase difficulty:
- Epoch 1-5: Uniform sampling (fast convergence on obvious patterns)
- Epoch 5-10: Popularity-weighted (harder distinction)
- Epoch 10+: Hard negative mining (fine-grained ranking)
This curriculum often converges faster than starting with hard negatives; a small scheduling sketch follows the sampler code below.
class NegativeSampler:
"""Different strategies for sampling negative items."""
def __init__(self, num_items: int, item_popularity: torch.Tensor = None):
self.num_items = num_items
self.item_popularity = item_popularity
def uniform(self, batch_size: int, num_negatives: int) -> torch.Tensor:
"""Random uniform sampling."""
return torch.randint(0, self.num_items, (batch_size, num_negatives))
def popularity(self, batch_size: int, num_negatives: int) -> torch.Tensor:
"""Sample proportional to item popularity (harder negatives)."""
probs = self.item_popularity / self.item_popularity.sum()
indices = torch.multinomial(probs, batch_size * num_negatives, replacement=True)
return indices.view(batch_size, num_negatives)
def in_batch(self, targets: torch.Tensor) -> torch.Tensor:
"""Use other items in batch as negatives (efficient)."""
B = targets.shape[0]
# All items in batch except self
negatives = targets.unsqueeze(0).expand(B, -1) # (B, B)
mask = ~torch.eye(B, dtype=torch.bool, device=targets.device)
return negatives[mask].view(B, B-1)
def hard_negatives(
self,
hidden: torch.Tensor,
item_embeddings: torch.Tensor,
num_negatives: int,
exclude: torch.Tensor,
) -> torch.Tensor:
"""
Sample items with high scores but wrong labels.
Expensive but very effective.
"""
# Score all items
scores = torch.mm(hidden, item_embeddings.T) # (B, num_items)
# Mask out true positives
scores.scatter_(1, exclude.unsqueeze(-1), float('-inf'))
# Sample from top-k
_, top_indices = scores.topk(num_negatives * 10, dim=-1)
# Randomly sub-sample from the top-k pool
rand_indices = torch.randint(0, num_negatives * 10, (hidden.shape[0], num_negatives), device=hidden.device)
negatives = top_indices.gather(1, rand_indices)
return negatives
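Tying these together, a minimal curriculum schedule over the sampler above might look like the following (the epoch thresholds mirror the list earlier and are illustrative):

```python
def sample_negatives_for_epoch(sampler, epoch, batch_size, num_negatives,
                               hidden=None, item_embeddings=None, targets=None):
    """Easy -> harder negatives as training progresses."""
    if epoch < 5:
        return sampler.uniform(batch_size, num_negatives)
    if epoch < 10:
        return sampler.popularity(batch_size, num_negatives)
    # Hard mining needs the current model state (sequence reps + item embeddings)
    return sampler.hard_negatives(hidden, item_embeddings, num_negatives, exclude=targets)
```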
Part IX: Handling Scale and Efficiency
The Item Vocabulary Problem
Real recommendation systems have millions of items. Full softmax is infeasible:
┌─────────────────────────────────────────────────────────────────────────┐
│ SCALING CHALLENGES │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ FULL SOFTMAX COST: │
│ ─────────────────── │
│ │
│ Items: 10 million │
│ Hidden dim: 256 │
│ Batch size: 1024 │
│ │
│ Embedding table: 10M × 256 = 2.56 GB │
│ Output projection: 10M × 256 = 2.56 GB │
│ Softmax computation: 1024 × 10M = 10.24 billion ops per batch │
│ │
│ This is INFEASIBLE for training and inference. │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SOLUTIONS: │
│ ────────── │
│ │
│ 1. SAMPLED SOFTMAX │
│ - Train with ~1000 negatives instead of all items │
│ - 10,000x reduction in computation │
│ │
│ 2. TWO-TOWER RETRIEVAL │
│ - Separate user and item encoders │
│ - Use ANN (approximate nearest neighbor) for retrieval │
│ - Inference: O(log N) instead of O(N) │
│ │
│ 3. HIERARCHICAL SOFTMAX │
│ - Organize items in tree structure │
│ - O(log N) path instead of O(N) softmax │
│ │
│ 4. HASH EMBEDDINGS │
│ - Multiple items share embedding buckets │
│ - Reduces embedding table size 10-100x │
│ │
└─────────────────────────────────────────────────────────────────────────┘
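Of the four, hash embeddings haven't appeared elsewhere in this post, so here is a minimal sketch in the spirit of the hashing trick (the double-hashing scheme and bucket counts are illustrative assumptions):

```python
import torch
import torch.nn as nn

class HashEmbedding(nn.Module):
    """
    Hashed item embeddings: multiple items share buckets, shrinking the table.
    The final embedding is the sum of the bucket vectors from each hash.
    """
    def __init__(self, num_buckets: int, hidden_dim: int, num_hashes: int = 2):
        super().__init__()
        self.num_buckets = num_buckets
        self.tables = nn.ModuleList(
            [nn.Embedding(num_buckets, hidden_dim) for _ in range(num_hashes)]
        )
        # Fixed random odd multipliers act as cheap hash functions
        self.register_buffer("seeds", torch.randint(1, 2**30, (num_hashes,)) * 2 + 1)

    def forward(self, item_ids: torch.Tensor) -> torch.Tensor:
        emb = 0
        for table, seed in zip(self.tables, self.seeds):
            buckets = (item_ids * seed) % self.num_buckets
            emb = emb + table(buckets)   # collisions differ across hashes
        return emb

# e.g. 10M items squeezed into 2 x 500K-row tables instead of one 10M-row table
# embedder = HashEmbedding(num_buckets=500_000, hidden_dim=256)
```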
Two-Tower Architecture
The standard production architecture separates user and item encoding. This pattern is universal in industrial recommendation systems—YouTube, Pinterest, LinkedIn, and virtually every large-scale system uses some variant.
Why two towers instead of one?
A unified model that jointly encodes (user, item) pairs seems more powerful—it can capture complex user-item interactions. But it has a fatal flaw:
Unified Model (Dense Retrieval):
─────────────────────────────────────────────────────────────────────────
For each user, score all 10M items:
for item in all_items:
score = model(user_features, item_features) # Full forward pass
Cost: 10M forward passes per user per request
At 1000 QPS = 10 BILLION forward passes per second = IMPOSSIBLE
Two-Tower Model:
─────────────────────────────────────────────────────────────────────────
OFFLINE (once per day):
for item in all_items:
item_embeddings[item] = item_tower(item_features) # 10M passes total
build_ann_index(item_embeddings)
ONLINE (per request):
user_emb = user_tower(user_features) # 1 forward pass
candidates = ann_index.search(user_emb, k=100) # O(log N) lookup
Cost: 1 forward pass + index lookup per request
At 1000 QPS = 1000 forward passes per second = EASY
The key constraint: Dot product interaction only
For ANN indexes to work, the score must decompose as a simple dot product between a user vector and an item vector:
$$\text{score}(u, i) = \mathbf{u} \cdot \mathbf{v}_i$$
No cross-features, no nonlinear combinations. This seems limiting, but the towers can be arbitrarily complex—SASRec, BERT4Rec, any transformer. All user-item interaction is compressed into the learned embeddings.
What each tower learns:
- User tower: Learns to compress user history into a "preference vector" that captures what kinds of items this user likes
- Item tower: Learns to represent items in the same space, such that "relevant" items are close to users who would like them
The embedding space is shared—users and items live in the same vector space, making similarity computation possible.
Temperature scaling for contrastive learning:
Two-tower models are typically trained with a contrastive (in-batch softmax) loss:
$$\mathcal{L} = -\log \frac{\exp(\mathbf{u} \cdot \mathbf{v}_{i^{+}} / \tau)}{\sum_{j \in \text{batch}} \exp(\mathbf{u} \cdot \mathbf{v}_{j} / \tau)}$$
The temperature $\tau$ is critical:
- High temperature ($\tau \approx 1.0$): Softer distribution, model focuses on all negatives equally
- Low temperature ($\tau \approx 0.05$): Sharper distribution, model focuses on hardest negatives
Most systems use a low temperature (around 0.05) for sharp, discriminative embeddings; a minimal training sketch appears after the retrieval code below.
class TwoTowerModel(nn.Module):
"""
Two-tower architecture for large-scale retrieval.
User tower: Encodes user history
Item tower: Encodes item features
"""
def __init__(
self,
num_items: int,
hidden_dim: int = 128,
user_tower: nn.Module = None, # e.g., SASRec
):
super().__init__()
self.user_tower = user_tower or SASRec(num_items, hidden_dim=hidden_dim)
# Item tower: Could be simple embedding or full encoder
self.item_embedding = nn.Embedding(num_items, hidden_dim)
self.item_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
def encode_user(self, item_seq: torch.Tensor) -> torch.Tensor:
"""Encode user from their interaction sequence."""
hidden = self.user_tower(item_seq)
return hidden[:, -1, :] # Last position
def encode_items(self, item_ids: torch.Tensor) -> torch.Tensor:
"""Encode items."""
emb = self.item_embedding(item_ids)
return self.item_mlp(emb)
def forward(self, item_seq: torch.Tensor, candidate_items: torch.Tensor):
"""
Score candidate items for users.
Args:
item_seq: (B, L) user sequences
candidate_items: (B, C) candidate item IDs
Returns:
scores: (B, C) scores for each candidate
"""
user_emb = self.encode_user(item_seq) # (B, D)
item_emb = self.encode_items(candidate_items) # (B, C, D)
scores = torch.bmm(item_emb, user_emb.unsqueeze(-1)).squeeze(-1)
return scores
# At inference time:
# 1. Pre-compute all item embeddings
# 2. Build ANN index (FAISS, ScaNN, etc.)
# 3. For each user, encode once, retrieve top-k via ANN
import faiss
def build_item_index(model, num_items, hidden_dim):
"""Build FAISS index for fast retrieval."""
# Encode all items
all_items = torch.arange(num_items)
item_embeddings = model.encode_items(all_items).detach().numpy()
# Build index
index = faiss.IndexFlatIP(hidden_dim) # Inner product
index.add(item_embeddings)
return index
def retrieve_candidates(model, index, item_seq, k=100):
"""Retrieve top-k candidates for a user."""
user_emb = model.encode_user(item_seq).detach().numpy()
scores, indices = index.search(user_emb, k)
return indices, scores
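For completeness, a minimal two-tower training sketch with in-batch softmax negatives and temperature scaling (batch keys and hyperparameters are assumptions):

```python
import torch
import torch.nn.functional as F

def train_two_tower(model, dataloader, epochs=10, temperature=0.05, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for batch in dataloader:
            user_emb = model.encode_user(batch['item_seq'])       # (B, D)
            item_emb = model.encode_items(batch['target_item'])   # (B, D)
            user_emb = F.normalize(user_emb, dim=-1)
            item_emb = F.normalize(item_emb, dim=-1)
            # In-batch negatives: every other user's positive acts as a negative
            logits = user_emb @ item_emb.T / temperature          # (B, B)
            labels = torch.arange(logits.shape[0], device=logits.device)
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```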
Efficient Attention for Long Sequences
Users can have thousands of interactions. Standard attention is $O(n^2)$ in both time and memory:
# Linear attention approximations
class LinearAttention(nn.Module):
"""
Linear attention via kernel feature maps.
O(n) instead of O(n^2).
"""
def __init__(self, hidden_dim: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.out_proj = nn.Linear(hidden_dim, hidden_dim)
def feature_map(self, x):
"""ELU-based feature map for positive attention."""
return F.elu(x) + 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, D = x.shape
H, d = self.num_heads, self.head_dim
Q = self.feature_map(self.q_proj(x)).view(B, L, H, d)
K = self.feature_map(self.k_proj(x)).view(B, L, H, d)
V = self.v_proj(x).view(B, L, H, d)
# Linear attention: (Q @ K.T) @ V = Q @ (K.T @ V)
# Compute K.T @ V first: O(L * d^2) instead of O(L^2 * d)
KV = torch.einsum('blhd,blhe->bhde', K, V) # (B, H, d, d)
# Q @ KV
output = torch.einsum('blhd,bhde->blhe', Q, KV) # (B, L, H, d)
# Normalize
K_sum = K.sum(dim=1) # (B, H, d)
normalizer = torch.einsum('blhd,bhd->blh', Q, K_sum) # (B, L, H)
output = output / (normalizer.unsqueeze(-1) + 1e-6)
output = output.reshape(B, L, D)
return self.out_proj(output)
Part X: Incorporating Side Information
Beyond Item IDs
Pure ID-based models ignore rich item metadata. Modern systems incorporate categorical features (category, brand), continuous features (price), and other side information alongside the item ID:
class SASRecWithFeatures(nn.Module):
"""SASRec with item and context features."""
def __init__(
self,
num_items: int,
num_categories: int,
num_brands: int,
hidden_dim: int = 64,
**kwargs
):
super().__init__()
# ID embeddings
self.item_embedding = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)
# Feature embeddings
self.category_embedding = nn.Embedding(num_categories, hidden_dim // 4)
self.brand_embedding = nn.Embedding(num_brands, hidden_dim // 4)
# Continuous features
self.price_proj = nn.Linear(1, hidden_dim // 4)
# Combine features
feature_dim = hidden_dim + hidden_dim // 4 * 3
self.feature_fusion = nn.Linear(feature_dim, hidden_dim)
# Rest of transformer...
self.transformer = ...
def embed_items(
self,
item_ids: torch.Tensor,
categories: torch.Tensor,
brands: torch.Tensor,
prices: torch.Tensor,
) -> torch.Tensor:
"""Combine ID and feature embeddings."""
id_emb = self.item_embedding(item_ids)
cat_emb = self.category_embedding(categories)
brand_emb = self.brand_embedding(brands)
price_emb = self.price_proj(prices.unsqueeze(-1))
combined = torch.cat([id_emb, cat_emb, brand_emb, price_emb], dim=-1)
return self.feature_fusion(combined)
Multi-Modal Features
For rich content like images and descriptions:
class MultiModalSASRec(nn.Module):
"""SASRec with text and image features."""
    def __init__(self, num_items: int, hidden_dim: int = 256):
super().__init__()
# Pre-trained encoders (frozen or fine-tuned)
self.text_encoder = ... # e.g., sentence-transformers
self.image_encoder = ... # e.g., CLIP vision encoder
# Project to common space
self.text_proj = nn.Linear(768, hidden_dim) # Assuming BERT-base
self.image_proj = nn.Linear(512, hidden_dim) # Assuming CLIP
self.id_embedding = nn.Embedding(num_items, hidden_dim)
# Fusion
self.fusion = nn.Sequential(
nn.Linear(hidden_dim * 3, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
)
def embed_items(self, item_ids, text_features, image_features):
id_emb = self.id_embedding(item_ids)
text_emb = self.text_proj(text_features)
image_emb = self.image_proj(image_features)
combined = torch.cat([id_emb, text_emb, image_emb], dim=-1)
return self.fusion(combined)
Part XI: Evaluation and Metrics
Standard Metrics
def evaluate_recommendations(
predictions: torch.Tensor, # (num_users, k) top-k item indices
ground_truth: torch.Tensor, # (num_users,) true next items
k_values: list = [1, 5, 10, 20],
):
"""Compute standard recommendation metrics."""
results = {}
for k in k_values:
top_k = predictions[:, :k]
# Hit Rate (HR@k): Did we include the true item?
hits = (top_k == ground_truth.unsqueeze(-1)).any(dim=-1)
results[f'HR@{k}'] = hits.float().mean().item()
# NDCG@k: Normalized Discounted Cumulative Gain
# Accounts for position (higher rank = better)
positions = (top_k == ground_truth.unsqueeze(-1)).float()
ranks = torch.arange(1, k + 1, device=predictions.device).float()
dcg = (positions / torch.log2(ranks + 1)).sum(dim=-1)
idcg = 1.0 # Perfect ranking has single relevant item at position 1
results[f'NDCG@{k}'] = (dcg / idcg).mean().item()
# MRR: Mean Reciprocal Rank
match_positions = (top_k == ground_truth.unsqueeze(-1)).float().argmax(dim=-1)
has_match = (top_k == ground_truth.unsqueeze(-1)).any(dim=-1)
rr = torch.where(has_match, 1.0 / (match_positions + 1), torch.zeros_like(match_positions.float()))
results[f'MRR@{k}'] = rr.mean().item()
return results
Beyond Accuracy: Coverage and Diversity
def evaluate_coverage_diversity(
predictions: torch.Tensor, # (num_users, k)
item_popularity: torch.Tensor, # (num_items,) interaction counts
num_items: int,
):
"""Evaluate beyond accuracy metrics."""
# Coverage: What fraction of items are ever recommended?
unique_items = predictions.unique()
coverage = len(unique_items) / num_items
# Gini coefficient: How unequal is the recommendation distribution?
# Lower = more equal = better diversity
rec_counts = torch.bincount(predictions.flatten(), minlength=num_items).float()
sorted_counts = rec_counts.sort().values
n = len(sorted_counts)
index = torch.arange(1, n + 1, device=sorted_counts.device).float()
gini = (2 * (index * sorted_counts).sum() / (n * sorted_counts.sum())) - (n + 1) / n
# Popularity bias: Are we over-recommending popular items?
rec_popularity = item_popularity[predictions].float().mean()
overall_popularity = item_popularity.float().mean()
popularity_bias = rec_popularity / overall_popularity
return {
'coverage': coverage,
'gini': gini.item(),
'popularity_bias': popularity_bias.item(),
}
Part XII: Production Considerations
Serving Architecture
┌─────────────────────────────────────────────────────────────────────────┐
│ PRODUCTION SERVING ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ REQUEST FLOW: │
│ ───────────── │
│ │
│ User Request │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ RETRIEVAL STAGE │ │
│ │ (Fast, broad filtering: 10M items → 1000 candidates) │ │
│ │ │ │
│ │ • User tower encodes recent history │ │
│ │ • ANN search against pre-computed item embeddings │ │
│ │ • Latency: 5-10ms │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ RANKING STAGE │ │
│ │ (Accurate, fine-grained scoring: 1000 → 100) │ │
│ │ │ │
│ │ • Full transformer model scores candidates │ │
│ │ • Uses rich features, cross-attention │ │
│ │ • Latency: 10-30ms │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ RE-RANKING STAGE │ │
│ │ (Business logic, diversity: 100 → 20) │ │
│ │ │ │
│ │ • Apply business rules (availability, margins) │ │
│ │ • Ensure diversity (categories, price ranges) │ │
│ │ • Personalization constraints │ │
│ │ • Latency: 5ms │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ↓ │
│ Response (20 recommendations) │
│ Total latency: 20-50ms │
│ │
└─────────────────────────────────────────────────────────────────────────┘
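One way the three stages compose in code is a simple orchestration function. The sketch below is illustrative only; user_tower, ann_index, ranker, and apply_business_rules are hypothetical components standing in for the user encoder, the ANN index, the full ranking model, and the re-ranking logic.
def recommend(user_history, user_tower, ann_index, ranker, apply_business_rules, k: int = 20):
    """End-to-end request flow: retrieval -> ranking -> re-ranking (illustrative sketch)."""
    # Stage 1: Retrieval (10M items -> ~1000 candidates), ~5-10 ms
    user_emb = user_tower.encode(user_history)
    _, candidate_ids = ann_index.search(user_emb, 1000)
    # Stage 2: Ranking (1000 -> ~100) with the full model and rich features, ~10-30 ms
    scores = ranker.score(user_history, candidate_ids)
    ranked = [item for _, item in sorted(zip(scores, candidate_ids), reverse=True)][:100]
    # Stage 3: Re-ranking (100 -> k) with business rules and diversity constraints, ~5 ms
    return apply_business_rules(ranked)[:k]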
Model Updates and Freshness
# Handling new items (cold start)
class ItemEmbeddingWithFallback:
"""
Handle new items that weren't in training.
"""
def __init__(self, model, content_encoder):
self.model = model
self.content_encoder = content_encoder # e.g., BERT for item text
# Cache for computed embeddings
self.embedding_cache = {}
def get_embedding(self, item_id: int, item_features: dict = None):
"""Get embedding, with fallback for new items."""
# Check if item was in training
if item_id < self.model.num_items:
return self.model.item_embedding.weight[item_id]
# New item: use content features
if item_id in self.embedding_cache:
return self.embedding_cache[item_id]
if item_features is None:
# No features: use category average or random
return self.model.item_embedding.weight.mean(dim=0)
# Encode content features
content_emb = self.content_encoder(item_features)
self.embedding_cache[item_id] = content_emb
return content_emb
Part XIII: Production Systems at Scale
Industrial Deployments (2024-2025)
The gap between academic research and production is closing. Here are the transformer-based systems powering major platforms:
┌─────────────────────────────────────────────────────────────────────────┐
│ TRANSFORMER RECSYS IN PRODUCTION (2024-2025) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ COMPANY MODEL KEY INNOVATION │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ Meta HSTU Trillion-param, no-softmax attention │
│ 12.4%+ improvement in production │
│ │
│ LinkedIn LiGR Gated residuals, set-wise scoring │
│ Deprecated feature engineering │
│ │
│ ByteDance Monolith Real-time online training │
│ Collision-less embeddings │
│ │
│ Kuaishou OneRec End-to-end generative (no retrieval) │
│ DPO alignment, 1.6% watch-time gain │
│ │
│ Pinterest PinnerFormer Dense all-action loss │
│ Long-term engagement prediction │
│ │
│ Spotify Semantic IDs LLaMA fine-tuning for domain │
│ Unified search + recommendation │
│ │
│ Netflix UniCoRn Unified contextual ranker │
│ FM-Intent for intent prediction │
│ │
│ YouTube Transformer Watch Transformer layers for engagement │
│ Next Multi-task learning + retrieval │
│ │
│ Albatross DenseRec Sequential embeddings for cold-start │
│ Dual-path (content + behavior) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Monolith: ByteDance's Real-Time System
Monolith powers TikTok's recommendation engine with real-time online training—the model updates continuously as users interact.
# Monolith architecture concepts (simplified)
class MonolithSystem:
"""
ByteDance's real-time recommendation system.
Key: Continuous online learning from streaming data.
"""
def __init__(self):
# Collision-less embedding table using cuckoo hashing
self.embedding_table = CollisionlessEmbeddingTable(
num_buckets=1_000_000_000, # Billion-scale
embedding_dim=128,
)
# Parameter server architecture
self.ps = ParameterServer()
# Streaming training pipeline
self.kafka_actions = KafkaConsumer('user_actions')
self.kafka_features = KafkaConsumer('features')
self.flink_joiner = FlinkStreamJoiner()
def online_training_loop(self):
"""
Continuous training from streaming data.
Model updates propagate to serving within minutes.
"""
while True:
# Join action and feature streams
training_examples = self.flink_joiner.join(
self.kafka_actions,
self.kafka_features,
window_size='1_minute'
)
# Update model
for batch in training_examples:
gradients = self.compute_gradients(batch)
self.ps.apply_gradients(gradients)
# Sync to serving (near real-time)
self.ps.sync_to_inference()
class CollisionlessEmbeddingTable:
"""
Cuckoo hashing for collision-free embeddings.
Handles billions of items without hash collisions.
"""
def __init__(self, num_buckets: int, embedding_dim: int):
# Two tables with different hash functions
self.table1 = torch.zeros(num_buckets, embedding_dim)
self.table2 = torch.zeros(num_buckets, embedding_dim)
# Frequency filtering (remove rare items)
self.access_counts = torch.zeros(num_buckets)
self.min_frequency = 5
# TTL for expiring old embeddings
self.last_access = torch.zeros(num_buckets)
self.ttl_days = 7
    def lookup(self, item_ids: torch.Tensor) -> torch.Tensor:
        """Cuckoo lookup: each item lives in exactly one of its two candidate slots."""
        # Try first hash
        h1 = self.hash1(item_ids)
        result = self.table1[h1]
        # Fall back to the second table when the first slot is unoccupied
        # (a full implementation would also compare stored keys to detect true misses)
        h2 = self.hash2(item_ids)
        missing = (self.access_counts[h1] == 0)
        result[missing] = self.table2[h2[missing]]
        return result
Monolith innovations:
- Real-time updates: Model parameters sync to serving every ~1 minute
- Collision-less embeddings: Cuckoo hashing eliminates hash collisions
- Expiring embeddings: TTL removes stale items automatically
- Frequency filtering: Ignores items with <5 interactions
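The frequency filter and TTL listed above can be sketched as a simple admission/eviction policy in front of the embedding table. This is a hedged illustration of the idea, not ByteDance's implementation; min_frequency and ttl_days mirror the fields of the simplified class above.
import time

class EmbeddingAdmissionPolicy:
    """Admit an item only after min_frequency interactions; evict after ttl_days of inactivity."""
    def __init__(self, min_frequency: int = 5, ttl_days: int = 7):
        self.min_frequency = min_frequency
        self.ttl_seconds = ttl_days * 86400
        self.counts = {}     # item_id -> observed interaction count
        self.last_seen = {}  # item_id -> timestamp of last interaction

    def observe(self, item_id: int) -> bool:
        """Record an interaction; return True once the item deserves its own embedding."""
        now = time.time()
        self.counts[item_id] = self.counts.get(item_id, 0) + 1
        self.last_seen[item_id] = now
        return self.counts[item_id] >= self.min_frequency

    def evict_stale(self) -> list:
        """Return item_ids whose embeddings can be freed (no interaction within the TTL)."""
        now = time.time()
        stale = [i for i, t in self.last_seen.items() if now - t > self.ttl_seconds]
        for i in stale:
            self.counts.pop(i, None)
            self.last_seen.pop(i, None)
        return stale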
OneRec: Kuaishou's Unified Generative Recommender
OneRec (February 2025) is the first end-to-end generative recommender deployed at scale, replacing the traditional retrieve-rank pipeline.
┌─────────────────────────────────────────────────────────────────────────┐
│ TRADITIONAL vs ONEREC ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ TRADITIONAL (3-Stage): │
│ ─────────────────────── │
│ │
│ Retrieval → Pre-Ranking → Ranking → Results │
│ (10M→1K) (1K→100) (100→20) │
│ │
│ • 3 separate models to maintain │
│ • Complex infrastructure │
│ • Information loss between stages │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ONEREC (Unified): │
│ ───────────────── │
│ │
│ User History → Encoder-Decoder → Generate Session → Results │
│ ↓ │
│ Sparse MoE │
│ ↓ │
│ DPO Alignment │
│ │
│ • Single model end-to-end │
│ • Session-wise generation (not point-by-point) │
│ • 10.6% OpEx of traditional pipeline │
│ │
└─────────────────────────────────────────────────────────────────────────┘
OneRec key components:
- Session-wise Generation: Generates a coherent session of recommendations, not individual items
- Sparse MoE: Scales to 10x FLOPs with mixture of experts
- Iterative DPO Alignment: Uses reward models to generate preference pairs for Direct Preference Optimization
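The first component, session-wise generation, boils down to decoding a short sequence of item tokens autoregressively instead of scoring candidates one at a time. The sketch below illustrates that idea under strong assumptions (a generic encoder-decoder model with encoder and decoder attributes, greedy decoding, and a flat item vocabulary); the actual OneRec adds semantic item IDs, a sparse-MoE decoder, and DPO alignment on top.
import torch

@torch.no_grad()
def generate_session(model, user_history: torch.Tensor, session_len: int = 8, bos_id: int = 0) -> torch.Tensor:
    """Autoregressively decode a whole session of item IDs from the user's history (illustrative sketch)."""
    # Encode the user's interaction history once
    memory = model.encoder(user_history)           # (1, L, D)
    generated = torch.tensor([[bos_id]])           # start-of-session token
    for _ in range(session_len):
        logits = model.decoder(generated, memory)  # (1, t, num_items) next-item logits
        next_item = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_item], dim=1)
    return generated[:, 1:]                        # drop the BOS token: the recommended session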
Production results at Kuaishou:
- 1.6% watch-time increase (substantial at Kuaishou's scale)
- 25% of total QPS served by OneRec
- 10.6% OpEx compared to traditional pipeline
- Demonstrated scaling laws for recommendation models
PinnerFormer: Pinterest's Long-Term Engagement Model
PinnerFormer (deployed since 2021) focuses on predicting long-term user engagement, not just next-click.
class PinnerFormer(nn.Module):
    """
    Pinterest's user representation model (simplified).
    Key: Predict long-term engagement, not just next action.
    """
    def __init__(self, num_items: int, hidden_dim: int = 256, num_layers: int = 4):
        super().__init__()
        # Transformer encoder for user sequences
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim, nhead=8, batch_first=True),
            num_layers=num_layers
        )
        # Item encoder for embedding future actions
        self.item_encoder = nn.Embedding(num_items, hidden_dim)
        # Dense all-action loss (key innovation)
        # Predict ALL future actions, not just next one
        self.future_predictor = nn.Linear(hidden_dim, hidden_dim)
def dense_all_action_loss(
self,
user_history: torch.Tensor,
future_actions: torch.Tensor, # All actions in next N days
) -> torch.Tensor:
"""
Loss that considers ALL future engagement, not just next action.
Standard: Predict item at t+1 given items 1...t
PinnerFormer: Predict items at t+1...t+N given items 1...t
"""
# Encode history
history_repr = self.transformer(user_history)
user_embedding = history_repr[:, -1] # Last position
# Predict future engagement
future_pred = self.future_predictor(user_embedding)
# Loss against ALL future actions (not just next)
future_embeddings = self.item_encoder(future_actions)
scores = torch.mm(future_pred, future_embeddings.T)
# Each future action is a positive
loss = -torch.log_softmax(scores, dim=-1).diag().mean()
return loss
Why dense all-action loss matters:
- Next-item prediction optimizes for immediate clicks
- Dense all-action optimizes for long-term value
- Better alignment with business metrics (retention, LTV)
YouTube: Transformer-Era Watch Next
YouTube evolved from their foundational 2016 DNN system to incorporate transformer architectures for their Watch Next recommendations. The 2024-2025 system uses multi-task learning with transformer layers to jointly optimize for engagement and satisfaction.
┌─────────────────────────────────────────────────────────────────────────┐
│ YOUTUBE TRANSFORMER ARCHITECTURE (2024-2025) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ CANDIDATE GENERATION: │
│ ──────────────────── │
│ │
│ User History → Transformer Encoder → User Embedding │
│ ↓ │
│ Watch sequence: [v₁, v₂, v₃, ..., vₙ] │
│ ↓ │
│ Multi-head self-attention (captures viewing patterns) │
│ ↓ │
│ Video-level cross-attention (video features) │
│ ↓ │
│ [USER] token aggregation │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ MULTI-TASK RANKING: │
│ ────────────────── │
│ │
│ User + Candidate → Transformer Ranking Model → Scores │
│ │
│ Objectives (jointly trained): │
│ • Click probability (immediate engagement) │
│ • Watch time (session depth) │
│ • Long-term satisfaction (surveys + implicit signals) │
│ • Creator fairness (distribution constraints) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ KEY INNOVATIONS: │
│ • Sequential modeling with transformers (vs. bag-of-features) │
│ • Multi-objective optimization with Pareto frontiers │
│ • Exploration bonuses for new/underserved content │
│ • Real-time feature updates via streaming pipelines │
│ │
└─────────────────────────────────────────────────────────────────────────┘
YouTube's transformer evolution:
- 2016-2019: Original two-tower DNN with averaged watch history
- 2019-2022: RNN-based sequential modeling for temporal patterns
- 2022-2024: Transformer encoders replace RNNs for user sequences
- 2024-2025: Full transformer ranking with multi-task heads
Key production learnings from YouTube:
- Causal masking is essential: User can only have seen previous videos
- Timestamp positional embeddings: When videos were watched matters as much as what
- Multi-objective balance: Pure engagement optimization leads to "rabbit holes"; satisfaction signals are crucial
- Scale challenges: Billions of users, millions of videos, sub-100ms latency requirements
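The timestamp point above is worth making concrete: bucket the elapsed time since each watch and add a learned embedding per bucket. The bucket boundaries below are arbitrary assumptions for illustration, not YouTube's actual configuration.
import torch
import torch.nn as nn

class TimestampPositionalEmbedding(nn.Module):
    """Add a learned embedding for *when* each item was watched, not just its position in the sequence."""
    def __init__(self, hidden_dim: int):
        super().__init__()
        # Bucket boundaries in seconds (minute, hour, 6 hours, day, week, month); assumed values
        self.register_buffer('boundaries',
                             torch.tensor([60.0, 3600.0, 6 * 3600.0, 86400.0, 7 * 86400.0, 30 * 86400.0]))
        self.bucket_embedding = nn.Embedding(self.boundaries.numel() + 1, hidden_dim)

    def forward(self, item_emb: torch.Tensor, seconds_ago: torch.Tensor) -> torch.Tensor:
        # item_emb: (B, L, D); seconds_ago: (B, L) elapsed time since each interaction
        buckets = torch.bucketize(seconds_ago, self.boundaries)
        return item_emb + self.bucket_embedding(buckets)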
Albatross AI: DenseRec for Cold-Start
Albatross AI, founded by ex-Amazon recommendation scientists, addresses one of RecSys's hardest problems: cold-start for new items. Their DenseRec paper (RecSys 2025) introduces a dual-path embedding architecture that combines content signals with sequential behavior patterns.
┌─────────────────────────────────────────────────────────────────────────┐
│ DENSEREC ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE COLD-START PROBLEM: │
│ ─────────────────────── │
│ │
│ New Item (no interactions) │
│ ↓ │
│ Traditional RecSys: "I have no embedding for this!" │
│ ↓ │
│ DenseRec: "I can generate one from content + category behavior" │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ DUAL-PATH EMBEDDING: │
│ ──────────────────── │
│ │
│ PATH 1: Content Encoder │
│ ───────────────────── │
│ • Text (titles, descriptions) → BERT/T5 encoder │
│ • Images → Vision transformer (ViT/CLIP) │
│ • Attributes → Categorical embeddings │
│ • Combined via learned fusion │
│ │
│ PATH 2: Sequential Behavior (Category-Level) │
│ ───────────────────────────────────────────── │
│ • Items in same category/brand → aggregated behavior │
│ • Temporal patterns → transformer encoder │
│ • Transfer signals from similar items │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ FUSION: │
│ │
│ Content Path ──┐ │
│ ├─→ Gated Fusion ─→ Dense Embedding │
│ Behavior Path ─┘ │
│ │
│ Gate learns when to trust content vs. behavior signals │
│ │
└─────────────────────────────────────────────────────────────────────────┘
DenseRec innovations:
- Category-level sequential patterns: New items inherit behavior patterns from their category
- Adaptive gating: Learns to weight content vs. behavior based on item maturity
- Temporal decay: Recent category behavior weighted more heavily
- Multi-modal content fusion: Text + image + attributes combined
# DenseRec concept (simplified)
class DenseRecEmbedding(nn.Module):
    """
    Albatross's dual-path embedding for cold-start items.
    """
    def __init__(self, num_attributes: int, hidden_dim: int = 256):
        super().__init__()
        # Path 1: Content encoders (pre-trained, frozen or fine-tuned)
        self.text_encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        self.image_encoder = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
        self.attribute_encoder = nn.Embedding(num_attributes, hidden_dim)
        # Path 2: Sequential behavior encoder
        self.behavior_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim, nhead=4, batch_first=True),
            num_layers=2
        )
# Fusion gate
self.fusion_gate = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Sigmoid()
)
# Final projection
self.projection = nn.Linear(hidden_dim, hidden_dim)
def get_embedding(
self,
text: str,
image: Optional[torch.Tensor],
attributes: List[int],
category_id: int,
interaction_count: int
) -> torch.Tensor:
# Content path
text_emb = self.text_encoder(text)
if image is not None:
image_emb = self.image_encoder(image)
content_emb = (text_emb + image_emb) / 2
else:
content_emb = text_emb
# Behavior path (category-level)
category_sequence = self.get_category_behavior(category_id)
behavior_emb = self.behavior_transformer(category_sequence)[:, -1]
# Adaptive fusion based on item maturity
gate_input = torch.cat([content_emb, behavior_emb], dim=-1)
gate = self.fusion_gate(gate_input)
# Cold items rely more on content, warm items on behavior
combined = gate * content_emb + (1 - gate) * behavior_emb
return self.projection(combined)
def get_category_behavior(self, category_id: int) -> torch.Tensor:
"""
Aggregate sequential behavior patterns from category.
New items inherit these patterns for warm-start embeddings.
"""
# Get recent interactions for items in this category
category_items = self.category_to_items[category_id]
category_sequences = [self.item_sequences[i] for i in category_items]
# Aggregate with temporal weighting
aggregated = self.temporal_aggregate(category_sequences)
return aggregated
Why DenseRec matters for production:
- Immediate recommendations for new products: No waiting for interaction data
- Catalog turnover handling: E-commerce sites add thousands of items daily
- Seasonal/trending items: New items can surface immediately based on content similarity
- Long-tail coverage: Items with few interactions get meaningful embeddings
Albatross AI background:
- Founded by former Amazon personalization scientists
- €12.5M Series A funding (2024)
- Customers include major European retailers
- Focus on e-commerce and marketplace recommendations
Part XIV: Reinforcement Learning for Recommendation Systems
Traditional supervised learning optimizes for immediate metrics (next-click prediction). But recommendations have long-term effects—showing clickbait might get clicks but hurts user retention. Reinforcement Learning (RL) models recommendations as a sequential decision problem, optimizing for long-term user value.
Why RL for Recommendations?
┌─────────────────────────────────────────────────────────────────────────┐
│ SUPERVISED LEARNING vs REINFORCEMENT LEARNING │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ SUPERVISED LEARNING (Standard RecSys): │
│ ───────────────────────────────────── │
│ │
│ Training: Log data → Predict clicked items │
│ Objective: Maximize P(click | user, context) │
│ │
│ Problems: │
│ • Optimizes for immediate engagement, not long-term value │
│ • Ignores sequential effects of recommendations │
│ • Exploitation bias: Shows what users already like │
│ • Can't explore to discover new user preferences │
│ • Feedback loops: Popular items get more popular │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ REINFORCEMENT LEARNING: │
│ ────────────────────── │
│ │
│ Formulation: │
│ • State: User history + context │
│ • Action: Item(s) to recommend │
│ • Reward: User engagement (click, purchase, watch time, return visit) │
│ • Policy: π(action | state) - the recommendation model │
│ │
│ Objective: Maximize cumulative discounted reward │
│ max E[Σ γ^t × r_t] where γ ∈ [0,1] is discount factor │
│ │
│ Benefits: │
│ • Optimizes for long-term user value (retention, LTV) │
│ • Natural exploration-exploitation trade-off │
│ • Models sequential nature of user sessions │
│ • Can incorporate delayed rewards (purchase after browsing) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ EXAMPLE: Video Recommendation │
│ ──────────────────────────────── │
│ │
│ Supervised: Recommend video with highest click probability │
│ Result: Clickbait thumbnails, sensational content │
│ │
│ RL: Maximize (watch time + return visits + subscriptions) │
│ Result: Quality content that builds long-term engagement │
│ │
└─────────────────────────────────────────────────────────────────────────┘
The MDP Formulation
Recommendation as a Markov Decision Process (MDP):
┌─────────────────────────────────────────────────────────────────────────┐
│ MDP FOR RECOMMENDATION SYSTEMS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ COMPONENTS: │
│ ─────────── │
│ │
│ State (s_t): │
│ • User interaction history: [item_1, item_2, ..., item_t] │
│ • User features: demographics, preferences, context │
│ • Session features: time of day, device, location │
│ • System state: already-shown items, diversity budget │
│ │
│ Action (a_t): │
│ • Single item: Select one item to recommend │
│ • Slate: Select K items for a page │
│ • Ranking: Order K candidate items │
│ │
│ Reward (r_t): │
│ • Immediate: click, add-to-cart, watch start │
│ • Delayed: purchase, subscription, return visit │
│ • Composite: weighted sum of multiple signals │
│ • Negative: skip, dislike, unsubscribe │
│ │
│ Transition (P(s_{t+1} | s_t, a_t)): │
│ • User's response to recommendation │
│ • State evolution based on user behavior │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ TRAJECTORY: │
│ │
│ s_0 → a_0 → r_0 → s_1 → a_1 → r_1 → s_2 → ... → s_T │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ └── User watched 80% of video │
│ │ │ │ │ └── Recommended video B │
│ │ │ │ └── User clicked, updated history │
│ │ │ └── User clicked (+1 reward) │
│ │ └── Recommended video A │
│ └── Initial user state │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ CHALLENGES IN RECSYS MDP: │
│ │
│ 1. LARGE ACTION SPACE │
│ Millions of items = millions of actions │
│ Can't enumerate all actions like in games │
│ │
│ 2. PARTIAL OBSERVABILITY │
│ Don't see user's full intent or preferences │
│ State is an approximation from observed behavior │
│ │
│ 3. DELAYED AND SPARSE REWARDS │
│ User might purchase days after viewing │
│ Most interactions have no explicit feedback │
│ │
│ 4. NON-STATIONARITY │
│ User preferences change over time │
│ Item catalog changes daily │
│ │
│ 5. SAFETY CONSTRAINTS │
│ Can't show harmful content while exploring │
│ Business rules must be respected │
│ │
└─────────────────────────────────────────────────────────────────────────┘
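In code, the logged data these components produce is typically stored as transition tuples, which the bandit and RL algorithms below consume. A minimal representation (field names are illustrative, not from any particular system):
from dataclasses import dataclass
from typing import List

@dataclass
class Transition:
    """One logged step of interaction: the basic unit of data for bandits and offline RL."""
    state: List[int]         # user history (item IDs) plus context at time t
    action: int              # item (or slate) that was recommended
    reward: float            # observed engagement signal (click, watch time, ...)
    next_state: List[int]    # user history after the interaction
    done: bool               # end of session
    propensity: float = 1.0  # probability the logging policy showed this item (for off-policy corrections)

def discounted_return(rewards: List[float], gamma: float = 0.99) -> float:
    """Cumulative discounted reward G_0 = sum_t gamma^t * r_t for one trajectory."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g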
Bandit Algorithms: The Foundation
Before full RL, contextual bandits provide a simpler framework that's widely deployed:
import numpy as np
from typing import List, Dict, Any
import torch
import torch.nn as nn
class ContextualBandit:
"""
Contextual bandit for recommendations.
Simpler than full RL:
- No state transitions (each decision is independent)
- Immediate rewards only
- Still handles exploration vs exploitation
"""
def __init__(
self,
num_items: int,
context_dim: int,
exploration_rate: float = 0.1,
):
self.num_items = num_items
self.context_dim = context_dim
self.exploration_rate = exploration_rate
# Thompson Sampling: Maintain posterior for each item
# Using Bayesian linear regression
self.A = [np.eye(context_dim) for _ in range(num_items)] # Precision matrices
self.b = [np.zeros(context_dim) for _ in range(num_items)] # Mean accumulators
def select_item(self, context: np.ndarray) -> int:
"""
Select item using Thompson Sampling.
Args:
context: User/context features
Returns:
Selected item index
"""
# Sample from posterior for each item
sampled_rewards = []
for i in range(self.num_items):
# Posterior mean and covariance
A_inv = np.linalg.inv(self.A[i])
theta_mean = A_inv @ self.b[i]
theta_cov = A_inv
# Sample theta from posterior
theta_sample = np.random.multivariate_normal(theta_mean, theta_cov)
# Predicted reward
reward = context @ theta_sample
sampled_rewards.append(reward)
return np.argmax(sampled_rewards)
def update(self, context: np.ndarray, item: int, reward: float):
"""Update posterior after observing reward."""
self.A[item] += np.outer(context, context)
self.b[item] += reward * context
class LinUCB:
"""
Linear Upper Confidence Bound (LinUCB).
Classic bandit algorithm for recommendations.
Used in Yahoo! News personalization (2010).
"""
def __init__(
self,
num_items: int,
context_dim: int,
alpha: float = 1.0, # Exploration parameter
):
self.num_items = num_items
self.context_dim = context_dim
self.alpha = alpha
# Initialize parameters
self.A = [np.eye(context_dim) for _ in range(num_items)]
self.b = [np.zeros(context_dim) for _ in range(num_items)]
def select_item(self, context: np.ndarray) -> int:
"""
Select item with highest UCB.
UCB = predicted_reward + alpha * uncertainty
"""
ucb_scores = []
for i in range(self.num_items):
A_inv = np.linalg.inv(self.A[i])
theta = A_inv @ self.b[i]
# Predicted reward
pred_reward = context @ theta
# Uncertainty (confidence interval width)
uncertainty = self.alpha * np.sqrt(context @ A_inv @ context)
ucb = pred_reward + uncertainty
ucb_scores.append(ucb)
return np.argmax(ucb_scores)
def update(self, context: np.ndarray, item: int, reward: float):
"""Update parameters after observing reward."""
self.A[item] += np.outer(context, context)
self.b[item] += reward * context
class NeuralContextualBandit(nn.Module):
"""
Neural network-based contextual bandit.
Uses neural network to model reward function,
with dropout for uncertainty estimation.
"""
def __init__(
self,
context_dim: int,
num_items: int,
hidden_dim: int = 256,
dropout: float = 0.1,
):
super().__init__()
self.num_items = num_items
# Shared context encoder
self.context_encoder = nn.Sequential(
nn.Linear(context_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
)
# Item embeddings
self.item_embedding = nn.Embedding(num_items, hidden_dim)
# Reward predictor
self.reward_head = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, 1),
)
def forward(
self,
context: torch.Tensor,
item_ids: torch.Tensor = None,
) -> torch.Tensor:
"""
Predict rewards for items given context.
Args:
context: (batch, context_dim)
item_ids: (batch, num_candidates) or None for all items
"""
# Encode context
context_emb = self.context_encoder(context) # (batch, hidden)
if item_ids is None:
# Score all items
item_emb = self.item_embedding.weight # (num_items, hidden)
# Expand for batch
context_emb = context_emb.unsqueeze(1) # (batch, 1, hidden)
item_emb = item_emb.unsqueeze(0) # (1, num_items, hidden)
else:
item_emb = self.item_embedding(item_ids) # (batch, num_candidates, hidden)
context_emb = context_emb.unsqueeze(1) # (batch, 1, hidden)
# Combine and predict
combined = torch.cat([
context_emb.expand(-1, item_emb.size(1), -1),
item_emb.expand(context_emb.size(0), -1, -1),
], dim=-1)
rewards = self.reward_head(combined).squeeze(-1) # (batch, num_items)
return rewards
def select_with_uncertainty(
self,
context: torch.Tensor,
num_samples: int = 10,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select items with uncertainty estimation via MC Dropout.
"""
self.train() # Enable dropout
# Multiple forward passes
reward_samples = []
for _ in range(num_samples):
rewards = self.forward(context)
reward_samples.append(rewards)
reward_samples = torch.stack(reward_samples) # (num_samples, batch, num_items)
# Mean and uncertainty
mean_rewards = reward_samples.mean(dim=0)
uncertainty = reward_samples.std(dim=0)
# UCB-style selection
ucb = mean_rewards + uncertainty
return ucb.argmax(dim=-1), uncertainty
Policy Gradient Methods for Recommendations
For full RL with sequential states, we use policy gradient methods:
┌─────────────────────────────────────────────────────────────────────────┐
│ POLICY GRADIENT FOR RECSYS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ REINFORCE (Williams, 1992): │
│ ─────────────────────────── │
│ │
│ Policy: π_θ(a|s) - probability of recommending item a given state s │
│ │
│ Objective: J(θ) = E_π[Σ γ^t r_t] │
│ │
│ Gradient: ∇J(θ) = E_π[Σ ∇log π_θ(a_t|s_t) × G_t] │
│ │
│ Where G_t = Σ_{k=0}^∞ γ^k r_{t+k} (return from time t) │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ ACTOR-CRITIC: │
│ ───────────── │
│ │
│ Add a value function V(s) to reduce variance: │
│ │
│ Advantage: A(s,a) = Q(s,a) - V(s) │
│ │
│ Actor (policy): π_θ(a|s) │
│ Critic (value): V_φ(s) │
│ │
│ Actor update: ∇J(θ) = E[∇log π_θ(a|s) × A(s,a)] │
│ Critic update: Minimize (V_φ(s) - G_t)² │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ OFF-POLICY LEARNING (Crucial for RecSys): │
│ ────────────────────────────────────────── │
│ │
│ Problem: Can't interact with users in real-time during training │
│ Solution: Learn from logged data (offline RL) │
│ │
│ Importance Sampling: │
│ J(π) = E_{π_old}[π(a|s)/π_old(a|s) × r] │
│ │
│ Challenges: │
│ • High variance when π differs from π_old │
│ • Need logged propensities (probability of showing item) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class REINFORCERecommender(nn.Module):
"""
REINFORCE algorithm for sequential recommendations.
The policy network outputs a distribution over items,
and we train it to maximize expected cumulative reward.
"""
def __init__(
self,
num_items: int,
state_dim: int,
hidden_dim: int = 256,
gamma: float = 0.99,
):
super().__init__()
self.gamma = gamma
self.num_items = num_items
# State encoder (could be transformer for sequential state)
self.state_encoder = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
# Policy head: outputs logits over items
self.policy_head = nn.Linear(hidden_dim, num_items)
# Storage for episode
self.log_probs = []
self.rewards = []
def forward(self, state: torch.Tensor) -> torch.distributions.Categorical:
"""
Compute policy distribution over items.
"""
state_emb = self.state_encoder(state)
logits = self.policy_head(state_emb)
return torch.distributions.Categorical(logits=logits)
def select_action(self, state: torch.Tensor) -> tuple[int, torch.Tensor]:
"""
Sample action from policy and store log probability.
"""
dist = self.forward(state)
action = dist.sample()
log_prob = dist.log_prob(action)
self.log_probs.append(log_prob)
return action.item(), log_prob
def store_reward(self, reward: float):
"""Store reward for current step."""
self.rewards.append(reward)
def compute_loss(self) -> torch.Tensor:
"""
Compute REINFORCE loss at end of episode.
Loss = -Σ log π(a|s) × G_t
where G_t is the return from step t
"""
# Compute returns (discounted cumulative rewards)
returns = []
G = 0
for r in reversed(self.rewards):
G = r + self.gamma * G
returns.insert(0, G)
returns = torch.tensor(returns)
# Normalize returns (variance reduction)
returns = (returns - returns.mean()) / (returns.std() + 1e-8)
# Policy gradient loss
loss = 0
for log_prob, G in zip(self.log_probs, returns):
loss -= log_prob * G
return loss
def clear_episode(self):
"""Clear stored episode data."""
self.log_probs = []
self.rewards = []
class ActorCriticRecommender(nn.Module):
"""
Actor-Critic for recommendations.
Actor: Policy network π(a|s)
Critic: Value network V(s)
Advantage = r + γV(s') - V(s)
"""
def __init__(
self,
num_items: int,
state_dim: int,
hidden_dim: int = 256,
gamma: float = 0.99,
):
super().__init__()
self.gamma = gamma
# Shared state encoder
self.state_encoder = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
# Actor (policy)
self.actor = nn.Linear(hidden_dim, num_items)
# Critic (value function)
self.critic = nn.Linear(hidden_dim, 1)
def forward(self, state: torch.Tensor) -> tuple[torch.distributions.Categorical, torch.Tensor]:
"""
Compute policy and value for state.
"""
state_emb = self.state_encoder(state)
# Policy distribution
logits = self.actor(state_emb)
policy = torch.distributions.Categorical(logits=logits)
# Value estimate
value = self.critic(state_emb)
return policy, value
def compute_loss(
self,
states: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
dones: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute actor and critic losses.
"""
        # Current policy and values
        policy, values = self.forward(states)
        values = values.squeeze(-1)  # (batch,)
        # Next state values (for TD target)
        with torch.no_grad():
            _, next_values = self.forward(next_states)
            next_values = next_values.squeeze(-1) * (1 - dones.float())
        # TD target and advantage
        td_target = rewards + self.gamma * next_values
        advantage = td_target - values
# Actor loss (policy gradient with advantage)
log_probs = policy.log_prob(actions)
actor_loss = -(log_probs * advantage.detach()).mean()
# Critic loss (MSE on value prediction)
critic_loss = F.mse_loss(values, td_target.detach())
return actor_loss, critic_loss
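Neither class above shows the off-policy correction from the box (importance sampling against the logging policy). A minimal hedged sketch of an importance-weighted policy-gradient loss over logged transitions, assuming the logged propensities π_old(a|s) are available and using ratio clipping to control variance:
import torch

def off_policy_pg_loss(policy, states, actions, returns, logged_propensities, clip: float = 5.0):
    """
    REINFORCE-style loss on logged data with importance sampling.
    policy(states) is assumed to return a Categorical over items, as in REINFORCERecommender.
    states: (B, state_dim), actions: (B,), returns: (B,), logged_propensities: (B,)
    """
    dist = policy(states)
    log_probs = dist.log_prob(actions)
    # Importance weight pi_new(a|s) / pi_old(a|s), clipped and detached for variance control
    ratio = (log_probs.exp() / logged_propensities).clamp(max=clip).detach()
    return -(ratio * log_probs * returns).mean()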
Offline RL: Learning from Logged Data
In practice, we can't do online exploration—we must learn from logged interaction data:
┌─────────────────────────────────────────────────────────────────────────┐
│ OFFLINE RL FOR RECOMMENDATIONS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ THE OFFLINE RL PROBLEM: │
│ ──────────────────────── │
│ │
│ We have: Logged data D = {(s, a, r, s')} from old policy π_old │
│ We want: Learn new policy π_new that maximizes rewards │
│ │
│ Challenge: Distribution shift │
│ • π_new might choose actions never seen in D │
│ • Q-value overestimation for unseen actions │
│ • Need to constrain π_new to stay close to π_old │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SOLUTION 1: Conservative Q-Learning (CQL) │
│ ────────────────────────────────────────── │
│ │
│ Penalize Q-values for actions not in dataset: │
│ │
│ L_CQL = E_s[log Σ_a exp(Q(s,a))] - E_{s,a~D}[Q(s,a)] │
│ │
│ First term: Pushes down Q for ALL actions │
│ Second term: Pulls up Q for actions IN dataset │
│ Net effect: Conservatively low Q for unseen actions │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SOLUTION 2: Batch Constrained Q-Learning (BCQ) │
│ ────────────────────────────────────────────── │
│ │
│ Only consider actions similar to those in dataset: │
│ │
│ a* = argmax_a Q(s,a) subject to π_old(a|s) > threshold │
│ │
│ Learns a generative model of π_old to filter actions │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ SOLUTION 3: Implicit Q-Learning (IQL) │
│ ────────────────────────────────────── │
│ │
│ Never query Q for actions outside dataset: │
│ │
│ Learn V(s) separately from Q(s,a) │
│ V(s) = E_a~π_old[Q(s,a)] using expectile regression │
│ │
│ Extract policy: π(a|s) ∝ exp(Q(s,a) - V(s)) │
│ Only uses in-dataset (s,a) pairs for policy extraction │
│ │
└─────────────────────────────────────────────────────────────────────────┘
class ConservativeQLearning(nn.Module):
"""
Conservative Q-Learning (CQL) for offline recommendation.
Key idea: Penalize Q-values for out-of-distribution actions
to prevent overestimation of unseen items.
"""
def __init__(
self,
num_items: int,
state_dim: int,
hidden_dim: int = 256,
cql_alpha: float = 1.0, # CQL regularization weight
gamma: float = 0.99,
):
super().__init__()
self.num_items = num_items
self.cql_alpha = cql_alpha
self.gamma = gamma
# Q-network
self.q_network = nn.Sequential(
nn.Linear(state_dim + num_items, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)
# Target Q-network (for stability)
self.target_q_network = nn.Sequential(
nn.Linear(state_dim + num_items, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)
# Item embeddings (for one-hot or learned)
self.item_embedding = nn.Embedding(num_items, num_items)
self.item_embedding.weight.data = torch.eye(num_items) # One-hot
def get_q_values(
self,
states: torch.Tensor,
actions: torch.Tensor,
use_target: bool = False,
) -> torch.Tensor:
"""
Compute Q(s, a) for given state-action pairs.
"""
network = self.target_q_network if use_target else self.q_network
action_emb = self.item_embedding(actions)
sa = torch.cat([states, action_emb], dim=-1)
return network(sa)
def get_all_q_values(self, states: torch.Tensor) -> torch.Tensor:
"""
Compute Q(s, a) for all actions.
"""
batch_size = states.size(0)
# Expand states for all items
states_expanded = states.unsqueeze(1).expand(-1, self.num_items, -1)
# All item embeddings
all_items = torch.arange(self.num_items, device=states.device)
item_emb = self.item_embedding(all_items)
item_emb = item_emb.unsqueeze(0).expand(batch_size, -1, -1)
# Concatenate and compute Q
sa = torch.cat([states_expanded, item_emb], dim=-1)
q_values = self.q_network(sa.view(-1, sa.size(-1)))
return q_values.view(batch_size, self.num_items)
    def compute_cql_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        dones: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute CQL loss.
        CQL Loss = TD Loss + α × (logsumexp(Q(s,·)) - Q(s,a))
        """
        # Standard TD loss
        current_q = self.get_q_values(states, actions)  # (batch, 1)
        with torch.no_grad():
            # Max Q for next state (this sketch reuses the online Q-network;
            # in practice the target network would be used here)
            next_q_all = self.get_all_q_values(next_states)
            next_q_max = next_q_all.max(dim=1, keepdim=True)[0]
            target_q = rewards.unsqueeze(-1) + self.gamma * next_q_max * (1 - dones.float().unsqueeze(-1))
        td_loss = F.mse_loss(current_q, target_q)
# CQL regularization
# Push down Q for all actions, pull up Q for dataset actions
all_q = self.get_all_q_values(states) # (batch, num_items)
# logsumexp term (pushes down all Q values)
logsumexp_q = torch.logsumexp(all_q, dim=1, keepdim=True)
# Dataset Q term (pulls up Q for observed actions)
dataset_q = current_q
cql_loss = (logsumexp_q - dataset_q).mean()
# Total loss
total_loss = td_loss + self.cql_alpha * cql_loss
return total_loss, td_loss, cql_loss
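IQL, described in the box above but not implemented there, avoids ever querying Q for out-of-dataset actions by learning V(s) with expectile regression and extracting the policy with advantage weighting. A minimal sketch of the two losses, assuming q_values = Q(s,a) and values = V(s) are computed only on (s, a) pairs from the logged dataset:
import torch

def expectile_loss(q_values: torch.Tensor, values: torch.Tensor, expectile: float = 0.7) -> torch.Tensor:
    """Expectile regression: V(s) tracks an upper expectile of Q(s,a) over dataset actions."""
    diff = q_values.detach() - values
    weight = torch.where(diff > 0, torch.full_like(diff, expectile), torch.full_like(diff, 1 - expectile))
    return (weight * diff.pow(2)).mean()

def advantage_weighted_policy_loss(log_probs: torch.Tensor, q_values: torch.Tensor,
                                   values: torch.Tensor, beta: float = 3.0) -> torch.Tensor:
    """Policy extraction: weight log-likelihood of dataset actions by exp(beta * (Q - V))."""
    advantage = (q_values - values).detach()
    weights = torch.clamp(torch.exp(beta * advantage), max=100.0)
    return -(weights * log_probs).mean()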
class SlateRL(nn.Module):
"""
RL for slate recommendation (multiple items at once).
Challenges:
- Combinatorial action space: C(N, K) possible slates
- Item interactions within slate
- Position bias
"""
def __init__(
self,
num_items: int,
slate_size: int,
state_dim: int,
hidden_dim: int = 256,
):
super().__init__()
self.num_items = num_items
self.slate_size = slate_size
# State encoder
self.state_encoder = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
# Item scorer (for ranking)
self.item_encoder = nn.Embedding(num_items, hidden_dim)
self.scorer = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)
# Position encoding (for modeling position bias)
self.position_encoding = nn.Embedding(slate_size, hidden_dim)
def score_items(self, state: torch.Tensor) -> torch.Tensor:
"""
Score all items for a given state.
"""
batch_size = state.size(0)
# Encode state
state_emb = self.state_encoder(state) # (batch, hidden)
# Encode all items
all_items = torch.arange(self.num_items, device=state.device)
item_emb = self.item_encoder(all_items) # (num_items, hidden)
# Score each item
state_expanded = state_emb.unsqueeze(1).expand(-1, self.num_items, -1)
item_expanded = item_emb.unsqueeze(0).expand(batch_size, -1, -1)
combined = torch.cat([state_expanded, item_expanded], dim=-1)
scores = self.scorer(combined).squeeze(-1) # (batch, num_items)
return scores
def select_slate(
self,
state: torch.Tensor,
explore: bool = True,
) -> torch.Tensor:
"""
Select slate of K items.
Uses sequential selection with position-aware scoring.
"""
batch_size = state.size(0)
# Initial item scores
scores = self.score_items(state) # (batch, num_items)
selected = []
mask = torch.zeros(batch_size, self.num_items, device=state.device)
for pos in range(self.slate_size):
# Mask already selected items
masked_scores = scores - 1e9 * mask
if explore:
# Sample from softmax
probs = F.softmax(masked_scores, dim=-1)
item = torch.multinomial(probs, 1).squeeze(-1)
else:
# Greedy selection
item = masked_scores.argmax(dim=-1)
selected.append(item)
# Update mask
mask.scatter_(1, item.unsqueeze(1), 1)
return torch.stack(selected, dim=1) # (batch, slate_size)
Production RL Systems (2024-2025)
Major platforms using RL for recommendations:
┌─────────────────────────────────────────────────────────────────────────┐
│ RL IN PRODUCTION RECOMMENDATION SYSTEMS │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ YOUTUBE (2019-present): │
│ ──────────────────────── │
│ • "Reinforcement Learning for Slate-based Recommender Systems" │
│ • SlateQ: Handles slate-level rewards (not per-item) │
│ • Optimizes for long-term watch time, not just clicks │
│ • Uses REINFORCE with baseline for variance reduction │
│ │
│ SPOTIFY (2022-present): │
│ ──────────────────────── │
│ • Contextual bandits for playlist generation │
│ • Thompson Sampling for song selection │
│ • Optimizes for listening time + skip rate │
│ │
│ NETFLIX (2020-present): │
│ ──────────────────────── │
│ • Counterfactual evaluation for offline RL │
│ • Inverse propensity scoring for unbiased learning │
│ • Multi-objective RL (engagement + diversity + freshness) │
│ │
│ ALIBABA (2018-present): │
│ ──────────────────────── │
│ • Virtual Taobao: Simulated environment for RL │
│ • Batch RL for e-commerce recommendations │
│ • Multi-agent RL for marketplace optimization │
│ │
│ KUAISHOU (2024-2025): │
│ ──────────────────────── │
│ • OneRec uses DPO (Direct Preference Optimization) │
│ • Iterative reward model training │
│ • RLHF-style alignment for video recommendations │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ KEY LESSONS FROM PRODUCTION: │
│ │
│ 1. Start with bandits, not full RL │
│ Simpler, more stable, often sufficient │
│ │
│ 2. Offline RL is essential │
│ Can't do online exploration at scale │
│ │
│ 3. Reward shaping is critical │
│ Raw metrics often don't capture business value │
│ │
│ 4. Safety constraints are non-negotiable │
│ RL without guardrails can recommend harmful content │
│ │
│ 5. Hybrid approaches work best │
│ RL for exploration, supervised for exploitation │
│ │
└─────────────────────────────────────────────────────────────────────────┘
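The last lesson (hybrid approaches) is often implemented very simply: serve most of the slate from the supervised ranker and reserve a couple of slots for an exploration policy. A hedged sketch, where ranker_top_k and bandit_sample are hypothetical callables standing in for the exploitation and exploration policies:
import random

def hybrid_slate(user_state, ranker_top_k, bandit_sample, slate_size: int = 20, explore_slots: int = 2):
    """Fill most of the slate by exploitation (supervised ranker) and a few slots by exploration (bandit)."""
    exploit = ranker_top_k(user_state, k=slate_size)      # best items according to the supervised model
    explore = bandit_sample(user_state, k=explore_slots)  # uncertainty-driven picks
    slate = list(exploit[: slate_size - explore_slots])
    # Insert exploration items at random positions to avoid systematic position bias
    for item in explore:
        slate.insert(random.randrange(len(slate) + 1), item)
    return slate[:slate_size]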