- Author: Youngju Kim (@fjvbn20031)
Overview
All algorithms covered so far were model-free methods. They learned through trial and error without knowing the environment's transition dynamics. But humans imagine "what would happen if I do this?" before acting. Model-based reinforcement learning enables such imagination (rollouts) by learning an environment model.
This post covers the differences between model-free and model-based approaches, the problem of model imperfection, and the Imagination-Augmented Agent (I2A).
Model-Free vs Model-Based
Key Differences
| Item | Model-Free | Model-Based |
|---|---|---|
| Environment knowledge | None | Learn/possess transition model |
| Sample efficiency | Low | High |
| Planning | Not possible | Possible (rollouts) |
| Model error impact | None | Present (causes bias) |
| Representative algorithms | DQN, PPO, A3C | Dyna, MBPO, I2A |
General Flow of Model-Based RL
- Collect data by interacting with the real environment
- Learn the environment model (transition function + reward function) from collected data
- Generate additional experiences using the learned model (imagination/rollouts)
- Improve policy using both real and imagined experiences
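The four steps above can be sketched as a Dyna-style loop. The toy chain environment and tabular model below are illustrative stand-ins, not part of I2A:

```python
import random

random.seed(0)

N = 5                                  # chain states 0..N-1; reward at the right end
Q = {(s, a): 0.0 for s in range(N) for a in (-1, 1)}
model = {}                             # learned table: (s, a) -> (next_s, reward)

def true_step(s, a):                   # real (unknown) dynamics
    ns = max(0, min(N - 1, s + a))
    return ns, 1.0 if ns == N - 1 else 0.0

def q_update(s, a, r, ns, alpha=0.5, gamma=0.9):
    best_next = max(Q[(ns, b)] for b in (-1, 1))
    Q[(s, a)] += alpha * (r + gamma * best_next - Q[(s, a)])

for episode in range(100):
    s = 0
    for _ in range(20):
        a = random.choice((-1, 1))     # 1. interact with the real environment
        ns, r = true_step(s, a)
        model[(s, a)] = (ns, r)        # 2. learn the model (tabular here)
        q_update(s, a, r, ns)          # 4. improve policy from real experience
        for _ in range(5):             # 3. imagined experience from the model
            ms, ma = random.choice(list(model))
            mns, mr = model[(ms, ma)]
            q_update(ms, ma, mr, mns)
        s = ns

print(f"Q(3, right)={Q[(3, 1)]:.2f}  Q(3, left)={Q[(3, -1)]:.2f}")
```

The planning inner loop reuses the model several times per real step, which is where the sample-efficiency gain comes from.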
```python
import torch
import torch.nn as nn
import numpy as np


class EnvironmentModel(nn.Module):
    """Learnable environment model: predicts state transitions and rewards"""

    def __init__(self, obs_size, act_size, hidden_size=256):
        super().__init__()
        self.transition = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, obs_size),
        )
        self.reward_pred = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state, action):
        """Predict next state and reward"""
        if action.dim() == 1:
            action = action.unsqueeze(-1)
        sa = torch.cat([state, action.float()], dim=-1)
        next_state = state + self.transition(sa)  # Residual learning
        reward = self.reward_pred(sa)
        return next_state, reward.squeeze(-1)
```
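With such a model in hand, imagined rollouts (step 3 of the flow above) replace real environment steps. A minimal sketch, where `TinyModel` and the random placeholder policy are hypothetical stand-ins:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Minimal stand-in for a learned environment model."""
    def __init__(self, obs_size, act_size):
        super().__init__()
        self.net = nn.Linear(obs_size + act_size, obs_size + 1)

    def forward(self, state, action):
        out = self.net(torch.cat([state, action], dim=-1))
        # Residual next-state prediction plus a scalar reward
        return state + out[..., :-1], out[..., -1]

def imagine_rollout(model, policy, state, horizon=5):
    """Unroll the model for `horizon` steps without touching the real env."""
    states, rewards = [], []
    with torch.no_grad():
        for _ in range(horizon):
            action = policy(state)
            state, reward = model(state, action)
            states.append(state)
            rewards.append(reward)
    return torch.stack(states, dim=1), torch.stack(rewards, dim=1)

model = TinyModel(obs_size=4, act_size=2)
policy = lambda s: torch.randn(s.shape[0], 2)  # placeholder rollout policy
states, rewards = imagine_rollout(model, policy, torch.zeros(8, 4))
print(states.shape, rewards.shape)  # torch.Size([8, 5, 4]) torch.Size([8, 5])
```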
Model Imperfection Problem
Why Models Are Not Perfect
Learned environment models inevitably have errors:
- Data limitations: Predictions are inaccurate for unvisited state-action pairs
- Compounding error: Errors accumulate as rollouts get longer
- Distribution shift: The data distribution the model was trained on differs from what the policy visits
```python
def demonstrate_compounding_error():
    """Demonstrate compounding error: how 1-step errors accumulate"""
    true_state = np.array([0.0, 0.0])
    predicted_state = np.array([0.0, 0.0])
    errors = []
    for step in range(50):
        # True transition (example)
        true_state = true_state + np.array([0.1, 0.05])
        # Model transition (with slight error)
        model_error = np.random.normal(0, 0.01, 2)
        predicted_state = predicted_state + np.array([0.1, 0.05]) + model_error
        error = np.linalg.norm(true_state - predicted_state)
        errors.append(error)
    # Error grows roughly in proportion to sqrt(t)
    print(f"Error at 10 steps: {errors[9]:.4f}")
    print(f"Error at 50 steps: {errors[49]:.4f}")
```
Mitigation Strategies
- Short rollouts: Keep model predictions short to limit compounding error
- Ensemble models: Leverage uncertainty from multiple models
- Mix model and real experiences: Continue using real data in Dyna style
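The ensemble strategy can be sketched directly: train several independent models and read their disagreement as an uncertainty estimate. The `EnsembleModel` class and sizes below are illustrative:

```python
import torch
import torch.nn as nn

class EnsembleModel(nn.Module):
    """Ensemble of transition models; disagreement estimates uncertainty."""
    def __init__(self, obs_size, act_size, n_models=5):
        super().__init__()
        self.models = nn.ModuleList(
            nn.Sequential(
                nn.Linear(obs_size + act_size, 64),
                nn.ReLU(),
                nn.Linear(64, obs_size),
            )
            for _ in range(n_models)
        )

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        preds = torch.stack([m(sa) for m in self.models])  # (n_models, batch, obs)
        mean = preds.mean(dim=0)
        # High std across members marks state-actions where the model is unreliable
        uncertainty = preds.std(dim=0).mean(dim=-1)
        return mean, uncertainty

ens = EnsembleModel(obs_size=4, act_size=2)
mean, unc = ens(torch.randn(8, 4), torch.randn(8, 2))
print(mean.shape, unc.shape)  # torch.Size([8, 4]) torch.Size([8])
```

Rollouts can then be truncated, or imagined samples down-weighted, wherever `uncertainty` is large.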
Imagination-Augmented Agent (I2A)
I2A is a model-based RL architecture proposed by DeepMind in 2017 that leverages imagination while accounting for model imperfection.
Core Components of I2A
- Environment Model: Predicts state transitions
- Rollout Policy: Selects actions during imagination
- Rollout Encoder: Compresses imagined trajectories into fixed-length vectors
- Model-Free Path: Extracts features directly from state without the environment model
Environment Model
For visual environments like Atari, a convolutional environment model predicts the next frame and reward (a simplified version of the paper's action-conditioned architecture):
```python
class AtariEnvironmentModel(nn.Module):
    """Atari environment model: predicts next frame and reward"""

    def __init__(self, num_actions, channels=1):
        super().__init__()
        self.num_actions = num_actions
        # Expand action to an image-sized channel
        self.action_embed = nn.Embedding(num_actions, 84 * 84)
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(channels + 1, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.ReLU(),
        )
        # Decoder (next frame generation)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),
        )
        # Reward prediction
        self.reward_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, state, action):
        """Predict next state/reward from current state and action"""
        batch_size = state.shape[0]
        # Convert action to image format
        action_plane = self.action_embed(action).view(batch_size, 1, 84, 84)
        # Combine state + action
        x = torch.cat([state, action_plane], dim=1)
        # Encoding
        encoded = self.encoder(x)
        # Next frame prediction
        next_frame = self.decoder(encoded)
        # Reward prediction
        reward = self.reward_head(encoded)
        return next_frame, reward.squeeze(-1)
```
Rollout Policy
A policy that selects actions during imagination. Usually a pretrained simple policy or a copy of the current policy:
```python
class RolloutPolicy(nn.Module):
    """Policy used for imagination rollouts"""

    def __init__(self, obs_channels, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
        )
        self._feature_size = self._get_conv_size(obs_channels)
        self.fc = nn.Sequential(
            nn.Linear(self._feature_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def _get_conv_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.conv(x).view(1, -1).shape[1]

    def forward(self, obs):
        features = self.conv(obs).view(obs.size(0), -1)
        return self.fc(features)

    def get_action(self, obs):
        logits = self.forward(obs)
        return torch.distributions.Categorical(logits=logits).sample()
```
Rollout Encoder
Summarizes imagined trajectories (state-reward sequences) into fixed-length vectors:
```python
class RolloutEncoder(nn.Module):
    """Encode imagined trajectories into fixed-length vectors"""

    def __init__(self, obs_channels, hidden_size=256):
        super().__init__()
        # Feature extraction for each frame
        self.frame_encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
        )
        self._frame_feature_size = 64 * 4 * 4
        # Sequence encoding (LSTM)
        self.lstm = nn.LSTM(
            self._frame_feature_size + 1,  # Frame features + reward
            hidden_size,
            batch_first=True,
        )
        self.output_size = hidden_size

    def forward(self, imagined_frames, imagined_rewards):
        """
        imagined_frames: (batch, rollout_len, C, H, W)
        imagined_rewards: (batch, rollout_len)
        """
        batch, rollout_len = imagined_frames.shape[:2]
        # Encode each frame
        frames_flat = imagined_frames.view(-1, *imagined_frames.shape[2:])
        frame_features = self.frame_encoder(frames_flat)
        frame_features = frame_features.view(batch, rollout_len, -1)
        # Combine with rewards
        rewards = imagined_rewards.unsqueeze(-1)
        lstm_input = torch.cat([frame_features, rewards], dim=-1)
        # LSTM sequence encoding
        _, (hidden, _) = self.lstm(lstm_input)
        return hidden.squeeze(0)  # (batch, hidden_size)
```
Full I2A Architecture
```python
class I2A(nn.Module):
    """Imagination-Augmented Agent"""

    def __init__(self, obs_channels, num_actions, rollout_len=5,
                 hidden_size=256):
        super().__init__()
        self.num_actions = num_actions
        self.rollout_len = rollout_len
        # Environment model (pretrained)
        self.env_model = AtariEnvironmentModel(num_actions, obs_channels)
        # Rollout policy (pretrained)
        self.rollout_policy = RolloutPolicy(obs_channels, num_actions)
        # Rollout encoder for each action
        self.rollout_encoder = RolloutEncoder(obs_channels, hidden_size)
        # Model-free path
        self.model_free_path = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        self._mf_size = self._get_mf_size(obs_channels)
        # Policy/value output after combination
        combined_size = self._mf_size + hidden_size * num_actions
        self.policy_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def _get_mf_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.model_free_path(x).shape[1]

    def imagine(self, state):
        """Perform an imagination rollout for each action"""
        batch_size = state.shape[0]
        all_encodings = []
        for action_idx in range(self.num_actions):
            imagined_frames = []
            imagined_rewards = []
            current = state
            for step in range(self.rollout_len):
                # First step uses the specified action,
                # subsequent steps use the rollout policy
                if step == 0:
                    action = torch.full(
                        (batch_size,), action_idx, dtype=torch.long
                    )
                else:
                    action = self.rollout_policy.get_action(current)
                # Predict the next state with the environment model
                with torch.no_grad():
                    next_frame, reward = self.env_model(current, action)
                imagined_frames.append(next_frame)
                imagined_rewards.append(reward)
                current = next_frame
            # Rollout encoding
            frames_stack = torch.stack(imagined_frames, dim=1)
            rewards_stack = torch.stack(imagined_rewards, dim=1)
            encoding = self.rollout_encoder(frames_stack, rewards_stack)
            all_encodings.append(encoding)
        # Combine imagination results from all actions
        return torch.cat(all_encodings, dim=-1)

    def forward(self, state):
        # Model-free features
        mf_features = self.model_free_path(state)
        # Imagination-based features
        imagination_features = self.imagine(state)
        # Combine
        combined = torch.cat([mf_features, imagination_features], dim=-1)
        policy_logits = self.policy_head(combined)
        value = self.value_head(combined)
        return policy_logits, value
```
I2A Training Procedure
Three-Stage Training
```python
def train_i2a(env, config):
    """I2A three-stage training"""
    # === Stage 1: Environment model training ===
    print("Stage 1: Training environment model")
    env_model = AtariEnvironmentModel(
        config['num_actions'], config['obs_channels']
    )
    env_model_optimizer = torch.optim.Adam(env_model.parameters(), lr=1e-3)
    # Collect data with a random policy, then train
    transitions = collect_random_transitions(env, num_steps=50000)
    for epoch in range(config['env_model_epochs']):
        loss = train_env_model_epoch(env_model, transitions,
                                     env_model_optimizer)
        if epoch % 10 == 0:
            print(f"  Epoch {epoch}: EnvModel Loss={loss:.4f}")

    # === Stage 2: Rollout policy training (simple A2C) ===
    print("Stage 2: Training rollout policy")
    rollout_policy = RolloutPolicy(
        config['obs_channels'], config['num_actions']
    )
    train_a2c_policy(env, rollout_policy, num_steps=100000)

    # === Stage 3: Full I2A training ===
    print("Stage 3: Training I2A")
    i2a = I2A(config['obs_channels'], config['num_actions'])
    i2a.env_model = env_model            # Pretrained environment model
    i2a.rollout_policy = rollout_policy  # Pretrained rollout policy
    # Freeze environment model and rollout policy
    for param in i2a.env_model.parameters():
        param.requires_grad = False
    for param in i2a.rollout_policy.parameters():
        param.requires_grad = False
    # Train remaining parameters with A2C
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, i2a.parameters()),
        lr=7e-4
    )
    train_a2c_with_model(env, i2a, optimizer,
                         num_steps=config['total_steps'])
    return i2a


def train_env_model_epoch(env_model, transitions, optimizer):
    """Train environment model for one epoch"""
    total_loss = 0
    for state, action, reward, next_state in transitions:
        pred_next, pred_reward = env_model(state, action)
        frame_loss = nn.MSELoss()(pred_next, next_state)
        reward_loss = nn.MSELoss()(pred_reward, reward)
        loss = frame_loss + reward_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(transitions)
```
Atari Breakout Experimental Results
Environment Model Quality
How well the learned environment model predicts Breakout game frames is key to I2A performance:
- 1-step prediction: Very accurate (low MSE)
- 5-step prediction: Slight blurring in ball position
- 15-step prediction: Significant uncertainty, but overall game situation remains understandable
I2A vs Baseline Performance
| Method | Breakout Score |
|---|---|
| A2C (model-free) | ~400 |
| I2A (5-step rollout) | ~500 |
| I2A + perfect environment model | ~600 |
I2A shows performance improvement even with an imperfect environment model. The key is that the rollout encoder naturally learns to handle model uncertainty.
Practical Considerations
Difficulties in Environment Model Training
- Visual environments: Pixel-level prediction is difficult and expensive. Predicting in latent space is more efficient.
- Stochastic environments: Deterministic models average out multimodal transitions. VAE or GAN-based models may be needed.
Recent Research Directions
- World Models: Learn latent space with VAE and predict transitions with RNN
- Dreamer: Unified framework for imagining and planning in latent space
- MuZero: Learns only the latent dynamics needed to predict values, policies, and rewards, without reconstructing observations
Key Takeaways
- Model-based RL improves sample efficiency by learning environment models for imagination (rollouts)
- Compounding error of the model is the main challenge, mitigated by short rollouts and ensembles
- I2A encodes imagination results for multiple actions to inform policy decisions
- A model-free path provides a safety net against model errors
In the next post, we will examine AlphaGo Zero, the most impressive example of model-based RL.