FlashAttention

IO‑aware, numerically‑stable tiled attention with streaming softmax

Motivation

Naive attention computes QKᵀ and stores the full n×n score matrix, which is memory‑bound and O(n²) in space. FlashAttention reorders the computation so that QKᵀ is never materialized: queries, keys, and values are tiled into on‑chip shared memory (SRAM), and the softmax is computed in a streaming, numerically stable fashion. The result is a large wall‑clock speedup and far less global‑memory traffic.

Baseline vs Flash

Attn(Q,K,V) = softmax\big( QK^T / \sqrt{d_k} + M \big) V

Here M is an optional additive mask (e.g., causal). Naive implementations form S = QKᵀ/√d_k, apply the mask, compute a row‑wise softmax, and multiply by V. FlashAttention instead tiles K and V into blocks and processes each query tile against them, maintaining a per‑row running max and normalizer so the softmax is computed without ever storing S. A naive reference is sketched below for contrast.
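For contrast, a minimal naive reference in NumPy that does materialize S (the function name and signature are illustrative, not from any particular library):

import numpy as np

def naive_attention(Q, K, V, mask=None):
    # Q, K, V: (n, d); mask: optional (n, n) additive mask (0 / -inf entries).
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                 # full n x n score matrix: the O(n^2) workspace
    if mask is not None:
        S = S + mask
    S = S - S.max(axis=-1, keepdims=True)    # usual max-subtraction for stability
    P = np.exp(S)
    P /= P.sum(axis=-1, keepdims=True)       # row-wise softmax
    return P @ V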

Streaming Softmax

Given running statistics (m, l, o) for one query row and a new key/value block b with scores s_b:

m' = \max\big(m,\ \max_j s_{b,j}\big), \quad l \leftarrow l\,e^{m - m'} + \sum_j e^{s_{b,j} - m'}, \quad o \leftarrow o\,e^{m - m'} + \sum_j e^{s_{b,j} - m'}\, v_{b,j}, \quad m \leftarrow m'

Here s_{b,j} are the scores of the query row against the j‑th key in block b and v_{b,j} the matching value rows; m is the running max (for numerical stability), l the running normalizer, and o the running unnormalized output. After the last block the row's output is o / l. This is the log‑sum‑exp trick applied block by block; a single‑row sketch follows below.
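A minimal single‑row sketch of this recurrence in NumPy, checked against the direct softmax (the block size and names are illustrative):

import numpy as np

def streaming_softmax_row(scores, values, block=4):
    # scores: (n,) scores for one query row; values: (n, d) value rows.
    m, l, o = -np.inf, 0.0, np.zeros(values.shape[1])
    for start in range(0, len(scores), block):
        s_b, v_b = scores[start:start+block], values[start:start+block]
        m_new = max(m, s_b.max())              # m' = max(m, max_j s_{b,j})
        p = np.exp(s_b - m_new)                # stabilized block weights
        l = l * np.exp(m - m_new) + p.sum()    # rescale old normalizer, add block sum
        o = o * np.exp(m - m_new) + p @ v_b    # same rescaling for the numerator
        m = m_new
    return o / l                               # normalize once at the end

rng = np.random.default_rng(0)
s, V = rng.normal(size=16), rng.normal(size=(16, 8))
direct = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
assert np.allclose(streaming_softmax_row(s, V), direct)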

[Diagram: a Q tile held in SMEM streams over K and V tiles, updating the running (m, l, o) statistics before its rows of O are written out.]
Each Q tile streams over K/V tiles, updating running softmax statistics and outputs without storing QKᵀ.

Causal Masking

For autoregressive models, FlashAttention applies the causal mask at the tile level: entries of a score tile whose key index exceeds the query index are set to −∞ before the streaming update (so they receive zero weight after exponentiation), and implementations typically skip key tiles that lie entirely in the future of the current query tile. Because masking happens before each block's softmax statistics are accumulated, the result matches a full‑matrix causal softmax exactly; a tile‑level sketch follows below.
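A sketch of how the mask could be applied to one score tile given its global offsets (apply_causal_mask and the offset arguments are illustrative names, not a library API):

import numpy as np

def apply_causal_mask(s, q_start, k_start):
    # s: (tq, tk) score tile; q_start / k_start: global index of the tile's
    # first query row and first key column.
    tq, tk = s.shape
    q_idx = q_start + np.arange(tq)[:, None]     # global query indices (rows)
    k_idx = k_start + np.arange(tk)[None, :]     # global key indices (cols)
    return np.where(k_idx <= q_idx, s, -np.inf)  # future keys -> -inf -> zero weight

A key tile with k_start > q_start + tq − 1 lies entirely in the future of the query tile and can simply be skipped.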

Backward Pass

The backward pass keeps only the per‑row softmax statistics from the forward pass (the running max and normalizer, often stored together as the log‑sum‑exp) and recomputes the attention probabilities tile by tile, accumulating dQ, dK, and dV with the same tiling; see the sketch below. Memory usage therefore stays linear in n·d (plus O(n) statistics) rather than O(n²), and numerical stability is preserved by reusing the saved max/normalizer when re‑exponentiating scores.
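A simplified tiled NumPy sketch of this, assuming the forward pass saved O and the per‑row log‑sum‑exp L = m + log l (masking is omitted for brevity; names are illustrative):

import numpy as np

def flash_attention_backward(Q, K, V, O, dO, L, tile=64):
    # Q, K, V, O, dO: (n, d); L: (n,) per-row log-sum-exp from the forward pass.
    n, d = Q.shape
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    D = (dO * O).sum(axis=1)                      # D_i = sum_j P_ij dP_ij = dO_i . O_i
    for j in range(0, n, tile):                   # key/value tiles
        Kj, Vj = K[j:j+tile], V[j:j+tile]
        for i in range(0, n, tile):               # query tiles
            Qi, dOi = Q[i:i+tile], dO[i:i+tile]
            S = Qi @ Kj.T / np.sqrt(d)            # recompute the score tile
            P = np.exp(S - L[i:i+tile, None])     # recover probabilities from saved L
            dV[j:j+tile] += P.T @ dOi
            dP = dOi @ Vj.T
            dS = P * (dP - D[i:i+tile, None])     # row-wise softmax Jacobian
            dQ[i:i+tile] += dS @ Kj / np.sqrt(d)
            dK[j:j+tile] += dS.T @ Qi / np.sqrt(d)
    return dQ, dK, dV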

IO‑Aware Complexity

Variant         | Global memory (HBM) traffic              | Workspace
Naive           | Θ(n² + nd)                               | full n×n score matrix
FlashAttention  | Θ(n²d²/M), with M the on‑chip SRAM size  | per‑tile blocks only, no n×n

The exact bound depends on the tile sizes and the SRAM size M; the key point is that the n×n score matrix is never materialized in global memory. A back‑of‑the‑envelope comparison follows below.
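A back‑of‑the‑envelope comparison in Python (the sequence length, head dimension, tile sizes, and fp16 storage are all illustrative choices):

n, d = 8192, 64                  # sequence length, head dimension
tq = tk = 128                    # tile sizes
bytes_per_elem = 2               # fp16

naive_scores = n * n * bytes_per_elem                       # full QK^T in global memory
tile_workspace = (tq*d + 2*tk*d + tq*tk) * bytes_per_elem   # Q tile, K/V tiles, score tile

print(f"naive n x n scores : {naive_scores / 2**20:.0f} MiB")   # 128 MiB per head
print(f"per-tile workspace : {tile_workspace / 2**10:.0f} KiB") # 80 KiB, fits in SRAM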

Pseudocode (Forward)

import numpy as np

def flash_attention_forward(Q, K, V, tq=64, tk=64):
    n, d = Q.shape
    out = np.empty_like(Q)
    for i in range(0, n, tq):                         # stream over query tiles
        q = Q[i:i+tq]
        m = np.full(len(q), -np.inf); l = np.zeros(len(q)); o = np.zeros_like(q)
        for j in range(0, n, tk):                     # stream over key/value tiles
            s = q @ K[j:j+tk].T / np.sqrt(d)          # block scores (tq x tk)
            # s = apply_causal_mask(s, i, j) would go here for autoregressive models
            m_new = np.maximum(m, s.max(axis=1))      # updated running row max
            p = np.exp(s - m_new[:, None])            # stabilized block weights
            scale = np.exp(m - m_new)                 # rescales the old l and o
            l = l * scale + p.sum(axis=1)
            o = o * scale[:, None] + p @ V[j:j+tk]
            m = m_new
        out[i:i+tq] = o / l[:, None]                  # normalize once per query tile
    return out
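Assuming the naive_attention and flash_attention_forward sketches above, the two can be cross‑checked on random inputs:

rng = np.random.default_rng(42)
n, d = 256, 64
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
assert np.allclose(flash_attention_forward(Q, K, V), naive_attention(Q, K, V))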

Interactive: Tile Estimator

[Interactive widget: given a sequence length n, a Q tile size tq (rows per tile), and a K tile size tk (columns per tile), it reports the number of tiles processed ≈ ceil(n/tq)·ceil(n/tk), the n² entries of QKᵀ a naive implementation would store, the relative tile‑level work for FlashAttention, and the workspace saved by never materializing the n×n scores.]
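A small standalone version of the estimator under the same definitions (the function and key names are illustrative):

import math

def tile_estimate(n, tq, tk):
    tiles = math.ceil(n / tq) * math.ceil(n / tk)    # tiles processed ≈ ceil(n/tq)*ceil(n/tk)
    naive_scores = n * n                             # entries in QK^T a naive kernel stores
    flash_tile_ops = tiles * tq * tk                 # relative tile-level work units
    return {"tiles": tiles, "naive_scores": naive_scores, "flash_tile_ops": flash_tile_ops}

print(tile_estimate(n=4096, tq=128, tk=128))
# {'tiles': 1024, 'naive_scores': 16777216, 'flash_tile_ops': 16777216}

Note that the tile‑level score work matches the naive score count: FlashAttention saves memory traffic and workspace, not arithmetic.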