Skip to content
Published on

[深層強化学習] 06. Deep Q-Network:DQNの原理と実装

Authors

現実世界における価値反復の限界

前の記事で扱ったテーブルベースの方法は状態空間が小さい時のみ使用可能です。現実世界の問題では以下のような限界に直面します。

状態空間の爆発

  • Atariゲーム画面: 210 x 160ピクセル、3色チャンネル = 約10万個の連続値
  • 囲碁: 約10の170乗に達する可能な状態
  • 自動運転: センサーデータの組み合わせは事実上無限

このような環境ではQテーブルを作ること自体が不可能です。ニューラルネットワークでQ関数を**近似(approximation)**する必要があります。


テーブルQ学習からディープQ学習へ

核心アイデア

Qテーブルの代わりにニューラルネットワーク Q(s, a; theta) を使ってQ値を近似します。ここでthetaはニューラルネットワークのパラメータ(重み)です。

Q-테이블: Q[state_index][action_index] = value (정확한 값 저장)
DQN:      Q(state; theta) -> [q_action_0, q_action_1, ...] (함수 근사)

なぜ単純にニューラルネットワークを入れるだけではダメなのか

Q学習に単純にニューラルネットワークを代入すると学習が非常に不安定になります。3つの核心問題があります:

  1. データ間の相関: 連続した経験は互いに強く相関しており、i.i.d.仮定に違反する
  2. 非定常ターゲット: 学習ターゲットが常に変化するため収束が困難
  3. マルコフ性質違反: 単一フレームだけでは環境状態を完全に表現できない

DQNの核心技法

1. 経験リプレイ(Experience Replay)

エージェントの経験をバッファに保存し、学習時にランダムにサンプリングして使用します。これによりデータ間の相関を破壊します。

import numpy as np
from collections import deque
import random

class ReplayBuffer:
    """경험 리플레이 버퍼"""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.bool_),
        )

    def __len__(self):
        return len(self.buffer)

# 사용 예시
buffer = ReplayBuffer(capacity=100000)

# 경험 저장
buffer.push(
    state=np.array([0.1, 0.2, 0.3, 0.4]),
    action=1,
    reward=1.0,
    next_state=np.array([0.2, 0.3, 0.4, 0.5]),
    done=False,
)

# 배치 샘플링
if len(buffer) >= 32:
    states, actions, rewards, next_states, dones = buffer.sample(32)

2. ターゲットネットワーク(Target Network)

Q値更新のターゲット計算時に別のネットワークを使用します。ターゲットネットワークは定期的にオンラインネットワークの重みをコピーされます。

TD 타겟 = r + gamma * max_{a'} Q_target(s', a'; theta_target)
손실 = (Q(s, a; theta) - TD 타겟)^2

ターゲットネットワークを使うと学習ターゲットが一定期間固定され、学習が安定化します。

3. フレームスタッキング(Frame Stacking)

Atariゲームでは単一フレームではボールの移動方向などがわかりません。連続4個のフレームを重ねて1つの観測として使用します。


DQNモデル実装

Atari用DQNネットワーク

import torch
import torch.nn as nn

class DQN(nn.Module):
    """Atari 게임용 DQN 네트워크"""
    def __init__(self, input_channels, n_actions):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(input_channels)

        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, channels):
        o = self.conv(torch.zeros(1, channels, 84, 84))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x.float() / 255.0)
        return self.fc(conv_out.view(conv_out.size(0), -1))

DQNエージェント実装

import torch
import torch.optim as optim

class DQNAgent:
    """DQN 에이전트"""
    def __init__(self, env, device="cpu", buffer_size=100000,
                 batch_size=32, gamma=0.99, lr=1e-4,
                 epsilon_start=1.0, epsilon_end=0.01,
                 epsilon_decay=100000, target_update=1000):
        self.env = env
        self.device = device
        self.batch_size = batch_size
        self.gamma = gamma
        self.target_update = target_update

        # 엡실론 스케줄
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        n_actions = env.action_space.n

        # 온라인 네트워크와 타겟 네트워크
        self.online_net = DQN(4, n_actions).to(device)
        self.target_net = DQN(4, n_actions).to(device)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        self.step_count = 0

    def get_epsilon(self):
        """현재 엡실론 값 계산 (선형 감소)"""
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            max(0, 1 - self.step_count / self.epsilon_decay)
        return epsilon

    def select_action(self, state):
        """엡실론-탐욕 정책으로 행동 선택"""
        epsilon = self.get_epsilon()

        if random.random() < epsilon:
            return self.env.action_space.sample()

        with torch.no_grad():
            state_tensor = torch.tensor(
                np.array([state]), dtype=torch.uint8
            ).to(self.device)
            q_values = self.online_net(state_tensor)
            return q_values.argmax(dim=1).item()

    def update(self):
        """경험 리플레이에서 배치를 샘플링하여 네트워크 업데이트"""
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # 텐서 변환
        states_t = torch.tensor(states, dtype=torch.uint8).to(self.device)
        actions_t = torch.tensor(actions, dtype=torch.long).to(self.device)
        rewards_t = torch.tensor(rewards, dtype=torch.float32).to(self.device)
        next_states_t = torch.tensor(next_states, dtype=torch.uint8).to(self.device)
        dones_t = torch.tensor(dones, dtype=torch.bool).to(self.device)

        # 현재 Q값: Q(s, a)
        current_q = self.online_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # 타겟 Q값: r + gamma * max_a' Q_target(s', a')
        with torch.no_grad():
            next_q = self.target_net(next_states_t).max(dim=1)[0]
            next_q[dones_t] = 0.0
            target_q = rewards_t + self.gamma * next_q

        # 손실 계산 및 역전파
        loss = nn.SmoothL1Loss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # 그래디언트 클리핑 (안정적 학습)
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), max_norm=10)
        self.optimizer.step()

        return loss.item()

    def sync_target_network(self):
        """타겟 네트워크를 온라인 네트워크로 동기화"""
        self.target_net.load_state_dict(self.online_net.state_dict())

Atari環境の前処理

DQN学習のためのAtari環境前処理パイプラインです。

import gymnasium as gym
from gymnasium import ObservationWrapper, Wrapper
from gymnasium.spaces import Box
import numpy as np
import cv2

class FireResetWrapper(Wrapper):
    """에피소드 시작 시 FIRE 행동을 자동 실행"""
    def __init__(self, env):
        super().__init__(env)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(1)  # FIRE
        if terminated or truncated:
            obs, info = self.env.reset(**kwargs)
        return obs, info

class MaxAndSkipWrapper(Wrapper):
    """프레임 스킵: 4프레임마다 행동을 선택하고 중간 프레임은 반복"""
    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip
        self.obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        total_reward = 0.0
        done = False
        for i in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if i == self.skip - 2:
                self.obs_buffer[0] = obs
            if i == self.skip - 1:
                self.obs_buffer[1] = obs
            if terminated or truncated:
                done = True
                break
        max_frame = self.obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info

class ProcessFrame84Wrapper(ObservationWrapper):
    """프레임을 84x84 그레이스케일로 변환"""
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def observation(self, obs):
        return self._process(obs)

    def _process(self, frame):
        img = np.mean(frame, axis=2).astype(np.uint8)
        resized = cv2.resize(img, (84, 84), interpolation=cv2.INTER_AREA)
        return resized.reshape(84, 84, 1)

class FrameStackWrapper(ObservationWrapper):
    """연속 n개의 프레임을 채널 방향으로 스태킹"""
    def __init__(self, env, n_frames=4):
        super().__init__(env)
        self.n_frames = n_frames
        old_shape = env.observation_space.shape
        self.observation_space = Box(
            low=0, high=255,
            shape=(old_shape[0], old_shape[1], n_frames),
            dtype=np.uint8,
        )
        self.frames = deque(maxlen=n_frames)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return np.concatenate(list(self.frames), axis=-1), info

    def observation(self, obs):
        self.frames.append(obs)
        return np.concatenate(list(self.frames), axis=-1)

class ImageToChannelsFirstWrapper(ObservationWrapper):
    """(H, W, C) -> (C, H, W) 변환"""
    def __init__(self, env):
        super().__init__(env)
        old_shape = env.observation_space.shape
        self.observation_space = Box(
            low=0, high=255,
            shape=(old_shape[2], old_shape[0], old_shape[1]),
            dtype=np.uint8,
        )

    def observation(self, obs):
        return np.transpose(obs, (2, 0, 1))


def make_atari_env(env_name="ALE/Pong-v5"):
    """Atari 환경 생성 및 래퍼 적용"""
    env = gym.make(env_name)
    env = MaxAndSkipWrapper(env)
    env = FireResetWrapper(env)
    env = ProcessFrame84Wrapper(env)
    env = FrameStackWrapper(env)
    env = ImageToChannelsFirstWrapper(env)
    return env

DQN学習ループ

from torch.utils.tensorboard import SummaryWriter

def train_dqn_pong():
    """DQN으로 Pong 학습"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"사용 장치: {device}")

    env = make_atari_env("ALE/Pong-v5")
    agent = DQNAgent(
        env=env,
        device=device,
        buffer_size=100000,
        batch_size=32,
        gamma=0.99,
        lr=1e-4,
        epsilon_start=1.0,
        epsilon_end=0.02,
        epsilon_decay=100000,
        target_update=1000,
    )

    writer = SummaryWriter("runs/dqn_pong")

    obs, _ = env.reset()
    episode_reward = 0
    episode_count = 0
    best_mean_reward = -float("inf")

    for frame in range(1, 1000001):
        # 행동 선택 및 실행
        action = agent.select_action(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)

        # 경험 저장
        agent.buffer.push(obs, action, reward, next_obs, terminated or truncated)
        episode_reward += reward
        agent.step_count += 1

        if terminated or truncated:
            episode_count += 1
            writer.add_scalar("reward/episode", episode_reward, episode_count)
            writer.add_scalar("epsilon", agent.get_epsilon(), frame)

            if episode_count % 10 == 0:
                print(
                    f"프레임 {frame}, 에피소드 {episode_count}: "
                    f"보상={episode_reward:.0f}, "
                    f"엡실론={agent.get_epsilon():.3f}, "
                    f"버퍼={len(agent.buffer)}"
                )

            episode_reward = 0
            obs, _ = env.reset()
        else:
            obs = next_obs

        # 네트워크 업데이트
        loss = agent.update()
        if loss is not None and frame % 1000 == 0:
            writer.add_scalar("loss/td", loss, frame)

        # 타겟 네트워크 동기화
        if frame % agent.target_update == 0:
            agent.sync_target_network()

        # 모델 저장
        if frame % 50000 == 0:
            torch.save(agent.online_net.state_dict(), f"dqn_pong_{frame}.pth")

    writer.close()
    env.close()

# train_dqn_pong()

DQN学習結果分析

Pong環境でのDQN学習は一般的に以下のようなパターンを示します。

学習曲線の特徴

  • 初期(0-10万フレーム): ランダムに近い行動、報酬約-21(完敗)
  • 探索期(10-30万フレーム): イプシロンが減少しながら間欠的にスコアを得始める
  • 学習期(30-70万フレーム): 急速に性能が向上、報酬が0付近まで上昇
  • 収束期(70万以降): 報酬+19〜+21に到達、ほぼ完璧なプレイ

核心ハイパーパラメータの影響

パラメータ役割
学習率1e-4大きすぎると不安定、小さすぎると遅い
バッチサイズ32大きいほど安定だがメモリ使用増加
バッファサイズ10万小さいと最近の経験のみ、大きいと多様な経験
ターゲット更新周期1000短いと不安定、長いと古いターゲット使用
イプシロン減少区間10万フレーム探索から活用への転換速度
割引因子0.99未来報酬の重要度

簡単な環境でのDQN実験

Pongを学習するにはGPUと時間が多く必要です。CartPoleでDQNの核心概念を検証できます。

class SimpleDQN(nn.Module):
    """CartPole용 간단한 DQN"""
    def __init__(self, obs_size, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        return self.net(x)

def train_dqn_cartpole():
    """CartPole에서 DQN 학습"""
    env = gym.make("CartPole-v1")
    device = torch.device("cpu")

    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    online_net = SimpleDQN(obs_size, n_actions).to(device)
    target_net = SimpleDQN(obs_size, n_actions).to(device)
    target_net.load_state_dict(online_net.state_dict())

    optimizer = optim.Adam(online_net.parameters(), lr=1e-3)
    buffer = ReplayBuffer(10000)

    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    gamma = 0.99
    batch_size = 64
    target_update = 100

    rewards_history = []
    step = 0

    for episode in range(500):
        obs, _ = env.reset()
        total_reward = 0

        while True:
            # 엡실론-탐욕 행동 선택
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    q = online_net(torch.tensor([obs], dtype=torch.float32))
                    action = q.argmax(dim=1).item()

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            buffer.push(obs, action, reward, next_obs, done)
            total_reward += reward
            obs = next_obs
            step += 1

            # 학습
            if len(buffer) >= batch_size:
                s, a, r, ns, d = buffer.sample(batch_size)
                s_t = torch.tensor(s, dtype=torch.float32)
                a_t = torch.tensor(a, dtype=torch.long)
                r_t = torch.tensor(r, dtype=torch.float32)
                ns_t = torch.tensor(ns, dtype=torch.float32)
                d_t = torch.tensor(d, dtype=torch.bool)

                current_q = online_net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)

                with torch.no_grad():
                    next_q = target_net(ns_t).max(dim=1)[0]
                    next_q[d_t] = 0.0
                    target_q = r_t + gamma * next_q

                loss = nn.MSELoss()(current_q, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # 타겟 네트워크 업데이트
            if step % target_update == 0:
                target_net.load_state_dict(online_net.state_dict())

            if done:
                break

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        rewards_history.append(total_reward)

        if episode % 50 == 0:
            mean_reward = np.mean(rewards_history[-50:])
            print(f"에피소드 {episode}: 평균 보상={mean_reward:.1f}, 엡실론={epsilon:.3f}")

            if mean_reward >= 475:
                print("CartPole 해결!")
                break

    env.close()
    return online_net

# trained_model = train_dqn_cartpole()

まとめ

  1. テーブルQ学習の限界: 大規模状態空間ではQテーブルを使用できない
  2. DQNの核心技法: 経験リプレイ(相関除去)、ターゲットネットワーク(学習安定化)
  3. Atari前処理: フレームスキップ、グレースケール変換、84x84リサイズ、フレームスタッキング
  4. 学習過程: 100万フレーム以上の学習でPongで人間レベルを達成
  5. イプシロンスケジュール: 初期探索から段階的に活用へ転換

次の記事ではDQNの性能を大きく向上させる様々な拡張技法(Double DQN、Dueling DQN、Prioritized Replayなど)を扱います。