- Published on
KAN(Kolmogorov-Arnold Networks) 논문 분석: MLP를 대체하는 학습 가능한 활성화 함수 아키텍처
- Authors
- Name
- 들어가며
- Kolmogorov-Arnold 표현 정리
- KAN 아키텍처 상세 분석
- B-스플라인 활성화 함수
- KAN vs MLP 비교 분석
- 구현 코드
- 학습과 시각화
- 실전 적용 사례
- 한계와 향후 방향
- 트러블슈팅과 최적화 팁
- 참고자료

들어가며
딥러닝의 기본 구성 단위인 **Multi-Layer Perceptron(MLP)**은 1958년 Frank Rosenblatt의 퍼셉트론 이래 60년 넘게 신경망 아키텍처의 핵심으로 자리 잡아 왔다. MLP의 구조는 명확하다. 각 뉴런(노드)에 고정된 활성화 함수(ReLU, SiLU 등)를 적용하고, 엣지(가중치)는 학습 가능한 선형 변환을 수행한다. 이 단순하면서도 강력한 구조는 Universal Approximation Theorem에 의해 이론적 정당성을 확보했으며, GPU 연산에 최적화된 행렬 곱셈 덕분에 실용적으로도 지배적인 위치를 유지해 왔다.
그런데 2024년 4월, MIT의 Ziming Liu, Max Tegmark 등이 발표한 "KAN: Kolmogorov-Arnold Networks" 논문은 이 60년 된 패러다임에 근본적인 질문을 던졌다. 활성화 함수가 반드시 노드에 고정되어 있어야 하는가? 엣지(가중치)가 단순한 스칼라 곱이어야 하는가? KAN은 이 두 가정을 모두 뒤집는다. 활성화 함수를 노드가 아닌 엣지에 배치하고, 각 활성화 함수를 B-스플라인으로 매개변수화하여 학습 가능하게 만든 것이다. 이 아이디어의 수학적 근거는 1957년 Andrey Kolmogorov와 그의 제자 Vladimir Arnold가 증명한 Kolmogorov-Arnold 표현 정리에서 비롯된다.
KAN은 ICLR 2025에 채택되었으며, 함수 피팅, 편미분방정식(PDE) 풀이, 기호 회귀(symbolic regression) 등 과학적 탐구 태스크에서 MLP보다 적은 파라미터로 높은 정확도를 달성했다. 더 나아가, 학습된 활성화 함수를 시각화하고 기호적으로 해석할 수 있어 해석가능성(interpretability) 측면에서도 MLP를 크게 앞선다.
이 글에서는 KAN 논문의 수학적 기반부터 아키텍처 설계, B-스플라인 활성화 함수의 동작 원리, MLP와의 체계적 비교, PyTorch 구현 코드, 시각화 방법, 실전 적용 사례, 그리고 현재의 한계까지 포괄적으로 분석한다.
Kolmogorov-Arnold 표현 정리
정리의 수학적 정의
KAN의 이론적 기반은 1957년에 증명된 **Kolmogorov-Arnold 표현 정리(Kolmogorov-Arnold Representation Theorem)**이다. 이 정리는 힐베르트의 13번째 문제(Hilbert's 13th Problem)에 대한 답으로, 다변수 연속 함수를 단변수 연속 함수의 합성으로 표현할 수 있음을 증명했다.
정리를 수학적으로 기술하면 다음과 같다. 유계 영역 위의 임의의 연속 함수 에 대해 다음이 성립한다:
여기서 은 **내부 함수(inner function)**이고, 은 **외부 함수(outer function)**이다. 핵심은 내부 함수와 외부 함수가 모두 단변수(univariate) 연속 함수라는 점이다. 즉, 아무리 복잡한 다변수 함수라 하더라도 단변수 함수의 합성과 덧셈만으로 정확히 표현할 수 있다.
MLP의 Universal Approximation Theorem과의 차이
MLP의 이론적 기반인 Universal Approximation Theorem은 **근사(approximation)**를 보장하는 반면, Kolmogorov-Arnold 표현 정리는 **정확한 표현(exact representation)**을 보장한다. 두 정리를 비교하면 다음과 같다:
| 항목 | Universal Approximation (MLP) | Kolmogorov-Arnold (KAN) |
|---|---|---|
| 보장 수준 | 임의 정밀도의 근사 | 정확한 표현 |
| 구조 | 고정 활성화 + 학습 가중치 | 학습 가능한 단변수 함수 합성 |
| 뉴런 수 | 무한히 많을 수 있음 | 개의 내부 합으로 충분 |
| 활성화 함수 위치 | 노드(뉴런) | 엣지(가중치) |
| 해석가능성 | 블랙박스 | 단변수 함수 시각화 가능 |
신경망 관점에서의 재해석
Kolmogorov-Arnold 표현 정리를 신경망의 관점에서 재해석하면, 이 정리는 **2개의 은닉층(hidden layer)**을 가진 네트워크로 볼 수 있다. 첫 번째 은닉층에서는 개의 단변수 함수 가 각 입력 차원에 적용되고, 두 번째 은닉층에서는 개의 외부 함수 가 내부 합에 적용된다.
하지만 원래 정리의 와 는 매끄럽지 않거나 심지어 프랙탈적 구조를 가질 수 있어, 실제 학습에 직접 사용하기 어렵다. KAN 논문의 핵심 기여는 이 정리를 임의 깊이의 네트워크로 일반화하고, 각 단변수 함수를 B-스플라인으로 매개변수화함으로써 실용적으로 만든 것이다.
KAN 아키텍처 상세 분석
MLP와 KAN의 구조적 차이
MLP와 KAN의 근본적인 차이는 활성화 함수와 가중치의 역할이 뒤바뀐다는 점이다.
MLP: 엣지에서 선형 변환() 수행, 노드에서 비선형 활성화 적용
KAN: 엣지에서 학습 가능한 비선형 함수 적용, 노드에서는 단순 합산만 수행
여기서 는 번째 입력에서 번째 출력으로 가는 엣지에 배치된 학습 가능한 단변수 함수이다.
KAN 레이어 아키텍처 다이어그램
다음은 KAN 레이어의 구조를 보여주는 다이어그램이다. 핵심은 각 엣지가 단순한 스칼라 가중치가 아니라 학습 가능한 B-스플라인 함수라는 점이다.
KAN Layer: [2, 3] (2 inputs -> 3 outputs)
Input Layer Learnable Activations (Edges) Output Layer
(B-spline functions)
+------ phi_1,1(x) --------+
| |
x_1 -------+------ phi_2,1(x) --------+-----> y_1 = phi_1,1(x_1) + phi_1,2(x_2)
| |
+------ phi_3,1(x) --------+
| |
| +-----> y_2 = phi_2,1(x_1) + phi_2,2(x_2)
| |
x_2 -------+------ phi_1,2(x) --------+
| |
+------ phi_2,2(x) --------+-----> y_3 = phi_3,1(x_1) + phi_3,2(x_2)
| |
+------ phi_3,2(x) --------+
[Nodes: SUM only] [Edges: Learnable] [Nodes: SUM only]
Total learnable functions: n_in x n_out = 2 x 3 = 6
Each phi_{j,i} is a B-spline with (G + k) parameters
where G = grid intervals, k = spline order
다층 KAN 네트워크
Kolmogorov-Arnold 표현 정리의 원래 구조는 2개의 레이어로 제한되지만, KAN 논문은 이를 임의의 깊이로 확장한다. 깊이가 인 KAN 네트워크에서 각 레이어 은 개의 입력과 개의 출력을 가지며, 전체 네트워크는 다음과 같이 표현된다:
여기서 각 은 KAN 레이어이고, 레이어 에는 총 개의 학습 가능한 단변수 함수가 존재한다. 전체 네트워크의 학습 가능한 함수 수는 다음과 같다:
KAN의 파라미터 수 계산
각 단변수 함수가 개의 그리드 구간과 차 B-스플라인으로 매개변수화되면, 함수 하나당 개의 계수가 필요하다. 따라서 전체 네트워크의 파라미터 수는 다음과 같다:
예를 들어, 폭(width)이 이고 , 인 KAN의 경우:
- 레이어 0: 파라미터
- 레이어 1: 파라미터
- 총 파라미터: 개
같은 구조의 MLP 은 개의 파라미터를 가진다. KAN이 더 많은 파라미터를 가지지만, 논문에서는 파라미터 대비 정확도 측면에서 KAN이 MLP를 압도한다고 주장한다.
B-스플라인 활성화 함수
B-스플라인의 정의
B-스플라인(Basis Spline)은 구간별 다항식(piecewise polynomial)의 선형 결합으로 정의되는 매끈한 함수이다. 차 B-스플라인 기저 함수는 De Boor의 재귀 공식으로 정의된다:
여기서 는 매듭(knot) 벡터의 원소이다. 개의 그리드 구간과 차 스플라인을 사용하면, 개의 B-스플라인 기저 함수가 생성되며, 스플라인 함수는 이들의 선형 결합으로 표현된다:
여기서 는 학습 가능한 계수이다.
KAN에서의 활성화 함수 구조
KAN의 각 엣지에 배치된 활성화 함수 는 두 부분으로 구성된다:
여기서 는 **잔차 함수(residual function)**로 기본값은 이며, 와 는 각각의 스케일링 계수이다. 잔차 함수는 학습 초기에 안정성을 제공하고, 스플라인 부분이 점진적으로 복잡한 패턴을 학습하도록 돕는다.
B-스플라인 기저 함수 구현
다음은 PyTorch로 B-스플라인 기저 함수를 구현하는 코드이다:
import torch
import torch.nn as nn
import numpy as np
def compute_bspline_basis(x: torch.Tensor, grid: torch.Tensor, k: int) -> torch.Tensor:
"""B-스플라인 기저 함수를 De Boor 재귀 공식으로 계산한다.
Args:
x: 입력 텐서 [batch_size, in_features]
grid: 매듭(knot) 벡터 [in_features, grid_size + 2k + 1]
k: 스플라인 차수 (보통 3, 즉 3차 스플라인)
Returns:
기저 함수 값 [batch_size, in_features, grid_size + k]
"""
# x를 grid와 비교할 수 있도록 차원 확장
# x: [batch, in_features] -> [batch, in_features, 1]
x = x.unsqueeze(-1)
# 0차 기저 함수: 구간 내에 있으면 1, 아니면 0
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).float()
# De Boor 재귀: k번 반복하여 k차 기저 함수 계산
for degree in range(1, k + 1):
# 왼쪽 항: (x - t_i) / (t_{i+degree} - t_i)
left_num = x - grid[:, :-(degree + 1)]
left_den = grid[:, degree:-1] - grid[:, :-(degree + 1)]
left = left_num / left_den.clamp(min=1e-8)
# 오른쪽 항: (t_{i+degree+1} - x) / (t_{i+degree+1} - t_{i+1})
right_num = grid[:, (degree + 1):] - x
right_den = grid[:, (degree + 1):] - grid[:, 1:(-degree)]
right = right_num / right_den.clamp(min=1e-8)
# 재귀 결합
bases = left * bases[:, :, :-1] + right * bases[:, :, 1:]
return bases # [batch, in_features, grid_size + k]
# 사용 예시
batch_size, in_features = 32, 4
grid_size, k = 5, 3 # 5개 그리드 구간, 3차 스플라인
# 균일 그리드 생성: [-1, 1] 범위에 k개씩 양쪽 확장
grid_points = torch.linspace(-1, 1, grid_size + 1)
extended_grid = torch.cat([
grid_points[0:1] - torch.arange(k, 0, -1) * (grid_points[1] - grid_points[0]),
grid_points,
grid_points[-1:] + torch.arange(1, k + 1) * (grid_points[1] - grid_points[0])
])
# in_features 차원으로 복제
grid = extended_grid.unsqueeze(0).repeat(in_features, 1)
x = torch.randn(batch_size, in_features)
basis_values = compute_bspline_basis(x, grid, k)
print(f"기저 함수 shape: {basis_values.shape}")
# 출력: 기저 함수 shape: torch.Size([32, 4, 8]) # G + k = 5 + 3 = 8
그리드 업데이트 메커니즘
KAN의 중요한 특징 중 하나는 학습 도중 그리드를 동적으로 업데이트할 수 있다는 점이다. 초기에는 균일한 그리드를 사용하지만, 학습 데이터의 분포에 맞게 그리드를 재배치하면 근사 정확도를 크게 높일 수 있다. grid_eps 하이퍼파라미터가 이를 제어한다:
grid_eps = 1.0: 완전히 균일한 그리드grid_eps = 0.0: 데이터 샘플의 분위수(quantile)에 따른 그리드0 < grid_eps < 1: 두 극단의 보간
그리드 구간 수 를 점진적으로 늘리면(grid extension), 해상도를 높여 더 복잡한 함수를 표현할 수 있다. 이는 MLP에서 뉴런 수를 늘리는 것과 유사하지만, 기존 학습 결과를 보존하면서 해상도만 높일 수 있다는 장점이 있다.
KAN vs MLP 비교 분석
체계적 비교표
| 비교 항목 | MLP | KAN |
|---|---|---|
| 활성화 함수 위치 | 노드 (고정: ReLU, SiLU 등) | 엣지 (학습 가능: B-스플라인) |
| 가중치 역할 | 학습 가능한 선형 변환 | 없음 (활성화 함수가 대체) |
| 파라미터 효율성 | 낮음 (폭을 넓혀야 정확도 향상) | 높음 (그리드 해상도로 정확도 향상) |
| 스케일링 법칙 | , 가 작음 | , 가 더 큼 |
| 해석가능성 | 블랙박스 | 활성화 함수 시각화 가능 |
| 기호 회귀 | 불가 | 학습 후 기호 함수 추출 가능 |
| 학습 속도 | 빠름 (GPU 행렬 곱 최적화) | 느림 (스플라인 계산 오버헤드) |
| GPU 병렬화 | 우수 (배치 행렬 곱) | 제한적 (각 엣지마다 다른 함수) |
| 고차원 입력 | 우수 (이미지, NLP 등) | 제한적 (차원의 저주 영향) |
| 주요 적용 분야 | 범용 딥러닝 | 과학적 발견, 기호 회귀 |
정확도 스케일링 비교
KAN 논문에서 가장 인상적인 결과 중 하나는 **신경 스케일링 법칙(Neural Scaling Law)**에서의 차이다. 테스트 RMSE를 파라미터 수 의 함수로 표현할 때:
MLP의 경우 가 Universal Approximation Theorem에 의해 상한이 정해지지만, KAN은 Kolmogorov-Arnold 표현 정리에 기반하여 더 빠른 스케일링(더 큰 )을 보여준다. 논문의 실험에서 5개의 특수 함수에 대해, KAN [2,5,1]은 파라미터 80개로 MLP [2,100,1]의 파라미터 201개보다 2-3배 적으면서도 동등하거나 더 나은 정확도를 달성했다.
학습 속도와 현실적 고려사항
KAN의 학습 속도는 MLP에 비해 상당히 느리다. 주요 원인은 다음과 같다:
- 스플라인 기저 함수 계산: 각 엣지마다 개의 기저 함수를 재귀적으로 계산해야 한다.
- GPU 병렬화 한계: MLP는 가중치 행렬의 행렬 곱으로 한 번에 계산되지만, KAN은 엣지마다 서로 다른 함수를 적용하므로 효율적인 배치 행렬 곱을 활용하기 어렵다.
- 그리드 업데이트 오버헤드: 학습 중 그리드 재배치가 추가적인 연산 비용을 발생시킨다.
"KAN or MLP: A Fairer Comparison" (Yu et al., 2024) 논문에 따르면, MLP가 대부분의 표준 머신러닝, 컴퓨터 비전, NLP, 오디오 처리 태스크에서 더 높은 평균 정확도를 보였다. KAN이 우위를 보인 것은 기호 수식 표현(symbolic formula representation)과 같은 특정 과학적 태스크에 한정되었다.
구현 코드
KAN 레이어 PyTorch 구현
다음은 KAN 레이어를 처음부터(from scratch) PyTorch로 구현한 코드이다:
import torch
import torch.nn as nn
import torch.nn.functional as F
class KANLayer(nn.Module):
"""Kolmogorov-Arnold Network 레이어.
각 엣지에 학습 가능한 B-스플라인 활성화 함수를 배치한다.
노드에서는 단순 합산만 수행한다.
"""
def __init__(
self,
in_features: int,
out_features: int,
grid_size: int = 5,
spline_order: int = 3,
scale_base: float = 1.0,
scale_spline: float = 1.0,
grid_range: tuple = (-1.0, 1.0),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
self.scale_base = scale_base
self.scale_spline = scale_spline
# 균일 그리드 생성 및 확장 (양쪽으로 spline_order만큼)
h = (grid_range[1] - grid_range[0]) / grid_size
grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0]
# grid shape: [grid_size + 2 * spline_order + 1]
# 각 (in, out) 쌍마다 별도 그리드를 가질 수 있도록 확장
self.register_buffer(
"grid",
grid.unsqueeze(0).unsqueeze(0).expand(out_features, in_features, -1)
)
# 스플라인 계수: 각 엣지마다 (grid_size + spline_order)개
self.spline_weight = nn.Parameter(
torch.randn(out_features, in_features, grid_size + spline_order)
* 0.1
)
# 잔차(base) 가중치
self.base_weight = nn.Parameter(
torch.randn(out_features, in_features) * 0.1
)
def b_splines(self, x: torch.Tensor) -> torch.Tensor:
"""B-스플라인 기저 함수 계산.
Args:
x: [batch_size, in_features]
Returns:
기저 함수 값 [batch_size, in_features, grid_size + spline_order]
"""
# x를 [batch, 1, in_features, 1]로 확장하여 grid와 비교
x = x.unsqueeze(1).unsqueeze(-1) # [B, 1, in, 1]
# grid: [out, in, G+2k+1] -> 방송을 위해
grid = self.grid # [out, in, G+2k+1]
# 0차 기저: 구간 내에 있으면 1
# grid[:, :, :-1]과 grid[:, :, 1:]를 사용
# 첫 번째 out 차원의 그리드만 사용 (모두 동일하므로)
g = grid[0] # [in, G+2k+1]
bases = ((x[:, 0] >= g[:, :-1]) & (x[:, 0] < g[:, 1:])).float()
# bases: [B, in, G+2k]
for degree in range(1, self.spline_order + 1):
left_num = x[:, 0] - g[:, :-(degree + 1)]
left_den = g[:, degree:-1] - g[:, :-(degree + 1)]
left = left_num / left_den.clamp(min=1e-8)
right_num = g[:, (degree + 1):] - x[:, 0]
right_den = g[:, (degree + 1):] - g[:, 1:(-degree)]
right = right_num / right_den.clamp(min=1e-8)
bases = left * bases[:, :, :-1] + right * bases[:, :, 1:]
return bases # [B, in, G+k]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""순전파: 잔차 함수 + 스플라인 함수.
Args:
x: [batch_size, in_features]
Returns:
출력 [batch_size, out_features]
"""
batch_size = x.size(0)
# 1) 잔차 부분: SiLU 활성화 후 선형 결합
base_output = F.silu(x) # [B, in]
base_output = torch.einsum("bi,oi->bo", base_output, self.base_weight)
# base_output: [B, out]
# 2) 스플라인 부분: B-스플라인 기저 * 계수
spline_basis = self.b_splines(x) # [B, in, G+k]
spline_output = torch.einsum(
"big,oig->bo", spline_basis, self.spline_weight
)
# spline_output: [B, out]
# 3) 결합
return self.scale_base * base_output + self.scale_spline * spline_output
완전한 KAN 네트워크 구현
class KAN(nn.Module):
"""다층 Kolmogorov-Arnold Network.
Args:
width: 각 레이어의 뉴런 수 리스트. 예: [2, 5, 1]
grid_size: B-스플라인 그리드 구간 수
spline_order: B-스플라인 차수 (기본값: 3, 3차 스플라인)
"""
def __init__(
self,
width: list,
grid_size: int = 5,
spline_order: int = 3,
):
super().__init__()
self.width = width
self.layers = nn.ModuleList()
for i in range(len(width) - 1):
self.layers.append(
KANLayer(
in_features=width[i],
out_features=width[i + 1],
grid_size=grid_size,
spline_order=spline_order,
)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = layer(x)
return x
def parameter_count(self) -> dict:
"""레이어별 파라미터 수를 반환한다."""
info = {}
total = 0
for i, layer in enumerate(self.layers):
layer_params = sum(p.numel() for p in layer.parameters())
info[f"layer_{i}"] = layer_params
total += layer_params
info["total"] = total
return info
# 사용 예시
model = KAN(width=[2, 5, 1], grid_size=5, spline_order=3)
print(f"모델 구조: {model.width}")
print(f"파라미터 수: {model.parameter_count()}")
# 순전파 테스트
x = torch.randn(16, 2)
y = model(x)
print(f"입력: {x.shape}, 출력: {y.shape}")
# 출력 예시:
# 모델 구조: [2, 5, 1]
# 파라미터 수: {'layer_0': 90, 'layer_1': 45, 'total': 135}
# 입력: torch.Size([16, 2]), 출력: torch.Size([16, 1])
MLP 비교 구현
동일한 구조의 MLP를 구현하여 공정한 비교를 수행한다:
class MLP(nn.Module):
"""비교용 Multi-Layer Perceptron.
KAN과 동일한 width 구조를 사용하되,
고정 활성화 함수(SiLU)와 선형 가중치를 사용한다.
"""
def __init__(self, width: list):
super().__init__()
layers = []
for i in range(len(width) - 1):
layers.append(nn.Linear(width[i], width[i + 1]))
if i < len(width) - 2: # 마지막 레이어 제외
layers.append(nn.SiLU())
self.network = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.network(x)
# MLP vs KAN 파라미터 비교
mlp = MLP(width=[2, 5, 1])
kan = KAN(width=[2, 5, 1], grid_size=5, spline_order=3)
mlp_params = sum(p.numel() for p in mlp.parameters())
kan_params = sum(p.numel() for p in kan.parameters())
print(f"MLP 파라미터: {mlp_params}")
print(f"KAN 파라미터: {kan_params}")
print(f"KAN/MLP 파라미터 비율: {kan_params / mlp_params:.1f}x")
# 출력 예시:
# MLP 파라미터: 21
# KAN 파라미터: 135
# KAN/MLP 파라미터 비율: 6.4x
학습과 시각화
학습 루프 구현
KAN 논문에서는 LBFGS 옵티마이저를 권장하지만, 대규모 데이터셋에서는 Adam도 사용할 수 있다. 다음은 함수 근사 태스크를 위한 완전한 학습 루프이다:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
def create_dataset(f, n_train=1000, n_test=200, input_dim=2):
"""학습/테스트 데이터셋 생성.
Args:
f: 근사할 대상 함수
n_train: 학습 샘플 수
n_test: 테스트 샘플 수
input_dim: 입력 차원
"""
x_train = torch.rand(n_train, input_dim) * 2 - 1 # [-1, 1]
y_train = f(x_train)
x_test = torch.rand(n_test, input_dim) * 2 - 1
y_test = f(x_test)
return x_train, y_train, x_test, y_test
def train_model(model, x_train, y_train, x_test, y_test,
epochs=500, lr=1e-2, optimizer_type="adam"):
"""KAN 또는 MLP 모델 학습.
Args:
model: KAN 또는 MLP 인스턴스
optimizer_type: 'adam' 또는 'lbfgs'
Returns:
학습 이력 딕셔너리
"""
if optimizer_type == "lbfgs":
optimizer = torch.optim.LBFGS(
model.parameters(), lr=lr, max_iter=20,
history_size=10, line_search_fn="strong_wolfe"
)
else:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
history = {"train_loss": [], "test_loss": []}
for epoch in range(epochs):
model.train()
if optimizer_type == "lbfgs":
def closure():
optimizer.zero_grad()
pred = model(x_train)
loss = criterion(pred, y_train)
loss.backward()
return loss
loss = optimizer.step(closure)
else:
optimizer.zero_grad()
pred = model(x_train)
loss = criterion(pred, y_train)
loss.backward()
optimizer.step()
# 평가
model.eval()
with torch.no_grad():
train_loss = criterion(model(x_train), y_train).item()
test_loss = criterion(model(x_test), y_test).item()
history["train_loss"].append(train_loss)
history["test_loss"].append(test_loss)
if (epoch + 1) % 100 == 0:
print(f"Epoch {epoch+1}/{epochs} | "
f"Train MSE: {train_loss:.6f} | "
f"Test MSE: {test_loss:.6f}")
return history
# 대상 함수: f(x1, x2) = exp(sin(pi * x1) + x2^2)
target_fn = lambda x: torch.exp(
torch.sin(torch.pi * x[:, 0:1]) + x[:, 1:2] ** 2
)
x_train, y_train, x_test, y_test = create_dataset(target_fn)
# KAN 학습
kan_model = KAN(width=[2, 5, 1], grid_size=5, spline_order=3)
kan_history = train_model(kan_model, x_train, y_train, x_test, y_test,
epochs=500, lr=1e-2, optimizer_type="adam")
# MLP 학습 (파라미터 수를 맞추기 위해 폭을 넓힘)
mlp_model = MLP(width=[2, 50, 50, 1])
mlp_history = train_model(mlp_model, x_train, y_train, x_test, y_test,
epochs=500, lr=1e-3, optimizer_type="adam")
# 손실 곡선 비교 시각화
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.semilogy(kan_history["test_loss"], label="KAN [2,5,1]", linewidth=2)
ax.semilogy(mlp_history["test_loss"], label="MLP [2,50,50,1]", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("Test MSE (log scale)")
ax.set_title("KAN vs MLP: Function Approximation")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("kan_vs_mlp_loss.png", dpi=150)
plt.show()
pykan 라이브러리를 활용한 학습과 시각화
실제 연구나 실험에서는 공식 pykan 라이브러리를 사용하는 것이 편리하다:
# pip install pykan
from kan import KAN as PyKAN
import torch
# 대상 함수 정의
target_fn = lambda x: torch.exp(
torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2
)
# 데이터셋 생성 (pykan 내장 유틸리티 활용)
dataset = {
"train_input": torch.rand(1000, 2) * 2 - 1,
"test_input": torch.rand(200, 2) * 2 - 1,
}
dataset["train_label"] = target_fn(dataset["train_input"])
dataset["test_label"] = target_fn(dataset["test_input"])
# KAN 모델 생성 및 학습
model = PyKAN(width=[2, 5, 1], grid=5, k=3, seed=42)
# 학습 실행 (LBFGS 옵티마이저 사용)
results = model.fit(dataset, opt="LBFGS", steps=50, lamb=0.01)
# 학습된 활성화 함수 시각화
# 각 엣지의 학습된 B-스플라인 함수를 플롯으로 표시
model.plot()
# 기호 회귀: 학습된 함수에서 수학적 수식 추출 시도
# model.auto_symbolic()을 통해 sin, exp 등 기호 함수로 매핑
model.auto_symbolic()
formula = model.symbolic_formula()
print(f"추출된 수식: {formula}")
# 예상 출력: exp(sin(pi*x1) + x2^2)
실전 적용 사례
1. 편미분방정식(PDE) 풀이
KAN은 물리학과 공학에서 편미분방정식을 풀기 위한 신경망(Physics-Informed Neural Networks, PINNs) 대안으로 주목받고 있다. 논문에서 포아송 방정식 에 대해 KAN이 MLP 기반 PINN보다 100배 적은 파라미터로 동일한 정확도를 달성했다.
적용 분야:
- 유체역학 시뮬레이션
- 열전달 방정식
- 파동 방정식
2. 기호 회귀(Symbolic Regression)
KAN의 가장 강력한 응용은 데이터로부터 수학적 수식을 발견하는 기호 회귀이다. 학습된 활성화 함수를 시각화하고 알려진 기호 함수(sin, cos, exp 등)와 매칭하면, 데이터에 숨겨진 수학적 관계를 추출할 수 있다.
KAN 2.0 논문(Liu et al., 2024)에서는 이를 확장하여 다음과 같은 물리 법칙을 재발견했다:
- 보존량(Conserved Quantities): 역학 시스템의 에너지 보존 법칙
- 라그랑지안(Lagrangians): 변분 원리로부터의 운동 방정식
- 대칭성(Symmetries): 뇌터 정리에 기반한 대칭 구조
- 구성 법칙(Constitutive Laws): 재료의 응력-변형 관계
3. 매듭 이론(Knot Theory)
수학 분야에서도 KAN은 의미 있는 결과를 보여주었다. 매듭 불변량(knot invariants) 간의 비자명한 관계를 재발견하는 데 성공했으며, 이는 순수 수학 연구에서 AI의 역할 확장 가능성을 시사한다.
4. 응집물질 물리학
상전이(phase transition) 경계를 식별하는 태스크에서 KAN은 MLP보다 더 정확한 경계를 학습했으며, 학습된 활성화 함수를 통해 상전이의 물리적 의미를 해석할 수 있었다.
한계와 향후 방향
현재의 한계
1. 학습 속도
KAN은 MLP에 비해 학습 속도가 현저히 느리다. 스플라인 기저 함수 계산과 각 엣지마다 서로 다른 함수를 적용해야 하는 구조적 특성 때문에, GPU의 배치 행렬 곱 가속을 충분히 활용하지 못한다. 동일한 태스크에서 KAN의 학습 시간이 MLP의 10배 이상 소요되는 경우도 보고되고 있다.
2. 고차원 입력의 확장성
입력 차원이 높아질수록(예: 이미지의 수만 픽셀) KAN의 성능이 저하된다. CIFAR-10과 같은 표준 컴퓨터 비전 벤치마크에서 KAN은 MLP에 비해 뚜렷한 우위를 보이지 못하며, 스플라인 파라미터 튜닝이 어려워진다.
3. 노이즈 민감성
KAN은 깨끗한 함수형 데이터에서는 우수한 성능을 보이지만, 노이즈가 심한 실제 데이터에서는 MLP보다 민감할 수 있다. 스플라인 함수가 노이즈를 과적합(overfit)하려는 경향이 있기 때문이다.
4. 메모리 사용량
각 엣지마다 개의 스플라인 계수와 그리드 정보를 저장해야 하므로, 동일한 구조의 MLP에 비해 메모리 사용량이 상당히 높다. 대규모 네트워크에서는 이것이 실질적인 병목이 될 수 있다.
5. 표준 태스크에서의 성능
기계 학습, 컴퓨터 비전, NLP, 오디오 처리 등 일반적인 태스크에서는 MLP가 KAN보다 평균적으로 더 높은 정확도를 보인다. KAN이 진정한 우위를 보이는 영역은 기호 수식 표현, 과학적 발견 등 특정 도메인에 한정된다.
향후 연구 방향
1. 효율적인 KAN 구현
EfficientKAN, FastKAN, FourierKAN 등 KAN의 계산 효율성을 높이려는 다양한 변형이 제안되고 있다. B-스플라인 대신 체비셰프 다항식이나 푸리에 기저를 사용하는 접근이 학습 속도를 개선하면서도 KAN의 장점을 유지할 수 있는지 탐구 중이다.
2. 하이브리드 아키텍처
MLP와 KAN의 장점을 결합한 하이브리드 아키텍처가 유망한 방향이다. 예를 들어, 트랜스포머의 FFN 레이어에서 MLP 대신 KAN을 사용하거나, 특정 레이어에서만 선택적으로 KAN을 적용하는 방식이 연구되고 있다.
3. KAN 2.0과 과학적 발견
KAN 2.0 논문에서는 MultKAN(곱셈 노드 포함), 그리드 자동 최적화, 다양한 과학 도메인 적용 등을 통해 KAN의 실용성을 크게 확장했다. Physical Review X에 게재된 이 후속 연구는 KAN이 단순한 함수 근사 도구를 넘어 과학적 발견 도구로서의 가능성을 보여준다.
트러블슈팅과 최적화 팁
1. 학습이 수렴하지 않을 때
KAN 학습에서 가장 흔한 문제는 수렴 실패이다. 다음을 확인해 보자:
# 그리드 범위가 입력 데이터 범위를 포함하는지 확인
x_min, x_max = x_train.min().item(), x_train.max().item()
print(f"데이터 범위: [{x_min:.2f}, {x_max:.2f}]")
# 그리드 범위를 데이터 범위보다 약간 넓게 설정
# grid_range=(-1.5, 1.5) 등으로 조정
# LBFGS 옵티마이저 사용 시 closure 패턴 필수
optimizer = torch.optim.LBFGS(
model.parameters(),
lr=0.1, # LBFGS는 보통 더 큰 lr 사용
max_iter=20, # 내부 반복 횟수
tolerance_grad=1e-7,
tolerance_change=1e-9,
line_search_fn="strong_wolfe" # 안정적 라인서치
)
# grid_eps 조정으로 그리드 적응
# 초기에는 0.5 정도로 시작하고, 학습 후반에 0.0으로 줄이기
2. 과적합 방지
KAN은 표현력이 매우 높아 과적합에 취약할 수 있다:
- 정규화: pykan에서는
lamb파라미터로 L1 정규화 적용. 엔트로피 정규화(lamb_entropy)도 활용 가능 - 그리드 크기 조절: 초기에 작은 (예: 3)로 시작하고, 학습이 안정화되면 점진적으로 늘리기
- 프루닝(Pruning): 학습 후 불필요한 엣지(거의 0에 가까운 활성화 함수)를 제거
3. 메모리 최적화
# 그래디언트 체크포인팅을 활용한 메모리 절약
from torch.utils.checkpoint import checkpoint
class MemoryEfficientKAN(nn.Module):
def __init__(self, width, grid_size=5, spline_order=3):
super().__init__()
self.layers = nn.ModuleList([
KANLayer(width[i], width[i+1], grid_size, spline_order)
for i in range(len(width) - 1)
])
def forward(self, x):
for layer in self.layers:
# 그래디언트 체크포인팅: 메모리 절약 (역전파 시 재계산)
x = checkpoint(layer, x, use_reentrant=False)
return x
4. 시각화를 통한 디버깅
KAN의 큰 장점은 학습된 활성화 함수를 직접 시각화할 수 있다는 점이다. 학습이 제대로 진행되고 있는지 확인하려면 각 엣지의 활성화 함수를 주기적으로 플롯하는 것이 유용하다:
import matplotlib.pyplot as plt
import torch
import numpy as np
def visualize_kan_activations(model, layer_idx=0, input_range=(-1, 1)):
"""KAN 레이어의 학습된 활성화 함수를 시각화한다.
Args:
model: KAN 모델 인스턴스
layer_idx: 시각화할 레이어 인덱스
input_range: 입력 범위
"""
layer = model.layers[layer_idx]
n_in = layer.in_features
n_out = layer.out_features
# 시각화용 입력 생성
x_vis = torch.linspace(input_range[0], input_range[1], 200)
fig, axes = plt.subplots(n_out, n_in, figsize=(4 * n_in, 3 * n_out))
if n_out == 1:
axes = axes.reshape(1, -1)
if n_in == 1:
axes = axes.reshape(-1, 1)
for j in range(n_out):
for i in range(n_in):
ax = axes[j][i]
# 단일 입력 차원에 대한 활성화 함수 계산
x_input = torch.zeros(200, n_in)
x_input[:, i] = x_vis
with torch.no_grad():
# 스플라인 부분만 추출
basis = layer.b_splines(x_input) # [200, in, G+k]
spline_out = torch.einsum(
"big,g->bi", basis[:, i:i+1, :],
layer.spline_weight[j, i, :]
)
ax.plot(x_vis.numpy(), spline_out[:, 0].numpy(),
linewidth=2, color="blue")
ax.set_title(f"phi_{j+1},{i+1}(x)", fontsize=10)
ax.grid(True, alpha=0.3)
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
plt.suptitle(f"Layer {layer_idx} Activation Functions", fontsize=14)
plt.tight_layout()
plt.savefig(f"kan_layer{layer_idx}_activations.png", dpi=150)
plt.show()
# 학습 후 시각화
# visualize_kan_activations(kan_model, layer_idx=0)
5. 성능 비교 시 주의사항
KAN과 MLP를 공정하게 비교하려면 다음 사항에 유의해야 한다:
- 파라미터 수 통제: 동일한 네트워크 구조가 아니라 동일한 파라미터 수로 비교해야 한다
- 학습 시간 통제: 에포크 수가 아닌 총 학습 시간(wall-clock time)으로 비교하는 것이 현실적이다
- 태스크 적합성: KAN은 저차원 함수 근사에서 강점을 보이고, MLP는 고차원 패턴 인식에서 강하다. 비교 시 다양한 태스크를 포함해야 한다
- 옵티마이저 선택: KAN은 LBFGS에서, MLP는 Adam에서 최적 성능을 보이는 경향이 있으므로, 각 모델에 최적의 옵티마이저를 사용해야 공정하다
참고자료
KAN: Kolmogorov-Arnold Networks - Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljacic, Thomas Y. Hou, Max Tegmark (ICLR 2025) https://arxiv.org/abs/2404.19756
KAN 2.0: Kolmogorov-Arnold Networks Meet Science - Ziming Liu et al. (Physical Review X, 2024) https://arxiv.org/abs/2408.10205
pykan: Official Python Library for KAN - KindXiaoming (GitHub) https://github.com/KindXiaoming/pykan
KAN or MLP: A Fairer Comparison - Runpeng Yu, Weihao Yu, Xinchao Wang (2024) https://arxiv.org/abs/2407.16674
Kolmogorov-Arnold Representation Theorem - Wikipedia https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Arnold_representation_theorem
pykan Documentation - Kolmogorov Arnold Network Documentation https://kindxiaoming.github.io/pykan/
Kolmogorov-Arnold Networks: The Latest Advance in Neural Networks, Simply Explained - Towards Data Science https://towardsdatascience.com/kolmogorov-arnold-networks-the-latest-advance-in-neural-networks-simply-explained-f083cf994a85/
A Comprehensive and FAIR Comparison Between MLP and KAN - ScienceDirect (2024) https://www.sciencedirect.com/science/article/abs/pii/S0045782524005462