- Authors

- Name
- Youngju Kim
- @fjvbn20031
- はじめに
- 1. コンパイラの基礎: 全体パイプライン
- 2. 中間表現 (IR)
- 3. 最適化パス
- 4. コード生成: x86-64アセンブリ
- 5. JITコンパイル: 実行時最適化
- 6. AI/MLコンパイラのエコシステム
- 7. MLIR: 次世代コンパイラインフラ
- 8. クイズ: 理解度確認
- まとめ
はじめに
コンパイラを理解する理由は、プログラミング言語を作るためだけではありません。今日のAI/ML分野ではtorch.compile()、TVM、XLA、MLIRがコアインフラとして定着しており、それらの内部動作はすべてコンパイラ理論の上に成り立っています。
この記事ではコンパイラ設計のパイプライン全体を最初から最後まで解説し、ML専用コンパイラの最前線まで辿り着きます。
1. コンパイラの基礎: 全体パイプライン
コンパイラはソースコードを実行可能なコードへ変換するプログラムです。
ソースコード
↓ 字句解析 (Lexing)
トークン列
↓ 構文解析 (Parsing)
AST (抽象構文木)
↓ 意味解析 (Semantic Analysis)
注釈付きAST
↓ IR生成
中間表現 (IR)
↓ 最適化 (Optimization)
最適化されたIR
↓ コード生成 (Code Generation)
機械語 / アセンブリ
1.1 字句解析 (Lexical Analysis)
字句解析器(Lexer / Tokenizer)はソースコードの文字列をトークンのストリームに変換します。トークンは数値リテラル、識別子、キーワード、演算子などの意味単位です。
import re
from enum import Enum, auto
from dataclasses import dataclass
from typing import List, Optional
class TokenType(Enum):
NUMBER = auto()
IDENT = auto()
PLUS = auto()
MINUS = auto()
STAR = auto()
SLASH = auto()
LPAREN = auto()
RPAREN = auto()
ASSIGN = auto()
SEMICOL = auto()
LET = auto()
EOF = auto()
@dataclass
class Token:
type: TokenType
value: str
line: int
KEYWORDS = {'let': TokenType.LET}
TOKEN_PATTERNS = [
(r'\d+(\.\d+)?', TokenType.NUMBER),
(r'[a-zA-Z_]\w*', TokenType.IDENT),
(r'\+', TokenType.PLUS),
(r'-', TokenType.MINUS),
(r'\*', TokenType.STAR),
(r'/', TokenType.SLASH),
(r'\(', TokenType.LPAREN),
(r'\)', TokenType.RPAREN),
(r'=', TokenType.ASSIGN),
(r';', TokenType.SEMICOL),
]
def tokenize(source: str) -> List[Token]:
tokens = []
line = 1
pos = 0
while pos < len(source):
if source[pos] in ' \t\r':
pos += 1
continue
if source[pos] == '\n':
line += 1
pos += 1
continue
matched = False
for pattern, ttype in TOKEN_PATTERNS:
m = re.match(pattern, source[pos:])
if m:
value = m.group(0)
if ttype == TokenType.IDENT and value in KEYWORDS:
ttype = KEYWORDS[value]
tokens.append(Token(ttype, value, line))
pos += len(value)
matched = True
break
if not matched:
raise SyntaxError(f"予期しない文字 '{source[pos]}' (行 {line})")
tokens.append(Token(TokenType.EOF, '', line))
return tokens
# 使用例
src = "let x = 10 + 20 * 3;"
for tok in tokenize(src):
print(tok)
1.2 構文解析とAST
パーサはトークン列を受け取り**抽象構文木(AST)**を構築します。
from dataclasses import dataclass
from typing import Union
# ASTノード定義
@dataclass
class NumberLit:
value: float
@dataclass
class Identifier:
name: str
@dataclass
class BinOp:
op: str
left: 'Expr'
right: 'Expr'
@dataclass
class LetStmt:
name: str
value: 'Expr'
Expr = Union[NumberLit, Identifier, BinOp]
# 再帰下降パーサ
class Parser:
def __init__(self, tokens: List[Token]):
self.tokens = tokens
self.pos = 0
def peek(self) -> Token:
return self.tokens[self.pos]
def consume(self, expected: Optional[TokenType] = None) -> Token:
tok = self.tokens[self.pos]
if expected and tok.type != expected:
raise SyntaxError(f"期待: {expected}, 実際: {tok.type} (行 {tok.line})")
self.pos += 1
return tok
def parse_let(self) -> LetStmt:
self.consume(TokenType.LET)
name = self.consume(TokenType.IDENT).value
self.consume(TokenType.ASSIGN)
expr = self.parse_expr()
self.consume(TokenType.SEMICOL)
return LetStmt(name=name, value=expr)
def parse_expr(self) -> Expr:
return self.parse_additive()
def parse_additive(self) -> Expr:
left = self.parse_multiplicative()
while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
op = self.consume().value
right = self.parse_multiplicative()
left = BinOp(op=op, left=left, right=right)
return left
def parse_multiplicative(self) -> Expr:
left = self.parse_primary()
while self.peek().type in (TokenType.STAR, TokenType.SLASH):
op = self.consume().value
right = self.parse_primary()
left = BinOp(op=op, left=left, right=right)
return left
def parse_primary(self) -> Expr:
tok = self.peek()
if tok.type == TokenType.NUMBER:
self.consume()
return NumberLit(value=float(tok.value))
if tok.type == TokenType.IDENT:
self.consume()
return Identifier(name=tok.value)
if tok.type == TokenType.LPAREN:
self.consume()
expr = self.parse_expr()
self.consume(TokenType.RPAREN)
return expr
raise SyntaxError(f"予期しないトークン: {tok}")
主要なパース戦略の比較:
| 手法 | 種別 | 用途 |
|---|---|---|
| 再帰下降 | LL(k) | 手書き、読みやすい |
| Pratt Parser | トップダウン演算子優先 | 式の解析に優秀 |
| LALR(1) | ボトムアップ | Bison、GCC、Python |
| PEG / Packrat | 解析表現文法 | 曖昧さなし、メモ化 |
2. 中間表現 (IR)
ASTから直接機械語へ変換するのは非効率です。IRを介することで言語に依存しない最適化が可能になります。
2.1 3アドレスコード
最も単純なフラットなIR。すべての演算は最大3つのオペランドを持ちます。
# ソース: x = (a + b) * (c - d)
t1 = a + b
t2 = c - d
x = t1 * t2
2.2 SSA形式 (静的単一代入)
SSAでは各変数が正確に1回だけ代入されます。これによりデータフロー分析が劇的に単純化されます。
# 通常コード
x = 1
x = x + 2 # xが2回代入される
# SSA変換後
x_1 = 1
x_2 = x_1 + 2
制御フローの合流点ではPhi関数を使います。
if cond:
x_1 = 10
else:
x_2 = 20
# 合流点
x_3 = phi(x_1, x_2)
2.3 LLVM IR
LLVM IRはClang、Rust、Swiftなどが採用する産業標準のIRです。
; 関数定義: int add(int a, int b)
define i32 @add(i32 %a, i32 %b) {
entry:
%result = add i32 %a, %b
ret i32 %result
}
; 条件分岐 + SSA phi関数
define i32 @max(i32 %a, i32 %b) {
entry:
%cmp = icmp sgt i32 %a, %b ; a > b ?
br i1 %cmp, label %then, label %else
then:
br label %merge
else:
br label %merge
merge:
%result = phi i32 [ %a, %then ], [ %b, %else ]
ret i32 %result
}
LLVM IRの特徴:
- 強い型システム:
i32、i64、float、ptrなどすべて明示的 - SSA基盤: すべての値は1回だけ定義
- 仮想レジスタ無制限:
%0、%1、%2... 制限なし - 言語独立性: C、Rust、Swift、Haskellがすべて同じIRに変換
3. 最適化パス
3.1 定数畳み込み (Constant Folding)
コンパイル時に計算可能な定数式を事前計算します。
# 最適化前
x = 2 * 3 + 4
# 最適化後 (コンパイル時に計算済み)
x = 10
def constant_fold(node):
"""簡単な定数畳み込みの実装"""
if isinstance(node, BinOp):
left = constant_fold(node.left)
right = constant_fold(node.right)
if isinstance(left, NumberLit) and isinstance(right, NumberLit):
ops = {
'+': lambda a, b: a + b,
'-': lambda a, b: a - b,
'*': lambda a, b: a * b,
'/': lambda a, b: a / b,
}
return NumberLit(value=ops[node.op](left.value, right.value))
return BinOp(op=node.op, left=left, right=right)
return node
3.2 デッドコード除去 (Dead Code Elimination)
結果が使われない計算を除去します。
def dead_code_eliminate(stmts, used_vars):
"""使われない変数への代入を削除する後方パス"""
live = set()
for stmt in reversed(stmts):
if isinstance(stmt, LetStmt):
if stmt.name in used_vars or stmt.name in live:
live.update(collect_vars(stmt.value))
return [s for s in stmts
if not isinstance(s, LetStmt) or s.name in live]
3.3 ループアンローリング (Loop Unrolling)
ループオーバーヘッドを削減し、命令レベル並列性(ILP)を高めます。
// 元のコード (ループ4回)
for (int i = 0; i < 4; i++) {
a[i] = b[i] + c[i];
}
// アンローリング後 (ループ消去)
a[0] = b[0] + c[0];
a[1] = b[1] + c[1];
a[2] = b[2] + c[2];
a[3] = b[3] + c[3];
3.4 関数インライン化 (Function Inlining)
関数呼び出しのオーバーヘッドをなくし、追加の最適化機会を生み出します。
// インライン前
inline int square(int x) { return x * x; }
int y = square(5);
// インライン後 (定数畳み込みで25に)
int y = 25;
4. コード生成: x86-64アセンブリ
4.1 レジスタ割り付け
x86-64には汎用レジスタが16個(rax、rbx、rcx、rdx、rsi、rdi、rsp、rbp、r8〜r15)あります。無制限の仮想レジスタをこの有限セットにマッピングする必要があります。
グラフ彩色法が最も有名なアルゴリズムです。同時に生存している仮想レジスタ間に辺を張り、N色でグラフを彩色します(N=物理レジスタ数)。彩色できない場合はスピル(スタックへの退避)を行います。
4.2 x86-64アセンブリの例
; C関数: int add(int a, int b) { return a + b; }
; Linux System V ABI: 引数はrdi, rsi順; 返り値はrax
add:
push rbp
mov rbp, rsp
mov DWORD PTR [rbp-4], edi
mov DWORD PTR [rbp-8], esi
mov edx, DWORD PTR [rbp-4]
mov eax, DWORD PTR [rbp-8]
add eax, edx
pop rbp
ret
; -O2最適化版 (1命令のみ!)
add:
lea eax, [rdi + rsi]
ret
5. JITコンパイル: 実行時最適化
**JIT(Just-In-Time)**コンパイルはプログラム実行中にコードをコンパイルします。インタープリタ言語とネイティブコードのパフォーマンスギャップを埋める核心技術です。
5.1 Numba JIT
NumbaはPython/NumPyコードをLLVMを通じてネイティブ機械語にコンパイルします。
import numpy as np
from numba import njit, prange
import time
# 純粋Pythonベースライン
def matmul_python(A, B):
N = A.shape[0]
C = np.zeros((N, N))
for i in range(N):
for j in range(N):
for k in range(N):
C[i, j] += A[i, k] * B[k, j]
return C
# Numba JIT版 (並列化あり)
@njit(parallel=True)
def matmul_numba(A, B):
N = A.shape[0]
C = np.zeros((N, N))
for i in prange(N): # 並列ループ
for j in range(N):
for k in range(N):
C[i, j] += A[i, k] * B[k, j]
return C
N = 256
A = np.random.rand(N, N).astype(np.float32)
B = np.random.rand(N, N).astype(np.float32)
# 初回呼び出しでコンパイル (ウォームアップ)
_ = matmul_numba(A, B)
t0 = time.perf_counter()
C_py = matmul_python(A, B)
print(f"Python: {time.perf_counter()-t0:.3f}s")
t0 = time.perf_counter()
C_nb = matmul_numba(A, B)
print(f"Numba JIT: {time.perf_counter()-t0:.4f}s")
# 典型的な結果: Python ~30s vs Numba ~0.01s (数百倍高速)
Numbaの内部動作:
- Pythonバイトコードの読み取り
- 型推論により具体的な型を確定
- LLVM IRへのローワリング
- LLVMバックエンドが機械語にコンパイル
- コンパイル済みアーティファクトをキャッシュ
5.2 PyPyとトレーシングJIT
PyPyのトレーシングJITは「ホットパス」(頻繁に実行されるトレース)を記録し、それを直接コンパイルします。コール境界を越えてインライン化できる点が特徴です。
6. AI/MLコンパイラのエコシステム
6.1 torch.compile()の内部動作
PyTorch 2.0の核心機能torch.compile()は3段階のコンパイルパイプラインで構成されています。
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(128, 256)
self.fc2 = nn.Linear(256, 10)
self.relu = nn.ReLU()
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
model = SimpleNet().cuda()
x = torch.randn(32, 128).cuda()
# torch.compile適用
compiled = torch.compile(model, mode='max-autotune')
# 初回呼び出し: コンパイル発生 (数秒かかる)
out = compiled(x)
# 以降の呼び出し: コンパイル済みカーネルを再利用
for _ in range(100):
out = compiled(x)
第1段階 — TorchDynamo (グラフキャプチャ)
DynamoはPythonバイトコードを傍受し、PyTorch演算をtorch.fx.Graphとして捕捉します。
import torch._dynamo as dynamo
def my_func(x, y):
return torch.sin(x) + torch.cos(y)
# FXグラフの確認
explanation = dynamo.explain(my_func)(
torch.randn(4), torch.randn(4)
)
print(explanation.graphs[0].print_tabular())
# opcode | name | target | args
# --------|--------|--------------|----------
# call_fn | sin | torch.sin | (x,)
# call_fn | cos | torch.cos | (y,)
# call_fn | add | operator.add | (sin, cos)
# output | output | output | (add,)
第2段階 — AOTAutograd (自動微分の事前コンパイル)
順伝播FXグラフを受け取り、実行前に順伝播・逆伝播グラフを両方生成します。これによりバックエンドが完全な勾配計算を最適化できます。
第3段階 — TorchInductor (カーネル生成)
FXグラフをTriton(CUDA)またはC++(CPU)のカーネルコードに変換します。演算子フュージョン、タイリング、ベクトル化などの最適化がここで適用されます。
6.2 Apache TVM
TVMは深層学習モデルを多様なハードウェアターゲット(GPU、CPU、FPGA、NPU)向けに最適化します。
import tvm
from tvm import relay
import torch, torchvision
import numpy as np
# PyTorchモデルから開始
model = torchvision.models.resnet18(pretrained=False).eval()
input_shape = [1, 3, 224, 224]
trace_input = torch.randn(input_shape)
# TorchScriptを経由してTVM Relay IRへ変換
scripted = torch.jit.trace(model, trace_input)
mod, params = relay.frontend.from_pytorch(
scripted, [("input0", input_shape)]
)
target = tvm.target.Target("cuda")
# AutoScheduler: 最適カーネルパラメータの探索
from tvm import auto_scheduler
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target
)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=200,
runner=auto_scheduler.LocalRunner(repeat=10),
measure_callbacks=[auto_scheduler.RecordToFile("tuning.json")],
)
tuner.tune(tune_option)
# 最良スケジュールでコンパイル
with auto_scheduler.ApplyHistoryBest("tuning.json"):
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
print("TVMコンパイル完了!")
6.3 演算子フュージョン (Operator Fusion)
フュージョンはGPUにおける最も重要な最適化です。複数のカーネル呼び出しを1つに統合します。
# フュージョン前: 3回のGPUカーネル起動
y1 = conv2d(x) # カーネル1 — HBMへ書き込み
y2 = batch_norm(y1) # カーネル2 — HBM読み書き
y3 = relu(y2) # カーネル3 — HBM読み書き
# フュージョン後: 1回のGPUカーネル起動
y3 = fused_conv_bn_relu(x)
# 中間結果はL1/L2キャッシュに留まる
# HBM帯域幅の節約 = 大幅な高速化
6.4 XLAとJAX
XLA(Accelerated Linear Algebra)はTensorFlowとJAXのコンパイラです。HLO(High Level Operations)をIRとして使用します。
import jax
import jax.numpy as jnp
@jax.jit
def compute(x, W, b):
# XLAが自動的にmatmul + add + reluをフューズ
return jax.nn.relu(x @ W + b)
x = jnp.ones((32, 128))
W = jnp.ones((128, 256))
b = jnp.zeros(256)
# 初回呼び出し: XLAコンパイル発生
result = compute(x, W, b)
print(result.shape) # (32, 256)
XLAの主要最適化:
- 演算フュージョン: matmul/conv後のelementwise演算を統合
- レイアウト最適化: NCHWとNHWCを自動選択
- メモリ計画: ピーク使用メモリを最小化
7. MLIR: 次世代コンパイラインフラ
MLIR(Multi-Level Intermediate Representation)はGoogleが設計したコンパイラフレームワークで、複数レベルのIRを単一インフラに統合します。
TensorFlow Graph
↓ TF Dialect
MLIR HLO Dialect
↓ Linalg Dialect
ループネスト IR
↓ Affine Dialect
アフィンループ IR
↓ LLVM Dialect
LLVM IR
↓
機械語
各Dialectは特定の抽象化レベルを表現し、コンパイルは高レベルのDialectから低レベルへと段階的にローワリング(lowering)することで進みます。
MLIRは現在以下の基盤として使われています:
- IREE (Googleのデバイス向けMLランタイム)
- ONNX-MLIR
- Torch-MLIR (PyTorchからMLIRへ)
- Flang (LLVMのFortranコンパイラ)
8. クイズ: 理解度確認
Q1. LL(1)パーサとLR(1)パーサの解析能力の違いは?
答え: LR(1)パーサの方がより広い範囲の文法を処理できます。
解説: LL(1)はLeft-to-right、Leftmost derivation、1トークン先読みを意味するトップダウンパーサです。左再帰文法を処理できず、1つのトークンだけを見てどの規則を適用するか決定しなければなりません。LR(1)はLeft-to-right、Rightmost derivation、1先読みのボトムアップ(シフト還元)パーサです。左再帰を自然に扱え、決定性文脈自由言語をほぼすべてカバーします。LALR(1)はLR(1)のメモリ効率的な近似で、BisonやGCCで広く使われています。実際のほとんどのプログラミング言語文法はLALR(1)で表現可能です。
Q2. SSA形式が最適化に有利な理由は?
答え: 各変数の定義(def)と使用(use)の関係が自明になり、データフロー解析が大幅に簡単になるためです。
解説: SSAでは各変数が正確に1回だけ定義されるため: (1) 定数伝播 — 定義地点で値が定数ならすべての使用箇所で即座に置換可能、(2) デッドコード除去 — 使用のない定義は自明に不要、(3) レジスタ割り付け — 生存解析が定義から使用への到達可能性に還元、(4) ループ不変コードの移動 — 依存関係解析が容易になります。LLVM、GCCの最適化パスはすべてSSAを基盤としています。
Q3. LLVM IRの特徴と言語独立的な最適化が可能な理由は?
答え: LLVM IRが特定言語やアーキテクチャに依存しない汎用低水準IRであるためです。
解説: LLVM IRの主要特性: (1) 強い静的型システムにより最適化器が依存できる不変条件を確保、(2) SSA形式ですべてのデータ依存関係が明示的、(3) 無制限の仮想レジスタ — 物理レジスタ割り付けはコード生成時のみ、(4) 明示的なメモリモデルとアラインメント情報。言語独立性の核心は、フロントエンド(C→IR、Rust→IR、Swift→IR)とバックエンド(IR→x86、IR→ARM)を分離した点です。中間の最適化パスはIRのみを扱うため、すべての言語とアーキテクチャの組み合わせに適用されます。
Q4. TVMでのオペレータフュージョンがGPU性能を向上させる原理は?
答え: GPU高帯域幅メモリ(HBM)へのアクセス回数を削減し、オンチップキャッシュの効率を最大化するためです。
解説: GPUの演算でボトルネックになるのは演算(FLOPS)自体ではなく、メモリ帯域幅(HBM↔SRAM)であるケースが多いです。フュージョンなしでは各演算が中間結果をHBM(グローバルメモリ)に書き込み、次のカーネルが再度読み込む必要があります。フュージョン後は中間結果がL1/L2キャッシュやレジスタに留まり、HBMアクセス回数が大幅に減少します。特にReLU、BatchNormなどのelementwise演算はメモリバウンドなため、フュージョンの効果が劇的です。TVMのRelay→TIR変換過程でフュージョン可能な演算グループが自動的に検出されます。
Q5. torch.compile()が内部的に使用するコンパイルスタックは?
答え: TorchDynamo → AOTAutograd → TorchInductorの3段階スタックです。
解説: (1) TorchDynamo: Pythonバイトコードレベルで動作し、PyTorch演算部分をFX Graphとしてキャプチャします。Pythonの動的特性(制御フロー、データ依存分岐)に対応するため、ガードベースの特殊化を使用します。 (2) AOTAutograd: 順伝播FX Graphを受け取り、torch.autogradを通じて逆伝播グラフも実行前に(ahead-of-time)生成します。これによりバックエンドが完全な勾配計算全体を最適化できます。 (3) TorchInductor: FX GraphをTriton(CUDA)またはC++(CPU)のカーネルコードに変換し、演算子フュージョン、ループタイリング、ベクトル化などの最適化を適用します。mode引数で'default'、'reduce-overhead'、'max-autotune'から選択可能です。
まとめ
コンパイラ設計は単なる理論ではありません。今日のAIインフラの核心であるtorch.compile()、TVM、XLA、MLIRはすべて数十年のコンパイラ研究の上に構築されています。
Lexer → Parser → AST → IR → Optimization → Code Genのパイプラインを理解すれば、以下のことが自然と分かります:
- なぜGPUプログラミングでカーネルフュージョンがこれほど重要なのか
- なぜLLVMがこれほど多くの言語のバックエンドとして使われるのか
- なぜSSA形式が最適化の普遍的な基盤となっているのか
次のステップとして、LLVMのKaleidoscopeチュートリアル、TVM公式ドキュメント、PyTorch 2.0論文(Ansel et al., 2024)を読まれることをお勧めします。