Skip to content
Published on

[深層強化学習] 14. 連続行動空間:DDPGと分布方策

Authors

概要

これまで扱ってきたほとんどの環境は離散行動空間(左/右、各方向キーなど)を持っていました。しかし、ロボット制御、自律走行、物理シミュレーションなどの現実問題では連続行動空間が必要です。関節角度、トルク、加速度などは実数値だからです。

この記事では、連続行動空間に対する方策設計、A2Cの拡張、DDPG(Deep Deterministic Policy Gradient)、そして分布方策勾配を扱います。


なぜ連続空間が必要なのか

離散化の限界

連続行動を離散化する方法(例:トルクを-1、-0.5、0、0.5、1に分割)は簡単ですが問題があります:

  • 次元の呪い:N個の関節それぞれをK段階に離散化すると行動空間がK^Nに爆発する
  • 精度不足:離散化間隔の間の細かい制御が不可能
  • 非効率性:大部分の離散行動の組み合わせは無意味

例えば6関節ロボットを10段階に離散化すると10^6 = 100万個の行動が生まれます。

連続行動空間の表現

連続方策は行動の確率分布を出力します:

  • ガウシアン方策:平均と分散を出力して正規分布から行動をサンプリング
  • ベータ方策:有限範囲の行動に適したベータ分布を使用
  • 決定的方策:1つの行動値を直接出力(DDPG)

行動空間の設計

import torch
import torch.nn as nn
import numpy as np

class ContinuousActionSpace:
    """連続行動空間の定義"""

    def __init__(self, low, high):
        self.low = np.array(low, dtype=np.float32)
        self.high = np.array(high, dtype=np.float32)
        self.shape = self.low.shape

    def clip(self, action):
        return np.clip(action, self.low, self.high)

    def sample(self):
        return np.random.uniform(self.low, self.high)

A2Cを連続行動空間に適用

既存A2Cのカテゴリカル分布をガウシアン分布に置き換えます:

class ContinuousActorCritic(nn.Module):
    """連続行動空間用Actor-Critic"""

    def __init__(self, obs_size, act_size):
        super().__init__()

        self.shared = nn.Sequential(
            nn.Linear(obs_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )

        # Actor: 平均とログ標準偏差を出力
        self.mu = nn.Sequential(
            nn.Linear(256, act_size),
            nn.Tanh(),  # 行動範囲を-1, 1に制限
        )
        # ログ標準偏差は状態非依存パラメータとして学習
        self.log_std = nn.Parameter(torch.zeros(act_size))

        # Critic
        self.value = nn.Linear(256, 1)

    def forward(self, obs):
        features = self.shared(obs)
        mu = self.mu(features)
        std = self.log_std.exp()
        value = self.value(features)
        return mu, std, value

    def get_action(self, obs):
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action, log_prob, value

    def evaluate(self, obs, action):
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        log_prob = dist.log_prob(action).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return log_prob, entropy, value

連続A2Cの学習

def train_continuous_a2c(env, model, num_steps=2048, num_updates=1000,
                         gamma=0.99, lr=3e-4, entropy_coef=0.01):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    obs = env.reset()
    for update in range(num_updates):
        # 経験収集
        observations = []
        actions = []
        rewards = []
        dones = []
        log_probs = []
        values = []

        for _ in range(num_steps):
            obs_t = torch.FloatTensor(obs).unsqueeze(0)
            action, log_prob, value = model.get_action(obs_t)

            action_np = action.detach().numpy().flatten()
            action_clipped = np.clip(action_np, env.action_space.low,
                                     env.action_space.high)

            next_obs, reward, done, _ = env.step(action_clipped)

            observations.append(obs)
            actions.append(action.detach())
            rewards.append(reward)
            dones.append(done)
            log_probs.append(log_prob)
            values.append(value.squeeze())

            obs = next_obs if not done else env.reset()

        # GAE(Generalized Advantage Estimation)の計算
        advantages, returns = compute_gae(rewards, values, dones, gamma)

        # 更新
        obs_batch = torch.FloatTensor(np.array(observations))
        act_batch = torch.cat(actions)
        adv_batch = torch.FloatTensor(advantages)
        ret_batch = torch.FloatTensor(returns)

        new_log_probs, entropy, new_values = model.evaluate(
            obs_batch, act_batch
        )

        policy_loss = -(new_log_probs * adv_batch).mean()
        value_loss = (ret_batch - new_values.squeeze()).pow(2).mean()
        entropy_loss = -entropy.mean()

        loss = policy_loss + 0.5 * value_loss + entropy_coef * entropy_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

def compute_gae(rewards, values, dones, gamma, lam=0.95):
    """Generalized Advantage Estimation"""
    advantages = []
    gae = 0
    values_list = [v.detach().item() for v in values]
    values_list.append(0)  # ブートストラップ

    for t in reversed(range(len(rewards))):
        if dones[t]:
            delta = rewards[t] - values_list[t]
            gae = delta
        else:
            delta = rewards[t] + gamma * values_list[t+1] - values_list[t]
            gae = delta + gamma * lam * gae
        advantages.insert(0, gae)

    returns = [a + v for a, v in zip(advantages, values_list[:-1])]
    return advantages, returns

決定的方策勾配(DDPG)

DDPGの核心アイデア

**DDPG(Deep Deterministic Policy Gradient)**はDQNのアイデアを連続行動空間に適用したアルゴリズムです:

  • 決定的方策:確率的にサンプリングせず、状態に対して1つの行動を直接出力
  • 経験リプレイ:DQNのように過去の経験をバッファに保存して再利用
  • ターゲットネットワーク:学習安定化のためのソフト更新
  • 探索ノイズ:決定的方策にノイズを追加して探索

ネットワーク構造

class DDPGActor(nn.Module):
    """決定的方策ネットワーク"""

    def __init__(self, obs_size, act_size, act_limit):
        super().__init__()
        self.act_limit = act_limit
        self.net = nn.Sequential(
            nn.Linear(obs_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, act_size),
            nn.Tanh(),
        )

    def forward(self, obs):
        return self.act_limit * self.net(obs)

class DDPGCritic(nn.Module):
    """行動価値関数 Q(s, a)"""

    def __init__(self, obs_size, act_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )

    def forward(self, obs, action):
        x = torch.cat([obs, action], dim=-1)
        return self.net(x)

探索ノイズ:Ornstein-Uhlenbeckプロセス

DDPGは時間的相関のあるOUノイズを使用して連続的で滑らかな探索を行います:

class OrnsteinUhlenbeckNoise:
    """時間相関のある探索ノイズ"""

    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.size = size
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.state = np.ones(size) * mu

    def reset(self):
        self.state = np.ones(self.size) * self.mu

    def sample(self):
        dx = (self.theta * (self.mu - self.state)
              + self.sigma * np.random.randn(self.size))
        self.state += dx
        return self.state.copy()

DDPG全体の実装

import copy
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity=1000000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states),
                np.array(dones, dtype=np.float32))

    def __len__(self):
        return len(self.buffer)

class DDPGAgent:
    def __init__(self, obs_size, act_size, act_limit,
                 gamma=0.99, tau=0.005, lr_actor=1e-4, lr_critic=1e-3):
        self.gamma = gamma
        self.tau = tau

        # メインネットワーク
        self.actor = DDPGActor(obs_size, act_size, act_limit)
        self.critic = DDPGCritic(obs_size, act_size)

        # ターゲットネットワーク(ソフト更新対象)
        self.target_actor = copy.deepcopy(self.actor)
        self.target_critic = copy.deepcopy(self.critic)

        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=lr_critic)

        self.replay_buffer = ReplayBuffer()
        self.noise = OrnsteinUhlenbeckNoise(act_size)

    def select_action(self, state, add_noise=True):
        state_t = torch.FloatTensor(state).unsqueeze(0)
        action = self.actor(state_t).detach().numpy().flatten()
        if add_noise:
            action += self.noise.sample()
        return np.clip(action, -self.actor.act_limit, self.actor.act_limit)

    def update(self, batch_size=256):
        if len(self.replay_buffer) < batch_size:
            return

        states, actions, rewards, next_states, dones = \
            self.replay_buffer.sample(batch_size)

        states_t = torch.FloatTensor(states)
        actions_t = torch.FloatTensor(actions)
        rewards_t = torch.FloatTensor(rewards).unsqueeze(1)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones).unsqueeze(1)

        # === Critic更新 ===
        with torch.no_grad():
            next_actions = self.target_actor(next_states_t)
            target_q = self.target_critic(next_states_t, next_actions)
            target_value = rewards_t + self.gamma * (1 - dones_t) * target_q

        current_q = self.critic(states_t, actions_t)
        critic_loss = nn.MSELoss()(current_q, target_value)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # === Actor更新 ===
        # 方策の目標:Q(s, mu(s))を最大化
        predicted_actions = self.actor(states_t)
        actor_loss = -self.critic(states_t, predicted_actions).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # === ターゲットネットワークのソフト更新 ===
        self._soft_update(self.actor, self.target_actor)
        self._soft_update(self.critic, self.target_critic)

    def _soft_update(self, source, target):
        for src_param, tgt_param in zip(source.parameters(),
                                         target.parameters()):
            tgt_param.data.copy_(
                self.tau * src_param.data + (1 - self.tau) * tgt_param.data
            )

分布方策勾配

D4PG: Distributed Distributional DDPG

D4PGはDDPGに分布強化学習(Distributional RL)を組み合わせたアルゴリズムです。Q値の期待値の代わりに全体の分布を学習します。

class DistributionalCritic(nn.Module):
    """分布Q関数:リターンの分布を予測"""

    def __init__(self, obs_size, act_size,
                 num_atoms=51, v_min=-10.0, v_max=10.0):
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max

        self.support = torch.linspace(v_min, v_max, num_atoms)
        self.delta_z = (v_max - v_min) / (num_atoms - 1)

        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, num_atoms),
        )

    def forward(self, obs, action):
        x = torch.cat([obs, action], dim=-1)
        logits = self.net(x)
        probs = torch.softmax(logits, dim=-1)
        return probs

    def get_q_value(self, obs, action):
        """期待Q値の計算"""
        probs = self.forward(obs, action)
        return (probs * self.support.unsqueeze(0)).sum(dim=-1, keepdim=True)

分布TD学習

def distributional_critic_loss(critic, target_critic, target_actor,
                                states, actions, rewards, next_states,
                                dones, gamma=0.99):
    """分布クリティックの損失計算(カテゴリカルプロジェクション)"""
    num_atoms = critic.num_atoms
    v_min = critic.v_min
    v_max = critic.v_max
    delta_z = critic.delta_z
    support = critic.support

    with torch.no_grad():
        next_actions = target_actor(next_states)
        target_probs = target_critic(next_states, next_actions)

        # ベルマン更新されたサポート
        tz = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * \
             support.unsqueeze(0)
        tz = tz.clamp(v_min, v_max)

        # カテゴリカルプロジェクション
        b = (tz - v_min) / delta_z
        l = b.floor().long()
        u = b.ceil().long()

        projected = torch.zeros_like(target_probs)
        for i in range(num_atoms):
            projected.scatter_add_(1, l[:, i:i+1],
                                   target_probs[:, i:i+1] * (u[:, i:i+1].float() - b[:, i:i+1]))
            projected.scatter_add_(1, u[:, i:i+1],
                                   target_probs[:, i:i+1] * (b[:, i:i+1] - l[:, i:i+1].float()))

    # クロスエントロピー損失
    current_logits = critic.net(torch.cat([states, actions], dim=-1))
    loss = -(projected * torch.log_softmax(current_logits, dim=-1)).sum(dim=-1).mean()

    return loss

実験結果の比較

MuJoCoのHalfCheetah-v4環境での性能比較:

アルゴリズム100万ステップ平均報酬学習安定性サンプル効率
A2C(連続)約3000
DDPG約8000低(敏感)
D4PG約9500

DDPGはハイパーパラメータに敏感ですが、適切にチューニングすれば高い性能を示します。D4PGは分布学習のおかげでより安定的です。


要点まとめ

  • 連続行動空間はガウシアン方策(A2C)または決定的方策(DDPG)で処理する
  • DDPGはDQNの核心技法(リプレイバッファ、ターゲットネットワーク)を連続空間に適用する
  • OUノイズで時間的相関のある滑らかな探索を行う
  • 分布方策勾配(D4PG)はリターンの分布を学習して安定性を高める

次の記事では、方策更新の安定性を保証する**Trust Region手法(TRPO、PPO、ACKTR)**を扱います。