Split View: 컴파일러/인터프리터 설계 완전 정복: 파서부터 LLVM, AI 컴파일러(TVM/XLA)까지
컴파일러/인터프리터 설계 완전 정복: 파서부터 LLVM, AI 컴파일러(TVM/XLA)까지
- 들어가며
- 1. 컴파일러 기초: 전체 파이프라인
- 2. 중간 표현 (IR, Intermediate Representation)
- 3. 최적화 패스 (Optimization Passes)
- 4. 코드 생성: x86-64 어셈블리
- 5. JIT 컴파일: 런타임 최적화
- 6. AI/ML 컴파일러 생태계
- 7. MLIR: 차세대 컴파일러 인프라
- 8. 퀴즈: 이해도 점검
- 마치며
들어가며
소프트웨어 엔지니어가 "컴파일러"를 이해해야 하는 이유는 단순히 언어를 만들기 위해서가 아닙니다. 오늘날 AI/ML 분야에서 torch.compile(), TVM, XLA, MLIR 같은 도구들이 핵심 인프라로 자리 잡았고, 이들의 내부 동작은 모두 컴파일러 이론 위에 세워져 있습니다.
이 글에서는 컴파일러 설계의 전체 파이프라인을 처음부터 끝까지 살펴봅니다.
1. 컴파일러 기초: 전체 파이프라인
컴파일러는 소스 코드를 받아 실행 가능한 코드로 변환하는 프로그램입니다. 이 과정은 여러 단계로 나뉩니다.
소스 코드
↓ 어휘 분석 (Lexing)
토큰 스트림
↓ 구문 분석 (Parsing)
AST (추상 구문 트리)
↓ 의미 분석 (Semantic Analysis)
Annotated AST
↓ IR 생성
중간 표현 (IR)
↓ 최적화 (Optimization)
최적화된 IR
↓ 코드 생성 (Code Generation)
기계어 / 어셈블리
1.1 어휘 분석 (Lexical Analysis / Lexing)
어휘 분석기(Lexer 또는 Tokenizer)는 소스 코드 문자열을 토큰(Token) 의 스트림으로 변환합니다.
토큰은 숫자 리터럴, 식별자, 키워드, 연산자, 구두점 같은 의미 단위입니다.
import re
from enum import Enum, auto
from dataclasses import dataclass
from typing import List, Optional
class TokenType(Enum):
NUMBER = auto()
IDENT = auto()
PLUS = auto()
MINUS = auto()
STAR = auto()
SLASH = auto()
LPAREN = auto()
RPAREN = auto()
ASSIGN = auto()
SEMICOL = auto()
LET = auto()
EOF = auto()
@dataclass
class Token:
type: TokenType
value: str
line: int
KEYWORDS = {'let': TokenType.LET}
TOKEN_PATTERNS = [
(r'\d+(\.\d+)?', TokenType.NUMBER),
(r'[a-zA-Z_]\w*', TokenType.IDENT),
(r'\+', TokenType.PLUS),
(r'-', TokenType.MINUS),
(r'\*', TokenType.STAR),
(r'/', TokenType.SLASH),
(r'\(', TokenType.LPAREN),
(r'\)', TokenType.RPAREN),
(r'=', TokenType.ASSIGN),
(r';', TokenType.SEMICOL),
]
def tokenize(source: str) -> List[Token]:
tokens = []
line = 1
pos = 0
while pos < len(source):
if source[pos] in ' \t\r':
pos += 1
continue
if source[pos] == '\n':
line += 1
pos += 1
continue
matched = False
for pattern, ttype in TOKEN_PATTERNS:
m = re.match(pattern, source[pos:])
if m:
value = m.group(0)
# 키워드 처리
if ttype == TokenType.IDENT and value in KEYWORDS:
ttype = KEYWORDS[value]
tokens.append(Token(ttype, value, line))
pos += len(value)
matched = True
break
if not matched:
raise SyntaxError(f"Unexpected character '{source[pos]}' at line {line}")
tokens.append(Token(TokenType.EOF, '', line))
return tokens
# 사용 예
src = "let x = 10 + 20 * 3;"
for tok in tokenize(src):
print(tok)
1.2 구문 분석 (Parsing) 과 AST
파서는 토큰 스트림을 받아 추상 구문 트리(AST, Abstract Syntax Tree) 를 만듭니다. AST는 프로그램의 계층적 구조를 표현합니다.
from dataclasses import dataclass, field
from typing import Union
# AST 노드 정의
@dataclass
class NumberLit:
value: float
@dataclass
class Identifier:
name: str
@dataclass
class BinOp:
op: str
left: 'Expr'
right: 'Expr'
@dataclass
class LetStmt:
name: str
value: 'Expr'
Expr = Union[NumberLit, Identifier, BinOp]
Stmt = Union[LetStmt]
# Recursive Descent Parser
class Parser:
def __init__(self, tokens: List[Token]):
self.tokens = tokens
self.pos = 0
def peek(self) -> Token:
return self.tokens[self.pos]
def consume(self, expected: Optional[TokenType] = None) -> Token:
tok = self.tokens[self.pos]
if expected and tok.type != expected:
raise SyntaxError(f"Expected {expected}, got {tok.type} at line {tok.line}")
self.pos += 1
return tok
def parse_stmt(self) -> Stmt:
if self.peek().type == TokenType.LET:
return self.parse_let()
raise SyntaxError("Expected statement")
def parse_let(self) -> LetStmt:
self.consume(TokenType.LET)
name = self.consume(TokenType.IDENT).value
self.consume(TokenType.ASSIGN)
expr = self.parse_expr()
self.consume(TokenType.SEMICOL)
return LetStmt(name=name, value=expr)
def parse_expr(self) -> Expr:
return self.parse_additive()
def parse_additive(self) -> Expr:
left = self.parse_multiplicative()
while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
op = self.consume().value
right = self.parse_multiplicative()
left = BinOp(op=op, left=left, right=right)
return left
def parse_multiplicative(self) -> Expr:
left = self.parse_primary()
while self.peek().type in (TokenType.STAR, TokenType.SLASH):
op = self.consume().value
right = self.parse_primary()
left = BinOp(op=op, left=left, right=right)
return left
def parse_primary(self) -> Expr:
tok = self.peek()
if tok.type == TokenType.NUMBER:
self.consume()
return NumberLit(value=float(tok.value))
if tok.type == TokenType.IDENT:
self.consume()
return Identifier(name=tok.value)
if tok.type == TokenType.LPAREN:
self.consume()
expr = self.parse_expr()
self.consume(TokenType.RPAREN)
return expr
raise SyntaxError(f"Unexpected token {tok}")
2. 중간 표현 (IR, Intermediate Representation)
AST를 직접 기계어로 변환하는 것은 비효율적입니다. 대신 중간 표현(IR) 을 거치면 언어와 독립적인 최적화가 가능해집니다.
2.1 3-주소 코드 (Three-Address Code)
가장 단순한 IR 형태입니다. 모든 연산은 최대 3개의 피연산자를 가집니다.
# 소스: x = (a + b) * (c - d)
t1 = a + b
t2 = c - d
x = t1 * t2
2.2 SSA (Static Single Assignment) Form
SSA는 각 변수가 정확히 한 번만 할당되는 IR 형태입니다. 최적화 분석이 크게 단순해집니다.
# 일반 코드
x = 1
x = x + 2 # x가 두 번 할당됨
# SSA 변환
x_1 = 1
x_2 = x_1 + 2
분기가 합류하는 지점에는 Phi 함수를 사용합니다.
if cond:
x_1 = 10
else:
x_2 = 20
# 합류 지점
x_3 = phi(x_1, x_2)
2.3 LLVM IR
LLVM IR은 실제 산업용 컴파일러에서 사용하는 강력한 IR입니다.
; 함수 정의: int add(int a, int b)
define i32 @add(i32 %a, i32 %b) {
entry:
%result = add i32 %a, %b
ret i32 %result
}
; 더 복잡한 예: 조건 분기 + SSA
define i32 @max(i32 %a, i32 %b) {
entry:
%cmp = icmp sgt i32 %a, %b ; a > b ?
br i1 %cmp, label %then, label %else
then:
br label %merge
else:
br label %merge
merge:
; phi 함수: 어느 블록에서 왔느냐에 따라 값 선택
%result = phi i32 [ %a, %then ], [ %b, %else ]
ret i32 %result
}
LLVM IR의 특징:
- 타입 시스템:
i32,i64,float,double, 포인터 등 명확한 타입 - SSA 기반: 모든 값은 한 번만 정의
- 무한 가상 레지스터:
%0,%1,%2... 레지스터 수 제한 없음 - 언어 독립성: C, C++, Rust, Swift 모두 같은 LLVM IR로 변환됨
3. 최적화 패스 (Optimization Passes)
컴파일러 최적화의 핵심 기법들을 살펴봅니다.
3.1 Constant Folding (상수 접기)
컴파일 시간에 계산 가능한 상수 표현식을 미리 계산합니다.
# 최적화 전
x = 2 * 3 + 4
# 최적화 후 (컴파일러가 미리 계산)
x = 10
def constant_fold(node):
"""간단한 상수 접기 구현"""
if isinstance(node, BinOp):
left = constant_fold(node.left)
right = constant_fold(node.right)
if isinstance(left, NumberLit) and isinstance(right, NumberLit):
ops = {'+': lambda a,b: a+b, '-': lambda a,b: a-b,
'*': lambda a,b: a*b, '/': lambda a,b: a/b}
result = ops[node.op](left.value, right.value)
return NumberLit(value=result)
return BinOp(op=node.op, left=left, right=right)
return node
3.2 Dead Code Elimination (죽은 코드 제거)
실행되지 않거나 결과가 사용되지 않는 코드를 제거합니다.
def is_reachable(block, cfg):
"""제어 흐름 그래프에서 도달 가능한 블록을 BFS로 탐색"""
visited = set()
queue = [cfg.entry]
while queue:
current = queue.pop(0)
if current in visited:
continue
visited.add(current)
queue.extend(cfg.successors[current])
return block in visited
3.3 Loop Unrolling (루프 풀기)
루프 오버헤드를 줄이고 명령어 수준 병렬성(ILP)을 높입니다.
// 원본 (루프 4회)
for (int i = 0; i < 4; i++) {
a[i] = b[i] + c[i];
}
// 언롤링 후 (루프 제거)
a[0] = b[0] + c[0];
a[1] = b[1] + c[1];
a[2] = b[2] + c[2];
a[3] = b[3] + c[3];
3.4 Inlining (함수 인라이닝)
함수 호출 오버헤드를 없애고 추가 최적화 기회를 만듭니다.
// 원본
inline int square(int x) { return x * x; }
int y = square(5);
// 인라이닝 후
int y = 5 * 5; // → 상수 접기로 25
4. 코드 생성: x86-64 어셈블리
IR을 실제 기계 코드로 변환하는 과정입니다.
4.1 레지스터 할당
x86-64에는 범용 레지스터가 16개(rax, rbx, rcx, rdx, rsi, rdi, rsp, rbp, r8~r15)입니다. 가상 레지스터를 물리 레지스터에 매핑해야 합니다.
그래프 채색(Graph Coloring) 알고리즘이 가장 유명한 레지스터 할당 방법입니다. 동시에 살아 있는(live) 변수들은 같은 레지스터를 공유할 수 없습니다.
4.2 x86-64 어셈블리 예시
; C 함수: int add(int a, int b) { return a + b; }
; Linux x86-64 System V ABI: 인자는 rdi, rsi 순서
add:
push rbp
mov rbp, rsp
mov DWORD PTR [rbp-4], edi ; a 저장
mov DWORD PTR [rbp-8], esi ; b 저장
mov edx, DWORD PTR [rbp-4]
mov eax, DWORD PTR [rbp-8]
add eax, edx ; a + b
pop rbp
ret ; 결과는 eax에
; 최적화된 버전 (-O2)
add:
lea eax, [rdi + rsi] ; 단 1 명령어!
ret
5. JIT 컴파일: 런타임 최적화
JIT(Just-In-Time) 컴파일은 프로그램 실행 중에 코드를 컴파일합니다. Python의 느린 실행 속도를 극복하는 핵심 기술입니다.
5.1 Numba JIT
Numba는 Python/NumPy 코드를 LLVM을 통해 기계어로 컴파일합니다.
import numpy as np
from numba import njit, prange
import time
# 순수 Python 버전
def matmul_python(A, B):
N = A.shape[0]
C = np.zeros((N, N))
for i in range(N):
for j in range(N):
for k in range(N):
C[i, j] += A[i, k] * B[k, j]
return C
# Numba JIT 버전 (병렬화 포함)
@njit(parallel=True)
def matmul_numba(A, B):
N = A.shape[0]
C = np.zeros((N, N))
for i in prange(N): # 병렬 루프
for j in range(N):
for k in range(N):
C[i, j] += A[i, k] * B[k, j]
return C
N = 256
A = np.random.rand(N, N).astype(np.float32)
B = np.random.rand(N, N).astype(np.float32)
# 첫 호출에서 컴파일 발생 (웜업)
_ = matmul_numba(A, B)
t0 = time.perf_counter()
C_python = matmul_python(A, B)
print(f"Python: {time.perf_counter()-t0:.3f}s")
t0 = time.perf_counter()
C_numba = matmul_numba(A, B)
print(f"Numba JIT: {time.perf_counter()-t0:.4f}s")
# 결과: Python ~30s vs Numba ~0.01s (수백 배 빠름)
Numba의 내부 동작:
- Python 바이트코드를 Numba IR로 변환
- 타입 추론 (type inference)으로 정적 타입 확정
- LLVM IR 생성
- LLVM 백엔드가 기계어로 컴파일
- 컴파일된 코드를 캐시에 저장
6. AI/ML 컴파일러 생태계
6.1 torch.compile() 내부 동작
PyTorch 2.0의 핵심 기능인 torch.compile()은 세 가지 컴포넌트로 구성됩니다.
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.linear1(x))
return self.linear2(x)
model = SimpleNet().cuda()
x = torch.randn(32, 128).cuda()
# torch.compile 적용
compiled_model = torch.compile(model, mode='max-autotune')
# 첫 호출: 컴파일 발생 (수 초 소요)
out = compiled_model(x)
# 이후 호출: 컴파일된 커널 재사용 (빠름)
for _ in range(100):
out = compiled_model(x)
컴파일 스택 3단계:
1단계 - TorchDynamo (그래프 캡처)
Python 바이트코드를 가로채서 torch.fx.Graph 로 변환합니다.
import torch._dynamo as dynamo
def my_func(x, y):
return torch.sin(x) + torch.cos(y)
# FX 그래프 확인
explanation = dynamo.explain(my_func)(torch.randn(4), torch.randn(4))
print(explanation.graphs[0].print_tabular())
# opcode | name | target | args
# --------|--------|---------------|----------
# call_fn | sin | torch.sin | (x,)
# call_fn | cos | torch.cos | (y,)
# call_fn | add | operator.add | (sin, cos)
# output | output | output | (add,)
2단계 - AOTAutograd (자동 미분 사전 컴파일) 순전파(forward)와 역전파(backward) 그래프를 미리 생성합니다.
3단계 - TorchInductor (커널 생성) FX 그래프를 Triton(GPU) 또는 C++(CPU) 커널로 변환합니다.
6.2 TVM: 딥러닝 컴파일러
Apache TVM은 딥러닝 모델을 다양한 하드웨어(GPU, CPU, FPGA, NPU)에 최적화합니다.
import tvm
from tvm import relay
import tvm.relay.testing
import numpy as np
# PyTorch 모델을 Relay IR로 변환
import torch
import torchvision
model = torchvision.models.resnet18(pretrained=False).eval()
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
# TorchScript를 거쳐 TVM Relay로
scripted_model = torch.jit.trace(model, input_data)
shape_list = [("input0", input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
# 타겟 설정 및 최적화
target = tvm.target.Target("cuda")
# AutoTuning: 최적 커널 파라미터 탐색
from tvm import auto_scheduler
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=200, # 측정 횟수
runner=auto_scheduler.LocalRunner(repeat=10),
measure_callbacks=[auto_scheduler.RecordToFile("tuning_log.json")],
)
tuner.tune(tune_option)
# 튜닝 결과로 컴파일
with auto_scheduler.ApplyHistoryBest("tuning_log.json"):
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
print("TVM 컴파일 완료!")
6.3 Operator Fusion (연산자 퓨전)
GPU에서 가장 중요한 최적화입니다. 여러 커널 호출을 하나로 합칩니다.
# 퓨전 전: 3번의 GPU 커널 호출
y1 = conv2d(x) # 커널 1 (메모리 R/W)
y2 = batch_norm(y1) # 커널 2 (메모리 R/W)
y3 = relu(y2) # 커널 3 (메모리 R/W)
# 퓨전 후: 1번의 GPU 커널 호출
y3 = conv_bn_relu(x) # 단일 퓨즈드 커널
# 중간 결과를 L1/L2 캐시에 유지
# 메모리 대역폭 절약 = 성능 향상
XLA(Accelerated Linear Algebra)는 TensorFlow/JAX의 컴파일러로, HLO(High Level Operations) IR을 사용합니다.
import jax
import jax.numpy as jnp
# JAX + XLA: jit 데코레이터로 XLA 컴파일 활성화
@jax.jit
def compute(x, W, b):
# XLA가 자동으로 matmul + add + relu를 퓨즈
return jax.nn.relu(x @ W + b)
x = jnp.ones((32, 128))
W = jnp.ones((128, 256))
b = jnp.zeros(256)
# 첫 호출: XLA 컴파일
result = compute(x, W, b)
print(result.shape) # (32, 256)
7. MLIR: 차세대 컴파일러 인프라
MLIR(Multi-Level Intermediate Representation)은 구글이 설계한 컴파일러 프레임워크로, 여러 레벨의 IR을 하나의 인프라로 통합합니다.
TensorFlow Graph
↓ (TF Dialect)
MLIR HLO Dialect
↓ (Linalg Dialect)
Loop Nest
↓ (Affine Dialect)
Affine Loops
↓ (LLVM Dialect)
LLVM IR
↓
기계어
각 Dialect는 특정 수준의 추상화를 표현하며, 점진적으로 낮은 수준으로 변환(lowering)합니다.
8. 퀴즈: 이해도 점검
Q1. LL(1) 파서와 LR(1) 파서의 파싱 능력 차이는?
정답: LR(1) 파서가 더 넓은 범위의 문법을 처리할 수 있습니다.
설명: LL(1)은 Left-to-right, Leftmost derivation, 1 lookahead를 의미하며 하향식(top-down) 파서입니다. 좌재귀(left-recursive) 문법을 처리하지 못하고, 첫 번째 토큰만 보고 어떤 규칙을 적용할지 결정해야 합니다. LR(1)은 Left-to-right, Rightmost derivation, 1 lookahead를 의미하며 상향식(bottom-up) 파서입니다. LALR(1)은 LR(1)의 메모리 효율적인 변형으로, GCC, Bison 등 실제 컴파일러에서 널리 사용됩니다. LR 파서는 LL 파서보다 훨씬 넓은 문법 클래스를 처리하며, 대부분의 프로그래밍 언어 문법이 LALR(1)로 표현 가능합니다.
Q2. SSA form이 최적화에 유리한 이유는?
정답: 변수의 데이터 흐름 분석이 극도로 단순해지기 때문입니다.
설명: SSA에서 각 변수는 정확히 한 번만 정의되므로, 변수의 정의(def)와 사용(use) 관계가 명확합니다. 이를 통해 (1) 상수 전파: 정의 지점에서 값이 상수면 사용 지점에서 즉시 대체 가능, (2) Dead code elimination: use가 없는 def를 즉시 제거 가능, (3) 레지스터 할당: liveness analysis가 단순해짐, (4) 루프 불변 코드 이동: 의존성 분석이 쉬워짐. LLVM, GCC의 middle-end 최적화 패스들이 모두 SSA 기반으로 동작합니다.
Q3. LLVM IR의 특징과 언어 독립적 최적화가 가능한 이유는?
정답: LLVM IR이 특정 언어나 아키텍처에 종속되지 않은 범용 저수준 IR이기 때문입니다.
설명: LLVM IR의 핵심 특징은 (1) 강타입 시스템으로 타입 안전성 보장, (2) SSA 기반으로 분석/최적화 용이, (3) 무한 가상 레지스터 (물리 레지스터 할당은 코드 생성 시), (4) 명시적인 메모리 모델과 정렬 정보. 언어 독립성은 프론트엔드(C→IR, Rust→IR, Swift→IR)와 백엔드(IR→x86, IR→ARM, IR→RISC-V)를 분리했기 때문입니다. 중간의 최적화 패스는 IR만 다루므로 모든 언어-아키텍처 조합에 적용됩니다.
Q4. TVM에서 operator fusion이 GPU 성능을 향상시키는 원리는?
정답: GPU 메모리 대역폭 병목을 줄이고 캐시 효율을 극대화하기 때문입니다.
설명: GPU 연산에서 실제 병목은 연산(FLOPS) 자체가 아니라 메모리 대역폭(HBM ↔ SRAM)인 경우가 많습니다. 퓨전 없이 연산할 때는 각 연산마다 중간 결과를 HBM(전역 메모리)에 쓰고 다음 커널이 다시 읽어야 합니다. 퓨전 후에는 중간 결과가 L1/L2 캐시나 레지스터에 머물러 HBM 접근 횟수가 크게 줄어듭니다. 특히 elementwise 연산들(ReLU, BatchNorm 등)은 메모리 바운드이므로 퓨전 효과가 극적입니다. TVM의 Relay→TIR 변환 과정에서 퓨전 가능한 연산 그룹을 자동으로 탐지합니다.
Q5. torch.compile()이 내부적으로 사용하는 컴파일 스택은?
정답: TorchDynamo → AOTAutograd → TorchInductor의 3단계 스택입니다.
설명: (1) TorchDynamo: Python 바이트코드 수준에서 동작하며, Python 코드를 실행하면서 PyTorch 연산 부분을 FX Graph로 캡처합니다. Python의 동적 특성(제어 흐름, 데이터 의존적 분기)을 처리하기 위해 guard 기반 특수화를 사용합니다. (2) AOTAutograd: 순전파 FX Graph를 받아 역전파 그래프까지 미리(ahead-of-time) 생성합니다. (3) TorchInductor: FX Graph를 실제 GPU 커널 코드(Triton) 또는 CPU 코드(C++)로 변환합니다. operator fusion, tiling, vectorization 등의 최적화가 여기서 적용됩니다. mode 옵션으로 'default', 'reduce-overhead', 'max-autotune' 중 선택 가능합니다.
마치며
컴파일러 설계는 단순한 이론이 아닙니다. 오늘날 AI/ML 인프라의 핵심인 torch.compile(), TVM, XLA, MLIR은 모두 수십 년의 컴파일러 연구 위에 세워져 있습니다.
Lexer → Parser → AST → IR → Optimization → Code Gen 파이프라인을 이해하면, 왜 GPU 프로그래밍에서 커널 퓨전이 중요한지, 왜 LLVM이 이렇게 많은 언어의 백엔드로 사용되는지, 왜 SSA form이 최적화의 기반이 되는지를 자연스럽게 이해할 수 있습니다.
다음 단계로는 LLVM 튜토리얼(Kaleidoscope), TVM 공식 문서, 그리고 PyTorch 2.0 논문을 읽어보시길 추천합니다.
Compiler & Interpreter Design: From Parsers to LLVM and AI Compilers (TVM/XLA)
- Introduction
- 1. Compiler Fundamentals: The Full Pipeline
- 2. Intermediate Representations (IR)
- 3. Optimization Passes
- 4. Code Generation: x86-64 Assembly
- 5. JIT Compilation
- 6. AI/ML Compilers
- 7. MLIR: Next-Generation Compiler Infrastructure
- 8. Quiz
- Conclusion
Introduction
Understanding compilers is not just about building programming languages. Today, tools like torch.compile(), TVM, XLA, and MLIR are core infrastructure in AI/ML, and their internals are built entirely on compiler theory.
This guide walks through the entire compiler pipeline from first principles, ending at the cutting edge of ML compilation.
1. Compiler Fundamentals: The Full Pipeline
A compiler transforms source code into executable code through a series of well-defined stages.
Source Code
↓ Lexical Analysis (Lexing)
Token Stream
↓ Syntactic Analysis (Parsing)
AST (Abstract Syntax Tree)
↓ Semantic Analysis
Annotated AST
↓ IR Generation
Intermediate Representation (IR)
↓ Optimization Passes
Optimized IR
↓ Code Generation
Machine Code / Assembly
1.1 Lexical Analysis (Tokenization)
The lexer reads source characters and emits a stream of tokens — the smallest meaningful units like numbers, identifiers, keywords, and operators.
import re
from enum import Enum, auto
from dataclasses import dataclass
from typing import List, Optional
class TokenType(Enum):
NUMBER = auto()
IDENT = auto()
PLUS = auto()
MINUS = auto()
STAR = auto()
SLASH = auto()
LPAREN = auto()
RPAREN = auto()
ASSIGN = auto()
SEMICOL = auto()
LET = auto()
EOF = auto()
@dataclass
class Token:
type: TokenType
value: str
line: int
KEYWORDS = {'let': TokenType.LET}
TOKEN_PATTERNS = [
(r'\d+(\.\d+)?', TokenType.NUMBER),
(r'[a-zA-Z_]\w*', TokenType.IDENT),
(r'\+', TokenType.PLUS),
(r'-', TokenType.MINUS),
(r'\*', TokenType.STAR),
(r'/', TokenType.SLASH),
(r'\(', TokenType.LPAREN),
(r'\)', TokenType.RPAREN),
(r'=', TokenType.ASSIGN),
(r';', TokenType.SEMICOL),
]
def tokenize(source: str) -> List[Token]:
tokens = []
line = 1
pos = 0
while pos < len(source):
if source[pos] in ' \t\r':
pos += 1
continue
if source[pos] == '\n':
line += 1
pos += 1
continue
matched = False
for pattern, ttype in TOKEN_PATTERNS:
m = re.match(pattern, source[pos:])
if m:
value = m.group(0)
if ttype == TokenType.IDENT and value in KEYWORDS:
ttype = KEYWORDS[value]
tokens.append(Token(ttype, value, line))
pos += len(value)
matched = True
break
if not matched:
raise SyntaxError(f"Unexpected character '{source[pos]}' at line {line}")
tokens.append(Token(TokenType.EOF, '', line))
return tokens
# Example usage
src = "let x = 10 + 20 * 3;"
for tok in tokenize(src):
print(tok)
1.2 Parsing and the AST
The parser consumes tokens and builds an Abstract Syntax Tree (AST) that captures the hierarchical structure of the program.
from dataclasses import dataclass
from typing import Union
# AST node definitions
@dataclass
class NumberLit:
value: float
@dataclass
class Identifier:
name: str
@dataclass
class BinOp:
op: str
left: 'Expr'
right: 'Expr'
@dataclass
class LetStmt:
name: str
value: 'Expr'
Expr = Union[NumberLit, Identifier, BinOp]
# Recursive Descent Parser
class Parser:
def __init__(self, tokens: List[Token]):
self.tokens = tokens
self.pos = 0
def peek(self) -> Token:
return self.tokens[self.pos]
def consume(self, expected: Optional[TokenType] = None) -> Token:
tok = self.tokens[self.pos]
if expected and tok.type != expected:
raise SyntaxError(f"Expected {expected}, got {tok.type} at line {tok.line}")
self.pos += 1
return tok
def parse_let(self) -> LetStmt:
self.consume(TokenType.LET)
name = self.consume(TokenType.IDENT).value
self.consume(TokenType.ASSIGN)
expr = self.parse_expr()
self.consume(TokenType.SEMICOL)
return LetStmt(name=name, value=expr)
def parse_expr(self) -> Expr:
return self.parse_additive()
def parse_additive(self) -> Expr:
left = self.parse_multiplicative()
while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
op = self.consume().value
right = self.parse_multiplicative()
left = BinOp(op=op, left=left, right=right)
return left
def parse_multiplicative(self) -> Expr:
left = self.parse_primary()
while self.peek().type in (TokenType.STAR, TokenType.SLASH):
op = self.consume().value
right = self.parse_primary()
left = BinOp(op=op, left=left, right=right)
return left
def parse_primary(self) -> Expr:
tok = self.peek()
if tok.type == TokenType.NUMBER:
self.consume()
return NumberLit(value=float(tok.value))
if tok.type == TokenType.IDENT:
self.consume()
return Identifier(name=tok.value)
if tok.type == TokenType.LPAREN:
self.consume()
expr = self.parse_expr()
self.consume(TokenType.RPAREN)
return expr
raise SyntaxError(f"Unexpected token {tok}")
Parsing strategies differ in power and complexity:
| Strategy | Type | Notes |
|---|---|---|
| Recursive Descent | LL(k) | Hand-written, easy to read |
| Pratt Parser | Top-down operator precedence | Elegant for expressions |
| LALR(1) | Bottom-up | Used by Bison, GCC, Python |
| PEG / Packrat | Parsing Expression | No ambiguity, memoized |
2. Intermediate Representations (IR)
Going straight from AST to machine code is impractical. A well-designed IR enables language-agnostic optimizations.
2.1 Three-Address Code
The simplest flat IR. Every operation uses at most three operands.
# Source: x = (a + b) * (c - d)
t1 = a + b
t2 = c - d
x = t1 * t2
2.2 SSA Form (Static Single Assignment)
SSA requires that every variable is assigned exactly once. This simple constraint dramatically simplifies dataflow analysis.
# Original code
x = 1
x = x + 2 # x assigned twice
# SSA form
x_1 = 1
x_2 = x_1 + 2
At control-flow merge points, phi functions select the correct version:
if cond:
x_1 = 10
else:
x_2 = 20
# Merge point
x_3 = phi(x_1, x_2)
2.3 LLVM IR
LLVM IR is the industry-standard IR used by Clang, Rust, Swift, and many others.
; Function: int add(int a, int b)
define i32 @add(i32 %a, i32 %b) {
entry:
%result = add i32 %a, %b
ret i32 %result
}
; Branching with SSA phi nodes
define i32 @max(i32 %a, i32 %b) {
entry:
%cmp = icmp sgt i32 %a, %b ; signed greater-than
br i1 %cmp, label %then, label %else
then:
br label %merge
else:
br label %merge
merge:
%result = phi i32 [ %a, %then ], [ %b, %else ]
ret i32 %result
}
Key properties of LLVM IR:
- Strong type system:
i32,i64,float,ptr— all explicit - SSA-based: every
%nameis defined exactly once - Infinite virtual registers:
%0,%1,%2... no limit - Target-independent: C, Rust, Swift all lower to the same IR
3. Optimization Passes
3.1 Constant Folding
Pre-compute constant expressions at compile time.
# Before
x = 2 * 3 + 4
# After (compiler computes at build time)
x = 10
def constant_fold(node):
"""Simple constant folding pass."""
if isinstance(node, BinOp):
left = constant_fold(node.left)
right = constant_fold(node.right)
if isinstance(left, NumberLit) and isinstance(right, NumberLit):
ops = {
'+': lambda a, b: a + b,
'-': lambda a, b: a - b,
'*': lambda a, b: a * b,
'/': lambda a, b: a / b,
}
return NumberLit(value=ops[node.op](left.value, right.value))
return BinOp(op=node.op, left=left, right=right)
return node
3.2 Dead Code Elimination
Remove code whose results are never used.
def dead_code_eliminate(stmts, used_vars):
"""Remove assignments to variables that are never read."""
live = set()
# Backward pass: collect used variables
for stmt in reversed(stmts):
if isinstance(stmt, LetStmt):
if stmt.name in used_vars or stmt.name in live:
live.update(collect_vars(stmt.value))
return [s for s in stmts
if not isinstance(s, LetStmt) or s.name in live]
3.3 Loop Unrolling
Reduce loop overhead and expose instruction-level parallelism.
// Original — 4 iterations with loop overhead
for (int i = 0; i < 4; i++) {
a[i] = b[i] + c[i];
}
// Unrolled — zero loop overhead, can be vectorized
a[0] = b[0] + c[0];
a[1] = b[1] + c[1];
a[2] = b[2] + c[2];
a[3] = b[3] + c[3];
3.4 Function Inlining
Replace a call site with the callee body, removing call overhead and enabling further optimization.
// Before
inline int square(int x) { return x * x; }
int y = square(5);
// After inlining (then constant folding)
int y = 25;
4. Code Generation: x86-64 Assembly
4.1 Register Allocation
x86-64 has 16 general-purpose registers (rax, rbx, rcx, rdx, rsi, rdi, rsp, rbp, r8–r15). The compiler must map unlimited virtual registers to this finite set.
Graph Coloring is the classic algorithm: build an interference graph where two virtual registers share an edge if they are live at the same time, then color the graph with N colors (N = number of physical registers). If coloring fails, spill a variable to the stack.
4.2 x86-64 Example
; C: int add(int a, int b) { return a + b; }
; Linux System V ABI: args in rdi, rsi; return in rax
add:
push rbp
mov rbp, rsp
mov DWORD PTR [rbp-4], edi
mov DWORD PTR [rbp-8], esi
mov edx, DWORD PTR [rbp-4]
mov eax, DWORD PTR [rbp-8]
add eax, edx
pop rbp
ret
; With -O2 (optimized)
add:
lea eax, [rdi + rsi] ; single instruction!
ret
5. JIT Compilation
JIT (Just-In-Time) compilation happens at runtime. It is the key technology that closes the performance gap between interpreted languages and native code.
5.1 Numba JIT for Python
Numba uses LLVM to compile Python/NumPy code to native machine code at runtime.
import numpy as np
from numba import njit, prange
import time
# Pure Python baseline
def matmul_python(A, B):
N = A.shape[0]
C = np.zeros((N, N))
for i in range(N):
for j in range(N):
for k in range(N):
C[i, j] += A[i, k] * B[k, j]
return C
# Numba JIT with parallelism
@njit(parallel=True)
def matmul_numba(A, B):
N = A.shape[0]
C = np.zeros((N, N))
for i in prange(N): # parallel loop
for j in range(N):
for k in range(N):
C[i, j] += A[i, k] * B[k, j]
return C
N = 256
A = np.random.rand(N, N).astype(np.float32)
B = np.random.rand(N, N).astype(np.float32)
# First call triggers compilation (warmup)
_ = matmul_numba(A, B)
t0 = time.perf_counter()
C_py = matmul_python(A, B)
print(f"Python: {time.perf_counter()-t0:.3f}s")
t0 = time.perf_counter()
C_nb = matmul_numba(A, B)
print(f"Numba JIT: {time.perf_counter()-t0:.4f}s")
# Typical result: Python ~30s vs Numba ~0.01s
How Numba works internally:
- Intercept Python bytecode
- Type inference to determine concrete types
- Lower to LLVM IR
- LLVM backend compiles to native machine code
- Cache the compiled artifact for subsequent calls
5.2 PyPy and Tracing JIT
PyPy's tracing JIT records "hot paths" (frequently executed traces) and compiles them directly. Unlike method-based JIT (HotSpot JVM), it can inline across call boundaries along the trace.
6. AI/ML Compilers
6.1 torch.compile() Internals
torch.compile(), introduced in PyTorch 2.0, is a three-stage compilation pipeline.
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(128, 256)
self.fc2 = nn.Linear(256, 10)
self.relu = nn.ReLU()
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
model = SimpleNet().cuda()
x = torch.randn(32, 128).cuda()
# Apply torch.compile
compiled = torch.compile(model, mode='max-autotune')
# First call: compilation happens (takes a few seconds)
out = compiled(x)
# Subsequent calls: reuse compiled kernels
for _ in range(100):
out = compiled(x)
Stage 1 — TorchDynamo (Graph Capture)
Dynamo intercepts Python bytecode and captures PyTorch operations as a torch.fx.Graph.
import torch._dynamo as dynamo
def my_func(x, y):
return torch.sin(x) + torch.cos(y)
# Inspect the captured FX graph
explanation = dynamo.explain(my_func)(
torch.randn(4), torch.randn(4)
)
print(explanation.graphs[0].print_tabular())
# opcode | name | target | args
# ----------|--------|--------------|----------
# call_fn | sin | torch.sin | (x,)
# call_fn | cos | torch.cos | (y,)
# call_fn | add | operator.add | (sin, cos)
# output | output | output | (add,)
Stage 2 — AOTAutograd (Pre-compile Autograd)
Takes the forward FX graph and traces through torch.autograd to produce both forward and backward graphs before execution. This enables the backend to optimize across the full gradient computation.
Stage 3 — TorchInductor (Kernel Generation)
Converts the FX graph into Triton (for CUDA) or C++ (for CPU) kernel code, applying operator fusion, tiling, and vectorization.
6.2 Apache TVM
TVM compiles deep learning models to a wide variety of hardware targets (GPU, CPU, FPGA, custom NPUs).
import tvm
from tvm import relay
import torch, torchvision
import numpy as np
# Start from a PyTorch model
model = torchvision.models.resnet18(pretrained=False).eval()
input_shape = [1, 3, 224, 224]
trace_input = torch.randn(input_shape)
# Trace to TorchScript, then import into TVM Relay IR
scripted = torch.jit.trace(model, trace_input)
mod, params = relay.frontend.from_pytorch(scripted, [("input0", input_shape)])
target = tvm.target.Target("cuda")
# AutoScheduler: search for optimal kernel parameters
from tvm import auto_scheduler
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target
)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=200,
runner=auto_scheduler.LocalRunner(repeat=10),
measure_callbacks=[auto_scheduler.RecordToFile("tuning.json")],
)
tuner.tune(tune_option)
# Compile with best-found schedules
with auto_scheduler.ApplyHistoryBest("tuning.json"):
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
print("TVM compilation complete!")
6.3 Operator Fusion
Fusion is the single most impactful GPU optimization. Multiple kernel launches are merged into one.
# Without fusion: 3 separate GPU kernel launches
y1 = conv2d(x) # kernel 1 — write to HBM
y2 = batch_norm(y1) # kernel 2 — read+write HBM
y3 = relu(y2) # kernel 3 — read+write HBM
# With fusion: 1 kernel launch
y3 = fused_conv_bn_relu(x)
# Intermediate results stay in L1/L2 cache
# HBM bandwidth saved = significant speedup
6.4 XLA and JAX
XLA (Accelerated Linear Algebra) is the compiler behind TensorFlow and JAX. It uses HLO (High Level Operations) as its IR.
import jax
import jax.numpy as jnp
@jax.jit
def compute(x, W, b):
# XLA automatically fuses: matmul + add + relu
return jax.nn.relu(x @ W + b)
x = jnp.ones((32, 128))
W = jnp.ones((128, 256))
b = jnp.zeros(256)
# First call: XLA compilation
result = compute(x, W, b)
print(result.shape) # (32, 256)
XLA's key optimizations include:
- Operation fusion: merges elementwise ops after matmul/conv
- Layout optimization: chooses NCHW vs NHWC automatically
- Memory planning: minimizes peak live memory
7. MLIR: Next-Generation Compiler Infrastructure
MLIR (Multi-Level Intermediate Representation), developed at Google, unifies many compiler IRs under a single extensible framework.
TensorFlow Graph
↓ TF Dialect
MLIR HLO Dialect
↓ Linalg Dialect
Loop Nest IR
↓ Affine Dialect
Affine Loop IR
↓ LLVM Dialect
LLVM IR
↓
Machine Code
Each Dialect represents a different level of abstraction. Compilation proceeds by lowering from high-level dialects to lower-level ones, with each step preserving correctness while enabling new optimizations.
MLIR is now the backbone of:
- IREE (Google's on-device ML runtime)
- ONNX-MLIR
- Torch-MLIR (PyTorch to MLIR)
- Flang (LLVM Fortran compiler)
8. Quiz
Q1. What is the difference in parsing power between LL(1) and LR(1) parsers?
Answer: LR(1) parsers can handle a strictly larger class of grammars than LL(1).
Explanation: LL(1) is a top-down parser that reads left-to-right and constructs a leftmost derivation using 1 token of lookahead. It cannot handle left-recursive grammars and requires the grammar to be factored for unique predictions. LR(1) is a bottom-up, shift-reduce parser that can handle left recursion naturally and covers all deterministic context-free languages parsable with 1 lookahead. LALR(1) is a memory-efficient approximation of LR(1) used in Bison and GCC. In practice, nearly all real programming language grammars fit within LALR(1).
Q2. Why does SSA form make optimization easier?
Answer: Because each variable has exactly one definition, the def-use chain is trivial to compute, making dataflow analysis simple and efficient.
Explanation: In SSA, the single definition property means: (1) Constant propagation — if a def is a constant, every use can be replaced immediately; (2) Dead code elimination — a def with no uses is trivially dead; (3) Register allocation — liveness analysis reduces to reachability from def to uses; (4) Loop-invariant code motion — dependencies are explicit. LLVM, GCC, and virtually every modern compiler optimizer operates on SSA.
Q3. What properties of LLVM IR enable language-independent optimization?
Answer: LLVM IR is a typed, SSA-based, target-independent representation that cleanly separates frontend concerns from backend code generation.
Explanation: LLVM IR properties: (1) Strong static typing enforces invariants optimizers can rely on; (2) SSA form makes all data dependencies explicit; (3) Infinite virtual registers — physical register allocation happens only at code-gen time; (4) Explicit memory model with alignment. The separation of concerns is key: C, Rust, Swift, and Haskell frontends all emit LLVM IR, and a single set of optimization passes (InstCombine, GVN, LICM, etc.) improves all of them. The backend then handles target-specific concerns like instruction selection and register allocation.
Q4. How does operator fusion improve GPU performance in TVM?
Answer: Fusion reduces HBM (high-bandwidth memory) traffic by keeping intermediate results in faster on-chip memory (registers or L1/L2 cache).
Explanation: Modern GPUs are often memory-bandwidth-bound rather than compute-bound. Without fusion, each operator writes its output to HBM and the next operator reads it back — very expensive. With fusion, the intermediate result stays in registers or shared memory, and HBM is only accessed for the initial input and final output. Elementwise operators (ReLU, BatchNorm, element-wise add) are especially good fusion candidates because they are 100% memory-bandwidth-bound. TVM's Relay-to-TIR lowering pass identifies fusable operator groups automatically using a dataflow analysis.
Q5. What are the three stages of the torch.compile() stack?
Answer: TorchDynamo (graph capture) → AOTAutograd (autograd pre-compilation) → TorchInductor (kernel generation).
Explanation: (1) TorchDynamo operates at the Python bytecode level. It intercepts execution and traces PyTorch operations into an FX Graph. It handles Python's dynamic behavior (control flow, data-dependent branches) using guards — if a guard is violated at runtime, it recompiles. (2) AOTAutograd takes the forward FX graph and pre-traces through PyTorch's autograd engine to produce a joint forward+backward graph before any execution. (3) TorchInductor converts the FX graph into Triton GPU kernels or C++ CPU code, applying fusion, loop tiling, and vectorization. The mode argument controls the trade-off: 'default' is fast to compile, 'max-autotune' spends more time searching for the best tile sizes and kernel configurations.
Conclusion
Compiler design is not dusty theory — it is the engine powering today's AI infrastructure. torch.compile(), TVM, XLA, and MLIR are all built on decades of compiler research: lexers, parsers, SSA, dataflow analysis, register allocation, and code generation.
Understanding the Lexer → Parser → AST → IR → Optimization → Code Gen pipeline gives you the mental model to understand:
- Why GPU kernel fusion matters so much
- Why LLVM is the backend for so many languages
- Why SSA form is the universal foundation for optimization
Recommended next steps: the LLVM Kaleidoscope tutorial, the TVM documentation, and the PyTorch 2.0 paper (Ansel et al., 2024).