[Deep RL] 12. Training Chatbots with Reinforcement Learning

Overview

Reinforcement learning is applied not only to games and robot control but also to natural language processing (NLP). Dialogue systems (chatbots) are a natural fit: what makes a "good conversation" is hard to express as a per-token supervised loss, whereas reinforcement learning can optimize a sequence-level reward directly, even a non-differentiable one such as BLEU. This lets it address some limitations of conventional supervised learning.

This post covers the basics of Seq2Seq models and how to apply reinforcement learning to improve chatbot response quality.


Deep NLP Basics

Recurrent Neural Networks (RNN)

Natural language is sequential data. Recurrent Neural Networks (RNNs) incorporate information from previous time steps into the current computation.

import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len, input_size)
        output, hidden = self.rnn(x, hidden)
        # output: (batch, seq_len, hidden_size)
        return self.fc(output), hidden

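Plain RNNs struggle to learn long-term dependencies because gradients tend to vanish or explode over many time steps. Gated variants such as LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are commonly used to mitigate this.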
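
As a rough illustration of how the gated variants slot in, the sketch below runs the same batch through nn.LSTM and nn.GRU. The sizes are arbitrary values chosen for this example. At this level the visible difference is that the LSTM carries an extra cell state alongside the hidden state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_size, hidden_size = 2, 5, 8, 16  # illustrative sizes
x = torch.randn(batch, seq_len, input_size)

lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
gru = nn.GRU(input_size, hidden_size, batch_first=True)

# LSTM returns the per-step outputs plus a (hidden, cell) pair
lstm_out, (h_n, c_n) = lstm(x)
# GRU returns the per-step outputs plus a single hidden state
gru_out, g_n = gru(x)

print(lstm_out.shape, h_n.shape, c_n.shape)
print(gru_out.shape, g_n.shape)
```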

Word Embeddings

Embeddings that represent words as dense vectors are a core element of NLP:

class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) integer tensor
        # output: (batch, seq_len, embed_dim) float tensor
        return self.embedding(token_ids)

You can use pretrained embeddings like Word2Vec or GloVe, or train them from scratch together with the model.
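
A minimal sketch of the pretrained option, assuming the vectors have already been loaded into a tensor. The small matrix here is a made-up stand-in for real Word2Vec or GloVe vectors, and using index 0 for padding is an assumption of this example.

```python
import torch
import torch.nn as nn

# Stand-in for vectors loaded from Word2Vec or GloVe: 4 words, 3 dimensions
pretrained = torch.tensor([
    [0.0, 0.0, 0.0],   # index 0: padding (assumed convention)
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
])

# freeze=True keeps the vectors fixed during training
frozen = nn.Embedding.from_pretrained(pretrained, freeze=True, padding_idx=0)
# freeze=False starts from the same vectors but fine-tunes them with the model
trainable = nn.Embedding.from_pretrained(pretrained.clone(), freeze=False, padding_idx=0)

token_ids = torch.tensor([[1, 3, 0]])  # (batch=1, seq_len=3)
vectors = frozen(token_ids)            # (1, 3, 3)
```

Freezing is a common choice when the dialogue dataset is small, while fine-tuning lets the embeddings adapt to the conversational domain.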

Encoder-Decoder

A Seq2Seq model has an encoder that compresses the input sequence into a fixed-length vector, and a decoder that generates an output sequence based on this vector.

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        outputs, (hidden, cell) = self.lstm(embedded)
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_id, hidden, cell):
        # input_id: (batch, 1) - generates one token at a time
        embedded = self.embedding(input_id)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden, cell

Seq2Seq Training: Supervised Learning Approach

Log-Likelihood Training

The most basic Seq2Seq training method is log-likelihood maximization using Teacher Forcing:

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, vocab_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = vocab_size

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
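        # Allocate on the same device as the inputs; outputs[:, 0] stays zero
        # and is skipped when the loss is computed
        outputs = torch.zeros(batch_size, trg_len, self.vocab_size, device=src.device)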

        hidden, cell = self.encoder(src)

        # First input is the SOS token
        input_token = trg[:, 0:1]

        for t in range(1, trg_len):
            prediction, hidden, cell = self.decoder(input_token, hidden, cell)
            outputs[:, t] = prediction

            # Teacher Forcing: use ground truth as next input
            if torch.rand(1).item() < teacher_forcing_ratio:
                input_token = trg[:, t:t+1]
            else:
                input_token = prediction.argmax(dim=-1, keepdim=True)

        return outputs

def train_supervised(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0

    for src, trg in dataloader:
        optimizer.zero_grad()
        output = model(src, trg)

        # Cross-entropy loss
        output = output[:, 1:].reshape(-1, output.shape[-1])
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)
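
To make the link between cross-entropy and log-likelihood concrete, the sketch below checks that nn.CrossEntropyLoss with ignore_index=0 equals the average negative log-probability of the target tokens over non-padding positions. The logits and targets are made-up values, and padding at index 0 is assumed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 6
logits = torch.randn(4, vocab_size)    # 4 flattened decoder steps
targets = torch.tensor([3, 5, 2, 0])   # last position is padding (index 0)

criterion = nn.CrossEntropyLoss(ignore_index=0)
loss = criterion(logits, targets)

# Manual version: mean of -log p(target) over non-padding positions only
log_probs = F.log_softmax(logits, dim=-1)
picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
mask = targets != 0
manual = -picked[mask].mean()

print(loss.item(), manual.item())
```

Minimizing this loss is therefore the same as maximizing the log-likelihood of the reference tokens, with padded positions contributing nothing.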

BLEU Score

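BLEU (Bilingual Evaluation Understudy) is a common automatic evaluation metric for generated text. It scores a hypothesis by its clipped n-gram precision against a reference, combined with a brevity penalty for outputs that are too short. The simplified version below handles a single reference and applies no smoothing, so it returns 0 whenever any n-gram order has no match.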

from collections import Counter
import math

def compute_bleu(reference, hypothesis, max_n=4):
    """Simple BLEU score computation"""
    scores = []
    for n in range(1, max_n + 1):
        ref_ngrams = Counter(zip(*[reference[i:] for i in range(n)]))
        hyp_ngrams = Counter(zip(*[hypothesis[i:] for i in range(n)]))

        # Clipped counts
        clipped = sum(min(hyp_ngrams[ng], ref_ngrams[ng])
                      for ng in hyp_ngrams)
        total = max(sum(hyp_ngrams.values()), 1)
        scores.append(clipped / total)

    # Geometric mean
    if min(scores) > 0:
        log_avg = sum(math.log(s) for s in scores) / len(scores)
        bleu = math.exp(log_avg)
    else:
        bleu = 0.0

    # Brevity Penalty
    bp = min(1.0, math.exp(1 - len(reference) / max(len(hypothesis), 1)))
    return bp * bleu
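
A small worked example may help with the clipping and the brevity penalty. The helper below is a standalone re-statement of the per-order precision used in compute_bleu, and the two sentences are invented for illustration.

```python
from collections import Counter
import math

def clipped_precision(reference, hypothesis, n):
    """Clipped n-gram precision against a single reference (token lists)."""
    ref = Counter(zip(*[reference[i:] for i in range(n)]))
    hyp = Counter(zip(*[hypothesis[i:] for i in range(n)]))
    # Each hypothesis n-gram is credited at most as often as it occurs in the reference
    matched = sum(min(count, ref[ng]) for ng, count in hyp.items())
    return matched / max(sum(hyp.values()), 1)

reference = "the cat is on the mat".split()
hypothesis = "the the the cat".split()

p1 = clipped_precision(reference, hypothesis, 1)  # "the" counts only twice: 3/4
p2 = clipped_precision(reference, hypothesis, 2)  # only ("the", "cat") matches: 1/3
p3 = clipped_precision(reference, hypothesis, 3)  # no trigram matches: 0
bp = min(1.0, math.exp(1 - len(reference) / len(hypothesis)))  # exp(-0.5)
```

Because the trigram precision is 0, the unsmoothed geometric mean is 0 for this pair. Short chatbot responses hit this case often, which is one reason smoothed BLEU variants are commonly preferred when BLEU serves as a reward.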

Limitations of Supervised Learning

  1. Exposure Bias: During training, Teacher Forcing provides ground truth, but during inference, the model uses its own outputs as input. This train-test mismatch causes error accumulation.
  2. Metric Mismatch: Cross-entropy is minimized during training, but actual evaluation uses BLEU or human satisfaction.
  3. Lack of Diversity: Maximum-likelihood training pulls the model toward generic, high-frequency responses (the "average" of the dataset), making it difficult to generate diverse and interesting conversations.

Applying Reinforcement Learning to Seq2Seq

Defining the Problem as an MDP

Seq2Seq text generation can be reformulated as a reinforcement learning problem:

  • State: Encoder output + tokens generated so far
  • Action: Selecting the next token from the vocabulary
  • Reward: BLEU score or other metrics after sequence completion
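  • Policy: The decoder's output probability distribution

The following function rolls out the policy for a single source sequence (batch size 1) and records the log probability of each sampled token:
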
def generate_with_policy(model, src, max_len=50, sos_token=1, eos_token=2):
    """Generate a sequence using the policy (decoder) while recording log probabilities"""
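    # Assumes batch size 1, since token.item() is used below.
    # eval() disables dropout but not gradient tracking, so the returned
    # log_probs can still be backpropagated.
    model.eval()
    hidden, cell = model.encoder(src)

    input_token = torch.tensor([[sos_token]], device=src.device)
    generated_tokens = []
    log_probs = []

    for _ in range(max_len):
        prediction, hidden, cell = model.decoder(input_token, hidden, cell)
        # Build the distribution from the logits directly, which is
        # numerically more stable than applying softmax first
        dist = torch.distributions.Categorical(logits=prediction)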

        token = dist.sample()
        log_prob = dist.log_prob(token)

        generated_tokens.append(token.item())
        log_probs.append(log_prob)

        if token.item() == eos_token:
            break

        input_token = token.unsqueeze(0)

    return generated_tokens, log_probs

Self-Critical Sequence Training (SCST)

SCST is a variant of the REINFORCE algorithm that uses the reward from greedy decoding as the baseline. This reduces variance and stabilizes learning.
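
Written out, the gradient estimate for a single sampled sequence is commonly stated as below. The symbols are notation introduced here for illustration: x is the input, w^s the sampled sequence of length T, \hat{w} the greedy sequence, and r the sequence-level reward.

```latex
\nabla_\theta L(\theta) \approx
  -\bigl(r(w^s) - r(\hat{w})\bigr)\,
  \nabla_\theta \sum_{t=1}^{T} \log p_\theta\bigl(w^s_t \mid w^s_{<t},\, x\bigr)
```

The self_critical_loss function in this post additionally divides by the sequence length, which rescales the gradient for a sample without changing its direction.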

Core Idea of SCST

  1. Sampling path: Stochastically sample tokens from the policy to generate a sequence and compute the reward (BLEU)
  2. Greedy path: Select the highest probability token at each time step to generate a sequence and compute the reward
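  3. Advantage: Sampled reward minus greedy reward. Sampled sequences that beat the greedy output are reinforced, and those that do worse are suppressed.
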
def self_critical_loss(model, src, reference, reward_fn):
    """Self-Critical Sequence Training loss function"""
    # 1. Generate sequence by sampling
    sampled_tokens, log_probs = generate_with_policy(model, src)
    sampled_reward = reward_fn(reference, sampled_tokens)

    # 2. Generate sequence greedily (baseline)
    with torch.no_grad():
        greedy_tokens = greedy_decode(model, src)
        baseline_reward = reward_fn(reference, greedy_tokens)

    # 3. REINFORCE with baseline. Both rewards are plain floats, so no
    #    gradient flows through them.
    advantage = sampled_reward - baseline_reward

    # Policy gradient loss, averaged over the sampled sequence length
    policy_loss = -advantage * torch.stack(log_probs).sum()

    return policy_loss / len(log_probs)

def greedy_decode(model, src, max_len=50, sos_token=1, eos_token=2):
    """Greedy decoding: always select the highest probability token"""
    model.eval()
    hidden, cell = model.encoder(src)

    input_token = torch.tensor([[sos_token]])
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_len):
            prediction, hidden, cell = model.decoder(input_token, hidden, cell)
            token = prediction.argmax(dim=-1)
            generated_tokens.append(token.item())

            if token.item() == eos_token:
                break
            input_token = token.unsqueeze(0)

    return generated_tokens
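
The self-critical update can be seen in isolation on a toy problem. The sketch below is not the chatbot model: it assumes a one-step "policy" over three tokens with made-up rewards, and applies the same rule of sampled reward minus greedy reward.

```python
import torch

torch.manual_seed(0)
logits = torch.zeros(3, requires_grad=True)  # toy one-step policy over 3 "tokens"
rewards = torch.tensor([0.0, 1.0, 0.2])      # hypothetical reward per token

optimizer = torch.optim.SGD([logits], lr=0.5)
for _ in range(200):
    dist = torch.distributions.Categorical(logits=logits)
    sampled = dist.sample()                  # sampling path
    greedy = logits.argmax()                 # greedy path, used only as baseline
    advantage = rewards[sampled] - rewards[greedy]

    # REINFORCE with the greedy reward as baseline
    loss = -dist.log_prob(sampled) * advantage
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

probs = torch.softmax(logits.detach(), dim=-1)
print(probs)
```

Samples that merely tie with the greedy choice have zero advantage and produce no update, so the policy only moves when sampling finds something better or worse than its own current best guess.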

Chatbot Implementation

Data Preparation

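We use a dialogue dataset such as the Cornell Movie-Dialogs Corpus. The loader below assumes the data has already been converted to JSON Lines, with one object per line containing "input" and "response" fields: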

import json

class DialogDataset:
    def __init__(self, data_path, vocab, max_len=50):
        self.pairs = []
        self.vocab = vocab
        self.max_len = max_len
        self._load_data(data_path)

    def _load_data(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                pair = json.loads(line)
                src = self.vocab.encode(pair['input'])[:self.max_len]
                trg = self.vocab.encode(pair['response'])[:self.max_len]
                self.pairs.append((src, trg))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, trg = self.pairs[idx]
        return (torch.tensor(src, dtype=torch.long),
                torch.tensor(trg, dtype=torch.long))
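
Since the pairs have different lengths, batching them for supervised training needs padding. One possible approach is a custom collate function, sketched below. The name collate_pairs and the constant PAD_ID are introduced here for illustration, with PAD_ID = 0 chosen to match the ignore_index=0 used for the loss in this post, and the toy pairs stand in for DialogDataset items.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD_ID = 0  # assumed padding index

def collate_pairs(batch):
    """Pad variable-length (src, trg) pairs to the longest sequence in the batch."""
    srcs, trgs = zip(*batch)
    src_batch = pad_sequence(list(srcs), batch_first=True, padding_value=PAD_ID)
    trg_batch = pad_sequence(list(trgs), batch_first=True, padding_value=PAD_ID)
    return src_batch, trg_batch

# Toy pairs standing in for DialogDataset items
toy_pairs = [
    (torch.tensor([1, 5, 6, 2]), torch.tensor([1, 7, 2])),
    (torch.tensor([1, 8, 2]), torch.tensor([1, 9, 10, 11, 2])),
]
loader = DataLoader(toy_pairs, batch_size=2, collate_fn=collate_pairs)
src_batch, trg_batch = next(iter(loader))
```

With right-padding like this, the encoder as written would also read the trailing padding tokens. Packing the sequences before the LSTM is a common refinement.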

Two-Stage Training Pipeline

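def train_chatbot(model, train_data, val_data, config):
    """Two-stage chatbot training.

    train_data is assumed to be an iterable of (src, trg) tensor batches,
    and evaluate_bleu a helper that returns the mean BLEU of greedy
    responses on val_data.
    """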
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore PAD

    # === Stage 1: Supervised Learning (Teacher Forcing) ===
    print("Stage 1: Starting supervised learning")
    for epoch in range(config['supervised_epochs']):
        loss = train_supervised(model, train_data, optimizer, criterion)
        bleu = evaluate_bleu(model, val_data)
        print(f"Epoch {epoch}: Loss={loss:.4f}, BLEU={bleu:.4f}")

    # === Stage 2: SCST Reinforcement Learning ===
    print("Stage 2: Starting SCST reinforcement learning")
    rl_optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['rl_lr'])  # Smaller learning rate

    for epoch in range(config['rl_epochs']):
        model.train()
        total_rl_loss = 0.0

        for src, trg in train_data:
            rl_optimizer.zero_grad()

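            # Use BLEU as the reward function. The decoding helpers assume
            # batch size 1, and compute_bleu expects plain token-id lists,
            # so convert the target tensor first.
            reference = trg.squeeze(0).tolist()
            loss = self_critical_loss(
                model, src, reference,
                reward_fn=compute_bleu
            )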

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            rl_optimizer.step()
            total_rl_loss += loss.item()

        avg_loss = total_rl_loss / len(train_data)
        bleu = evaluate_bleu(model, val_data)
        print(f"RL Epoch {epoch}: Loss={avg_loss:.4f}, BLEU={bleu:.4f}")

    return model
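
The training loop relies on an evaluation helper. One way such a helper might look is sketched below, under the assumption that validation data is available as (src, reference) pairs. The name evaluate_bleu_sketch and its signature are hypothetical: the decoder and the scoring function are passed in so the sketch stays independent of any particular model.

```python
def evaluate_bleu_sketch(pairs, decode_fn, score_fn):
    """Average sentence-level score of decoded responses over (src, reference) pairs.

    decode_fn(src) returns a list of token ids.
    score_fn(reference, hypothesis) returns a float.
    """
    scores = [score_fn(reference, decode_fn(src)) for src, reference in pairs]
    return sum(scores) / max(len(scores), 1)

# Toy check with stand-in functions (not a real model or a real BLEU)
toy_pairs = [([1, 2], [5, 6, 7]), ([3], [8, 9])]
echo_decoder = lambda src: [5, 6, 7]              # always produces the same response
exact_match = lambda ref, hyp: float(ref == hyp)  # 1.0 only on a perfect match
score = evaluate_bleu_sketch(toy_pairs, echo_decoder, exact_match)
```

With the functions from this post, decode_fn could wrap greedy_decode for a given model and score_fn could be compute_bleu, provided the references are plain token-id lists.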

Dialogue Testing

def chat(model, vocab, max_len=50):
    """Dialogue interface"""
    model.eval()
    print("Starting conversation with the chatbot. Type 'quit' to exit.")

    while True:
        user_input = input("User: ")
        if user_input.strip().lower() == 'quit':
            break

        # Encode input
        tokens = vocab.encode(user_input)
        src = torch.tensor([tokens], dtype=torch.long)

        # Generate response via greedy decoding
        response_tokens = greedy_decode(model, src, max_len)
        response = vocab.decode(response_tokens)
        print(f"Chatbot: {response}")
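
The vocab object used by the dataset and the chat loop needs encode and decode methods. A minimal whitespace-based sketch is shown below. The class itself and the UNK id are assumptions of this example, while the PAD, SOS, and EOS ids follow the defaults used in the code of this post (0, 1, 2).

```python
class Vocab:
    """Hypothetical word-level vocabulary with fixed special-token ids."""
    PAD, SOS, EOS, UNK = 0, 1, 2, 3

    def __init__(self, sentences):
        self.itos = ["<pad>", "<sos>", "<eos>", "<unk>"]
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        for sentence in sentences:
            for word in sentence.lower().split():
                if word not in self.stoi:
                    self.stoi[word] = len(self.itos)
                    self.itos.append(word)

    def encode(self, text):
        """Map text to token ids, wrapped in SOS and EOS."""
        ids = [self.stoi.get(w, self.UNK) for w in text.lower().split()]
        return [self.SOS] + ids + [self.EOS]

    def decode(self, ids):
        """Map token ids back to text, dropping PAD, SOS, and EOS."""
        specials = {self.PAD, self.SOS, self.EOS}
        return " ".join(self.itos[i] for i in ids if i not in specials)

vocab = Vocab(["hello there", "how are you"])
ids = vocab.encode("hello you")
text = vocab.decode(ids)
```

A real system would more likely use a subword tokenizer, which avoids most unknown-word cases.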

Practical Considerations

Importance of Reward Design

BLEU score alone does not guarantee good dialogue quality. In real chatbot systems, multiple rewards are combined:

  • Fluency: Language model perplexity
  • Relevance: Semantic similarity between input and response
  • Diversity: Repetition pattern penalty
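  • Safety: Harmful content filtering score

A sketch of such a combination follows. The weights are example values, and compute_distinct_ngrams and compute_repetition_penalty are helper functions assumed to return values in [0, 1]:
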
def combined_reward(reference, hypothesis, input_text=None):
    bleu = compute_bleu(reference, hypothesis)
    diversity = compute_distinct_ngrams(hypothesis)
    repetition_penalty = compute_repetition_penalty(hypothesis)

    return 0.5 * bleu + 0.3 * diversity - 0.2 * repetition_penalty
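
The two helpers used above are not defined in this post. One plausible reading is sketched below: a distinct-n ratio for diversity and a repeated-token fraction for the penalty. Both definitions are assumptions made for illustration, and other formulations are equally reasonable.

```python
from collections import Counter

def compute_distinct_ngrams(tokens, n=2):
    """Distinct-n: unique n-grams divided by total n-grams (0.0 if there are none)."""
    ngrams = list(zip(*[tokens[i:] for i in range(n)]))
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0

def compute_repetition_penalty(tokens):
    """Fraction of tokens that repeat an earlier token in the same response."""
    if not tokens:
        return 0.0
    counts = Counter(tokens)
    repeated = sum(c - 1 for c in counts.values())
    return repeated / len(tokens)

varied = [4, 5, 6, 7]       # no repeated tokens
loopy = [4, 5, 4, 5, 4, 5]  # the same bigram over and over

print(compute_distinct_ngrams(varied), compute_repetition_penalty(varied))
print(compute_distinct_ngrams(loopy), compute_repetition_penalty(loopy))
```

Both values lie in [0, 1], so they can be mixed with BLEU using fixed weights without one term dominating purely because of its scale.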

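Modern dialogue systems use Transformer-based large language models (LLMs) and RLHF (Reinforcement Learning from Human Feedback). The SCST concepts covered in this post carry over: PPO-based fine-tuning in RLHF is also a policy gradient method with a baseline, but it typically uses a learned reward model in place of BLEU and a learned value function in place of the greedy rollout.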


Key Takeaways

  • Seq2Seq models solve sequence transformation problems with an encoder-decoder architecture
  • Reinforcement learning mitigates the exposure bias and metric mismatch issues of supervised learning
  • SCST is a REINFORCE variant that reduces variance by using greedy decoding as the baseline
  • Reward function design is key to chatbot quality

In the next post, we will explore how to apply reinforcement learning to web navigation.