Split View: Google TPU 완전 해부: Systolic Array가 행렬 곱셈을 어떻게 완벽히 해결하는가
Google TPU 완전 해부: Systolic Array가 행렬 곱셈을 어떻게 완벽히 해결하는가
- 들어가며: 구글은 왜 직접 칩을 만들었는가
- 1. Systolic Array: 심장이 뛰듯 데이터가 흐른다
- 2. TPU v1 스펙 완전 분석
- 3. GPU vs TPU: 철학의 차이
- 4. TPU 메모리 계층 구조
- 5. bfloat16: Google이 만든 수 형식
- 6. TPU 버전 역사: 4세대의 진화
- 7. XLA: TPU를 위한 컴파일러
- 8. TPU Pod에서 LLM 서빙하기
- 9. 실전: JAX로 TPU에서 LLM 추론 최적화
- 10. TPU vs 경쟁자: 실제 LLM 서빙 벤치마크
- 마치며
- 참고 자료
들어가며: 구글은 왜 직접 칩을 만들었는가
2013년 구글 엔지니어들이 충격적인 계산을 했습니다. 만약 모든 Gmail 사용자가 하루 3분씩 음성 검색을 사용하면, 기존 서버 인프라의 두 배가 필요하다는 것이었습니다. 음성 인식에 사용되는 딥 뉴럴 네트워크(DNN)의 연산 비용이 문제였습니다.
Jeff Dean을 포함한 구글 팀은 결론을 내렸습니다: 범용 CPU와 GPU로는 한계가 있다. 뉴럴 네트워크 추론에 특화된 칩을 만들자.
그 결과가 Tensor Processing Unit (TPU) 입니다. 2015년 구글 데이터센터에 처음 배포되었고, 2017년 ISCA 논문으로 세상에 공개되었습니다. Norman Jouppi가 이끈 하드웨어 팀의 핵심 인사이트는 단순했습니다:
"뉴럴 네트워크 추론에는 완전한 IEEE 부동소수점이 필요 없다. INT8이면 충분하다."
이 단순한 인사이트가 경쟁 GPU 대비 10배 이상의 성능/전력 효율을 가능하게 했습니다.
1. Systolic Array: 심장이 뛰듯 데이터가 흐른다
이름의 유래
Systolic Array라는 이름은 심장의 수축기(systolic) 에서 왔습니다. 심장이 규칙적으로 수축하며 혈액을 펌핑하듯, Systolic Array는 규칙적인 클럭 사이클마다 데이터를 흘려보냅니다. H.T. Kung과 Charles Leiserson이 1978년에 제안한 개념을 구글이 뉴럴 네트워크에 맞게 대규모로 구현했습니다.
Systolic Array의 기본 구조
4x4 Systolic Array로 행렬 곱셈 A × B = C를 수행하는 과정을 단계별로 살펴봅시다.
Systolic Array: 4x4 MAC (Multiply-Accumulate) 유닛 배열
각 셀: 왼쪽(A)과 위(B)에서 입력을 받아 곱하고, 누산기에 더함
시간 0단계: 데이터 로드 준비
B[0,0] B[1,0] B[2,0] B[3,0] ← B 행렬 열이 아래로 흐름
↓ ↓ ↓ ↓
A[0,0]→[MAC][MAC][MAC][MAC]
A[1,0]→[MAC][MAC][MAC][MAC]
A[2,0]→[MAC][MAC][MAC][MAC]
A[3,0]→[MAC][MAC][MAC][MAC]
↓ ↓ ↓ ↓
C열0 C열1 C열2 C열3 ← 결과가 아래로 출력
시간 1단계: A[0,0]이 MAC[0,0]에 도달 → A[0,0]*B[0,0] 계산
시간 2단계: A[0,0]이 MAC[0,1]으로 이동, A[0,1]이 MAC[0,0]에 진입
각 셀이 부분 곱을 누산
핵심 인사이트:
- 각 MAC 유닛은 매 클럭 사이클마다 동작 (유휴 시간 없음)
- 연산 중 메모리 읽기 없음 (데이터가 이미 "비행 중")
- TPU v1: 256×256 = 65,536개의 MAC이 동시에 동작!
기존 방식과의 차이
일반적인 프로세서에서 행렬 곱셈을 할 때:
일반 프로세서 방식:
for i in range(N):
for j in range(N):
for k in range(N):
C[i][j] += A[i][k] * B[k][j] # 매번 메모리에서 A, B 읽기
문제점: 메모리 접근이 O(N³) 번 발생
캐시 미스가 병목이 됨
실제 계산보다 데이터 이동 시간이 더 걸림
Systolic Array는 이 문제를 근본적으로 해결합니다:
Systolic Array 방식:
- A의 각 원소: 한 번 입력 → 전체 행을 통과하며 모든 곱을 기여
- B의 각 원소: 한 번 입력 → 전체 열을 통과하며 모든 곱을 기여
- 각 MAC 유닛: 데이터를 기다리지 않고, 흘러오는 데이터를 즉시 처리
결과: 메모리 접근 횟수를 O(N²)으로 줄임 (N이 크면 엄청난 차이!)
2. TPU v1 스펙 완전 분석
2015년 배포된 TPU v1의 실제 스펙을 분석해봅시다.
TPU v1 하드웨어 스펙:
┌─────────────────────────────────────────────────────┐
│ Systolic Array: 256 × 256 = 65,536 MAC 유닛 │
│ 데이터 형식: INT8 (가중치), INT16 (어큐뮬레이터) │
│ 클럭 속도: 700 MHz │
│ 온칩 메모리: 28 MB (Weight FIFO + Unified Buffer) │
│ 메모리 대역폭: 30 GB/s (DDR3) │
│ 전력 소비: 40W │
│ 패키지: 28nm CMOS, PCIe 카드 형태 │
└─────────────────────────────────────────────────────┘
성능 계산:
65,536 MAC × 700 MHz × 2 (곱셈+덧셈) = 92 TOPS (INT8)
비교:
- TPU v1: 92 TOPS @ 40W = 2.3 TOPS/W
- NVIDIA K80 GPU: 8.7 TOPS @ 300W = 0.029 TOPS/W
→ TPU v1이 K80 대비 에너지 효율 79배!
→ 성능 자체도 10.6배 높음
이 숫자들은 Google의 ISCA 2017 논문에서 검증된 실측값입니다. 6개의 서로 다른 뉴럴 네트워크 워크로드에서 TPU v1은 GPU 대비 평균 15~30배 빠른 추론 성능을 보였습니다.
3. GPU vs TPU: 철학의 차이
GPU 설계 철학: "나는 모든 것을 할 수 있다"
┌─────────────────────────────────────────────┐
│ 수천 개의 범용 CUDA 코어 │
│ 유연한 프로그래밍 모델 (CUDA) │
│ 복잡한 분기/조건문 지원 │
│ 임의 메모리 접근 패턴 │
│ 그래픽, 물리 시뮬레이션, AI 모두 가능 │
│ 오버헤드: 스케줄러, 레지스터 파일, 캐시 계층 │
└─────────────────────────────────────────────┘
TPU 설계 철학: "나는 행렬 곱셈만 한다, 하지만 완벽하게"
┌─────────────────────────────────────────────┐
│ Systolic Array (행렬 곱셈 전용 하드웨어) │
│ 예측 가능한 데이터 흐름 (컴파일 타임 결정) │
│ 제한된 연산 집합 (MAC 위주) │
│ 복잡한 제어 로직 없음 │
│ 낭비 없는 100% 하드웨어 활용 │
│ 에너지 효율 극대화 │
└─────────────────────────────────────────────┘
왜 이 전략이 먹히는가?
Transformer 연산의 95% 이상이 행렬 곱셈 (GEMM)
→ 특화 설계가 범용 설계를 압도
트랜스포머 모델에서 실제로 어떤 연산이 얼마나 발생하는지 측정해봅시다:
# GPT-2 (124M params) 연산 분석
import torch
from torch.profiler import profile, ProfilerActivity
# Attention + FFN의 FLOPS 분포
layer_config = {
'd_model': 768,
'n_heads': 12,
'n_layers': 12,
'seq_len': 1024
}
d = layer_config['d_model']
n = layer_config['seq_len']
L = layer_config['n_layers']
# Q, K, V projection: 각 3개의 [d×d] 행렬 곱셈
qkv_flops = 3 * 2 * n * d * d * L # factor of 2 for multiply+add
# Attention 계산: Q×K^T + softmax + ×V
attn_flops = (2 * n * n * d + 2 * n * n * d) * L
# FFN: 2개의 [d×4d] 행렬 곱셈
ffn_flops = 2 * 2 * n * d * (4 * d) * L
total = qkv_flops + attn_flops + ffn_flops
print(f"QKV projection: {qkv_flops/total*100:.1f}%") # ~47%
print(f"Attention: {attn_flops/total*100:.1f}%") # ~12%
print(f"FFN: {ffn_flops/total*100:.1f}%") # ~41%
# GEMM 계열: 약 88-95%
4. TPU 메모리 계층 구조
TPU의 메모리 설계는 Systolic Array의 데이터 흐름에 최적화되어 있습니다.
TPU v4 메모리 계층:
┌─────────────────────────────────────────────────────────┐
│ Weight FIFO (온칩, Systolic Array 직접 공급) │
│ 크기: 32 MB / 대역폭: 2.7 TB/s │
│ 목적: 가중치를 Systolic Array에 끊임없이 공급 │
├─────────────────────────────────────────────────────────┤
│ Unified Buffer (온칩, 활성화 값 저장) │
│ 크기: 256 MB / 대역폭: 900 GB/s │
│ 목적: 레이어 간 중간 활성화 값 저장 │
├─────────────────────────────────────────────────────────┤
│ High Bandwidth Memory (HBM, 오프칩) │
│ 크기: 32 GB / 대역폭: 1.2 TB/s │
│ 목적: 모델 가중치 전체 저장 │
└─────────────────────────────────────────────────────────┘
데이터 흐름:
HBM → Weight FIFO → Systolic Array (가중치 흐름)
HBM → Unified Buffer → Systolic Array (활성화 흐름)
Systolic Array → Unified Buffer → 다음 레이어
Weight Stationary vs Output Stationary Dataflow
Systolic Array를 사용하는 방식에는 두 가지 전략이 있습니다:
Weight Stationary (가중치 고정):
- 가중치(W)를 Systolic Array의 각 MAC에 미리 로드
- 입력 활성화(X)를 흘려보냄
- 결과: W × X를 계산
- 장점: 가중치 재사용 최대화 (배치 크기가 클 때 유리)
- TPU v1이 사용하는 방식
Output Stationary (출력 고정):
- 출력 C[i][j]를 해당 MAC에 누산
- 입력 A와 B를 모두 흘려보냄
- 장점: 출력 쓰기 최소화 (작은 배치에 유리)
선택 기준:
- 추론 (배치 크기 작음): Output Stationary 유리
- 학습 (배치 크기 큼): Weight Stationary 유리
5. bfloat16: Google이 만든 수 형식
TPU v2부터 도입된 bfloat16은 현재 딥러닝의 표준 데이터 형식이 되었습니다.
수 형식 비교:
FP32: [1 부호][8 지수][23 가수] = 32비트
범위: ±3.4 × 10^38
정밀도: 약 7자리
FP16: [1 부호][5 지수][10 가수] = 16비트
범위: ±6.5 × 10^4 (너무 좁음!)
정밀도: 약 3-4자리
문제: 기울기 소실/폭발 위험
bfloat16:[1 부호][8 지수][7 가수] = 16비트 ← Google 혁신
범위: ±3.4 × 10^38 (FP32와 동일!)
정밀도: 약 2-3자리
장점: FP32 상위 16비트만 떼어내면 됨
왜 딥러닝에 bfloat16이 좋은가:
1. FP32와 같은 지수 범위 → 오버플로/언더플로 없음
2. 가수 정밀도 손실 → 딥러닝은 정밀도보다 범위가 중요
3. FP32 → bfloat16 변환: 마지막 16비트 버리면 됨 (하드웨어 비용 극소)
4. 혼합 정밀도 학습: FP32 마스터 가중치 + bfloat16 연산
실제 bfloat16 사용 예:
import jax
import jax.numpy as jnp
# bfloat16으로 행렬 연산
A = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
B = jnp.array([[5.0, 6.0], [7.0, 8.0]], dtype=jnp.bfloat16)
# TPU에서 bfloat16 행렬 곱셈
C = jnp.dot(A, B) # 자동으로 TPU Systolic Array 활용
print(f"Result dtype: {C.dtype}") # bfloat16
print(f"Result:\n{C}")
# FP32 vs bfloat16 정밀도 비교
x_fp32 = jnp.array(1.0 / 3.0, dtype=jnp.float32)
x_bf16 = jnp.array(1.0 / 3.0, dtype=jnp.bfloat16)
print(f"FP32: {float(x_fp32):.10f}") # 0.3333333433
print(f"BF16: {float(x_bf16):.10f}") # 0.3320312500 (약간 덜 정확)
# 딥러닝에서 이 정도 차이는 무시 가능
6. TPU 버전 역사: 4세대의 진화
| 버전 | 연도 | 핵심 혁신 | 온칩 메모리 | 성능 |
|---|---|---|---|---|
| TPU v1 | 2015 | INT8 추론, 256×256 Systolic Array | 28MB | 92 TOPS |
| TPU v2 | 2017 | bfloat16 학습, HBM 도입 | 8GB HBM | 45 TFLOPS |
| TPU v3 | 2018 | 수냉 쿨링, v2 대비 2배 성능 | 16GB HBM | 90 TFLOPS |
| TPU v4 | 2021 | 3D 토러스 인터커넥트, OCS | 32GB HBM | 275 TFLOPS |
| TPU v5p | 2023 | 최대 포드 규모, 트랜스포머 최적화 | 96GB HBM | 459 TFLOPS |
TPU v4의 3D 토러스 인터커넥트
TPU v4 Pod 토폴로지:
각 TPU v4 칩은 6개의 방향으로 직접 연결됨:
+X, -X, +Y, -Y, +Z, -Z (3차원 토러스)
4096개 TPU = 16 × 16 × 16 = 4096 노드 3D 토러스
칩당 인터커넥트 대역폭: 600 GB/s
왜 3D 토러스인가?
- 임의 두 노드 간 홉 수: O(N^(1/3))
- 4096개 노드에서 최대 홉 수: 24 (vs 링 구조의 2048)
- 집합 통신 (AllReduce) 효율 극대화
OCS (Optical Circuit Switch):
- 광 스위치로 토폴로지를 동적으로 재구성
- 다른 워크로드에 맞게 네트워크 구조 변경 가능
- 대역폭 낭비 없는 최적 경로
7. XLA: TPU를 위한 컴파일러
XLA (Accelerated Linear Algebra)는 TPU의 소프트웨어 스택의 핵심입니다. TensorFlow/JAX 코드를 TPU에 최적화된 기계어로 변환합니다.
# JAX 코드 (XLA → TPU로 JIT 컴파일):
import jax
import jax.numpy as jnp
from functools import partial
@jax.jit # XLA로 JIT 컴파일
def transformer_layer(x, w_q, w_k, w_v, w_o, w_ff1, w_ff2):
"""단순화된 트랜스포머 레이어"""
batch, seq, d_model = x.shape
# Multi-head attention (Q, K, V 투영)
q = jnp.dot(x, w_q) # GEMM → Systolic Array
k = jnp.dot(x, w_k) # GEMM → Systolic Array
v = jnp.dot(x, w_v) # GEMM → Systolic Array
# Attention 점수 계산
scale = jnp.sqrt(d_model // 8).astype(jnp.float32)
scores = jnp.einsum('bqd,bkd->bqk', q, k) / scale
attn = jax.nn.softmax(scores, axis=-1)
# 가중합
out = jnp.einsum('bqk,bkd->bqd', attn, v)
out = jnp.dot(out, w_o) # GEMM
# FFN
hidden = jax.nn.gelu(jnp.dot(out, w_ff1)) # GEMM
result = jnp.dot(hidden, w_ff2) # GEMM
return result
# XLA가 수행하는 최적화:
# 1. 연산 퓨전: multiple ops → 단일 커널
# 2. 레이아웃 최적화: Systolic Array에 최적 메모리 배치
# 3. 재물질화: 중간 활성화 재계산 vs 저장 트레이드오프
# 4. 자동 샤딩: 여러 TPU 코어에 자동 분할
XLA의 연산 퓨전 (Operation Fusion)
퓨전 전:
LayerNorm: Mean 계산 → Variance 계산 → Normalize → Scale → Bias
각 단계마다 HBM에 쓰고 읽음 = 5번의 메모리 왕복
퓨전 후:
LayerNorm: 단일 커널로 모든 계산 수행
메모리 왕복: 1번 (입력 읽기 + 출력 쓰기)
결과: 메모리 대역폭 절약 → 5배 빠른 LayerNorm
XLA는 수백 가지 이런 퓨전 패턴을 자동 적용
8. TPU Pod에서 LLM 서빙하기
실제 JAX 코드로 구현하는 분산 추론
import jax
import jax.numpy as jnp
from jax import devices
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import numpy as np
# TPU 디바이스 확인
print(f"사용 가능한 TPU: {jax.devices()}")
# 출력: [TpuDevice(id=0, ...), TpuDevice(id=1, ...), ...]
# 8개 TPU 코어에 모델 샤딩
num_devices = len(jax.devices())
devices_array = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices_array, axis_names=('model',))
# 샤딩 전략 정의
# 가중치 행렬: 'model' 축 방향으로 열 분할
weight_sharding = NamedSharding(mesh, PartitionSpec('model', None))
# 활성화: 복제 (모든 디바이스에 동일)
activation_sharding = NamedSharding(mesh, PartitionSpec(None, None))
def shard_weights(weights):
"""모델 가중치를 TPU들에 분산"""
return jax.device_put(weights, weight_sharding)
# 실제 추론 함수
@partial(jax.jit, in_shardings=(activation_sharding, weight_sharding),
out_shardings=activation_sharding)
def linear_layer(x, w):
"""샤딩된 선형 레이어"""
return jnp.dot(x, w)
# 예시: 7B 모델을 8개 TPU에 분산
d_model = 4096
d_ffn = 16384 # 4x expansion
batch_size = 8
seq_len = 2048
# 가중치 생성 및 샤딩
w_ffn = np.random.randn(d_model, d_ffn).astype(np.float16)
w_ffn_sharded = shard_weights(w_ffn)
# 입력 생성
x = np.random.randn(batch_size, seq_len, d_model).astype(np.float16)
x_batched = jax.device_put(x, activation_sharding)
# 분산 추론 실행
with mesh:
output = linear_layer(x_batched, w_ffn_sharded)
# 8개 TPU가 자동으로 협력하여 계산!
TPU Pod에서 LLM 학습: 실제 규모
PaLM 540B 학습 설정 (Google, 2022):
- 하드웨어: TPU v4 Pod 6144개
- 총 연산 능력: 6144 × 275 TFLOPS = 1.69 ExaFLOPS
- 학습 토큰: 780 billion tokens
- 학습 시간: 약 57일
- 배치 크기: 2048 sequences × 2048 tokens
Gemini Ultra 학습 (Google DeepMind, 2023):
- TPU v5p 사용
- 다중 포드를 OCS로 연결
- 최대 규모의 트랜스포머 학습
Gemini 1.5 Pro:
- 1M 토큰 컨텍스트 처리
- TPU Pod의 3D 토러스로 분산 KV 캐시
9. 실전: JAX로 TPU에서 LLM 추론 최적화
# TPU에서 효율적인 추론을 위한 실전 팁
import jax
import jax.numpy as jnp
from jax import lax
import functools
# 팁 1: vmap으로 배치 처리 벡터화
@jax.vmap # 배치 차원을 자동으로 벡터화
def process_single_example(x):
return some_model_fn(x)
# 여러 입력을 한 번에 처리
batched_results = process_single_example(batch_of_inputs)
# 팁 2: scan으로 레이어 반복 효율화
def transformer_layer_fn(carry, x):
"""단일 트랜스포머 레이어"""
return carry, apply_layer(carry, x)
# N개 레이어를 메모리 효율적으로 처리
final_state, outputs = lax.scan(
transformer_layer_fn,
init_state,
layer_inputs
)
# 팁 3: 정적 형태(static shapes)로 XLA 재컴파일 방지
@functools.partial(jax.jit, static_argnames=['seq_len'])
def generate_tokens(model_params, prompt_ids, seq_len: int):
"""seq_len이 같으면 재컴파일 없음"""
# ... 토큰 생성 로직
pass
# 팁 4: 프로파일링으로 병목 찾기
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
result = my_model_fn(inputs)
# Google Chrome Tracing으로 병목 시각화
# 팁 5: 그래디언트 체크포인팅으로 메모리 절약
from jax.checkpoint import checkpoint
# 순방향 패스에서 일부 활성화만 저장
checkpointed_layer = checkpoint(transformer_layer_fn)
10. TPU vs 경쟁자: 실제 LLM 서빙 벤치마크
Llama 2 70B 추론 벤치마크 (배치 크기 1, FP16):
하드웨어 | 메모리 대역폭 | 처리량 (tok/s) | 지연시간 (ms/tok)
------------------|-------------|--------------|----------------
NVIDIA H100 SXM | 3.35 TB/s | ~120 | ~8.3
NVIDIA A100 80GB | 2.0 TB/s | ~72 | ~13.9
TPU v5p (4칩) | 4× 1.2TB/s | ~160 | ~6.3
AMD MI300X | 5.3 TB/s | ~190 | ~5.3
Apple M3 Ultra | 800 GB/s | ~55 | ~18.2
결론:
- TPU는 메모리 대역폭 대비 효율 최고
- 클라우드 규모에서는 TPU Pod의 집합 통신 효율이 추가 이점
- 온디바이스에서는 Apple Silicon의 통합 메모리 구조 유리
마치며
Systolic Array는 "특화가 범용을 이긴다"는 원리의 완벽한 증명입니다. 1978년에 제안된 아이디어가 2015년 구글의 데이터센터에서 부활하여, 현재 전 세계 AI 인프라의 핵심을 이루고 있습니다.
핵심 교훈:
- 도메인 특화: 트랜스포머의 95%가 GEMM이라는 사실을 활용하면 전용 하드웨어로 압도적 효율 달성
- 데이터 재사용: Systolic Array의 핵심은 데이터를 한 번 읽어 최대한 재사용하는 것
- 수 형식의 중요성: bfloat16은 "충분히 정확하면서 빠른" 트레이드오프의 승자
- 컴파일러와 하드웨어 협력: XLA 없이는 TPU의 성능을 끌어내기 어렵다
다음 편에서는 스마트폰과 PC에 내장된 NPU가 이 아이디어들을 어떻게 에너지 효율적으로 구현하는지 살펴보겠습니다.
참고 자료
- Jouppi et al., "In-Datacenter Performance Analysis of a Tensor Processing Unit" (ISCA 2017)
- Google TPU Research Cloud: cloud.google.com/tpu
- JAX Documentation: jax.readthedocs.io
- "Scaling Language Models: Methods, Analysis & Insights from Training Gopher" (DeepMind, 2021)
- PaLM: "Scaling Language Modeling with Pathways" (Google, 2022)
Google TPU Deep Dive: How Systolic Arrays Solve Matrix Multiplication Perfectly
- Introduction: Why Google Built Its Own Chip
- 1. Systolic Array: Data Flows Like a Heartbeat
- 2. TPU v1 Specs — The Full Analysis
- 3. GPU vs TPU: A Philosophical Divide
- 4. TPU Memory Hierarchy
- 5. bfloat16: Google's Number Format Innovation
- 6. TPU Version History: Four Generations of Evolution
- 7. XLA: The Compiler That Makes TPU Sing
- 8. Running LLM Inference on a TPU Pod
- 9. Performance Numbers: TPU vs The Competition
- 10. Practical Profiling and Optimization
- Conclusion
- References
Introduction: Why Google Built Its Own Chip
In 2013, Google engineers ran a chilling calculation. If every Gmail user spent just 3 minutes per day using voice search — powered by Deep Neural Networks — it would require doubling Google's entire data center capacity. The computational cost of DNN inference was the bottleneck.
Jeff Dean and the team reached a conclusion: general-purpose CPUs and GPUs simply aren't efficient enough. We need a chip designed specifically for neural network inference.
The result was the Tensor Processing Unit (TPU). First deployed in Google data centers in 2015, and revealed to the world via an ISCA 2017 paper authored by Norman Jouppi's hardware team. The core insight was elegantly simple:
"For neural network inference, you don't need full IEEE floating point — INT8 is enough."
This single insight enabled 10x+ performance-per-watt efficiency over contemporary GPUs.
1. Systolic Array: Data Flows Like a Heartbeat
The Name's Origin
The term "Systolic Array" comes from the heart's systolic pumping rhythm. Just as the heart contracts rhythmically to pump blood, a Systolic Array pulses data through its grid of processing elements on every clock cycle. The concept was originally proposed by H.T. Kung and Charles Leiserson in 1978 — Google scaled it massively for neural networks.
The Core Structure
Let's walk through how a 4×4 Systolic Array performs matrix multiplication A × B = C, step by step.
Systolic Array: 4x4 grid of MAC (Multiply-Accumulate) units
Each cell: receives inputs from left (A row) and top (B column),
multiplies them, adds to its accumulator
Time step 0: Setup
B[0,0] B[1,0] B[2,0] B[3,0] <- B matrix columns flow DOWN
| | | |
A[0,0]->[MAC][MAC][MAC][MAC]
A[1,0]->[MAC][MAC][MAC][MAC]
A[2,0]->[MAC][MAC][MAC][MAC]
A[3,0]->[MAC][MAC][MAC][MAC]
| | | |
C col C col ... ... <- results flow OUT
Time step 1: A[0,0] reaches MAC[0,0]: computes A[0,0]*B[0,0]
Time step 2: A[0,0] passes to MAC[0,1], A[0,1] enters MAC[0,0]
Each cell accumulates partial products over time
KEY INSIGHT:
- Every MAC unit is busy EVERY SINGLE CLOCK CYCLE
- No memory reads during computation (data is already "in flight")
- TPU v1: 256x256 = 65,536 MACs running simultaneously!
- Perfect data reuse: each value read from memory contributes to 256 operations
Why This Beats Conventional Processors
On a general-purpose processor, matrix multiplication looks like this:
# Naive matrix multiplication — the memory access problem
def matmul_naive(A, B, N):
C = [[0] * N for _ in range(N)]
for i in range(N):
for j in range(N):
for k in range(N):
# READ A[i][k] FROM MEMORY
# READ B[k][j] FROM MEMORY
C[i][j] += A[i][k] * B[k][j]
return C
# Memory accesses: O(N^3) — catastrophic at large N
# For N=256: 16,777,216 memory reads!
# Cache miss penalty: ~100-300 cycles each
# Systolic Array solution:
# A values: read once, flow through entire row (256 ops per read)
# B values: read once, flow through entire column (256 ops per read)
# Total memory accesses: O(N^2) instead of O(N^3)
# For N=256: 65,536 reads vs 16M reads — 256x fewer!
2. TPU v1 Specs — The Full Analysis
Let's dissect the actual hardware numbers from the 2015 TPU v1.
TPU v1 Hardware Specifications:
+-----------------------------------------------------+
| Systolic Array: 256 x 256 = 65,536 MAC units |
| Data types: INT8 weights, INT32 accumulator |
| Clock speed: 700 MHz |
| On-chip memory: 28 MB (Weight FIFO + Unified Buf) |
| Memory bandwidth: 30 GB/s (DDR3) |
| Power draw: 40W |
| Process node: 28nm CMOS, PCIe card form factor |
+-----------------------------------------------------+
Performance calculation:
65,536 MACs x 700 MHz x 2 (one multiply + one add) = 92 TOPS (INT8)
Competitive comparison:
- TPU v1: 92 TOPS @ 40W = 2.30 TOPS/W
- K80 GPU: 8.7 TOPS @ 300W = 0.029 TOPS/W
Result:
- 79x better energy efficiency than K80
- 10.6x higher raw throughput
These numbers come from the ISCA 2017 paper's measured results across 6 different production neural network workloads. On average, the TPU v1 delivered 15-30x faster inference than contemporary GPUs, at 80x better energy efficiency.
3. GPU vs TPU: A Philosophical Divide
GPU Philosophy: "I can do everything"
+---------------------------------------------+
| Thousands of general-purpose CUDA cores |
| Flexible CUDA programming model |
| Complex branching and control flow support |
| Arbitrary memory access patterns |
| Graphics, simulation, AI — all supported |
| Overhead: scheduler, register file, caches |
+---------------------------------------------+
TPU Philosophy: "I only do matrix multiplication — perfectly"
+---------------------------------------------+
| Systolic Array (dedicated to GEMM) |
| Deterministic dataflow (compiler decides) |
| Limited operation set (MAC-centric) |
| No complex control logic needed |
| 100% hardware utilization, no waste |
| Maximized energy efficiency |
+---------------------------------------------+
Why specialization wins:
Transformer operations: ~95% are GEMM (matrix multiply)
-> Hardware specialized for GEMM dominates everything else
Let's measure exactly what fraction of transformer compute is GEMM:
# FLOPS breakdown for a GPT-2-scale transformer (124M params)
layer_config = {
'd_model': 768,
'n_heads': 12,
'n_layers': 12,
'seq_len': 1024
}
d = layer_config['d_model'] # 768
n = layer_config['seq_len'] # 1024
L = layer_config['n_layers'] # 12
# Q, K, V projections: 3 separate [n x d] x [d x d] GEMMs per layer
qkv_flops = 3 * 2 * n * d * d * L
# Attention: Q x K^T (batched GEMM) + weighted sum by V
attn_flops = (2 * n * n * d + 2 * n * n * d) * L
# FFN: two GEMMs [d -> 4d -> d]
ffn_flops = (2 * n * d * (4*d) + 2 * n * (4*d) * d) * L
# Output projection per layer
out_flops = 2 * n * d * d * L
total = qkv_flops + attn_flops + ffn_flops + out_flops
print(f"QKV projections: {qkv_flops/total*100:.1f}%") # ~38%
print(f"Attention GEMM: {attn_flops/total*100:.1f}%") # ~12%
print(f"FFN: {ffn_flops/total*100:.1f}%") # ~41%
print(f"Output proj: {out_flops/total*100:.1f}%") # ~9%
# Total GEMM: ~100% (attention is also batched GEMM)
4. TPU Memory Hierarchy
The memory architecture is designed to keep the Systolic Array fed at all times.
TPU v4 Memory Hierarchy:
+-----------------------------------------------------------+
| Weight FIFO (on-chip, feeds directly to Systolic) |
| Capacity: 32 MB / Bandwidth: 2.7 TB/s |
| Purpose: stream weights into array without stalls |
+-----------------------------------------------------------+
| Unified Buffer (on-chip, holds activations) |
| Capacity: 256 MB / Bandwidth: 900 GB/s |
| Purpose: inter-layer intermediate activation storage |
+-----------------------------------------------------------+
| High Bandwidth Memory (HBM, off-chip) |
| Capacity: 32 GB / Bandwidth: 1.2 TB/s |
| Purpose: store the full model weights |
+-----------------------------------------------------------+
Data flow:
HBM -> Weight FIFO -> Systolic Array (weight stream)
HBM -> Unified Buffer -> Systolic Array (activation stream)
Systolic Array -> Unified Buffer -> next layer
Weight Stationary vs Output Stationary Dataflow
There are two primary strategies for orchestrating data through a Systolic Array:
Weight Stationary:
- Pre-load weights (W) into each MAC unit
- Stream input activations (X) through the array
- Compute: W x X
- Advantage: maximizes weight reuse (great for large batches)
- Used by: TPU v1 for inference
Output Stationary:
- Each MAC accumulates one output element C[i][j]
- Stream both input matrices through
- Advantage: minimizes output writes (great for small batches)
- Better for single-sample inference
When to use which:
- Batch inference (batch_size > 32): Weight Stationary
- Single-sample inference: Output Stationary
- Training (large batches): Weight Stationary
5. bfloat16: Google's Number Format Innovation
Introduced with TPU v2, bfloat16 has become the de facto standard for deep learning.
Floating-point format comparison:
FP32: [1 sign][8 exponent][23 mantissa] = 32 bits
Range: +/- 3.4 x 10^38
Precision: ~7 decimal digits
FP16: [1 sign][5 exponent][10 mantissa] = 16 bits
Range: +/- 6.5 x 10^4 (dangerously narrow!)
Precision: ~3-4 decimal digits
Problem: gradients vanish/explode during training
bfloat16: [1 sign][8 exponent][7 mantissa] = 16 bits (Google's design)
Range: +/- 3.4 x 10^38 (SAME as FP32!)
Precision: ~2-3 decimal digits
Key trick: just drop the last 16 bits of FP32!
Why bfloat16 wins for deep learning:
1. Same exponent range as FP32 -> no overflow/underflow in gradients
2. Less mantissa -> DL needs range more than precision
3. FP32 -> bfloat16 conversion: just truncate last 16 bits (free!)
4. Mixed precision: FP32 master weights + bfloat16 compute
import jax
import jax.numpy as jnp
import numpy as np
# Demonstrating bfloat16 properties
x_fp32 = np.float32(1.0 / 3.0)
x_bf16 = jnp.bfloat16(1.0 / 3.0)
x_fp16 = np.float16(1.0 / 3.0)
print(f"FP32: {float(x_fp32):.10f}") # 0.3333333433
print(f"BF16: {float(x_bf16):.10f}") # 0.3320312500 (less precise, OK)
print(f"FP16: {float(x_fp16):.10f}") # 0.3334960938
# Range comparison (critical for training)
print(f"FP32 max: {np.finfo(np.float32).max:.2e}") # 3.40e+38
print(f"BF16 max: {float(jnp.finfo(jnp.bfloat16).max):.2e}") # 3.39e+38
print(f"FP16 max: {np.finfo(np.float16).max:.2e}") # 6.55e+04 << problem!
# In practice: training loss at step 10000
# FP32: 0.2341 (clean, stable)
# BF16: 0.2343 (virtually identical)
# FP16: NaN (gradient overflow in many models!)
6. TPU Version History: Four Generations of Evolution
| Version | Year | Key Innovation | Memory | Peak Performance |
|---|---|---|---|---|
| TPU v1 | 2015 | INT8 inference, 256x256 array | 28MB on-chip | 92 TOPS |
| TPU v2 | 2017 | bfloat16 training, HBM | 8GB HBM | 45 TFLOPS |
| TPU v3 | 2018 | Liquid cooling, 2x v2 perf | 16GB HBM | 90 TFLOPS |
| TPU v4 | 2021 | 3D torus interconnect, OCS | 32GB HBM | 275 TFLOPS |
| TPU v5p | 2023 | Largest pod, transformer-opt | 96GB HBM | 459 TFLOPS |
TPU v4's 3D Torus Interconnect
TPU v4 Pod Topology:
Each TPU v4 chip connects to 6 neighbors: +X, -X, +Y, -Y, +Z, -Z
This forms a 3-dimensional torus network
Full TPU v4 Pod: 16 x 16 x 16 = 4,096 chips in 3D torus
Per-chip interconnect bandwidth: 600 GB/s
Why 3D torus?
- Max hops between any two nodes: O(N^(1/3))
- At 4096 nodes: max 24 hops (vs 2048 hops for a simple ring)
- Collective communication (AllReduce) is highly efficient
- Bisection bandwidth scales well with pod size
OCS (Optical Circuit Switch):
- Software-configurable optical switching fabric
- Dynamically reconfigures the torus topology per workload
- Eliminates electrical switch bottlenecks
- Enables full bandwidth between any chip pairs
7. XLA: The Compiler That Makes TPU Sing
XLA (Accelerated Linear Algebra) is the backbone of TPU's software stack. It compiles JAX/TensorFlow compute graphs into highly optimized TPU machine code.
# JAX code that compiles to XLA -> runs on TPU
import jax
import jax.numpy as jnp
from functools import partial
@jax.jit # JIT compile via XLA
def transformer_forward(x, params):
"""Single transformer layer, JAX style"""
w_q, w_k, w_v, w_o = params['attn']
w_ff1, w_ff2 = params['ffn']
# Layer norm
x_norm = jax.nn.standardize(x, axis=-1)
# Multi-head attention projections (GEMM -> Systolic Array)
q = jnp.dot(x_norm, w_q) # [batch, seq, d_model] x [d_model, d_head]
k = jnp.dot(x_norm, w_k)
v = jnp.dot(x_norm, w_v)
# Attention (batched GEMM)
d_head = q.shape[-1]
scale = jnp.sqrt(float(d_head))
scores = jnp.einsum('bqh,bkh->bqk', q, k) / scale
attn_weights = jax.nn.softmax(scores, axis=-1)
attended = jnp.einsum('bqk,bkh->bqh', attn_weights, v)
# Output projection
out = jnp.dot(attended, w_o) + x # residual
# FFN
x2_norm = jax.nn.standardize(out, axis=-1)
hidden = jax.nn.gelu(jnp.dot(x2_norm, w_ff1))
ffn_out = jnp.dot(hidden, w_ff2) + out # residual
return ffn_out
# What XLA does to this code:
# 1. Operation fusion: LayerNorm (5 ops) -> single kernel
# 2. Layout optimization: pick memory layout for Systolic Array
# 3. Rematerialization: recompute activations vs store (memory/compute tradeoff)
# 4. Auto-sharding: split across TPU chips with minimal communication
# 5. Constant folding: pre-compute anything computable at compile time
XLA Operation Fusion: Concrete Example
Without fusion (naive):
LayerNorm breakdown:
Step 1: mean(x) -> write to HBM
Step 2: x - mean -> write to HBM
Step 3: variance(x - mean) -> write to HBM
Step 4: normalize -> write to HBM
Step 5: scale * x + bias -> write to HBM
Total: 5 HBM round trips per layer
With XLA fusion:
LayerNorm: single kernel
- Read input once from HBM
- Do all 5 computations in SRAM (on-chip)
- Write output once to HBM
Total: 1 HBM round trip per layer
Impact: 5x reduction in memory bandwidth usage
XLA applies hundreds of such fusion patterns automatically
8. Running LLM Inference on a TPU Pod
Distributed Inference with JAX: Full Implementation
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from functools import partial
import numpy as np
# Check available TPU devices
print(f"Available TPUs: {jax.devices()}")
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), ...]
# Create 8-TPU mesh for tensor parallelism
num_devices = 8 # one TPU v4 chip = 4 cores; 2 chips = 8 cores
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(device_mesh, axis_names=('model',))
# Define sharding strategies
# Weight matrix W: shard along columns (model-parallel)
W_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
# Bias: replicated on all devices
bias_sharding = NamedSharding(mesh, PartitionSpec(None))
# Activations: replicated (all devices see same input)
act_sharding = NamedSharding(mesh, PartitionSpec(None, None))
@partial(jax.jit,
in_shardings=(act_sharding, W_sharding, bias_sharding),
out_shardings=act_sharding)
def sharded_linear(x, W, b):
"""Column-parallel linear layer across 8 TPUs"""
# Each TPU computes [batch, seq, d_model/8]
local_out = jnp.dot(x, W)
# AllReduce to sum partial results (automatic!)
return local_out + b
# Real inference pipeline for a 7B model
def create_model_params(d_model=4096, d_ffn=16384, n_layers=32):
"""Create model params sharded across 8 TPUs"""
params = {}
for i in range(n_layers):
# Each layer's FFN weights sharded across devices
params[f'layer_{i}_ff1'] = jax.device_put(
np.random.randn(d_model, d_ffn).astype(np.float16),
W_sharding
)
params[f'layer_{i}_ff2'] = jax.device_put(
np.random.randn(d_ffn, d_model).astype(np.float16),
W_sharding
)
return params
# Inference: generate one token
def generate_next_token(params, input_ids, kv_cache):
batch_size, seq_len = input_ids.shape
# Run through all layers
x = embedding_lookup(params['embed'], input_ids) # [batch, seq, d_model]
for layer_idx in range(32):
# Attention with cached KV
x = attention_with_cache(x, params, kv_cache, layer_idx)
# FFN (sharded linear layers)
with mesh:
ffn_hidden = sharded_linear(
x,
params[f'layer_{layer_idx}_ff1'],
params[f'bias_{layer_idx}_ff1']
)
x = x + sharded_linear(
jax.nn.gelu(ffn_hidden),
params[f'layer_{layer_idx}_ff2'],
params[f'bias_{layer_idx}_ff2']
)
# Final logits
logits = jnp.dot(x[:, -1, :], params['lm_head'])
return logits.argmax(axis=-1)
Real-World Scale: Training PaLM on TPU Pods
PaLM 540B Training Configuration (Google, 2022):
- Hardware: TPU v4 Pod x 6,144 chips
- Total compute: 6,144 x 275 TFLOPS = 1.69 ExaFLOPS
- Training data: 780 billion tokens
- Training time: ~57 days
- Batch size: 2,048 sequences x 2,048 tokens
- MFU: 46.2% (Model FLOPS Utilization)
Gemini Ultra Training (Google DeepMind, 2023):
- Hardware: Multiple TPU v5p pods via OCS
- First model to outperform human experts on MMLU
- Context length training: up to 32,768 tokens
Why TPU Pods beat GPU clusters for this scale:
1. 3D torus: O(N^1/3) hop distance vs O(N) for ring-allreduce
2. OCS: full bandwidth between any chip pair (no switch bottleneck)
3. XLA: whole-model compilation, global optimization
4. bfloat16: stable training without loss scaling tricks
9. Performance Numbers: TPU vs The Competition
Llama 2 70B Inference Benchmark (batch_size=1, float16):
Hardware | Mem Bandwidth | Throughput | Latency
------------------|---------------|-------------|----------
NVIDIA H100 SXM | 3.35 TB/s | ~120 tok/s | ~8.3 ms/tok
NVIDIA A100 80GB | 2.0 TB/s | ~72 tok/s | ~13.9 ms/tok
TPU v5p (4-chip) | 4x 1.2 TB/s | ~160 tok/s | ~6.3 ms/tok
AMD MI300X | 5.3 TB/s | ~190 tok/s | ~5.3 ms/tok
Apple M3 Ultra | 800 GB/s | ~55 tok/s | ~18.2 ms/tok
Key insight: throughput correlates almost perfectly with memory bandwidth!
This is the "memory-bound" nature of LLM inference — more in the NPU post.
TPU advantages at scale (beyond single-chip numbers):
- TPU Pod AllReduce: 600 GB/s per chip interconnect
- GPU NVLink: 900 GB/s (H100), but only within NVSwitch domain
- TPU OCS: reconfigurable topology for different workloads
- Cost: TPU v5p ~$2.20/hr vs H100 ~$3.50/hr on comparable clouds
10. Practical Profiling and Optimization
# Profile your JAX/TPU code to find bottlenecks
import jax
import jax.numpy as jnp
# Method 1: JAX Profiler (produces Chrome Trace format)
with jax.profiler.trace("/tmp/jax_trace", create_perfetto_link=True):
# Warmup (needed to exclude JIT compilation time)
_ = my_model_fn(sample_input).block_until_ready()
# Actual profiled run
result = my_model_fn(real_input).block_until_ready()
# Open in Perfetto UI: https://ui.perfetto.dev
# Look for:
# - Long gaps between ops (memory bandwidth bound)
# - Unbalanced ops across devices (sharding inefficiency)
# - Frequent recompilation (dynamic shapes)
# Method 2: Measure Model FLOPS Utilization (MFU)
def compute_mfu(model_flops, elapsed_seconds, peak_tflops):
"""MFU = actual FLOPS / theoretical peak FLOPS"""
actual_tflops = model_flops / elapsed_seconds / 1e12
return actual_tflops / peak_tflops * 100
# Example: 7B model, 100 token/s throughput
flops_per_token = 2 * 7e9 # 2 * num_params (rough estimate)
elapsed = 1.0 / 100 # 1 sec / 100 tokens = 0.01 sec/token
tpu_v5p_tflops = 459
mfu = compute_mfu(flops_per_token, elapsed, tpu_v5p_tflops)
print(f"MFU: {mfu:.1f}%")
# Good MFU: >40% means you're using the hardware well
# Low MFU: memory bandwidth bound (typical for inference)
# Method 3: Identify recompilation events
# Add logging to catch shape changes that trigger recompiles
jax.config.update("jax_log_compiles", True)
# Check stderr for "Compiling..." messages during inference
Conclusion
The Systolic Array is a perfect demonstration of the principle: specialization beats generalization.
An idea from 1978 became the engine powering Google's data centers in 2015, and today runs the world's most capable AI systems — Gemini, PaLM, and the infrastructure behind Google Search, Gmail, and Google Assistant.
The key lessons:
- Domain specialization pays: 95% of transformer compute is GEMM. Hardware built for GEMM dominates everything else.
- Data reuse is the key: The Systolic Array reads data once and extracts maximum computation from it — this is the fundamental insight.
- Number formats matter: bfloat16 hit the sweet spot of "accurate enough while fast" — same dynamic range as FP32, half the bandwidth.
- Compiler-hardware co-design: Without XLA, TPU would be 50-70% less effective. The compiler IS part of the hardware.
- Scale requires topology: The 3D torus interconnect in TPU Pods is what makes ExaFLOP-scale training possible.
In the next post, we dive into NPUs — the chips embedded in your phone and laptop — and explain how they implement these same principles in a 1-5 watt power envelope.
References
- Jouppi et al., "In-Datacenter Performance Analysis of a Tensor Processing Unit" (ISCA 2017)
- Google TPU Research Cloud: cloud.google.com/tpu
- JAX Documentation: jax.readthedocs.io
- "PaLM: Scaling Language Modeling with Pathways" (Chowdhery et al., 2022)
- "Gemini: A Family of Highly Capable Multimodal Models" (Google DeepMind, 2023)
- H.T. Kung & C.E. Leiserson, "Systolic Arrays for VLSI" (1978)