Unifying preference learning with policy optimization
Preference‑based training aligns a model to human choices using paired or grouped responses. GRPO generalizes direct preference optimization by weighting multiple candidates and optionally constraining updates via KL regularization or clipping, combining DPO‑style preference losses with PPO‑style update stability.
Given a prompt x and a set of K responses {y_i} from a language model (policy) πθ(y|x), we want the policy to place higher probability mass on preferred responses. We denote a frozen reference policy πref (often the base SFT model) that anchors updates.
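The loop later in this section needs per‑sequence log‑probabilities under both policies. Here is a minimal sketch of that computation, assuming a Hugging Face‑style causal LM; the helper name sequence_logprob, the (prompt + response) token layout, and the prompt_len argument are illustrative choices, not a fixed API:

import torch
import torch.nn.functional as F

def sequence_logprob(model, input_ids, attention_mask, prompt_len):
    # Sum of log pi(y_t | x, y_<t) over response tokens only; also return the token count.
    # prompt_len: (B,) LongTensor with the number of prompt tokens per example.
    logits = model(input_ids, attention_mask=attention_mask).logits   # (B, L, V)
    logps = F.log_softmax(logits[:, :-1], dim=-1)                     # position t predicts token t+1
    targets = input_ids[:, 1:]
    token_logp = logps.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (B, L-1)
    # Keep only response positions (drop prompt tokens and padding).
    pos = torch.arange(targets.size(1), device=input_ids.device).unsqueeze(0)
    resp_mask = (pos >= (prompt_len - 1).unsqueeze(1)) & attention_mask[:, 1:].bool()
    return (token_logp * resp_mask).sum(-1), resp_mask.sum(-1)        # (B,), (B,)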
To avoid length bias, use a length‑normalized score s(y|x) = (1/T)·log πθ(y|x), where T is the number of response tokens, or add an explicit length penalty. We also define a scaled, reference‑relative score used by many preference losses: s_i = β·( (1/T_i)·log πθ(y_i|x) − (1/T_i)·log πref(y_i|x) ), with scale β > 0.
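A minimal sketch of this score, reusing the illustrative sequence_logprob helper above; β = 0.1 is an arbitrary default, not a recommendation:

def relative_scores(policy, ref_policy, input_ids, attention_mask, prompt_len, beta=0.1):
    # Length-normalized log-probs under the current policy and the frozen reference.
    logp, n_tok = sequence_logprob(policy, input_ids, attention_mask, prompt_len)
    with torch.no_grad():
        logp_ref, _ = sequence_logprob(ref_policy, input_ids, attention_mask, prompt_len)
    n = n_tok.clamp(min=1)
    return beta * (logp / n - logp_ref / n), logp, logp_ref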
Weights w_i can be derived from preference labels (e.g., +1/−1), Bradley‑Terry scores, or a learned reward model. A trust‑region flavor can be added by clipping probability ratios, as in PPO.
Different preference assumptions induce different weights (see the sketch after this list):
- Pairwise labels (DPO‑style pairs): w_i = +1 for the chosen response and −1 for the rejected one.
- Listwise Bradley‑Terry: w_i is a softmax over the reference‑relative scores s_i within the group.
- Learned reward model: w_i proportional to the reward r(x, y_i), e.g. centered by the group‑mean reward for an advantage‑style weight.
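A minimal sketch of these weighting schemes; the function names are illustrative, with listwise_softmax matching the name used in the training loop below:

def pairwise_weights(chosen_mask):
    # chosen_mask: (B, K) bool, True for the preferred response in each pair; gives +1 / -1 weights.
    return chosen_mask.float() * 2.0 - 1.0

def listwise_softmax(S):
    # S: (B, K) reference-relative scores; Bradley-Terry style softmax within each group of K.
    return torch.softmax(S, dim=-1)

def reward_weights(rewards):
    # rewards: (B, K) from a learned reward model; center by the group mean (advantage-style).
    return rewards - rewards.mean(dim=-1, keepdim=True)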
Define the GRPO objective (minimization form) with KL regularization: L(θ) = −E_x[ (1/K)·Σ_i w_i·log πθ(y_i|x) ] + λ·KL(πθ || πref). With the optional clipping surrogate, the weighted log‑likelihood term is replaced by the PPO‑style conservative term E_x[ (1/K)·Σ_i min( r_i·w_i, clip(r_i, 1−ε, 1+ε)·w_i ) ], where r_i = πθ(y_i|x)/πref(y_i|x), and the loss is its negative plus the KL penalty.
Gradient (ignoring the dependence of w on θ for a simple, effective estimator): ∇θ L(θ) = −E_x[ (1/K)·Σ_i w_i·∇θ log πθ(y_i|x) ] + λ·∇θ KL(πθ || πref).
Detaching w_i (stop‑gradient) avoids second‑order terms and is standard in practice.
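A minimal sketch of this estimator; the sample‑based KL term (logp − logp_ref averaged over responses drawn from πθ) is one common approximation rather than the only option, and λ = 0.05 is arbitrary:

def grpo_loss(logp, logp_ref, w, lam=0.05):
    # logp, logp_ref: (B, K) sequence log-probs; w: (B, K) preference weights.
    kl = (logp - logp_ref).mean()              # Monte Carlo estimate of KL(pi_theta || pi_ref)
    return -(w.detach() * logp).mean() + lam * kl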
Summed sequence log‑probabilities favor shorter outputs, since every additional token contributes another negative log‑probability term. Common fixes (a small demonstration follows):
- Length‑normalize: divide the sequence log‑probability by the number of response tokens, as in s(y|x) above.
- Explicit length penalty: add a term to the score that offsets the advantage of shorter responses.
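A tiny demonstration of the bias and the normalized fix; the per‑token log‑probabilities are made up for the example:

import torch

short = torch.full((3,), -1.0)   # 3-token response: sum = -3.0, mean = -1.0
long = torch.full((9,), -0.9)    # 9-token response: sum = -8.1, mean = -0.9
print(short.sum() > long.sum())    # True: summed log-probs prefer the shorter response
print(short.mean() > long.mean())  # False: length-normalized scores prefer the better long one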
Clipping limits harmful updates when πθ drifts from πref: once the probability ratio r = πθ(y|x)/πref(y|x) leaves the [1−ε, 1+ε] band, the surrogate min(r·w, clip(r, 1−ε, 1+ε)·w) stops crediting further movement in the favored direction. The plot shows this conservative surrogate as a function of r.
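A minimal sketch of that surrogate, which can also be used to regenerate the plot; ε = 0.2 is only an illustrative value:

def clipped_surrogate(r, w, eps=0.2):
    # r: probability ratio pi_theta / pi_ref; w: preference weight (positive or negative).
    return torch.minimum(r * w, torch.clamp(r, 1 - eps, 1 + eps) * w)

r = torch.linspace(0.0, 2.0, steps=9)
print(clipped_surrogate(r, torch.tensor(1.0)))   # flattens once r exceeds 1 + eps
print(clipped_surrogate(r, torch.tensor(-1.0)))  # flattens once r drops below 1 - eps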
Putting the pieces together, a PyTorch‑style sketch of the training loop (the sampling, log‑prob, and normalization helpers are left abstract, as in the original outline):

import torch
from torch.nn.utils import clip_grad_norm_

# Inputs: prompts x, frozen reference policy pi_ref, current policy pi_theta, K candidates per prompt
for batch in dataloader:
    # 1) Sample K responses per prompt (top-p, temperature) or read them from a buffer
    Y = sample_candidates(model=pi_theta, prompts=batch.x, K=K)

    # 2) Sequence log-probs under pi_theta and pi_ref (padding masked out)
    logp = logprob(pi_theta, batch.x, Y)                 # carries gradients
    with torch.no_grad():
        logp_ref = logprob(pi_ref, batch.x, Y)           # reference is frozen

    # 3) Relative scores and weights (length-normalized); detach w to avoid second-order terms
    S = beta * (normalize(logp) - normalize(logp_ref))
    w = listwise_softmax(S).detach()                     # or pairwise / heuristic weights

    # 4) Loss: clipping surrogate or weighted log-likelihood, each with a KL penalty
    kl = (logp - logp_ref).mean()                        # sample-based KL(pi_theta || pi_ref) estimate
    if use_clipping:
        r = torch.exp(logp - logp_ref)                   # probability ratio pi_theta / pi_ref
        obj = torch.minimum(r * w, torch.clamp(r, 1 - eps, 1 + eps) * w).mean() - lam * kl
        loss = -obj
    else:
        loss = -(w * logp).mean() + lam * kl

    # 5) Update theta
    loss.backward()
    clip_grad_norm_(pi_theta.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()
The same procedure, as compact pseudocode:

freeze π_ref
for each minibatch of prompts x:
    sample K candidates {y_i} ~ current π_θ (or read from a buffer)
    compute weights w_i from preferences or scores (stop-gradient through w_i)
    if using clipping:
        r_i = π_θ(y_i|x) / π_ref(y_i|x)
        obj = mean_i( min(r_i·w_i, clip(r_i, 1−ε, 1+ε)·w_i) ) − λ·KL(π_θ || π_ref)
    else:
        obj = mean_i( w_i · log π_θ(y_i|x) ) − λ·KL(π_θ || π_ref)
    take a gradient ascent step on obj with respect to θ