Skip to content

필사 모드: [심층 강화학습] 06. Deep Q-Network: DQN의 원리와 구현

한국어
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

현실 세계에서의 가치 반복의 한계

이전 글에서 다룬 테이블 기반 방법은 상태 공간이 작을 때만 사용할 수 있습니다. 현실 세계의 문제에서는 다음과 같은 한계에 직면합니다.

상태 공간의 폭발

- **Atari 게임 화면**: 210 x 160 픽셀, 3개 컬러 채널 = 약 10만 개의 연속 값

- **바둑**: 약 10의 170승에 달하는 가능한 상태

- **자율 주행**: 센서 데이터의 조합은 사실상 무한

이러한 환경에서는 Q 테이블을 만드는 것 자체가 불가능합니다. 신경망으로 Q함수를 **근사(approximation)**해야 합니다.

테이블 Q-러닝에서 딥 Q-러닝으로

핵심 아이디어

Q 테이블 대신 신경망 Q(s, a; theta)를 사용하여 Q값을 근사합니다. 여기서 theta는 신경망의 파라미터(가중치)입니다.

Q-테이블: Q[state_index][action_index] = value (정확한 값 저장)

DQN: Q(state; theta) -> [q_action_0, q_action_1, ...] (함수 근사)

왜 단순히 신경망만 넣으면 안 되는가

Q-러닝에 단순히 신경망을 대입하면 학습이 매우 불안정합니다. 세 가지 핵심 문제가 있습니다.

1. **데이터 간의 상관관계**: 연속된 경험들은 서로 강하게 상관되어 있어 i.i.d. 가정을 위반합니다

2. **비정상 타겟**: 학습 타겟이 계속 변하므로 수렴이 어렵습니다

3. **마르코프 성질 위반**: 단일 프레임만으로는 환경 상태를 완전히 표현하기 어렵습니다

DQN의 핵심 기법

1. 경험 리플레이 (Experience Replay)

에이전트의 경험을 버퍼에 저장하고, 학습 시 무작위로 샘플링하여 사용합니다. 이를 통해 데이터 간 상관관계를 깨뜨립니다.

from collections import deque

class ReplayBuffer:

"""경험 리플레이 버퍼"""

def __init__(self, capacity):

self.buffer = deque(maxlen=capacity)

def push(self, state, action, reward, next_state, done):

self.buffer.append((state, action, reward, next_state, done))

def sample(self, batch_size):

batch = random.sample(self.buffer, batch_size)

states, actions, rewards, next_states, dones = zip(*batch)

return (

np.array(states),

np.array(actions),

np.array(rewards, dtype=np.float32),

np.array(next_states),

np.array(dones, dtype=np.bool_),

)

def __len__(self):

return len(self.buffer)

사용 예시

buffer = ReplayBuffer(capacity=100000)

경험 저장

buffer.push(

state=np.array([0.1, 0.2, 0.3, 0.4]),

action=1,

reward=1.0,

next_state=np.array([0.2, 0.3, 0.4, 0.5]),

done=False,

)

배치 샘플링

if len(buffer) >= 32:

states, actions, rewards, next_states, dones = buffer.sample(32)

2. 타겟 네트워크 (Target Network)

Q값 업데이트의 타겟을 계산할 때 별도의 네트워크를 사용합니다. 타겟 네트워크는 주기적으로 온라인 네트워크의 가중치를 복사받습니다.

TD 타겟 = r + gamma * max_{a'} Q_target(s', a'; theta_target)

손실 = (Q(s, a; theta) - TD 타겟)^2

타겟 네트워크를 사용하면 학습 타겟이 일정 기간 동안 고정되어 학습이 안정화됩니다.

3. 프레임 스태킹 (Frame Stacking)

Atari 게임에서 단일 프레임은 공의 이동 방향 등을 알 수 없습니다. 연속 4개의 프레임을 쌓아서 하나의 관찰로 사용합니다.

DQN 모델 구현

Atari용 DQN 네트워크

class DQN(nn.Module):

"""Atari 게임용 DQN 네트워크"""

def __init__(self, input_channels, n_actions):

super().__init__()

self.conv = nn.Sequential(

nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),

nn.ReLU(),

nn.Conv2d(32, 64, kernel_size=4, stride=2),

nn.ReLU(),

nn.Conv2d(64, 64, kernel_size=3, stride=1),

nn.ReLU(),

)

conv_out_size = self._get_conv_out(input_channels)

self.fc = nn.Sequential(

nn.Linear(conv_out_size, 512),

nn.ReLU(),

nn.Linear(512, n_actions),

)

def _get_conv_out(self, channels):

o = self.conv(torch.zeros(1, channels, 84, 84))

return int(np.prod(o.size()))

def forward(self, x):

conv_out = self.conv(x.float() / 255.0)

return self.fc(conv_out.view(conv_out.size(0), -1))

DQN 에이전트 구현

class DQNAgent:

"""DQN 에이전트"""

def __init__(self, env, device="cpu", buffer_size=100000,

batch_size=32, gamma=0.99, lr=1e-4,

epsilon_start=1.0, epsilon_end=0.01,

epsilon_decay=100000, target_update=1000):

self.env = env

self.device = device

self.batch_size = batch_size

self.gamma = gamma

self.target_update = target_update

엡실론 스케줄

self.epsilon_start = epsilon_start

self.epsilon_end = epsilon_end

self.epsilon_decay = epsilon_decay

n_actions = env.action_space.n

온라인 네트워크와 타겟 네트워크

self.online_net = DQN(4, n_actions).to(device)

self.target_net = DQN(4, n_actions).to(device)

self.target_net.load_state_dict(self.online_net.state_dict())

self.target_net.eval()

self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)

self.buffer = ReplayBuffer(buffer_size)

self.step_count = 0

def get_epsilon(self):

"""현재 엡실론 값 계산 (선형 감소)"""

epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \

max(0, 1 - self.step_count / self.epsilon_decay)

return epsilon

def select_action(self, state):

"""엡실론-탐욕 정책으로 행동 선택"""

epsilon = self.get_epsilon()

if random.random() < epsilon:

return self.env.action_space.sample()

with torch.no_grad():

state_tensor = torch.tensor(

np.array([state]), dtype=torch.uint8

).to(self.device)

q_values = self.online_net(state_tensor)

return q_values.argmax(dim=1).item()

def update(self):

"""경험 리플레이에서 배치를 샘플링하여 네트워크 업데이트"""

if len(self.buffer) < self.batch_size:

return None

states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

텐서 변환

states_t = torch.tensor(states, dtype=torch.uint8).to(self.device)

actions_t = torch.tensor(actions, dtype=torch.long).to(self.device)

rewards_t = torch.tensor(rewards, dtype=torch.float32).to(self.device)

next_states_t = torch.tensor(next_states, dtype=torch.uint8).to(self.device)

dones_t = torch.tensor(dones, dtype=torch.bool).to(self.device)

현재 Q값: Q(s, a)

current_q = self.online_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

타겟 Q값: r + gamma * max_a' Q_target(s', a')

with torch.no_grad():

next_q = self.target_net(next_states_t).max(dim=1)[0]

next_q[dones_t] = 0.0

target_q = rewards_t + self.gamma * next_q

손실 계산 및 역전파

loss = nn.SmoothL1Loss()(current_q, target_q)

self.optimizer.zero_grad()

loss.backward()

그래디언트 클리핑 (안정적 학습)

torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), max_norm=10)

self.optimizer.step()

return loss.item()

def sync_target_network(self):

"""타겟 네트워크를 온라인 네트워크로 동기화"""

self.target_net.load_state_dict(self.online_net.state_dict())

Atari 환경 전처리

DQN 학습을 위한 Atari 환경 전처리 파이프라인입니다.

from gymnasium import ObservationWrapper, Wrapper

from gymnasium.spaces import Box

class FireResetWrapper(Wrapper):

"""에피소드 시작 시 FIRE 행동을 자동 실행"""

def __init__(self, env):

super().__init__(env)

def reset(self, **kwargs):

obs, info = self.env.reset(**kwargs)

obs, _, terminated, truncated, info = self.env.step(1) # FIRE

if terminated or truncated:

obs, info = self.env.reset(**kwargs)

return obs, info

class MaxAndSkipWrapper(Wrapper):

"""프레임 스킵: 4프레임마다 행동을 선택하고 중간 프레임은 반복"""

def __init__(self, env, skip=4):

super().__init__(env)

self.skip = skip

self.obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

def step(self, action):

total_reward = 0.0

done = False

for i in range(self.skip):

obs, reward, terminated, truncated, info = self.env.step(action)

total_reward += reward

if i == self.skip - 2:

self.obs_buffer[0] = obs

if i == self.skip - 1:

self.obs_buffer[1] = obs

if terminated or truncated:

done = True

break

max_frame = self.obs_buffer.max(axis=0)

return max_frame, total_reward, terminated, truncated, info

class ProcessFrame84Wrapper(ObservationWrapper):

"""프레임을 84x84 그레이스케일로 변환"""

def __init__(self, env):

super().__init__(env)

self.observation_space = Box(

low=0, high=255, shape=(84, 84, 1), dtype=np.uint8

)

def observation(self, obs):

return self._process(obs)

def _process(self, frame):

img = np.mean(frame, axis=2).astype(np.uint8)

resized = cv2.resize(img, (84, 84), interpolation=cv2.INTER_AREA)

return resized.reshape(84, 84, 1)

class FrameStackWrapper(ObservationWrapper):

"""연속 n개의 프레임을 채널 방향으로 스태킹"""

def __init__(self, env, n_frames=4):

super().__init__(env)

self.n_frames = n_frames

old_shape = env.observation_space.shape

self.observation_space = Box(

low=0, high=255,

shape=(old_shape[0], old_shape[1], n_frames),

dtype=np.uint8,

)

self.frames = deque(maxlen=n_frames)

def reset(self, **kwargs):

obs, info = self.env.reset(**kwargs)

for _ in range(self.n_frames):

self.frames.append(obs)

return np.concatenate(list(self.frames), axis=-1), info

def observation(self, obs):

self.frames.append(obs)

return np.concatenate(list(self.frames), axis=-1)

class ImageToChannelsFirstWrapper(ObservationWrapper):

"""(H, W, C) -> (C, H, W) 변환"""

def __init__(self, env):

super().__init__(env)

old_shape = env.observation_space.shape

self.observation_space = Box(

low=0, high=255,

shape=(old_shape[2], old_shape[0], old_shape[1]),

dtype=np.uint8,

)

def observation(self, obs):

return np.transpose(obs, (2, 0, 1))

def make_atari_env(env_name="ALE/Pong-v5"):

"""Atari 환경 생성 및 래퍼 적용"""

env = gym.make(env_name)

env = MaxAndSkipWrapper(env)

env = FireResetWrapper(env)

env = ProcessFrame84Wrapper(env)

env = FrameStackWrapper(env)

env = ImageToChannelsFirstWrapper(env)

return env

DQN 학습 루프

from torch.utils.tensorboard import SummaryWriter

def train_dqn_pong():

"""DQN으로 Pong 학습"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"사용 장치: {device}")

env = make_atari_env("ALE/Pong-v5")

agent = DQNAgent(

env=env,

device=device,

buffer_size=100000,

batch_size=32,

gamma=0.99,

lr=1e-4,

epsilon_start=1.0,

epsilon_end=0.02,

epsilon_decay=100000,

target_update=1000,

)

writer = SummaryWriter("runs/dqn_pong")

obs, _ = env.reset()

episode_reward = 0

episode_count = 0

best_mean_reward = -float("inf")

for frame in range(1, 1000001):

행동 선택 및 실행

action = agent.select_action(obs)

next_obs, reward, terminated, truncated, _ = env.step(action)

경험 저장

agent.buffer.push(obs, action, reward, next_obs, terminated or truncated)

episode_reward += reward

agent.step_count += 1

if terminated or truncated:

episode_count += 1

writer.add_scalar("reward/episode", episode_reward, episode_count)

writer.add_scalar("epsilon", agent.get_epsilon(), frame)

if episode_count % 10 == 0:

print(

f"프레임 {frame}, 에피소드 {episode_count}: "

f"보상={episode_reward:.0f}, "

f"엡실론={agent.get_epsilon():.3f}, "

f"버퍼={len(agent.buffer)}"

)

episode_reward = 0

obs, _ = env.reset()

else:

obs = next_obs

네트워크 업데이트

loss = agent.update()

if loss is not None and frame % 1000 == 0:

writer.add_scalar("loss/td", loss, frame)

타겟 네트워크 동기화

if frame % agent.target_update == 0:

agent.sync_target_network()

모델 저장

if frame % 50000 == 0:

torch.save(agent.online_net.state_dict(), f"dqn_pong_{frame}.pth")

writer.close()

env.close()

train_dqn_pong()

DQN 학습 결과 분석

Pong 환경에서의 DQN 학습은 일반적으로 다음과 같은 패턴을 보입니다.

학습 곡선의 특징

- **초기 (0-10만 프레임)**: 무작위에 가까운 행동, 보상 약 -21 (완패)

- **탐색기 (10-30만 프레임)**: 엡실론이 감소하면서 간헐적으로 점수를 얻기 시작

- **학습기 (30-70만 프레임)**: 빠르게 성능이 향상, 보상이 0 근처로 상승

- **수렴기 (70만 이후)**: 보상 +19~+21에 도달, 거의 완벽한 플레이

핵심 하이퍼파라미터의 영향

| 파라미터 | 값 | 역할 |

| ------------------ | ----------- | ------------------------------------------- |

| 학습률 | 1e-4 | 너무 크면 불안정, 너무 작으면 느림 |

| 배치 크기 | 32 | 클수록 안정적이지만 메모리 사용 증가 |

| 버퍼 크기 | 10만 | 작으면 최근 경험만 사용, 클수록 다양한 경험 |

| 타겟 업데이트 주기 | 1000 | 짧으면 불안정, 길면 오래된 타겟 사용 |

| 엡실론 감소 구간 | 10만 프레임 | 탐색에서 활용으로의 전환 속도 |

| 할인 인자 | 0.99 | 미래 보상의 중요도 |

간단한 환경에서의 DQN 실험

Pong을 학습하려면 GPU와 시간이 많이 필요합니다. CartPole에서 DQN의 핵심 개념을 검증할 수 있습니다.

class SimpleDQN(nn.Module):

"""CartPole용 간단한 DQN"""

def __init__(self, obs_size, n_actions):

super().__init__()

self.net = nn.Sequential(

nn.Linear(obs_size, 128),

nn.ReLU(),

nn.Linear(128, 128),

nn.ReLU(),

nn.Linear(128, n_actions),

)

def forward(self, x):

return self.net(x)

def train_dqn_cartpole():

"""CartPole에서 DQN 학습"""

env = gym.make("CartPole-v1")

device = torch.device("cpu")

obs_size = env.observation_space.shape[0]

n_actions = env.action_space.n

online_net = SimpleDQN(obs_size, n_actions).to(device)

target_net = SimpleDQN(obs_size, n_actions).to(device)

target_net.load_state_dict(online_net.state_dict())

optimizer = optim.Adam(online_net.parameters(), lr=1e-3)

buffer = ReplayBuffer(10000)

epsilon = 1.0

epsilon_min = 0.01

epsilon_decay = 0.995

gamma = 0.99

batch_size = 64

target_update = 100

rewards_history = []

step = 0

for episode in range(500):

obs, _ = env.reset()

total_reward = 0

while True:

엡실론-탐욕 행동 선택

if random.random() < epsilon:

action = env.action_space.sample()

else:

with torch.no_grad():

q = online_net(torch.tensor([obs], dtype=torch.float32))

action = q.argmax(dim=1).item()

next_obs, reward, terminated, truncated, _ = env.step(action)

done = terminated or truncated

buffer.push(obs, action, reward, next_obs, done)

total_reward += reward

obs = next_obs

step += 1

학습

if len(buffer) >= batch_size:

s, a, r, ns, d = buffer.sample(batch_size)

s_t = torch.tensor(s, dtype=torch.float32)

a_t = torch.tensor(a, dtype=torch.long)

r_t = torch.tensor(r, dtype=torch.float32)

ns_t = torch.tensor(ns, dtype=torch.float32)

d_t = torch.tensor(d, dtype=torch.bool)

current_q = online_net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)

with torch.no_grad():

next_q = target_net(ns_t).max(dim=1)[0]

next_q[d_t] = 0.0

target_q = r_t + gamma * next_q

loss = nn.MSELoss()(current_q, target_q)

optimizer.zero_grad()

loss.backward()

optimizer.step()

타겟 네트워크 업데이트

if step % target_update == 0:

target_net.load_state_dict(online_net.state_dict())

if done:

break

epsilon = max(epsilon_min, epsilon * epsilon_decay)

rewards_history.append(total_reward)

if episode % 50 == 0:

mean_reward = np.mean(rewards_history[-50:])

print(f"에피소드 {episode}: 평균 보상={mean_reward:.1f}, 엡실론={epsilon:.3f}")

if mean_reward >= 475:

print("CartPole 해결!")

break

env.close()

return online_net

trained_model = train_dqn_cartpole()

정리

1. **테이블 Q-러닝의 한계**: 대규모 상태 공간에서는 Q 테이블을 사용할 수 없음

2. **DQN의 핵심 기법**: 경험 리플레이(상관관계 제거), 타겟 네트워크(학습 안정화)

3. **Atari 전처리**: 프레임 스킵, 그레이스케일 변환, 84x84 리사이즈, 프레임 스태킹

4. **학습 과정**: 백만 프레임 이상의 학습을 통해 Pong에서 인간 수준 달성

5. **엡실론 스케줄**: 초기 탐색에서 점진적으로 활용으로 전환

다음 글에서는 DQN의 성능을 크게 향상시키는 다양한 확장 기법(Double DQN, Dueling DQN, Prioritized Replay 등)을 다루겠습니다.

현재 단락 (1/372)

이전 글에서 다룬 테이블 기반 방법은 상태 공간이 작을 때만 사용할 수 있습니다. 현실 세계의 문제에서는 다음과 같은 한계에 직면합니다.

작성 글자: 0원문 글자: 11,483작성 단락: 0/372