[Deep RL] 17. Model-Based Reinforcement Learning: Imagination-Augmented Agent

Overview

Every algorithm covered so far has been model-free: each learned purely through trial and error, with no knowledge of the environment's transition dynamics. But humans imagine "what would happen if I do this?" before acting. Model-based reinforcement learning enables exactly this kind of imagination (rollouts) by learning a model of the environment.

This post covers the differences between model-free and model-based approaches, the problem of model imperfection, and I2A (Imagination-Augmented Agent).


Model-Free vs Model-Based

Key Differences

| Aspect | Model-Free | Model-Based |
| --- | --- | --- |
| Environment knowledge | None | Learns/uses a transition model |
| Sample efficiency | Low | High |
| Planning | Not possible | Possible (via rollouts) |
| Model error impact | None | Present (introduces bias) |
| Representative algorithms | DQN, PPO, A3C | Dyna, MBPO, I2A |

General Flow of Model-Based RL

  1. Collect data by interacting with the real environment
  2. Learn the environment model (transition function + reward function) from collected data
  3. Generate additional experiences using the learned model (imagination/rollouts)
  4. Improve policy using both real and imagined experiences
import torch
import torch.nn as nn
import numpy as np

class EnvironmentModel(nn.Module):
    """Learnable environment model: predicts state transitions and rewards"""

    def __init__(self, obs_size, act_size, hidden_size=256):
        super().__init__()
        self.transition = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, obs_size),
        )
        self.reward_pred = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state, action):
        """Predict next state and reward"""
        if action.dim() == 1:
            action = action.unsqueeze(-1)
        sa = torch.cat([state, action.float()], dim=-1)
        next_state = state + self.transition(sa)  # Residual learning
        reward = self.reward_pred(sa)
        return next_state, reward.squeeze(-1)
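The four-step loop above can also be seen in miniature in Dyna-Q, where the "environment model" is just a table of observed transitions. Below is a minimal sketch on a toy 5-state chain; the environment, reward, and all hyperparameters are illustrative assumptions, not part of I2A itself:

```python
import random

def dyna_q_sketch(num_episodes=200, planning_steps=10,
                  alpha=0.5, gamma=0.95, eps=0.1, seed=0):
    """Dyna-Q on a toy 5-state chain: reward 1 for reaching the rightmost state."""
    rng = random.Random(seed)
    n_states, actions = 5, [0, 1]            # 0 = left, 1 = right
    goal = n_states - 1
    Q = {(s, a): 0.0 for s in range(n_states) for a in actions}
    model = {}                               # (s, a) -> (r, s'): the learned "model"

    def env_step(s, a):                      # deterministic toy dynamics
        s2 = min(s + 1, goal) if a == 1 else max(s - 1, 0)
        return (1.0 if s2 == goal else 0.0), s2

    for _ in range(num_episodes):
        s = 0
        for _ in range(100):                 # step cap keeps the sketch bounded
            if rng.random() < eps:           # epsilon-greedy action selection
                a = rng.choice(actions)
            else:                            # greedy with random tie-breaking
                a = max(actions, key=lambda b: (Q[(s, b)], rng.random()))
            r, s2 = env_step(s, a)           # (1) interact with the real environment
            model[(s, a)] = (r, s2)          # (2) update the environment model
            Q[(s, a)] += alpha * (r + gamma * max(Q[(s2, b)] for b in actions) - Q[(s, a)])
            for _ in range(planning_steps):  # (3)+(4) extra updates from imagined experience
                ps, pa = rng.choice(list(model))
                pr, ps2 = model[(ps, pa)]
                Q[(ps, pa)] += alpha * (pr + gamma * max(Q[(ps2, b)] for b in actions) - Q[(ps, pa)])
            s = s2
            if s == goal:
                break
    return Q

Q = dyna_q_sketch()
# After training, "right" should dominate "left" in every non-terminal state.
assert all(Q[(s, 1)] > Q[(s, 0)] for s in range(4))
```

The planning loop is where sample efficiency comes from: each real transition is replayed many times through the model, so the value function propagates far faster than with real experience alone.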

Model Imperfection Problem

Why Models Are Not Perfect

Learned environment models inevitably have errors:

  1. Data limitations: Predictions are inaccurate for unvisited state-action pairs
  2. Compounding error: Errors accumulate as rollouts get longer
  3. Distribution shift: The data distribution the model was trained on differs from what the policy visits
def demonstrate_compounding_error():
    """Demonstrate compounding error: how 1-step errors accumulate"""
    true_state = np.array([0.0, 0.0])
    predicted_state = np.array([0.0, 0.0])
    errors = []

    for step in range(50):
        # True transition (example)
        true_state = true_state + np.array([0.1, 0.05])

        # Model transition (with slight error)
        model_error = np.random.normal(0, 0.01, 2)
        predicted_state = predicted_state + np.array([0.1, 0.05]) + model_error

        error = np.linalg.norm(true_state - predicted_state)
        errors.append(error)

    # With i.i.d. noise the error norm grows roughly as sqrt(t);
    # a systematic (biased) model error would compound linearly or faster
    print(f"Error at 10 steps: {errors[9]:.4f}")
    print(f"Error at 50 steps: {errors[49]:.4f}")

Mitigation Strategies

  • Short rollouts: Keep model predictions short to limit compounding error
  • Ensemble models: Leverage uncertainty from multiple models
  • Mix model and real experiences: Continue using real data in Dyna style
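The ensemble idea can be sketched as follows: train several independently initialized dynamics models and treat their disagreement as an uncertainty signal, truncating an imagined rollout once disagreement exceeds a threshold. The model sizes, the threshold, and the toy dimensions below are illustrative assumptions:

```python
import torch
import torch.nn as nn

class SmallDynamicsModel(nn.Module):
    """One ensemble member: predicts the next state from (state, action)."""
    def __init__(self, obs_size, act_size, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden), nn.ReLU(),
            nn.Linear(hidden, obs_size),
        )

    def forward(self, state, action):
        return state + self.net(torch.cat([state, action], dim=-1))  # residual

def uncertainty_truncated_rollout(ensemble, state, actions, threshold=0.05):
    """Roll out with the ensemble mean; stop when member disagreement grows."""
    trajectory = [state]
    for action in actions:
        preds = torch.stack([m(state, action) for m in ensemble])  # (K, B, obs)
        disagreement = preds.std(dim=0).mean().item()              # scalar proxy
        if disagreement > threshold:
            break                                                  # model no longer trusted
        state = preds.mean(dim=0)
        trajectory.append(state)
    return trajectory

torch.manual_seed(0)
ensemble = [SmallDynamicsModel(4, 2) for _ in range(5)]
state = torch.zeros(1, 4)
actions = [torch.ones(1, 2)] * 10
traj = uncertainty_truncated_rollout(ensemble, state, actions)
print(f"rollout kept {len(traj) - 1} of 10 imagined steps")
```

Since the members here are untrained, they disagree almost immediately; after training on the same data, disagreement stays low in well-covered regions and rises in unvisited ones, which is exactly what makes it a useful truncation signal.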

Imagination-Augmented Agent (I2A)

I2A is a model-based RL architecture proposed by DeepMind (Racanière et al., 2017) that leverages imagination while accounting for model imperfection.

Core Components of I2A

  1. Environment Model: Predicts state transitions
  2. Rollout Policy: Selects actions during imagination
  3. Rollout Encoder: Compresses imagined trajectories into fixed-length vectors
  4. Model-Free Path: Extracts features directly from state without the environment model

Environment Model

For visual environments like Atari, the original paper uses a recurrent (ConvLSTM-style) environment model; the sketch below simplifies this to a feed-forward convolutional encoder-decoder:

class AtariEnvironmentModel(nn.Module):
    """Atari environment model: predicts next frame and reward"""

    def __init__(self, num_actions, channels=1):
        super().__init__()
        self.num_actions = num_actions

        # Expand action to image-sized channel
        self.action_embed = nn.Embedding(num_actions, 84 * 84)

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(channels + 1, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Decoder (next frame generation)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),
        )

        # Reward prediction
        self.reward_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, state, action):
        """Predict next state/reward from current state and action"""
        batch_size = state.shape[0]

        # Convert action to image format
        action_plane = self.action_embed(action).view(
            batch_size, 1, 84, 84
        )

        # Combine state + action
        x = torch.cat([state, action_plane], dim=1)

        # Encoding
        encoded = self.encoder(x)

        # Next frame prediction
        next_frame = self.decoder(encoded)

        # Reward prediction
        reward = self.reward_head(encoded)

        return next_frame, reward.squeeze(-1)

Rollout Policy

The rollout policy selects actions inside the imagined rollouts. It is usually a small pretrained policy or a distilled copy of the current policy:

class RolloutPolicy(nn.Module):
    """Policy used for imagination rollouts"""

    def __init__(self, obs_channels, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
        )
        self._feature_size = self._get_conv_size(obs_channels)
        self.fc = nn.Sequential(
            nn.Linear(self._feature_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def _get_conv_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.conv(x).view(1, -1).shape[1]

    def forward(self, obs):
        features = self.conv(obs).view(obs.size(0), -1)
        return self.fc(features)

    def get_action(self, obs):
        logits = self.forward(obs)
        return torch.distributions.Categorical(
            logits=logits
        ).sample()

Rollout Encoder

Summarizes imagined trajectories (state-reward sequences) into fixed-length vectors:

class RolloutEncoder(nn.Module):
    """Encode imagined trajectories into fixed-length vectors"""

    def __init__(self, obs_channels, hidden_size=256):
        super().__init__()

        # Feature extraction for each frame
        self.frame_encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
        )
        self._frame_feature_size = 64 * 4 * 4

        # Sequence encoding (LSTM)
        self.lstm = nn.LSTM(
            self._frame_feature_size + 1,  # Frame features + reward
            hidden_size,
            batch_first=True,
        )

        self.output_size = hidden_size

    def forward(self, imagined_frames, imagined_rewards):
        """
        imagined_frames: (batch, rollout_len, C, H, W)
        imagined_rewards: (batch, rollout_len)
        """
        batch, rollout_len = imagined_frames.shape[:2]

        # Encode each frame
        frames_flat = imagined_frames.view(-1, *imagined_frames.shape[2:])
        frame_features = self.frame_encoder(frames_flat)
        frame_features = frame_features.view(batch, rollout_len, -1)

        # Combine with rewards
        rewards = imagined_rewards.unsqueeze(-1)
        lstm_input = torch.cat([frame_features, rewards], dim=-1)

        # LSTM sequence encoding
        _, (hidden, _) = self.lstm(lstm_input)
        return hidden.squeeze(0)  # (batch, hidden_size)

Full I2A Architecture

class I2A(nn.Module):
    """Imagination-Augmented Agent"""

    def __init__(self, obs_channels, num_actions, rollout_len=5,
                 hidden_size=256):
        super().__init__()
        self.num_actions = num_actions
        self.rollout_len = rollout_len

        # Environment model (pretrained)
        self.env_model = AtariEnvironmentModel(num_actions, obs_channels)

        # Rollout policy (pretrained)
        self.rollout_policy = RolloutPolicy(obs_channels, num_actions)

        # Shared rollout encoder (applied to each action's imagined trajectory)
        self.rollout_encoder = RolloutEncoder(obs_channels, hidden_size)

        # Model-free path
        self.model_free_path = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        self._mf_size = self._get_mf_size(obs_channels)

        # Policy/value output after combination
        combined_size = self._mf_size + hidden_size * num_actions
        self.policy_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def _get_mf_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.model_free_path(x).shape[1]

    def imagine(self, state):
        """Perform imagination rollout for each action"""
        batch_size = state.shape[0]
        all_encodings = []

        for action_idx in range(self.num_actions):
            imagined_frames = []
            imagined_rewards = []
            current = state

            for step in range(self.rollout_len):
                # First step uses specified action, subsequent steps use rollout policy
                if step == 0:
                    action = torch.full(
                        (batch_size,), action_idx,
                        dtype=torch.long, device=state.device,
                    )
                else:
                    action = self.rollout_policy.get_action(current)

                # Predict next state with environment model
                with torch.no_grad():
                    next_frame, reward = self.env_model(current, action)

                imagined_frames.append(next_frame)
                imagined_rewards.append(reward)
                current = next_frame

            # Rollout encoding
            frames_stack = torch.stack(imagined_frames, dim=1)
            rewards_stack = torch.stack(imagined_rewards, dim=1)
            encoding = self.rollout_encoder(frames_stack, rewards_stack)
            all_encodings.append(encoding)

        # Combine imagination results from all actions
        return torch.cat(all_encodings, dim=-1)

    def forward(self, state):
        # Model-free features
        mf_features = self.model_free_path(state)

        # Imagination-based features
        imagination_features = self.imagine(state)

        # Combine
        combined = torch.cat([mf_features, imagination_features], dim=-1)

        policy_logits = self.policy_head(combined)
        value = self.value_head(combined)

        return policy_logits, value

I2A Training Procedure

Three-Stage Training

def train_i2a(env, config):
    """I2A three-stage training"""

    # === Stage 1: Environment model training ===
    print("Stage 1: Training environment model")
    env_model = AtariEnvironmentModel(
        config['num_actions'], config['obs_channels']
    )
    env_model_optimizer = torch.optim.Adam(
        env_model.parameters(), lr=1e-3
    )

    # Collect data with random policy then train
    transitions = collect_random_transitions(env, num_steps=50000)
    for epoch in range(config['env_model_epochs']):
        loss = train_env_model_epoch(env_model, transitions,
                                      env_model_optimizer)
        if epoch % 10 == 0:
            print(f"  Epoch {epoch}: EnvModel Loss={loss:.4f}")

    # === Stage 2: Rollout policy training (simple A2C) ===
    print("Stage 2: Training rollout policy")
    rollout_policy = RolloutPolicy(
        config['obs_channels'], config['num_actions']
    )
    train_a2c_policy(env, rollout_policy, num_steps=100000)

    # === Stage 3: Full I2A training ===
    print("Stage 3: Training I2A")
    i2a = I2A(config['obs_channels'], config['num_actions'])
    i2a.env_model = env_model  # Pretrained environment model
    i2a.rollout_policy = rollout_policy  # Pretrained rollout policy

    # Freeze environment model and rollout policy
    for param in i2a.env_model.parameters():
        param.requires_grad = False
    for param in i2a.rollout_policy.parameters():
        param.requires_grad = False

    # Train remaining parameters with A2C
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, i2a.parameters()),
        lr=7e-4
    )
    train_a2c_with_model(env, i2a, optimizer,
                         num_steps=config['total_steps'])

    return i2a

def train_env_model_epoch(env_model, transitions, optimizer):
    """Train environment model for one epoch"""
    total_loss = 0
    for state, action, reward, next_state in transitions:
        pred_next, pred_reward = env_model(state, action)

        frame_loss = nn.MSELoss()(pred_next, next_state)
        reward_loss = nn.MSELoss()(pred_reward, reward)
        loss = frame_loss + reward_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(transitions)

Atari Breakout Experimental Results

Environment Model Quality

How well the learned environment model predicts Breakout game frames is key to I2A performance:

  • 1-step prediction: Very accurate (low MSE)
  • 5-step prediction: Slight blurring in ball position
  • 15-step prediction: Significant uncertainty, but overall game situation remains understandable

I2A vs Baseline Performance

| Method | Breakout score |
| --- | --- |
| A2C (model-free) | ~400 |
| I2A (5-step rollouts) | ~500 |
| I2A + perfect environment model | ~600 |

I2A shows performance improvement even with an imperfect environment model. The key is that the rollout encoder naturally learns to handle model uncertainty.


Practical Considerations

Difficulties in Environment Model Training

  • Visual environments: Pixel-level prediction is difficult and expensive. Predicting in latent space is more efficient.
  • Stochastic environments: Deterministic models average out multimodal transitions. VAE or GAN-based models may be needed.
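A minimal sketch of the latent-space idea: encode the observation once, then roll the model forward entirely in the low-dimensional latent space instead of predicting pixels. The architecture sizes below are illustrative assumptions:

```python
import torch
import torch.nn as nn

class LatentDynamicsModel(nn.Module):
    """Encode pixels to a small latent vector, then predict transitions there."""
    def __init__(self, obs_channels=1, num_actions=4, latent_size=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 16, 8, stride=4), nn.ReLU(),
            nn.Conv2d(16, 32, 4, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(32, latent_size),
        )
        self.transition = nn.Sequential(
            nn.Linear(latent_size + num_actions, 128), nn.ReLU(),
            nn.Linear(128, latent_size),
        )
        self.num_actions = num_actions

    def rollout(self, obs, actions):
        """One encode, then cheap latent steps -- no pixel decoding needed."""
        z = self.encoder(obs)
        latents = [z]
        for a in actions:  # a: (batch,) long tensor of action indices
            a_onehot = nn.functional.one_hot(a, self.num_actions).float()
            z = z + self.transition(torch.cat([z, a_onehot], dim=-1))  # residual step
            latents.append(z)
        return latents

model = LatentDynamicsModel()
obs = torch.zeros(2, 1, 84, 84)
actions = [torch.zeros(2, dtype=torch.long)] * 5
latents = model.rollout(obs, actions)
print(len(latents), latents[-1].shape)  # 6 latent states of shape (2, 32)
```

Each imagined step is now a small MLP forward pass on a 32-dimensional vector rather than an 84x84 frame reconstruction, which is why World Models and Dreamer can afford long imagined rollouts.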

Recent Research Directions

  • World Models: Learn latent space with VAE and predict transitions with RNN
  • Dreamer: Unified framework for imagining and planning in latent space
  • MuZero: Learns only the latent dynamics needed for value prediction without explicitly learning an environment model

Key Takeaways

  • Model-based RL improves sample efficiency by learning environment models for imagination (rollouts)
  • Compounding error of the model is the main challenge, mitigated by short rollouts and ensembles
  • I2A encodes imagination results for multiple actions to inform policy decisions
  • A model-free path provides a safety net against model errors

In the next post, we will examine AlphaGo Zero, the most impressive example of model-based RL.