Skip to content
Published on

[深層強化学習] 12. 強化学習でチャットボットを訓練する

Authors

概要

強化学習はゲームやロボット制御だけでなく、**自然言語処理(NLP)**にも適用されます。特に対話システム(チャットボット)では、「良い会話」という報酬信号を定義するのが難しいため、従来の教師あり学習の限界を強化学習で克服できます。

この記事では、Seq2Seqモデルの基礎から始めて、強化学習を適用してチャットボットの応答品質を向上させる方法を扱います。


深層NLPの基礎

リカレントニューラルネットワーク(RNN)

自然言語は順序のあるデータです。リカレントニューラルネットワーク(Recurrent Neural Network)は、以前の時間ステップの情報を現在の計算に反映します。

import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len, input_size)
        output, hidden = self.rnn(x, hidden)
        # output: (batch, seq_len, hidden_size)
        return self.fc(output), hidden

RNNの長期依存性問題を解決するために、**LSTM(Long Short-Term Memory)GRU(Gated Recurrent Unit)**が使用されます。

単語埋め込み(Word Embeddings)

単語を密ベクトルで表現する埋め込みはNLPの核心要素です:

class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) 整数テンソル
        # 出力: (batch, seq_len, embed_dim) 実数テンソル
        return self.embedding(token_ids)

Word2Vec、GloVeなどの事前学習済み埋め込みを使用するか、モデルと一緒にゼロから学習することができます。

エンコーダ・デコーダ(Encoder-Decoder)

Seq2Seqモデルはエンコーダが入力シーケンスを固定長ベクトルに圧縮し、デコーダがこのベクトルを基に出力シーケンスを生成します。

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        outputs, (hidden, cell) = self.lstm(embedded)
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_id, hidden, cell):
        # input_id: (batch, 1) - 1トークンずつ生成
        embedded = self.embedding(input_id)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden, cell

Seq2Seq学習:教師あり学習方式

対数尤度学習(Log-Likelihood Training)

最も基本的なSeq2Seq学習方法は**Teacher Forcing(教師強制)**を使用した対数尤度最大化です:

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, vocab_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = vocab_size

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        outputs = torch.zeros(batch_size, trg_len, self.vocab_size)

        hidden, cell = self.encoder(src)

        # 最初の入力はSOSトークン
        input_token = trg[:, 0:1]

        for t in range(1, trg_len):
            prediction, hidden, cell = self.decoder(input_token, hidden, cell)
            outputs[:, t] = prediction

            # Teacher Forcing: 正解を次の入力として使用
            if torch.rand(1).item() < teacher_forcing_ratio:
                input_token = trg[:, t:t+1]
            else:
                input_token = prediction.argmax(dim=-1, keepdim=True)

        return outputs

def train_supervised(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0

    for src, trg in dataloader:
        optimizer.zero_grad()
        output = model(src, trg)

        # クロスエントロピー損失
        output = output[:, 1:].reshape(-1, output.shape[-1])
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)

BLEUスコア

教師あり学習の評価指標として**BLEU(Bilingual Evaluation Understudy)**スコアが使用されます。n-gramの重複に基づいて生成テキストの品質を測定します。

from collections import Counter
import math

def compute_bleu(reference, hypothesis, max_n=4):
    """簡単なBLEUスコア計算"""
    scores = []
    for n in range(1, max_n + 1):
        ref_ngrams = Counter(zip(*[reference[i:] for i in range(n)]))
        hyp_ngrams = Counter(zip(*[hypothesis[i:] for i in range(n)]))

        # クリッピングされたカウント
        clipped = sum(min(hyp_ngrams[ng], ref_ngrams[ng])
                      for ng in hyp_ngrams)
        total = max(sum(hyp_ngrams.values()), 1)
        scores.append(clipped / total)

    # 幾何平均
    if min(scores) > 0:
        log_avg = sum(math.log(s) for s in scores) / len(scores)
        bleu = math.exp(log_avg)
    else:
        bleu = 0.0

    # Brevity Penalty
    bp = min(1.0, math.exp(1 - len(reference) / max(len(hypothesis), 1)))
    return bp * bleu

教師あり学習の限界

  1. 露出バイアス(Exposure Bias):学習時はTeacher Forcingで正解を見ますが、推論時は自分の出力を入力として使用します。学習と推論の不一致がエラーを蓄積させます。
  2. メトリクスの不一致:クロスエントロピーを最小化しますが、実際の評価はBLEUや人間の満足度で行います。
  3. 多様性の不足:データセットの平均的な応答に収束し、多様で興味深い会話を生成するのが困難です。

強化学習をSeq2Seqに適用

問題をMDPとして定義

Seq2Seqテキスト生成を強化学習問題として再構成できます:

  • 状態:エンコーダ出力 + これまでに生成されたトークン
  • 行動:語彙辞書から次のトークンを選択
  • 報酬:シーケンス完成後のBLEUスコアまたはその他のメトリクス
  • 方策:デコーダの出力確率分布
def generate_with_policy(model, src, max_len=50, sos_token=1, eos_token=2):
    """方策(デコーダ)でシーケンスを生成しながらログ確率を記録"""
    model.eval()
    hidden, cell = model.encoder(src)

    input_token = torch.tensor([[sos_token]])
    generated_tokens = []
    log_probs = []

    for _ in range(max_len):
        prediction, hidden, cell = model.decoder(input_token, hidden, cell)
        probs = torch.softmax(prediction, dim=-1)
        dist = torch.distributions.Categorical(probs)

        token = dist.sample()
        log_prob = dist.log_prob(token)

        generated_tokens.append(token.item())
        log_probs.append(log_prob)

        if token.item() == eos_token:
            break

        input_token = token.unsqueeze(0)

    return generated_tokens, log_probs

Self-Critical Sequence Training(SCST)

SCSTはREINFORCEアルゴリズムの変形で、**ベースライン(baseline)**としてグリーディデコーディングの報酬を使用します。これにより分散を減らし学習を安定化させます。

SCSTの核心アイデア

  1. サンプリングパス:方策から確率的にトークンをサンプリングしてシーケンスを生成し報酬(BLEU)を計算
  2. グリーディパス:各時間ステップで最も確率の高いトークンを選択してシーケンスを生成し報酬を計算
  3. アドバンテージ:サンプリング報酬 - グリーディ報酬
def self_critical_loss(model, src, reference, reward_fn):
    """Self-Critical Sequence Training損失関数"""
    # 1. サンプリングでシーケンス生成
    sampled_tokens, log_probs = generate_with_policy(model, src)
    sampled_reward = reward_fn(reference, sampled_tokens)

    # 2. グリーディでシーケンス生成(ベースライン)
    with torch.no_grad():
        greedy_tokens = greedy_decode(model, src)
        baseline_reward = reward_fn(reference, greedy_tokens)

    # 3. REINFORCE with baseline
    advantage = sampled_reward - baseline_reward

    # 方策勾配損失
    policy_loss = 0.0
    for log_prob in log_probs:
        policy_loss -= log_prob * advantage

    return policy_loss / len(log_probs)

def greedy_decode(model, src, max_len=50, sos_token=1, eos_token=2):
    """グリーディデコーディング:常に最も高い確率のトークンを選択"""
    model.eval()
    hidden, cell = model.encoder(src)

    input_token = torch.tensor([[sos_token]])
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_len):
            prediction, hidden, cell = model.decoder(input_token, hidden, cell)
            token = prediction.argmax(dim=-1)
            generated_tokens.append(token.item())

            if token.item() == eos_token:
                break
            input_token = token.unsqueeze(0)

    return generated_tokens

チャットボットの実装

データ準備

Cornell Movie Dialog Corpusなどの対話データセットを使用します:

import json

class DialogDataset:
    def __init__(self, data_path, vocab, max_len=50):
        self.pairs = []
        self.vocab = vocab
        self.max_len = max_len
        self._load_data(data_path)

    def _load_data(self, path):
        with open(path, 'r') as f:
            for line in f:
                pair = json.loads(line)
                src = self.vocab.encode(pair['input'])[:self.max_len]
                trg = self.vocab.encode(pair['response'])[:self.max_len]
                self.pairs.append((src, trg))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, trg = self.pairs[idx]
        return (torch.tensor(src, dtype=torch.long),
                torch.tensor(trg, dtype=torch.long))

2段階学習パイプライン

def train_chatbot(model, train_data, val_data, config):
    """2段階チャットボット学習"""
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # PADを無視

    # === 第1段階:教師あり学習(Teacher Forcing)===
    print("第1段階:教師あり学習開始")
    for epoch in range(config['supervised_epochs']):
        loss = train_supervised(model, train_data, optimizer, criterion)
        bleu = evaluate_bleu(model, val_data)
        print(f"Epoch {epoch}: Loss={loss:.4f}, BLEU={bleu:.4f}")

    # === 第2段階:SCST強化学習 ===
    print("第2段階:SCST強化学習開始")
    rl_optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['rl_lr'])  # より小さい学習率

    for epoch in range(config['rl_epochs']):
        model.train()
        total_rl_loss = 0.0

        for src, trg in train_data:
            rl_optimizer.zero_grad()

            # BLEUを報酬関数として使用
            loss = self_critical_loss(
                model, src, trg,
                reward_fn=compute_bleu
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            rl_optimizer.step()
            total_rl_loss += loss.item()

        avg_loss = total_rl_loss / len(train_data)
        bleu = evaluate_bleu(model, val_data)
        print(f"RL Epoch {epoch}: Loss={avg_loss:.4f}, BLEU={bleu:.4f}")

    return model

対話テスト

def chat(model, vocab, max_len=50):
    """対話インターフェース"""
    model.eval()
    print("チャットボットとの会話を始めます。'quit'と入力すると終了します。")

    while True:
        user_input = input("ユーザー: ")
        if user_input.lower() == 'quit':
            break

        # 入力をエンコード
        tokens = vocab.encode(user_input)
        src = torch.tensor([tokens], dtype=torch.long)

        # グリーディデコーディングで応答生成
        response_tokens = greedy_decode(model, src, max_len)
        response = vocab.decode(response_tokens)
        print(f"チャットボット: {response}")

実践的な考慮事項

報酬設計の重要性

BLEUスコアだけでは良い対話品質を保証できません。実際のチャットボットシステムでは複数の報酬を組み合わせます:

  • 流暢さ:言語モデルのperplexity
  • 関連性:入力と応答の意味的類似度
  • 多様性:繰り返しパターンへのペナルティ
  • 安全性:有害コンテンツフィルタリングスコア
def combined_reward(reference, hypothesis, input_text=None):
    bleu = compute_bleu(reference, hypothesis)
    diversity = compute_distinct_ngrams(hypothesis)
    repetition_penalty = compute_repetition_penalty(hypothesis)

    return 0.5 * bleu + 0.3 * diversity - 0.2 * repetition_penalty

最新動向

現代の対話システムはTransformerベースの大規模言語モデル(LLM)とRLHF(Reinforcement Learning from Human Feedback)を使用しています。この記事で扱ったSCSTの概念は、RLHFのPPOベースのファインチューニングと直接的につながっています。


要点まとめ

  • Seq2Seqモデルはエンコーダ・デコーダ構造でシーケンス変換問題を解決する
  • 教師あり学習の露出バイアスとメトリクス不一致の問題を強化学習が緩和する
  • SCSTはグリーディデコーディングをベースラインとして使用し分散を減らしたREINFORCEの変形である
  • 報酬関数の設計がチャットボット品質の鍵である

次の記事では、強化学習をウェブナビゲーションに適用する方法を見ていきます。