Compiler & Interpreter Design: From Parsers to LLVM and AI Compilers (TVM/XLA)

Author: Youngju Kim (@fjvbn20031)

Contents
- Introduction
- 1. Compiler Fundamentals: The Full Pipeline
- 2. Intermediate Representations (IR)
- 3. Optimization Passes
- 4. Code Generation: x86-64 Assembly
- 5. JIT Compilation
- 6. AI/ML Compilers
- 7. MLIR: Next-Generation Compiler Infrastructure
- 8. Quiz
- Conclusion
Introduction
Understanding compilers is not just about building programming languages. Today, tools like torch.compile(), TVM, XLA, and MLIR are core infrastructure in AI/ML, and their internals are built entirely on compiler theory.
This guide walks through the entire compiler pipeline from first principles, ending at the cutting edge of ML compilation.
1. Compiler Fundamentals: The Full Pipeline
A compiler transforms source code into executable code through a series of well-defined stages.
```
Source Code
    ↓ Lexical Analysis (Lexing)
Token Stream
    ↓ Syntactic Analysis (Parsing)
AST (Abstract Syntax Tree)
    ↓ Semantic Analysis
Annotated AST
    ↓ IR Generation
Intermediate Representation (IR)
    ↓ Optimization Passes
Optimized IR
    ↓ Code Generation
Machine Code / Assembly
```
1.1 Lexical Analysis (Tokenization)
The lexer reads source characters and emits a stream of tokens — the smallest meaningful units like numbers, identifiers, keywords, and operators.
```python
import re
from enum import Enum, auto
from dataclasses import dataclass
from typing import List, Optional

class TokenType(Enum):
    NUMBER = auto()
    IDENT = auto()
    PLUS = auto()
    MINUS = auto()
    STAR = auto()
    SLASH = auto()
    LPAREN = auto()
    RPAREN = auto()
    ASSIGN = auto()
    SEMICOL = auto()
    LET = auto()
    EOF = auto()

@dataclass
class Token:
    type: TokenType
    value: str
    line: int

KEYWORDS = {'let': TokenType.LET}

TOKEN_PATTERNS = [
    (r'\d+(\.\d+)?', TokenType.NUMBER),
    (r'[a-zA-Z_]\w*', TokenType.IDENT),
    (r'\+', TokenType.PLUS),
    (r'-', TokenType.MINUS),
    (r'\*', TokenType.STAR),
    (r'/', TokenType.SLASH),
    (r'\(', TokenType.LPAREN),
    (r'\)', TokenType.RPAREN),
    (r'=', TokenType.ASSIGN),
    (r';', TokenType.SEMICOL),
]

def tokenize(source: str) -> List[Token]:
    tokens = []
    line = 1
    pos = 0
    while pos < len(source):
        if source[pos] in ' \t\r':
            pos += 1
            continue
        if source[pos] == '\n':
            line += 1
            pos += 1
            continue
        matched = False
        for pattern, ttype in TOKEN_PATTERNS:
            m = re.match(pattern, source[pos:])
            if m:
                value = m.group(0)
                if ttype == TokenType.IDENT and value in KEYWORDS:
                    ttype = KEYWORDS[value]
                tokens.append(Token(ttype, value, line))
                pos += len(value)
                matched = True
                break
        if not matched:
            raise SyntaxError(f"Unexpected character '{source[pos]}' at line {line}")
    tokens.append(Token(TokenType.EOF, '', line))
    return tokens

# Example usage
src = "let x = 10 + 20 * 3;"
for tok in tokenize(src):
    print(tok)
```
1.2 Parsing and the AST
The parser consumes tokens and builds an Abstract Syntax Tree (AST) that captures the hierarchical structure of the program.
```python
from dataclasses import dataclass
from typing import Union

# AST node definitions
@dataclass
class NumberLit:
    value: float

@dataclass
class Identifier:
    name: str

@dataclass
class BinOp:
    op: str
    left: 'Expr'
    right: 'Expr'

@dataclass
class LetStmt:
    name: str
    value: 'Expr'

Expr = Union[NumberLit, Identifier, BinOp]

# Recursive Descent Parser
class Parser:
    def __init__(self, tokens: List[Token]):
        self.tokens = tokens
        self.pos = 0

    def peek(self) -> Token:
        return self.tokens[self.pos]

    def consume(self, expected: Optional[TokenType] = None) -> Token:
        tok = self.tokens[self.pos]
        if expected and tok.type != expected:
            raise SyntaxError(f"Expected {expected}, got {tok.type} at line {tok.line}")
        self.pos += 1
        return tok

    def parse_let(self) -> LetStmt:
        self.consume(TokenType.LET)
        name = self.consume(TokenType.IDENT).value
        self.consume(TokenType.ASSIGN)
        expr = self.parse_expr()
        self.consume(TokenType.SEMICOL)
        return LetStmt(name=name, value=expr)

    def parse_expr(self) -> Expr:
        return self.parse_additive()

    def parse_additive(self) -> Expr:
        left = self.parse_multiplicative()
        while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
            op = self.consume().value
            right = self.parse_multiplicative()
            left = BinOp(op=op, left=left, right=right)
        return left

    def parse_multiplicative(self) -> Expr:
        left = self.parse_primary()
        while self.peek().type in (TokenType.STAR, TokenType.SLASH):
            op = self.consume().value
            right = self.parse_primary()
            left = BinOp(op=op, left=left, right=right)
        return left

    def parse_primary(self) -> Expr:
        tok = self.peek()
        if tok.type == TokenType.NUMBER:
            self.consume()
            return NumberLit(value=float(tok.value))
        if tok.type == TokenType.IDENT:
            self.consume()
            return Identifier(name=tok.value)
        if tok.type == TokenType.LPAREN:
            self.consume()
            expr = self.parse_expr()
            self.consume(TokenType.RPAREN)
            return expr
        raise SyntaxError(f"Unexpected token {tok}")
```
Parsing strategies differ in power and complexity:
| Strategy | Type | Notes |
|---|---|---|
| Recursive Descent | LL(k), top-down | Hand-written, easy to read; used by GCC and Clang |
| Pratt Parser | Top-down operator precedence | Elegant for expressions |
| LALR(1) | Bottom-up (shift-reduce) | Generated by Yacc/Bison |
| PEG / Packrat | Parsing expression grammar | Unambiguous by construction; memoized (CPython's parser since 3.9) |
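As a taste of the Pratt approach from the table: a single function plus a precedence table handles any number of precedence levels. This is a minimal sketch; the token list and binding-power values are illustrative, not from any particular implementation.

```python
# Binding powers: higher binds tighter
PREC = {'+': 10, '-': 10, '*': 20, '/': 20}

def parse(tokens):
    """Parse a flat token list like ['1', '+', '2', '*', '3'] into a nested tuple AST."""
    pos = 0

    def expr(min_bp):
        nonlocal pos
        left = tokens[pos]  # primary: a number or name
        pos += 1
        # Keep consuming operators at or above the current binding power.
        while pos < len(tokens) and tokens[pos] in PREC and PREC[tokens[pos]] >= min_bp:
            op = tokens[pos]
            pos += 1
            right = expr(PREC[op] + 1)  # +1 makes the operator left-associative
            left = (op, left, right)
        return left

    return expr(0)

print(parse(['1', '+', '2', '*', '3']))
# → ('+', '1', ('*', '2', '3'))
```

Note how `*` binds tighter than `+` purely because of the table entries; adding a new operator is a one-line change.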
2. Intermediate Representations (IR)
Going straight from AST to machine code is impractical. A well-designed IR enables language-agnostic optimizations.
2.1 Three-Address Code
The simplest flat IR. Each instruction has at most three operands: typically two sources and one destination.
```python
# Source: x = (a + b) * (c - d)
t1 = a + b
t2 = c - d
x = t1 * t2
```
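The translation above can be automated with a tiny recursive lowering pass. This sketch uses nested tuples `('op', left, right)` as a stand-in AST for brevity:

```python
from itertools import count

def lower(expr, code, temps=None):
    """Lower a nested tuple AST ('op', left, right) into three-address code."""
    if temps is None:
        temps = count(1)
    if not isinstance(expr, tuple):
        return str(expr)  # leaf: a variable name or constant
    op, lhs, rhs = expr
    a = lower(lhs, code, temps)   # emit code for the left operand first
    b = lower(rhs, code, temps)
    t = f"t{next(temps)}"         # fresh temporary for this operation
    code.append(f"{t} = {a} {op} {b}")
    return t

# x = (a + b) * (c - d)
code = []
result = lower(('*', ('+', 'a', 'b'), ('-', 'c', 'd')), code)
code.append(f"x = {result}")
print('\n'.join(code))
```

This emits one extra temporary (`t3`) compared to the hand-written version above; a real compiler would clean that up with copy propagation.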
2.2 SSA Form (Static Single Assignment)
SSA requires that every variable is assigned exactly once. This simple constraint dramatically simplifies dataflow analysis.
```python
# Original code
x = 1
x = x + 2   # x assigned twice

# SSA form
x_1 = 1
x_2 = x_1 + 2
```
At control-flow merge points, phi functions select the correct version:
```python
if cond:
    x_1 = 10
else:
    x_2 = 20
# Merge point
x_3 = phi(x_1, x_2)
```
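For straight-line code (no branches, so no phi functions needed), SSA construction is just systematic renaming. A minimal sketch, representing statements as `(target, expression-string)` pairs purely for illustration:

```python
import re

def to_ssa(stmts):
    """Convert straight-line assignments (target, expr-string) to SSA by renaming."""
    version = {}
    out = []
    for target, expr in stmts:
        # Rewrite every variable use to its current SSA version.
        expr = re.sub(
            r'[a-zA-Z_]\w*',
            lambda m: (f"{m.group(0)}_{version[m.group(0)]}"
                       if m.group(0) in version else m.group(0)),
            expr)
        # Each assignment creates a brand-new version of the target.
        version[target] = version.get(target, 0) + 1
        out.append(f"{target}_{version[target]} = {expr}")
    return out

print(to_ssa([('x', '1'), ('x', 'x + 2')]))
# → ['x_1 = 1', 'x_2 = x_1 + 2']
```

Handling branches requires placing phi functions at dominance frontiers (the Cytron et al. algorithm), which is beyond this sketch.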
2.3 LLVM IR
LLVM IR is the industry-standard IR used by Clang, Rust, Swift, and many others.
```llvm
; Function: int add(int a, int b)
define i32 @add(i32 %a, i32 %b) {
entry:
  %result = add i32 %a, %b
  ret i32 %result
}

; Branching with SSA phi nodes
define i32 @max(i32 %a, i32 %b) {
entry:
  %cmp = icmp sgt i32 %a, %b     ; signed greater-than
  br i1 %cmp, label %then, label %else
then:
  br label %merge
else:
  br label %merge
merge:
  %result = phi i32 [ %a, %then ], [ %b, %else ]
  ret i32 %result
}
```
Key properties of LLVM IR:
- Strong type system: `i32`, `i64`, `float`, `ptr` are all explicit
- SSA-based: every `%name` is defined exactly once
- Infinite virtual registers: `%0`, `%1`, `%2`, ... with no limit
- Target-independent: C, Rust, Swift all lower to the same IR
3. Optimization Passes
3.1 Constant Folding
Pre-compute constant expressions at compile time.
```python
# Before
x = 2 * 3 + 4

# After (compiler computes at build time)
x = 10
```
```python
def constant_fold(node):
    """Simple constant folding pass."""
    if isinstance(node, BinOp):
        left = constant_fold(node.left)
        right = constant_fold(node.right)
        if isinstance(left, NumberLit) and isinstance(right, NumberLit):
            ops = {
                '+': lambda a, b: a + b,
                '-': lambda a, b: a - b,
                '*': lambda a, b: a * b,
                '/': lambda a, b: a / b,
            }
            return NumberLit(value=ops[node.op](left.value, right.value))
        return BinOp(op=node.op, left=left, right=right)
    return node
```
3.2 Dead Code Elimination
Remove code whose results are never used.
```python
def collect_vars(expr):
    """Gather every Identifier name referenced in an expression."""
    if isinstance(expr, Identifier):
        return {expr.name}
    if isinstance(expr, BinOp):
        return collect_vars(expr.left) | collect_vars(expr.right)
    return set()

def dead_code_eliminate(stmts, used_vars):
    """Remove assignments to variables that are never read (backward liveness pass)."""
    live = set(used_vars)
    kept = []
    for stmt in reversed(stmts):
        if isinstance(stmt, LetStmt) and stmt.name not in live:
            continue  # dead: its result is never read
        if isinstance(stmt, LetStmt):
            live.update(collect_vars(stmt.value))
        kept.append(stmt)
    return list(reversed(kept))
```
3.3 Loop Unrolling
Reduce loop overhead and expose instruction-level parallelism.
```c
// Original — 4 iterations with loop overhead
for (int i = 0; i < 4; i++) {
    a[i] = b[i] + c[i];
}

// Unrolled — zero loop overhead, can be vectorized
a[0] = b[0] + c[0];
a[1] = b[1] + c[1];
a[2] = b[2] + c[2];
a[3] = b[3] + c[3];
```
3.4 Function Inlining
Replace a call site with the callee body, removing call overhead and enabling further optimization.
```c
// Before
inline int square(int x) { return x * x; }
int y = square(5);

// After inlining (then constant folding)
int y = 25;
```
4. Code Generation: x86-64 Assembly
4.1 Register Allocation
x86-64 has 16 general-purpose registers (rax, rbx, rcx, rdx, rsi, rdi, rsp, rbp, r8–r15). The compiler must map unlimited virtual registers to this finite set.
Graph Coloring is the classic algorithm: build an interference graph where two virtual registers share an edge if they are live at the same time, then color the graph with N colors (N = number of physical registers). If coloring fails, spill a variable to the stack.
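A greedy approximation of graph coloring fits in a few lines. This is a sketch of the idea, not LLVM's actual allocator, and the interference graph below is hand-made for illustration:

```python
def color_graph(interference, num_regs):
    """Greedy graph coloring: map each variable to a register index, spilling on failure."""
    assignment, spilled = {}, []
    # Visit highest-degree nodes first, a common ordering heuristic.
    for var in sorted(interference, key=lambda v: -len(interference[v])):
        taken = {assignment[n] for n in interference[var] if n in assignment}
        reg = next((r for r in range(num_regs) if r not in taken), None)
        if reg is None:
            spilled.append(var)   # no color left: spill to the stack
        else:
            assignment[var] = reg
    return assignment, spilled

# 'a' is live at the same time as 'b' and 'c'; 'b' and 'c' never overlap.
graph = {'a': {'b', 'c'}, 'b': {'a'}, 'c': {'a'}}
print(color_graph(graph, 2))
# → ({'a': 0, 'b': 1, 'c': 1}, [])
```

With only two registers, `b` and `c` can share one because they never interfere; production allocators (Chaitin-Briggs, linear scan) add coalescing and smarter spill heuristics on top of this core idea.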
4.2 x86-64 Example
```asm
; C: int add(int a, int b) { return a + b; }
; Linux System V ABI: args in rdi, rsi; return in rax
add:
    push rbp
    mov  rbp, rsp
    mov  DWORD PTR [rbp-4], edi
    mov  DWORD PTR [rbp-8], esi
    mov  edx, DWORD PTR [rbp-4]
    mov  eax, DWORD PTR [rbp-8]
    add  eax, edx
    pop  rbp
    ret

; With -O2 (optimized)
add:
    lea  eax, [rdi + rsi]   ; single instruction!
    ret
```
5. JIT Compilation
JIT (Just-In-Time) compilation happens at runtime. It is the key technology that closes the performance gap between interpreted languages and native code.
5.1 Numba JIT for Python
Numba uses LLVM to compile Python/NumPy code to native machine code at runtime.
```python
import numpy as np
from numba import njit, prange
import time

# Pure Python baseline
def matmul_python(A, B):
    N = A.shape[0]
    C = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            for k in range(N):
                C[i, j] += A[i, k] * B[k, j]
    return C

# Numba JIT with parallelism
@njit(parallel=True)
def matmul_numba(A, B):
    N = A.shape[0]
    C = np.zeros((N, N))
    for i in prange(N):  # parallel loop
        for j in range(N):
            for k in range(N):
                C[i, j] += A[i, k] * B[k, j]
    return C

N = 256
A = np.random.rand(N, N).astype(np.float32)
B = np.random.rand(N, N).astype(np.float32)

# First call triggers compilation (warmup)
_ = matmul_numba(A, B)

t0 = time.perf_counter()
C_py = matmul_python(A, B)
print(f"Python:    {time.perf_counter()-t0:.3f}s")

t0 = time.perf_counter()
C_nb = matmul_numba(A, B)
print(f"Numba JIT: {time.perf_counter()-t0:.4f}s")

# Typical result: Python ~30s vs Numba ~0.01s
```
How Numba works internally:
- Intercept Python bytecode
- Type inference to determine concrete types
- Lower to LLVM IR
- LLVM backend compiles to native machine code
- Cache the compiled artifact for subsequent calls
5.2 PyPy and Tracing JIT
PyPy's tracing JIT records "hot paths" (frequently executed traces) and compiles them directly. Unlike method-based JIT (HotSpot JVM), it can inline across call boundaries along the trace.
6. AI/ML Compilers
6.1 torch.compile() Internals
torch.compile(), introduced in PyTorch 2.0, runs models through a three-stage compilation pipeline.
```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = SimpleNet().cuda()
x = torch.randn(32, 128).cuda()

# Apply torch.compile
compiled = torch.compile(model, mode='max-autotune')

# First call: compilation happens (takes a few seconds)
out = compiled(x)

# Subsequent calls: reuse compiled kernels
for _ in range(100):
    out = compiled(x)
```
Stage 1 — TorchDynamo (Graph Capture)
Dynamo intercepts Python bytecode and captures PyTorch operations as a torch.fx.Graph.
```python
import torch
import torch._dynamo as dynamo

def my_func(x, y):
    return torch.sin(x) + torch.cos(y)

# Inspect the captured FX graph
explanation = dynamo.explain(my_func)(
    torch.randn(4), torch.randn(4)
)
explanation.graphs[0].graph.print_tabular()

# Simplified output (exact names vary by PyTorch version):
# opcode        | name   | target       | args
# --------------|--------|--------------|------------
# call_function | sin    | torch.sin    | (x,)
# call_function | cos    | torch.cos    | (y,)
# call_function | add    | operator.add | (sin, cos)
# output        | output | output       | ((add,),)
```
Stage 2 — AOTAutograd (Pre-compile Autograd)
Takes the forward FX graph and traces through torch.autograd to produce both forward and backward graphs before execution. This enables the backend to optimize across the full gradient computation.
Stage 3 — TorchInductor (Kernel Generation)
Converts the FX graph into Triton (for CUDA) or C++ (for CPU) kernel code, applying operator fusion, tiling, and vectorization.
6.2 Apache TVM
TVM compiles deep learning models to a wide variety of hardware targets (GPU, CPU, FPGA, custom NPUs).
```python
import tvm
from tvm import relay
import torch, torchvision
import numpy as np

# Start from a PyTorch model
model = torchvision.models.resnet18(pretrained=False).eval()
input_shape = [1, 3, 224, 224]
trace_input = torch.randn(input_shape)

# Trace to TorchScript, then import into TVM Relay IR
scripted = torch.jit.trace(model, trace_input)
mod, params = relay.frontend.from_pytorch(scripted, [("input0", input_shape)])

target = tvm.target.Target("cuda")

# AutoScheduler: search for optimal kernel parameters
from tvm import auto_scheduler

tasks, task_weights = auto_scheduler.extract_tasks(
    mod["main"], params, target
)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=200,
    runner=auto_scheduler.LocalRunner(repeat=10),
    measure_callbacks=[auto_scheduler.RecordToFile("tuning.json")],
)
tuner.tune(tune_option)

# Compile with best-found schedules
with auto_scheduler.ApplyHistoryBest("tuning.json"):
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)

print("TVM compilation complete!")
```
6.3 Operator Fusion
Fusion is one of the most impactful GPU optimizations: multiple kernel launches are merged into one, so intermediate results never round-trip through global memory.
```python
# Without fusion: 3 separate GPU kernel launches
y1 = conv2d(x)        # kernel 1 — write to HBM
y2 = batch_norm(y1)   # kernel 2 — read+write HBM
y3 = relu(y2)         # kernel 3 — read+write HBM

# With fusion: 1 kernel launch
y3 = fused_conv_bn_relu(x)
# Intermediate results stay in L1/L2 cache
# HBM bandwidth saved = significant speedup
```
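The bandwidth argument can be sketched in plain Python. The three stages below are cheap stand-ins for conv/batch-norm/ReLU, not real kernels: the unfused pipeline materializes two full intermediate arrays, while the fused loop materializes none.

```python
def unfused(xs):
    # Three passes over the data: each stage writes a full intermediate list.
    t1 = [x * 2.0 for x in xs]          # stand-in for conv
    t2 = [x + 1.0 for x in t1]          # stand-in for batch norm
    return [max(x, 0.0) for x in t2]    # relu

def fused(xs):
    # One pass: intermediates live only in registers, never in "memory".
    return [max(x * 2.0 + 1.0, 0.0) for x in xs]

data = [-2.0, -0.5, 1.0, 3.0]
assert unfused(data) == fused(data)
print(fused(data))
# → [0.0, 0.0, 3.0, 7.0]
```

On a GPU the same transformation trades two round-trips through HBM for a handful of extra registers per thread, which is almost always a win.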
6.4 XLA and JAX
XLA (Accelerated Linear Algebra) is the compiler behind TensorFlow and JAX. It uses HLO (High Level Operations) as its IR.
```python
import jax
import jax.numpy as jnp

@jax.jit
def compute(x, W, b):
    # XLA automatically fuses: matmul + add + relu
    return jax.nn.relu(x @ W + b)

x = jnp.ones((32, 128))
W = jnp.ones((128, 256))
b = jnp.zeros(256)

# First call: XLA compilation
result = compute(x, W, b)
print(result.shape)  # (32, 256)
```
XLA's key optimizations include:
- Operation fusion: merges elementwise ops after matmul/conv
- Layout optimization: chooses NCHW vs NHWC automatically
- Memory planning: minimizes peak live memory
7. MLIR: Next-Generation Compiler Infrastructure
MLIR (Multi-Level Intermediate Representation), originally developed at Google and now part of the LLVM project, unifies many compiler IRs under a single extensible framework.
```
TensorFlow Graph
    ↓ TF Dialect
MLIR HLO Dialect
    ↓ Linalg Dialect
Loop Nest IR
    ↓ Affine Dialect
Affine Loop IR
    ↓ LLVM Dialect
LLVM IR
    ↓
Machine Code
```
Each Dialect represents a different level of abstraction. Compilation proceeds by lowering from high-level dialects to lower-level ones, with each step preserving correctness while enabling new optimizations.
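As a sketch of what a single lowering step manipulates (simplified syntax; the tensor shapes here are illustrative, not taken from a real pipeline): in the Linalg dialect, a matrix multiply is one named structured op on whole tensors, which a later pass rewrites into explicit loop nests.

```mlir
// One named structured op on whole tensors (Linalg dialect)
%0 = linalg.matmul
       ins(%A, %B : tensor<4x8xf32>, tensor<8x4xf32>)
       outs(%C : tensor<4x4xf32>) -> tensor<4x4xf32>
```

Lowering replaces this single op with three nested loops over i, j, and k, at which point the loop-level optimizations from Section 3 (unrolling, tiling, fusion) apply again.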
MLIR is now the backbone of:
- IREE (Google's on-device ML runtime)
- ONNX-MLIR
- Torch-MLIR (PyTorch to MLIR)
- Flang (LLVM Fortran compiler)
8. Quiz
Q1. What is the difference in parsing power between LL(1) and LR(1) parsers?
Answer: LR(1) parsers can handle a strictly larger class of grammars than LL(1).
Explanation: LL(1) is a top-down parser that reads left-to-right and constructs a leftmost derivation using one token of lookahead. It cannot handle left-recursive grammars and requires the grammar to be left-factored so each production can be predicted uniquely. LR(1) is a bottom-up, shift-reduce parser that handles left recursion naturally and covers all deterministic context-free languages parsable with one token of lookahead. LALR(1) is a memory-efficient approximation of LR(1) produced by generators such as Yacc and Bison. In practice, most programming-language grammars can be made to fit LALR(1).
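For instance (a standard textbook illustration, not taken from any specific compiler): the natural left-recursive expression grammar defeats LL(1) prediction and must be rewritten, while an LR(1) parser consumes it directly.

```text
E  → E '+' T | T      left-recursive: an LL(1) parser would recurse forever on E

E  → T E'             equivalent right-recursive form, LL(1)-friendly
E' → '+' T E' | ε
```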
Q2. Why does SSA form make optimization easier?
Answer: Because each variable has exactly one definition, the def-use chain is trivial to compute, making dataflow analysis simple and efficient.
Explanation: In SSA, the single definition property means: (1) Constant propagation — if a def is a constant, every use can be replaced immediately; (2) Dead code elimination — a def with no uses is trivially dead; (3) Register allocation — liveness analysis reduces to reachability from def to uses; (4) Loop-invariant code motion — dependencies are explicit. LLVM, GCC, and virtually every modern compiler optimizer operates on SSA.
Q3. What properties of LLVM IR enable language-independent optimization?
Answer: LLVM IR is a typed, SSA-based, target-independent representation that cleanly separates frontend concerns from backend code generation.
Explanation: LLVM IR properties: (1) Strong static typing enforces invariants optimizers can rely on; (2) SSA form makes all data dependencies explicit; (3) Infinite virtual registers — physical register allocation happens only at code-gen time; (4) Explicit memory model with alignment. The separation of concerns is key: C, Rust, Swift, and Haskell frontends all emit LLVM IR, and a single set of optimization passes (InstCombine, GVN, LICM, etc.) improves all of them. The backend then handles target-specific concerns like instruction selection and register allocation.
Q4. How does operator fusion improve GPU performance in TVM?
Answer: Fusion reduces HBM (high-bandwidth memory) traffic by keeping intermediate results in faster on-chip memory (registers or L1/L2 cache).
Explanation: Modern GPUs are often memory-bandwidth-bound rather than compute-bound. Without fusion, each operator writes its output to HBM and the next operator reads it back — very expensive. With fusion, the intermediate result stays in registers or shared memory, and HBM is only accessed for the initial input and final output. Elementwise operators (ReLU, BatchNorm, element-wise add) are especially good fusion candidates because they are 100% memory-bandwidth-bound. TVM's Relay-to-TIR lowering pass identifies fusable operator groups automatically using a dataflow analysis.
Q5. What are the three stages of the torch.compile() stack?
Answer: TorchDynamo (graph capture) → AOTAutograd (autograd pre-compilation) → TorchInductor (kernel generation).
Explanation: (1) TorchDynamo operates at the Python bytecode level. It intercepts execution and traces PyTorch operations into an FX Graph. It handles Python's dynamic behavior (control flow, data-dependent branches) using guards — if a guard is violated at runtime, it recompiles. (2) AOTAutograd takes the forward FX graph and pre-traces through PyTorch's autograd engine to produce a joint forward+backward graph before any execution. (3) TorchInductor converts the FX graph into Triton GPU kernels or C++ CPU code, applying fusion, loop tiling, and vectorization. The mode argument controls the trade-off: 'default' is fast to compile, 'max-autotune' spends more time searching for the best tile sizes and kernel configurations.
Conclusion
Compiler design is not dusty theory — it is the engine powering today's AI infrastructure. torch.compile(), TVM, XLA, and MLIR are all built on decades of compiler research: lexers, parsers, SSA, dataflow analysis, register allocation, and code generation.
Understanding the Lexer → Parser → AST → IR → Optimization → Code Gen pipeline gives you the mental model to understand:
- Why GPU kernel fusion matters so much
- Why LLVM is the backend for so many languages
- Why SSA form is the universal foundation for optimization
Recommended next steps: the LLVM Kaleidoscope tutorial, the TVM documentation, and the PyTorch 2.0 paper (Ansel et al., 2024).