- Authors

- Name
- Youngju Kim
- @fjvbn20031
Transformer 아키텍처 완전 분석: Attention부터 최신 LLM까지
2017년 Google이 발표한 "Attention is All You Need" 논문은 자연어처리 분야의 패러다임을 완전히 바꿨습니다. RNN, LSTM 기반의 순차적 처리 방식에서 벗어나, Attention 메커니즘만으로 구성된 Transformer 아키텍처는 이후 GPT, BERT, T5, LLaMA 등 현대 LLM의 근간이 되었습니다. 이 글에서는 Transformer를 완전히 처음부터 이해하고, PyTorch로 직접 구현할 수 있도록 수식과 코드를 함께 설명합니다.
1. Attention 메커니즘의 탄생 배경
1.1 RNN/LSTM의 한계
Transformer 이전의 시퀀스 처리는 주로 RNN(Recurrent Neural Network)과 LSTM(Long Short-Term Memory)에 의존했습니다. 이 모델들은 시간 단계별로 입력을 처리하면서 히든 스테이트를 통해 정보를 전파했는데, 다음과 같은 근본적인 문제가 있었습니다.
장기 의존성 문제 (Long-Range Dependency Problem)
RNN은 멀리 떨어진 단어들 사이의 관계를 학습하는 데 어려움을 겪습니다. 예를 들어 "The cat that sat on the mat was hungry"에서 "cat"과 "was hungry"의 관계를 학습하려면, 중간에 있는 모든 토큰을 거쳐야 합니다. 시퀀스가 길어질수록 정보가 소실되거나 희석됩니다. LSTM이 Gating 메커니즘으로 이를 완화했지만, 근본적인 해결책은 아니었습니다.
순차 처리의 비효율 (Sequential Processing)
RNN은 타임스텝 t를 계산하려면 반드시 t-1이 완료되어야 합니다. 이는 현대 GPU의 병렬 처리 능력을 전혀 활용하지 못한다는 뜻입니다. 배치 내에서는 병렬화가 가능하지만, 시퀀스 내에서는 순차적으로만 처리할 수 있습니다.
기울기 소실/폭발 (Vanishing/Exploding Gradients)
긴 시퀀스를 학습할 때 역전파 과정에서 기울기가 소실되거나 폭발하는 문제가 발생합니다. LSTM의 게이트가 이를 완화하지만, 수백~수천 토큰 길이의 시퀀스에서는 여전히 문제가 됩니다.
1.2 Attention의 직관적 이해
Attention 메커니즘은 인간의 주의 집중 방식에서 영감을 받았습니다. 문장을 읽을 때 우리는 모든 단어에 동일한 주의를 기울이지 않습니다. 현재 처리 중인 단어와 관련성이 높은 다른 단어들에 더 많은 주의를 기울입니다.
예를 들어 "나는 파리에서 에펠탑을 봤는데, 그곳은 정말 아름다웠다"를 해석할 때:
- "그곳"을 처리할 때 "파리"와 "에펠탑"에 높은 Attention 가중치가 부여됩니다
- 각 단어가 다른 모든 단어와 직접적으로 상호작용할 수 있습니다
- 거리와 무관하게 관련성 높은 단어들이 연결됩니다
2014년 Bahdanau et al.은 처음으로 seq2seq 모델에 Attention을 도입했습니다. 당시에는 RNN의 보조 메커니즘으로 사용되었지만, 2017년 Vaswani et al.은 "Attention is All You Need"에서 RNN을 완전히 제거하고 Attention만으로 모든 것을 처리하는 Transformer를 제안했습니다.
2. Scaled Dot-Product Attention
2.1 Q, K, V의 개념
Transformer의 핵심 아이디어는 Query, Key, Value 세 가지 벡터를 이용한 Attention입니다. 이 개념은 데이터베이스 검색에서 비유적으로 이해할 수 있습니다.
- Query (Q): "내가 찾고 있는 것은 무엇인가?" — 현재 처리 중인 위치의 표현
- Key (K): "나는 어떤 정보를 가지고 있는가?" — 각 위치의 레이블/식별자
- Value (V): "실제 저장된 정보는 무엇인가?" — 각 위치의 실제 내용
Query와 Key의 유사도(내적)를 계산하여 Attention 가중치를 구하고, 이 가중치로 Value를 가중 합산합니다.
2.2 수식 상세 분석
Scaled Dot-Product Attention의 수식은 다음과 같습니다:
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
여기서:
- Q: Query 행렬 (시퀀스 길이 x d_k)
- K: Key 행렬 (시퀀스 길이 x d_k)
- V: Value 행렬 (시퀀스 길이 x d_v)
- d_k: Key/Query 차원수
- sqrt(d_k): 스케일링 팩터
Scaling의 이유:
d_k가 커질수록 Q와 K의 내적값이 커지게 됩니다. 예를 들어 d_k=512이면 내적값이 매우 커져 softmax가 매우 극단적인 값(0 또는 1에 가까운)을 출력하게 됩니다. 이는 기울기 소실로 이어집니다. sqrt(d_k)로 나누면 내적의 분산을 약 1로 유지할 수 있습니다.
수학적으로, 각 원소가 평균 0, 분산 1인 분포에서 독립적으로 샘플링된 q와 k의 내적 q·k의 분산은 d_k입니다. sqrt(d_k)로 나누면 분산이 1이 됩니다.
2.3 Masking
두 가지 종류의 마스크가 사용됩니다:
패딩 마스크 (Padding Mask): 배치 내 시퀀스 길이가 다를 때, 짧은 시퀀스에 패딩 토큰을 추가합니다. 패딩 위치에는 매우 작은 값(-inf)을 더해 softmax 후 가중치가 0이 되도록 합니다.
인과적 마스크 (Causal Mask / Look-ahead Mask): 디코더에서 사용됩니다. 현재 위치 i에서는 i보다 앞의 위치들만 참조할 수 있어야 합니다 (미래 토큰을 보면 안 됨). 행렬의 상삼각 부분에 -inf를 더합니다.
2.4 PyTorch 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None,
dropout_p: float = 0.0,
) -> tuple:
"""
Scaled Dot-Product Attention 구현
Args:
query: (batch, heads, seq_len, d_k)
key: (batch, heads, seq_len, d_k)
value: (batch, heads, seq_len, d_v)
mask: (batch, 1, 1, seq_len) 또는 (batch, 1, seq_len, seq_len)
dropout_p: dropout 확률
Returns:
output: (batch, heads, seq_len, d_v)
attn_weights: (batch, heads, seq_len, seq_len)
"""
d_k = query.size(-1)
# Q * K^T / sqrt(d_k): (batch, heads, seq_len, seq_len)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 마스크 적용
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax로 Attention 가중치 계산
attn_weights = F.softmax(scores, dim=-1)
# NaN 처리 (모든 위치가 마스킹된 경우)
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
# Dropout
if dropout_p > 0.0:
attn_weights = F.dropout(attn_weights, p=dropout_p)
# 가중 합산: (batch, heads, seq_len, d_v)
output = torch.matmul(attn_weights, value)
return output, attn_weights
# 테스트
batch_size = 2
num_heads = 8
seq_len = 10
d_k = 64
d_v = 64
q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_v)
# 인과적 마스크 생성 (하삼각 행렬)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
output, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(f"Output shape: {output.shape}") # (2, 8, 10, 64)
print(f"Attention weights shape: {weights.shape}") # (2, 8, 10, 10)
3. Multi-Head Attention
3.1 여러 표현 공간의 장점
Single-Head Attention은 입력을 하나의 표현 공간에서만 처리합니다. Multi-Head Attention은 같은 입력을 여러 개의 다른 표현 공간에서 동시에 처리함으로써, 각 헤드가 서로 다른 측면의 관계를 학습할 수 있게 합니다.
예를 들어:
- Head 1: 구문론적 관계 (주어-동사 일치)
- Head 2: 의미론적 관계 (동의어, 유사 개념)
- Head 3: 위치적 관계 (인접한 단어들)
- Head 4: 공참조 관계 (대명사가 가리키는 명사)
각 헤드의 d_k = d_model / num_heads입니다. h개의 헤드가 병렬로 실행되지만, 각 헤드의 차원이 줄어들기 때문에 전체 연산량은 Single-Head와 비슷합니다.
3.2 수식
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)
여기서 W_Q_i, W_K_i, W_V_i는 각 헤드의 투영 행렬이고, W_O는 출력 투영 행렬입니다.
3.3 완전한 PyTorch 구현
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 투영 레이어 (하나의 큰 행렬로 Q, K, V를 한번에 계산)
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
# 가중치 초기화
self._init_weights()
def _init_weights(self):
for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
nn.init.xavier_uniform_(module.weight)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
"""
batch_size, seq_len, _ = x.size()
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2)
def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
(batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
"""
batch_size, _, seq_len, _ = x.size()
x = x.transpose(1, 2).contiguous()
return x.view(batch_size, seq_len, self.d_model)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None,
) -> tuple:
"""
Args:
query, key, value: (batch, seq_len, d_model)
mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
Returns:
output: (batch, seq_len, d_model)
attn_weights: (batch, num_heads, seq_len, seq_len)
"""
# Q, K, V 투영
Q = self.split_heads(self.W_q(query)) # (batch, heads, seq_len, d_k)
K = self.split_heads(self.W_k(key)) # (batch, heads, seq_len, d_k)
V = self.split_heads(self.W_v(value)) # (batch, heads, seq_len, d_k)
# Attention 계산
attn_output, attn_weights = scaled_dot_product_attention(
Q, K, V, mask=mask, dropout_p=self.dropout.p if self.training else 0.0
)
# 헤드 결합 및 출력 투영
output = self.combine_heads(attn_output) # (batch, seq_len, d_model)
output = self.W_o(output)
return output, attn_weights
# 테스트
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512) # (batch=2, seq_len=10, d_model=512)
out, weights = mha(x, x, x)
print(f"MHA output: {out.shape}") # (2, 10, 512)
print(f"Attn weights: {weights.shape}") # (2, 8, 10, 10)
4. Positional Encoding
4.1 순서 정보가 필요한 이유
Attention 메커니즘은 순서를 고려하지 않습니다. "나는 밥을 먹었다"와 "밥은 나를 먹었다"가 동일한 Attention 점수를 받을 수 있습니다. 따라서 각 토큰의 위치 정보를 별도로 입력에 더해주어야 합니다.
4.2 사인/코사인 Positional Encoding
원래 Transformer에서 사용된 방식으로, 학습 없이 수식으로 위치 임베딩을 계산합니다:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
여기서 pos는 위치, i는 차원 인덱스입니다.
이 방식의 장점:
- 학습이 필요 없음 (파라미터 없음)
- 훈련 시 본 것보다 긴 시퀀스도 처리 가능
- PE(pos+k)를 PE(pos)의 선형 변환으로 표현 가능 → 상대적 위치 관계 인코딩
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
# PE 행렬 계산: (max_seq_len, d_model)
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len).unsqueeze(1).float()
# 분모: 10000^(2i/d_model) = exp(2i * log(10000) / d_model)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term) # 짝수 차원: sin
pe[:, 1::2] = torch.cos(position * div_term) # 홀수 차원: cos
# (1, max_seq_len, d_model)로 reshape하여 배치 브로드캐스팅
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe) # 학습되지 않는 파라미터로 등록
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (batch, seq_len, d_model)
"""
seq_len = x.size(1)
x = x + self.pe[:, :seq_len, :]
return self.dropout(x)
4.3 RoPE (Rotary Position Embedding)
RoPE는 현대 LLM(LLaMA, GPT-NeoX, PaLM 등)의 표준 위치 인코딩이 되었습니다. 핵심 아이디어는 위치 정보를 Q, K 벡터에 회전 변환으로 인코딩하는 것입니다.
핵심 특징:
- 절대 위치 대신 상대 위치를 인코딩
- Q와 K의 내적이 자동으로 상대 위치에 의존하게 됨
- 긴 시퀀스로의 외삽(extrapolation)이 우수함
- Value 벡터에는 적용하지 않음
def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0):
"""RoPE 주파수 행렬 사전계산"""
# 각 차원 쌍에 대한 주파수: theta^(-2i/dim)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len)
# 외적으로 (max_seq_len, dim/2) 생성
freqs = torch.outer(t, freqs)
# 복소수로 변환 (크기=1, 각도=freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
"""RoPE 적용"""
# 마지막 차원을 복소수 쌍으로 변환
xq_r = xq.float().reshape(*xq.shape[:-1], -1, 2)
xk_r = xk.float().reshape(*xk.shape[:-1], -1, 2)
xq_complex = torch.view_as_complex(xq_r)
xk_complex = torch.view_as_complex(xk_r)
# 회전 적용 (복소수 곱셈)
freqs_cis = freqs_cis[:xq.shape[1]]
xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)
4.4 ALiBi (Attention with Linear Biases)
ALiBi는 Attention 점수에 위치 거리에 비례한 선형 편향을 더하는 방식입니다:
Attention Score = Q * K^T / sqrt(d_k) - m * |i - j|
여기서 m은 헤드별로 다른 기울기(slope) 값입니다. 별도의 위치 임베딩이 필요 없으며, 훈련 길이보다 훨씬 긴 시퀀스에서도 잘 동작합니다.
5. Transformer Encoder
5.1 Encoder 구조
Transformer Encoder는 N개의 동일한 레이어를 스택합니다. 각 레이어는 두 가지 서브레이어로 구성됩니다:
- Multi-Head Self-Attention
- Position-wise Feed-Forward Network
각 서브레이어 주변에는 잔차 연결(Residual Connection)과 레이어 정규화(Layer Normalization)가 적용됩니다.
5.2 Feed-Forward Network (FFN)
FFN은 각 위치에 독립적으로 적용되는 2층 완전연결 신경망입니다:
FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2
원래 논문에서는 d_model=512, d_ff=2048 (4배)를 사용했습니다. 현대 LLM에서는 SwiGLU 활성화 함수를 사용하고, d_ff ≈ 2.67 * d_model을 사용하기도 합니다.
5.3 Pre-LN vs Post-LN
원래 Transformer는 Post-LN을 사용했습니다: Sublayer(x) + x → LayerNorm
# Post-LN (원래 Transformer)
x = LayerNorm(x + Sublayer(x))
# Pre-LN (현대 Transformer의 표준)
x = x + Sublayer(LayerNorm(x))
Pre-LN이 훈련 안정성이 더 좋아 현대 모델들이 채택했습니다. 워밍업 스케줄 없이도 안정적으로 학습됩니다.
5.4 완전한 Encoder 구현
class FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class EncoderLayer(nn.Module):
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
# Pre-LN: Sublayer(LayerNorm(x)) + x
# Self-Attention
attn_out, _ = self.self_attn(
self.norm1(x), self.norm1(x), self.norm1(x), mask=mask
)
x = x + self.dropout(attn_out)
# Feed-Forward
x = x + self.dropout(self.ffn(self.norm2(x)))
return x
class TransformerEncoder(nn.Module):
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_layers: int = 6,
d_ff: int = 2048,
max_seq_len: int = 512,
dropout: float = 0.1,
):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
x: (batch, seq_len) 토큰 인덱스
"""
x = self.embedding(x) * math.sqrt(self.d_model) # 임베딩 스케일링
x = self.pos_encoding(x)
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
6. Transformer Decoder
6.1 Decoder의 구조
Decoder는 Encoder보다 복잡하며 세 가지 서브레이어를 가집니다:
- Masked Multi-Head Self-Attention (인과적 마스크 적용)
- Multi-Head Cross-Attention (Encoder 출력에 대한 Attention)
- Feed-Forward Network
6.2 Cross-Attention 메커니즘
Cross-Attention에서:
- Query: Decoder의 현재 상태
- Key, Value: Encoder의 출력
이를 통해 Decoder가 소스 시퀀스의 어느 부분에 주목해야 하는지 학습합니다.
6.3 Autoregressive 생성
추론 시 Decoder는 자기회귀적으로 동작합니다:
[BOS]토큰으로 시작- 이전에 생성된 모든 토큰을 입력으로 받음
- 다음 토큰 예측
[EOS]토큰이 생성될 때까지 반복
class DecoderLayer(nn.Module):
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
encoder_output: torch.Tensor,
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None,
) -> torch.Tensor:
# 1. Masked Self-Attention (인과적 마스크)
self_attn_out, _ = self.self_attn(
self.norm1(x), self.norm1(x), self.norm1(x), mask=tgt_mask
)
x = x + self.dropout(self_attn_out)
# 2. Cross-Attention (Encoder에 대한 Attention)
cross_attn_out, _ = self.cross_attn(
self.norm2(x), encoder_output, encoder_output, mask=src_mask
)
x = x + self.dropout(cross_attn_out)
# 3. Feed-Forward
x = x + self.dropout(self.ffn(self.norm3(x)))
return x
7. 전체 Transformer 구현
class Transformer(nn.Module):
"""
완전한 Encoder-Decoder Transformer 구현
"""
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
d_ff: int = 2048,
max_seq_len: int = 512,
dropout: float = 0.1,
):
super().__init__()
self.d_model = d_model
# 임베딩 레이어
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
# Encoder
self.encoder_layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_encoder_layers)
])
# Decoder
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_decoder_layers)
])
# 출력 레이어
self.encoder_norm = nn.LayerNorm(d_model)
self.decoder_norm = nn.LayerNorm(d_model)
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
# 가중치 초기화
self._init_weights()
def _init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def encode(
self,
src: torch.Tensor,
src_mask: torch.Tensor = None,
) -> torch.Tensor:
x = self.src_embedding(src) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.encoder_layers:
x = layer(x, src_mask)
return self.encoder_norm(x)
def decode(
self,
tgt: torch.Tensor,
encoder_output: torch.Tensor,
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None,
) -> torch.Tensor:
x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.decoder_layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return self.decoder_norm(x)
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None,
) -> torch.Tensor:
encoder_output = self.encode(src, src_mask)
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
logits = self.output_projection(decoder_output)
return logits
@torch.no_grad()
def generate(
self,
src: torch.Tensor,
bos_token_id: int,
eos_token_id: int,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> torch.Tensor:
"""Greedy 디코딩"""
self.eval()
device = src.device
# 소스 인코딩
src_mask = self._make_padding_mask(src)
encoder_output = self.encode(src, src_mask)
# BOS 토큰으로 시작
tgt = torch.tensor([[bos_token_id]], device=device)
for _ in range(max_new_tokens):
seq_len = tgt.size(1)
tgt_mask = torch.tril(
torch.ones(seq_len, seq_len, device=device)
).unsqueeze(0).unsqueeze(0)
logits = self.decode(tgt, encoder_output, src_mask, tgt_mask)
# 마지막 토큰의 logits만 사용
next_token_logits = logits[:, -1, :] / temperature
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
tgt = torch.cat([tgt, next_token], dim=1)
if next_token.item() == eos_token_id:
break
return tgt
def _make_padding_mask(self, x: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
# (batch, 1, 1, seq_len)
return (x != pad_token_id).unsqueeze(1).unsqueeze(2)
# 모델 생성 및 테스트
model = Transformer(
src_vocab_size=32000,
tgt_vocab_size=32000,
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ff=2048,
)
# 파라미터 수 계산
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}") # 약 44M
# Forward pass 테스트
src = torch.randint(1, 32000, (2, 20)) # (batch=2, src_len=20)
tgt = torch.randint(1, 32000, (2, 15)) # (batch=2, tgt_len=15)
logits = model(src, tgt)
print(f"Logits shape: {logits.shape}") # (2, 15, 32000)
8. Transformer 변형 아키텍처
8.1 BERT (Encoder-only)
BERT(Bidirectional Encoder Representations from Transformers)는 Encoder만 사용하는 아키텍처입니다. 두 가지 사전학습 태스크를 사용합니다:
Masked Language Modeling (MLM): 입력 토큰의 15%를 무작위로 마스킹하고, 마스킹된 토큰을 예측합니다. 양방향 컨텍스트를 활용하므로 이해(Understanding) 태스크에 강합니다.
Next Sentence Prediction (NSP): 두 문장이 연속적인지 예측합니다 (이후 연구에서 실제로 NSP의 효과는 미미하다는 것이 밝혀짐).
BERT는 분류, 질의응답, NER 등 NLU 태스크에 강하지만, 생성 태스크에는 직접 사용할 수 없습니다.
8.2 GPT (Decoder-only)
GPT(Generative Pre-trained Transformer)는 Decoder만 사용하는 자기회귀 모델입니다. 인과적 언어 모델링(CLM)으로 사전학습됩니다: 이전 토큰들로 다음 토큰을 예측합니다.
Cross-Attention이 없는 단순화된 Decoder 구조를 사용합니다 (Self-Attention만 사용). GPT-2, GPT-3, GPT-4, LLaMA, Mistral 등이 모두 이 Decoder-only 아키텍처를 따릅니다.
8.3 T5, BART (Encoder-Decoder)
T5(Text-To-Text Transfer Transformer)는 모든 NLP 태스크를 텍스트-텍스트 문제로 통일했습니다. 번역, 요약, 분류, 질의응답을 모두 동일한 형식으로 처리합니다.
BART는 노이즈 제거 오토인코더로 사전학습됩니다. 다양한 노이즈 전략(토큰 마스킹, 문장 셔플링, 문서 회전)을 사용합니다.
8.4 Vision Transformer (ViT)
이미지를 16x16 패치로 분할하고, 각 패치를 선형 투영하여 토큰으로 처리합니다. 위치 임베딩을 더한 후 Transformer Encoder에 입력합니다.
ViT는 대규모 데이터로 사전학습하면 CNN을 능가하는 성능을 보여줍니다.
9. 현대 LLM의 Transformer 최적화
9.1 RMSNorm
LayerNorm 대신 RMSNorm(Root Mean Square Layer Normalization)을 사용합니다:
RMSNorm(x) = x / RMS(x) * g
RMS(x) = sqrt(mean(x^2) + epsilon)
평균을 계산하지 않아 더 빠르고, 성능은 LayerNorm과 유사합니다. LLaMA, Mistral 등에서 사용됩니다.
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# RMS 계산
rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
return x / rms * self.weight
9.2 SwiGLU Activation Function
SwiGLU는 Swish 함수와 GLU(Gated Linear Unit)를 결합한 활성화 함수입니다:
SwiGLU(x, W, V, b, c) = Swish(xW + b) ⊙ (xV + c)
Swish(x) = x * sigmoid(x)
일반적으로 FFN 차원을 d_ff = (2/3) _ 4 _ d_model ≈ 2.67 * d_model로 조정합니다.
class SwiGLUFeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int = None):
super().__init__()
if d_ff is None:
# LLaMA 스타일: 2/3 * 4 * d_model을 64의 배수로 반올림
d_ff = int(2 * 4 * d_model / 3)
d_ff = ((d_ff + 63) // 64) * 64
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
self.w3 = nn.Linear(d_model, d_ff, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# SwiGLU: Swish(W1*x) * (W3*x) -> W2
return self.w2(F.silu(self.w1(x)) * self.w3(x))
9.3 Grouped Query Attention (GQA)
Multi-Head Attention(MHA)에서는 각 헤드가 독립적인 Q, K, V를 가집니다. Multi-Query Attention(MQA)은 모든 헤드가 K, V를 공유합니다. GQA는 MHA와 MQA의 중간: G개의 그룹이 각각 독립적인 K, V를 가지고, 그룹 내 헤드들이 공유합니다.
LLaMA 2, Mistral, Gemma 등이 GQA를 사용합니다.
class GroupedQueryAttention(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int, # Query 헤드 수
num_kv_heads: int, # Key/Value 헤드 수 (그룹 수)
dropout: float = 0.0,
):
super().__init__()
assert num_heads % num_kv_heads == 0
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.num_rep = num_heads // num_kv_heads # 각 KV가 몇 개 Q와 공유되는지
self.d_head = d_model // num_heads
self.wq = nn.Linear(d_model, num_heads * self.d_head, bias=False)
self.wk = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.wv = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.wo = nn.Linear(num_heads * self.d_head, d_model, bias=False)
self.dropout = dropout
def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
"""KV 헤드를 Query 헤드 수만큼 반복"""
# (batch, num_kv_heads, seq_len, d_head) -> (batch, num_heads, seq_len, d_head)
if self.num_rep == 1:
return x
batch, n_kv, seq_len, d_head = x.shape
return x.unsqueeze(2).expand(
batch, n_kv, self.num_rep, seq_len, d_head
).reshape(batch, n_kv * self.num_rep, seq_len, d_head)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
freqs_cis: torch.Tensor = None,
) -> torch.Tensor:
batch, seq_len, _ = x.shape
# Q, K, V 투영
xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
xk = self.wk(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
xv = self.wv(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
# RoPE 적용 (선택적)
if freqs_cis is not None:
xq, xk = apply_rotary_emb(
xq.transpose(1, 2), xk.transpose(1, 2), freqs_cis
)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
# KV를 Q 헤드 수만큼 반복
xk = self.repeat_kv(xk)
xv = self.repeat_kv(xv)
# Attention 계산
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = F.dropout(attn, p=self.dropout, training=self.training)
out = torch.matmul(attn, xv)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
return self.wo(out)
9.4 KV Cache 메커니즘
추론 시 각 생성 단계마다 이전 토큰들의 K, V를 재계산하는 것은 매우 비효율적입니다. KV Cache는 이전 단계에서 계산된 K, V를 저장해두고 재사용합니다.
class KVCacheAttention(nn.Module):
"""KV Cache를 사용하는 Attention"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.wq = nn.Linear(d_model, d_model, bias=False)
self.wk = nn.Linear(d_model, d_model, bias=False)
self.wv = nn.Linear(d_model, d_model, bias=False)
self.wo = nn.Linear(d_model, d_model, bias=False)
# KV Cache
self.cache_k = None
self.cache_v = None
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
batch, seq_len, _ = x.shape
xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
xk = self.wk(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
xv = self.wv(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
# 생성 단계에서 KV Cache 사용
if start_pos > 0 and self.cache_k is not None:
self.cache_k = torch.cat([self.cache_k, xk], dim=2)
self.cache_v = torch.cat([self.cache_v, xv], dim=2)
xk = self.cache_k
xv = self.cache_v
else:
self.cache_k = xk
self.cache_v = xv
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, xv)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
return self.wo(out)
10. Flash Attention
10.1 기존 Attention의 메모리 문제
표준 Attention의 중간 계산 결과인 Attention 행렬은 시퀀스 길이 N에 대해 O(N^2)의 메모리를 필요로 합니다. N=8192이면 Attention 행렬만 8192×8192×4bytes ≈ 256MB입니다 (FP32 기준). GPU HBM에 이 행렬을 쓰고 읽는 과정이 병목이 됩니다.
현대 GPU의 FLOP 대비 메모리 대역폭이 부족하기 때문에 (compute-bound가 아니라 memory-bound), Attention은 실제 FLOP에 비해 훨씬 느립니다.
10.2 IO-aware Attention 알고리즘
Flash Attention(Dao et al., 2022)은 Attention 행렬을 명시적으로 HBM에 쓰지 않고도 정확한 Attention을 계산하는 알고리즘입니다.
핵심 아이디어: Tiling
입력 Q, K, V를 블록(타일)으로 나누어 SRAM(온칩 고속 메모리)에 올린 후, 블록 단위로 Attention을 계산합니다. Online Softmax 알고리즘을 사용하여 전체 행렬 없이도 정확한 softmax를 계산합니다.
수식:
각 블록 i에서:
- Q_i * K_j^T 를 계산하여 스코어 S_ij를 얻음
- 이전 블록들의 최대값과 합을 이용한 온라인 방식으로 softmax 계산
- V_j와 가중합산
복잡도:
- 메모리: O(N) — Attention 행렬 저장 불필요
- 연산: O(N^2)로 동일하지만 메모리 접근 횟수 감소
- 속도: A100에서 2-4배 빠름
10.3 Flash Attention 발전
Flash Attention 1 (2022):
- 처음으로 IO-aware Attention을 공식화
- Forward/Backward pass 모두 구현
- CUDA 커스텀 커널 필요
Flash Attention 2 (2023):
- Q를 외부 루프, K/V를 내부 루프로 변경 (병렬성 향상)
- 워프(warp) 간 작업 분배 최적화
- 비인과적 Attention에서 시퀀스 길이가 2의 배수가 아닐 때 처리 개선
- 약 2배 추가 속도 향상
Flash Attention 3 (2024):
- H100 GPU의 WGMMA (Warp Group Matrix Multiply-Accumulate) 활용
- TMA (Tensor Memory Accelerator) 비동기 복사 활용
- 약 1.5-2배 추가 속도 향상
10.4 사용법
import torch
import torch.nn.functional as F
# PyTorch 2.0+ 에서 내장 Flash Attention
# CUDA 환경에서 자동으로 Flash Attention을 사용
q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
# 인과적 마스크를 is_causal=True로 간단하게 지정
with torch.backends.cuda.sdp_kernel(
enable_flash=True, # Flash Attention 활성화
enable_math=False, # 표준 Attention 비활성화
enable_mem_efficient=False
):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(f"Output shape: {output.shape}") # (2, 8, 1024, 64)
# flash-attn 패키지 직접 사용
# pip install flash-attn --no-build-isolation
try:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# QKV packed 형식
qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
print(f"Flash Attention output: {out.shape}") # (2, 1024, 8, 64)
except ImportError:
print("flash-attn 패키지가 설치되지 않았습니다.")
class FlashAttentionMHA(nn.Module):
"""Flash Attention을 사용하는 MHA"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
super().__init__()
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.wo = nn.Linear(d_model, d_model, bias=False)
self.dropout = dropout
def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
batch, seq_len, d_model = x.shape
# Q, K, V 한번에 계산
qkv = self.wqkv(x)
qkv = qkv.view(batch, seq_len, 3, self.num_heads, self.d_head)
q, k, v = qkv.unbind(dim=2)
# (batch, heads, seq_len, d_head)로 변환
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# PyTorch 내장 scaled_dot_product_attention (Flash Attention 자동 선택)
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
# (batch, seq_len, d_model)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
return self.wo(out)
11. Mixture of Experts (MoE)
11.1 MoE의 아이디어
Mixture of Experts는 모델 용량(파라미터 수)을 늘리면서도 연산량은 서브리니어하게 유지하는 방법입니다. 핵심 아이디어는 각 토큰을 "전문가(Expert)" 네트워크들 중 일부에만 라우팅하는 것입니다.
예를 들어, 8개의 FFN(전문가)을 가진 MoE에서 각 토큰이 Top-2 전문가에게만 라우팅된다면:
- 총 파라미터: 8개 FFN × 크기 = Dense 모델의 8배
- 실제 연산: 2개 FFN만 사용 = Dense 모델과 동일
11.2 Top-k Routing
class MoELayer(nn.Module):
"""Mixture of Experts FFN 레이어"""
def __init__(
self,
d_model: int,
d_ff: int,
num_experts: int = 8,
top_k: int = 2,
dropout: float = 0.0,
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# 라우터 (게이팅 네트워크)
self.router = nn.Linear(d_model, num_experts, bias=False)
# 전문가들
self.experts = nn.ModuleList([
SwiGLUFeedForward(d_model, d_ff)
for _ in range(num_experts)
])
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> tuple:
batch, seq_len, d_model = x.shape
x_flat = x.view(-1, d_model) # (batch*seq_len, d_model)
# 라우팅 스코어 계산
router_logits = self.router(x_flat) # (batch*seq_len, num_experts)
router_probs = F.softmax(router_logits, dim=-1)
# Top-k 전문가 선택
topk_probs, topk_indices = router_probs.topk(self.top_k, dim=-1)
# 선택된 전문가들의 확률을 정규화
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
# 각 토큰을 전문가들에게 디스패치
output = torch.zeros_like(x_flat)
for i in range(self.top_k):
expert_idx = topk_indices[:, i] # 각 토큰이 선택한 i번째 전문가
prob = topk_probs[:, i:i+1] # 가중치
# 각 전문가로 토큰 그룹화
for e in range(self.num_experts):
token_mask = (expert_idx == e)
if token_mask.any():
expert_input = x_flat[token_mask]
expert_output = self.experts[e](expert_input)
output[token_mask] += prob[token_mask] * expert_output
# 보조 로드 밸런싱 손실 (모든 전문가를 균등하게 사용하도록)
# 이를 훈련 손실에 더해 전문가 활용 균등화
aux_loss = self._load_balancing_loss(router_probs)
output = output.view(batch, seq_len, d_model)
return self.dropout(output), aux_loss
def _load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
"""Switch Transformer 스타일 보조 손실"""
# 각 전문가가 선택되는 비율이 균등해야 함
num_tokens = router_probs.size(0)
# 각 전문가의 평균 라우팅 확률
avg_probs = router_probs.mean(dim=0) # (num_experts,)
# 각 전문가가 선택된 토큰 비율
top1_indices = router_probs.argmax(dim=-1)
expert_counts = torch.bincount(top1_indices, minlength=self.num_experts).float()
expert_fractions = expert_counts / num_tokens
# 보조 손실: 전문가 비율과 평균 확률의 내적
aux_loss = self.num_experts * (avg_probs * expert_fractions).sum()
return aux_loss
11.3 Mixtral, DeepSeek 아키텍처
Mixtral 8x7B:
- 8개의 전문가 FFN, Top-2 라우팅
- 실제 파라미터: 46.7B, 활성 파라미터: 12.9B (추론 시)
- 각 토큰이 2개 전문가만 사용
- 표준 MHA + RoPE + SwiGLU + 그룹화된 전문가
DeepSeek MoE:
- Fine-grained Expert Segmentation: 전문가를 더 작게 나눔
- Shared Expert: 모든 토큰이 공유하는 기본 전문가 + 라우팅 전문가
- 전문가 붕괴(Expert Collapse) 방지 전략
12. 마치며
Transformer 아키텍처는 단순한 Attention 메커니즘에서 시작하여, 오늘날 수천억 파라미터의 LLM을 구동하는 핵심 구조로 발전했습니다.
이 글에서 다룬 핵심 개념들을 정리하면:
- Scaled Dot-Product Attention: Q, K, V의 내적 기반 정보 검색
- Multi-Head Attention: 여러 표현 공간에서의 병렬 Attention
- Positional Encoding: 사인파 PE → RoPE → ALiBi
- Encoder/Decoder: BERT/GPT 패밀리의 기초
- 현대 최적화: RMSNorm, SwiGLU, GQA, KV Cache
- Flash Attention: IO-aware 메모리 효율적 Attention
- MoE: 희소 활성화로 효율적인 스케일링
다음으로 공부할 것들:
- Speculative Decoding (추론 가속)
- LoRA/QLoRA 파인튜닝
- LLM 정렬 (RLHF, DPO)
- vLLM, TensorRT-LLM 서빙 최적화
참고 자료
- Vaswani et al. (2017). "Attention Is All You Need." — https://arxiv.org/abs/1706.03762
- Su et al. (2022). "RoFormer: Enhanced Transformer with Rotary Position Embedding." — https://arxiv.org/abs/2104.09864
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention." — https://arxiv.org/abs/2205.14135
- Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism." — https://arxiv.org/abs/2307.08691
- Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models." — https://arxiv.org/abs/2305.13245
- Shazeer (2020). "GLU Variants Improve Transformer." — https://arxiv.org/abs/2002.05202
- Jiang et al. (2024). "Mixtral of Experts." — https://arxiv.org/abs/2401.04088