
nanochat: Andrej Karpathy's Full-Stack ChatGPT Clone

A comprehensive, equation-complete analysis of nanochat—the complete ChatGPT pipeline from tokenization through reinforcement learning. Deep dive into the modern GPT architecture (RoPE, RMSNorm, GQA, QK-norm, ReLU²), the Muon optimizer with Newton-Schulz orthogonalization, KV cache inference, and tool use.


Repository Overview

nanochat implements the complete ChatGPT pipeline in ~8,000 lines of Python. The core files:

| File | Lines | Purpose |
|------|-------|---------|
| nanochat/gpt.py | 311 | Modern GPT architecture: RoPE, RMSNorm, QK-norm, GQA, ReLU² |
| nanochat/muon.py | 188 | Muon optimizer with Newton-Schulz orthogonalization |
| nanochat/adamw.py | ~100 | Distributed AdamW for embeddings |
| nanochat/tokenizer.py | 399 | BPE tokenizer (rustbpe + tiktoken) |
| nanochat/engine.py | 385 | Inference engine with KV cache and tool use |
| nanochat/dataloader.py | ~200 | Data loading for all training stages |
| tasks/*.py | varies | Dataset implementations (GSM8K, MMLU, SmolTalk, etc.) |

GitHub: github.com/karpathy/nanochat

Architecture Differences from nanoGPT

nanochat modernizes the GPT-2 architecture used in nanoGPT:

| Feature | nanoGPT (GPT-2) | nanochat |
|---------|-----------------|----------|
| Position encoding | Learned absolute | RoPE (relative) |
| Normalization | LayerNorm (learned γ, β) | RMSNorm (no params) |
| Q/K scaling | $1/\sqrt{d_k}$ | QK normalization |
| Activation | GELU | ReLU² |
| K/V heads | Same as Q | GQA (can differ) |
| Output logits | Raw | Soft-capped (tanh) |
| Embeddings | Tied | Untied |
| Optimizer | AdamW only | AdamW + Muon |

Model Configuration

gpt.py:26-33:

Python
@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6       # Number of query heads
    n_kv_head: int = 6    # Number of key/value heads (GQA)
    n_embd: int = 768

The speedrun "d20" model uses: n_layer=20, n_head=10, n_kv_head=10, n_embd=1280, giving 561M parameters.

The Five-Stage Training Pipeline

Stage 1: Tokenization

nanochat provides two tokenizer implementations:

  1. RustBPE - Custom Rust library for fast training
  2. tiktoken - For efficient inference

Tokenizer Training (scripts/tok_train.py)

Python
parser.add_argument('--max_chars', type=int, default=10_000_000_000)  # 10B chars
parser.add_argument('--doc_cap', type=int, default=10_000)            # Max per doc
parser.add_argument('--vocab_size', type=int, default=65536)          # 2^16

def text_iterator():
    nchars = 0
    for batch in parquets_iter_batched(split="train"):
        for doc in batch:
            doc_text = doc[:args.doc_cap]  # Crop long documents
            nchars += len(doc_text)
            yield doc_text
            if nchars > args.max_chars:
                return

tokenizer = RustBPETokenizer.train_from_iterator(text_iterator(), args.vocab_size)

Training uses FineWeb-Edu data with documents capped at 10K characters to prevent single documents from dominating the vocabulary.

BPE Algorithm

  1. Initialize the vocabulary $\mathcal{V}$ with 256 byte tokens
  2. Count all adjacent token pairs in the corpus
  3. Find the most frequent pair $(a, b)$
  4. Create a new token $ab$: $\mathcal{V} \leftarrow \mathcal{V} \cup \{ab\}$
  5. Replace all occurrences of $(a, b)$ with $ab$ in the corpus
  6. Repeat until $|\mathcal{V}|$ reaches the target vocabulary size
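
As a concrete illustration, here is a minimal, unoptimized sketch of this merge loop in plain Python. It is not the rustbpe implementation, just the algorithm above applied to byte-level token lists.

Python
from collections import Counter

def train_bpe(corpus_tokens, target_vocab_size):
    """corpus_tokens: list of token lists, initially raw bytes (0..255)."""
    vocab_size = 256
    merges = []
    while vocab_size < target_vocab_size:
        # Count all adjacent pairs across the corpus
        pair_counts = Counter()
        for tokens in corpus_tokens:
            pair_counts.update(zip(tokens, tokens[1:]))
        if not pair_counts:
            break
        (a, b), _ = pair_counts.most_common(1)[0]
        new_token = vocab_size  # id of the freshly minted merged token
        merges.append(((a, b), new_token))
        # Replace every occurrence of (a, b) with the new token
        for i, tokens in enumerate(corpus_tokens):
            merged, j = [], 0
            while j < len(tokens):
                if j + 1 < len(tokens) and (tokens[j], tokens[j + 1]) == (a, b):
                    merged.append(new_token)
                    j += 2
                else:
                    merged.append(tokens[j])
                    j += 1
            corpus_tokens[i] = merged
        vocab_size += 1
    return merges

merges = train_bpe([list("hello hello world".encode("utf-8"))], target_vocab_size=260)
print(merges)  # learned merge rules as ((left, right), new_token_id)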

Special Tokens (tokenizer.py:13-25)

Python
SPECIAL_TOKENS = [
    "<|bos|>",              # Beginning of sequence
    "<|user_start|>",       # Start of user message
    "<|user_end|>",         # End of user message
    "<|assistant_start|>",  # Start of assistant response
    "<|assistant_end|>",    # End of assistant response
    "<|python_start|>",     # Start of Python code (tool use)
    "<|python_end|>",       # End of Python code
    "<|output_start|>",     # Start of tool output
    "<|output_end|>",       # End of tool output
]

Pre-tokenization Pattern (tokenizer.py:30)

Python
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

Note: \p{N}{1,2} uses 1-2 digits (not 3 like GPT-4) to save vocabulary space in smaller models.
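
To see the effect of the pattern (including the 1-2 digit grouping), here is a small sketch using the third-party regex package, which supports the possessive quantifiers in the pattern. The sample string and the splits shown in the comment are my own illustration, not repo output.

Python
import regex  # third-party package; the stdlib re module does not support ?+ / ++

SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

print(regex.findall(SPLIT_PATTERN, "Hello world 123456"))
# ['Hello', ' world', ' ', '12', '34', '56']  -- digit runs split into 1-2 digit chunks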

Dual Tokenizer Implementation (tokenizer.py:35-80)

Python
class HuggingFaceTokenizer:
    """For training - uses HuggingFace Tokenizers library"""
    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        tokenizer = HFTokenizer(BPE(byte_fallback=True, unk_token=None))
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Split(pattern=Regex(SPLIT_PATTERN), behavior="isolated"),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        ])
        trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=SPECIAL_TOKENS)
        tokenizer.train_from_iterator(text_iterator, trainer=trainer)
        return cls(tokenizer)

class RustBPETokenizer:
    """For training - uses custom rustbpe for speed"""

class TiktokenTokenizer:
    """For inference - uses tiktoken for efficiency"""

Stage 2: Pretraining (scripts/base_train.py)

Training Objective:

$$\mathcal{L}_{\text{pretrain}} = -\frac{1}{T} \sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})$$

Chinchilla Scaling Laws:

Given compute budget $C$ (in FLOPs):

$$C \approx 6 \cdot N \cdot D$$

Optimal allocation:

$$N^* \propto C^{0.5}, \quad D^* \propto C^{0.5}$$

The Chinchilla-optimal ratio:

$$D^* \approx 20 \cdot N^*$$

For the d20 model (561M parameters), this suggests ~11B tokens—exactly what nanochat uses.
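
A quick back-of-the-envelope check of these numbers (the $6ND$ constant is the usual approximation):

Python
N = 561e6                 # d20 parameter count
D = 20 * N                # Chinchilla-optimal token count, ~11.2B
C = 6 * N * D             # approximate training compute, ~3.8e19 FLOPs
print(f"tokens: {D:.3g}, FLOPs: {C:.3g}")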

Complete Pretraining Loop

base_train.py:214-364:

Python
while True:
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * total_batch_size * step

    # Evaluate val bpb (bits per byte)
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        model.train()

    # Termination
    if last_step:
        break

    # Single training step
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
        x, y, dataloader_state_dict = next(train_loader)

    # Gradient clipping
    if grad_clip > 0.0:
        grad_norm = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)

    # Step optimizers with LR scheduling
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

    step += 1

Pretraining Learning Rate Schedule

base_train.py:180-189:

Python
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

Default: 0% warmup, 20% warmdown, final LR = 0%.

Muon Momentum Warmup

base_train.py:192-195:

Python
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

Momentum ramps from 0.85 to 0.95 over first 300 steps.

Data Loading from Parquet Files

dataloader.py:10-89: Streams data from FineWeb-Edu 100B shuffle:

Python
def tokenizing_distributed_data_loader_with_state(B, T, split, device="cuda", resume_state_dict=None):
    def document_batches():
        parquet_paths = list_parquet_files()
        parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
        while True:
            for filepath in parquet_paths:
                pf = pq.ParquetFile(filepath)
                # Each rank reads a disjoint, strided subset of row groups
                for rg_idx in range(ddp_rank, pf.num_row_groups, ddp_world_size):
                    rg = pf.read_row_group(rg_idx)
                    batch = rg.column('text').to_pylist()
                    yield batch

    batches = document_batches()
    token_buffer = deque()
    needed_tokens = B * T + 1
    while True:
        # Refill the buffer until a full batch of B*T+1 tokens is available
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
            for tokens in token_lists:
                token_buffer.extend(tokens)
        # Build inputs/targets, shifted by one token
        chunk = torch.tensor([token_buffer.popleft() for _ in range(needed_tokens)], dtype=torch.long)
        inputs = chunk[:-1].view(B, T).to(device)
        targets = chunk[1:].view(B, T).to(device)
        yield inputs, targets, state_dict  # state_dict tracks the resume position (construction omitted)

Bits Per Byte (BPB) Evaluation

loss_eval.py:9-65: Tokenizer-independent metric:

Python
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
    total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
    total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
    for _ in range(steps):
        x, y = next(batches)
        loss2d = model(x, y, loss_reduction='none')
        num_bytes2d = token_bytes[y]  # Bytes per target token
        total_nats += (loss2d * (num_bytes2d > 0)).sum()
        total_bytes += num_bytes2d.sum()
    # All-reduce across ranks
    if world_size > 1:
        dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
    bpb = total_nats.item() / (math.log(2) * total_bytes.item())
    return bpb

$$\text{BPB} = \frac{\sum_t \mathcal{L}_t}{\ln 2 \cdot \sum_t \text{bytes}(t)}$$
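
As a sanity check of the formula, here is a tiny worked example with hypothetical totals (not numbers from an actual nanochat run):

Python
import math

# Hypothetical accumulated totals over one evaluation pass
total_nats = 2.1e9    # sum of per-token cross-entropy losses (in nats)
total_bytes = 1.0e9   # total UTF-8 bytes of the target tokens
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f} bits per byte")  # ~3.03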

Stage 3: Midtraining (scripts/mid_train.py)

Introduces conversational structure, special tokens, and tool use patterns.

Midtraining Data Mixture

mid_train.py:96-111:

Python
train_dataset = TaskMixture([
    SmolTalk(split="train"),                    # 460K conversations
    MMLU(subset="auxiliary_train", split="train"),  # 100K multiple choice
    GSM8K(subset="main", split="train"),        # 8K math with tool use
    CustomJSON(filepath=identity_conversations_filepath),  # 1K identity
    CustomJSON(filepath=identity_conversations_filepath),  # 2 epochs of identity
    SimpleSpelling(size=200000, split="train"), # 200K spelling
    SpellingBee(size=80000, split="train"),     # 80K letter counting
])  # Total: ~848K rows

TaskMixture Implementation

tasks/common.py:54-86:

Python
class TaskMixture(Task):
    def __init__(self, tasks, **kwargs):
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)
        # Build shuffled index map
        self.index_map = []
        for task_idx, task_length in enumerate(self.lengths):
            for local_idx in range(task_length):
                self.index_map.append((task_idx, local_idx))
        rng = random.Random(42)
        rng.shuffle(self.index_map)  # Deterministic shuffle

    def get_example(self, index):
        task_idx, local_idx = self.index_map[index]
        return self.tasks[task_idx][local_idx]

Midtraining Learning Rate Schedule

mid_train.py:163-165:

Python
def get_lr_multiplier(progress):
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

$$\eta(t) = \begin{cases} \eta_{\max} & \text{if } t/T < 0.8 \\ \eta_{\max} \cdot \frac{1 - t/T}{0.2} & \text{otherwise} \end{cases}$$

80% constant LR, then linear decay to 0 over final 20%.

Stage 4: Supervised Fine-Tuning (scripts/chat_sft.py)

Loss Masking: Only compute loss on assistant tokens:

$$\mathcal{L}_{\text{SFT}} = -\frac{1}{|a|} \sum_{t \in a} \log P_\theta(x_t \mid x_{<t})$$

where $a$ is the set of assistant token positions.
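
The same masking can be expressed directly through cross-entropy's ignore_index; a minimal sketch, assuming targets use -1 for non-assistant positions as in the collation code shown later:

Python
import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(1, 6, vocab_size)           # (batch, time, vocab)
targets = torch.tensor([[-1, -1, 3, 5, 2, -1]])  # -1 marks user/system tokens

# The loss is averaged over assistant positions only (the three non-ignored targets)
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1), ignore_index=-1)
print(loss)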

SFT Data Mixture

chat_sft.py:83-93:

Python
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"),      # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"),        # 8K rows
    SmolTalk(split="train", stop=10_000),       # 10K rows
    CustomJSON(filepath=identity_conversations_filepath),  # 1K identity
    SimpleSpelling(size=300, split="train"),    # 300 rows
    SpellingBee(size=300, split="train"),       # 300 rows
])  # Total: ~23K rows

SFT Data Collation with Loss Masking

chat_sft.py:98-118:

Python
def sft_data_generator(dataset, batch_size):
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")

    def collate_and_yield(batch):
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask in batch) - 1
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long)  # -1 = ignore

        for i, (ids, mask) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            inputs[i, :n-1] = ids_tensor[:-1]
            row_targets = ids_tensor[1:]
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            row_targets[mask_tensor == 0] = -1  # Mask non-assistant tokens
            targets[i, :n-1] = row_targets
        return inputs.to(device), targets.to(device)

The mask from render_conversation() indicates which positions are assistant tokens.

SFT Training Loop

chat_sft.py:169-246:

Python
for step in range(num_iterations):
    last_step = step == num_iterations - 1

    # Evaluate validation loss
    if last_step or step % eval_every == 0:
        model.eval()
        losses = []
        for _ in range(eval_steps):
            val_inputs, val_targets = next(val_loader)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)
        val_loss = torch.stack(losses).mean()
        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)

    # Gradient accumulation
    num_tokens = torch.tensor(0, device=device)
    for micro_step in range(grad_accum_steps):
        train_inputs, train_targets = next(train_loader)
        with autocast_ctx:
            loss = model(train_inputs, train_targets)
        loss = loss / grad_accum_steps
        loss.backward()
        num_tokens += (train_targets >= 0).sum()

    # LR scheduler: linear decay
    lrm = 1.0 - step / num_iterations
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

Stage 5: Reinforcement Learning (scripts/chat_rl.py)

nanochat implements a simplified GRPO variant—closer to REINFORCE with baseline:

Simplifications from full GRPO:

  1. No KL regularization to reference model
  2. On-policy (no PPO ratio clipping needed)
  3. Token-level normalization (GAPO style)
  4. Advantage = $r - \bar{r}$ (no variance normalization)

GRPO Loss

For $K$ completions per problem with rewards $r_i$:

$$A_i = r_i - \bar{r}, \quad \bar{r} = \frac{1}{K}\sum_{j=1}^{K} r_j$$

$$\mathcal{L}_{\text{GRPO}} = -\frac{1}{K} \sum_{i=1}^{K} A_i \sum_{t} \log P_\theta(y_{i,t} \mid x, y_{i,<t})$$
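
For example, with $K = 4$ samples where two completions are correct (binary GSM8K rewards), the advantages are just centered rewards:

Python
import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])   # two of four samples correct
advantages = rewards - rewards.mean()          # tensor([ 0.5, -0.5, -0.5,  0.5])
# Correct samples get their log-probs pushed up, incorrect ones pushed down.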

Rollout Generation

chat_rl.py:78-140:

Python
@torch.no_grad()
def get_batch():
    for example_idx in itertools.cycle(rank_indices):
        conversation = train_task[example_idx]
        tokens = tokenizer.render_for_completion(conversation)
        prefix_length = len(tokens)

        # Generate num_samples completions
        generated_sequences = []
        for sampling_step in range(num_samples // device_batch_size):
            seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
            with autocast_ctx:
                seqs, masks = engine.generate_batch(
                    tokens, num_samples=device_batch_size,
                    max_tokens=max_new_tokens,
                    temperature=temperature, top_k=top_k, seed=seed
                )
            generated_sequences.extend(seqs)

        # Calculate rewards
        rewards = []
        for sample_tokens in generated_sequences:
            generated_tokens = sample_tokens[prefix_length:]
            generated_text = tokenizer.decode(generated_tokens)
            reward = train_task.reward(conversation, generated_text)
            rewards.append(reward)

        # Compute advantages
        rewards = torch.tensor(rewards, device=device)
        mu = rewards.mean()
        advantages = rewards - mu

        # (Padding/collation of the sampled sequences into inputs/targets omitted)
        yield sequences, inputs, targets, rewards, advantages

GSM8K Reward Function

tasks/gsm8k.py:110-117:

Python
def reward(self, conversation, assistant_response):
    is_correct = self.evaluate(conversation, assistant_response)
    return float(is_correct)  # Binary: 0.0 or 1.0

Evaluation extracts the answer after the #### marker:

Python
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
    match = GSM_RE.search(completion)
    if match:
        return match.group(1).strip().replace(",", "")
    return None
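
For instance, running the marker-based extraction on a sample completion (the text is illustrative):

Python
import re

GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
match = GSM_RE.search("... so she pays 12 * 103 = 1,236 dollars.\n#### 1,236")
print(match.group(1).strip().replace(",", ""))  # -> 1236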

RL Training Loop

chat_rl.py:241-300:

Python
for step in range(num_steps):
    # Evaluate pass@k
    if step % eval_every == 0:
        model.eval()
        passk = torch.zeros(device_batch_size, device=device)
        records = list(run_gsm8k_eval(val_task, tokenizer, engine, ...))
        for k in range(1, device_batch_size + 1):
            passk[k-1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
        if ddp:
            dist.all_reduce(passk, op=dist.ReduceOp.SUM)

    # Forward/backward on rollouts
    for example_step in range(examples_per_rank):
        sequences, inputs, targets, rewards, advantages = next(batch_iterator)
        model.train()

        for pass_idx in range(num_passes):
            b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
            with autocast_ctx:
                logp = -model(inputs[b0:b1], targets[b0:b1], loss_reduction='none')
            # Policy gradient objective
            pg_obj = (logp * advantages[b0:b1].unsqueeze(-1)).sum()
            num_valid = (targets[b0:b1] >= 0).sum().clamp(min=1)
            pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
            loss = -pg_obj  # Minimize negative of objective
            loss.backward()

    # Update
    lrm = 1.0 - step / num_steps
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
        opt.step()
    model.zero_grad(set_to_none=True)

Key insight: No PPO clipping needed because we're on-policy (completions generated from current model).

Checkpointing (nanochat/checkpoint_manager.py)

nanochat implements distributed checkpointing with sharded optimizer states:

checkpoint_manager.py:24-48:

Python
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Save model parameters (only rank 0)
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        torch.save(model_data, model_path)
        # Save metadata as JSON
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "w") as f:
            json.dump(meta_data, f, indent=2)

    # Optimizer state is sharded - each rank saves its own
    if optimizer_data is not None:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        torch.save(optimizer_data, optimizer_path)

def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    model_data = torch.load(model_path, map_location=device)
    optimizer_data = None
    if load_optimizer:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        optimizer_data = torch.load(optimizer_path, map_location=device)
    with open(os.path.join(checkpoint_dir, f"meta_{step:06d}.json")) as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data

Checkpoint structure:

Code
checkpoint_dir/
├── model_000100.pt          # Model weights (saved by rank 0 only)
├── meta_000100.json         # Metadata: step, config, metrics
├── optim_000100_rank0.pt    # Optimizer state shard for rank 0
├── optim_000100_rank1.pt    # Optimizer state shard for rank 1
├── optim_000100_rank2.pt    # ... etc
└── optim_000100_rank7.pt

Rotary Position Embeddings (RoPE)

Mathematical Foundation

RoPE encodes position through rotation in the complex plane. For each pair of dimensions, rotation by angle $\theta$:

$$\text{rot}_\theta(x_1, x_2) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}$$

Position-Dependent Rotation

For position $m$ and dimension pair $i$:

$$\theta_{m,i} = m \cdot \text{base}^{-2i/d}$$

where base = 10,000. This creates a spectrum of frequencies.

Complete Rotation Matrix

For a vector $x \in \mathbb{R}^d$:

$$\text{RoPE}(x, m) = \begin{pmatrix} x_1 \cos(m\theta_1) - x_2 \sin(m\theta_1) \\ x_1 \sin(m\theta_1) + x_2 \cos(m\theta_1) \\ \vdots \\ x_{d-1} \cos(m\theta_{d/2}) - x_d \sin(m\theta_{d/2}) \\ x_{d-1} \sin(m\theta_{d/2}) + x_d \cos(m\theta_{d/2}) \end{pmatrix}$$

Relative Position Property

The key insight: after rotating queries and keys by their positions:

$$\text{RoPE}(q, m)^T \cdot \text{RoPE}(k, n) = f(q, k, m-n)$$

The dot product depends only on the relative position $m - n$, not on the absolute positions.

Proof (2D case): Let $q = (q_1, q_2)$, $k = (k_1, k_2)$.

$$\text{RoPE}(q, m) \cdot \text{RoPE}(k, n) = (q_1 c_m - q_2 s_m)(k_1 c_n - k_2 s_n) + (q_1 s_m + q_2 c_m)(k_1 s_n + k_2 c_n)$$

where $c_m = \cos(m\theta)$, $s_m = \sin(m\theta)$.

Expanding and using $\cos(a)\cos(b) + \sin(a)\sin(b) = \cos(a-b)$:

$$= (q_1 k_1 + q_2 k_2)\cos((m-n)\theta) + (q_1 k_2 - q_2 k_1)\sin((m-n)\theta)$$

The result depends only on $m - n$.
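
This property is easy to verify numerically; a small sketch for a single 2-D pair (plain Python for illustration, not nanochat code):

Python
import math

def rotate(v, pos, theta=0.1):
    # Rotate a 2-D vector by the position-dependent angle pos * theta
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return (v[0] * c - v[1] * s, v[0] * s + v[1] * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k = (0.3, -1.2), (0.7, 0.5)
# Same relative offset (3) at two different absolute positions
a = dot(rotate(q, 5), rotate(k, 2))
b = dot(rotate(q, 105), rotate(k, 102))
print(abs(a - b) < 1e-9)  # True: the score depends only on m - n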

Implementation

Precomputation (gpt.py:190-204):

Python
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
    if device is None:
        device = self.transformer.wte.weight.device
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    t = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(t, inv_freq)
    cos, sin = freqs.cos(), freqs.sin()
    cos, sin = cos.bfloat16(), sin.bfloat16()
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    return cos, sin

Application (gpt.py:41-49):

Python
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)
    out = out.to(x.dtype)
    return out

RMSNorm

Mathematical Definition

RMSNorm without learnable parameters:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2 + \epsilon}}$$

Compare to LayerNorm:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

RMSNorm removes:

  1. Mean centering ($-\mu$)
  2. Learned scale ($\gamma$)
  3. Learned shift ($\beta$)

Implementation

gpt.py:36-38:

Python
def norm(x):
    return F.rms_norm(x, (x.size(-1),))

This pure functional form aids training stability and requires no parameters.
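
A quick check that F.rms_norm matches the formula above (F.rms_norm is available in recent PyTorch releases; in float64 to rule out precision effects):

Python
import torch
import torch.nn.functional as F

x = torch.randn(4, 768, dtype=torch.float64)
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
builtin = F.rms_norm(x, (x.size(-1),), eps=1e-6)
print(torch.allclose(manual, builtin))  # True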

QK Normalization

Mathematical Definition

After RoPE, queries and keys are normalized before computing attention scores:

$$q' = \frac{q}{\|q\|_2}, \quad k' = \frac{k}{\|k\|_2}$$

so the attention scores become cosine similarities:

$$q' \cdot k' = \cos(\angle(q, k)) \in [-1, 1]$$

(In nanochat the normalization is RMSNorm rather than an exact unit norm, so $q'$ and $k'$ have norm $\sqrt{d_k}$; the scores are still bounded cosine similarities up to that fixed scale.)

Benefits

  1. Bounded logits: No risk of softmax overflow
  2. Scale-invariant: Works across model sizes without tuning
  3. Training stability: Well-behaved gradients

Implementation

gpt.py:77:

Python
q, k = norm(q), norm(k)  # QK norm

ReLU² Activation

Mathematical Definition

$$\text{ReLU}^2(x) = (\max(0, x))^2 = \begin{cases} x^2 & \text{if } x > 0 \\ 0 & \text{otherwise} \end{cases}$$

Properties

| Property | ReLU | GELU | ReLU² |
|----------|------|------|-------|
| Sparsity | High | Low | High |
| Derivative at 0 | Undefined | ~0.5 | 0 (continuous) |
| Computation | Fast | Slow | Fast |

Implementation

gpt.py:118-122:

Python
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU² activation
        x = self.c_proj(x)
        return x

Group Query Attention (GQA)

Mathematical Definition

Standard MHA: $n_{\text{kv\_heads}} = n_{\text{heads}}$

GQA: $n_{\text{kv\_heads}} < n_{\text{heads}}$

Each KV head serves $n_{\text{heads}} / n_{\text{kv\_heads}}$ query heads.
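
A minimal sketch of the grouping, independent of nanochat: expanding K/V along the head dimension so each group of query heads attends to its shared KV head expresses the same computation that enable_gqa=True performs.

Python
import torch
import torch.nn.functional as F

B, T, head_dim = 1, 8, 64
n_head, n_kv_head = 8, 2                       # 4 query heads per KV head
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# Repeat each KV head once per query head in its group
k_rep = k.repeat_interleave(n_head // n_kv_head, dim=1)
v_rep = v.repeat_interleave(n_head // n_kv_head, dim=1)
y = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)
print(y.shape)  # torch.Size([1, 8, 8, 64])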

Memory Savings

KV cache size without GQA:

$$\text{Memory} = 2 \times L \times h \times T \times d_k \times \text{bytes}$$

With GQA (ratio $r = h/g$):

$$\text{Memory}_{\text{GQA}} = \frac{\text{Memory}}{r}$$

Example (a hypothetical 20-layer model with 16 heads, head_dim 64, 1024-token context, bf16; the d20 speedrun config itself uses n_kv_head = n_head, i.e. no GQA):

  • Without GQA: 2 × 20 × 16 × 1024 × 64 × 2 bytes ≈ 84 MB
  • With 4:1 GQA: ≈ 21 MB (4× reduction)

Implementation

gpt.py:51-64:

Python
class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.n_embd // self.n_head
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

gpt.py:87-91:

Python
enable_gqa = self.n_head != self.n_kv_head
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)

Logit Soft-Capping

Mathematical Definition

$$\text{logits}' = c \cdot \tanh(\text{logits} / c)$$

where $c = 15$ is the cap value.

Properties

| Input | Output | Effect |
|-------|--------|--------|
| 1.0 | ~1.0 | Unchanged |
| 5.0 | ~4.8 | Slight compression |
| 15.0 | ~11.4 | Significant compression |
| 50.0 | ~15.0 | Saturated |
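
The table values follow directly from the formula:

Python
import torch

softcap = 15.0
logits = torch.tensor([1.0, 5.0, 15.0, 50.0])
print(softcap * torch.tanh(logits / softcap))
# ~ tensor([ 0.999,  4.823, 11.424, 14.962])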

Benefits

  1. Numerical stability: Prevents softmax overflow
  2. Prevents overconfidence: caps extreme logits that would otherwise push a single token's probability toward 1.0
  3. Differentiable: Unlike hard clipping

Implementation

gpt.py:266-270:

Python
softcap = 15
logits = self.lm_head(x)
logits = logits[..., :self.config.vocab_size]  # Remove padding
logits = logits.float()  # FP32 for stability
logits = softcap * torch.tanh(logits / softcap)

Untied Embeddings

Unlike nanoGPT's weight tying, nanochat uses separate embedding and output matrices:

gpt.py:146-150:

Python
self.transformer = nn.ModuleDict({
    "wte": nn.Embedding(padded_vocab_size, config.n_embd),
    "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)

Why untie? Input embeddings optimize for understanding tokens; output projections optimize for predicting tokens. These roles differ enough that separate parameters improve capacity.

Weight Initialization

Standard Initialization

gpt.py:177-187:

Python
def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        fan_out = module.weight.size(0)
        fan_in = module.weight.size(1)
        std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

Zero Initialization for Output Projections

gpt.py:163-168:

Python
def init_weights(self):
    self.apply(self._init_weights)
    torch.nn.init.zeros_(self.lm_head.weight)
    for block in self.transformer.h:
        torch.nn.init.zeros_(block.mlp.c_proj.weight)
        torch.nn.init.zeros_(block.attn.c_proj.weight)

Zero-initializing output projections ensures residual connections initially pass through unchanged.

The Muon Optimizer

Algorithm Overview

Muon (MomentUm Orthogonalized by Newton-schulz) orthogonalizes gradient updates:

  1. Compute the momentum buffer: $m_t = \beta m_{t-1} + (1-\beta)\, g_t$
  2. Optionally apply Nesterov-style blending: $g' = (1-\beta)\, g_t + \beta\, m_t$
  3. Orthogonalize via Newton-Schulz: $G' = \text{ortho}(G)$
  4. Apply the update with aspect-ratio scaling

Newton-Schulz Orthogonalization

To orthogonalize a matrix $G$, we want $G' = G(G^TG)^{-1/2}$, making all singular values equal to 1.

Newton-Schulz iteration:

$$X_0 = \frac{G}{\|G\|_F}, \qquad X_{k+1} = X_k \left(aI + b\,X_k^T X_k + c\,(X_k^T X_k)^2\right)$$

with quintic coefficients $(a, b, c) = (3.4445, -4.7750, 2.0315)$.

Five iterations approximate:

$$\text{ortho}(G) \approx G(G^TG)^{-1/2}$$
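
A small self-contained sketch, mirroring the iteration implemented below, that checks the singular values of the output are pushed toward 1. It runs in float64 for a clean check, whereas the repo runs the iteration in bfloat16; note the quintic coefficients trade exactness for speed, so the values land near 1 rather than exactly on it.

Python
import torch

def newton_schulz(G, steps=5):
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm() + 1e-7)          # Frobenius norm bounds the spectral norm by 1
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

G = torch.randn(256, 512, dtype=torch.float64)
print(torch.linalg.svdvals(newton_schulz(G)))  # clustered loosely around 1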

Why Orthogonalization Helps

Gradient updates often have low-rank structure—few directions receive most learning signal. This causes:

  • Slow learning in underrepresented directions
  • Oscillation in dominant directions

Orthogonalization redistributes signal: all singular values become 1, no direction is preferred.

Implementation

muon.py:10-36:

Python
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

muon.py:69-83:

Python
@torch.no_grad()
def step(self):
    for group in self.param_groups:
        params = group["params"]
        for p in params:
            g = p.grad
            state = self.state[p]
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(g)
            buf = state["momentum_buffer"]
            buf.lerp_(g, 1 - group["momentum"])
            g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
            g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
            p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)

Note the aspect-ratio scaling: max(1, p.size(-2) / p.size(-1))**0.5.

Distributed Muon (DistMuon)

For multi-GPU training, DistMuon handles gradient synchronization:

muon.py:86-187:

Python
class DistMuon(torch.optim.Optimizer):
    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Kick off reduce_scatter to average gradients
        all_reduce_futures = []
        for group in self.param_groups:
            params = group["params"]
            for base_i in range(0, len(params), world_size):
                owner_idx = base_i + rank
                rs_input = [p.grad for p in params[base_i:base_i + world_size]]
                rs_output = params[owner_idx].grad if owner_idx < len(params) else zero_buffer
                work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True)
                all_reduce_futures.append(work.get_future())

        # Each rank computes update for its owned params, then all_gather
        future_idx = 0
        all_gather_futures = []
        for group in self.param_groups:
            params = group["params"]
            for base_i in range(0, len(params), world_size):
                owner_idx = base_i + rank
                all_reduce_futures[future_idx].wait()
                future_idx += 1

                if owner_idx < len(params):
                    p = params[owner_idx]
                    g = p.grad  # Now averaged across ranks
                    # Muon update
                    buf = self.state[p].setdefault("momentum_buffer", torch.zeros_like(g))
                    buf.lerp_(g, 1.0 - group["momentum"])
                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                    scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                    p.add_(g, alpha=-group["lr"] * scale)

                # All_gather to replicate updated params
                work = dist.all_gather(ag_output, ag_input, async_op=True)
                all_gather_futures.append(work.get_future())

        torch.futures.collect_all(all_gather_futures).wait()

Communication pattern: reduce_scatter → compute → all_gather.

Distributed AdamW (DistAdamW)

adamw.py:10-78: ZeRO-2 style sharded optimizer states:

Python
class DistAdamW(torch.optim.Optimizer):
    @torch.compile
    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # reduce_scatter gradients
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr, wd, eps = group['lr'], group['weight_decay'], group['eps']
            for p in group["params"]:
                rank_size = p.shape[0] // world_size
                grad_slice = torch.empty_like(p[:rank_size])
                dist.reduce_scatter_tensor(grad_slice, p.grad, op=dist.ReduceOp.AVG)

                # Update only this rank's slice
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                state = self.state[p]
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64)
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)

                # AdamW update on slice
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                exp_avg.mul_(beta1).add_(grad_slice, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad_slice, grad_slice, value=1 - beta2)
                # Bias correction and update
                t = state['step']
                step_size = lr * (torch.sqrt(1 - beta2**t) / (1 - beta1**t))
                p_slice.addcdiv_(exp_avg, exp_avg_sq.sqrt() + eps, value=-step_size)
                # Decoupled weight decay
                p_slice.mul_(1 - lr * wd)

                # all_gather to replicate
                dist.all_gather_into_tensor(p, p_slice)

Memory savings: Each rank stores optimizer states for only $1/N$ of the parameters.

Dual Optimizer Setup

nanochat uses different optimizers for different parameter types:

AdamW for Embeddings

gpt.py:217-235:

Python
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
    model_dim = self.config.n_embd
    # Separate parameters
    matrix_params = list(self.transformer.h.parameters())
    embedding_params = list(self.transformer.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())

    # Scale LR by ∝1/√dmodel
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
    adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)

    # Muon for linear layers
    muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
    muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)

    return [adamw_optimizer, muon_optimizer]

Learning Rate Scaling:

$$\eta_{\text{scaled}} = \eta_{\text{base}} \cdot \left(\frac{d}{768}\right)^{-0.5}$$

For the d20 model ($d = 1280$): scale factor $= (1280/768)^{-0.5} \approx 0.77$.

KV Cache for Inference

Data Structure

engine.py:83-100:

Python
class KVCache:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0

    def reset(self):
        self.pos = 0

    def get_pos(self):
        return self.pos

Cache Operations

Insert and retrieve (engine.py:135-160):

Python
def insert_kv(self, layer_idx, k, v):
    if self.kv_cache is None:
        self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
    B, H, T_add, D = k.size()
    t0, t1 = self.pos, self.pos + T_add

    # Dynamic growth if needed
    if t1 > self.kv_cache.size(4):
        t_needed = (t1 + 1024 + 1023) & ~1023  # Grow by at least 1024, rounded up to a multiple of 1024
        additional = torch.empty(..., dtype=k.dtype, device=k.device)
        self.kv_cache = torch.cat([self.kv_cache, additional], dim=4)

    # Insert k, v
    self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
    self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v

    # Return full cache view
    key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
    value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]

    # Advance pos after last layer
    if layer_idx == self.kv_cache.size(0) - 1:
        self.pos = t1
    return key_view, value_view

Complexity Analysis

Without KV cache:

  • Each of the $T$ generation steps reprocesses $O(t)$ tokens
  • Total: $O(T^2)$ token computations, $O(T^3)$ attention work

With KV cache:

  • Each step processes only the 1 new token
  • Total: $O(T)$ token computations, $O(T^2)$ attention work

For 1000 generated tokens this is roughly a 500× reduction in token-level compute (about 500,500 token evaluations without the cache vs. 1,000 with it).

Tool Use (Calculator)

State Machine

engine.py:184-191:

Python
class RowState:
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []
        self.forced_tokens = deque()
        self.in_python_block = False
        self.python_expr_tokens = []
        self.completed = False

Safe Evaluation

engine.py:47-80:

Python
def use_calculator(expr):
    expr = expr.replace(",", "")  # Remove number commas

    # Pure math expression
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # Disallow power
            return None
        return eval_with_timeout(expr)

    # String operations (e.g., .count())
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', ...]
    if any(pattern in expr.lower() for pattern in dangerous_patterns):
        return None

    if '.count(' not in expr:
        return None

    return eval_with_timeout(expr)

Generation Loop Integration

engine.py:282-295:

Python
if next_token == python_start:
    state.in_python_block = True
    state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
    state.in_python_block = False
    if state.python_expr_tokens:
        expr = self.tokenizer.decode(state.python_expr_tokens)
        result = use_calculator(expr)
        if result is not None:
            result_tokens = self.tokenizer.encode(str(result))
            state.forced_tokens.append(output_start)
            state.forced_tokens.extend(result_tokens)
            state.forced_tokens.append(output_end)
elif state.in_python_block:
    state.python_expr_tokens.append(next_token)

Complete Attention Module

gpt.py:51-109:

Python
class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.n_embd // self.n_head
        assert config.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x, cos_sin, kv_cache):
        B, T, C = x.size()

        # Project to Q, K, V
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Apply RoPE and QK norm
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # KV cache
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)

        # Attention with GQA
        enable_gqa = self.n_head != self.n_kv_head
        if kv_cache is None or q.size(2) == k.size(2):
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
        elif q.size(2) == 1:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
        else:
            # Chunked inference with custom mask
            ...

        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y

Training Cost Tiers

| Tier | Model | Layers | Dim | Parameters | Tokens | Cost | Time (8×H100) |
|------|-------|--------|-----|------------|--------|------|---------------|
| $100 | d20 | 20 | 1280 | 561M | 11B | ~$100 | 4h |
| $300 | d26 | 26 | 1664 | ~1B | 20B | ~$300 | 12h |
| $800 | d32 | 32 | 2048 | 1.9B | 38B | ~$800 | 33h |

Model sizing from depth: layers = d, dim = 64*d, heads = ceil(d*64/128).
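
A quick sketch of this sizing rule with a rough parameter count (untied embeddings plus 12·dim² matrix parameters per layer, assuming n_kv_head = n_head as in the d20 config and ignoring any vocab padding slack); it reproduces the table's numbers:

Python
import math

VOCAB = 65536  # nanochat vocabulary size

def config_from_depth(d):
    n_layer, n_embd = d, 64 * d
    n_head = math.ceil(d * 64 / 128)
    # Per block: Q, K, V, proj (4*dim^2 with full KV heads) + MLP (8*dim^2) = 12*dim^2
    matrix_params = 12 * n_embd ** 2 * n_layer
    embedding_params = 2 * VOCAB * n_embd  # untied wte + lm_head
    return n_layer, n_embd, n_head, matrix_params + embedding_params

for d in (20, 26, 32):
    n_layer, n_embd, n_head, n_params = config_from_depth(d)
    print(f"d{d}: layers={n_layer} dim={n_embd} heads={n_head} params~{n_params/1e6:.0f}M")
# d20 ~561M, d26 ~1082M, d32 ~1879M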

Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
