QLoRA Internals

NF4 quantization, double quant, dequant math, and paged optimizers

Why QLoRA

QLoRA enables parameter‑efficient fine‑tuning on a 4‑bit quantized base model while training small LoRA adapters in higher precision. The base weights are frozen and stored in a compact 4‑bit format; the adapters capture the task‑specific updates. Coupled with paged optimizers, this allows single‑GPU fine‑tuning of multi‑billion‑parameter models with minimal quality loss.

NF4 Quantization

NormalFloat4 (NF4) is a 4‑bit, non‑uniform quantizer whose 16 codebook values approximate the cumulative distribution of standardized weights (roughly Gaussian). Rather than linear steps, NF4 places more levels near 0 to preserve small weights.

x ∈ ℝ → \hat{x} = s·c[q] + b,   q ∈ {0,…,15}

Here c[q] is a fixed 16‑entry codebook on standardized space, s is a per‑group or per‑channel scale, and b is an optional bias/offset (often 0). Quantization selects the nearest code:

q(x) = \arg\min_{q} | (x − b)/s − c[q] |

Grouping choices (e.g., 64 weights per group) trade off fidelity and metadata overhead. Per‑channel scales give higher fidelity but slightly larger metadata.
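
As a minimal illustration of the nearest‑code rule, the following sketch quantizes one small group against a toy 4‑level codebook (a stand‑in for the real 16‑entry NF4 table, with b = 0 and an absmax scale chosen just for the example):

import torch

# Toy 4-level codebook (stand-in for the 16-entry NF4 table, just to show the rule)
codebook = torch.tensor([-1.0, -0.3, 0.3, 1.0])

x = torch.tensor([0.02, -0.15, 0.40, -0.90])          # one small group of weights
s = x.abs().max()                                      # per-group scale (absmax here; b = 0)
q = (x / s - codebook.view(-1, 1)).abs().argmin(dim=0) # nearest code index per weight
x_hat = s * codebook[q]                                # dequantized approximation
print(q.tolist(), x_hat.tolist())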

Double Quantization

To reduce overhead further, QLoRA “double‑quantizes” the scales: instead of storing each per‑group scale s in full precision (FP32 in the original paper), it stores an 8‑bit quantized scale ŝ together with a small set of second‑level constants shared across blocks of scales. At G=64 this cuts scale memory from roughly 0.5 to roughly 0.13 bits per parameter, with negligible quality loss.

s ≈ s₀ · d[k],   k ∈ {0,…,255}

At runtime, dequant uses the small codebook d and reconstructs s before reconstructing x̂. In practice, libraries handle this transparently.
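
A rough sketch of the idea, using a simple linear 8‑bit quantizer for the scales rather than the codebook‑based scheme described above (the block size of 256 is an assumption):

import torch

def double_quantize_scales(scales, block=256):
    # scales: 1D FP32 tensor of positive per-group scales; length divisible by block
    blocks = scales.view(-1, block)
    s0 = blocks.max(dim=1, keepdim=True).values / 255.0       # shared FP32 constant per block
    k = (blocks / s0).round().clamp(0, 255).to(torch.uint8)   # 8-bit scale codes
    return k, s0.squeeze(1)

def dequantize_scales(k, s0):
    return k.float() * s0.unsqueeze(1)     # s ≈ s0 · k within each block

scales = torch.rand(1024) * 0.05 + 1e-3    # e.g., 1024 per-group scales
k, s0 = double_quantize_scales(scales)
recon = dequantize_scales(k, s0).flatten()
print(float((recon - scales).abs().max())) # small reconstruction error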

Dequantization Math

During forward passes, matrix multiplications operate on dequantized values on‑the‑fly (often in fused kernels):

\hat{W} = s ⊙ gather(c, Q) + b,   y = \hat{W}x

Where Q stores 4‑bit codes, gather(c,Q) looks up codebook entries, s broadcasts per group/channel, and ⊙ is elementwise scaling. Fused dequant‑GEMM minimizes memory traffic.
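
In tensor terms, a non‑fused reference of this reconstruction (shapes are illustrative, b is taken as 0, and the codebook here is a random stand‑in):

import torch

G, n_groups, d_in = 64, 128, 32            # illustrative sizes: 8192 weights total
codebook = torch.randn(16).sort().values   # stand-in 16-entry codebook
Q = torch.randint(0, 16, (n_groups, G))    # 4-bit codes, one row per group
s = torch.rand(n_groups, 1)                # per-group scales (broadcast across each group)

W_hat = (s * codebook[Q]).view(-1, d_in)   # gather + rescale, then reshape to (out, in)
x = torch.randn(d_in)
y = W_hat @ x                              # matmul on the dequantized weight
print(W_hat.shape, y.shape)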

Paged Optimizers

Since base weights are frozen, only adapter parameters need optimizer state. Paged Adam(W) further reduces GPU memory pressure by allocating optimizer state in paged memory that can spill to CPU RAM and is pulled back to the GPU on demand. This guards against out‑of‑memory spikes with long sequences and large batch sizes.
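
A usage sketch, assuming the bitsandbytes library is installed (the exact optimizer class name varies by version; PagedAdamW8bit is one commonly available variant):

import torch.nn as nn
import bitsandbytes as bnb   # assumed installed; provides paged optimizer variants

# Only the small trainable adapter gets optimizer state; the frozen 4-bit base
# contributes none. The paged optimizer keeps Adam's m/v in paged memory that can
# spill to CPU RAM and is fetched back to the GPU on demand.
adapter = nn.Linear(4096, 8, bias=False)     # stand-in for a LoRA A matrix
optimizer = bnb.optim.PagedAdamW8bit(adapter.parameters(), lr=2e-4)

With the Hugging Face Trainer, the same behavior is typically selected through the optimizer option (a “paged_adamw” variant) rather than constructed by hand.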

Memory Accounting

Component                   | Precision             | Approx. Memory                                     | Notes
Base weights (frozen)       | NF4 (4‑bit) + scales  | ≈0.5 bytes/param + 0.05–0.15 bytes/param (scales)  | Codebook indices + per‑group scales
LoRA A, B                   | FP16/BF16             | 2·d·r params per d×d layer (×2 bytes each)         | Trainable; often r ∈ [4, 16]
Optimizer state (adapters)  | FP32 (CPU, paged)     | ≈2× adapter params (m, v)                          | Paged to CPU to save VRAM
KV cache (inference)        | FP16/BF16             | O(L·n_layers·d)                                    | Dominates long‑context inference

For a 7B model, 4‑bit weights are ~3.5 GB plus ~0.4–1.0 GB for scales/metadata depending on grouping, leaving room for adapters and context on a 24 GB GPU.

Interactive: Memory Calculator

[Interactive calculator: inputs are model size in billions of parameters, base bytes/param (NF4 ≈0.5–0.6), scale bytes/param (double‑quant ≈0.05–0.15), the fraction of weights targeted by adapters, LoRA rank r, hidden size d, and KV‑cache size at the target context; outputs are approximate base, scales/metadata, adapter, KV‑cache, and total GPU memory, with optimizer state paged off‑GPU.]
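
A plain‑Python version of the same estimate (mirroring the calculator’s inputs; the default rates and the square d×d approximation are assumptions, not measurements):

def qlora_memory_gb(params_b, base_bytes=0.6, scale_bytes=0.1,
                    targeted_frac=0.7, r=8, d=4096, kv_cache_gb=2.0):
    # Rough GPU-memory estimate in GB; defaults mirror the calculator's hints.
    n = params_b * 1e9
    base = n * base_bytes / 1e9                   # NF4 codes
    meta = n * scale_bytes / 1e9                  # scales + double-quant metadata
    # approximate targeted weights as square d x d blocks, each with LoRA A and B
    n_blocks = targeted_frac * n / (d * d)
    adapters = n_blocks * 2 * d * r * 2 / 1e9     # 2·d·r params per block, 2 bytes (FP16)
    total = base + meta + adapters + kv_cache_gb  # optimizer state is paged off-GPU
    return {"base": base, "meta": meta, "adapters": adapters,
            "kv_cache": kv_cache_gb, "total_gpu": total}

print(qlora_memory_gb(7))    # e.g., a 7B model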

Training Pipeline (Diagram)

[Diagram] Input tokens → Dequant + GEMM (NF4) → LoRA BA (FP16) → Output activations → Paged AdamW (CPU)
Fused dequant‑GEMM runs on the frozen 4‑bit base; LoRA adapters add residual updates; paged optimizers keep adapter states off‑GPU.

NF4 Quant/Dequant Sketch

import math
import torch
import torch.nn as nn

# Stand-in 16-entry codebook c (sorted): Gaussian quantiles at evenly spaced
# probabilities. Real NF4 uses a fixed table normalized to [-1, 1]; this keeps
# the sketch self-contained. Group size G defaults to 64.
p = torch.linspace(1 / 32, 31 / 32, 16)
c = math.sqrt(2) * torch.erfinv(2 * p - 1)

def quantize_nf4(x, group_size=64):
    # x: 1D tensor of weights (FP32); length assumed divisible by group_size
    groups = x.view(-1, group_size)
    scales = groups.std(dim=1, keepdim=True).clamp_min(1e-8)  # per-group scale
    x_hat = groups / scales  # standardize each group
    # nearest code index per element (kept in uint8 here; packed to 4 bits in real code)
    q = (x_hat.unsqueeze(-1) - c.view(1, 1, -1)).abs().argmin(dim=-1).to(torch.uint8)
    return q, scales.squeeze(1)  # store q codes + per-group scales

def dequantize_nf4(q, scales):
    # codebook lookup, then rescale each group
    return c[q.long()] * scales.unsqueeze(1)

# LoRA on a quantized base layer
class QLoRALinear(nn.Module):
    def __init__(self, in_f, out_f, r=8, alpha=16):
        super().__init__()
        self.in_f, self.out_f = in_f, out_f
        # Stand-in: quantize a freshly initialized weight; real code quantizes the
        # pretrained weight once and keeps the codes/scales frozen.
        w0 = torch.empty(out_f, in_f).normal_(std=0.02)
        q_codes, scales = quantize_nf4(w0.flatten())
        self.register_buffer("q_codes", q_codes)  # frozen NF4 codes
        self.register_buffer("scales", scales)    # frozen per-group scales
        self.A = nn.Parameter(torch.zeros(r, in_f))
        self.B = nn.Parameter(torch.zeros(out_f, r))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.zeros_(self.B)  # B starts at zero so the adapter is initially a no-op
        self.scaling = alpha / r

    def forward(self, x):
        # dequantize the frozen base weight (fused with the matmul in practice)
        W = dequantize_nf4(self.q_codes, self.scales).view(self.out_f, self.in_f)
        y = x @ W.T
        y = y + self.scaling * (x @ self.A.T @ self.B.T)  # LoRA residual update
        return y

Real implementations fuse dequant and matmul in custom kernels and keep q_codes/scales in compact layouts.
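
A quick round trip with this sketch, to get a feel for the size of the approximation error (numbers are illustrative only):

w = torch.randn(1024 * 1024) * 0.02          # fake, roughly Gaussian layer weights
q, scales = quantize_nf4(w)
w_hat = dequantize_nf4(q, scales).flatten()
print("relative error:", float((w_hat - w).norm() / w.norm()))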

Tips & Pitfalls

Quantization Error & Fidelity

Quantization introduces approximation error. For NF4, the non‑uniform codebook c concentrates bins near 0, reducing MSE for weight distributions that are approximately zero‑mean Gaussian.

\operatorname{MSE}(s) = \mathbb{E}\big[ (x - s·c[q(x/s)])^2 \big],\; q(z)=\arg\min_j |z - c[j]|\,,\; x \sim \mathcal{N}(0,\sigma^2)

The best scale s is the one that minimizes this MSE within its group. Empirically, per‑channel scales are the most faithful but carry slightly more metadata than per‑group scales; channels with outliers benefit most from per‑channel scaling.

[Figure: NF4 codebook density (schematic); codebook levels concentrated near 0.]
NF4 places more quantization levels near zero, reducing MSE for typical weight distributions.
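
The MSE‑minimizing scale can be found numerically. A brute‑force sketch over a single group, reusing the codebook c and helpers from the sketch above (the candidate grid is arbitrary):

x = torch.randn(64) * 0.02                           # one group of weights
candidates = torch.linspace(0.5, 2.0, 50) * x.std()  # candidate scales around the std

def group_mse(s):
    q = (x / s - c.view(-1, 1)).abs().argmin(dim=0)  # nearest code per weight
    return ((x - s * c[q]) ** 2).mean()

best_s = min(candidates, key=group_mse)
print(float(best_s), float(group_mse(best_s)))

In practice libraries use closed‑form absmax or std‑based scales rather than a search; the search only makes the objective concrete.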

Grouping Strategies & Metadata

Scales can be stored per‑tensor, per‑row, per‑channel, or per‑group of G weights (e.g., G=64). Smaller groups → better fidelity but more scale metadata.

[Diagram: a weight stream split into groups of G=64, each with its own per‑group scale; the scales feed a second, double‑quantization stage.]
Groups of G weights share a scale; scales themselves can be quantized (double quantization).
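
Back‑of‑the‑envelope metadata cost for a few group sizes, assuming FP32 first‑level scales and, for double quantization, 8‑bit scale codes plus one FP32 constant per 256 scales (the 256 block size is an assumption):

for G in (32, 64, 128, 256):
    fp32_scales = 32 / G                    # bits per weight, one FP32 scale per group
    double_quant = (8 + 32 / 256) / G       # 8-bit code per scale + shared FP32 constants
    print(f"G={G:3d}  fp32 scales: {fp32_scales:.3f} bits/param   "
          f"double-quant: {double_quant:.3f} bits/param")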

Fused Dequant‑GEMM (Kernel View)

Efficient inference/training fuses dequantization into matmul kernels to reduce memory traffic. Tiles of codes are loaded, dequantized into registers/shared memory, then multiplied with activation tiles.

Y_tile += (S ⊙ C[Q_tile]) · X_tile,   streaming softmax/accumulation for attention blocks
[Diagram: NF4 code tiles in global memory → dequantized tile in shared memory → GEMM with the activation tile; tiles stream left→right within the kernel.]
Kernel loads NF4 code tiles, dequantizes into fast memory, and multiplies with activation tiles.
# Pseudocode of fused kernel loop
for tile in weight_codes.tiles():
    q = global_load(tile)
    w_tile = dequantize(codebook, q, scales_for(tile))  # shared/registers
    y_tile += matmul(w_tile, x_tile)
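
A non‑fused PyTorch analogue of that loop, reusing the quantize/dequantize helpers from the sketch above (tiles here are just blocks of output rows; a real kernel stages them through shared memory and registers):

def dequant_matmul_tiled(q, scales, x, out_f, in_f, tile_rows=64):
    # q: (n_groups, G) codes; scales: (n_groups,); logical weight shape is (out_f, in_f)
    G = q.shape[1]
    groups_per_row = in_f // G              # assumes in_f is a multiple of G
    y = torch.zeros(out_f)
    for r0 in range(0, out_f, tile_rows):
        r1 = min(r0 + tile_rows, out_f)
        g0, g1 = r0 * groups_per_row, r1 * groups_per_row
        # "load" a tile of codes, dequantize it, multiply with the activation vector
        w_tile = dequantize_nf4(q[g0:g1], scales[g0:g1]).view(r1 - r0, in_f)
        y[r0:r1] = w_tile @ x
    return y

# sanity check against full dequantization
out_f, in_f = 256, 512
q, scales = quantize_nf4(torch.randn(out_f * in_f) * 0.02)
x = torch.randn(in_f)
print(torch.allclose(dequant_matmul_tiled(q, scales, x, out_f, in_f),
                     dequantize_nf4(q, scales).view(out_f, in_f) @ x, atol=1e-5))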

LoRA Placement in Transformers

Adapters are commonly applied to attention projections (Q and V) and optionally to MLPs. With QLoRA, the base QKV/O weights are quantized; adapters are FP16/BF16.

[Diagram: W_q, W_k, W_v, W_o stored in NF4; LoRA adapters B_q A_q on Q and B_v A_v on V; LoRA on K and O optional.]
Quantized base projections with FP16 adapters on Q and V (common), K optional, O sometimes.
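
In the Hugging Face stack this placement is usually expressed in configuration rather than custom modules. A sketch assuming the transformers and peft APIs (argument names and the example model id are assumptions that may differ across versions):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 base weights
    bnb_4bit_use_double_quant=True,         # double-quantized scales
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute dtype for dequantized matmuls
)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             quantization_config=bnb_config)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],    # adapters on Q and V projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

target_modules controls exactly which projections receive adapters; adding K, O, or the MLP projections trades a little memory for capacity.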

Worked Memory Example (7B)

Assume a 7B model (7×10⁹ params), NF4 with per‑group scales (~0.6 bytes/param effective), LoRA rank r=8 on attention and MLP layers covering ~70% of params with square d×d blocks.

Base ≈ 7×10⁹ × 0.6 bytes ≈ 4.2 GB;   Adapters ≈ 2·d·r params per targeted d×d block (×2 bytes in FP16), summed over the ~70% of weights that are targeted.

Adapter memory typically ranges from tens of MB (low rank, attention‑only) to a few hundred MB (higher rank, broader targeting) in FP16. Optimizer state (paged) adds Adam’s m and v in FP32 per adapter parameter, kept mostly off‑GPU. This easily fits on a 24 GB GPU with ample KV cache left for long contexts.

[Chart: rough memory split on a 24 GB GPU: base ≈4.2 GB, scales/metadata, adapters, and the remaining headroom for KV cache.]
Rough memory split on a 24 GB GPU: base weights, metadata, adapters, and KV cache.
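
To make the adapter figure concrete, here is the arithmetic under one assumed layout (a LLaMA‑7B‑style model: 32 layers, d=4096, MLP width 11008, adapters on all attention and MLP projections; the dimensions are assumptions for illustration):

layers, d, d_ff, r, fp16_bytes = 32, 4096, 11008, 8, 2

attn = 4 * r * (d + d)            # q, k, v, o projections: r * (in + out) params each
mlp = 3 * r * (d + d_ff)          # gate, up, down projections
params = layers * (attn + mlp)
print(f"{params / 1e6:.1f}M adapter params, {params * fp16_bytes / 1e6:.0f} MB in FP16")
# ≈20M params (≈40 MB) at r=8; adapter memory scales linearly with r, so ranks in
# the 16–64 range land in the few-hundred-MB region mentioned above.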

Training Stability & Scaling

Validation Checklist