- Authors

- Name
- Youngju Kim
- @fjvbn20031
現実世界における価値反復の限界
前の記事で扱ったテーブルベースの方法は状態空間が小さい時のみ使用可能です。現実世界の問題では以下のような限界に直面します。
状態空間の爆発
- Atariゲーム画面: 210 x 160ピクセル、3色チャンネル = 約10万個の連続値
- 囲碁: 約10の170乗に達する可能な状態
- 自動運転: センサーデータの組み合わせは事実上無限
このような環境ではQテーブルを作ること自体が不可能です。ニューラルネットワークでQ関数を**近似(approximation)**する必要があります。
テーブルQ学習からディープQ学習へ
核心アイデア
Qテーブルの代わりにニューラルネットワーク Q(s, a; theta) を使ってQ値を近似します。ここでthetaはニューラルネットワークのパラメータ(重み)です。
Q-테이블: Q[state_index][action_index] = value (정확한 값 저장)
DQN: Q(state; theta) -> [q_action_0, q_action_1, ...] (함수 근사)
なぜ単純にニューラルネットワークを入れるだけではダメなのか
Q学習に単純にニューラルネットワークを代入すると学習が非常に不安定になります。3つの核心問題があります:
- データ間の相関: 連続した経験は互いに強く相関しており、i.i.d.仮定に違反する
- 非定常ターゲット: 学習ターゲットが常に変化するため収束が困難
- マルコフ性質違反: 単一フレームだけでは環境状態を完全に表現できない
DQNの核心技法
1. 経験リプレイ(Experience Replay)
エージェントの経験をバッファに保存し、学習時にランダムにサンプリングして使用します。これによりデータ間の相関を破壊します。
import numpy as np
from collections import deque
import random
class ReplayBuffer:
"""경험 리플레이 버퍼"""
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (
np.array(states),
np.array(actions),
np.array(rewards, dtype=np.float32),
np.array(next_states),
np.array(dones, dtype=np.bool_),
)
def __len__(self):
return len(self.buffer)
# 사용 예시
buffer = ReplayBuffer(capacity=100000)
# 경험 저장
buffer.push(
state=np.array([0.1, 0.2, 0.3, 0.4]),
action=1,
reward=1.0,
next_state=np.array([0.2, 0.3, 0.4, 0.5]),
done=False,
)
# 배치 샘플링
if len(buffer) >= 32:
states, actions, rewards, next_states, dones = buffer.sample(32)
2. ターゲットネットワーク(Target Network)
Q値更新のターゲット計算時に別のネットワークを使用します。ターゲットネットワークは定期的にオンラインネットワークの重みをコピーされます。
TD 타겟 = r + gamma * max_{a'} Q_target(s', a'; theta_target)
손실 = (Q(s, a; theta) - TD 타겟)^2
ターゲットネットワークを使うと学習ターゲットが一定期間固定され、学習が安定化します。
3. フレームスタッキング(Frame Stacking)
Atariゲームでは単一フレームではボールの移動方向などがわかりません。連続4個のフレームを重ねて1つの観測として使用します。
DQNモデル実装
Atari用DQNネットワーク
import torch
import torch.nn as nn
class DQN(nn.Module):
"""Atari 게임용 DQN 네트워크"""
def __init__(self, input_channels, n_actions):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(),
)
conv_out_size = self._get_conv_out(input_channels)
self.fc = nn.Sequential(
nn.Linear(conv_out_size, 512),
nn.ReLU(),
nn.Linear(512, n_actions),
)
def _get_conv_out(self, channels):
o = self.conv(torch.zeros(1, channels, 84, 84))
return int(np.prod(o.size()))
def forward(self, x):
conv_out = self.conv(x.float() / 255.0)
return self.fc(conv_out.view(conv_out.size(0), -1))
DQNエージェント実装
import torch
import torch.optim as optim
class DQNAgent:
"""DQN 에이전트"""
def __init__(self, env, device="cpu", buffer_size=100000,
batch_size=32, gamma=0.99, lr=1e-4,
epsilon_start=1.0, epsilon_end=0.01,
epsilon_decay=100000, target_update=1000):
self.env = env
self.device = device
self.batch_size = batch_size
self.gamma = gamma
self.target_update = target_update
# 엡실론 스케줄
self.epsilon_start = epsilon_start
self.epsilon_end = epsilon_end
self.epsilon_decay = epsilon_decay
n_actions = env.action_space.n
# 온라인 네트워크와 타겟 네트워크
self.online_net = DQN(4, n_actions).to(device)
self.target_net = DQN(4, n_actions).to(device)
self.target_net.load_state_dict(self.online_net.state_dict())
self.target_net.eval()
self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)
self.buffer = ReplayBuffer(buffer_size)
self.step_count = 0
def get_epsilon(self):
"""현재 엡실론 값 계산 (선형 감소)"""
epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
max(0, 1 - self.step_count / self.epsilon_decay)
return epsilon
def select_action(self, state):
"""엡실론-탐욕 정책으로 행동 선택"""
epsilon = self.get_epsilon()
if random.random() < epsilon:
return self.env.action_space.sample()
with torch.no_grad():
state_tensor = torch.tensor(
np.array([state]), dtype=torch.uint8
).to(self.device)
q_values = self.online_net(state_tensor)
return q_values.argmax(dim=1).item()
def update(self):
"""경험 리플레이에서 배치를 샘플링하여 네트워크 업데이트"""
if len(self.buffer) < self.batch_size:
return None
states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
# 텐서 변환
states_t = torch.tensor(states, dtype=torch.uint8).to(self.device)
actions_t = torch.tensor(actions, dtype=torch.long).to(self.device)
rewards_t = torch.tensor(rewards, dtype=torch.float32).to(self.device)
next_states_t = torch.tensor(next_states, dtype=torch.uint8).to(self.device)
dones_t = torch.tensor(dones, dtype=torch.bool).to(self.device)
# 현재 Q값: Q(s, a)
current_q = self.online_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
# 타겟 Q값: r + gamma * max_a' Q_target(s', a')
with torch.no_grad():
next_q = self.target_net(next_states_t).max(dim=1)[0]
next_q[dones_t] = 0.0
target_q = rewards_t + self.gamma * next_q
# 손실 계산 및 역전파
loss = nn.SmoothL1Loss()(current_q, target_q)
self.optimizer.zero_grad()
loss.backward()
# 그래디언트 클리핑 (안정적 학습)
torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), max_norm=10)
self.optimizer.step()
return loss.item()
def sync_target_network(self):
"""타겟 네트워크를 온라인 네트워크로 동기화"""
self.target_net.load_state_dict(self.online_net.state_dict())
Atari環境の前処理
DQN学習のためのAtari環境前処理パイプラインです。
import gymnasium as gym
from gymnasium import ObservationWrapper, Wrapper
from gymnasium.spaces import Box
import numpy as np
import cv2
class FireResetWrapper(Wrapper):
"""에피소드 시작 시 FIRE 행동을 자동 실행"""
def __init__(self, env):
super().__init__(env)
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
obs, _, terminated, truncated, info = self.env.step(1) # FIRE
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
return obs, info
class MaxAndSkipWrapper(Wrapper):
"""프레임 스킵: 4프레임마다 행동을 선택하고 중간 프레임은 반복"""
def __init__(self, env, skip=4):
super().__init__(env)
self.skip = skip
self.obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
def step(self, action):
total_reward = 0.0
done = False
for i in range(self.skip):
obs, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
if i == self.skip - 2:
self.obs_buffer[0] = obs
if i == self.skip - 1:
self.obs_buffer[1] = obs
if terminated or truncated:
done = True
break
max_frame = self.obs_buffer.max(axis=0)
return max_frame, total_reward, terminated, truncated, info
class ProcessFrame84Wrapper(ObservationWrapper):
"""프레임을 84x84 그레이스케일로 변환"""
def __init__(self, env):
super().__init__(env)
self.observation_space = Box(
low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
)
def observation(self, obs):
return self._process(obs)
def _process(self, frame):
img = np.mean(frame, axis=2).astype(np.uint8)
resized = cv2.resize(img, (84, 84), interpolation=cv2.INTER_AREA)
return resized.reshape(84, 84, 1)
class FrameStackWrapper(ObservationWrapper):
"""연속 n개의 프레임을 채널 방향으로 스태킹"""
def __init__(self, env, n_frames=4):
super().__init__(env)
self.n_frames = n_frames
old_shape = env.observation_space.shape
self.observation_space = Box(
low=0, high=255,
shape=(old_shape[0], old_shape[1], n_frames),
dtype=np.uint8,
)
self.frames = deque(maxlen=n_frames)
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
for _ in range(self.n_frames):
self.frames.append(obs)
return np.concatenate(list(self.frames), axis=-1), info
def observation(self, obs):
self.frames.append(obs)
return np.concatenate(list(self.frames), axis=-1)
class ImageToChannelsFirstWrapper(ObservationWrapper):
"""(H, W, C) -> (C, H, W) 변환"""
def __init__(self, env):
super().__init__(env)
old_shape = env.observation_space.shape
self.observation_space = Box(
low=0, high=255,
shape=(old_shape[2], old_shape[0], old_shape[1]),
dtype=np.uint8,
)
def observation(self, obs):
return np.transpose(obs, (2, 0, 1))
def make_atari_env(env_name="ALE/Pong-v5"):
"""Atari 환경 생성 및 래퍼 적용"""
env = gym.make(env_name)
env = MaxAndSkipWrapper(env)
env = FireResetWrapper(env)
env = ProcessFrame84Wrapper(env)
env = FrameStackWrapper(env)
env = ImageToChannelsFirstWrapper(env)
return env
DQN学習ループ
from torch.utils.tensorboard import SummaryWriter
def train_dqn_pong():
"""DQN으로 Pong 학습"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"사용 장치: {device}")
env = make_atari_env("ALE/Pong-v5")
agent = DQNAgent(
env=env,
device=device,
buffer_size=100000,
batch_size=32,
gamma=0.99,
lr=1e-4,
epsilon_start=1.0,
epsilon_end=0.02,
epsilon_decay=100000,
target_update=1000,
)
writer = SummaryWriter("runs/dqn_pong")
obs, _ = env.reset()
episode_reward = 0
episode_count = 0
best_mean_reward = -float("inf")
for frame in range(1, 1000001):
# 행동 선택 및 실행
action = agent.select_action(obs)
next_obs, reward, terminated, truncated, _ = env.step(action)
# 경험 저장
agent.buffer.push(obs, action, reward, next_obs, terminated or truncated)
episode_reward += reward
agent.step_count += 1
if terminated or truncated:
episode_count += 1
writer.add_scalar("reward/episode", episode_reward, episode_count)
writer.add_scalar("epsilon", agent.get_epsilon(), frame)
if episode_count % 10 == 0:
print(
f"프레임 {frame}, 에피소드 {episode_count}: "
f"보상={episode_reward:.0f}, "
f"엡실론={agent.get_epsilon():.3f}, "
f"버퍼={len(agent.buffer)}"
)
episode_reward = 0
obs, _ = env.reset()
else:
obs = next_obs
# 네트워크 업데이트
loss = agent.update()
if loss is not None and frame % 1000 == 0:
writer.add_scalar("loss/td", loss, frame)
# 타겟 네트워크 동기화
if frame % agent.target_update == 0:
agent.sync_target_network()
# 모델 저장
if frame % 50000 == 0:
torch.save(agent.online_net.state_dict(), f"dqn_pong_{frame}.pth")
writer.close()
env.close()
# train_dqn_pong()
DQN学習結果分析
Pong環境でのDQN学習は一般的に以下のようなパターンを示します。
学習曲線の特徴
- 初期(0-10万フレーム): ランダムに近い行動、報酬約-21(完敗)
- 探索期(10-30万フレーム): イプシロンが減少しながら間欠的にスコアを得始める
- 学習期(30-70万フレーム): 急速に性能が向上、報酬が0付近まで上昇
- 収束期(70万以降): 報酬+19〜+21に到達、ほぼ完璧なプレイ
核心ハイパーパラメータの影響
| パラメータ | 値 | 役割 |
|---|---|---|
| 学習率 | 1e-4 | 大きすぎると不安定、小さすぎると遅い |
| バッチサイズ | 32 | 大きいほど安定だがメモリ使用増加 |
| バッファサイズ | 10万 | 小さいと最近の経験のみ、大きいと多様な経験 |
| ターゲット更新周期 | 1000 | 短いと不安定、長いと古いターゲット使用 |
| イプシロン減少区間 | 10万フレーム | 探索から活用への転換速度 |
| 割引因子 | 0.99 | 未来報酬の重要度 |
簡単な環境でのDQN実験
Pongを学習するにはGPUと時間が多く必要です。CartPoleでDQNの核心概念を検証できます。
class SimpleDQN(nn.Module):
"""CartPole용 간단한 DQN"""
def __init__(self, obs_size, n_actions):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_size, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, n_actions),
)
def forward(self, x):
return self.net(x)
def train_dqn_cartpole():
"""CartPole에서 DQN 학습"""
env = gym.make("CartPole-v1")
device = torch.device("cpu")
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
online_net = SimpleDQN(obs_size, n_actions).to(device)
target_net = SimpleDQN(obs_size, n_actions).to(device)
target_net.load_state_dict(online_net.state_dict())
optimizer = optim.Adam(online_net.parameters(), lr=1e-3)
buffer = ReplayBuffer(10000)
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
gamma = 0.99
batch_size = 64
target_update = 100
rewards_history = []
step = 0
for episode in range(500):
obs, _ = env.reset()
total_reward = 0
while True:
# 엡실론-탐욕 행동 선택
if random.random() < epsilon:
action = env.action_space.sample()
else:
with torch.no_grad():
q = online_net(torch.tensor([obs], dtype=torch.float32))
action = q.argmax(dim=1).item()
next_obs, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
buffer.push(obs, action, reward, next_obs, done)
total_reward += reward
obs = next_obs
step += 1
# 학습
if len(buffer) >= batch_size:
s, a, r, ns, d = buffer.sample(batch_size)
s_t = torch.tensor(s, dtype=torch.float32)
a_t = torch.tensor(a, dtype=torch.long)
r_t = torch.tensor(r, dtype=torch.float32)
ns_t = torch.tensor(ns, dtype=torch.float32)
d_t = torch.tensor(d, dtype=torch.bool)
current_q = online_net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)
with torch.no_grad():
next_q = target_net(ns_t).max(dim=1)[0]
next_q[d_t] = 0.0
target_q = r_t + gamma * next_q
loss = nn.MSELoss()(current_q, target_q)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 타겟 네트워크 업데이트
if step % target_update == 0:
target_net.load_state_dict(online_net.state_dict())
if done:
break
epsilon = max(epsilon_min, epsilon * epsilon_decay)
rewards_history.append(total_reward)
if episode % 50 == 0:
mean_reward = np.mean(rewards_history[-50:])
print(f"에피소드 {episode}: 평균 보상={mean_reward:.1f}, 엡실론={epsilon:.3f}")
if mean_reward >= 475:
print("CartPole 해결!")
break
env.close()
return online_net
# trained_model = train_dqn_cartpole()
まとめ
- テーブルQ学習の限界: 大規模状態空間ではQテーブルを使用できない
- DQNの核心技法: 経験リプレイ(相関除去)、ターゲットネットワーク(学習安定化)
- Atari前処理: フレームスキップ、グレースケール変換、84x84リサイズ、フレームスタッキング
- 学習過程: 100万フレーム以上の学習でPongで人間レベルを達成
- イプシロンスケジュール: 初期探索から段階的に活用へ転換
次の記事ではDQNの性能を大きく向上させる様々な拡張技法(Double DQN、Dueling DQN、Prioritized Replayなど)を扱います。