[Deep RL] 11. A3C: Asynchronous Advantage Actor-Critic
Overview
The A2C (Advantage Actor-Critic) method we examined in the previous post suffered from correlation issues in experience data collected from a single environment. Consecutive state transitions are strongly correlated with each other, reducing learning efficiency. DQN solved this problem with experience replay, but on-policy methods like Actor-Critic require a different approach.
A3C (Asynchronous Advantage Actor-Critic) breaks the data correlation by running multiple environments simultaneously. Proposed by Mnih et al. at DeepMind in the 2016 paper "Asynchronous Methods for Deep Reinforcement Learning", this method achieves stable learning without a replay buffer.
Correlation and Sample Efficiency
Why Is Correlation a Problem?
In reinforcement learning, transitions collected by an agent during a single episode are temporally consecutive. For example, consecutive frames in a Pong game where the ball is flying to the right have very similar states. When updating a neural network with such correlated data:
- Gradients become biased in a particular direction
- Learning becomes unstable and convergence slows down
- In the worst case, learning can diverge
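The correlation described above is easy to quantify: the lag-1 autocorrelation of consecutive states along a trajectory sits near 1, while a shuffled (decorrelated) version of the same data sits near 0. A minimal sketch using a 1-D random walk as a stand-in for an episode's state sequence:

```python
import random

random.seed(0)

def lag1_autocorr(xs):
    """Sample lag-1 autocorrelation of a sequence."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    cov = sum((xs[i] - mean) * (xs[i + 1] - mean)
              for i in range(n - 1)) / (n - 1)
    return cov / var

# Consecutive states of a random walk (a stand-in for an episode's
# state trajectory) are almost perfectly correlated...
walk = [0.0]
for _ in range(5000):
    walk.append(walk[-1] + random.gauss(0, 1))

# ...while shuffling, which is roughly what a replay buffer or many
# independent environments approximate, destroys the temporal correlation.
shuffled = walk[:]
random.shuffle(shuffled)

print(round(lag1_autocorr(walk), 3))      # close to 1.0
print(round(lag1_autocorr(shuffled), 3))  # close to 0.0
```

Gradient estimates computed on the first sequence are systematically biased toward the current region of the state space; estimates computed on the second are not.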
Comparison of Solution Approaches
| Method | Principle | Advantages | Disadvantages |
|---|---|---|---|
| Experience Replay | Store past transitions in buffer, sample randomly | Sample efficient | Off-policy only |
| Parallel Envs | Run multiple environments simultaneously | On-policy compatible | More computation needed |
| A3C | Async parallel envs + independent learning | Maximizes exploration diversity | Complex implementation |
From A2C to A3C: The Meaning of the Added A
A2C runs multiple environments in parallel but collects experiences synchronously and updates all at once. A3C adds Asynchronous to this.
A2C's Synchronous Approach
# A2C: all workers step synchronously (schematic pseudocode;
# get_state/predict/update stand in for your env/model API)
class A2CAgent:
    def __init__(self, num_envs, model):
        self.envs = [make_env() for _ in range(num_envs)]
        self.model = model  # shared model

    def train_step(self):
        # 1. Query the policy for all environments at once
        states = [env.get_state() for env in self.envs]
        actions, values = self.model.predict(states)
        # 2. Step every environment with its chosen action
        rewards, next_states, dones = [], [], []
        for env, action in zip(self.envs, actions):
            r, ns, d = env.step(action)
            rewards.append(r)
            next_states.append(ns)
            dones.append(d)
        # 3. One synchronous update on the gathered batch
        self.model.update(states, actions, rewards, next_states, dones)
A3C's Asynchronous Approach
In A3C, each worker independently interacts with its environment, computes gradients locally, and asynchronously applies them to the central model.
Key differences:
- Each worker does not wait for other workers
- Each worker can use a different exploration policy (e.g., different epsilon values)
- Exploration diversity is naturally achieved
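Per-worker exploration policies can be made concrete by deriving the exploration hyperparameters from the worker id. The helper below is a hypothetical sketch (the name and the specific ranges are illustrative, not taken from the paper):

```python
import math
import random

def worker_exploration_params(worker_id, num_workers, seed=0):
    """Hypothetical helper: derive per-worker exploration settings from the
    worker id so that no two workers follow the same exploration policy.
    The ranges are illustrative."""
    rng = random.Random(seed + worker_id)
    # Entropy coefficient sampled log-uniformly in [1e-3, 1e-1]
    entropy_beta = math.exp(rng.uniform(math.log(1e-3), math.log(1e-1)))
    # Epsilon spread evenly across the workers
    epsilon = 0.05 + 0.3 * worker_id / max(1, num_workers - 1)
    return entropy_beta, epsilon

params = [worker_exploration_params(i, 8) for i in range(8)]
# All eight workers end up with distinct exploration settings
assert len(set(params)) == 8
```

Each worker would then pass its own `entropy_beta` (and epsilon, if used) into its training loop, so the pool of workers covers different trade-offs between exploration and exploitation.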
Python Multiprocessing Basics
Before implementing A3C, we need to understand Python's multiprocessing. Due to the GIL (Global Interpreter Lock), thread-based parallelism is not suitable for CPU-bound tasks.
import torch.multiprocessing as mp
def worker_process(worker_id, shared_model, optimizer, device):
"""Function executed by each worker process"""
env = make_env()
local_model = ActorCritic(env.observation_space.shape[0],
env.action_space.n)
local_model.to(device)
    while True:
        # Copy shared model parameters to local model
        local_model.load_state_dict(shared_model.state_dict())
        # Collect experiences from local environment
        experiences = collect_experiences(env, local_model, n_steps=20)
        # Zero stale local gradients, then compute fresh ones
        local_model.zero_grad()
        loss = compute_loss(local_model, experiences)
        loss.backward()
        # Hand the local gradients to the shared model, then update it
        for shared_param, local_param in zip(shared_model.parameters(),
                                             local_model.parameters()):
            shared_param.grad = local_param.grad
        optimizer.step()
        optimizer.zero_grad()
if __name__ == '__main__':
mp.set_start_method('spawn')
    shared_model = ActorCritic(obs_size, act_size)  # obs_size/act_size read from the target env
shared_model.share_memory() # Share memory across processes
optimizer = SharedAdam(shared_model.parameters(), lr=1e-4)
optimizer.share_memory()
processes = []
for i in range(mp.cpu_count()):
p = mp.Process(target=worker_process,
args=(i, shared_model, optimizer, 'cpu'))
p.start()
processes.append(p)
for p in processes:
p.join()
SharedAdam Optimizer
The Adam optimizer's running moment estimates (exp_avg, exp_avg_sq) must also be shared across processes:
import torch

class SharedAdam(torch.optim.Adam):
    """Adam whose per-parameter state lives in shared memory, so every
    worker process updates the same moment estimates."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, lr=lr, betas=betas, eps=eps)
        # Eagerly create Adam's internal state for every parameter
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)

    def share_memory(self):
        """Move the optimizer state into shared memory. Stock
        torch.optim.Adam has no such method, so we define it here;
        the training code calls optimizer.share_memory() after construction."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
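A quick sanity check that the optimizer state actually ends up in shared memory, using a throwaway linear layer (a condensed SharedAdam is repeated here so the snippet runs standalone):

```python
import torch
import torch.nn as nn

class SharedAdam(torch.optim.Adam):
    """Condensed copy of the SharedAdam above: Adam state tensors are
    created eagerly and moved to shared memory."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, lr=lr)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                for t in (state['step'], state['exp_avg'], state['exp_avg_sq']):
                    t.share_memory_()

model = nn.Linear(4, 2)  # throwaway layer just to create parameters
opt = SharedAdam(model.parameters(), lr=1e-4)

# Every state tensor should now report shared storage
assert all(opt.state[p]['exp_avg'].is_shared()
           for g in opt.param_groups for p in g['params'])
assert all(opt.state[p]['exp_avg_sq'].is_shared()
           for g in opt.param_groups for p in g['params'])
```

If `is_shared()` returned False for any of these tensors, each worker process would silently update its own private copy of the Adam state.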
A3C Data Parallelism
Data Parallelism is a scheme in which each worker only collects experience data and ships it to a central learner, which batches the transitions and performs all training updates itself.
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from collections import namedtuple
Experience = namedtuple('Experience',
['state', 'action', 'reward', 'done', 'next_state'])
class ActorCritic(nn.Module):
def __init__(self, obs_size, act_size):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(obs_size, 256),
nn.ReLU(),
)
self.policy = nn.Sequential(
nn.Linear(256, act_size),
nn.Softmax(dim=-1),
)
self.value = nn.Linear(256, 1)
def forward(self, x):
shared_out = self.shared(x)
return self.policy(shared_out), self.value(shared_out)
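A quick smoke test of the network, using CartPole-sized spaces (4 observations, 2 actions): the policy head should emit a valid probability distribution per state and the value head one scalar per state. The class is repeated inside the snippet so it runs standalone:

```python
import torch
import torch.nn as nn

class ActorCritic(nn.Module):
    """Same two-headed network as above."""

    def __init__(self, obs_size, act_size):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(obs_size, 256), nn.ReLU())
        self.policy = nn.Sequential(nn.Linear(256, act_size), nn.Softmax(dim=-1))
        self.value = nn.Linear(256, 1)

    def forward(self, x):
        h = self.shared(x)
        return self.policy(h), self.value(h)

net = ActorCritic(obs_size=4, act_size=2)  # CartPole-sized spaces
batch = torch.randn(8, 4)
probs, values = net(batch)

assert probs.shape == (8, 2)
assert values.shape == (8, 1)
# The softmax head emits a valid distribution for every state in the batch
assert torch.allclose(probs.sum(dim=-1), torch.ones(8), atol=1e-5)
```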
def data_worker(worker_id, shared_model, data_queue, num_steps=20):
    """Data collection worker: puts experiences into the queue"""
    env = make_env()
    # Build the local network once; only its weights are refreshed per loop
    local_model = ActorCritic(env.observation_space.shape[0],
                              env.action_space.n)
    state = env.reset()
    while True:
        # Sync local weights with the shared model
        local_model.load_state_dict(shared_model.state_dict())
        experiences = []
        for _ in range(num_steps):
            state_t = torch.FloatTensor(state)
            with torch.no_grad():  # pure inference; no gradients needed here
                probs, _ = local_model(state_t.unsqueeze(0))
            action = torch.multinomial(probs, 1).item()
            next_state, reward, done, _ = env.step(action)
            experiences.append(Experience(state, action, reward, done, next_state))
            state = next_state
            if done:
                state = env.reset()
        # Send collected experiences to queue
        data_queue.put(experiences)
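The learner side of this scheme is not shown above, so here is a sketch of one central update step. It uses a plain queue.Queue with synthetic transitions so it runs standalone (in the real setup the batches would arrive through an mp.Queue from data_worker), and it computes a simple 1-step actor-critic loss; the exact loss is a design choice, not dictated by the data-parallel layout:

```python
import queue
from collections import namedtuple

import torch
import torch.nn as nn

Experience = namedtuple('Experience',
                        ['state', 'action', 'reward', 'done', 'next_state'])

class ActorCritic(nn.Module):
    """Same network as above, repeated so the sketch runs standalone."""

    def __init__(self, obs_size, act_size):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(obs_size, 256), nn.ReLU())
        self.policy = nn.Sequential(nn.Linear(256, act_size), nn.Softmax(dim=-1))
        self.value = nn.Linear(256, 1)

    def forward(self, x):
        h = self.shared(x)
        return self.policy(h), self.value(h)

def central_update(model, optimizer, experiences, gamma=0.99):
    """One learner-side update from a batch of worker transitions,
    using a simple 1-step actor-critic loss."""
    states = torch.tensor([e.state for e in experiences], dtype=torch.float32)
    actions = torch.tensor([e.action for e in experiences])
    rewards = torch.tensor([e.reward for e in experiences], dtype=torch.float32)
    dones = torch.tensor([float(e.done) for e in experiences])
    next_states = torch.tensor([e.next_state for e in experiences],
                               dtype=torch.float32)

    probs, values = model(states)
    with torch.no_grad():  # bootstrap targets carry no gradient
        _, next_values = model(next_states)
    targets = rewards + gamma * (1.0 - dones) * next_values.squeeze(-1)
    advantages = targets - values.squeeze(-1).detach()

    dist = torch.distributions.Categorical(probs)
    policy_loss = -(dist.log_prob(actions) * advantages).mean()
    value_loss = 0.5 * (targets - values.squeeze(-1)).pow(2).mean()
    loss = policy_loss + value_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

model = ActorCritic(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

data_queue = queue.Queue()  # stand-in for the mp.Queue filled by data_worker
data_queue.put([Experience([0.1] * 4, 0, 1.0, False, [0.2] * 4)
                for _ in range(5)])

batch = data_queue.get()
loss = central_update(model, optimizer, batch)
```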
A3C Gradient Parallelism
Gradient Parallelism is a method where each worker performs both experience collection and gradient computation, then directly applies the computed gradients to the central model. This is the approach from the original A3C paper.
def gradient_worker(worker_id, shared_model, optimizer, counter, lock,
max_episodes=10000, gamma=0.99, entropy_beta=0.01):
"""Gradient computation worker: computes gradients locally"""
env = make_env()
local_model = ActorCritic(env.observation_space.shape[0],
env.action_space.n)
state = env.reset()
episode_reward = 0.0
    while counter.value < max_episodes:
        # Sync with shared model
        local_model.load_state_dict(shared_model.state_dict())
log_probs = []
values = []
rewards = []
entropies = []
for _ in range(20): # n-step
state_t = torch.FloatTensor(state).unsqueeze(0)
probs, value = local_model(state_t)
dist = torch.distributions.Categorical(probs)
action = dist.sample()
log_prob = dist.log_prob(action)
entropy = dist.entropy()
next_state, reward, done, _ = env.step(action.item())
log_probs.append(log_prob)
values.append(value.squeeze())
rewards.append(reward)
entropies.append(entropy)
episode_reward += reward
state = next_state
if done:
state = env.reset()
with lock:
counter.value += 1
episode_reward = 0.0
break
# Compute bootstrap value
if done:
R = torch.tensor(0.0)
else:
_, R = local_model(torch.FloatTensor(state).unsqueeze(0))
R = R.squeeze().detach()
# Compute returns and loss in reverse
policy_loss = 0.0
value_loss = 0.0
entropy_loss = 0.0
for i in reversed(range(len(rewards))):
R = rewards[i] + gamma * R
advantage = R - values[i].detach()
policy_loss -= log_probs[i] * advantage
value_loss += 0.5 * (R - values[i]) ** 2
entropy_loss -= entropies[i]
total_loss = policy_loss + value_loss + entropy_beta * entropy_loss
        # Zero stale gradients on both models, then backprop locally
        local_model.zero_grad()
        optimizer.zero_grad()
        total_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(local_model.parameters(), 40.0)
# Transfer gradients to shared model and update
for shared_param, local_param in zip(shared_model.parameters(),
local_model.parameters()):
if shared_param.grad is None:
shared_param.grad = local_param.grad.clone()
else:
shared_param.grad.copy_(local_param.grad)
optimizer.step()
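The backward loop above implements the recursion R_i = r_i + gamma * R_{i+1}, seeded with the critic's bootstrap value for the state after the last step. A small standalone check that this recursion reproduces the closed-form n-step return:

```python
gamma = 0.99

def nstep_returns(rewards, bootstrap, gamma):
    """Backward recursion from the worker: R_i = r_i + gamma * R_{i+1},
    seeded with the bootstrap value of the state after the last step."""
    R = bootstrap
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    return list(reversed(returns))

rewards = [1.0, 0.0, 2.0]
bootstrap = 5.0  # stands in for V(s_3) from the critic
returns = nstep_returns(rewards, bootstrap, gamma)

# Closed form for the first step:
# R_0 = r_0 + gamma * r_1 + gamma^2 * r_2 + gamma^3 * V(s_3)
expected = 1.0 + 0.99 * 0.0 + 0.99 ** 2 * 2.0 + 0.99 ** 3 * 5.0
assert abs(returns[0] - expected) < 1e-9
```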
Complete A3C Training Loop
def train_a3c(env_name='CartPole-v1', num_workers=4, max_episodes=5000):
"""A3C main training function"""
    env = make_env(env_name)  # temporary env, only to read the space sizes
obs_size = env.observation_space.shape[0]
act_size = env.action_space.n
env.close()
shared_model = ActorCritic(obs_size, act_size)
shared_model.share_memory()
optimizer = SharedAdam(shared_model.parameters(), lr=1e-4)
optimizer.share_memory()
counter = mp.Value('i', 0)
lock = mp.Lock()
processes = []
for i in range(num_workers):
p = mp.Process(
target=gradient_worker,
args=(i, shared_model, optimizer, counter, lock, max_episodes)
)
p.start()
processes.append(p)
    # Monitor progress from the main process
    import time
    while counter.value < max_episodes:
        time.sleep(10)
        print(f"Completed episodes: {counter.value}/{max_episodes}")
for p in processes:
p.terminate()
p.join()
return shared_model
if __name__ == '__main__':
mp.set_start_method('spawn')
model = train_a3c()
torch.save(model.state_dict(), 'a3c_model.pth')
Data Parallelism vs Gradient Parallelism Comparison
| Item | Data Parallelism | Gradient Parallelism |
|---|---|---|
| Worker role | Experience collection only | Experience collection + gradient computation |
| Communication | Transition data (state, action, reward) | Gradient tensors |
| Communication volume | Proportional to state size | Proportional to model parameter count |
| Central load | Training computation concentrated | Only performs updates |
| Implementation difficulty | Relatively easy | Requires shared memory management |
| Scalability | Central bottleneck possible | Scales linearly with worker count |
Experimental Results Comparison
Training performance on CartPole-v1 environment:
- A2C (single env): Converges around 500 episodes
- A2C (8 parallel envs): Converges around 200 episodes
- A3C (8 workers, data parallel): Converges around 150 episodes
- A3C (8 workers, gradient parallel): Converges around 120 episodes
While asynchronous updates introduce some noise, the exploration diversity generally leads to faster convergence overall.
Practical Tips and Considerations
- Number of workers: It is common to set this equal to the number of CPU cores. When using a GPU, it is more efficient to reduce the number of workers and increase the batch size.
- Instability of asynchronous updates: The larger the model version difference (staleness) between workers, the more unstable learning becomes. Gradient clipping is essential.
- A2C vs A3C selection criteria: If using a GPU, A2C is often more efficient. Vectorized environments can provide sufficient exploration diversity even with A2C. A3C has greater advantages in CPU-based learning.
- Debugging: Asynchronous programs are difficult to debug. It is best to verify correct behavior with a single worker first, then increase the number of workers.
Key Takeaways
- A3C solves the data correlation problem through asynchronous parallel learning
- Data parallelism distributes experience collection, while gradient parallelism distributes computation as well
- Implementation uses Python's multiprocessing and shared memory
- Recently, synchronous methods (A2C) and PPO are preferred due to GPU efficiency
In the next post, we will explore how to apply reinforcement learning to natural language processing for training chatbots.