Skip to content

필사 모드: コンパイラ/インタープリタ設計完全攻略: パーサからLLVM、AIコンパイラ(TVM/XLA)まで

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

はじめに

コンパイラを理解する理由は、プログラミング言語を作るためだけではありません。今日のAI/ML分野では`torch.compile()`、TVM、XLA、MLIRがコアインフラとして定着しており、それらの内部動作はすべてコンパイラ理論の上に成り立っています。

この記事ではコンパイラ設計のパイプライン全体を最初から最後まで解説し、ML専用コンパイラの最前線まで辿り着きます。

1. コンパイラの基礎: 全体パイプライン

コンパイラはソースコードを実行可能なコードへ変換するプログラムです。

ソースコード

↓ 字句解析 (Lexing)

トークン列

↓ 構文解析 (Parsing)

AST (抽象構文木)

↓ 意味解析 (Semantic Analysis)

注釈付きAST

↓ IR生成

中間表現 (IR)

↓ 最適化 (Optimization)

最適化されたIR

↓ コード生成 (Code Generation)

機械語 / アセンブリ

1.1 字句解析 (Lexical Analysis)

字句解析器(Lexer / Tokenizer)はソースコードの文字列を**トークン**のストリームに変換します。トークンは数値リテラル、識別子、キーワード、演算子などの意味単位です。

from enum import Enum, auto

from dataclasses import dataclass

from typing import List, Optional

class TokenType(Enum):

NUMBER = auto()

IDENT = auto()

PLUS = auto()

MINUS = auto()

STAR = auto()

SLASH = auto()

LPAREN = auto()

RPAREN = auto()

ASSIGN = auto()

SEMICOL = auto()

LET = auto()

EOF = auto()

@dataclass

class Token:

type: TokenType

value: str

line: int

KEYWORDS = {'let': TokenType.LET}

TOKEN_PATTERNS = [

(r'\d+(\.\d+)?', TokenType.NUMBER),

(r'[a-zA-Z_]\w*', TokenType.IDENT),

(r'\+', TokenType.PLUS),

(r'-', TokenType.MINUS),

(r'\*', TokenType.STAR),

(r'/', TokenType.SLASH),

(r'\(', TokenType.LPAREN),

(r'\)', TokenType.RPAREN),

(r'=', TokenType.ASSIGN),

(r';', TokenType.SEMICOL),

]

def tokenize(source: str) -> List[Token]:

tokens = []

line = 1

pos = 0

while pos < len(source):

if source[pos] in ' \t\r':

pos += 1

continue

if source[pos] == '\n':

line += 1

pos += 1

continue

matched = False

for pattern, ttype in TOKEN_PATTERNS:

m = re.match(pattern, source[pos:])

if m:

value = m.group(0)

if ttype == TokenType.IDENT and value in KEYWORDS:

ttype = KEYWORDS[value]

tokens.append(Token(ttype, value, line))

pos += len(value)

matched = True

break

if not matched:

raise SyntaxError(f"予期しない文字 '{source[pos]}' (行 {line})")

tokens.append(Token(TokenType.EOF, '', line))

return tokens

使用例

src = "let x = 10 + 20 * 3;"

for tok in tokenize(src):

print(tok)

1.2 構文解析とAST

パーサはトークン列を受け取り**抽象構文木(AST)**を構築します。

from dataclasses import dataclass

from typing import Union

ASTノード定義

@dataclass

class NumberLit:

value: float

@dataclass

class Identifier:

name: str

@dataclass

class BinOp:

op: str

left: 'Expr'

right: 'Expr'

@dataclass

class LetStmt:

name: str

value: 'Expr'

Expr = Union[NumberLit, Identifier, BinOp]

再帰下降パーサ

class Parser:

def __init__(self, tokens: List[Token]):

self.tokens = tokens

self.pos = 0

def peek(self) -> Token:

return self.tokens[self.pos]

def consume(self, expected: Optional[TokenType] = None) -> Token:

tok = self.tokens[self.pos]

if expected and tok.type != expected:

raise SyntaxError(f"期待: {expected}, 実際: {tok.type} (行 {tok.line})")

self.pos += 1

return tok

def parse_let(self) -> LetStmt:

self.consume(TokenType.LET)

name = self.consume(TokenType.IDENT).value

self.consume(TokenType.ASSIGN)

expr = self.parse_expr()

self.consume(TokenType.SEMICOL)

return LetStmt(name=name, value=expr)

def parse_expr(self) -> Expr:

return self.parse_additive()

def parse_additive(self) -> Expr:

left = self.parse_multiplicative()

while self.peek().type in (TokenType.PLUS, TokenType.MINUS):

op = self.consume().value

right = self.parse_multiplicative()

left = BinOp(op=op, left=left, right=right)

return left

def parse_multiplicative(self) -> Expr:

left = self.parse_primary()

while self.peek().type in (TokenType.STAR, TokenType.SLASH):

op = self.consume().value

right = self.parse_primary()

left = BinOp(op=op, left=left, right=right)

return left

def parse_primary(self) -> Expr:

tok = self.peek()

if tok.type == TokenType.NUMBER:

self.consume()

return NumberLit(value=float(tok.value))

if tok.type == TokenType.IDENT:

self.consume()

return Identifier(name=tok.value)

if tok.type == TokenType.LPAREN:

self.consume()

expr = self.parse_expr()

self.consume(TokenType.RPAREN)

return expr

raise SyntaxError(f"予期しないトークン: {tok}")

主要なパース戦略の比較:

| 手法 | 種別 | 用途 |

| ------------- | ---------------------- | ------------------ |

| 再帰下降 | LL(k) | 手書き、読みやすい |

| Pratt Parser | トップダウン演算子優先 | 式の解析に優秀 |

| LALR(1) | ボトムアップ | Bison、GCC、Python |

| PEG / Packrat | 解析表現文法 | 曖昧さなし、メモ化 |

2. 中間表現 (IR)

ASTから直接機械語へ変換するのは非効率です。IRを介することで言語に依存しない最適化が可能になります。

2.1 3アドレスコード

最も単純なフラットなIR。すべての演算は最大3つのオペランドを持ちます。

ソース: x = (a + b) * (c - d)

t1 = a + b

t2 = c - d

x = t1 * t2

2.2 SSA形式 (静的単一代入)

SSAでは各変数が**正確に1回だけ代入**されます。これによりデータフロー分析が劇的に単純化されます。

通常コード

x = 1

x = x + 2 # xが2回代入される

SSA変換後

x_1 = 1

x_2 = x_1 + 2

制御フローの合流点では**Phi関数**を使います。

if cond:

x_1 = 10

else:

x_2 = 20

合流点

x_3 = phi(x_1, x_2)

2.3 LLVM IR

LLVM IRはClang、Rust、Swiftなどが採用する産業標準のIRです。

; 関数定義: int add(int a, int b)

define i32 @add(i32 %a, i32 %b) {

entry:

%result = add i32 %a, %b

ret i32 %result

}

; 条件分岐 + SSA phi関数

define i32 @max(i32 %a, i32 %b) {

entry:

%cmp = icmp sgt i32 %a, %b ; a > b ?

br i1 %cmp, label %then, label %else

then:

br label %merge

else:

br label %merge

merge:

%result = phi i32 [ %a, %then ], [ %b, %else ]

ret i32 %result

}

LLVM IRの特徴:

- **強い型システム**: `i32`、`i64`、`float`、`ptr` などすべて明示的

- **SSA基盤**: すべての値は1回だけ定義

- **仮想レジスタ無制限**: `%0`、`%1`、`%2` ... 制限なし

- **言語独立性**: C、Rust、Swift、Haskellがすべて同じIRに変換

3. 最適化パス

3.1 定数畳み込み (Constant Folding)

コンパイル時に計算可能な定数式を事前計算します。

最適化前

x = 2 * 3 + 4

最適化後 (コンパイル時に計算済み)

x = 10

def constant_fold(node):

"""簡単な定数畳み込みの実装"""

if isinstance(node, BinOp):

left = constant_fold(node.left)

right = constant_fold(node.right)

if isinstance(left, NumberLit) and isinstance(right, NumberLit):

ops = {

'+': lambda a, b: a + b,

'-': lambda a, b: a - b,

'*': lambda a, b: a * b,

'/': lambda a, b: a / b,

}

return NumberLit(value=ops[node.op](left.value, right.value))

return BinOp(op=node.op, left=left, right=right)

return node

3.2 デッドコード除去 (Dead Code Elimination)

結果が使われない計算を除去します。

def dead_code_eliminate(stmts, used_vars):

"""使われない変数への代入を削除する後方パス"""

live = set()

for stmt in reversed(stmts):

if isinstance(stmt, LetStmt):

if stmt.name in used_vars or stmt.name in live:

live.update(collect_vars(stmt.value))

return [s for s in stmts

if not isinstance(s, LetStmt) or s.name in live]

3.3 ループアンローリング (Loop Unrolling)

ループオーバーヘッドを削減し、命令レベル並列性(ILP)を高めます。

// 元のコード (ループ4回)

for (int i = 0; i < 4; i++) {

a[i] = b[i] + c[i];

}

// アンローリング後 (ループ消去)

a[0] = b[0] + c[0];

a[1] = b[1] + c[1];

a[2] = b[2] + c[2];

a[3] = b[3] + c[3];

3.4 関数インライン化 (Function Inlining)

関数呼び出しのオーバーヘッドをなくし、追加の最適化機会を生み出します。

// インライン前

inline int square(int x) { return x * x; }

int y = square(5);

// インライン後 (定数畳み込みで25に)

int y = 25;

4. コード生成: x86-64アセンブリ

4.1 レジスタ割り付け

x86-64には汎用レジスタが16個(`rax`、`rbx`、`rcx`、`rdx`、`rsi`、`rdi`、`rsp`、`rbp`、`r8`〜`r15`)あります。無制限の仮想レジスタをこの有限セットにマッピングする必要があります。

**グラフ彩色法**が最も有名なアルゴリズムです。同時に生存している仮想レジスタ間に辺を張り、N色でグラフを彩色します(N=物理レジスタ数)。彩色できない場合はスピル(スタックへの退避)を行います。

4.2 x86-64アセンブリの例

; C関数: int add(int a, int b) { return a + b; }

; Linux System V ABI: 引数はrdi, rsi順; 返り値はrax

add:

push rbp

mov rbp, rsp

mov DWORD PTR [rbp-4], edi

mov DWORD PTR [rbp-8], esi

mov edx, DWORD PTR [rbp-4]

mov eax, DWORD PTR [rbp-8]

add eax, edx

pop rbp

ret

; -O2最適化版 (1命令のみ!)

add:

lea eax, [rdi + rsi]

ret

5. JITコンパイル: 実行時最適化

**JIT(Just-In-Time)**コンパイルはプログラム実行中にコードをコンパイルします。インタープリタ言語とネイティブコードのパフォーマンスギャップを埋める核心技術です。

5.1 Numba JIT

NumbaはPython/NumPyコードをLLVMを通じてネイティブ機械語にコンパイルします。

from numba import njit, prange

純粋Pythonベースライン

def matmul_python(A, B):

N = A.shape[0]

C = np.zeros((N, N))

for i in range(N):

for j in range(N):

for k in range(N):

C[i, j] += A[i, k] * B[k, j]

return C

Numba JIT版 (並列化あり)

@njit(parallel=True)

def matmul_numba(A, B):

N = A.shape[0]

C = np.zeros((N, N))

for i in prange(N): # 並列ループ

for j in range(N):

for k in range(N):

C[i, j] += A[i, k] * B[k, j]

return C

N = 256

A = np.random.rand(N, N).astype(np.float32)

B = np.random.rand(N, N).astype(np.float32)

初回呼び出しでコンパイル (ウォームアップ)

_ = matmul_numba(A, B)

t0 = time.perf_counter()

C_py = matmul_python(A, B)

print(f"Python: {time.perf_counter()-t0:.3f}s")

t0 = time.perf_counter()

C_nb = matmul_numba(A, B)

print(f"Numba JIT: {time.perf_counter()-t0:.4f}s")

典型的な結果: Python ~30s vs Numba ~0.01s (数百倍高速)

**Numbaの内部動作:**

1. Pythonバイトコードの読み取り

2. 型推論により具体的な型を確定

3. LLVM IRへのローワリング

4. LLVMバックエンドが機械語にコンパイル

5. コンパイル済みアーティファクトをキャッシュ

5.2 PyPyとトレーシングJIT

PyPyのトレーシングJITは「ホットパス」(頻繁に実行されるトレース)を記録し、それを直接コンパイルします。コール境界を越えてインライン化できる点が特徴です。

6. AI/MLコンパイラのエコシステム

6.1 torch.compile()の内部動作

PyTorch 2.0の核心機能`torch.compile()`は3段階のコンパイルパイプラインで構成されています。

class SimpleNet(nn.Module):

def __init__(self):

super().__init__()

self.fc1 = nn.Linear(128, 256)

self.fc2 = nn.Linear(256, 10)

self.relu = nn.ReLU()

def forward(self, x):

return self.fc2(self.relu(self.fc1(x)))

model = SimpleNet().cuda()

x = torch.randn(32, 128).cuda()

torch.compile適用

compiled = torch.compile(model, mode='max-autotune')

初回呼び出し: コンパイル発生 (数秒かかる)

out = compiled(x)

以降の呼び出し: コンパイル済みカーネルを再利用

for _ in range(100):

out = compiled(x)

**第1段階 — TorchDynamo (グラフキャプチャ)**

DynamoはPythonバイトコードを傍受し、PyTorch演算を`torch.fx.Graph`として捕捉します。

def my_func(x, y):

return torch.sin(x) + torch.cos(y)

FXグラフの確認

explanation = dynamo.explain(my_func)(

torch.randn(4), torch.randn(4)

)

print(explanation.graphs[0].print_tabular())

opcode | name | target | args

--------|--------|--------------|----------

call_fn | sin | torch.sin | (x,)

call_fn | cos | torch.cos | (y,)

call_fn | add | operator.add | (sin, cos)

output | output | output | (add,)

**第2段階 — AOTAutograd (自動微分の事前コンパイル)**

順伝播FXグラフを受け取り、実行前に順伝播・逆伝播グラフを両方生成します。これによりバックエンドが完全な勾配計算を最適化できます。

**第3段階 — TorchInductor (カーネル生成)**

FXグラフをTriton(CUDA)またはC++(CPU)のカーネルコードに変換します。演算子フュージョン、タイリング、ベクトル化などの最適化がここで適用されます。

6.2 Apache TVM

TVMは深層学習モデルを多様なハードウェアターゲット(GPU、CPU、FPGA、NPU)向けに最適化します。

from tvm import relay

PyTorchモデルから開始

model = torchvision.models.resnet18(pretrained=False).eval()

input_shape = [1, 3, 224, 224]

trace_input = torch.randn(input_shape)

TorchScriptを経由してTVM Relay IRへ変換

scripted = torch.jit.trace(model, trace_input)

mod, params = relay.frontend.from_pytorch(

scripted, [("input0", input_shape)]

)

target = tvm.target.Target("cuda")

AutoScheduler: 最適カーネルパラメータの探索

from tvm import auto_scheduler

tasks, task_weights = auto_scheduler.extract_tasks(

mod["main"], params, target

)

tuner = auto_scheduler.TaskScheduler(tasks, task_weights)

tune_option = auto_scheduler.TuningOptions(

num_measure_trials=200,

runner=auto_scheduler.LocalRunner(repeat=10),

measure_callbacks=[auto_scheduler.RecordToFile("tuning.json")],

)

tuner.tune(tune_option)

最良スケジュールでコンパイル

with auto_scheduler.ApplyHistoryBest("tuning.json"):

with tvm.transform.PassContext(opt_level=3):

lib = relay.build(mod, target=target, params=params)

print("TVMコンパイル完了!")

6.3 演算子フュージョン (Operator Fusion)

フュージョンはGPUにおける最も重要な最適化です。複数のカーネル呼び出しを1つに統合します。

フュージョン前: 3回のGPUカーネル起動

y1 = conv2d(x) # カーネル1 — HBMへ書き込み

y2 = batch_norm(y1) # カーネル2 — HBM読み書き

y3 = relu(y2) # カーネル3 — HBM読み書き

フュージョン後: 1回のGPUカーネル起動

y3 = fused_conv_bn_relu(x)

中間結果はL1/L2キャッシュに留まる

HBM帯域幅の節約 = 大幅な高速化

6.4 XLAとJAX

XLA(Accelerated Linear Algebra)はTensorFlowとJAXのコンパイラです。HLO(High Level Operations)をIRとして使用します。

@jax.jit

def compute(x, W, b):

XLAが自動的にmatmul + add + reluをフューズ

return jax.nn.relu(x @ W + b)

x = jnp.ones((32, 128))

W = jnp.ones((128, 256))

b = jnp.zeros(256)

初回呼び出し: XLAコンパイル発生

result = compute(x, W, b)

print(result.shape) # (32, 256)

XLAの主要最適化:

- **演算フュージョン**: matmul/conv後のelementwise演算を統合

- **レイアウト最適化**: NCHWとNHWCを自動選択

- **メモリ計画**: ピーク使用メモリを最小化

7. MLIR: 次世代コンパイラインフラ

MLIR(Multi-Level Intermediate Representation)はGoogleが設計したコンパイラフレームワークで、複数レベルのIRを単一インフラに統合します。

TensorFlow Graph

↓ TF Dialect

MLIR HLO Dialect

↓ Linalg Dialect

ループネスト IR

↓ Affine Dialect

アフィンループ IR

↓ LLVM Dialect

LLVM IR

機械語

各**Dialect**は特定の抽象化レベルを表現し、コンパイルは高レベルのDialectから低レベルへと段階的にローワリング(lowering)することで進みます。

MLIRは現在以下の基盤として使われています:

- IREE (Googleのデバイス向けMLランタイム)

- ONNX-MLIR

- Torch-MLIR (PyTorchからMLIRへ)

- Flang (LLVMのFortranコンパイラ)

8. クイズ: 理解度確認

**答え**: LR(1)パーサの方がより広い範囲の文法を処理できます。

**解説**: LL(1)はLeft-to-right、Leftmost derivation、1トークン先読みを意味するトップダウンパーサです。左再帰文法を処理できず、1つのトークンだけを見てどの規則を適用するか決定しなければなりません。LR(1)はLeft-to-right、Rightmost derivation、1先読みのボトムアップ(シフト還元)パーサです。左再帰を自然に扱え、決定性文脈自由言語をほぼすべてカバーします。LALR(1)はLR(1)のメモリ効率的な近似で、BisonやGCCで広く使われています。実際のほとんどのプログラミング言語文法はLALR(1)で表現可能です。

**答え**: 各変数の定義(def)と使用(use)の関係が自明になり、データフロー解析が大幅に簡単になるためです。

**解説**: SSAでは各変数が正確に1回だけ定義されるため: (1) 定数伝播 — 定義地点で値が定数ならすべての使用箇所で即座に置換可能、(2) デッドコード除去 — 使用のない定義は自明に不要、(3) レジスタ割り付け — 生存解析が定義から使用への到達可能性に還元、(4) ループ不変コードの移動 — 依存関係解析が容易になります。LLVM、GCCの最適化パスはすべてSSAを基盤としています。

**答え**: LLVM IRが特定言語やアーキテクチャに依存しない汎用低水準IRであるためです。

**解説**: LLVM IRの主要特性: (1) 強い静的型システムにより最適化器が依存できる不変条件を確保、(2) SSA形式ですべてのデータ依存関係が明示的、(3) 無制限の仮想レジスタ — 物理レジスタ割り付けはコード生成時のみ、(4) 明示的なメモリモデルとアラインメント情報。言語独立性の核心は、フロントエンド(C→IR、Rust→IR、Swift→IR)とバックエンド(IR→x86、IR→ARM)を分離した点です。中間の最適化パスはIRのみを扱うため、すべての言語とアーキテクチャの組み合わせに適用されます。

**答え**: GPU高帯域幅メモリ(HBM)へのアクセス回数を削減し、オンチップキャッシュの効率を最大化するためです。

**解説**: GPUの演算でボトルネックになるのは演算(FLOPS)自体ではなく、メモリ帯域幅(HBM↔SRAM)であるケースが多いです。フュージョンなしでは各演算が中間結果をHBM(グローバルメモリ)に書き込み、次のカーネルが再度読み込む必要があります。フュージョン後は中間結果がL1/L2キャッシュやレジスタに留まり、HBMアクセス回数が大幅に減少します。特にReLU、BatchNormなどのelementwise演算はメモリバウンドなため、フュージョンの効果が劇的です。TVMのRelay→TIR変換過程でフュージョン可能な演算グループが自動的に検出されます。

**答え**: TorchDynamo → AOTAutograd → TorchInductorの3段階スタックです。

**解説**: (1) TorchDynamo: Pythonバイトコードレベルで動作し、PyTorch演算部分をFX Graphとしてキャプチャします。Pythonの動的特性(制御フロー、データ依存分岐)に対応するため、ガードベースの特殊化を使用します。 (2) AOTAutograd: 順伝播FX Graphを受け取り、`torch.autograd`を通じて逆伝播グラフも実行前に(ahead-of-time)生成します。これによりバックエンドが完全な勾配計算全体を最適化できます。 (3) TorchInductor: FX GraphをTriton(CUDA)またはC++(CPU)のカーネルコードに変換し、演算子フュージョン、ループタイリング、ベクトル化などの最適化を適用します。`mode`引数で'default'、'reduce-overhead'、'max-autotune'から選択可能です。

まとめ

コンパイラ設計は単なる理論ではありません。今日のAIインフラの核心である`torch.compile()`、TVM、XLA、MLIRはすべて数十年のコンパイラ研究の上に構築されています。

Lexer → Parser → AST → IR → Optimization → Code Genのパイプラインを理解すれば、以下のことが自然と分かります:

- なぜGPUプログラミングでカーネルフュージョンがこれほど重要なのか

- なぜLLVMがこれほど多くの言語のバックエンドとして使われるのか

- なぜSSA形式が最適化の普遍的な基盤となっているのか

次のステップとして、LLVMのKaleidoscopeチュートリアル、TVM公式ドキュメント、PyTorch 2.0論文(Ansel et al., 2024)を読まれることをお勧めします。

현재 단락 (1/400)

コンパイラを理解する理由は、プログラミング言語を作るためだけではありません。今日のAI/ML分野では`torch.compile()`、TVM、XLA、MLIRがコアインフラとして定着しており、それら...

작성 글자: 0원문 글자: 12,502작성 단락: 0/400