[Deep RL] 17. Model-Based Reinforcement Learning: Imagination-Augmented Agent
Overview
All algorithms covered so far were model-free methods. They learned through trial and error without knowing the environment's transition dynamics. But humans imagine "what would happen if I do this?" before acting. Model-based reinforcement learning enables such imagination (rollouts) by learning an environment model.
This post covers the differences between model-free and model-based approaches, the problem of model imperfection, and I2A (Imagination-Augmented Agent).
Model-Free vs Model-Based
Key Differences
| Item | Model-Free | Model-Based |
|---|---|---|
| Environment knowledge | None | Learn/possess transition model |
| Sample efficiency | Low | High |
| Planning | Not possible | Possible (rollouts) |
| Model error impact | None | Present (causes bias) |
| Representative algorithms | DQN, PPO, A3C | Dyna, MBPO, I2A |
General Flow of Model-Based RL
- Collect data by interacting with the real environment
- Learn the environment model (transition function + reward function) from collected data
- Generate additional experiences using the learned model (imagination/rollouts)
- Improve policy using both real and imagined experiences
```python
import torch
import torch.nn as nn
import numpy as np


class EnvironmentModel(nn.Module):
    """Learnable environment model: predicts state transitions and rewards"""

    def __init__(self, obs_size, act_size, hidden_size=256):
        super().__init__()
        self.transition = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, obs_size),
        )
        self.reward_pred = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state, action):
        """Predict next state and reward"""
        if action.dim() == 1:
            action = action.unsqueeze(-1)
        sa = torch.cat([state, action.float()], dim=-1)
        next_state = state + self.transition(sa)  # residual learning
        reward = self.reward_pred(sa)
        return next_state, reward.squeeze(-1)
```
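The model above predicts a residual (a state delta) rather than the raw next state, and imagined rollouts are built by feeding each prediction back in as the next input. A minimal standalone sketch of that chaining, using an untrained stand-in network rather than the `EnvironmentModel` itself:

```python
import torch
import torch.nn as nn

# Stand-in dynamics network (untrained); illustrates rollout chaining only.
torch.manual_seed(0)
obs_size, act_size = 4, 1
dynamics = nn.Linear(obs_size + act_size, obs_size)

state = torch.zeros(1, obs_size)
trajectory = [state]
for _ in range(3):                        # 3-step imagined rollout
    action = torch.zeros(1, act_size)     # placeholder action
    delta = dynamics(torch.cat([state, action], dim=-1))
    state = state + delta                 # residual prediction, as in the model above
    trajectory.append(state)
```

Each imagined step consumes the previous prediction, which is exactly why one-step errors compound over long horizons, as discussed next.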
Model Imperfection Problem
Why Models Are Not Perfect
Learned environment models inevitably have errors:
- Data limitations: Predictions are inaccurate for unvisited state-action pairs
- Compounding error: Errors accumulate as rollouts get longer
- Distribution shift: The data distribution the model was trained on differs from what the policy visits
```python
def demonstrate_compounding_error():
    """Demonstrate compounding error: how 1-step errors accumulate"""
    true_state = np.array([0.0, 0.0])
    predicted_state = np.array([0.0, 0.0])
    errors = []
    for step in range(50):
        # True transition (example)
        true_state = true_state + np.array([0.1, 0.05])
        # Model transition (with slight error)
        model_error = np.random.normal(0, 0.01, 2)
        predicted_state = predicted_state + np.array([0.1, 0.05]) + model_error
        error = np.linalg.norm(true_state - predicted_state)
        errors.append(error)
    # The error is a random walk, so it grows roughly in proportion to sqrt(t)
    print(f"Error at 10 steps: {errors[9]:.4f}")
    print(f"Error at 50 steps: {errors[49]:.4f}")
```
Mitigation Strategies
- Short rollouts: Keep model predictions short to limit compounding error
- Ensemble models: Leverage uncertainty from multiple models
- Mix model and real experiences: Continue using real data in Dyna style
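The ensemble idea can be sketched in isolation: run the same state through several independently perturbed models and treat their disagreement (the standard deviation of their predictions) as an uncertainty estimate. The linear models here are hypothetical stand-ins for bootstrap-trained ensemble members:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical ensemble: K copies of a linear dynamics model, each with
# slightly perturbed parameters (standing in for bootstrap-trained members).
K = 5
true_A = np.array([[1.0, 0.1], [0.0, 1.0]])
models = [true_A + rng.normal(0, 0.02, size=(2, 2)) for _ in range(K)]

def ensemble_step(state):
    """Average prediction, plus disagreement as an uncertainty proxy."""
    preds = np.stack([A @ state for A in models])   # (K, 2)
    return preds.mean(axis=0), preds.std(axis=0).sum()

state = np.array([1.0, 1.0])
uncertainties = []
for _ in range(20):
    state, u = ensemble_step(state)
    uncertainties.append(u)
# Disagreement grows along the rollout, flagging unreliable long horizons.
```

In practice the uncertainty signal is used to truncate rollouts or penalize imagined rewards in uncertain regions.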
Imagination-Augmented Agent (I2A)
I2A is a model-based RL architecture proposed by DeepMind in 2017 that leverages imagination while accounting for model imperfection.
Core Components of I2A
- Environment Model: Predicts state transitions
- Rollout Policy: Selects actions during imagination
- Rollout Encoder: Compresses imagined trajectories into fixed-length vectors
- Model-Free Path: Extracts features directly from state without the environment model
Environment Model
For visual environments like Atari, the environment model predicts the next frame and reward directly from pixels. The version below is a simplified convolutional sketch (the original paper builds on a more elaborate action-conditional CNN):
```python
class AtariEnvironmentModel(nn.Module):
    """Atari environment model: predicts next frame and reward"""

    def __init__(self, num_actions, channels=1):
        super().__init__()
        self.num_actions = num_actions
        # Expand the action into an image-sized plane
        self.action_embed = nn.Embedding(num_actions, 84 * 84)
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(channels + 1, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.ReLU(),
        )
        # Decoder (next frame generation)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),
        )
        # Reward prediction
        self.reward_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, state, action):
        """Predict next state/reward from the current state and action"""
        batch_size = state.shape[0]
        # Turn the action into an image-shaped plane
        action_plane = self.action_embed(action).view(
            batch_size, 1, 84, 84
        )
        # Concatenate state and action plane along the channel dimension
        x = torch.cat([state, action_plane], dim=1)
        # Encode
        encoded = self.encoder(x)
        # Predict the next frame
        next_frame = self.decoder(encoded)
        # Predict the reward
        reward = self.reward_head(encoded)
        return next_frame, reward.squeeze(-1)
```
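A quick arithmetic check that the encoder/decoder above preserves the 84×84 resolution. `conv_out` and `deconv_out` are just the standard PyTorch output-size formulas, not part of the model:

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of Conv2d: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel, stride, padding):
    """Spatial output size of ConvTranspose2d: (size - 1) * s - 2p + k."""
    return (size - 1) * stride - 2 * padding + kernel

h = conv_out(84, 4, 2, 1)    # 42 after the first stride-2 conv
h = conv_out(h, 4, 2, 1)     # 21 after the second; the 3x3 conv keeps it
h = deconv_out(h, 4, 2, 1)   # 42 after the first transposed conv
h = deconv_out(h, 4, 2, 1)   # 84: the decoder restores the input resolution
```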
Rollout Policy
The rollout policy chooses actions inside the imagined rollouts. It is usually a small, cheap policy, either pretrained or distilled from the agent's current policy:
```python
class RolloutPolicy(nn.Module):
    """Policy used for imagination rollouts"""

    def __init__(self, obs_channels, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
        )
        self._feature_size = self._get_conv_size(obs_channels)
        self.fc = nn.Sequential(
            nn.Linear(self._feature_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def _get_conv_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.conv(x).view(1, -1).shape[1]

    def forward(self, obs):
        features = self.conv(obs).view(obs.size(0), -1)
        return self.fc(features)

    def get_action(self, obs):
        logits = self.forward(obs)
        return torch.distributions.Categorical(
            logits=logits
        ).sample()
```
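`get_action` relies on `torch.distributions.Categorical` treating the network outputs as unnormalized logits. A standalone illustration of that sampling step:

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 0.5, -1.0]])           # one state, three actions
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()                              # shape (1,), values in {0, 1, 2}
log_prob = dist.log_prob(action)                    # useful for policy-gradient losses
probs = dist.probs                                  # softmax of the logits
```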
Rollout Encoder
Summarizes imagined trajectories (state-reward sequences) into fixed-length vectors:
```python
class RolloutEncoder(nn.Module):
    """Encode imagined trajectories into fixed-length vectors"""

    def __init__(self, obs_channels, hidden_size=256):
        super().__init__()
        # Per-frame feature extraction
        self.frame_encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
        )
        self._frame_feature_size = 64 * 4 * 4
        # Sequence encoding (LSTM)
        self.lstm = nn.LSTM(
            self._frame_feature_size + 1,  # frame features + reward
            hidden_size,
            batch_first=True,
        )
        self.output_size = hidden_size

    def forward(self, imagined_frames, imagined_rewards):
        """
        imagined_frames: (batch, rollout_len, C, H, W)
        imagined_rewards: (batch, rollout_len)
        """
        batch, rollout_len = imagined_frames.shape[:2]
        # Encode each frame
        frames_flat = imagined_frames.view(-1, *imagined_frames.shape[2:])
        frame_features = self.frame_encoder(frames_flat)
        frame_features = frame_features.view(batch, rollout_len, -1)
        # Append rewards to the per-frame features
        rewards = imagined_rewards.unsqueeze(-1)
        lstm_input = torch.cat([frame_features, rewards], dim=-1)
        # Summarize the sequence with the final LSTM hidden state
        _, (hidden, _) = self.lstm(lstm_input)
        return hidden.squeeze(0)  # (batch, hidden_size)
```
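The core trick of the encoder is that an LSTM's final hidden state is a fixed-size summary regardless of sequence length. In isolation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

for rollout_len in (3, 5, 10):
    seq = torch.randn(4, rollout_len, 8)   # (batch, rollout_len, features)
    _, (hidden, _) = lstm(seq)
    summary = hidden.squeeze(0)            # always (4, 16), whatever the length
```

This is what lets I2A use the same downstream heads no matter how long the imagined rollouts are.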
Full I2A Architecture
```python
class I2A(nn.Module):
    """Imagination-Augmented Agent"""

    def __init__(self, obs_channels, num_actions, rollout_len=5,
                 hidden_size=256):
        super().__init__()
        self.num_actions = num_actions
        self.rollout_len = rollout_len
        # Environment model (pretrained)
        self.env_model = AtariEnvironmentModel(num_actions, obs_channels)
        # Rollout policy (pretrained)
        self.rollout_policy = RolloutPolicy(obs_channels, num_actions)
        # Rollout encoder, shared across all actions
        self.rollout_encoder = RolloutEncoder(obs_channels, hidden_size)
        # Model-free path
        self.model_free_path = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        self._mf_size = self._get_mf_size(obs_channels)
        # Policy/value heads on the combined features
        combined_size = self._mf_size + hidden_size * num_actions
        self.policy_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def _get_mf_size(self, channels):
        with torch.no_grad():
            x = torch.zeros(1, channels, 84, 84)
            return self.model_free_path(x).shape[1]

    def imagine(self, state):
        """Perform an imagination rollout for each possible first action"""
        batch_size = state.shape[0]
        all_encodings = []
        for action_idx in range(self.num_actions):
            imagined_frames = []
            imagined_rewards = []
            current = state
            for step in range(self.rollout_len):
                # First step uses the specified action; later steps use the rollout policy
                if step == 0:
                    action = torch.full(
                        (batch_size,), action_idx,
                        dtype=torch.long, device=state.device,
                    )
                else:
                    action = self.rollout_policy.get_action(current)
                # Predict the next state with the environment model
                with torch.no_grad():
                    next_frame, reward = self.env_model(current, action)
                imagined_frames.append(next_frame)
                imagined_rewards.append(reward)
                current = next_frame
            # Encode this action's imagined trajectory
            frames_stack = torch.stack(imagined_frames, dim=1)
            rewards_stack = torch.stack(imagined_rewards, dim=1)
            encoding = self.rollout_encoder(frames_stack, rewards_stack)
            all_encodings.append(encoding)
        # Concatenate the imagination encodings from all actions
        return torch.cat(all_encodings, dim=-1)

    def forward(self, state):
        # Model-free features
        mf_features = self.model_free_path(state)
        # Imagination-based features
        imagination_features = self.imagine(state)
        # Combine both paths
        combined = torch.cat([mf_features, imagination_features], dim=-1)
        policy_logits = self.policy_head(combined)
        value = self.value_head(combined)
        return policy_logits, value
```
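How large is `combined_size`? With the conv stack above on 84×84 inputs, the model-free features flatten to 64·9·9 = 5184, and each action contributes one `hidden_size`-dimensional rollout encoding. A worked check, using a hypothetical `num_actions = 4`:

```python
def conv_out(size, kernel, stride, padding=0):
    """Spatial output size of Conv2d: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

h = conv_out(84, 8, 4)                 # 20 after Conv2d(_, 32, 8, stride=4)
h = conv_out(h, 4, 2)                  # 9 after Conv2d(32, 64, 4, stride=2)
mf_size = 64 * h * h                   # 5184 flattened model-free features

num_actions, hidden_size = 4, 256      # example values
combined_size = mf_size + hidden_size * num_actions   # 5184 + 1024 = 6208
```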
I2A Training Procedure
Three-Stage Training
```python
def train_i2a(env, config):
    """I2A three-stage training"""
    # === Stage 1: Environment model training ===
    print("Stage 1: Training environment model")
    env_model = AtariEnvironmentModel(
        config['num_actions'], config['obs_channels']
    )
    env_model_optimizer = torch.optim.Adam(
        env_model.parameters(), lr=1e-3
    )
    # Collect data with a random policy, then fit the model
    # (collect_random_transitions is assumed to be defined elsewhere)
    transitions = collect_random_transitions(env, num_steps=50000)
    for epoch in range(config['env_model_epochs']):
        loss = train_env_model_epoch(env_model, transitions,
                                     env_model_optimizer)
        if epoch % 10 == 0:
            print(f"  Epoch {epoch}: EnvModel Loss={loss:.4f}")
    # === Stage 2: Rollout policy training (simple A2C) ===
    print("Stage 2: Training rollout policy")
    rollout_policy = RolloutPolicy(
        config['obs_channels'], config['num_actions']
    )
    train_a2c_policy(env, rollout_policy, num_steps=100000)  # assumed helper
    # === Stage 3: Full I2A training ===
    print("Stage 3: Training I2A")
    i2a = I2A(config['obs_channels'], config['num_actions'])
    i2a.env_model = env_model            # pretrained environment model
    i2a.rollout_policy = rollout_policy  # pretrained rollout policy
    # Freeze the environment model and rollout policy
    for param in i2a.env_model.parameters():
        param.requires_grad = False
    for param in i2a.rollout_policy.parameters():
        param.requires_grad = False
    # Train the remaining parameters with A2C
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, i2a.parameters()),
        lr=7e-4
    )
    train_a2c_with_model(env, i2a, optimizer,
                         num_steps=config['total_steps'])  # assumed helper
    return i2a


def train_env_model_epoch(env_model, transitions, optimizer):
    """Train the environment model for one epoch"""
    total_loss = 0
    for state, action, reward, next_state in transitions:
        pred_next, pred_reward = env_model(state, action)
        frame_loss = nn.MSELoss()(pred_next, next_state)
        reward_loss = nn.MSELoss()(pred_reward, reward)
        loss = frame_loss + reward_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(transitions)
```
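`collect_random_transitions`, `train_a2c_policy`, and `train_a2c_with_model` are left undefined above. As one example, the data-collection helper might look like this sketch, assuming a Gymnasium-style environment API (`reset()` returning `(obs, info)`, `step()` returning `(obs, reward, terminated, truncated, info)`):

```python
def collect_random_transitions(env, num_steps):
    """Collect (state, action, reward, next_state) tuples with a random policy.

    A minimal sketch assuming a Gymnasium-style API; real code would also
    batch transitions into tensors before training the environment model.
    """
    transitions = []
    state, _ = env.reset()
    for _ in range(num_steps):
        action = env.action_space.sample()
        next_state, reward, terminated, truncated, _ = env.step(action)
        transitions.append((state, action, reward, next_state))
        if terminated or truncated:
            state, _ = env.reset()
        else:
            state = next_state
    return transitions
```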
Atari Breakout Experimental Results
Environment Model Quality
How well the learned environment model predicts Breakout game frames is key to I2A performance:
- 1-step prediction: Very accurate (low MSE)
- 5-step prediction: Slight blurring in ball position
- 15-step prediction: Significant uncertainty, but overall game situation remains understandable
I2A vs Baseline Performance
| Method | Breakout Score |
|---|---|
| A2C (model-free) | ~400 |
| I2A (5-step rollout) | ~500 |
| I2A + perfect environment model | ~600 |
I2A shows performance improvement even with an imperfect environment model. The key is that the rollout encoder naturally learns to handle model uncertainty.
Practical Considerations
Difficulties in Environment Model Training
- Visual environments: Pixel-level prediction is difficult and expensive. Predicting in latent space is more efficient.
- Stochastic environments: Deterministic models average out multimodal transitions. VAE or GAN-based models may be needed.
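The multimodality problem in miniature: if the environment moves to +1 or −1 with equal probability, a deterministic model trained with MSE converges to the mean, 0, a state the environment never actually produces:

```python
import numpy as np

rng = np.random.default_rng(0)
next_states = rng.choice([-1.0, 1.0], size=10_000)  # 50/50 bimodal transition
mse_optimal = next_states.mean()                    # the MSE-optimal point prediction
# mse_optimal is near 0.0, a "next state" that never occurs in the data
```

This is why stochastic environments call for models that output distributions (e.g. VAE-based latent models) rather than single point predictions.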
Recent Research Directions
- World Models: Learn latent space with VAE and predict transitions with RNN
- Dreamer: Unified framework for imagining and planning in latent space
- MuZero: Learns only the latent dynamics needed for value prediction without explicitly learning an environment model
Key Takeaways
- Model-based RL improves sample efficiency by learning environment models for imagination (rollouts)
- Compounding error of the model is the main challenge, mitigated by short rollouts and ensembles
- I2A encodes imagination results for multiple actions to inform policy decisions
- A model-free path provides a safety net against model errors
In the next post, we will examine AlphaGo Zero, the most impressive example of model-based RL.