- Authors

- Name
- Youngju Kim
- @fjvbn20031
概要
方策勾配法の核心問題の1つは**学習率(step size)**の選択です。大きすぎると方策が急に悪化して回復できなくなり、小さすぎると学習が遅くなりすぎます。Trust Region手法は方策更新の大きさを制限してこの問題を解決します。
この記事では、Roboschool環境でA2Cベースラインを設定した後、PPO、TRPO、ACKTRの原理と実装を見ていきます。
Roboschool環境
Roboschool(現在はPyBullet)はMuJoCoのオープンソース代替として、さまざまなロボット制御環境を提供します:
- HalfCheetah:2足チーターロボットの前進歩行
- Hopper:1足ジャンプロボットのバランスと移動
- Walker2D:2足歩行ロボット
- Humanoid:人体型ロボットの歩行
import gymnasium as gym
def make_env(env_name='HalfCheetah-v4'):
env = gym.make(env_name)
obs_size = env.observation_space.shape[0]
act_size = env.action_space.shape[0]
act_limit = env.action_space.high[0]
return env, obs_size, act_size, act_limit
A2Cベースライン
比較のための連続行動空間A2Cベースライン:
import torch
import torch.nn as nn
import numpy as np
class A2CBaseline(nn.Module):
def __init__(self, obs_size, act_size):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(obs_size, 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
)
self.mu = nn.Linear(64, act_size)
self.log_std = nn.Parameter(torch.zeros(act_size))
self.value = nn.Linear(64, 1)
def forward(self, obs):
features = self.shared(obs)
mu = self.mu(features)
std = self.log_std.exp()
value = self.value(features)
return mu, std, value
A2Cはシンプルですが学習率に敏感です。大きな更新1回で方策が崩壊する可能性があります。
Proximal Policy Optimization(PPO)
PPOは2017年にOpenAIのSchulmanらが提案したアルゴリズムで、TRPOの複雑な制約最適化を単純なクリッピングで近似します。実装が簡単でありながら性能が優れているため、最も広く使用されているアルゴリズムです。
クリッピング目的関数
PPOの核心は方策比率(policy ratio)をクリッピングして大きすぎる更新を防ぐことです:
方策比率 r(t) = 新しい方策の確率 / 以前の方策の確率
クリッピングされた目的関数はr(t)を(1-epsilon, 1+epsilon)の範囲に制限します。一般的にepsilon = 0.2を使用します。
class PPOAgent:
def __init__(self, obs_size, act_size, clip_epsilon=0.2,
lr=3e-4, gamma=0.99, lam=0.95):
self.clip_epsilon = clip_epsilon
self.gamma = gamma
self.lam = lam
self.model = A2CBaseline(obs_size, act_size)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
def compute_ppo_loss(self, obs, actions, old_log_probs,
advantages, returns):
"""PPOクリッピング目的関数"""
mu, std, values = self.model(obs)
dist = torch.distributions.Normal(mu, std)
# 現在の方策のログ確率
new_log_probs = dist.log_prob(actions).sum(dim=-1)
entropy = dist.entropy().sum(dim=-1).mean()
# 方策比率
ratio = torch.exp(new_log_probs - old_log_probs)
# クリッピングされた目的関数
surr1 = ratio * advantages
surr2 = torch.clamp(ratio,
1.0 - self.clip_epsilon,
1.0 + self.clip_epsilon) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
# 価値損失
value_loss = (returns - values.squeeze()).pow(2).mean()
# 全体損失
loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
return loss, policy_loss.item(), value_loss.item(), entropy.item()
PPO全体の学習ループ
def collect_trajectories(env, model, num_steps=2048):
"""環境から経験を収集"""
obs_list, act_list, rew_list = [], [], []
done_list, logprob_list, val_list = [], [], []
obs = env.reset()[0]
for _ in range(num_steps):
obs_t = torch.FloatTensor(obs).unsqueeze(0)
mu, std, value = model(obs_t)
dist = torch.distributions.Normal(mu, std)
action = dist.sample()
log_prob = dist.log_prob(action).sum(dim=-1)
action_np = action.detach().numpy().flatten()
next_obs, reward, terminated, truncated, _ = env.step(action_np)
done = terminated or truncated
obs_list.append(obs)
act_list.append(action.detach().squeeze(0))
rew_list.append(reward)
done_list.append(done)
logprob_list.append(log_prob.detach())
val_list.append(value.squeeze().detach())
obs = next_obs if not done else env.reset()[0]
return (obs_list, act_list, rew_list,
done_list, logprob_list, val_list)
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
"""Generalized Advantage Estimation"""
advantages = []
gae = 0
next_value = 0
for t in reversed(range(len(rewards))):
if dones[t]:
delta = rewards[t] - values[t]
gae = delta
else:
next_val = values[t+1] if t+1 < len(values) else next_value
delta = rewards[t] + gamma * next_val - values[t]
gae = delta + gamma * lam * gae
advantages.insert(0, gae)
advantages = torch.FloatTensor(advantages)
returns = advantages + torch.FloatTensor([v.item() for v in values])
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
return advantages, returns
def train_ppo(env_name='HalfCheetah-v4', total_timesteps=1000000,
num_steps=2048, num_epochs=10, batch_size=64):
"""PPOメイン学習関数"""
env, obs_size, act_size, _ = make_env(env_name)
agent = PPOAgent(obs_size, act_size)
num_updates = total_timesteps // num_steps
for update in range(num_updates):
# 1. 経験収集
data = collect_trajectories(env, agent.model, num_steps)
obs_list, act_list, rew_list, done_list, logprob_list, val_list = data
# 2. GAE計算
advantages, returns = compute_gae(rew_list, val_list, done_list)
# テンソル変換
obs_t = torch.FloatTensor(np.array(obs_list))
act_t = torch.stack(act_list)
old_logprobs = torch.cat(logprob_list)
# 3. ミニバッチPPO更新(複数エポック)
dataset_size = len(obs_list)
for epoch in range(num_epochs):
indices = np.random.permutation(dataset_size)
for start in range(0, dataset_size, batch_size):
end = start + batch_size
idx = indices[start:end]
loss, pl, vl, ent = agent.compute_ppo_loss(
obs_t[idx], act_t[idx], old_logprobs[idx],
advantages[idx], returns[idx]
)
agent.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
agent.model.parameters(), 0.5
)
agent.optimizer.step()
if update % 10 == 0:
avg_reward = np.sum(rew_list) / max(sum(done_list), 1)
print(f"Update {update}: AvgReward={avg_reward:.1f}, "
f"PolicyLoss={pl:.4f}, Entropy={ent:.4f}")
return agent
Trust Region Policy Optimization(TRPO)
TRPOはPPOより先に提案されたアルゴリズムで、方策更新をKLダイバージェンス(Kullback-Leibler divergence)制約の下で行います。
TRPOの数学的背景
TRPOは以下の最適化問題を解く必要があります:
目的:サロゲート目的関数の最大化(方策比率 * アドバンテージの期待値) 制約:以前の方策と新しい方策の間のKLダイバージェンスがdelta以下
これを効率的に解くために**共役勾配(conjugate gradient)**法を使用します。
def conjugate_gradient(Avp_fn, b, num_steps=10, residual_tol=1e-10):
"""共役勾配アルゴリズム
Avp_fn: ヘシアン-ベクトル積を計算する関数
b: 右辺ベクトル(方策勾配)
"""
x = torch.zeros_like(b)
r = b.clone()
p = b.clone()
rdotr = torch.dot(r, r)
for _ in range(num_steps):
Avp = Avp_fn(p)
alpha = rdotr / (torch.dot(p, Avp) + 1e-8)
x += alpha * p
r -= alpha * Avp
new_rdotr = torch.dot(r, r)
if new_rdotr < residual_tol:
break
beta = new_rdotr / rdotr
p = r + beta * p
rdotr = new_rdotr
return x
def hessian_vector_product(model, obs, old_dist, vector, damping=0.1):
"""フィッシャー情報行列とベクトルの積を計算"""
mu, std, _ = model(obs)
new_dist = torch.distributions.Normal(mu, std)
kl = torch.distributions.kl_divergence(old_dist, new_dist).sum(dim=-1).mean()
kl_grad = torch.autograd.grad(kl, model.parameters(), create_graph=True)
kl_grad_flat = torch.cat([g.view(-1) for g in kl_grad])
kl_grad_vector = torch.dot(kl_grad_flat, vector)
hvp = torch.autograd.grad(kl_grad_vector, model.parameters())
hvp_flat = torch.cat([g.contiguous().view(-1) for g in hvp])
return hvp_flat + damping * vector
def trpo_step(model, obs, actions, advantages, old_log_probs,
max_kl=0.01):
"""TRPO更新ステップ"""
# 1. 方策勾配の計算
mu, std, _ = model(obs)
dist = torch.distributions.Normal(mu, std)
log_probs = dist.log_prob(actions).sum(dim=-1)
ratio = torch.exp(log_probs - old_log_probs)
surrogate = (ratio * advantages).mean()
policy_grad = torch.autograd.grad(surrogate, model.parameters())
policy_grad_flat = torch.cat([g.view(-1) for g in policy_grad])
# 2. 共役勾配でステップ方向を計算
old_dist = torch.distributions.Normal(mu.detach(), std.detach())
Avp_fn = lambda v: hessian_vector_product(model, obs, old_dist, v)
step_dir = conjugate_gradient(Avp_fn, policy_grad_flat)
# 3. ステップサイズの決定(ラインサーチ)
shs = 0.5 * torch.dot(step_dir, Avp_fn(step_dir))
max_step = torch.sqrt(max_kl / (shs + 1e-8))
full_step = max_step * step_dir
# 4. ラインサーチで適切なステップを見つける
old_params = torch.cat([p.data.view(-1) for p in model.parameters()])
expected_improve = torch.dot(policy_grad_flat, full_step)
for fraction in [1.0, 0.5, 0.25, 0.125]:
new_params = old_params + fraction * full_step
_set_flat_params(model, new_params)
# KLダイバージェンスのチェック
new_mu, new_std, _ = model(obs)
new_dist = torch.distributions.Normal(new_mu, new_std)
kl = torch.distributions.kl_divergence(old_dist, new_dist)
kl = kl.sum(dim=-1).mean()
if kl < max_kl:
return True # 成功
# 失敗時は元のパラメータに復元
_set_flat_params(model, old_params)
return False
def _set_flat_params(model, flat_params):
offset = 0
for param in model.parameters():
size = param.numel()
param.data.copy_(flat_params[offset:offset+size].view(param.shape))
offset += size
ACKTR: A2C using Kronecker-Factored Trust Region
ACKTRは**Kronecker-factored approximate curvature(K-FAC)**を使用して自然勾配(natural gradient)を効率的に近似します。
核心アイデア
- 通常の勾配:パラメータ空間での最急降下方向
- 自然勾配:分布空間(確率分布が変化する程度)での最急降下方向
- K-FAC:フィッシャー情報行列の逆行列を効率的に近似
class KFACOptimizer:
"""K-FACオプティマイザの概念的実装"""
def __init__(self, model, lr=0.25, damping=1e-3, update_freq=10):
self.model = model
self.lr = lr
self.damping = damping
self.update_freq = update_freq
self.steps = 0
# 各レイヤーのフィッシャー情報因子
self.fisher_factors = {}
def step(self, closure=None):
"""K-FAC更新ステップ"""
self.steps += 1
# フィッシャー因子の更新(定期的に)
if self.steps % self.update_freq == 0:
self._update_fisher_factors()
# 自然勾配の計算と適用
for name, param in self.model.named_parameters():
if param.grad is not None:
# フィッシャー逆行列 * 勾配 = 自然勾配
natural_grad = self._compute_natural_gradient(
name, param.grad
)
param.data -= self.lr * natural_grad
def _update_fisher_factors(self):
"""クロネッカー因子の更新"""
# A = E[a * a^T](入力活性化の外積)
# G = E[g * g^T](出力勾配の外積)
# フィッシャー近似: F ~ A (x) G(クロネッカー積)
pass
def _compute_natural_gradient(self, name, grad):
"""自然勾配 = F^-1 * grad"""
# K-FAC: (A (x) G)^-1 = A^-1 (x) G^-1
# 大きな行列の逆行列の代わりに小さな行列2つの逆行列のみ計算
return grad # 簡略化した返却
アルゴリズムの比較
HalfCheetah-v4の性能比較
| アルゴリズム | 100万ステップ報酬 | 実装難易度 | ハイパーパラメータ感度 |
|---|---|---|---|
| A2C | 約3000 | 簡単 | 高い |
| PPO | 約6000 | 簡単 | 低い |
| TRPO | 約5500 | 難しい | 非常に低い |
| ACKTR | 約7000 | 非常に難しい | 低い |
選択ガイド
- PPO:ほとんどの状況で最初の選択肢。実装が簡単で性能が良い
- TRPO:理論的保証が必要な研究目的
- ACKTR:サンプル効率が重要な場合。ただし実装の複雑度が高い
PPO実践的なヒント
- 学習率スケジューリング:線形減少が一般的
def linear_schedule(initial_lr, current_step, total_steps):
return initial_lr * (1.0 - current_step / total_steps)
-
アドバンテージの正規化:ミニバッチ内で平均0、分散1に正規化
-
価値関数のクリッピング:方策と同様に価値関数の更新もクリッピング
-
勾配クリッピング:max_grad_norm = 0.5が一般的
-
並列環境:ベクトル化環境でサンプル収集速度を向上
def make_vec_env(env_name, num_envs=8):
"""ベクトル化環境の作成"""
def make_single():
return gym.make(env_name)
envs = gym.vector.SyncVectorEnv(
[make_single for _ in range(num_envs)]
)
return envs
要点まとめ
- Trust Region手法は方策更新の大きさを制限して学習の安定性を保証する
- PPOはクリッピングされた目的関数で簡単でありながら効果的なTrust Region近似を提供する
- TRPOはKLダイバージェンス制約と共役勾配で理論的に正確なTrust Regionを実装する
- ACKTRはK-FACを使用した自然勾配で高いサンプル効率を達成する
- 実践ではPPOが最も広く使用されている
次の記事では、勾配なしで方策を最適化する**Black-Box最適化(進化戦略、遺伝的アルゴリズム)**を見ていきます。