nanochat: Andrej Karpathy's Full-Stack ChatGPT Clone
A comprehensive, equation-complete analysis of nanochat—the complete ChatGPT pipeline from tokenization through reinforcement learning. Deep dive into the modern GPT architecture (RoPE, RMSNorm, GQA, QK-norm, ReLU²), the Muon optimizer with Newton-Schulz orthogonalization, KV cache inference, and tool use.
Repository Overview
nanochat implements the complete ChatGPT pipeline in ~8,000 lines of Python. The core files:
| File | Lines | Purpose |
|---|---|---|
| nanochat/gpt.py | 311 | Modern GPT architecture: RoPE, RMSNorm, QK-norm, GQA, ReLU² |
| nanochat/muon.py | 188 | Muon optimizer with Newton-Schulz orthogonalization |
| nanochat/adamw.py | ~100 | Distributed AdamW for embeddings |
| nanochat/tokenizer.py | 399 | BPE tokenizer (rustbpe + tiktoken) |
| nanochat/engine.py | 385 | Inference engine with KV cache and tool use |
| nanochat/dataloader.py | ~200 | Data loading for all training stages |
| tasks/*.py | varies | Dataset implementations (GSM8K, MMLU, SmolTalk, etc.) |
GitHub: github.com/karpathy/nanochat
Architecture Differences from nanoGPT
nanochat modernizes the GPT-2 architecture used in nanoGPT:
| Feature | nanoGPT (GPT-2) | nanochat |
|---|---|---|
| Position encoding | Learned absolute | RoPE (relative) |
| Normalization | LayerNorm (learned γ, β) | RMSNorm (no params) |
| Q/K scaling | 1/√d_k scaling only | QK normalization |
| Activation | GELU | ReLU² |
| K/V heads | Same as Q | GQA (can differ) |
| Output logits | Raw | Soft-capped (tanh) |
| Embeddings | Tied | Untied |
| Optimizer | AdamW only | AdamW + Muon |
Model Configuration
gpt.py:26-33:
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # Number of query heads
n_kv_head: int = 6 # Number of key/value heads (GQA)
n_embd: int = 768
The speedrun "d20" model uses: n_layer=20, n_head=10, n_kv_head=10, n_embd=1280, giving 561M parameters.
The Five-Stage Training Pipeline
Stage 1: Tokenization
nanochat provides two tokenizer implementations:
- RustBPE - Custom Rust library for fast training
- tiktoken - For efficient inference
Tokenizer Training (scripts/tok_train.py)
parser.add_argument('--max_chars', type=int, default=10_000_000_000) # 10B chars
parser.add_argument('--doc_cap', type=int, default=10_000) # Max per doc
parser.add_argument('--vocab_size', type=int, default=65536) # 2^16
def text_iterator():
nchars = 0
for batch in parquets_iter_batched(split="train"):
for doc in batch:
doc_text = doc[:args.doc_cap] # Crop long documents
nchars += len(doc_text)
yield doc_text
if nchars > args.max_chars:
return
tokenizer = RustBPETokenizer.train_from_iterator(text_iterator(), args.vocab_size)
Training uses FineWeb-Edu data with documents capped at 10K characters to prevent single documents from dominating the vocabulary.
BPE Algorithm
- Initialize vocabulary with 256 byte tokens
- Count all adjacent token pairs in corpus
- Find most frequent pair
- Create a new token t for the pair (a, b) and add it to the vocabulary
- Replace all occurrences of (a, b) with t in the corpus
- Repeat until the vocabulary reaches vocab_size (a minimal sketch follows below)
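As referenced above, here is a minimal, illustrative sketch of the merge loop (not the rustbpe implementation, which operates on pre-tokenized chunks):

```python
from collections import Counter

def train_bpe(corpus: bytes, vocab_size: int) -> dict:
    tokens = list(corpus)                 # start from the 256 byte tokens
    merges = {}                           # (a, b) -> new token id
    next_id = 256
    while next_id < vocab_size:
        pairs = Counter(zip(tokens, tokens[1:]))
        if not pairs:
            break
        (a, b), _ = pairs.most_common(1)[0]    # most frequent adjacent pair
        merges[(a, b)] = next_id
        out, i = [], 0
        while i < len(tokens):                 # replace every (a, b) with the new token
            if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == (a, b):
                out.append(next_id); i += 2
            else:
                out.append(tokens[i]); i += 1
        tokens, next_id = out, next_id + 1
    return merges
```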
Special Tokens (tokenizer.py:13-25)
SPECIAL_TOKENS = [
"<|bos|>", # Beginning of sequence
"<|user_start|>", # Start of user message
"<|user_end|>", # End of user message
"<|assistant_start|>", # Start of assistant response
"<|assistant_end|>", # End of assistant response
"<|python_start|>", # Start of Python code (tool use)
"<|python_end|>", # End of Python code
"<|output_start|>", # Start of tool output
"<|output_end|>", # End of tool output
]
Pre-tokenization Pattern (tokenizer.py:30)
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
Note: \p{N}{1,2} groups runs of at most 2 digits (GPT-4's pattern uses {1,3}) to save vocabulary space in smaller models.
Dual Tokenizer Implementation (tokenizer.py:35-80)
class HuggingFaceTokenizer:
"""For training - uses HuggingFace Tokenizers library"""
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
tokenizer = HFTokenizer(BPE(byte_fallback=True, unk_token=None))
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Split(pattern=Regex(SPLIT_PATTERN), behavior="isolated"),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=SPECIAL_TOKENS)
tokenizer.train_from_iterator(text_iterator, trainer=trainer)
return cls(tokenizer)
class RustBPETokenizer:
"""For training - uses custom rustbpe for speed"""
class TiktokenTokenizer:
"""For inference - uses tiktoken for efficiency"""
Stage 2: Pretraining (scripts/base_train.py)
Training Objective: standard next-token cross-entropy over the pretraining corpus:

$$\mathcal{L}(\theta) = -\frac{1}{T}\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$

Chinchilla Scaling Laws:

Given compute budget $C \approx 6ND$ (in FLOPs), where $N$ is the parameter count and $D$ the number of training tokens:

Optimal allocation: $N_{\text{opt}} \propto C^{1/2}$, $D_{\text{opt}} \propto C^{1/2}$

The Chinchilla-optimal ratio: $D/N \approx 20$ tokens per parameter.
For the d20 model (561M parameters), this suggests ~11B tokens—exactly what nanochat uses.
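The arithmetic behind that figure (using the ~20 tokens-per-parameter rule of thumb from above):

```python
n_params = 561e6
print(f"{20 * n_params / 1e9:.1f}B tokens")  # 11.2B
```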
Complete Pretraining Loop
base_train.py:214-364:
while True:
last_step = step == num_iterations
flops_so_far = num_flops_per_token * total_batch_size * step
# Evaluate val bpb (bits per byte)
if last_step or step % eval_every == 0:
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
model.train()
# Termination
if last_step:
break
# Single training step
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach()
loss = loss / grad_accum_steps
loss.backward()
x, y, dataloader_state_dict = next(train_loader)
# Gradient clipping
if grad_clip > 0.0:
grad_norm = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
# Step optimizers with LR scheduling
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
step += 1
Pretraining Learning Rate Schedule
base_train.py:180-189:
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
Default: 0% warmup, 20% warmdown, final LR = 0%.
Muon Momentum Warmup
base_train.py:192-195:
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
Momentum ramps from 0.85 to 0.95 over first 300 steps.
Data Loading from Parquet Files
dataloader.py:10-89: Streams data from FineWeb-Edu 100B shuffle:
def tokenizing_distributed_data_loader_with_state(B, T, split, device="cuda", resume_state_dict=None):
def document_batches():
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
while True:
for filepath in parquet_paths:
pf = pq.ParquetFile(filepath)
for rg_idx in range(ddp_rank, pf.num_row_groups, ddp_world_size):
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist()
yield batch
    batches = document_batches()
    token_buffer = deque()
    needed_tokens = B * T + 1
    while True:
        # Refill the buffer until we have B*T+1 tokens
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
            for tokens in token_lists:
                token_buffer.extend(tokens)
        # Pop a contiguous chunk and build shifted inputs/targets
        chunk = torch.tensor([token_buffer.popleft() for _ in range(needed_tokens)], dtype=torch.long)
        inputs = chunk[:-1].view(B, T).to(device)
        targets = chunk[1:].view(B, T).to(device)
        yield inputs, targets, state_dict  # state_dict tracks resume position (construction elided)
Bits Per Byte (BPB) Evaluation
loss_eval.py:9-65: Tokenizer-independent metric:
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
for _ in range(steps):
        x, y = next(batches)
loss2d = model(x, y, loss_reduction='none')
num_bytes2d = token_bytes[y] # Bytes per target token
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
# All-reduce across ranks
if world_size > 1:
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
bpb = total_nats.item() / (math.log(2) * total_bytes.item())
return bpb
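The quantity returned above is the per-byte cross-entropy expressed in bits:

$$\text{bpb} = \frac{\sum_t \ell_t}{\ln 2 \cdot \sum_t \text{bytes}(y_t)}$$

where $\ell_t$ is the per-token loss in nats and $\text{bytes}(y_t)$ is the UTF-8 byte length of target token $y_t$ (looked up via token_bytes). Because it normalizes by bytes rather than tokens, bpb is comparable across tokenizers with different vocabularies.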
Stage 3: Midtraining (scripts/mid_train.py)
Introduces conversational structure, special tokens, and tool use patterns.
Midtraining Data Mixture
mid_train.py:96-111:
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K conversations
MMLU(subset="auxiliary_train", split="train"), # 100K multiple choice
GSM8K(subset="main", split="train"), # 8K math with tool use
CustomJSON(filepath=identity_conversations_filepath), # 1K identity
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of identity
SimpleSpelling(size=200000, split="train"), # 200K spelling
SpellingBee(size=80000, split="train"), # 80K letter counting
]) # Total: ~848K rows
TaskMixture Implementation
tasks/common.py:54-86:
class TaskMixture(Task):
def __init__(self, tasks, **kwargs):
self.tasks = tasks
self.lengths = [len(task) for task in self.tasks]
self.num_conversations = sum(self.lengths)
# Build shuffled index map
self.index_map = []
for task_idx, task_length in enumerate(self.lengths):
for local_idx in range(task_length):
self.index_map.append((task_idx, local_idx))
rng = random.Random(42)
rng.shuffle(self.index_map) # Deterministic shuffle
def get_example(self, index):
task_idx, local_idx = self.index_map[index]
return self.tasks[task_idx][local_idx]
Midtraining Learning Rate Schedule
mid_train.py:163-165:
def get_lr_multiplier(progress):
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
80% constant LR, then linear decay to 0 over final 20%.
Stage 4: Supervised Fine-Tuning (scripts/chat_sft.py)
Loss Masking: Only compute loss on assistant tokens:

$$\mathcal{L}_{\text{SFT}} = -\frac{1}{|\mathcal{A}|}\sum_{t \in \mathcal{A}} \log p_\theta(x_t \mid x_{<t})$$

where $\mathcal{A}$ is the set of assistant token positions.
SFT Data Mixture
chat_sft.py:83-93:
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows
CustomJSON(filepath=identity_conversations_filepath), # 1K identity
SimpleSpelling(size=300, split="train"), # 300 rows
SpellingBee(size=300, split="train"), # 300 rows
]) # Total: ~23K rows
SFT Data Collation with Loss Masking
chat_sft.py:98-118:
def sft_data_generator(dataset, batch_size):
pad_token_id = tokenizer.encode_special("<|assistant_end|>")
def collate_and_yield(batch):
nrows = len(batch)
ncols = max(len(ids) for ids, mask in batch) - 1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 = ignore
for i, (ids, mask) in enumerate(batch):
n = len(ids)
ids_tensor = torch.tensor(ids, dtype=torch.long)
inputs[i, :n-1] = ids_tensor[:-1]
row_targets = ids_tensor[1:]
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
row_targets[mask_tensor == 0] = -1 # Mask non-assistant tokens
targets[i, :n-1] = row_targets
return inputs.to(device), targets.to(device)
The mask from render_conversation() indicates which positions are assistant tokens.
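A schematic of one (ids, mask) pair (token strings are illustrative; whether <|assistant_start|> itself is supervised is a detail of render_conversation, shown here as unsupervised):

```python
# mask = 1 only on positions the assistant produced, including <|assistant_end|>.
ids  = ["<|bos|>", "<|user_start|>", "What", "is", "2+2", "?", "<|user_end|>",
        "<|assistant_start|>", "4", "<|assistant_end|>"]
mask = [0, 0, 0, 0, 0, 0, 0,
        0, 1, 1]
```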
SFT Training Loop
chat_sft.py:169-246:
for step in range(num_iterations):
last_step = step == num_iterations - 1
# Evaluate validation loss
if last_step or step % eval_every == 0:
model.eval()
losses = []
for _ in range(eval_steps):
val_inputs, val_targets = next(val_loader)
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
losses.append(loss)
val_loss = torch.stack(losses).mean()
if ddp:
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
# Gradient accumulation
num_tokens = torch.tensor(0, device=device)
for micro_step in range(grad_accum_steps):
train_inputs, train_targets = next(train_loader)
with autocast_ctx:
loss = model(train_inputs, train_targets)
loss = loss / grad_accum_steps
loss.backward()
num_tokens += (train_targets >= 0).sum()
# LR scheduler: linear decay
lrm = 1.0 - step / num_iterations
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
Stage 5: Reinforcement Learning (scripts/chat_rl.py)
nanochat implements a simplified GRPO variant—closer to REINFORCE with baseline:
Simplifications from full GRPO:
- No KL regularization to reference model
- On-policy (no PPO ratio clipping needed)
- Token-level normalization (GAPO style)
- Advantage $A_i = r_i - \bar{r}$, i.e. mean-shifted reward (no variance normalization)
GRPO Loss
For $G$ completions per prompt with rewards $r_1, \dots, r_G$:

$$\mathcal{L}(\theta) = -\frac{1}{\sum_{i=1}^{G} T_i}\sum_{i=1}^{G}\sum_{t=1}^{T_i} A_i \log \pi_\theta(x_{i,t} \mid x_{i,<t}), \qquad A_i = r_i - \frac{1}{G}\sum_{j=1}^{G} r_j$$
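A worked example of the group-relative advantage: with four sampled completions of which two are correct, the mean reward acts as the baseline, so correct samples are reinforced and incorrect ones suppressed.

```python
import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])    # binary GSM8K rewards for one prompt
advantages = rewards - rewards.mean()           # tensor([ 0.5, -0.5, -0.5,  0.5])
```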
Rollout Generation
chat_rl.py:78-140:
@torch.no_grad()
def get_batch():
for example_idx in itertools.cycle(rank_indices):
conversation = train_task[example_idx]
tokens = tokenizer.render_for_completion(conversation)
prefix_length = len(tokens)
# Generate num_samples completions
generated_sequences = []
for sampling_step in range(num_samples // device_batch_size):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
with autocast_ctx:
seqs, masks = engine.generate_batch(
tokens, num_samples=device_batch_size,
max_tokens=max_new_tokens,
temperature=temperature, top_k=top_k, seed=seed
)
generated_sequences.extend(seqs)
# Calculate rewards
rewards = []
for sample_tokens in generated_sequences:
generated_tokens = sample_tokens[prefix_length:]
generated_text = tokenizer.decode(generated_tokens)
reward = train_task.reward(conversation, generated_text)
rewards.append(reward)
# Compute advantages
rewards = torch.tensor(rewards, device=device)
mu = rewards.mean()
advantages = rewards - mu
yield sequences, inputs, targets, rewards, advantages
GSM8K Reward Function
tasks/gsm8k.py:110-117:
def reward(self, conversation, assistant_response):
is_correct = self.evaluate(conversation, assistant_response)
return float(is_correct) # Binary: 0.0 or 1.0
Evaluation extracts the answer after #### marker:
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
match = GSM_RE.search(completion)
if match:
return match.group(1).strip().replace(",", "")
return None
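For example (a hypothetical completion, using extract_answer as defined above):

```python
print(extract_answer("She earns 10*3 = 30 dollars. #### 30"))  # "30"
print(extract_answer("I am not sure."))                        # None
```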
RL Training Loop
chat_rl.py:241-300:
for step in range(num_steps):
# Evaluate pass@k
if step % eval_every == 0:
model.eval()
passk = torch.zeros(device_batch_size, device=device)
records = list(run_gsm8k_eval(val_task, tokenizer, engine, ...))
for k in range(1, device_batch_size + 1):
passk[k-1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
if ddp:
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
# Forward/backward on rollouts
for example_step in range(examples_per_rank):
sequences, inputs, targets, rewards, advantages = next(batch_iterator)
model.train()
for pass_idx in range(num_passes):
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
with autocast_ctx:
logp = -model(inputs[b0:b1], targets[b0:b1], loss_reduction='none')
# Policy gradient objective
pg_obj = (logp * advantages[b0:b1].unsqueeze(-1)).sum()
num_valid = (targets[b0:b1] >= 0).sum().clamp(min=1)
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
loss = -pg_obj # Minimize negative of objective
loss.backward()
# Update
lrm = 1.0 - step / num_steps
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
opt.step()
model.zero_grad(set_to_none=True)
Key insight: No PPO clipping needed because we're on-policy (completions generated from current model).
Checkpointing (nanochat/checkpoint_manager.py)
nanochat implements distributed checkpointing with sharded optimizer states:
checkpoint_manager.py:24-48:
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
# Save model parameters (only rank 0)
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
torch.save(model_data, model_path)
# Save metadata as JSON
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "w") as f:
json.dump(meta_data, f, indent=2)
# Optimizer state is sharded - each rank saves its own
if optimizer_data is not None:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
torch.save(optimizer_data, optimizer_path)
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    model_data = torch.load(model_path, map_location=device)
    optimizer_data = None
    if load_optimizer:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        optimizer_data = torch.load(optimizer_path, map_location=device)
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path) as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data
Checkpoint structure:
checkpoint_dir/
├── model_000100.pt # Model weights (saved by rank 0 only)
├── meta_000100.json # Metadata: step, config, metrics
├── optim_000100_rank0.pt # Optimizer state shard for rank 0
├── optim_000100_rank1.pt # Optimizer state shard for rank 1
├── optim_000100_rank2.pt # ... etc
└── optim_000100_rank7.pt
Rotary Position Embeddings (RoPE)
Mathematical Foundation
RoPE encodes position through rotation in the complex plane. For each pair of dimensions $(x_{2i}, x_{2i+1})$, rotation by angle $\theta$:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$
Position-Dependent Rotation
For position $m$ and dimension pair $i$:

$$\theta_{m,i} = m \cdot \text{base}^{-2i/d_{\text{head}}}$$

where base = 10,000. This creates a geometric spectrum of frequencies, from fast-rotating low dimensions to slowly rotating high dimensions.
Complete Rotation Matrix
For a vector $x \in \mathbb{R}^{d}$, the full rotation $R_m$ is block-diagonal, applying a 2×2 rotation with angle $m\theta_i$ to each dimension pair:

$$R_m = \operatorname{diag}\big(R(m\theta_0), R(m\theta_1), \dots, R(m\theta_{d/2-1})\big), \qquad R(\alpha) = \begin{pmatrix} \cos\alpha & -\sin\alpha \\ \sin\alpha & \cos\alpha \end{pmatrix}$$
Relative Position Property
The key insight: after rotating queries and keys by their positions,

$$(R_m q)^\top (R_n k) = q^\top R_{n-m}\, k$$

The dot product depends only on the relative position $n - m$, not the absolute positions.

Proof (2D case): Let $q' = R_m q$ and $k' = R_n k$, where $R_m$ rotates by angle $m\theta$ and $R_n$ by $n\theta$.

$$q'^\top k' = q^\top R_m^\top R_n\, k$$

Expanding and using $R_m^\top = R_{-m}$ together with the angle-addition identities for $\cos$ and $\sin$:

$$R_m^\top R_n = R_{-m} R_n = R_{n-m}$$

so $q'^\top k' = q^\top R_{n-m}\, k$, which depends only on $n - m$.
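A quick numerical check of this property in 2D (illustrative; the choice $\theta = 0.1$ is arbitrary):

```python
import math
import torch

def rot(m: int, theta: float = 0.1) -> torch.Tensor:
    a = m * theta
    return torch.tensor([[math.cos(a), -math.sin(a)],
                         [math.sin(a),  math.cos(a)]])

q, k = torch.randn(2), torch.randn(2)
m, n = 7, 12
lhs = (rot(m) @ q) @ (rot(n) @ k)   # dot product of position-rotated q and k
rhs = q @ (rot(n - m) @ k)          # rotation by the relative offset only
print(torch.allclose(lhs, rhs, atol=1e-5))  # True
```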
Implementation
Precomputation (gpt.py:190-204):
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
if device is None:
device = self.transformer.wte.weight.device
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
inv_freq = 1.0 / (base ** (channel_range / head_dim))
t = torch.arange(seq_len, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16()
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
return cos, sin
Application (gpt.py:41-49):
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:]
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3)
out = out.to(x.dtype)
return out
RMSNorm
Mathematical Definition
RMSNorm without learnable parameters:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}$$

Compare to LayerNorm:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

RMSNorm removes:
- Mean centering ($x - \mu$)
- Learned scale ($\gamma$)
- Learned shift ($\beta$)
Implementation
gpt.py:36-38:
def norm(x):
return F.rms_norm(x, (x.size(-1),))
This purely functional form adds no learnable parameters to store or tune.
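A quick equivalence check of what this one-liner computes (a sketch assuming PyTorch ≥ 2.4, where F.rms_norm is available):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 768)
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
print(torch.allclose(F.rms_norm(x, (x.size(-1),), eps=1e-6), manual, atol=1e-5))  # True
```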
QK Normalization
Mathematical Definition
After RoPE, normalize each query and key vector to unit RMS (equivalently, L2 norm $\sqrt{d_{\text{head}}}$):

$$\hat{q} = \frac{q}{\text{RMS}(q)}, \qquad \hat{k} = \frac{k}{\text{RMS}(k)}$$

The attention logits become bounded, scaled cosine similarities:

$$\text{score}_{ij} = \frac{\hat{q}_i^\top \hat{k}_j}{\sqrt{d_{\text{head}}}} = \sqrt{d_{\text{head}}}\;\cos\angle(q_i, k_j)$$
Benefits
- Bounded logits: No risk of softmax overflow
- Scale-invariant: Works across model sizes without tuning
- Training stability: Well-behaved gradients
Implementation
gpt.py:77:
q, k = norm(q), norm(k) # QK norm
ReLU² Activation
Mathematical Definition

$$\text{ReLU}^2(x) = \big(\max(0, x)\big)^2$$
Properties
| Property | ReLU | GELU | ReLU² |
|---|---|---|---|
| Sparsity | High | Low | High |
| Derivative at 0 | Undefined | ~0.5 | 0 (continuous) |
| Computation | Fast | Slow | Fast |
Implementation
gpt.py:118-122:
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square() # ReLU² activation
x = self.c_proj(x)
return x
Group Query Attention (GQA)
Mathematical Definition
Standard MHA: every query head has its own key/value head:

$$K, V \in \mathbb{R}^{H \times T \times d_{\text{head}}}$$

GQA: only $H_{kv} < H$ key/value heads, shared across groups of query heads:

$$K, V \in \mathbb{R}^{H_{kv} \times T \times d_{\text{head}}}, \qquad \text{query head } i \text{ uses KV head } \big\lfloor i \,/\, (H / H_{kv}) \big\rfloor$$

Each KV head serves $H / H_{kv}$ query heads.
Memory Savings
KV cache size without GQA:

$$2 \times L \times H \times T \times d_{\text{head}} \times \text{bytes per element}$$

With GQA (ratio $r = H / H_{kv}$), this shrinks by a factor of $r$:

$$2 \times L \times H_{kv} \times T \times d_{\text{head}} \times \text{bytes per element}$$

Example (illustrative config: 20 layers, 16 query heads, head dim 64, 1024-token context, bf16; note the d20 speedrun model itself sets n_kv_head = n_head):
- Without GQA: 2 × 20 × 16 × 1024 × 64 × 2 bytes ≈ 84 MB
- With 4:1 GQA: ≈ 21 MB (4× reduction)
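The same numbers, computed directly for the illustrative config above (bf16 = 2 bytes per element):

```python
layers, kv_heads, seq_len, head_dim, elem_bytes = 20, 16, 1024, 64, 2
full = 2 * layers * kv_heads * seq_len * head_dim * elem_bytes   # keys + values
print(f"{full / 1e6:.0f} MB without GQA, {full / 4 / 1e6:.0f} MB with 4:1 GQA")
# 84 MB without GQA, 21 MB with 4:1 GQA
```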
Implementation
gpt.py:51-64:
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.head_dim = config.n_embd // self.n_head
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
gpt.py:87-91:
enable_gqa = self.n_head != self.n_kv_head
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
Logit Soft-Capping
Mathematical Definition

$$\text{softcap}(z) = c \cdot \tanh\!\left(\frac{z}{c}\right)$$

where $c = 15$ is the cap value; outputs are bounded to $(-c, c)$.
Properties
| Input | Output | Effect |
|---|---|---|
| 1.0 | ~1.0 | Unchanged |
| 5.0 | ~4.8 | Slight compression |
| 15.0 | ~11.4 | Significant compression |
| 50.0 | ~15.0 | Saturated |
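These values can be reproduced directly (a quick check, not repo code):

```python
import torch

softcap = 15.0
z = torch.tensor([1.0, 5.0, 15.0, 50.0])
print(softcap * torch.tanh(z / softcap))  # ≈ [1.0, 4.8, 11.4, 15.0]
```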
Benefits
- Numerical stability: Prevents softmax overflow
- Prevents overconfidence: without capping, a few extreme logits can push softmax probabilities to ≈ 1.0
- Differentiable: Unlike hard clipping
Implementation
gpt.py:266-270:
softcap = 15
logits = self.lm_head(x)
logits = logits[..., :self.config.vocab_size] # Remove padding
logits = logits.float() # FP32 for stability
logits = softcap * torch.tanh(logits / softcap)
Untied Embeddings
Unlike nanoGPT's weight tying, nanochat uses separate embedding and output matrices:
gpt.py:146-150:
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
Why untie? Input embeddings optimize for understanding tokens; output projections optimize for predicting tokens. These roles differ enough that separate parameters improve capacity.
Weight Initialization
Standard Initialization
gpt.py:177-187:
def _init_weights(self, module):
if isinstance(module, nn.Linear):
fan_out = module.weight.size(0)
fan_in = module.weight.size(1)
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
Zero Initialization for Output Projections
gpt.py:163-168:
def init_weights(self):
self.apply(self._init_weights)
torch.nn.init.zeros_(self.lm_head.weight)
for block in self.transformer.h:
torch.nn.init.zeros_(block.mlp.c_proj.weight)
torch.nn.init.zeros_(block.attn.c_proj.weight)
Zero-initializing output projections ensures residual connections initially pass through unchanged.
The Muon Optimizer
Algorithm Overview
Muon (MomentUm Orthogonalized by Newton-schulz) orthogonalizes gradient updates:
1. Compute the momentum buffer (EMA of gradients): $B_t = \mu B_{t-1} + (1-\mu) G_t$
2. Optionally apply Nesterov momentum: $\tilde{G}_t = (1-\mu) G_t + \mu B_t$ (otherwise $\tilde{G}_t = B_t$)
3. Orthogonalize via Newton-Schulz: $O_t = \mathrm{NS}_5(\tilde{G}_t)$
4. Apply the update with aspect-ratio scaling: $W_t = W_{t-1} - \eta \sqrt{\max\!\big(1, \tfrac{n_{\text{rows}}}{n_{\text{cols}}}\big)}\, O_t$
Newton-Schulz Orthogonalization
To orthogonalize a matrix $G$ with SVD $G = U\Sigma V^\top$, we want to replace it with $UV^\top$, making all singular values equal to 1.

Newton-Schulz iteration (starting from $X_0 = G / \|G\|_F$):

$$X_{k+1} = a X_k + b\,(X_k X_k^\top)X_k + c\,(X_k X_k^\top)^2 X_k$$

with quintic coefficients $(a, b, c) = (3.4445, -4.7750, 2.0315)$.

Five iterations give the approximation $X_5 \approx UV^\top$.
Why Orthogonalization Helps
Gradient updates often have low-rank structure—few directions receive most learning signal. This causes:
- Slow learning in underrepresented directions
- Oscillation in dominant directions
Orthogonalization redistributes signal: all singular values become 1, no direction is preferred.
Implementation
muon.py:10-36:
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
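A usage sketch (assuming the function above is in scope; the singular values are only approximately equalized, and the bf16 arithmetic adds some noise):

```python
import torch

G = torch.randn(1024, 256)                      # a tall "gradient" matrix
X = zeropower_via_newtonschulz5(G, steps=5)
print(torch.linalg.svdvals(G)[:4])              # widely spread singular values
print(torch.linalg.svdvals(X.float())[:4])      # all roughly near 1
```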
muon.py:69-83:
@torch.no_grad()
def step(self):
for group in self.param_groups:
params = group["params"]
for p in params:
g = p.grad
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
Note the aspect-ratio scaling: max(1, p.size(-2) / p.size(-1))**0.5.
Distributed Muon (DistMuon)
For multi-GPU training, DistMuon handles gradient synchronization:
muon.py:86-187:
class DistMuon(torch.optim.Optimizer):
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Kick off reduce_scatter to average gradients
all_reduce_futures = []
for group in self.param_groups:
params = group["params"]
for base_i in range(0, len(params), world_size):
owner_idx = base_i + rank
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
rs_output = params[owner_idx].grad if owner_idx < len(params) else zero_buffer
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True)
all_reduce_futures.append(work.get_future())
# Each rank computes update for its owned params, then all_gather
future_idx = 0
all_gather_futures = []
        for group in self.param_groups:
            params = group["params"]
            for base_i in range(0, len(params), world_size):
owner_idx = base_i + rank
all_reduce_futures[future_idx].wait()
future_idx += 1
if owner_idx < len(params):
p = params[owner_idx]
g = p.grad # Now averaged across ranks
# Muon update
buf = self.state[p].setdefault("momentum_buffer", torch.zeros_like(g))
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
p.add_(g, alpha=-group["lr"] * scale)
# All_gather to replicate updated params
work = dist.all_gather(ag_output, ag_input, async_op=True)
all_gather_futures.append(work.get_future())
torch.futures.collect_all(all_gather_futures).wait()
Communication pattern: reduce_scatter → compute → all_gather.
Distributed AdamW (DistAdamW)
adamw.py:10-78: ZeRO-2 style sharded optimizer states:
class DistAdamW(torch.optim.Optimizer):
@torch.compile
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# reduce_scatter gradients
for group in self.param_groups:
for p in group["params"]:
rank_size = p.shape[0] // world_size
grad_slice = torch.empty_like(p[:rank_size])
dist.reduce_scatter_tensor(grad_slice, p.grad, op=dist.ReduceOp.AVG)
# Update only this rank's slice
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
state = self.state[p]
if not state:
state['step'] = torch.tensor(0, dtype=torch.int64)
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
# AdamW update on slice
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
state['step'] += 1
exp_avg.mul_(beta1).add_(grad_slice, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad_slice, grad_slice, value=1 - beta2)
# Bias correction and update
step_size = lr * (torch.sqrt(1 - beta2**t) / (1 - beta1**t))
p_slice.addcdiv_(exp_avg, exp_avg_sq.sqrt() + eps, value=-step_size)
# Decoupled weight decay
p_slice.mul_(1 - lr * wd)
# all_gather to replicate
dist.all_gather_into_tensor(p, p_slice)
Memory savings: each rank stores optimizer states for only $1/\text{world\_size}$ of the parameters.
Dual Optimizer Setup
nanochat uses different optimizers for different parameter types:
AdamW for Embeddings
gpt.py:217-235:
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
model_dim = self.config.n_embd
# Separate parameters
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
# Scale LR by ∝1/√dmodel
dmodel_lr_scale = (model_dim / 768) ** -0.5
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
]
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Muon for linear layers
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
return [adamw_optimizer, muon_optimizer]
Learning Rate Scaling: the AdamW learning rates are scaled by

$$\eta = \eta_{\text{base}} \cdot \left(\frac{d_{\text{model}}}{768}\right)^{-1/2}$$

For the d20 model ($d_{\text{model}} = 1280$): scale factor $= (1280/768)^{-0.5} \approx 0.775$.
KV Cache for Inference
Data Structure
engine.py:83-100:
class KVCache:
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0
def reset(self):
self.pos = 0
def get_pos(self):
return self.pos
Cache Operations
Insert and retrieve (engine.py:135-160):
def insert_kv(self, layer_idx, k, v):
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
# Dynamic growth if needed
if t1 > self.kv_cache.size(4):
t_needed = (t1 + 1024 + 1023) & ~1023 # Round up to 1024
additional = torch.empty(..., dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional], dim=4)
# Insert k, v
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
# Return full cache view
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
# Advance pos after last layer
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_view
Complexity Analysis
Without KV cache:
- Each of the $T$ generation steps re-runs the forward pass over all $t$ tokens produced so far
- Total: $O(T^2)$ token positions processed, $O(T^3)$ attention work

With KV cache:
- Each step processes only the 1 new token, attending to cached keys and values
- Total: $O(T)$ token positions processed, $O(T^2)$ attention work

For a 1000-token generation, this is roughly 500× less compute spent re-processing the prefix (the savings factor is ≈ T/2).
Tool Use (Calculator)
State Machine
engine.py:184-191:
class RowState:
def __init__(self, current_tokens=None):
self.current_tokens = current_tokens or []
self.forced_tokens = deque()
self.in_python_block = False
self.python_expr_tokens = []
self.completed = False
Safe Evaluation
engine.py:47-80:
def use_calculator(expr):
expr = expr.replace(",", "") # Remove number commas
# Pure math expression
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # Disallow power
return None
return eval_with_timeout(expr)
# String operations (e.g., .count())
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', ...]
if any(pattern in expr.lower() for pattern in dangerous_patterns):
return None
if '.count(' not in expr:
return None
return eval_with_timeout(expr)
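A few illustrative calls (assuming the elided dangerous_patterns list does not block the string example):

```python
print(use_calculator("1,234 + 8*7"))              # 1290
print(use_calculator("2**10"))                    # None (power is disallowed)
print(use_calculator("'strawberry'.count('r')"))  # 3
```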
Generation Loop Integration
engine.py:282-295:
if next_token == python_start:
state.in_python_block = True
state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
state.in_python_block = False
if state.python_expr_tokens:
expr = self.tokenizer.decode(state.python_expr_tokens)
result = use_calculator(expr)
if result is not None:
result_tokens = self.tokenizer.encode(str(result))
state.forced_tokens.append(output_start)
state.forced_tokens.extend(result_tokens)
state.forced_tokens.append(output_end)
elif state.in_python_block:
state.python_expr_tokens.append(next_token)
Complete Attention Module
gpt.py:51-109:
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.head_dim = config.n_embd // self.n_head
assert config.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project to Q, K, V
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply RoPE and QK norm
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
# KV cache
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
# Attention with GQA
enable_gqa = self.n_head != self.n_kv_head
if kv_cache is None or q.size(2) == k.size(2):
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif q.size(2) == 1:
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
# Chunked inference with custom mask
...
y = y.transpose(1, 2).contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
Training Cost Tiers
| Tier | Model | Layers | Dim | Parameters | Tokens | Cost | Time (8×H100) |
|---|---|---|---|---|---|---|---|
| $100 | d20 | 20 | 1280 | 561M | 11B | ~$100 | 4h |
| $300 | d26 | 26 | 1664 | ~1B | 20B | ~$300 | 12h |
| $800 | d32 | 32 | 2048 | 1.9B | 38B | ~$800 | 33h |
Model sizing from depth: layers = d, dim = 64*d, heads = ceil(d*64/128).
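The sizing rule, written out (a small helper reproducing the formulas above, not repo code):

```python
import math

def model_dims(d: int) -> dict:
    # depth -> (layers, model dim, attention heads) per the sizing rule above
    return dict(n_layer=d, n_embd=64 * d, n_head=math.ceil(d * 64 / 128))

print(model_dims(20))  # {'n_layer': 20, 'n_embd': 1280, 'n_head': 10}
print(model_dims(32))  # {'n_layer': 32, 'n_embd': 2048, 'n_head': 16}
```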