Skip to content
Published on

[深層強化学習] 11. A3C: 非同期Advantage Actor-Critic

Authors

概要

前回の記事で紹介したA2C(Advantage Actor-Critic)は、単一環境で収集した経験データの**相関(correlation)**問題がありました。連続する状態遷移は互いに強く関連しており、学習効率が低下します。DQNではexperience replayでこの問題を解決しましたが、on-policy手法であるActor-Criticでは別のアプローチが必要です。

**A3C(Asynchronous Advantage Actor-Critic)**は、複数の環境を同時に実行してデータの相関を破壊する方法です。2016年にDeepMindのMnihらが提案したこの方法は、replay bufferなしでも安定した学習を実現します。


相関とサンプル効率

なぜ相関が問題なのか

強化学習において、エージェントが1つのエピソード中に収集する遷移(transition)は時間的に連続しています。例えば、Pongゲームでボールが右に飛んでいる連続フレームは非常に類似した状態を持ちます。このような相関データでニューラルネットワークを更新すると:

  • 勾配が特定の方向に偏る
  • 学習が不安定になり収束速度が遅くなる
  • 最悪の場合、学習が発散する可能性がある

解決アプローチの比較

方法原理長所短所
Experience Replay過去の遷移をバッファに保存しランダムサンプリングサンプル効率的Off-policyのみ
並列環境複数の環境を同時に実行On-policy互換より多くの計算必要
A3C非同期並列環境 + 独立学習探索多様性の最大化実装が複雑

A2CからA3Cへ:追加されたAの意味

A2Cは複数の環境を並列に実行しつつ**同期的(synchronous)に経験を集めて一括更新します。A3CはここにAsynchronous(非同期)**を追加します。

A2Cの同期方式

# A2C: すべてのワーカーが同期的に動作
class A2CAgent:
    def __init__(self, num_envs, model):
        self.envs = [make_env() for _ in range(num_envs)]
        self.model = model  # 共有モデル

    def train_step(self):
        # 1. すべての環境から同時に行動を収集
        states = [env.get_state() for env in self.envs]
        actions, values = self.model.predict(states)

        # 2. すべての環境で同時にステップ実行
        rewards, next_states, dones = [], [], []
        for env, action in zip(self.envs, actions):
            r, ns, d = env.step(action)
            rewards.append(r)
            next_states.append(ns)
            dones.append(d)

        # 3. まとめて一度に更新
        self.model.update(states, actions, rewards, next_states, dones)

A3Cの非同期方式

A3Cでは、各ワーカーが独立的に環境と相互作用し、自らの勾配を計算した後、中央モデルに非同期的に反映します。

主な違い:

  • 各ワーカーは他のワーカーを待たない
  • ワーカーごとに異なる探索方策を使用可能(イプシロン値を異なる設定など)
  • 自然に探索の多様性が確保される

Pythonマルチプロセッシングの基礎

A3Cを実装する前に、Pythonのマルチプロセッシングを理解する必要があります。GIL(Global Interpreter Lock)のため、スレッドベースの並列化はCPUバウンドタスクには適していません。

import torch.multiprocessing as mp

def worker_process(worker_id, shared_model, optimizer, device):
    """各ワーカープロセスが実行する関数"""
    env = make_env()
    local_model = ActorCritic(env.observation_space.shape[0],
                              env.action_space.n)
    local_model.to(device)

    while True:
        # 共有モデルのパラメータをローカルモデルにコピー
        local_model.load_state_dict(shared_model.state_dict())

        # ローカル環境で経験を収集
        experiences = collect_experiences(env, local_model, n_steps=20)

        # ローカルで勾配を計算
        loss = compute_loss(local_model, experiences)
        loss.backward()

        # 共有モデルに勾配を反映
        for shared_param, local_param in zip(shared_model.parameters(),
                                              local_model.parameters()):
            shared_param.grad = local_param.grad

        optimizer.step()
        optimizer.zero_grad()

if __name__ == '__main__':
    mp.set_start_method('spawn')
    shared_model = ActorCritic(obs_size, act_size)
    shared_model.share_memory()  # プロセス間メモリ共有

    optimizer = SharedAdam(shared_model.parameters(), lr=1e-4)
    optimizer.share_memory()

    processes = []
    for i in range(mp.cpu_count()):
        p = mp.Process(target=worker_process,
                       args=(i, shared_model, optimizer, 'cpu'))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

SharedAdamオプティマイザ

Adamオプティマイザのモメンタム状態もプロセス間で共有する必要があります:

import torch

class SharedAdam(torch.optim.Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, lr=lr, betas=betas, eps=eps)
        # Adamの内部状態を共有メモリに移動
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                # 共有メモリの設定
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

A3Cデータ並列化

データ並列化(Data Parallelism)は、各ワーカーが経験データを収集して中央に送信し、中央でまとめて学習する方式です。

import torch
import torch.nn as nn
import torch.multiprocessing as mp
from collections import namedtuple

Experience = namedtuple('Experience',
    ['state', 'action', 'reward', 'done', 'next_state'])

class ActorCritic(nn.Module):
    def __init__(self, obs_size, act_size):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_size, 256),
            nn.ReLU(),
        )
        self.policy = nn.Sequential(
            nn.Linear(256, act_size),
            nn.Softmax(dim=-1),
        )
        self.value = nn.Linear(256, 1)

    def forward(self, x):
        shared_out = self.shared(x)
        return self.policy(shared_out), self.value(shared_out)

def data_worker(worker_id, shared_model, data_queue, num_steps=20):
    """データ収集ワーカー: 経験をキューに入れる"""
    env = make_env()
    state = env.reset()

    while True:
        # 共有モデルパラメータの同期
        local_model = ActorCritic(env.observation_space.shape[0],
                                  env.action_space.n)
        local_model.load_state_dict(shared_model.state_dict())

        experiences = []
        for _ in range(num_steps):
            state_t = torch.FloatTensor(state)
            probs, _ = local_model(state_t.unsqueeze(0))
            action = torch.multinomial(probs, 1).item()

            next_state, reward, done, _ = env.step(action)
            experiences.append(Experience(state, action, reward, done, next_state))

            state = next_state
            if done:
                state = env.reset()

        # 収集した経験をキューに送信
        data_queue.put(experiences)

A3C勾配並列化

勾配並列化(Gradient Parallelism)は、各ワーカーが経験収集と勾配計算の両方を行い、計算された勾配を中央モデルに直接適用する方式です。これが元のA3C論文の方式です。

def gradient_worker(worker_id, shared_model, optimizer, counter, lock,
                    max_episodes=10000, gamma=0.99, entropy_beta=0.01):
    """勾配計算ワーカー: ローカルで勾配まで計算"""
    env = make_env()
    local_model = ActorCritic(env.observation_space.shape[0],
                              env.action_space.n)

    state = env.reset()
    episode_reward = 0.0

    while True:
        # 共有モデルの同期
        local_model.load_state_dict(shared_model.state_dict())

        log_probs = []
        values = []
        rewards = []
        entropies = []

        for _ in range(20):  # n-step
            state_t = torch.FloatTensor(state).unsqueeze(0)
            probs, value = local_model(state_t)

            dist = torch.distributions.Categorical(probs)
            action = dist.sample()

            log_prob = dist.log_prob(action)
            entropy = dist.entropy()

            next_state, reward, done, _ = env.step(action.item())

            log_probs.append(log_prob)
            values.append(value.squeeze())
            rewards.append(reward)
            entropies.append(entropy)

            episode_reward += reward
            state = next_state

            if done:
                state = env.reset()
                with lock:
                    counter.value += 1
                episode_reward = 0.0
                break

        # ブートストラップ値の計算
        if done:
            R = torch.tensor(0.0)
        else:
            _, R = local_model(torch.FloatTensor(state).unsqueeze(0))
            R = R.squeeze().detach()

        # 逆方向にリターン計算と損失計算
        policy_loss = 0.0
        value_loss = 0.0
        entropy_loss = 0.0

        for i in reversed(range(len(rewards))):
            R = rewards[i] + gamma * R
            advantage = R - values[i].detach()

            policy_loss -= log_probs[i] * advantage
            value_loss += 0.5 * (R - values[i]) ** 2
            entropy_loss -= entropies[i]

        total_loss = policy_loss + value_loss + entropy_beta * entropy_loss

        # ローカル勾配の計算
        optimizer.zero_grad()
        total_loss.backward()

        # 勾配クリッピング
        torch.nn.utils.clip_grad_norm_(local_model.parameters(), 40.0)

        # 共有モデルに勾配を転送して更新
        for shared_param, local_param in zip(shared_model.parameters(),
                                              local_model.parameters()):
            if shared_param.grad is None:
                shared_param.grad = local_param.grad.clone()
            else:
                shared_param.grad.copy_(local_param.grad)
        optimizer.step()

A3C全体の学習ループ

def train_a3c(env_name='CartPole-v1', num_workers=4, max_episodes=5000):
    """A3Cメイン学習関数"""
    env = make_env()
    obs_size = env.observation_space.shape[0]
    act_size = env.action_space.n
    env.close()

    shared_model = ActorCritic(obs_size, act_size)
    shared_model.share_memory()

    optimizer = SharedAdam(shared_model.parameters(), lr=1e-4)
    optimizer.share_memory()

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    processes = []
    for i in range(num_workers):
        p = mp.Process(
            target=gradient_worker,
            args=(i, shared_model, optimizer, counter, lock, max_episodes)
        )
        p.start()
        processes.append(p)

    # モニタリングプロセス
    while counter.value < max_episodes:
        import time
        time.sleep(10)
        print(f"完了エピソード: {counter.value}/{max_episodes}")

    for p in processes:
        p.terminate()
        p.join()

    return shared_model

if __name__ == '__main__':
    mp.set_start_method('spawn')
    model = train_a3c()
    torch.save(model.state_dict(), 'a3c_model.pth')

データ並列化 vs 勾配並列化の比較

項目データ並列化勾配並列化
ワーカーの役割経験収集のみ経験収集 + 勾配計算
通信内容遷移データ(状態、行動、報酬)勾配テンソル
通信量状態サイズに比例モデルパラメータ数に比例
中央の負担学習演算集中更新のみ実行
実装難易度比較的容易共有メモリ管理が必要
スケーラビリティ中央ボトルネックの可能性ワーカー数に応じて線形スケール

実験結果の比較

CartPole-v1環境での学習性能:

  • A2C(単一環境):約500エピソードで収束
  • A2C(8並列環境):約200エピソードで収束
  • A3C(8ワーカー、データ並列):約150エピソードで収束
  • A3C(8ワーカー、勾配並列):約120エピソードで収束

非同期更新による若干のノイズがありますが、探索の多様性のおかげで全体的により速く収束する傾向があります。


実践的なヒントと注意事項

  1. ワーカー数の選択:CPUコア数と同じに設定するのが一般的です。GPU使用時はワーカー数を減らしてバッチサイズを増やすのが効率的です。

  2. 非同期更新の不安定性:ワーカー間のモデルバージョン差(staleness)が大きいほど学習が不安定になります。勾配クリッピングは必須です。

  3. A2C vs A3Cの選択基準:GPUを使用するならA2Cの方が効率的な場合が多いです。ベクトル化環境を使用すればA2Cでも十分な探索多様性を確保できます。A3CはCPUベースの学習で利点が大きいです。

  4. デバッグ:非同期プログラムはデバッグが困難です。まず単一ワーカーで正常動作を確認してからワーカー数を増やすのが良いでしょう。


要点まとめ

  • A3Cは非同期並列学習でデータの相関問題を解決する
  • データ並列化は経験収集を分散し、勾配並列化は計算も分散する
  • Pythonのmultiprocessingとshared memoryを活用して実装する
  • 最近ではGPU効率性のため同期方式(A2C)やPPOがより好まれる傾向にある

次の記事では、強化学習を自然言語処理に適用してチャットボットを訓練する方法を見ていきます。