Skip to content
Published on

[深層強化学習] 17. モデルベース強化学習:Imagination-Augmented Agent

Authors

概要

これまで扱ったアルゴリズムはすべてモデルフリー(model-free)手法でした。環境の遷移ダイナミクスを知らなくても試行錯誤だけで学習しました。しかし人間は行動する前に「こうしたらどうなるだろう?」と想像します。モデルベース(model-based)強化学習は環境モデルを学習してこのような想像(ロールアウト)を可能にします。

この記事では、モデルフリーとモデルベースの違い、環境モデルの不完全性問題、そしてI2A(Imagination-Augmented Agent)を扱います。


モデルフリー vs モデルベース

核心的な違い

項目モデルフリーモデルベース
環境の知識なし遷移モデルを学習/保有
サンプル効率低い高い
計画不可可能(ロールアウト)
モデル誤差の影響なしあり(バイアスを引き起こす)
代表アルゴリズムDQN, PPO, A3CDyna, MBPO, I2A

モデルベースRLの一般的な流れ

  1. 実際の環境と相互作用してデータを収集
  2. 収集データで環境モデル(遷移関数 + 報酬関数)を学習
  3. 学習済みモデルで追加経験を生成(想像/ロールアウト)
  4. 実際 + 想像の経験を合わせて方策を改善
import torch
import torch.nn as nn
import numpy as np

class EnvironmentModel(nn.Module):
    """学習可能な環境モデル:状態遷移と報酬を予測"""

    def __init__(self, obs_size, act_size, hidden_size=256):
        super().__init__()
        self.transition = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, obs_size),
        )
        self.reward_pred = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state, action):
        """次の状態と報酬を予測"""
        if action.dim() == 1:
            action = action.unsqueeze(-1)
        sa = torch.cat([state, action.float()], dim=-1)
        next_state = state + self.transition(sa)  # 残差学習
        reward = self.reward_pred(sa)
        return next_state, reward.squeeze(-1)

モデルの不完全性問題

なぜモデルが完全でないのか

学習された環境モデルは必然的に誤差を持ちます:

  1. データの限界:訪問していない状態-行動ペアに対する予測が不正確
  2. 複合誤差(Compounding Error):ロールアウトが長くなるほど誤差が蓄積
  3. 分布シフト:モデルが学習したデータ分布と方策が訪問する分布が異なる
def demonstrate_compounding_error():
    """複合誤差のデモ:1ステップ誤差がどう蓄積されるか"""
    true_state = np.array([0.0, 0.0])
    predicted_state = np.array([0.0, 0.0])
    errors = []

    for step in range(50):
        # 実際の遷移(例)
        true_state = true_state + np.array([0.1, 0.05])

        # モデル遷移(わずかな誤差を含む)
        model_error = np.random.normal(0, 0.01, 2)
        predicted_state = predicted_state + np.array([0.1, 0.05]) + model_error

        error = np.linalg.norm(true_state - predicted_state)
        errors.append(error)

    # 誤差がsqrt(t)に比例して増加
    print(f"10ステップ誤差: {errors[9]:.4f}")
    print(f"50ステップ誤差: {errors[49]:.4f}")

対応戦略

  • 短いロールアウト:モデル予測を短く保って複合誤差を制限
  • アンサンブルモデル:複数モデルの不確実性を活用
  • モデルと実際の経験を混合:Dynaスタイルで実データを引き続き使用

Imagination-Augmented Agent(I2A)

I2Aは2017年にDeepMindが提案したモデルベースRLアーキテクチャで、環境モデルの不完全性を考慮しながらも想像を活用する方法です。

I2Aの核心構成要素

  1. 環境モデル(Environment Model):状態遷移を予測
  2. ロールアウト方策(Rollout Policy):想像の中で行動を選択
  3. ロールアウトエンコーダ(Rollout Encoder):想像軌道を固定長ベクトルに圧縮
  4. モデルフリー経路:環境モデルなしで直接状態から特徴を抽出

環境モデル

Atariのようなビジュアル環境ではConvLSTMベースの環境モデルを使用します:

class AtariEnvironmentModel(nn.Module):
    """Atari環境モデル:次のフレームと報酬を予測"""

    def __init__(self, num_actions, channels=1):
        super().__init__()
        self.num_actions = num_actions

        # 行動を画像と同じサイズのチャネルに拡張
        self.action_embed = nn.Embedding(num_actions, 84 * 84)

        # エンコーダ
        self.encoder = nn.Sequential(
            nn.Conv2d(channels + 1, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.ReLU(),
        )

        # デコーダ(次のフレーム生成)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),
        )

        # 報酬予測
        self.reward_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, state, action):
        """現在の状態と行動から次の状態/報酬を予測"""
        batch_size = state.shape[0]
        action_plane = self.action_embed(action).view(batch_size, 1, 84, 84)
        x = torch.cat([state, action_plane], dim=1)
        encoded = self.encoder(x)
        next_frame = self.decoder(encoded)
        reward = self.reward_head(encoded)
        return next_frame, reward.squeeze(-1)

ロールアウト方策

想像の中で行動を選択する方策。通常、事前学習された簡単な方策や現在の方策のコピーを使用します:

class RolloutPolicy(nn.Module):
    """想像ロールアウトに使用される方策"""

    def __init__(self, obs_channels, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
        )
        self._feature_size = self._get_conv_size(obs_channels)
        self.fc = nn.Sequential(
            nn.Linear(self._feature_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def _get_conv_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.conv(x).view(1, -1).shape[1]

    def forward(self, obs):
        features = self.conv(obs).view(obs.size(0), -1)
        return self.fc(features)

    def get_action(self, obs):
        logits = self.forward(obs)
        return torch.distributions.Categorical(logits=logits).sample()

ロールアウトエンコーダ

想像軌道(状態-報酬シーケンス)を固定長ベクトルに要約します:

class RolloutEncoder(nn.Module):
    """想像軌道を固定長ベクトルにエンコード"""

    def __init__(self, obs_channels, hidden_size=256):
        super().__init__()
        self.frame_encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
        )
        self._frame_feature_size = 64 * 4 * 4
        self.lstm = nn.LSTM(
            self._frame_feature_size + 1,
            hidden_size,
            batch_first=True,
        )
        self.output_size = hidden_size

    def forward(self, imagined_frames, imagined_rewards):
        batch, rollout_len = imagined_frames.shape[:2]
        frames_flat = imagined_frames.view(-1, *imagined_frames.shape[2:])
        frame_features = self.frame_encoder(frames_flat)
        frame_features = frame_features.view(batch, rollout_len, -1)
        rewards = imagined_rewards.unsqueeze(-1)
        lstm_input = torch.cat([frame_features, rewards], dim=-1)
        _, (hidden, _) = self.lstm(lstm_input)
        return hidden.squeeze(0)

I2A全体アーキテクチャ

class I2A(nn.Module):
    """Imagination-Augmented Agent"""

    def __init__(self, obs_channels, num_actions, rollout_len=5,
                 hidden_size=256):
        super().__init__()
        self.num_actions = num_actions
        self.rollout_len = rollout_len
        self.env_model = AtariEnvironmentModel(num_actions, obs_channels)
        self.rollout_policy = RolloutPolicy(obs_channels, num_actions)
        self.rollout_encoder = RolloutEncoder(obs_channels, hidden_size)
        self.model_free_path = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        self._mf_size = self._get_mf_size(obs_channels)
        combined_size = self._mf_size + hidden_size * num_actions
        self.policy_head = nn.Sequential(
            nn.Linear(combined_size, 512), nn.ReLU(),
            nn.Linear(512, num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Linear(combined_size, 512), nn.ReLU(),
            nn.Linear(512, 1),
        )

    def _get_mf_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.model_free_path(x).shape[1]

    def imagine(self, state):
        batch_size = state.shape[0]
        all_encodings = []
        for action_idx in range(self.num_actions):
            imagined_frames, imagined_rewards = [], []
            current = state
            for step in range(self.rollout_len):
                if step == 0:
                    action = torch.full((batch_size,), action_idx, dtype=torch.long)
                else:
                    action = self.rollout_policy.get_action(current)
                with torch.no_grad():
                    next_frame, reward = self.env_model(current, action)
                imagined_frames.append(next_frame)
                imagined_rewards.append(reward)
                current = next_frame
            frames_stack = torch.stack(imagined_frames, dim=1)
            rewards_stack = torch.stack(imagined_rewards, dim=1)
            encoding = self.rollout_encoder(frames_stack, rewards_stack)
            all_encodings.append(encoding)
        return torch.cat(all_encodings, dim=-1)

    def forward(self, state):
        mf_features = self.model_free_path(state)
        imagination_features = self.imagine(state)
        combined = torch.cat([mf_features, imagination_features], dim=-1)
        policy_logits = self.policy_head(combined)
        value = self.value_head(combined)
        return policy_logits, value

I2A学習手順

3段階学習

def train_i2a(env, config):
    """I2A 3段階学習"""

    # === 第1段階:環境モデル学習 ===
    print("第1段階:環境モデル学習")
    env_model = AtariEnvironmentModel(config['num_actions'], config['obs_channels'])
    env_model_optimizer = torch.optim.Adam(env_model.parameters(), lr=1e-3)
    transitions = collect_random_transitions(env, num_steps=50000)
    for epoch in range(config['env_model_epochs']):
        loss = train_env_model_epoch(env_model, transitions, env_model_optimizer)
        if epoch % 10 == 0:
            print(f"  Epoch {epoch}: EnvModel Loss={loss:.4f}")

    # === 第2段階:ロールアウト方策学習(簡単なA2C)===
    print("第2段階:ロールアウト方策学習")
    rollout_policy = RolloutPolicy(config['obs_channels'], config['num_actions'])
    train_a2c_policy(env, rollout_policy, num_steps=100000)

    # === 第3段階:I2A全体学習 ===
    print("第3段階:I2A学習")
    i2a = I2A(config['obs_channels'], config['num_actions'])
    i2a.env_model = env_model
    i2a.rollout_policy = rollout_policy
    for param in i2a.env_model.parameters():
        param.requires_grad = False
    for param in i2a.rollout_policy.parameters():
        param.requires_grad = False
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, i2a.parameters()), lr=7e-4
    )
    train_a2c_with_model(env, i2a, optimizer, num_steps=config['total_steps'])
    return i2a

def train_env_model_epoch(env_model, transitions, optimizer):
    """環境モデルの1エポック学習"""
    total_loss = 0
    for state, action, reward, next_state in transitions:
        pred_next, pred_reward = env_model(state, action)
        frame_loss = nn.MSELoss()(pred_next, next_state)
        reward_loss = nn.MSELoss()(pred_reward, reward)
        loss = frame_loss + reward_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(transitions)

Atari Breakout実験結果

環境モデルの品質

学習された環境モデルがBreakoutゲームのフレームをどれだけ正確に予測できるかがI2A性能の鍵です:

  • 1ステップ予測:非常に正確(低MSE)
  • 5ステップ予測:ボール位置にわずかなぼやけが発生
  • 15ステップ予測:かなりの不確実性、しかし全体的なゲーム状況は把握可能

I2A vs ベースラインの性能

方法Breakoutスコア
A2C(モデルフリー)約400
I2A(5ステップロールアウト)約500
I2A + 完全な環境モデル約600

I2Aは不完全な環境モデルでも性能向上を示します。核心はロールアウトエンコーダがモデルの不確実性を自然に処理するよう学習することです。


実践的な考慮事項

環境モデル学習の困難さ

  • ビジュアル環境:ピクセルレベルの予測は難しくコストがかかる。潜在空間(latent space)で予測する方法が効率的
  • 確率的環境:決定的モデルはマルチモーダル遷移を平均化する。VAEやGANベースのモデルが必要な場合がある

最新の研究方向

  • World Models:VAEで潜在空間を学習しRNNで遷移を予測
  • Dreamer:潜在空間で想像し計画する統合フレームワーク
  • MuZero:環境モデルを明示的に学習せず、価値予測に必要な潜在ダイナミクスのみ学習

要点まとめ

  • モデルベースRLは環境モデルを学習して想像(ロールアウト)でサンプル効率を高める
  • モデルの複合誤差が主な課題であり、短いロールアウトとアンサンブルで対応する
  • I2Aは複数の行動に対する想像結果をエンコードして方策決定に活用する
  • モデルフリー経路を併用してモデル誤差に対する安全装置を提供する

次の記事では、モデルベースRLの最も印象的な事例であるAlphaGo Zeroを見ていきます。