
[Deep RL] 11. A3C: Asynchronous Advantage Actor-Critic


Overview

The A2C (Advantage Actor-Critic) method from the previous post learned from experience collected in a single environment, and that data is strongly correlated: consecutive state transitions closely resemble one another, which reduces learning efficiency. DQN addressed this with experience replay, but on-policy methods like Actor-Critic must learn from data generated by the current policy, so replaying old transitions is not an option and a different approach is needed.

A3C (Asynchronous Advantage Actor-Critic) breaks the data correlation by running multiple environments simultaneously. Proposed by Mnih et al. at DeepMind in 2016, this method achieves stable learning without a replay buffer.


Correlation and Sample Efficiency

Why Is Correlation a Problem?

In reinforcement learning, the transitions an agent collects during a single episode are temporally consecutive. For example, consecutive frames of a Pong game in which the ball is flying to the right are nearly identical. Stochastic gradient descent assumes roughly independent training samples, so when a neural network is updated with such correlated data:

  • Gradients become biased toward whatever situation the agent is currently in
  • Learning becomes unstable and convergence slows down
  • In the worst case, learning can diverge
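To make the correlation concrete, the sketch below uses a toy autoregressive process as a stand-in for an environment's state (an assumption for illustration; real observations are high-dimensional, and ar1_trajectory and pearson are hypothetical helpers). It compares the correlation between consecutive states of one environment with the correlation between states taken at the same step from two independent environments.

```python
import random

def ar1_trajectory(n_steps, phi=0.95, seed=0):
    """Toy 1-D 'state' that changes slowly, like consecutive game frames.

    x_{t+1} = phi * x_t + noise; phi close to 1 means strong correlation.
    """
    rng = random.Random(seed)
    x, trajectory = 0.0, []
    for _ in range(n_steps):
        x = phi * x + rng.gauss(0.0, 1.0)
        trajectory.append(x)
    return trajectory

def pearson(a, b):
    """Sample Pearson correlation between two equal-length sequences."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / (var_a * var_b) ** 0.5

# One environment: each state versus the state that follows it
single_env = ar1_trajectory(5000, seed=0)
consecutive_corr = pearson(single_env[:-1], single_env[1:])

# Two independent environments: states observed at the same time step
env_a = ar1_trajectory(5000, seed=1)
env_b = ar1_trajectory(5000, seed=2)
cross_env_corr = pearson(env_a, env_b)

print(f"consecutive steps, one env: {consecutive_corr:.2f}")  # close to phi
print(f"same step, two envs:        {cross_env_corr:.2f}")    # close to 0
```

Consecutive samples from one trajectory are nearly redundant, while samples drawn from independent environments carry largely fresh information; this is the property that parallel workers exploit.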

Comparison of Solution Approaches

| Method | Principle | Advantages | Disadvantages |
|---|---|---|---|
| Experience Replay | Store past transitions in a buffer, sample randomly | Sample efficient | Off-policy only |
| Parallel Envs | Run multiple environments simultaneously | On-policy compatible | More computation needed |
| A3C | Async parallel envs + independent learning | Maximizes exploration diversity | Complex implementation |

From A2C to A3C: The Meaning of the Added A

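A2C can run multiple environments in parallel, but it steps them in lockstep and performs a single update on the combined batch. The extra "A" in A3C stands for Asynchronous: each worker proceeds and updates on its own schedule, without waiting for the others.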

A2C's Synchronous Approach

# A2C: all environments advance in lockstep (sketch; make_env, get_state,
# predict and update are placeholder interfaces, not a specific library API)
class A2CAgent:
    def __init__(self, num_envs, model):
        self.envs = [make_env() for _ in range(num_envs)]
        self.model = model  # single model shared by all environments

    def train_step(self):
        # 1. Choose actions for every environment in one batched forward pass
        states = [env.get_state() for env in self.envs]
        actions, values = self.model.predict(states)

        # 2. Advance every environment by one step; no environment moves
        #    ahead until all of them have stepped
        rewards, next_states, dones = [], [], []
        for env, action in zip(self.envs, actions):
            next_state, reward, done = env.step(action)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)

        # 3. Perform one synchronous update on the combined batch
        self.model.update(states, actions, rewards, next_states, dones)
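Reduced to its control flow, the synchronous pattern above is: every worker reports, then the learner performs one update. The toy below illustrates this under stated assumptions: a single scalar parameter w stands in for the network, each worker contributes the gradient of (w - target)^2 for its own hypothetical target, and synchronous_update is a hypothetical name rather than part of any library.

```python
def synchronous_update(w, worker_targets, lr=0.1):
    """One A2C-style step: all gradients are taken at the same w, then averaged."""
    # Each worker's gradient of (w - target)^2, all evaluated at the same w
    grads = [2.0 * (w - target) for target in worker_targets]
    # Barrier: nothing is applied until every worker has reported
    mean_grad = sum(grads) / len(grads)
    return w - lr * mean_grad

w = 0.0
targets = [2.0, 3.0, 4.0]  # e.g. three environments pulling in different directions
for _ in range(100):
    w = synchronous_update(w, targets)
print(round(w, 4))  # 3.0 (the mean target)
```

Because every gradient is evaluated at the same w, there is no staleness; the cost is that the slowest worker sets the pace of every step.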

A3C's Asynchronous Approach

In A3C, each worker independently interacts with its environment, computes gradients locally, and asynchronously applies them to the central model.

Key differences:

  • Each worker does not wait for other workers
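  • Workers can explore differently (e.g., different ε values when workers act ε-greedily; actor-critic workers already diverge because each samples actions from its stochastic policy with its own random seed)
  • Because workers are at different points in different episodes, the gradients reaching the central model come from decorrelated experience, giving exploration diversity without a replay buffer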
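The asynchronous counterpart can be sketched with the same toy loss. Assumptions: Python threads stand in for worker processes (enough to show the control flow, not to gain speed), a lock protects each individual update (a simplification; many A3C implementations apply updates without any lock), and SharedParams and async_worker are hypothetical names. Each worker reads a possibly stale copy of w, computes its own gradient, and applies it immediately without waiting for the others.

```python
import random
import threading

class SharedParams:
    """Hypothetical central model: one scalar weight plus an update counter."""
    def __init__(self, w=0.0):
        self.w = w
        self.updates = 0
        self.lock = threading.Lock()

def async_worker(shared, worker_id, n_updates, lr=0.05):
    rng = random.Random(worker_id)  # each worker sees its own noisy data
    for _ in range(n_updates):
        w_local = shared.w                  # copy parameters (may already be stale)
        target = 3.0 + rng.gauss(0.0, 0.1)  # this worker's noisy "experience"
        grad = 2.0 * (w_local - target)     # gradient of (w - target)^2 at the copy
        with shared.lock:                   # apply at once; no barrier with other workers
            shared.w -= lr * grad
            shared.updates += 1

shared = SharedParams()
workers = [threading.Thread(target=async_worker, args=(shared, i, 500))
           for i in range(4)]
for t in workers:
    t.start()
for t in workers:
    t.join()
print(shared.updates, round(shared.w, 2))  # 2000 updates, w close to 3.0
```

Unlike the synchronous version, a fast worker may apply several updates while another is still working from an older copy of w; that gap is the staleness mentioned in the practical tips later in this post.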

Python Multiprocessing Basics

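Before implementing A3C, we need to understand Python's multiprocessing. CPython's GIL (Global Interpreter Lock) lets only one thread execute Python bytecode at a time, so threads do not speed up CPU-bound work such as stepping environments; A3C implementations therefore use separate processes. The code below uses torch.multiprocessing, a wrapper around the standard multiprocessing module that can share tensors between processes.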

import torch.multiprocessing as mp

def worker_process(worker_id, shared_model, optimizer, device):
    """Function executed by each worker process.

    make_env, collect_experiences and compute_loss are placeholders;
    ActorCritic and SharedAdam are defined later in this post.
    """
    env = make_env()
    local_model = ActorCritic(env.observation_space.shape[0],
                              env.action_space.n)
    local_model.to(device)

    while True:
        # Copy shared model parameters to local model
        local_model.load_state_dict(shared_model.state_dict())

        # Collect experiences from local environment
        experiences = collect_experiences(env, local_model, n_steps=20)

        # Compute gradients locally; clear the previous iteration's gradients
        # first, because load_state_dict() does not reset them
        local_model.zero_grad()
        loss = compute_loss(local_model, experiences)
        loss.backward()

        # Hand the local gradients to the shared model (moving them to the
        # shared model's device if needed), then update it
        for shared_param, local_param in zip(shared_model.parameters(),
                                              local_model.parameters()):
            shared_param.grad = local_param.grad.to(shared_param.device)

        optimizer.step()
        optimizer.zero_grad()

if __name__ == '__main__':
    mp.set_start_method('spawn')
    obs_size, act_size = 4, 2  # placeholder sizes (e.g. CartPole); use your env's
    shared_model = ActorCritic(obs_size, act_size)
    shared_model.share_memory()  # Share parameters across processes

    optimizer = SharedAdam(shared_model.parameters(), lr=1e-4)
    optimizer.share_memory()

    processes = []
    for i in range(mp.cpu_count()):
        p = mp.Process(target=worker_process,
                       args=(i, shared_model, optimizer, 'cpu'))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

SharedAdam Optimizer

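Adam's per-parameter state (the step count and the first- and second-moment estimates) must also live in shared memory, so that all workers update the same optimizer statistics: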

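import torch

class SharedAdam(torch.optim.Adam):
    """Adam whose state tensors live in shared memory across processes."""
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, lr=lr, betas=betas, eps=eps)
        # Create Adam's per-parameter state eagerly (Adam normally creates
        # it lazily on the first step() call, inside each process)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.tensor(0.0)
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
        self.share_memory()

    def share_memory(self):
        """Move all state tensors into shared memory (safe to call repeatedly)."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
        return self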
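Why must this state be shared? Adam's step count drives its bias correction, and its moment estimates average over every gradient it has seen. The pure-Python sketch below is a hypothetical scalar re-implementation of the standard Adam update rule (not PyTorch's code); two workers take turns updating through one shared state dict, so the step count and moments reflect every update rather than only that worker's half.

```python
import math

def adam_update(param, grad, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """Scalar Adam step; `state` holds step, exp_avg and exp_avg_sq."""
    beta1, beta2 = betas
    state["step"] += 1
    state["exp_avg"] = beta1 * state["exp_avg"] + (1 - beta1) * grad
    state["exp_avg_sq"] = beta2 * state["exp_avg_sq"] + (1 - beta2) * grad * grad
    # Bias correction depends on the step count, so it must count all updates
    m_hat = state["exp_avg"] / (1 - beta1 ** state["step"])
    v_hat = state["exp_avg_sq"] / (1 - beta2 ** state["step"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)

shared_state = {"step": 0, "exp_avg": 0.0, "exp_avg_sq": 0.0}
w = 0.0
for _ in range(3):
    for worker_grad in (5.0, 4.0):  # two workers take turns on the same state
        w = adam_update(w, worker_grad, shared_state)

print(shared_state["step"])  # 6: the shared state saw every worker's update
```

With a separate state per process, each worker's step count and moment estimates would be built from its own updates only, even though all workers modify the same parameters.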

A3C Data Parallelism

Data Parallelism is a method where each worker only collects experience data and sends it to a central learner, which gathers the transitions from all workers and performs every gradient computation and update.

import torch
import torch.nn as nn
import torch.multiprocessing as mp
from collections import namedtuple

Experience = namedtuple('Experience',
    ['state', 'action', 'reward', 'done', 'next_state'])

class ActorCritic(nn.Module):
    """Shared trunk with a policy head (action probabilities) and a value head."""
    def __init__(self, obs_size, act_size):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_size, 256),
            nn.ReLU(),
        )
        self.policy = nn.Sequential(
            nn.Linear(256, act_size),
            nn.Softmax(dim=-1),
        )
        self.value = nn.Linear(256, 1)

    def forward(self, x):
        shared_out = self.shared(x)
        return self.policy(shared_out), self.value(shared_out)

def data_worker(worker_id, shared_model, data_queue, num_steps=20):
    """Data collection worker: puts experiences into the queue.

    make_env() is a placeholder; reset()/step() follow the classic Gym API
    (reset -> obs, step -> obs, reward, done, info).
    """
    env = make_env()
    state = env.reset()
    # Build the local model once; only its parameters are refreshed below
    local_model = ActorCritic(env.observation_space.shape[0],
                              env.action_space.n)

    while True:
        # Sync shared model parameters
        local_model.load_state_dict(shared_model.state_dict())

        experiences = []
        with torch.no_grad():  # acting only; no gradients are needed here
            for _ in range(num_steps):
                state_t = torch.FloatTensor(state)
                probs, _ = local_model(state_t.unsqueeze(0))
                action = torch.multinomial(probs, 1).item()

                next_state, reward, done, _ = env.step(action)
                experiences.append(
                    Experience(state, action, reward, done, next_state))

                state = next_state
                if done:
                    state = env.reset()

        # Send collected experiences to the central learner's queue
        data_queue.put(experiences)
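The code above covers only the worker side; the central learner that consumes data_queue is not shown. As a hedged sketch of the overall data-parallel flow, the toy below replaces the network with a scalar w and the transitions with noisy target values, and uses threads with queue.Queue where the real code would use processes with mp.Queue (all assumptions; collector and central_learner are hypothetical names). Workers only produce data; every gradient step happens in one place.

```python
import queue
import random
import threading

def collector(worker_id, data_queue, n_batches, batch_size=20):
    """Worker role: generate 'transitions' and ship them, never touching w."""
    rng = random.Random(worker_id)
    for _ in range(n_batches):
        batch = [3.0 + rng.gauss(0.0, 0.1) for _ in range(batch_size)]
        data_queue.put(batch)

def central_learner(data_queue, total_batches, lr=0.1):
    """Learner role: pull batches from all workers and do every update itself."""
    w = 0.0
    for _ in range(total_batches):
        batch = data_queue.get()  # blocks until some worker has sent data
        grad = sum(2.0 * (w - target) for target in batch) / len(batch)
        w -= lr * grad
    return w

data_queue = queue.Queue()
num_workers, batches_per_worker = 4, 50
threads = [threading.Thread(target=collector, args=(i, data_queue, batches_per_worker))
           for i in range(num_workers)]
for t in threads:
    t.start()
w = central_learner(data_queue, num_workers * batches_per_worker)
for t in threads:
    t.join()
print(round(w, 2))  # close to 3.0
```

Because the learner alone computes gradients, it can become the bottleneck as workers are added; that trade-off is summarized in the comparison table later in this post.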

A3C Gradient Parallelism

Gradient Parallelism is a method where each worker performs both experience collection and gradient computation, then directly applies the computed gradients to the central model. This is the approach from the original A3C paper.

def gradient_worker(worker_id, shared_model, optimizer, counter, lock,
                    max_episodes=10000, gamma=0.99, entropy_beta=0.01):
    """Gradient computation worker: computes gradients locally"""
    env = make_env()
    local_model = ActorCritic(env.observation_space.shape[0],
                              env.action_space.n)

    state = env.reset()
    episode_reward = 0.0

    while counter.value < max_episodes:  # stop once enough episodes have finished
        # Sync with shared model
        local_model.load_state_dict(shared_model.state_dict())

        log_probs = []
        values = []
        rewards = []
        entropies = []

        for _ in range(20):  # n-step
            state_t = torch.FloatTensor(state).unsqueeze(0)
            probs, value = local_model(state_t)

            dist = torch.distributions.Categorical(probs)
            action = dist.sample()

            log_prob = dist.log_prob(action)
            entropy = dist.entropy()

            next_state, reward, done, _ = env.step(action.item())

            log_probs.append(log_prob)
            values.append(value.squeeze())
            rewards.append(reward)
            entropies.append(entropy)

            episode_reward += reward
            state = next_state

            if done:
                state = env.reset()
                with lock:
                    counter.value += 1
                episode_reward = 0.0
                break

        # Compute bootstrap value
        if done:
            R = torch.tensor(0.0)
        else:
            _, R = local_model(torch.FloatTensor(state).unsqueeze(0))
            R = R.squeeze().detach()

        # Compute returns and loss in reverse
        policy_loss = 0.0
        value_loss = 0.0
        entropy_loss = 0.0

        for i in reversed(range(len(rewards))):
            R = rewards[i] + gamma * R
            advantage = R - values[i].detach()

            policy_loss -= log_probs[i] * advantage
            value_loss += 0.5 * (R - values[i]) ** 2
            entropy_loss -= entropies[i]

        total_loss = policy_loss + value_loss + entropy_beta * entropy_loss

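        # Clear old gradients on both models (load_state_dict() does not reset
        # the local model's .grad), then compute local gradients
        optimizer.zero_grad()
        local_model.zero_grad()
        total_loss.backward()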

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(local_model.parameters(), 40.0)

        # Transfer gradients to shared model and update
        for shared_param, local_param in zip(shared_model.parameters(),
                                              local_model.parameters()):
            if shared_param.grad is None:
                shared_param.grad = local_param.grad.clone()
            else:
                shared_param.grad.copy_(local_param.grad)
        optimizer.step()
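The reverse loop in gradient_worker computes n-step bootstrapped returns: R = r + gamma * R, seeded with 0 when the rollout ended in a terminal state and with the critic's estimate of the last state otherwise; the advantage is R minus the critic's value V(s) at each step. The same recursion on plain floats, with made-up rewards and value estimates for illustration (n_step_returns and advantages are hypothetical helpers):

```python
def n_step_returns(rewards, bootstrap_value, gamma=0.99):
    """Discounted returns computed backwards, as in the worker's loss loop."""
    R = bootstrap_value  # 0.0 if the rollout ended in a terminal state
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    return returns[::-1]  # restore chronological order

def advantages(returns, values):
    """A_t = R_t - V(s_t); the critic's values are treated as constants here."""
    return [R - v for R, v in zip(returns, values)]

rets = n_step_returns([1.0, 1.0, 1.0], bootstrap_value=0.0, gamma=0.9)
print([round(x, 2) for x in rets])                                # [2.71, 1.9, 1.0]
print([round(a, 2) for a in advantages(rets, [2.5, 2.0, 1.0])])   # [0.21, -0.1, 0.0]
```

A positive advantage (first step) makes the policy loss term -log_prob * advantage push that action's probability up; a negative one (second step) pushes it down.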

Complete A3C Training Loop

import time

def train_a3c(env_name='CartPole-v1', num_workers=4, max_episodes=5000):
    """A3C main training function.

    make_env() is a placeholder that should build the env_name environment;
    the env created here is only used to read observation/action sizes.
    """
    env = make_env()
    obs_size = env.observation_space.shape[0]
    act_size = env.action_space.n
    env.close()

    shared_model = ActorCritic(obs_size, act_size)
    shared_model.share_memory()

    optimizer = SharedAdam(shared_model.parameters(), lr=1e-4)
    optimizer.share_memory()

    counter = mp.Value('i', 0)  # shared count of finished episodes
    lock = mp.Lock()

    processes = []
    for i in range(num_workers):
        p = mp.Process(
            target=gradient_worker,
            args=(i, shared_model, optimizer, counter, lock, max_episodes)
        )
        p.start()
        processes.append(p)

    # Monitor progress from the main process
    while counter.value < max_episodes:
        time.sleep(10)
        print(f"Completed episodes: {counter.value}/{max_episodes}")

    # Stop any workers that are still running
    for p in processes:
        p.terminate()
        p.join()

    return shared_model

if __name__ == '__main__':
    mp.set_start_method('spawn')
    model = train_a3c()
    torch.save(model.state_dict(), 'a3c_model.pth')

Data Parallelism vs Gradient Parallelism Comparison

| Item | Data Parallelism | Gradient Parallelism |
|---|---|---|
| Worker role | Experience collection only | Experience collection + gradient computation |
| Communication | Transition data (state, action, reward) | Gradient tensors |
| Communication volume | Proportional to state size | Proportional to model parameter count |
| Central load | All training computation concentrated in the learner | Only parameter updates |
| Implementation difficulty | Relatively easy | Requires shared memory management |
| Scalability | Central learner can become a bottleneck | Scales better, but gradient staleness and contention grow with worker count |
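To put rough numbers on the "Communication volume" row, the arithmetic below counts the values each approach would send per 20-step rollout for the ActorCritic network defined earlier. Assumptions for illustration: a CartPole-sized problem (4-dimensional observations, 2 actions), and action, reward, and done counted as one number each; linear_params is a hypothetical helper.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of an nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out

obs_size, act_size, hidden, n_steps = 4, 2, 256, 20

# Gradient parallelism: one gradient entry per model parameter
param_count = (linear_params(obs_size, hidden)     # shared trunk
               + linear_params(hidden, act_size)   # policy head
               + linear_params(hidden, 1))         # value head

# Data parallelism: state + next_state + action + reward + done per transition
floats_per_transition = obs_size + obs_size + 3
floats_per_rollout = n_steps * floats_per_transition

print(param_count, floats_per_rollout)  # 2051 220
```

For this tiny observation the gradient message is roughly nine times larger. With image observations each state alone holds tens of thousands of values (84 x 84 x 4 = 28,224 for stacked Atari frames), so the comparison can come out very differently.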

Experimental Results Comparison

Training performance on the CartPole-v1 environment:

  • A2C (single env): Converges around 500 episodes
  • A2C (8 parallel envs): Converges around 200 episodes
  • A3C (8 workers, data parallel): Converges around 150 episodes
  • A3C (8 workers, gradient parallel): Converges around 120 episodes

In these runs, the more diverse experience from independent workers outweighed the extra noise introduced by asynchronous updates, leading to faster convergence overall.


Practical Tips and Considerations

  1. Number of workers: It is common to set this equal to the number of CPU cores. When training on a GPU, it is usually more efficient to use fewer workers and a larger batch size.

  2. Instability of asynchronous updates: Staleness is the number of updates other workers apply to the shared model between the moment a worker copies the parameters and the moment it applies its own gradient. The larger the staleness, the more unstable learning becomes, so gradient clipping is essential.

  3. Choosing between A2C and A3C: On a GPU, A2C is often more efficient, and vectorized environments can give A2C sufficient exploration diversity. A3C has the greater advantage in CPU-based training.

  4. Debugging: Asynchronous programs are difficult to debug. It is best to verify correct behavior with a single worker first, then increase the number of workers.
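Item 2 recommends gradient clipping, and gradient_worker uses torch.nn.utils.clip_grad_norm_ with a threshold of 40. The pure-Python sketch below shows the idea of global-norm clipping (a simplified re-implementation for illustration, not PyTorch's source; clip_by_global_norm is a hypothetical name): if the L2 norm over all gradients exceeds max_norm, every gradient is scaled by the same factor, so the update keeps its direction and only its length shrinks.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their combined L2 norm is <= max_norm.

    Returns the clipped gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm

# Two "parameter tensors" whose joint gradient norm is 5 (a 3-4-5 triangle)
clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # approximately [[0.6], [0.8]]

# Gradients already inside the threshold pass through unchanged
unchanged, _ = clip_by_global_norm([[0.3], [0.4]], max_norm=1.0)
```

Clipping the joint norm rather than each tensor separately preserves the relative sizes of the gradients, which matters when a stale worker occasionally produces an unusually large update.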


Key Takeaways

  • A3C solves the data correlation problem through asynchronous parallel learning
  • Data parallelism distributes experience collection, while gradient parallelism distributes computation as well
  • Implementation uses Python's multiprocessing and shared memory
  • In recent practice, synchronous methods such as A2C and PPO are often preferred because they make more efficient use of GPUs

In the next post, we will explore how to apply reinforcement learning to natural language processing for training chatbots.