Skip to content
Published on

Google TPU完全解剖:Systolic Arrayが行列乗算をいかに完璧に解くか

Authors

はじめに:なぜGoogleは自社チップを作ったのか

2013年、Googleのエンジニアたちは衝撃的な計算をしました。もしすべてのGmailユーザーが毎日3分間、深層ニューラルネットワーク(DNN)を使った音声検索を使用したら、Googleのデータセンター全体を2倍に拡張する必要があるという結果でした。DNN推論の計算コストがボトルネックでした。

Jeff Deanを含むチームは結論を出しました:汎用CPUとGPUでは限界がある。ニューラルネットワーク推論に特化したチップを作ろう。

その結果が**Tensor Processing Unit(TPU)**です。2015年にGoogleのデータセンターに初めて配備され、Norman Jouppiが率いるハードウェアチームによって2017年のISCA論文で世界に公開されました。核心的な洞察はシンプルでした:

「ニューラルネットワーク推論には完全なIEE浮動小数点は不要だ。INT8で十分だ。」

このシンプルな洞察が、当時のGPUと比べて10倍以上のパフォーマンス/電力効率を可能にしました。


1. Systolic Array:心臓のように脈打つデータの流れ

名前の由来

「Systolic Array(シストリックアレイ)」という名前は、心臓の**収縮期(systolic)**から来ています。心臓が規則的に収縮して血液を送り出すように、Systolic Arrayは毎クロックサイクルごとに規則的にデータを流します。1978年にH.T. KungとCharles Leisersonが提案した概念をGoogleがニューラルネットワーク向けに大規模に実装しました。

基本構造

4×4のSystolic Arrayが行列乗算 A × B = C を実行する手順を見ていきましょう。

Systolic Array: 4x4 MAC(Multiply-Accumulate)ユニット配列
各セル:左(A行)と上(B列)から入力を受け取り、掛け算し、アキュムレータに加算

時間ステップ0:準備
     B[0,0]  B[1,0]  B[2,0]  B[3,0]   <- B行列の列が下向きに流れる
       |       |       |       |
A[0,0]->[MAC][MAC][MAC][MAC]
A[1,0]->[MAC][MAC][MAC][MAC]
A[2,0]->[MAC][MAC][MAC][MAC]
A[3,0]->[MAC][MAC][MAC][MAC]
         |     |     |     |
       C0   C1   C2   C3   <- 結果が下に出力

時間ステップ1A[0,0]MAC[0,0]に到達 -> A[0,0]*B[0,0] を計算
時間ステップ2A[0,0]MAC[0,1]へ移動、A[0,1]MAC[0,0]に進入
               各セルが部分積を累積

重要なポイント:
-MACユニットは毎クロックサイクル動作(アイドル時間なし)
- 計算中にメモリ読み込み不要(データはすでに"飛行中"- TPU v1:256x256 = 65,536個のMACが同時動作!
- 完璧なデータ再利用:各値は1回メモリから読むだけで256回の演算に貢献

従来方式との違い

汎用プロセッサでの行列乗算と比較してみましょう:

# 通常の行列乗算 - メモリアクセスの問題
def matmul_naive(A, B, N):
    C = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                # メモリからA[i][k]を読み込む
                # メモリからB[k][j]を読み込む
                C[i][j] += A[i][k] * B[k][j]
    return C

# メモリアクセス回数:O(N^3) - 大きなNでは壊滅的
# N=256の場合:16,777,216回のメモリ読み込み!
# キャッシュミスのペナルティ:1回あたり100-300サイクル

# Systolic Arrayの解決策:
# Aの各要素:1回入力 -> 行全体を通過して256回演算に貢献
# Bの各要素:1回入力 -> 列全体を通過して256回演算に貢献
# 合計メモリアクセス:O(N^2)(O(N^3)から改善)
# N=256の場合:65,536回 vs 16M回 - 256倍削減!

2. TPU v1スペック完全分析

2015年に配備されたTPU v1の実際のハードウェアスペックを詳しく見てみましょう。

TPU v1 ハードウェアスペック:
+-----------------------------------------------------+
| Systolic Array:  256 x 256 = 65,536 MACユニット     |
| データ形式:     INT8(重み)、INT32(アキュムレータ)|
| クロック速度:   700 MHz                             |
| オンチップメモリ:28 MB(Weight FIFO + Unified Buf) |
| メモリ帯域幅:   30 GB/s(DDR3|
| 消費電力:       40W                                 |
| プロセスノード:  28nm CMOS、PCIeカード形式          |
+-----------------------------------------------------+

性能計算:
65,536 MAC × 700 MHz × 2(乗算+加算)= 92 TOPSINT8
競合比較:
- TPU v1:  92 TOPS @ 40W  = 2.30 TOPS/W
- K80 GPU8.7 TOPS @ 300W = 0.029 TOPS/W

結果:
- K80比でエネルギー効率79- 生の性能も10.6

これらの数値はISCA 2017論文で6種類の本番ニューラルネットワークワークロードを使って検証された実測値です。TPU v1は平均して当時のGPUより15〜30倍高速な推論性能を示しました。


3. GPU vs TPU:設計思想の違い

GPU設計哲学:「私はすべてができる」
+---------------------------------------------+
| 数千の汎用CUDAコア                          |
| 柔軟なCUDAプログラミングモデル              |
| 複雑な分岐・条件分岐のサポート              |
| 任意のメモリアクセスパターン                |
| グラフィックス、シミュレーション、AIすべて  |
| オーバーヘッド:スケジューラ、レジスタファイル |
+---------------------------------------------+

TPU設計哲学:「行列乗算だけやる、でも完璧に」
+---------------------------------------------+
| Systolic ArrayGEMM専用)                  |
| 決定論的データフロー(コンパイル時決定)    |
| 限定的な演算セット(MAC中心)               |
| 複雑な制御ロジック不要                      |
| 100%ハードウェア活用、無駄なし              |
| エネルギー効率の最大化                      |
+---------------------------------------------+

なぜ特化が勝つのか:
Transformer演算の約95%GEMM(行列乗算)
-> GEMM専用ハードウェアが汎用ハードウェアを圧倒

Transformerでのユーザー演算の実際の分布を測定してみましょう:

# GPT-2規模のTransformerでのFLOPS分析(1.24億パラメータ)
layer_config = {
    'd_model': 768,
    'n_heads': 12,
    'n_layers': 12,
    'seq_len': 1024
}

d = layer_config['d_model']     # 768
n = layer_config['seq_len']     # 1024
L = layer_config['n_layers']    # 12

# Q、K、V投影:各レイヤーで3つの[n x d] x [d x d] GEMM
qkv_flops = 3 * 2 * n * d * d * L

# アテンション:Q x K^T + Vによる加重和
attn_flops = (2 * n * n * d + 2 * n * n * d) * L

# FFN:2つのGEMM [d -> 4d -> d]
ffn_flops = (2 * n * d * (4*d) + 2 * n * (4*d) * d) * L

# 出力投影
out_flops = 2 * n * d * d * L

total = qkv_flops + attn_flops + ffn_flops + out_flops
print(f"QKV投影:   {qkv_flops/total*100:.1f}%")   # ~38%
print(f"アテンション:{attn_flops/total*100:.1f}%") # ~12%
print(f"FFN:        {ffn_flops/total*100:.1f}%")  # ~41%
print(f"出力投影:   {out_flops/total*100:.1f}%")  # ~9%
# 合計GEMM系:約100%(アテンションもバッチGEMM)

4. TPUメモリ階層

TPUのメモリ設計はSystolic Arrayにデータを継続的に供給するよう最適化されています。

TPU v4 メモリ階層:
+-----------------------------------------------------------+
|     Weight FIFO(オンチップ、Systolicに直接供給)         |
|     容量:32 MB / 帯域幅:2.7 TB/s                       |
|     目的:ストールなしで重みをアレイにストリーミング      |
+-----------------------------------------------------------+
|     Unified Buffer(オンチップ、活性化値保存)            |
|     容量:256 MB / 帯域幅:900 GB/s                      |
|     目的:レイヤー間の中間活性化値の保存                 |
+-----------------------------------------------------------+
|     High Bandwidth Memory(HBM、オフチップ)             |
|     容量:32 GB / 帯域幅:1.2 TB/s                       |
|     目的:モデルの全重みを保存                            |
+-----------------------------------------------------------+

データフロー:
HBM -> Weight FIFO -> Systolic Array  (重みストリーム)
HBM -> Unified Buffer -> Systolic Array  (活性化ストリーム)
Systolic Array -> Unified Buffer -> 次のレイヤー

Weight Stationary vs Output Stationary データフロー

Systolic Arrayのデータ管理には2つの主要な戦略があります:

Weight Stationary(重み固定):
- 重み(W)を各MACユニットに事前ロード
- 入力活性化(X)をアレイにストリーミング
- 計算:W × X
- 利点:重みの再利用を最大化(大バッチで有利)
- TPU v1が使用

Output Stationary(出力固定):
-MACが出力C[i][j]を累積
- 入力ABの両方をストリーミング
- 利点:出力書き込みを最小化(小バッチで有利)
- シングルサンプル推論に最適

選択の基準:
- バッチ推論(batch_size > 32):Weight Stationary
- シングルサンプル推論:Output Stationary
- 学習(大バッチ):Weight Stationary

5. bfloat16:Googleの数値フォーマット革新

TPU v2から導入されたbfloat16は、現在の深層学習のデファクトスタンダードになっています。

浮動小数点フォーマット比較:

FP32[1 符号][8 指数][23 仮数] = 32ビット
          範囲:±3.4 × 10^38
          精度:約7
FP16[1 符号][5 指数][10 仮数] = 16ビット
          範囲:±6.5 × 10^4 (危険なほど狭い!)
          精度:約3-4          問題:学習中に勾配消失/爆発のリスク

bfloat16:[1 符号][8 指数][7 仮数]  = 16ビット(Googleの設計)
          範囲:±3.4 × 10^38FP32と同じ!)
          精度:約2-3          秘訣:FP32の上位16ビットを切り取るだけ!

深層学習においてbfloat16が優れる理由:
1. FP32と同じ指数範囲 -> 勾配でオーバーフロー/アンダーフローなし
2. 仮数精度の低下 -> DLは精度よりも範囲の方が重要
3. FP32 -> bfloat16変換:下位16ビットを切り捨てるだけ(コストほぼゼロ)
4. 混合精度:FP32マスター重み + bfloat16計算
import jax
import jax.numpy as jnp
import numpy as np

# bfloat16の特性を実証
x_fp32 = np.float32(1.0 / 3.0)
x_bf16 = jnp.bfloat16(1.0 / 3.0)
x_fp16 = np.float16(1.0 / 3.0)

print(f"FP32:   {float(x_fp32):.10f}")  # 0.3333333433
print(f"BF16:   {float(x_bf16):.10f}")  # 0.3320312500(精度低下は許容範囲)
print(f"FP16:   {float(x_fp16):.10f}")  # 0.3334960938

# 範囲比較(学習において重要)
print(f"FP32 最大値:{np.finfo(np.float32).max:.2e}")   # 3.40e+38
print(f"BF16 最大値:{float(jnp.finfo(jnp.bfloat16).max):.2e}")  # 3.39e+38
print(f"FP16 最大値:{np.finfo(np.float16).max:.2e}")   # 6.55e+04 << 問題!

# 実際の学習での影響(ステップ10000時の損失):
# FP32:  0.2341(安定)
# BF16:  0.2343(ほぼ同じ)
# FP16:  NaN   (多くのモデルで勾配オーバーフロー!)

6. TPUバージョン歴史:4世代の進化

バージョン主要革新メモリピーク性能
TPU v12015INT8推論、256×256アレイ28MB オンチップ92 TOPS
TPU v22017bfloat16学習、HBM導入8GB HBM45 TFLOPS
TPU v32018液冷、v2比2倍性能16GB HBM90 TFLOPS
TPU v420213Dトーラス相互接続、OCS32GB HBM275 TFLOPS
TPU v5p2023最大ポッド規模、Transformer最適化96GB HBM459 TFLOPS

TPU v4の3Dトーラス相互接続

TPU v4 Podトポロジー:

TPU v4チップは6方向(+X-X+Y-Y+Z-Z)の隣と直接接続
これにより3次元トーラスネットワークを形成

TPU v4 Pod全体:16 × 16 × 16 = 4,096チップを3Dトーラスで接続
チップ間相互接続帯域幅:600 GB/s

なぜ3Dトーラスか?
- 任意2ノード間の最大ホップ数:O(N^(1/3))
- 4,096ノードで最大24ホップ(単純リングなら2,048ホップ)
- 集合通信(AllReduce)の高効率化
- ポッドサイズに対してバイセクション帯域幅がよくスケール

OCS(光回路スイッチ):
- ソフトウェアで設定可能な光スイッチングファブリック
- ワークロードに応じてトーラストポロジーを動的に再構成
- 電気スイッチのボトルネックを排除
- 任意チップペア間で全帯域幅を実現

7. XLA:TPUを輝かせるコンパイラ

XLA(Accelerated Linear Algebra)はTPUのソフトウェアスタックの根幹です。JAX/TensorFlowの計算グラフを高度に最適化されたTPUマシンコードにコンパイルします。

# XLA経由でTPUにJITコンパイルされるJAXコード
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit  # XLAでJITコンパイル
def transformer_forward(x, params):
    """JAXスタイルの単一Transformerレイヤー"""
    w_q, w_k, w_v, w_o = params['attn']
    w_ff1, w_ff2 = params['ffn']

    # Layer Norm
    x_norm = jax.nn.standardize(x, axis=-1)

    # マルチヘッドアテンション投影(GEMM -> Systolic Array)
    q = jnp.dot(x_norm, w_q)
    k = jnp.dot(x_norm, w_k)
    v = jnp.dot(x_norm, w_v)

    # アテンション(バッチGEMM)
    d_head = q.shape[-1]
    scale = jnp.sqrt(float(d_head))
    scores = jnp.einsum('bqh,bkh->bqk', q, k) / scale
    attn_weights = jax.nn.softmax(scores, axis=-1)
    attended = jnp.einsum('bqk,bkh->bqh', attn_weights, v)

    # 出力投影
    out = jnp.dot(attended, w_o) + x  # 残差接続

    # FFN
    x2_norm = jax.nn.standardize(out, axis=-1)
    hidden = jax.nn.gelu(jnp.dot(x2_norm, w_ff1))
    ffn_out = jnp.dot(hidden, w_ff2) + out  # 残差接続

    return ffn_out

# XLAがこのコードに施す最適化:
# 1. 演算融合:LayerNorm(5演算)-> 単一カーネル
# 2. レイアウト最適化:Systolic Arrayに最適なメモリ配置
# 3. 再マテリアライゼーション:活性化の保存vs再計算のトレードオフ
# 4. 自動シャーディング:最小通信でTPUチップに自動分割
# 5. 定数畳み込み:コンパイル時に計算可能な値を事前計算

XLA演算融合の具体例

融合前(素朴な実装):
  LayerNorm分解:
  ステップ1:mean(x) -> HBMに書き込み
  ステップ2:x - mean -> HBMに書き込み
  ステップ3:variance(x - mean) -> HBMに書き込み
  ステップ4:normalize -> HBMに書き込み
  ステップ5:scale * x + bias -> HBMに書き込み
  合計:1レイヤーあたり5回のHBM往復

XLA融合後:
  LayerNorm:単一カーネル
  - HBMから入力を1回読み込み
  - 5つの計算すべてをSRAM(オンチップ)で実行
  - HBMに出力を1回書き込み
  合計:1レイヤーあたり1回のHBM往復

効果:メモリ帯域幅使用量が5分の1に削減
XLAはこのような融合パターンを数百種類自動適用

8. TPU PodでのLLM推論

JAXによる完全な分散推論実装

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from functools import partial
import numpy as np

# 利用可能なTPUデバイスを確認
print(f"利用可能なTPU:{jax.devices()}")
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), ...]

# テンソル並列性のための8-TPUメッシュを作成
num_devices = 8
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(device_mesh, axis_names=('model',))

# シャーディング戦略を定義
# 重み行列W:列方向にシャード(モデル並列)
W_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
# 活性化:複製(全デバイスで同じ入力)
act_sharding = NamedSharding(mesh, PartitionSpec(None, None))

@partial(jax.jit,
         in_shardings=(act_sharding, W_sharding),
         out_shardings=act_sharding)
def sharded_linear(x, W):
    """8つのTPUにわたる列並列線形レイヤー"""
    # 各TPUが [batch, seq, d_model/8] を計算
    local_out = jnp.dot(x, W)
    # 部分結果を合計するためのAllReduce(自動!)
    return local_out

# 7Bモデルの推論パイプライン
def run_inference(params, input_ids):
    """シャードされたパラメータを使った推論"""
    x = embedding_lookup(params['embed'], input_ids)

    for layer_idx in range(32):
        with mesh:
            # 注意:各線形演算はTPU全体に自動分散
            ffn_hidden = sharded_linear(x, params[f'w_ff1_{layer_idx}'])
            x = x + sharded_linear(
                jax.nn.gelu(ffn_hidden),
                params[f'w_ff2_{layer_idx}']
            )

    logits = jnp.dot(x[:, -1, :], params['lm_head'])
    return logits.argmax(axis=-1)

実際の規模:TPU PodでのPaLM学習

PaLM 540B 学習設定(Google, 2022):
- ハードウェア:   TPU v4 Pod × 6,144チップ
- 総計算能力:     6,144 × 275 TFLOPS = 1.69 ExaFLOPS
- 学習トークン数:  7,800億トークン
- 学習期間:       約57- バッチサイズ:   2,048シーケンス × 2,048トークン
- MFU46.2%(モデルFLOPS活用率)

Gemini Ultra学習(Google DeepMind, 2023):
- ハードウェア:   OCS経由で接続された複数のTPU v5p Pod
- MMLUで人間の専門家を初めて上回ったモデル
- コンテキスト長:最大32,768トークンで学習

GPU クラスターに対するTPU Podの優位性:
1. 3Dトーラス:ノード間の最大ホップ数 O(N^1/3) vs リングの O(N)
2. OCS:任意チップペア間でフル帯域幅(スイッチボトルネックなし)
3. XLA:モデル全体コンパイル、グローバル最適化
4. bfloat16:損失スケーリング不要で安定学習

9. 性能比較:TPU vs 競合他社

Llama 2 70B 推論ベンチマーク(batch_size=1、float16):

ハードウェア         | メモリ帯域幅  | スループット | レイテンシ
--------------------|-------------|------------|------------
NVIDIA H100 SXM     | 3.35 TB/s   | ~120 tok/s | ~8.3 ms/tok
NVIDIA A100 80GB    | 2.0  TB/s   | ~72  tok/s | ~13.9 ms/tok
TPU v5p(4チップ)  | 4x 1.2 TB/s | ~160 tok/s | ~6.3 ms/tok
AMD MI300X          | 5.3  TB/s   | ~190 tok/s | ~5.3 ms/tok
Apple M3 Ultra      | 800  GB/s   | ~55  tok/s | ~18.2 ms/tok

重要な発見:スループットはメモリ帯域幅とほぼ完全に相関!
これがLLM推論の"メモリ律速"特性 - NPU記事で詳しく解説

スケールでのTPUの優位性(単一チップ性能を超える点):
- TPU Pod AllReduce:チップ間600 GB/s相互接続
- NVLink(H100):900 GB/s、ただしNVSwitchドメイン内のみ
- TPU OCS:異なるワークロード向けに再構成可能なトポロジー
- コスト:TPU v5p ~$2.20/時 vs H100 ~$3.50/時(比較可能なクラウド)

10. プロファイリングと最適化の実践

# JAX/TPUコードのプロファイリングでボトルネックを発見
import jax
import jax.numpy as jnp

# 方法1:JAXプロファイラー(Chromeトレース形式を出力)
with jax.profiler.trace("/tmp/jax_trace", create_perfetto_link=True):
    # ウォームアップ(JITコンパイル時間を除外)
    _ = my_model_fn(sample_input).block_until_ready()
    # 実際のプロファイリング実行
    result = my_model_fn(real_input).block_until_ready()

# Perfetto UIで開く:https://ui.perfetto.dev
# 確認すべきポイント:
# - 演算間の長いギャップ(メモリ帯域幅律速)
# - デバイス間の不均衡(シャーディングの非効率)
# - 頻繁な再コンパイル(動的シェイプ)

# 方法2:モデルFLOPS活用率(MFU)の計算
def compute_mfu(model_flops, elapsed_seconds, peak_tflops):
    """MFU = 実際のFLOPS / 理論ピークFLOPS"""
    actual_tflops = model_flops / elapsed_seconds / 1e12
    return actual_tflops / peak_tflops * 100

# 例:7Bモデル、100 tok/sのスループット
flops_per_token = 2 * 7e9   # 2 * パラメータ数(概算)
elapsed = 1.0 / 100         # 100 tok/sなら1トークン=0.01秒
tpu_v5p_tflops = 459

mfu = compute_mfu(flops_per_token, elapsed, tpu_v5p_tflops)
print(f"MFU: {mfu:.1f}%")
# MFU > 40%:ハードウェアをうまく活用できている
# MFU が低い:メモリ帯域幅律速(推論では一般的)

# 方法3:再コンパイルイベントの検出
jax.config.update("jax_log_compiles", True)
# 推論中にstderrの "Compiling..." メッセージを確認

まとめ

Systolic Arrayは「特化が汎用を上回る」原理の完璧な実証です。

1978年に提案されたアイデアが2015年にGoogleのデータセンターで復活し、現在では世界で最も高度なAIシステム — Gemini、PaLM、そしてGoogle検索、Gmail、Googleアシスタントを支えるインフラ — の中核を担っています。

重要な教訓:

  1. ドメイン特化の力:Transformer計算の95%がGEMMという事実を活かせば、専用ハードウェアで他を圧倒できる
  2. データ再利用が鍵:Systolic Arrayの本質はデータを1回読み込んで最大限の計算を引き出すこと
  3. 数値フォーマットの重要性:bfloat16は「十分精度で高速」というトレードオフの勝者
  4. コンパイラとハードウェアの協調設計:XLAなしではTPUの性能を引き出せない
  5. スケールにはトポロジーが必要:TPU Podの3Dトーラス相互接続こそがExaFLOPS規模の学習を可能にする

次の記事では、スマートフォンやノートPCに内蔵されたNPUが、これらの原理を1〜5ワットの消費電力でどのように実装しているかを詳しく解説します。


参考文献

  • Jouppi et al., "In-Datacenter Performance Analysis of a Tensor Processing Unit" (ISCA 2017)
  • Google TPU Research Cloud: cloud.google.com/tpu
  • JAX ドキュメント: jax.readthedocs.io
  • "PaLM: Scaling Language Modeling with Pathways" (Chowdhery et al., 2022)
  • "Gemini: A Family of Highly Capable Multimodal Models" (Google DeepMind, 2023)
  • H.T. Kung & C.E. Leiserson, "Systolic Arrays for VLSI" (1978)