Skip to content
Published on

[深層強化学習] 18. AlphaGo Zero:自己対局で学ぶAI

Authors

概要

2016年にAlphaGoがイ・セドル9段を破った時、全世界が驚きました。しかし真の革新は2017年のAlphaGo Zeroで起こりました。人間の棋譜データなしに、ただ**自己対局(self-play)**だけで囲碁のルールから超人的レベルまで学習したのです。

この記事では、ボードゲームと強化学習の関係、AlphaGo Zeroの核心アルゴリズム(MCTS + ニューラルネットワーク + 自己対局)、そしてConnect4(四目並べ)ボットを実装する方法を扱います。


ボードゲームと強化学習

ボードゲームがRLに適している理由

  • 完全情報ゲーム:すべての状態が観察可能
  • 決定的遷移:行動の結果が明確(サイコロゲームを除く)
  • 明確な報酬:勝/敗/引き分けが確実
  • 自己対局可能:相手なしで自分自身と対戦しながら学習

AlphaGoの進化

バージョン核心特徴データ
AlphaGo FanCNN + MCTS + 価値ネットワーク人間の棋譜 + 自己対局
AlphaGo Leeより深いネットワーク + 改良MCTS人間の棋譜 + 自己対局
AlphaGo ZeroResNet + 統合ネットワーク自己対局のみ
AlphaZero囲碁、チェス、将棋統合自己対局のみ

モンテカルロ木探索(MCTS)

MCTSはゲーム木を効率的に探索するアルゴリズムです。全体の木を探索するミニマックスとは異なり、有望な部分のみを集中的に探索します。

MCTSの4段階

  1. 選択(Selection):ルートから開始してUCB(Upper Confidence Bound)公式で子ノードを選択
  2. 展開(Expansion):リーフノードに新しい子ノードを追加
  3. シミュレーション(Simulation):新ノードからランダムプレイアウト(AlphaGo Zeroではニューラルネットワーク評価に置換)
  4. 逆伝播(Backpropagation):結果をルートまで伝播して統計を更新
import math
import numpy as np

class MCTSNode:
    """MCTS木ノード"""

    def __init__(self, state, parent=None, action=None, prior=0.0):
        self.state = state
        self.parent = parent
        self.action = action  # このノードに到達した行動
        self.prior = prior    # ニューラルネットワークが予測した事前確率

        self.children = {}    # action -> MCTSNode
        self.visit_count = 0
        self.total_value = 0.0

    @property
    def q_value(self):
        if self.visit_count == 0:
            return 0.0
        return self.total_value / self.visit_count

    def ucb_score(self, c_puct=1.0):
        """PUCT(Predictor + UCT)スコアの計算"""
        if self.parent is None:
            return 0.0
        exploration = c_puct * self.prior * math.sqrt(
            self.parent.visit_count
        ) / (1 + self.visit_count)
        return self.q_value + exploration

    def is_leaf(self):
        return len(self.children) == 0

    def best_child(self, c_puct=1.0):
        """UCBスコアが最も高い子を選択"""
        best_score = -float('inf')
        best_node = None
        for child in self.children.values():
            score = child.ucb_score(c_puct)
            if score > best_score:
                best_score = score
                best_node = child
        return best_node

    def best_action(self, temperature=1.0):
        """訪問回数に基づいて行動を選択"""
        visit_counts = np.array([
            child.visit_count for child in self.children.values()
        ])
        actions = list(self.children.keys())
        if temperature == 0:
            return actions[np.argmax(visit_counts)]
        else:
            probs = visit_counts ** (1.0 / temperature)
            probs = probs / probs.sum()
            return np.random.choice(actions, p=probs)

MCTS探索関数

class MCTS:
    """ニューラルネットワークベースのMCTS"""

    def __init__(self, model, game, num_simulations=800, c_puct=1.0):
        self.model = model
        self.game = game
        self.num_simulations = num_simulations
        self.c_puct = c_puct

    def search(self, state):
        """MCTS探索の実行"""
        root = MCTSNode(state)
        self._expand(root)
        for _ in range(self.num_simulations):
            node = root
            while not node.is_leaf():
                node = node.best_child(self.c_puct)
            game_result = self.game.get_result(node.state)
            if game_result is None:
                value = self._expand(node)
            else:
                value = game_result
            self._backpropagate(node, value)
        return root

    def _expand(self, node):
        """ノード展開:ニューラルネットワークで方策と価値を予測"""
        import torch
        state_tensor = self.game.state_to_tensor(node.state)
        state_tensor = state_tensor.unsqueeze(0)
        with torch.no_grad():
            policy, value = self.model(state_tensor)
            policy = torch.softmax(policy, dim=-1).squeeze().numpy()
            value = value.item()
        valid_actions = self.game.get_valid_actions(node.state)
        for action in valid_actions:
            child_state = self.game.apply_action(node.state, action)
            child = MCTSNode(
                state=child_state, parent=node,
                action=action, prior=policy[action],
            )
            node.children[action] = child
        return value

    def _backpropagate(self, node, value):
        """結果をルートまで逆伝播"""
        while node is not None:
            node.visit_count += 1
            node.total_value += value
            value = -value
            node = node.parent

AlphaGo Zeroの自己対局と学習

自己対局(Self-Play)

import torch

def self_play_game(model, game, mcts_simulations=800, temperature=1.0):
    """自己対局で学習データを生成"""
    mcts = MCTS(model, game, num_simulations=mcts_simulations)
    state = game.get_initial_state()
    training_data = []
    move_count = 0

    while True:
        root = mcts.search(state)
        policy_target = np.zeros(game.action_size)
        for action, child in root.children.items():
            policy_target[action] = child.visit_count
        policy_target = policy_target / policy_target.sum()

        temp = temperature if move_count < 30 else 0.01
        action = root.best_action(temperature=temp)

        training_data.append((
            game.state_to_tensor(state), policy_target, None,
        ))

        state = game.apply_action(state, action)
        move_count += 1

        result = game.get_result(state)
        if result is not None:
            final_data = []
            for i, (s, p, _) in enumerate(training_data):
                if i % 2 == 0:
                    value = result
                else:
                    value = -result
                final_data.append((s, p, value))
            return final_data

学習ループ

class AlphaZeroTrainer:
    """AlphaZero学習マネージャ"""

    def __init__(self, model, game, config):
        self.model = model
        self.game = game
        self.config = config
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.get('lr', 1e-3),
            weight_decay=config.get('weight_decay', 1e-4),
        )
        self.replay_buffer = []
        self.max_buffer_size = config.get('buffer_size', 100000)

    def train(self, num_iterations=100):
        for iteration in range(num_iterations):
            print(f"反復 {iteration}: 自己対局中...")
            for _ in range(self.config['games_per_iteration']):
                game_data = self_play_game(
                    self.model, self.game,
                    mcts_simulations=self.config['mcts_simulations'],
                )
                self.replay_buffer.extend(game_data)
            if len(self.replay_buffer) > self.max_buffer_size:
                self.replay_buffer = self.replay_buffer[-self.max_buffer_size:]

            print(f"反復 {iteration}: 学習中...")
            for _ in range(self.config['train_epochs']):
                self._train_step()

            if iteration % self.config['eval_interval'] == 0:
                win_rate = self._evaluate()
                print(f"反復 {iteration}: 勝率 = {win_rate:.1%}")

    def _train_step(self):
        batch_size = self.config.get('batch_size', 256)
        if len(self.replay_buffer) < batch_size:
            return
        indices = np.random.choice(len(self.replay_buffer), batch_size)
        batch = [self.replay_buffer[i] for i in indices]
        states = torch.stack([b[0] for b in batch])
        target_policies = torch.FloatTensor(np.array([b[1] for b in batch]))
        target_values = torch.FloatTensor([b[2] for b in batch])

        pred_policies, pred_values = self.model(states)
        pred_policies = torch.log_softmax(pred_policies, dim=-1)
        pred_values = pred_values.squeeze(-1)

        policy_loss = -(target_policies * pred_policies).sum(dim=-1).mean()
        value_loss = (target_values - pred_values).pow(2).mean()
        loss = policy_loss + value_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _evaluate(self, num_games=20):
        wins = 0
        for _ in range(num_games):
            result = play_against_random(self.model, self.game)
            if result > 0:
                wins += 1
        return wins / num_games

Connect4ボットの実装

Connect4(四目並べ)は囲碁よりシンプルでありながらAlphaZeroの原理を実習するのに良いゲームです。

ゲームモデル

class Connect4:
    """Connect4ゲームルール"""
    ROWS = 6
    COLS = 7
    action_size = 7

    @staticmethod
    def get_initial_state():
        return np.zeros((Connect4.ROWS, Connect4.COLS), dtype=np.int8)

    @staticmethod
    def get_valid_actions(state):
        return [c for c in range(Connect4.COLS) if state[0, c] == 0]

    @staticmethod
    def apply_action(state, action):
        new_state = state.copy()
        for row in range(Connect4.ROWS - 1, -1, -1):
            if new_state[row, action] == 0:
                new_state[row, action] = 1
                break
        return -new_state

    @staticmethod
    def get_result(state):
        for player in [1, -1]:
            for r in range(Connect4.ROWS):
                for c in range(Connect4.COLS - 3):
                    if all(state[r, c+i] == player for i in range(4)):
                        return -1 if player == -1 else None
            for r in range(Connect4.ROWS - 3):
                for c in range(Connect4.COLS):
                    if all(state[r+i, c] == player for i in range(4)):
                        return -1 if player == -1 else None
            for r in range(Connect4.ROWS - 3):
                for c in range(Connect4.COLS - 3):
                    if all(state[r+i, c+i] == player for i in range(4)):
                        return -1 if player == -1 else None
            for r in range(Connect4.ROWS - 3):
                for c in range(3, Connect4.COLS):
                    if all(state[r+i, c-i] == player for i in range(4)):
                        return -1 if player == -1 else None
        if all(state[0, c] != 0 for c in range(Connect4.COLS)):
            return 0
        return None

    @staticmethod
    def state_to_tensor(state):
        player_plane = (state == 1).astype(np.float32)
        opponent_plane = (state == -1).astype(np.float32)
        return torch.FloatTensor(np.stack([player_plane, opponent_plane]))

Connect4ニューラルネットワーク

class Connect4Net(nn.Module):
    """Connect4用方策-価値ネットワーク"""
    def __init__(self, num_res_blocks=5):
        super().__init__()
        self.conv_input = nn.Sequential(
            nn.Conv2d(2, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
        )
        self.res_blocks = nn.ModuleList([ResBlock(128) for _ in range(num_res_blocks)])
        self.policy_head = nn.Sequential(
            nn.Conv2d(128, 32, 1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Flatten(), nn.Linear(32 * 6 * 7, Connect4.action_size),
        )
        self.value_head = nn.Sequential(
            nn.Conv2d(128, 1, 1), nn.BatchNorm2d(1), nn.ReLU(),
            nn.Flatten(), nn.Linear(6 * 7, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Tanh(),
        )

    def forward(self, x):
        out = self.conv_input(x)
        for block in self.res_blocks:
            out = block(out)
        return self.policy_head(out), self.value_head(out)

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + residual)

学習の実行

def train_connect4():
    game = Connect4()
    model = Connect4Net(num_res_blocks=5)
    config = {
        'lr': 1e-3, 'weight_decay': 1e-4, 'buffer_size': 50000,
        'games_per_iteration': 100, 'mcts_simulations': 400,
        'train_epochs': 10, 'batch_size': 128, 'eval_interval': 5,
    }
    trainer = AlphaZeroTrainer(model, game, config)
    trainer.train(num_iterations=50)
    return model

def play_against_human(model, game):
    mcts = MCTS(model, game, num_simulations=400)
    state = game.get_initial_state()
    human_turn = True
    while True:
        display_board(state)
        result = game.get_result(state)
        if result is not None:
            if result == 0:
                print("引き分け!")
            else:
                print("ゲーム終了!")
            break
        if human_turn:
            valid = game.get_valid_actions(state)
            col = int(input(f"列を選択 {valid}: "))
            state = game.apply_action(state, col)
        else:
            root = mcts.search(state)
            action = root.best_action(temperature=0.01)
            print(f"AIの選択: 列 {action}")
            state = game.apply_action(state, action)
        human_turn = not human_turn

def display_board(state):
    symbols = {0: '.', 1: 'O', -1: 'X'}
    print("\n 0 1 2 3 4 5 6")
    for row in range(Connect4.ROWS):
        line = ' '.join(symbols[state[row, c]] for c in range(Connect4.COLS))
        print(f" {line}")
    print()

学習結果

Connect4で50回反復学習後:

  • ランダム相手:95%以上の勝率
  • ミニマックス(深さ4):80%以上の勝率
  • 学習時間:GPU基準で約2〜4時間

AlphaZeroの核心的な洞察は、MCTSが方策を改善し、改善された方策がより良い学習データを生成する好循環が起こるということです。


要点まとめ

  • AlphaGo Zeroは人間データなしで自己対局のみで超人的レベルに到達した
  • MCTSはニューラルネットワークの方策/価値予測を活用してゲーム木を効率的に探索する
  • 自己対局でMCTSの訪問回数が方策学習のターゲットとなり、ゲーム結果が価値学習のターゲットとなる
  • Connect4のような簡単なゲームでもAlphaZeroフレームワークの効果を確認できる

次の記事では、深層強化学習の実践的な応用事例を幅広く見ていきます。