
Compiler & Interpreter Design: From Parsers to LLVM and AI Compilers (TVM/XLA)


Introduction

Understanding compilers is not just about building programming languages. Today, tools like `torch.compile()`, TVM, XLA, and MLIR are core infrastructure in AI/ML, and their internals are built directly on compiler theory.

This guide walks through the entire compiler pipeline from first principles, ending at the cutting edge of ML compilation.

1. Compiler Fundamentals: The Full Pipeline

A compiler transforms source code into executable code through a series of well-defined stages.

```text
Source Code
    ↓  Lexical Analysis (Lexing)
Token Stream
    ↓  Syntactic Analysis (Parsing)
AST (Abstract Syntax Tree)
    ↓  Semantic Analysis
Annotated AST
    ↓  IR Generation
Intermediate Representation (IR)
    ↓  Optimization Passes
Optimized IR
    ↓  Code Generation
Machine Code / Assembly
```

1.1 Lexical Analysis (Tokenization)

The lexer reads source characters and emits a stream of **tokens** — the smallest meaningful units like numbers, identifiers, keywords, and operators.

```python
import re
from enum import Enum, auto
from dataclasses import dataclass
from typing import List, Optional

class TokenType(Enum):
    NUMBER = auto()
    IDENT = auto()
    PLUS = auto()
    MINUS = auto()
    STAR = auto()
    SLASH = auto()
    LPAREN = auto()
    RPAREN = auto()
    ASSIGN = auto()
    SEMICOL = auto()
    LET = auto()
    EOF = auto()

@dataclass
class Token:
    type: TokenType
    value: str
    line: int

# Reserved words are lexed as identifiers first, then reclassified
KEYWORDS = {'let': TokenType.LET}

# Tried in order: the first pattern that matches at the current position wins
TOKEN_PATTERNS = [
    (r'\d+(\.\d+)?', TokenType.NUMBER),
    (r'[a-zA-Z_]\w*', TokenType.IDENT),
    (r'\+', TokenType.PLUS),
    (r'-', TokenType.MINUS),
    (r'\*', TokenType.STAR),
    (r'/', TokenType.SLASH),
    (r'\(', TokenType.LPAREN),
    (r'\)', TokenType.RPAREN),
    (r'=', TokenType.ASSIGN),
    (r';', TokenType.SEMICOL),
]

def tokenize(source: str) -> List[Token]:
    tokens = []
    line = 1
    pos = 0
    while pos < len(source):
        # Skip whitespace, counting newlines for error messages
        if source[pos] in ' \t\r':
            pos += 1
            continue
        if source[pos] == '\n':
            line += 1
            pos += 1
            continue
        matched = False
        for pattern, ttype in TOKEN_PATTERNS:
            m = re.match(pattern, source[pos:])
            if m:
                value = m.group(0)
                if ttype == TokenType.IDENT and value in KEYWORDS:
                    ttype = KEYWORDS[value]
                tokens.append(Token(ttype, value, line))
                pos += len(value)
                matched = True
                break
        if not matched:
            raise SyntaxError(f"Unexpected character '{source[pos]}' at line {line}")
    tokens.append(Token(TokenType.EOF, '', line))
    return tokens

# Example usage
src = "let x = 10 + 20 * 3;"
for tok in tokenize(src):
    print(tok)
```
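The lexer above copies the rest of the input (`source[pos:]`) and retries every pattern at each token. That is simple, but it does redundant work on long inputs. A common alternative combines all patterns into one compiled regex with named groups and lets `re.finditer` scan the input in a single pass. This mirrors the tokenizer recipe in Python's `re` documentation. The sketch below assumes tokens can be reported as plain `(kind, text, line)` tuples instead of the `Token` class.

```python
import re
from typing import Iterator, Tuple

# Order matters: at each position the first alternative that matches wins
TOKEN_SPEC = [
    ('NUMBER',  r'\d+(?:\.\d+)?'),
    ('IDENT',   r'[A-Za-z_]\w*'),
    ('OP',      r'[+\-*/=;()]'),
    ('NEWLINE', r'\n'),
    ('SKIP',    r'[ \t\r]+'),
    ('ERROR',   r'.'),  # any other character is a lexing error
]
MASTER = re.compile('|'.join(f'(?P<{name}>{pattern})' for name, pattern in TOKEN_SPEC))
KEYWORDS = {'let'}

def tokenize(source: str) -> Iterator[Tuple[str, str, int]]:
    """Yield (kind, text, line) tuples; keywords get an upper-case kind of their own."""
    line = 1
    for m in MASTER.finditer(source):
        kind, text = m.lastgroup, m.group()  # lastgroup names the alternative that matched
        if kind == 'NEWLINE':
            line += 1
        elif kind == 'SKIP':
            continue
        elif kind == 'ERROR':
            raise SyntaxError(f"Unexpected character {text!r} at line {line}")
        else:
            if kind == 'IDENT' and text in KEYWORDS:
                kind = text.upper()
            yield kind, text, line

print(list(tokenize("let x = 10 + 20 * 3;")))
```

Because `ERROR` matches any single non-newline character and `NEWLINE` matches `\n`, every position in the input is covered. The scan therefore never silently skips text.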

1.2 Parsing and the AST

The parser consumes tokens and builds an **Abstract Syntax Tree (AST)** that captures the hierarchical structure of the program.

```python
from dataclasses import dataclass
from typing import List, Optional, Union

# Reuses Token and TokenType from the lexer above

# AST node definitions
@dataclass
class NumberLit:
    value: float

@dataclass
class Identifier:
    name: str

@dataclass
class BinOp:
    op: str
    left: 'Expr'
    right: 'Expr'

@dataclass
class LetStmt:
    name: str
    value: 'Expr'

Expr = Union[NumberLit, Identifier, BinOp]

# Recursive Descent Parser: one method per precedence level
class Parser:
    def __init__(self, tokens: List[Token]):
        self.tokens = tokens
        self.pos = 0

    def peek(self) -> Token:
        return self.tokens[self.pos]

    def consume(self, expected: Optional[TokenType] = None) -> Token:
        tok = self.tokens[self.pos]
        if expected is not None and tok.type != expected:
            raise SyntaxError(f"Expected {expected}, got {tok.type} at line {tok.line}")
        self.pos += 1
        return tok

    # let_stmt := 'let' IDENT '=' expr ';'
    def parse_let(self) -> LetStmt:
        self.consume(TokenType.LET)
        name = self.consume(TokenType.IDENT).value
        self.consume(TokenType.ASSIGN)
        expr = self.parse_expr()
        self.consume(TokenType.SEMICOL)
        return LetStmt(name=name, value=expr)

    def parse_expr(self) -> Expr:
        return self.parse_additive()

    # additive := multiplicative (('+' | '-') multiplicative)*   (the loop makes it left-associative)
    def parse_additive(self) -> Expr:
        left = self.parse_multiplicative()
        while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
            op = self.consume().value
            right = self.parse_multiplicative()
            left = BinOp(op=op, left=left, right=right)
        return left

    # multiplicative := primary (('*' | '/') primary)*
    def parse_multiplicative(self) -> Expr:
        left = self.parse_primary()
        while self.peek().type in (TokenType.STAR, TokenType.SLASH):
            op = self.consume().value
            right = self.parse_primary()
            left = BinOp(op=op, left=left, right=right)
        return left

    # primary := NUMBER | IDENT | '(' expr ')'
    def parse_primary(self) -> Expr:
        tok = self.peek()
        if tok.type == TokenType.NUMBER:
            self.consume()
            return NumberLit(value=float(tok.value))
        if tok.type == TokenType.IDENT:
            self.consume()
            return Identifier(name=tok.value)
        if tok.type == TokenType.LPAREN:
            self.consume()
            expr = self.parse_expr()
            self.consume(TokenType.RPAREN)
            return expr
        raise SyntaxError(f"Unexpected token {tok}")
```

Parsing strategies differ in power and complexity:

| Strategy          | Family                       | Notes                                                           |
| ----------------- | ---------------------------- | --------------------------------------------------------------- |
| Recursive Descent | Top-down, LL(k)              | Hand-written and easy to read; used by GCC and Clang for C/C++  |
| Pratt Parser      | Top-down operator precedence | Elegant for expressions                                         |
| LALR(1)           | Bottom-up (shift-reduce)     | Generated by Yacc and Bison                                     |
| PEG / Packrat     | Top-down, ordered choice     | No ambiguity; memoization gives linear time; used by CPython 3.9+ |
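The Pratt row deserves a concrete look, because it collapses the whole precedence hierarchy into one number per operator. The sketch below is minimal and makes several assumptions:

- It handles only numbers, parentheses, and the four left-associative arithmetic operators.
- It returns nested `(op, left, right)` tuples instead of the AST classes above.
- Its binding-power values are arbitrary; only their relative order matters.

```python
import re
from typing import List, Tuple, Union

# Binding power per infix operator: higher binds tighter
BINDING_POWER = {'+': 10, '-': 10, '*': 20, '/': 20}
Node = Union[float, Tuple[str, 'Node', 'Node']]

class PrattParser:
    def __init__(self, src: str):
        self.tokens: List[str] = re.findall(r'\d+(?:\.\d+)?|[-+*/()]', src)
        self.pos = 0

    def peek(self) -> str:
        return self.tokens[self.pos] if self.pos < len(self.tokens) else ''

    def advance(self) -> str:
        tok = self.peek()
        self.pos += 1
        return tok

    def parse(self, min_bp: int = 0) -> Node:
        # Prefix position: a number or a parenthesized sub-expression
        tok = self.advance()
        if tok == '(':
            left = self.parse(0)
            if self.advance() != ')':
                raise SyntaxError("expected ')'")
        elif re.fullmatch(r'\d+(?:\.\d+)?', tok):
            left = float(tok)
        else:
            raise SyntaxError(f"unexpected token {tok!r}")
        # Infix loop: absorb operators that bind tighter than min_bp.
        # Recursing with the operator's own power (and testing with >) makes it left-associative.
        while BINDING_POWER.get(self.peek(), 0) > min_bp:
            op = self.advance()
            left = (op, left, self.parse(BINDING_POWER[op]))
        return left

def parse_expression(src: str) -> Node:
    parser = PrattParser(src)
    tree = parser.parse()
    if parser.peek():
        raise SyntaxError(f"trailing input at {parser.peek()!r}")
    return tree

print(parse_expression("1 + 2 * 3"))  # ('+', 1.0, ('*', 2.0, 3.0))
```

Adding a new precedence level here means adding one dictionary entry. The recursive-descent parser would instead need a new method.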

2. Intermediate Representations (IR)

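Translating an AST directly to machine code is impractical for an optimizing compiler. A well-designed IR enables language-agnostic optimizations and decouples front ends from back ends: M source languages and N targets need M + N components instead of M × N separate compilers.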

2.1 Three-Address Code

The simplest flat IR. Each instruction has at most three addresses: up to two operands and one result.

Source: `x = (a + b) * (c - d)`

```text
t1 = a + b
t2 = c - d
x  = t1 * t2
```
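The flattening above is mechanical: a post-order walk gives every interior node a fresh temporary. The sketch below assumes a hypothetical tuple encoding of expressions, in which leaves are variable names and interior nodes are `(op, left, right)`.

```python
import itertools
from typing import List, Tuple, Union

Expr = Union[str, Tuple[str, 'Expr', 'Expr']]

def to_three_address(target: str, expr: Expr) -> List[str]:
    """Flatten an expression tree into three-address instructions that assign to `target`."""
    code: List[str] = []
    temps = itertools.count(1)

    def operands(node: Tuple[str, Expr, Expr]) -> Tuple[str, str, str]:
        op, lhs, rhs = node
        return op, lower(lhs), lower(rhs)  # left operand is lowered first

    def lower(node: Expr) -> str:
        if isinstance(node, str):
            return node  # a variable is already an address
        op, a, b = operands(node)
        tmp = f"t{next(temps)}"
        code.append(f"{tmp} = {a} {op} {b}")
        return tmp

    if isinstance(expr, str):
        code.append(f"{target} = {expr}")
    else:
        # The root writes straight into the target instead of a temporary
        op, a, b = operands(expr)
        code.append(f"{target} = {a} {op} {b}")
    return code

for instr in to_three_address("x", ('*', ('+', 'a', 'b'), ('-', 'c', 'd'))):
    print(instr)
```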

2.2 SSA Form (Static Single Assignment)

SSA requires that every variable is assigned **exactly once**. This simple constraint dramatically simplifies dataflow analysis.

```python
# Original code
x = 1
x = x + 2   # x assigned twice

# SSA form
x_1 = 1
x_2 = x_1 + 2
```

At control-flow merge points, **phi functions** select the correct version:

```python
if cond:
    x_1 = 10
else:
    x_2 = 20
# Merge point: phi yields x_1 or x_2 depending on which branch ran
x_3 = phi(x_1, x_2)
```
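For straight-line code, building SSA is just renaming. Each assignment creates the next version of its target, and each read refers to the latest version. The sketch assumes statements arrive as hypothetical `(target, expression_text)` pairs with no branches. Inserting phi functions at merge points, as in the example above, requires dominance frontiers and is out of scope here.

```python
import re
from typing import Dict, List, Tuple

IDENT = re.compile(r'[A-Za-z_]\w*')

def to_ssa(stmts: List[Tuple[str, str]]) -> List[str]:
    """Rename straight-line assignments so every variable is assigned exactly once."""
    version: Dict[str, int] = {}
    out: List[str] = []
    for target, expr in stmts:
        # Reads use the current version; names never assigned (inputs) keep their name
        rhs = IDENT.sub(
            lambda m: f"{m.group()}_{version[m.group()]}" if m.group() in version else m.group(),
            expr,
        )
        # Bump the version only after renaming the RHS, so `x = x + 2` reads the old x
        version[target] = version.get(target, 0) + 1
        out.append(f"{target}_{version[target]} = {rhs}")
    return out

for line in to_ssa([("x", "1"), ("x", "x + 2"), ("y", "x * z")]):
    print(line)
```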

2.3 LLVM IR

LLVM IR is the industry-standard IR used by Clang, Rust, Swift, and many others.

```llvm
; Function: int add(int a, int b)
define i32 @add(i32 %a, i32 %b) {
entry:
  %result = add i32 %a, %b
  ret i32 %result
}

; Branching with SSA phi nodes
define i32 @max(i32 %a, i32 %b) {
entry:
  %cmp = icmp sgt i32 %a, %b        ; signed greater-than
  br i1 %cmp, label %then, label %else

then:
  br label %merge

else:
  br label %merge

merge:
  %result = phi i32 [ %a, %then ], [ %b, %else ]
  ret i32 %result
}
```

Key properties of LLVM IR:

- **Strong type system**: `i32`, `i64`, `float`, `ptr` — all explicit

- **SSA-based**: every `%name` is defined exactly once

- **Infinite virtual registers**: `%0`, `%1`, `%2` ... no limit

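- **Largely target-independent**: C, Rust, and Swift all lower to the same IR, though frontends still encode some target details, such as the data layout and ABI lowering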

3. Optimization Passes

3.1 Constant Folding

Pre-compute constant expressions at compile time.

```python
# Before
x = 2 * 3 + 4

# After (the compiler computes the value at build time)
x = 10
```

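```python
# Reuses the NumberLit and BinOp classes from section 1.2
def constant_fold(node):
    """Simple constant folding pass (post-order: fold children first)."""
    if isinstance(node, BinOp):
        left = constant_fold(node.left)
        right = constant_fold(node.right)
        # Leave x / 0 for run time instead of crashing the compiler
        divides_by_zero = node.op == '/' and isinstance(right, NumberLit) and right.value == 0
        if isinstance(left, NumberLit) and isinstance(right, NumberLit) and not divides_by_zero:
            ops = {
                '+': lambda a, b: a + b,
                '-': lambda a, b: a - b,
                '*': lambda a, b: a * b,
                '/': lambda a, b: a / b,
            }
            return NumberLit(value=ops[node.op](left.value, right.value))
        return BinOp(op=node.op, left=left, right=right)
    return node
```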
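The pass above needs the AST classes from section 1.2. To try the same post-order idea in isolation, here is a sketch that folds constants in Python's own syntax tree using the standard-library `ast` module (`ast.unparse` requires Python 3.9+). It folds only `+ - * /` on int and float literals and, like a real compiler, declines to fold a division by zero.

```python
import ast
import operator

FOLDABLE = {ast.Add: operator.add, ast.Sub: operator.sub,
            ast.Mult: operator.mul, ast.Div: operator.truediv}

def is_number(node: ast.AST) -> bool:
    return (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool))

class ConstantFolder(ast.NodeTransformer):
    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        self.generic_visit(node)  # fold the operands first (post-order)
        fn = FOLDABLE.get(type(node.op))
        if fn and is_number(node.left) and is_number(node.right):
            try:
                value = fn(node.left.value, node.right.value)
            except ZeroDivisionError:
                return node  # leave it for run time to raise
            return ast.copy_location(ast.Constant(value), node)
        return node

def fold(source: str) -> str:
    tree = ConstantFolder().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))

print(fold("x = 2 * 3 + 4"))    # x = 10
print(fold("y = a * (1 + 2)"))  # y = a * 3
```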

3.2 Dead Code Elimination

Remove code whose results are never used.

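```python
# Reuses Identifier, BinOp, and LetStmt from section 1.2
def collect_vars(expr):
    """Return the set of variable names read by an expression."""
    if isinstance(expr, Identifier):
        return {expr.name}
    if isinstance(expr, BinOp):
        return collect_vars(expr.left) | collect_vars(expr.right)
    return set()  # NumberLit reads no variables

def dead_code_eliminate(stmts, used_vars):
    """Remove assignments to variables that are never read (assumes single assignment, as in SSA)."""
    live = set(used_vars)  # variables still needed after this block
    kept = []
    # Backward pass: keep a definition only if something later reads it
    for stmt in reversed(stmts):
        if isinstance(stmt, LetStmt):
            if stmt.name not in live:
                continue  # dead assignment: drop it
            live.update(collect_vars(stmt.value))
        kept.append(stmt)
    return kept[::-1]
```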
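A backward pass like the one above is correct only if each variable is assigned once. When a variable can be reassigned, an earlier assignment is also dead if it is overwritten before being read. The walk must therefore _kill_ a variable at its definition before adding the variables that definition reads. The self-contained sketch below assumes pure statements (no side effects), encoded as hypothetical `(target, reads, text)` triples.

```python
from typing import FrozenSet, List, Set, Tuple

Stmt = Tuple[str, FrozenSet[str], str]  # (assigned variable, variables read, source text)

def eliminate_dead_stores(stmts: List[Stmt], live_out: Set[str]) -> List[Stmt]:
    """Backward liveness pass over straight-line code; drops assignments nobody reads."""
    live = set(live_out)  # variables needed after the block
    kept: List[Stmt] = []
    for target, reads, text in reversed(stmts):
        if target not in live:
            continue  # never read afterwards, or overwritten first: dead
        live.discard(target)  # kill: this definition satisfies the later reads
        live |= reads         # gen: its operands must be live before it
        kept.append((target, reads, text))
    return kept[::-1]

program = [
    ("x", frozenset(), "x = 1"),           # dead: overwritten before any read
    ("x", frozenset(), "x = 2"),
    ("t", frozenset({"x"}), "t = x * 5"),  # dead: t is never read
    ("y", frozenset({"x"}), "y = x + 1"),
]
for _, _, text in eliminate_dead_stores(program, live_out={"y"}):
    print(text)
```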

3.3 Loop Unrolling

Reduce loop overhead and expose instruction-level parallelism.

```c
// Original: 4 iterations, each paying for a compare, a branch, and an increment
for (int i = 0; i < 4; i++) {
    a[i] = b[i] + c[i];
}

// Unrolled: zero loop overhead, and the four independent adds can be vectorized
a[0] = b[0] + c[0];
a[1] = b[1] + c[1];
a[2] = b[2] + c[2];
a[3] = b[3] + c[3];
```
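Full unrolling, as above, works only when the trip count is a small compile-time constant. More often a compiler unrolls by a factor and adds a remainder (epilogue) for the leftover iterations. The sketch below emits C-like source text for that transformation. The `body` template with an `{i}` placeholder is an assumption made for illustration, and the factor must be at least 1.

```python
from typing import List

def unroll_loop(trip_count: int, factor: int, body: str) -> List[str]:
    """Unroll `for (i = 0; i < trip_count; i++) body` by `factor` (factor >= 1).

    `body` is a template in which `{i}` stands for the loop index.
    """
    if factor >= trip_count:
        # Full unrolling: the loop disappears entirely
        return [body.format(i=j) for j in range(trip_count)]
    main = trip_count - trip_count % factor  # iterations covered by the unrolled loop
    lines = [f"for (int i = 0; i < {main}; i += {factor}) {{"]
    for k in range(factor):
        index = "i" if k == 0 else f"i + {k}"
        lines.append("    " + body.format(i=index))
    lines.append("}")
    # Remainder: the trip_count % factor iterations that did not fit
    lines.extend(body.format(i=j) for j in range(main, trip_count))
    return lines

print("\n".join(unroll_loop(10, 4, "a[{i}] = b[{i}] + c[{i}];")))
```

Larger factors cut more loop overhead but grow code size and register pressure, which is why compilers pick the factor heuristically.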

3.4 Function Inlining

Replace a call site with the callee body, removing call overhead and enabling further optimization.

```cpp
// Before
inline int square(int x) { return x * x; }
int y = square(5);

// After inlining (then constant folding)
int y = 25;
```

4. Code Generation: x86-64 Assembly

4.1 Register Allocation

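x86-64 has 16 general-purpose registers (`rax`, `rbx`, `rcx`, `rdx`, `rsi`, `rdi`, `rsp`, `rbp`, `r8`–`r15`). `rsp` is reserved as the stack pointer (and `rbp` often serves as the frame pointer), so roughly 14 to 15 are available for values. The compiler must map an unlimited number of virtual registers onto this finite set.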

**Graph Coloring** is the classic algorithm: build an interference graph where two virtual registers share an edge if they are live at the same time, then color the graph with N colors (N = number of physical registers). If coloring fails, spill a variable to the stack.
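The simplify/select structure of Chaitin-style allocators fits in a few lines. This sketch assumes:

- a symmetric interference graph given as a dict of neighbor sets;
- registers numbered `0..k-1`;
- no coalescing and no spill-cost heuristics.

When no node has fewer than `k` neighbors, the sketch optimistically pushes the highest-degree node (Briggs' refinement). It spills only if the select phase truly finds no free register.

```python
from typing import Dict, List, Optional, Set

def allocate_registers(interference: Dict[str, Set[str]], k: int) -> Dict[str, Optional[int]]:
    """Map each virtual register to a register index 0..k-1, or None if it must be spilled."""
    graph = {v: set(ns) for v, ns in interference.items()}  # working copy
    stack: List[str] = []
    # Simplify: a node with < k neighbors is always colorable, so remove it and defer it
    while graph:
        node = next((v for v in graph if len(graph[v]) < k), None)
        if node is None:
            node = max(graph, key=lambda v: len(graph[v]))  # potential spill, pushed optimistically
        for neighbor in graph.pop(node):
            graph[neighbor].discard(node)
        stack.append(node)
    # Select: pop in reverse order, taking the lowest register no colored neighbor uses
    colors: Dict[str, Optional[int]] = {}
    while stack:
        node = stack.pop()
        taken = {colors[n] for n in interference[node] if colors.get(n) is not None}
        free = [r for r in range(k) if r not in taken]
        colors[node] = free[0] if free else None  # None: spill to the stack
    return colors

# a, b, c are simultaneously live (a triangle); d overlaps only with c
interference = {"a": {"b", "c"}, "b": {"a", "c"}, "c": {"a", "b", "d"}, "d": {"c"}}
print(allocate_registers(interference, k=3))
print(allocate_registers(interference, k=2))
```

With three registers everything fits. With two, the triangle `a`, `b`, `c` cannot be colored, so exactly one of them is spilled.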

4.2 x86-64 Example

```asm
; C: int add(int a, int b) { return a + b; }
; Linux System V ABI: args in rdi, rsi; return in rax

; Without optimization (-O0): arguments round-trip through the stack frame
add:
    push rbp
    mov  rbp, rsp
    mov  DWORD PTR [rbp-4], edi
    mov  DWORD PTR [rbp-8], esi
    mov  edx, DWORD PTR [rbp-4]
    mov  eax, DWORD PTR [rbp-8]
    add  eax, edx
    pop  rbp
    ret

; With -O2 (optimized)
add:
    lea  eax, [rdi + rsi]   ; single instruction!
    ret
```
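Why does the unoptimized version shuffle values around? A naive tree-walking code generator shows the same kind of redundancy: it parks every intermediate result in `eax` and uses the stack for the other operand. The sketch below is hypothetical and deliberately simple:

- It supports only `+`, `-`, and `*` on 32-bit ints.
- It maps two assumed parameter names, `a` and `b`, to `edi`/`esi` per the System V ABI.
- It emits Intel-syntax text rather than machine code.

```python
from typing import Dict, List, Tuple, Union

Expr = Union[int, str, Tuple[str, 'Expr', 'Expr']]

ARG_REGS: Dict[str, str] = {"a": "edi", "b": "esi"}  # System V: 1st/2nd int args (names assumed)
OPCODES = {"+": "add", "-": "sub", "*": "imul"}

def gen_expr(expr: Expr, out: List[str]) -> None:
    """Emit instructions that leave the value of `expr` in eax."""
    if isinstance(expr, int):
        out.append(f"mov eax, {expr}")
    elif isinstance(expr, str):
        out.append(f"mov eax, {ARG_REGS[expr]}")
    else:
        op, lhs, rhs = expr
        gen_expr(rhs, out)
        out.append("push rax")                 # save the right operand on the stack
        gen_expr(lhs, out)
        out.append("pop rcx")                  # right operand now in ecx
        out.append(f"{OPCODES[op]} eax, ecx")  # eax = lhs op rhs

def gen_function(name: str, body: Expr) -> List[str]:
    out = [f"{name}:"]
    gen_expr(body, out)
    out.append("ret")
    return out

print("\n".join(gen_function("add", ("+", "a", "b"))))
```

An optimizer notices that both operands already sit in registers and that `lea` can add them without touching the stack, which is exactly the `-O2` output above.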

5. JIT Compilation

**JIT (Just-In-Time)** compilation happens at runtime. It is a key technique for narrowing the performance gap between interpreted languages and native code.

5.1 Numba JIT for Python

Numba uses LLVM to compile Python/NumPy code to native machine code at runtime.

```python
import time

import numpy as np
from numba import njit, prange

# Pure Python baseline
def matmul_python(A, B):
    N = A.shape[0]
    C = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            for k in range(N):
                C[i, j] += A[i, k] * B[k, j]
    return C

# Numba JIT with parallelism
@njit(parallel=True)
def matmul_numba(A, B):
    N = A.shape[0]
    C = np.zeros((N, N))
    for i in prange(N):  # parallel loop: each thread owns whole rows of C
        for j in range(N):
            for k in range(N):
                C[i, j] += A[i, k] * B[k, j]
    return C

N = 256
A = np.random.rand(N, N).astype(np.float32)
B = np.random.rand(N, N).astype(np.float32)

# First call triggers compilation (warmup), so it is kept out of the timing
_ = matmul_numba(A, B)

t0 = time.perf_counter()
C_py = matmul_python(A, B)
print(f"Python: {time.perf_counter()-t0:.3f}s")

t0 = time.perf_counter()
C_nb = matmul_numba(A, B)
print(f"Numba JIT: {time.perf_counter()-t0:.4f}s")

# Typical result: Python ~30s vs Numba ~0.01s (exact numbers depend on the machine)
```

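**How Numba works internally:**

1. Analyze the function's Python bytecode and translate it into Numba's own IR
2. Infer concrete types for every variable from the argument types of the call
3. Lower the typed IR to LLVM IR
4. Compile it to native machine code with the LLVM backend
5. Cache the compiled code per argument-type signature, so later calls with the same types skip steps 1 to 4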
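Step 5 is why the benchmark above needs a warmup call. Compilation happens once per distinct argument-type signature, and later calls only pay for a dictionary lookup. Below is a toy model of that dispatch logic, with "compilation" reduced to a counter. Nothing here is Numba's actual implementation; `toy_jit` is a hypothetical name.

```python
import functools
from typing import Any, Callable, Dict, Tuple

def toy_jit(func: Callable) -> Callable:
    """Cache one 'compiled' specialization per tuple of argument types."""
    cache: Dict[Tuple[type, ...], Callable] = {}

    @functools.wraps(func)
    def dispatcher(*args: Any) -> Any:
        signature = tuple(type(arg) for arg in args)
        if signature not in cache:
            # A real JIT would run type inference and emit machine code here
            dispatcher.compilations += 1
            cache[signature] = func
        return cache[signature](*args)

    dispatcher.compilations = 0
    return dispatcher

@toy_jit
def add(x, y):
    return x + y

add(1, 2)
add(3, 4)                # same (int, int) signature: reuses the cached version
add(1.0, 2.5)            # new (float, float) signature: "compiles" again
print(add.compilations)  # 2
```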

5.2 PyPy and Tracing JIT

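PyPy's tracing JIT records the operations actually executed along a hot loop (a _trace_) and compiles that linear trace. It inserts guards that fall back to the interpreter when execution leaves the recorded path. Because a trace follows execution through function calls, inlining along the hot path comes for free. A method-based JIT such as HotSpot's instead compiles whole methods and decides separately which callees to inline.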

6. AI/ML Compilers

6.1 torch.compile() Internals

`torch.compile()`, introduced in PyTorch 2.0, is a three-stage compilation pipeline.

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = SimpleNet().cuda()  # requires a CUDA GPU
x = torch.randn(32, 128).cuda()

# Apply torch.compile
compiled = torch.compile(model, mode='max-autotune')

# First call: compilation happens (seconds or longer; 'max-autotune' benchmarks many kernel variants)
out = compiled(x)

# Subsequent calls: reuse compiled kernels
for _ in range(100):
    out = compiled(x)
```

**Stage 1 — TorchDynamo (Graph Capture)**

Dynamo intercepts Python bytecode and captures PyTorch operations as a `torch.fx.Graph`.

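```python
import torch
import torch._dynamo as dynamo

def my_func(x, y):
    return torch.sin(x) + torch.cos(y)

# Inspect the captured FX graph (explain(fn)(*args) form of PyTorch 2.1+)
explanation = dynamo.explain(my_func)(torch.randn(4), torch.randn(4))

# graphs holds fx.GraphModule objects; print_tabular() prints on its own
# (it returns None) and needs the `tabulate` package
explanation.graphs[0].graph.print_tabular()

# Simplified output (placeholder rows for x and y omitted; names vary by PyTorch version):
# opcode        | name   | target       | args
# --------------|--------|--------------|-----------
# call_function | sin    | torch.sin    | (x,)
# call_function | cos    | torch.cos    | (y,)
# call_function | add    | operator.add | (sin, cos)
# output        | output | output       | (add,)
```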

**Stage 2 — AOTAutograd (Pre-compile Autograd)**

Takes the forward FX graph and traces through `torch.autograd` to produce both forward and backward graphs _before_ execution. This enables the backend to optimize across the full gradient computation.

**Stage 3 — TorchInductor (Kernel Generation)**

Converts the FX graph into Triton kernels (for GPUs) or C++/OpenMP code (for CPUs), applying operator fusion, tiling, and vectorization.

6.2 Apache TVM

TVM compiles deep learning models to a wide variety of hardware targets (GPU, CPU, FPGA, custom NPUs).

```python
import torch
import torchvision
import tvm
from tvm import relay, auto_scheduler

# Relay + auto_scheduler API of classic TVM releases; newer releases center on the Relax IR

# Start from a PyTorch model
model = torchvision.models.resnet18(weights=None).eval()
input_shape = [1, 3, 224, 224]
trace_input = torch.randn(input_shape)

# Trace to TorchScript, then import into TVM Relay IR
scripted = torch.jit.trace(model, trace_input)
mod, params = relay.frontend.from_pytorch(scripted, [("input0", input_shape)])

target = tvm.target.Target("cuda")

# AutoScheduler: search for fast schedules (tiling, unrolling, thread binding) per extracted subgraph
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=200,  # total trials across all tasks: a small demo budget
    runner=auto_scheduler.LocalRunner(repeat=10),
    measure_callbacks=[auto_scheduler.RecordToFile("tuning.json")],
)
tuner.tune(tune_option)

# Compile with the best schedules found; the config flag tells Relay to use them
with auto_scheduler.ApplyHistoryBest("tuning.json"):
    with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
        lib = relay.build(mod, target=target, params=params)

print("TVM compilation complete!")
```

6.3 Operator Fusion

Fusion is often one of the most impactful GPU optimizations: several kernel launches are merged into one, so intermediate tensors never make a round trip through HBM.

```python
# Without fusion: 3 separate GPU kernel launches
y1 = conv2d(x)        # kernel 1: write y1 to HBM
y2 = batch_norm(y1)   # kernel 2: read y1, write y2 to HBM
y3 = relu(y2)         # kernel 3: read y2, write y3 to HBM

# With fusion: 1 kernel launch
y3 = fused_conv_bn_relu(x)
# Intermediate results stay in registers / shared memory
# HBM traffic for y1 and y2 disappears, which is where the speedup comes from
```
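A back-of-the-envelope model makes the saving concrete. The sketch below is a toy: plain Python lists stand in for GPU buffers, and the elementwise ops (scale, shift, ReLU) are hypothetical. It checks that the fused and unfused versions compute the same values. It then counts the bytes each version would move through main memory if every unfused op read its whole input and wrote its whole output once.

```python
from typing import List

def unfused(x: List[float], scale: float, shift: float) -> List[float]:
    y1 = [v * scale for v in x]       # op 1: read x,  write y1
    y2 = [v + shift for v in y1]      # op 2: read y1, write y2
    return [max(v, 0.0) for v in y2]  # op 3: read y2, write y3

def fused(x: List[float], scale: float, shift: float) -> List[float]:
    # One pass: intermediates live only in local variables (the "registers")
    return [max(v * scale + shift, 0.0) for v in x]

def memory_traffic(num_elements: int, num_ops: int, is_fused: bool, bytes_per_element: int = 4) -> int:
    """Bytes read + written in main memory by a chain of elementwise ops (toy model)."""
    passes = 1 if is_fused else num_ops
    return passes * 2 * num_elements * bytes_per_element  # each pass reads and writes every element

x = [-2.0, -0.5, 0.0, 1.5]
assert unfused(x, 2.0, 1.0) == fused(x, 2.0, 1.0)
n = 32 * 256 * 56 * 56  # an arbitrary float32 activation size
print(memory_traffic(n, 3, False) / memory_traffic(n, 3, True))  # 3.0
```

For a chain of `k` bandwidth-bound elementwise ops, the model predicts a `k`-fold traffic reduction. Real kernels deviate from this through caching and launch overheads.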

6.4 XLA and JAX

XLA (Accelerated Linear Algebra) is the compiler behind TensorFlow and JAX. It uses HLO (High Level Operations) as its IR.

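```python
import jax
import jax.numpy as jnp

@jax.jit
def compute(x, W, b):
    # XLA typically fuses the bias add and ReLU; exact fusion decisions are backend-specific
    return jax.nn.relu(x @ W + b)

x = jnp.ones((32, 128))
W = jnp.ones((128, 256))
b = jnp.zeros(256)

# First call: tracing + XLA compilation; later calls with the same shapes/dtypes reuse it
result = compute(x, W, b)
print(result.shape)  # (32, 256)
```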

XLA's key optimizations include:

- **Operation fusion**: merges elementwise ops after matmul/conv

- **Layout optimization**: chooses NCHW vs NHWC automatically

- **Memory planning**: minimizes peak live memory

7. MLIR: Next-Generation Compiler Infrastructure

MLIR (Multi-Level Intermediate Representation), originally developed at Google and now part of the LLVM project, unifies many compiler IRs under a single extensible framework.

```text
TensorFlow Graph
    ↓  TF Dialect
MLIR HLO Dialect
    ↓  Linalg Dialect
Loop Nest IR
    ↓  Affine Dialect
Affine Loop IR
    ↓  LLVM Dialect
LLVM IR
    ↓  LLVM backend
Machine Code
```

Each **Dialect** represents a different level of abstraction. Compilation proceeds by _lowering_ from high-level dialects to lower-level ones, with each step preserving correctness while enabling new optimizations.

MLIR is now the backbone of:

- IREE (an MLIR-based ML compiler and runtime, started at Google)

- ONNX-MLIR

- Torch-MLIR (PyTorch to MLIR)

- Flang (LLVM Fortran compiler)

8. Quiz

**Answer**: LR(1) parsers can handle a strictly larger class of grammars than LL(1).

**Explanation**: LL(1) is a top-down parser that reads left to right and builds a leftmost derivation using one token of lookahead. It cannot handle left-recursive grammars, and the grammar must be left-factored so that one token always predicts a unique production. LR(1) is a bottom-up, shift-reduce parser that builds a rightmost derivation in reverse. It handles left recursion naturally and recognizes exactly the deterministic context-free languages. LALR(1) merges LR(1) states with identical cores. This gives much smaller tables at the cost of occasional reduce/reduce conflicts, and it is the default algorithm in Yacc and Bison. In practice, most programming-language grammars can be written in LALR(1) form, although languages such as C++ need extra help from semantic information.

**Answer**: Because each variable has exactly one definition, the def-use chain is trivial to compute, making dataflow analysis simple and efficient.

**Explanation**: In SSA, the single definition property means: (1) Constant propagation — if a def is a constant, every use can be replaced immediately; (2) Dead code elimination — a def with no uses is trivially dead; (3) Register allocation — liveness analysis reduces to reachability from def to uses; (4) Loop-invariant code motion — dependencies are explicit. LLVM, GCC, and virtually every modern compiler optimizer operates on SSA.

**Answer**: LLVM IR is a typed, SSA-based, target-independent representation that cleanly separates frontend concerns from backend code generation.

**Explanation**: LLVM IR properties: (1) Strong static typing enforces invariants optimizers can rely on; (2) SSA form makes all data dependencies explicit; (3) Infinite virtual registers — physical register allocation happens only at code-gen time; (4) Explicit memory model with alignment. The separation of concerns is key: C, Rust, Swift, and Haskell frontends all emit LLVM IR, and a single set of optimization passes (InstCombine, GVN, LICM, etc.) improves all of them. The backend then handles target-specific concerns like instruction selection and register allocation.

**Answer**: Fusion reduces HBM (high-bandwidth memory) traffic by keeping intermediate results in faster on-chip memory (registers or L1/L2 cache).

**Explanation**: Many GPU workloads are memory-bandwidth-bound rather than compute-bound. Without fusion, each operator writes its output to HBM and the next operator reads it back, which is expensive. With fusion, the intermediate result stays in registers or shared memory, and HBM is accessed only for the initial input and the final output. Elementwise operators (ReLU, BatchNorm, elementwise add) are especially good fusion candidates: they do very little arithmetic per byte loaded, so they are almost always bandwidth-bound. TVM's Relay `FuseOps` pass groups fusable operators automatically, using a post-dominator analysis of the dataflow graph.

**Answer**: TorchDynamo (graph capture) → AOTAutograd (autograd pre-compilation) → TorchInductor (kernel generation).

**Explanation**: (1) TorchDynamo operates at the Python bytecode level. It hooks CPython's frame evaluation and traces PyTorch operations into an FX graph. It also installs guards (checks on tensor shapes, dtypes, and Python values), so a compiled graph is reused only while its assumptions hold; a failed guard triggers recompilation. Code it cannot trace, such as data-dependent control flow, causes a graph break, and that part runs as regular Python. (2) AOTAutograd takes the forward FX graph and pre-traces through PyTorch's autograd engine to produce a joint forward+backward graph before any execution. (3) TorchInductor converts the FX graph into Triton GPU kernels or C++ CPU code, applying fusion, loop tiling, and vectorization. The `mode` argument controls the trade-off: 'default' compiles quickly, while 'max-autotune' spends more time searching for the best tile sizes and kernel configurations.

Conclusion

Compiler design is not dusty theory — it is the engine powering today's AI infrastructure. `torch.compile()`, TVM, XLA, and MLIR are all built on decades of compiler research: lexers, parsers, SSA, dataflow analysis, register allocation, and code generation.

Understanding the Lexer → Parser → AST → IR → Optimization → Code Gen pipeline gives you the mental model to understand:

- Why GPU kernel fusion matters so much

- Why LLVM is the backend for so many languages

- Why SSA form is the universal foundation for optimization

Recommended next steps: the LLVM Kaleidoscope tutorial, the TVM documentation, and the PyTorch 2.0 paper (Ansel et al., 2024).
