Skip to content
Published on

[심층 강화학습] 06. Deep Q-Network: DQN의 원리와 구현

Authors

현실 세계에서의 가치 반복의 한계

이전 글에서 다룬 테이블 기반 방법은 상태 공간이 작을 때만 사용할 수 있습니다. 현실 세계의 문제에서는 다음과 같은 한계에 직면합니다.

상태 공간의 폭발

  • Atari 게임 화면: 210 x 160 픽셀, 3개 컬러 채널 = 약 10만 개의 연속 값
  • 바둑: 약 10의 170승에 달하는 가능한 상태
  • 자율 주행: 센서 데이터의 조합은 사실상 무한

이러한 환경에서는 Q 테이블을 만드는 것 자체가 불가능합니다. 신경망으로 Q함수를 **근사(approximation)**해야 합니다.


테이블 Q-러닝에서 딥 Q-러닝으로

핵심 아이디어

Q 테이블 대신 신경망 Q(s, a; theta)를 사용하여 Q값을 근사합니다. 여기서 theta는 신경망의 파라미터(가중치)입니다.

Q-테이블: Q[state_index][action_index] = value (정확한 값 저장)
DQN:      Q(state; theta) -> [q_action_0, q_action_1, ...] (함수 근사)

왜 단순히 신경망만 넣으면 안 되는가

Q-러닝에 단순히 신경망을 대입하면 학습이 매우 불안정합니다. 세 가지 핵심 문제가 있습니다.

  1. 데이터 간의 상관관계: 연속된 경험들은 서로 강하게 상관되어 있어 i.i.d. 가정을 위반합니다
  2. 비정상 타겟: 학습 타겟이 계속 변하므로 수렴이 어렵습니다
  3. 마르코프 성질 위반: 단일 프레임만으로는 환경 상태를 완전히 표현하기 어렵습니다

DQN의 핵심 기법

1. 경험 리플레이 (Experience Replay)

에이전트의 경험을 버퍼에 저장하고, 학습 시 무작위로 샘플링하여 사용합니다. 이를 통해 데이터 간 상관관계를 깨뜨립니다.

import numpy as np
from collections import deque
import random

class ReplayBuffer:
    """경험 리플레이 버퍼"""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.bool_),
        )

    def __len__(self):
        return len(self.buffer)

# 사용 예시
buffer = ReplayBuffer(capacity=100000)

# 경험 저장
buffer.push(
    state=np.array([0.1, 0.2, 0.3, 0.4]),
    action=1,
    reward=1.0,
    next_state=np.array([0.2, 0.3, 0.4, 0.5]),
    done=False,
)

# 배치 샘플링
if len(buffer) >= 32:
    states, actions, rewards, next_states, dones = buffer.sample(32)

2. 타겟 네트워크 (Target Network)

Q값 업데이트의 타겟을 계산할 때 별도의 네트워크를 사용합니다. 타겟 네트워크는 주기적으로 온라인 네트워크의 가중치를 복사받습니다.

TD 타겟 = r + gamma * max_{a'} Q_target(s', a'; theta_target)
손실 = (Q(s, a; theta) - TD 타겟)^2

타겟 네트워크를 사용하면 학습 타겟이 일정 기간 동안 고정되어 학습이 안정화됩니다.

3. 프레임 스태킹 (Frame Stacking)

Atari 게임에서 단일 프레임은 공의 이동 방향 등을 알 수 없습니다. 연속 4개의 프레임을 쌓아서 하나의 관찰로 사용합니다.


DQN 모델 구현

Atari용 DQN 네트워크

import torch
import torch.nn as nn

class DQN(nn.Module):
    """Atari 게임용 DQN 네트워크"""
    def __init__(self, input_channels, n_actions):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(input_channels)

        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, channels):
        o = self.conv(torch.zeros(1, channels, 84, 84))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x.float() / 255.0)
        return self.fc(conv_out.view(conv_out.size(0), -1))

DQN 에이전트 구현

import torch
import torch.optim as optim

class DQNAgent:
    """DQN 에이전트"""
    def __init__(self, env, device="cpu", buffer_size=100000,
                 batch_size=32, gamma=0.99, lr=1e-4,
                 epsilon_start=1.0, epsilon_end=0.01,
                 epsilon_decay=100000, target_update=1000):
        self.env = env
        self.device = device
        self.batch_size = batch_size
        self.gamma = gamma
        self.target_update = target_update

        # 엡실론 스케줄
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        n_actions = env.action_space.n

        # 온라인 네트워크와 타겟 네트워크
        self.online_net = DQN(4, n_actions).to(device)
        self.target_net = DQN(4, n_actions).to(device)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        self.step_count = 0

    def get_epsilon(self):
        """현재 엡실론 값 계산 (선형 감소)"""
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            max(0, 1 - self.step_count / self.epsilon_decay)
        return epsilon

    def select_action(self, state):
        """엡실론-탐욕 정책으로 행동 선택"""
        epsilon = self.get_epsilon()

        if random.random() < epsilon:
            return self.env.action_space.sample()

        with torch.no_grad():
            state_tensor = torch.tensor(
                np.array([state]), dtype=torch.uint8
            ).to(self.device)
            q_values = self.online_net(state_tensor)
            return q_values.argmax(dim=1).item()

    def update(self):
        """경험 리플레이에서 배치를 샘플링하여 네트워크 업데이트"""
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # 텐서 변환
        states_t = torch.tensor(states, dtype=torch.uint8).to(self.device)
        actions_t = torch.tensor(actions, dtype=torch.long).to(self.device)
        rewards_t = torch.tensor(rewards, dtype=torch.float32).to(self.device)
        next_states_t = torch.tensor(next_states, dtype=torch.uint8).to(self.device)
        dones_t = torch.tensor(dones, dtype=torch.bool).to(self.device)

        # 현재 Q값: Q(s, a)
        current_q = self.online_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # 타겟 Q값: r + gamma * max_a' Q_target(s', a')
        with torch.no_grad():
            next_q = self.target_net(next_states_t).max(dim=1)[0]
            next_q[dones_t] = 0.0
            target_q = rewards_t + self.gamma * next_q

        # 손실 계산 및 역전파
        loss = nn.SmoothL1Loss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # 그래디언트 클리핑 (안정적 학습)
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), max_norm=10)
        self.optimizer.step()

        return loss.item()

    def sync_target_network(self):
        """타겟 네트워크를 온라인 네트워크로 동기화"""
        self.target_net.load_state_dict(self.online_net.state_dict())

Atari 환경 전처리

DQN 학습을 위한 Atari 환경 전처리 파이프라인입니다.

import gymnasium as gym
from gymnasium import ObservationWrapper, Wrapper
from gymnasium.spaces import Box
import numpy as np
import cv2

class FireResetWrapper(Wrapper):
    """에피소드 시작 시 FIRE 행동을 자동 실행"""
    def __init__(self, env):
        super().__init__(env)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(1)  # FIRE
        if terminated or truncated:
            obs, info = self.env.reset(**kwargs)
        return obs, info

class MaxAndSkipWrapper(Wrapper):
    """프레임 스킵: 4프레임마다 행동을 선택하고 중간 프레임은 반복"""
    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip
        self.obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        total_reward = 0.0
        done = False
        for i in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if i == self.skip - 2:
                self.obs_buffer[0] = obs
            if i == self.skip - 1:
                self.obs_buffer[1] = obs
            if terminated or truncated:
                done = True
                break
        max_frame = self.obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info

class ProcessFrame84Wrapper(ObservationWrapper):
    """프레임을 84x84 그레이스케일로 변환"""
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def observation(self, obs):
        return self._process(obs)

    def _process(self, frame):
        img = np.mean(frame, axis=2).astype(np.uint8)
        resized = cv2.resize(img, (84, 84), interpolation=cv2.INTER_AREA)
        return resized.reshape(84, 84, 1)

class FrameStackWrapper(ObservationWrapper):
    """연속 n개의 프레임을 채널 방향으로 스태킹"""
    def __init__(self, env, n_frames=4):
        super().__init__(env)
        self.n_frames = n_frames
        old_shape = env.observation_space.shape
        self.observation_space = Box(
            low=0, high=255,
            shape=(old_shape[0], old_shape[1], n_frames),
            dtype=np.uint8,
        )
        self.frames = deque(maxlen=n_frames)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return np.concatenate(list(self.frames), axis=-1), info

    def observation(self, obs):
        self.frames.append(obs)
        return np.concatenate(list(self.frames), axis=-1)

class ImageToChannelsFirstWrapper(ObservationWrapper):
    """(H, W, C) -> (C, H, W) 변환"""
    def __init__(self, env):
        super().__init__(env)
        old_shape = env.observation_space.shape
        self.observation_space = Box(
            low=0, high=255,
            shape=(old_shape[2], old_shape[0], old_shape[1]),
            dtype=np.uint8,
        )

    def observation(self, obs):
        return np.transpose(obs, (2, 0, 1))


def make_atari_env(env_name="ALE/Pong-v5"):
    """Atari 환경 생성 및 래퍼 적용"""
    env = gym.make(env_name)
    env = MaxAndSkipWrapper(env)
    env = FireResetWrapper(env)
    env = ProcessFrame84Wrapper(env)
    env = FrameStackWrapper(env)
    env = ImageToChannelsFirstWrapper(env)
    return env

DQN 학습 루프

from torch.utils.tensorboard import SummaryWriter

def train_dqn_pong():
    """DQN으로 Pong 학습"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"사용 장치: {device}")

    env = make_atari_env("ALE/Pong-v5")
    agent = DQNAgent(
        env=env,
        device=device,
        buffer_size=100000,
        batch_size=32,
        gamma=0.99,
        lr=1e-4,
        epsilon_start=1.0,
        epsilon_end=0.02,
        epsilon_decay=100000,
        target_update=1000,
    )

    writer = SummaryWriter("runs/dqn_pong")

    obs, _ = env.reset()
    episode_reward = 0
    episode_count = 0
    best_mean_reward = -float("inf")

    for frame in range(1, 1000001):
        # 행동 선택 및 실행
        action = agent.select_action(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)

        # 경험 저장
        agent.buffer.push(obs, action, reward, next_obs, terminated or truncated)
        episode_reward += reward
        agent.step_count += 1

        if terminated or truncated:
            episode_count += 1
            writer.add_scalar("reward/episode", episode_reward, episode_count)
            writer.add_scalar("epsilon", agent.get_epsilon(), frame)

            if episode_count % 10 == 0:
                print(
                    f"프레임 {frame}, 에피소드 {episode_count}: "
                    f"보상={episode_reward:.0f}, "
                    f"엡실론={agent.get_epsilon():.3f}, "
                    f"버퍼={len(agent.buffer)}"
                )

            episode_reward = 0
            obs, _ = env.reset()
        else:
            obs = next_obs

        # 네트워크 업데이트
        loss = agent.update()
        if loss is not None and frame % 1000 == 0:
            writer.add_scalar("loss/td", loss, frame)

        # 타겟 네트워크 동기화
        if frame % agent.target_update == 0:
            agent.sync_target_network()

        # 모델 저장
        if frame % 50000 == 0:
            torch.save(agent.online_net.state_dict(), f"dqn_pong_{frame}.pth")

    writer.close()
    env.close()

# train_dqn_pong()

DQN 학습 결과 분석

Pong 환경에서의 DQN 학습은 일반적으로 다음과 같은 패턴을 보입니다.

학습 곡선의 특징

  • 초기 (0-10만 프레임): 무작위에 가까운 행동, 보상 약 -21 (완패)
  • 탐색기 (10-30만 프레임): 엡실론이 감소하면서 간헐적으로 점수를 얻기 시작
  • 학습기 (30-70만 프레임): 빠르게 성능이 향상, 보상이 0 근처로 상승
  • 수렴기 (70만 이후): 보상 +19~+21에 도달, 거의 완벽한 플레이

핵심 하이퍼파라미터의 영향

파라미터역할
학습률1e-4너무 크면 불안정, 너무 작으면 느림
배치 크기32클수록 안정적이지만 메모리 사용 증가
버퍼 크기10만작으면 최근 경험만 사용, 클수록 다양한 경험
타겟 업데이트 주기1000짧으면 불안정, 길면 오래된 타겟 사용
엡실론 감소 구간10만 프레임탐색에서 활용으로의 전환 속도
할인 인자0.99미래 보상의 중요도

간단한 환경에서의 DQN 실험

Pong을 학습하려면 GPU와 시간이 많이 필요합니다. CartPole에서 DQN의 핵심 개념을 검증할 수 있습니다.

class SimpleDQN(nn.Module):
    """CartPole용 간단한 DQN"""
    def __init__(self, obs_size, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        return self.net(x)

def train_dqn_cartpole():
    """CartPole에서 DQN 학습"""
    env = gym.make("CartPole-v1")
    device = torch.device("cpu")

    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    online_net = SimpleDQN(obs_size, n_actions).to(device)
    target_net = SimpleDQN(obs_size, n_actions).to(device)
    target_net.load_state_dict(online_net.state_dict())

    optimizer = optim.Adam(online_net.parameters(), lr=1e-3)
    buffer = ReplayBuffer(10000)

    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    gamma = 0.99
    batch_size = 64
    target_update = 100

    rewards_history = []
    step = 0

    for episode in range(500):
        obs, _ = env.reset()
        total_reward = 0

        while True:
            # 엡실론-탐욕 행동 선택
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    q = online_net(torch.tensor([obs], dtype=torch.float32))
                    action = q.argmax(dim=1).item()

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            buffer.push(obs, action, reward, next_obs, done)
            total_reward += reward
            obs = next_obs
            step += 1

            # 학습
            if len(buffer) >= batch_size:
                s, a, r, ns, d = buffer.sample(batch_size)
                s_t = torch.tensor(s, dtype=torch.float32)
                a_t = torch.tensor(a, dtype=torch.long)
                r_t = torch.tensor(r, dtype=torch.float32)
                ns_t = torch.tensor(ns, dtype=torch.float32)
                d_t = torch.tensor(d, dtype=torch.bool)

                current_q = online_net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)

                with torch.no_grad():
                    next_q = target_net(ns_t).max(dim=1)[0]
                    next_q[d_t] = 0.0
                    target_q = r_t + gamma * next_q

                loss = nn.MSELoss()(current_q, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # 타겟 네트워크 업데이트
            if step % target_update == 0:
                target_net.load_state_dict(online_net.state_dict())

            if done:
                break

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        rewards_history.append(total_reward)

        if episode % 50 == 0:
            mean_reward = np.mean(rewards_history[-50:])
            print(f"에피소드 {episode}: 평균 보상={mean_reward:.1f}, 엡실론={epsilon:.3f}")

            if mean_reward >= 475:
                print("CartPole 해결!")
                break

    env.close()
    return online_net

# trained_model = train_dqn_cartpole()

정리

  1. 테이블 Q-러닝의 한계: 대규모 상태 공간에서는 Q 테이블을 사용할 수 없음
  2. DQN의 핵심 기법: 경험 리플레이(상관관계 제거), 타겟 네트워크(학습 안정화)
  3. Atari 전처리: 프레임 스킵, 그레이스케일 변환, 84x84 리사이즈, 프레임 스태킹
  4. 학습 과정: 백만 프레임 이상의 학습을 통해 Pong에서 인간 수준 달성
  5. 엡실론 스케줄: 초기 탐색에서 점진적으로 활용으로 전환

다음 글에서는 DQN의 성능을 크게 향상시키는 다양한 확장 기법(Double DQN, Dueling DQN, Prioritized Replay 등)을 다루겠습니다.