Skip to content
Published on

[심층 강화학습] 18. AlphaGo Zero: 자기 대국으로 배우는 AI

Authors

개요

2016년 AlphaGo가 이세돌 9단을 이겼을 때 전 세계가 놀랐다. 하지만 진정한 혁신은 2017년의 AlphaGo Zero에서 일어났다. 인간의 기보 데이터 없이, 오직 자기 대국(self-play) 만으로 바둑의 규칙부터 초인적 수준까지 학습한 것이다.

이 글에서는 보드 게임과 강화학습의 관계, AlphaGo Zero의 핵심 알고리즘(MCTS + 신경망 + 자기 대국), 그리고 Connect4(사목 연결) 봇을 직접 구현하는 방법을 다룬다.


보드 게임과 강화학습

보드 게임이 RL에 적합한 이유

  • 완전 정보 게임: 모든 상태가 관찰 가능하다
  • 결정적 전이: 행동의 결과가 명확하다 (주사위 게임 제외)
  • 명확한 보상: 승/패/무가 확실하다
  • 자기 대국 가능: 상대가 필요 없이 자신과 대전하며 학습

AlphaGo의 진화

버전핵심 특징데이터
AlphaGo FanCNN + MCTS + 가치 네트워크인간 기보 + 자기 대국
AlphaGo Lee더 깊은 네트워크 + 향상된 MCTS인간 기보 + 자기 대국
AlphaGo ZeroResNet + 통합 네트워크자기 대국만
AlphaZero바둑, 체스, 장기 통합자기 대국만

Monte Carlo Tree Search (MCTS)

MCTS는 게임 트리를 효율적으로 탐색하는 알고리즘이다. 전체 트리를 탐색하는 미니맥스와 달리, 유망한 부분만 집중적으로 탐색한다.

MCTS의 4단계

  1. 선택(Selection): 루트에서 시작하여 UCB(Upper Confidence Bound) 공식으로 자식 노드 선택
  2. 확장(Expansion): 리프 노드에 새 자식 노드 추가
  3. 시뮬레이션(Simulation): 새 노드에서 무작위 플레이아웃(AlphaGo Zero에서는 신경망 평가로 대체)
  4. 역전파(Backpropagation): 결과를 루트까지 전파하여 통계 업데이트
import math
import numpy as np

class MCTSNode:
    """MCTS 트리 노드"""

    def __init__(self, state, parent=None, action=None, prior=0.0):
        self.state = state
        self.parent = parent
        self.action = action  # 이 노드에 도달한 행동
        self.prior = prior    # 신경망이 예측한 사전 확률

        self.children = {}    # action -> MCTSNode
        self.visit_count = 0
        self.total_value = 0.0

    @property
    def q_value(self):
        if self.visit_count == 0:
            return 0.0
        return self.total_value / self.visit_count

    def ucb_score(self, c_puct=1.0):
        """PUCT(Predictor + UCT) 점수 계산"""
        if self.parent is None:
            return 0.0

        exploration = c_puct * self.prior * math.sqrt(
            self.parent.visit_count
        ) / (1 + self.visit_count)

        return self.q_value + exploration

    def is_leaf(self):
        return len(self.children) == 0

    def best_child(self, c_puct=1.0):
        """UCB 점수가 가장 높은 자식 선택"""
        best_score = -float('inf')
        best_node = None
        for child in self.children.values():
            score = child.ucb_score(c_puct)
            if score > best_score:
                best_score = score
                best_node = child
        return best_node

    def best_action(self, temperature=1.0):
        """방문 횟수 기반으로 행동 선택"""
        visit_counts = np.array([
            child.visit_count for child in self.children.values()
        ])
        actions = list(self.children.keys())

        if temperature == 0:
            # 탐욕적 선택
            return actions[np.argmax(visit_counts)]
        else:
            # 온도 기반 확률적 선택
            probs = visit_counts ** (1.0 / temperature)
            probs = probs / probs.sum()
            return np.random.choice(actions, p=probs)

MCTS 탐색 함수

class MCTS:
    """신경망 기반 MCTS"""

    def __init__(self, model, game, num_simulations=800, c_puct=1.0):
        self.model = model      # 정책 + 가치 네트워크
        self.game = game        # 게임 규칙
        self.num_simulations = num_simulations
        self.c_puct = c_puct

    def search(self, state):
        """MCTS 탐색 수행"""
        root = MCTSNode(state)

        # 루트 노드 확장
        self._expand(root)

        for _ in range(self.num_simulations):
            node = root

            # 1. 선택: 리프까지 내려감
            while not node.is_leaf():
                node = node.best_child(self.c_puct)

            # 게임 종료 상태인지 확인
            game_result = self.game.get_result(node.state)

            if game_result is None:
                # 2. 확장: 리프 노드 확장
                value = self._expand(node)
            else:
                # 게임 종료 시 실제 결과 사용
                value = game_result

            # 3. 역전파: 결과를 루트까지 전파
            self._backpropagate(node, value)

        return root

    def _expand(self, node):
        """노드 확장: 신경망으로 정책과 가치 예측"""
        import torch

        state_tensor = self.game.state_to_tensor(node.state)
        state_tensor = state_tensor.unsqueeze(0)

        with torch.no_grad():
            policy, value = self.model(state_tensor)
            policy = torch.softmax(policy, dim=-1).squeeze().numpy()
            value = value.item()

        # 유효한 행동에 대해서만 자식 노드 생성
        valid_actions = self.game.get_valid_actions(node.state)
        for action in valid_actions:
            child_state = self.game.apply_action(node.state, action)
            child = MCTSNode(
                state=child_state,
                parent=node,
                action=action,
                prior=policy[action],
            )
            node.children[action] = child

        return value

    def _backpropagate(self, node, value):
        """결과를 루트까지 역전파"""
        while node is not None:
            node.visit_count += 1
            # 상대 플레이어의 관점에서 가치 반전
            node.total_value += value
            value = -value  # 상대 관점
            node = node.parent

AlphaGo Zero의 자기 대국과 학습

자기 대국 (Self-Play)

import torch

def self_play_game(model, game, mcts_simulations=800, temperature=1.0):
    """자기 대국으로 학습 데이터 생성"""
    mcts = MCTS(model, game, num_simulations=mcts_simulations)
    state = game.get_initial_state()

    training_data = []  # (state, policy, value) 저장
    move_count = 0

    while True:
        # MCTS 탐색
        root = mcts.search(state)

        # 정책 타겟: 방문 횟수를 정규화
        policy_target = np.zeros(game.action_size)
        for action, child in root.children.items():
            policy_target[action] = child.visit_count
        policy_target = policy_target / policy_target.sum()

        # 초반 탐색 촉진, 후반 결정적
        temp = temperature if move_count < 30 else 0.01
        action = root.best_action(temperature=temp)

        training_data.append((
            game.state_to_tensor(state),
            policy_target,
            None,  # 가치는 게임 종료 후 채움
        ))

        # 행동 수행
        state = game.apply_action(state, action)
        move_count += 1

        # 게임 종료 확인
        result = game.get_result(state)
        if result is not None:
            # 승패 결과를 각 시점의 가치로 채움
            final_data = []
            for i, (s, p, _) in enumerate(training_data):
                # 해당 시점의 현재 플레이어 관점에서의 결과
                if i % 2 == 0:
                    value = result
                else:
                    value = -result
                final_data.append((s, p, value))
            return final_data

학습 루프

class AlphaZeroTrainer:
    """AlphaZero 학습 매니저"""

    def __init__(self, model, game, config):
        self.model = model
        self.game = game
        self.config = config
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.get('lr', 1e-3),
            weight_decay=config.get('weight_decay', 1e-4),
        )
        self.replay_buffer = []
        self.max_buffer_size = config.get('buffer_size', 100000)

    def train(self, num_iterations=100):
        for iteration in range(num_iterations):
            # 1. 자기 대국으로 데이터 생성
            print(f"반복 {iteration}: 자기 대국 중...")
            for _ in range(self.config['games_per_iteration']):
                game_data = self_play_game(
                    self.model, self.game,
                    mcts_simulations=self.config['mcts_simulations'],
                )
                self.replay_buffer.extend(game_data)

            # 버퍼 크기 제한
            if len(self.replay_buffer) > self.max_buffer_size:
                self.replay_buffer = self.replay_buffer[-self.max_buffer_size:]

            # 2. 신경망 학습
            print(f"반복 {iteration}: 학습 중...")
            for _ in range(self.config['train_epochs']):
                self._train_step()

            # 3. 평가
            if iteration % self.config['eval_interval'] == 0:
                win_rate = self._evaluate()
                print(f"반복 {iteration}: 승률 = {win_rate:.1%}")

    def _train_step(self):
        batch_size = self.config.get('batch_size', 256)
        if len(self.replay_buffer) < batch_size:
            return

        # 랜덤 샘플링
        indices = np.random.choice(len(self.replay_buffer), batch_size)
        batch = [self.replay_buffer[i] for i in indices]

        states = torch.stack([b[0] for b in batch])
        target_policies = torch.FloatTensor(
            np.array([b[1] for b in batch])
        )
        target_values = torch.FloatTensor([b[2] for b in batch])

        # 모델 예측
        pred_policies, pred_values = self.model(states)
        pred_policies = torch.log_softmax(pred_policies, dim=-1)
        pred_values = pred_values.squeeze(-1)

        # 손실 계산
        policy_loss = -(target_policies * pred_policies).sum(dim=-1).mean()
        value_loss = (target_values - pred_values).pow(2).mean()
        loss = policy_loss + value_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _evaluate(self, num_games=20):
        """무작위 상대와 대전하여 평가"""
        wins = 0
        for _ in range(num_games):
            result = play_against_random(self.model, self.game)
            if result > 0:
                wins += 1
        return wins / num_games

Connect4 봇 구현

Connect4(사목 연결)는 바둑보다 간단하면서 AlphaZero의 원리를 실습하기 좋은 게임이다.

게임 모델

class Connect4:
    """Connect4 게임 규칙"""
    ROWS = 6
    COLS = 7
    action_size = 7  # 7개의 열에 돌 놓기

    @staticmethod
    def get_initial_state():
        """빈 보드 반환. 0=빈칸, 1=플레이어1, -1=플레이어2"""
        return np.zeros((Connect4.ROWS, Connect4.COLS), dtype=np.int8)

    @staticmethod
    def get_valid_actions(state):
        """놓을 수 있는 열 목록"""
        return [c for c in range(Connect4.COLS) if state[0, c] == 0]

    @staticmethod
    def apply_action(state, action):
        """열에 돌 놓기 (현재 플레이어 = 1)"""
        new_state = state.copy()
        # 해당 열에서 가장 아래 빈칸 찾기
        for row in range(Connect4.ROWS - 1, -1, -1):
            if new_state[row, action] == 0:
                new_state[row, action] = 1
                break
        # 다음 턴을 위해 보드 뒤집기 (상대 관점)
        return -new_state

    @staticmethod
    def get_result(state):
        """게임 종료 확인. 1=현재 플레이어 승, -1=패, 0=무승부, None=진행 중"""
        # 4연속 체크 (가로, 세로, 대각선)
        for player in [1, -1]:
            # 가로
            for r in range(Connect4.ROWS):
                for c in range(Connect4.COLS - 3):
                    if all(state[r, c+i] == player for i in range(4)):
                        return -1 if player == -1 else None

            # 세로
            for r in range(Connect4.ROWS - 3):
                for c in range(Connect4.COLS):
                    if all(state[r+i, c] == player for i in range(4)):
                        return -1 if player == -1 else None

            # 대각선 (우하향)
            for r in range(Connect4.ROWS - 3):
                for c in range(Connect4.COLS - 3):
                    if all(state[r+i, c+i] == player for i in range(4)):
                        return -1 if player == -1 else None

            # 대각선 (좌하향)
            for r in range(Connect4.ROWS - 3):
                for c in range(3, Connect4.COLS):
                    if all(state[r+i, c-i] == player for i in range(4)):
                        return -1 if player == -1 else None

        # 무승부 (보드 가득 참)
        if all(state[0, c] != 0 for c in range(Connect4.COLS)):
            return 0

        return None  # 진행 중

    @staticmethod
    def state_to_tensor(state):
        """보드를 신경망 입력 텐서로 변환"""
        # 채널 0: 현재 플레이어 돌, 채널 1: 상대 플레이어 돌
        player_plane = (state == 1).astype(np.float32)
        opponent_plane = (state == -1).astype(np.float32)
        return torch.FloatTensor(
            np.stack([player_plane, opponent_plane])
        )

Connect4 신경망

class Connect4Net(nn.Module):
    """Connect4용 정책-가치 네트워크"""

    def __init__(self, num_res_blocks=5):
        super().__init__()

        # 입력 변환
        self.conv_input = nn.Sequential(
            nn.Conv2d(2, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        # Residual 블록
        self.res_blocks = nn.ModuleList([
            ResBlock(128) for _ in range(num_res_blocks)
        ])

        # 정책 헤드
        self.policy_head = nn.Sequential(
            nn.Conv2d(128, 32, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 6 * 7, Connect4.action_size),
        )

        # 가치 헤드
        self.value_head = nn.Sequential(
            nn.Conv2d(128, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(6 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Tanh(),  # -1 ~ 1 범위
        )

    def forward(self, x):
        out = self.conv_input(x)
        for block in self.res_blocks:
            out = block(out)
        policy = self.policy_head(out)
        value = self.value_head(out)
        return policy, value

class ResBlock(nn.Module):
    """Residual 블록"""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = torch.relu(out + residual)
        return out

학습 실행

def train_connect4():
    """Connect4 AlphaZero 학습"""
    game = Connect4()
    model = Connect4Net(num_res_blocks=5)

    config = {
        'lr': 1e-3,
        'weight_decay': 1e-4,
        'buffer_size': 50000,
        'games_per_iteration': 100,
        'mcts_simulations': 400,
        'train_epochs': 10,
        'batch_size': 128,
        'eval_interval': 5,
    }

    trainer = AlphaZeroTrainer(model, game, config)
    trainer.train(num_iterations=50)

    return model

def play_against_human(model, game):
    """학습된 모델과 인간 대전"""
    mcts = MCTS(model, game, num_simulations=400)
    state = game.get_initial_state()
    human_turn = True

    while True:
        display_board(state)
        result = game.get_result(state)
        if result is not None:
            if result == 0:
                print("무승부!")
            else:
                print("게임 종료!")
            break

        if human_turn:
            valid = game.get_valid_actions(state)
            col = int(input(f"열 선택 {valid}: "))
            state = game.apply_action(state, col)
        else:
            root = mcts.search(state)
            action = root.best_action(temperature=0.01)
            print(f"AI 선택: 열 {action}")
            state = game.apply_action(state, action)

        human_turn = not human_turn

def display_board(state):
    """보드 표시"""
    symbols = {0: '.', 1: 'O', -1: 'X'}
    print("\n 0 1 2 3 4 5 6")
    for row in range(Connect4.ROWS):
        line = ' '.join(symbols[state[row, c]] for c in range(Connect4.COLS))
        print(f" {line}")
    print()

학습 결과

Connect4에서 50회 반복 학습 후:

  • 무작위 상대: 95% 이상 승률
  • 미니맥스(깊이 4): 80% 이상 승률
  • 학습 시간: GPU 기준 약 2-4시간

AlphaZero의 핵심 통찰은 MCTS가 정책을 개선하고, 개선된 정책이 더 나은 학습 데이터를 생성하는 선순환이 일어난다는 것이다.


핵심 요약

  • AlphaGo Zero는 인간 데이터 없이 자기 대국만으로 초인적 수준에 도달했다
  • MCTS는 신경망의 정책/가치 예측을 활용하여 효율적으로 게임 트리를 탐색한다
  • 자기 대국에서 MCTS의 방문 횟수가 정책 학습의 타겟이 되고, 게임 결과가 가치 학습의 타겟이 된다
  • Connect4 같은 간단한 게임에서도 AlphaZero 프레임워크의 효과를 확인할 수 있다

다음 글에서는 심층 강화학습의 실전 응용 사례를 폭넓게 살펴보겠다.