
llm.c: Andrej Karpathy's GPT-2 Training in Pure C/CUDA

A comprehensive deep dive into llm.c - Karpathy's implementation of GPT-2 training in pure C and CUDA without PyTorch. Covers forward and backward passes, CUDA kernels for attention, LayerNorm, matmul, GELU, AdamW optimizer, multi-GPU training with NCCL, and ZeRO optimization.


llm.c is Andrej Karpathy's implementation of GPT-2 training in pure C and CUDA - no PyTorch, no Python dependencies. The project demonstrates that you can train GPT-2-class language models with roughly 1,200 lines of C (CPU) or 1,900 lines of CUDA (GPU).

Repository: github.com/karpathy/llm.c

Repository Structure

Code
llm.c/
├── train_gpt2.c       # 1182 lines - CPU training (reference)
├── train_gpt2.cu      # 1904 lines - CUDA training (production)
├── train_gpt2.py      # 860 lines - Python reference
├── train_llama3.py    # 1255 lines - Llama 3 training
├── llmc/              # CUDA kernel implementations
│   ├── attention.cuh      # 11KB - Attention kernels
│   ├── layernorm.cuh      # 23KB - LayerNorm kernels
│   ├── matmul.cuh         # 14KB - Matrix multiplication
│   ├── gelu.cuh           # 3KB - GELU activation
│   ├── adamw.cuh          # 5KB - AdamW optimizer
│   ├── encoder.cuh        # 11KB - Token/position encoding
│   ├── fused_classifier.cuh # 7KB - Fused softmax + cross-entropy
│   ├── global_norm.cuh    # 4KB - Gradient norm computation
│   ├── zero.cuh           # 23KB - ZeRO optimization
│   ├── dataloader.h       # 24KB - Data loading utilities
│   └── cuda_utils.cuh     # 11KB - CUDA utilities
└── dev/               # Development utilities

GPT-2 Architecture in C

The model configuration mirrors PyTorch's GPT-2:

train_gpt2.c:526-533:

C
typedef struct {
    int max_seq_len; // max sequence length, e.g. 1024
    int vocab_size;  // vocab size, e.g. 50257
    int padded_vocab_size; // padded to e.g. %128==0, 50304
    int num_layers;  // number of layers, e.g. 12
    int num_heads;   // number of heads in attention, e.g. 12
    int channels;    // number of channels, e.g. 768
} GPT2Config;

Parameter Tensors

train_gpt2.c:536-554:

C
#define NUM_PARAMETER_TENSORS 16
typedef struct {
    float* wte;      // (V, C) token embeddings
    float* wpe;      // (maxT, C) position embeddings
    float* ln1w;     // (L, C) layernorm 1 weights
    float* ln1b;     // (L, C) layernorm 1 biases
    float* qkvw;     // (L, 3*C, C) QKV projection weights
    float* qkvb;     // (L, 3*C) QKV projection biases
    float* attprojw; // (L, C, C) attention output projection
    float* attprojb; // (L, C)
    float* ln2w;     // (L, C) layernorm 2 weights
    float* ln2b;     // (L, C)
    float* fcw;      // (L, 4*C, C) MLP first layer
    float* fcb;      // (L, 4*C)
    float* fcprojw;  // (L, C, 4*C) MLP second layer
    float* fcprojb;  // (L, C)
    float* lnfw;     // (C) final layernorm
    float* lnfb;     // (C)
} ParameterTensors;

Activation Tensors

train_gpt2.c:601-626:

C
#define NUM_ACTIVATION_TENSORS 23
typedef struct {
    float* encoded;   // (B, T, C)
    float* ln1;       // (L, B, T, C)
    float* ln1_mean;  // (L, B, T)
    float* ln1_rstd;  // (L, B, T)
    float* qkv;       // (L, B, T, 3*C)
    float* atty;      // (L, B, T, C)
    float* preatt;    // (L, B, NH, T, T)
    float* att;       // (L, B, NH, T, T)
    float* attproj;   // (L, B, T, C)
    float* residual2; // (L, B, T, C)
    float* ln2;       // (L, B, T, C)
    float* ln2_mean;  // (L, B, T)
    float* ln2_rstd;  // (L, B, T)
    float* fch;       // (L, B, T, 4*C)
    float* fch_gelu;  // (L, B, T, 4*C)
    float* fcproj;    // (L, B, T, C)
    float* residual3; // (L, B, T, C)
    float* lnf;       // (B, T, C)
    float* lnf_mean;  // (B, T)
    float* lnf_rstd;  // (B, T)
    float* logits;    // (B, T, V)
    float* probs;     // (B, T, V)
    float* losses;    // (B, T)
} ActivationTensors;

The activation struct stores all intermediate values needed for the backward pass - this is the memory cost of backpropagation.
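
As a rough back-of-envelope check (my own sketch, not code from the repository), you can sum the tensor shapes from the struct comments for the small debug configuration used later in main() - B=4, T=64, L=12, NH=12, C=768 - assuming logits and probs are allocated at the padded vocab size Vp=50304:

C
#include <stdio.h>

// Count activation floats for GPT-2 124M at B=4, T=64 (shapes taken from the
// ActivationTensors comments above; Vp padding for logits/probs is an assumption).
int main(void) {
    size_t B = 4, T = 64, L = 12, NH = 12, C = 768, Vp = 50304;
    size_t n = 0;
    n += B*T*C;               // encoded
    n += 7 * L*B*T*C;         // ln1, atty, attproj, residual2, ln2, fcproj, residual3
    n += 4 * L*B*T;           // ln1_mean, ln1_rstd, ln2_mean, ln2_rstd
    n += L*B*T*3*C;           // qkv
    n += 2 * L*B*NH*T*T;      // preatt, att
    n += 2 * L*B*T*4*C;       // fch, fch_gelu
    n += B*T*C + 2*B*T;       // lnf, lnf_mean, lnf_rstd
    n += 2*B*T*Vp + B*T;      // logits, probs, losses
    printf("%zu floats = %.1f MiB\n", n, (double)n * sizeof(float) / (1024.0 * 1024.0));
    return 0;
}

This works out to roughly 73 million floats (about 280 MiB) even at this tiny batch size; the (L, B, NH, T, T) attention scores and the (B, T, Vp) logits/probs dominate, and both grow quickly with B and T.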

Forward Pass Layers

Encoder (Token + Position Embeddings)

train_gpt2.c:35-58:

C
void encoder_forward(float* out,
                   int* inp, float* wte, float* wpe,
                   int B, int T, int C) {
    // out is (B,T,C). At each position (b,t), a C-dimensional vector summarizing token & position
    // inp is (B,T) of integers, holding the token ids at each (b,t) position
    // wte is (V,C) of token embeddings, short for "weight token embeddings"
    // wpe is (maxT,C) of position embeddings, short for "weight positional embedding"
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            // get the index of the token at inp[b, t]
            int ix = inp[b * T + t];
            // seek to the position in wte corresponding to the token
            float* wte_ix = wte + ix * C;
            // seek to the position in wpe corresponding to the position
            float* wpe_t = wpe + t * C;
            // add the two vectors and store the result in out[b,t,:]
            for (int i = 0; i < C; i++) {
                out_bt[i] = wte_ix[i] + wpe_t[i];
            }
        }
    }
}

$$\text{encoded}_{b,t} = \text{wte}[\text{inp}_{b,t}] + \text{wpe}[t]$$

LayerNorm Forward

train_gpt2.c:78-118:

C
void layernorm_forward(float* out, float* mean, float* rstd,
                       float* inp, float* weight, float* bias,
                       int B, int T, int C) {
    // reference: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
    float eps = 1e-5f;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the input position inp[b,t,:]
            float* x = inp + b * T * C + t * C;
            // calculate the mean
            float m = 0.0f;
            for (int i = 0; i < C; i++) {
                m += x[i];
            }
            m = m/C;
            // calculate the variance (without any bias correction)
            float v = 0.0f;
            for (int i = 0; i < C; i++) {
                float xshift = x[i] - m;
                v += xshift * xshift;
            }
            v = v/C;
            // calculate the rstd (reciprocal standard deviation)
            float s = 1.0f / sqrtf(v + eps);
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            for (int i = 0; i < C; i++) {
                float n = (s * (x[i] - m)); // normalize
                float o = n * weight[i] + bias[i]; // scale and shift
                out_bt[i] = o; // write
            }
            // cache the mean and rstd for the backward pass later
            mean[b * T + t] = m;
            rstd[b * T + t] = s;
        }
    }
}

LayerNorm formula:

$$\mu = \frac{1}{C}\sum_{i=1}^{C} x_i, \qquad \sigma^2 = \frac{1}{C}\sum_{i=1}^{C} (x_i - \mu)^2$$

$$\text{LayerNorm}(x)_i = \gamma_i \cdot \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i$$

We cache $\mu$ and $1/\sqrt{\sigma^2 + \epsilon}$ (rstd) for the backward pass.

Matrix Multiplication

train_gpt2.c:184-229:

C
void matmul_forward(float* out,
                    const float* inp, const float* weight, const float* bias,
                    int B, int T, int C, int OC) {
    // most of the running time is spent here and in matmul_backward
    // inp is (B,T,C), weight is (OC, C), bias is (OC)
    // out will be (B,T,OC)

    // make sure the tiled loop will be correct or fallback to naive version
    const int LOOP_UNROLL = 8;
    if (B*T % LOOP_UNROLL != 0) {
        matmul_forward_naive(out, inp, weight, bias, B, T, C, OC);
        return;
    }

    // collapse the B and T loops into one and turn it into a strided loop
    // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times
    #pragma omp parallel for
    for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) {
        for (int o = 0; o < OC; o++) {
            // we'll keep LOOP_UNROLL many results in registers
            float result[LOOP_UNROLL];
            // initialize the bias, if it exists
            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                result[ibt] = (bias != NULL) ? bias[o] : 0.0f;
            }
            // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache
            // the value of weight[i + o * C] and reuse it.
            for (int i = 0; i < C; i++) {
                float w = weight[i + o * C];
                for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                    int bt = obt + ibt;
                    result[ibt] += inp[bt * C + i] * w;
                }
            }
            // write back results to main memory
            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                int bt = obt + ibt;
                out[bt * OC + o] = result[ibt];
            }
        }
    }
}

Optimization: unrolling the (b,t) loop by a factor of 8 lets each loaded weight value be reused 8 times from registers, improving cache efficiency. When B*T is not a multiple of 8, the function falls back to a naive implementation.
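
For reference, here is a minimal sketch of what that naive fallback computes (see matmul_forward_naive in the repository for the exact version); it is the plain triple loop over (b,t) positions, output channels, and input channels:

C
void matmul_forward_naive(float* out,
                          const float* inp, const float* weight, const float* bias,
                          int B, int T, int C, int OC) {
    // out[b,t,o] = bias[o] + sum_i inp[b,t,i] * weight[o,i]
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            int bt = b * T + t;
            for (int o = 0; o < OC; o++) {
                float val = (bias != NULL) ? bias[o] : 0.0f;
                for (int i = 0; i < C; i++) {
                    val += inp[bt * C + i] * weight[o * C + i];
                }
                out[bt * OC + o] = val;
            }
        }
    }
}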

Attention Forward

train_gpt2.c:271-345:

C
void attention_forward(float* out, float* preatt, float* att,
                       float* inp,
                       int B, int T, int C, int NH) {
    // input is (B, T, 3C) holding the query, key, value (Q, K, V) vectors
    // preatt, att are (B, NH, T, T)
    // output is (B, T, C)
    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    #pragma omp parallel for collapse(3)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                float* query_t = inp + b * T * C3 + t * C3 + h * hs;
                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;

                // pass 1: calculate query dot key and maxval
                float maxval = -10000.0f;
                for (int t2 = 0; t2 <= t; t2++) {
                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

                    // (query_t) dot (key_t2)
                    float val = 0.0f;
                    for (int i = 0; i < hs; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;
                    if (val > maxval) {
                        maxval = val;
                    }
                    preatt_bth[t2] = val;
                }

                // pass 2: calculate the exp and keep track of sum
                float expsum = 0.0f;
                for (int t2 = 0; t2 <= t; t2++) {
                    float expv = expf(preatt_bth[t2] - maxval);
                    expsum += expv;
                    att_bth[t2] = expv;
                }
                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

                // pass 3: normalize to get the softmax
                for (int t2 = 0; t2 < T; t2++) {
                    if (t2 <= t) {
                        att_bth[t2] *= expsum_inv;
                    } else {
                        att_bth[t2] = 0.0f; // causal attention mask
                    }
                }

                // pass 4: accumulate weighted values into the output of attention
                float* out_bth = out + b * T * C + t * C + h * hs;
                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
                for (int t2 = 0; t2 <= t; t2++) {
                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2;
                    float att_btht2 = att_bth[t2];
                    for (int i = 0; i < hs; i++) {
                        out_bth[i] += att_btht2 * value_t2[i];
                    }
                }
            }
        }
    }
}

Four passes:

  1. Compute $Q \cdot K^T / \sqrt{d}$ and track the row max (for softmax stability)
  2. Compute $\exp(\text{score} - \max)$ and its running sum
  3. Normalize to get attention weights
  4. Compute weighted sum of values
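
Putting the four passes together, each head computes a causal, numerically stable softmax over the scaled dot products and then a weighted sum of values:

$$\text{att}_{t,t_2} = \frac{\exp\left(q_t \cdot k_{t_2}/\sqrt{d_h} - m_t\right)}{\sum_{t' \le t} \exp\left(q_t \cdot k_{t'}/\sqrt{d_h} - m_t\right)}, \qquad \text{out}_t = \sum_{t_2 \le t} \text{att}_{t,t_2}\, v_{t_2}$$

where $m_t = \max_{t' \le t} q_t \cdot k_{t'}/\sqrt{d_h}$. Subtracting $m_t$ changes nothing mathematically (the factor $e^{-m_t}$ cancels between numerator and denominator) but keeps the exponentials from overflowing in float32.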

GELU Activation

train_gpt2.c:407-415:

C
#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
void gelu_forward(float* out, float* inp, int N) {
    // (approximate) GeLU elementwise non-linearity in the MLP block of Transformer
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));
    }
}

GELU approximation (faster than exact):

$$\text{GELU}(x) \approx 0.5\,x\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)$$

Softmax and Cross-Entropy

train_gpt2.c:449-521:

C
void softmax_forward(float* probs, float* logits, int B, int T, int V, int Vp) {
    // output: probs are (B,T,Vp) of the probabilities
    // input: logits is (B,T,Vp) of the unnormalized log probabilities
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* logits_bt = logits + b * T * Vp + t * Vp;
            float* probs_bt = probs + b * T * Vp + t * Vp;

            float maxval = -10000.0f;
            for (int i = 0; i < V; i++) {
                if (logits_bt[i] > maxval) {
                    maxval = logits_bt[i];
                }
            }
            float sum = 0.0f;
            for (int i = 0; i < V; i++) {
                probs_bt[i] = expf(logits_bt[i] - maxval);
                sum += probs_bt[i];
            }
            for (int i = 0; i < V; i++) {
                probs_bt[i] /= sum;
            }
        }
    }
}

void crossentropy_forward(float* losses,
                          float* probs, int* targets,
                          int B, int T, int Vp) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* probs_bt = probs + b * T * Vp + t * Vp;
            int ix = targets[b * T + t];
            losses[b * T + t] = -logf(probs_bt[ix]);
        }
    }
}

Backward Pass Layers

The backward pass computes gradients for all parameters. This is where llm.c really shines - implementing backprop in raw C.

LayerNorm Backward

train_gpt2.c:120-161:

C
void layernorm_backward(float* dinp, float* dweight, float* dbias,
                        float* dout, float* inp, float* weight, float* mean, float* rstd,
                        int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* dout_bt = dout + b * T * C + t * C;
            float* inp_bt = inp + b * T * C + t * C;
            float* dinp_bt = dinp + b * T * C + t * C;
            float mean_bt = mean[b * T + t];
            float rstd_bt = rstd[b * T + t];

            // first: two reduce operations
            float dnorm_mean = 0.0f;
            float dnorm_norm_mean = 0.0f;
            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean = dnorm_mean / C;
            dnorm_norm_mean = dnorm_norm_mean / C;

            // now iterate again and accumulate all the gradients
            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                // gradient contribution to bias
                dbias[i] += dout_bt[i];
                // gradient contribution to weight
                dweight[i] += norm_bti * dout_bt[i];
                // gradient contribution to input
                float dval = 0.0f;
                dval += dnorm_i; // term 1
                dval -= dnorm_mean; // term 2
                dval -= norm_bti * dnorm_norm_mean; // term 3
                dval *= rstd_bt; // final scale
                dinp_bt[i] += dval;
            }
        }
    }
}

The LayerNorm backward requires two passes:

  1. Compute intermediate sums (dnorm_mean, dnorm_norm_mean)
  2. Compute gradients for weight, bias, and input
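
The three "terms" accumulated into dinp are the standard LayerNorm input gradient. Writing $g_i = \gamma_i \frac{\partial L}{\partial y_i}$ (dnorm_i in the code) and $\hat{x}_i = (x_i - \mu)\cdot\text{rstd}$ (norm_bti):

$$\frac{\partial L}{\partial x_i} = \text{rstd}\left(g_i \;-\; \frac{1}{C}\sum_{j=1}^{C} g_j \;-\; \hat{x}_i \cdot \frac{1}{C}\sum_{j=1}^{C} g_j\,\hat{x}_j\right)$$

while the parameter gradients simply accumulate over all positions: $\frac{\partial L}{\partial \gamma_i} = \sum_{b,t} \hat{x}_i\,\frac{\partial L}{\partial y_i}$ and $\frac{\partial L}{\partial \beta_i} = \sum_{b,t} \frac{\partial L}{\partial y_i}$.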

Matmul Backward

train_gpt2.c:231-269:

C
void matmul_backward(float* dinp, float* dweight, float* dbias,
                     const float* dout, const float* inp, const float* weight,
                     int B, int T, int C, int OC) {
    // backward into inp first, parallelize over B,T
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            const float* dout_bt = dout + b * T * OC + t * OC;
            float* dinp_bt = dinp + b * T * C + t * C;
            for (int o = 0; o < OC; o++) {
                const float* wrow = weight + o*C;
                float d = dout_bt[o];
                for (int i = 0; i < C; i++) {
                    dinp_bt[i] += wrow[i] * d;
                }
            }
        }
    }
    // backward into weight/bias, parallelize over output channels OC
    #pragma omp parallel for
    for (int o = 0; o < OC; o++) {
        for (int b = 0; b < B; b++) {
            for (int t = 0; t < T; t++) {
                const float* dout_bt = dout + b * T * OC + t * OC;
                const float* inp_bt = inp + b * T * C + t * C;
                float* dwrow = dweight + o*C;
                float d = dout_bt[o];
                if (dbias != NULL) { dbias[o] += d; }
                for (int i = 0; i < C; i++) {
                    dwrow[i] += inp_bt[i] * d;
                }
            }
        }
    }
}

For $Y = XW^T + b$:

  • $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\, W$ (gradient w.r.t. input)
  • $\frac{\partial L}{\partial W} = \left(\frac{\partial L}{\partial Y}\right)^{T} X$ (gradient w.r.t. weight)
  • $\frac{\partial L}{\partial b} = \sum_{b,t} \frac{\partial L}{\partial Y}$ (gradient w.r.t. bias)

Attention Backward

train_gpt2.c:347-405:

C
void attention_backward(float* dinp, float* dpreatt, float* datt,
                        float* dout, float* inp, float* att,
                        int B, int T, int C, int NH) {
    int C3 = C*3;
    int hs = C / NH;
    float scale = 1.f / sqrtf(hs);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;
                float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;
                float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;
                float* dquery_t = dinp + b * T * C3 + t * C3 + h * hs;
                float* query_t = inp + b * T * C3 + t * C3 + h * hs;

                // backward pass 4, through the value accumulation
                float* dout_bth = dout + b * T * C + t * C + h * hs;
                for (int t2 = 0; t2 <= t; t2++) {
                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2;
                    float* dvalue_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C*2;
                    for (int i = 0; i < hs; i++) {
                        datt_bth[t2] += value_t2[i] * dout_bth[i];
                        dvalue_t2[i] += att_bth[t2] * dout_bth[i];
                    }
                }

                // backward pass 2 & 3, the softmax
                for (int t2 = 0; t2 <= t; t2++) {
                    for (int t3 = 0; t3 <= t; t3++) {
                        float indicator = t2 == t3 ? 1.0f : 0.0f;
                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);
                        dpreatt_bth[t3] += local_derivative * datt_bth[t2];
                    }
                }

                // backward pass 1, the query @ key matmul
                for (int t2 = 0; t2 <= t; t2++) {
                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C;
                    float* dkey_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C;
                    for (int i = 0; i < hs; i++) {
                        dquery_t[i] += key_t2[i] * dpreatt_bth[t2] * scale;
                        dkey_t2[i] += query_t[i] * dpreatt_bth[t2] * scale;
                    }
                }
            }
        }
    }
}

GELU Backward

train_gpt2.c:422-434:

C
void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] += local_grad * dout[i];
    }
}

Cross-Entropy Softmax Backward (Fused)

train_gpt2.c:502-521:

C
void crossentropy_softmax_backward(float* dlogits,
                           float* dlosses, float* probs, int* targets,
                           int B, int T, int V, int Vp) {
    // backwards through both softmax and crossentropy
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* dlogits_bt = dlogits + b * T * Vp + t * Vp;
            float* probs_bt = probs + b * T * Vp + t * Vp;
            float dloss = dlosses[b * T + t];
            int ix = targets[b * T + t];
            for (int i = 0; i < V; i++) {
                float p = probs_bt[i];
                float indicator = i == ix ? 1.0f : 0.0f;
                dlogits_bt[i] += (p - indicator) * dloss;
            }
        }
    }
}

The elegant result: $\frac{\partial L}{\partial \text{logits}_i} = p_i - \mathbb{1}_{i=\text{target}}$
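
This follows in two lines from the definitions. With $p_i = \mathrm{softmax}(z)_i$, $L = -\log p_{\text{target}}$, and the softmax Jacobian $\frac{\partial p_j}{\partial z_i} = p_j(\mathbb{1}_{i=j} - p_i)$:

$$\frac{\partial L}{\partial z_i} = -\frac{1}{p_{\text{target}}}\cdot\frac{\partial p_{\text{target}}}{\partial z_i} = -\frac{1}{p_{\text{target}}}\cdot p_{\text{target}}\left(\mathbb{1}_{i=\text{target}} - p_i\right) = p_i - \mathbb{1}_{i=\text{target}}$$

which is exactly the (p - indicator) * dloss line in the code.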

AdamW Optimizer

train_gpt2.c:1007-1033:

C
void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
    // lazily allocate the memory for m_memory and v_memory
    if (model->m_memory == NULL) {
        model->m_memory = (float*)calloc(model->num_parameters, sizeof(float));
        model->v_memory = (float*)calloc(model->num_parameters, sizeof(float));
    }

    for (size_t i = 0; i < model->num_parameters; i++) {
        float param = model->params_memory[i];
        float grad = model->grads_memory[i];

        // update the first moment (momentum)
        float m = beta1 * model->m_memory[i] + (1.0f - beta1) * grad;
        // update the second moment (RMSprop)
        float v = beta2 * model->v_memory[i] + (1.0f - beta2) * grad * grad;
        // bias-correct both moments
        float m_hat = m / (1.0f - powf(beta1, t));
        float v_hat = v / (1.0f - powf(beta2, t));

        // update
        model->m_memory[i] = m;
        model->v_memory[i] = v;
        model->params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
    }
}

AdamW equations:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t$$

$$v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

$$\theta_t = \theta_{t-1} - \alpha \left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1}\right)$$

Data Loading Infrastructure

DataLoader for Distributed Training

llmc/dataloader.h:

C
typedef struct {
    // distributed training
    int process_rank;
    int num_processes;
    // batch and token information
    size_t B;  // batch size
    size_t T;  // sequence length
    size_t num_tokens; // total number of tokens
    size_t shard_num_samples;  // samples in current shard per process
    // shards and position
    glob_t glob_result; // all shards to iterate
    size_t current_shard_idx;
    size_t current_sample_idx;
    // file handle
    FILE* tokens_file;
    // data buffers
    uint16_t* buffer; // raw data from file
    int* inputs;      // input tokens
    int* targets;     // target tokens
    // random shuffle
    mt19937_state shuffle_rng;
    int should_shuffle;
    int* shard_indices;
    int* intra_shard_indices;
} DataLoader;

Shard-based loading:

C
int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) {
    if (loader->should_shuffle) {
        shard_index = loader->shard_indices[shard_index];
    }
    const char* filename = loader->glob_result.gl_pathv[shard_index];
    loader->tokens_file = fopenCheck(filename, "rb");
    // validate header (magic number 20240520)
    int header[HEADER_SIZE];
    freadCheck(header, sizeof(int), HEADER_SIZE, loader->tokens_file);
    if (header[0] != 20240520) {
        printf("Bad magic in the data file\n");
        exit(EXIT_FAILURE);
    }
    // ... load shard metadata ...
}

Learning Rate Schedulers

llmc/schedulers.h:

C
typedef struct {
    const char* type;
    float learning_rate;
    int warmup_iterations;
    int train_num_batches;
    float final_learning_rate_frac;
} LearningRateScheduler;

// Cosine decay with warmup
float get_learning_rate_cosine(LearningRateScheduler *scheduler, int step) {
    float lr = scheduler->learning_rate;
    if (step < scheduler->warmup_iterations) {
        // Linear warmup
        lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;
    } else {
        // Cosine decay
        float decay_ratio = ((float)(step - scheduler->warmup_iterations)) /
                           (scheduler->train_num_batches - scheduler->warmup_iterations);
        float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio));
        float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;
        lr = min_lr + coeff * (scheduler->learning_rate - min_lr);
    }
    return lr;
}

// Linear decay with warmup
float get_learning_rate_linear(LearningRateScheduler *scheduler, int step) {
    float lr = scheduler->learning_rate;
    if (step < scheduler->warmup_iterations) {
        lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;
    } else {
        float decay_ratio = ((float)(step - scheduler->warmup_iterations)) /
                           (scheduler->train_num_batches - scheduler->warmup_iterations);
        float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;
        lr = scheduler->learning_rate - decay_ratio * (scheduler->learning_rate - min_lr);
    }
    return lr;
}

Cosine schedule formula:

$$\eta(t) = \begin{cases} \eta_{\max} \cdot \dfrac{t}{T_{\text{warmup}}} & \text{if } t < T_{\text{warmup}} \\[1.5ex] \eta_{\min} + \dfrac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\pi \cdot \dfrac{t - T_{\text{warmup}}}{T - T_{\text{warmup}}}\right)\right) & \text{otherwise} \end{cases}$$

Complete Training Loop

train_gpt2.c:1077-1172:

C
int main() {
    // (abridged excerpt: tokenizer setup, dataset path handling, and the
    //  timing/loop variable declarations from the full file are omitted here)

    // build the GPT-2 model from a checkpoint
    GPT2 model;
    gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");

    // build the DataLoaders
    int B = 4;  // batch size
    int T = 64; // sequence length
    DataLoader train_loader, val_loader;
    dataloader_init(&train_loader, train_tokens, B, T, 0, 1, 1);
    dataloader_init(&val_loader, val_tokens, B, T, 0, 1, 0);

    // train
    for (int step = 0; step <= 40; step++) {
        // once in a while estimate the validation loss
        if (step % 10 == 0) {
            float val_loss = 0.0f;
            dataloader_reset(&val_loader);
            for (int i = 0; i < val_num_batches; i++) {
                dataloader_next_batch(&val_loader);
                gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T);
                val_loss += model.mean_loss;
            }
            val_loss /= val_num_batches;
            printf("val loss %f\n", val_loss);
        }

        // do a training step
        clock_gettime(CLOCK_MONOTONIC, &start);
        dataloader_next_batch(&train_loader);
        gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T);
        gpt2_zero_grad(&model);
        gpt2_backward(&model);
        gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
        clock_gettime(CLOCK_MONOTONIC, &end);
        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
        printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000);
    }
}

CUDA Implementation

The CUDA version (train_gpt2.cu) uses optimized GPU kernels.

CUDA Encoder (Token + Position Embedding)

llmc/encoder.cuh - Vectorized token + position embedding:

CUDA
__global__ void encoder_forward_kernel3(floatX* out,
                               const int* inp, const floatX* wte, const floatX* wpe,
                               int B, int T, int C) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
    int N = B * T * C;
    if (idx >= N) { return; }

    int bt = idx / C;
    int b = bt / T;
    int t = bt % T;
    int c = idx % C;

    int ix = inp[b * T + t];  // Token ID

    floatX* out_btc = out + b * T * C + t * C + c;
    const floatX* wte_ix = wte + ix * C + c;  // Token embedding
    const floatX* wpe_tc = wpe + t * C + c;   // Position embedding

    // Load 128 bits (8 BF16 values) at once
    x128 packed_out;
    x128 wte128 = load128cs(wte_ix);
    x128 wpe128 = load128cs(wpe_tc);
    for (int k = 0; k < x128::size; k++) {
        packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);
    }
    store128(out_btc, packed_out);
}

Key optimizations:

  • Each thread handles x128::size (8 for BF16) elements
  • load128cs = load 128 bits, don't cache (streaming)
  • Coalesced memory access pattern
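
The x128 type and the load128cs/store128 helpers live in llmc/cuda_utils.cuh. As a rough sketch of the idea (my own simplification, not the repository's exact code), the list above can be pictured like this:

CUDA
#include <cstring>  // memcpy (also available as a builtin in device code)

// Sketch of the x128 idea: bundle as many elements as fit into one 128-bit (int4)
// transaction, so each thread issues a single wide load or store.
template<class ElementType>
struct alignas(16) Packed128 {
    static constexpr const int size = sizeof(int4) / sizeof(ElementType); // 8 for BF16
    ElementType payload[size];
    __device__ ElementType& operator[](int i) { return payload[i]; }
    __device__ const ElementType& operator[](int i) const { return payload[i]; }
};

// load128cs: one 128-bit load with a "cache streaming" hint (__ldcs), i.e. tell the
// hardware this data will not be re-read soon, so don't keep it in L1.
template<class ElementType>
__device__ Packed128<ElementType> load128cs_sketch(const ElementType* address) {
    int4 bits = __ldcs(reinterpret_cast<const int4*>(address));
    Packed128<ElementType> result;
    memcpy(result.payload, &bits, sizeof(int4));
    return result;
}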

CUDA LayerNorm Kernel

llmc/layernorm.cuh (simplified):

CUDA
__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd,
                                    const floatX*  __restrict__ inp, const floatX*  __restrict__ weight,
                                    const floatX* __restrict__ bias, int N, int C) {
    int lane_id = threadIdx.x % WARP_SIZE;
    int warp_id = threadIdx.x / WARP_SIZE;
    int num_warps = blockDim.x / WARP_SIZE;

    int idx = blockIdx.x * num_warps + warp_id;
    if(idx >= N) { return; }

    const floatX* x = inp + idx * C;
    floatX* o = out + idx * C;

    // thread coarsening: each thread handles multiple elements
    float sum = 0.0f;
    for (int i = lane_id; i < C; i += WARP_SIZE) {
        sum += (float)x[i];
    }
    sum = warpReduceSum(sum);
    float m = sum / C;
    if(lane_id == 0 && mean != nullptr) {
        __stcs(mean + idx, (floatX)m);
    }

    // compute variance
    sum = 0.0f;
    for (int i = lane_id; i < C; i += WARP_SIZE) {
        float xshift = (float)x[i] - m;
        sum += xshift * xshift;
    }
    sum = warpReduceSum(sum);
    float s = rsqrtf(sum / C + 1e-5f);
    if(lane_id == 0 && rstd != nullptr) {
        __stcs(rstd + idx, (floatX)s);
    }

    // normalize and write output
    for (int c = lane_id; c < C; c += WARP_SIZE) {
        float n = s * ((float)x[c] - m);
        __stcs(&o[c], (floatX)(n * (float)weight[c] + (float)bias[c]));
    }
}

Key CUDA optimizations:

  • Warp-level reductions (warpReduceSum)
  • Thread coarsening (each thread handles multiple elements)
  • Memory coalescing
  • __stcs for streaming stores (bypass L1 cache)
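
warpReduceSum itself is a tiny helper defined in llmc/cuda_utils.cuh. A typical implementation - shown here as my own sketch - reduces the 32 lanes of a warp with shuffle instructions, never touching shared memory:

CUDA
// Minimal warp-level sum reduction sketch: after log2(32) = 5 butterfly steps,
// every lane in the warp holds the sum of all 32 lane values.
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }
    return val;
}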

CUDA Attention with cuDNN

llmc/cudnn_att.cpp:

C++
void attention_forward_cudnn(floatX* out,
                             floatX* stats,
                             floatX* inp,
                             int B, int T, int NH, int C) {
    int HS = C / NH;

    cudnnAttnDescriptor_t attn_desc;
    cudnnCreateAttnDescriptor(&attn_desc);
    cudnnSetAttnDescriptor(attn_desc,
        CUDNN_ATTN_QUERYMAP_ALL_TO_ONE,
        NH, 1.0f / sqrtf(HS),
        CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT,
        CUDNN_DEFAULT_MATH,
        NULL, NULL,
        // ... more configuration
    );

    cudnnMultiHeadAttnForward(cudnn_handle,
        attn_desc,
        -1, // no lookahead (causal)
        // ... inputs and outputs
    );
}

CUDA Matmul with cuBLAS

llmc/matmul.cuh uses cuBLASLt for efficient matrix multiplication:

CUDA
void matmul_forward_cublaslt(floatX* out,
                             floatX* inp, floatX* weight, floatX* bias,
                             int B, int T, int C, int OC, cudaStream_t stream) {
    // Use cuBLASLt for the matrix multiplication
    // inp is (B*T, C), weight is (OC, C), out is (B*T, OC)
    // (abridged: creation of the matrix layout descriptors Adesc/Bdesc/Cdesc,
    //  the algorithm heuristic, and the optional bias epilogue are omitted here)

    const float alpha = 1.0f;
    const float beta = 0.0f;

    cublasLtMatmulDesc_t operationDesc;
    cublasLtMatmulDescCreate(&operationDesc, cublas_compute, CUDA_R_32F);

    // Set transpose operation for weight matrix
    cublasOperation_t transa = CUBLAS_OP_T;
    cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));

    cublasLtMatmul(cublaslt_handle,
                   operationDesc,
                   &alpha,
                   weight, Adesc,
                   inp, Bdesc,
                   &beta,
                   out, Cdesc,
                   out, Cdesc,
                   NULL, cublaslt_workspace, cublaslt_workspace_size, stream);
}

CUDA Matmul Backward (Bias)

llmc/matmul.cuh - Warp-level reduction for bias gradients:

CUDA
template<typename OutFloat, bool UseAuxBuffer>
__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout,
                                             int B, int T, int OC) {
    constexpr const int bdx = 4;
    constexpr const int bdy = WARP_SIZE / bdx;

    int warp_d = (int)threadIdx.x;
    int warp_c = (int)threadIdx.y;
    int block_d = (int)threadIdx.z;

    const int OC_per_warp = bdy * x128::size;  // 64 at BF16
    int global_oc = blockIdx.x * OC_per_warp + warp_c * x128::size;

    float accumulators[x128::size] = {0.0f};

    if(global_oc < OC) {
        // sum up over all bt within registers
        // (local_bt and bt_per_block are derived from the thread indices in the full kernel)
        for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) {
            x128 packed_dout = load128(dout + global_oc + idx*OC);
            for (int k = 0; k < x128::size; k++) {
                accumulators[k] += (float)packed_dout[k];
            }
        }
    }

    // Warp shuffle reduction
    for (int k = 0; k < x128::size; k++) {
        float v = accumulators[k];
        v += __shfl_down_sync(0xffffffff, v, 1, 4);
        v += __shfl_down_sync(0xffffffff, v, 2, 4);
        // Write to shared memory for block reduction
    }
}

Key optimizations:

  • x128 packed loads (128-bit = 8 × BF16)
  • Warp shuffle reductions (__shfl_down_sync)
  • Coalesced memory access patterns

CUDA AdamW Kernel

llmc/adamw.cuh:

CUDA
__global__ void adamw_kernel3(floatX* params_memory, float* master_params_memory, floatX* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
                              float grad_scale, unsigned int seed) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_parameters) { return; }

    float grad = grad_scale * (float)grads_memory[idx];
    float m = m_memory[idx];
    float v = v_memory[idx];

    // AdamW update
    m = beta1 * m + (1.0f - beta1) * grad;
    v = beta2 * v + (1.0f - beta2) * grad * grad;
    m_memory[idx] = m;
    v_memory[idx] = v;

    float m_hat = m / beta1_correction;
    float v_hat = v / beta2_correction;

    float param = cyclic_load(master_params_memory, params_memory, idx);
    param -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
    cyclic_store(master_params_memory, params_memory, idx, param);
}

CUDA GELU Forward and Backward

llmc/gelu.cuh - Vectorized GELU with 128-bit packed loads:

CUDA
#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;

    x128 packed_out;
    x128 packed_inp = load128cs(inp + idx);  // load, don't cache
    for(int k = 0; k < packed_inp.size; ++k) {
        float xi = (float)packed_inp[k];
        float cube = 0.044715f * xi * xi * xi;
        packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))));
    }
    store128(out + idx, packed_out);  // keep in cache for next op
}

__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;

    x128 packed_dinp;
    x128 packed_inp = load128cs(inp + idx);
    x128 packed_dout = load128(d_in_out + idx);

    for (int k = 0; k < packed_inp.size; ++k) {
        float x = (float)packed_inp[k];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) +
                          x * 0.5f * sech_out * GELU_SCALING_FACTOR *
                          (1.0f + 3.0f * 0.044715f * x * x);
        packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);
    }
    store128(d_in_out + idx, packed_dinp);
}

Key optimizations:

  • x128 packed loads process 8 BF16 values (128 bits) per thread
  • load128cs bypasses L1 cache for streaming loads
  • In-place backward saves memory allocation

Gradient Clipping (Global Norm)

llmc/global_norm.cuh - Compute gradient norm for clipping:

CUDA
template<class T>
__device__ float global_norm_squared_for_range(const T* data, size_t count) {
    size_t index = blockIdx.x * blockDim.x + threadIdx.x;
    size_t grid_width = blockDim.x * gridDim.x;
    float accumulator = 0.f;

    // Grid-stride loop for all elements
    for(size_t i = index; i < count; i += grid_width) {
        accumulator += (float)data[i] * (float)data[i];
    }

    // Block-level reduction
    return blockReduce<warpReduceSum>(accumulator);
}

template<class T>
__global__ void global_norm_squared_kernel(float* out, const T* data,
                                           size_t count, ptrdiff_t stride) {
    float block_sum = global_norm_squared_for_range(data + blockIdx.y * stride, count);
    // Each block writes partial sum (no atomics for determinism)
    if(threadIdx.x == 0) {
        size_t out_index = blockIdx.y * gridDim.x + blockIdx.x;
        out[out_index] = out[out_index] + block_sum;
    }
}

__global__ void global_norm_aggregate_kernel(float* out, size_t grid_size) {
    size_t index = threadIdx.x;
    float block_sum = (index < grid_size) ? out[index] : 0.f;
    float sum = blockReduce<warpReduceSum>(block_sum);
    if(threadIdx.x == 0) {
        out[0] = sum;  // Final norm squared
    }
}

Gradient clipping usage:

CUDA
// Compute global gradient norm
global_norm_squared(norm_buffer, grads, num_params, ...);
global_norm_aggregate_kernel<<<1, 1024>>>(norm_buffer, grid_size);
float grad_norm = sqrtf(norm_buffer[0]);

// Clip if norm exceeds threshold
if (grad_norm > max_norm) {
    float clip_coef = max_norm / grad_norm;
    // Scale all gradients by clip_coef
}
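
Note that the grad_scale argument of adamw_kernel3 above (float grad = grad_scale * (float)grads_memory[idx];) is exactly where such a clip coefficient can be applied without a separate pass over the gradients. If you did want a standalone scaling pass, a hedged sketch would look like:

CUDA
// Sketch of a standalone "scale all gradients" kernel (the coefficient can instead
// be folded into the optimizer update via grad_scale, as in the AdamW kernel above).
__global__ void scale_gradients_kernel(float* grads, size_t n, float clip_coef) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        grads[i] *= clip_coef;  // clip_coef = max_norm / grad_norm <= 1
    }
}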

Stochastic Rounding for BF16

llmc/encoder.cuh - Deterministic stochastic rounding for BF16 training:

CUDA
__global__ void wte_backward_kernel(floatX* dwte,
                                    const int4* bucket_info, const int* workload_indices,
                                    const floatX* dout, const int* inp,
                                    unsigned int seed, int B, int T, int C) {
    // ... accumulate gradients in float ...

    // Stochastic rounding: FP32 -> BF16 with deterministic seed
    for (unsigned int k = 0; k < x128::size; k++) {
        // Unique seed per parameter for determinism without UB
        stochastic_rounding(accum[k] + (float)packed_in_out[k],
                           &packed_in_out[k],
                           seed + bucket * WARP_SIZE + threadIdx.x + k);
    }
    store128(dwte_ix, packed_in_out);
}

Why stochastic rounding?

BF16 has limited precision (a 7-bit mantissa). When accumulating small gradients:

  • Round-to-nearest: Small gradients may round to 0, causing gradient vanishing
  • Stochastic rounding: Small gradients have a probability of rounding up

$$P(\text{round up}) = \frac{x - \lfloor x \rfloor}{\epsilon}$$

where $\lfloor x \rfloor$ is the nearest representable BF16 value below $x$ and $\epsilon$ is the spacing to the next representable value.

This gives unbiased gradients in expectation while maintaining full determinism (same seed = same result).
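
As a concrete illustration (my own sketch, not the helper in llmc/cuda_utils.cuh), stochastic rounding from FP32 to BF16 can be implemented by comparing the 16 low-order bits that BF16 discards against a pseudo-random 16-bit threshold derived from the seed:

CUDA
#include <cuda_bf16.h>

// Sketch only: round an FP32 value to BF16, bumping to the next representable value
// (in magnitude) with probability proportional to the discarded low-order bits.
__device__ void stochastic_round_bf16(float in, __nv_bfloat16* out, unsigned int seed) {
    // cheap integer hash of the seed -> pseudo-random 16-bit threshold (illustrative)
    unsigned int r = seed;
    r ^= r >> 16; r *= 0x7feb352dU; r ^= r >> 15; r *= 0x846ca68bU; r ^= r >> 16;
    unsigned int bits  = __float_as_uint(in);
    unsigned int lower = bits & 0xFFFFu;       // the bits BF16 would throw away
    unsigned int upper = bits & 0xFFFF0000u;   // the bits BF16 keeps (sign, exponent, 7 mantissa bits)
    if ((r & 0xFFFFu) < lower) { upper += 0x10000u; }  // round up with P = lower / 2^16
    __nv_bfloat16_raw raw;
    raw.x = (unsigned short)(upper >> 16);
    *out = __nv_bfloat16(raw);
}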

Fused Softmax + Cross-Entropy

llmc/fused_classifier.cuh - Fuses softmax and cross-entropy for efficiency:

CUDA
__global__ void fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
                                         const floatX* dlosses, const int* targets,
                                         int B, int T, int V, int P) {
    // Each block handles one (b,t) position
    int idx = blockIdx.x;
    int b = idx / T;
    int t = idx % T;

    const floatX* logits_bt = logits + idx * P;
    int target = targets[idx];

    // Step 1: Find max for numerical stability (warp reduction)
    float maxval = -INFINITY;
    for (int i = threadIdx.x; i < V; i += blockDim.x) {
        maxval = fmaxf(maxval, (float)logits_bt[i]);
    }
    maxval = warpReduceMax(maxval);

    // Step 2: Compute exp sum
    float sumval = 0.0f;
    for (int i = threadIdx.x; i < V; i += blockDim.x) {
        sumval += expf((float)logits_bt[i] - maxval);
    }
    sumval = warpReduceSum(sumval);

    // Step 3: Compute loss = -log(softmax[target])
    float prob_target = expf((float)logits_bt[target] - maxval) / sumval;
    if (threadIdx.x == 0) {
        losses[idx] = -logf(prob_target);
    }
}

Benefits of fusion:

  • Single pass over logits (better memory bandwidth)
  • No intermediate storage for probabilities
  • Reduced kernel launch overhead

Multi-GPU Training with NCCL

llmc/zero.cuh implements ZeRO-style distributed training:

CUDA
void multi_gpu_async_reduce_gradient(
    floatX* local_grads,
    floatX* all_grads,
    MultiGpuConfig* config) {

    // All-reduce gradients across GPUs
    ncclAllReduce(
        local_grads,
        all_grads,
        num_parameters,
        ncclFloatX,
        ncclSum,
        config->nccl_comm,
        config->stream
    );
}

ZeRO Optimization: Partitions optimizer states across GPUs to reduce memory:

  • ZeRO-1: Partition optimizer states
  • ZeRO-2: Partition gradients + optimizer states
  • ZeRO-3: Partition parameters + gradients + optimizer states
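
A hedged sketch of the ZeRO-1 step (the communication pattern only; the real logic lives in llmc/zero.cuh and is built around llm.c's own MultiGpuConfig): each rank reduce-scatters gradients so it owns the summed gradients for its 1/N parameter shard, runs the optimizer on that shard alone, then all-gathers the updated shards.

C
#include <nccl.h>
#include <cuda_runtime.h>

// ZeRO-1 communication skeleton (names and layout here are illustrative assumptions,
// not llm.c's API). Optimizer state (m, v) is allocated only for this rank's shard,
// which is where the memory saving comes from.
void zero1_step(float* params, float* grads, size_t num_parameters,
                int rank, int world_size, ncclComm_t comm, cudaStream_t stream,
                void (*shard_optimizer_step)(float* p, const float* g, size_t n)) {
    size_t shard = num_parameters / world_size;   // assume divisibility for simplicity
    size_t offset = (size_t)rank * shard;
    // 1) each rank receives the summed gradients for its own contiguous shard (in-place)
    ncclReduceScatter(grads, grads + offset, shard, ncclFloat, ncclSum, comm, stream);
    cudaStreamSynchronize(stream);
    // 2) optimizer update touches only the local shard
    shard_optimizer_step(params + offset, grads + offset, shard);
    // 3) all-gather the updated shards so every rank has the full parameters again (in-place)
    ncclAllGather(params + offset, params, shard, ncclFloat, comm, stream);
}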

Building and Running

Bash
# CPU version
make train_gpt2
./train_gpt2

# CUDA version (requires CUDA toolkit)
make train_gpt2cu
./train_gpt2cu

# With cuDNN (faster attention)
make train_gpt2cu USE_CUDNN=1
./train_gpt2cu

Performance Comparison

| Implementation | Hardware | Tokens/sec | Notes |
| --- | --- | --- | --- |
| train_gpt2.c | 8-core CPU | ~1K | Reference, OpenMP |
| train_gpt2.cu | RTX 4090 | ~300K | CUDA, cuBLAS |
| train_gpt2.cu + cuDNN | RTX 4090 | ~400K | Flash Attention |
| PyTorch | RTX 4090 | ~250K | For comparison |

Summary

llm.c demonstrates that LLM training doesn't require Python or deep learning frameworks:

| Component | C Lines | CUDA Lines |
| --- | --- | --- |
| LayerNorm | 82 | 200+ |
| Attention | 133 | 300+ |
| Matmul | 98 | (cuBLAS) |
| GELU | 27 | 50+ |
| AdamW | 27 | 50+ |
| Full training | 1182 | 1904 |

Key insights:

  • The backward pass is ~2x the complexity of forward
  • Most time is spent in matmul (use cuBLAS/BLAS)
  • Attention backward is $O(T^2)$ - the bottleneck for long sequences
  • Multi-GPU training requires careful gradient synchronization
