Skip to content
Published on

[深層強化学習] 15. Trust Region手法:TRPO、PPO、ACKTR

Authors

概要

方策勾配法の核心問題の1つは**学習率(step size)**の選択です。大きすぎると方策が急に悪化して回復できなくなり、小さすぎると学習が遅くなりすぎます。Trust Region手法は方策更新の大きさを制限してこの問題を解決します。

この記事では、Roboschool環境でA2Cベースラインを設定した後、PPO、TRPO、ACKTRの原理と実装を見ていきます。


Roboschool環境

Roboschool(現在はPyBullet)はMuJoCoのオープンソース代替として、さまざまなロボット制御環境を提供します:

  • HalfCheetah:2足チーターロボットの前進歩行
  • Hopper:1足ジャンプロボットのバランスと移動
  • Walker2D:2足歩行ロボット
  • Humanoid:人体型ロボットの歩行
import gymnasium as gym

def make_env(env_name='HalfCheetah-v4'):
    env = gym.make(env_name)
    obs_size = env.observation_space.shape[0]
    act_size = env.action_space.shape[0]
    act_limit = env.action_space.high[0]
    return env, obs_size, act_size, act_limit

A2Cベースライン

比較のための連続行動空間A2Cベースライン:

import torch
import torch.nn as nn
import numpy as np

class A2CBaseline(nn.Module):
    def __init__(self, obs_size, act_size):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_size, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
        )
        self.mu = nn.Linear(64, act_size)
        self.log_std = nn.Parameter(torch.zeros(act_size))
        self.value = nn.Linear(64, 1)

    def forward(self, obs):
        features = self.shared(obs)
        mu = self.mu(features)
        std = self.log_std.exp()
        value = self.value(features)
        return mu, std, value

A2Cはシンプルですが学習率に敏感です。大きな更新1回で方策が崩壊する可能性があります。


Proximal Policy Optimization(PPO)

PPOは2017年にOpenAIのSchulmanらが提案したアルゴリズムで、TRPOの複雑な制約最適化を単純なクリッピングで近似します。実装が簡単でありながら性能が優れているため、最も広く使用されているアルゴリズムです。

クリッピング目的関数

PPOの核心は方策比率(policy ratio)をクリッピングして大きすぎる更新を防ぐことです:

方策比率 r(t) = 新しい方策の確率 / 以前の方策の確率

クリッピングされた目的関数はr(t)を(1-epsilon, 1+epsilon)の範囲に制限します。一般的にepsilon = 0.2を使用します。

class PPOAgent:
    def __init__(self, obs_size, act_size, clip_epsilon=0.2,
                 lr=3e-4, gamma=0.99, lam=0.95):
        self.clip_epsilon = clip_epsilon
        self.gamma = gamma
        self.lam = lam

        self.model = A2CBaseline(obs_size, act_size)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def compute_ppo_loss(self, obs, actions, old_log_probs,
                         advantages, returns):
        """PPOクリッピング目的関数"""
        mu, std, values = self.model(obs)
        dist = torch.distributions.Normal(mu, std)

        # 現在の方策のログ確率
        new_log_probs = dist.log_prob(actions).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1).mean()

        # 方策比率
        ratio = torch.exp(new_log_probs - old_log_probs)

        # クリッピングされた目的関数
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio,
                            1.0 - self.clip_epsilon,
                            1.0 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        # 価値損失
        value_loss = (returns - values.squeeze()).pow(2).mean()

        # 全体損失
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
        return loss, policy_loss.item(), value_loss.item(), entropy.item()

PPO全体の学習ループ

def collect_trajectories(env, model, num_steps=2048):
    """環境から経験を収集"""
    obs_list, act_list, rew_list = [], [], []
    done_list, logprob_list, val_list = [], [], []

    obs = env.reset()[0]
    for _ in range(num_steps):
        obs_t = torch.FloatTensor(obs).unsqueeze(0)
        mu, std, value = model(obs_t)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)

        action_np = action.detach().numpy().flatten()
        next_obs, reward, terminated, truncated, _ = env.step(action_np)
        done = terminated or truncated

        obs_list.append(obs)
        act_list.append(action.detach().squeeze(0))
        rew_list.append(reward)
        done_list.append(done)
        logprob_list.append(log_prob.detach())
        val_list.append(value.squeeze().detach())

        obs = next_obs if not done else env.reset()[0]

    return (obs_list, act_list, rew_list,
            done_list, logprob_list, val_list)

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation"""
    advantages = []
    gae = 0
    next_value = 0

    for t in reversed(range(len(rewards))):
        if dones[t]:
            delta = rewards[t] - values[t]
            gae = delta
        else:
            next_val = values[t+1] if t+1 < len(values) else next_value
            delta = rewards[t] + gamma * next_val - values[t]
            gae = delta + gamma * lam * gae
        advantages.insert(0, gae)

    advantages = torch.FloatTensor(advantages)
    returns = advantages + torch.FloatTensor([v.item() for v in values])
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns

def train_ppo(env_name='HalfCheetah-v4', total_timesteps=1000000,
              num_steps=2048, num_epochs=10, batch_size=64):
    """PPOメイン学習関数"""
    env, obs_size, act_size, _ = make_env(env_name)
    agent = PPOAgent(obs_size, act_size)

    num_updates = total_timesteps // num_steps

    for update in range(num_updates):
        # 1. 経験収集
        data = collect_trajectories(env, agent.model, num_steps)
        obs_list, act_list, rew_list, done_list, logprob_list, val_list = data

        # 2. GAE計算
        advantages, returns = compute_gae(rew_list, val_list, done_list)

        # テンソル変換
        obs_t = torch.FloatTensor(np.array(obs_list))
        act_t = torch.stack(act_list)
        old_logprobs = torch.cat(logprob_list)

        # 3. ミニバッチPPO更新(複数エポック)
        dataset_size = len(obs_list)
        for epoch in range(num_epochs):
            indices = np.random.permutation(dataset_size)
            for start in range(0, dataset_size, batch_size):
                end = start + batch_size
                idx = indices[start:end]

                loss, pl, vl, ent = agent.compute_ppo_loss(
                    obs_t[idx], act_t[idx], old_logprobs[idx],
                    advantages[idx], returns[idx]
                )

                agent.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    agent.model.parameters(), 0.5
                )
                agent.optimizer.step()

        if update % 10 == 0:
            avg_reward = np.sum(rew_list) / max(sum(done_list), 1)
            print(f"Update {update}: AvgReward={avg_reward:.1f}, "
                  f"PolicyLoss={pl:.4f}, Entropy={ent:.4f}")

    return agent

Trust Region Policy Optimization(TRPO)

TRPOはPPOより先に提案されたアルゴリズムで、方策更新をKLダイバージェンス(Kullback-Leibler divergence)制約の下で行います。

TRPOの数学的背景

TRPOは以下の最適化問題を解く必要があります:

目的:サロゲート目的関数の最大化(方策比率 * アドバンテージの期待値) 制約:以前の方策と新しい方策の間のKLダイバージェンスがdelta以下

これを効率的に解くために**共役勾配(conjugate gradient)**法を使用します。

def conjugate_gradient(Avp_fn, b, num_steps=10, residual_tol=1e-10):
    """共役勾配アルゴリズム
    Avp_fn: ヘシアン-ベクトル積を計算する関数
    b: 右辺ベクトル(方策勾配)
    """
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)

    for _ in range(num_steps):
        Avp = Avp_fn(p)
        alpha = rdotr / (torch.dot(p, Avp) + 1e-8)
        x += alpha * p
        r -= alpha * Avp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr

    return x

def hessian_vector_product(model, obs, old_dist, vector, damping=0.1):
    """フィッシャー情報行列とベクトルの積を計算"""
    mu, std, _ = model(obs)
    new_dist = torch.distributions.Normal(mu, std)
    kl = torch.distributions.kl_divergence(old_dist, new_dist).sum(dim=-1).mean()

    kl_grad = torch.autograd.grad(kl, model.parameters(), create_graph=True)
    kl_grad_flat = torch.cat([g.view(-1) for g in kl_grad])

    kl_grad_vector = torch.dot(kl_grad_flat, vector)
    hvp = torch.autograd.grad(kl_grad_vector, model.parameters())
    hvp_flat = torch.cat([g.contiguous().view(-1) for g in hvp])

    return hvp_flat + damping * vector

def trpo_step(model, obs, actions, advantages, old_log_probs,
              max_kl=0.01):
    """TRPO更新ステップ"""
    # 1. 方策勾配の計算
    mu, std, _ = model(obs)
    dist = torch.distributions.Normal(mu, std)
    log_probs = dist.log_prob(actions).sum(dim=-1)

    ratio = torch.exp(log_probs - old_log_probs)
    surrogate = (ratio * advantages).mean()

    policy_grad = torch.autograd.grad(surrogate, model.parameters())
    policy_grad_flat = torch.cat([g.view(-1) for g in policy_grad])

    # 2. 共役勾配でステップ方向を計算
    old_dist = torch.distributions.Normal(mu.detach(), std.detach())
    Avp_fn = lambda v: hessian_vector_product(model, obs, old_dist, v)
    step_dir = conjugate_gradient(Avp_fn, policy_grad_flat)

    # 3. ステップサイズの決定(ラインサーチ)
    shs = 0.5 * torch.dot(step_dir, Avp_fn(step_dir))
    max_step = torch.sqrt(max_kl / (shs + 1e-8))
    full_step = max_step * step_dir

    # 4. ラインサーチで適切なステップを見つける
    old_params = torch.cat([p.data.view(-1) for p in model.parameters()])
    expected_improve = torch.dot(policy_grad_flat, full_step)

    for fraction in [1.0, 0.5, 0.25, 0.125]:
        new_params = old_params + fraction * full_step
        _set_flat_params(model, new_params)

        # KLダイバージェンスのチェック
        new_mu, new_std, _ = model(obs)
        new_dist = torch.distributions.Normal(new_mu, new_std)
        kl = torch.distributions.kl_divergence(old_dist, new_dist)
        kl = kl.sum(dim=-1).mean()

        if kl < max_kl:
            return True  # 成功

    # 失敗時は元のパラメータに復元
    _set_flat_params(model, old_params)
    return False

def _set_flat_params(model, flat_params):
    offset = 0
    for param in model.parameters():
        size = param.numel()
        param.data.copy_(flat_params[offset:offset+size].view(param.shape))
        offset += size

ACKTR: A2C using Kronecker-Factored Trust Region

ACKTRは**Kronecker-factored approximate curvature(K-FAC)**を使用して自然勾配(natural gradient)を効率的に近似します。

核心アイデア

  • 通常の勾配:パラメータ空間での最急降下方向
  • 自然勾配:分布空間(確率分布が変化する程度)での最急降下方向
  • K-FAC:フィッシャー情報行列の逆行列を効率的に近似
class KFACOptimizer:
    """K-FACオプティマイザの概念的実装"""

    def __init__(self, model, lr=0.25, damping=1e-3, update_freq=10):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.update_freq = update_freq
        self.steps = 0

        # 各レイヤーのフィッシャー情報因子
        self.fisher_factors = {}

    def step(self, closure=None):
        """K-FAC更新ステップ"""
        self.steps += 1

        # フィッシャー因子の更新(定期的に)
        if self.steps % self.update_freq == 0:
            self._update_fisher_factors()

        # 自然勾配の計算と適用
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                # フィッシャー逆行列 * 勾配 = 自然勾配
                natural_grad = self._compute_natural_gradient(
                    name, param.grad
                )
                param.data -= self.lr * natural_grad

    def _update_fisher_factors(self):
        """クロネッカー因子の更新"""
        # A = E[a * a^T](入力活性化の外積)
        # G = E[g * g^T](出力勾配の外積)
        # フィッシャー近似: F ~ A (x) G(クロネッカー積)
        pass

    def _compute_natural_gradient(self, name, grad):
        """自然勾配 = F^-1 * grad"""
        # K-FAC: (A (x) G)^-1 = A^-1 (x) G^-1
        # 大きな行列の逆行列の代わりに小さな行列2つの逆行列のみ計算
        return grad  # 簡略化した返却

アルゴリズムの比較

HalfCheetah-v4の性能比較

アルゴリズム100万ステップ報酬実装難易度ハイパーパラメータ感度
A2C約3000簡単高い
PPO約6000簡単低い
TRPO約5500難しい非常に低い
ACKTR約7000非常に難しい低い

選択ガイド

  • PPO:ほとんどの状況で最初の選択肢。実装が簡単で性能が良い
  • TRPO:理論的保証が必要な研究目的
  • ACKTR:サンプル効率が重要な場合。ただし実装の複雑度が高い

PPO実践的なヒント

  1. 学習率スケジューリング:線形減少が一般的
def linear_schedule(initial_lr, current_step, total_steps):
    return initial_lr * (1.0 - current_step / total_steps)
  1. アドバンテージの正規化:ミニバッチ内で平均0、分散1に正規化

  2. 価値関数のクリッピング:方策と同様に価値関数の更新もクリッピング

  3. 勾配クリッピング:max_grad_norm = 0.5が一般的

  4. 並列環境:ベクトル化環境でサンプル収集速度を向上

def make_vec_env(env_name, num_envs=8):
    """ベクトル化環境の作成"""
    def make_single():
        return gym.make(env_name)
    envs = gym.vector.SyncVectorEnv(
        [make_single for _ in range(num_envs)]
    )
    return envs

要点まとめ

  • Trust Region手法は方策更新の大きさを制限して学習の安定性を保証する
  • PPOはクリッピングされた目的関数で簡単でありながら効果的なTrust Region近似を提供する
  • TRPOはKLダイバージェンス制約と共役勾配で理論的に正確なTrust Regionを実装する
  • ACKTRはK-FACを使用した自然勾配で高いサンプル効率を達成する
  • 実践ではPPOが最も広く使用されている

次の記事では、勾配なしで方策を最適化する**Black-Box最適化(進化戦略、遺伝的アルゴリズム)**を見ていきます。