NF4 quantization, double quant, dequant math, and paged optimizers
QLoRA enables parameter‑efficient fine‑tuning on a 4‑bit quantized base model while training small LoRA adapters in higher precision. The base weights are frozen and stored at very low precision, so their memory footprint is small, while the adapters capture the task‑specific updates. Coupled with paged optimizers, this allows single‑GPU fine‑tuning of multi‑billion‑parameter models with minimal quality loss.
NormalFloat4 (NF4) is a 4‑bit, non‑uniform quantizer whose 16 codebook values sit at quantiles of a standard normal distribution, matching the roughly Gaussian distribution of standardized weights. Rather than using equal linear steps, NF4 places more levels near 0, where most weights lie, to preserve small weights.
Dequantization reconstructs an approximate weight from its stored 4‑bit code:

x̂ = s · c[q] + b

Here c[q] is a fixed 16‑entry codebook on standardized space, s is a per‑group or per‑channel scale, and b is an optional bias/offset (often 0). Quantization selects the nearest code:

q = argmin_k | (x − b)/s − c[k] |
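To make the quantile construction concrete, here is a small sketch of how an NF4‑style codebook can be derived from standard‑normal quantiles. It is a simplified, symmetric construction for illustration (the helper name is ours); the actual bitsandbytes table is asymmetric so that 0 is exactly representable.

import torch
from torch.distributions import Normal

def nf4_style_codebook(bits=4):
    n = 2 ** bits                              # 16 levels for 4 bits
    probs = torch.linspace(0, 1, n + 2)[1:-1]  # evenly spaced, strictly inside (0, 1)
    levels = Normal(0.0, 1.0).icdf(probs)      # quantiles of the standard normal
    return levels / levels.abs().max()         # normalize codes to [-1, 1]

The resulting levels cluster near 0 and spread out toward ±1, which is the behavior described above.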
Grouping choices (e.g., 64 weights per group) trade off fidelity against metadata overhead. Finer-grained scales (per channel or per small group) give higher fidelity than a single per‑tensor scale, at the cost of slightly more metadata.
To reduce this overhead further, QLoRA “double‑quantizes” the scales: the per‑group scales s (stored in FP32 in the QLoRA recipe) are themselves quantized to 8‑bit codes using a second, small codebook d plus a second level of full‑precision constants s₂ shared across blocks of scales (256 scales per block in the paper). This cuts the scale overhead from roughly 0.5 to roughly 0.127 bits per parameter with negligible quality loss.
At runtime, dequantization first reconstructs the group scale, ŝ = s₂ · d[q_s], and then the weight, x̂ = ŝ · c[q]. In practice, libraries handle this transparently.
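As a rough sketch of the idea, the snippet below double‑quantizes scales with a plain 8‑bit absmax quantizer over blocks. The real QLoRA recipe quantizes FP32 constants to 8‑bit floats with a block size of 256, so the helper names and the linear 8‑bit code here are illustrative only.

import torch

def double_quantize_scales(scales, block=256):
    # scales: 1D FP32 tensor of positive per-group scales; numel divisible by block
    blocks = scales.view(-1, block)
    s2 = blocks.amax(dim=1, keepdim=True)                 # second-level FP32 constants
    q_s = torch.round(blocks / s2 * 255).to(torch.uint8)  # 8-bit code per scale
    return q_s, s2.squeeze(1)

def dequantize_scales(q_s, s2):
    # reconstruct the per-group scales from 8-bit codes + second-level constants
    return (q_s.float() / 255) * s2.unsqueeze(1)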
During forward passes, matrix multiplications operate on dequantized values on‑the‑fly (often in fused kernels):

Ŵ = s ⊙ gather(c, Q),   y = x Ŵᵀ

where Q stores the 4‑bit codes, gather(c, Q) looks up codebook entries, s broadcasts per group/channel, and ⊙ is elementwise scaling. Fused dequant‑GEMM kernels avoid materializing Ŵ in global memory and minimize memory traffic.
Since the base weights are frozen, only the adapter parameters need optimizer state. Paged Adam(W) further reduces peak GPU memory by backing that state with paged buffers that spill to CPU memory and are pulled back to the GPU on demand, absorbing the memory spikes caused by long sequences and large batch sizes.
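A minimal sketch of wiring this up with bitsandbytes' paged optimizer classes (assuming bitsandbytes is installed; with the Hugging Face Trainer, optim="paged_adamw_8bit" or "paged_adamw_32bit" selects the same behavior):

import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096)  # stand-in for a quantized model carrying LoRA adapters
trainable = [p for p in model.parameters() if p.requires_grad]  # adapters only in QLoRA
optimizer = bnb.optim.PagedAdamW8bit(trainable, lr=2e-4, weight_decay=0.0)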
| Component | Precision | Approx. Memory | Notes |
|---|---|---|---|
| Base weights (frozen) | NF4 (4‑bit) + scales | ≈ 0.5 bytes/param + 0.05–0.15 bytes/param | Codebook indices + per‑group scales |
| LoRA A,B | FP16/BF16 | 2·d·r per d×d layer (×2 bytes) | Trainable; often r∈[4,16] |
| Optimizer state (adapters) | FP32 (CPU paged) | ≈ 2× adapter params (m,v) | Paged to CPU to save VRAM |
| KV cache (inference) | FP16/BF16 | O(L·n_layers·d) | Dominates long‑context inference |
For a 7B model, 4‑bit weights are ~3.5 GB plus ~0.4–1.0 GB for scales/metadata depending on grouping, leaving room for adapters and context on a 24 GB GPU.
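A quick back‑of‑the‑envelope check of those numbers (illustrative; assumes one scale per group of 64 weights):

params = 7e9
weights_gb = params * 0.5 / 1e9       # 4-bit codes ≈ 0.5 bytes/param       -> ~3.5 GB
scales_fp32 = params / 64 * 4 / 1e9   # one FP32 scale per 64 weights        -> ~0.44 GB
scales_dq = params / 64 * 1 / 1e9     # ~1 byte per scale after double quant -> ~0.11 GB
print(weights_gb, scales_fp32, scales_dq)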
import torch
import torch.nn as nn

# Approximate NF4 codebook: 16 sorted values on [-1, 1] (rounded; see the
# bitsandbytes source for the exact table). Default group size is 64.
c = torch.tensor([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
                  0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0])

def quantize_nf4(x, group_size=64):
    # x: 1D tensor of weights (FP32); numel must be divisible by group_size
    groups = x.reshape(-1, group_size)
    # NF4 uses absmax scaling so standardized values land in [-1, 1]
    scales = groups.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    x_hat = groups / scales  # standardize
    # nearest code index per element
    q = (x_hat.unsqueeze(-1) - c.view(1, 1, -1)).abs().argmin(dim=-1).to(torch.uint8)
    return q, scales.squeeze(1)  # store q (4-bit codes) + per-group scale

def dequantize_nf4(q, scales):
    # look up codes, rescale per group, and return a flat FP32 tensor
    return (c[q.long()] * scales.unsqueeze(1)).reshape(-1)
# LoRA on a quantized base (NF4 codes frozen, adapters trainable)
class QLoRALinear(nn.Module):
    def __init__(self, in_f, out_f, r=8, alpha=16, base_weight=None):
        super().__init__()
        # Frozen base weight stored as NF4 codes + per-group scales
        if base_weight is None:
            base_weight = 0.02 * torch.randn(out_f, in_f)  # placeholder weight
        q_codes, scales = quantize_nf4(base_weight.reshape(-1))
        self.register_buffer("q_codes", q_codes)
        self.register_buffer("scales", scales)
        self.in_f, self.out_f = in_f, out_f
        self.A = nn.Parameter(torch.zeros(r, in_f))
        self.B = nn.Parameter(torch.zeros(out_f, r))
        nn.init.kaiming_uniform_(self.A, a=5 ** 0.5)
        nn.init.zeros_(self.B)
        self.scaling = alpha / r

    def forward(self, x):
        # Dequantize the frozen base weight (fused with the matmul in practice)
        W = dequantize_nf4(self.q_codes, self.scales).view(self.out_f, self.in_f)
        y = x @ W.T
        y = y + self.scaling * (x @ self.A.T @ self.B.T)  # low-rank update, scaled by alpha/r
        return y
Real implementations fuse dequant and matmul in custom kernels and keep q_codes/scales in compact packed layouts (e.g., two 4-bit codes per byte), rather than one code per byte as in the sketch above.
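A quick shape-level check of the sketch above (not performance-realistic; the dimensions are chosen so the weight count divides the group size):

layer = QLoRALinear(in_f=512, out_f=256, r=8)
x = torch.randn(2, 512)
y = layer(x)  # shape (2, 256)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # only A and B train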
Quantization introduces approximation error. For NF4, the non‑uniform codebook c concentrates bins near 0, reducing MSE for weight distributions that are approximately zero‑mean Gaussian.
The best scale s for a group is the one that minimizes reconstruction MSE within that group. Finer scale granularity is more faithful but carries more metadata, and channels with outliers benefit most from per-channel or finer scaling, which keeps a few large values from stretching the quantization range for the rest of the group.
Scales can be stored per‑tensor, per‑row, per‑channel, or per‑group of G weights (e.g., G=64). Smaller groups → better fidelity but more scale metadata.
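To see the fidelity/metadata trade-off numerically, quantize_nf4 and dequantize_nf4 from above can be run on synthetic Gaussian weights at several group sizes (a sketch; absolute errors depend on the codebook and data):

import torch

torch.manual_seed(0)
w = torch.randn(1 << 20)   # synthetic ~N(0, 1) weights
for g in (32, 64, 128, 256):
    q, s = quantize_nf4(w, group_size=g)
    mse = (w - dequantize_nf4(q, s)).pow(2).mean().item()
    scale_bits = 32 / g    # FP32 scale overhead, in bits per weight
    print(f"group={g:4d}  mse={mse:.5f}  scale_overhead={scale_bits:.2f} bits/weight")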
Efficient inference/training fuses dequantization into matmul kernels to reduce memory traffic. Tiles of codes are loaded, dequantized into registers/shared memory, then multiplied with activation tiles.
# Pseudocode of a fused dequant-GEMM kernel loop
for tile in weight_codes.tiles():
    q = global_load(tile)                               # load a tile of 4-bit codes
    w_tile = dequantize(codebook, q, scales_for(tile))  # into shared memory/registers
    y_tile += matmul(w_tile, x_tile)                    # accumulate the partial product
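The same tiling can be emulated at the PyTorch level to check the decomposition; a real kernel would dequantize in shared memory/registers, so the per-tile storage below is only meant to mirror the loop structure:

import torch

out_f, in_f, tile = 256, 512, 128
W = torch.randn(out_f, in_f)

# Pre-quantize W in column tiles, mimicking a tiled storage layout
tiles = []
for start in range(0, in_f, tile):
    q, s = quantize_nf4(W[:, start:start + tile].reshape(-1))
    tiles.append((start, q, s))

x = torch.randn(4, in_f)
y = torch.zeros(4, out_f)
for start, q, s in tiles:                            # load one tile of codes
    w_tile = dequantize_nf4(q, s).view(out_f, tile)  # dequantize into a small buffer
    y += x[:, start:start + tile] @ w_tile.T         # accumulate the partial product

Looping over column tiles of Ŵ reproduces x Ŵᵀ exactly, up to floating-point accumulation order.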
Adapters are commonly applied to attention projections (Q and V) and optionally to MLPs. With QLoRA, the base QKV/O weights are quantized; adapters are FP16/BF16.
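A typical end-to-end configuration with the Hugging Face transformers and peft libraries looks roughly like the sketch below; the checkpoint name is just an example, and target module names vary by architecture:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",           # NF4 base weights
    bnb_4bit_use_double_quant=True,      # double-quantized scales
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",          # example checkpoint
    quantization_config=bnb_cfg,
)
lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
                      target_modules=["q_proj", "v_proj"],  # attention Q and V
                      task_type="CAUSAL_LM")
model = get_peft_model(model, lora_cfg)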
Assume a 7B model (7×10⁹ params), NF4 with per-group scales (~0.6 bytes/param effective), and LoRA rank r=8 on attention and MLP layers covering ~70% of the parameters (approximated as square d×d blocks).
Typical adapter memory ranges from tens of MB to a few hundred MB in FP16, depending on the rank and which layers are targeted. Optimizer states (paged) live mostly in CPU memory, adding two FP32 states (m, v) per adapter parameter. This easily fits on a 24 GB GPU with ample headroom for the KV cache at long contexts.
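A back-of-the-envelope version of that budget, assuming a LLaMA-7B-like geometry (32 layers, d_model = 4096, MLP width 11008); the numbers are illustrative and scale roughly linearly with the rank r:

r, layers, d, d_ff = 8, 32, 4096, 11008
attn = 4 * r * (d + d)                     # q, k, v, o projections
mlp = 2 * r * (d + d_ff) + r * (d_ff + d)  # gate, up, down projections
adapter_params = layers * (attn + mlp)
print(adapter_params / 1e6, "M adapter params")                    # ≈ 20 M
print(adapter_params * 2 / 1e6, "MB of adapters in FP16")          # ≈ 40 MB
print(adapter_params * 8 / 1e6, "MB of Adam m+v in FP32 (paged)")  # ≈ 160 MB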