[Deep RL] 17. Model-Based Reinforcement Learning: Imagination-Augmented Agent

Overview

All algorithms covered so far were model-free methods. They learned through trial and error without knowing the environment's transition dynamics. But humans imagine "what would happen if I do this?" before acting. Model-based reinforcement learning enables such imagination (rollouts) by learning an environment model.

This post covers the differences between model-free and model-based approaches, the problem of model imperfection, and I2A (Imagination-Augmented Agent).


Model-Free vs Model-Based

Key Differences

| Item | Model-Free | Model-Based |
|---|---|---|
| Environment knowledge | None | Learns (or is given) a transition model |
| Sample efficiency | Low | High |
| Planning | Not possible | Possible (rollouts) |
| Model error impact | None | Present (causes bias) |
| Representative algorithms | DQN, PPO, A3C | Dyna, MBPO, I2A |

General Flow of Model-Based RL

  1. Collect data by interacting with the real environment
  2. Learn the environment model (transition function + reward function) from collected data
  3. Generate additional experiences using the learned model (imagination/rollouts)
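  4. Improve the policy using both real and imagined experiences

A minimal environment model for vector observations predicts the change in state (residual learning) and the reward:
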
import torch
import torch.nn as nn
import numpy as np

class EnvironmentModel(nn.Module):
    """Learnable environment model: predicts state transitions and rewards"""

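    def __init__(self, obs_size, act_size, hidden_size=256):
        super().__init__()
        self.act_size = act_size
        self.transition = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, obs_size),
        )
        self.reward_pred = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state, action):
        """Predict next state and reward.

        action: float tensor (batch, act_size) for continuous actions,
        or long tensor (batch,) of action indices for discrete actions.
        """
        if action.dtype == torch.long:
            # Discrete action index -> one-hot vector of length act_size
            action = nn.functional.one_hot(action, self.act_size)
        elif action.dim() == 1:
            # Single continuous action dimension -> (batch, 1)
            action = action.unsqueeze(-1)
        sa = torch.cat([state, action.float()], dim=-1)
        next_state = state + self.transition(sa)  # Residual learning: predict the state change
        reward = self.reward_pred(sa)
        return next_state, reward.squeeze(-1)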
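The four steps above appear in their smallest form in tabular Dyna-Q, the classic algorithm behind "Dyna style" training. The sketch below is illustrative only: the corridor environment (`step`), the table-lookup model, and all hyperparameters are hypothetical choices that are not part of this post's code, and the model simply memorizes the last observed outcome, which is only appropriate for a deterministic environment.

```python
import random

N_STATES, GOAL = 6, 5      # hypothetical corridor: states 0..5, reward on reaching 5
ACTIONS = [-1, +1]         # move left / move right

def step(s, a):
    """Hypothetical deterministic environment: move along the corridor."""
    s2 = min(max(s + a, 0), N_STATES - 1)
    return s2, (1.0 if s2 == GOAL else 0.0), s2 == GOAL

def dyna_q(episodes=30, planning_steps=50, alpha=0.5, gamma=0.95, eps=0.1, seed=0):
    rng = random.Random(seed)
    Q = {(s, a): 0.0 for s in range(N_STATES) for a in ACTIONS}
    model = {}  # learned model: (s, a) -> (reward, next_state, done)

    def greedy(s):
        best = max(Q[(s, a)] for a in ACTIONS)
        return rng.choice([a for a in ACTIONS if Q[(s, a)] == best])

    def update(s, a, r, s2, done):
        # One-step Q-learning update, used for both real and imagined data
        target = r if done else r + gamma * max(Q[(s2, b)] for b in ACTIONS)
        Q[(s, a)] += alpha * (target - Q[(s, a)])

    for _ in range(episodes):
        s, done = 0, False
        while not done:
            a = rng.choice(ACTIONS) if rng.random() < eps else greedy(s)
            s2, r, done = step(s, a)              # 1. interact with the real environment
            model[(s, a)] = (r, s2, done)         # 2. update the model from real data
            update(s, a, r, s2, done)             # 4. learn from real experience
            for _ in range(planning_steps):       # 3. generate imagined experience
                ps, pa = rng.choice(list(model))
                pr, ps2, pdone = model[(ps, pa)]
                update(ps, pa, pr, ps2, pdone)    # 4. learn from imagined experience
            s = s2
    return Q

Q = dyna_q()
print([max(ACTIONS, key=lambda a: Q[(s, a)]) for s in range(GOAL)])
```

With `planning_steps=0` this reduces to plain Q-learning; the planning loop is where transitions replayed from the learned model substitute for extra real interaction.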

Model Imperfection Problem

Why Models Are Not Perfect

Learned environment models inevitably have errors:

  1. Data limitations: Predictions are inaccurate for unvisited state-action pairs
  2. Compounding error: Errors accumulate as rollouts get longer
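  3. Distribution shift: The model is trained on states visited by earlier policies, while the improved policy may steer into states the model has rarely seen

A toy simulation of how small one-step errors accumulate:
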
def demonstrate_compounding_error():
    """Demonstrate compounding error: how 1-step errors accumulate"""
    true_state = np.array([0.0, 0.0])
    predicted_state = np.array([0.0, 0.0])
    errors = []

    for step in range(50):
        # True transition (example)
        true_state = true_state + np.array([0.1, 0.05])

        # Model transition (with slight error)
        model_error = np.random.normal(0, 0.01, 2)
        predicted_state = predicted_state + np.array([0.1, 0.05]) + model_error

        error = np.linalg.norm(true_state - predicted_state)
        errors.append(error)

    # Independent zero-mean noise at each step makes the error a random walk,
    # so its expected size grows roughly like sqrt(t)
    print(f"Error at 10 steps: {errors[9]:.4f}")
    print(f"Error at 50 steps: {errors[49]:.4f}")
    return errors
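In the toy above the noise is independent at every step, so the error only accumulates like a random walk. In a real open-loop rollout the model consumes its own predictions, so an error in the learned dynamics is also propagated, and possibly amplified, by those dynamics. A minimal sketch, assuming a hypothetical 2-D linear system with one slowly growing mode and a learned model that is off by 0.01 on that mode:

```python
import numpy as np

def compounding_through_dynamics(steps=50):
    """Open-loop rollout of a slightly wrong linear model (hypothetical system)."""
    A = np.array([[1.02, 0.00],          # true dynamics: one slowly growing mode
                  [0.00, 0.95]])         # and one decaying mode
    A_hat = A + np.array([[0.01, 0.0],   # learned model: 0.01 error on the
                          [0.00, 0.0]])  # growing mode only
    x_true = np.array([1.0, 1.0])
    x_pred = x_true.copy()
    errors = []
    for _ in range(steps):
        x_true = A @ x_true
        x_pred = A_hat @ x_pred          # the model is fed its own previous prediction
        errors.append(np.linalg.norm(x_true - x_pred))
    return errors

errors = compounding_through_dynamics()
print(f"Error at 10 steps: {errors[9]:.4f}")
print(f"Error at 50 steps: {errors[49]:.4f}")
```

Here the error after t steps is 1.03^t - 1.02^t, which grows geometrically rather than like sqrt(t). This is why short rollouts are the first line of defense.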

Mitigation Strategies

  • Short rollouts: Keep model predictions short to limit compounding error
  • Ensemble models: Leverage uncertainty from multiple models
  • Mix model and real experiences: Continue using real data in Dyna style
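Ensemble disagreement and short rollouts can be combined in a few lines. This sketch assumes a hypothetical 1-D system `s' = 0.8 s + noise` observed only for `s` in [-1, 1], an ensemble of 10 linear models each fit with `numpy.polyfit` on a bootstrap resample, and an arbitrary disagreement threshold of 0.02. Members agree where there is data and diverge away from it, and the rollout is cut off as soon as they diverge:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical true dynamics, observed only for states in [-1, 1]
s = rng.uniform(-1.0, 1.0, size=200)
s_next = 0.8 * s + rng.normal(0.0, 0.05, size=200)

# Ensemble: each member is a linear fit on a bootstrap resample of the data
ensemble = []
for _ in range(10):
    idx = rng.integers(0, len(s), size=len(s))
    ensemble.append(np.polyfit(s[idx], s_next[idx], deg=1))  # [slope, intercept]

def predict(state):
    """Return the ensemble mean prediction and the disagreement (std across members)."""
    preds = np.array([np.polyval(coef, state) for coef in ensemble])
    return preds.mean(), preds.std()

def imagine(state, max_len=10, threshold=0.02):
    """Roll out with the ensemble mean; stop as soon as members disagree too much."""
    trajectory = [state]
    for _ in range(max_len):
        mean, disagreement = predict(trajectory[-1])
        if disagreement > threshold:
            break  # the model is unreliable here: truncate the rollout
        trajectory.append(mean)
    return trajectory

print(predict(0.5)[1], predict(20.0)[1])   # disagreement near vs. far from the data
print(len(imagine(0.5)), len(imagine(20.0)))
```

Methods such as MBPO use neural-network ensembles with short rollouts branched from real states; the threshold rule here is just one simple way to turn disagreement into a rollout length.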

Imagination-Augmented Agent (I2A)

I2A is a model-based RL architecture proposed by DeepMind in 2017 that leverages imagination while accounting for model imperfection.

Core Components of I2A

  1. Environment Model: Predicts the next state and reward for a given state and action
  2. Rollout Policy: Selects actions during imagination
  3. Rollout Encoder: Compresses imagined trajectories into fixed-length vectors
  4. Model-Free Path: Extracts features directly from state without the environment model
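The data flow through these four components can be traced with toy stand-ins before looking at the PyTorch version. Everything here is hypothetical: a 1-D state, a hand-written model, a fixed rollout policy, an "encoder" that just returns the imagined rewards from last step to first, and a model-free path that returns the raw state. Only the shape of the computation matters: one imagined rollout per action, each compressed to a fixed-length code, concatenated with the model-free features.

```python
import numpy as np

NUM_ACTIONS, ROLLOUT_LEN = 3, 4  # hypothetical sizes

def env_model(state, action):
    """Toy environment model: action 0/1/2 shifts the state by -0.5/0/+0.5;
    the reward is higher the closer the new state is to 0."""
    next_state = state + (action - 1) * 0.5
    return next_state, -abs(next_state)

def rollout_policy(state):
    """Toy rollout policy: move toward 0, stay once there."""
    if state > 0:
        return 0
    if state < 0:
        return 2
    return 1

def rollout_encoder(states, rewards):
    """Toy encoder: a fixed-length code made of the rewards, last step first.
    (A real encoder would be a learned network over states and rewards.)"""
    return np.array(rewards[::-1])

def model_free_path(state):
    """Toy model-free features: just the raw state."""
    return np.array([state])

def i2a_features(state):
    codes = []
    for first_action in range(NUM_ACTIONS):      # one imagined rollout per action
        s, states, rewards = state, [], []
        for t in range(ROLLOUT_LEN):
            # First step takes the action being evaluated, then the rollout policy acts
            a = first_action if t == 0 else rollout_policy(s)
            s, r = env_model(s, a)
            states.append(s)
            rewards.append(r)
        codes.append(rollout_encoder(states, rewards))
    # Model-free features followed by one code per action
    return np.concatenate([model_free_path(state)] + codes)

features = i2a_features(1.0)
print(features.shape)
```

The output length is `1 + NUM_ACTIONS * ROLLOUT_LEN`; in the PyTorch class the same concatenation yields `mf_size + hidden_size * num_actions` features.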

Environment Model

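For visual environments like Atari, the environment model is a convolutional network that takes the current frame and the action and predicts the next frame and the reward. The simplified version below is a feed-forward encoder-decoder, with the action injected as an extra 84x84 input plane: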

class AtariEnvironmentModel(nn.Module):
    """Atari environment model: predicts next frame and reward"""

    def __init__(self, num_actions, channels=1):
        super().__init__()
        self.num_actions = num_actions

        # Expand action to image-sized channel
        self.action_embed = nn.Embedding(num_actions, 84 * 84)

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(channels + 1, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Decoder (next frame generation)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),
        )

        # Reward prediction
        self.reward_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, state, action):
        """Predict next state/reward from current state and action"""
        batch_size = state.shape[0]

        # Convert action to image format
        action_plane = self.action_embed(action).view(
            batch_size, 1, 84, 84
        )

        # Combine state + action
        x = torch.cat([state, action_plane], dim=1)

        # Encoding
        encoded = self.encoder(x)

        # Next frame prediction
        next_frame = self.decoder(encoded)

        # Reward prediction
        reward = self.reward_head(encoded)

        return next_frame, reward.squeeze(-1)
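The encoder and decoder strides have to cancel out for the predicted frame to come back at 84x84, otherwise the frame loss fails with a shape mismatch. A quick check using the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1, no output padding), applied to the layer settings in the class above:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

h = 84
for k, s, p in [(4, 2, 1), (4, 2, 1), (3, 1, 1)]:   # encoder layers
    h = conv_out(h, k, s, p)
encoded = h                                        # 84 -> 42 -> 21 -> 21
for k, s, p in [(4, 2, 1), (4, 2, 1)]:             # decoder layers
    h = deconv_out(h, k, s, p)                     # 21 -> 42 -> 84
print(encoded, h)
```

The round trip works because 84 is divisible by 4; an 85-pixel input would also encode to 21x21 but decode back to 84, so the prediction would no longer match the target frame.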

Rollout Policy

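A policy that selects actions during imagination (after the first, fixed action). In the I2A paper it is a small network trained by distillation to imitate the full agent's policy; simpler options are a separately pretrained policy (as in the training procedure below) or a copy of the current policy: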

class RolloutPolicy(nn.Module):
    """Policy used for imagination rollouts"""

    def __init__(self, obs_channels, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
        )
        self._feature_size = self._get_conv_size(obs_channels)
        self.fc = nn.Sequential(
            nn.Linear(self._feature_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def _get_conv_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.conv(x).view(1, -1).shape[1]

    def forward(self, obs):
        features = self.conv(obs).view(obs.size(0), -1)
        return self.fc(features)

    def get_action(self, obs):
        logits = self.forward(obs)
        return torch.distributions.Categorical(
            logits=logits
        ).sample()
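In the I2A paper the rollout policy is kept close to the agent's policy by distillation: a cross-entropy auxiliary loss pulls the rollout policy's action distribution toward the full agent's. A minimal NumPy sketch of that loss term, assuming `agent_logits` are treated as fixed targets and `rollout_logits` come from the small rollout network; the logit values are made up for illustration:

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def distillation_loss(agent_logits, rollout_logits):
    """Cross-entropy H(pi_agent, pi_rollout), averaged over the batch.

    agent_logits act as targets only; in an autograd version they would be
    detached so that no gradient flows back into the agent.
    """
    target_probs = np.exp(log_softmax(agent_logits))
    return -(target_probs * log_softmax(rollout_logits)).sum(axis=-1).mean()

agent = np.array([[2.0, 0.5, -1.0, 0.0]])           # illustrative logits, 4 actions
print(distillation_loss(agent, agent))              # minimum: the agent's own entropy
print(distillation_loss(agent, np.zeros((1, 4))))   # uniform rollout policy: log(4)
```

The loss is smallest when the rollout policy matches the agent exactly, so minimizing it makes imagined trajectories resemble what the agent itself would do.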

Rollout Encoder

Summarizes imagined trajectories (state-reward sequences) into fixed-length vectors:

class RolloutEncoder(nn.Module):
    """Encode imagined trajectories into fixed-length vectors"""

    def __init__(self, obs_channels, hidden_size=256):
        super().__init__()

        # Feature extraction for each frame
        self.frame_encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
        )
        self._frame_feature_size = 64 * 4 * 4

        # Sequence encoding (LSTM)
        self.lstm = nn.LSTM(
            self._frame_feature_size + 1,  # Frame features + reward
            hidden_size,
            batch_first=True,
        )

        self.output_size = hidden_size

    def forward(self, imagined_frames, imagined_rewards):
        """
        imagined_frames: (batch, rollout_len, C, H, W)
        imagined_rewards: (batch, rollout_len)
        """
        batch, rollout_len = imagined_frames.shape[:2]

        # Encode each frame
        frames_flat = imagined_frames.view(-1, *imagined_frames.shape[2:])
        frame_features = self.frame_encoder(frames_flat)
        frame_features = frame_features.view(batch, rollout_len, -1)

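        # Combine with rewards
        rewards = imagined_rewards.unsqueeze(-1)
        lstm_input = torch.cat([frame_features, rewards], dim=-1)

        # Feed the trajectory in reverse order (last imagined step first),
        # as in the I2A paper, mimicking a Bellman-style backup
        lstm_input = torch.flip(lstm_input, dims=[1])

        # LSTM sequence encoding
        _, (hidden, _) = self.lstm(lstm_input)
        return hidden.squeeze(0)  # (batch, hidden_size)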

Full I2A Architecture

class I2A(nn.Module):
    """Imagination-Augmented Agent"""

    def __init__(self, obs_channels, num_actions, rollout_len=5,
                 hidden_size=256):
        super().__init__()
        self.num_actions = num_actions
        self.rollout_len = rollout_len

        # Environment model (pretrained)
        self.env_model = AtariEnvironmentModel(num_actions, obs_channels)

        # Rollout policy (pretrained)
        self.rollout_policy = RolloutPolicy(obs_channels, num_actions)

        # Rollout encoder for each action
        self.rollout_encoder = RolloutEncoder(obs_channels, hidden_size)

        # Model-free path
        self.model_free_path = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        self._mf_size = self._get_mf_size(obs_channels)

        # Policy/value output after combination
        combined_size = self._mf_size + hidden_size * num_actions
        self.policy_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def _get_mf_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.model_free_path(x).shape[1]

    def imagine(self, state):
        """Perform imagination rollout for each action"""
        batch_size = state.shape[0]
        all_encodings = []

        for action_idx in range(self.num_actions):
            imagined_frames = []
            imagined_rewards = []
            current = state

            for step in range(self.rollout_len):
                # The environment model and rollout policy are treated as fixed
                # (pretrained), so no gradients are needed to generate the rollout
                with torch.no_grad():
                    # First step uses the specified action, subsequent steps use the rollout policy
                    if step == 0:
                        action = torch.full(
                            (batch_size,), action_idx,
                            dtype=torch.long, device=state.device,
                        )
                    else:
                        action = self.rollout_policy.get_action(current)

                    # Predict next state with the environment model
                    next_frame, reward = self.env_model(current, action)

                imagined_frames.append(next_frame)
                imagined_rewards.append(reward)
                current = next_frame

            # Rollout encoding
            frames_stack = torch.stack(imagined_frames, dim=1)
            rewards_stack = torch.stack(imagined_rewards, dim=1)
            encoding = self.rollout_encoder(frames_stack, rewards_stack)
            all_encodings.append(encoding)

        # Combine imagination results from all actions
        return torch.cat(all_encodings, dim=-1)

    def forward(self, state):
        # Model-free features
        mf_features = self.model_free_path(state)

        # Imagination-based features
        imagination_features = self.imagine(state)

        # Combine
        combined = torch.cat([mf_features, imagination_features], dim=-1)

        policy_logits = self.policy_head(combined)
        value = self.value_head(combined)

        return policy_logits, value
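The width of `combined`, and therefore the input size of `policy_head` and `value_head`, follows from the layer settings: the model-free path maps an 84x84 frame to 64 x 9 x 9 features, and each action contributes one `hidden_size` encoding. Worked out for Breakout's 4 actions and the default `hidden_size=256`:

```python
def conv_out(size, kernel, stride):
    """Spatial output size of nn.Conv2d with no padding and dilation 1."""
    return (size - kernel) // stride + 1

h = conv_out(conv_out(84, 8, 4), 4, 2)   # 84 -> 20 -> 9
mf_size = 64 * h * h                     # channels x height x width after Flatten
num_actions, hidden_size = 4, 256
combined_size = mf_size + hidden_size * num_actions
print(h, mf_size, combined_size)
```

The imagination part scales linearly with the number of actions, both in feature width and in compute, since `imagine` runs one full rollout per action.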

I2A Training Procedure

Three-Stage Training

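def train_i2a(env, config):
    """I2A three-stage training.

    collect_random_transitions, train_a2c_policy and train_a2c_with_model are
    helper functions assumed to exist; they are not defined in this post.
    """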

    # === Stage 1: Environment model training ===
    print("Stage 1: Training environment model")
    env_model = AtariEnvironmentModel(
        config['num_actions'], config['obs_channels']
    )
    env_model_optimizer = torch.optim.Adam(
        env_model.parameters(), lr=1e-3
    )

    # Collect data with random policy then train
    transitions = collect_random_transitions(env, num_steps=50000)
    for epoch in range(config['env_model_epochs']):
        loss = train_env_model_epoch(env_model, transitions,
                                      env_model_optimizer)
        if epoch % 10 == 0:
            print(f"  Epoch {epoch}: EnvModel Loss={loss:.4f}")

    # === Stage 2: Rollout policy training (simple A2C) ===
    print("Stage 2: Training rollout policy")
    rollout_policy = RolloutPolicy(
        config['obs_channels'], config['num_actions']
    )
    train_a2c_policy(env, rollout_policy, num_steps=100000)

    # === Stage 3: Full I2A training ===
    print("Stage 3: Training I2A")
    i2a = I2A(config['obs_channels'], config['num_actions'])
    i2a.env_model = env_model  # Pretrained environment model
    i2a.rollout_policy = rollout_policy  # Pretrained rollout policy

    # Freeze environment model and rollout policy
    for param in i2a.env_model.parameters():
        param.requires_grad = False
    for param in i2a.rollout_policy.parameters():
        param.requires_grad = False

    # Train remaining parameters with A2C
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, i2a.parameters()),
        lr=7e-4
    )
    train_a2c_with_model(env, i2a, optimizer,
                         num_steps=config['total_steps'])

    return i2a

def train_env_model_epoch(env_model, transitions, optimizer):
    """Train environment model for one epoch"""
    total_loss = 0
    for state, action, reward, next_state in transitions:
        pred_next, pred_reward = env_model(state, action)

        frame_loss = nn.MSELoss()(pred_next, next_state)
        reward_loss = nn.MSELoss()(pred_reward, reward)
        loss = frame_loss + reward_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(transitions)
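`train_a2c_with_model` is left undefined above. One common form of the A2C loss it would minimize is sketched here as a NumPy forward computation only; the coefficient defaults (0.5 for the value term, 0.01 for the entropy bonus) are conventional choices assumed for illustration, not values from this post:

```python
import numpy as np

def a2c_loss(logits, values, actions, returns, value_coef=0.5, entropy_coef=0.01):
    """A2C loss for one batch (forward computation only).

    logits: (batch, num_actions), values: (batch,), actions: (batch,) ints,
    returns: (batch,) discounted n-step returns used as value targets.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    probs = np.exp(log_probs)
    # In an autograd version the advantages would be detached for the policy term
    advantages = returns - values
    chosen = log_probs[np.arange(len(actions)), actions]
    policy_loss = -(chosen * advantages).mean()
    value_loss = (advantages ** 2).mean()
    entropy = -(probs * log_probs).sum(axis=1).mean()
    return policy_loss + value_coef * value_loss - entropy_coef * entropy

# Uniform policy over 4 actions, value estimate 0, observed return 1
print(a2c_loss(np.zeros((1, 4)), np.array([0.0]), np.array([0]), np.array([1.0])))
```

For I2A, `logits` and `values` would come from `policy_head` and `value_head`, so only the unfrozen parameters (model-free path, rollout encoder, heads) receive gradients.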

Atari Breakout Experimental Results

Environment Model Quality

How well the learned environment model predicts Breakout game frames is key to I2A performance:

  • 1-step prediction: Very accurate (low MSE)
  • 5-step prediction: Slight blurring in ball position
  • 15-step prediction: Significant uncertainty, but overall game situation remains understandable

I2A vs Baseline Performance

| Method | Breakout score |
|---|---|
| A2C (model-free) | ~400 |
| I2A (5-step rollout) | ~500 |
| I2A + perfect environment model | ~600 |

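I2A improves on the model-free baseline even with an imperfect environment model. Because the rollout encoder is trained end to end on the policy and value objectives, it can learn which parts of the imagined trajectories are useful and down-weight unreliable predictions instead of treating them as ground truth.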


Practical Considerations

Difficulties in Environment Model Training

  • Visual environments: Pixel-level prediction is difficult and expensive. Predicting in latent space is more efficient.
  • Stochastic environments: A deterministic model trained with MSE predicts the average of multimodal outcomes, which shows up as blurry frames. VAE- or GAN-based models may be needed.
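Why deterministic models blur: under mean-squared error, the best single prediction is the mean of the possible outcomes. A minimal sketch with a hypothetical transition in which the ball moves to +1 or -1 with equal probability, searching a grid of candidate deterministic predictions:

```python
import numpy as np

rng = np.random.default_rng(0)
outcomes = rng.choice([-1.0, 1.0], size=10_000)   # hypothetical bimodal next positions

candidates = np.linspace(-1.0, 1.0, 201)          # possible deterministic predictions
mse = [np.mean((outcomes - c) ** 2) for c in candidates]
best = candidates[int(np.argmin(mse))]
print(best, outcomes.mean())
```

The MSE-optimal prediction lands near 0, a position the ball never actually occupies; in pixel space the same effect produces a faint, smeared ball instead of one sharp ball in either location.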

Recent Research Directions

  • World Models: Compress frames into a latent space with a VAE and predict latent transitions with an RNN
  • Dreamer: Learns a latent dynamics model and trains an actor-critic entirely on trajectories imagined in that latent space
  • MuZero: Learns a latent model trained only to predict rewards, values, and policies (not observations), and plans with it using Monte Carlo tree search

Key Takeaways

  • Model-based RL improves sample efficiency by learning environment models for imagination (rollouts)
  • Compounding error of the model is the main challenge, mitigated by short rollouts and ensembles
  • I2A encodes imagination results for multiple actions to inform policy decisions
  • A model-free path provides a safety net against model errors

In the next post, we will examine AlphaGo Zero, the most impressive example of model-based RL.