Google TPU Deep Dive: How Systolic Arrays Solve Matrix Multiplication Perfectly
By Youngju Kim (@fjvbn20031)

Contents
- Introduction: Why Google Built Its Own Chip
- 1. Systolic Array: Data Flows Like a Heartbeat
- 2. TPU v1 Specs — The Full Analysis
- 3. GPU vs TPU: A Philosophical Divide
- 4. TPU Memory Hierarchy
- 5. bfloat16: Google's Number Format Innovation
- 6. TPU Version History: Five Generations of Evolution
- 7. XLA: The Compiler That Makes TPU Sing
- 8. Running LLM Inference on a TPU Pod
- 9. Performance Numbers: TPU vs The Competition
- 10. Practical Profiling and Optimization
- Conclusion
- References
Introduction: Why Google Built Its Own Chip
In 2013, Google engineers ran a chilling calculation: if users spent just 3 minutes per day on voice search — powered by deep neural networks — Google would need to double its entire data center capacity. The computational cost of DNN inference was the bottleneck.
Jeff Dean and the team reached a conclusion: general-purpose CPUs and GPUs simply aren't efficient enough. We need a chip designed specifically for neural network inference.
The result was the Tensor Processing Unit (TPU). It was first deployed in Google data centers in 2015 and revealed to the world in an ISCA 2017 paper from Norman Jouppi's hardware team. The core insight was elegantly simple:
"For neural network inference, you don't need full IEEE floating point — INT8 is enough."
This single insight enabled 10x+ performance-per-watt efficiency over contemporary GPUs.
1. Systolic Array: Data Flows Like a Heartbeat
The Name's Origin
The term "Systolic Array" comes from the heart's systolic pumping rhythm. Just as the heart contracts rhythmically to pump blood, a Systolic Array pulses data through its grid of processing elements on every clock cycle. The concept was originally proposed by H.T. Kung and Charles Leiserson in 1978 — Google scaled it massively for neural networks.
The Core Structure
Let's walk through how a 4×4 Systolic Array performs matrix multiplication A × B = C, step by step.
Systolic Array: 4x4 grid of MAC (Multiply-Accumulate) units
Each cell: receives inputs from left (A row) and top (B column),
multiplies them, adds to its accumulator
Time step 0: Setup
B[0,0] B[1,0] B[2,0] B[3,0] <- B matrix columns flow DOWN
| | | |
A[0,0]->[MAC][MAC][MAC][MAC]
A[1,0]->[MAC][MAC][MAC][MAC]
A[2,0]->[MAC][MAC][MAC][MAC]
A[3,0]->[MAC][MAC][MAC][MAC]
| | | |
C col C col ... ... <- results flow OUT
Time step 1: A[0,0] reaches MAC[0,0]: computes A[0,0]*B[0,0]
Time step 2: A[0,0] passes to MAC[0,1], A[0,1] enters MAC[0,0]
Each cell accumulates partial products over time
KEY INSIGHT:
- Every MAC unit is busy EVERY SINGLE CLOCK CYCLE
- No memory reads during computation (data is already "in flight")
- TPU v1: 256x256 = 65,536 MACs running simultaneously!
- Perfect data reuse: each value read from memory contributes to 256 operations
Why This Beats Conventional Processors
On a general-purpose processor, matrix multiplication looks like this:
# Naive matrix multiplication — the memory access problem
def matmul_naive(A, B, N):
    C = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                # READ A[i][k] FROM MEMORY
                # READ B[k][j] FROM MEMORY
                C[i][j] += A[i][k] * B[k][j]
    return C

# Memory accesses: O(N^3) — catastrophic at large N
# For N=256: 256^3 = 16,777,216 inner-loop iterations, each reading A and B
# Cache miss penalty: ~100-300 cycles each

# Systolic Array solution:
#   A values: read once, flow through entire row (256 ops per read)
#   B values: read once, flow through entire column (256 ops per read)
#   Total memory accesses: O(N^2) instead of O(N^3)
#   For N=256: 65,536 reads vs ~16M — 256x fewer!
2. TPU v1 Specs — The Full Analysis
Let's dissect the actual hardware numbers from the 2015 TPU v1.
TPU v1 Hardware Specifications:
+-----------------------------------------------------+
| Systolic Array: 256 x 256 = 65,536 MAC units |
| Data types: INT8 weights, INT32 accumulator |
| Clock speed: 700 MHz |
| On-chip memory: 28 MB (Weight FIFO + Unified Buf) |
| Memory bandwidth: 30 GB/s (DDR3) |
| Power draw: 40W |
| Process node: 28nm CMOS, PCIe card form factor |
+-----------------------------------------------------+
Performance calculation:
65,536 MACs x 700 MHz x 2 (one multiply + one add) = 92 TOPS (INT8)
Competitive comparison:
- TPU v1: 92 TOPS @ 40W = 2.30 TOPS/W
- K80 GPU: 8.7 TOPS @ 300W = 0.029 TOPS/W
Result:
- 79x better energy efficiency than K80
- 10.6x higher raw throughput
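The peak-throughput and efficiency figures above are one-line arithmetic; a quick sanity check in Python (using the spec numbers quoted above):

```python
# Peak throughput: each MAC does one multiply and one add per cycle (2 ops)
macs = 256 * 256              # 65,536 MAC units
clock_hz = 700e6              # 700 MHz
tops = macs * clock_hz * 2 / 1e12
print(f"TPU v1 peak: {tops:.1f} TOPS")        # ~91.8 TOPS (the paper rounds to 92)

tpu_eff = tops / 40           # measured power draw: 40 W
k80_eff = 8.7 / 300           # K80: 8.7 TOPS at 300 W
print(f"Efficiency advantage: {tpu_eff / k80_eff:.0f}x")   # ~79x
```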
These numbers come from the ISCA 2017 paper's measured results across six production neural network workloads. On average, the TPU v1 delivered 15-30x faster inference than contemporary CPUs and GPUs, at 30-80x better energy efficiency.
3. GPU vs TPU: A Philosophical Divide
GPU Philosophy: "I can do everything"
+---------------------------------------------+
| Thousands of general-purpose CUDA cores |
| Flexible CUDA programming model |
| Complex branching and control flow support |
| Arbitrary memory access patterns |
| Graphics, simulation, AI — all supported |
| Overhead: scheduler, register file, caches |
+---------------------------------------------+
TPU Philosophy: "I only do matrix multiplication — perfectly"
+---------------------------------------------+
| Systolic Array (dedicated to GEMM) |
| Deterministic dataflow (compiler decides) |
| Limited operation set (MAC-centric) |
| No complex control logic needed |
| 100% hardware utilization, no waste |
| Maximized energy efficiency |
+---------------------------------------------+
Why specialization wins:
Transformer operations: ~95% are GEMM (matrix multiply)
-> Hardware specialized for GEMM dominates everything else
Let's measure exactly what fraction of transformer compute is GEMM:
# FLOPS breakdown for a GPT-2-scale transformer (124M params)
layer_config = {
    'd_model': 768,
    'n_heads': 12,
    'n_layers': 12,
    'seq_len': 1024,
}
d = layer_config['d_model']   # 768
n = layer_config['seq_len']   # 1024
L = layer_config['n_layers']  # 12

# Q, K, V projections: 3 separate [n x d] x [d x d] GEMMs per layer
qkv_flops = 3 * 2 * n * d * d * L
# Attention: Q x K^T (batched GEMM) + weighted sum by V
attn_flops = (2 * n * n * d + 2 * n * n * d) * L
# FFN: two GEMMs [d -> 4d -> d]
ffn_flops = (2 * n * d * (4 * d) + 2 * n * (4 * d) * d) * L
# Output projection per layer
out_flops = 2 * n * d * d * L

total = qkv_flops + attn_flops + ffn_flops + out_flops
print(f"QKV projections: {qkv_flops/total*100:.1f}%")  # ~20%
print(f"Attention GEMM:  {attn_flops/total*100:.1f}%") # ~18%
print(f"FFN:             {ffn_flops/total*100:.1f}%")  # ~55%
print(f"Output proj:     {out_flops/total*100:.1f}%")  # ~7%
# Total GEMM: ~100% (attention is also batched GEMM)
4. TPU Memory Hierarchy
The memory architecture is designed to keep the Systolic Array fed at all times.
TPU v4 Memory Hierarchy:
+-----------------------------------------------------------+
| Weight FIFO (on-chip, feeds directly to Systolic) |
| Capacity: 32 MB / Bandwidth: 2.7 TB/s |
| Purpose: stream weights into array without stalls |
+-----------------------------------------------------------+
| Unified Buffer (on-chip, holds activations) |
| Capacity: 256 MB / Bandwidth: 900 GB/s |
| Purpose: inter-layer intermediate activation storage |
+-----------------------------------------------------------+
| High Bandwidth Memory (HBM, off-chip) |
| Capacity: 32 GB / Bandwidth: 1.2 TB/s |
| Purpose: store the full model weights |
+-----------------------------------------------------------+
Data flow:
HBM -> Weight FIFO -> Systolic Array (weight stream)
HBM -> Unified Buffer -> Systolic Array (activation stream)
Systolic Array -> Unified Buffer -> next layer
Weight Stationary vs Output Stationary Dataflow
There are two primary strategies for orchestrating data through a Systolic Array:
Weight Stationary:
- Pre-load weights (W) into each MAC unit
- Stream input activations (X) through the array
- Compute: W x X
- Advantage: maximizes weight reuse (great for large batches)
- Used by: TPU v1 for inference
Output Stationary:
- Each MAC accumulates one output element C[i][j]
- Stream both input matrices through
- Advantage: minimizes output writes (great for small batches)
- Better for single-sample inference
When to use which:
- Batch inference (batch_size > 32): Weight Stationary
- Single-sample inference: Output Stationary
- Training (large batches): Weight Stationary
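To make the dataflow concrete, here is a minimal NumPy simulation of the output-stationary variant — an educational sketch under simplifying assumptions, not actual TPU microarchitecture. Skewed rows of A flow right, skewed columns of B flow down, and each PE accumulates exactly one C[i][j]:

```python
import numpy as np

def systolic_matmul(A, B):
    """Cycle-by-cycle simulation of an output-stationary N x N systolic array.
    A values hop one PE to the right per clock, B values hop one PE down,
    and every PE performs one multiply-accumulate per cycle."""
    N = A.shape[0]
    C = np.zeros((N, N))
    h = np.zeros((N, N))   # horizontal pipeline registers (A values in flight)
    v = np.zeros((N, N))   # vertical pipeline registers (B values in flight)
    for t in range(3 * N - 2):          # last operand drains at cycle 3N-3
        h = np.roll(h, 1, axis=1)       # every A value moves one PE right
        v = np.roll(v, 1, axis=0)       # every B value moves one PE down
        for i in range(N):              # inject skewed A rows at the left edge
            k = t - i
            h[i, 0] = A[i, k] if 0 <= k < N else 0.0
        for j in range(N):              # inject skewed B columns at the top edge
            k = t - j
            v[0, j] = B[k, j] if 0 <= k < N else 0.0
        C += h * v                      # every PE does one MAC this cycle
    return C

A = np.random.randn(4, 4)
B = np.random.randn(4, 4)
assert np.allclose(systolic_matmul(A, B), A @ B)
print("4x4 systolic result matches A @ B after", 3 * 4 - 2, "cycles")
```

The skew is the key trick: PE[i][j] at cycle t sees A[i, t-i-j] and B[t-i-j, j], so the same index k always meets itself and the accumulator builds sum_k A[i,k]*B[k,j] with no memory traffic mid-computation.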
5. bfloat16: Google's Number Format Innovation
Introduced with TPU v2, bfloat16 has become the de facto standard for deep learning.
Floating-point format comparison:
FP32: [1 sign][8 exponent][23 mantissa] = 32 bits
Range: +/- 3.4 x 10^38
Precision: ~7 decimal digits
FP16: [1 sign][5 exponent][10 mantissa] = 16 bits
Range: +/- 6.5 x 10^4 (dangerously narrow!)
Precision: ~3-4 decimal digits
Problem: gradients vanish/explode during training
bfloat16: [1 sign][8 exponent][7 mantissa] = 16 bits (Google's design)
Range: +/- 3.4 x 10^38 (SAME as FP32!)
Precision: ~2-3 decimal digits
Key trick: just drop the last 16 bits of FP32!
Why bfloat16 wins for deep learning:
1. Same exponent range as FP32 -> no overflow/underflow in gradients
2. Less mantissa -> DL needs range more than precision
3. FP32 -> bfloat16 conversion: just truncate last 16 bits (free!)
4. Mixed precision: FP32 master weights + bfloat16 compute
import jax
import jax.numpy as jnp
import numpy as np
# Demonstrating bfloat16 properties
x_fp32 = np.float32(1.0 / 3.0)
x_bf16 = jnp.bfloat16(1.0 / 3.0)
x_fp16 = np.float16(1.0 / 3.0)
print(f"FP32: {float(x_fp32):.10f}") # 0.3333333433
print(f"BF16: {float(x_bf16):.10f}") # 0.3320312500 (less precise, OK)
print(f"FP16: {float(x_fp16):.10f}") # 0.3332519531
# Range comparison (critical for training)
print(f"FP32 max: {np.finfo(np.float32).max:.2e}") # 3.40e+38
print(f"BF16 max: {float(jnp.finfo(jnp.bfloat16).max):.2e}") # 3.39e+38
print(f"FP16 max: {np.finfo(np.float16).max:.2e}") # 6.55e+04 << problem!
# In practice: training loss at step 10000
# FP32: 0.2341 (clean, stable)
# BF16: 0.2343 (virtually identical)
# FP16: NaN (gradient overflow in many models!)
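The truncation trick from point 3 can be shown at the bit level with plain NumPy. One caveat: real converters typically use round-to-nearest-even rather than pure truncation, but truncation already illustrates why the conversion is nearly free:

```python
import numpy as np

def fp32_to_bf16_bits(x):
    """The bfloat16 encoding is literally the top 16 bits of the FP32 encoding."""
    return np.uint16(np.float32(x).view(np.uint32) >> np.uint32(16))

def bf16_bits_to_fp32(b):
    """Pad 16 zero bits back on and reinterpret as FP32."""
    return (np.uint32(b) << np.uint32(16)).view(np.float32)

x = 1.0 / 3.0
bf = bf16_bits_to_fp32(fp32_to_bf16_bits(x))
print(f"{float(bf):.10f}")   # 0.3320312500 — same value jnp.bfloat16 produced above
```

Sign and exponent bits survive intact, which is exactly why the dynamic range matches FP32.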
6. TPU Version History: Five Generations of Evolution
| Version | Year | Key Innovation | Memory | Peak Performance |
|---|---|---|---|---|
| TPU v1 | 2015 | INT8 inference, 256x256 array | 28MB on-chip | 92 TOPS |
| TPU v2 | 2017 | bfloat16 training, HBM | 8GB HBM | 45 TFLOPS |
| TPU v3 | 2018 | Liquid cooling, 2x v2 perf | 16GB HBM | 90 TFLOPS |
| TPU v4 | 2021 | 3D torus interconnect, OCS | 32GB HBM | 275 TFLOPS |
| TPU v5p | 2023 | Largest pod, transformer-opt | 96GB HBM | 459 TFLOPS |
TPU v4's 3D Torus Interconnect
TPU v4 Pod Topology:
Each TPU v4 chip connects to 6 neighbors: +X, -X, +Y, -Y, +Z, -Z
This forms a 3-dimensional torus network
Full TPU v4 Pod: 16 x 16 x 16 = 4,096 chips in 3D torus
Per-chip interconnect bandwidth: 600 GB/s
Why 3D torus?
- Max hops between any two nodes: O(N^(1/3))
- At 4096 nodes: max 24 hops (vs 2048 hops for a simple ring)
- Collective communication (AllReduce) is highly efficient
- Bisection bandwidth scales well with pod size
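The hop-count claims above can be checked in a few lines, assuming shortest-path routing over the wraparound links:

```python
def max_hops_ring(n_nodes):
    # In a ring, the farthest node is halfway around
    return n_nodes // 2

def max_hops_torus_3d(side):
    # Each dimension wraps, so at most side // 2 hops per dimension
    return 3 * (side // 2)

print(max_hops_ring(4096))      # 2048
print(max_hops_torus_3d(16))    # 24 (8 hops in each of X, Y, Z)
```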
OCS (Optical Circuit Switch):
- Software-configurable optical switching fabric
- Dynamically reconfigures the torus topology per workload
- Eliminates electrical switch bottlenecks
- Enables full bandwidth between any chip pairs
7. XLA: The Compiler That Makes TPU Sing
XLA (Accelerated Linear Algebra) is the backbone of TPU's software stack. It compiles JAX/TensorFlow compute graphs into highly optimized TPU machine code.
# JAX code that compiles to XLA -> runs on TPU
import jax
import jax.numpy as jnp

@jax.jit  # JIT compile via XLA
def transformer_forward(x, params):
    """Single transformer layer, JAX style"""
    w_q, w_k, w_v, w_o = params['attn']
    w_ff1, w_ff2 = params['ffn']

    # Layer norm
    x_norm = jax.nn.standardize(x, axis=-1)

    # Attention projections, single head shown for brevity (GEMM -> Systolic Array)
    q = jnp.dot(x_norm, w_q)  # [batch, seq, d_model] x [d_model, d_head]
    k = jnp.dot(x_norm, w_k)
    v = jnp.dot(x_norm, w_v)

    # Attention (batched GEMM)
    d_head = q.shape[-1]
    scale = jnp.sqrt(float(d_head))
    scores = jnp.einsum('bqh,bkh->bqk', q, k) / scale
    attn_weights = jax.nn.softmax(scores, axis=-1)
    attended = jnp.einsum('bqk,bkh->bqh', attn_weights, v)

    # Output projection
    out = jnp.dot(attended, w_o) + x  # residual

    # FFN
    x2_norm = jax.nn.standardize(out, axis=-1)
    hidden = jax.nn.gelu(jnp.dot(x2_norm, w_ff1))
    ffn_out = jnp.dot(hidden, w_ff2) + out  # residual
    return ffn_out
# What XLA does to this code:
# 1. Operation fusion: LayerNorm (5 ops) -> single kernel
# 2. Layout optimization: pick memory layout for Systolic Array
# 3. Rematerialization: recompute activations vs store (memory/compute tradeoff)
# 4. Auto-sharding: split across TPU chips with minimal communication
# 5. Constant folding: pre-compute anything computable at compile time
XLA Operation Fusion: Concrete Example
Without fusion (naive):
LayerNorm breakdown:
Step 1: mean(x) -> write to HBM
Step 2: x - mean -> write to HBM
Step 3: variance(x - mean) -> write to HBM
Step 4: normalize -> write to HBM
Step 5: scale * x + bias -> write to HBM
Total: 5 HBM round trips per layer
With XLA fusion:
LayerNorm: single kernel
- Read input once from HBM
- Do all 5 computations in SRAM (on-chip)
- Write output once to HBM
Total: 1 HBM round trip per layer
Impact: 5x reduction in memory bandwidth usage
XLA applies hundreds of such fusion patterns automatically
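You can see the unfused starting point yourself by tracing a hand-written LayerNorm with jax.make_jaxpr — the jaxpr lists every primitive separately, and jax.jit hands that graph to XLA, which fuses it into a single kernel (a small sketch; on CPU the same fusion machinery applies):

```python
import jax
import jax.numpy as jnp

def layer_norm(x, scale, bias, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)               # step 1
    centered = x - mean                                 # step 2
    var = (centered ** 2).mean(axis=-1, keepdims=True)  # step 3
    normed = centered / jnp.sqrt(var + eps)             # step 4
    return scale * normed + bias                        # step 5

x = jnp.arange(8.0).reshape(1, 8)
scale, bias = jnp.ones(8), jnp.zeros(8)

# Every primitive (reduce, sub, mul, rsqrt, ...) appears as a separate op here;
# XLA's fusion pass collapses them into one kernel with one HBM round trip
print(jax.make_jaxpr(layer_norm)(x, scale, bias))
fused = jax.jit(layer_norm)   # compiled via XLA with fusion applied
print(fused(x, scale, bias))
```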
8. Running LLM Inference on a TPU Pod
Distributed Inference with JAX
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from functools import partial
import numpy as np

# Check available TPU devices
print(f"Available TPUs: {jax.devices()}")
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), ...]

# Create 8-TPU mesh for tensor parallelism
num_devices = 8  # one TPU v4 chip = 2 TensorCores; 4 chips = 8 cores
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(device_mesh, axis_names=('model',))

# Define sharding strategies
# Weight matrix W: shard along columns (model-parallel)
W_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
# Bias: replicated on all devices
bias_sharding = NamedSharding(mesh, PartitionSpec(None))
# Activations: replicated (all devices see same input)
act_sharding = NamedSharding(mesh, PartitionSpec(None, None))

@partial(jax.jit,
         in_shardings=(act_sharding, W_sharding, bias_sharding),
         out_shardings=act_sharding)
def sharded_linear(x, W, b):
    """Column-parallel linear layer across 8 TPUs"""
    # Each TPU computes its own [batch, seq, d_ffn/8] slice of the output;
    # XLA inserts the collective needed to reassemble the replicated result
    local_out = jnp.dot(x, W)
    return local_out + b

# Sketch of an inference pipeline for a 7B-class model
def create_model_params(d_model=4096, d_ffn=16384, n_layers=32):
    """Create FFN params sharded across 8 TPUs"""
    params = {}
    for i in range(n_layers):
        # Each layer's FFN weights sharded across devices
        params[f'layer_{i}_ff1'] = jax.device_put(
            np.random.randn(d_model, d_ffn).astype(np.float16),
            W_sharding
        )
        params[f'layer_{i}_ff2'] = jax.device_put(
            np.random.randn(d_ffn, d_model).astype(np.float16),
            W_sharding
        )
    return params

# Inference: generate one token
# (embedding_lookup and attention_with_cache are assumed to be defined elsewhere)
def generate_next_token(params, input_ids, kv_cache):
    batch_size, seq_len = input_ids.shape
    # Run through all layers
    x = embedding_lookup(params['embed'], input_ids)  # [batch, seq, d_model]
    for layer_idx in range(32):
        # Attention with cached KV
        x = attention_with_cache(x, params, kv_cache, layer_idx)
        # FFN (sharded linear layers)
        with mesh:
            ffn_hidden = sharded_linear(
                x,
                params[f'layer_{layer_idx}_ff1'],
                params[f'bias_{layer_idx}_ff1']
            )
            x = x + sharded_linear(
                jax.nn.gelu(ffn_hidden),
                params[f'layer_{layer_idx}_ff2'],
                params[f'bias_{layer_idx}_ff2']
            )
    # Final logits
    logits = jnp.dot(x[:, -1, :], params['lm_head'])
    return logits.argmax(axis=-1)
Real-World Scale: Training PaLM on TPU Pods
PaLM 540B Training Configuration (Google, 2022):
- Hardware: TPU v4 Pod x 6,144 chips
- Total compute: 6,144 x 275 TFLOPS = 1.69 ExaFLOPS
- Training data: 780 billion tokens
- Training time: ~57 days
- Batch size: 2,048 sequences x 2,048 tokens
- MFU: 46.2% (Model FLOPS Utilization)
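These figures are internally consistent. Using the standard ~6ND FLOPs approximation for training (an estimate on my part, not Google's published accounting), the pure-compute time lands in the same ballpark as the reported wall-clock:

```python
n_params = 540e9
n_tokens = 780e9
train_flops = 6 * n_params * n_tokens     # ~2.5e24 FLOPs total (6ND rule of thumb)
peak = 6144 * 275e12                      # 1.69 ExaFLOPS peak across the pods
effective = peak * 0.462                  # at the reported 46.2% MFU
days = train_flops / effective / 86400
print(f"{train_flops:.2e} FLOPs -> ~{days:.0f} days of pure compute")
# ~37 days of pure compute; the ~57-day wall-clock also covers restarts, evals, etc.
```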
Gemini Ultra Training (Google DeepMind, 2023):
- Hardware: Multiple TPU v5p pods via OCS
- First model to outperform human experts on MMLU
- Context length training: up to 32,768 tokens
Why TPU Pods beat GPU clusters for this scale:
1. 3D torus: O(N^1/3) hop distance vs O(N) for ring-allreduce
2. OCS: full bandwidth between any chip pair (no switch bottleneck)
3. XLA: whole-model compilation, global optimization
4. bfloat16: stable training without loss scaling tricks
9. Performance Numbers: TPU vs The Competition
Llama 2 70B Inference Benchmark (batch_size=1, float16):
Hardware | Mem Bandwidth | Throughput | Latency
------------------|---------------|-------------|----------
NVIDIA H100 SXM | 3.35 TB/s | ~120 tok/s | ~8.3 ms/tok
NVIDIA A100 80GB | 2.0 TB/s | ~72 tok/s | ~13.9 ms/tok
TPU v5p (4-chip) | 4x 1.2 TB/s | ~160 tok/s | ~6.3 ms/tok
AMD MI300X | 5.3 TB/s | ~190 tok/s | ~5.3 ms/tok
Apple M3 Ultra | 800 GB/s | ~55 tok/s | ~18.2 ms/tok
Key insight: throughput correlates almost perfectly with memory bandwidth!
This is the "memory-bound" nature of LLM inference — more in the NPU post.
TPU advantages at scale (beyond single-chip numbers):
- TPU Pod AllReduce: 600 GB/s per chip interconnect
- GPU NVLink: 900 GB/s (H100), but only within NVSwitch domain
- TPU OCS: reconfigurable topology for different workloads
- Cost: TPU v5p ~$2.20/hr vs H100 ~$3.50/hr on comparable clouds
10. Practical Profiling and Optimization
# Profile your JAX/TPU code to find bottlenecks
import jax
import jax.numpy as jnp

# Method 1: JAX Profiler (produces Chrome Trace format)
with jax.profiler.trace("/tmp/jax_trace", create_perfetto_link=True):
    # Warmup (needed to exclude JIT compilation time)
    _ = my_model_fn(sample_input).block_until_ready()
    # Actual profiled run
    result = my_model_fn(real_input).block_until_ready()

# Open in Perfetto UI: https://ui.perfetto.dev
# Look for:
#   - Long gaps between ops (memory bandwidth bound)
#   - Unbalanced ops across devices (sharding inefficiency)
#   - Frequent recompilation (dynamic shapes)

# Method 2: Measure Model FLOPS Utilization (MFU)
def compute_mfu(model_flops, elapsed_seconds, peak_tflops):
    """MFU = actual FLOPS / theoretical peak FLOPS"""
    actual_tflops = model_flops / elapsed_seconds / 1e12
    return actual_tflops / peak_tflops * 100

# Example: 7B model, 100 tokens/s throughput
flops_per_token = 2 * 7e9   # 2 * num_params (rough estimate)
elapsed = 1.0 / 100         # 1 sec / 100 tokens = 0.01 sec/token
tpu_v5p_tflops = 459
mfu = compute_mfu(flops_per_token, elapsed, tpu_v5p_tflops)
print(f"MFU: {mfu:.1f}%")
# Good MFU: >40% means you're using the hardware well
# Low MFU: memory bandwidth bound (typical for inference)

# Method 3: Identify recompilation events
# Add logging to catch shape changes that trigger recompiles
jax.config.update("jax_log_compiles", True)
# Check the logs for "Compiling..." messages during inference
Conclusion
The Systolic Array is a perfect demonstration of the principle: specialization beats generalization.
An idea from 1978 became the engine powering Google's data centers in 2015, and today runs the world's most capable AI systems — Gemini, PaLM, and the infrastructure behind Google Search, Gmail, and Google Assistant.
The key lessons:
- Domain specialization pays: 95% of transformer compute is GEMM. Hardware built for GEMM dominates everything else.
- Data reuse is the key: The Systolic Array reads data once and extracts maximum computation from it — this is the fundamental insight.
- Number formats matter: bfloat16 hit the sweet spot of "accurate enough while fast" — same dynamic range as FP32, half the bandwidth.
- Compiler-hardware co-design: Without XLA's fusion, layout, and sharding optimizations, much of the TPU's peak throughput would sit idle. The compiler IS part of the hardware.
- Scale requires topology: The 3D torus interconnect in TPU Pods is what makes ExaFLOP-scale training possible.
In the next post, we dive into NPUs — the chips embedded in your phone and laptop — and explain how they implement these same principles in a 1-5 watt power envelope.
References
- Jouppi et al., "In-Datacenter Performance Analysis of a Tensor Processing Unit" (ISCA 2017)
- Google TPU Research Cloud: cloud.google.com/tpu
- JAX Documentation: jax.readthedocs.io
- "PaLM: Scaling Language Modeling with Pathways" (Chowdhery et al., 2022)
- "Gemini: A Family of Highly Capable Multimodal Models" (Google DeepMind, 2023)
- H.T. Kung & C.E. Leiserson, "Systolic Arrays for VLSI" (1978)