llama2.c: Andrej Karpathy's Pure C Inference Engine
A deep dive into llama2.c - Karpathy's ~1000 lines of pure C that runs Llama 2 inference. Covers the complete transformer implementation including RMSNorm, RoPE, SwiGLU, GQA, KV cache, BPE tokenization, and top-p sampling - all without any dependencies.
llama2.c is Andrej Karpathy's minimalist implementation of Llama 2 inference in pure C. The entire program - forward pass, tokenizer, and sampler - fits in under 1,000 lines of dependency-free C code, making it an exceptional educational resource for understanding how LLMs work at the lowest level.
Repository: github.com/karpathy/llama2.c
Repository Structure
llama2.c/
├── run.c # 973 lines - Main C inference engine
├── runq.c # 1092 lines - Quantized (int8) inference
├── model.py # 343 lines - PyTorch model for training
├── train.py # 343 lines - Training loop
├── export.py # 567 lines - Export model to binary format
├── tokenizer.py # 78 lines - Tokenizer utilities
├── tokenizer.bin # Pre-trained tokenizer
└── tinystories.py # TinyStories dataset preparation
Architecture Overview
llama2.c implements the Llama 2 architecture with these key components:
| Component | Implementation | File Reference |
|---|---|---|
| Normalization | RMSNorm (not LayerNorm) | run.c:182-195 |
| Position Encoding | RoPE (computed on-the-fly) | run.c:264-279 |
| Attention | Multi-head with GQA support | run.c:281-319 |
| FFN | SwiGLU (w1, w2, w3) | run.c:332-348 |
| Tokenizer | BPE with byte fallback | run.c:364-571 |
| Sampling | Top-p (nucleus) sampling | run.c:624-665 |
Model Configuration
The model hyperparameters are stored in a simple struct:
run.c:19-27:
typedef struct {
int dim; // transformer dimension
int hidden_dim; // for ffn layers
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size, usually 256 (byte-level)
int seq_len; // max sequence length
} Config;
The n_kv_heads field enables Grouped Query Attention (GQA) - where multiple query heads share the same key/value heads, reducing memory and compute.
Weight Storage
All model weights are stored contiguously in memory:
run.c:29-48:
typedef struct {
// token embedding table
float* token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim)
// weights for matmuls. note dim == n_heads * head_size
float* wq; // (layer, dim, n_heads * head_size)
float* wk; // (layer, dim, n_kv_heads * head_size)
float* wv; // (layer, dim, n_kv_heads * head_size)
float* wo; // (layer, n_heads * head_size, dim)
// weights for ffn
float* w1; // (layer, hidden_dim, dim)
float* w2; // (layer, dim, hidden_dim)
float* w3; // (layer, hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
// (optional) classifier weights for the logits, on the last layer
float* wcls;
} TransformerWeights;
Note: Llama uses three FFN weight matrices (w1, w2, w3) for SwiGLU, unlike GPT-2's two matrices.
Runtime State (Activations + KV Cache)
run.c:50-65:
typedef struct {
// current wave of activations
float *x; // activation at current time stamp (dim,)
float *xb; // same, but inside a residual branch (dim,)
float *xb2; // an additional buffer just for convenience (dim,)
float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
float *q; // query (dim,)
float *k; // key (dim,)
float *v; // value (dim,)
float *att; // buffer for scores/attention values (n_heads, seq_len)
float *logits; // output logits
// kv cache
float* key_cache; // (layer, seq_len, dim)
float* value_cache; // (layer, seq_len, dim)
} RunState;
The KV cache stores the keys and values computed at earlier positions, so each new token only computes its own key/value row and attends over the cache instead of reprocessing the entire prefix.
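As a minimal sketch of the indexing scheme (the names mirror run.c's forward pass, but this helper is illustrative rather than verbatim source): each layer owns a contiguous slab of seq_len × kv_dim floats, and position pos writes its key/value row at offset pos × kv_dim inside that slab.

```c
#include <stddef.h>

/* Illustrative sketch, not verbatim run.c: offset at which layer `l`,
 * position `pos` stores its key/value row in a cache laid out as
 * (n_layers, seq_len, kv_dim), matching the RunState comments above. */
static size_t kv_cache_offset(int l, int pos, int seq_len, int kv_dim) {
    return ((size_t)l * seq_len + pos) * kv_dim;
}

/* Usage (inside the layer loop):
 *   float* k_row = s->key_cache   + kv_cache_offset(l, pos, p->seq_len, kv_dim);
 *   float* v_row = s->value_cache + kv_cache_offset(l, pos, p->seq_len, kv_dim);
 */
```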
Memory-Mapped Weight Loading
Weights are loaded via mmap for efficiency - the OS handles paging from disk:
run.c:142-162:
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
int* fd, float** data, ssize_t* file_size) {
FILE *file = fopen(checkpoint, "rb");
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
// read in the config header
if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
int shared_weights = config->vocab_size > 0 ? 1 : 0;
config->vocab_size = abs(config->vocab_size);
// figure out the file size
fseek(file, 0, SEEK_END);
*file_size = ftell(file);
fclose(file);
// memory map the Transformer weights into the data pointer
*fd = open(checkpoint, O_RDONLY);
if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
*data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
float* weights_ptr = *data + sizeof(Config)/sizeof(float);
memory_map_weights(weights, config, weights_ptr, shared_weights);
}
Benefits of mmap:
- No explicit memory allocation needed
- OS pages in data on-demand
- Multiple processes can share the same physical memory
- Works with models larger than available RAM (with disk thrashing)
RMSNorm Implementation
Llama uses RMSNorm instead of LayerNorm - it's simpler (no mean subtraction) and empirically works just as well:
run.c:182-195:
void rmsnorm(float* o, float* x, float* weight, int size) {
// calculate sum of squares
float ss = 0.0f;
for (int j = 0; j < size; j++) {
ss += x[j] * x[j];
}
ss /= size;
ss += 1e-5f;
ss = 1.0f / sqrtf(ss);
// normalize and scale
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
}
}
Mathematical formulation:

$$\text{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}} \cdot g_i$$

where $g_i$ is a learned scale parameter and $\epsilon = 10^{-5}$. Note there's no $\beta$ (bias) term and no mean subtraction.
Comparison with LayerNorm:
| | LayerNorm | RMSNorm |
|---|---|---|
| Mean subtraction | Yes | No |
| Learned parameters | $\gamma$, $\beta$ | $g$ only |
| Compute | mean and variance | root mean square only |
Softmax with Numerical Stability
run.c:197-215:
void softmax(float* x, int size) {
// find max value (for numerical stability)
float max_val = x[0];
for (int i = 1; i < size; i++) {
if (x[i] > max_val) {
max_val = x[i];
}
}
// exp and sum
float sum = 0.0f;
for (int i = 0; i < size; i++) {
x[i] = expf(x[i] - max_val);
sum += x[i];
}
// normalize
for (int i = 0; i < size; i++) {
x[i] /= sum;
}
}
The max subtraction prevents overflow: since $x_i - \max_j x_j \le 0$, every $e^{x_i - \max_j x_j} \le 1$, and the result equals the unshifted softmax because the factor $e^{-\max_j x_j}$ cancels in the ratio.
Matrix Multiplication
The most time-consuming operation - a simple nested loop (one dot product per output row) parallelized with OpenMP:
run.c:217-229:
void matmul(float* xout, float* x, float* w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
int i;
#pragma omp parallel for private(i)
for (i = 0; i < d; i++) {
float val = 0.0f;
for (int j = 0; j < n; j++) {
val += w[i * n + j] * x[j];
}
xout[i] = val;
}
}
This computes $x_{\text{out}} = W x$, where $W \in \mathbb{R}^{d \times n}$ and $x \in \mathbb{R}^{n}$.
The #pragma omp parallel for distributes rows across CPU cores.
The Complete Forward Pass
run.c:231-362 - The heart of the inference engine:
float* forward(Transformer* transformer, int token, int pos) {
// a few convenience variables
Config* p = &transformer->config;
TransformerWeights* w = &transformer->weights;
RunState* s = &transformer->state;
float *x = s->x;
int dim = p->dim;
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
int hidden_dim = p->hidden_dim;
int head_size = dim / p->n_heads;
// copy the token embedding into x
float* content_row = w->token_embedding_table + token * dim;
memcpy(x, content_row, dim*sizeof(*x));
// forward all the layers
for(unsigned long long l = 0; l < p->n_layers; l++) {
// ... layer computation ...
}
// final rmsnorm
rmsnorm(x, x, w->rms_final_weight, dim);
// classifier into logits
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
return s->logits;
}
RoPE: Rotary Position Embeddings
RoPE encodes position by rotating query and key vectors:
run.c:264-279:
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
float v0 = vec[i];
float v1 = vec[i+1];
vec[i] = v0 * fcr - v1 * fci;
vec[i+1] = v0 * fci + v1 * fcr;
}
}
RoPE Mathematics:

For position $p$ and dimension pair $(2i, 2i+1)$ within a head, with frequency $\theta_i = 10000^{-2i/\text{head\_size}}$:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(p\,\theta_i) & -\sin(p\,\theta_i) \\ \sin(p\,\theta_i) & \cos(p\,\theta_i) \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$

The rotation encodes absolute position, but attention scores depend only on relative position:

$$(R_m q)^\top (R_n k) = q^\top R_{n-m}\, k$$
Multi-Head Attention with GQA
run.c:281-319:
// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
for (h = 0; h < p->n_heads; h++) {
// get the query vector for this head
float* q = s->q + h * head_size;
// attention scores for this head
float* att = s->att + h * p->seq_len;
// iterate over all timesteps, including the current one
for (int t = 0; t <= pos; t++) {
// get the key vector for this head and at this timestep
float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// calculate the attention score as the dot product of q and k
float score = 0.0f;
for (int i = 0; i < head_size; i++) {
score += q[i] * k[i];
}
score /= sqrtf(head_size);
// save the score to the attention buffer
att[t] = score;
}
// softmax the scores to get attention weights, from 0..pos inclusively
softmax(att, pos + 1);
// weighted sum of the values, store back into xb
float* xb = s->xb + h * head_size;
memset(xb, 0, head_size * sizeof(float));
for (int t = 0; t <= pos; t++) {
// get the value vector for this head and at this timestep
float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// get the attention weight for this timestep
float a = att[t];
// accumulate the weighted value into xb
for (int i = 0; i < head_size; i++) {
xb[i] += a * v[i];
}
}
}
Key insight: (h / kv_mul) maps multiple query heads to the same KV head for GQA.
For example, with n_heads=32 and n_kv_heads=8, kv_mul=4, so heads 0-3 share KV head 0, heads 4-7 share KV head 1, etc.
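A tiny standalone demo (not part of run.c) makes the mapping concrete by printing which KV head each query head reads from:

```c
#include <stdio.h>

/* Illustrative demo of the GQA head mapping described above: with 32 query
 * heads and 8 KV heads, kv_mul = 4 and query head h reads KV head h / kv_mul. */
int main(void) {
    int n_heads = 32, n_kv_heads = 8;
    int kv_mul = n_heads / n_kv_heads;   // 4 query heads share each KV head
    for (int h = 0; h < n_heads; h++) {
        printf("query head %2d -> kv head %d\n", h, h / kv_mul);
    }
    return 0;
}
```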
SwiGLU Feed-Forward Network
run.c:332-348:
// ffn rmsnorm
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
// SwiGLU non-linearity
for (int i = 0; i < hidden_dim; i++) {
float val = s->hb[i];
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
val *= (1.0f / (1.0f + expf(-val)));
// elementwise multiply with w3(x)
val *= s->hb2[i];
s->hb[i] = val;
}
// final matmul to get the output of the ffn
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
SwiGLU Formula:

$$\text{FFN}(x) = W_2\left(\text{SiLU}(W_1 x) \odot W_3 x\right)$$

where SiLU (Swish) is:

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

The $\odot$ is element-wise multiplication (gating).
BPE Tokenizer
The tokenizer implements Byte Pair Encoding with byte fallback:
run.c:452-571 (encode function):
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
if (t->sorted_vocab == NULL) {
// lazily malloc and sort the vocabulary
t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
for (int i = 0; i < t->vocab_size; i++) {
t->sorted_vocab[i].str = t->vocab[i];
t->sorted_vocab[i].id = i;
}
qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
}
// ... UTF-8 handling ...
// merge the best consecutive pair each iteration, according the scores in vocab_scores
while (1) {
float best_score = -1e10;
int best_id = -1;
int best_idx = -1;
for (int i=0; i < (*n_tokens-1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
if (id != -1 && t->vocab_scores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = t->vocab_scores[id];
best_id = id;
best_idx = i;
}
}
if (best_idx == -1) {
break; // we couldn't find any more pairs to merge, so we're done
}
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens[best_idx] = best_id;
// delete token at position best_idx+1, shift the entire sequence back 1
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
tokens[i] = tokens[i+1];
}
(*n_tokens)--; // token length decreased
}
}
The algorithm:
- Start with byte-level tokens
- Find the highest-scoring merge pair in vocabulary
- Merge that pair into a single token
- Repeat until no more merges possible
Top-p (Nucleus) Sampling
run.c:624-665:
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
int n0 = 0;
// quicksort indices in descending order of probabilities
// values smaller than (1 - topp) / (n - 1) cannot be part of the result
// so for efficiency we crop these out as candidates before sorting
const float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < n; i++) {
if (probabilities[i] >= cutoff) {
probindex[n0].index = i;
probindex[n0].prob = probabilities[i];
n0++;
}
}
qsort(probindex, n0, sizeof(ProbIndex), compare);
// truncate the list where cumulative probability exceeds topp
float cumulative_prob = 0.0f;
int last_idx = n0 - 1;
for (int i = 0; i < n0; i++) {
cumulative_prob += probindex[i].prob;
if (cumulative_prob > topp) {
last_idx = i;
break;
}
}
// sample from the truncated list
float r = coin * cumulative_prob;
float cdf = 0.0f;
for (int i = 0; i <= last_idx; i++) {
cdf += probindex[i].prob;
if (r < cdf) {
return probindex[i].index;
}
}
return probindex[last_idx].index;
}
Top-p algorithm:
- Sort tokens by probability (descending)
- Keep tokens until the cumulative probability exceeds p (e.g., 0.9)
- Sample from this "nucleus" of tokens
This prevents sampling extremely unlikely tokens while maintaining diversity.
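As a small worked example with toy numbers (not from the repository): for sorted probabilities {0.5, 0.3, 0.1, 0.06, 0.04} and topp = 0.85, the cumulative sum first exceeds 0.85 at the third token, so the nucleus is the top three tokens with total mass 0.9, and the random coin is rescaled by that mass before sampling. A minimal sketch of just the truncation step:

```c
#include <stdio.h>

/* Toy illustration of nucleus truncation. The probabilities are already
 * sorted in descending order, so the qsort step from sample_topp is skipped. */
int main(void) {
    float probs[] = {0.5f, 0.3f, 0.1f, 0.06f, 0.04f};
    int n = 5, last_idx = n - 1;
    float topp = 0.85f, cumulative = 0.0f;
    for (int i = 0; i < n; i++) {
        cumulative += probs[i];
        if (cumulative > topp) { last_idx = i; break; }
    }
    printf("nucleus = top %d tokens, total mass %.2f\n", last_idx + 1, cumulative);
    return 0;
}
```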
Generation Loop
run.c:729-783:
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
// start the main loop
long start = 0;
int next;
int token = prompt_tokens[0];
int pos = 0;
while (pos < steps) {
// forward the transformer to get logits for the next token
float* logits = forward(transformer, token, pos);
// advance the state machine
if (pos < num_prompt_tokens - 1) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos + 1];
} else {
// otherwise sample the next token from the logits
next = sample(sampler, logits);
}
pos++;
// data-dependent terminating condition: the BOS (=1) token delimits sequences
if (next == 1) { break; }
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next);
safe_printf(piece);
fflush(stdout);
token = next;
// init the timer here because the first iteration can be slower
if (start == 0) { start = time_in_ms(); }
}
// report achieved tok/s
if (pos > 1) {
long end = time_in_ms();
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
}
}
Training with train.py
While run.c handles inference, train.py provides a complete PyTorch training pipeline:
train.py - Key hyperparameters:
# model
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = 6
multiple_of = 32
dropout = 0.0
# adamw optimizer
gradient_accumulation_steps = 4
learning_rate = 5e-4
max_iters = 100000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
# learning rate decay
decay_lr = True
warmup_iters = 1000
Distributed Training Support
train.py:
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
if ddp:
init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
device = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(device)
Running DDP:
# Single GPU
python train.py --compile=False --batch_size=8
# 4 GPUs on 1 node
torchrun --standalone --nproc_per_node=4 train.py
# 8 GPUs across 2 nodes
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
--master_addr=123.456.123.456 --master_port=1234 train.py
TinyStories Dataset
llama2.c uses the TinyStories dataset (small children's stories) for training small models.
tinystories.py - Custom tokenizer training:
def train_vocab(vocab_size):
"""Train a custom sentencepiece tokenizer on TinyStories"""
prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
# Export text from first 10 shards for training
tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
with open(tiny_file, "w", encoding="utf-8") as of:
for shard in shard_filenames[:10]:
with open(shard, "r") as f:
data = json.load(f)
for example in data:
text = example["story"].strip()
of.write(text + "\n")
# Train sentencepiece BPE model
spm.SentencePieceTrainer.train(
input=tiny_file,
model_prefix=prefix,
model_type="bpe",
vocab_size=vocab_size,
character_coverage=1.0,
num_threads=os.cpu_count(),
split_digits=True, # Split digits into individual tokens
allow_whitespace_only_pieces=True,
byte_fallback=True, # Handle unknown bytes
normalization_rule_name="identity" # No unicode normalization
)
tinystories.py - Pre-tokenization for efficient training:
def process_shard(args, vocab_size):
"""Process a single JSON shard into binary tokens"""
shard_id, shard = args
enc = Tokenizer(get_tokenizer_model_path(vocab_size))
with open(shard, "r") as f:
data = json.load(f)
all_tokens = []
for example in data:
text = example["story"].strip()
tokens = enc.encode(text, bos=True, eos=False) # BOS but no EOS
all_tokens.extend(tokens)
# Save as uint16 (vocab < 65536)
all_tokens = np.array(all_tokens, dtype=np.uint16)
tokenized_filename = shard.replace(".json", ".bin")
with open(tokenized_filename, "wb") as f:
f.write(all_tokens.tobytes())
def pretokenize(vocab_size):
"""Pre-tokenize all shards in parallel"""
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
fun = partial(process_shard, vocab_size=vocab_size)
with ProcessPoolExecutor() as executor:
executor.map(fun, enumerate(shard_filenames))
tinystories.py - PyTorch DataLoader:
class PretokDataset(torch.utils.data.IterableDataset):
"""Loads pretokenized .bin files for training"""
def __init__(self, split, max_seq_len, vocab_size, vocab_source):
self.split = split
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
def __iter__(self):
# Unique seed per worker + DDP rank for proper shuffling
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id if worker_info else 0
rank = dist.get_rank() if dist.is_initialized() else 0
seed = 42 + worker_id + 1337 * rank
rng = random.Random(seed)
# Load .bin shards and yield sequences
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
while True:
rng.shuffle(shard_filenames) # Shuffle shards each epoch
for shard in shard_filenames:
m = np.memmap(shard, dtype=np.uint16, mode="r")
# ... yield sequences of max_seq_len ...
Key design choices:
- uint16 storage (2 bytes/token vs 4 for int32)
- BOS token separates stories (used to calculate avg sequence length)
- Parallel processing with ProcessPoolExecutor
- memmap for memory-efficient loading of large shards
Python Model (for Training)
The Python implementation mirrors the C code but uses PyTorch:
model.py:94-164 (Attention):
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
bsz, seqlen, _ = x.shape
# QKV
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# RoPE relative positional embeddings
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep)
xv = repeat_kv(xv, self.n_rep)
# flash implementation
output = torch.nn.functional.scaled_dot_product_attention(
xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2),
attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True
)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
Exporting Models
The export.py script converts PyTorch models to the binary format.
Quantization (Q8_0)
export.py - Group quantization to int8:
def quantize_q80(w, group_size):
"""
Symmetric quantization into int8, range [-127,127]
Returns: (int8_values, scales, max_error)
"""
assert w.numel() % group_size == 0
w = w.float().reshape(-1, group_size)
# Find max absolute value in each group
wmax = torch.abs(w).max(dim=1).values
# Scale factor: float = quant * scale
scale = wmax / 127.0
# Quantize
quant = w / scale[:, None]
int8val = torch.round(quant).to(torch.int8)
# Compute max quantization error
fp32val = (int8val.float() * scale[:, None]).view(-1)
err = torch.abs(fp32val.reshape(-1, group_size) - w).max(dim=1).values
maxerr = err.max().item()
return int8val, scale, maxerr
Quantization math (per group of group_size elements):

$$s = \frac{\max_i |w_i|}{127}, \qquad q_i = \operatorname{round}\!\left(\frac{w_i}{s}\right), \qquad \hat{w}_i = q_i \cdot s$$
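A quick worked example with toy numbers (not from the repository): for a group $w = (0.5,\, -1.0,\, 0.25)$, the maximum absolute value is 1.0, so

$$s = \frac{1.0}{127} \approx 0.00787, \qquad q = \operatorname{round}(w/s) = (64,\, -127,\, 32), \qquad \hat{w} = q \cdot s \approx (0.504,\, -1.000,\, 0.252)$$

The largest reconstruction error here is about 0.004, which is what the maxerr return value reports.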
Binary Export
export.py - Legacy format (v0):
def legacy_export(model, filepath):
"""Export model to .bin format for run.c"""
out_file = open(filepath, 'wb')
# Header: 7 integers (28 bytes)
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
p = model.params
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
# Negative vocab_size signals unshared classifier weights
if not shared_classifier:
p.vocab_size = -p.vocab_size
header = struct.pack('iiiiiii',
p.dim, hidden_dim, p.n_layers, p.n_heads,
p.n_kv_heads, p.vocab_size, p.max_seq_len
)
out_file.write(header)
# Weights in specific order (must match run.c memory_map_weights)
serialize_fp32(out_file, model.tok_embeddings.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.attention_norm.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wq.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wk.weight)
# ... remaining weights in order ...
Weight order (must match memory_map_weights in run.c):
- Token embeddings (vocab_size, dim)
- RMSNorm attention weights (n_layers, dim)
- Wq weights (n_layers, dim, n_heads * head_size)
- Wk weights (n_layers, dim, n_kv_heads * head_size)
- Wv weights (n_layers, dim, n_kv_heads * head_size)
- Wo weights (n_layers, n_heads * head_size, dim)
- RMSNorm FFN weights (n_layers, dim)
- W1 weights (n_layers, hidden_dim, dim)
- W2 weights (n_layers, dim, hidden_dim)
- W3 weights (n_layers, hidden_dim, dim)
- Final RMSNorm weight (dim,)
- Classifier weights (if not shared)
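Given that fixed layout, the expected file size follows directly from the Config. Below is a hypothetical helper (not part of the repo) that estimates the size of a v0 checkpoint from the weight shapes listed above, assuming fp32 weights, a shared classifier, and the 28-byte header; the sample config values beyond the train.py defaults (hidden_dim, vocab_size) are assumptions:

```c
#include <stdio.h>

/* Hypothetical helper (not in the repo): estimate a v0 .bin checkpoint size
 * from the weight shapes listed above, assuming fp32 weights, a shared
 * classifier, and the 7-int (28-byte) Config header. */
long long checkpoint_bytes(long long dim, long long hidden_dim, long long n_layers,
                           long long n_heads, long long n_kv_heads,
                           long long vocab_size) {
    long long head_size = dim / n_heads;
    long long floats = 0;
    floats += vocab_size * dim;                               // token embeddings
    floats += 2 * n_layers * dim;                             // rms_att + rms_ffn
    floats += n_layers * dim * (n_heads * head_size);         // wq
    floats += 2 * n_layers * dim * (n_kv_heads * head_size);  // wk + wv
    floats += n_layers * (n_heads * head_size) * dim;         // wo
    floats += 3 * n_layers * hidden_dim * dim;                // w1 + w2 + w3
    floats += dim;                                            // rms_final_weight
    return 7 * 4 + floats * 4;                                // header + fp32 data
}

int main(void) {
    /* train.py defaults above: dim=288, 6 layers, 6 heads;
     * hidden_dim=768 and vocab_size=32000 are assumed here. */
    printf("~%.1f MB\n", checkpoint_bytes(288, 768, 6, 6, 6, 32000) / 1e6);
    return 0;
}
```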
Performance Considerations
Memory Layout
All weights are stored contiguously with pointer arithmetic:
run.c:111-140:
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
int head_size = p->dim / p->n_heads;
unsigned long long n_layers = p->n_layers;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
w->rms_att_weight = ptr;
ptr += n_layers * p->dim;
w->wq = ptr;
ptr += n_layers * p->dim * (p->n_heads * head_size);
// ... etc
}
KV Cache Efficiency
The KV cache stores (layer, seq_len, kv_dim) tensors:
run.c:86-88:
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
With GQA (n_kv_heads < n_heads), cache size is reduced proportionally.
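For a rough sense of scale, consider a hypothetical 4096-dim, 32-layer, 4096-context configuration (not a model shipped with the repo). The cache size is

$$2 \times n_{\text{layers}} \times \text{seq\_len} \times \text{kv\_dim} \times 4\ \text{bytes}$$

where the factor 2 covers the key and value caches. With full multi-head attention (kv_dim = dim = 4096) that is 2 × 32 × 4096 × 4096 × 4 ≈ 4.3 GB; with n_kv_heads = n_heads / 4 (kv_dim = 1024) it drops to roughly 1.1 GB.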
Quantized Inference (runq.c)
llama2.c also provides int8 quantized inference for reduced memory and faster computation.
Quantized Tensor Structure
runq.c:33-36:
typedef struct {
int8_t* q; // quantized values
float* s; // scaling factors
} QuantizedTensor;
Quantized Weights
runq.c:38-60:
typedef struct {
// token embedding table
QuantizedTensor *q_tokens; // (vocab_size, dim)
float* token_embedding_table; // same, but dequantized
// weights for rmsnorms (kept in float - small)
float* rms_att_weight; // (layer, dim)
float* rms_ffn_weight; // (layer, dim)
// weights for matmuls - QUANTIZED
QuantizedTensor *wq; // (layer, dim, n_heads * head_size)
QuantizedTensor *wk; // (layer, dim, n_kv_heads * head_size)
QuantizedTensor *wv; // (layer, dim, n_kv_heads * head_size)
QuantizedTensor *wo; // (layer, n_heads * head_size, dim)
QuantizedTensor *w1; // (layer, hidden_dim, dim)
QuantizedTensor *w2; // (layer, dim, hidden_dim)
QuantizedTensor *w3; // (layer, hidden_dim, dim)
float* rms_final_weight; // (dim,)
QuantizedTensor *wcls;
} TransformerWeights;
Quantized Matmul
runq.c (quantized matrix multiplication):
void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
int i;
#pragma omp parallel for private(i)
for (i = 0; i < d; i++) {
float val = 0.0f;
int32_t ival = 0;
int in = i * n;
// do the matmul in groups of GS
int j;
for (j = 0; j <= n - GS; j += GS) {
for (int k = 0; k < GS; k++) {
ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
}
val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
ival = 0;
}
xout[i] = val;
}
}
Group quantization: Weights are quantized in groups of GS (e.g., 64). Each group shares a single scale factor, balancing compression and accuracy.
Quantization formula (per group of GS elements, with per-group scales $s^w$ and $s^x$):

$$x_{\text{out},\,i} \mathrel{+}= \left(\sum_{k=0}^{GS-1} q^{x}_{j+k}\; q^{w}_{in+j+k}\right) \cdot s^{w}_{(in+j)/GS} \cdot s^{x}_{j/GS}, \qquad in = i \cdot n$$
Memory savings: roughly 4× reduction (float32 → int8), minus a small overhead for the per-group scale factors.
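Concretely, with GS = 64 each group replaces 64 float32 values (256 bytes) with 64 int8 values plus one float32 scale (68 bytes):

$$\frac{64 \times 4}{64 \times 1 + 4} = \frac{256}{68} \approx 3.8\times$$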
Compiling and Running
# Compile with OpenMP for parallel matmul
gcc -O3 -o run run.c -lm -fopenmp
# Compile quantized version
gcc -O3 -o runq runq.c -lm -fopenmp
# Run inference
./run stories15M.bin -t 0.8 -p 0.9 -n 256 -i "Once upon a time"
# Run quantized inference
./runq stories15M_q80.bin -t 0.8 -p 0.9 -n 256 -i "Once upon a time"
Command-line options:
- -t: Temperature (0 = greedy, 1 = original distribution)
- -p: Top-p value for nucleus sampling
- -n: Number of tokens to generate
- -i: Input prompt
Summary
llama2.c demonstrates that a complete Llama 2 inference engine fits in under 1,000 lines of C - with the core transformer math taking only a few hundred:
| Component | Lines | Complexity |
|---|---|---|
| RMSNorm | 14 | O(dim) |
| Softmax | 19 | O(n) |
| Matmul | 13 | O(d · n) |
| RoPE | 16 | O(dim) |
| Attention | 39 | O(n_heads · pos · head_size) |
| SwiGLU FFN | 17 | O(dim · hidden_dim) |
| Forward pass | 132 | O(n_layers · (dim² + dim · hidden_dim + pos · dim)) per token |
The simplicity makes it perfect for:
- Education: Understanding transformers at the metal
- Embedded deployment: No Python/PyTorch dependencies
- Prototyping: Quick experiments with custom models
- Verification: Cross-checking PyTorch implementations
Related Articles
nanoGPT: Andrej Karpathy's Minimal GPT Training Framework
A comprehensive, equation-complete analysis of nanoGPT—Andrej Karpathy's influential minimal GPT implementation. Deep dive into the ~300-line model definition (model.py), training loop (train.py), Flash Attention, weight initialization, and the mathematical foundations behind every component.
nanochat: Andrej Karpathy's Full-Stack ChatGPT Clone
A comprehensive, equation-complete analysis of nanochat—the complete ChatGPT pipeline from tokenization through reinforcement learning. Deep dive into the modern GPT architecture (RoPE, RMSNorm, GQA, QK-norm, ReLU²), the Muon optimizer with Newton-Schulz orthogonalization, KV cache inference, and tool use.
Transformer Architecture: A Complete Deep Dive
A comprehensive exploration of the transformer architecture—from embedding layers through attention and feed-forward networks to the output head. Understand why decoder-only models dominate, how residual connections enable deep networks, and the engineering decisions behind GPT, Llama, and modern LLMs.
Context Extension: How LLMs Scale Beyond Training Length
A comprehensive deep dive into context extension techniques—how models trained on 4K tokens extrapolate to 128K+. Understand RoPE scaling, Position Interpolation, NTK-aware scaling, YaRN, and the mathematics of long-context LLMs.