Skip to content
Published on

Google TPU 완전 해부: Systolic Array가 행렬 곱셈을 어떻게 완벽히 해결하는가

Authors

들어가며: 구글은 왜 직접 칩을 만들었는가

2013년 구글 엔지니어들이 충격적인 계산을 했습니다. 만약 모든 Gmail 사용자가 하루 3분씩 음성 검색을 사용하면, 기존 서버 인프라의 두 배가 필요하다는 것이었습니다. 음성 인식에 사용되는 딥 뉴럴 네트워크(DNN)의 연산 비용이 문제였습니다.

Jeff Dean을 포함한 구글 팀은 결론을 내렸습니다: 범용 CPU와 GPU로는 한계가 있다. 뉴럴 네트워크 추론에 특화된 칩을 만들자.

그 결과가 Tensor Processing Unit (TPU) 입니다. 2015년 구글 데이터센터에 처음 배포되었고, 2017년 ISCA 논문으로 세상에 공개되었습니다. Norman Jouppi가 이끈 하드웨어 팀의 핵심 인사이트는 단순했습니다:

"뉴럴 네트워크 추론에는 완전한 IEEE 부동소수점이 필요 없다. INT8이면 충분하다."

이 단순한 인사이트가 경쟁 GPU 대비 10배 이상의 성능/전력 효율을 가능하게 했습니다.


1. Systolic Array: 심장이 뛰듯 데이터가 흐른다

이름의 유래

Systolic Array라는 이름은 심장의 수축기(systolic) 에서 왔습니다. 심장이 규칙적으로 수축하며 혈액을 펌핑하듯, Systolic Array는 규칙적인 클럭 사이클마다 데이터를 흘려보냅니다. H.T. Kung과 Charles Leiserson이 1978년에 제안한 개념을 구글이 뉴럴 네트워크에 맞게 대규모로 구현했습니다.

Systolic Array의 기본 구조

4x4 Systolic Array로 행렬 곱셈 A × B = C를 수행하는 과정을 단계별로 살펴봅시다.

Systolic Array: 4x4 MAC (Multiply-Accumulate) 유닛 배열
각 셀: 왼쪽(A)(B)에서 입력을 받아 곱하고, 누산기에 더함

시간 0단계: 데이터 로드 준비
     B[0,0]  B[1,0]  B[2,0]  B[3,0]B 행렬 열이 아래로 흐름
       ↓       ↓       ↓       ↓
A[0,0][MAC][MAC][MAC][MAC]
A[1,0][MAC][MAC][MAC][MAC]
A[2,0][MAC][MAC][MAC][MAC]
A[3,0][MAC][MAC][MAC][MAC]
        ↓     ↓     ↓     ↓
      C0   C1   C2   C3   ← 결과가 아래로 출력

시간 1단계: A[0,0]MAC[0,0]에 도달 → A[0,0]*B[0,0] 계산
시간 2단계: A[0,0]MAC[0,1]으로 이동, A[0,1]MAC[0,0]에 진입
            각 셀이 부분 곱을 누산

핵심 인사이트:
-MAC 유닛은 매 클럭 사이클마다 동작 (유휴 시간 없음)
- 연산 중 메모리 읽기 없음 (데이터가 이미 "비행 중")
- TPU v1: 256×256 = 65,536개의 MAC이 동시에 동작!

기존 방식과의 차이

일반적인 프로세서에서 행렬 곱셈을 할 때:

일반 프로세서 방식:
for i in range(N):
    for j in range(N):
        for k in range(N):
            C[i][j] += A[i][k] * B[k][j]  # 매번 메모리에서 A, B 읽기

문제점: 메모리 접근이 O(N³) 번 발생
       캐시 미스가 병목이 됨
       실제 계산보다 데이터 이동 시간이 더 걸림

Systolic Array는 이 문제를 근본적으로 해결합니다:

Systolic Array 방식:
- A의 각 원소: 한 번 입력 → 전체 행을 통과하며 모든 곱을 기여
- B의 각 원소: 한 번 입력 → 전체 열을 통과하며 모든 곱을 기여
-MAC 유닛: 데이터를 기다리지 않고, 흘러오는 데이터를 즉시 처리

결과: 메모리 접근 횟수를 O(N²)으로 줄임 (N이 크면 엄청난 차이!)

2. TPU v1 스펙 완전 분석

2015년 배포된 TPU v1의 실제 스펙을 분석해봅시다.

TPU v1 하드웨어 스펙:
┌─────────────────────────────────────────────────────┐
Systolic Array: 256 × 256 = 65,536 MAC 유닛         │
│ 데이터 형식: INT8 (가중치), INT16 (어큐뮬레이터)│ 클럭 속도: 700 MHz│ 온칩 메모리: 28 MB (Weight FIFO + Unified Buffer)│ 메모리 대역폭: 30 GB/s (DDR3)│ 전력 소비: 40W                                       │
│ 패키지: 28nm CMOS, PCIe 카드 형태                   │
└─────────────────────────────────────────────────────┘

성능 계산:
65,536 MAC × 700 MHz × 2 (곱셈+덧셈) = 92 TOPS (INT8)

비교:
- TPU v1: 92 TOPS @ 40W = 2.3 TOPS/W
- NVIDIA K80 GPU: 8.7 TOPS @ 300W = 0.029 TOPS/W

TPU v1이 K80 대비 에너지 효율 79!
→ 성능 자체도 10.6배 높음

이 숫자들은 Google의 ISCA 2017 논문에서 검증된 실측값입니다. 6개의 서로 다른 뉴럴 네트워크 워크로드에서 TPU v1은 GPU 대비 평균 15~30배 빠른 추론 성능을 보였습니다.


3. GPU vs TPU: 철학의 차이

GPU 설계 철학: "나는 모든 것을 할 수 있다"
┌─────────────────────────────────────────────┐
│ 수천 개의 범용 CUDA 코어                    │
│ 유연한 프로그래밍 모델 (CUDA)│ 복잡한 분기/조건문 지원                     │
│ 임의 메모리 접근 패턴                       │
│ 그래픽, 물리 시뮬레이션, AI 모두 가능       │
│ 오버헤드: 스케줄러, 레지스터 파일, 캐시 계층 │
└─────────────────────────────────────────────┘

TPU 설계 철학: "나는 행렬 곱셈만 한다, 하지만 완벽하게"
┌─────────────────────────────────────────────┐
Systolic Array (행렬 곱셈 전용 하드웨어)│ 예측 가능한 데이터 흐름 (컴파일 타임 결정)│ 제한된 연산 집합 (MAC 위주)│ 복잡한 제어 로직 없음                       │
│ 낭비 없는 100% 하드웨어 활용                │
│ 에너지 효율 극대화                          │
└─────────────────────────────────────────────┘

왜 이 전략이 먹히는가?
Transformer 연산의 95% 이상이 행렬 곱셈 (GEMM)
→ 특화 설계가 범용 설계를 압도

트랜스포머 모델에서 실제로 어떤 연산이 얼마나 발생하는지 측정해봅시다:

# GPT-2 (124M params) 연산 분석
import torch
from torch.profiler import profile, ProfilerActivity

# Attention + FFN의 FLOPS 분포
layer_config = {
    'd_model': 768,
    'n_heads': 12,
    'n_layers': 12,
    'seq_len': 1024
}

d = layer_config['d_model']
n = layer_config['seq_len']
L = layer_config['n_layers']

# Q, K, V projection: 각 3개의 [d×d] 행렬 곱셈
qkv_flops = 3 * 2 * n * d * d * L  # factor of 2 for multiply+add

# Attention 계산: Q×K^T + softmax + ×V
attn_flops = (2 * n * n * d + 2 * n * n * d) * L

# FFN: 2개의 [d×4d] 행렬 곱셈
ffn_flops = 2 * 2 * n * d * (4 * d) * L

total = qkv_flops + attn_flops + ffn_flops
print(f"QKV projection: {qkv_flops/total*100:.1f}%")  # ~47%
print(f"Attention:       {attn_flops/total*100:.1f}%")  # ~12%
print(f"FFN:             {ffn_flops/total*100:.1f}%")  # ~41%
# GEMM 계열: 약 88-95%

4. TPU 메모리 계층 구조

TPU의 메모리 설계는 Systolic Array의 데이터 흐름에 최적화되어 있습니다.

TPU v4 메모리 계층:
┌─────────────────────────────────────────────────────────┐
Weight FIFO (온칩, Systolic Array 직접 공급)│          크기: 32 MB / 대역폭: 2.7 TB/s                  │
│          목적: 가중치를 Systolic Array에 끊임없이 공급   │
├─────────────────────────────────────────────────────────┤
Unified Buffer (온칩, 활성화 값 저장)│          크기: 256 MB / 대역폭: 900 GB/s                 │
│          목적: 레이어 간 중간 활성화 값 저장             │
├─────────────────────────────────────────────────────────┤
High Bandwidth Memory (HBM, 오프칩)│          크기: 32 GB / 대역폭: 1.2 TB/s                  │
│          목적: 모델 가중치 전체 저장                     │
└─────────────────────────────────────────────────────────┘

데이터 흐름:
HBMWeight FIFOSystolic Array (가중치 흐름)
HBMUnified BufferSystolic Array (활성화 흐름)
Systolic ArrayUnified Buffer → 다음 레이어

Weight Stationary vs Output Stationary Dataflow

Systolic Array를 사용하는 방식에는 두 가지 전략이 있습니다:

Weight Stationary (가중치 고정):
- 가중치(W)Systolic Array의 각 MAC에 미리 로드
- 입력 활성화(X)를 흘려보냄
- 결과: W × X를 계산
- 장점: 가중치 재사용 최대화 (배치 크기가 클 때 유리)
- TPU v1이 사용하는 방식

Output Stationary (출력 고정):
- 출력 C[i][j]를 해당 MAC에 누산
- 입력 AB를 모두 흘려보냄
- 장점: 출력 쓰기 최소화 (작은 배치에 유리)

선택 기준:
- 추론 (배치 크기 작음): Output Stationary 유리
- 학습 (배치 크기 큼): Weight Stationary 유리

5. bfloat16: Google이 만든 수 형식

TPU v2부터 도입된 bfloat16은 현재 딥러닝의 표준 데이터 형식이 되었습니다.

수 형식 비교:

FP32:    [1 부호][8 지수][23 가수] = 32비트
         범위: ±3.4 × 10^38
         정밀도:7자리

FP16:    [1 부호][5 지수][10 가수] = 16비트
         범위: ±6.5 × 10^4 (너무 좁음!)
         정밀도:3-4자리
         문제: 기울기 소실/폭발 위험

bfloat16:[1 부호][8 지수][7 가수]  = 16비트  ← Google 혁신
         범위: ±3.4 × 10^38 (FP32와 동일!)
         정밀도:2-3자리
         장점: FP32 상위 16비트만 떼어내면 됨

왜 딥러닝에 bfloat16이 좋은가:
1. FP32와 같은 지수 범위 → 오버플로/언더플로 없음
2. 가수 정밀도 손실 → 딥러닝은 정밀도보다 범위가 중요
3. FP32 → bfloat16 변환: 마지막 16비트 버리면  (하드웨어 비용 극소)
4. 혼합 정밀도 학습: FP32 마스터 가중치 + bfloat16 연산

실제 bfloat16 사용 예:

import jax
import jax.numpy as jnp

# bfloat16으로 행렬 연산
A = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
B = jnp.array([[5.0, 6.0], [7.0, 8.0]], dtype=jnp.bfloat16)

# TPU에서 bfloat16 행렬 곱셈
C = jnp.dot(A, B)  # 자동으로 TPU Systolic Array 활용
print(f"Result dtype: {C.dtype}")   # bfloat16
print(f"Result:\n{C}")

# FP32 vs bfloat16 정밀도 비교
x_fp32 = jnp.array(1.0 / 3.0, dtype=jnp.float32)
x_bf16 = jnp.array(1.0 / 3.0, dtype=jnp.bfloat16)
print(f"FP32:    {float(x_fp32):.10f}")   # 0.3333333433
print(f"BF16:    {float(x_bf16):.10f}")   # 0.3320312500 (약간 덜 정확)
# 딥러닝에서 이 정도 차이는 무시 가능

6. TPU 버전 역사: 4세대의 진화

버전연도핵심 혁신온칩 메모리성능
TPU v12015INT8 추론, 256×256 Systolic Array28MB92 TOPS
TPU v22017bfloat16 학습, HBM 도입8GB HBM45 TFLOPS
TPU v32018수냉 쿨링, v2 대비 2배 성능16GB HBM90 TFLOPS
TPU v420213D 토러스 인터커넥트, OCS32GB HBM275 TFLOPS
TPU v5p2023최대 포드 규모, 트랜스포머 최적화96GB HBM459 TFLOPS

TPU v4의 3D 토러스 인터커넥트

TPU v4 Pod 토폴로지:

TPU v4 칩은 6개의 방향으로 직접 연결됨:
+X, -X, +Y, -Y, +Z, -Z (3차원 토러스)

4096TPU = 16 × 16 × 16 = 4096 노드 3D 토러스
칩당 인터커넥트 대역폭: 600 GB/s

왜 3D 토러스인가?
- 임의 두 노드 간 홉 수: O(N^(1/3))
- 4096개 노드에서 최대 홉 수: 24 (vs 링 구조의 2048)
- 집합 통신 (AllReduce) 효율 극대화

OCS (Optical Circuit Switch):
- 광 스위치로 토폴로지를 동적으로 재구성
- 다른 워크로드에 맞게 네트워크 구조 변경 가능
- 대역폭 낭비 없는 최적 경로

7. XLA: TPU를 위한 컴파일러

XLA (Accelerated Linear Algebra)는 TPU의 소프트웨어 스택의 핵심입니다. TensorFlow/JAX 코드를 TPU에 최적화된 기계어로 변환합니다.

# JAX 코드 (XLA → TPU로 JIT 컴파일):
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit  # XLA로 JIT 컴파일
def transformer_layer(x, w_q, w_k, w_v, w_o, w_ff1, w_ff2):
    """단순화된 트랜스포머 레이어"""
    batch, seq, d_model = x.shape

    # Multi-head attention (Q, K, V 투영)
    q = jnp.dot(x, w_q)  # GEMM → Systolic Array
    k = jnp.dot(x, w_k)  # GEMM → Systolic Array
    v = jnp.dot(x, w_v)  # GEMM → Systolic Array

    # Attention 점수 계산
    scale = jnp.sqrt(d_model // 8).astype(jnp.float32)
    scores = jnp.einsum('bqd,bkd->bqk', q, k) / scale
    attn = jax.nn.softmax(scores, axis=-1)

    # 가중합
    out = jnp.einsum('bqk,bkd->bqd', attn, v)
    out = jnp.dot(out, w_o)  # GEMM

    # FFN
    hidden = jax.nn.gelu(jnp.dot(out, w_ff1))  # GEMM
    result = jnp.dot(hidden, w_ff2)  # GEMM

    return result

# XLA가 수행하는 최적화:
# 1. 연산 퓨전: multiple ops → 단일 커널
# 2. 레이아웃 최적화: Systolic Array에 최적 메모리 배치
# 3. 재물질화: 중간 활성화 재계산 vs 저장 트레이드오프
# 4. 자동 샤딩: 여러 TPU 코어에 자동 분할

XLA의 연산 퓨전 (Operation Fusion)

퓨전 전:
LayerNorm: Mean 계산 → Variance 계산 → NormalizeScaleBias
각 단계마다 HBM에 쓰고 읽음 = 5번의 메모리 왕복

퓨전 후:
LayerNorm: 단일 커널로 모든 계산 수행
메모리 왕복: 1 (입력 읽기 + 출력 쓰기)

결과: 메모리 대역폭 절약 → 5배 빠른 LayerNorm
XLA는 수백 가지 이런 퓨전 패턴을 자동 적용

8. TPU Pod에서 LLM 서빙하기

실제 JAX 코드로 구현하는 분산 추론

import jax
import jax.numpy as jnp
from jax import devices
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import numpy as np

# TPU 디바이스 확인
print(f"사용 가능한 TPU: {jax.devices()}")
# 출력: [TpuDevice(id=0, ...), TpuDevice(id=1, ...), ...]

# 8개 TPU 코어에 모델 샤딩
num_devices = len(jax.devices())
devices_array = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices_array, axis_names=('model',))

# 샤딩 전략 정의
# 가중치 행렬: 'model' 축 방향으로 열 분할
weight_sharding = NamedSharding(mesh, PartitionSpec('model', None))
# 활성화: 복제 (모든 디바이스에 동일)
activation_sharding = NamedSharding(mesh, PartitionSpec(None, None))

def shard_weights(weights):
    """모델 가중치를 TPU들에 분산"""
    return jax.device_put(weights, weight_sharding)

# 실제 추론 함수
@partial(jax.jit, in_shardings=(activation_sharding, weight_sharding),
         out_shardings=activation_sharding)
def linear_layer(x, w):
    """샤딩된 선형 레이어"""
    return jnp.dot(x, w)

# 예시: 7B 모델을 8개 TPU에 분산
d_model = 4096
d_ffn = 16384  # 4x expansion
batch_size = 8
seq_len = 2048

# 가중치 생성 및 샤딩
w_ffn = np.random.randn(d_model, d_ffn).astype(np.float16)
w_ffn_sharded = shard_weights(w_ffn)

# 입력 생성
x = np.random.randn(batch_size, seq_len, d_model).astype(np.float16)
x_batched = jax.device_put(x, activation_sharding)

# 분산 추론 실행
with mesh:
    output = linear_layer(x_batched, w_ffn_sharded)
    # 8개 TPU가 자동으로 협력하여 계산!

TPU Pod에서 LLM 학습: 실제 규모

PaLM 540B 학습 설정 (Google, 2022):
- 하드웨어: TPU v4 Pod 6144- 총 연산 능력: 6144 × 275 TFLOPS = 1.69 ExaFLOPS
- 학습 토큰: 780 billion tokens
- 학습 시간:57- 배치 크기: 2048 sequences × 2048 tokens

Gemini Ultra 학습 (Google DeepMind, 2023):
- TPU v5p 사용
- 다중 포드를 OCS로 연결
- 최대 규모의 트랜스포머 학습

Gemini 1.5 Pro:
- 1M 토큰 컨텍스트 처리
- TPU Pod의 3D 토러스로 분산 KV 캐시

9. 실전: JAX로 TPU에서 LLM 추론 최적화

# TPU에서 효율적인 추론을 위한 실전 팁

import jax
import jax.numpy as jnp
from jax import lax
import functools

# 팁 1: vmap으로 배치 처리 벡터화
@jax.vmap  # 배치 차원을 자동으로 벡터화
def process_single_example(x):
    return some_model_fn(x)

# 여러 입력을 한 번에 처리
batched_results = process_single_example(batch_of_inputs)

# 팁 2: scan으로 레이어 반복 효율화
def transformer_layer_fn(carry, x):
    """단일 트랜스포머 레이어"""
    return carry, apply_layer(carry, x)

# N개 레이어를 메모리 효율적으로 처리
final_state, outputs = lax.scan(
    transformer_layer_fn,
    init_state,
    layer_inputs
)

# 팁 3: 정적 형태(static shapes)로 XLA 재컴파일 방지
@functools.partial(jax.jit, static_argnames=['seq_len'])
def generate_tokens(model_params, prompt_ids, seq_len: int):
    """seq_len이 같으면 재컴파일 없음"""
    # ... 토큰 생성 로직
    pass

# 팁 4: 프로파일링으로 병목 찾기
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    result = my_model_fn(inputs)
# Google Chrome Tracing으로 병목 시각화

# 팁 5: 그래디언트 체크포인팅으로 메모리 절약
from jax.checkpoint import checkpoint

# 순방향 패스에서 일부 활성화만 저장
checkpointed_layer = checkpoint(transformer_layer_fn)

10. TPU vs 경쟁자: 실제 LLM 서빙 벤치마크

Llama 2 70B 추론 벤치마크 (배치 크기 1, FP16):

하드웨어           | 메모리 대역폭 | 처리량 (tok/s) | 지연시간 (ms/tok)
------------------|-------------|--------------|----------------
NVIDIA H100 SXM   | 3.35 TB/s   | ~120         | ~8.3
NVIDIA A100 80GB  | 2.0 TB/s    | ~72          | ~13.9
TPU v5p (4)     | 4× 1.2TB/s  | ~160         | ~6.3
AMD MI300X        | 5.3 TB/s    | ~190         | ~5.3
Apple M3 Ultra    | 800 GB/s    | ~55          | ~18.2

결론:
- TPU는 메모리 대역폭 대비 효율 최고
- 클라우드 규모에서는 TPU Pod의 집합 통신 효율이 추가 이점
- 온디바이스에서는 Apple Silicon의 통합 메모리 구조 유리

마치며

Systolic Array는 "특화가 범용을 이긴다"는 원리의 완벽한 증명입니다. 1978년에 제안된 아이디어가 2015년 구글의 데이터센터에서 부활하여, 현재 전 세계 AI 인프라의 핵심을 이루고 있습니다.

핵심 교훈:

  1. 도메인 특화: 트랜스포머의 95%가 GEMM이라는 사실을 활용하면 전용 하드웨어로 압도적 효율 달성
  2. 데이터 재사용: Systolic Array의 핵심은 데이터를 한 번 읽어 최대한 재사용하는 것
  3. 수 형식의 중요성: bfloat16은 "충분히 정확하면서 빠른" 트레이드오프의 승자
  4. 컴파일러와 하드웨어 협력: XLA 없이는 TPU의 성능을 끌어내기 어렵다

다음 편에서는 스마트폰과 PC에 내장된 NPU가 이 아이디어들을 어떻게 에너지 효율적으로 구현하는지 살펴보겠습니다.


참고 자료

  • Jouppi et al., "In-Datacenter Performance Analysis of a Tensor Processing Unit" (ISCA 2017)
  • Google TPU Research Cloud: cloud.google.com/tpu
  • JAX Documentation: jax.readthedocs.io
  • "Scaling Language Models: Methods, Analysis & Insights from Training Gopher" (DeepMind, 2021)
  • PaLM: "Scaling Language Modeling with Pathways" (Google, 2022)