llama2.c: Andrej Karpathy's Pure C Inference Engine

A deep dive into llama2.c - Karpathy's ~1000 lines of pure C that runs Llama 2 inference. Covers the complete transformer implementation including RMSNorm, RoPE, SwiGLU, GQA, KV cache, BPE tokenization, and top-p sampling - all without any dependencies.

llama2.c is Andrej Karpathy's minimalist implementation of Llama 2 inference in pure C. The entire inference engine fits in roughly 1,000 lines of dependency-free C code (run.c), making it an exceptional educational resource for understanding how LLMs work at the lowest level.

Repository: github.com/karpathy/llama2.c

Repository Structure

Code
llama2.c/
├── run.c              # 973 lines - Main C inference engine
├── runq.c             # 1092 lines - Quantized (int8) inference
├── model.py           # 343 lines - PyTorch model for training
├── train.py           # 343 lines - Training loop
├── export.py          # 567 lines - Export model to binary format
├── tokenizer.py       # 78 lines - Tokenizer utilities
├── tokenizer.bin      # Pre-trained tokenizer
└── tinystories.py     # TinyStories dataset preparation

Architecture Overview

llama2.c implements the Llama 2 architecture with these key components:

Component         | Implementation              | File Reference
Normalization     | RMSNorm (not LayerNorm)     | run.c:182-195
Position Encoding | RoPE (computed on-the-fly)  | run.c:264-279
Attention         | Multi-head with GQA support | run.c:281-319
FFN               | SwiGLU (w1, w2, w3)         | run.c:332-348
Tokenizer         | BPE with byte fallback      | run.c:364-571
Sampling          | Top-p (nucleus) sampling    | run.c:624-665

Model Configuration

The model hyperparameters are stored in a simple struct:

run.c:19-27:

C
typedef struct {
    int dim;        // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers;   // number of layers
    int n_heads;    // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len;    // max sequence length
} Config;

The n_kv_heads field enables Grouped Query Attention (GQA), where groups of query heads share a single key/value head, reducing KV-cache memory and compute.
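
For a concrete sense of scale, here is an illustrative Config roughly matching the small stories15M checkpoint (dim, layers, and heads follow the train.py defaults shown later; hidden_dim, vocab_size, and seq_len are assumptions for illustration):

C
// Illustrative Config, roughly the stories15M checkpoint.
// dim/n_layers/n_heads match the train.py defaults quoted later;
// hidden_dim, vocab_size, and seq_len are assumed for this example.
Config cfg = {
    .dim        = 288,
    .hidden_dim = 768,
    .n_layers   = 6,
    .n_heads    = 6,
    .n_kv_heads = 6,     // == n_heads here, so no GQA sharing
    .vocab_size = 32000,
    .seq_len    = 256,
};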

Weight Storage

All model weights are stored contiguously in memory:

run.c:29-48:

C
typedef struct {
    // token embedding table
    float* token_embedding_table;    // (vocab_size, dim)
    // weights for rmsnorms
    float* rms_att_weight; // (layer, dim) rmsnorm weights
    float* rms_ffn_weight; // (layer, dim)
    // weights for matmuls. note dim == n_heads * head_size
    float* wq; // (layer, dim, n_heads * head_size)
    float* wk; // (layer, dim, n_kv_heads * head_size)
    float* wv; // (layer, dim, n_kv_heads * head_size)
    float* wo; // (layer, n_heads * head_size, dim)
    // weights for ffn
    float* w1; // (layer, hidden_dim, dim)
    float* w2; // (layer, dim, hidden_dim)
    float* w3; // (layer, hidden_dim, dim)
    // final rmsnorm
    float* rms_final_weight; // (dim,)
    // (optional) classifier weights for the logits, on the last layer
    float* wcls;
} TransformerWeights;

Note: Llama uses three FFN weight matrices (w1, w2, w3) for SwiGLU, unlike GPT-2's two matrices.

Runtime State (Activations + KV Cache)

run.c:50-65:

C
typedef struct {
    // current wave of activations
    float *x;      // activation at current time stamp (dim,)
    float *xb;     // same, but inside a residual branch (dim,)
    float *xb2;    // an additional buffer just for convenience (dim,)
    float *hb;     // buffer for hidden dimension in the ffn (hidden_dim,)
    float *hb2;    // buffer for hidden dimension in the ffn (hidden_dim,)
    float *q;      // query (dim,)
    float *k;      // key (dim,)
    float *v;      // value (dim,)
    float *att;    // buffer for scores/attention values (n_heads, seq_len)
    float *logits; // output logits
    // kv cache
    float* key_cache;   // (layer, seq_len, dim)
    float* value_cache; // (layer, seq_len, dim)
} RunState;

The KV cache stores previously computed keys and values, so each step only computes K and V for the new token instead of recomputing them for every earlier position.
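
A minimal sketch of how the cache is indexed per layer and position (run.c computes an equivalent loff offset, which also appears in the attention loop later); kv_dim here means (dim * n_kv_heads) / n_heads:

C
// Sketch: locate the KV-cache rows for (layer l, position pos).
// loff is the start of this layer's cache; each position owns kv_dim floats.
static void kv_cache_rows(RunState* s, Config* p, int l, int pos, int kv_dim,
                          float** k_row, float** v_row) {
    int loff = l * p->seq_len * kv_dim;           // this layer's slice of the cache
    *k_row = s->key_cache   + loff + pos * kv_dim;
    *v_row = s->value_cache + loff + pos * kv_dim;
    // the current token's key/value projections are written here, so later
    // steps read them back instead of recomputing K and V for old tokens.
}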

Memory-Mapped Weight Loading

Weights are loaded via mmap for efficiency - the OS handles paging from disk:

run.c:142-162:

C
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
                     int* fd, float** data, ssize_t* file_size) {
    FILE *file = fopen(checkpoint, "rb");
    if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
    // read in the config header
    if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
    // negative vocab size is hacky way of signaling unshared weights. bit yikes.
    int shared_weights = config->vocab_size > 0 ? 1 : 0;
    config->vocab_size = abs(config->vocab_size);
    // figure out the file size
    fseek(file, 0, SEEK_END);
    *file_size = ftell(file);
    fclose(file);
    // memory map the Transformer weights into the data pointer
    *fd = open(checkpoint, O_RDONLY);
    if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
    *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
    float* weights_ptr = *data + sizeof(Config)/sizeof(float);
    memory_map_weights(weights, config, weights_ptr, shared_weights);
}

Benefits of mmap:

  • No explicit memory allocation needed
  • OS pages in data on-demand
  • Multiple processes can share the same physical memory
  • Works with models larger than available RAM (with disk thrashing)

RMSNorm Implementation

Llama uses RMSNorm instead of LayerNorm - it's simpler (no mean subtraction) and empirically works just as well:

run.c:182-195:

C
void rmsnorm(float* o, float* x, float* weight, int size) {
    // calculate sum of squares
    float ss = 0.0f;
    for (int j = 0; j < size; j++) {
        ss += x[j] * x[j];
    }
    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);
    // normalize and scale
    for (int j = 0; j < size; j++) {
        o[j] = weight[j] * (ss * x[j]);
    }
}

Mathematical formulation:

\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma

where γ is a learned scale parameter. Note there's no β (bias) term and no mean subtraction.

Comparison with LayerNorm:

                   | LayerNorm | RMSNorm
Mean subtraction   | Yes       | No
Learned parameters | γ, β      | γ only
Compute            | O(2d)     | O(d)
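
A tiny numeric check of the rmsnorm code above, with a made-up 3-element vector and unit weights:

C
// Numeric check: x = {1, 2, 2}, weight = {1, 1, 1}.
// mean(x^2) = (1 + 4 + 4) / 3 = 3, so each element is scaled by ~1/sqrt(3).
#include <math.h>
#include <stdio.h>

int main(void) {
    float x[3] = {1.0f, 2.0f, 2.0f};
    float ss = 0.0f;
    for (int j = 0; j < 3; j++) ss += x[j] * x[j];
    ss = 1.0f / sqrtf(ss / 3 + 1e-5f);                 // ~0.57735
    printf("%f %f %f\n", x[0]*ss, x[1]*ss, x[2]*ss);   // ~0.577 1.155 1.155
    return 0;
}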

Softmax with Numerical Stability

run.c:197-215:

C
void softmax(float* x, int size) {
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) {
            max_val = x[i];
        }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}

The max subtraction prevents overflow, since softmax(x - max(x)) = softmax(x).

Matrix Multiplication

The most time-consuming operation - a simple nested loop parallelized with OpenMP:

run.c:217-229:

C
void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
    }
}

This computes y = Wx, where W is a d×n matrix and x is an n-dimensional vector.

The #pragma omp parallel for distributes rows across CPU cores.
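
For a sense of how it is called, the per-layer query/key/value projections pass pointer offsets into the contiguous weight blocks; a lightly abridged sketch in the spirit of run.c's forward pass (l is the layer index, kv_dim is defined in the forward pass shown next):

C
// Sketch: per-layer QKV projections as matmul calls with pointer offsets
// (kv_dim = dim * n_kv_heads / n_heads; l is the layer index).
matmul(s->q, s->xb, w->wq + l*dim*dim,    dim, dim);     // W_q is (dim, dim)
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);  // W_k is (kv_dim, dim)
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);  // W_v is (kv_dim, dim)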

The Complete Forward Pass

run.c:231-362 - The heart of the inference engine:

C
float* forward(Transformer* transformer, int token, int pos) {
    // a few convenience variables
    Config* p = &transformer->config;
    TransformerWeights* w = &transformer->weights;
    RunState* s = &transformer->state;
    float *x = s->x;
    int dim = p->dim;
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
    int hidden_dim = p->hidden_dim;
    int head_size = dim / p->n_heads;

    // copy the token embedding into x
    float* content_row = w->token_embedding_table + token * dim;
    memcpy(x, content_row, dim*sizeof(*x));

    // forward all the layers
    for(unsigned long long l = 0; l < p->n_layers; l++) {
        // ... layer computation ...
    }

    // final rmsnorm
    rmsnorm(x, x, w->rms_final_weight, dim);

    // classifier into logits
    matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
    return s->logits;
}

RoPE: Rotary Position Embeddings

RoPE encodes position by rotating query and key vectors:

run.c:264-279:

C
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
    int head_dim = i % head_size;
    float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
    float val = pos * freq;
    float fcr = cosf(val);
    float fci = sinf(val);
    int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
    for (int v = 0; v < rotn; v++) {
        float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
        float v0 = vec[i];
        float v1 = vec[i+1];
        vec[i]   = v0 * fcr - v1 * fci;
        vec[i+1] = v0 * fci + v1 * fcr;
    }
}

RoPE Mathematics:

For position m and dimension pair (i, i+1):

\theta_i = 10000^{-2i/d}

\begin{pmatrix} q'_i \\ q'_{i+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} q_i \\ q_{i+1} \end{pmatrix}

The rotation encodes absolute position, but attention scores depend only on relative position:

q_m^T k_n = q^T R_{m-n} k
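
A small standalone check of this property: rotating toy 2-D query/key vectors by position-dependent angles and comparing dot products for two pairs with the same relative distance (all values here are illustrative):

C
// Check: the dot product of a rotated query at position m and a rotated key
// at position n depends only on (m - n). Single frequency theta, toy vectors.
#include <math.h>
#include <stdio.h>

static void rotate(float* x, float* y, float angle) {
    float c = cosf(angle), s = sinf(angle);
    float x0 = *x, y0 = *y;
    *x = x0 * c - y0 * s;
    *y = x0 * s + y0 * c;
}

int main(void) {
    float theta = 0.1f;
    float q1[2] = {1.0f, 0.2f}, k1[2] = {0.5f, -0.3f};  // positions m=7,  n=3
    float q2[2] = {1.0f, 0.2f}, k2[2] = {0.5f, -0.3f};  // positions m=14, n=10
    rotate(&q1[0], &q1[1], 7.0f  * theta); rotate(&k1[0], &k1[1], 3.0f  * theta);
    rotate(&q2[0], &q2[1], 14.0f * theta); rotate(&k2[0], &k2[1], 10.0f * theta);
    // both pairs are 4 positions apart, so both scores come out identical
    printf("%f %f\n", q1[0]*k1[0] + q1[1]*k1[1],
                      q2[0]*k2[0] + q2[1]*k2[1]);
    return 0;
}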

Multi-Head Attention with GQA

run.c:281-319:

C
// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
for (h = 0; h < p->n_heads; h++) {
    // get the query vector for this head
    float* q = s->q + h * head_size;
    // attention scores for this head
    float* att = s->att + h * p->seq_len;
    // iterate over all timesteps, including the current one
    for (int t = 0; t <= pos; t++) {
        // get the key vector for this head and at this timestep
        float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
        // calculate the attention score as the dot product of q and k
        float score = 0.0f;
        for (int i = 0; i < head_size; i++) {
            score += q[i] * k[i];
        }
        score /= sqrtf(head_size);
        // save the score to the attention buffer
        att[t] = score;
    }

    // softmax the scores to get attention weights, from 0..pos inclusively
    softmax(att, pos + 1);

    // weighted sum of the values, store back into xb
    float* xb = s->xb + h * head_size;
    memset(xb, 0, head_size * sizeof(float));
    for (int t = 0; t <= pos; t++) {
        // get the value vector for this head and at this timestep
        float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
        // get the attention weight for this timestep
        float a = att[t];
        // accumulate the weighted value into xb
        for (int i = 0; i < head_size; i++) {
            xb[i] += a * v[i];
        }
    }
}

Key insight: (h / kv_mul) maps multiple query heads to the same KV head for GQA.

For example, with n_heads=32 and n_kv_heads=8, kv_mul=4, so heads 0-3 share KV head 0, heads 4-7 share KV head 1, etc.

SwiGLU Feed-Forward Network

run.c:332-348:

C
// ffn rmsnorm
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);

// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);

// SwiGLU non-linearity
for (int i = 0; i < hidden_dim; i++) {
    float val = s->hb[i];
    // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
    val *= (1.0f / (1.0f + expf(-val)));
    // elementwise multiply with w3(x)
    val *= s->hb2[i];
    s->hb[i] = val;
}

// final matmul to get the output of the ffn
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);

SwiGLU Formula:

\text{FFN}(x) = W_2 \cdot (\text{SiLU}(W_1 x) \odot W_3 x)

where SiLU (Swish) is:

\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}

The ⊙ is element-wise multiplication (gating).
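
A quick numeric look at the SiLU gate in isolation (input values chosen arbitrarily) shows it behaves like a smooth ReLU, passing large positive values through and squashing negatives toward zero:

C
// SiLU at a few sample points.
#include <math.h>
#include <stdio.h>

static float silu(float x) { return x / (1.0f + expf(-x)); }

int main(void) {
    printf("%f %f %f %f\n", silu(-5.0f), silu(-1.0f), silu(1.0f), silu(5.0f));
    // ~ -0.033      -0.269      0.731      4.966
    return 0;
}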

BPE Tokenizer

The tokenizer implements Byte Pair Encoding with byte fallback:

run.c:452-571 (encode function):

C
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
    // encode the string text (input) into an upper-bound preallocated tokens[] array
    if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }

    if (t->sorted_vocab == NULL) {
        // lazily malloc and sort the vocabulary
        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
        for (int i = 0; i < t->vocab_size; i++) {
            t->sorted_vocab[i].str = t->vocab[i];
            t->sorted_vocab[i].id = i;
        }
        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
    }

    // ... UTF-8 handling ...

    // merge the best consecutive pair each iteration, according the scores in vocab_scores
    while (1) {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;

        for (int i=0; i < (*n_tokens-1); i++) {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
            if (id != -1 && t->vocab_scores[id] > best_score) {
                // this merge pair exists in vocab! record its score and position
                best_score = t->vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
        }

        if (best_idx == -1) {
            break; // we couldn't find any more pairs to merge, so we're done
        }

        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
        tokens[best_idx] = best_id;
        // delete token at position best_idx+1, shift the entire sequence back 1
        for (int i = best_idx+1; i < (*n_tokens-1); i++) {
            tokens[i] = tokens[i+1];
        }
        (*n_tokens)--; // token length decreased
    }
}

The algorithm:

  1. Start with byte-level tokens
  2. Find the highest-scoring merge pair in vocabulary
  3. Merge that pair into a single token
  4. Repeat until no more merges possible
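
A hypothetical usage sketch of the encoder above, assuming a Tokenizer already loaded from tokenizer.bin; the +3 headroom mirrors the allocation in generate() shown later:

C
// Encode a prompt into token ids (bos=1, eos=0, matching how generate() calls it).
char* text = "Once upon a time";
int* tokens = malloc((strlen(text) + 3) * sizeof(int));
int n_tokens = 0;
encode(&tokenizer, text, 1, 0, tokens, &n_tokens);
// tokens[0] == 1 (BOS), followed by the merged BPE ids for the prompt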

Top-p (Nucleus) Sampling

run.c:624-665:

C
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
    // top-p sampling (or "nucleus sampling") samples from the smallest set of
    // tokens that exceed probability topp. This way we never sample tokens that
    // have very low probabilities and are less likely to go "off the rails".

    int n0 = 0;
    // quicksort indices in descending order of probabilities
    // values smaller than (1 - topp) / (n - 1) cannot be part of the result
    // so for efficiency we crop these out as candidates before sorting
    const float cutoff = (1.0f - topp) / (n - 1);
    for (int i = 0; i < n; i++) {
        if (probabilities[i] >= cutoff) {
            probindex[n0].index = i;
            probindex[n0].prob = probabilities[i];
            n0++;
        }
    }
    qsort(probindex, n0, sizeof(ProbIndex), compare);

    // truncate the list where cumulative probability exceeds topp
    float cumulative_prob = 0.0f;
    int last_idx = n0 - 1;
    for (int i = 0; i < n0; i++) {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp) {
            last_idx = i;
            break;
        }
    }

    // sample from the truncated list
    float r = coin * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i <= last_idx; i++) {
        cdf += probindex[i].prob;
        if (r < cdf) {
            return probindex[i].index;
        }
    }
    return probindex[last_idx].index;
}

Top-p algorithm:

  1. Sort tokens by probability (descending)
  2. Keep tokens until cumulative probability exceeds p (e.g., 0.9)
  3. Sample from this "nucleus" of tokens

This prevents sampling extremely unlikely tokens while maintaining diversity.
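
A toy walk-through of the cutoff step, with made-up probabilities and top-p = 0.9:

C
// Nucleus cut on a toy distribution (not the run.c code itself).
#include <stdio.h>

int main(void) {
    // already sorted in descending order for clarity
    float probs[] = {0.50f, 0.30f, 0.15f, 0.04f, 0.01f};
    float topp = 0.9f, cum = 0.0f;
    int n = 5, last_idx = n - 1;
    for (int i = 0; i < n; i++) {
        cum += probs[i];
        if (cum > topp) { last_idx = i; break; }  // 0.50+0.30+0.15 = 0.95 > 0.9
    }
    // nucleus = indices 0..2; sampling renormalizes by cum (0.95)
    printf("nucleus size = %d, mass = %.2f\n", last_idx + 1, cum);
    return 0;
}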

Generation Loop

run.c:729-783:

C
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
    // encode the (string) prompt into tokens sequence
    int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);

    // start the main loop
    long start = 0;
    int next;
    int token = prompt_tokens[0];
    int pos = 0;
    while (pos < steps) {
        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);

        // advance the state machine
        if (pos < num_prompt_tokens - 1) {
            // if we are still processing the input prompt, force the next prompt token
            next = prompt_tokens[pos + 1];
        } else {
            // otherwise sample the next token from the logits
            next = sample(sampler, logits);
        }
        pos++;

        // data-dependent terminating condition: the BOS (=1) token delimits sequences
        if (next == 1) { break; }

        // print the token as string, decode it with the Tokenizer object
        char* piece = decode(tokenizer, token, next);
        safe_printf(piece);
        fflush(stdout);
        token = next;

        // init the timer here because the first iteration can be slower
        if (start == 0) { start = time_in_ms(); }
    }

    // report achieved tok/s
    if (pos > 1) {
        long end = time_in_ms();
        fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
    }
}

Training with train.py

While run.c handles inference, train.py provides a complete PyTorch training pipeline:

train.py - Key hyperparameters:

Python
# model
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = 6
multiple_of = 32
dropout = 0.0

# adamw optimizer
gradient_accumulation_steps = 4
learning_rate = 5e-4
max_iters = 100000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0

# learning rate decay
decay_lr = True
warmup_iters = 1000

Distributed Training Support

train.py:

Python
ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
if ddp:
    init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)

Running DDP:

Bash
# Single GPU
python train.py --compile=False --batch_size=8

# 4 GPUs on 1 node
torchrun --standalone --nproc_per_node=4 train.py

# 8 GPUs across 2 nodes
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
    --master_addr=123.456.123.456 --master_port=1234 train.py

TinyStories Dataset

llama2.c uses the TinyStories dataset (small children's stories) for training small models.

tinystories.py - Custom tokenizer training:

Python
def train_vocab(vocab_size):
    """Train a custom sentencepiece tokenizer on TinyStories"""
    prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")

    # Export text from first 10 shards for training
    tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    with open(tiny_file, "w", encoding="utf-8") as of:
        for shard in shard_filenames[:10]:
            with open(shard, "r") as f:
                data = json.load(f)
            for example in data:
                text = example["story"].strip()
                of.write(text + "\n")

    # Train sentencepiece BPE model
    spm.SentencePieceTrainer.train(
        input=tiny_file,
        model_prefix=prefix,
        model_type="bpe",
        vocab_size=vocab_size,
        character_coverage=1.0,
        num_threads=os.cpu_count(),
        split_digits=True,           # Split digits into individual tokens
        allow_whitespace_only_pieces=True,
        byte_fallback=True,          # Handle unknown bytes
        normalization_rule_name="identity"  # No unicode normalization
    )

tinystories.py - Pre-tokenization for efficient training:

Python
def process_shard(args, vocab_size):
    """Process a single JSON shard into binary tokens"""
    shard_id, shard = args
    enc = Tokenizer(get_tokenizer_model_path(vocab_size))

    with open(shard, "r") as f:
        data = json.load(f)

    all_tokens = []
    for example in data:
        text = example["story"].strip()
        tokens = enc.encode(text, bos=True, eos=False)  # BOS but no EOS
        all_tokens.extend(tokens)

    # Save as uint16 (vocab < 65536)
    all_tokens = np.array(all_tokens, dtype=np.uint16)

    tokenized_filename = shard.replace(".json", ".bin")
    with open(tokenized_filename, "wb") as f:
        f.write(all_tokens.tobytes())

def pretokenize(vocab_size):
    """Pre-tokenize all shards in parallel"""
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    fun = partial(process_shard, vocab_size=vocab_size)
    with ProcessPoolExecutor() as executor:
        executor.map(fun, enumerate(shard_filenames))

tinystories.py - PyTorch DataLoader:

Python
class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized .bin files for training"""

    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
        self.split = split
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size

    def __iter__(self):
        # Unique seed per worker + DDP rank for proper shuffling
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        rank = dist.get_rank() if dist.is_initialized() else 0
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)

        # Load .bin shards and yield sequences
        shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        while True:
            rng.shuffle(shard_filenames)  # Shuffle shards each epoch
            for shard in shard_filenames:
                m = np.memmap(shard, dtype=np.uint16, mode="r")
                # ... yield sequences of max_seq_len ...

Key design choices:

  • uint16 storage (2 bytes/token vs 4 for int32)
  • BOS token separates stories (used to calculate avg sequence length)
  • Parallel processing with ProcessPoolExecutor
  • memmap for memory-efficient loading of large shards

Python Model (for Training)

The Python implementation mirrors the C code but uses PyTorch:

model.py:94-164 (Attention):

Python
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)

        # flash implementation
        output = torch.nn.functional.scaled_dot_product_attention(
            xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2),
            attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True
        )

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

Exporting Models

The export.py script converts PyTorch models to the binary format.

Quantization (Q8_0)

export.py - Group quantization to int8:

Python
def quantize_q80(w, group_size):
    """
    Symmetric quantization into int8, range [-127,127]
    Returns: (int8_values, scales, max_error)
    """
    assert w.numel() % group_size == 0
    w = w.float().reshape(-1, group_size)

    # Find max absolute value in each group
    wmax = torch.abs(w).max(dim=1).values

    # Scale factor: float = quant * scale
    scale = wmax / 127.0

    # Quantize
    quant = w / scale[:, None]
    int8val = torch.round(quant).to(torch.int8)

    # Compute max quantization error
    fp32val = (int8val.float() * scale[:, None]).view(-1)
    err = torch.abs(fp32val.reshape(-1, group_size) - w).max(dim=1).values
    maxerr = err.max().item()

    return int8val, scale, maxerr

Quantization math:

\text{scale}_g = \frac{\max(|w_g|)}{127}

w_{\text{int8}} = \text{round}\left(\frac{w}{\text{scale}_g}\right)

w_{\text{dequant}} = w_{\text{int8}} \cdot \text{scale}_g

Binary Export

export.py - Legacy format (v0):

Python
def legacy_export(model, filepath):
    """Export model to .bin format for run.c"""
    out_file = open(filepath, 'wb')

    # Header: 7 integers (28 bytes)
    hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
    p = model.params
    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)

    # Negative vocab_size signals unshared classifier weights
    if not shared_classifier:
        p.vocab_size = -p.vocab_size

    header = struct.pack('iiiiiii',
        p.dim, hidden_dim, p.n_layers, p.n_heads,
        p.n_kv_heads, p.vocab_size, p.max_seq_len
    )
    out_file.write(header)

    # Weights in specific order (must match run.c memory_map_weights)
    serialize_fp32(out_file, model.tok_embeddings.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention_norm.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention.wq.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention.wk.weight)
    # ... remaining weights in order ...

Weight order (must match memory_map_weights in run.c):

  1. Token embeddings (vocab_size, dim)
  2. RMSNorm attention weights (n_layers, dim)
  3. Wq weights (n_layers, dim, n_heads * head_size)
  4. Wk weights (n_layers, dim, n_kv_heads * head_size)
  5. Wv weights (n_layers, dim, n_kv_heads * head_size)
  6. Wo weights (n_layers, n_heads * head_size, dim)
  7. RMSNorm FFN weights (n_layers, dim)
  8. W1 weights (n_layers, hidden_dim, dim)
  9. W2 weights (n_layers, dim, hidden_dim)
  10. W3 weights (n_layers, hidden_dim, dim)
  11. Final RMSNorm weight (dim,)
  12. Classifier (if not shared)

Performance Considerations

Memory Layout

All weights are stored contiguously with pointer arithmetic:

run.c:111-140:

C
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
    int head_size = p->dim / p->n_heads;
    unsigned long long n_layers = p->n_layers;
    w->token_embedding_table = ptr;
    ptr += p->vocab_size * p->dim;
    w->rms_att_weight = ptr;
    ptr += n_layers * p->dim;
    w->wq = ptr;
    ptr += n_layers * p->dim * (p->n_heads * head_size);
    // ... etc
}

KV Cache Efficiency

The KV cache stores (layer, seq_len, kv_dim) tensors:

run.c:86-88:

C
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));

With GQA (n_kv_heads < n_heads), cache size is reduced proportionally.
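
For a rough sense of scale, a back-of-the-envelope sketch with assumed Llama 2 7B-like shapes (32 layers, dim 4096, 4096-token context, no GQA so kv_dim = dim) puts the float32 cache at about 4 GiB:

C
// Back-of-the-envelope KV-cache size; these shapes are assumptions for the
// example, not values read from the repository.
#include <stdio.h>

int main(void) {
    long n_layers = 32, seq_len = 4096, kv_dim = 4096;   // no GQA: kv_dim == dim
    long bytes = 2L * n_layers * seq_len * kv_dim * (long)sizeof(float);  // key + value
    printf("KV cache: %.1f GiB\n", bytes / (1024.0 * 1024.0 * 1024.0));   // ~4.0 GiB
    // with GQA at n_kv_heads = n_heads / 4, kv_dim drops 4x and so does the cache
    return 0;
}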

Quantized Inference (runq.c)

llama2.c also provides int8 quantized inference for reduced memory and faster computation.

Quantized Tensor Structure

runq.c:33-36:

C
typedef struct {
    int8_t* q;    // quantized values
    float* s;     // scaling factors
} QuantizedTensor;
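
Dequantization is the inverse mapping: each int8 value is multiplied by its group's scale factor. A minimal sketch consistent with this layout (runq.c includes an equivalent helper; the group size is passed explicitly here, whereas runq.c uses a global GS):

C
// Minimal dequantization sketch for the QuantizedTensor layout above:
// element i belongs to group i / GS and is rescaled by that group's factor.
void dequantize_sketch(QuantizedTensor *qx, float* x, int n, int GS) {
    for (int i = 0; i < n; i++) {
        x[i] = qx->q[i] * qx->s[i / GS];
    }
}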

Quantized Weights

runq.c:38-60:

C
typedef struct {
    // token embedding table
    QuantizedTensor *q_tokens; // (vocab_size, dim)
    float* token_embedding_table; // same, but dequantized

    // weights for rmsnorms (kept in float - small)
    float* rms_att_weight; // (layer, dim)
    float* rms_ffn_weight; // (layer, dim)

    // weights for matmuls - QUANTIZED
    QuantizedTensor *wq; // (layer, dim, n_heads * head_size)
    QuantizedTensor *wk; // (layer, dim, n_kv_heads * head_size)
    QuantizedTensor *wv; // (layer, dim, n_kv_heads * head_size)
    QuantizedTensor *wo; // (layer, n_heads * head_size, dim)
    QuantizedTensor *w1; // (layer, hidden_dim, dim)
    QuantizedTensor *w2; // (layer, dim, hidden_dim)
    QuantizedTensor *w3; // (layer, hidden_dim, dim)

    float* rms_final_weight; // (dim,)
    QuantizedTensor *wcls;
} TransformerWeights;

Quantized Matmul

runq.c (quantized matrix multiplication):

C
void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;
        int32_t ival = 0;
        int in = i * n;

        // do the matmul in groups of GS
        int j;
        for (j = 0; j <= n - GS; j += GS) {
            for (int k = 0; k < GS; k++) {
                ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
            }
            val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
            ival = 0;
        }
        xout[i] = val;
    }
}

Group quantization: Weights are quantized in groups of GS (e.g., 64). Each group shares a single scale factor, balancing compression and accuracy.

Quantization formula:

w_{\text{int8}} = \text{round}\left(\frac{w_{\text{float}}}{\text{scale}}\right), \quad \text{scale} = \frac{\max(|w|)}{127}

Memory savings: roughly 4× (float32 → int8); with GS = 64, each group stores 64 int8 values plus one float32 scale, so 68 bytes replace 256.

Compiling and Running

Bash
# Compile with OpenMP for parallel matmul
gcc -O3 -o run run.c -lm -fopenmp

# Compile quantized version
gcc -O3 -o runq runq.c -lm -fopenmp

# Run inference
./run stories15M.bin -t 0.8 -p 0.9 -n 256 -i "Once upon a time"

# Run quantized inference
./runq stories15M_q80.bin -t 0.8 -p 0.9 -n 256 -i "Once upon a time"

Command-line options:

  • -t: Temperature (0 = greedy, 1 = original distribution)
  • -p: Top-p value for nucleus sampling
  • -n: Number of tokens to generate
  • -i: Input prompt

Summary

llama2.c demonstrates that a complete Llama 2 inference engine fits in roughly 1,000 lines of C:

Component    | Lines | Complexity
RMSNorm      | 14    | O(d)
Softmax      | 19    | O(n)
Matmul       | 13    | O(nd)
RoPE         | 16    | O(d)
Attention    | 39    | O(T·d)
SwiGLU FFN   | 17    | O(d·h)
Forward pass | 132   | O(L(Td + dh))

The simplicity makes it perfect for:

  • Education: Understanding transformers close to the metal
  • Embedded deployment: No Python/PyTorch dependencies
  • Prototyping: Quick experiments with custom models
  • Verification: Cross-checking PyTorch implementations

Enrico Piovano, PhD

Co-founder & CTO at Goji AI. Former Applied Scientist at Amazon (Alexa & AGI), focused on Agentic AI and LLMs. PhD in Electrical Engineering from Imperial College London. Gold Medalist at the National Mathematical Olympiad.
