[Deep RL] 15. Trust Region Methods: TRPO, PPO, ACKTR

Overview

One of the key problems in policy gradient methods is step size selection. If too large, the policy can suddenly deteriorate beyond recovery; if too small, learning becomes excessively slow. Trust Region methods solve this problem by limiting the size of policy updates.

This post first sets up an A2C baseline on a continuous-control locomotion task, then examines the principles and implementations of PPO, TRPO, and ACKTR.


Roboschool Environment

Roboschool (since deprecated in favor of PyBullet) was an open-source alternative to MuJoCo that provided various robot control environments. MuJoCo itself has since been open-sourced, so the code in this post uses Gymnasium's MuJoCo versions of the same tasks (e.g. HalfCheetah-v4):

  • HalfCheetah: Forward locomotion of a bipedal cheetah robot
  • Hopper: Balance and movement of a single-legged jumping robot
  • Walker2D: Bipedal walking robot
  • Humanoid: Walking of a humanoid robot
import gymnasium as gym

def make_env(env_name='HalfCheetah-v4'):
    env = gym.make(env_name)
    obs_size = env.observation_space.shape[0]
    act_size = env.action_space.shape[0]
    act_limit = env.action_space.high[0]
    return env, obs_size, act_size, act_limit
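
A quick check of what make_env returns (the 17 / 6 / 1.0 values below are HalfCheetah-v4's observation size, action size, and action bound, assuming the MuJoCo environments are installed):

env, obs_size, act_size, act_limit = make_env('HalfCheetah-v4')
print(obs_size, act_size, act_limit)  # 17 6 1.0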

A2C Baseline

A continuous action space A2C baseline for comparison:

import torch
import torch.nn as nn
import numpy as np

class A2CBaseline(nn.Module):
    def __init__(self, obs_size, act_size):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_size, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
        )
        self.mu = nn.Linear(64, act_size)
        self.log_std = nn.Parameter(torch.zeros(act_size))
        self.value = nn.Linear(64, 1)

    def forward(self, obs):
        features = self.shared(obs)
        mu = self.mu(features)
        std = self.log_std.exp()
        value = self.value(features)
        return mu, std, value
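
A minimal shape check of the baseline network (the 17/6 sizes again match HalfCheetah-v4):

model = A2CBaseline(obs_size=17, act_size=6)
mu, std, value = model(torch.randn(1, 17))
action = torch.distributions.Normal(mu, std).sample()
print(action.shape, value.shape)  # torch.Size([1, 6]) torch.Size([1, 1])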

A2C is simple but sensitive to the learning rate. A single large update can collapse the policy.


Proximal Policy Optimization (PPO)

PPO was proposed by Schulman et al. at OpenAI in 2017. It approximates TRPO's complex constrained optimization with simple clipping. It is the most widely used algorithm due to its simple implementation and excellent performance.

Clipped Surrogate Objective

The core of PPO is clipping the policy ratio to prevent excessively large updates:

Policy ratio: r(t) = pi_new(a_t | s_t) / pi_old(a_t | s_t)

The clipped objective takes the pessimistic minimum of the unclipped and clipped surrogate terms, which removes any incentive to push r(t) outside [1 - epsilon, 1 + epsilon]:

L_CLIP = E_t[ min( r(t) * A(t), clip(r(t), 1 - epsilon, 1 + epsilon) * A(t) ) ]

Typically epsilon = 0.2 is used.

class PPOAgent:
    def __init__(self, obs_size, act_size, clip_epsilon=0.2,
                 lr=3e-4, gamma=0.99, lam=0.95):
        self.clip_epsilon = clip_epsilon
        self.gamma = gamma
        self.lam = lam

        self.model = A2CBaseline(obs_size, act_size)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def compute_ppo_loss(self, obs, actions, old_log_probs,
                         advantages, returns):
        """PPO clipping objective"""
        mu, std, values = self.model(obs)
        dist = torch.distributions.Normal(mu, std)

        # Log probability under current policy
        new_log_probs = dist.log_prob(actions).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1).mean()

        # Policy ratio
        ratio = torch.exp(new_log_probs - old_log_probs)

        # Clipped objective
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio,
                            1.0 - self.clip_epsilon,
                            1.0 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        # Value loss
        value_loss = (returns - values.squeeze()).pow(2).mean()

        # Total loss
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
        return loss, policy_loss.item(), value_loss.item(), entropy.item()

Full PPO Training Loop

def collect_trajectories(env, model, num_steps=2048):
    """Collect experiences from the environment"""
    obs_list, act_list, rew_list = [], [], []
    done_list, logprob_list, val_list = [], [], []

    obs = env.reset()[0]
    for _ in range(num_steps):
        obs_t = torch.FloatTensor(obs).unsqueeze(0)
        mu, std, value = model(obs_t)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)

        action_np = action.detach().numpy().flatten()
        next_obs, reward, terminated, truncated, _ = env.step(action_np)
        done = terminated or truncated

        obs_list.append(obs)
        act_list.append(action.detach().squeeze(0))
        rew_list.append(reward)
        done_list.append(done)
        logprob_list.append(log_prob.detach())
        val_list.append(value.squeeze().detach())

        obs = next_obs if not done else env.reset()[0]

    return (obs_list, act_list, rew_list,
            done_list, logprob_list, val_list)
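
The next function implements Generalized Advantage Estimation (GAE, Schulman et al., 2016). For reference, the recursion it computes is

delta(t) = r(t) + gamma * V(s(t+1)) - V(s(t))
A_hat(t) = delta(t) + gamma * lambda * A_hat(t+1)

with the bootstrap term dropped at terminal steps.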

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation"""
    advantages = []
    gae = 0
    next_value = 0

    for t in reversed(range(len(rewards))):
        if dones[t]:
            delta = rewards[t] - values[t]
            gae = delta
        else:
            next_val = values[t+1] if t+1 < len(values) else next_value
            delta = rewards[t] + gamma * next_val - values[t]
            gae = delta + gamma * lam * gae
        advantages.insert(0, gae)

    advantages = torch.stack(advantages)  # list of 0-dim tensors -> 1-D tensor
    returns = advantages + torch.FloatTensor([v.item() for v in values])
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns

def train_ppo(env_name='HalfCheetah-v4', total_timesteps=1000000,
              num_steps=2048, num_epochs=10, batch_size=64):
    """PPO main training function"""
    env, obs_size, act_size, _ = make_env(env_name)
    agent = PPOAgent(obs_size, act_size)

    num_updates = total_timesteps // num_steps

    for update in range(num_updates):
        # 1. Collect experiences
        data = collect_trajectories(env, agent.model, num_steps)
        obs_list, act_list, rew_list, done_list, logprob_list, val_list = data

        # 2. Compute GAE
        advantages, returns = compute_gae(rew_list, val_list, done_list)

        # Convert to tensors
        obs_t = torch.FloatTensor(np.array(obs_list))
        act_t = torch.stack(act_list)
        old_logprobs = torch.cat(logprob_list)

        # 3. Mini-batch PPO update (multiple epochs)
        dataset_size = len(obs_list)
        for epoch in range(num_epochs):
            indices = np.random.permutation(dataset_size)
            for start in range(0, dataset_size, batch_size):
                end = start + batch_size
                idx = indices[start:end]

                loss, pl, vl, ent = agent.compute_ppo_loss(
                    obs_t[idx], act_t[idx], old_logprobs[idx],
                    advantages[idx], returns[idx]
                )

                agent.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    agent.model.parameters(), 0.5
                )
                agent.optimizer.step()

        if update % 10 == 0:
            avg_reward = np.sum(rew_list) / max(sum(done_list), 1)
            print(f"Update {update}: AvgReward={avg_reward:.1f}, "
                  f"PolicyLoss={pl:.4f}, Entropy={ent:.4f}")

    return agent
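
To run the loop end to end (a short smoke-test configuration; the step counts here are just example values):

if __name__ == "__main__":
    agent = train_ppo(env_name='HalfCheetah-v4',
                      total_timesteps=100_000)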

Trust Region Policy Optimization (TRPO)

TRPO (Schulman et al., 2015) was proposed before PPO and performs policy updates under a KL divergence (Kullback-Leibler divergence) constraint.

Mathematical Background of TRPO

TRPO solves the following optimization problem:

Maximize over theta:  E_t[ r(t) * A(t) ]                    (surrogate objective)
Subject to:           E_t[ KL(pi_old || pi_theta) ] <= delta

Solving this constrained problem directly is intractable for neural-network policies, so TRPO linearizes the objective, takes a quadratic approximation of the KL constraint, and solves the resulting system efficiently with the conjugate gradient method.
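
Concretely, with g the policy gradient and F the Fisher information matrix, the local subproblem becomes

maximize  g^T x    subject to    (1/2) x^T F x <= delta

whose solution is x = sqrt(2 * delta / (g^T F^-1 g)) * F^-1 g. Conjugate gradient computes F^-1 g using only Fisher-vector products, never forming F itself; the shs and max_step quantities in trpo_step below come directly from this formula.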

def conjugate_gradient(Avp_fn, b, num_steps=10, residual_tol=1e-10):
    """Conjugate gradient algorithm
    Avp_fn: function that computes Hessian-vector product
    b: right-hand side vector (policy gradient)
    """
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)

    for _ in range(num_steps):
        Avp = Avp_fn(p)
        alpha = rdotr / (torch.dot(p, Avp) + 1e-8)
        x += alpha * p
        r -= alpha * Avp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr

    return x
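
As a quick sanity check (my addition, not part of the original post), the routine can be verified against a small symmetric positive-definite system:

torch.manual_seed(0)
M = torch.randn(5, 5)
A = M @ M.t() + 5 * torch.eye(5)   # symmetric positive definite by construction
b = torch.randn(5)
x = conjugate_gradient(lambda v: A @ v, b, num_steps=50)
print(torch.allclose(A @ x, b, atol=1e-4))  # expected: True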

def hessian_vector_product(model, obs, old_dist, vector, damping=0.1):
    """Compute product of Fisher Information Matrix and a vector"""
    mu, std, _ = model(obs)
    new_dist = torch.distributions.Normal(mu, std)
    kl = torch.distributions.kl_divergence(old_dist, new_dist).sum(dim=-1).mean()

    params = list(model.parameters())
    # kl depends only on the policy path, so grads w.r.t. the value head
    # come back as None; zero-fill them to keep the flat vector aligned.
    kl_grad = torch.autograd.grad(kl, params, create_graph=True,
                                  allow_unused=True)
    kl_grad_flat = torch.cat([
        g.view(-1) if g is not None else torch.zeros_like(p).view(-1)
        for g, p in zip(kl_grad, params)
    ])

    kl_grad_vector = torch.dot(kl_grad_flat, vector)
    hvp = torch.autograd.grad(kl_grad_vector, params, allow_unused=True)
    hvp_flat = torch.cat([
        g.contiguous().view(-1) if g is not None
        else torch.zeros_like(p).view(-1)
        for g, p in zip(hvp, params)
    ])

    return hvp_flat + damping * vector

def trpo_step(model, obs, actions, advantages, old_log_probs,
              max_kl=0.01):
    """TRPO update step"""
    # 1. Compute policy gradient
    mu, std, _ = model(obs)
    dist = torch.distributions.Normal(mu, std)
    log_probs = dist.log_prob(actions).sum(dim=-1)

    ratio = torch.exp(log_probs - old_log_probs)
    surrogate = (ratio * advantages).mean()

    params = list(model.parameters())
    # The surrogate does not involve the value head; zero-fill its grads
    # so the flat gradient stays aligned with the parameter vector.
    policy_grad = torch.autograd.grad(surrogate, params, allow_unused=True)
    policy_grad_flat = torch.cat([
        g.view(-1) if g is not None else torch.zeros_like(p).view(-1)
        for g, p in zip(policy_grad, params)
    ])

    # 2. Compute step direction via conjugate gradient
    old_dist = torch.distributions.Normal(mu.detach(), std.detach())
    Avp_fn = lambda v: hessian_vector_product(model, obs, old_dist, v)
    step_dir = conjugate_gradient(Avp_fn, policy_grad_flat)

    # 3. Determine step size (line search)
    shs = 0.5 * torch.dot(step_dir, Avp_fn(step_dir))
    max_step = torch.sqrt(max_kl / (shs + 1e-8))
    full_step = max_step * step_dir

    # 4. Backtracking line search: accept the largest step that both
    #    improves the surrogate and satisfies the KL constraint
    old_params = torch.cat([p.data.view(-1) for p in model.parameters()])
    expected_improve = torch.dot(policy_grad_flat, full_step)

    for fraction in [1.0, 0.5, 0.25, 0.125]:
        new_params = old_params + fraction * full_step
        _set_flat_params(model, new_params)

        # Evaluate surrogate and KL at the candidate parameters
        with torch.no_grad():
            new_mu, new_std, _ = model(obs)
            new_dist = torch.distributions.Normal(new_mu, new_std)
            new_log_probs = new_dist.log_prob(actions).sum(dim=-1)
            new_surrogate = (torch.exp(new_log_probs - old_log_probs)
                             * advantages).mean()
            kl = torch.distributions.kl_divergence(old_dist, new_dist)
            kl = kl.sum(dim=-1).mean()

        improvement = new_surrogate - surrogate.detach()
        if kl < max_kl and improvement > 0.1 * fraction * expected_improve:
            return True  # Success

    # Restore original parameters on failure
    _set_flat_params(model, old_params)
    return False

def _set_flat_params(model, flat_params):
    offset = 0
    for param in model.parameters():
        size = param.numel()
        param.data.copy_(flat_params[offset:offset+size].view(param.shape))
        offset += size
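
A minimal sketch of how trpo_step can be wired to the data-collection and GAE code from the PPO section. This wiring is my illustration, not from the original post; the value head is fitted with a separate regression step, and because the trunk is shared that step also moves policy features, so a separate value network would be the more faithful choice:

def train_trpo_update(env, model, value_optimizer, num_steps=2048):
    data = collect_trajectories(env, model, num_steps)
    obs_list, act_list, rew_list, done_list, logprob_list, val_list = data
    advantages, returns = compute_gae(rew_list, val_list, done_list)

    obs_t = torch.FloatTensor(np.array(obs_list))
    act_t = torch.stack(act_list)
    old_logprobs = torch.cat(logprob_list)

    # Policy update: a single KL-constrained natural-gradient step
    ok = trpo_step(model, obs_t, act_t, advantages, old_logprobs)

    # Value update: plain regression on the empirical returns
    for _ in range(5):
        _, _, values = model(obs_t)
        value_loss = (returns - values.squeeze(-1)).pow(2).mean()
        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()
    return ok

# e.g. value_optimizer = torch.optim.Adam(model.value.parameters(), lr=1e-3)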

ACKTR: A2C using Kronecker-Factored Trust Region

ACKTR (Wu et al., 2017) uses Kronecker-factored approximate curvature (K-FAC) to efficiently approximate the natural gradient.

Core Idea

  • Regular gradient: steepest descent direction in parameter space
  • Natural gradient: steepest descent direction in distribution space (how much the probability distribution changes)
  • K-FAC: efficiently approximates the inverse of the Fisher information matrix
class KFACOptimizer:
    """Conceptual implementation of K-FAC optimizer"""

    def __init__(self, model, lr=0.25, damping=1e-3, update_freq=10):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.update_freq = update_freq
        self.steps = 0

        # Fisher information factors for each layer
        self.fisher_factors = {}

    def step(self, closure=None):
        """K-FAC update step"""
        self.steps += 1

        # Update Fisher factors (periodically)
        if self.steps % self.update_freq == 0:
            self._update_fisher_factors()

        # Compute and apply natural gradient
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                # Fisher inverse * gradient = natural gradient
                natural_grad = self._compute_natural_gradient(
                    name, param.grad
                )
                param.data -= self.lr * natural_grad

    def _update_fisher_factors(self):
        """Update Kronecker factors"""
        # A = E[a * a^T] (outer product of input activations)
        # G = E[g * g^T] (outer product of output gradients)
        # Fisher approximation: F ~ A (x) G (Kronecker product)
        pass

    def _compute_natural_gradient(self, name, grad):
        """Natural gradient = F^-1 * grad"""
        # K-FAC: (A (x) G)^-1 = A^-1 (x) G^-1
        # Instead of inverting a large matrix, only invert two small matrices
        return grad  # Simplified return
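
The claim in _compute_natural_gradient can be made concrete for a single linear layer. The sketch below (my illustration, with made-up dimensions) applies the Kronecker identity (A (x) G)^-1 vec(grad) = vec(G^-1 grad A^-1) using two small inverses instead of one huge one:

in_dim, out_dim, batch = 64, 6, 128
grad = torch.randn(out_dim, in_dim)              # weight-matrix gradient
a = torch.randn(batch, in_dim)                   # layer inputs
g = torch.randn(batch, out_dim)                  # backpropagated output grads
A = a.t() @ a / batch + 1e-3 * torch.eye(in_dim)    # damped input factor
G = g.t() @ g / batch + 1e-3 * torch.eye(out_dim)   # damped output factor

# Natural gradient for this layer via two small solves
natural_grad = torch.linalg.solve(G, grad) @ torch.linalg.inv(A)
print(natural_grad.shape)  # torch.Size([6, 64])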

Algorithm Comparison

HalfCheetah-v4 Performance Comparison

Algorithm | Reward at 1M steps | Implementation difficulty | Hyperparameter sensitivity
A2C       | ~3000              | Easy                      | High
PPO       | ~6000              | Easy                      | Low
TRPO      | ~5500              | Difficult                 | Very low
ACKTR     | ~7000              | Very difficult            | Low

Selection Guide

  • PPO: First choice in most situations. Simple implementation and good performance
  • TRPO: For research purposes requiring theoretical guarantees
  • ACKTR: When sample efficiency is important. However, implementation complexity is high

PPO Practical Tips

  1. Learning rate scheduling: Linear decay is common (see the sketches at the end of this section)
def linear_schedule(initial_lr, current_step, total_steps):
    return initial_lr * (1.0 - current_step / total_steps)
  2. Advantage normalization: Normalize to mean 0, variance 1 within mini-batches

  3. Value function clipping: Clip value function updates like the policy (see the sketches at the end of this section)

  4. Gradient clipping: max_grad_norm = 0.5 is typical

  5. Parallel environments: Use vectorized environments to speed up sample collection

def make_vec_env(env_name, num_envs=8):
    """Create vectorized environments"""
    def make_single():
        return gym.make(env_name)
    envs = gym.vector.SyncVectorEnv(
        [make_single for _ in range(num_envs)]
    )
    return envs
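
Hedged sketches for tips 1 and 3 (the helper names update_lr and clipped_value_loss are my own; the value loss follows the common PPO2-style clipped form):

def update_lr(optimizer, initial_lr, current_step, total_steps):
    """Apply the linear schedule to every parameter group."""
    lr = linear_schedule(initial_lr, current_step, total_steps)
    for group in optimizer.param_groups:
        group['lr'] = lr

def clipped_value_loss(values, old_values, returns, clip_eps=0.2):
    """Bound how far the new value estimate may move from the one
    recorded at collection time, mirroring the policy clipping."""
    clipped = old_values + torch.clamp(values - old_values,
                                       -clip_eps, clip_eps)
    return torch.max((values - returns).pow(2),
                     (clipped - returns).pow(2)).mean()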

Key Takeaways

  • Trust Region methods stabilize learning by limiting the size of policy updates
  • PPO provides a simple yet effective Trust Region approximation through clipped objectives
  • TRPO implements a theoretically precise Trust Region with KL divergence constraints and conjugate gradients
  • ACKTR achieves high sample efficiency with natural gradients using K-FAC
  • In practice, PPO is the most widely used

In the next post, we will explore black-box optimization methods (evolution strategies, genetic algorithms) that optimize policies without gradients.