Transformer Architecture Complete Analysis: From Attention to Modern LLMs
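Google's 2017 paper "Attention is All You Need" completely changed the paradigm of natural language processing. Breaking away from the sequential processing of RNN- and LSTM-based models, the Transformer architecture, built from Attention mechanisms alone, became the foundation of modern LLMs such as GPT, BERT, T5, and LLaMA. This article explains the Transformer from scratch, pairing the equations with code so that you can implement it yourself in PyTorch.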
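1. The Origins of the Attention Mechanism
1.1 Limitations of RNN/LSTM
Before the Transformer, sequence processing relied mainly on RNNs (Recurrent Neural Networks) and LSTMs (Long Short-Term Memory). These models process the input one time step at a time and pass information along through a hidden state, which leads to the following fundamental problems.
Long-Range Dependency Problem
RNNs struggle to learn relationships between words that are far apart. In "The cat that sat on the mat was hungry", for example, learning the relationship between "cat" and "was hungry" requires passing through every token in between. As sequences grow longer, information is lost or diluted. LSTMs mitigated this with gating, but gating was not a fundamental solution.
Sequential Processing
To compute time step t, an RNN must first finish step t-1, so it cannot exploit the parallelism of modern GPUs along the sequence. Computation can be parallelized across the batch, but within a sequence it can only proceed step by step.
Vanishing/Exploding Gradients
When training on long sequences, gradients vanish or explode during backpropagation. LSTM gates alleviate this, but it remains a problem for sequences hundreds to thousands of tokens long.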
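1.2 An Intuitive Understanding of Attention
The Attention mechanism was inspired by the way humans focus their attention. When reading a sentence, we do not pay equal attention to every word; we pay more attention to the words that are most relevant to the one we are currently processing.
For example, when interpreting "I saw the Eiffel Tower in Paris, and the place was truly beautiful":
- When processing "the place", high Attention weights are assigned to "Paris" and "Eiffel Tower"
- Every word can interact directly with every other word
- Relevant words are connected regardless of the distance between them
In 2014, Bahdanau et al. first introduced Attention into seq2seq models. At the time it served as an auxiliary mechanism for RNNs, but in 2017 Vaswani et al. proposed, in "Attention is All You Need", the Transformer, which removes the RNN entirely and handles everything with Attention alone.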
2. Scaled Dot-Product Attention
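2.1 The Concept of Q, K, V
The core idea of the Transformer is Attention built on three kinds of vectors: Query, Key, and Value. The concept can be understood by analogy with a database lookup.
- Query (Q): "What am I looking for?" The representation of the position currently being processed
- Key (K): "What information do I hold?" A label/identifier for each position
- Value (V): "What is actually stored?" The actual content of each position
We compute the similarity (dot product) between the Query and each Key, normalize the scores with softmax to obtain Attention weights, and use these weights to take a weighted sum of the Values.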
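2.2 The Formula in Detail
Scaled Dot-Product Attention is defined as:
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Where:
- Q: Query matrix (sequence length x d_k)
- K: Key matrix (sequence length x d_k)
- V: Value matrix (sequence length x d_v)
- d_k: dimension of the Keys/Queries
- sqrt(d_k): scaling factor
Why scale?
As d_k grows, the dot products of Q and K grow in magnitude. With d_k=512, for example, the dot products become so large that softmax outputs extreme values (close to 0 or 1), which leads to vanishing gradients. Dividing by sqrt(d_k) keeps the variance of the dot product at about 1.
Mathematically, if the components of q and k are drawn independently from a distribution with mean 0 and variance 1, then q·k = q_1 k_1 + ... + q_{d_k} k_{d_k} is a sum of d_k independent terms, each with variance 1, so its variance is d_k. Dividing by sqrt(d_k) brings the variance back to 1.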
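A quick numerical check of this argument, as a sketch that assumes standard-normal entries for q and k (the sample counts and the 10-key softmax are arbitrary illustration choices): the raw dot products have variance close to d_k, the scaled ones close to 1, and the unscaled softmax is nearly one-hot.

```python
import math
import torch

torch.manual_seed(0)
d_k, n = 512, 10_000

# n independent (q, k) pairs with i.i.d. N(0, 1) components
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)
dots = (q * k).sum(dim=-1)  # q·k for each pair

raw_var = dots.var().item()                        # close to d_k
scaled_var = (dots / math.sqrt(d_k)).var().item()  # close to 1
print(f"var(q·k) = {raw_var:.1f}, var(q·k / sqrt(d_k)) = {scaled_var:.3f}")

# Effect on softmax: 1,000 queries, each scored against the same 10 keys
scores = torch.randn(1_000, d_k) @ torch.randn(d_k, 10)
max_w_raw = torch.softmax(scores, dim=-1).max(dim=-1).values.mean().item()
max_w_scaled = torch.softmax(scores / math.sqrt(d_k), dim=-1).max(dim=-1).values.mean().item()
print(f"mean largest weight: unscaled {max_w_raw:.2f}, scaled {max_w_scaled:.2f}")
```

When the largest weight is close to 1, the other weights (and their gradients) are close to 0, which is the saturation the scaling avoids.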
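2.3 Masking
Two kinds of masks are used:
Padding mask: when sequences in a batch have different lengths, the shorter ones are filled with padding tokens. A very large negative value (-inf) is added at padding positions so that their weights become 0 after softmax.
Causal mask (look-ahead mask): used in the decoder. Position i may attend only to positions up to and including i (it must not see future tokens). -inf is added to the strictly upper-triangular part of the score matrix, i.e. above the diagonal.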
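Both masks can be built as boolean tensors and combined with a logical AND. A minimal sketch, assuming pad token id 0 and the convention of the implementation in this article (True/1 means "may attend", positions where the mask is 0 are blocked); `make_padding_mask` and `make_causal_mask` are illustrative helper names:

```python
import torch

def make_padding_mask(token_ids: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    # (batch, seq_len) -> (batch, 1, 1, seq_len); False at padding positions
    return (token_ids != pad_id)[:, None, None, :]

def make_causal_mask(seq_len: int) -> torch.Tensor:
    # (1, 1, seq_len, seq_len) lower-triangular: position i may see 0..i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))[None, None]

tokens = torch.tensor([[5, 7, 9, 0],   # last position is padding
                       [3, 4, 0, 0]])  # last two positions are padding
combined = make_padding_mask(tokens) & make_causal_mask(tokens.size(1))  # (2, 1, 4, 4)

# Uniform scores make the effect of the mask easy to read off
scores = torch.zeros(2, 1, 4, 4)
weights = torch.softmax(scores.masked_fill(~combined, float("-inf")), dim=-1)
print(weights[0, 0])  # each row spreads weight only over allowed, non-padding keys
```

Because position 0 is never padding here, no row is fully masked; the `nan_to_num` step in the implementation below handles that case when it does occur.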
2.4 PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
    dropout_p: float = 0.0,
) -> tuple:
    """
    Scaled Dot-Product Attention

    Args:
        query: (batch, heads, q_len, d_k)
        key: (batch, heads, k_len, d_k)
        value: (batch, heads, k_len, d_v)
        mask: (batch, 1, 1, k_len) or (batch, 1, q_len, k_len);
              positions where mask == 0 are blocked
        dropout_p: dropout probability applied to the attention weights
    Returns:
        output: (batch, heads, q_len, d_v)
        attn_weights: (batch, heads, q_len, k_len)
    """
    d_k = query.size(-1)
    # Q * K^T / sqrt(d_k): (batch, heads, q_len, k_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply the mask: blocked positions get -inf, i.e. weight 0 after softmax
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax over the key dimension gives the attention weights
    attn_weights = F.softmax(scores, dim=-1)
    # Rows where every position is masked become NaN; replace them with 0
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
    # Dropout (callers pass 0.0 outside training)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)
    # Weighted sum of the values: (batch, heads, q_len, d_v)
    output = torch.matmul(attn_weights, value)
    return output, attn_weights


# Test
batch_size = 2
num_heads = 8
seq_len = 10
d_k = 64
d_v = 64
q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_v)
# Causal mask (lower-triangular matrix): (1, 1, seq_len, seq_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
output, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(f"Output shape: {output.shape}")  # (2, 8, 10, 64)
print(f"Attention weights shape: {weights.shape}")  # (2, 8, 10, 10)
3. Multi-Head Attention
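3.1 The Benefit of Multiple Representation Subspaces
Single-head Attention processes the input in a single representation space. Multi-Head Attention processes the same input in several different representation subspaces at once, so each head can learn a different aspect of the relationships between tokens.
For example, different heads might capture:
- Head 1: syntactic relations (subject-verb agreement)
- Head 2: semantic relations (synonyms, similar concepts)
- Head 3: positional relations (neighboring words)
- Head 4: coreference (which noun a pronoun refers to)
Each head has d_k = d_model / num_heads. The h heads run in parallel, but because each head works in a smaller dimension, the total computation is about the same as single-head Attention.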
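The cost claim can be checked with simple arithmetic. A sketch that counts only the projection weights and the multiply-accumulates of Q_i @ K_i^T (the function names are illustrative):

```python
def projection_params(d_model: int, num_heads: int) -> int:
    """Weights in all W_Q_i, W_K_i, W_V_i plus W_O (no biases)."""
    d_k = d_model // num_heads
    per_head = 3 * d_model * d_k                      # W_Q_i, W_K_i, W_V_i: each d_model x d_k
    return num_heads * per_head + d_model * d_model   # + W_O: (h * d_k) x d_model

def score_macs(seq_len: int, d_model: int, num_heads: int) -> int:
    """Multiply-accumulates for Q_i @ K_i^T summed over all heads."""
    d_k = d_model // num_heads
    return num_heads * seq_len * seq_len * d_k        # each head: (n x d_k) @ (d_k x n)

for h in (1, 8):
    print(h, projection_params(512, h), score_macs(128, 512, h))
```

What does grow with h is the number of n x n attention matrices (one per head), so the softmax work and the memory for attention weights scale with the head count.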
3.2 Formula
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)
Here W_Q_i, W_K_i, and W_V_i are the projection matrices of head i, and W_O is the output projection matrix.
3.3 Complete PyTorch Implementation
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Projection layers: each d_model x d_model matrix holds the projections
        # of all heads side by side; split_heads separates them afterwards
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)"""
        batch_size, _, seq_len, _ = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, seq_len, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> tuple:
        """
        Args:
            query: (batch, q_len, d_model)
            key, value: (batch, k_len, d_model)
            mask: (batch, 1, 1, k_len) or (batch, 1, q_len, k_len)
        Returns:
            output: (batch, q_len, d_model)
            attn_weights: (batch, num_heads, q_len, k_len)
        """
        # Project Q, K, V and split into heads
        Q = self.split_heads(self.W_q(query))  # (batch, heads, q_len, d_k)
        K = self.split_heads(self.W_k(key))  # (batch, heads, k_len, d_k)
        V = self.split_heads(self.W_v(value))  # (batch, heads, k_len, d_k)
        # Attention (dropout on the weights only in training mode)
        attn_output, attn_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, dropout_p=self.dropout.p if self.training else 0.0
        )
        # Combine the heads and apply the output projection
        output = self.combine_heads(attn_output)  # (batch, q_len, d_model)
        output = self.W_o(output)
        return output, attn_weights


# Test
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # (batch=2, seq_len=10, d_model=512)
out, weights = mha(x, x, x)
print(f"MHA output: {out.shape}")  # (2, 10, 512)
print(f"Attn weights: {weights.shape}")  # (2, 8, 10, 10)
4. Positional Encoding
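4.1 Why Order Information Is Needed
The Attention mechanism itself ignores token order: without position information, self-attention treats its input as an unordered set, so "The dog bit the man" and "The man bit the dog" produce the same attention scores, merely permuted. Position information for each token must therefore be added to the input separately.
4.2 Sinusoidal Positional Encoding
This is the method used in the original Transformer: position embeddings are computed from a formula, with no learning:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Here pos is the position and i indexes the dimension pairs (dimensions 2i and 2i+1).
Advantages of this approach:
- No learning required (no parameters)
- Encodings can be computed for sequences longer than those seen during training (although the model is not guaranteed to generalize to them)
- For a fixed offset k, PE(pos+k) can be written as a linear transformation of PE(pos), which helps encode relative positions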
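The linear-shift property can be verified numerically. A sketch in float64, where `sinusoidal_pe` and `shift_matrix` are illustrative helpers: for each (sin, cos) pair the shift by k is a 2x2 rotation by angle w_i * k that does not depend on pos, so the whole shift is one block-diagonal matrix M_k.

```python
import math
import torch

def sinusoidal_pe(num_pos: int, d_model: int):
    """Sinusoidal PE as in the formula above, plus the per-pair frequencies w_i."""
    pos = torch.arange(num_pos, dtype=torch.float64).unsqueeze(1)
    freqs = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float64) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(num_pos, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(pos * freqs)
    pe[:, 1::2] = torch.cos(pos * freqs)
    return pe, freqs

def shift_matrix(k: int, freqs: torch.Tensor) -> torch.Tensor:
    """Block-diagonal M_k with PE(pos + k) = M_k @ PE(pos) for every pos."""
    d_model = 2 * freqs.numel()
    m = torch.zeros(d_model, d_model, dtype=torch.float64)
    for i, w in enumerate(freqs.tolist()):
        c, s = math.cos(w * k), math.sin(w * k)
        # sin(a + b) = sin(a)cos(b) + cos(a)sin(b)
        m[2 * i, 2 * i], m[2 * i, 2 * i + 1] = c, s
        # cos(a + b) = cos(a)cos(b) - sin(a)sin(b)
        m[2 * i + 1, 2 * i], m[2 * i + 1, 2 * i + 1] = -s, c
    return m

pe, freqs = sinusoidal_pe(num_pos=100, d_model=16)
k = 7
m_k = shift_matrix(k, freqs)
# Rows of pe are PE(pos); right-multiplying by M_k^T applies M_k to every row
shifted = pe[:-k] @ m_k.T
print(torch.allclose(shifted, pe[k:]))
```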
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # PE matrix: (max_seq_len, d_model)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        # 1 / 10000^(2i/d_model) = exp(-2i * log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cos
        # Reshape to (1, max_seq_len, d_model) to broadcast over the batch
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)  # buffer: saved with the model, but not trained

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, d_model)"""
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)
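4.3 RoPE (Rotary Position Embedding)
RoPE has become the standard positional encoding of modern LLMs (LLaMA, GPT-NeoX, PaLM, and others). The key idea is to encode position by applying a position-dependent rotation to the Q and K vectors.
Key properties:
- Encodes relative rather than absolute position in the attention scores
- The dot product of Q and K automatically depends only on their relative position
- Extrapolation beyond the training length is limited on its own; long-context models usually pair RoPE with context-extension techniques
- Not applied to the Value vectors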
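def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute RoPE rotations as unit complex numbers: (max_seq_len, dim // 2)"""
    # Frequency of each dimension pair: theta^(-2i/dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    # Outer product: rotation angle for every (position, pair) -> (max_seq_len, dim // 2)
    freqs = torch.outer(t, freqs)
    # Complex numbers with magnitude 1 and angle = freqs
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """
    Apply RoPE to queries and keys.
    xq: (batch, seq_len, n_heads, head_dim); xk: (batch, seq_len, n_kv_heads, head_dim)
    """
    # Reinterpret the last dimension as head_dim // 2 complex numbers (pairs)
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Take the first seq_len positions and reshape to (1, seq_len, 1, head_dim // 2)
    # so the rotation broadcasts over batch and heads along the sequence axis
    # (without this reshape it would be aligned with the head axis instead)
    seq_len = xq.shape[1]
    freqs_cis = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
    # Rotation = complex multiplication, then back to real pairs
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)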
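The relative-position property can be checked numerically. A self-contained sketch: `rope_rotate` is an illustrative helper that applies the same pairwise rotation as `apply_rotary_emb`, but to a single vector at position `pos`. The score between a rotated query at position m and a rotated key at position n depends only on m - n.

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, theta: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive pairs of x (shape (dim,)) by angles pos * theta^(-2i/dim)."""
    dim = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    rot = torch.polar(torch.ones_like(freqs), pos * freqs)  # unit complex numbers
    x_c = torch.view_as_complex(x.reshape(-1, 2).contiguous())
    return torch.view_as_real(x_c * rot).flatten()

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

score_a = rope_rotate(q, 10) @ rope_rotate(k, 3)     # m - n = 7
score_b = rope_rotate(q, 107) @ rope_rotate(k, 100)  # m - n = 7, shifted by 97
score_c = rope_rotate(q, 10) @ rope_rotate(k, 9)     # m - n = 1
print(score_a.item(), score_b.item(), score_c.item())
```

score_a and score_b agree up to floating-point error, while score_c generally differs: the rotation angles cancel in the dot product except for their difference.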
4.4 ALiBi (Attention with Linear Biases)
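ALiBi adds to the attention scores a linear bias proportional to the distance between positions:
Attention Score = Q * K^T / sqrt(d_k) - m * |i - j|
Here m is a slope that differs from head to head. No separate positional embedding is needed, and the method works well even on sequences much longer than the training length.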
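A minimal sketch of the ALiBi bias, assuming the slope schedule from the ALiBi paper for a power-of-two number of heads (a geometric sequence starting at 2^(-8/num_heads) with the same ratio). It uses the symmetric |i - j| form written above; a causal decoder only ever uses the j <= i half. The helper names are illustrative:

```python
import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    """Slopes m_h = 2^(-8h / num_heads) for h = 1..num_heads (power-of-two head counts)."""
    h = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return 2.0 ** (-8.0 * h / num_heads)

def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    """Bias of shape (num_heads, seq_len, seq_len): -m_h * |i - j|."""
    pos = torch.arange(seq_len)
    distance = (pos[:, None] - pos[None, :]).abs().float()  # |i - j|
    return -alibi_slopes(num_heads)[:, None, None] * distance

bias = alibi_bias(num_heads=8, seq_len=5)
print(alibi_slopes(8))  # the slopes are 1/2, 1/4, ..., 1/256
# Usage: scores = q @ k.transpose(-2, -1) / math.sqrt(d_k) + bias
```

Because the bias only depends on |i - j|, it can be built for any sequence length at inference time without retraining.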
5. Transformer Encoder
5.1 Encoder Structure
The Transformer Encoder stacks N identical layers. Each layer consists of two sublayers:
- Multi-Head Self-Attention
- Position-wise Feed-Forward Network
Each sublayer is wrapped with a residual connection and layer normalization.
5.2 Feed-Forward Network (FFN)
The FFN is a two-layer fully connected network applied to each position independently:
FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2
The original paper used d_model=512 and d_ff=2048 (4x). Modern LLMs often use the SwiGLU activation instead, with d_ff ≈ 2.67 * d_model.
5.3 Pre-LN vs Post-LN
The original Transformer used Post-LN: add the residual, Sublayer(x) + x, then apply LayerNorm.
# Post-LN (original Transformer)
x = LayerNorm(x + Sublayer(x))
# Pre-LN (standard in modern Transformers)
x = x + Sublayer(LayerNorm(x))
Modern models adopted Pre-LN because it trains more stably; it can often be trained without a learning-rate warmup stage.
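One way to see the difference: in Pre-LN the residual path carries x through untouched, while in Post-LN every layer's output passes through a LayerNorm. A minimal sketch (the `ResidualBlock` class, its `pre_ln` flag and the `Zero` sublayer are illustrative, not part of the original code); a sublayer that outputs zeros isolates the residual path:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Wraps a sublayer with a residual connection in Pre-LN or Post-LN order."""

    def __init__(self, d_model: int, sublayer: nn.Module, pre_ln: bool = True):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.pre_ln = pre_ln

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pre_ln:
            return x + self.sublayer(self.norm(x))  # x + Sublayer(LayerNorm(x))
        return self.norm(x + self.sublayer(x))      # LayerNorm(x + Sublayer(x))

class Zero(nn.Module):
    """A sublayer that outputs zeros."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)

x = torch.randn(2, 4, 8) * 5 + 3  # activations far from normalized
pre = ResidualBlock(8, Zero(), pre_ln=True)(x)
post = ResidualBlock(8, Zero(), pre_ln=False)(x)
print(torch.equal(pre, x))        # Pre-LN: the identity path is left unchanged
print(post.mean(-1).abs().max())  # Post-LN: re-normalized to zero mean per position
```

The intact identity path is one common explanation for Pre-LN's better-behaved gradients in deep stacks, and it is also why Pre-LN models add a final LayerNorm after the last layer, as the encoder below does.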
5.4 Complete Encoder Implementation
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-LN: x + Sublayer(LayerNorm(x))
        # Self-Attention: normalize once and use the result as Q, K and V
        h = self.norm1(x)
        attn_out, _ = self.self_attn(h, h, h, mask=mask)
        x = x + self.dropout(attn_out)
        # Feed-Forward
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """x: (batch, seq_len) token indices"""
        x = self.embedding(x) * math.sqrt(self.d_model)  # scale embeddings by sqrt(d_model)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        # Final LayerNorm: with Pre-LN the residual stream itself is never normalized
        return self.norm(x)
6. Transformer Decoder
6.1 Decoder Structure
The Decoder is more complex than the Encoder and has three sublayers:
- Masked Multi-Head Self-Attention (with a causal mask)
- Multi-Head Cross-Attention (Attention over the Encoder output)
- Feed-Forward Network
6.2 The Cross-Attention Mechanism
In Cross-Attention:
- Query: the Decoder's current states
- Key, Value: the Encoder's output
This lets the Decoder learn which parts of the source sequence to focus on.
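In cross-attention the query and key lengths generally differ, so the weight matrix is (tgt_len x src_len) rather than square. A shape-only sketch using PyTorch's built-in `nn.MultiheadAttention` with `batch_first=True` (the sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 64, 4
cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_states = torch.randn(2, 5, d_model)  # (batch, tgt_len=5, d_model): queries
encoder_output = torch.randn(2, 7, d_model)  # (batch, src_len=7, d_model): keys and values

out, weights = cross_attn(decoder_states, encoder_output, encoder_output,
                          average_attn_weights=False)
print(out.shape)      # torch.Size([2, 5, 64]): one output per target position
print(weights.shape)  # torch.Size([2, 4, 5, 7]): per head, each target position over 7 source positions
```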
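6.3 Autoregressive Generation
At inference time, the Decoder generates autoregressively:
- Start with the [BOS] token
- Feed in all previously generated tokens
- Predict the next token
- Repeat until the [EOS] token is generated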
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # 1. Masked self-attention (causal mask)
        h = self.norm1(x)
        self_attn_out, _ = self.self_attn(h, h, h, mask=tgt_mask)
        x = x + self.dropout(self_attn_out)
        # 2. Cross-attention: queries from the decoder, keys/values from the encoder
        cross_attn_out, _ = self.cross_attn(
            self.norm2(x), encoder_output, encoder_output, mask=src_mask
        )
        x = x + self.dropout(cross_attn_out)
        # 3. Feed-forward
        x = x + self.dropout(self.ffn(self.norm3(x)))
        return x
7. Full Transformer Implementation
class Transformer(nn.Module):
    """
    Complete Encoder-Decoder Transformer
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        # Embedding layers
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        # Encoder
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        # Decoder
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        # Final norms and projection to vocabulary logits
        self.encoder_norm = nn.LayerNorm(d_model)
        self.decoder_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        # Xavier init for every weight matrix; 1-D parameters (biases, LayerNorm) keep their defaults
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return self.encoder_norm(x)

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Returns decoder hidden states (batch, tgt_len, d_model), not logits"""
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.decoder_norm(x)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        logits = self.output_projection(decoder_output)
        return logits

    @torch.no_grad()
    def generate(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Greedy decoding for a single source sequence (src: (1, src_len)).
        Dividing by temperature does not change the argmax; it only matters
        if argmax is replaced by sampling."""
        self.eval()
        device = src.device
        # Encode the source once
        src_mask = self._make_padding_mask(src)
        encoder_output = self.encode(src, src_mask)
        # Start from the BOS token
        tgt = torch.tensor([[bos_token_id]], device=device)
        for _ in range(max_new_tokens):
            seq_len = tgt.size(1)
            # Causal mask over the tokens generated so far
            tgt_mask = torch.tril(
                torch.ones(seq_len, seq_len, device=device)
            ).unsqueeze(0).unsqueeze(0)
            decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
            # Project only the last position's hidden state to vocabulary logits
            next_token_logits = self.output_projection(decoder_output[:, -1, :]) / temperature
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            if next_token.item() == eos_token_id:
                break
        return tgt

    def _make_padding_mask(self, x: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
        # (batch, 1, 1, seq_len): True where x is a real (non-padding) token
        return (x != pad_token_id).unsqueeze(1).unsqueeze(2)


# Build and test the model
model = Transformer(
    src_vocab_size=32000,
    tgt_vocab_size=32000,
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    d_ff=2048,
)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # 93,287,680 (about 93M)
# Forward pass test (masks omitted: this only checks shapes)
src = torch.randint(1, 32000, (2, 20))  # (batch=2, src_len=20)
tgt = torch.randint(1, 32000, (2, 15))  # (batch=2, tgt_len=15)
logits = model(src, tgt)
print(f"Logits shape: {logits.shape}")  # (2, 15, 32000)
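The parameter count printed by this test can be worked out by hand. A sketch that mirrors the layer sizes of the `Transformer` class above (separate source/target embeddings, bias-free attention projections, FFN and output projection with biases, LayerNorm with weight and bias); `transformer_param_count` is an illustrative helper:

```python
def transformer_param_count(src_vocab: int, tgt_vocab: int, d_model: int,
                            d_ff: int, n_enc: int, n_dec: int) -> dict:
    mha = 4 * d_model * d_model                                  # W_q, W_k, W_v, W_o (no bias)
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)   # two Linear layers with bias
    ln = 2 * d_model                                             # LayerNorm weight + bias
    parts = {
        "embeddings": (src_vocab + tgt_vocab) * d_model,
        "encoder": n_enc * (mha + ffn + 2 * ln),
        "decoder": n_dec * (2 * mha + ffn + 3 * ln),             # self-attn + cross-attn
        "final_norms": 2 * ln,
        "output_projection": d_model * tgt_vocab + tgt_vocab,
    }
    parts["total"] = sum(parts.values())
    return parts

counts = transformer_param_count(32000, 32000, 512, 2048, 6, 6)
for name, n in counts.items():
    print(f"{name:>18}: {n:,}")
```

The embeddings and the output projection account for more than half of the total here, which is one reason the original paper shared weights between the embedding layers and the pre-softmax projection.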
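8. Transformer Variant Architectures
8.1 BERT (Encoder-only)
BERT (Bidirectional Encoder Representations from Transformers) uses only the Encoder and is pretrained on two tasks:
Masked Language Modeling (MLM): 15% of the input tokens are randomly selected and the model must predict the original tokens at those positions (of the selected tokens, 80% are replaced by [MASK], 10% by a random token, and 10% are left unchanged). Because it uses context from both directions, BERT excels at understanding tasks.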
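A minimal sketch of the MLM corruption step, assuming the 80/10/10 recipe described in the BERT paper. `mask_token_id`, the vocabulary size and the token-id range are illustrative; the -100 label for "no prediction here" matches the default `ignore_index` of `F.cross_entropy`:

```python
import torch

def mlm_mask(input_ids: torch.Tensor, mask_token_id: int, vocab_size: int,
             mask_prob: float = 0.15, generator: torch.Generator = None):
    """Return (corrupted_ids, labels); labels are -100 wherever no prediction is needed."""
    labels = input_ids.clone()
    selected = torch.rand(input_ids.shape, generator=generator) < mask_prob
    labels[~selected] = -100                       # loss only on the selected positions

    corrupted = input_ids.clone()
    r = torch.rand(input_ids.shape, generator=generator)
    to_mask = selected & (r < 0.8)                 # 80% -> [MASK]
    to_random = selected & (r >= 0.8) & (r < 0.9)  # 10% -> random token
    corrupted[to_mask] = mask_token_id
    corrupted[to_random] = torch.randint(vocab_size, (int(to_random.sum()),),
                                         generator=generator)
    return corrupted, labels                       # remaining 10% left unchanged

g = torch.Generator().manual_seed(0)
ids = torch.randint(5, 1000, (64, 128), generator=g)  # ids 0-4 reserved for special tokens
corrupted, labels = mlm_mask(ids, mask_token_id=4, vocab_size=1000, generator=g)
selected = labels != -100
print(f"selected: {selected.float().mean():.3f}")
print(f"[MASK] among selected: {(corrupted[selected] == 4).float().mean():.3f}")
```

Keeping some selected tokens unchanged or randomized means the model cannot rely on [MASK] always marking the positions it is graded on, which matters because [MASK] never appears at fine-tuning time.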
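Next Sentence Prediction (NSP): predicts whether two sentences are consecutive (later research found that NSP contributes little).
BERT is strong on NLU tasks such as classification, question answering, and NER, but it cannot be used directly for generation.
8.2 GPT (Decoder-only)
GPT (Generative Pre-trained Transformer) is an autoregressive, Decoder-only model. It is pretrained with causal language modeling (CLM): predicting the next token from the previous tokens.
It uses a simplified Decoder without Cross-Attention (Self-Attention only). GPT-2, GPT-3, LLaMA, Mistral, and others follow this Decoder-only architecture (OpenAI has not published the details of GPT-4's architecture).
8.3 T5 and BART (Encoder-Decoder)
T5 (Text-To-Text Transfer Transformer) casts every NLP task as a text-to-text problem: translation, summarization, classification, and question answering are all handled in the same format.
BART is pretrained as a denoising autoencoder, using a variety of noising strategies (token masking, sentence shuffling, document rotation).
8.4 Vision Transformer (ViT)
An image is split into 16x16 patches, and each patch is linearly projected and treated as a token. Positional embeddings are added, and the sequence is fed into a Transformer Encoder.
When pretrained on sufficiently large datasets, ViT outperforms CNNs.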
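The patch-to-token step is commonly implemented as a convolution whose kernel size and stride both equal the patch size, which is equivalent to cutting non-overlapping patches and applying one shared linear projection to each. A sketch assuming 224x224 RGB inputs and a 768-dimensional embedding (ViT-Base-like sizes, for illustration only); `PatchEmbedding` is an illustrative class name:

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size: int = 224, patch_size: int = 16,
                 in_chans: int = 3, d_model: int = 768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel = stride = patch_size: one linear projection per non-overlapping patch
        self.proj = nn.Conv2d(in_chans, d_model, kernel_size=patch_size, stride=patch_size)
        # Learned positional embedding, one vector per patch
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, d_model))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        x = self.proj(images)             # (batch, d_model, 14, 14) for 224 / 16
        x = x.flatten(2).transpose(1, 2)  # (batch, num_patches, d_model)
        return x + self.pos_embed

tokens = PatchEmbedding()(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 196, 768])
```

The original ViT also prepends a learnable [CLS] token whose final state is used for classification; it is omitted here to keep the sketch short.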
9. Transformer Optimizations in Modern LLMs
9.1 RMSNorm
RMSNorm (Root Mean Square Layer Normalization) is used instead of LayerNorm:
RMSNorm(x) = x / RMS(x) * g
RMS(x) = sqrt(mean(x^2) + epsilon)
It skips subtracting the mean (and has no bias term), which makes it faster, while performing similarly to LayerNorm. It is used in LLaMA, Mistral, and others.
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last dimension: sqrt(mean(x^2) + eps)
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight
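RMSNorm is LayerNorm without the mean-centering (and bias): for an input whose mean over the last dimension is already zero, the two coincide. A sketch checking this against PyTorch's `F.layer_norm` without affine parameters, using a functional RMSNorm that follows the class above (without the gain):

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), no learnable gain
    return x / x.pow(2).mean(-1, keepdim=True).add(eps).sqrt()

torch.manual_seed(0)
x = torch.randn(4, 16) * 3 + 2                 # nonzero mean per row
x_centered = x - x.mean(-1, keepdim=True)

ln = F.layer_norm(x, normalized_shape=(16,), eps=1e-6)
print(torch.allclose(rms_norm(x), ln, atol=1e-5))           # differs: x has a nonzero mean
print(torch.allclose(rms_norm(x_centered), ln, atol=1e-5))  # matches once x is centered
```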
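9.2 SwiGLU Activation Function
SwiGLU is an activation function that combines the Swish function with a GLU (Gated Linear Unit):
SwiGLU(x, W, V, b, c) = Swish(xW + b) ⊙ (xV + c)
Swish(x) = x * sigmoid(x)
Because a SwiGLU FFN has three weight matrices instead of two, the FFN dimension is usually reduced to d_ff = (2/3) * 4 * d_model ≈ 2.67 * d_model, which keeps the parameter count about the same as a standard 4x FFN.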
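The 2/3 factor keeps the parameter count equal to that of a standard 4x FFN: three d_model x d_ff matrices with d_ff = (8/3) d_model hold 8 d_model^2 weights, the same as two d_model x 4 d_model matrices. A small sketch of the arithmetic. The `multiple_of` rounding is an assumption that mirrors the code below (64) and the value 256 used in Meta's reference LLaMA code, which gives LLaMA-7B's hidden size of 11008 for d_model = 4096:

```python
def swiglu_hidden_dim(d_model: int, multiple_of: int) -> int:
    """2/3 * 4 * d_model, rounded up to a multiple of `multiple_of`."""
    d_ff = int(2 * 4 * d_model / 3)
    return ((d_ff + multiple_of - 1) // multiple_of) * multiple_of

def ffn_params(d_model: int, d_ff: int, gated: bool) -> int:
    """Weight count without biases: 3 matrices for SwiGLU, 2 for a ReLU FFN."""
    return (3 if gated else 2) * d_model * d_ff

d_model = 4096
print(swiglu_hidden_dim(d_model, 64))   # 10944
print(swiglu_hidden_dim(d_model, 256))  # 11008
print(ffn_params(d_model, 4 * d_model, gated=False))                     # 134217728
print(ffn_params(d_model, swiglu_hidden_dim(d_model, 256), gated=True))  # 135266304
```

Rounding up to a hardware-friendly multiple costs under 1% of extra parameters here.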
class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        if d_ff is None:
            # LLaMA-style sizing: 2/3 * 4 * d_model, rounded up to a multiple of 64
            d_ff = int(2 * 4 * d_model / 3)
            d_ff = ((d_ff + 63) // 64) * 64
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection (through Swish)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # linear projection that gets gated

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: W2(Swish(W1 x) * (W3 x)); F.silu is Swish with beta = 1
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
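9.3 Grouped Query Attention (GQA)
In Multi-Head Attention (MHA), every head has its own Q, K, and V. In Multi-Query Attention (MQA), all heads share a single K and V. GQA sits between the two: the query heads are divided into G groups, each group has its own K and V, and the heads within a group share them.
GQA is used by LLaMA 2 (34B/70B), Mistral 7B, and others.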
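GQA's main payoff is a smaller KV cache at inference time, since only num_kv_heads heads of K and V have to be stored. A back-of-the-envelope sketch; the configuration (32 layers, 32 query heads of dimension 128, a 4096-token context, fp16) is a hypothetical example, and 8 KV heads is just one possible GQA setting:

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    # K and V per layer, each of shape (batch, num_kv_heads, seq_len, head_dim)
    return 2 * num_layers * batch * num_kv_heads * seq_len * head_dim * bytes_per_elem

GiB = 1024 ** 3
for name, kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(num_layers=32, num_kv_heads=kv_heads, head_dim=128, seq_len=4096)
    print(f"{name}: {kv_heads:>2} KV heads -> {size / GiB:.3f} GiB per sequence")
```

The cache shrinks by num_heads / num_kv_heads, which translates directly into longer contexts or larger batches on the same GPU.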
class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,  # number of query heads
        num_kv_heads: int,  # number of key/value heads (= number of groups)
        dropout: float = 0.0,
    ):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_rep = num_heads // num_kv_heads  # query heads sharing each KV head
        self.d_head = d_model // num_heads
        self.wq = nn.Linear(d_model, num_heads * self.d_head, bias=False)
        self.wk = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wv = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wo = nn.Linear(num_heads * self.d_head, d_model, bias=False)
        self.dropout = dropout

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat each KV head num_rep times to match the number of query heads"""
        # (batch, num_kv_heads, seq_len, d_head) -> (batch, num_heads, seq_len, d_head)
        if self.num_rep == 1:
            return x
        batch, n_kv, seq_len, d_head = x.shape
        return x.unsqueeze(2).expand(
            batch, n_kv, self.num_rep, seq_len, d_head
        ).reshape(batch, n_kv * self.num_rep, seq_len, d_head)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        # Q, K, V projections: (batch, heads, seq_len, d_head)
        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        # Optional RoPE; apply_rotary_emb expects (batch, seq_len, heads, d_head)
        if freqs_cis is not None:
            xq, xk = apply_rotary_emb(
                xq.transpose(1, 2), xk.transpose(1, 2), freqs_cis
            )
            xq = xq.transpose(1, 2)
            xk = xk.transpose(1, 2)
        # Repeat K and V to match the number of query heads
        xk = self.repeat_kv(xk)
        xv = self.repeat_kv(xv)
        # Attention
        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        out = torch.matmul(attn, xv)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)
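9.4 The KV Cache Mechanism
During inference, recomputing the K and V of all previous tokens at every generation step is very wasteful. A KV cache stores the K and V computed in earlier steps and reuses them, so each step only computes K and V for the new tokens.
class KVCacheAttention(nn.Module):
    """Attention with a KV cache (RoPE omitted for brevity)"""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        # KV cache: (batch, heads, cached_len, d_head)
        self.cache_k = None
        self.cache_v = None

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """x: new tokens (batch, seq_len, d_model); start_pos: position of x[:, 0]"""
        batch, seq_len, _ = x.shape
        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        if start_pos > 0 and self.cache_k is not None:
            # Generation step: append the new K, V to the cache and attend over all of it
            self.cache_k = torch.cat([self.cache_k, xk], dim=2)
            self.cache_v = torch.cat([self.cache_v, xv], dim=2)
            xk = self.cache_k
            xv = self.cache_v
        else:
            # Prefill (start_pos == 0): start a new cache from the prompt
            self.cache_k = xk
            self.cache_v = xv
        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
        # Causal mask when several tokens arrive at once (e.g. the prompt):
        # the query at position start_pos + i may only see keys 0..start_pos + i
        if seq_len > 1:
            total_len = xk.size(2)
            causal = torch.ones(
                seq_len, total_len, dtype=torch.bool, device=x.device
            ).tril(diagonal=total_len - seq_len)
            scores = scores.masked_fill(~causal, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, xv)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)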
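The point of the cache is that incremental decoding gives exactly the same result as recomputing full causal attention at every step. A self-contained, single-head sketch that checks this (the random projection matrices and sizes are illustrative):

```python
import math
import torch

torch.manual_seed(0)
d, seq_len = 16, 6
w_q, w_k, w_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))
x = torch.randn(seq_len, d)  # one sequence of token representations

def full_causal_attention(x: torch.Tensor) -> torch.Tensor:
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / math.sqrt(d)
    causal = torch.ones(len(x), len(x), dtype=torch.bool).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

# Incremental decoding: compute K, V only for the newest token and append to the cache
cache_k = torch.empty(0, d)
cache_v = torch.empty(0, d)
outputs = []
for t in range(seq_len):
    x_t = x[t:t + 1]                           # (1, d): the newest token
    cache_k = torch.cat([cache_k, x_t @ w_k])  # (t + 1, d)
    cache_v = torch.cat([cache_v, x_t @ w_v])
    q_t = x_t @ w_q
    # No mask needed: the cache only holds positions 0..t
    weights = torch.softmax(q_t @ cache_k.T / math.sqrt(d), dim=-1)
    outputs.append(weights @ cache_v)

incremental = torch.cat(outputs)
print(torch.allclose(incremental, full_causal_attention(x), atol=1e-6))
```

Per step, the cached version does O(t) work for the new token instead of recomputing O(t) keys and values and an O(t^2) score matrix.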
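10. Flash Attention
10.1 The Memory Problem of Standard Attention
The attention matrix, an intermediate result of standard Attention, requires O(N^2) memory for a sequence of length N. For N=8192, the attention matrix alone takes 8192×8192×4 bytes ≈ 256MB in FP32, for every head of every batch element. Writing this matrix to GPU HBM and reading it back becomes the bottleneck.
Modern GPUs provide far more compute (FLOPs) than memory bandwidth, so the memory-heavy parts of standard Attention are memory-bound rather than compute-bound, and Attention runs much slower than its FLOP count would suggest.
10.2 IO-aware Attention Algorithm
Flash Attention (Dao et al., 2022) computes exact Attention without ever writing the attention matrix to HBM.
Key idea: Tiling
Q, K, and V are split into blocks (tiles) that are loaded into SRAM (fast on-chip memory), and Attention is computed block by block. An online softmax algorithm yields the exact softmax without the full matrix.
Computation:
For each query block i and key/value block j:
- Compute Q_i * K_j^T to obtain the scores S_ij
- Update the softmax online, using the running maximum and sum from the previous blocks
- Accumulate the weighted sum with V_j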
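The online-softmax recurrence can be sketched in plain PyTorch. This is not the FlashAttention kernel (no SRAM management, a single query block, no backward pass); it only shows that visiting K/V block by block with a running max `m`, a running denominator `l` and a rescaled accumulator reproduces exact attention. Block size and tensor sizes are illustrative:

```python
import math
import torch

def tiled_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    block: int = 32) -> torch.Tensor:
    """Exact softmax(q k^T / sqrt(d)) v, visiting K/V one block at a time.
    q: (n_q, d), k: (n_k, d), v: (n_k, d_v)."""
    scale = 1.0 / math.sqrt(q.size(-1))
    m = torch.full((q.size(0), 1), float("-inf"))  # running row max
    l = torch.zeros(q.size(0), 1)                   # running softmax denominator
    acc = torch.zeros(q.size(0), v.size(-1))        # running unnormalized output
    for start in range(0, k.size(0), block):
        k_j, v_j = k[start:start + block], v[start:start + block]
        s_ij = (q @ k_j.T) * scale                   # scores for this block
        m_new = torch.maximum(m, s_ij.max(dim=-1, keepdim=True).values)
        p_ij = torch.exp(s_ij - m_new)               # unnormalized block probabilities
        correction = torch.exp(m - m_new)            # rescales everything seen so far
        l = l * correction + p_ij.sum(dim=-1, keepdim=True)
        acc = acc * correction + p_ij @ v_j
        m = m_new
    return acc / l

torch.manual_seed(0)
q, k, v = torch.randn(16, 64), torch.randn(100, 64), torch.randn(100, 64)
reference = torch.softmax(q @ k.T / math.sqrt(64), dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v), reference, atol=1e-5))
```

Only O(block) scores exist at any time, which is what lets the real kernel keep them in on-chip SRAM instead of writing an N x N matrix to HBM.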
Complexity:
- Memory: O(N), since the attention matrix is never stored
- Compute: still O(N^2), but with far fewer memory accesses
- Speed: reported 2-4x faster on A100
10.3 The Evolution of Flash Attention
Flash Attention 1 (2022):
- First formulated IO-aware Attention
- Implements both the forward and backward passes
- Requires custom CUDA kernels
Flash Attention 2 (2023):
- Q blocks in the outer loop, K/V blocks in the inner loop (better parallelism)
- Better distribution of work between warps
- Fewer non-matmul FLOPs
- Roughly another 2x speedup
Flash Attention 3 (2024):
- Uses the H100's WGMMA (Warpgroup Matrix Multiply-Accumulate) instructions
- Uses TMA (Tensor Memory Accelerator) for asynchronous copies
- Roughly another 1.5-2x speedup
10.4 Usage
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # recent PyTorch releases

# PyTorch 2.0+ ships FlashAttention inside F.scaled_dot_product_attention;
# on CUDA it is selected automatically when the inputs qualify (e.g. fp16/bf16)
q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
# The causal mask is requested simply with is_causal=True.
# sdpa_kernel restricts SDPA to the Flash backend; it replaces the deprecated
# torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(f"Output shape: {output.shape}")  # (2, 8, 1024, 64)

# Using the flash-attn package directly
# pip install flash-attn --no-build-isolation
try:
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    # Packed QKV layout: (batch, seq_len, 3, num_heads, head_dim)
    qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
    out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
    print(f"Flash Attention output: {out.shape}")  # (2, 1024, 8, 64)
except ImportError:
    print("The flash-attn package is not installed.")


class FlashAttentionMHA(nn.Module):
    """MHA built on F.scaled_dot_product_attention (uses Flash when available)"""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        # Compute Q, K, V with a single matmul
        qkv = self.wqkv(x)
        qkv = qkv.view(batch, seq_len, 3, self.num_heads, self.d_head)
        q, k, v = qkv.unbind(dim=2)
        # To (batch, heads, seq_len, d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # PyTorch's built-in SDPA picks the Flash kernel automatically when it can
        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # Back to (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.wo(out)
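The snippets above need a CUDA GPU. As a CPU-runnable sanity check, the sketch below confirms that `F.scaled_dot_product_attention` with `is_causal=True` computes the same thing as the explicit masked-softmax formula. On CPU PyTorch uses a non-Flash backend, so this checks the semantics, not the Flash kernel itself; the sizes are arbitrary:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 10, 32) for _ in range(3))  # (batch, heads, seq_len, d_head)

# Explicit reference: softmax(Q K^T / sqrt(d) with future positions masked) V
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.ones(10, 10, dtype=torch.bool).tril()
reference = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ v

fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(fused, reference, atol=1e-5))
```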
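11. Mixture of Experts (MoE)
11.1 The Idea Behind MoE
Mixture of Experts increases model capacity (parameter count) while keeping the computation per token nearly constant, so compute grows sublinearly with the number of parameters. The key idea is to route each token to only a few of several "expert" networks.
For example, in an MoE layer with 8 FFN experts where each token is routed only to its Top-2 experts:
- Total parameters: 8 FFNs, i.e. 8x the FFN parameters of a dense model with a single FFN of the same size
- Actual computation: only 2 FFNs run per token, i.e. about 2x the FFN compute of that dense model and a quarter of running all 8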
11.2 Top-k Routing
class MoELayer(nn.Module):
    """Mixture of Experts FFN layer"""

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Router (gating network): one logit per expert
        self.router = nn.Linear(d_model, num_experts, bias=False)
        # Experts: independent SwiGLU FFNs
        self.experts = nn.ModuleList([
            SwiGLUFeedForward(d_model, d_ff)
            for _ in range(num_experts)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple:
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)  # (batch*seq_len, d_model)
        # Routing scores
        router_logits = self.router(x_flat)  # (batch*seq_len, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)
        # Select the top-k experts per token
        topk_probs, topk_indices = router_probs.topk(self.top_k, dim=-1)
        # Renormalize the selected experts' probabilities so they sum to 1
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        # Dispatch tokens to experts (a simple loop; production systems batch this)
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]  # each token's i-th choice of expert
            prob = topk_probs[:, i:i+1]  # the corresponding gate weight
            # Group the tokens that chose expert e
            for e in range(self.num_experts):
                token_mask = (expert_idx == e)
                if token_mask.any():
                    expert_input = x_flat[token_mask]
                    expert_output = self.experts[e](expert_input)
                    output[token_mask] += prob[token_mask] * expert_output
        # Auxiliary load-balancing loss that encourages even use of all experts;
        # add it, scaled by a small coefficient, to the training loss
        aux_loss = self._load_balancing_loss(router_probs)
        output = output.view(batch, seq_len, d_model)
        return self.dropout(output), aux_loss

    def _load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
        """Switch Transformer-style auxiliary loss"""
        # The share of tokens routed to each expert should be uniform
        num_tokens = router_probs.size(0)
        # Mean routing probability of each expert
        avg_probs = router_probs.mean(dim=0)  # (num_experts,)
        # Fraction of tokens whose top-1 choice is each expert
        top1_indices = router_probs.argmax(dim=-1)
        expert_counts = torch.bincount(top1_indices, minlength=self.num_experts).float()
        expert_fractions = expert_counts / num_tokens
        # Auxiliary loss: num_experts * dot(token fractions, mean probabilities)
        aux_loss = self.num_experts * (avg_probs * expert_fractions).sum()
        return aux_loss
11.3 Mixtral and DeepSeek Architectures
Mixtral 8x7B:
- 8 expert FFNs with Top-2 routing
- Total parameters: 46.7B; active parameters: 12.9B (at inference)
- Each token uses only 2 experts
- Otherwise a Mistral-style design: GQA + RoPE + SwiGLU experts
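The total/active split can be reproduced from the model's shape. A sketch assuming Mixtral 8x7B's published configuration values (32 layers, d_model 4096, expert hidden size 14336, 32 query heads and 8 KV heads of size 128, a 32000-token vocabulary, untied input and output embeddings); norms and the router are included, and the helper name is illustrative:

```python
def moe_param_counts(n_layers=32, d_model=4096, d_ff=14336, n_heads=32,
                     n_kv_heads=8, head_dim=128, vocab=32000,
                     n_experts=8, top_k=2) -> tuple:
    expert = 3 * d_model * d_ff                                  # SwiGLU: w1, w2, w3
    attn = d_model * head_dim * (2 * n_heads + 2 * n_kv_heads)   # wq + wo, wk + wv (GQA)
    router = d_model * n_experts
    norms = 2 * d_model                                          # two RMSNorm gains per layer
    shared = n_layers * (attn + router + norms) + d_model + 2 * vocab * d_model
    total = shared + n_layers * n_experts * expert
    active = shared + n_layers * top_k * expert                  # only top_k experts run per token
    return total, active

total, active = moe_param_counts()
print(f"total: {total / 1e9:.1f}B, active per token: {active / 1e9:.1f}B")
```

The experts hold about 97% of the weights, so routing each token to 2 of 8 experts cuts the per-token parameter use to roughly 28% of the total.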
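DeepSeek MoE:
- Fine-grained expert segmentation: experts are split into smaller units
- Shared experts: always-on experts used by every token, in addition to the routed experts
- Strategies to prevent expert collapse (routing that concentrates on only a few experts)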
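12. Conclusion
The Transformer architecture started from a simple Attention mechanism and has grown into the core structure that powers today's LLMs with hundreds of billions of parameters.
To summarize the key concepts covered in this article:
- Scaled Dot-Product Attention: a soft lookup based on dot products of Q, K, and V
- Multi-Head Attention: Attention in parallel across several representation subspaces
- Positional Encoding: sinusoidal PE, RoPE, ALiBi
- Encoder/Decoder: the foundation of the BERT and GPT families
- Modern optimizations: RMSNorm, SwiGLU, GQA, KV Cache
- Flash Attention: IO-aware, memory-efficient Attention
- MoE: efficient scaling through sparse activation
What to study next:
- Speculative Decoding (inference acceleration)
- LoRA/QLoRA fine-tuning
- LLM alignment (RLHF, DPO)
- Serving optimization with vLLM and TensorRT-LLM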
References
- Vaswani et al. (2017). "Attention Is All You Need." https://arxiv.org/abs/1706.03762
- Su et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding." https://arxiv.org/abs/2104.09864
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." https://arxiv.org/abs/2205.14135
- Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." https://arxiv.org/abs/2307.08691
- Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." https://arxiv.org/abs/2305.13245
- Shazeer (2020). "GLU Variants Improve Transformer." https://arxiv.org/abs/2002.05202
- Jiang et al. (2024). "Mixtral of Experts." https://arxiv.org/abs/2401.04088
Transformer Architecture Complete Analysis: From Attention to Modern LLMs
Google's 2017 paper "Attention is All You Need" completely transformed the natural language processing landscape. Moving away from sequential RNN and LSTM architectures, the Transformer — built entirely on the Attention mechanism — became the foundation for GPT, BERT, T5, LLaMA, and every major modern LLM. This guide takes you from first principles through the complete architecture, with PyTorch implementations at every step.
1. The Origins of Attention
1.1 Limitations of RNN/LSTM
Before Transformers, sequence modeling relied heavily on RNNs and LSTMs. These models propagated information through hidden states processed one time step at a time, with fundamental limitations:
Long-Range Dependency Problem
RNNs struggle to learn relationships between words that are far apart. In "The cat that sat on the mat was hungry," the connection between "cat" and "was hungry" must pass through every intermediate token. As sequences grow longer, information fades or gets overwritten. LSTMs alleviate this with gating but do not eliminate it.
Sequential Processing
RNNs require step t to complete before step t+1 begins, so the parallelism of modern GPUs cannot be exploited along the sequence dimension. You can parallelize across the examples in a batch, but not across the positions within a sequence.
Vanishing/Exploding Gradients
Backpropagating through hundreds or thousands of time steps multiplies many Jacobians together, so gradients tend to either shrink toward zero or grow uncontrollably. LSTM gates mitigate this, but sequences of several hundred tokens or more remain difficult to train on.
1.2 The Intuition Behind Attention
The Attention mechanism draws inspiration from how humans read. We do not give equal attention to every word — we focus more on words relevant to what we are currently processing.
Consider: "I saw the Eiffel Tower in Paris; it was breathtaking." When processing "it," the model should attend strongly to both "Eiffel Tower" and "Paris." Every word can directly interact with every other word, regardless of distance.
Bahdanau et al. (2014) introduced the first Attention mechanism as an auxiliary component in a seq2seq model. Vaswani et al. (2017) then made the decisive move: eliminate the RNN entirely and build everything from Attention alone.
2. Scaled Dot-Product Attention
2.1 The Q, K, V Framework
The Transformer uses three vectors — Query, Key, and Value — to implement Attention. Think of it as a soft database lookup:
- Query (Q): "What am I looking for?" — the current position's representation
- Key (K): "What information do I hold?" — each position's label/identifier
- Value (V): "What is actually stored?" — each position's content
We compute similarity scores between Queries and Keys, normalize them with softmax into attention weights, then use those weights to produce a weighted sum of the Values.
2.2 The Formula
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Where:
- Q: Query matrix (seq_len x d_k)
- K: Key matrix (seq_len x d_k)
- V: Value matrix (seq_len x d_v)
- d_k: Key/Query dimension
- sqrt(d_k): scaling factor
Why scale by sqrt(d_k)?
If the components of q and k are independent with zero mean and unit variance, their dot product q · k has mean 0 and variance d_k, so its typical magnitude grows as sqrt(d_k). With d_k = 64 (the per-head size in the original paper), raw scores already have a standard deviation of about 8. Scores this large push softmax toward a one-hot distribution, where gradients are close to zero. Dividing by sqrt(d_k) brings the variance back to 1.
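The following is a quick numerical check of this argument. It assumes q and k have independent standard-normal components. That is an idealization, and learned projections need not behave this way:

```python
import math
import torch

torch.manual_seed(0)
num_samples = 10_000
results = {}
for d_k in (16, 64, 256):
    q = torch.randn(num_samples, d_k)  # components ~ N(0, 1)
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)          # one dot product q . k per sample
    raw_var = dots.var().item()
    scaled_var = (dots / math.sqrt(d_k)).var().item()
    results[d_k] = (raw_var, scaled_var)
    print(f"d_k={d_k:4d}  var(q.k)={raw_var:7.2f}  var(q.k / sqrt(d_k))={scaled_var:.2f}")
```

The raw variance tracks d_k, while the scaled variance stays near 1 for every d_k.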
2.3 Masking
Padding Mask: When sequences in a batch have different lengths, the shorter ones are padded. The scores of padded key positions are set to -inf, so softmax assigns them zero weight.
Causal Mask (Look-ahead Mask): Used in Decoders. Position i may only attend to positions 0 through i, never to future positions. The entries above the diagonal of the score matrix are therefore set to -inf.
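Here is a minimal sketch of building both masks. It uses the convention of the implementation in this section: mask == 0 means "blocked". It assumes right padding and a hypothetical pad id of 0:

```python
import torch

pad_id = 0  # hypothetical padding token id
tokens = torch.tensor([[5, 7, 9, 0, 0],
                       [3, 4, 0, 0, 0]])  # (batch, seq_len), right-padded
batch, seq_len = tokens.shape

# Padding mask: True where the key position holds a real token.
# Shape (batch, 1, 1, seq_len) broadcasts over heads and query positions.
padding_mask = (tokens != pad_id).unsqueeze(1).unsqueeze(2)

# Causal mask: True on and below the diagonal. Shape (1, 1, seq_len, seq_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)

# Decoder self-attention needs both: a key is visible only if it is real AND not in the future
combined = padding_mask & causal_mask  # (batch, 1, seq_len, seq_len)

# Blocked scores become -inf before softmax, so they receive exactly zero weight
scores = torch.randn(batch, 1, seq_len, seq_len).masked_fill(~combined, float('-inf'))
weights = torch.softmax(scores, dim=-1)
print(weights[1, 0])  # second sequence: columns 2..4 (padding) are all zero
```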
2.4 PyTorch Implementation
```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled Dot-Product Attention

    Args:
        query: (batch, heads, seq_len_q, d_k)
        key: (batch, heads, seq_len_k, d_k)
        value: (batch, heads, seq_len_k, d_v)
        mask: broadcastable to (batch, heads, seq_len_q, seq_len_k); positions where
              mask == 0 are blocked, e.g. (batch, 1, 1, seq_len_k) or (batch, 1, seq_len_q, seq_len_k)
        dropout_p: dropout probability (pass 0.0 at inference time)
    Returns:
        output: (batch, heads, seq_len_q, d_v)
        attn_weights: (batch, heads, seq_len_q, seq_len_k)
    """
    d_k = query.size(-1)
    # Q * K^T / sqrt(d_k): (batch, heads, seq_len_q, seq_len_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Blocked positions get -inf so softmax assigns them zero weight
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax over the key dimension gives the attention weights
    attn_weights = F.softmax(scores, dim=-1)
    # A row whose positions are all masked is NaN after softmax; replace it with zeros
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
    # Dropout on the attention weights (F.dropout is always active, so the caller controls dropout_p)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)
    # Weighted sum of values: (batch, heads, seq_len_q, d_v)
    output = torch.matmul(attn_weights, value)
    return output, attn_weights


# Quick test
batch_size = 2
num_heads = 8
seq_len = 10
d_k = 64
q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_k)
# Causal mask (lower triangular), shape (1, 1, seq_len, seq_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
output, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(f"Output shape: {output.shape}")  # (2, 8, 10, 64)
print(f"Weights shape: {weights.shape}")  # (2, 8, 10, 10)
```
3. Multi-Head Attention
3.1 Why Multiple Heads?
Single-head attention computes one weighting over the input, in a single representation subspace. Multi-Head Attention runs several independent attention operations in parallel, and each one can learn a different aspect of token relationships. Heads in trained models often specialize in patterns such as:
- Head 1: syntactic relations (subject-verb agreement)
- Head 2: semantic similarity (synonyms, related concepts)
- Head 3: positional relations (neighboring tokens)
- Head 4: coreference (pronouns and their referents)
Each head uses d_k = d_model / num_heads, so the total computation is similar to a single full-sized attention.
3.2 Formula
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)
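In code, the h per-head matrices W_Q_i are usually packed into one d_model x d_model matrix. The sketch below checks that projecting once and then splitting the last dimension into heads gives the same result as applying each W_Q_i separately. It uses plain tensors, the row-vector convention x @ W, and arbitrarily chosen sizes:

```python
import torch

torch.manual_seed(0)
d_model, num_heads = 64, 4
d_k = d_model // num_heads
x = torch.randn(2, 10, d_model)      # (batch, seq_len, d_model)
W_Q = torch.randn(d_model, d_model)  # all heads' query projections packed together

# One big projection, then split into heads: (batch, heads, seq_len, d_k)
big = (x @ W_Q).view(2, 10, num_heads, d_k).transpose(1, 2)

# Per-head projections W_Q_i are just column slices of the packed matrix
per_head = torch.stack(
    [x @ W_Q[:, i * d_k:(i + 1) * d_k] for i in range(num_heads)], dim=1
)
print(torch.allclose(big, per_head, atol=1e-4))
```

This equivalence is why the implementation below can use a single `nn.Linear(d_model, d_model)` per Q, K and V.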
3.3 Complete PyTorch Implementation
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # One d_model x d_model projection each for Q, K and V covers all heads at once;
        # its output is split into heads afterwards (equivalent to h separate W_Q_i, W_K_i, W_V_i)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)"""
        batch_size, _, seq_len, _ = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, seq_len, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> tuple:
        """
        Args:
            query: (batch, seq_len_q, d_model)
            key, value: (batch, seq_len_k, d_model)
            mask: (batch, 1, 1, seq_len_k) or (batch, 1, seq_len_q, seq_len_k)
        Returns:
            output: (batch, seq_len_q, d_model)
            attn_weights: (batch, num_heads, seq_len_q, seq_len_k)
        """
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        # Dropout on attention weights only while training
        attn_output, attn_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, dropout_p=self.dropout.p if self.training else 0.0
        )
        output = self.combine_heads(attn_output)
        output = self.W_o(output)
        return output, attn_weights


# Test
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(f"MHA output: {out.shape}")  # (2, 10, 512)
print(f"Attn weights: {weights.shape}")  # (2, 8, 10, 10)
```
4. Positional Encoding
4.1 Why Position Information Matters
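Self-attention is permutation-equivariant: shuffling the input tokens simply shuffles the outputs in the same way, because each score depends only on the contents of the two tokens involved. Without positional encoding, "I ate rice" and "Rice ate I" would therefore produce the same per-token representations, only in a different order. Sequence order information has to be injected explicitly.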
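A small demonstration, using a bare single-head self-attention with random projection matrices and no positional information (an illustrative setup, not a trained model):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(seq_len, d)  # token embeddings, no positional encoding
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))


def self_attention(x: torch.Tensor) -> torch.Tensor:
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    return F.softmax(q @ k.T / d ** 0.5, dim=-1) @ v


perm = torch.tensor([3, 0, 4, 1, 2])  # reorder the tokens
out = self_attention(x)
out_perm = self_attention(x[perm])
# Permuting the input only permutes the output rows: order information is lost
print(torch.allclose(out_perm, out[perm], atol=1e-5))  # True
```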
4.2 Sinusoidal Positional Encoding
The original Transformer uses fixed sinusoidal functions — no learned parameters:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Where pos is the token position and i is the dimension index.
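Advantages:
- No parameters to learn
- Defined for any position, so it can in principle be applied to sequences longer than those seen during training (in practice, extrapolation far beyond the training length is limited)
- For any fixed offset k, PE(pos+k) is a linear function of PE(pos) (a rotation of each sin/cos pair), which makes relative positions easy to express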
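The linear-shift property follows from the angle-addition formulas. The sketch below builds the block-diagonal matrix M_k and checks that PE(pos + k) = M_k · PE(pos) for many positions at once. The helper names `sinusoidal_pe` and `shift_matrix` are illustrative, and the sizes are arbitrary:

```python
import math
import torch


def sinusoidal_pe(positions: torch.Tensor, d_model: int) -> torch.Tensor:
    """Same formula as above: sin on even dims, cos on odd dims."""
    freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    angles = positions.float().unsqueeze(1) * freqs  # (num_positions, d_model/2)
    pe = torch.zeros(len(positions), d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe


def shift_matrix(k: int, d_model: int) -> torch.Tensor:
    """Block-diagonal M_k with PE(pos + k) = M_k @ PE(pos): each (sin, cos) pair
    with frequency w is rotated by the angle k * w."""
    freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    M = torch.zeros(d_model, d_model)
    for i, w in enumerate(freqs):
        c, s = math.cos(k * w.item()), math.sin(k * w.item())
        # sin(a + b) = sin(a) cos(b) + cos(a) sin(b)
        M[2 * i, 2 * i], M[2 * i, 2 * i + 1] = c, s
        # cos(a + b) = cos(a) cos(b) - sin(a) sin(b)
        M[2 * i + 1, 2 * i], M[2 * i + 1, 2 * i + 1] = -s, c
    return M


d_model, k = 16, 7
pos = torch.arange(0, 50)
pe = sinusoidal_pe(pos, d_model)
pe_shifted = sinusoidal_pe(pos + k, d_model)
M = shift_matrix(k, d_model)
# PE rows are row vectors, so M_k @ PE(pos) becomes pe @ M.T
print(torch.allclose(pe @ M.T, pe_shifted, atol=1e-4))
```

M_k depends only on the offset k, not on pos.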
```python
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        # 1 / 10000^(2i/d_model) = exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cos
        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model) for batch broadcasting
        # A buffer follows .to(device) and is saved in state_dict, but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, d_model)"""
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)
```
4.3 RoPE (Rotary Position Embedding)
RoPE has become the standard positional encoding for modern LLMs (LLaMA, GPT-NeoX, PaLM, Mistral). The key idea is to encode position by applying rotation matrices to Q and K vectors.
Core properties:
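- Rotates each Q/K vector by an angle proportional to its absolute position
- The Q·K dot product then depends only on the relative position (m - n), not on m and n separately
- Plain RoPE does not extrapolate well far beyond the training length, but it is the basis for context-extension methods such as position interpolation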
- Applied only to Q and K — not V
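```python
def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute RoPE rotations as complex unit numbers: (max_seq_len, dim/2)"""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # theta_i per pair
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, freqs)  # angle m * theta_i: (max_seq_len, dim/2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # e^{i * m * theta_i}
    return freqs_cis


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply RoPE to query and key tensors.

    xq: (batch, seq_len, num_heads, head_dim); xk: (batch, seq_len, num_kv_heads, head_dim)
    freqs_cis: (>= seq_len, head_dim/2) from precompute_freqs_cis
    """
    # Pair adjacent dims (x0, x1), (x2, x3), ... and view each pair as one complex number
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    seq_len = xq.shape[1]
    # Reshape to (1, seq_len, 1, head_dim/2) so it broadcasts over batch and heads
    freqs_cis = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
    # Complex multiplication rotates each pair by its position-dependent angle
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```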
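The relative-position property can be checked directly. The sketch below uses a hypothetical standalone helper `rope_rotate`, which rotates a single 1-D vector with the same frequencies as `precompute_freqs_cis`, in float64 for a clean comparison. It then compares scores for query/key pairs at equal and unequal offsets:

```python
import torch


def rope_rotate(x: torch.Tensor, pos: int, theta: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive pairs (x0, x1), (x2, x3), ... of a 1-D vector by pos * theta_i,
    where theta_i = theta^(-2i/dim)."""
    dim = x.numel()
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    rot = torch.polar(torch.ones_like(freqs), pos * freqs)        # e^{i * pos * theta_i}
    x_complex = torch.view_as_complex(x.double().reshape(-1, 2))  # pairs as complex numbers
    return torch.view_as_real(x_complex * rot).flatten()


torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)
# Same offset m - n = 3 at very different absolute positions
s1 = rope_rotate(q, 10) @ rope_rotate(k, 7)
s2 = rope_rotate(q, 503) @ rope_rotate(k, 500)
# A different offset (m - n = 8) generally gives a different score
s3 = rope_rotate(q, 10) @ rope_rotate(k, 2)
print(s1.item(), s2.item(), s3.item())
```

s1 and s2 agree to floating-point precision because only m - n enters the score. s3 uses a different offset.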
4.4 ALiBi (Attention with Linear Biases)
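ALiBi does not add position embeddings to the inputs. Instead, it adds a penalty proportional to the query-key distance to the attention scores:
Score(i, j) = q_i * k_j^T / sqrt(d_k) - m * |i - j|
Where m is a fixed (not learned), head-specific slope, for example the geometric sequence 1/2, 1/4, ..., 1/256 for 8 heads. No position embeddings are needed. The ALiBi paper reports that models trained on short sequences extrapolate well to longer ones at inference time.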
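A sketch of the bias tensor follows. It assumes the ALiBi paper's slope schedule for a power-of-two number of heads, m_h = 2^(-8h/num_heads) for h = 1..num_heads. The helper names are hypothetical:

```python
import torch


def alibi_slopes(num_heads: int) -> torch.Tensor:
    """Geometric slopes m_h = 2^(-8h/num_heads), h = 1..num_heads (power-of-two head count)."""
    h = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return 2.0 ** (-8.0 * h / num_heads)


def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    """Bias -m_h * |i - j| to add to the attention scores: (num_heads, seq_len, seq_len)."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    distance = (i - j).abs().float()
    return -alibi_slopes(num_heads).view(-1, 1, 1) * distance


bias = alibi_bias(num_heads=8, seq_len=5)
# Usage: scores = q @ k.transpose(-2, -1) / sqrt(d_k) + bias, then causal mask and softmax
print(alibi_slopes(8))  # 1/2, 1/4, ..., 1/256
print(bias[0])          # head 0: penalty grows by 0.5 per position of distance
```

Heads with small slopes can still attend far back, while heads with large slopes focus on nearby tokens.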
5. Transformer Encoder
5.1 Encoder Architecture
The Encoder stacks N identical layers, each with two sub-layers:
- Multi-Head Self-Attention
- Position-wise Feed-Forward Network
Each sub-layer is wrapped with a residual connection and layer normalization.
5.2 Feed-Forward Network
The FFN is a two-layer MLP applied independently to each position:
FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2
The original paper uses d_model=512, d_ff=2048. Modern LLMs use SwiGLU and d_ff ≈ 2.67 * d_model.
5.3 Pre-LN vs Post-LN
# Post-LN (original Transformer)
x = LayerNorm(x + Sublayer(x))
# Pre-LN (modern standard)
x = x + Sublayer(LayerNorm(x))
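Pre-LN trains more stably in deep stacks and can often be trained without the learning-rate warmup that Post-LN depends on, which is why most modern models use it. The implementations below follow Pre-LN, unlike the original paper, and therefore apply a final LayerNorm after the last layer.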
5.4 Full Encoder Implementation
```python
class FeedForward(nn.Module):
    """Position-wise FFN: Linear -> ReLU -> Dropout -> Linear, applied to each position independently"""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # No dropout after the last Linear: the calling layer applies residual dropout
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-LN self-attention: normalize once and use the result as Q, K and V
        h = self.norm1(x)
        attn_out, _ = self.self_attn(h, h, h, mask=mask)
        x = x + self.dropout(attn_out)
        # Pre-LN feed-forward
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        # A Pre-LN stack needs a final LayerNorm on the last layer's output
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """x: (batch, seq_len) token indices; mask: (batch, 1, 1, seq_len), 0 = padding"""
        # Embeddings are scaled by sqrt(d_model) before adding positional encodings
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
```
6. Transformer Decoder
6.1 Decoder Architecture
The Decoder has three sub-layers per layer:
- Masked Multi-Head Self-Attention (with causal mask)
- Multi-Head Cross-Attention (attends to Encoder output)
- Feed-Forward Network
6.2 Cross-Attention
In Cross-Attention:
- Query comes from the Decoder's current state
- Keys and Values come from the Encoder's output
This is how the Decoder learns which parts of the source sequence to focus on when generating each output token.
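A shape-level sketch with `F.scaled_dot_product_attention` (PyTorch 2.0+). Random tensors stand in for the projected decoder queries and encoder keys/values, and for brevity the same tensor serves as both keys and values. The point is that source and target lengths can differ:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, d_head = 2, 8, 64
src_len, tgt_len = 20, 7                                     # source and target lengths differ
encoder_output = torch.randn(batch, heads, src_len, d_head)  # stands in for projected K/V
decoder_state = torch.randn(batch, heads, tgt_len, d_head)   # stands in for projected Q

# Queries come from the decoder, keys/values from the encoder
out = F.scaled_dot_product_attention(decoder_state, encoder_output, encoder_output)
# Weight matrix: one row per target position, one column per source position
weights = torch.softmax(
    decoder_state @ encoder_output.transpose(-2, -1) / d_head ** 0.5, dim=-1
)
print(out.shape, weights.shape)  # (2, 8, 7, 64) and (2, 8, 7, 20)
```

No causal mask is needed here, because every target position may look at the entire source sentence.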
6.3 Autoregressive Generation
During inference, the Decoder generates tokens one at a time:
- Start with a [BOS] token
- Use all previously generated tokens as input
- Predict the next token
- Repeat until [EOS] is generated
```python
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # 1. Masked (causal) self-attention over the target prefix
        h = self.norm1(x)
        self_attn_out, _ = self.self_attn(h, h, h, mask=tgt_mask)
        x = x + self.dropout(self_attn_out)
        # 2. Cross-attention: queries from the decoder, keys/values from the encoder output
        cross_attn_out, _ = self.cross_attn(
            self.norm2(x), encoder_output, encoder_output, mask=src_mask
        )
        x = x + self.dropout(cross_attn_out)
        # 3. Feed-forward
        x = x + self.dropout(self.ffn(self.norm3(x)))
        return x
```
7. Full Transformer Implementation
```python
class Transformer(nn.Module):
    """Complete Encoder-Decoder Transformer (Pre-LN variant)"""

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        self.encoder_norm = nn.LayerNorm(d_model)
        self.decoder_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        self._init_weights()

    def _init_weights(self):
        # Xavier init for every weight matrix (embeddings included); vectors keep their defaults
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return self.encoder_norm(x)

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Returns decoder hidden states (batch, tgt_len, d_model), not logits"""
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.decoder_norm(x)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.output_projection(decoder_output)  # (batch, tgt_len, tgt_vocab_size)

    @torch.no_grad()
    def generate(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        max_new_tokens: int = 100,
        temperature: float = 0.0,
        pad_token_id: int = 0,  # assumption: padding id 0
    ) -> torch.Tensor:
        """Greedy decoding when temperature == 0, temperature sampling when temperature > 0"""
        self.eval()
        device = src.device
        batch_size = src.size(0)
        src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, src_len)
        encoder_output = self.encode(src, src_mask)
        tgt = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for _ in range(max_new_tokens):
            seq_len = tgt.size(1)
            tgt_mask = torch.tril(
                torch.ones(seq_len, seq_len, device=device)
            ).unsqueeze(0).unsqueeze(0)
            hidden = self.decode(tgt, encoder_output, src_mask, tgt_mask)
            # decode() returns hidden states: project the last position to vocabulary logits
            next_token_logits = self.output_projection(hidden[:, -1, :])
            if temperature > 0:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            # Sequences that already produced EOS keep emitting EOS
            next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token_id)
            tgt = torch.cat([tgt, next_token], dim=1)
            finished |= next_token.squeeze(1) == eos_token_id
            if finished.all():
                break
        return tgt


# Create and test model
model = Transformer(src_vocab_size=32000, tgt_vocab_size=32000)
total_params = sum(p.numel() for p in model.parameters())
# 93,287,680 (~93M); the two untied embeddings and the output projection alone are ~49M
print(f"Total parameters: {total_params:,}")
src = torch.randint(1, 32000, (2, 20))
tgt = torch.randint(1, 32000, (2, 15))
logits = model(src, tgt)  # masks omitted for this shape check
print(f"Logits shape: {logits.shape}")  # (2, 15, 32000)
```
8. Transformer Variants
8.1 BERT (Encoder-only)
BERT (Bidirectional Encoder Representations from Transformers) uses only the Encoder with two pretraining objectives:
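Masked Language Modeling (MLM): 15% of token positions are selected at random. Of these, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged. The model predicts the original tokens at the selected positions using context from both directions, which makes BERT well suited to understanding tasks.
Next Sentence Prediction (NSP): Predicts whether the second sentence actually follows the first. Later research found that NSP contributes little: RoBERTa removed it and matched or slightly improved downstream performance.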
BERT excels at classification, QA, and NER but cannot directly generate text.
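The MLM corruption step can be sketched as follows. `mask_tokens` is a hypothetical helper, not a library API. It follows the recipe from the BERT paper: select 15% of positions, replace 80% of those with [MASK], replace 10% with a random token, and leave 10% unchanged. It assumes integer token ids, a known [MASK] id, and PyTorch's default `ignore_index` of -100 for positions that should not contribute to the loss. Special tokens are excluded only if a `special_mask` is passed:

```python
import torch


def mask_tokens(input_ids, mask_token_id, vocab_size, mlm_prob=0.15,
                special_mask=None, generator=None):
    """Returns (corrupted_ids, labels); labels == -100 where no prediction is made."""
    labels = input_ids.clone()
    prob = torch.full(input_ids.shape, mlm_prob)
    if special_mask is not None:
        prob.masked_fill_(special_mask, 0.0)  # never select special tokens
    selected = torch.bernoulli(prob, generator=generator).bool()
    labels[~selected] = -100  # ignored by nn.CrossEntropyLoss by default

    corrupted = input_ids.clone()
    # 80% of the selected positions -> [MASK]
    to_mask = torch.bernoulli(torch.full(input_ids.shape, 0.8), generator=generator).bool() & selected
    corrupted[to_mask] = mask_token_id
    # Half of the remaining 20% (10% overall) -> random token
    to_random = (torch.bernoulli(torch.full(input_ids.shape, 0.5), generator=generator).bool()
                 & selected & ~to_mask)
    random_ids = torch.randint(vocab_size, input_ids.shape, generator=generator)
    corrupted[to_random] = random_ids[to_random]
    # The last 10% of selected positions keep their original token
    return corrupted, labels


g = torch.Generator().manual_seed(0)
ids = torch.randint(5, 1000, (64, 128), generator=g)  # toy ids; 4 plays the role of [MASK]
corrupted, labels = mask_tokens(ids, mask_token_id=4, vocab_size=1000, generator=g)
```

The labels keep the original ids only at selected positions, so the loss is computed on roughly 15% of the tokens.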
8.2 GPT (Decoder-only)
GPT (Generative Pre-trained Transformer) uses only the Decoder stack with causal language modeling: predict the next token from all previous tokens.
The architecture is a simplified Decoder with no cross-attention: each layer has only masked self-attention and an FFN. GPT-2, GPT-3, LLaMA, Mistral, and most modern LLMs follow this Decoder-only design. GPT-4's architecture has not been published in detail.
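Below is a minimal sketch of one Decoder-only (GPT-style) layer. It assumes Pre-LN, GELU in the FFN, no dropout, and PyTorch 2.0+ for `F.scaled_dot_product_attention`. The class name `DecoderOnlyBlock` is hypothetical. Compared with the encoder-decoder `DecoderLayer`, the cross-attention sub-layer is simply gone:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoderOnlyBlock(nn.Module):
    """Pre-LN block: causal self-attention + FFN, no cross-attention."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int):
        super().__init__()
        self.num_heads = num_heads
        self.norm1 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model)  # packed Q, K, V projections
        self.proj = nn.Linear(d_model, d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        q, k, v = self.qkv(self.norm1(x)).split(d, dim=-1)
        # (b, t, d) -> (b, heads, t, d_head)
        q, k, v = (z.view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
                   for z in (q, k, v))
        # is_causal=True: position i only sees positions 0..i
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.proj(attn.transpose(1, 2).reshape(b, t, d))
        return x + self.ffn(self.norm2(x))


torch.manual_seed(0)
block = DecoderOnlyBlock(d_model=64, num_heads=4, d_ff=256).eval()
x = torch.randn(2, 10, 64)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(2, 64)  # alter only the last token
y, y_changed = block(x), block(x_changed)
# Earlier positions never see the last token, so their outputs stay the same
print(torch.allclose(y[:, :-1], y_changed[:, :-1], atol=1e-5))
```

A full GPT-style model stacks such blocks between a token embedding and an output projection to the vocabulary.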
8.3 T5 and BART (Encoder-Decoder)
T5 (Text-To-Text Transfer Transformer) casts every NLP task as text-to-text. Translation, summarization, classification, and QA all take text in and produce text out, with the task indicated by a prefix such as "translate English to German:".
BART is pretrained as a denoising autoencoder with various corruption strategies including token masking, sentence shuffling, and document rotation.
8.4 Vision Transformer (ViT)
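ViT splits an image into fixed-size patches (16×16 pixels in ViT-B/16) and linearly projects each flattened patch into a token embedding. It then adds positional embeddings and feeds the sequence into a standard Transformer Encoder.
With large-scale pretraining (e.g., on JFT-300M), ViT matches or surpasses strong CNN baselines on image classification benchmarks. Trained only on mid-sized datasets such as ImageNet, it tends to fall short of comparable CNNs.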
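A minimal patch-embedding sketch follows. It assumes ViT-Base-like sizes (224×224 RGB input, 16×16 patches, d_model = 768) and the learnable [CLS] token and learned positional embeddings used in the ViT paper. The class name `PatchEmbedding` is hypothetical. A `Conv2d` whose kernel size and stride both equal the patch size is a common way to implement "flatten each patch, then apply one shared linear projection":

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, d_model=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size: one linear map per non-overlapping patch
        self.proj = nn.Conv2d(in_chans, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        x = self.proj(images)             # (B, d_model, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, d_model)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)    # prepend the [CLS] token
        return x + self.pos_embed         # learned positional embeddings


tokens = PatchEmbedding()(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # (2, 197, 768): 14 * 14 = 196 patches plus [CLS]
```

The resulting sequence can go straight into a stack of encoder layers. No token embedding is needed, because the patches already are the embeddings.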
9. Modern LLM Optimizations
9.1 RMSNorm
Modern LLMs replace LayerNorm with RMSNorm:
RMSNorm(x) = x / RMS(x) * g
RMS(x) = sqrt(mean(x^2) + epsilon)
No mean subtraction needed — it is faster with comparable performance. Used in LLaMA, Mistral, and Gemma.
```python
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the feature dimension: no mean subtraction and no bias term
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight
```
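The only difference from LayerNorm is the missing mean subtraction (and bias). The sketch below shows this with a gain-free `rms_norm` helper (hypothetical, g = 1). On zero-mean features it coincides with `nn.LayerNorm` without affine parameters. On features with a non-zero mean the two differ:

```python
import torch
import torch.nn as nn


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm with the gain g fixed to 1."""
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 16) + 3.0               # features with a non-zero mean
x_centered = x - x.mean(-1, keepdim=True)  # zero-mean features

layer_norm = nn.LayerNorm(16, eps=1e-6, elementwise_affine=False)
same_centered = torch.allclose(rms_norm(x_centered), layer_norm(x_centered), atol=1e-5)
same_raw = torch.allclose(rms_norm(x), layer_norm(x), atol=1e-5)
print(same_centered, same_raw)  # True False
```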
9.2 SwiGLU Activation
SwiGLU combines Swish and Gated Linear Units:
SwiGLU(x, W, V) = Swish(x * W) * (x * V)
Swish(x) = x * sigmoid(x) = x * sigma(x)
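Because SwiGLU has three weight matrices instead of two, the hidden size is typically scaled by 2/3. This keeps the parameter count close to that of a standard FFN with d_ff = 4 * d_model: d_ff = (2/3) * 4 * d_model ≈ 2.67 * d_model.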
```python
class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        if d_ff is None:
            d_ff = int(2 * 4 * d_model / 3)
            d_ff = ((d_ff + 63) // 64) * 64  # round up to a multiple of 64
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection (W)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # output projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # value projection (V)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.silu is Swish with beta = 1: x * sigmoid(x)
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```
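The reason for the (2/3) * 4 * d_model width can be checked with simple arithmetic. The sketch below counts weights for bias-free layers with an arbitrary example width d_model = 4096. It compares a classic two-matrix FFN with d_ff = 4 * d_model against a three-matrix SwiGLU FFN with d_ff = (2/3) * 4 * d_model:

```python
def ffn_params(d_model: int, d_ff: int) -> int:
    """Classic FFN: W_1 (d_model x d_ff) and W_2 (d_ff x d_model), no biases."""
    return 2 * d_model * d_ff


def swiglu_params(d_model: int, d_ff: int) -> int:
    """SwiGLU FFN: w1 and w3 (d_model x d_ff) plus w2 (d_ff x d_model), no biases."""
    return 3 * d_model * d_ff


d_model = 4096
dense = ffn_params(d_model, 4 * d_model)
d_ff_swiglu = 2 * 4 * d_model // 3  # (2/3) * 4 * d_model, rounded down
gated = swiglu_params(d_model, d_ff_swiglu)
print(dense, gated, gated / dense)
```

The two counts agree to within rounding of d_ff. With that width, a SwiGLU FFN can replace a classic FFN at roughly equal parameter and compute cost.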
9.3 Grouped Query Attention (GQA)
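MHA: every head has its own K and V projections. MQA: all query heads share a single K/V head. GQA: query heads are split into G groups, and each group shares one K/V head (G = 1 recovers MQA; G = num_heads recovers MHA).
The larger LLaMA 2 models (34B, 70B) and Mistral 7B use GQA to shrink the KV cache while keeping quality close to MHA.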
```python
class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_rep = num_heads // num_kv_heads  # query heads per K/V head
        self.d_head = d_model // num_heads
        self.wq = nn.Linear(d_model, num_heads * self.d_head, bias=False)
        self.wk = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wv = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wo = nn.Linear(num_heads * self.d_head, d_model, bias=False)
        self.dropout = dropout

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match the Q head count: (b, n_kv, s, d) -> (b, n_kv * num_rep, s, d)"""
        if self.num_rep == 1:
            return x
        batch, n_kv, seq_len, d_head = x.shape
        return x.unsqueeze(2).expand(
            batch, n_kv, self.num_rep, seq_len, d_head
        ).reshape(batch, n_kv * self.num_rep, seq_len, d_head)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        # Only num_kv_heads K/V heads are computed (and cached); share them across query groups
        xk = self.repeat_kv(xk)
        xv = self.repeat_kv(xv)
        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        out = torch.matmul(attn, xv)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)
```
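The KV-cache saving is easy to quantify. The sketch below assumes a hypothetical 32-layer model with 32 query heads of size 128, an FP16 cache (2 bytes per element), and a 4096-token context. It compares MHA, GQA with 8 K/V heads, and MQA:

```python
def kv_cache_bytes(num_layers, num_kv_heads, d_head, seq_len, batch_size=1, bytes_per_elem=2):
    """K and V (factor 2) cached per layer, per K/V head, per token."""
    return 2 * num_layers * num_kv_heads * d_head * seq_len * batch_size * bytes_per_elem


# Hypothetical model: 32 layers, 32 query heads of size 128, FP16 cache, 4096-token context
for name, kv_heads in [("MHA", 32), ("GQA (8 groups)", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(32, kv_heads, 128, 4096) / 2**30
    print(f"{name:15s} {gib:.4f} GiB per sequence")
```

The cache shrinks in direct proportion to num_kv_heads. GQA with 8 groups needs a quarter of the MHA cache, which leaves room for longer contexts or larger batches.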
9.4 KV Cache
During autoregressive generation, recomputing K and V for all previous tokens at every step is wasteful. KV Cache stores previously computed K and V tensors and reuses them:
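```python
class KVCacheAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.cache_k = None  # (batch, heads, tokens_so_far, d_head)
        self.cache_v = None

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """x: (batch, seq_len, d_model), the tokens at positions start_pos .. start_pos + seq_len - 1"""
        batch, seq_len, _ = x.shape
        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        if start_pos > 0 and self.cache_k is not None:
            # Append the new keys/values to those cached from earlier steps
            self.cache_k = torch.cat([self.cache_k, xk], dim=2)
            self.cache_v = torch.cat([self.cache_v, xv], dim=2)
        else:
            # start_pos == 0: a new sequence starts a fresh cache
            self.cache_k = xk
            self.cache_v = xv
        scores = torch.matmul(xq, self.cache_k.transpose(-2, -1)) / math.sqrt(self.d_head)
        if seq_len > 1:
            # Prefill with several tokens: query i (absolute position start_pos + i)
            # must not attend to keys at later positions
            q_pos = torch.arange(start_pos, start_pos + seq_len, device=x.device).unsqueeze(1)
            k_pos = torch.arange(self.cache_k.size(2), device=x.device).unsqueeze(0)
            scores = scores.masked_fill(k_pos > q_pos, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, self.cache_v)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)
```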
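The sketch below checks the central claim of KV caching: decoding one token at a time against a growing K/V cache reproduces full causal attention over the whole sequence. For brevity, random tensors stand in for the projected Q, K and V. Each step adds a single query that may see every cached key, so no mask is needed during single-token decoding. A causal mask is only required when several new tokens are processed at once, as in prompt prefill:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, d_head = 1, 2, 6, 8
q = torch.randn(batch, heads, seq_len, d_head)
k = torch.randn(batch, heads, seq_len, d_head)
v = torch.randn(batch, heads, seq_len, d_head)

# Reference: causal attention over the whole sequence in one call
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Incremental decoding: one new token per step, its K/V appended to the cache
cache_k = torch.empty(batch, heads, 0, d_head)
cache_v = torch.empty(batch, heads, 0, d_head)
steps = []
for t in range(seq_len):
    cache_k = torch.cat([cache_k, k[:, :, t:t + 1]], dim=2)
    cache_v = torch.cat([cache_v, v[:, :, t:t + 1]], dim=2)
    # The newest query may attend to every cached position, so no mask is needed
    steps.append(F.scaled_dot_product_attention(q[:, :, t:t + 1], cache_k, cache_v))
incremental = torch.cat(steps, dim=2)
print(torch.allclose(full, incremental, atol=1e-5))  # True
```

Each step computes K and V only for the new token. Without a cache, the K/V projections for all previous tokens would be recomputed at every step.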
10. Flash Attention
10.1 The Memory Problem with Standard Attention
The attention matrix is O(N^2) in memory. For N=8192, the attention matrix alone requires 8192 x 8192 x 4 bytes = 256 MiB in FP32, for each head of each batch element. Writing this matrix to GPU HBM (high-bandwidth memory) and reading it back creates a bandwidth bottleneck.
Modern GPUs can perform far more arithmetic per second than they can move bytes to and from HBM, so operations with low arithmetic intensity are memory-bandwidth-bound. Standard Attention spends much of its time on such operations (softmax, masking, and dropout over the N x N matrix), so it runs much slower than its FLOP count suggests.
10.2 IO-Aware Attention Algorithm
Flash Attention (Dao et al., 2022) computes exact Attention without ever materializing the full attention matrix in HBM.
Core Idea: Tiling
Split Q, K, V into blocks that fit in SRAM (fast on-chip memory). Process one block at a time using the Online Softmax algorithm, which accumulates the correct softmax result without seeing all scores at once.
Algorithm sketch:
- Load Q_i block to SRAM
- For each K_j, V_j block: load to SRAM, compute S_ij = Q_i * K_j^T
- Use running max/sum to update softmax incrementally
- Accumulate into output O_i
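The online-softmax bookkeeping can be sketched in plain PyTorch for a single head. This only illustrates the recurrence on small CPU tensors. It loops over K/V blocks in Python and keeps everything in ordinary memory rather than SRAM. It also omits the outer loop over query blocks and the backward pass that the real CUDA kernels implement:

```python
import torch


def tiled_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_size: int = 4):
    """Exact softmax attention computed block by block over K/V.
    q: (N_q, d), k: (N_k, d), v: (N_k, d_v). Only per-row running statistics are kept."""
    scale = q.size(-1) ** -0.5
    n_q = q.size(0)
    m = torch.full((n_q, 1), float('-inf'))  # running row max
    l = torch.zeros(n_q, 1)                  # running softmax denominator
    o = torch.zeros(n_q, v.size(-1))         # running unnormalized output
    for start in range(0, k.size(0), block_size):
        k_j = k[start:start + block_size]
        v_j = v[start:start + block_size]
        s = (q @ k_j.T) * scale              # scores for this block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)             # block weights relative to the new max
        correction = torch.exp(m - m_new)    # rescales what was accumulated so far
        l = l * correction + p.sum(dim=-1, keepdim=True)
        o = o * correction + p @ v_j
        m = m_new
    return o / l                             # normalize once at the end


torch.manual_seed(0)
q, k, v = torch.randn(5, 16), torch.randn(11, 16), torch.randn(11, 8)
reference = torch.softmax(q @ k.T / 16 ** 0.5, dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v, block_size=4), reference, atol=1e-5))  # True
```

The full 5 x 11 score matrix is never stored at once. At any moment only one 5 x 4 block of scores exists, together with O(N_q) running statistics.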
Complexity:
- Memory: O(N) instead of O(N^2) — no full attention matrix stored
- FLOPs: O(N^2) — same as standard
- Wall-clock speed: reported as 2-4x faster than optimized standard attention on A100, thanks to far fewer HBM reads/writes
10.3 Flash Attention Versions
Flash Attention 1 (2022): First IO-aware Attention formalization. Custom CUDA kernels for forward and backward pass.
Flash Attention 2 (2023): Outer loop over Q, inner loop over K/V (better parallelism). Optimized warp-level work partitioning. Roughly 2x additional speedup.
Flash Attention 3 (2024): Exploits H100's WGMMA (Warp Group Matrix Multiply-Accumulate) and TMA (Tensor Memory Accelerator) asynchronous copy. Another 1.5-2x speedup.
10.4 Usage
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# Newer PyTorch releases; replaces the deprecated torch.backends.cuda.sdp_kernel context manager
from torch.nn.attention import SDPBackend, sdpa_kernel

# F.scaled_dot_product_attention (PyTorch 2.0+) picks the fastest available backend,
# including Flash Attention on supported CUDA GPUs with fp16/bf16 inputs
if torch.cuda.is_available():
    q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    # Restrict SDPA to the Flash Attention backend (it errors if Flash cannot handle the inputs)
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        # is_causal=True applies the causal mask without materializing it
        output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(f"Output shape: {output.shape}")  # (2, 8, 1024, 64)

    # Direct flash-attn package usage
    # pip install flash-attn --no-build-isolation
    try:
        from flash_attn import flash_attn_qkvpacked_func
        # Packed layout: (batch, seq_len, 3, num_heads, head_dim)
        qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
        out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
        print(f"Flash Attention output: {out.shape}")  # (2, 1024, 8, 64)
    except ImportError:
        print("flash-attn not installed")


class FlashAttentionMHA(nn.Module):
    """MHA using PyTorch's built-in fused attention (Flash Attention when available)"""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wqkv = nn.Linear(d_model, 3 * d_model, bias=False)  # packed Q, K, V projection
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        qkv = self.wqkv(x).view(batch, seq_len, 3, self.num_heads, self.d_head)
        q, k, v = qkv.unbind(dim=2)
        q = q.transpose(1, 2)  # (batch, heads, seq_len, d_head)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.wo(out)
```
11. Mixture of Experts (MoE)
11.1 The Core Idea
Mixture of Experts scales model capacity (parameter count) while keeping the computation per token roughly fixed. The key: each token is routed to only a small subset of "expert" networks.
For an 8-expert MoE with Top-2 routing:
- Total FFN parameters: 8 expert FFNs = 8x a dense model's FFN
- Active FFN parameters per token: 2 experts, i.e. about 2x the FFN compute of a dense model with one expert-sized FFN, yet only a quarter of the MoE's total FFN parameters
This achieves a favorable tradeoff: larger models (more capacity/knowledge) without proportionally larger inference cost.
11.2 Top-k Routing
```python
class MoELayer(nn.Module):
    """Mixture of Experts FFN layer with top-k routing"""

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Router / gating network: one logit per expert for each token
        self.router = nn.Linear(d_model, num_experts, bias=False)
        # Experts: independent SwiGLU FFNs
        self.experts = nn.ModuleList([
            SwiGLUFeedForward(d_model, d_ff)
            for _ in range(num_experts)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple:
        """x: (batch, seq_len, d_model) -> (output of the same shape, auxiliary loss)"""
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)  # (num_tokens, d_model); reshape also handles non-contiguous x
        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)  # (num_tokens, num_experts)
        # Keep the top-k experts per token and renormalize their weights to sum to 1
        topk_probs, topk_indices = router_probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]  # i-th expert choice for every token
            prob = topk_probs[:, i:i + 1]    # its routing weight: (num_tokens, 1)
            for e in range(self.num_experts):
                token_mask = (expert_idx == e)
                if token_mask.any():
                    # Run expert e only on the tokens routed to it
                    expert_input = x_flat[token_mask]
                    expert_output = self.experts[e](expert_input)
                    output[token_mask] += prob[token_mask] * expert_output
        aux_loss = self._load_balancing_loss(router_probs)
        return self.dropout(output.view(batch, seq_len, d_model)), aux_loss

    def _load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
        """Switch Transformer auxiliary loss to prevent expert collapse"""
        num_tokens = router_probs.size(0)
        avg_probs = router_probs.mean(dim=0)  # P_i: mean router probability per expert
        top1_indices = router_probs.argmax(dim=-1)
        expert_counts = torch.bincount(top1_indices, minlength=self.num_experts).float()
        expert_fractions = expert_counts / num_tokens  # f_i: share of tokens whose top-1 is expert i
        return self.num_experts * (avg_probs * expert_fractions).sum()
```
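To see what the auxiliary loss rewards, here is a standalone copy of the same formula, num_experts * sum_i f_i * P_i, evaluated on two hand-made routing patterns (toy sizes chosen for illustration). Balanced routing gives a loss of about 1, its minimum. Collapse onto a single expert gives num_experts:

```python
import torch


def load_balancing_loss(router_probs: torch.Tensor, num_experts: int) -> torch.Tensor:
    """num_experts * sum_i f_i * P_i, with f_i the fraction of tokens whose top-1 expert
    is i and P_i the mean router probability assigned to expert i."""
    num_tokens = router_probs.size(0)
    avg_probs = router_probs.mean(dim=0)
    top1 = router_probs.argmax(dim=-1)
    fractions = torch.bincount(top1, minlength=num_experts).float() / num_tokens
    return num_experts * (avg_probs * fractions).sum()


num_experts, num_tokens = 4, 8
# Balanced: each expert is the top-1 choice (p = 0.7) for exactly 2 tokens
balanced = torch.full((num_tokens, num_experts), 0.1)
balanced[torch.arange(num_tokens), torch.arange(num_tokens) % num_experts] = 0.7
# Collapsed: every token sends all of its probability to expert 0
collapsed = torch.zeros(num_tokens, num_experts)
collapsed[:, 0] = 1.0

loss_balanced = load_balancing_loss(balanced, num_experts).item()
loss_collapsed = load_balancing_loss(collapsed, num_experts).item()
print(loss_balanced, loss_collapsed)
```

Only the P_i term carries gradients, because the bincount-based fractions are constants. Minimizing the loss therefore pushes the router to spread probability mass toward under-used experts.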
11.3 Mixtral and DeepSeek
Mixtral 8x7B:
- 8 expert FFNs, Top-2 routing
- Total parameters: 46.7B; Active parameters per token: 12.9B
- Otherwise follows the Mistral 7B architecture: grouped-query attention (8 KV heads), RoPE, RMSNorm, and SwiGLU expert FFNs
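The total and active figures can be reproduced from the architecture hyperparameters listed in the Mixtral paper: dim 4096, 32 layers, 32 query heads and 8 KV heads of size 128, expert hidden size 14336, vocabulary 32000, 8 experts, top-2. The sketch makes some modeling assumptions of its own: bias-free linear layers, untied input and output embeddings, and two RMSNorm gains per layer plus a final one. These assumptions are not taken from this article:

```python
def mixtral_param_counts(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, head_dim=128,
                         hidden_dim=14336, vocab_size=32000, num_experts=8, top_k=2):
    # Attention: wq and wo are dim x (n_heads * head_dim); wk and wv use only n_kv_heads (GQA)
    attn = 2 * dim * n_heads * head_dim + 2 * dim * n_kv_heads * head_dim
    expert = 3 * dim * hidden_dim   # SwiGLU expert: w1, w2, w3
    router = dim * num_experts      # gating network
    norms = 2 * dim                 # two RMSNorm gains per layer
    embed = 2 * vocab_size * dim    # input embedding + untied output head
    shared = n_layers * (attn + router + norms) + embed + dim  # + final RMSNorm
    total = shared + n_layers * num_experts * expert
    active = shared + n_layers * top_k * expert  # only top_k experts run per token
    return total, active


total, active = mixtral_param_counts()
print(f"total ≈ {total / 1e9:.1f}B, active ≈ {active / 1e9:.1f}B")  # total ≈ 46.7B, active ≈ 12.9B
```

Expert FFNs account for about 45B of the 46.7B total. This is why the active count is far below 2/8 of the total: the attention and embedding parameters are shared by every token.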
DeepSeek MoE:
- Fine-grained expert segmentation: more, smaller experts
- Shared experts: a small number of experts that process every token, alongside the routed experts
- Expert-level and device-level balance losses to prevent routing collapse and uneven load across devices
11.4 MoE Tradeoffs
Advantages:
- Larger model capacity per FLOPs
- Specialization of experts for different token types
- Efficient scaling
Disadvantages:
- All expert parameters must fit in memory even if only 2 are active
- Inter-device communication overhead when experts are sharded across GPUs (expert parallelism)
- Training instability from imbalanced routing without auxiliary losses
Summary
The Transformer has evolved from a seq2seq translation model into the universal architecture powering virtually all state-of-the-art AI systems.
Key concepts covered:
- Scaled Dot-Product Attention: soft database retrieval with Q/K/V
- Multi-Head Attention: parallel attention in multiple subspaces
- Positional Encoding: sinusoidal PE → RoPE → ALiBi
- Encoder/Decoder: foundation for BERT and GPT families
- Modern Optimizations: RMSNorm, SwiGLU, GQA, KV Cache
- Flash Attention: IO-aware memory-efficient exact attention
- MoE: sparse activation for efficient scaling
Next steps to explore:
- Speculative decoding for inference acceleration
- LoRA/QLoRA fine-tuning
- Alignment techniques (RLHF, DPO)
- Production serving with vLLM and TensorRT-LLM
References
- Vaswani et al. (2017). "Attention Is All You Need." — https://arxiv.org/abs/1706.03762
- Su et al. (2022). "RoFormer: Enhanced Transformer with Rotary Position Embedding." — https://arxiv.org/abs/2104.09864
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention." — https://arxiv.org/abs/2205.14135
- Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism." — https://arxiv.org/abs/2307.08691
- Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models." — https://arxiv.org/abs/2305.13245
- Shazeer (2020). "GLU Variants Improve Transformer." — https://arxiv.org/abs/2002.05202
- Jiang et al. (2024). "Mixtral of Experts." — https://arxiv.org/abs/2401.04088
- PyTorch scaled_dot_product_attention — https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- Flash Attention GitHub — https://github.com/Dao-AILab/flash-attention