IO‑aware, numerically‑stable tiled attention with streaming softmax
Naive attention computes QKᵀ and stores the full n×n score matrix, which makes it memory-bandwidth-bound and O(n²) in space. FlashAttention reorders the computation to avoid materializing QKᵀ: it tiles Q and K/V into blocks that fit in on-chip SRAM (shared memory) and computes the softmax in a streaming, numerically stable fashion. The result is a large wall-clock speedup driven by reduced global-memory traffic.
The operation is O = softmax(QKᵀ/√d + M)V, with masking M (e.g., causal). Naive implementations form S = QKᵀ/√d, apply the mask, compute softmax row-wise, and multiply by V. FlashAttention instead tiles K/V into blocks and processes each query tile against them, maintaining a per-row running max and normalizer so that the softmax is computed without ever storing S.
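For reference, a minimal dense sketch of that computation, assuming NumPy and single-head inputs of shape n×d (the helper name naive_attention is illustrative):

```python
import numpy as np

def naive_attention(Q, K, V, causal=False):
    """Dense reference: materializes the full n x n score matrix."""
    n, d = Q.shape
    S = Q @ K.T / np.sqrt(d)                      # n x n scores
    if causal:
        S = np.where(np.tril(np.ones((n, n), dtype=bool)), S, -np.inf)
    S = S - S.max(axis=1, keepdims=True)          # numerically stable softmax
    P = np.exp(S)
    P /= P.sum(axis=1, keepdims=True)
    return P @ V                                  # n x d output
```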
For each query row, consuming one key tile at a time updates the per-row statistics as

m_new = max(m, max(s_b))
l ← l·exp(m − m_new) + Σ_j exp(s_b[j] − m_new)
o ← o·exp(m − m_new) + Σ_j exp(s_b[j] − m_new)·v_j
m ← m_new

where s_b are the block scores for that query row against the key tile, m is the running max for numerical stability, l the running normalizer, and o the running output numerator; the final output is o/l. This mirrors log-sum-exp evaluated in blocks.
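A tiny numerical check of these updates (a sketch assuming NumPy; the two-block split and variable names are illustrative) shows that streaming over key blocks reproduces the direct softmax-weighted sum:

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.normal(size=8)           # one query row's scores against 8 keys
v = rng.normal(size=(8, 4))      # matching value vectors (d = 4)

# Direct computation for comparison.
p = np.exp(s - s.max())
expected = (p / p.sum()) @ v

# Streaming computation over two key blocks of 4.
m, l, o = -np.inf, 0.0, np.zeros(4)
for s_b, v_b in ((s[:4], v[:4]), (s[4:], v[4:])):
    m_new = max(m, s_b.max())
    l = l * np.exp(m - m_new) + np.exp(s_b - m_new).sum()
    o = o * np.exp(m - m_new) + np.exp(s_b - m_new) @ v_b
    m = m_new

assert np.allclose(o / l, expected)
```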
For autoregressive models, FlashAttention applies the causal mask at the tile level: as tiles advance, key/value tiles that lie entirely in a query tile's future can be skipped, and tiles that straddle the diagonal have their future (upper-triangular) entries set to −∞ before the streaming update, so exp() zeroes them out and the result stays exactly correct.
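One way to realize this per block (a minimal sketch; the helper name causal_block_mask and the absolute-offset convention are ours, and a real kernel would skip fully masked tiles rather than materialize them):

```python
import numpy as np

def causal_block_mask(s, q_start, k_start):
    """Apply a causal mask to one block of scores.

    s holds scores for queries q_start.. against keys k_start..; entries
    where the key lies in the query's future become -inf so exp() zeroes them.
    """
    q_idx = q_start + np.arange(s.shape[0])[:, None]   # absolute query positions
    k_idx = k_start + np.arange(s.shape[1])[None, :]   # absolute key positions
    return np.where(k_idx <= q_idx, s, -np.inf)
```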
The backward pass saves only the output and the per-row softmax statistics (running max and normalizer, equivalently the row log-sum-exp) from the forward pass; it recomputes block scores and attention probabilities tile by tile and accumulates dQ, dK, dV with the same tiling. Memory usage remains O(n·d) rather than O(n²), and reusing the saved max/normalizer keeps the recomputation numerically stable.
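A non-tiled sketch of that recomputation (assuming NumPy; the function name and the convention that lse holds the forward pass's per-row log-sum-exp are ours) illustrates the identities a tiled backward applies block by block:

```python
import numpy as np

def attention_backward_from_stats(Q, K, V, O, dO, lse):
    """Recompute probabilities from saved row statistics, then form the gradients."""
    d = Q.shape[1]
    S = Q @ K.T / np.sqrt(d)                    # recomputed scores (per tile in practice)
    # If a mask was used in the forward pass, re-apply it to S here.
    P = np.exp(S - lse[:, None])                # recomputed attention probabilities
    dV = P.T @ dO
    dP = dO @ V.T
    D = (dO * O).sum(axis=1, keepdims=True)     # row-wise dot(dO, O) = rowsum(P * dP)
    dS = P * (dP - D)                           # softmax backward
    dQ = dS @ K / np.sqrt(d)
    dK = dS.T @ Q / np.sqrt(d)
    return dQ, dK, dV
```

The saved statistics relate to the streaming quantities above via lse = m + log(l).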
| Variant | HBM traffic (reads/writes) | Workspace |
|---|---|---|
| Naive | Θ(n² + n·d) | n×n score matrix in global memory |
| FlashAttention | Θ(n²d²/M), for SRAM size M | O(n·d); no n×n matrix |
The exact bound depends on tile sizes and SRAM; the key is avoiding the n² materialization.
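To make the gap concrete, a rough back-of-the-envelope calculation (the sequence length, head dimension, and tile size below are arbitrary illustrations, fp32 throughout):

```python
n, d, tile = 8192, 64, 128                    # illustrative sizes
scores = n * n * 4                            # full n x n score matrix, in bytes
q_k_v_out = 4 * n * d * 4                     # Q, K, V, and the output
workspace = (3 * tile * d + tile * tile) * 4  # one Q/K/V tile plus one block of scores

print(scores / 2**20, "MiB of scores")        # 256.0 MiB that naive attention materializes
print(q_k_v_out / 2**20, "MiB of I/O data")   # 8.0 MiB that any method must read/write
print(workspace / 2**10, "KiB on-chip")       # 160.0 KiB, which fits in SRAM
```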
import numpy as np

def flash_attention(Q, K, V, tile=128, mask_fn=None):
    n, d = Q.shape
    out = np.empty_like(Q)
    for qs in range(0, n, tile):                        # loop over query tiles
        q = Q[qs:qs + tile]
        m, l, o = np.full(len(q), -np.inf), np.zeros(len(q)), np.zeros_like(q)
        for ks in range(0, n, tile):                    # loop over key/value tiles
            s = q @ K[ks:ks + tile].T / np.sqrt(d)      # block scores for this tile pair
            if mask_fn is not None:
                s = mask_fn(s, qs, ks)                  # e.g. causal_block_mask above
            m_new = np.maximum(m, s.max(axis=1))        # updated running row max
            l = l * np.exp(m - m_new) + np.exp(s - m_new[:, None]).sum(axis=1)
            o = o * np.exp(m - m_new)[:, None] + np.exp(s - m_new[:, None]) @ V[ks:ks + tile]
            m = m_new
        out[qs:qs + tile] = o / l[:, None]              # final per-row normalization
    return out
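A quick correctness check against the dense reference (a sketch using the naive_attention and causal_block_mask helpers sketched above; sizes and tile width are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 256, 64
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))

dense = naive_attention(Q, K, V, causal=True)
tiled = flash_attention(Q, K, V, tile=64, mask_fn=causal_block_mask)
assert np.allclose(dense, tiled)              # streaming softmax matches the dense result
```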