[Deep RL] 15. Trust Region Methods: TRPO, PPO, ACKTR

Overview

One of the key problems in policy gradient methods is step size selection. If too large, the policy can suddenly deteriorate beyond recovery; if too small, learning becomes excessively slow. Trust Region methods solve this problem by limiting the size of policy updates.

This post sets up an A2C baseline on a Roboschool-style continuous control environment, then examines the principles and implementations of PPO, TRPO, and ACKTR.


Roboschool Environment

Roboschool was an open-source alternative to MuJoCo (since deprecated in favor of PyBullet) that provides various robot control environments; the code below uses the equivalent Gymnasium MuJoCo environments such as HalfCheetah-v4:

  • HalfCheetah: Forward locomotion of a planar, two-legged cheetah-like robot
  • Hopper: Balance and movement of a single-legged jumping robot
  • Walker2D: Bipedal walking robot
  • Humanoid: Walking of a humanoid robot

import gymnasium as gym

def make_env(env_name='HalfCheetah-v4'):
    """Create a continuous-control environment and return its key dimensions."""
    env = gym.make(env_name)
    obs_size = env.observation_space.shape[0]   # observation dimension
    act_size = env.action_space.shape[0]        # action dimension
    act_limit = env.action_space.high[0]        # assumes symmetric action bounds
    return env, obs_size, act_size, act_limit
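
A quick sanity check of this helper (assuming gymnasium[mujoco] is installed; the commented values are what HalfCheetah-v4 reports):

env, obs_size, act_size, act_limit = make_env('HalfCheetah-v4')
print(obs_size, act_size, act_limit)  # 17 6 1.0
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())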

A2C Baseline

An A2C baseline with a continuous (Gaussian) policy, for comparison:

import torch
import torch.nn as nn
import numpy as np

class A2CBaseline(nn.Module):
    def __init__(self, obs_size, act_size):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_size, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
        )
        self.mu = nn.Linear(64, act_size)                   # mean of the Gaussian policy
        self.log_std = nn.Parameter(torch.zeros(act_size))  # state-independent log std
        self.value = nn.Linear(64, 1)                        # state-value head

    def forward(self, obs):
        features = self.shared(obs)
        mu = self.mu(features)
        std = self.log_std.exp()
        value = self.value(features)
        return mu, std, value

A2C is simple but sensitive to the learning rate. A single large update can collapse the policy.
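
For contrast with PPO below, here is a minimal sketch (illustrative, not taken from a specific library) of the loss such a baseline would optimize; note that nothing bounds how far a single update can move the policy:

def compute_a2c_loss(model, obs, actions, advantages, returns):
    """Vanilla actor-critic loss: the policy term is the unclipped log-prob * advantage."""
    mu, std, values = model(obs)
    dist = torch.distributions.Normal(mu, std)
    log_probs = dist.log_prob(actions).sum(dim=-1)

    policy_loss = -(log_probs * advantages).mean()            # unconstrained policy gradient
    value_loss = (returns - values.squeeze()).pow(2).mean()   # critic regression
    entropy = dist.entropy().sum(dim=-1).mean()               # exploration bonus
    return policy_loss + 0.5 * value_loss - 0.01 * entropy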


Proximal Policy Optimization (PPO)

PPO was proposed by Schulman et al. at OpenAI in 2017. It approximates TRPO's complex constrained optimization with a simple clipped objective, and thanks to its easy implementation and strong performance it has become one of the most widely used deep RL algorithms.

Clipping Objective

The core of PPO is clipping the policy ratio to prevent excessively large updates:

The policy ratio r_t is the probability of the taken action under the new policy divided by its probability under the old policy; it measures how much the policy has changed for that action.

The clipped objective takes the minimum of the plain surrogate r_t * advantage and a version in which r_t is clipped to the range [1 - epsilon, 1 + epsilon], removing any incentive to push the ratio outside that range. Typically epsilon = 0.2 is used.
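
In the notation of the PPO paper, the clipped surrogate objective is

L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}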

class PPOAgent:
    def __init__(self, obs_size, act_size, clip_epsilon=0.2,
                 lr=3e-4, gamma=0.99, lam=0.95):
        self.clip_epsilon = clip_epsilon
        self.gamma = gamma
        self.lam = lam

        self.model = A2CBaseline(obs_size, act_size)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def compute_ppo_loss(self, obs, actions, old_log_probs,
                         advantages, returns):
        """PPO clipping objective"""
        mu, std, values = self.model(obs)
        dist = torch.distributions.Normal(mu, std)

        # Log probability under current policy
        new_log_probs = dist.log_prob(actions).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1).mean()

        # Policy ratio
        ratio = torch.exp(new_log_probs - old_log_probs)

        # Clipped objective
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio,
                            1.0 - self.clip_epsilon,
                            1.0 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        # Value loss
        value_loss = (returns - values.squeeze()).pow(2).mean()

        # Total loss
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
        return loss, policy_loss.item(), value_loss.item(), entropy.item()

Full PPO Training Loop

def collect_trajectories(env, model, num_steps=2048):
    """Collect experiences from the environment"""
    obs_list, act_list, rew_list = [], [], []
    done_list, logprob_list, val_list = [], [], []

    obs = env.reset()[0]
    for _ in range(num_steps):
        obs_t = torch.FloatTensor(obs).unsqueeze(0)
        mu, std, value = model(obs_t)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)

        action_np = action.detach().numpy().flatten()
        next_obs, reward, terminated, truncated, _ = env.step(action_np)
        done = terminated or truncated

        obs_list.append(obs)
        act_list.append(action.detach().squeeze(0))
        rew_list.append(reward)
        done_list.append(done)
        logprob_list.append(log_prob.detach())
        val_list.append(value.squeeze().detach())

        obs = next_obs if not done else env.reset()[0]

    return (obs_list, act_list, rew_list,
            done_list, logprob_list, val_list)

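# compute_gae below implements the standard GAE recursion:
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lam * A_{t+1}   (A_t = delta_t at episode ends)
# Returns are then recovered as A_t + V(s_t).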
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation"""
    advantages = []
    gae = 0
    # Simplification: the value after the final collected step is bootstrapped
    # with 0, which slightly biases the last (possibly truncated) segment
    next_value = 0

    for t in reversed(range(len(rewards))):
        if dones[t]:
            delta = rewards[t] - values[t]
            gae = delta
        else:
            next_val = values[t+1] if t+1 < len(values) else next_value
            delta = rewards[t] + gamma * next_val - values[t]
            gae = delta + gamma * lam * gae
        advantages.insert(0, gae)

    advantages = torch.FloatTensor(advantages)
    returns = advantages + torch.FloatTensor([v.item() for v in values])
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns

def train_ppo(env_name='HalfCheetah-v4', total_timesteps=1000000,
              num_steps=2048, num_epochs=10, batch_size=64):
    """PPO main training function"""
    env, obs_size, act_size, _ = make_env(env_name)
    agent = PPOAgent(obs_size, act_size)

    num_updates = total_timesteps // num_steps

    for update in range(num_updates):
        # 1. Collect experiences
        data = collect_trajectories(env, agent.model, num_steps)
        obs_list, act_list, rew_list, done_list, logprob_list, val_list = data

        # 2. Compute GAE
        advantages, returns = compute_gae(rew_list, val_list, done_list)

        # Convert to tensors
        obs_t = torch.FloatTensor(np.array(obs_list))
        act_t = torch.stack(act_list)
        old_logprobs = torch.cat(logprob_list)

        # 3. Mini-batch PPO update (multiple epochs)
        dataset_size = len(obs_list)
        for epoch in range(num_epochs):
            indices = np.random.permutation(dataset_size)
            for start in range(0, dataset_size, batch_size):
                end = start + batch_size
                idx = indices[start:end]

                loss, pl, vl, ent = agent.compute_ppo_loss(
                    obs_t[idx], act_t[idx], old_logprobs[idx],
                    advantages[idx], returns[idx]
                )

                agent.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    agent.model.parameters(), 0.5
                )
                agent.optimizer.step()

        if update % 10 == 0:
            avg_reward = np.sum(rew_list) / max(sum(done_list), 1)
            print(f"Update {update}: AvgReward={avg_reward:.1f}, "
                  f"PolicyLoss={pl:.4f}, Entropy={ent:.4f}")

    return agent

Trust Region Policy Optimization (TRPO)

TRPO was proposed by Schulman et al. in 2015, before PPO, and performs policy updates under an explicit KL divergence constraint.

Mathematical Background of TRPO

TRPO solves the following optimization problem:

Objective: maximize the surrogate objective, i.e. the expected policy ratio times the advantage.
Constraint: the average KL divergence between the old and new policies must stay below a threshold delta.
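
In LaTeX notation:

\max_\theta\ \hat{\mathbb{E}}_t\!\left[\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}\,\hat{A}_t\right]
\quad \text{subject to} \quad
\hat{\mathbb{E}}_t\!\left[\mathrm{KL}\big(\pi_{\theta_{\mathrm{old}}}(\cdot \mid s_t)\,\|\,\pi_\theta(\cdot \mid s_t)\big)\right] \le \delta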

It solves this constrained problem efficiently with the conjugate gradient method (to approximate the natural gradient direction) followed by a backtracking line search.

def conjugate_gradient(Avp_fn, b, num_steps=10, residual_tol=1e-10):
    """Conjugate gradient algorithm.
    Approximately solves A x = b where A is only available through
    matrix-vector products, so the Fisher matrix is never formed explicitly.
    Avp_fn: function that computes the (Fisher/Hessian)-vector product
    b: right-hand side vector (the policy gradient)
    """
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)

    for _ in range(num_steps):
        Avp = Avp_fn(p)
        alpha = rdotr / (torch.dot(p, Avp) + 1e-8)
        x += alpha * p
        r -= alpha * Avp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr

    return x

def hessian_vector_product(model, obs, old_dist, vector, damping=0.1):
    """Compute the product of the Fisher information matrix and a vector,
    via the Hessian of the KL divergence (the matrix is never formed explicitly)."""
    mu, std, _ = model(obs)
    new_dist = torch.distributions.Normal(mu, std)
    kl = torch.distributions.kl_divergence(old_dist, new_dist).sum(dim=-1).mean()

    # allow_unused=True: the value head does not contribute to the KL term
    params = list(model.parameters())
    kl_grad = torch.autograd.grad(kl, params, create_graph=True, allow_unused=True)
    kl_grad_flat = torch.cat([
        (g if g is not None else torch.zeros_like(p)).view(-1)
        for g, p in zip(kl_grad, params)
    ])

    kl_grad_vector = torch.dot(kl_grad_flat, vector)
    hvp = torch.autograd.grad(kl_grad_vector, params, allow_unused=True)
    hvp_flat = torch.cat([
        (g if g is not None else torch.zeros_like(p)).contiguous().view(-1)
        for g, p in zip(hvp, params)
    ])

    return hvp_flat + damping * vector

def trpo_step(model, obs, actions, advantages, old_log_probs,
              max_kl=0.01):
    """TRPO update step"""
    # 1. Compute policy gradient
    mu, std, _ = model(obs)
    dist = torch.distributions.Normal(mu, std)
    log_probs = dist.log_prob(actions).sum(dim=-1)

    ratio = torch.exp(log_probs - old_log_probs)
    surrogate = (ratio * advantages).mean()

    # allow_unused=True: the value head does not affect the surrogate objective
    params = list(model.parameters())
    policy_grad = torch.autograd.grad(surrogate, params, allow_unused=True)
    policy_grad_flat = torch.cat([
        (g if g is not None else torch.zeros_like(p)).view(-1)
        for g, p in zip(policy_grad, params)
    ])

    # 2. Compute step direction via conjugate gradient
    old_dist = torch.distributions.Normal(mu.detach(), std.detach())
    Avp_fn = lambda v: hessian_vector_product(model, obs, old_dist, v)
    step_dir = conjugate_gradient(Avp_fn, policy_grad_flat)

    # 3. Determine step size (line search)
    shs = 0.5 * torch.dot(step_dir, Avp_fn(step_dir))
    max_step = torch.sqrt(max_kl / (shs + 1e-8))
    full_step = max_step * step_dir

    # 4. Backtracking line search: accept the largest step that keeps the KL
    #    divergence below max_kl and actually improves the surrogate objective
    old_params = torch.cat([p.data.view(-1) for p in model.parameters()])

    for fraction in [1.0, 0.5, 0.25, 0.125]:
        new_params = old_params + fraction * full_step
        _set_flat_params(model, new_params)

        with torch.no_grad():
            # KL divergence check
            new_mu, new_std, _ = model(obs)
            new_dist = torch.distributions.Normal(new_mu, new_std)
            kl = torch.distributions.kl_divergence(old_dist, new_dist)
            kl = kl.sum(dim=-1).mean()

            # Surrogate improvement check
            new_log_probs = new_dist.log_prob(actions).sum(dim=-1)
            new_surrogate = (torch.exp(new_log_probs - old_log_probs)
                             * advantages).mean()

        if kl < max_kl and new_surrogate > surrogate.item():
            return True  # Success: keep the new parameters

    # Restore original parameters on failure
    _set_flat_params(model, old_params)
    return False

def _set_flat_params(model, flat_params):
    offset = 0
    for param in model.parameters():
        size = param.numel()
        param.data.copy_(flat_params[offset:offset+size].view(param.shape))
        offset += size

ACKTR: Actor-Critic using Kronecker-Factored Trust Region

ACKTR (Wu et al., 2017) uses Kronecker-factored approximate curvature (K-FAC) to efficiently approximate the natural gradient.

Core Idea

  • Regular gradient: steepest descent direction in parameter space
  • Natural gradient: steepest descent direction in distribution space (how much the probability distribution changes)
  • K-FAC: efficiently approximates the inverse of the Fisher information matrix

class KFACOptimizer:
    """Conceptual implementation of K-FAC optimizer"""

    def __init__(self, model, lr=0.25, damping=1e-3, update_freq=10):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.update_freq = update_freq
        self.steps = 0

        # Fisher information factors for each layer
        self.fisher_factors = {}

    def step(self, closure=None):
        """K-FAC update step"""
        self.steps += 1

        # Update Fisher factors (periodically)
        if self.steps % self.update_freq == 0:
            self._update_fisher_factors()

        # Compute and apply natural gradient
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                # Fisher inverse * gradient = natural gradient
                natural_grad = self._compute_natural_gradient(
                    name, param.grad
                )
                param.data -= self.lr * natural_grad

    def _update_fisher_factors(self):
        """Update Kronecker factors"""
        # A = E[a * a^T] (outer product of input activations)
        # G = E[g * g^T] (outer product of output gradients)
        # Fisher approximation: F ~ A (x) G (Kronecker product)
        pass

    def _compute_natural_gradient(self, name, grad):
        """Natural gradient = F^-1 * grad"""
        # K-FAC: (A (x) G)^-1 = A^-1 (x) G^-1
        # Instead of inverting a large matrix, only invert two small matrices
        return grad  # Simplified return
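
To make the Kronecker factorization concrete, here is a small numerical sketch (illustrative only, not ACKTR's actual K-FAC implementation) of the natural-gradient computation for a single linear layer:

import torch

# Hypothetical single linear layer with weight W of shape (out_features, in_features)
batch, in_features, out_features = 32, 3, 4
a = torch.randn(batch, in_features)    # layer inputs (activations)
g = torch.randn(batch, out_features)   # gradients w.r.t. the layer's pre-activations
grad_W = g.t() @ a / batch             # ordinary gradient dL/dW

damping = 1e-3
A = a.t() @ a / batch + damping * torch.eye(in_features)   # A = E[a a^T]
G = g.t() @ g / batch + damping * torch.eye(out_features)  # G = E[g g^T]

# Fisher ~ A (x) G, so the natural gradient F^-1 * grad corresponds to
# G^-1 @ grad_W @ A^-1 -- two small inverses instead of one huge one
natural_grad_W = torch.linalg.solve(G, grad_W) @ torch.linalg.inv(A)
print(natural_grad_W.shape)  # torch.Size([4, 3]), same shape as grad_W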

Algorithm Comparison

HalfCheetah-v4 Performance Comparison

Algorithm   Reward at 1M steps   Implementation difficulty   Hyperparameter sensitivity
A2C         ~3000                Easy                         High
PPO         ~6000                Easy                         Low
TRPO        ~5500                Difficult                    Very low
ACKTR       ~7000                Very difficult               Low

Selection Guide

  • PPO: First choice in most situations. Simple implementation and good performance
  • TRPO: For research purposes requiring theoretical guarantees
  • ACKTR: When sample efficiency is important. However, implementation complexity is high

PPO Practical Tips

  1. Learning rate scheduling: Linear decay is common

def linear_schedule(initial_lr, current_step, total_steps):
    return initial_lr * (1.0 - current_step / total_steps)

  2. Advantage normalization: Normalize to mean 0, variance 1 within mini-batches

  3. Value function clipping: Clip value function updates like the policy (a sketch follows the vectorized-environment helper below)

  4. Gradient clipping: max_grad_norm = 0.5 is typical

  5. Parallel environments: Use vectorized environments to speed up sample collection

def make_vec_env(env_name, num_envs=8):
    """Create vectorized environments"""
    def make_single():
        return gym.make(env_name)
    envs = gym.vector.SyncVectorEnv(
        [make_single for _ in range(num_envs)]
    )
    return envs
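
Regarding tip 3 (value function clipping), a minimal sketch along the lines of several common PPO implementations; old_values would be the value predictions stored at collection time, and clipped_value_loss is an illustrative name, not part of the code above:

def clipped_value_loss(values, old_values, returns, clip_eps=0.2):
    """Keep the new value prediction within clip_eps of the old prediction."""
    values_clipped = old_values + torch.clamp(values - old_values,
                                              -clip_eps, clip_eps)
    loss_unclipped = (values - returns).pow(2)
    loss_clipped = (values_clipped - returns).pow(2)
    # The element-wise maximum makes the clipped branch a pessimistic bound
    return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()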

Key Takeaways

  • Trust Region methods guarantee learning stability by limiting the size of policy updates
  • PPO provides a simple yet effective Trust Region approximation through clipped objectives
  • TRPO implements a theoretically precise Trust Region with KL divergence constraints and conjugate gradients
  • ACKTR achieves high sample efficiency with natural gradients using K-FAC
  • In practice, PPO is the most widely used

In the next post, we will explore black-box optimization methods (evolution strategies, genetic algorithms) that optimize policies without gradients.