Skip to content
Published on

スクラッチからLLMを構築する: コードでGPTを理解する完全ガイド

Authors

はじめに

GPT、Llama、Mistral — 大規模言語モデル(LLM)は現在のAI革命の中心にあります。しかし、これらのモデルがコードレベルでどのように動作するかを理解している人はほとんどいません。このガイドでは、PyTorchを使ってゼロから小さなGPTスタイルのモデル「miniGPT」を構築し、モダンLLMの背後にあるすべてのコアコンセプトをマスターします。


1. LLMアーキテクチャ概要

1.1 事前学習パイプライン

モダンなLLM開発は3つのステージを経ます:

  1. 事前学習(Pretraining): モデルは数十億のテキストトークンから次のトークンを予測することを学習します。このステージで、統計的な言語パターン、世界知識、推論能力を獲得します。

  2. インストラクションファインチューニング(SFT): モデルは人間が書いた質問と回答のペアでトレーニングされ、指示に従う方法を学習します。

  3. 選好学習(RLHF / DPO): モデルは人間の選好データを使ってより有用で安全な応答を生成するよう洗練されます。

1.2 miniGPTの仕様

パラメータ
語彙サイズ50,257(GPT-2と同じ)
コンテキスト長1024
埋め込み次元768
レイヤー数12
Attentionヘッド数12
FFN拡張比4x
総パラメータ数約124M

これはGPT-2 Smallと一致し、単一のGPUでトレーニング可能です。


2. トークナイザーの実装

2.1 BPEアルゴリズム

Byte-Pair Encoding(BPE)は、最新のほとんどのLLMで使用されているトークン化アルゴリズムです。

from collections import defaultdict
from typing import List, Tuple, Dict

class SimpleBPETokenizer:
    """スクラッチから実装したBPEトークナイザー"""

    def __init__(self):
        self.vocab = {}
        self.merges = {}
        self.special_tokens = {
            "<|endoftext|>": 50256,
        }

    def get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple, int]:
        """隣接するトークンペアの頻度を数える"""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i+1])] += freq
        return pairs

    def merge_vocab(self, pair: Tuple, v_in: Dict) -> Dict:
        """最も頻繁なペアをマージする"""
        v_out = {}
        bigram = " ".join(pair)
        replacement = "".join(pair)
        for word in v_in:
            w_out = word.replace(bigram, replacement)
            v_out[w_out] = v_in[word]
        return v_out

    def train(self, corpus: List[str], vocab_size: int = 1000):
        """BPE語彙をトレーニングする"""
        # 1. 文字レベルの初期語彙を構築
        word_freq = defaultdict(int)
        for text in corpus:
            for word in text.split():
                word_freq[" ".join(list(word)) + " </w>"] += 1

        vocab = dict(word_freq)

        # 2. vocab_sizeに達するまで反復的にマージ
        for i in range(vocab_size):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges[best] = i
            print(f"Step {i}: merged {best} (freq={pairs[best]})")

        # 3. 最終語彙を構築
        for word in vocab:
            for token in word.split():
                if token not in self.vocab:
                    self.vocab[token] = len(self.vocab)

        return self.vocab, self.merges

2.2 tiktokenの使用

実際には、OpenAIのtiktokenライブラリの方がはるかに効率的です。

import tiktoken
import torch
from typing import List

class GPTTokenizer:
    """tiktokenベースのGPTトークナイザーラッパー"""

    def __init__(self, encoding_name: str = "gpt2"):
        self.enc = tiktoken.get_encoding(encoding_name)
        self.vocab_size = self.enc.n_vocab
        self.eot_token = self.enc.eot_token  # 50256

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """テキストをトークンIDのリストに変換"""
        ids = self.enc.encode(text, allowed_special={"<|endoftext|>"})
        if add_special_tokens:
            ids = [self.eot_token] + ids
        return ids

    def decode(self, ids: List[int]) -> str:
        """トークンIDのリストをテキストに戻す"""
        return self.enc.decode(ids)

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """バッチエンコード"""
        return [self.encode(t) for t in texts]

    def __len__(self):
        return self.vocab_size

# 使用例
tokenizer = GPTTokenizer()
text = "Hello! Let's build GPT from scratch."
ids = tokenizer.encode(text)
print(f"Token count: {len(ids)}")
print(f"Token IDs: {ids}")
decoded = tokenizer.decode(ids)
print(f"Decoded: {decoded}")

3. データ準備

3.1 テキストデータセットクラス

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class TextDataset(Dataset):
    """スライディングウィンドウ言語モデリングデータセット"""

    def __init__(
        self,
        data_path: str,
        tokenizer: GPTTokenizer,
        seq_len: int = 1024,
        stride: int = 512
    ):
        self.seq_len = seq_len
        self.stride = stride

        print(f"Loading data: {data_path}")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()

        print(f"Total characters: {len(text):,}")
        self.tokens = tokenizer.encode(text, add_special_tokens=False)
        print(f"Total tokens: {len(self.tokens):,}")

        self.tokens = np.array(self.tokens, dtype=np.int32)

    def __len__(self):
        return max(0, (len(self.tokens) - self.seq_len) // self.stride)

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.seq_len + 1

        chunk = self.tokens[start:end]

        # 入力: tokens[0:seq_len], ラベル: tokens[1:seq_len+1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)

        return x, y


def create_dataloader(
    data_path: str,
    tokenizer: GPTTokenizer,
    batch_size: int = 8,
    seq_len: int = 1024,
    stride: int = 512,
    num_workers: int = 4,
    shuffle: bool = True
) -> DataLoader:
    dataset = TextDataset(data_path, tokenizer, seq_len, stride)
    print(f"Dataset size: {len(dataset):,} samples")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )

4. モデルアーキテクチャ

4.1 設定クラス

from dataclasses import dataclass

@dataclass
class GPTConfig:
    """GPTモデル設定"""
    vocab_size: int = 50257
    max_seq_len: int = 1024
    n_embd: int = 768
    n_layer: int = 12
    n_head: int = 12
    dropout: float = 0.1
    bias: bool = False

    ffn_mult: int = 4
    use_swiglu: bool = True
    use_rope: bool = True

4.2 RMSNorm

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization(Llamaスタイル)"""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)

4.3 回転位置埋め込み(RoPE)

class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding(RoPE)"""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # 周波数を計算: theta_i = base^(-2i/dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cache", emb.sin()[None, None, :, :])

    def rotate_half(self, x):
        """テンソルの半分を回転"""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int):
        cos = self.cos_cache[:, :, :seq_len, :].to(q.dtype)
        sin = self.sin_cache[:, :, :seq_len, :].to(q.dtype)

        q_rot = (q * cos) + (self.rotate_half(q) * sin)
        k_rot = (k * cos) + (self.rotate_half(k) * sin)
        return q_rot, k_rot

4.4 KVキャッシュ付きマルチヘッドCausal Attention

class CausalSelfAttention(nn.Module):
    """KVキャッシュをサポートするCausal Self-Attention"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # 結合されたQ, K, V射影
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        if config.use_rope:
            self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len)
        else:
            self.rope = None

        # Causalマスク(下三角行列)
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
            .view(1, 1, config.max_seq_len, config.max_seq_len)
        )

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv=None
    ):
        B, T, C = x.shape

        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # ヘッドを分割: [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        if self.rope is not None:
            q, k = self.rope(q, k, T)

        # KVキャッシュの処理
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None
        kv_len = k.shape[2]

        # スケールドドット積アテンション(利用可能な場合はFlash Attentionを使用)
        if hasattr(F, "scaled_dot_product_attention"):
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=True
            )
        else:
            # 手動実装
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            mask = self.causal_mask[:, :, :T, :kv_len]
            attn = attn.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v

        # ヘッドをマージ: [B, n_head, T, head_dim] -> [B, T, C]
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))

        return y, present_kv

4.5 SwiGLUフィードフォワードネットワーク

class SwiGLUFFN(nn.Module):
    """SwiGLU活性化を持つFFN(Llamaスタイル)"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden_dim = int(config.n_embd * config.ffn_mult * 2 / 3)
        # ハードウェア効率のために64の倍数に切り上げ
        hidden_dim = ((hidden_dim + 63) // 64) * 64

        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: gate * SiLU(x)
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.dropout(self.down_proj(gate * up))


class GPTFFN(nn.Module):
    """クラシックなGPTスタイルのGELU FFN"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden_dim = config.n_embd * config.ffn_mult
        self.fc1 = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.fc2 = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.fc1(x))
        x = self.dropout(self.fc2(x))
        return x

4.6 Transformerブロック

class TransformerBlock(nn.Module):
    """Pre-RMSNorm Transformerブロック"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)

        if config.use_swiglu:
            self.ffn = SwiGLUFFN(config)
        else:
            self.ffn = GPTFFN(config)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv=None):
        # Pre-Norm + Attention + Residual
        attn_out, present_kv = self.attn(
            self.norm1(x),
            use_cache=use_cache,
            past_kv=past_kv
        )
        x = x + attn_out

        # Pre-Norm + FFN + Residual
        x = x + self.ffn(self.norm2(x))

        return x, present_kv

4.7 完全なGPTモデル

class MiniGPT(nn.Module):
    """完全なminiGPTモデル"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # 埋め込みレイヤー
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)

        if not config.use_rope:
            self.pos_emb = nn.Embedding(config.max_seq_len, config.n_embd)

        self.emb_dropout = nn.Dropout(config.dropout)

        # Transformerレイヤー
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # 最終正規化
        self.norm_final = RMSNorm(config.n_embd)

        # 言語モデルヘッド
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # 重み共有(入力埋め込み == 出力埋め込み)
        self.lm_head.weight = self.token_emb.weight

        # 重みの初期化
        self.apply(self._init_weights)

        # 残差ストリームの特別なスケーリング
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print(f"Parameters: {self.count_params() / 1e6:.1f}M")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def count_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
        use_cache: bool = False,
        past_kvs: list = None
    ):
        B, T = input_ids.shape
        device = input_ids.device

        # トークン埋め込み
        x = self.token_emb(input_ids)  # [B, T, n_embd]

        # 位置埋め込み(RoPEを使用しない場合のみ)
        if not self.config.use_rope:
            positions = torch.arange(T, device=device)
            x = x + self.pos_emb(positions)

        x = self.emb_dropout(x)

        # Transformerレイヤー
        present_kvs = []
        for i, layer in enumerate(self.layers):
            past_kv = past_kvs[i] if past_kvs is not None else None
            x, present_kv = layer(x, use_cache=use_cache, past_kv=past_kv)
            present_kvs.append(present_kv)

        x = self.norm_final(x)

        if targets is not None:
            # トレーニング: シーケンス全体でロスを計算
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )
            return logits, loss
        else:
            # 推論: 最後のトークンのlogitsのみ必要
            logits = self.lm_head(x[:, [-1], :])
            return logits, present_kvs

    @classmethod
    def from_pretrained(cls, model_path: str):
        checkpoint = torch.load(model_path, map_location="cpu")
        config = checkpoint["config"]
        model = cls(config)
        model.load_state_dict(checkpoint["model"])
        return model

5. 事前学習の実装

5.1 学習率スケジューラー

import math

def get_cosine_schedule_with_warmup(
    optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.1
):
    """ウォームアップ付きコサインLRスケジュール"""

    def lr_lambda(current_step: int):
        if current_step < warmup_steps:
            return current_step / max(1, warmup_steps)
        else:
            progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(min_lr_ratio, cosine_decay)

    from torch.optim.lr_scheduler import LambdaLR
    return LambdaLR(optimizer, lr_lambda)

5.2 トレーニングループ

import torch
import torch.nn as nn
from torch.optim import AdamW
import os
import time

class Trainer:
    """miniGPT事前学習トレーナー"""

    def __init__(self, model: MiniGPT, train_dataloader, val_dataloader, config: dict):
        self.model = model
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.config = config
        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")

        # 混合精度
        self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == torch.float16))
        self.ctx = torch.amp.autocast(device_type="cuda", dtype=self.dtype)

        self.optimizer = self._create_optimizer()

        total_steps = len(train_dataloader) * config["num_epochs"]
        warmup_steps = int(total_steps * config.get("warmup_ratio", 0.05))
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            warmup_steps=warmup_steps,
            total_steps=total_steps
        )

        self.step = 0
        self.best_val_loss = float("inf")

    def _create_optimizer(self):
        """weight decayあり/なしのパラメータを分離"""
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() == 1 or "emb" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optim_groups = [
            {"params": decay_params, "weight_decay": self.config.get("weight_decay", 0.1)},
            {"params": no_decay_params, "weight_decay": 0.0}
        ]

        return AdamW(
            optim_groups,
            lr=self.config["learning_rate"],
            betas=(0.9, 0.95),
            fused=True
        )

    def compute_val_loss(self) -> float:
        self.model.eval()
        total_loss = 0.0
        num_batches = min(len(self.val_dl), 50)

        with torch.no_grad():
            for i, (x, y) in enumerate(self.val_dl):
                if i >= num_batches:
                    break
                x, y = x.to(self.device), y.to(self.device)
                with self.ctx:
                    _, loss = self.model(x, y)
                total_loss += loss.item()

        self.model.train()
        return total_loss / num_batches

    def save_checkpoint(self, path: str):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            "step": self.step,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "config": self.model.config,
            "best_val_loss": self.best_val_loss
        }, path)
        print(f"Checkpoint saved: {path}")

    def train(self):
        self.model.to(self.device)
        self.model.train()

        num_epochs = self.config["num_epochs"]
        grad_accum = self.config.get("gradient_accumulation_steps", 1)
        max_grad_norm = self.config.get("max_grad_norm", 1.0)
        log_interval = self.config.get("log_interval", 100)
        eval_interval = self.config.get("eval_interval", 500)
        save_interval = self.config.get("save_interval", 1000)

        print(f"Training on {self.device}, dtype={self.dtype}")

        for epoch in range(num_epochs):
            t0 = time.time()

            for batch_idx, (x, y) in enumerate(self.train_dl):
                x, y = x.to(self.device), y.to(self.device)

                with self.ctx:
                    _, loss = self.model(x, y)
                    loss = loss / grad_accum

                self.scaler.scale(loss).backward()

                if (batch_idx + 1) % grad_accum == 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.scheduler.step()
                    self.optimizer.zero_grad(set_to_none=True)
                    self.step += 1

                if self.step % log_interval == 0:
                    lr = self.optimizer.param_groups[0]["lr"]
                    elapsed = time.time() - t0
                    tokens_per_sec = log_interval * x.shape[0] * x.shape[1] / elapsed
                    print(
                        f"Epoch {epoch} | Step {self.step} | "
                        f"Loss: {loss.item() * grad_accum:.4f} | "
                        f"LR: {lr:.6f} | "
                        f"Tokens/s: {tokens_per_sec:.0f}"
                    )
                    t0 = time.time()

                if self.step % eval_interval == 0:
                    val_loss = self.compute_val_loss()
                    print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("./checkpoints/best_model.pt")

                if self.step % save_interval == 0:
                    self.save_checkpoint(f"./checkpoints/step_{self.step}.pt")

6. テキスト生成

6.1 デコード戦略

import torch
import torch.nn.functional as F

class TextGenerator:
    """miniGPTテキストジェネレーター"""

    def __init__(self, model: MiniGPT, tokenizer: GPTTokenizer, device: str = "cuda"):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 200,
        strategy: str = "top_p",
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        num_beams: int = 4
    ) -> str:
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0)

        if strategy == "greedy":
            generated = self._greedy_decode(input_ids, max_new_tokens)
        elif strategy == "temperature":
            generated = self._temperature_sampling(input_ids, max_new_tokens, temperature)
        elif strategy == "top_k":
            generated = self._top_k_sampling(input_ids, max_new_tokens, temperature, top_k)
        elif strategy == "top_p":
            generated = self._top_p_sampling(input_ids, max_new_tokens, temperature, top_p)
        elif strategy == "beam":
            generated = self._beam_search(input_ids, max_new_tokens, num_beams)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        new_tokens = generated[0][input_ids.shape[1]:].tolist()
        return self.tokenizer.decode(new_tokens)

    def _greedy_decode(self, input_ids, max_new_tokens):
        """グリーディデコード: 常に最高確率のトークンを選択"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _temperature_sampling(self, input_ids, max_new_tokens, temperature):
        """温度サンプリング: softmax前にlogitsをスケール"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _top_k_sampling(self, input_ids, max_new_tokens, temperature, top_k):
        """Top-kサンプリング: 上位kトークンのみからサンプリング"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature

            top_k_logits, top_k_indices = torch.topk(logits, k=top_k, dim=-1)
            filtered_logits = torch.full_like(logits, float("-inf"))
            filtered_logits.scatter_(1, top_k_indices, top_k_logits)

            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _top_p_sampling(self, input_ids, max_new_tokens, temperature, top_p):
        """Top-p(Nucleus)サンプリング: 確率質量pをカバーする上位トークンからサンプリング"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature

            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # top_pを超える累積確率のトークンを削除
            sorted_indices_to_remove = cumulative_probs - F.softmax(sorted_logits, dim=-1) > top_p
            sorted_logits[sorted_indices_to_remove] = float("-inf")

            logits = torch.zeros_like(logits).scatter_(1, sorted_indices, sorted_logits)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _beam_search(self, input_ids, max_new_tokens, num_beams):
        """ビームサーチ: 複数の候補を同時に探索"""
        B = input_ids.shape[0]
        assert B == 1, "Beam search only supports batch size 1"

        beams = [(input_ids, 0.0)]

        for _ in range(max_new_tokens):
            all_candidates = []
            for beam_ids, beam_score in beams:
                logits, _ = self.model(beam_ids)
                log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
                top_log_probs, top_ids = log_probs.topk(num_beams)

                for i in range(num_beams):
                    token = top_ids[0, i].unsqueeze(0).unsqueeze(0)
                    new_ids = torch.cat([beam_ids, token], dim=1)
                    new_score = beam_score + top_log_probs[0, i].item()
                    all_candidates.append((new_ids, new_score))

            all_candidates.sort(key=lambda x: x[1] / x[0].shape[1], reverse=True)
            beams = all_candidates[:num_beams]

            if beams[0][0][0, -1].item() == self.tokenizer.eot_token:
                break

        return beams[0][0]

7. パープレキシティ評価

import math
import torch

@torch.no_grad()
def compute_perplexity(
    model: MiniGPT,
    tokenizer: GPTTokenizer,
    text: str,
    device: str = "cuda",
    seq_len: int = 512
) -> float:
    """テキストのパープレキシティを計算"""
    model.eval()
    model.to(device)

    tokens = tokenizer.encode(text, add_special_tokens=False)
    total_loss = 0.0
    total_tokens = 0

    for i in range(0, len(tokens) - seq_len, seq_len):
        chunk = tokens[i:i + seq_len + 1]
        if len(chunk) < seq_len + 1:
            break

        x = torch.tensor(chunk[:-1], dtype=torch.long, device=device).unsqueeze(0)
        y = torch.tensor(chunk[1:], dtype=torch.long, device=device).unsqueeze(0)

        _, loss = model(x, y)
        total_loss += loss.item() * seq_len
        total_tokens += seq_len

    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)


# 使用例
model = MiniGPT.from_pretrained("./checkpoints/best_model.pt")
tokenizer = GPTTokenizer()
generator = TextGenerator(model, tokenizer)

test_text = "The quick brown fox jumps over the lazy dog."
ppl = compute_perplexity(model, tokenizer, test_text)
print(f"Perplexity: {ppl:.2f}")

for strategy in ["greedy", "top_k", "top_p"]:
    generated = generator.generate(
        prompt="Once upon a time",
        max_new_tokens=100,
        strategy=strategy,
        temperature=0.8
    )
    print(f"\n[{strategy}]: {generated}")

8. スケールアップ

8.1 Flash Attention 2

# pip install flash-attn

from flash_attn import flash_attn_qkvpacked_func

class FlashCausalAttention(nn.Module):
    """Flash Attention 2を使用した高速アテンション"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor):
        B, T, C = x.shape

        qkv = self.qkv_proj(x)
        # Flash Attentionは[B, T, 3, n_head, head_dim]の形状を期待
        qkv = qkv.view(B, T, 3, self.n_head, self.head_dim)

        attn_out = flash_attn_qkvpacked_func(
            qkv,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True
        )  # [B, T, n_head, head_dim]

        attn_out = attn_out.reshape(B, T, C)
        return self.out_proj(attn_out)

8.2 FSDP分散トレーニング

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

def setup_fsdp_training():
    """FSDP分散トレーニングのセットアップ"""
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    config = GPTConfig(
        n_embd=1024,
        n_layer=24,
        n_head=16,
        vocab_size=50257
    )
    model = MiniGPT(config)

    # ラップポリシー: TransformerBlock境界でシャード
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=torch.distributed.fsdp.MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16
        ),
        sharding_strategy="FULL_SHARD",  # ZeRO-3相当
        device_id=rank
    )

    return model, rank, world_size

8.3 Chinchillaスケーリング則

ChinchillaペーパーHoffmann et al. (2022)は、最適モデルサイズとトレーニングトークンの関係について次の式を示しています:

最適トークン数 = 20 x パラメータ数

def chinchilla_optimal_tokens(model_params: int) -> int:
    """Chinchillaスケーリング則に基づいた最適トレーニングトークン数を計算"""
    return 20 * model_params

models = {
    "124M (GPT-2 Small)": 124e6,
    "1.3B": 1.3e9,
    "7B (Llama-2 7B)": 7e9,
    "13B": 13e9,
    "70B (Llama-2 70B)": 70e9,
}

print("モデルサイズ別の推奨トレーニングトークン数:")
for name, params in models.items():
    optimal_tokens = chinchilla_optimal_tokens(int(params))
    print(f"  {name}: {optimal_tokens/1e9:.1f}B tokens")

9. インストラクションチューニング

9.1 SFTデータフォーマット

# ChatMLフォーマット
CHATML_SYSTEM = "<|im_start|>system\n{system}<|im_end|>\n"
CHATML_USER = "<|im_start|>user\n{user}<|im_end|>\n"
CHATML_ASSISTANT = "<|im_start|>assistant\n{assistant}<|im_end|>\n"

def format_chatml(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    formatted = CHATML_SYSTEM.format(system=system)
    formatted += CHATML_USER.format(user=user)
    formatted += CHATML_ASSISTANT.format(assistant=assistant)
    if add_generation_prompt:
        formatted += "<|im_start|>assistant\n"
    return formatted


# Llama-3フォーマット
def format_llama3(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    formatted = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>\n"
    formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>\n"
    formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>"
    if add_generation_prompt:
        formatted += "\n<|start_header_id|>assistant<|end_header_id|>\n\n"
    return formatted

9.2 インストラクションデータセットクラス

class InstructionDataset(Dataset):
    """SFTデータセット: アシスタント応答のみでロスを計算"""

    def __init__(
        self,
        data: list,
        tokenizer: GPTTokenizer,
        max_seq_len: int = 2048,
        format_func=format_chatml
    ):
        self.samples = []

        for item in data:
            # 会話全体をトークン化
            full_text = format_func(
                system=item.get("system", "You are a helpful assistant."),
                user=item["instruction"],
                assistant=item["response"]
            )
            full_ids = tokenizer.encode(full_text, add_special_tokens=False)

            # プロンプトのみをトークン化(応答の開始位置を見つけるため)
            prompt_text = format_func(
                system=item.get("system", "You are a helpful assistant."),
                user=item["instruction"],
                assistant="",
                add_generation_prompt=True
            )
            prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
            prompt_len = len(prompt_ids)

            if len(full_ids) > max_seq_len:
                full_ids = full_ids[:max_seq_len]

            # プロンプトトークンをラベルでマスク(-100 = ロスで無視)
            labels = full_ids.copy()
            labels[:prompt_len] = [-100] * prompt_len

            self.samples.append({
                "input_ids": full_ids,
                "labels": labels
            })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        return (
            torch.tensor(sample["input_ids"], dtype=torch.long),
            torch.tensor(sample["labels"], dtype=torch.long)
        )

9.3 SFTトレーニングループ

def train_sft(
    base_model_path: str,
    data_path: str,
    output_dir: str,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    batch_size: int = 4
):
    """教師ありファインチューニング(SFT)"""
    import json

    with open(data_path) as f:
        data = json.load(f)

    tokenizer = GPTTokenizer()
    dataset = InstructionDataset(data, tokenizer, max_seq_len=2048)

    def collate_fn(batch):
        input_ids = [item[0] for item in batch]
        labels = [item[1] for item in batch]
        max_len = max(len(x) for x in input_ids)
        padded_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
        padded_labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
        for i, (ids, lbls) in enumerate(zip(input_ids, labels)):
            padded_ids[i, :len(ids)] = ids
            padded_labels[i, :len(lbls)] = lbls
        return padded_ids, padded_labels

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model = MiniGPT.from_pretrained(base_model_path)
    device = "cuda"
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(dataloader) * num_epochs
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * 0.1), total_steps)

    model.train()
    for epoch in range(num_epochs):
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if step % 50 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    os.makedirs(output_dir, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "config": model.config
    }, f"{output_dir}/sft_model.pt")
    print(f"SFT complete. Model saved: {output_dir}")

10. オープンソースLLMアーキテクチャ分析

10.1 Llama 3アーキテクチャ

Llama 3はMetaのオープンソースLLMで、以下の主要な特性を持ちます:

特性詳細
アーキテクチャDecoder-only Transformer
正規化Pre-RMSNorm
活性化関数SwiGLU
位置エンコーディングRoPE (theta=500,000)
語彙サイズ128,256
コンテキスト長128K (Llama 3.1+)
アテンションGQA(8B、70Bモデル)
# Llama 3スタイルの設定(8Bモデル)
llama3_config = GPTConfig(
    vocab_size=128256,
    max_seq_len=8192,
    n_embd=4096,
    n_layer=32,
    n_head=32,
    use_rope=True,
    use_swiglu=True,
    bias=False
)

10.2 グループクエリアテンション(GQA)

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention(Llama 3、Mistralで使用)"""

    def __init__(self, config: GPTConfig, n_kv_heads: int = 8):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_heads = n_kv_heads
        self.n_rep = config.n_head // n_kv_heads
        self.head_dim = config.n_embd // config.n_head

        self.q_proj = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len, base=500000)

    def forward(self, x: torch.Tensor):
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q, k = self.rope(q, k, T)

        # KVヘッドをQueryヘッド数に合わせて繰り返す(GQAのコア)
        k = k.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        k = k.reshape(B, self.n_head, T, self.head_dim)
        v = v.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        v = v.reshape(B, self.n_head, T, self.head_dim)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)

10.3 Mixture of Experts(MoE)

DeepSeekとMixtralは、計算量を比例して増やさずにモデル容量を増やすためにMixture of Expertsを使用しています。

class MoEFFN(nn.Module):
    """Mixture of Experts FFN(簡略版)"""

    def __init__(self, config: GPTConfig, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # ルーター
        self.router = nn.Linear(config.n_embd, num_experts, bias=False)

        # エキスパートFFN
        self.experts = nn.ModuleList([
            SwiGLUFFN(config) for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        x_flat = x.view(B * T, C)

        # ルーティングスコアを計算
        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)

        # top-kエキスパートを選択
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # エキスパート出力を計算
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_idx = top_k_indices[:, k]
            expert_prob = top_k_probs[:, k].unsqueeze(-1)

            for e in range(self.num_experts):
                mask = (expert_idx == e)
                if mask.any():
                    expert_out = self.experts[e](x_flat[mask])
                    output[mask] += expert_prob[mask] * expert_out

        return output.view(B, T, C)

すべてを統合する

全コンポーネントを統合した完全なスクリプトです:

import torch
from dataclasses import dataclass

# 1. 設定
config = GPTConfig(
    vocab_size=50257,
    max_seq_len=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
    dropout=0.1,
    use_rope=True,
    use_swiglu=True
)

# 2. モデル
model = MiniGPT(config)
print(f"Model parameters: {model.count_params() / 1e6:.1f}M")

# 3. トークナイザー
tokenizer = GPTTokenizer()

# 4. データ
train_loader = create_dataloader("train.txt", tokenizer, batch_size=8, seq_len=1024)
val_loader = create_dataloader("val.txt", tokenizer, batch_size=8, seq_len=1024, shuffle=False)

# 5. トレーナー
trainer_config = {
    "num_epochs": 5,
    "learning_rate": 6e-4,
    "weight_decay": 0.1,
    "gradient_accumulation_steps": 4,
    "max_grad_norm": 1.0,
    "warmup_ratio": 0.05,
    "log_interval": 100,
    "eval_interval": 500,
    "save_interval": 1000,
    "device": "cuda"
}

trainer = Trainer(model, train_loader, val_loader, trainer_config)
trainer.train()

# 6. 生成
generator = TextGenerator(model, tokenizer)

prompts = [
    "The history of artificial intelligence",
    "In the year 2050, scientists discovered",
    "The best way to learn programming is",
]

for prompt in prompts:
    print(f"\nPrompt: {prompt}")
    for strategy in ["greedy", "top_p"]:
        text = generator.generate(prompt, max_new_tokens=100, strategy=strategy)
        print(f"  [{strategy}]: {text}")

まとめ

このガイドでは、miniGPTをゼロから構築し、モダンLLMのすべてのコア要素をカバーしました:

  • トークナイザー: BPEアルゴリズムがテキストをサブワードトークンに分割
  • 埋め込み: トークン埋め込み + RoPE位置エンコーディング
  • マルチヘッドCausal Attention: Q/K/V射影、Causalマスク、KVキャッシュ
  • FFN: 非線形性のためのSwiGLU活性化
  • Pre-RMSNorm: トレーニングの安定性
  • テキスト生成: Greedy、Temperature、Top-k、Top-p、ビームサーチ
  • 事前学習: コサインLRスケジュール、勾配クリッピング、混合精度
  • SFT: アシスタント応答トークンのみでロスを計算
  • スケーリング: Flash Attention、FSDP、Chinchilla則

この基盤により、あらゆるモダンなオープンソースLLMを理解し拡張するためのツールが揃います。ここで説明したコンセプトは、Llama 3、Mistral、Qwenの実装詳細に直接対応しています。

参考資料