Policy improvement with KL‑constrained natural gradient
TRPO solves this trust-region problem with a second-order approximation: it linearizes the surrogate objective, replaces the KL constraint by its quadratic form, and computes the resulting natural-gradient step with conjugate gradients.
A backtracking line search then ensures the exact KL constraint holds empirically on the batch.
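For reference, the second-order subproblem the natural-gradient step solves, and its closed-form solution, are

\[
\max_{\Delta}\ g^{\top}\Delta
\quad\text{s.t.}\quad \tfrac{1}{2}\,\Delta^{\top} H \Delta \le \delta,
\qquad
\Delta^{*} \;=\; \sqrt{\tfrac{2\delta}{g^{\top}H^{-1}g}}\; H^{-1}g ,
\]

where g = ∇_θ L(θ) and H is the Hessian of the KL at the current θ (the Fisher information matrix). With x = H⁻¹g, the step scaling √(2δ / (x^T H x)) used below is the same quantity, since g^T H⁻¹ g = x^T H x.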
We never form H (the Fisher information matrix, i.e. the Hessian of the KL at the current θ) explicitly. Instead, we compute Fisher-vector products Hv via automatic differentiation through the KL and use conjugate gradients to solve Hx = g; 10–20 CG iterations typically suffice.
# Conjugate gradient for Hx = g using only Fisher-vector products.
# fisher_vec(v) returns Hv, where H is the Hessian of the KL at the current θ
# (a PyTorch construction is sketched after this block).
import numpy as np

def conjugate_gradient(fisher_vec, g, K=10, tol=1e-10):
    x = np.zeros_like(g)
    r = g.copy()                       # residual r = g - Hx with x = 0
    p = r.copy()
    rs = r @ r
    for _ in range(K):
        Ap = fisher_vec(p)
        alpha = rs / (p @ Ap + 1e-8)   # small constant guards against division by zero
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x
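As a minimal sketch of the Fisher-vector product itself, built with a double backward pass in PyTorch: differentiate the KL once, dot the gradient with v, and differentiate that scalar again. The names policy, policy.dist(obs) (returning a torch.distributions object), and obs are assumptions for illustration; the damping term is the usual CG stabilizer.

import torch

def make_fisher_vec(policy, obs, damping=1e-2):
    params = list(policy.parameters())
    with torch.no_grad():
        old_dist = policy.dist(obs)    # fixed reference distribution π_θ_old

    def fisher_vec(v):
        v_t = torch.as_tensor(v, dtype=torch.float32)
        kl = torch.distributions.kl_divergence(old_dist, policy.dist(obs)).mean()
        # First backward: ∇_θ KL, keeping the graph so we can differentiate again.
        grads = torch.autograd.grad(kl, params, create_graph=True)
        flat_grad = torch.cat([gr.reshape(-1) for gr in grads])
        # Second backward on (∇_θ KL · v) yields the Hessian-vector product Hv.
        hvp = torch.autograd.grad((flat_grad * v_t).sum(), params)
        flat_hvp = torch.cat([h.reshape(-1) for h in hvp])
        return (flat_hvp + damping * v_t).detach().numpy()

    return fisher_vec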
1. Compute the surrogate gradient g = ∇_θ L(θ).
2. Solve Hx = g approximately via conjugate gradients, using Fisher-vector products in place of H.
3. Scale the step: step = √(2δ / (x^T H x)) * x, so the quadratic KL model equals δ (x^T H x is just x^T fisher_vec(x)).
4. Backtracking line search: shrink the step until KL(π_θ || π_{θ+step}) ≤ δ on the batch and L improves.
5. Apply the accepted step to θ (a sketch of the full update follows).
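A minimal sketch of this outer loop, reusing conjugate_gradient and fisher_vec from above. The helpers get_flat_params, set_flat_params, and surrogate_and_kl (returning L and the batch KL at a given flat θ) are assumptions for illustration, and δ = 0.01 is a typical trust-region size.

def trpo_update(get_flat_params, set_flat_params, surrogate_and_kl, g, fisher_vec,
                delta=0.01, backtrack_ratio=0.5, max_backtracks=10):
    # Natural-gradient direction x ≈ H⁻¹g, scaled so that ½ step^T H step = δ.
    x = conjugate_gradient(fisher_vec, g)
    step = np.sqrt(2.0 * delta / (x @ fisher_vec(x) + 1e-8)) * x
    theta_old = get_flat_params()
    L_old, _ = surrogate_and_kl(theta_old)
    for i in range(max_backtracks):
        theta_new = theta_old + (backtrack_ratio ** i) * step
        L_new, kl = surrogate_and_kl(theta_new)
        if kl <= delta and L_new > L_old:   # accept: KL constraint holds and L improves
            set_flat_params(theta_new)
            return True
    set_flat_params(theta_old)              # no acceptable step; keep old parameters
    return False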