
[Deep RL] 18. AlphaGo Zero: AI That Learns by Playing Itself


Overview

When AlphaGo defeated Lee Sedol 9-dan in 2016, the whole world was amazed. But the true innovation came with AlphaGo Zero in 2017. Starting from nothing but the rules of Go, it reached superhuman strength through self-play alone, without any human game records.

This post covers the relationship between board games and reinforcement learning, the core algorithms of AlphaGo Zero (MCTS + neural network + self-play), and how to implement a Connect4 bot.


Board Games and Reinforcement Learning

Why Board Games Are Suitable for RL

  • Perfect information: The entire game state is visible to both players
  • Deterministic transitions: The outcome of each move is fully determined (dice games are the exception)
  • Clear rewards: Every game ends in a definite win, loss, or draw
  • Self-play possible: An agent can learn by playing against itself, with no external opponent needed

Evolution of AlphaGo

| Version | Key Features | Training Data |
|---|---|---|
| AlphaGo Fan | CNN + MCTS + value network | Human game records + self-play |
| AlphaGo Lee | Deeper network + improved MCTS | Human game records + self-play |
| AlphaGo Zero | ResNet + unified policy-value network | Self-play only |
| AlphaZero | One algorithm for Go, chess, and shogi | Self-play only |

Monte Carlo Tree Search (MCTS)

MCTS is an algorithm for efficiently searching game trees. Unlike minimax, which expands every branch up to a fixed depth, MCTS grows the tree selectively and spends most of its search budget on the most promising lines of play.
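For contrast, here is a minimal depth-limited minimax in negamax form. It is a sketch under assumptions: the toy game (a single pile of stones, each player removes 1 to 3, whoever takes the last stone wins) and the function names are illustrative. Note that it visits every legal branch down to the depth cutoff, which is exactly the cost MCTS tries to avoid.

```python
def legal_moves(pile):
    """Toy Nim: remove 1 to 3 stones (never more than remain)."""
    return [k for k in (1, 2, 3) if k <= pile]

def negamax(pile, depth, counter):
    """Value of `pile` for the player to move: +1 win, -1 loss, 0 unknown (cutoff).
    `counter` is a one-element list that tallies visited nodes."""
    counter[0] += 1
    if pile == 0:
        return -1  # The previous player took the last stone, so the player to move lost
    if depth == 0:
        return 0   # Depth cutoff: a trivial heuristic that treats the position as unknown
    # Negamax: my value is the best of the negated values of the positions I can reach
    return max(-negamax(pile - k, depth - 1, counter) for k in legal_moves(pile))

def best_move(pile, depth):
    """Exhaustively score every root move; also return the number of nodes visited."""
    counter = [0]
    scores = {k: -negamax(pile - k, depth - 1, counter) for k in legal_moves(pile)}
    return max(scores, key=scores.get), counter[0]

move, nodes = best_move(5, depth=5)
print(move, nodes)
```

The node count grows exponentially with depth, since every branch is expanded regardless of how promising it looks.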

Four Stages of MCTS

  1. Selection: Starting from the root, repeatedly pick the child with the highest UCB (Upper Confidence Bound) score until a leaf is reached (AlphaGo Zero uses the PUCT variant)
  2. Expansion: Add child nodes to the selected leaf
  3. Simulation: Estimate the leaf's value with a random playout (replaced by a neural network evaluation in AlphaGo Zero)
  4. Backpropagation: Propagate the result back up to the root, updating visit counts and values along the path
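Before any neural network enters the picture, the four stages can be sketched with classic UCT and random playouts. This is a minimal illustration under assumptions: the toy game is single-pile Nim (take 1 to 3 stones, last stone wins), the UCB1 constant is √2, and all names are hypothetical. Each node credits wins to the player who moved into it, so selection can simply maximize. From a pile of 5, the search should settle on taking 1 stone, which leaves the opponent a losing pile of 4.

```python
import math
import random

def moves(pile):
    """Toy Nim: take 1 to 3 stones; whoever takes the last stone wins."""
    return [k for k in (1, 2, 3) if k <= pile]

class Node:
    def __init__(self, pile, parent=None, move=None):
        self.pile, self.parent, self.move = pile, parent, move
        self.children = []
        self.untried = moves(pile)  # Moves not yet expanded into children
        self.visits = 0
        self.wins = 0.0             # Wins for the player who moved INTO this node

    def ucb1(self, c=math.sqrt(2)):
        exploit = self.wins / self.visits
        explore = c * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploit + explore

def rollout(pile):
    """Random playout: +1 if the player to move at `pile` wins, else -1."""
    sign = 1
    while pile > 0:
        pile -= random.choice(moves(pile))
        sign = -sign
    # sign == -1 exactly when the starting player made the last (winning) move
    return -sign

def uct_search(root_pile, iterations=2000):
    root = Node(root_pile)
    for _ in range(iterations):
        node = root
        # 1. Selection: descend through fully expanded nodes by UCB1
        while not node.untried and node.children:
            node = max(node.children, key=Node.ucb1)
        # 2. Expansion: add one untried child (terminal nodes have none)
        if node.untried:
            k = node.untried.pop()
            child = Node(node.pile - k, parent=node, move=k)
            node.children.append(child)
            node = child
        # 3. Simulation: random playout, valued for the player to move at `node`
        value = rollout(node.pile)
        # 4. Backpropagation: the mover into `node` scores 1 for a win, 0 for a loss
        while node is not None:
            node.visits += 1
            node.wins += (1 - value) / 2
            value = -value  # One ply up, the other player's point of view
            node = node.parent
    # Final decision: the most-visited root move
    return max(root.children, key=lambda n: n.visits).move

random.seed(0)
print(uct_search(5))
```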
import math
import numpy as np

class MCTSNode:
    """MCTS tree node"""

    def __init__(self, state, parent=None, action=None, prior=0.0):
        self.state = state
        self.parent = parent
        self.action = action  # Action that led to this node
        self.prior = prior    # Prior probability predicted by neural network

        self.children = {}    # action -> MCTSNode
        self.visit_count = 0
        self.total_value = 0.0

    @property
    def q_value(self):
        if self.visit_count == 0:
            return 0.0
        return self.total_value / self.visit_count

    def ucb_score(self, c_puct=1.0):
        """PUCT (Predictor + UCT) score computation"""
        if self.parent is None:
            return 0.0

        exploration = c_puct * self.prior * math.sqrt(
            self.parent.visit_count
        ) / (1 + self.visit_count)

        return self.q_value + exploration

    def is_leaf(self):
        return len(self.children) == 0

    def best_child(self, c_puct=1.0):
        """Select child with highest UCB score"""
        best_score = -float('inf')
        best_node = None
        for child in self.children.values():
            score = child.ucb_score(c_puct)
            if score > best_score:
                best_score = score
                best_node = child
        return best_node

    def best_action(self, temperature=1.0):
        """Select an action based on visit counts"""
        actions = list(self.children.keys())
        visit_counts = np.array(
            [child.visit_count for child in self.children.values()],
            dtype=np.float64,
        )

        if temperature == 0:
            # Greedy selection: the most-visited action
            return actions[int(np.argmax(visit_counts))]
        # Temperature-based sampling with probability proportional to N^(1/T),
        # computed in log space so small temperatures do not overflow
        logits = np.log(visit_counts + 1e-10) / temperature
        probs = np.exp(logits - logits.max())
        probs = probs / probs.sum()
        return actions[np.random.choice(len(actions), p=probs)]

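For reference, `ucb_score` and `best_action` correspond to the two formulas used in AlphaGo Zero: the PUCT selection rule, where $P(s,a)$ is the network prior, $N(s,a)$ the visit count and $Q(s,a)$ the mean value of taking $a$ in $s$; and the visit-count policy with temperature $\tau$. The code approximates $\sum_b N(s,b)$ by the parent's own visit count.

$$a^* = \arg\max_a \left[ Q(s,a) + c_{\text{puct}}\, P(s,a)\, \frac{\sqrt{\sum_b N(s,b)}}{1 + N(s,a)} \right]$$

$$\pi(a \mid s) = \frac{N(s,a)^{1/\tau}}{\sum_b N(s,b)^{1/\tau}}$$

As $\tau \to 0$ the policy becomes greedy (always the most-visited move); $\tau = 1$ samples moves in proportion to their visit counts.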
MCTS Search Function

class MCTS:
    """Neural network-based MCTS"""

    def __init__(self, model, game, num_simulations=800, c_puct=1.0):
        self.model = model      # Policy + value network
        self.game = game        # Game rules
        self.num_simulations = num_simulations
        self.c_puct = c_puct

    def search(self, state):
        """Perform MCTS search"""
        root = MCTSNode(state)
        self._expand(root)

        for _ in range(self.num_simulations):
            node = root

            # 1. Selection: descend to leaf
            while not node.is_leaf():
                node = node.best_child(self.c_puct)

            # Check if game over
            game_result = self.game.get_result(node.state)

            if game_result is None:
                # 2. Expansion: expand leaf node
                value = self._expand(node)
            else:
                # Use actual result when game is over
                value = game_result

            # 3. Backpropagation: propagate result to root
            self._backpropagate(node, value)

        return root

    def _expand(self, node):
        """Expand node: predict policy and value with the neural network.
        Returns the value from the perspective of the player to move at `node`."""
        import torch

        self.model.eval()  # Use BatchNorm running statistics during search
        state_tensor = self.game.state_to_tensor(node.state).unsqueeze(0)

        with torch.no_grad():
            policy_logits, value = self.model(state_tensor)
            policy = torch.softmax(policy_logits, dim=-1).squeeze(0).numpy()
            value = value.item()

        # Create child nodes only for valid actions, renormalizing the priors
        # so that they sum to 1 over the legal moves
        valid_actions = self.game.get_valid_actions(node.state)
        priors = policy[valid_actions]
        priors = priors / priors.sum()
        for action, prior in zip(valid_actions, priors):
            child_state = self.game.apply_action(node.state, action)
            node.children[action] = MCTSNode(
                state=child_state,
                parent=node,
                action=action,
                prior=float(prior),
            )

        return value

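    def _backpropagate(self, node, value):
        """Backpropagate a leaf value to the root.
        `value` is from the perspective of the player to move at `node`. Each node
        accumulates value from the perspective of the player who moved into it,
        so the parent can simply pick the child with the highest Q-value."""
        while node is not None:
            node.visit_count += 1
            node.total_value += -value  # Credit the player who made the move into this node
            value = -value              # Switch perspective for the next level up
            node = node.parent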

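The sign convention is the easiest part of MCTS backup to get wrong. The stand-alone toy below (the `ToyNode` class is hypothetical) shows one consistent negamax-style convention: a leaf two plies below the root is evaluated at +1 for the player to move there, and each node is credited from the point of view of the player who moved into it.

```python
class ToyNode:
    def __init__(self, parent=None):
        self.parent = parent
        self.visit_count = 0
        self.total_value = 0.0

def backpropagate(node, value):
    """`value` is from the perspective of the player to move at `node`."""
    while node is not None:
        node.visit_count += 1
        node.total_value += -value  # Credit the player who moved into this node
        value = -value              # One ply up, the other player's point of view
        node = node.parent

root = ToyNode()
child = ToyNode(parent=root)        # Player A moved here; B is to move
grandchild = ToyNode(parent=child)  # B moved here; A is to move
backpropagate(grandchild, +1.0)     # Position is good for A, the player to move
print(child.total_value, grandchild.total_value)
```

A's move into `child` is credited +1, and B's move into `grandchild` is credited -1, so each parent can maximize over its children's values.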
AlphaGo Zero's Self-Play and Training

Self-Play

import torch

def self_play_game(model, game, mcts_simulations=800, temperature=1.0):
    """Generate training data through self-play"""
    mcts = MCTS(model, game, num_simulations=mcts_simulations)
    state = game.get_initial_state()

    training_data = []  # Store (state, policy, value)
    move_count = 0

    while True:
        # MCTS search
        root = mcts.search(state)

        # Policy target: normalized visit counts
        policy_target = np.zeros(game.action_size)
        for action, child in root.children.items():
            policy_target[action] = child.visit_count
        policy_target = policy_target / policy_target.sum()

        # Sample in proportion to visit counts early on to encourage exploration,
        # then play greedily (temperature 0) for the rest of the game
        temp = temperature if move_count < 30 else 0
        action = root.best_action(temperature=temp)

        training_data.append((
            game.state_to_tensor(state),
            policy_target,
            None,  # Value filled after game ends
        ))

        # Execute action
        state = game.apply_action(state, action)
        move_count += 1

        # Check game over
        result = game.get_result(state)
        if result is not None:
            # `result` is from the perspective of the player to move in the final
            # state. Position i belongs to that same player when (num_moves - i)
            # is even, and to the opponent otherwise.
            num_moves = len(training_data)
            final_data = []
            for i, (s, p, _) in enumerate(training_data):
                value = result if (num_moves - i) % 2 == 0 else -result
                final_data.append((s, p, value))
            return final_data

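Value targets are easy to get off by one ply, so here is a stand-alone sketch of assigning them by parity (the function name `assign_value_targets` is hypothetical). It assumes `result` is given from the perspective of the player to move in the final position, as `get_result` reports it, and walks the sign backward from the end of the game.

```python
def assign_value_targets(num_moves, result):
    """Return z_i for positions i = 0..num_moves-1.
    `result` is from the perspective of the player to move after the last move
    (+1 win, -1 loss, 0 draw). Position i belongs to that same player when
    (num_moves - i) is even."""
    return [result if (num_moves - i) % 2 == 0 else -result
            for i in range(num_moves)]

# The first player wins with the 7th move of the game. The second player is
# then to move and has lost, so result = -1.
targets = assign_value_targets(7, -1)
print(targets)  # [1, -1, 1, -1, 1, -1, 1]
```

Every position where the first player was to move (even indices) receives +1, as it should for the eventual winner.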
Training Loop

class AlphaZeroTrainer:
    """AlphaZero training manager"""

    def __init__(self, model, game, config):
        self.model = model
        self.game = game
        self.config = config
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.get('lr', 1e-3),
            weight_decay=config.get('weight_decay', 1e-4),
        )
        self.replay_buffer = []
        self.max_buffer_size = config.get('buffer_size', 100000)

    def train(self, num_iterations=100):
        for iteration in range(num_iterations):
            # 1. Generate data through self-play
            print(f"Iteration {iteration}: Self-playing...")
            for _ in range(self.config['games_per_iteration']):
                game_data = self_play_game(
                    self.model, self.game,
                    mcts_simulations=self.config['mcts_simulations'],
                )
                self.replay_buffer.extend(game_data)

            # Limit buffer size
            if len(self.replay_buffer) > self.max_buffer_size:
                self.replay_buffer = self.replay_buffer[-self.max_buffer_size:]

            # 2. Neural network training
            print(f"Iteration {iteration}: Training...")
            for _ in range(self.config['train_epochs']):
                self._train_step()

            # 3. Evaluation
            if iteration % self.config['eval_interval'] == 0:
                win_rate = self._evaluate()
                print(f"Iteration {iteration}: Win rate = {win_rate:.1%}")

    def _train_step(self):
        batch_size = self.config.get('batch_size', 256)
        if len(self.replay_buffer) < batch_size:
            return
        self.model.train()  # Training mode: BatchNorm uses batch statistics

        indices = np.random.choice(len(self.replay_buffer), batch_size)
        batch = [self.replay_buffer[i] for i in indices]

        states = torch.stack([b[0] for b in batch])
        target_policies = torch.FloatTensor(np.array([b[1] for b in batch]))
        target_values = torch.FloatTensor([b[2] for b in batch])

        pred_policies, pred_values = self.model(states)
        pred_policies = torch.log_softmax(pred_policies, dim=-1)
        pred_values = pred_values.squeeze(-1)

        policy_loss = -(target_policies * pred_policies).sum(dim=-1).mean()
        value_loss = (target_values - pred_values).pow(2).mean()
        loss = policy_loss + value_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

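    def _evaluate(self, num_games=20):
        """Evaluate by playing against a uniformly random opponent"""
        wins = 0
        for _ in range(num_games):
            result = play_against_random(self.model, self.game)
            if result > 0:
                wins += 1
        return wins / num_games


def play_against_random(model, game, num_simulations=100):
    """Play one game between an MCTS agent and a uniformly random opponent.
    Returns +1 if the agent wins, -1 if it loses, 0 for a draw.
    (The 100-simulation default is an arbitrary, cheap evaluation setting.)"""
    mcts = MCTS(model, game, num_simulations=num_simulations)
    state = game.get_initial_state()
    agent_turn = np.random.rand() < 0.5  # Randomize who moves first

    while True:
        result = game.get_result(state)
        if result is not None:
            # `result` is from the perspective of the player to move
            return result if agent_turn else -result
        if agent_turn:
            action = mcts.search(state).best_action(temperature=0)
        else:
            action = int(np.random.choice(game.get_valid_actions(state)))
        state = game.apply_action(state, action)
        agent_turn = not agent_turn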

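The loss in `_train_step` follows the AlphaGo Zero objective, where $z$ is the game outcome, $v$ the predicted value, $\boldsymbol{\pi}$ the MCTS visit distribution, and $\mathbf{p}$ the predicted policy. The $L_2$ term is applied here through Adam's `weight_decay` argument, which adds `weight_decay * θ` to the gradient:

$$l = (z - v)^2 - \boldsymbol{\pi}^{\top} \log \mathbf{p} + c\,\lVert\theta\rVert^2$$

The first term is `value_loss` and the second is `policy_loss` (a cross-entropy between the search policy and the network policy).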
Connect4 Bot Implementation

Connect4 is far simpler than Go (a 6×7 board with at most 7 legal moves per turn), which makes it an excellent testbed for practicing AlphaZero's principles.

Game Model

class Connect4:
    """Connect4 game rules"""
    ROWS = 6
    COLS = 7
    action_size = 7  # Place stone in one of 7 columns

    @staticmethod
    def get_initial_state():
        """Return empty board. 0=empty, 1=player1, -1=player2"""
        return np.zeros((Connect4.ROWS, Connect4.COLS), dtype=np.int8)

    @staticmethod
    def get_valid_actions(state):
        """List of columns where stones can be placed"""
        return [c for c in range(Connect4.COLS) if state[0, c] == 0]

    @staticmethod
    def apply_action(state, action):
        """Place stone in column (current player = 1)"""
        new_state = state.copy()
        # Find lowest empty row in the column
        for row in range(Connect4.ROWS - 1, -1, -1):
            if new_state[row, action] == 0:
                new_state[row, action] = 1
                break
        # Flip board for next turn (opponent's perspective)
        return -new_state

    @staticmethod
    def get_result(state):
        """Check game over. 1=current player wins, -1=loss, 0=draw, None=ongoing"""
        # Check 4-in-a-row (horizontal, vertical, diagonal)
        # A line of 1s means the current player has four in a row (+1);
        # a line of -1s means the opponent does (-1)
        for player in [1, -1]:
            # Horizontal
            for r in range(Connect4.ROWS):
                for c in range(Connect4.COLS - 3):
                    if all(state[r, c+i] == player for i in range(4)):
                        return player

            # Vertical
            for r in range(Connect4.ROWS - 3):
                for c in range(Connect4.COLS):
                    if all(state[r+i, c] == player for i in range(4)):
                        return player

            # Diagonal (down-right)
            for r in range(Connect4.ROWS - 3):
                for c in range(Connect4.COLS - 3):
                    if all(state[r+i, c+i] == player for i in range(4)):
                        return player

            # Diagonal (down-left)
            for r in range(Connect4.ROWS - 3):
                for c in range(3, Connect4.COLS):
                    if all(state[r+i, c-i] == player for i in range(4)):
                        return player

        # Draw (board full)
        if all(state[0, c] != 0 for c in range(Connect4.COLS)):
            return 0

        return None  # Ongoing

    @staticmethod
    def state_to_tensor(state):
        """Convert board to neural network input tensor"""
        # Channel 0: current player stones, Channel 1: opponent stones
        player_plane = (state == 1).astype(np.float32)
        opponent_plane = (state == -1).astype(np.float32)
        return torch.FloatTensor(np.stack([player_plane, opponent_plane]))

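The four loop groups in `get_result` can also be written as a scan over four direction vectors from every occupied cell. A minimal pure-Python sketch, assuming a list-of-lists board with the same 1/-1/0 encoding (the name `find_winner` is hypothetical):

```python
ROWS, COLS = 6, 7
# (dr, dc): right, down, down-right, down-left
DIRECTIONS = [(0, 1), (1, 0), (1, 1), (1, -1)]

def find_winner(board):
    """Return 1 or -1 if that value has four in a row, else None."""
    for r in range(ROWS):
        for c in range(COLS):
            player = board[r][c]
            if player == 0:
                continue
            for dr, dc in DIRECTIONS:
                end_r, end_c = r + 3 * dr, c + 3 * dc
                if not (0 <= end_r < ROWS and 0 <= end_c < COLS):
                    continue  # A line of four starting here would leave the board
                if all(board[r + i * dr][c + i * dc] == player for i in range(4)):
                    return player
    return None

board = [[0] * COLS for _ in range(ROWS)]
for i in range(4):
    board[5 - i][i] = -1  # Rising diagonal from the bottom-left corner
print(find_winner(board))  # -1
```

Scanning only forward directions is enough, because every line of four has exactly one starting cell from which it extends right, down, or diagonally downward.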
Connect4 Neural Network

import torch
import torch.nn as nn

class Connect4Net(nn.Module):
    """Policy-value network for Connect4"""

    def __init__(self, num_res_blocks=5):
        super().__init__()
        self.conv_input = nn.Sequential(
            nn.Conv2d(2, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.res_blocks = nn.ModuleList([
            ResBlock(128) for _ in range(num_res_blocks)
        ])
        self.policy_head = nn.Sequential(
            nn.Conv2d(128, 32, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 6 * 7, Connect4.action_size),
        )
        self.value_head = nn.Sequential(
            nn.Conv2d(128, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(6 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Tanh(),  # Range -1 to 1
        )

    def forward(self, x):
        out = self.conv_input(x)
        for block in self.res_blocks:
            out = block(out)
        policy = self.policy_head(out)
        value = self.value_head(out)
        return policy, value

class ResBlock(nn.Module):
    """Residual block"""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = torch.relu(out + residual)
        return out

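To get a sense of the model's size, the trainable parameters can be tallied by hand. This sketch assumes the default `num_res_blocks=5` and counts only learnable weights and biases (BatchNorm running statistics are buffers, not parameters), so it should match `sum(p.numel() for p in model.parameters())` in PyTorch. The helper names are hypothetical.

```python
def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out  # Weights + bias

def bn_params(c):
    return 2 * c  # Learnable scale (weight) + shift (bias)

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

def connect4net_params(num_res_blocks=5, channels=128, rows=6, cols=7, actions=7):
    stem = conv_params(2, channels, 3) + bn_params(channels)
    res_block = 2 * (conv_params(channels, channels, 3) + bn_params(channels))
    policy = (conv_params(channels, 32, 1) + bn_params(32)
              + linear_params(32 * rows * cols, actions))
    value = (conv_params(channels, 1, 1) + bn_params(1)
             + linear_params(rows * cols, 64) + linear_params(64, 1))
    return stem + num_res_blocks * res_block + policy + value

print(connect4net_params())  # 1497643
```

Nearly all of the roughly 1.5M parameters sit in the residual tower (295,680 per block), so `num_res_blocks` is the main knob for model size.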
Running Training

def train_connect4():
    """Connect4 AlphaZero training"""
    game = Connect4()
    model = Connect4Net(num_res_blocks=5)
    config = {
        'lr': 1e-3, 'weight_decay': 1e-4, 'buffer_size': 50000,
        'games_per_iteration': 100, 'mcts_simulations': 400,
        'train_epochs': 10, 'batch_size': 128, 'eval_interval': 5,
    }
    trainer = AlphaZeroTrainer(model, game, config)
    trainer.train(num_iterations=50)
    return model

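def play_against_human(model, game):
    """Play against the trained model (the human moves first and plays 'O')"""
    mcts = MCTS(model, game, num_simulations=400)
    state = game.get_initial_state()
    human_turn = True

    while True:
        # `state` is stored from the mover's perspective; flip it on the AI's
        # turn so the human's stones are always shown as 'O'
        display_board(state if human_turn else -state)
        result = game.get_result(state)
        if result is not None:
            if result == 0:
                print("Draw!")
            else:
                # A nonzero result means the player to move has lost
                print("AI wins!" if human_turn else "You win!")
            break

        if human_turn:
            valid = game.get_valid_actions(state)
            col = int(input(f"Choose column {valid}: "))
            while col not in valid:
                col = int(input(f"Invalid column. Choose from {valid}: "))
            state = game.apply_action(state, col)
        else:
            root = mcts.search(state)
            action = root.best_action(temperature=0)  # Greedy: most-visited move
            print(f"AI chooses: column {action}")
            state = game.apply_action(state, action)

        human_turn = not human_turn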

def display_board(state):
    """Display board"""
    symbols = {0: '.', 1: 'O', -1: 'X'}
    print("\n 0 1 2 3 4 5 6")
    for row in range(Connect4.ROWS):
        line = ' '.join(symbols[state[row, c]] for c in range(Connect4.COLS))
        print(f" {line}")
    print()

Training Results

After 50 iterations of training on Connect4:

  • Random opponent: Over 95% win rate
  • Minimax (depth 4): Over 80% win rate
  • Training time: Approximately 2-4 hours on GPU

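The key insight of AlphaZero is a virtuous cycle: network-guided MCTS produces move distributions stronger than the network's raw policy, the network is trained to imitate those distributions and to predict game outcomes, and the improved network in turn makes the next round of search stronger.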


Key Takeaways

  • AlphaGo Zero reached superhuman level using only self-play without human data
  • MCTS efficiently searches game trees by leveraging neural network policy/value predictions
  • In self-play, MCTS visit counts become the policy training target, and game results become the value training target
  • The effectiveness of the AlphaZero framework can be verified even in simple games like Connect4

In the next post, we will broadly explore practical applications of deep reinforcement learning.