Author: Youngju Kim (@fjvbn20031)
Introduction
GPT, Llama, Mistral: large language models (LLMs) are at the heart of the current AI revolution. Yet very few people understand how these models actually work at the code level. In this guide we build a small GPT-style model, miniGPT, from scratch in PyTorch and work through every core concept behind modern LLMs.
1. LLM Architecture Overview
1.1 The Pretraining Pipeline
Modern LLM development goes through three stages:
- Pretraining: the model learns to predict the next token from billions of text tokens. This is the stage where it acquires statistical language patterns, world knowledge, and reasoning ability.
- Instruction fine-tuning (SFT): the model is trained on human-written question-and-answer pairs and learns to follow instructions.
- Preference learning (RLHF / DPO): the model is refined with human preference data so that it produces more helpful and safer responses.
1.2 miniGPT Specification
| Parameter | Value |
|---|---|
| Vocabulary size | 50,257 (same as GPT-2) |
| Context length | 1024 |
| Embedding dimension | 768 |
| Number of layers | 12 |
| Number of attention heads | 12 |
| FFN expansion ratio | 4x |
| Total parameters | ~124M |
This matches GPT-2 Small and can be trained on a single GPU.
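As a quick sanity check on the ~124M figure, here is a rough back-of-the-envelope parameter count implied by the table (a sketch; it assumes weight tying between the token embedding and the LM head, no biases, and a classic 4x FFN, so the exact number shifts slightly with the configuration used later):

```python
# Rough parameter count for the configuration above (approximate sketch)
vocab_size, n_embd, n_layer = 50257, 768, 12

emb = vocab_size * n_embd                  # token embedding (tied with the LM head)
attn_per_layer = 4 * n_embd * n_embd       # Q, K, V and output projections
ffn_per_layer = 2 * n_embd * (4 * n_embd)  # up and down projections of a 4x FFN
norms_per_layer = 2 * n_embd               # two norm weight vectors per block

total = emb + n_layer * (attn_per_layer + ffn_per_layer + norms_per_layer) + n_embd
print(f"~{total / 1e6:.0f}M parameters")   # roughly 124M
```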
2. Implementing the Tokenizer
2.1 The BPE Algorithm
Byte-Pair Encoding (BPE) is the tokenization algorithm used by almost all modern LLMs: starting from individual characters, it repeatedly merges the most frequent adjacent pair of symbols into a new token until the target vocabulary size is reached.
from collections import defaultdict
from typing import List, Tuple, Dict
class SimpleBPETokenizer:
"""スクラッチから実装したBPEトークナイザー"""
def __init__(self):
self.vocab = {}
self.merges = {}
self.special_tokens = {
"<|endoftext|>": 50256,
}
def get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple, int]:
"""隣接するトークンペアの頻度を数える"""
pairs = defaultdict(int)
for word, freq in vocab.items():
symbols = word.split()
for i in range(len(symbols) - 1):
pairs[(symbols[i], symbols[i+1])] += freq
return pairs
def merge_vocab(self, pair: Tuple, v_in: Dict) -> Dict:
"""最も頻繁なペアをマージする"""
v_out = {}
bigram = " ".join(pair)
replacement = "".join(pair)
for word in v_in:
w_out = word.replace(bigram, replacement)
v_out[w_out] = v_in[word]
return v_out
def train(self, corpus: List[str], vocab_size: int = 1000):
"""BPE語彙をトレーニングする"""
# 1. 文字レベルの初期語彙を構築
word_freq = defaultdict(int)
for text in corpus:
for word in text.split():
word_freq[" ".join(list(word)) + " </w>"] += 1
vocab = dict(word_freq)
        # 2. Iteratively merge the most frequent pair (up to vocab_size merge steps)
for i in range(vocab_size):
pairs = self.get_stats(vocab)
if not pairs:
break
best = max(pairs, key=pairs.get)
vocab = self.merge_vocab(best, vocab)
self.merges[best] = i
print(f"Step {i}: merged {best} (freq={pairs[best]})")
        # 3. Build the final vocabulary
for word in vocab:
for token in word.split():
if token not in self.vocab:
self.vocab[token] = len(self.vocab)
return self.vocab, self.merges
2.2 Using tiktoken
In practice, OpenAI's tiktoken library is far more efficient.
import tiktoken
import torch
from typing import List
class GPTTokenizer:
"""tiktokenベースのGPTトークナイザーラッパー"""
def __init__(self, encoding_name: str = "gpt2"):
self.enc = tiktoken.get_encoding(encoding_name)
self.vocab_size = self.enc.n_vocab
self.eot_token = self.enc.eot_token # 50256
def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
"""テキストをトークンIDのリストに変換"""
ids = self.enc.encode(text, allowed_special={"<|endoftext|>"})
if add_special_tokens:
ids = [self.eot_token] + ids
return ids
def decode(self, ids: List[int]) -> str:
"""トークンIDのリストをテキストに戻す"""
return self.enc.decode(ids)
def encode_batch(self, texts: List[str]) -> List[List[int]]:
"""バッチエンコード"""
return [self.encode(t) for t in texts]
def __len__(self):
return self.vocab_size
# Usage example
tokenizer = GPTTokenizer()
text = "Hello! Let's build GPT from scratch."
ids = tokenizer.encode(text)
print(f"Token count: {len(ids)}")
print(f"Token IDs: {ids}")
decoded = tokenizer.decode(ids)
print(f"Decoded: {decoded}")
3. Data Preparation
3.1 A Text Dataset Class
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
class TextDataset(Dataset):
"""スライディングウィンドウ言語モデリングデータセット"""
def __init__(
self,
data_path: str,
tokenizer: GPTTokenizer,
seq_len: int = 1024,
stride: int = 512
):
self.seq_len = seq_len
self.stride = stride
print(f"Loading data: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
text = f.read()
print(f"Total characters: {len(text):,}")
self.tokens = tokenizer.encode(text, add_special_tokens=False)
print(f"Total tokens: {len(self.tokens):,}")
self.tokens = np.array(self.tokens, dtype=np.int32)
def __len__(self):
return max(0, (len(self.tokens) - self.seq_len) // self.stride)
def __getitem__(self, idx):
start = idx * self.stride
end = start + self.seq_len + 1
chunk = self.tokens[start:end]
        # Input: tokens[0:seq_len], labels: tokens[1:seq_len+1] (shifted by one position)
x = torch.tensor(chunk[:-1], dtype=torch.long)
y = torch.tensor(chunk[1:], dtype=torch.long)
return x, y
def create_dataloader(
data_path: str,
tokenizer: GPTTokenizer,
batch_size: int = 8,
seq_len: int = 1024,
stride: int = 512,
num_workers: int = 4,
shuffle: bool = True
) -> DataLoader:
dataset = TextDataset(data_path, tokenizer, seq_len, stride)
print(f"Dataset size: {len(dataset):,} samples")
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=True,
drop_last=True
)
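A quick way to confirm what the dataset returns: every batch is a pair of [batch, seq_len] tensors where the labels are the inputs shifted left by one token. A short check (a sketch; "train.txt" is a placeholder path for your corpus file):

```python
# Shape check for one batch (sketch; "train.txt" stands in for your corpus file)
loader = create_dataloader("train.txt", GPTTokenizer(), batch_size=4, seq_len=256, stride=256, num_workers=0)
x, y = next(iter(loader))
print(x.shape, y.shape)                      # torch.Size([4, 256]) torch.Size([4, 256])
print(bool((x[:, 1:] == y[:, :-1]).all()))   # True: y is x shifted left by one token
```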
4. Model Architecture
4.1 Configuration Class
from dataclasses import dataclass
@dataclass
class GPTConfig:
"""GPTモデル設定"""
vocab_size: int = 50257
max_seq_len: int = 1024
n_embd: int = 768
n_layer: int = 12
n_head: int = 12
dropout: float = 0.1
bias: bool = False
ffn_mult: int = 4
use_swiglu: bool = True
use_rope: bool = True
4.2 RMSNorm
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization(Llamaスタイル)"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
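Unlike LayerNorm, RMSNorm does not subtract the mean; it only rescales each vector by its root mean square. A small check (a sketch) shows that, with the weight at its initial value of one, the output has RMS close to 1 regardless of the input scale:

```python
# Sanity check: RMSNorm rescales each vector to (approximately) unit RMS (sketch)
norm = RMSNorm(dim=768)
x = torch.randn(2, 4, 768) * 5.0         # deliberately large-scale input
out = norm(x)
print(out.pow(2).mean(dim=-1).sqrt())    # all entries close to 1.0
```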
4.3 Rotary Position Embedding (RoPE)
class RotaryEmbedding(nn.Module):
"""Rotary Position Embedding(RoPE)"""
def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
        # Compute the frequencies: theta_i = base^(-2i/dim)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self._build_cache(max_seq_len)
def _build_cache(self, seq_len: int):
t = torch.arange(seq_len, device=self.inv_freq.device).float()
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
self.register_buffer("sin_cache", emb.sin()[None, None, :, :])
def rotate_half(self, x):
"""テンソルの半分を回転"""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat([-x2, x1], dim=-1)
    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, offset: int = 0):
        # offset is the number of already-cached positions, so cached decoding
        # rotates new tokens with their absolute positions
        cos = self.cos_cache[:, :, offset:offset + seq_len, :].to(q.dtype)
        sin = self.sin_cache[:, :, offset:offset + seq_len, :].to(q.dtype)
        q_rot = (q * cos) + (self.rotate_half(q) * sin)
        k_rot = (k * cos) + (self.rotate_half(k) * sin)
        return q_rot, k_rot
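The useful property of RoPE is that, after rotation, the attention score between a query at position i and a key at position j depends only on the content vectors and the offset j - i, not on the absolute positions. The small numerical check below (a sketch using the class above) places the same query and key vectors at different absolute positions and confirms the scores match:

```python
# Numerical check of RoPE's relative-position property (sketch)
torch.manual_seed(0)
rope = RotaryEmbedding(dim=64, max_seq_len=128)
q = torch.randn(64).repeat(1, 1, 128, 1)  # the same query vector at every position
k = torch.randn(64).repeat(1, 1, 128, 1)  # the same key vector at every position
q_rot, k_rot = rope(q, k, seq_len=128)

s1 = q_rot[0, 0, 10] @ k_rot[0, 0, 13]    # offset +3 at positions (10, 13)
s2 = q_rot[0, 0, 50] @ k_rot[0, 0, 53]    # offset +3 at positions (50, 53)
print(torch.allclose(s1, s2, atol=1e-4))  # True: same offset, same score
```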
4.4 Multi-Head Causal Attention with a KV Cache
class CausalSelfAttention(nn.Module):
"""KVキャッシュをサポートするCausal Self-Attention"""
def __init__(self, config: GPTConfig):
super().__init__()
assert config.n_embd % config.n_head == 0
self.n_head = config.n_head
self.n_embd = config.n_embd
self.head_dim = config.n_embd // config.n_head
        # Combined Q, K, V projection
self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
if config.use_rope:
self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len)
else:
self.rope = None
        # Causal mask (lower-triangular matrix)
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
.view(1, 1, config.max_seq_len, config.max_seq_len)
)
def forward(
self,
x: torch.Tensor,
use_cache: bool = False,
past_kv=None
):
B, T, C = x.shape
qkv = self.qkv_proj(x)
q, k, v = qkv.split(self.n_embd, dim=2)
        # Split into heads: [B, T, C] -> [B, n_head, T, head_dim]
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Length of any cached prefix, so RoPE rotates new tokens at their absolute positions
        past_len = past_kv[0].shape[2] if past_kv is not None else 0
        if self.rope is not None:
            q, k = self.rope(q, k, T, offset=past_len)
        # KV cache handling
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_kv = (k, v) if use_cache else None
        kv_len = k.shape[2]
        # Scaled dot-product attention (uses Flash Attention kernels when available).
        # is_causal=True assumes queries and keys cover the same positions, so fall
        # back to the explicit mask when decoding with a KV cache.
        if hasattr(F, "scaled_dot_product_attention") and past_kv is None:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=True
            )
        else:
            # Manual implementation
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            mask = self.causal_mask[:, :, past_len:past_len + T, :kv_len]
            attn = attn.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v
        # Merge heads: [B, n_head, T, head_dim] -> [B, T, C]
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.resid_dropout(self.out_proj(y))
return y, present_kv
4.5 SwiGLU Feed-Forward Network
class SwiGLUFFN(nn.Module):
"""SwiGLU活性化を持つFFN(Llamaスタイル)"""
def __init__(self, config: GPTConfig):
super().__init__()
        # The 2/3 scaling keeps the parameter count comparable to a 4x GELU FFN,
        # since SwiGLU uses three projection matrices instead of two
        hidden_dim = int(config.n_embd * config.ffn_mult * 2 / 3)
        # Round up to a multiple of 64 for hardware efficiency
        hidden_dim = ((hidden_dim + 63) // 64) * 64
self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU(gate_proj(x)) * up_proj(x)
gate = F.silu(self.gate_proj(x))
up = self.up_proj(x)
return self.dropout(self.down_proj(gate * up))
class GPTFFN(nn.Module):
"""クラシックなGPTスタイルのGELU FFN"""
def __init__(self, config: GPTConfig):
super().__init__()
hidden_dim = config.n_embd * config.ffn_mult
self.fc1 = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
self.fc2 = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.gelu(self.fc1(x))
x = self.dropout(self.fc2(x))
return x
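The 2/3 factor in SwiGLUFFN is what keeps the two variants comparable in size: SwiGLU needs three weight matrices (gate, up, down) where the GELU FFN needs two, so shrinking the hidden dimension by about 2/3 roughly cancels the extra matrix. A quick comparison using the classes above (a sketch):

```python
# Parameter counts of the two FFN variants for the default config (sketch)
cfg = GPTConfig()
n_swiglu = sum(p.numel() for p in SwiGLUFFN(cfg).parameters())
n_gelu = sum(p.numel() for p in GPTFFN(cfg).parameters())
print(f"SwiGLU FFN: {n_swiglu / 1e6:.2f}M, GELU FFN: {n_gelu / 1e6:.2f}M")  # both ~4.7M
```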
4.6 Transformer Block
class TransformerBlock(nn.Module):
    """Pre-RMSNorm Transformer block"""
def __init__(self, config: GPTConfig):
super().__init__()
self.norm1 = RMSNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.norm2 = RMSNorm(config.n_embd)
if config.use_swiglu:
self.ffn = SwiGLUFFN(config)
else:
self.ffn = GPTFFN(config)
def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv=None):
# Pre-Norm + Attention + Residual
attn_out, present_kv = self.attn(
self.norm1(x),
use_cache=use_cache,
past_kv=past_kv
)
x = x + attn_out
# Pre-Norm + FFN + Residual
x = x + self.ffn(self.norm2(x))
return x, present_kv
4.7 The Complete GPT Model
class MiniGPT(nn.Module):
    """The complete miniGPT model"""
def __init__(self, config: GPTConfig):
super().__init__()
self.config = config
        # Embedding layers
self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
if not config.use_rope:
self.pos_emb = nn.Embedding(config.max_seq_len, config.n_embd)
self.emb_dropout = nn.Dropout(config.dropout)
        # Transformer layers
self.layers = nn.ModuleList([
TransformerBlock(config) for _ in range(config.n_layer)
])
        # Final normalization
self.norm_final = RMSNorm(config.n_embd)
        # Language-model head
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying (input embedding == output embedding)
self.lm_head.weight = self.token_emb.weight
        # Initialize weights
self.apply(self._init_weights)
        # Scaled-down init for projections that write into the residual stream (GPT-2-style)
for pn, p in self.named_parameters():
if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
print(f"Parameters: {self.count_params() / 1e6:.1f}M")
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def count_params(self) -> int:
return sum(p.numel() for p in self.parameters())
def forward(
self,
input_ids: torch.Tensor,
targets: torch.Tensor = None,
use_cache: bool = False,
past_kvs: list = None
):
B, T = input_ids.shape
device = input_ids.device
        # Token embeddings
x = self.token_emb(input_ids) # [B, T, n_embd]
        # Positional embeddings (only when RoPE is not used)
if not self.config.use_rope:
positions = torch.arange(T, device=device)
x = x + self.pos_emb(positions)
x = self.emb_dropout(x)
        # Transformer layers
present_kvs = []
for i, layer in enumerate(self.layers):
past_kv = past_kvs[i] if past_kvs is not None else None
x, present_kv = layer(x, use_cache=use_cache, past_kv=past_kv)
present_kvs.append(present_kv)
x = self.norm_final(x)
if targets is not None:
            # Training: compute the loss over the entire sequence
logits = self.lm_head(x)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
                ignore_index=-100  # matches the -100 label masking used later for SFT
)
return logits, loss
else:
            # Inference: only the last token's logits are needed
logits = self.lm_head(x[:, [-1], :])
return logits, present_kvs
@classmethod
def from_pretrained(cls, model_path: str):
checkpoint = torch.load(model_path, map_location="cpu")
config = checkpoint["config"]
model = cls(config)
model.load_state_dict(checkpoint["model"])
return model
5. Implementing Pretraining
5.1 Learning-Rate Scheduler
import math
def get_cosine_schedule_with_warmup(
optimizer,
warmup_steps: int,
total_steps: int,
min_lr_ratio: float = 0.1
):
"""ウォームアップ付きコサインLRスケジュール"""
def lr_lambda(current_step: int):
if current_step < warmup_steps:
return current_step / max(1, warmup_steps)
else:
progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
return max(min_lr_ratio, cosine_decay)
from torch.optim.lr_scheduler import LambdaLR
return LambdaLR(optimizer, lr_lambda)
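To see the warmup and decay in action, the short check below prints the learning rate at a few points of the schedule (a sketch; the dummy linear layer and the 100/1000 step counts are placeholder values, not settings used elsewhere in this guide):

```python
# Quick inspection of the schedule shape (sketch with placeholder values)
import torch
from torch.optim import AdamW

dummy = torch.nn.Linear(1, 1)
opt = AdamW(dummy.parameters(), lr=6e-4)
sched = get_cosine_schedule_with_warmup(opt, warmup_steps=100, total_steps=1000)

for step in range(1000):
    opt.step()
    sched.step()
    if step in (0, 50, 99, 500, 999):
        print(f"step {step:4d}: lr = {opt.param_groups[0]['lr']:.6f}")
# Rises linearly to 6e-4 over the first 100 steps, then follows a cosine down to the 10% floor (6e-5)
```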
5.2 Training Loop
import torch
import torch.nn as nn
from torch.optim import AdamW
import os
import time
class Trainer:
"""miniGPT事前学習トレーナー"""
def __init__(self, model: MiniGPT, train_dataloader, val_dataloader, config: dict):
self.model = model
self.train_dl = train_dataloader
self.val_dl = val_dataloader
self.config = config
self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Mixed precision
self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == torch.float16))
self.ctx = torch.amp.autocast(device_type="cuda", dtype=self.dtype)
self.optimizer = self._create_optimizer()
total_steps = len(train_dataloader) * config["num_epochs"]
warmup_steps = int(total_steps * config.get("warmup_ratio", 0.05))
self.scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
warmup_steps=warmup_steps,
total_steps=total_steps
)
self.step = 0
self.best_val_loss = float("inf")
def _create_optimizer(self):
"""weight decayあり/なしのパラメータを分離"""
decay_params = []
no_decay_params = []
for name, param in self.model.named_parameters():
if not param.requires_grad:
continue
if param.dim() == 1 or "emb" in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optim_groups = [
{"params": decay_params, "weight_decay": self.config.get("weight_decay", 0.1)},
{"params": no_decay_params, "weight_decay": 0.0}
]
return AdamW(
optim_groups,
lr=self.config["learning_rate"],
betas=(0.9, 0.95),
fused=True
)
def compute_val_loss(self) -> float:
self.model.eval()
total_loss = 0.0
num_batches = min(len(self.val_dl), 50)
with torch.no_grad():
for i, (x, y) in enumerate(self.val_dl):
if i >= num_batches:
break
x, y = x.to(self.device), y.to(self.device)
with self.ctx:
_, loss = self.model(x, y)
total_loss += loss.item()
self.model.train()
return total_loss / num_batches
def save_checkpoint(self, path: str):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({
"step": self.step,
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"config": self.model.config,
"best_val_loss": self.best_val_loss
}, path)
print(f"Checkpoint saved: {path}")
def train(self):
self.model.to(self.device)
self.model.train()
num_epochs = self.config["num_epochs"]
grad_accum = self.config.get("gradient_accumulation_steps", 1)
max_grad_norm = self.config.get("max_grad_norm", 1.0)
log_interval = self.config.get("log_interval", 100)
eval_interval = self.config.get("eval_interval", 500)
save_interval = self.config.get("save_interval", 1000)
print(f"Training on {self.device}, dtype={self.dtype}")
for epoch in range(num_epochs):
t0 = time.time()
for batch_idx, (x, y) in enumerate(self.train_dl):
x, y = x.to(self.device), y.to(self.device)
with self.ctx:
_, loss = self.model(x, y)
loss = loss / grad_accum
self.scaler.scale(loss).backward()
if (batch_idx + 1) % grad_accum == 0:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.scheduler.step()
self.optimizer.zero_grad(set_to_none=True)
self.step += 1
if self.step % log_interval == 0:
lr = self.optimizer.param_groups[0]["lr"]
elapsed = time.time() - t0
                        tokens_per_sec = log_interval * grad_accum * x.shape[0] * x.shape[1] / elapsed  # grad_accum batches per optimizer step
print(
f"Epoch {epoch} | Step {self.step} | "
f"Loss: {loss.item() * grad_accum:.4f} | "
f"LR: {lr:.6f} | "
f"Tokens/s: {tokens_per_sec:.0f}"
)
t0 = time.time()
if self.step % eval_interval == 0:
val_loss = self.compute_val_loss()
print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.save_checkpoint("./checkpoints/best_model.pt")
if self.step % save_interval == 0:
self.save_checkpoint(f"./checkpoints/step_{self.step}.pt")
6. Text Generation
6.1 Decoding Strategies
import torch
import torch.nn.functional as F
class TextGenerator:
"""miniGPTテキストジェネレーター"""
def __init__(self, model: MiniGPT, tokenizer: GPTTokenizer, device: str = "cuda"):
self.model = model.to(device)
self.tokenizer = tokenizer
self.device = device
self.model.eval()
@torch.no_grad()
def generate(
self,
prompt: str,
max_new_tokens: int = 200,
strategy: str = "top_p",
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 0.9,
num_beams: int = 4
) -> str:
input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0)
if strategy == "greedy":
generated = self._greedy_decode(input_ids, max_new_tokens)
elif strategy == "temperature":
generated = self._temperature_sampling(input_ids, max_new_tokens, temperature)
elif strategy == "top_k":
generated = self._top_k_sampling(input_ids, max_new_tokens, temperature, top_k)
elif strategy == "top_p":
generated = self._top_p_sampling(input_ids, max_new_tokens, temperature, top_p)
elif strategy == "beam":
generated = self._beam_search(input_ids, max_new_tokens, num_beams)
else:
raise ValueError(f"Unknown strategy: {strategy}")
new_tokens = generated[0][input_ids.shape[1]:].tolist()
return self.tokenizer.decode(new_tokens)
def _greedy_decode(self, input_ids, max_new_tokens):
"""グリーディデコード: 常に最高確率のトークンを選択"""
current_ids = input_ids.clone()
for _ in range(max_new_tokens):
logits, _ = self.model(current_ids)
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
current_ids = torch.cat([current_ids, next_token], dim=1)
if next_token.item() == self.tokenizer.eot_token:
break
return current_ids
def _temperature_sampling(self, input_ids, max_new_tokens, temperature):
"""温度サンプリング: softmax前にlogitsをスケール"""
current_ids = input_ids.clone()
for _ in range(max_new_tokens):
logits, _ = self.model(current_ids)
logits = logits[:, -1, :] / temperature
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
current_ids = torch.cat([current_ids, next_token], dim=1)
if next_token.item() == self.tokenizer.eot_token:
break
return current_ids
def _top_k_sampling(self, input_ids, max_new_tokens, temperature, top_k):
"""Top-kサンプリング: 上位kトークンのみからサンプリング"""
current_ids = input_ids.clone()
for _ in range(max_new_tokens):
logits, _ = self.model(current_ids)
logits = logits[:, -1, :] / temperature
top_k_logits, top_k_indices = torch.topk(logits, k=top_k, dim=-1)
filtered_logits = torch.full_like(logits, float("-inf"))
filtered_logits.scatter_(1, top_k_indices, top_k_logits)
probs = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
current_ids = torch.cat([current_ids, next_token], dim=1)
if next_token.item() == self.tokenizer.eot_token:
break
return current_ids
def _top_p_sampling(self, input_ids, max_new_tokens, temperature, top_p):
"""Top-p(Nucleus)サンプリング: 確率質量pをカバーする上位トークンからサンプリング"""
current_ids = input_ids.clone()
for _ in range(max_new_tokens):
logits, _ = self.model(current_ids)
logits = logits[:, -1, :] / temperature
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Keep the smallest set of tokens whose cumulative probability exceeds top_p
            # (shifted so the token that crosses the threshold is still kept)
            sorted_indices_to_remove = cumulative_probs - F.softmax(sorted_logits, dim=-1) > top_p
sorted_logits[sorted_indices_to_remove] = float("-inf")
logits = torch.zeros_like(logits).scatter_(1, sorted_indices, sorted_logits)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
current_ids = torch.cat([current_ids, next_token], dim=1)
if next_token.item() == self.tokenizer.eot_token:
break
return current_ids
def _beam_search(self, input_ids, max_new_tokens, num_beams):
"""ビームサーチ: 複数の候補を同時に探索"""
B = input_ids.shape[0]
assert B == 1, "Beam search only supports batch size 1"
beams = [(input_ids, 0.0)]
for _ in range(max_new_tokens):
all_candidates = []
for beam_ids, beam_score in beams:
logits, _ = self.model(beam_ids)
log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
top_log_probs, top_ids = log_probs.topk(num_beams)
for i in range(num_beams):
token = top_ids[0, i].unsqueeze(0).unsqueeze(0)
new_ids = torch.cat([beam_ids, token], dim=1)
new_score = beam_score + top_log_probs[0, i].item()
all_candidates.append((new_ids, new_score))
all_candidates.sort(key=lambda x: x[1] / x[0].shape[1], reverse=True)
beams = all_candidates[:num_beams]
if beams[0][0][0, -1].item() == self.tokenizer.eot_token:
break
return beams[0][0]
7. Perplexity Evaluation
import math
import torch
@torch.no_grad()
def compute_perplexity(
model: MiniGPT,
tokenizer: GPTTokenizer,
text: str,
device: str = "cuda",
seq_len: int = 512
) -> float:
"""テキストのパープレキシティを計算"""
model.eval()
model.to(device)
tokens = tokenizer.encode(text, add_special_tokens=False)
total_loss = 0.0
total_tokens = 0
    # Walk over the text in seq_len-sized windows; texts shorter than seq_len are scored as one chunk
    for i in range(0, max(1, len(tokens) - 1), seq_len):
        chunk = tokens[i:i + seq_len + 1]
        if len(chunk) < 2:
            break
        x = torch.tensor(chunk[:-1], dtype=torch.long, device=device).unsqueeze(0)
        y = torch.tensor(chunk[1:], dtype=torch.long, device=device).unsqueeze(0)
        _, loss = model(x, y)
        n = len(chunk) - 1
        total_loss += loss.item() * n
        total_tokens += n
    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)
# Usage example
model = MiniGPT.from_pretrained("./checkpoints/best_model.pt")
tokenizer = GPTTokenizer()
generator = TextGenerator(model, tokenizer)
test_text = "The quick brown fox jumps over the lazy dog."
ppl = compute_perplexity(model, tokenizer, test_text)
print(f"Perplexity: {ppl:.2f}")
for strategy in ["greedy", "top_k", "top_p"]:
generated = generator.generate(
prompt="Once upon a time",
max_new_tokens=100,
strategy=strategy,
temperature=0.8
)
print(f"\n[{strategy}]: {generated}")
8. Scaling Up
8.1 Flash Attention 2
# pip install flash-attn
from flash_attn import flash_attn_qkvpacked_func
class FlashCausalAttention(nn.Module):
"""Flash Attention 2を使用した高速アテンション"""
def __init__(self, config: GPTConfig):
super().__init__()
self.n_head = config.n_head
self.n_embd = config.n_embd
self.head_dim = config.n_embd // config.n_head
self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.dropout = config.dropout
def forward(self, x: torch.Tensor):
B, T, C = x.shape
qkv = self.qkv_proj(x)
        # Flash Attention expects shape [B, T, 3, n_head, head_dim]
qkv = qkv.view(B, T, 3, self.n_head, self.head_dim)
attn_out = flash_attn_qkvpacked_func(
qkv,
dropout_p=self.dropout if self.training else 0.0,
causal=True
) # [B, T, n_head, head_dim]
attn_out = attn_out.reshape(B, T, C)
return self.out_proj(attn_out)
8.2 Distributed Training with FSDP
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools
def setup_fsdp_training():
"""FSDP分散トレーニングのセットアップ"""
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
config = GPTConfig(
n_embd=1024,
n_layer=24,
n_head=16,
vocab_size=50257
)
model = MiniGPT(config)
    # Wrap policy: shard at TransformerBlock boundaries
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerBlock}
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=torch.distributed.fsdp.MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16
),
        sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,  # equivalent to ZeRO-3
device_id=rank
)
return model, rank, world_size
8.3 Chinchilla Scaling Laws
The Chinchilla paper (Hoffmann et al., 2022) gives the following rule of thumb for the compute-optimal relationship between model size and training tokens:
Optimal training tokens ≈ 20 × number of parameters
def chinchilla_optimal_tokens(model_params: int) -> int:
"""Chinchillaスケーリング則に基づいた最適トレーニングトークン数を計算"""
return 20 * model_params
models = {
"124M (GPT-2 Small)": 124e6,
"1.3B": 1.3e9,
"7B (Llama-2 7B)": 7e9,
"13B": 13e9,
"70B (Llama-2 70B)": 70e9,
}
print("モデルサイズ別の推奨トレーニングトークン数:")
for name, params in models.items():
optimal_tokens = chinchilla_optimal_tokens(int(params))
print(f" {name}: {optimal_tokens/1e9:.1f}B tokens")
9. Instruction Tuning
9.1 SFT Data Formats
# ChatML format
CHATML_SYSTEM = "<|im_start|>system\n{system}<|im_end|>\n"
CHATML_USER = "<|im_start|>user\n{user}<|im_end|>\n"
CHATML_ASSISTANT = "<|im_start|>assistant\n{assistant}<|im_end|>\n"
def format_chatml(system: str, user: str, assistant: str = "", add_generation_prompt: bool = False) -> str:
    formatted = CHATML_SYSTEM.format(system=system)
    formatted += CHATML_USER.format(user=user)
    if add_generation_prompt:
        # Prompt-only form: open the assistant turn without closing it,
        # so the prompt is an exact prefix of the full formatted conversation
        formatted += "<|im_start|>assistant\n"
    else:
        formatted += CHATML_ASSISTANT.format(assistant=assistant)
    return formatted
# Llama-3 format
def format_llama3(system: str, user: str, assistant: str = "", add_generation_prompt: bool = False) -> str:
    formatted = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>\n"
    formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>\n"
    if add_generation_prompt:
        # Prompt-only form: end with an open assistant header
        formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    else:
        formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>"
    return formatted
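A quick look at what the ChatML template produces, and at the property the SFT dataset in the next section relies on: the prompt-only form is an exact prefix of the full formatted conversation (a sketch; the question and answer strings are made up for illustration):

```python
# Example ChatML output (illustrative strings only)
full = format_chatml(
    system="You are a helpful assistant.",
    user="What is the capital of France?",
    assistant="The capital of France is Paris."
)
prompt = format_chatml(
    system="You are a helpful assistant.",
    user="What is the capital of France?",
    add_generation_prompt=True
)
print(full)
print(full.startswith(prompt))  # True: the prompt is a prefix of the full conversation
```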
9.2 Instruction Dataset Class
class InstructionDataset(Dataset):
    """SFT dataset: the loss is computed only on assistant-response tokens"""
def __init__(
self,
data: list,
tokenizer: GPTTokenizer,
max_seq_len: int = 2048,
format_func=format_chatml
):
self.samples = []
for item in data:
            # Tokenize the full conversation
full_text = format_func(
system=item.get("system", "You are a helpful assistant."),
user=item["instruction"],
assistant=item["response"]
)
full_ids = tokenizer.encode(full_text, add_special_tokens=False)
            # Tokenize only the prompt (to find where the response starts)
prompt_text = format_func(
system=item.get("system", "You are a helpful assistant."),
user=item["instruction"],
assistant="",
add_generation_prompt=True
)
prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
            if len(full_ids) > max_seq_len:
                full_ids = full_ids[:max_seq_len]
            # Mask prompt tokens in the labels (-100 = ignored by the loss);
            # clamp prompt_len in case truncation cut into the prompt
            prompt_len = min(len(prompt_ids), len(full_ids))
            labels = full_ids.copy()
            labels[:prompt_len] = [-100] * prompt_len
self.samples.append({
"input_ids": full_ids,
"labels": labels
})
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample = self.samples[idx]
return (
torch.tensor(sample["input_ids"], dtype=torch.long),
torch.tensor(sample["labels"], dtype=torch.long)
)
9.3 SFT Training Loop
def train_sft(
base_model_path: str,
data_path: str,
output_dir: str,
num_epochs: int = 3,
learning_rate: float = 2e-5,
batch_size: int = 4
):
"""教師ありファインチューニング(SFT)"""
import json
with open(data_path) as f:
data = json.load(f)
tokenizer = GPTTokenizer()
dataset = InstructionDataset(data, tokenizer, max_seq_len=2048)
def collate_fn(batch):
input_ids = [item[0] for item in batch]
labels = [item[1] for item in batch]
max_len = max(len(x) for x in input_ids)
padded_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
padded_labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
for i, (ids, lbls) in enumerate(zip(input_ids, labels)):
padded_ids[i, :len(ids)] = ids
padded_labels[i, :len(lbls)] = lbls
return padded_ids, padded_labels
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
model = MiniGPT.from_pretrained(base_model_path)
device = "cuda"
model.to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
total_steps = len(dataloader) * num_epochs
scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * 0.1), total_steps)
model.train()
for epoch in range(num_epochs):
for step, (x, y) in enumerate(dataloader):
x, y = x.to(device), y.to(device)
logits, loss = model(x, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % 50 == 0:
print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
os.makedirs(output_dir, exist_ok=True)
torch.save({
"model": model.state_dict(),
"config": model.config
}, f"{output_dir}/sft_model.pt")
print(f"SFT complete. Model saved: {output_dir}")
10. Analyzing Open-Source LLM Architectures
10.1 The Llama 3 Architecture
Llama 3 is Meta's open-source LLM family, with the following key properties:
| Property | Detail |
|---|---|
| Architecture | Decoder-only Transformer |
| Normalization | Pre-RMSNorm |
| Activation | SwiGLU |
| Positional encoding | RoPE (theta = 500,000) |
| Vocabulary size | 128,256 |
| Context length | 128K (Llama 3.1+) |
| Attention | GQA (8B and 70B models) |
# Llama 3-style configuration (8B model)
llama3_config = GPTConfig(
vocab_size=128256,
max_seq_len=8192,
n_embd=4096,
n_layer=32,
n_head=32,
use_rope=True,
use_swiglu=True,
bias=False
)
10.2 Grouped-Query Attention (GQA)
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (used in Llama 3 and Mistral)"""
def __init__(self, config: GPTConfig, n_kv_heads: int = 8):
super().__init__()
self.n_head = config.n_head
self.n_kv_heads = n_kv_heads
self.n_rep = config.n_head // n_kv_heads
self.head_dim = config.n_embd // config.n_head
self.q_proj = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len, base=500000)
def forward(self, x: torch.Tensor):
B, T, C = x.shape
q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
q, k = self.rope(q, k, T)
        # Repeat the KV heads to match the number of query heads (the core of GQA)
k = k.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
k = k.reshape(B, self.n_head, T, self.head_dim)
v = v.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
v = v.reshape(B, self.n_head, T, self.head_dim)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
y = y.transpose(1, 2).contiguous().view(B, T, C)
return self.out_proj(y)
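The main payoff of GQA is KV-cache memory at inference time: the cache scales with the number of KV heads, so 8 KV heads instead of 32 query heads shrinks it by 4x. A rough estimate (a sketch; the layer count and head dimension below are illustrative 8B-scale assumptions, not measured values):

```python
# KV cache size = 2 (K and V) * n_layer * n_kv_heads * head_dim * seq_len * batch * bytes
def kv_cache_bytes(n_layer, n_kv_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    return 2 * n_layer * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

mha = kv_cache_bytes(n_layer=32, n_kv_heads=32, head_dim=128, seq_len=8192)  # full multi-head
gqa = kv_cache_bytes(n_layer=32, n_kv_heads=8, head_dim=128, seq_len=8192)   # grouped-query
print(f"MHA cache: {mha / 1e9:.1f} GB, GQA cache: {gqa / 1e9:.1f} GB")        # ~4.3 GB vs ~1.1 GB
```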
10.3 Mixture of Experts (MoE)
DeepSeek and Mixtral use Mixture of Experts to increase model capacity without a proportional increase in compute: a router sends each token to only a small number of expert FFNs, so only a fraction of the parameters is active for any given token.
class MoEFFN(nn.Module):
"""Mixture of Experts FFN(簡略版)"""
def __init__(self, config: GPTConfig, num_experts: int = 8, top_k: int = 2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
        # Router
self.router = nn.Linear(config.n_embd, num_experts, bias=False)
        # Expert FFNs
self.experts = nn.ModuleList([
SwiGLUFFN(config) for _ in range(num_experts)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.shape
x_flat = x.view(B * T, C)
        # Compute routing scores
router_logits = self.router(x_flat)
router_probs = F.softmax(router_logits, dim=-1)
        # Select the top-k experts per token and renormalize their weights
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # Compute the expert outputs and combine them with the routing weights
output = torch.zeros_like(x_flat)
for k in range(self.top_k):
expert_idx = top_k_indices[:, k]
expert_prob = top_k_probs[:, k].unsqueeze(-1)
for e in range(self.num_experts):
mask = (expert_idx == e)
if mask.any():
expert_out = self.experts[e](x_flat[mask])
output[mask] += expert_prob[mask] * expert_out
return output.view(B, T, C)
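With 8 experts and top-2 routing, total capacity and per-token compute decouple: every token only passes through 2 of the 8 expert FFNs. A quick count with the classes above (a sketch):

```python
# Total vs. per-token-active parameters of the MoE layer above (sketch, default config)
cfg = GPTConfig()
moe = MoEFFN(cfg, num_experts=8, top_k=2)

total = sum(p.numel() for p in moe.parameters())
per_expert = sum(p.numel() for p in moe.experts[0].parameters())
active = sum(p.numel() for p in moe.router.parameters()) + moe.top_k * per_expert
print(f"Total: {total / 1e6:.1f}M, active per token: {active / 1e6:.1f}M")  # ~37.8M vs ~9.4M
```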
Putting It All Together
Here is a complete script that ties all of the components together:
import torch
from dataclasses import dataclass
# 1. Configuration
config = GPTConfig(
vocab_size=50257,
max_seq_len=1024,
n_embd=768,
n_layer=12,
n_head=12,
dropout=0.1,
use_rope=True,
use_swiglu=True
)
# 2. Model
model = MiniGPT(config)
print(f"Model parameters: {model.count_params() / 1e6:.1f}M")
# 3. Tokenizer
tokenizer = GPTTokenizer()
# 4. Data
train_loader = create_dataloader("train.txt", tokenizer, batch_size=8, seq_len=1024)
val_loader = create_dataloader("val.txt", tokenizer, batch_size=8, seq_len=1024, shuffle=False)
# 5. Trainer
trainer_config = {
"num_epochs": 5,
"learning_rate": 6e-4,
"weight_decay": 0.1,
"gradient_accumulation_steps": 4,
"max_grad_norm": 1.0,
"warmup_ratio": 0.05,
"log_interval": 100,
"eval_interval": 500,
"save_interval": 1000,
"device": "cuda"
}
trainer = Trainer(model, train_loader, val_loader, trainer_config)
trainer.train()
# 6. Generation
generator = TextGenerator(model, tokenizer)
prompts = [
"The history of artificial intelligence",
"In the year 2050, scientists discovered",
"The best way to learn programming is",
]
for prompt in prompts:
print(f"\nPrompt: {prompt}")
for strategy in ["greedy", "top_p"]:
text = generator.generate(prompt, max_new_tokens=100, strategy=strategy)
print(f" [{strategy}]: {text}")
Summary
In this guide we built miniGPT from scratch and covered every core component of a modern LLM:
- Tokenizer: the BPE algorithm splits text into subword tokens
- Embeddings: token embeddings + RoPE positional encoding
- Multi-head causal attention: Q/K/V projections, causal mask, KV cache
- FFN: SwiGLU activation for the non-linearity
- Pre-RMSNorm: training stability
- Text generation: greedy, temperature, top-k, top-p, beam search
- Pretraining: cosine LR schedule, gradient clipping, mixed precision
- SFT: loss computed only on assistant-response tokens
- Scaling: Flash Attention, FSDP, the Chinchilla rule
With this foundation you have the tools to understand and extend any modern open-source LLM. The concepts covered here map directly onto the implementation details of Llama 3, Mistral, and Qwen.
References
- Karpathy's nanoGPT: https://github.com/karpathy/nanoGPT
- Chinchilla Scaling Laws (Hoffmann et al., 2022): https://arxiv.org/abs/2203.15556
- Llama 3 Technical Report: https://ai.meta.com/blog/meta-llama-3
- Flash Attention: https://github.com/Dao-AILab/flash-attention
- OpenAI tiktoken: https://github.com/openai/tiktoken
- Attention is All You Need (Vaswani et al., 2017)
- RoFormer: Enhanced Transformer with Rotary Position Embedding (Su et al., 2021)