
Sparse Attention Patterns: Longformer, BigBird, and Beyond

A comprehensive deep dive into sparse attention mechanisms—how Longformer, BigBird, and other architectures break the O(n²) attention barrier. Understand local attention, global tokens, random attention, and when to use each pattern.


The Quadratic Attention Problem

Standard transformer attention has a fundamental scaling problem: every token attends to every other token, creating an O(n²) computation and memory cost. Double the sequence length, and attention costs quadruple.

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    THE O(n²) PROBLEM                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ATTENTION MATRIX SIZE:                                                  │
│  ─────────────────────                                                   │
│                                                                          │
│  Sequence Length    Attention Matrix    Memory (FP16)                   │
│  ─────────────────────────────────────────────────────                  │
│  512               512 × 512           0.5 MB                           │
│  2,048             2K × 2K             8 MB                             │
│  8,192             8K × 8K             128 MB                           │
│  32,768            32K × 32K           2 GB                             │
│  131,072           128K × 128K         32 GB                            │
│                                                                          │
│  Per attention head! Multiply by num_heads × num_layers.               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  REAL MEMORY IMPACT (32 layers, 32 heads):                              │
│                                                                          │
│  Sequence Length    Total Attention Memory                             │
│  ─────────────────────────────────────────                              │
│  2K                 8 GB                                                │
│  8K                 128 GB                                              │
│  32K                2 TB                                                │
│                                                                          │
│  Without optimizations, 32K context is impossible!                     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPUTE SCALING:                                                        │
│                                                                          │
│  Attention FLOPs ∝ n² × d                                              │
│                                                                          │
│  At 8K context: 64× more compute than 1K                              │
│  At 128K context: 16,384× more compute than 1K                        │
│                                                                          │
│  This is unsustainable.                                                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
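
These numbers are easy to reproduce. Below is a minimal sketch; the 32-head × 32-layer configuration mirrors the table above, and FP16 scores are assumed at 2 bytes per entry:

Python
# Attention-score matrix memory, per head and for a full model.
def attn_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    return seq_len * seq_len * bytes_per_elem

for n in [512, 2_048, 8_192, 32_768, 131_072]:
    per_head = attn_matrix_bytes(n)
    total = per_head * 32 * 32  # 32 heads x 32 layers
    print(f"{n:>7} tokens: {per_head / 2**20:9.1f} MB per head, "
          f"{total / 2**30:9.1f} GB total")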

Solutions to Quadratic Attention

Several approaches address this:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    APPROACHES TO O(n²)                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1. FLASHATTENTION (Keep full attention, optimize memory)              │
│     • Still O(n²) compute                                              │
│     • O(n) memory via tiling                                           │
│     • Best when you need full attention                                │
│     • Covered in: attention-mechanisms-flash-attention-deep-dive       │
│                                                                          │
│  2. SPARSE ATTENTION (Attend to subset of tokens)                      │
│     • O(n) or O(n√n) compute and memory                               │
│     • Attends only to "important" positions                           │
│     • This post's focus                                                │
│                                                                          │
│  3. LINEAR ATTENTION (Approximate attention kernel)                    │
│     • O(n) compute and memory                                          │
│     • Approximates softmax with kernel tricks                          │
│     • Examples: Performers, Linear Transformers                        │
│                                                                          │
│  4. STATE SPACE MODELS (Different architecture)                        │
│     • O(n) or O(n log n) compute                                       │
│     • Recurrent-like structure                                         │
│     • Examples: Mamba, RWKV                                            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY SPARSE ATTENTION?                                                   │
│  ─────────────────────                                                   │
│                                                                          │
│  Key observation: Not all attention connections are equally useful.   │
│                                                                          │
│  In practice, attention is often:                                      │
│  • Local: Nearby tokens matter most                                   │
│  • Sparse: Few tokens get high attention weight                       │
│  • Structured: Certain positions (start, special tokens) matter       │
│                                                                          │
│  If we can identify which connections matter, we can skip the rest.   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part I: Attention Pattern Taxonomy

Basic Sparse Patterns

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SPARSE ATTENTION PATTERNS                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  In these diagrams:                                                     │
│  ■ = attention connection (query attends to key)                       │
│  □ = no attention (masked out)                                         │
│  Rows = query positions, Columns = key positions                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  1. FULL ATTENTION (Standard Transformer):                              │
│  ─────────────────────────────────────────                               │
│                                                                          │
│     Keys:  0 1 2 3 4 5 6 7                                             │
│  Q  0      ■ □ □ □ □ □ □ □    (causal: attend to past only)           │
│  u  1      ■ ■ □ □ □ □ □ □                                             │
│  e  2      ■ ■ ■ □ □ □ □ □                                             │
│  r  3      ■ ■ ■ ■ □ □ □ □                                             │
│  i  4      ■ ■ ■ ■ ■ □ □ □                                             │
│  e  5      ■ ■ ■ ■ ■ ■ □ □                                             │
│  s  6      ■ ■ ■ ■ ■ ■ ■ □                                             │
│     7      ■ ■ ■ ■ ■ ■ ■ ■                                             │
│                                                                          │
│  Complexity: O(n²)                                                      │
│  Every token attends to all previous tokens.                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  2. LOCAL/SLIDING WINDOW ATTENTION:                                     │
│  ──────────────────────────────────                                      │
│                                                                          │
│     Keys:  0 1 2 3 4 5 6 7                                             │
│  Q  0      ■ ■ □ □ □ □ □ □    (window size = 3)                       │
│  u  1      ■ ■ ■ □ □ □ □ □                                             │
│  e  2      □ ■ ■ ■ □ □ □ □                                             │
│  r  3      □ □ ■ ■ ■ □ □ □                                             │
│  i  4      □ □ □ ■ ■ ■ □ □                                             │
│  e  5      □ □ □ □ ■ ■ ■ □                                             │
│  s  6      □ □ □ □ □ ■ ■ ■                                             │
│     7      □ □ □ □ □ □ ■ ■                                             │
│                                                                          │
│  Complexity: O(n × w) where w = window size                            │
│  Each token attends to w neighbors.                                    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  3. DILATED/STRIDED ATTENTION:                                          │
│  ─────────────────────────────                                           │
│                                                                          │
│     Keys:  0 1 2 3 4 5 6 7                                             │
│  Q  0      ■ □ ■ □ □ □ □ □    (stride = 2)                            │
│  u  1      □ ■ □ ■ □ □ □ □                                             │
│  e  2      ■ □ ■ □ ■ □ □ □                                             │
│  r  3      □ ■ □ ■ □ ■ □ □                                             │
│  i  4      ■ □ ■ □ ■ □ ■ □                                             │
│  e  5      □ ■ □ ■ □ ■ □ ■                                             │
│  s  6      □ □ ■ □ ■ □ ■ □                                             │
│     7      □ □ □ ■ □ ■ □ ■                                             │
│                                                                          │
│  Complexity: O(n × n/stride)                                           │
│  Attends to every k-th token. Captures longer range.                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  4. GLOBAL ATTENTION (Selected tokens attend to all):                   │
│  ────────────────────────────────────────────────────                    │
│                                                                          │
│     Keys:  0 1 2 3 4 5 6 7                                             │
│  Q  0      ■ ■ ■ ■ ■ ■ ■ ■    ← Global token (e.g., [CLS])           │
│  u  1      ■ ■ □ □ □ □ □ □                                             │
│  e  2      ■ □ ■ □ □ □ □ □                                             │
│  r  3      ■ □ □ ■ □ □ □ □                                             │
│  i  4      ■ □ □ □ ■ □ □ □                                             │
│  e  5      ■ □ □ □ □ ■ □ □                                             │
│  s  6      ■ □ □ □ □ □ ■ □                                             │
│     7      ■ □ □ □ □ □ □ ■                                             │
│                                                                          │
│  Token 0 attends to everyone, everyone attends to token 0.            │
│  Complexity: O(n × g + n) where g = number of global tokens           │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
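
All four patterns can be written as boolean masks over the (query, key) grid. The sketch below builds them for intuition (True = attend); the exact window and stride semantics are simplifications of the diagrams above:

Python
import torch

def causal_mask(n: int) -> torch.Tensor:
    # Full causal attention: each query sees itself and all earlier keys.
    return torch.tril(torch.ones(n, n, dtype=torch.bool))

def sliding_window_mask(n: int, window: int) -> torch.Tensor:
    # Each query sees keys within `window` positions of itself.
    idx = torch.arange(n)
    return (idx[:, None] - idx[None, :]).abs() < window

def strided_mask(n: int, stride: int) -> torch.Tensor:
    # Each query sees keys at the same position modulo `stride`.
    idx = torch.arange(n)
    return (idx[:, None] - idx[None, :]) % stride == 0

def global_token_mask(n: int, num_global: int = 1) -> torch.Tensor:
    # Global tokens attend everywhere and are attended by everyone;
    # other tokens only see themselves (combine with a local mask).
    mask = torch.eye(n, dtype=torch.bool)
    mask[:num_global, :] = True
    mask[:, :num_global] = True
    return mask

print(sliding_window_mask(8, 2).int())  # diagonal band, like pattern 2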

Part II: Longformer

The Longformer Architecture

Longformer (Beltagy et al., 2020) combines local sliding window attention with global attention for specific tokens:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    LONGFORMER ATTENTION                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  LONGFORMER COMBINES:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  1. Local sliding window attention (for most tokens)                   │
│  2. Global attention (for special tokens like [CLS])                   │
│  3. Optional dilated sliding window (for some layers)                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  ATTENTION PATTERN (16 tokens, window=4, global=[CLS]):                │
│  ───────────────────────────────────────────────────────                 │
│                                                                          │
│            Keys                                                         │
│        [CLS] 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15               │
│  [CLS]   ■   ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■              │
│    1     ■   ■  ■  ■  □  □  □  □  □  □  □  □  □  □  □  □   Q          │
│    2     ■   ■  ■  ■  ■  □  □  □  □  □  □  □  □  □  □  □   u          │
│    3     ■   □  ■  ■  ■  ■  □  □  □  □  □  □  □  □  □  □   e          │
│    4     ■   □  □  ■  ■  ■  ■  □  □  □  □  □  □  □  □  □   r          │
│    5     ■   □  □  □  ■  ■  ■  ■  □  □  □  □  □  □  □  □   i          │
│    6     ■   □  □  □  □  ■  ■  ■  ■  □  □  □  □  □  □  □   e          │
│    7     ■   □  □  □  □  □  ■  ■  ■  ■  □  □  □  □  □  □   s          │
│    8     ■   □  □  □  □  □  □  ■  ■  ■  ■  □  □  □  □  □              │
│    9     ■   □  □  □  □  □  □  □  ■  ■  ■  ■  □  □  □  □              │
│   10     ■   □  □  □  □  □  □  □  □  ■  ■  ■  ■  □  □  □              │
│   11     ■   □  □  □  □  □  □  □  □  □  ■  ■  ■  ■  □  □              │
│   12     ■   □  □  □  □  □  □  □  □  □  □  ■  ■  ■  ■  □              │
│   13     ■   □  □  □  □  □  □  □  □  □  □  □  ■  ■  ■  ■              │
│   14     ■   □  □  □  □  □  □  □  □  □  □  □  □  ■  ■  ■              │
│   15     ■   □  □  □  □  □  □  □  □  □  □  □  □  □  ■  ■              │
│                                                                          │
│  First column: Global attention to [CLS]                               │
│  First row: [CLS] attends globally                                     │
│  Diagonal bands: Local sliding window                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPLEXITY ANALYSIS:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  For sequence length n, window size w, g global tokens:                │
│                                                                          │
│  Local attention: n × w                                                │
│  Global attention: ~2 × g × n                                           │
│  (global tokens attend to all; all tokens attend to the globals)        │
│                                                                          │
│  Total: O(n × w + g × n) = O(n × (w + g))                             │
│                                                                          │
│  For w=512, g=1, n=4096:                                              │
│  Full attention: 4096² = 16.7M                                        │
│  Longformer: 4096 × 513 ≈ 2.1M                                        │
│  8× reduction!                                                         │
│                                                                          │
│  For w=512, g=1, n=16384:                                             │
│  Full attention: 16384² = 268M                                        │
│  Longformer: 16384 × 513 ≈ 8.4M                                       │
│  32× reduction!                                                        │
│                                                                          │
│  Savings increase with sequence length.                               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
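
The reduction factors quoted above follow directly from the two formulas; a quick sanity check:

Python
# Attention pairs: full = n^2, Longformer ~= n * (w + g)
def reduction(n: int, w: int = 512, g: int = 1) -> float:
    full = n * n
    longformer = n * (w + g)
    return full / longformer

for n in [4_096, 16_384, 65_536]:
    print(f"n={n:>6}: {reduction(n):5.1f}x fewer attention pairs")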

Dilated Sliding Window

Longformer can use dilated attention in higher layers to increase receptive field:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    DILATED SLIDING WINDOW                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE IDEA:                                                               │
│  ─────────                                                               │
│                                                                          │
│  In lower layers: Dense local attention (window=512)                   │
│  In higher layers: Dilated attention (dilation=2, window=512)         │
│                                                                          │
│  This increases effective receptive field without increasing cost.    │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DILATION EXAMPLE:                                                       │
│  ─────────────────                                                       │
│                                                                          │
│  Non-dilated (d=1, 8 attended positions):                               │
│  Position 8 attends to: 5, 6, 7, 8, 9, 10, 11, 12                       │
│  Receptive field: 8 consecutive positions                               │
│                                                                          │
│  Dilated (d=2, 8 attended positions):                                   │
│  Position 8 attends to: 2, 4, 6, 8, 10, 12, 14, 16                      │
│  Receptive field spans ~16 positions (still only 8 attended tokens)     │
│                                                                          │
│  Same cost, roughly 2× the receptive field!                             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MULTI-HEAD WITH DIFFERENT DILATIONS:                                    │
│  ─────────────────────────────────────                                   │
│                                                                          │
│  Head 0: dilation = 1 (local details)                                 │
│  Head 1: dilation = 2                                                  │
│  Head 2: dilation = 4                                                  │
│  Head 3: dilation = 8 (long range)                                    │
│                                                                          │
│  Different heads capture different scales simultaneously.             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  EFFECTIVE RECEPTIVE FIELD:                                              │
│  ──────────────────────────                                              │
│                                                                          │
│  After L layers with window w and dilation d:                         │
│  Receptive field = L × w × d                                          │
│                                                                          │
│  Example: 12 layers, w=512, d varying 1-4:                            │
│  Lower layers (d=1): captures local patterns                          │
│  Higher layers (d=4): combines patterns from wider context            │
│  Effective receptive field: ~16K+ tokens                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
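
A small sketch of the dilation arithmetic from the worked example above: the positions a query attends to, given a fixed budget of attended keys (edge clipping omitted for brevity):

Python
def dilated_window_positions(pos: int, window: int, dilation: int) -> list:
    # `window` attended keys per query, spaced `dilation` apart.
    half = window // 2
    return [pos + k * dilation for k in range(-half + 1, half + 1)]

print(dilated_window_positions(8, 8, 1))  # [5, 6, 7, 8, 9, 10, 11, 12]
print(dilated_window_positions(8, 8, 2))  # [2, 4, 6, 8, 10, 12, 14, 16]
# Per-head dilations (1, 2, 4, 8) are just different `dilation` values.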

Longformer Implementation

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LongformerAttention(nn.Module):
    """
    Longformer-style attention with sliding window + global attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        window_size: int = 512,
        global_tokens: int = 1,  # Usually [CLS] token
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.window_size = window_size
        self.global_tokens = global_tokens

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        # Separate projections for global attention
        self.global_q_proj = nn.Linear(hidden_size, hidden_size)
        self.global_k_proj = nn.Linear(hidden_size, hidden_size)
        self.global_v_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape

        # Compute Q, K, V for local attention
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Local sliding window attention
        local_output = self._sliding_window_attention(q, k, v)

        # Global attention for first g tokens
        if self.global_tokens > 0:
            global_output = self._global_attention(
                hidden_states, q, k, v
            )
            # Combine: global tokens use global attention,
            # other tokens use local + attend to global
            output = self._combine_local_global(
                local_output, global_output, hidden_states
            )
        else:
            output = local_output

        return self.out_proj(output)

    def _sliding_window_attention(self, q, k, v):
        """
        Compute attention within sliding windows.

        For efficiency, real implementations use custom CUDA kernels.
        This is a simplified version for understanding.
        """
        batch_size, seq_len, num_heads, head_dim = q.shape
        w = self.window_size

        output = torch.zeros_like(q)

        for i in range(seq_len):
            # Window bounds
            start = max(0, i - w // 2)
            end = min(seq_len, i + w // 2 + 1)

            # Get Q for position i, K and V for window
            q_i = q[:, i:i+1, :, :]  # (batch, 1, heads, dim)
            k_window = k[:, start:end, :, :]  # (batch, window, heads, dim)
            v_window = v[:, start:end, :, :]

            # Compute attention
            scores = torch.einsum('bqhd,bkhd->bhqk', q_i, k_window)
            scores = scores / (head_dim ** 0.5)
            attn = F.softmax(scores, dim=-1)
            out_i = torch.einsum('bhqk,bkhd->bqhd', attn, v_window)

            output[:, i:i+1, :, :] = out_i

        return output.view(batch_size, seq_len, -1)

    def _global_attention(self, hidden_states, q, k, v):
        """
        Global attention for designated global tokens.
        """
        batch_size, seq_len, _ = hidden_states.shape
        g = self.global_tokens

        # Global tokens (e.g., first g tokens)
        global_hidden = hidden_states[:, :g, :]

        # Global Q, K, V with separate projections
        global_q = self.global_q_proj(global_hidden)
        global_k = self.global_k_proj(hidden_states)  # All tokens as keys
        global_v = self.global_v_proj(hidden_states)

        # Reshape
        global_q = global_q.view(batch_size, g, self.num_heads, self.head_dim)
        global_k = global_k.view(batch_size, seq_len, self.num_heads, self.head_dim)
        global_v = global_v.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Full attention for global tokens
        scores = torch.einsum('bghd,bshd->bhgs', global_q, global_k)
        scores = scores / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        global_output = torch.einsum('bhgs,bshd->bghd', attn, global_v)

        return global_output.view(batch_size, g, -1)

    def _combine_local_global(self, local_output, global_output, hidden_states):
        """
        Merge local and global outputs (simplified).

        Global tokens take their globally-attended output; all other
        tokens keep their sliding-window output. The full Longformer
        also adds each local token's attention *to* the global tokens,
        omitted here for brevity.
        """
        output = local_output.clone()
        output[:, :self.global_tokens, :] = global_output
        return output
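
A quick shape check for the module above (sizes are arbitrary; the Python loop over positions is for clarity, not speed):

Python
attn = LongformerAttention(hidden_size=256, num_heads=4,
                           window_size=8, global_tokens=1)
x = torch.randn(2, 32, 256)   # (batch, seq_len, hidden)
out = attn(x)
print(out.shape)              # torch.Size([2, 32, 256])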

Part III: BigBird

Adding Random Attention

BigBird (Zaheer et al., 2020) adds random attention patterns to the mix:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    BIGBIRD ATTENTION                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  BIGBIRD COMBINES THREE PATTERNS:                                        │
│  ─────────────────────────────────                                       │
│                                                                          │
│  1. Local sliding window attention                                     │
│  2. Global attention (select tokens)                                   │
│  3. Random attention (random token pairs)                              │
│                                                                          │
│  The random attention is key: it provides "shortcuts" through          │
│  the sequence, ensuring information can flow between distant tokens.  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BIGBIRD PATTERN (simplified, 12 tokens):                               │
│  ─────────────────────────────────────────                               │
│                                                                          │
│  G = Global token    L = Local window    R = Random                    │
│                                                                          │
│        Keys                                                             │
│        0  1  2  3  4  5  6  7  8  9  10 11                             │
│   0    G  G  G  G  G  G  G  G  G  G  G  G   ← Global token            │
│   1    G  L  L  L  □  □  □  R  □  □  □  □   Q                         │
│   2    G  L  L  L  L  □  □  □  □  R  □  □   u                         │
│   3    G  □  L  L  L  L  □  □  □  □  R  □   e                         │
│   4    G  □  □  L  L  L  L  □  R  □  □  □   r                         │
│   5    G  □  □  □  L  L  L  L  □  □  R  □   i                         │
│   6    G  □  □  R  □  L  L  L  L  □  □  □   e                         │
│   7    G  □  R  □  □  □  L  L  L  L  □  □   s                         │
│   8    G  □  □  □  R  □  □  L  L  L  L  □                             │
│   9    G  R  □  □  □  □  □  □  L  L  L  L                             │
│  10    G  □  □  □  □  R  □  □  □  L  L  L                             │
│  11    G  □  □  R  □  □  □  □  □  □  L  L                             │
│                                                                          │
│  G = Global (column 0): Everyone attends to global tokens             │
│  L = Local: Diagonal band for nearby tokens                           │
│  R = Random: Sparse random connections (different per token)          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY RANDOM ATTENTION?                                                   │
│  ─────────────────────                                                   │
│                                                                          │
│  Theory: Random graphs have small diameter.                            │
│                                                                          │
│  With only local attention:                                            │
│  Token 0 → Token 1000 requires ~1000/window hops                      │
│  Information "diffuses" slowly.                                        │
│                                                                          │
│  With random attention:                                                 │
│  Each token has "shortcuts" to random distant tokens.                 │
│  Like small-world networks: O(log n) hops to reach anywhere.         │
│                                                                          │
│  This matters theoretically: with random shortcuts, information can    │
│  reach any position in O(log n) layers, and BigBird's sparse attention │
│  retains strong expressiveness guarantees (universal approximation,    │
│  Turing completeness).                                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPLEXITY:                                                             │
│  ───────────                                                             │
│                                                                          │
│  Window size w, r random connections, g global tokens:                │
│                                                                          │
│  Per token: w (local) + r (random) + g (global) connections           │
│  Total: O(n × (w + r + g))                                            │
│                                                                          │
│  Typically w=64, r=3, g=2: each token attends to ~70 others          │
│  vs n for full attention.                                             │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
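
As a sketch, the three components can be merged into one boolean mask (True = attend). The window, random count, and global count below are small illustrative values, not BigBird's defaults:

Python
import torch

def bigbird_mask(n: int, window: int = 3, num_random: int = 2,
                 num_global: int = 1, seed: int = 0) -> torch.Tensor:
    gen = torch.Generator().manual_seed(seed)
    idx = torch.arange(n)
    # 1. Local sliding window (diagonal band)
    mask = (idx[:, None] - idx[None, :]).abs() <= window // 2
    # 2. Global tokens: attend to all, attended by all
    mask[:num_global, :] = True
    mask[:, :num_global] = True
    # 3. Random connections: a few random keys per query
    for i in range(n):
        rand_keys = torch.randint(0, n, (num_random,), generator=gen)
        mask[i, rand_keys] = True
    return mask

print(bigbird_mask(12).int())  # compare with the 12-token diagram above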

BigBird vs Longformer

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    BIGBIRD VS LONGFORMER                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  SIMILARITIES:                                                           │
│  ─────────────                                                           │
│  • Both use local sliding window attention                            │
│  • Both support global attention for special tokens                   │
│  • Both achieve O(n) complexity                                       │
│  • Both support up to 4K-16K context                                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DIFFERENCES:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Feature          Longformer              BigBird                       │
│  ───────────────────────────────────────────────────────────────       │
│  Random attn      No                      Yes                          │
│  Dilated attn     Yes (optional)          No                           │
│  Theoretical      Heuristic               Proven expressive            │
│  Implementation   Custom kernels          Block sparse                 │
│  Primary use      Document understanding  QA, summarization            │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHEN TO USE WHICH:                                                      │
│  ──────────────────                                                      │
│                                                                          │
│  Longformer:                                                            │
│  • Document classification                                            │
│  • When global tokens are sufficient for long-range                   │
│  • When you need dilated attention patterns                           │
│                                                                          │
│  BigBird:                                                               │
│  • Question answering over long documents                             │
│  • When theoretical expressiveness guarantees matter                  │
│  • When random attention helps your task                              │
│                                                                          │
│  In practice, both work well for long documents.                      │
│  Choice often comes down to available implementations.                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Part IV: Other Sparse Patterns

Block Sparse Attention

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    BLOCK SPARSE ATTENTION                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Divide sequence into blocks, attend to selected blocks.              │
│  More GPU-efficient than arbitrary sparse patterns.                   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  BLOCK PATTERN (block size = 4):                                        │
│  ────────────────────────────────                                        │
│                                                                          │
│        Block 0   Block 1   Block 2   Block 3                           │
│        0 1 2 3   4 5 6 7   8 9 A B   C D E F                          │
│  B  0  ■ ■ ■ ■   □ □ □ □   □ □ □ □   □ □ □ □                          │
│  l  1  ■ ■ ■ ■   □ □ □ □   □ □ □ □   □ □ □ □                          │
│  k  2  ■ ■ ■ ■   □ □ □ □   □ □ □ □   □ □ □ □                          │
│  0  3  ■ ■ ■ ■   ■ ■ ■ ■   □ □ □ □   □ □ □ □  ← attends to next block│
│        ─────────────────────────────────────                           │
│  B  4  □ □ □ □   ■ ■ ■ ■   □ □ □ □   □ □ □ □                          │
│  l  5  □ □ □ □   ■ ■ ■ ■   □ □ □ □   □ □ □ □                          │
│  k  6  □ □ □ □   ■ ■ ■ ■   □ □ □ □   □ □ □ □                          │
│  1  7  □ □ □ □   ■ ■ ■ ■   ■ ■ ■ ■   □ □ □ □                          │
│        ─────────────────────────────────────                           │
│  B  8  □ □ □ □   □ □ □ □   ■ ■ ■ ■   □ □ □ □                          │
│  l  9  □ □ □ □   □ □ □ □   ■ ■ ■ ■   □ □ □ □                          │
│  k  A  □ □ □ □   □ □ □ □   ■ ■ ■ ■   □ □ □ □                          │
│  2  B  □ □ □ □   □ □ □ □   ■ ■ ■ ■   ■ ■ ■ ■                          │
│                                                                          │
│  Each block attends densely to itself; here, the last row of each      │
│  block also attends to the next block.                                  │
│  Matrix operations are on dense blocks → GPU efficient.              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  GPU EFFICIENCY:                                                         │
│  ───────────────                                                         │
│                                                                          │
│  Arbitrary sparsity: Scattered memory access, poor GPU utilization    │
│  Block sparsity: Dense operations on blocks, near-optimal GPU use    │
│                                                                          │
│  Block sparse can achieve 70-90% of dense attention throughput       │
│  while reducing FLOPs by 10-100×.                                     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
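
A sketch of the idea at the layout level: sparsity is declared per block, and every nonzero layout entry becomes one dense block-by-block matmul (block size and neighbourhood below are illustrative):

Python
import torch

def local_block_layout(num_blocks: int, neighbours: int = 1) -> torch.Tensor:
    # layout[i, j] = 1 means "queries in block i may attend to keys in block j"
    idx = torch.arange(num_blocks)
    return ((idx[:, None] - idx[None, :]).abs() <= neighbours).long()

layout = local_block_layout(num_blocks=4)
print(layout)

# Expand to a token-level mask for block_size=4 tokens per block
token_mask = torch.kron(layout, torch.ones(4, 4, dtype=torch.long))
print(token_mask.shape)  # torch.Size([16, 16])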

Axial Attention

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    AXIAL ATTENTION                                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  For 2D data (images, tables), attend along axes separately.          │
│  Instead of n² full attention, do 2 × n√n attention operations.      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  EXAMPLE: 4×4 grid (16 tokens)                                         │
│  ─────────────────────────────                                           │
│                                                                          │
│  Original positions:                                                    │
│   0  1  2  3                                                           │
│   4  5  6  7                                                           │
│   8  9 10 11                                                           │
│  12 13 14 15                                                           │
│                                                                          │
│  Step 1: Row attention (each row attends within itself)               │
│  Token 5 attends to: 4, 5, 6, 7                                       │
│                                                                          │
│  Step 2: Column attention (each column attends within itself)         │
│  Token 5 attends to: 1, 5, 9, 13                                      │
│                                                                          │
│  Combined: Token 5 can reach any cell in 2 hops!                      │
│  Complexity: 2 × 16 × 4 = 128 vs 16² = 256                           │
│                                                                          │
│  For n tokens in √n × √n grid:                                        │
│  Axial: O(n√n)   vs   Full: O(n²)                                     │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  APPLICATIONS:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  • Image processing (ViT variants)                                    │
│  • Tabular data                                                        │
│  • Any data with natural 2D structure                                 │
│  • Can extend to 3D (video) with 3 axial attentions                  │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
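
A minimal sketch of row-then-column attention on a grid, using PyTorch's scaled_dot_product_attention with no projections or heads, purely to show the reshaping:

Python
import torch
import torch.nn.functional as F

def axial_attention(x: torch.Tensor) -> torch.Tensor:
    """x: (batch, H, W, dim). Row attention, then column attention."""
    b, h, w, d = x.shape
    # Step 1: each row of W tokens attends within itself
    rows = x.reshape(b * h, w, d)
    rows = F.scaled_dot_product_attention(rows, rows, rows)
    x = rows.reshape(b, h, w, d)
    # Step 2: each column of H tokens attends within itself
    cols = x.permute(0, 2, 1, 3).reshape(b * w, h, d)
    cols = F.scaled_dot_product_attention(cols, cols, cols)
    return cols.reshape(b, w, h, d).permute(0, 2, 1, 3)

out = axial_attention(torch.randn(2, 4, 4, 32))
print(out.shape)  # torch.Size([2, 4, 4, 32])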

Sliding Window with Attention Sinks (Mistral + StreamingLLM)

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SLIDING WINDOW WITH ATTENTION SINKS                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE APPROACH:                                                           │
│  ─────────────                                                           │
│                                                                          │
│  Sliding window attention (as in Mistral) + always attend to the        │
│  first few tokens ("attention sinks", a trick from StreamingLLM)        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  PATTERN (window=2, sinks=2):                                           │
│  ─────────────────────────────                                           │
│                                                                          │
│        Keys                                                             │
│        0  1  2  3  4  5  6  7  8  9  10 11                             │
│   0    ■  □  □  □  □  □  □  □  □  □  □  □   Q                         │
│   1    ■  ■  □  □  □  □  □  □  □  □  □  □   u                         │
│   2    ■  ■  ■  □  □  □  □  □  □  □  □  □   e                         │
│   3    ■  ■  ■  ■  □  □  □  □  □  □  □  □   r                         │
│   4    ■  ■  □  ■  ■  □  □  □  □  □  □  □   i                         │
│   5    ■  ■  □  □  ■  ■  □  □  □  □  □  □   e                         │
│   6    ■  ■  □  □  □  ■  ■  □  □  □  □  □   s                         │
│   7    ■  ■  □  □  □  □  ■  ■  □  □  □  □                             │
│   8    ■  ■  □  □  □  □  □  ■  ■  □  □  □                             │
│   9    ■  ■  □  □  □  □  □  □  ■  ■  □  □                             │
│  10    ■  ■  □  □  □  □  □  □  □  ■  ■  □                             │
│  11    ■  ■  □  □  □  □  □  □  □  □  ■  ■                             │
│                                                                          │
│  Columns 0-1 (sinks): Always attended to                              │
│  Diagonal: Local sliding window                                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  WHY ATTENTION SINKS?                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  Observation: First tokens often receive high attention.              │
│  This seems to be where models "dump" attention when uncertain.       │
│                                                                          │
│  Without sinks:                                                         │
│  At position 1000 with window 512, position 0 is unreachable.        │
│  Model may behave strangely when it can't access these sinks.        │
│                                                                          │
│  With sinks:                                                            │
│  Position 0-1 always accessible, model behavior more stable.         │
│  Minimal additional cost (just 2-4 extra attentions per token).      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  STREAMING INFERENCE:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  With sliding window + sinks:                                          │
│  • KV cache size fixed: sink_size + window_size                       │
│  • Can process infinite sequences!                                    │
│  • Memory: O(1) instead of O(n)                                       │
│                                                                          │
│  This is the recipe behind StreamingLLM-style streaming inference      │
│  on top of sliding-window models such as Mistral.                      │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
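
The fixed-size KV cache follows directly from the pattern above. Here is a sketch of the eviction rule (keep the sink tokens plus the most recent window); the sizes are illustrative, not Mistral's:

Python
import torch

def evict_kv(keys, values, sink_size: int = 2, window_size: int = 8):
    """keys/values: (batch, cached_len, dim). Keep sinks + recent window."""
    cached_len = keys.shape[1]
    if cached_len <= sink_size + window_size:
        return keys, values
    keep = torch.cat([
        torch.arange(sink_size),                            # attention sinks
        torch.arange(cached_len - window_size, cached_len)  # recent window
    ])
    return keys[:, keep], values[:, keep]

k, v = torch.randn(1, 50, 64), torch.randn(1, 50, 64)
k, v = evict_kv(k, v)
print(k.shape)  # torch.Size([1, 10, 64]) -- bounded, however long the stream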

Part V: Implementation Considerations

Efficient Sparse Attention

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    EFFICIENT IMPLEMENTATION                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  CHALLENGE: Making sparse attention fast on GPUs                       │
│  ─────────────────────────────────────────────                           │
│                                                                          │
│  GPUs are optimized for:                                               │
│  • Dense matrix operations (GEMM)                                     │
│  • Coalesced memory access                                            │
│  • High arithmetic intensity                                          │
│                                                                          │
│  Naive sparse attention violates all of these!                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SOLUTION 1: Block-Sparse Operations                                    │
│  ────────────────────────────────────                                    │
│                                                                          │
│  Organize sparsity into blocks. Within each block, compute dense.    │
│                                                                          │
│  Libraries:                                                             │
│  • Triton (OpenAI): Block-sparse attention kernels                    │
│  • DeepSpeed: Sparse attention support                                │
│  • xFormers: Memory-efficient attention with sparsity                 │
│                                                                          │
│  Example (block layout, xFormers-style API):                            │
│                                                                          │
│  # Define which blocks may attend to which blocks                       │
│  layout = torch.tril(torch.ones(num_blocks, num_blocks))                │
│  layout[0, :] = 1  # Global row                                         │
│  layout[:, 0] = 1  # Global column                                      │
│                                                                          │
│  # Create block-sparse attention over this layout                       │
│  sparse_attn = BlockSparseAttention(layout, block_size=64)              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SOLUTION 2: Built-In Windows and Mask-Based Kernels                    │
│  ────────────────────────────────────────────────────                    │
│                                                                          │
│  FlashAttention-2 natively supports causal and local (sliding-window)   │
│  attention via its window_size argument; it does not accept arbitrary   │
│  masks. Arbitrary patterns can go through PyTorch SDPA instead.         │
│                                                                          │
│  # Sliding-window mask (True = attend), for mask-based kernels          │
│  def sliding_window_mask(seq_len, window_size):                         │
│      mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)             │
│      for i in range(seq_len):                                           │
│          start = max(0, i - window_size // 2)                           │
│          end = min(seq_len, i + window_size // 2 + 1)                   │
│          mask[i, start:end] = True                                      │
│      return mask                                                        │
│                                                                          │
│  output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)       │
│                                                                          │
│  # Or use flash-attn's native sliding window (tokens left, right)       │
│  output = flash_attn_func(q, k, v, causal=True, window_size=(512, 0))   │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SOLUTION 3: Reformulate as Smaller Dense Attentions                   │
│  ────────────────────────────────────────────────────                    │
│                                                                          │
│  For sliding window: Chunk sequence, each chunk is dense.             │
│                                                                          │
│  def chunked_attention(q, k, v, chunk_size):                          │
│      """Process in chunks with overlapping windows."""                │
│      outputs = []                                                      │
│      for i in range(0, seq_len, chunk_size // 2):                    │
│          chunk_q = q[:, i:i+chunk_size]                              │
│          # Include overlap for context                                │
│          chunk_k = k[:, max(0,i-chunk_size//2):i+chunk_size]        │
│          chunk_v = v[:, max(0,i-chunk_size//2):i+chunk_size]        │
│          out = standard_attention(chunk_q, chunk_k, chunk_v)         │
│          outputs.append(out)                                          │
│      return combine_chunks(outputs)                                   │
│                                                                          │
│  This uses standard optimized dense attention on smaller inputs.     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
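
For reference, the mask route in the box maps directly onto PyTorch's built-in scaled_dot_product_attention (boolean mask, True = may attend). Note that this still materializes an n×n mask, so treat it as a correctness baseline rather than a memory win:

Python
import torch
import torch.nn.functional as F

def sliding_window_sdpa(q, k, v, window: int):
    """q, k, v: (batch, heads, seq, dim); bidirectional local attention."""
    seq = q.shape[-2]
    idx = torch.arange(seq, device=q.device)
    mask = (idx[:, None] - idx[None, :]).abs() <= window // 2  # True = attend
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = k = v = torch.randn(1, 8, 128, 64)
print(sliding_window_sdpa(q, k, v, window=16).shape)  # (1, 8, 128, 64)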

Library Support

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    LIBRARY AND MODEL SUPPORT                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  HUGGING FACE TRANSFORMERS:                                              │
│  ──────────────────────────                                              │
│                                                                          │
│  Longformer:                                                            │
│  from transformers import LongformerModel                              │
│  model = LongformerModel.from_pretrained(                              │
│      "allenai/longformer-base-4096")                                   │
│                                                                          │
│  BigBird:                                                               │
│  from transformers import BigBirdModel                                 │
│  model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")  │
│                                                                          │
│  Both support up to 4096 tokens out of the box.                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  XFORMERS:                                                               │
│  ─────────                                                               │
│                                                                          │
│  from xformers.components.attention import (                           │
│      ScaledDotProductAttention,                                        │
│      AttentionMask,                                                    │
│      BlockSparseAttention,                                             │
│  )                                                                       │
│                                                                          │
│  # Block sparse attention                                               │
│  attn = BlockSparseAttention(                                          │
│      layout=my_sparsity_pattern,                                       │
│      block_size=64,                                                    │
│  )                                                                       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  MODELS USING SPARSE ATTENTION:                                          │
│  ──────────────────────────────                                          │
│                                                                          │
│  Model               Attention Type       Max Context                  │
│  ─────────────────────────────────────────────────────────────         │
│  Longformer          Sliding + Global     4,096                        │
│  BigBird             Sliding + Random     4,096                        │
│  Mistral 7B          Sliding Window       32,768 (with RoPE)          │
│  Mixtral             Sliding Window       32,768                       │
│  LED                 Longformer-style     16,384                       │
│  LongT5              Local + TGlobal      16,384                       │
│                                                                          │
│  Most modern long-context models combine:                             │
│  • Sparse attention (for efficiency)                                  │
│  • RoPE scaling (for extrapolation)                                   │
│  • FlashAttention (for memory efficiency)                             │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
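
One detail the snippets above leave implicit: Longformer applies sliding-window attention everywhere by default and only gives global attention to positions you flag yourself. Below is a minimal sketch of classification-style usage with the standard transformers API; `long_document` is just a placeholder for your input text.

Code
import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

long_document = "your long input text here"  # placeholder
inputs = tokenizer(long_document, return_tensors="pt",
                   truncation=True, max_length=4096)

# 0 = local (sliding-window) attention, 1 = global attention.
# Here only the [CLS] token at position 0 gets global attention.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

outputs = model(**inputs, global_attention_mask=global_attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0]   # use for classification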

Part VI: When to Use Sparse Attention

Decision Framework

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    WHEN TO USE SPARSE ATTENTION                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  USE SPARSE ATTENTION WHEN:                                              │
│  ──────────────────────────                                              │
│                                                                          │
│  ✓ Sequence length > 4K tokens                                        │
│  ✓ Task doesn't need full all-to-all attention                       │
│  ✓ Most relevant context is local (nearby tokens)                    │
│  ✓ Global context can be captured by few special tokens              │
│                                                                          │
│  GOOD USE CASES:                                                         │
│  ───────────────                                                         │
│  • Document classification (Longformer with [CLS])                    │
│  • Long document QA (BigBird)                                         │
│  • Summarization of long texts                                        │
│  • Code analysis (local context usually sufficient)                  │
│  • Chat with very long history                                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DON'T USE SPARSE ATTENTION WHEN:                                        │
│  ─────────────────────────────────                                       │
│                                                                          │
│  ✗ Sequence length < 4K (full attention is fine)                     │
│  ✗ Task requires precise long-range dependencies                     │
│  ✗ You have GPUs with lots of memory                                 │
│  ✗ FlashAttention solves your memory problem                         │
│                                                                          │
│  BETTER ALTERNATIVES:                                                    │
│  ────────────────────                                                    │
│  • Short sequences: Just use FlashAttention                           │
│  • Need full attention: FlashAttention (O(n) memory, still O(n²) compute)
│  • Very long but can chunk: Chunk and process separately             │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DECISION TREE:                                                          │
│  ──────────────                                                          │
│                                                                          │
│  Is sequence length > 4K?                                              │
│    NO → Use standard attention + FlashAttention                       │
│    YES ↓                                                               │
│                                                                          │
│  Does task need full attention?                                        │
│    YES → Use FlashAttention (memory-efficient full attention)         │
│    NO ↓                                                                │
│                                                                          │
│  Is context mostly local?                                              │
│    YES → Sliding window (Mistral-style)                               │
│    NO ↓                                                                │
│                                                                          │
│  Need global context aggregation?                                      │
│    YES, via a few known tokens → Longformer (sliding + global)         │
│    YES, but unsure where → BigBird (sliding + global + random)         │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
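
If you prefer code to diagrams, the decision tree above collapses into a small helper. The function and flag names below are illustrative only, not part of any library.

Code
def choose_attention(seq_len: int,
                     needs_full_attention: bool,
                     mostly_local: bool,
                     needs_global_aggregation: bool) -> str:
    """Encodes the decision tree above; returns a recommendation string."""
    if seq_len <= 4096:
        return "standard attention + FlashAttention"
    if needs_full_attention:
        return "FlashAttention (memory-efficient full attention)"
    if mostly_local:
        return "sliding window (Mistral-style)"
    if needs_global_aggregation:
        return "Longformer (sliding + global)"
    return "BigBird (sliding + global + random)"

# e.g. a 32K-token log file where relevant context is mostly local:
# choose_attention(32_768, False, True, False) -> "sliding window (Mistral-style)"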

Part VII: Recent Innovations (2024-2025)

DeepSeek Sparse Attention (DSA)

DeepSeek introduced DSA, a sparse attention pattern optimized for inference efficiency:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    DEEPSEEK SPARSE ATTENTION (DSA)                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE INSIGHT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  DeepSeek observed attention patterns during inference:                 │
│  • Recent tokens (last ~4K): High attention, need full detail          │
│  • Middle tokens: Sparse attention to "landmark" positions             │
│  • Beginning tokens: Attention sinks, always accessed                  │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DSA PATTERN:                                                            │
│  ────────────                                                            │
│                                                                          │
│  For position p with context length L:                                  │
│                                                                          │
│  Attend to:                                                              │
│  1. Sink tokens: positions 0-127 (always)                              │
│  2. Recent window: positions max(0, p-4096) to p (dense)               │
│  3. Landmarks: Every 64th token before the window                      │
│                                                                          │
│  Example at position 10000 (L=128K):                                    │
│  • Sinks: 0-127 (128 tokens)                                           │
│  • Recent: 5904-10000 (4096 tokens, dense)                             │
│  • Landmarks: 128, 192, 256, ..., 5888 (91 tokens)                     │
│  • Total: ~4,316 attentions vs ~10,000 for full                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  IMPLEMENTATION DETAIL:                                                  │
│  ─────────────────────                                                   │
│                                                                          │
│  def get_dsa_attention_indices(position, context_len):                  │
│      indices = []                                                        │
│                                                                          │
│      # 1. Sink tokens                                                   │
│      indices.extend(range(min(128, context_len)))                       │
│                                                                          │
│      # 2. Recent window (dense)                                         │
│      window_start = max(128, position - 4096)                           │
│      indices.extend(range(window_start, position + 1))                  │
│                                                                          │
│      # 3. Landmarks (every 64th token before window)                    │
│      for i in range(128, window_start, 64):                             │
│          indices.append(i)                                               │
│                                                                          │
│      return sorted(set(indices))                                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RESULTS:                                                                │
│                                                                          │
│  Context Length    Full Attention    DSA Attention    Speedup          │
│  ─────────────────────────────────────────────────────────────         │
│  32K               ~32K tokens       ~4.6K tokens     7×               │
│  128K              ~128K tokens      ~6.3K tokens     20×              │
│  256K              ~256K tokens      ~8.3K tokens     31×              │
│                                                                          │
│  Quality: <0.5% perplexity degradation on benchmarks                   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
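
To make the pattern concrete, here is a naive single-query sketch of the same selection in PyTorch. The parameter names and the gather-then-softmax structure are my own simplification of the description above; a production kernel fuses the selection into the attention computation instead of materializing index lists.

Code
import torch

def dsa_indices(position, context_len, n_sink=128, window=4096, stride=64):
    # Sinks + dense recent window + strided landmarks, as described above.
    idx = set(range(min(n_sink, context_len)))
    window_start = max(n_sink, position - window)
    idx.update(range(window_start, position + 1))
    idx.update(range(n_sink, window_start, stride))
    return sorted(idx)

def dsa_attention_step(q, K, V, position):
    # q: [d], K and V: [context_len, d]; attend only to the selected positions.
    idx = torch.tensor(dsa_indices(position, K.shape[0]))
    scores = (K[idx] @ q) / (q.shape[-1] ** 0.5)
    probs = torch.softmax(scores, dim=-1)
    return probs @ V[idx]

# Sanity check against the worked example above:
# len(dsa_indices(10_000, 131_072)) is ~4.3K positions vs 10,001 for full attention.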

MInference: Dynamic Sparse Patterns

MInference (Microsoft, 2024) discovers optimal sparse patterns dynamically:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    MINFERENCE: LEARNED SPARSE PATTERNS                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE APPROACH:                                                           │
│  ──────────────                                                          │
│                                                                          │
│  Instead of fixed patterns (window, global), learn optimal sparsity    │
│  per attention head based on observed attention patterns.               │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  DISCOVERED PATTERNS (from LLM attention analysis):                     │
│  ──────────────────────────────────────────────────                      │
│                                                                          │
│  1. A-SHAPE: Attention to first tokens + recent tokens                 │
│     Common in: early layers, position-sensitive heads                  │
│     Pattern: [sink region] + [recent window]                           │
│                                                                          │
│  2. VERTICAL-SLASH: Strong attention to specific columns              │
│     Common in: middle layers, key-focused heads                        │
│     Pattern: Full attention to a few key positions                     │
│                                                                          │
│  3. BLOCK-SPARSE: Attention in dense blocks                            │
│     Common in: all layers, local-focused heads                         │
│     Pattern: Block diagonal + scattered blocks                         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  HEAD-SPECIFIC PATTERN SELECTION:                                        │
│  ─────────────────────────────────                                       │
│                                                                          │
│  MInference profiles attention patterns during a calibration run:      │
│                                                                          │
│  for layer in model.layers:                                              │
│      for head in layer.attention.heads:                                 │
│          pattern = analyze_attention_distribution(head)                │
│          if is_a_shape(pattern):                                        │
│              head.sparse_config = AShapeSparse(sink=128, recent=4096)  │
│          elif is_vertical(pattern):                                     │
│              head.sparse_config = VerticalSparse(key_positions=[...])  │
│          else:                                                          │
│              head.sparse_config = BlockSparse(block_size=64)           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RESULTS (Llama-3-70B, 128K context):                                   │
│                                                                          │
│  Method              Latency    Memory    Quality                       │
│  ─────────────────────────────────────────────────────────             │
│  Full Attention      Baseline   Baseline  Baseline                     │
│  Fixed Window        0.4×       0.3×      Degraded                     │
│  MInference          0.5×       0.4×      Near-lossless               │
│                                                                          │
│  MInference achieves 2× speedup with minimal quality loss.             │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
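
MInference's actual pattern search and kernels are considerably more involved, but the A-shape mask itself is easy to picture. Below is a minimal sketch that materializes it as a dense boolean mask, fine for inspection but far too memory-hungry for real long-context inference; the sink and recent sizes are the illustrative values used above.

Code
import torch

def a_shape_mask(seq_len, sink=128, recent=4096):
    """Boolean [seq_len, seq_len] mask for an 'A-shape' head: each query
    attends to the first `sink` tokens plus its `recent` predecessors (causal)."""
    q = torch.arange(seq_len).unsqueeze(1)   # query positions (rows)
    k = torch.arange(seq_len).unsqueeze(0)   # key positions (columns)
    causal = k <= q
    return causal & ((k < sink) | (q - k < recent))

# At 32K tokens each query touches at most sink + recent = 4,224 keys,
# i.e. ~13% of the columns, instead of all 32K.
mask = a_shape_mask(8_192)                       # small size just to inspect
print(mask.sum().item() / mask.numel())          # ~0.38 at 8K; shrinks as seq_len grows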

SparseAttn: FlashAttention for Sparse Patterns

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SPARSEATTN: EFFICIENT SPARSE KERNELS                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CHALLENGE:                                                          │
│  ───────────────                                                         │
│                                                                          │
│  Sparse attention saves FLOPs but naive implementations are slow:      │
│  • Irregular memory access patterns                                    │
│  • Can't use optimized dense GEMM kernels                             │
│  • Overhead from index management                                      │
│                                                                          │
│  Result: Theoretical speedup ≠ actual speedup                          │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SPARSEATTN APPROACH:                                                    │
│  ────────────────────                                                    │
│                                                                          │
│  Apply FlashAttention's tiling strategy to sparse patterns:            │
│                                                                          │
│  1. Divide sparse pattern into blocks                                  │
│  2. Identify which blocks are non-zero                                 │
│  3. Only compute attention for non-zero blocks                         │
│  4. Use FlashAttention's memory-efficient tiling within blocks         │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  SUPPORTED PATTERNS:                                                     │
│                                                                          │
│  Pattern              SparseAttn Support    Speedup vs Dense           │
│  ─────────────────────────────────────────────────────────────         │
│  Sliding Window       Yes                   2-10× (seq dependent)      │
│  Block Diagonal       Yes                   Near-linear scaling        │
│  Longformer-style     Yes                   3-8×                       │
│  BigBird-style        Partial               2-5×                       │
│  Arbitrary Sparse     Limited               Depends on structure       │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  USAGE:                                                                  │
│                                                                          │
│  from sparseattn import sparse_flash_attention                         │
│                                                                          │
│  # Define sparse pattern as block mask                                  │
│  block_mask = create_sliding_window_mask(                               │
│      seq_len=32768,                                                     │
│      block_size=256,                                                    │
│      window_blocks=16  # 4K window                                     │
│  )                                                                       │
│                                                                          │
│  output = sparse_flash_attention(                                       │
│      query, key, value,                                                 │
│      block_mask=block_mask,                                             │
│      block_size=256                                                     │
│  )                                                                       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
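
Whatever kernel ultimately runs, the core data structure is the same: a block-level mask recording which (query block, key block) pairs get computed at all. Here is a minimal plain-PyTorch sketch of the sliding-window block mask configured above; an efficient kernel consumes this block layout directly instead of expanding it to a token-level mask.

Code
import torch

def sliding_window_block_mask(seq_len, block_size, window_blocks):
    """Boolean [n_blocks, n_blocks] mask: entry [i, j] is True if query block i
    may attend to key block j (causal, within a window of `window_blocks` blocks)."""
    n_blocks = seq_len // block_size
    qb = torch.arange(n_blocks).unsqueeze(1)   # query block index
    kb = torch.arange(n_blocks).unsqueeze(0)   # key block index
    return (kb <= qb) & (qb - kb < window_blocks)

# Same configuration as above: 32K tokens, 256-token blocks, 16-block (4K) window.
block_mask = sliding_window_block_mask(32_768, 256, 16)   # shape [128, 128]
print(block_mask.sum().item() / block_mask.numel())        # ~0.12: only 12% of blocks computed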

NSA: Native Sparse Attention

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    NATIVE SPARSE ATTENTION (NSA)                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE CONCEPT:                                                            │
│  ────────────                                                            │
│                                                                          │
│  Train models with sparse attention from scratch, rather than as a     │
│  post-hoc modification, so the model learns to work within a limited   │
│  attention budget.                                                      │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  HYBRID DESIGN:                                                          │
│  ──────────────                                                          │
│                                                                          │
│  NSA uses different sparsity at different layers:                       │
│                                                                          │
│  Early layers (0-8): Local window + compression tokens                 │
│  • Local details need full resolution                                  │
│  • Compress distant context into summary tokens                        │
│                                                                          │
│  Middle layers (8-24): Selective attention                             │
│  • Attend to positions based on learned selection                      │
│  • Top-k attention routing                                             │
│                                                                          │
│  Late layers (24-32): Mixed local + global summaries                   │
│  • Local for recent context                                            │
│  • Global summary tokens for distant context                           │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  COMPRESSION TOKENS:                                                     │
│  ───────────────────                                                     │
│                                                                          │
│  Instead of attending to all distant tokens, compress them:            │
│                                                                          │
│  Tokens 0-4096: Summarize into 64 "compression tokens"                 │
│  Tokens 4097-8192: Summarize into 64 compression tokens                │
│  ... and so on                                                          │
│                                                                          │
│  New tokens attend to:                                                  │
│  • Recent 4K tokens (full resolution)                                  │
│  • Compression tokens for older context (64× compression)              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  RESULTS (pre-trained from scratch):                                    │
│                                                                          │
│  • 10× FLOPs reduction vs full attention                              │
│  • 95-98% performance retention on benchmarks                          │
│  • Native 128K context without quality degradation                     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
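
NSA learns its compression end to end, so the sketch below only illustrates the bookkeeping: mean-pooling stands in for the learned compression operator, and the function name and chunk sizes are placeholders rather than NSA's actual design.

Code
import torch

def compress_kv(keys, values, chunk=4096, n_comp=64):
    """Shrink distant K/V by summarizing each `chunk`-token span into `n_comp`
    vectors. Mean-pooling stands in for NSA's learned compression."""
    t, d = keys.shape
    t = (t // chunk) * chunk                 # drop the ragged tail for simplicity
    group = chunk // n_comp                  # tokens folded into each summary (64 here)
    k_comp = keys[:t].reshape(-1, group, d).mean(dim=1)
    v_comp = values[:t].reshape(-1, group, d).mean(dim=1)
    return k_comp, v_comp                    # 64x fewer rows than the input

# All 128K tokens collapse to 2,048 summary K/V rows; in practice the most
# recent ~4K tokens would be excluded and kept at full resolution.
keys = torch.randn(131_072, 128)
values = torch.randn(131_072, 128)
k_comp, v_comp = compress_kv(keys, values)
print(k_comp.shape)   # torch.Size([2048, 128])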

Updated Model Comparison

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    SPARSE ATTENTION MODELS (2025)                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Model            Pattern                 Context    Key Innovation     │
│  ─────────────────────────────────────────────────────────────────────  │
│  Longformer       Window + Global         4K         Original sparse    │
│  BigBird          Window + Global + Rand  4K         Theoretical basis  │
│  Mistral          Sliding + Sinks         32K        Streaming KV       │
│  Mixtral          Sliding + Sinks         32K        MoE + sparse       │
│  Phi-3            Sliding window          128K       Efficient small    │
│  DeepSeek-V3      DSA (adaptive)          128K       Dynamic landmarks  │
│  Llama 3.1        Full + FlashAttention   128K       Scaling, not sparse│
│  Gemini 1.5       Unknown (likely sparse) 1M+        Best long context  │
│                                                                          │
│  INFERENCE-ONLY METHODS:                                                │
│  • MInference: Dynamic per-head patterns                               │
│  • SparseAttn: Efficient sparse kernels                                │
│  • vLLM sparse: Production sparse inference                            │
│                                                                          │
│  TRAINABLE SPARSE (2025):                                               │
│  • NSA: Dynamic hierarchical, 9× inference speedup                    │
│  • MoBA: Block-sparse with gating, 6.5× at 1M context                │
│  • SeerAttention: Self-distilled gating, 5.67× vs FlashAttention-2   │
│                                                                          │
│  FLASHATTENTION-3 (2025):                                               │
│  • 1.5-2× speedup on H100 vs FlashAttention-2                        │
│  • ~75% GPU utilization with FP16/BF16 (up from ~35%)                │
│  • ~1.2 PFLOPs/s with FP8 low-precision                              │
│  • Warp-specialization + asynchronous tensor cores                   │
│                                                                          │
│  BLASST (2025 - Drop-in Sparse):                                       │
│  • Dynamic attention pruning without pre-computation                  │
│  • Works with MHA, GQA, MQA, and MLA variants                        │
│  • Fits into existing FlashAttention kernel designs                  │
│                                                                          │
│  MORE 2025 METHODS:                                                    │
│  • PBS-Attn: Permuted block-sparse, 2.75× prefill speedup           │
│  • InfLLM-V2: Dense-sparse switchable for short-to-long adaptation  │
│  • AdaSplash: Adaptive sparse flash attention                        │
│  • LServe: Efficient long-sequence serving with unified sparse       │
│  • GNA: Generalized Neighborhood Attention at the speed of light     │
│                                                                          │
│  TREND: Trainable sparse + FlashAttention-3 + dynamic pruning        │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Summary

Sparse attention patterns enable efficient processing of long sequences by attending to subsets of tokens rather than all pairs:

Code
┌─────────────────────────────────────────────────────────────────────────┐
│                    KEY TAKEAWAYS                                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  THE PROBLEM:                                                            │
│  • Full attention is O(n²) in memory and compute                      │
│  • 32K context → 2TB attention memory (impossible)                    │
│                                                                          │
│  CORE PATTERNS:                                                          │
│  • Local/Sliding Window: Attend to nearby tokens O(n×w)               │
│  • Global: Select tokens attend to everything O(n×g)                  │
│  • Random: Random connections for shortcuts O(n×r)                    │
│  • Block Sparse: Organize sparsity into dense blocks                  │
│                                                                          │
│  KEY MODELS:                                                             │
│  • Longformer: Local + Global + optional Dilated                      │
│  • BigBird: Local + Global + Random (theoretically complete)          │
│  • Mistral: Sliding Window + Attention Sinks                          │
│                                                                          │
│  2024-2025 INNOVATIONS:                                                 │
│  • NSA: Native Sparse Attention (ACL 2025 Best Paper)               │
│    - 9× faster inference, 6× faster training on 64K sequences       │
│    - Dynamic hierarchical: compression + selection + sliding window │
│  • MoBA: Mixture of Block Attention (Kimi, Feb 2025)                │
│    - Block sparse with gating (like MoE for attention)              │
│    - 6.5× speedup at 1M context, production-deployed at kimi.ai    │
│  • SeerAttention: Microsoft NeurIPS 2025                            │
│    - Learned sparse patterns via self-distillation                  │
│    - 90% sparsity, 5.67× speedup over FlashAttention-2             │
│  • DSA: Adaptive sparsity (sinks + window + landmarks)               │
│  • MInference: Per-head learned sparse patterns                      │
│                                                                          │
│  PRACTICAL GUIDANCE:                                                     │
│  • < 4K tokens: Just use FlashAttention                               │
│  • > 4K, local context: Sliding window                                │
│  • > 4K, need global: Longformer or BigBird                          │
│  • > 32K inference: Consider MInference or DSA                        │
│  • Modern models: Combine sparse + RoPE scaling + FlashAttention     │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
