
Rust for AI Systems: From Ownership to High-Performance Inference Servers with Candle, PyO3, and axum

Introduction

Python still dominates AI development, but for production AI infrastructure where performance and safety are non-negotiable, Rust is gaining serious traction.

HuggingFace's Candle, the blazing-fast Python linter Ruff, Pydantic v2's Rust core: core tools in the AI ecosystem are already being rewritten in Rust. This guide is a practical Rust introduction for AI engineers.

Why Rust for AI Systems?

| Aspect | Python | Rust |
| --- | --- | --- |
| Runtime speed | Moderate | C++ level |
| Memory safety | Relies on GC | Guaranteed at compile time |
| Concurrency | GIL-constrained | No data races |
| Binary size | Runtime required | Single binary |
| Edge deployment | Difficult | WASM and embedded support |

1. Rust Fundamentals: An AI Engineer's Perspective

Ownership and Memory Safety

Rust's cornerstone is the ownership system: memory safety without a garbage collector.

fn main() {
    // Ownership move
    let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];
    let moved = tensor_data; // tensor_data is no longer valid

    // Borrowing: reference without transferring ownership
    let weights = vec![0.1f32, 0.2, 0.3];
    let sum = compute_sum(&weights); // immutable reference
    println!("weights still valid: {:?}, sum: {}", weights, sum);
}

fn compute_sum(data: &[f32]) -> f32 {
    data.iter().sum()
}

Lifetimes

Lifetimes tell the compiler how long a reference remains valid. This matters when safely referencing model weight buffers without copying them.

// Lifetime annotation: the returned reference cannot outlive the input slices
fn longest_sequence<'a>(seq1: &'a [f32], seq2: &'a [f32]) -> &'a [f32] {
    if seq1.len() > seq2.len() {
        seq1
    } else {
        seq2
    }
}

struct ModelWeights<'a> {
    data: &'a [f32],  // references an external buffer (zero copy)
    shape: (usize, usize),
}

impl<'a> ModelWeights<'a> {
    fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        ModelWeights { data, shape: (rows, cols) }
    }

    fn get(&self, row: usize, col: usize) -> f32 {
        self.data[row * self.shape.1 + col]
    }
}

Traits and Generics

Unlike Python's duck typing, Rust checks types at compile time.

// Trait definition for AI operations
trait Activation {
    fn forward(&self, x: f32) -> f32;
    fn backward(&self, x: f32) -> f32; // derivative
}

struct ReLU;
struct Sigmoid;

impl Activation for ReLU {
    fn forward(&self, x: f32) -> f32 {
        x.max(0.0)
    }
    fn backward(&self, x: f32) -> f32 {
        if x > 0.0 { 1.0 } else { 0.0 }
    }
}

impl Activation for Sigmoid {
    fn forward(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
    fn backward(&self, x: f32) -> f32 {
        let s = self.forward(x);
        s * (1.0 - s)
    }
}

// Generic layer: works with any activation function
fn apply_activation<A: Activation>(activation: &A, inputs: &[f32]) -> Vec<f32> {
    inputs.iter().map(|&x| activation.forward(x)).collect()
}

fn main() {
    let relu = ReLU;
    let data = vec![-1.0, 0.5, 2.0, -0.3];
    let output = apply_activation(&relu, &data);
    println!("ReLU output: {:?}", output); // [0.0, 0.5, 2.0, 0.0]
}

Arc vs Rc: Multi-threaded AI Inference

use std::sync::{Arc, Mutex};
use std::thread;

// Rc is single-threaded only and cannot be used in an AI inference server
// Arc (Atomic Reference Count) is thread-safe

fn multi_thread_inference() {
    // Multiple threads share the model weights
    let model_weights: Arc<Vec<f32>> = Arc::new(vec![0.1, 0.2, 0.3, 0.4]);
    let results: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));

    let mut handles = vec![];

    for i in 0..4 {
        let weights = Arc::clone(&model_weights);
        let results = Arc::clone(&results);

        let handle = thread::spawn(move || {
            // Each thread runs inference against the same weights
            let inference_result = weights.iter().sum::<f32>() * i as f32;
            results.lock().unwrap().push(inference_result);
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("All results: {:?}", results.lock().unwrap());
}

2. The Rust AI Ecosystem

Candle: HuggingFace's Rust ML Framework

Candle is a Rust alternative to PyTorch, with minimal dependencies and first-class WASM support.

# Cargo.toml
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"

use candle_core::{Device, Tensor};
use candle_nn::{Linear, Module, VarBuilder};

// A simple feedforward network
struct FeedForward {
    fc1: Linear,
    fc2: Linear,
}

impl FeedForward {
    fn new(vb: VarBuilder) -> candle_core::Result<Self> {
        let fc1 = candle_nn::linear(784, 256, vb.pp("fc1"))?;
        let fc2 = candle_nn::linear(256, 10, vb.pp("fc2"))?;
        Ok(FeedForward { fc1, fc2 })
    }
}

impl Module for FeedForward {
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let x = self.fc1.forward(x)?;
        let x = x.relu()?;
        self.fc2.forward(&x)
    }
}

fn candle_tensor_ops() -> candle_core::Result<()> {
    let device = Device::Cpu; // or Device::new_cuda(0)?

    // Create tensors
    let a = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
    let b = Tensor::randn(0f32, 1f32, (4, 5), &device)?;

    // Matrix multiplication
    let c = a.matmul(&b)?;
    println!("Shape: {:?}", c.shape());

    // Element-wise operations
    let x = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device)?;
    let softmax = candle_nn::ops::softmax(&x, 1)?;
    println!("Softmax: {:?}", softmax.to_vec2::<f32>()?);

    Ok(())
}

Burn: A Backend-Agnostic ML Framework

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
struct SimpleNet<B: Backend> {
    fc1: Linear<B>,
    activation: Relu,
    fc2: Linear<B>,
}

impl<B: Backend> SimpleNet<B> {
    fn new(device: &B::Device) -> Self {
        SimpleNet {
            fc1: LinearConfig::new(128, 64).init(device),
            activation: Relu::new(),
            fc2: LinearConfig::new(64, 10).init(device),
        }
    }

    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
        self.fc2.forward(x)
    }
}

Linfa: A Scikit-learn Alternative

use linfa::prelude::*;
use linfa_clustering::KMeans;
use ndarray::Array2;

fn kmeans_clustering() -> Result<(), Box<dyn std::error::Error>> {
    // Generate synthetic data (treat the rows as embedding vectors)
    let data: Array2<f64> = Array2::from_shape_fn((100, 10), |(i, j)| {
        (i as f64 * 0.1 + j as f64) % 5.0
    });

    let dataset = Dataset::from(data);

    // K-means clustering (e.g., grouping similar documents in a RAG system)
    let model = KMeans::params(5)
        .tolerance(1e-5)
        .fit(&dataset)?;

    let predictions = model.predict(&dataset);
    println!("Cluster assignments: {:?}", &predictions.targets()[..10]);

    Ok(())
}

3. A High-Performance Inference Server: axum + Candle

A tokio Async LLM Inference Server

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
candle-core = "0.8"
candle-transformers = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"

use axum::{
    extract::State,
    http::StatusCode,
    response::Json,
    routing::post,
    Router,
};
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Deserialize)]
struct InferenceRequest {
    prompt: String,
    max_tokens: Option<usize>,
    temperature: Option<f32>,
}

#[derive(Serialize)]
struct InferenceResponse {
    text: String,
    tokens_generated: usize,
    latency_ms: u64,
}

// Model state shared across handlers via Arc (thread-safe)
struct AppState {
    device: Device,
    // model: Arc<Mutex<YourLLMModel>>,
}

async fn inference_handler(
    State(state): State<Arc<AppState>>,
    Json(req): Json<InferenceRequest>,
) -> Result<Json<InferenceResponse>, StatusCode> {
    let start = std::time::Instant::now();

    // A real implementation would run the tokenizer + model forward pass here
    let max_tokens = req.max_tokens.unwrap_or(100);
    let _temperature = req.temperature.unwrap_or(0.7);

    // Example: dummy inference (a real implementation would call a Candle model)
    let generated_text = format!("Response to: {}", req.prompt);
    let latency = start.elapsed().as_millis() as u64;

    Ok(Json(InferenceResponse {
        text: generated_text,
        tokens_generated: max_tokens,
        latency_ms: latency,
    }))
}

async fn health_handler() -> &'static str {
    "OK"
}

#[tokio::main]
async fn main() {
    let state = Arc::new(AppState {
        device: Device::Cpu,
    });

    let app = Router::new()
        .route("/inference", post(inference_handler))
        .route("/health", axum::routing::get(health_handler))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
    println!("AI Inference Server running on port 8080");
    axum::serve(listener, app).await.unwrap();
}

Batched Inference Optimization

use tokio::sync::mpsc;

struct BatchRequest {
    id: u64,
    prompt: String,
    response_tx: tokio::sync::oneshot::Sender<String>,
}

// Batch-processing worker: bundle requests together to maximize GPU utilization
async fn batch_inference_worker(
    mut rx: mpsc::Receiver<BatchRequest>,
    batch_size: usize,
    timeout_ms: u64,
) {
    loop {
        let mut batch = Vec::new();

        // Wait for the first request
        if let Some(req) = rx.recv().await {
            batch.push(req);
        } else {
            break;
        }

        // Fill the batch within the timeout window
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_millis(timeout_ms);

        while batch.len() < batch_size {
            match tokio::time::timeout_at(deadline, rx.recv()).await {
                Ok(Some(req)) => batch.push(req),
                _ => break,
            }
        }

        // Process the batch (a real implementation would run batched GPU inference)
        for req in batch {
            let response = format!("Batch response for: {}", req.prompt);
            let _ = req.response_tx.send(response);
        }
    }
}

4. Python Interop: PyO3

Calling Rust Functions from Python

[lib]
name = "rust_ai_tools"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"

use pyo3::prelude::*;

// A Rust function callable from Python
#[pyfunction]
fn fast_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> PyResult<f32> {
    if a.len() != b.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Vectors must have the same length"
        ));
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    Ok(dot / (norm_a * norm_b))
}

// Normalize a batch of embeddings; releasing the GIL lets other Python threads run meanwhile
#[pyfunction]
fn batch_normalize(
    py: Python,
    embeddings: Vec<Vec<f32>>,
) -> PyResult<Vec<Vec<f32>>> {
    // Release the GIL while the pure-Rust computation runs
    let result = py.allow_threads(|| {
        embeddings
            .iter()
            .map(|emb| {
                let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
                emb.iter().map(|x| x / norm).collect()
            })
            .collect::<Vec<Vec<f32>>>()
    });

    Ok(result)
}

// Register the Python module (pyo3 0.22 Bound API)
#[pymodule]
fn rust_ai_tools(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fast_cosine_similarity, m)?)?;
    m.add_function(wrap_pyfunction!(batch_normalize, m)?)?;
    Ok(())
}

Building and Publishing with maturin

# Install maturin
pip install maturin

# Development build (installs directly into the active virtualenv)
maturin develop

# Usage from Python:
# import rust_ai_tools
# similarity = rust_ai_tools.fast_cosine_similarity([1.0, 0.0], [0.0, 1.0])
# print(similarity)  # 0.0

# Build a release wheel for PyPI
maturin build --release

# pyproject.toml configuration:
# [build-system]
# requires = ["maturin>=1.0,<2.0"]
# build-backend = "maturin"

PyO3 GIL Handling and True Parallelism

use pyo3::prelude::*;
use rayon::prelude::*;

// True parallelism with a rayon thread pool
#[pyfunction]
fn parallel_embedding_search(
    py: Python,
    query: Vec<f32>,
    corpus: Vec<Vec<f32>>,
    top_k: usize,
) -> PyResult<Vec<usize>> {
    // Release the GIL: the rayon thread pool runs free of Python's GIL
    let indices = py.allow_threads(|| {
        let mut scores: Vec<(usize, f32)> = corpus
            .par_iter()  // parallel iterator
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.iter().take(top_k).map(|(i, _)| *i).collect::<Vec<_>>()
    });

    Ok(indices)
}

5. WASM for AI: In-Browser Inference

Running AI in the Browser with wasm-bindgen

[lib]
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = "0.2"
candle-core = { version = "0.8", features = ["wasm"] }

use wasm_bindgen::prelude::*;

// A struct callable from JavaScript
#[wasm_bindgen]
pub struct MiniModel {
    weights: Vec<f32>,
    input_size: usize,
    output_size: usize,
}

#[wasm_bindgen]
impl MiniModel {
    #[wasm_bindgen(constructor)]
    pub fn new(input_size: usize, output_size: usize) -> MiniModel {
        let weights = vec![0.01f32; input_size * output_size];
        MiniModel { weights, input_size, output_size }
    }

    // Accept a JavaScript Float32Array and run inference
    pub fn predict(&self, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0f32; self.output_size];
        for i in 0..self.output_size {
            for j in 0..self.input_size {
                output[i] += input[j] * self.weights[i * self.input_size + j];
            }
            // ReLU activation
            output[i] = output[i].max(0.0);
        }
        output
    }

    pub fn load_weights(&mut self, weights: Vec<f32>) {
        self.weights = weights;
    }
}

// Cosine similarity between text embeddings (enables client-side RAG in the browser)
#[wasm_bindgen]
pub fn cosine_similarity_wasm(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb)
}

// Using the WASM module from JavaScript
import init, { MiniModel, cosine_similarity_wasm } from './rust_ai_wasm.js'

async function runBrowserInference() {
  await init()

  const model = new MiniModel(128, 10)
  const input = new Float32Array(128).fill(0.5)
  const output = model.predict(input)
  console.log('Browser inference output:', output)

  // Vector similarity search
  const queryEmb = new Float32Array(384).map(() => Math.random())
  const docEmb = new Float32Array(384).map(() => Math.random())
  const sim = cosine_similarity_wasm(queryEmb, docEmb)
  console.log('Similarity:', sim)
}

6. Memory-Safe ML: Data Pipelines with Polars

Processing Large Training Datasets with Polars

[dependencies]
polars = { version = "0.44", features = ["lazy", "parquet", "json", "csv"] }

use polars::prelude::*;

fn preprocess_training_data() -> PolarsResult<DataFrame> {
    // Lazily scan the Parquet file (memory-efficient)
    let df = LazyFrame::scan_parquet("training_data.parquet", Default::default())?
        .filter(col("label").is_not_null())
        .with_column(
            // Extract a text-length feature
            col("text").str().len_chars().alias("text_length")
        )
        .with_column(
            // Standardize the score column
            (col("score") - col("score").mean()) / col("score").std(1)
        )
        .filter(col("text_length").gt(10))
        .select([col("text"), col("label"), col("score"), col("text_length")])
        .collect()?;

    println!("Shape: {:?}", df.shape());
    println!("{}", df.head(Some(5)));

    Ok(df)
}

fn batch_generator(df: &DataFrame, batch_size: usize) -> Vec<DataFrame> {
    let n = df.height();
    (0..n).step_by(batch_size).map(|start| {
        let end = (start + batch_size).min(n);
        df.slice(start as i64, end - start)
    }).collect()
}

fn compute_label_stats(df: &DataFrame) -> PolarsResult<DataFrame> {
    df.clone().lazy()
        .group_by([col("label")])
        .agg([
            col("score").mean().alias("avg_score"),
            col("score").std(1).alias("std_score"),
            col("text_length").mean().alias("avg_length"),
            col("label").count().alias("count"),
        ])
        .sort(["count"], SortMultipleOptions::default().with_order_descending(true))
        .collect()
}

Zero-Copy Data Sharing via Arrow

use arrow::array::{Float32Array, StringArray};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;

fn create_embedding_batch(
    texts: Vec<String>,
    embeddings: Vec<Vec<f32>>,
) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![
        Field::new("text", DataType::Utf8, false),
        Field::new("embedding_dim", DataType::Float32, false),
    ]));

    let text_array = StringArray::from(texts);
    // Store only the first embedding dimension (illustrative)
    let emb_dim: Vec<f32> = embeddings.iter().map(|e| e[0]).collect();
    let emb_array = Float32Array::from(emb_dim);

    RecordBatch::try_new(
        schema,
        vec![Arc::new(text_array), Arc::new(emb_array)],
    ).unwrap()
}

7. Case Studies

Hugging Face Candle

HuggingFace has been porting Python PyTorch models to Rust with Candle:

  • Binary size: PyTorch runtime 250MB+ vs. around 5MB for Candle
  • WASM support enables inference directly in the browser
  • Existing CUDA kernels can be reused

Ruff: A Python Linter Rewritten in Rust

Ruff is a Python linter roughly 100x faster than Flake8; on a large AI codebase the difference in lint speed is immediately noticeable.

# Legacy Python linter
time flake8 large_ml_project/  # ~30s

# Ruff (Rust)
time ruff check large_ml_project/  # ~0.3s

Pydantic v2's Rust Core

Pydantic v2 rewrote its core validation logic in Rust (pydantic-core), achieving a 5-50x speedup. For AI API servers, this significantly reduces request-validation latency.


Quiz

Q1. What distinguishes Arc from Rc in Rust, and why does multi-threaded AI inference need Arc?

Answer: Arc (Atomic Reference Count) manages its reference count with atomic operations and is therefore thread-safe; Rc (Reference Count) is single-threaded only.

Explanation: In an inference server, multiple request threads share the same model weights. Rc does not implement the Send trait, so it cannot move across threads. Arc increments and decrements its count with atomic operations (atomic fetch-add), so many threads can reference the same data without data races. The performance cost is negligible (nanoseconds), and the safety is guaranteed at compile time.
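The Send distinction is easy to see in std-only Rust. A minimal sketch (`threaded_sum` is an illustrative name): uncommenting the Rc lines reproduces the compile error, while the Arc version runs.

```rust
use std::sync::Arc;
use std::thread;

// Arc crosses thread boundaries; Rc would not compile here because Rc<T> is not Send.
fn threaded_sum(data: Vec<f32>) -> f32 {
    let shared = Arc::new(data);
    let mut handles = Vec::new();
    for _ in 0..2 {
        let w = Arc::clone(&shared); // atomic refcount increment
        handles.push(thread::spawn(move || w.iter().sum::<f32>()));
    }
    // Both threads read the same buffer; average their (identical) results
    handles.into_iter().map(|h| h.join().unwrap()).sum::<f32>() / 2.0
}

fn main() {
    // use std::rc::Rc;
    // let rc = Rc::new(vec![1.0f32]);
    // thread::spawn(move || rc.len());
    // ^ error[E0277]: `Rc<Vec<f32>>` cannot be sent between threads safely

    let s = threaded_sum(vec![0.1, 0.2, 0.3]);
    println!("sum = {}", s);
}
```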

Q2. How does PyO3 handle the GIL, and what performance benefit comes from calling parallel Rust code from Python?

Answer: PyO3 releases the GIL via py.allow_threads(), so Rust code runs truly in parallel, free of Python's GIL.

Explanation: Python's GIL (Global Interpreter Lock) allows only one Python thread to execute at a time. Calling allow_threads releases the GIL on the current thread, letting Rust code (e.g., rayon) use OS threads freely. For example, computing similarities over a million embeddings can take tens of seconds single-threaded in Python but drops to hundreds of milliseconds with Rust + rayon parallelism.
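What "true parallelism once the GIL is released" looks like on the Rust side can be sketched with std scoped threads alone; rayon's par_iter does the same with a work-stealing pool (`parallel_dot_scores` is an illustrative name).

```rust
use std::thread;

// Score every corpus vector against the query, splitting the corpus across threads.
fn parallel_dot_scores(query: &[f32], corpus: &[Vec<f32>], n_threads: usize) -> Vec<f32> {
    let chunk = corpus.len().div_ceil(n_threads).max(1);
    thread::scope(|s| {
        // Spawn one scoped thread per chunk of the corpus
        let handles: Vec<_> = corpus
            .chunks(chunk)
            .map(|part| {
                s.spawn(move || {
                    part.iter()
                        .map(|emb| query.iter().zip(emb).map(|(a, b)| a * b).sum::<f32>())
                        .collect::<Vec<f32>>()
                })
            })
            .collect();
        // Join in order, preserving corpus order in the output
        handles.into_iter().flat_map(|h| h.join().unwrap()).collect()
    })
}

fn main() {
    let corpus = vec![vec![1.0f32, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let scores = parallel_dot_scores(&[2.0, 3.0], &corpus, 2);
    println!("{:?}", scores); // [2.0, 3.0, 5.0]
}
```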

Q3. Why is Candle better suited to edge deployment than PyTorch?

Answer: Candle ships as a single Rust binary and supports WASM; instead of a 250MB+ PyTorch runtime, it can produce binaries on the order of 5MB.

Explanation: PyTorch requires a Python runtime, libtorch, CUDA drivers, and other heavyweight dependencies. Candle, thanks to Rust's static linking, produces small executables containing only the features you need. Compiled to WebAssembly, it runs inference directly in the browser, enabling client-side AI with no server cost. It also runs on a Raspberry Pi or embedded device without Python.

Q4. What internal design makes Polars faster than Pandas for large datasets?

Answer: Polars is written in Rust, uses the Apache Arrow columnar memory format, and applies SIMD optimizations plus lazy query optimization (lazy evaluation).

Explanation: Pandas is Python-based, processes many operations row by row, and is constrained by the GIL. Polars (1) uses the Arrow columnar format for cache-friendly access, (2) benefits from Rust's automatic SIMD vectorization for numeric work, (3) optimizes query plans through its lazy API (predicate pushdown, projection pushdown), and (4) parallelizes across threads with no GIL. It is typically 5-20x faster than Pandas, with lower memory usage.
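The columnar point can be illustrated in plain Rust: a struct-of-arrays layout keeps the score column contiguous, so a scan touches only the bytes it needs. This is a toy model of Arrow's layout, not Polars internals; the type and function names are illustrative.

```rust
// Row-oriented: each record interleaves its fields
struct RowRecord {
    text_length: u32,
    score: f32,
}

// Column-oriented: each field lives in its own contiguous buffer (Arrow-style)
struct ColumnBatch {
    text_lengths: Vec<u32>,
    scores: Vec<f32>,
}

fn mean_score_rows(rows: &[RowRecord]) -> f32 {
    // Strides over interleaved records, loading text_length bytes it never uses
    rows.iter().map(|r| r.score).sum::<f32>() / rows.len() as f32
}

fn mean_score_columns(batch: &ColumnBatch) -> f32 {
    // Scans one dense f32 buffer: cache-friendly and easy for LLVM to auto-vectorize
    batch.scores.iter().sum::<f32>() / batch.scores.len() as f32
}

fn main() {
    let rows = vec![
        RowRecord { text_length: 12, score: 1.0 },
        RowRecord { text_length: 40, score: 3.0 },
    ];
    let batch = ColumnBatch {
        text_lengths: vec![12, 40],
        scores: vec![1.0, 3.0],
    };
    assert_eq!(mean_score_rows(&rows), mean_score_columns(&batch));
    println!("mean = {}", mean_score_columns(&batch));
}
```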

Q5. Why do Rust's zero-cost abstractions reach C++-level performance?

Answer: The Rust compiler (LLVM backend) lowers high-level abstractions (traits, generics, iterators) to optimized machine code with no runtime overhead.

Explanation: Like C++ templates and inlining, Rust generics are monomorphized: the compiler generates concrete code for each type at compile time. Iterator chains (map, filter, fold) compile to the same assembly as a C-style for loop. Trait methods use static dispatch, inlining without any vtable lookup. The result is code that is convenient to write yet as fast as C.
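The iterator-vs-loop claim can be checked directly. A std-only sketch (`sum_of_mapped` and its loop twin are illustrative names): both functions are monomorphized for the concrete closure type and, in release builds, typically compile to equivalent machine code.

```rust
// Generic over any Fn(f32) -> f32: monomorphized per closure type, statically dispatched
fn sum_of_mapped<F: Fn(f32) -> f32>(data: &[f32], f: F) -> f32 {
    data.iter().map(|&x| f(x)).filter(|&y| y > 0.0).sum()
}

// Hand-written equivalent of the iterator chain above
fn sum_of_mapped_loop(data: &[f32], f: impl Fn(f32) -> f32) -> f32 {
    let mut acc = 0.0f32;
    for &x in data {
        let y = f(x);
        if y > 0.0 {
            acc += y;
        }
    }
    acc
}

fn main() {
    let data = [-1.0f32, 0.5, 2.0, -0.3];
    let relu = |x: f32| x.max(0.0);
    let a = sum_of_mapped(&data, relu);
    let b = sum_of_mapped_loop(&data, relu);
    assert_eq!(a, b);
    println!("iterator: {}, loop: {}", a, b); // both 2.5
}
```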


Closing Thoughts

Rust is a powerful tool for pushing AI infrastructure past its performance ceiling. A realistic strategy: prototype in Python, then rewrite the bottlenecks in Rust (via PyO3), or build performance-critical services in Rust from the start.

Learning roadmap:

  1. The Rust Book: the official tutorial
  2. Rustlings: interactive exercises
  3. Candle examples: HuggingFace GitHub
  4. axum official docs: building inference servers
  5. PyO3 guide: Python integration

Rust for AI Systems: From Ownership to High-Performance Inference Servers with Candle, PyO3, and axum

Introduction

Python dominates AI development, but for production AI infrastructure where performance and safety are non-negotiable, Rust is gaining serious traction.

HuggingFace's Candle, the blazing-fast Python linter Ruff, and Pydantic v2's Rust core — core tools in the AI ecosystem are already being rewritten in Rust. This guide is a practical Rust introduction for AI engineers.

Why Rust for AI Systems?

AspectPythonRust
Runtime speedModerateC++ level
Memory safetyGC dependentCompile-time guaranteed
ConcurrencyGIL constrainedNo data races
Binary sizeRuntime requiredSingle binary
Edge deploymentDifficultWASM, embedded support

1. Rust Fundamentals: An AI Engineer's Perspective

Ownership and Memory Safety

Rust's cornerstone is the ownership system — memory safety without a garbage collector.

fn main() {
    // Ownership move
    let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];
    let moved = tensor_data; // tensor_data is no longer valid

    // Borrowing — reference without transferring ownership
    let weights = vec![0.1f32, 0.2, 0.3];
    let sum = compute_sum(&weights); // immutable reference
    println!("weights still valid: {:?}, sum: {}", weights, sum);
}

fn compute_sum(data: &[f32]) -> f32 {
    data.iter().sum()
}

Lifetimes

Lifetimes tell the compiler how long a reference is valid. Critical when safely referencing weight data in AI models without copying.

// Lifetime annotation: returned reference can't outlive input slices
fn longest_sequence<'a>(seq1: &'a [f32], seq2: &'a [f32]) -> &'a [f32] {
    if seq1.len() > seq2.len() {
        seq1
    } else {
        seq2
    }
}

struct ModelWeights<'a> {
    data: &'a [f32],  // references an external buffer (zero copy)
    shape: (usize, usize),
}

impl<'a> ModelWeights<'a> {
    fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        ModelWeights { data, shape: (rows, cols) }
    }

    fn get(&self, row: usize, col: usize) -> f32 {
        self.data[row * self.shape.1 + col]
    }
}

Traits and Generics

Unlike Python's duck typing, Rust performs type checking at compile time.

// Trait definition for AI operations
trait Activation {
    fn forward(&self, x: f32) -> f32;
    fn backward(&self, x: f32) -> f32; // derivative
}

struct ReLU;
struct Sigmoid;

impl Activation for ReLU {
    fn forward(&self, x: f32) -> f32 {
        x.max(0.0)
    }
    fn backward(&self, x: f32) -> f32 {
        if x > 0.0 { 1.0 } else { 0.0 }
    }
}

impl Activation for Sigmoid {
    fn forward(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
    fn backward(&self, x: f32) -> f32 {
        let s = self.forward(x);
        s * (1.0 - s)
    }
}

// Generic layer: works with any activation function
fn apply_activation<A: Activation>(activation: &A, inputs: &[f32]) -> Vec<f32> {
    inputs.iter().map(|&x| activation.forward(x)).collect()
}

fn main() {
    let relu = ReLU;
    let data = vec![-1.0, 0.5, 2.0, -0.3];
    let output = apply_activation(&relu, &data);
    println!("ReLU output: {:?}", output); // [0.0, 0.5, 2.0, 0.0]
}

Arc vs Rc: Multi-threaded AI Inference

use std::sync::{Arc, Mutex};
use std::thread;

// Rc is single-threaded only — unusable in AI inference servers
// Arc (Atomic Reference Count) is thread-safe

fn multi_thread_inference() {
    // Share model weights across multiple threads
    let model_weights: Arc<Vec<f32>> = Arc::new(vec![0.1, 0.2, 0.3, 0.4]);
    let results: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));

    let mut handles = vec![];

    for i in 0..4 {
        let weights = Arc::clone(&model_weights);
        let results = Arc::clone(&results);

        let handle = thread::spawn(move || {
            // Each thread runs inference with the same weights
            let inference_result = weights.iter().sum::<f32>() * i as f32;
            results.lock().unwrap().push(inference_result);
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("All results: {:?}", results.lock().unwrap());
}

2. The Rust AI Ecosystem

Candle: HuggingFace's Rust ML Framework

Candle is PyTorch's Rust alternative, with minimal dependencies and first-class WASM support.

# Cargo.toml
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module, VarBuilder, VarMap};

// Simple feedforward network
struct FeedForward {
    fc1: Linear,
    fc2: Linear,
}

impl FeedForward {
    fn new(vb: VarBuilder) -> candle_core::Result<Self> {
        let fc1 = candle_nn::linear(784, 256, vb.pp("fc1"))?;
        let fc2 = candle_nn::linear(256, 10, vb.pp("fc2"))?;
        Ok(FeedForward { fc1, fc2 })
    }
}

impl Module for FeedForward {
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let x = self.fc1.forward(x)?;
        let x = x.relu()?;
        self.fc2.forward(&x)
    }
}

fn candle_tensor_ops() -> candle_core::Result<()> {
    let device = Device::Cpu; // or Device::new_cuda(0)?

    // Create tensors
    let a = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
    let b = Tensor::randn(0f32, 1f32, (4, 5), &device)?;

    // Matrix multiplication
    let c = a.matmul(&b)?;
    println!("Shape: {:?}", c.shape());

    // Element-wise operations
    let x = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device)?;
    let softmax = candle_nn::ops::softmax(&x, 1)?;
    println!("Softmax: {:?}", softmax.to_vec2::<f32>()?);

    Ok(())
}

Burn: Backend-Agnostic ML Framework

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
struct SimpleNet<B: Backend> {
    fc1: Linear<B>,
    activation: Relu,
    fc2: Linear<B>,
}

impl<B: Backend> SimpleNet<B> {
    fn new(device: &B::Device) -> Self {
        SimpleNet {
            fc1: LinearConfig::new(128, 64).init(device),
            activation: Relu::new(),
            fc2: LinearConfig::new(64, 10).init(device),
        }
    }

    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
        self.fc2.forward(x)
    }
}

Linfa: Scikit-learn Alternative

use linfa::prelude::*;
use linfa_clustering::KMeans;
use ndarray::Array2;

fn kmeans_clustering() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data (assume embedding vectors)
    let data: Array2<f64> = Array2::from_shape_fn((100, 10), |(i, j)| {
        (i as f64 * 0.1 + j as f64) % 5.0
    });

    let dataset = Dataset::from(data);

    // K-means clustering (grouping similar docs in a RAG system)
    let model = KMeans::params(5)
        .tolerance(1e-5)
        .fit(&dataset)?;

    let predictions = model.predict(&dataset);
    println!("Cluster assignments: {:?}", &predictions.targets()[..10]);

    Ok(())
}

3. High-Performance Inference Server: axum + Candle

tokio Async-based LLM Inference Server

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
candle-core = "0.8"
candle-transformers = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"
use axum::{
    extract::State,
    http::StatusCode,
    response::Json,
    routing::post,
    Router,
};
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Deserialize)]
struct InferenceRequest {
    prompt: String,
    max_tokens: Option<usize>,
    temperature: Option<f32>,
}

#[derive(Serialize)]
struct InferenceResponse {
    text: String,
    tokens_generated: usize,
    latency_ms: u64,
}

// Model state shared via Arc (thread-safe)
struct AppState {
    device: Device,
    // model: Arc<Mutex<YourLLMModel>>,
}

async fn inference_handler(
    State(state): State<Arc<AppState>>,
    Json(req): Json<InferenceRequest>,
) -> Result<Json<InferenceResponse>, StatusCode> {
    let start = std::time::Instant::now();

    let max_tokens = req.max_tokens.unwrap_or(100);
    let _temperature = req.temperature.unwrap_or(0.7);

    // Dummy inference (real implementation uses Candle model forward pass)
    let generated_text = format!("Response to: {}", req.prompt);
    let latency = start.elapsed().as_millis() as u64;

    Ok(Json(InferenceResponse {
        text: generated_text,
        tokens_generated: max_tokens,
        latency_ms: latency,
    }))
}

async fn health_handler() -> &'static str {
    "OK"
}

#[tokio::main]
async fn main() {
    let state = Arc::new(AppState {
        device: Device::Cpu,
    });

    let app = Router::new()
        .route("/inference", post(inference_handler))
        .route("/health", axum::routing::get(health_handler))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
    println!("AI Inference Server running on port 8080");
    axum::serve(listener, app).await.unwrap();
}

Batched Inference Optimization

use tokio::sync::mpsc;

struct BatchRequest {
    id: u64,
    prompt: String,
    response_tx: tokio::sync::oneshot::Sender<String>,
}

// Batch processing worker — bundle requests to maximize GPU utilization
async fn batch_inference_worker(
    mut rx: mpsc::Receiver<BatchRequest>,
    batch_size: usize,
    timeout_ms: u64,
) {
    loop {
        let mut batch = Vec::new();

        // Wait for first request
        if let Some(req) = rx.recv().await {
            batch.push(req);
        } else {
            break;
        }

        // Fill batch within timeout window
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_millis(timeout_ms);

        while batch.len() < batch_size {
            match tokio::time::timeout_at(deadline, rx.recv()).await {
                Ok(Some(req)) => batch.push(req),
                _ => break,
            }
        }

        // Process batch (real implementation runs parallel GPU inference)
        for req in batch {
            let response = format!("Batch response for: {}", req.prompt);
            let _ = req.response_tx.send(response);
        }
    }
}

4. Python Interop: PyO3

Calling Rust Functions from Python

[lib]
name = "rust_ai_tools"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"
use pyo3::prelude::*;

// Python-callable Rust function
#[pyfunction]
fn fast_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> PyResult<f32> {
    if a.len() != b.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Vectors must have the same length"
        ));
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    Ok(dot / (norm_a * norm_b))
}

// Accept numpy arrays and process — GIL release enables true parallelism
#[pyfunction]
fn batch_normalize(
    py: Python,
    embeddings: Vec<Vec<f32>>,
) -> PyResult<Vec<Vec<f32>>> {
    // Release GIL — Rust thread pool runs freely
    let result = py.allow_threads(|| {
        embeddings.iter().map(|emb| {
            let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter().map(|x| x / norm).collect()
        }).collect::<Vec<Vec<f32>>>()
    });

    Ok(result)
}

// Register Python module
#[pymodule]
fn rust_ai_tools(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fast_cosine_similarity, m)?)?;
    m.add_function(wrap_pyfunction!(batch_normalize, m)?)?;
    Ok(())
}

Build and Publish with maturin

# Install maturin
pip install maturin

# Development build (installs directly into virtualenv)
maturin develop

# Usage from Python:
# import rust_ai_tools
# sim = rust_ai_tools.fast_cosine_similarity([1.0, 0.0], [0.0, 1.0])
# print(sim)  # 0.0

# Build wheel for PyPI
maturin build --release

# pyproject.toml configuration:
# [build-system]
# requires = ["maturin>=1.0,<2.0"]
# build-backend = "maturin"

GIL Handling and True Parallelism

use pyo3::prelude::*;
use rayon::prelude::*;

// True parallelism via rayon
#[pyfunction]
fn parallel_embedding_search(
    py: Python,
    query: Vec<f32>,
    corpus: Vec<Vec<f32>>,
    top_k: usize,
) -> PyResult<Vec<usize>> {
    // Release GIL — Rust thread pool takes over
    let indices = py.allow_threads(|| {
        let mut scores: Vec<(usize, f32)> = corpus
            .par_iter()  // parallel iterator
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.iter().take(top_k).map(|(i, _)| *i).collect::<Vec<_>>()
    });

    Ok(indices)
}
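The ranking step inside `allow_threads` does not depend on PyO3 or rayon; a dependency-free sequential sketch of the same dot-product top-k (the function name here is illustrative) makes the logic easy to test on its own:

```rust
// Sequential top-k by dot-product score; mirrors the ranking logic above.
fn top_k_by_dot(query: &[f32], corpus: &[Vec<f32>], top_k: usize) -> Vec<usize> {
    let mut scores: Vec<(usize, f32)> = corpus
        .iter()
        .enumerate()
        .map(|(i, emb)| {
            let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
            (i, dot)
        })
        .collect();

    // total_cmp gives a NaN-safe descending sort
    scores.sort_by(|a, b| b.1.total_cmp(&a.1));
    scores.into_iter().take(top_k).map(|(i, _)| i).collect()
}

fn main() {
    let query = vec![1.0f32, 0.0];
    let corpus = vec![
        vec![0.0f32, 1.0], // orthogonal -> score 0.0
        vec![1.0f32, 0.0], // identical  -> score 1.0
        vec![0.5f32, 0.5], // partial    -> score 0.5
    ];
    let top = top_k_by_dot(&query, &corpus, 2);
    assert_eq!(top, vec![1, 2]);
    println!("top-2 indices: {:?}", top);
}
```

Swapping the `.iter()` for rayon's `.par_iter()` recovers the parallel version; the scoring and sorting code is unchanged.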

5. WASM for AI: Browser-side Inference

In-Browser AI with wasm-bindgen

[lib]
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = "0.2"
candle-core = { version = "0.8", features = ["wasm"] }
use wasm_bindgen::prelude::*;

// JavaScript-callable struct
#[wasm_bindgen]
pub struct MiniModel {
    weights: Vec<f32>,
    input_size: usize,
    output_size: usize,
}

#[wasm_bindgen]
impl MiniModel {
    #[wasm_bindgen(constructor)]
    pub fn new(input_size: usize, output_size: usize) -> MiniModel {
        let weights = vec![0.01f32; input_size * output_size];
        MiniModel { weights, input_size, output_size }
    }

    // Accept JavaScript Float32Array and run inference
    pub fn predict(&self, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0f32; self.output_size];
        for i in 0..self.output_size {
            for j in 0..self.input_size {
                output[i] += input[j] * self.weights[i * self.input_size + j];
            }
            // ReLU activation
            output[i] = output[i].max(0.0);
        }
        output
    }

    pub fn load_weights(&mut self, weights: Vec<f32>) {
        self.weights = weights;
    }
}

// Cosine similarity for embeddings (enables client-side RAG)
#[wasm_bindgen]
pub fn cosine_similarity_wasm(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb)
}
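Stripped of the `wasm_bindgen` attributes, `predict` is an ordinary matrix-vector product followed by ReLU, so its arithmetic can be sanity-checked in plain Rust before compiling to WASM (a minimal sketch, free-standing rather than a method):

```rust
// Same math as MiniModel::predict, without the wasm_bindgen machinery.
fn predict(weights: &[f32], input: &[f32], input_size: usize, output_size: usize) -> Vec<f32> {
    let mut output = vec![0.0f32; output_size];
    for i in 0..output_size {
        for j in 0..input_size {
            output[i] += input[j] * weights[i * input_size + j];
        }
        output[i] = output[i].max(0.0); // ReLU
    }
    output
}

fn main() {
    // Mirrors MiniModel::new(4, 2): all weights initialized to 0.01
    let weights = vec![0.01f32; 4 * 2];
    let input = vec![0.5f32; 4];
    let out = predict(&weights, &input, 4, 2);
    // Each output element: 4 * 0.5 * 0.01 = 0.02
    assert!((out[0] - 0.02).abs() < 1e-6);
    println!("{:?}", out);
}
```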
// Using the WASM module from JavaScript
import init, { MiniModel, cosine_similarity_wasm } from './rust_ai_wasm.js'

async function runBrowserInference() {
  await init()

  const model = new MiniModel(128, 10)
  const input = new Float32Array(128).fill(0.5)
  const output = model.predict(input)
  console.log('Browser inference output:', output)

  // Vector similarity search
  const queryEmb = new Float32Array(384).map(() => Math.random())
  const docEmb = new Float32Array(384).map(() => Math.random())
  const sim = cosine_similarity_wasm(queryEmb, docEmb)
  console.log('Similarity:', sim)
}

6. Memory-Safe ML: Polars Data Pipelines

Large-Scale Training Data Processing with Polars

[dependencies]
polars = { version = "0.44", features = ["lazy", "parquet", "json", "csv"] }
use polars::prelude::*;

fn preprocess_training_data() -> PolarsResult<DataFrame> {
    // Lazy Parquet scan (memory efficient)
    let df = LazyFrame::scan_parquet("training_data.parquet", Default::default())?
        .filter(col("label").is_not_null())
        .with_column(
            // Extract text length feature
            col("text").str().len_chars().alias("text_length")
        )
        .with_column(
            // Normalize scores
            (col("score") - col("score").mean()) / col("score").std(1)
        )
        .filter(col("text_length").gt(10))
        .select([col("text"), col("label"), col("score"), col("text_length")])
        .collect()?;

    println!("Shape: {:?}", df.shape());
    println!("{}", df.head(Some(5)));

    Ok(df)
}

fn batch_generator(df: &DataFrame, batch_size: usize) -> Vec<DataFrame> {
    let n = df.height();
    (0..n).step_by(batch_size).map(|start| {
        let end = (start + batch_size).min(n);
        df.slice(start as i64, end - start)
    }).collect()
}
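The start/end arithmetic in `batch_generator` is independent of Polars; extracting it into a dependency-free helper (an illustrative sketch) makes the boundary handling of the final partial batch easy to verify:

```rust
// Same step_by/min batching logic as batch_generator, reduced to index ranges.
fn batch_ranges(n: usize, batch_size: usize) -> Vec<(usize, usize)> {
    (0..n).step_by(batch_size)
        .map(|start| (start, (start + batch_size).min(n)))
        .collect()
}

fn main() {
    // 10 rows with batch size 4 -> batches of 4, 4, and 2 rows
    let ranges = batch_ranges(10, 4);
    assert_eq!(ranges, vec![(0, 4), (4, 8), (8, 10)]);
    println!("{:?}", ranges);
}
```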

fn compute_label_stats(df: &DataFrame) -> PolarsResult<DataFrame> {
    df.clone().lazy()
        .group_by([col("label")])
        .agg([
            col("score").mean().alias("avg_score"),
            col("score").std(1).alias("std_score"),
            col("text_length").mean().alias("avg_length"),
            col("label").count().alias("count"),
        ])
        .sort(["count"], SortMultipleOptions::default().with_order_descending(true))
        .collect()
}

Zero-Copy Data Sharing with Apache Arrow

use arrow::array::{Float32Array, StringArray};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;

fn create_embedding_batch(
    texts: Vec<String>,
    embeddings: Vec<Vec<f32>>,
) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![
        Field::new("text", DataType::Utf8, false),
        Field::new("embedding_dim0", DataType::Float32, false),
    ]));

    let text_array = StringArray::from(texts);
    // Store first dimension as example
    let emb_dim: Vec<f32> = embeddings.iter().map(|e| e[0]).collect();
    let emb_array = Float32Array::from(emb_dim);

    RecordBatch::try_new(
        schema,
        vec![Arc::new(text_array), Arc::new(emb_array)],
    ).unwrap()
}

7. Case Studies

Hugging Face Candle

HuggingFace built Candle, a minimalist Rust ML framework, and is porting popular PyTorch models to it:

  • Binary size: PyTorch 250MB+ vs Candle ~5MB
  • WASM support enables direct browser inference
  • Reusable CUDA kernels

Ruff: Python Linter Rewritten in Rust

Ruff runs 10-100x faster than Flake8, and the difference is palpable on large ML codebases.

# Legacy Python linter
time flake8 large_ml_project/  # ~30 seconds

# Ruff (Rust)
time ruff check large_ml_project/  # ~0.3 seconds

Pydantic v2's Rust Core

Pydantic v2 rewrote its core validation logic in Rust (pydantic-core), achieving 5-50x performance improvements. AI API servers see dramatically reduced request validation latency.


Quiz

Q1. What is the difference between Arc and Rc in Rust, and why is Arc required for multi-threaded AI inference?

Answer: Arc (Atomically Reference Counted) manages its reference count with atomic operations, making it thread-safe. Rc (Reference Counted) is for single-threaded use only.

Explanation: In an AI inference server, multiple request threads share the same model weights. Rc does not implement the Send trait, so it cannot be moved between threads. Arc uses atomic fetch-add operations for reference count changes, allowing multiple threads to reference the same data without data races. The performance difference is negligible (nanosecond-level), but safety is guaranteed at compile time.
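A stdlib-only sketch of this answer: every spawned thread reads the same Arc-shared weights, while the equivalent code with Rc is rejected at compile time because Rc is not Send.

```rust
use std::sync::Arc;
use std::thread;

// Each thread reads the same Arc-shared weights and returns its sum.
// Swapping Arc for Rc here fails to compile: Rc<Vec<f32>> is not Send.
fn sum_from_threads(weights: Arc<Vec<f32>>, n_threads: usize) -> Vec<f32> {
    let handles: Vec<_> = (0..n_threads).map(|_| {
        let w = Arc::clone(&weights); // atomic refcount increment
        thread::spawn(move || w.iter().sum::<f32>())
    }).collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    let weights = Arc::new(vec![0.1f32, 0.2, 0.3]);
    let sums = sum_from_threads(weights, 4);
    for s in &sums {
        assert!((s - 0.6).abs() < 1e-6); // every thread saw the same data
    }
    println!("4 threads, shared weights, sums: {:?}", sums);
}
```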

Q2. How does PyO3 handle the GIL, and what performance benefit does it provide when calling Rust parallel code from Python?

Answer: PyO3 releases the GIL via py.allow_threads(), enabling Rust code to execute truly in parallel without Python's GIL constraint.

Explanation: Python's GIL (Global Interpreter Lock) restricts execution to one Python thread at a time. Calling allow_threads releases the GIL for the current thread, letting Rust code (with rayon, for example) freely use OS threads. Computing cosine similarity over 1 million embeddings takes tens of seconds in single-threaded Python, but can be reduced to hundreds of milliseconds with Rust + rayon parallelism.
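What rayon provides can be approximated with std::thread alone; a hedged sketch of GIL-free parallel scoring (the function name and chunking scheme are illustrative, not from the original):

```rust
use std::sync::Arc;
use std::thread;

// Score a corpus against a query across OS threads — no GIL, no data races:
// each thread owns its chunk, the query is shared read-only via Arc.
fn parallel_dot_scores(query: Arc<Vec<f32>>, corpus: Vec<Vec<f32>>, n_threads: usize) -> Vec<f32> {
    let chunk = ((corpus.len() + n_threads - 1) / n_threads).max(1);
    let handles: Vec<_> = corpus
        .chunks(chunk)
        .map(|c| {
            let q = Arc::clone(&query);
            let c = c.to_vec(); // move an owned copy of the chunk into the thread
            thread::spawn(move || {
                c.iter()
                    .map(|emb| q.iter().zip(emb.iter()).map(|(a, b)| a * b).sum::<f32>())
                    .collect::<Vec<f32>>()
            })
        })
        .collect();

    // Joining in spawn order preserves corpus order in the output
    handles.into_iter().flat_map(|h| h.join().unwrap()).collect()
}

fn main() {
    let query = Arc::new(vec![1.0f32, 0.0]);
    let corpus = vec![vec![0.0f32, 1.0], vec![1.0, 0.0], vec![0.5, 0.5], vec![2.0, 2.0]];
    let scores = parallel_dot_scores(query, corpus, 2);
    assert_eq!(scores, vec![0.0, 1.0, 0.5, 2.0]);
    println!("{:?}", scores);
}
```

rayon's work-stealing pool handles the chunking, load balancing, and joining automatically; this sketch only shows why no lock is needed in the first place.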

Q3. Why is Candle better suited for edge deployment than PyTorch?

Answer: Candle ships as a single Rust binary with WASM support, producing ~5MB binaries without requiring the PyTorch runtime (250MB+).

Explanation: PyTorch requires Python runtime, libtorch, CUDA drivers, and other large dependencies. Candle's static linking produces small executables containing only what's needed. Compiling to WebAssembly enables direct browser inference — client-side AI with no server costs. It also runs on Raspberry Pi and embedded devices without Python.

Q4. Why does Polars process large datasets faster than Pandas internally?

Answer: Polars is written in Rust and uses the Apache Arrow columnar memory format, SIMD optimizations, and query optimization via lazy evaluation.

Explanation: Pandas relies on Python with frequent row-wise operations and GIL constraints. Polars uses (1) Arrow columnar format for cache-efficient access, (2) Rust's automatic SIMD vectorization for numeric operations, (3) a lazy API that optimizes query plans (predicate pushdown, projection pushdown), and (4) multi-threaded parallel processing without a GIL. Typical speedup is 5-20x over Pandas, with lower memory consumption.

Q5. Why do Rust's zero-cost abstractions match C++ performance?

Answer: The Rust compiler (LLVM backend) transforms high-level abstractions (traits, generics, iterators) into optimized machine code with no runtime overhead.

Explanation: Like C++ templates and inlining, Rust generics use monomorphization to generate concrete code for each type at compile time. Iterator chains (map, filter, fold) produce the same assembly as C-style for loops. Trait methods use static dispatch — no virtual table lookups, just inlining. The result: code that is both ergonomic to write and as fast as C.
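The claim is easy to check concretely: the iterator chain and the C-style loop below compute the same thing, and with optimizations enabled LLVM typically lowers both to the same vectorized machine code (verifiable on Compiler Explorer):

```rust
// Two equivalent sums: a high-level iterator chain and a C-style loop.
fn sum_of_squares_iter(data: &[f32]) -> f32 {
    data.iter().map(|x| x * x).sum()
}

fn sum_of_squares_loop(data: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for i in 0..data.len() {
        acc += data[i] * data[i];
    }
    acc
}

fn main() {
    let data = vec![1.0f32, 2.0, 3.0];
    let a = sum_of_squares_iter(&data);
    let b = sum_of_squares_loop(&data);
    assert_eq!(a, b);               // identical accumulation order, identical result
    assert!((a - 14.0).abs() < 1e-6);
    println!("sum of squares: {}", a);
}
```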


Wrapping Up

Rust is a powerful tool for pushing past the performance ceiling in AI infrastructure. The practical approach: prototype in Python, rewrite bottlenecks in Rust (via PyO3), or build high-performance services in Rust from the ground up.

Learning Roadmap:

  1. The Rust Book — official tutorial
  2. Rustlings — interactive exercises
  3. Candle examples — HuggingFace GitHub
  4. Axum documentation — building inference servers
  5. PyO3 guide — Python integration