[Deep RL] 06. Deep Q-Network: Principles and Implementation of DQN

Limitations of Value Iteration in the Real World

The table-based methods covered in the previous article can only be used when the state space is small. In real-world problems, we face the following limitations.

Explosion of the State Space

  • Atari game screens: 210 x 160 pixels x 3 color channels = 100,800 input values
  • Go: Approximately 10 to the power of 170 possible states
  • Autonomous driving: Sensor data combinations are practically infinite

In such environments, creating a Q table itself is impossible. We must approximate the Q function with neural networks.
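A quick back-of-the-envelope calculation makes "impossible" concrete. Assuming 256 intensity levels per color channel, the number of distinct raw Atari screens is 256^(210 x 160 x 3); even the decimal length of that number is staggering:

```python
import math

# Distinct raw Atari screens, assuming 256 intensity levels per
# pixel per channel: 256 ** (210 * 160 * 3) possible screens.
n_cells = 210 * 160 * 3                      # 100,800 pixel values
digits = int(n_cells * math.log10(256)) + 1  # decimal digits of that count
print(digits)  # a number over 240,000 digits long
```

No table could enumerate this many rows, which is why we approximate the Q function with a parameterized model instead.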


From Table Q-Learning to Deep Q-Learning

Core Idea

Instead of a Q table, we use a neural network Q(s, a; theta) to approximate Q values, where theta represents the network parameters (weights).

Q-table: Q[state_index][action_index] = value (stores exact values)
DQN:     Q(state; theta) -> [q_action_0, q_action_1, ...] (function approximation)
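The contrast can be sketched in a few lines. The table and the linear map below are purely illustrative (a hypothetical 16-state, 4-action problem, and the simplest possible approximator rather than a deep network):

```python
import numpy as np

# Tabular Q-learning: one exact entry per (state, action) pair.
q_table = np.zeros((16, 4))
q_table[3, 1] = 0.7  # direct lookup and assignment

# Function approximation: parameters theta map a state vector to
# Q-values for all actions at once; learning adjusts theta.
rng = np.random.default_rng(0)
theta = rng.normal(scale=0.01, size=(4, 2))

def q_approx(state):
    """The simplest approximator: a linear map (DQN uses a deep net)."""
    return state @ theta

print(q_approx(np.array([0.1, 0.2, 0.3, 0.4])).shape)  # (2,)
```

The table needs one cell per state; the approximator needs only as many parameters as theta has, regardless of how many states exist.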

Why Simply Adding a Neural Network Is Not Enough

Simply substituting a neural network into Q-learning makes training very unstable. There are three core problems:

  1. Correlation between data: Consecutive experiences are strongly correlated, violating the i.i.d. assumption
  2. Non-stationary targets: Learning targets keep changing, making convergence difficult
  3. Markov property violation: A single frame alone cannot fully represent the environment state

Core Techniques of DQN

1. Experience Replay

The agent's experiences are stored in a buffer and sampled at random during training, which breaks the correlation between consecutive samples.

import numpy as np
from collections import deque
import random

class ReplayBuffer:
    """경험 리플레이 버퍼"""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.bool_),
        )

    def __len__(self):
        return len(self.buffer)

# Usage example
buffer = ReplayBuffer(capacity=100000)

# Store an experience
buffer.push(
    state=np.array([0.1, 0.2, 0.3, 0.4]),
    action=1,
    reward=1.0,
    next_state=np.array([0.2, 0.3, 0.4, 0.5]),
    done=False,
)

# Sample a batch
if len(buffer) >= 32:
    states, actions, rewards, next_states, dones = buffer.sample(32)

2. Target Network

A separate network is used to compute the target for Q value updates. The target network periodically copies weights from the online network.

TD target = r + gamma * max_{a'} Q_target(s', a'; theta_target)
Loss = (Q(s, a; theta) - TD target)^2

Using a target network fixes the learning target for a period, stabilizing training.
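A minimal sketch of the two-network computation, using tiny linear stand-ins for the Q-networks and random data (illustration only, not the full agent):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# The online network is trained; the target network is a frozen copy
# that is only refreshed periodically via load_state_dict.
online_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
target_net.load_state_dict(online_net.state_dict())

state = torch.randn(1, 4)
next_state = torch.randn(1, 4)
reward, action, gamma = 1.0, 0, 0.99

with torch.no_grad():  # no gradient flows through the target side
    td_target = reward + gamma * target_net(next_state).max(dim=1).values

loss = (online_net(state)[0, action] - td_target) ** 2  # squared TD error
```

Because td_target is computed under no_grad from the frozen copy, it stays constant between synchronizations, which is exactly what stabilizes training.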

3. Frame Stacking

In Atari games, a single frame cannot convey motion information such as the ball's direction of travel, so we stack 4 consecutive frames into one observation.
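The mechanism in isolation: keep the last 4 frames in a deque and concatenate them along the channel axis (dummy all-zero 84x84 frames here):

```python
import numpy as np
from collections import deque

frames = deque(maxlen=4)  # the oldest frame drops out automatically
for _ in range(4):
    frames.append(np.zeros((84, 84, 1), dtype=np.uint8))

stacked = np.concatenate(list(frames), axis=-1)
print(stacked.shape)  # (84, 84, 4)
```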


DQN Model Implementation

DQN Network for Atari

import numpy as np
import torch
import torch.nn as nn

class DQN(nn.Module):
    """Atari 게임용 DQN 네트워크"""
    def __init__(self, input_channels, n_actions):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(input_channels)

        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, channels):
        o = self.conv(torch.zeros(1, channels, 84, 84))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x.float() / 255.0)
        return self.fc(conv_out.view(conv_out.size(0), -1))
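The convolutional stack can be sanity-checked on a dummy batch; with 84x84 inputs the three layers shrink the feature map to 64 x 7 x 7 = 3136 features, which is what _get_conv_out computes:

```python
import torch
import torch.nn as nn

# The same convolution stack as in DQN above, applied to a dummy batch.
conv = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
)
out = conv(torch.zeros(1, 4, 84, 84))  # (batch, channels, height, width)
print(tuple(out.shape))  # (1, 64, 7, 7) -> 64 * 7 * 7 = 3136
```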

DQN Agent Implementation

import torch
import torch.optim as optim

class DQNAgent:
    """DQN 에이전트"""
    def __init__(self, env, device="cpu", buffer_size=100000,
                 batch_size=32, gamma=0.99, lr=1e-4,
                 epsilon_start=1.0, epsilon_end=0.01,
                 epsilon_decay=100000, target_update=1000):
        self.env = env
        self.device = device
        self.batch_size = batch_size
        self.gamma = gamma
        self.target_update = target_update

        # Epsilon schedule
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        n_actions = env.action_space.n

        # Online and target networks
        self.online_net = DQN(4, n_actions).to(device)
        self.target_net = DQN(4, n_actions).to(device)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        self.step_count = 0

    def get_epsilon(self):
        """현재 엡실론 값 계산 (선형 감소)"""
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            max(0, 1 - self.step_count / self.epsilon_decay)
        return epsilon

    def select_action(self, state):
        """엡실론-탐욕 정책으로 행동 선택"""
        epsilon = self.get_epsilon()

        if random.random() < epsilon:
            return self.env.action_space.sample()

        with torch.no_grad():
            state_tensor = torch.tensor(
                np.array([state]), dtype=torch.uint8
            ).to(self.device)
            q_values = self.online_net(state_tensor)
            return q_values.argmax(dim=1).item()

    def update(self):
        """경험 리플레이에서 배치를 샘플링하여 네트워크 업데이트"""
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Convert to tensors
        states_t = torch.tensor(states, dtype=torch.uint8).to(self.device)
        actions_t = torch.tensor(actions, dtype=torch.long).to(self.device)
        rewards_t = torch.tensor(rewards, dtype=torch.float32).to(self.device)
        next_states_t = torch.tensor(next_states, dtype=torch.uint8).to(self.device)
        dones_t = torch.tensor(dones, dtype=torch.bool).to(self.device)

        # Current Q-values: Q(s, a)
        current_q = self.online_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # Target Q-values: r + gamma * max_a' Q_target(s', a')
        with torch.no_grad():
            next_q = self.target_net(next_states_t).max(dim=1)[0]
            next_q[dones_t] = 0.0
            target_q = rewards_t + self.gamma * next_q

        # Compute the loss and backpropagate
        loss = nn.SmoothL1Loss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping (for stable training)
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), max_norm=10)
        self.optimizer.step()

        return loss.item()

    def sync_target_network(self):
        """타겟 네트워크를 온라인 네트워크로 동기화"""
        self.target_net.load_state_dict(self.online_net.state_dict())
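The epsilon schedule used by get_epsilon can be inspected standalone; with the defaults above, epsilon falls linearly from 1.0 to 0.01 over the first 100K steps and then stays there:

```python
# The same linear decay as get_epsilon above, as a free function.
def linear_epsilon(step, start=1.0, end=0.01, decay=100_000):
    return end + (start - end) * max(0.0, 1 - step / decay)

for step in (0, 50_000, 200_000):
    print(round(linear_epsilon(step), 3))  # 1.0, 0.505, 0.01
```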

Atari Environment Preprocessing

The Atari environment preprocessing pipeline for DQN training:

import gymnasium as gym
from gymnasium import ObservationWrapper, Wrapper
from gymnasium.spaces import Box
from collections import deque
import numpy as np
import cv2

class FireResetWrapper(Wrapper):
    """에피소드 시작 시 FIRE 행동을 자동 실행"""
    def __init__(self, env):
        super().__init__(env)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(1)  # FIRE
        if terminated or truncated:
            obs, info = self.env.reset(**kwargs)
        return obs, info

class MaxAndSkipWrapper(Wrapper):
    """프레임 스킵: 4프레임마다 행동을 선택하고 중간 프레임은 반복"""
    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip
        self.obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        total_reward = 0.0
        for i in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            # Keep the last two frames for max-pooling (removes sprite flicker)
            if i == self.skip - 2:
                self.obs_buffer[0] = obs
            if i == self.skip - 1:
                self.obs_buffer[1] = obs
            if terminated or truncated:
                break
        max_frame = self.obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info

class ProcessFrame84Wrapper(ObservationWrapper):
    """프레임을 84x84 그레이스케일로 변환"""
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def observation(self, obs):
        return self._process(obs)

    def _process(self, frame):
        img = np.mean(frame, axis=2).astype(np.uint8)
        resized = cv2.resize(img, (84, 84), interpolation=cv2.INTER_AREA)
        return resized.reshape(84, 84, 1)

class FrameStackWrapper(ObservationWrapper):
    """연속 n개의 프레임을 채널 방향으로 스태킹"""
    def __init__(self, env, n_frames=4):
        super().__init__(env)
        self.n_frames = n_frames
        old_shape = env.observation_space.shape
        self.observation_space = Box(
            low=0, high=255,
            shape=(old_shape[0], old_shape[1], n_frames),
            dtype=np.uint8,
        )
        self.frames = deque(maxlen=n_frames)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return np.concatenate(list(self.frames), axis=-1), info

    def observation(self, obs):
        self.frames.append(obs)
        return np.concatenate(list(self.frames), axis=-1)

class ImageToChannelsFirstWrapper(ObservationWrapper):
    """(H, W, C) -> (C, H, W) 변환"""
    def __init__(self, env):
        super().__init__(env)
        old_shape = env.observation_space.shape
        self.observation_space = Box(
            low=0, high=255,
            shape=(old_shape[2], old_shape[0], old_shape[1]),
            dtype=np.uint8,
        )

    def observation(self, obs):
        return np.transpose(obs, (2, 0, 1))


def make_atari_env(env_name="ALE/Pong-v5"):
    """Atari 환경 생성 및 래퍼 적용"""
    env = gym.make(env_name)
    env = MaxAndSkipWrapper(env)
    env = FireResetWrapper(env)
    env = ProcessFrame84Wrapper(env)
    env = FrameStackWrapper(env)
    env = ImageToChannelsFirstWrapper(env)
    return env

DQN Training Loop

from torch.utils.tensorboard import SummaryWriter

def train_dqn_pong():
    """DQN으로 Pong 학습"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"사용 장치: {device}")

    env = make_atari_env("ALE/Pong-v5")
    agent = DQNAgent(
        env=env,
        device=device,
        buffer_size=100000,
        batch_size=32,
        gamma=0.99,
        lr=1e-4,
        epsilon_start=1.0,
        epsilon_end=0.02,
        epsilon_decay=100000,
        target_update=1000,
    )

    writer = SummaryWriter("runs/dqn_pong")

    obs, _ = env.reset()
    episode_reward = 0
    episode_count = 0

    for frame in range(1, 1000001):
        # Select and execute an action
        action = agent.select_action(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)

        # Store the experience
        agent.buffer.push(obs, action, reward, next_obs, terminated or truncated)
        episode_reward += reward
        agent.step_count += 1

        if terminated or truncated:
            episode_count += 1
            writer.add_scalar("reward/episode", episode_reward, episode_count)
            writer.add_scalar("epsilon", agent.get_epsilon(), frame)

            if episode_count % 10 == 0:
                print(
                    f"Frame {frame}, episode {episode_count}: "
                    f"reward={episode_reward:.0f}, "
                    f"epsilon={agent.get_epsilon():.3f}, "
                    f"buffer={len(agent.buffer)}"
                )

            episode_reward = 0
            obs, _ = env.reset()
        else:
            obs = next_obs

        # Update the network
        loss = agent.update()
        if loss is not None and frame % 1000 == 0:
            writer.add_scalar("loss/td", loss, frame)

        # Sync the target network
        if frame % agent.target_update == 0:
            agent.sync_target_network()

        # Save the model
        if frame % 50000 == 0:
            torch.save(agent.online_net.state_dict(), f"dqn_pong_{frame}.pth")

    writer.close()
    env.close()

# train_dqn_pong()

DQN Training Results Analysis

DQN training in the Pong environment generally shows the following pattern:

Learning Curve Characteristics

  • Initial (0-100K frames): Near-random actions, reward around -21 (complete loss)
  • Exploration phase (100K-300K frames): Epsilon decreases, occasionally scores points
  • Learning phase (300K-700K frames): Rapid performance improvement, reward rises to around 0
  • Convergence phase (700K+): Reward reaches +19 to +21, near-perfect play

Impact of Key Hyperparameters

| Parameter | Value | Role |
| --- | --- | --- |
| Learning rate | 1e-4 | Too large = unstable, too small = slow |
| Batch size | 32 | Larger = more stable but more memory |
| Buffer size | 100K | Small = only recent experience, large = diverse |
| Target update period | 1000 | Short = unstable, long = stale target |
| Epsilon decay range | 100K frames | Speed of transition from exploration to exploitation |
| Discount factor | 0.99 | Importance of future rewards |

DQN Experiment on Simple Environments

Training Pong requires significant GPU time. The core concepts of DQN can be verified on CartPole.

class SimpleDQN(nn.Module):
    """CartPole용 간단한 DQN"""
    def __init__(self, obs_size, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        return self.net(x)

def train_dqn_cartpole():
    """CartPole에서 DQN 학습"""
    env = gym.make("CartPole-v1")
    device = torch.device("cpu")

    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    online_net = SimpleDQN(obs_size, n_actions).to(device)
    target_net = SimpleDQN(obs_size, n_actions).to(device)
    target_net.load_state_dict(online_net.state_dict())

    optimizer = optim.Adam(online_net.parameters(), lr=1e-3)
    buffer = ReplayBuffer(10000)

    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    gamma = 0.99
    batch_size = 64
    target_update = 100

    rewards_history = []
    step = 0

    for episode in range(500):
        obs, _ = env.reset()
        total_reward = 0

        while True:
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
                    action = online_net(obs_t).argmax(dim=1).item()

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            buffer.push(obs, action, reward, next_obs, done)
            total_reward += reward
            obs = next_obs
            step += 1

            # Training step
            if len(buffer) >= batch_size:
                s, a, r, ns, d = buffer.sample(batch_size)
                s_t = torch.tensor(s, dtype=torch.float32)
                a_t = torch.tensor(a, dtype=torch.long)
                r_t = torch.tensor(r, dtype=torch.float32)
                ns_t = torch.tensor(ns, dtype=torch.float32)
                d_t = torch.tensor(d, dtype=torch.bool)

                current_q = online_net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)

                with torch.no_grad():
                    next_q = target_net(ns_t).max(dim=1)[0]
                    next_q[d_t] = 0.0
                    target_q = r_t + gamma * next_q

                loss = nn.MSELoss()(current_q, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Update the target network
            if step % target_update == 0:
                target_net.load_state_dict(online_net.state_dict())

            if done:
                break

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        rewards_history.append(total_reward)

        if episode % 50 == 0:
            mean_reward = np.mean(rewards_history[-50:])
            print(f"에피소드 {episode}: 평균 보상={mean_reward:.1f}, 엡실론={epsilon:.3f}")

            if mean_reward >= 475:
                print("CartPole 해결!")
                break

    env.close()
    return online_net

# trained_model = train_dqn_cartpole()

Summary

  1. Limitations of table Q-learning: Q tables cannot be used in large state spaces
  2. Core DQN techniques: Experience replay (removes correlation), target network (stabilizes training)
  3. Atari preprocessing: Frame skip, grayscale conversion, 84x84 resize, frame stacking
  4. Training process: Human-level performance on Pong through over one million frames of training
  5. Epsilon schedule: Gradual transition from initial exploration to exploitation

In the next article, we will cover various extension techniques that significantly improve DQN performance (Double DQN, Dueling DQN, Prioritized Replay, etc.).