[Deep RL] 12. Training Chatbots with Reinforcement Learning

Overview

Reinforcement learning is applied not only to games and robot control but also to natural language processing (NLP). In dialogue systems (chatbots) in particular, defining a reward signal for "good conversation" is difficult, so reinforcement learning can overcome the limitations of conventional supervised learning.

This post covers the basics of Seq2Seq models and how to apply reinforcement learning to improve chatbot response quality.


Deep NLP Basics

Recurrent Neural Networks (RNN)

Natural language is sequential data. Recurrent Neural Networks (RNNs) incorporate information from previous time steps into the current computation.

import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len, input_size)
        output, hidden = self.rnn(x, hidden)
        # output: (batch, seq_len, hidden_size)
        return self.fc(output), hidden

To address the long-term dependency problem of RNNs, LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are used.
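Swapping in one of these gated cells is a one-line change in PyTorch. A minimal shape check comparing the three cells (the sizes here are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

# Same input through all three recurrent cells; only LSTM carries a cell state.
x = torch.randn(4, 10, 32)                 # (batch, seq_len, input_size)

rnn  = nn.RNN(32, 64, batch_first=True)
lstm = nn.LSTM(32, 64, batch_first=True)
gru  = nn.GRU(32, 64, batch_first=True)

out_rnn, h_rnn = rnn(x)                    # h_rnn: (1, batch, hidden)
out_lstm, (h, c) = lstm(x)                 # LSTM additionally returns cell state c
out_gru, h_gru = gru(x)

print(out_rnn.shape, out_lstm.shape, out_gru.shape)
# each: torch.Size([4, 10, 64])
```

All three produce outputs of the same shape, so they can be substituted into the models below without changing the surrounding code.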

Word Embeddings

Embeddings that represent words as dense vectors are a core element of NLP:

class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) integer tensor
        # output: (batch, seq_len, embed_dim) float tensor
        return self.embedding(token_ids)

You can use pretrained embeddings like Word2Vec or GloVe, or train them from scratch together with the model.
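For the pretrained route, `nn.Embedding.from_pretrained` seeds the layer from an existing weight matrix. A sketch where a random 5x8 matrix stands in for real Word2Vec/GloVe vectors loaded from disk:

```python
import torch
import torch.nn as nn

# Placeholder for pretrained vectors: (vocab_size, embed_dim)
pretrained = torch.randn(5, 8)

# freeze=False allows fine-tuning together with the model;
# freeze=True (the default) keeps the vectors fixed.
emb = nn.Embedding.from_pretrained(pretrained, freeze=False)

ids = torch.tensor([[0, 3, 4]])            # (batch, seq_len)
vectors = emb(ids)                         # (1, 3, 8)
```

Each row of the matrix becomes the vector for the token id equal to its row index, so the vocabulary mapping must match the order in which the pretrained vectors were saved.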

Encoder-Decoder

A Seq2Seq model has an encoder that compresses the input sequence into a fixed-length vector, and a decoder that generates an output sequence based on this vector.

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        outputs, (hidden, cell) = self.lstm(embedded)
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_id, hidden, cell):
        # input_id: (batch, 1) - generates one token at a time
        embedded = self.embedding(input_id)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden, cell

Seq2Seq Training: Supervised Learning Approach

Log-Likelihood Training

The most basic Seq2Seq training method is log-likelihood maximization using Teacher Forcing:

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, vocab_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = vocab_size

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        outputs = torch.zeros(batch_size, trg_len, self.vocab_size,
                              device=src.device)  # keep on the same device as inputs

        hidden, cell = self.encoder(src)

        # First input is the SOS token
        input_token = trg[:, 0:1]

        for t in range(1, trg_len):
            prediction, hidden, cell = self.decoder(input_token, hidden, cell)
            outputs[:, t] = prediction

            # Teacher Forcing: use ground truth as next input
            if torch.rand(1).item() < teacher_forcing_ratio:
                input_token = trg[:, t:t+1]
            else:
                input_token = prediction.argmax(dim=-1, keepdim=True)

        return outputs

def train_supervised(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0

    for src, trg in dataloader:
        optimizer.zero_grad()
        output = model(src, trg)

        # Cross-entropy loss
        output = output[:, 1:].reshape(-1, output.shape[-1])
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)

BLEU Score

BLEU (Bilingual Evaluation Understudy) score is used as an evaluation metric for supervised learning. It measures the quality of generated text based on n-gram overlap.

from collections import Counter
import math

def compute_bleu(reference, hypothesis, max_n=4):
    """Simple BLEU score computation"""
    scores = []
    for n in range(1, max_n + 1):
        ref_ngrams = Counter(zip(*[reference[i:] for i in range(n)]))
        hyp_ngrams = Counter(zip(*[hypothesis[i:] for i in range(n)]))

        # Clipped counts
        clipped = sum(min(hyp_ngrams[ng], ref_ngrams[ng])
                      for ng in hyp_ngrams)
        total = max(sum(hyp_ngrams.values()), 1)
        scores.append(clipped / total)

    # Geometric mean
    if min(scores) > 0:
        log_avg = sum(math.log(s) for s in scores) / len(scores)
        bleu = math.exp(log_avg)
    else:
        bleu = 0.0

    # Brevity Penalty
    bp = min(1.0, math.exp(1 - len(reference) / max(len(hypothesis), 1)))
    return bp * bleu

Limitations of Supervised Learning

  1. Exposure Bias: During training, Teacher Forcing provides ground truth, but during inference, the model uses its own outputs as input. This train-test mismatch causes error accumulation.
  2. Metric Mismatch: Cross-entropy is minimized during training, but actual evaluation uses BLEU or human satisfaction.
  3. Lack of Diversity: The model converges to average responses in the dataset, making it difficult to generate diverse and interesting conversations.

Applying Reinforcement Learning to Seq2Seq

Defining the Problem as an MDP

Seq2Seq text generation can be reformulated as a reinforcement learning problem:

  • State: Encoder output + tokens generated so far
  • Action: Selecting the next token from the vocabulary
  • Reward: BLEU score or other metrics after sequence completion
  • Policy: The decoder's output probability distribution

Under this view, generating a response is one episode: the decoder acts as a stochastic policy, and we record each token's log probability for the policy gradient update:
def generate_with_policy(model, src, max_len=50, sos_token=1, eos_token=2):
    """Generate a sequence using the policy (decoder) while recording log probabilities"""
    model.eval()
    hidden, cell = model.encoder(src)

    input_token = torch.tensor([[sos_token]])
    generated_tokens = []
    log_probs = []

    for _ in range(max_len):
        prediction, hidden, cell = model.decoder(input_token, hidden, cell)
        probs = torch.softmax(prediction, dim=-1)
        dist = torch.distributions.Categorical(probs)

        token = dist.sample()
        log_prob = dist.log_prob(token)

        generated_tokens.append(token.item())
        log_probs.append(log_prob)

        if token.item() == eos_token:
            break

        input_token = token.unsqueeze(1)  # (batch, 1) for the next decoder step

    return generated_tokens, log_probs

Self-Critical Sequence Training (SCST)

SCST is a variant of the REINFORCE algorithm that uses the reward from greedy decoding as the baseline. This reduces variance and stabilizes learning.
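In symbols, with $w^s$ a sequence sampled from the policy and $\hat{w}$ the greedy-decoded sequence, the per-sample SCST gradient estimate is:

```latex
\nabla_\theta L(\theta) \approx -\bigl(r(w^s) - r(\hat{w})\bigr)\,\nabla_\theta \log p_\theta(w^s)
```

When sampling beats greedy decoding, the advantage is positive and the sampled tokens' probabilities are pushed up; when it does worse, they are pushed down. No learned value function is needed, since the model's own greedy output serves as the baseline.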

Core Idea of SCST

  1. Sampling path: Stochastically sample tokens from the policy to generate a sequence and compute the reward (BLEU)
  2. Greedy path: Select the highest probability token at each time step to generate a sequence and compute the reward
  3. Advantage: Sampling reward minus greedy reward

Putting the three steps together:
def self_critical_loss(model, src, reference, reward_fn):
    """Self-Critical Sequence Training loss function"""
    # 1. Generate sequence by sampling
    sampled_tokens, log_probs = generate_with_policy(model, src)
    sampled_reward = reward_fn(reference, sampled_tokens)

    # 2. Generate sequence greedily (baseline)
    with torch.no_grad():
        greedy_tokens = greedy_decode(model, src)
        baseline_reward = reward_fn(reference, greedy_tokens)

    # 3. REINFORCE with baseline
    advantage = sampled_reward - baseline_reward

    # Policy gradient loss
    policy_loss = 0.0
    for log_prob in log_probs:
        policy_loss -= log_prob * advantage

    return policy_loss / len(log_probs)

def greedy_decode(model, src, max_len=50, sos_token=1, eos_token=2):
    """Greedy decoding: always select the highest probability token"""
    model.eval()
    hidden, cell = model.encoder(src)

    input_token = torch.tensor([[sos_token]])
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_len):
            prediction, hidden, cell = model.decoder(input_token, hidden, cell)
            token = prediction.argmax(dim=-1)
            generated_tokens.append(token.item())

            if token.item() == eos_token:
                break
            input_token = token.unsqueeze(1)  # (batch, 1) for the next decoder step

    return generated_tokens
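To see the update direction concretely, here is a toy calculation with made-up reward values: a positive advantage means gradient descent on the SCST loss raises the sampled tokens' log probabilities.

```python
import torch

# Two fake per-token log-probs from a sampled sequence (values are made up).
log_probs = [torch.tensor(-1.2, requires_grad=True),
             torch.tensor(-0.7, requires_grad=True)]

sampled_reward, baseline_reward = 0.30, 0.25
advantage = sampled_reward - baseline_reward      # > 0: sampling beat greedy

# Same loss as self_critical_loss above, on this toy data.
loss = -sum(lp * advantage for lp in log_probs) / len(log_probs)
loss.backward()

# Gradient w.r.t. each log-prob is -advantage / 2, so a descent step
# increases the log-probabilities of the sampled tokens.
print(log_probs[0].grad)
```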

Chatbot Implementation

Data Preparation

We use a dialogue dataset such as the Cornell Movie Dialog Corpus:

import json

class DialogDataset:
    def __init__(self, data_path, vocab, max_len=50):
        self.pairs = []
        self.vocab = vocab
        self.max_len = max_len
        self._load_data(data_path)

    def _load_data(self, path):
        with open(path, 'r') as f:
            for line in f:
                pair = json.loads(line)
                src = self.vocab.encode(pair['input'])[:self.max_len]
                trg = self.vocab.encode(pair['response'])[:self.max_len]
                self.pairs.append((src, trg))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, trg = self.pairs[idx]
        return (torch.tensor(src, dtype=torch.long),
                torch.tensor(trg, dtype=torch.long))
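Dialogue pairs have varying lengths, so batching this dataset with a `DataLoader` needs a collate function that pads to the longest sequence in the batch. A hypothetical `collate` (not in the original post), assuming PAD id 0 to match the `ignore_index=0` used in the training loss below:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate(batch):
    """Pad (src, trg) pairs to the batch maximum length with PAD id 0."""
    srcs, trgs = zip(*batch)
    srcs = pad_sequence(srcs, batch_first=True, padding_value=0)
    trgs = pad_sequence(trgs, batch_first=True, padding_value=0)
    return srcs, trgs

# Two pairs of different lengths:
batch = [(torch.tensor([5, 6, 7]), torch.tensor([8, 9])),
         (torch.tensor([5]),       torch.tensor([8, 9, 10, 11]))]
src, trg = collate(batch)
print(src.shape, trg.shape)   # torch.Size([2, 3]) torch.Size([2, 4])
```

This would be passed as `collate_fn=collate` when constructing the `DataLoader`.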

Two-Stage Training Pipeline

def train_chatbot(model, train_data, val_data, config):
    """Two-stage chatbot training"""
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore PAD

    # === Stage 1: Supervised Learning (Teacher Forcing) ===
    print("Stage 1: Starting supervised learning")
    for epoch in range(config['supervised_epochs']):
        loss = train_supervised(model, train_data, optimizer, criterion)
        bleu = evaluate_bleu(model, val_data)
        print(f"Epoch {epoch}: Loss={loss:.4f}, BLEU={bleu:.4f}")

    # === Stage 2: SCST Reinforcement Learning ===
    print("Stage 2: Starting SCST reinforcement learning")
    rl_optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['rl_lr'])  # Smaller learning rate

    for epoch in range(config['rl_epochs']):
        model.train()
        total_rl_loss = 0.0

        for src, trg in train_data:
            rl_optimizer.zero_grad()

            # Use BLEU as the reward function
            loss = self_critical_loss(
                model, src, trg,
                reward_fn=compute_bleu
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            rl_optimizer.step()
            total_rl_loss += loss.item()

        avg_loss = total_rl_loss / len(train_data)
        bleu = evaluate_bleu(model, val_data)
        print(f"RL Epoch {epoch}: Loss={avg_loss:.4f}, BLEU={bleu:.4f}")

    return model
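The `config` dict above is assumed to carry at least the four keys used in the function. The values below are illustrative only; the point is that the RL learning rate sits well below the supervised one, since large policy gradient steps can quickly degrade the pretrained model:

```python
# Illustrative values, not from the original post.
config = {
    'lr': 1e-3,                # supervised stage learning rate
    'rl_lr': 1e-5,             # much smaller rate for the SCST stage
    'supervised_epochs': 10,
    'rl_epochs': 5,
}
# model = train_chatbot(model, train_loader, val_loader, config)
```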

Dialogue Testing

def chat(model, vocab, max_len=50):
    """Dialogue interface"""
    model.eval()
    print("Starting conversation with the chatbot. Type 'quit' to exit.")

    while True:
        user_input = input("User: ")
        if user_input.lower() == 'quit':
            break

        # Encode input
        tokens = vocab.encode(user_input)
        src = torch.tensor([tokens], dtype=torch.long)

        # Generate response via greedy decoding
        response_tokens = greedy_decode(model, src, max_len)
        response = vocab.decode(response_tokens)
        print(f"Chatbot: {response}")

Practical Considerations

Importance of Reward Design

BLEU score alone does not guarantee good dialogue quality. In real chatbot systems, multiple rewards are combined:

  • Fluency: Language model perplexity
  • Relevance: Semantic similarity between input and response
  • Diversity: Repetition pattern penalty
  • Safety: Harmful content filtering score

One way to combine these signals into a single scalar reward:
def combined_reward(reference, hypothesis, input_text=None):
    bleu = compute_bleu(reference, hypothesis)
    diversity = compute_distinct_ngrams(hypothesis)
    repetition_penalty = compute_repetition_penalty(hypothesis)

    return 0.5 * bleu + 0.3 * diversity - 0.2 * repetition_penalty
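`compute_distinct_ngrams` and `compute_repetition_penalty` are referenced above but never defined in the post. Plausible sketches: distinct-n is the ratio of unique n-grams to total n-grams, and the repetition penalty is simply its complement.

```python
def compute_distinct_ngrams(tokens, n=2):
    """Fraction of n-grams in the sequence that are unique (distinct-n)."""
    if len(tokens) < n:
        return 0.0
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    return len(set(ngrams)) / len(ngrams)

def compute_repetition_penalty(tokens, n=2):
    """Fraction of repeated n-grams: 0.0 when every n-gram is unique."""
    return 1.0 - compute_distinct_ngrams(tokens, n)

# Bigrams of [1,2,1,2,1] are (1,2),(2,1),(1,2),(2,1): 2 unique out of 4.
print(compute_distinct_ngrams([1, 2, 1, 2, 1]))   # 0.5
```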

Recent Trends

Modern dialogue systems use Transformer-based large language models (LLMs) and RLHF (Reinforcement Learning from Human Feedback). The SCST concepts covered in this post are directly connected to PPO-based fine-tuning in RLHF.


Key Takeaways

  • Seq2Seq models solve sequence transformation problems with an encoder-decoder architecture
  • Reinforcement learning mitigates the exposure bias and metric mismatch issues of supervised learning
  • SCST is a REINFORCE variant that reduces variance by using greedy decoding as the baseline
  • Reward function design is key to chatbot quality

In the next post, we will explore how to apply reinforcement learning to web navigation.