Rust for AI Systems: From Ownership to High-Performance Inference Servers with Candle, PyO3, and axum
Introduction
Python dominates AI development, but for production AI infrastructure where **performance and safety** are non-negotiable, **Rust** is gaining serious traction.
HuggingFace's **Candle**, the blazing-fast Python linter **Ruff**, and Pydantic v2's Rust core — core tools in the AI ecosystem are already being rewritten in Rust. This guide is a practical Rust introduction for AI engineers.
Why Rust for AI Systems?
| Aspect | Python | Rust |
| --------------- | ---------------- | ----------------------- |
| Runtime speed | Moderate | C++ level |
| Memory safety | GC dependent | Compile-time guaranteed |
| Concurrency | GIL constrained | No data races |
| Binary size | Runtime required | Single binary |
| Edge deployment | Difficult | WASM, embedded support |
1. Rust Fundamentals: An AI Engineer's Perspective
Ownership and Memory Safety
Rust's cornerstone is the **ownership system** — memory safety without a garbage collector.
fn main() {
// Ownership move
let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];
let moved = tensor_data; // tensor_data is no longer valid
// Borrowing — reference without transferring ownership
let weights = vec![0.1f32, 0.2, 0.3];
let sum = compute_sum(&weights); // immutable reference
println!("weights still valid: {:?}, sum: {}", weights, sum);
}
fn compute_sum(data: &[f32]) -> f32 {
data.iter().sum()
}
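The same rules cover mutation. As a minimal sketch (the helper names `scale_in_place` and `into_unit_norm` are hypothetical and not part of the example above), one function borrows a buffer mutably while the caller keeps ownership, and the other takes ownership and hands the value back:

```rust
/// Borrows the buffer mutably and scales it in place; the caller keeps ownership.
fn scale_in_place(data: &mut [f32], factor: f32) {
    for x in data.iter_mut() {
        *x *= factor;
    }
}

/// Takes ownership of the vector and returns it normalized to unit length.
/// A zero vector is returned unchanged to avoid dividing by zero.
fn into_unit_norm(mut data: Vec<f32>) -> Vec<f32> {
    let norm = data.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        scale_in_place(&mut data, 1.0 / norm);
    }
    data
}
```

Only one `&mut` borrow can exist at a time, so the compiler rejects any attempt to read `data` elsewhere while `scale_in_place` is running.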
Lifetimes
Lifetimes tell the compiler how long a reference is valid. They matter in AI code whenever a struct or function borrows model weight data instead of copying it.
// Lifetime annotation: returned reference can't outlive input slices
fn longest_sequence<'a>(seq1: &'a [f32], seq2: &'a [f32]) -> &'a [f32] {
if seq1.len() > seq2.len() {
seq1
} else {
seq2
}
}
struct ModelWeights<'a> {
data: &'a [f32], // references an external buffer (zero copy)
shape: (usize, usize),
}
impl<'a> ModelWeights<'a> {
fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
ModelWeights { data, shape: (rows, cols) }
}
fn get(&self, row: usize, col: usize) -> f32 {
self.data[row * self.shape.1 + col]
}
}
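A borrowed view can also hand out sub-slices that outlive the view itself, as long as the underlying buffer is still alive. The sketch below assumes a hypothetical `WeightView` type (a bounds-checked variant of the struct above, not the original code):

```rust
/// A read-only matrix view over a borrowed buffer (hypothetical type for illustration).
struct WeightView<'a> {
    data: &'a [f32],
    cols: usize,
}

impl<'a> WeightView<'a> {
    /// Returns `None` if `cols` is zero or the buffer length is not a multiple of `cols`.
    fn new(data: &'a [f32], cols: usize) -> Option<Self> {
        if cols == 0 || data.len() % cols != 0 {
            return None;
        }
        Some(WeightView { data, cols })
    }

    /// Borrows one row without copying. The returned slice is tied to the
    /// buffer's lifetime `'a`, not to the lifetime of the view.
    fn row(&self, r: usize) -> Option<&'a [f32]> {
        let start = r.checked_mul(self.cols)?;
        let end = start.checked_add(self.cols)?;
        self.data.get(start..end)
    }
}
```

Returning `Option` instead of indexing directly turns an out-of-range row into a value the caller can handle rather than a panic.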
Traits and Generics
Unlike Python's duck typing, Rust performs type checking at **compile time**.
// Trait definition for AI operations
trait Activation {
fn forward(&self, x: f32) -> f32;
fn backward(&self, x: f32) -> f32; // derivative
}
struct ReLU;
struct Sigmoid;
impl Activation for ReLU {
fn forward(&self, x: f32) -> f32 {
x.max(0.0)
}
fn backward(&self, x: f32) -> f32 {
if x > 0.0 { 1.0 } else { 0.0 }
}
}
impl Activation for Sigmoid {
fn forward(&self, x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
fn backward(&self, x: f32) -> f32 {
let s = self.forward(x);
s * (1.0 - s)
}
}
// Generic layer: works with any activation function
fn apply_activation<A: Activation>(activation: &A, inputs: &[f32]) -> Vec<f32> {
inputs.iter().map(|&x| activation.forward(x)).collect()
}
fn main() {
let relu = ReLU;
let data = vec![-1.0, 0.5, 2.0, -0.3];
let output = apply_activation(&relu, &data);
println!("ReLU output: {:?}", output); // [0.0, 0.5, 2.0, 0.0]
}
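Generics fix the activation type at compile time. When the choice is only known at runtime (for example, read from a config file), trait objects are the usual alternative. The following is a sketch under that assumption; it redefines a reduced `Activation` trait so the block is self-contained, and `activation_from_name` and `apply_chain` are hypothetical helpers:

```rust
trait Activation {
    fn forward(&self, x: f32) -> f32;
}

struct ReLU;
struct Sigmoid;

impl Activation for ReLU {
    fn forward(&self, x: f32) -> f32 {
        x.max(0.0)
    }
}

impl Activation for Sigmoid {
    fn forward(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
}

/// Picks an activation by name at runtime (hypothetical helper).
fn activation_from_name(name: &str) -> Option<Box<dyn Activation>> {
    match name {
        "relu" => Some(Box::new(ReLU)),
        "sigmoid" => Some(Box::new(Sigmoid)),
        _ => None,
    }
}

/// Applies a runtime-chosen list of activations in order (dynamic dispatch).
fn apply_chain(chain: &[Box<dyn Activation>], inputs: &[f32]) -> Vec<f32> {
    inputs
        .iter()
        .map(|&x| chain.iter().fold(x, |acc, act| act.forward(acc)))
        .collect()
}
```

The trade-off is a vtable lookup per call in exchange for flexibility; the generic version stays preferable in hot loops where the type is known up front.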
Arc vs Rc: Multi-threaded AI Inference
use std::sync::{Arc, Mutex};
use std::thread;
// Rc is single-threaded only — unusable in AI inference servers
// Arc (Atomic Reference Count) is thread-safe
fn multi_thread_inference() {
// Share model weights across multiple threads
let model_weights: Arc<Vec<f32>> = Arc::new(vec![0.1, 0.2, 0.3, 0.4]);
let results: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
let mut handles = vec![];
for i in 0..4 {
let weights = Arc::clone(&model_weights);
let results = Arc::clone(&results);
let handle = thread::spawn(move || {
// Each thread runs inference with the same weights
let inference_result = weights.iter().sum::<f32>() * i as f32;
results.lock().unwrap().push(inference_result);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("All results: {:?}", results.lock().unwrap());
}
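`Arc` is needed when threads may outlive the function that created the data. When all workers finish before the function returns, scoped threads from the standard library can borrow the weights directly. A minimal sketch, assuming a hypothetical `parallel_dot` helper that splits a dot product across threads:

```rust
use std::thread;

/// Splits the two slices into chunks and computes each chunk's partial dot
/// product on its own thread, borrowing the data instead of cloning an Arc.
fn parallel_dot(weights: &[f32], inputs: &[f32], num_threads: usize) -> f32 {
    assert_eq!(weights.len(), inputs.len(), "slices must have equal length");
    if weights.is_empty() {
        return 0.0;
    }
    let threads = num_threads.max(1);
    // Ceiling division so every element lands in some chunk.
    let chunk = (weights.len() + threads - 1) / threads;
    thread::scope(|s| {
        let handles: Vec<_> = weights
            .chunks(chunk)
            .zip(inputs.chunks(chunk))
            .map(|(w, x)| s.spawn(move || w.iter().zip(x).map(|(a, b)| a * b).sum::<f32>()))
            .collect();
        // The scope guarantees all threads are joined before it returns.
        handles.into_iter().map(|h| h.join().unwrap()).sum::<f32>()
    })
}
```

No `Mutex` is required here because each thread returns its partial result through `join` rather than writing into shared state.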
2. The Rust AI Ecosystem
Candle: HuggingFace's Rust ML Framework
**Candle** is Hugging Face's minimalist ML framework for Rust, often used as a lightweight alternative to PyTorch for inference, with few dependencies and WASM support.
Cargo.toml
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module, VarBuilder, VarMap};
// Simple feedforward network
struct FeedForward {
fc1: Linear,
fc2: Linear,
}
impl FeedForward {
fn new(vb: VarBuilder) -> candle_core::Result<Self> {
let fc1 = candle_nn::linear(784, 256, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(256, 10, vb.pp("fc2"))?;
Ok(FeedForward { fc1, fc2 })
}
}
impl Module for FeedForward {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let x = self.fc1.forward(x)?;
let x = x.relu()?;
self.fc2.forward(&x)
}
}
fn candle_tensor_ops() -> candle_core::Result<()> {
let device = Device::Cpu; // or Device::new_cuda(0)?
// Create tensors
let a = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
let b = Tensor::randn(0f32, 1f32, (4, 5), &device)?;
// Matrix multiplication
let c = a.matmul(&b)?;
println!("Shape: {:?}", c.shape());
// Element-wise operations
let x = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device)?;
let softmax = candle_nn::ops::softmax(&x, 1)?;
println!("Softmax: {:?}", softmax.to_vec2::<f32>()?);
Ok(())
}
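To make the softmax call above less of a black box, here is a plain-Rust sketch of what a row-wise softmax computes. It is independent of Candle and is not a description of how `candle_nn::ops::softmax` is implemented; the function name is illustrative only:

```rust
/// Numerically stable softmax over one row: subtracting the maximum before
/// exponentiating keeps `exp` from overflowing for large logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Without the max subtraction, logits around 1000 would produce `inf / inf`, which is `NaN` in `f32`.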
Burn: Backend-Agnostic ML Framework
use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};
#[derive(Module, Debug)]
struct SimpleNet<B: Backend> {
fc1: Linear<B>,
activation: Relu,
fc2: Linear<B>,
}
impl<B: Backend> SimpleNet<B> {
fn new(device: &B::Device) -> Self {
SimpleNet {
fc1: LinearConfig::new(128, 64).init(device),
activation: Relu::new(),
fc2: LinearConfig::new(64, 10).init(device),
}
}
fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.fc1.forward(x);
let x = self.activation.forward(x);
self.fc2.forward(x)
}
}
Linfa: Scikit-learn Alternative
use linfa::prelude::*;
use linfa_clustering::KMeans;
use ndarray::Array2;
fn kmeans_clustering() -> Result<(), Box<dyn std::error::Error>> {
// Generate data (assume embedding vectors)
let data: Array2<f64> = Array2::from_shape_fn((100, 10), |(i, j)| {
(i as f64 * 0.1 + j as f64) % 5.0
});
let dataset = Dataset::from(data);
// K-means clustering (grouping similar docs in a RAG system)
let model = KMeans::params(5)
.tolerance(1e-5)
.fit(&dataset)?;
let predictions = model.predict(&dataset);
println!("Cluster assignments: {:?}", &predictions.targets()[..10]);
Ok(())
}
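The `fit` call above hides the iteration that k-means performs. As a rough std-only sketch of one Lloyd iteration (assignment, then centroid update), with hypothetical function names and no claim that Linfa is implemented this way:

```rust
/// Squared Euclidean distance between two points of equal dimension.
fn squared_distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Assignment step: index of the nearest centroid for each point.
fn assign_clusters(points: &[Vec<f64>], centroids: &[Vec<f64>]) -> Vec<usize> {
    points
        .iter()
        .map(|p| {
            centroids
                .iter()
                .enumerate()
                .map(|(i, c)| (i, squared_distance(p, c)))
                .min_by(|a, b| a.1.total_cmp(&b.1))
                .map(|(i, _)| i)
                .expect("at least one centroid is required")
        })
        .collect()
}

/// Update step: each centroid moves to the mean of its assigned points.
/// Centroids with no assigned points are left where they were.
fn update_centroids(points: &[Vec<f64>], labels: &[usize], centroids: &mut [Vec<f64>]) {
    let dim = centroids.first().map_or(0, |c| c.len());
    let mut sums = vec![vec![0.0f64; dim]; centroids.len()];
    let mut counts = vec![0usize; centroids.len()];
    for (p, &label) in points.iter().zip(labels) {
        counts[label] += 1;
        for (s, x) in sums[label].iter_mut().zip(p) {
            *s += x;
        }
    }
    for ((c, s), &n) in centroids.iter_mut().zip(sums).zip(&counts) {
        if n > 0 {
            *c = s.into_iter().map(|v| v / n as f64).collect();
        }
    }
}
```

Repeating these two steps until the centroids move less than a tolerance corresponds to the role of the `tolerance(1e-5)` setting in the Linfa example.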
3. High-Performance Inference Server: axum + Candle
tokio Async-based LLM Inference Server
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
candle-core = "0.8"
candle-transformers = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::post,
Router,
};
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Deserialize)]
struct InferenceRequest {
prompt: String,
max_tokens: Option<usize>,
temperature: Option<f32>,
}
#[derive(Serialize)]
struct InferenceResponse {
text: String,
tokens_generated: usize,
latency_ms: u64,
}
// Model state shared via Arc (thread-safe)
struct AppState {
device: Device,
// model: Arc<Mutex<YourLLMModel>>,
}
async fn inference_handler(
State(state): State<Arc<AppState>>,
Json(req): Json<InferenceRequest>,
) -> Result<Json<InferenceResponse>, StatusCode> {
let start = std::time::Instant::now();
let max_tokens = req.max_tokens.unwrap_or(100);
let _temperature = req.temperature.unwrap_or(0.7);
// Dummy inference (real implementation uses Candle model forward pass)
let generated_text = format!("Response to: {}", req.prompt);
let latency = start.elapsed().as_millis() as u64;
Ok(Json(InferenceResponse {
text: generated_text,
tokens_generated: max_tokens,
latency_ms: latency,
}))
}
async fn health_handler() -> &'static str {
"OK"
}
#[tokio::main]
async fn main() {
let state = Arc::new(AppState {
device: Device::Cpu,
});
let app = Router::new()
.route("/inference", post(inference_handler))
.route("/health", axum::routing::get(health_handler))
.with_state(state);
let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
println!("AI Inference Server running on port 8080");
axum::serve(listener, app).await.unwrap();
}
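The handler above accepts any `max_tokens` or `temperature` a client sends. A production server would usually validate them first. The sketch below shows one way to do that in plain Rust; the `SamplingParams` type, the `resolve_params` function, and the limits (4096 tokens, temperature in [0.0, 2.0]) are assumptions chosen for illustration:

```rust
/// Sampling parameters after defaults and bounds checks have been applied.
#[derive(Debug, PartialEq)]
struct SamplingParams {
    max_tokens: usize,
    temperature: f32,
}

/// Hypothetical upper bound on generated tokens per request.
const MAX_TOKENS_LIMIT: usize = 4096;

/// Applies the same defaults as the handler (100 tokens, temperature 0.7)
/// and rejects values outside the assumed limits.
fn resolve_params(
    max_tokens: Option<usize>,
    temperature: Option<f32>,
) -> Result<SamplingParams, String> {
    let max_tokens = max_tokens.unwrap_or(100);
    let temperature = temperature.unwrap_or(0.7);
    if max_tokens == 0 || max_tokens > MAX_TOKENS_LIMIT {
        return Err(format!("max_tokens must be between 1 and {}", MAX_TOKENS_LIMIT));
    }
    // NaN fails the range check, infinities fail the finiteness check.
    if !temperature.is_finite() || !(0.0..=2.0).contains(&temperature) {
        return Err("temperature must be a finite value in [0.0, 2.0]".to_string());
    }
    Ok(SamplingParams { max_tokens, temperature })
}
```

Inside the axum handler, an `Err` from this function could be mapped to `StatusCode::BAD_REQUEST`, which fits the handler's existing `Result<_, StatusCode>` return type.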
Batched Inference Optimization
use tokio::sync::mpsc;
struct BatchRequest {
id: u64,
prompt: String,
response_tx: tokio::sync::oneshot::Sender<String>,
}
// Batch processing worker — bundle requests to maximize GPU utilization
async fn batch_inference_worker(
mut rx: mpsc::Receiver<BatchRequest>,
batch_size: usize,
timeout_ms: u64,
) {
loop {
let mut batch = Vec::new();
// Wait for first request
if let Some(req) = rx.recv().await {
batch.push(req);
} else {
break;
}
// Fill batch within timeout window
let deadline = tokio::time::Instant::now()
+ tokio::time::Duration::from_millis(timeout_ms);
while batch.len() < batch_size {
match tokio::time::timeout_at(deadline, rx.recv()).await {
Ok(Some(req)) => batch.push(req),
_ => break,
}
}
// Process batch (real implementation runs parallel GPU inference)
for req in batch {
let response = format!("Batch response for: {}", req.prompt);
let _ = req.response_tx.send(response);
}
}
}
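The batching window is easier to test in isolation with a synchronous channel. The sketch below mirrors the worker's collection logic using only `std::sync::mpsc`; it is a simplified stand-in for the tokio version, and `collect_batch` is a hypothetical name:

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::{Duration, Instant};

/// Blocks for the first item, then keeps collecting until the batch is full
/// or the time window closes. Returns `None` once the channel is closed and drained.
fn collect_batch<T>(rx: &Receiver<T>, batch_size: usize, window: Duration) -> Option<Vec<T>> {
    let first = rx.recv().ok()?;
    let mut batch = vec![first];
    let deadline = Instant::now() + window;
    while batch.len() < batch_size {
        // Time left in the window; zero once the deadline has passed.
        let remaining = deadline.saturating_duration_since(Instant::now());
        match rx.recv_timeout(remaining) {
            Ok(item) => batch.push(item),
            Err(RecvTimeoutError::Timeout) | Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    Some(batch)
}
```

The window starts when the first request arrives, so an idle server adds no latency and a busy one waits at most `window` before dispatching a partial batch.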
4. Python Interop: PyO3
Calling Rust Functions from Python
[lib]
name = "rust_ai_tools"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"
use pyo3::prelude::*;
// Python-callable Rust function
#[pyfunction]
fn fast_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> PyResult<f32> {
if a.len() != b.len() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Vectors must have the same length"
));
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
Ok(dot / (norm_a * norm_b))
}
// Normalize a batch of embeddings to unit length.
// Inputs arrive as nested Python sequences and are copied into Vec<Vec<f32>>.
#[pyfunction]
fn batch_normalize(
py: Python,
embeddings: Vec<Vec<f32>>,
) -> PyResult<Vec<Vec<f32>>> {
// Release the GIL so other Python threads can run during the computation
let result = py.allow_threads(|| {
embeddings.iter().map(|emb| {
let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
// Leave zero vectors unchanged instead of dividing by zero
let scale = if norm > 0.0 { 1.0 / norm } else { 1.0 };
emb.iter().map(|x| x * scale).collect()
}).collect::<Vec<Vec<f32>>>()
});
Ok(result)
}
// Register Python module
#[pymodule]
fn rust_ai_tools(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fast_cosine_similarity, m)?)?;
m.add_function(wrap_pyfunction!(batch_normalize, m)?)?;
Ok(())
}
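One edge case worth handling before exposing a similarity function to Python: a zero vector makes the denominator zero and the result `NaN`. A std-only sketch of the core computation with that case reported as an error (the function name and error messages are illustrative):

```rust
/// Cosine similarity that reports degenerate inputs instead of returning NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, String> {
    if a.len() != b.len() {
        return Err(format!("length mismatch: {} vs {}", a.len(), b.len()));
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return Err("cosine similarity is undefined for a zero vector".to_string());
    }
    Ok(dot / (norm_a * norm_b))
}
```

In a PyO3 wrapper, the `String` error could be converted into a `PyValueError`, the same exception type the length check above already raises.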
Build and Publish with maturin
Install maturin
pip install maturin
Development build (installs directly into virtualenv)
maturin develop
Usage from Python:
import rust_ai_tools
sim = rust_ai_tools.fast_cosine_similarity([1.0, 0.0], [0.0, 1.0])
print(sim) # 0.0
Build wheel for PyPI
maturin build --release
pyproject.toml configuration:
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
GIL Handling and True Parallelism
use pyo3::prelude::*;
use rayon::prelude::*;
// True parallelism via rayon
#[pyfunction]
fn parallel_embedding_search(
py: Python,
query: Vec<f32>,
corpus: Vec<Vec<f32>>,
top_k: usize,
) -> PyResult<Vec<usize>> {
// Release GIL — Rust thread pool takes over
let indices = py.allow_threads(|| {
let mut scores: Vec<(usize, f32)> = corpus
.par_iter() // parallel iterator
.enumerate()
.map(|(i, emb)| {
let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
(i, dot)
})
.collect();
// total_cmp defines a total order, so a NaN score cannot cause a panic here
scores.sort_by(|a, b| b.1.total_cmp(&a.1));
scores.iter().take(top_k).map(|(i, _)| *i).collect::<Vec<_>>()
});
Ok(indices)
}
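Sorting every score is more work than needed when `top_k` is small relative to the corpus. A possible refinement, sketched here with only the standard library and a hypothetical `top_k_indices` helper, is to partition first and sort just the winners:

```rust
/// Returns the indices of the `top_k` highest scores, best first.
/// Uses partial selection so only the top `k` entries get fully sorted.
fn top_k_indices(scores: &[f32], top_k: usize) -> Vec<usize> {
    let k = top_k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }
    let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
    // Descending by score; `total_cmp` is a total order, so NaN cannot panic.
    let by_score_desc = |a: &(usize, f32), b: &(usize, f32)| b.1.total_cmp(&a.1);
    if k < indexed.len() {
        // After this call the first k entries are the k largest, in arbitrary order.
        indexed.select_nth_unstable_by(k - 1, by_score_desc);
        indexed.truncate(k);
    }
    indexed.sort_by(by_score_desc);
    indexed.into_iter().map(|(i, _)| i).collect()
}
```

Note that under `total_cmp` a positive NaN ranks above every finite score, so inputs that may contain NaN are best filtered beforehand.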
5. WASM for AI: Browser-side Inference
In-Browser AI with wasm-bindgen
[lib]
crate-type = ["cdylib"]
[dependencies]
wasm-bindgen = "0.2"
candle-core = "0.8"
use wasm_bindgen::prelude::*;
// JavaScript-callable struct
#[wasm_bindgen]
pub struct MiniModel {
weights: Vec<f32>,
input_size: usize,
output_size: usize,
}
#[wasm_bindgen]
impl MiniModel {
#[wasm_bindgen(constructor)]
pub fn new(input_size: usize, output_size: usize) -> MiniModel {
let weights = vec![0.01f32; input_size * output_size];
MiniModel { weights, input_size, output_size }
}
// Accept JavaScript Float32Array and run inference
pub fn predict(&self, input: &[f32]) -> Vec<f32> {
let mut output = vec![0.0f32; self.output_size];
for i in 0..self.output_size {
for j in 0..self.input_size {
output[i] += input[j] * self.weights[i * self.input_size + j];
}
// ReLU activation
output[i] = output[i].max(0.0);
}
output
}
pub fn load_weights(&mut self, weights: Vec<f32>) {
self.weights = weights;
}
}
// Cosine similarity for embeddings (enables client-side RAG)
#[wasm_bindgen]
pub fn cosine_similarity_wasm(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot / (na * nb)
}
// Using the WASM module from JavaScript
async function runBrowserInference() {
await init()
const model = new MiniModel(128, 10)
const input = new Float32Array(128).fill(0.5)
const output = model.predict(input)
console.log('Browser inference output:', output)
// Vector similarity search
const queryEmb = new Float32Array(384).map(() => Math.random())
const docEmb = new Float32Array(384).map(() => Math.random())
const sim = cosine_similarity_wasm(queryEmb, docEmb)
console.log('Similarity:', sim)
}
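The `predict` method above indexes `input[j]` directly, so a `Float32Array` of the wrong length would trap inside the WASM module. A shape-checked version of the same dense-plus-ReLU computation might look like the following sketch (plain Rust, with the hypothetical name `dense_relu`):

```rust
/// Dense layer followed by ReLU, with explicit shape checks.
/// `weights` is row-major with shape `output_size x input_size`.
fn dense_relu(
    weights: &[f32],
    input: &[f32],
    input_size: usize,
    output_size: usize,
) -> Result<Vec<f32>, String> {
    if input.len() != input_size {
        return Err(format!("expected {} inputs, got {}", input_size, input.len()));
    }
    if input_size.checked_mul(output_size) != Some(weights.len()) {
        return Err(format!(
            "expected {} x {} weights, got {}",
            output_size, input_size, weights.len()
        ));
    }
    if input_size == 0 {
        // No inputs: every output is the empty sum, which is zero.
        return Ok(vec![0.0; output_size]);
    }
    Ok(weights
        .chunks_exact(input_size)
        .map(|row| row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>().max(0.0))
        .collect())
}
```

In a `#[wasm_bindgen]` method, the error branch could be surfaced to JavaScript as an exception, for instance by returning `Result<Vec<f32>, JsError>`.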
6. Memory-Safe ML: Polars Data Pipelines
Large-Scale Training Data Processing with Polars
[dependencies]
polars = { version = "0.44", features = ["lazy", "parquet", "json", "csv", "strings"] }
use polars::prelude::*;
fn preprocess_training_data() -> PolarsResult<DataFrame> {
// Lazy Parquet scan (memory efficient)
let df = LazyFrame::scan_parquet("training_data.parquet", Default::default())?
.filter(col("label").is_not_null())
.with_column(
// Extract text length feature
col("text").str().len_chars().alias("text_length")
)
.with_column(
// Normalize scores
(col("score") - col("score").mean()) / col("score").std(1)
)
.filter(col("text_length").gt(10))
.select([col("text"), col("label"), col("score"), col("text_length")])
.collect()?;
println!("Shape: {:?}", df.shape());
println!("{}", df.head(Some(5)));
Ok(df)
}
fn batch_generator(df: &DataFrame, batch_size: usize) -> Vec<DataFrame> {
let n = df.height();
(0..n).step_by(batch_size).map(|start| {
let end = (start + batch_size).min(n);
df.slice(start as i64, end - start)
}).collect()
}
fn compute_label_stats(df: &DataFrame) -> PolarsResult<DataFrame> {
df.clone().lazy()
.group_by([col("label")])
.agg([
col("score").mean().alias("avg_score"),
col("score").std(1).alias("std_score"),
col("text_length").mean().alias("avg_length"),
col("label").count().alias("count"),
])
.sort(["count"], SortMultipleOptions::default().with_order_descending(true))
.collect()
}
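The score normalization in `preprocess_training_data` is a z-score with a sample standard deviation (`std(1)` means one delta degree of freedom). For reference, here is the same arithmetic in plain Rust; the function name is illustrative and the `None` cases are a design choice of this sketch, not Polars behavior:

```rust
/// Standardizes values to zero mean and unit sample standard deviation (ddof = 1).
/// Returns `None` for fewer than two values or when all values are equal,
/// since the standard deviation is then undefined or zero.
fn z_scores(values: &[f64]) -> Option<Vec<f64>> {
    let n = values.len();
    if n < 2 {
        return None;
    }
    let mean = values.iter().sum::<f64>() / n as f64;
    let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
    let std = variance.sqrt();
    if std == 0.0 {
        return None;
    }
    Some(values.iter().map(|v| (v - mean) / std).collect())
}
```

For the input `[1.0, 2.0, 3.0]` the mean is 2 and the sample variance is (1 + 0 + 1) / 2 = 1, so the result is `[-1.0, 0.0, 1.0]`.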
Zero-Copy Data Sharing with Apache Arrow
use arrow::array::{Float32Array, StringArray};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
fn create_embedding_batch(
texts: Vec<String>,
embeddings: Vec<Vec<f32>>,
) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new("embedding_dim0", DataType::Float32, false),
]));
let text_array = StringArray::from(texts);
// Store first dimension as example
let emb_dim: Vec<f32> = embeddings.iter().map(|e| e[0]).collect();
let emb_array = Float32Array::from(emb_dim);
RecordBatch::try_new(
schema,
vec![Arc::new(text_array), Arc::new(emb_array)],
).unwrap()
}
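The example above keeps only the first embedding dimension. To store whole vectors in a columnar layout, the nested `Vec<Vec<f32>>` is typically flattened into one contiguous buffer first, which is similar in spirit to how Arrow lays out fixed-size lists. A std-only sketch of that step, with hypothetical helper names:

```rust
/// Flattens equal-length embeddings into one contiguous row-major buffer
/// and returns it together with the embedding dimension.
fn flatten_embeddings(embeddings: &[Vec<f32>]) -> Result<(Vec<f32>, usize), String> {
    let dim = embeddings.first().map_or(0, |e| e.len());
    let mut flat = Vec::with_capacity(dim * embeddings.len());
    for (i, e) in embeddings.iter().enumerate() {
        if e.len() != dim {
            return Err(format!("row {} has length {}, expected {}", i, e.len(), dim));
        }
        flat.extend_from_slice(e);
    }
    Ok((flat, dim))
}

/// Borrows embedding `i` from the flattened buffer without copying.
fn embedding_at(flat: &[f32], dim: usize, i: usize) -> Option<&[f32]> {
    let start = i.checked_mul(dim)?;
    let end = start.checked_add(dim)?;
    flat.get(start..end)
}
```

The flattening itself copies once; after that, every row access is a slice into the same buffer.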
7. Case Studies
Hugging Face Candle
HuggingFace is porting Python PyTorch models to Rust Candle:
- Binary size: PyTorch 250MB+ vs Candle ~5MB
- WASM support enables direct browser inference
- Reusable CUDA kernels
Ruff: Python Linter Rewritten in Rust
Ruff is advertised as 10-100x faster than existing linters such as Flake8. The difference is noticeable on large ML codebases; the timings below are illustrative.
Legacy Python linter
time flake8 large_ml_project/ # ~30 seconds
Ruff (Rust)
time ruff check large_ml_project/ # ~0.3 seconds
Pydantic v2's Rust Core
Pydantic v2 rewrote its core validation logic in Rust (pydantic-core), achieving 5-50x performance improvements. AI API servers see dramatically reduced request validation latency.
Quiz
**Answer**: Arc (Atomic Reference Count) manages reference counts with atomic operations, making it thread-safe. Rc (Reference Count) is for single-threaded use only.
**Explanation**: In an AI inference server, multiple request threads share the same model weights. Rc does not implement the Send trait, so it cannot be moved between threads. Arc uses atomic fetch-add operations for reference count changes, allowing multiple threads to reference the same data without data races. The performance difference is negligible (nanosecond-level), but safety is guaranteed at compile time.
**Answer**: PyO3 releases the GIL via `py.allow_threads()`, enabling Rust code to execute truly in parallel without Python's GIL constraint.
**Explanation**: Python's GIL (Global Interpreter Lock) restricts execution to one Python thread at a time. Calling `allow_threads` releases the GIL for the current thread, letting Rust code (with rayon, for example) freely use OS threads. Computing cosine similarity over 1 million embeddings takes tens of seconds in single-threaded Python, but can be reduced to hundreds of milliseconds with Rust + rayon parallelism.
**Answer**: Candle ships as a single Rust binary with WASM support, producing ~5MB binaries without requiring the PyTorch runtime (250MB+).
**Explanation**: PyTorch requires Python runtime, libtorch, CUDA drivers, and other large dependencies. Candle's static linking produces small executables containing only what's needed. Compiling to WebAssembly enables direct browser inference — client-side AI with no server costs. It also runs on Raspberry Pi and embedded devices without Python.
**Answer**: Polars is written in Rust and uses the Apache Arrow columnar memory format, SIMD optimizations, and query optimization via lazy evaluation.
**Explanation**: Pandas relies on Python with frequent row-wise operations and GIL constraints. Polars uses (1) Arrow columnar format for cache-efficient access, (2) Rust's automatic SIMD vectorization for numeric operations, (3) a lazy API that optimizes query plans (predicate pushdown, projection pushdown), and (4) multi-threaded parallel processing without a GIL. Typical speedup is 5-20x over Pandas, with lower memory consumption.
**Answer**: The Rust compiler (LLVM backend) transforms high-level abstractions (traits, generics, iterators) into optimized machine code with no runtime overhead.
**Explanation**: Like C++ templates and inlining, Rust generics use monomorphization to generate concrete code for each type at compile time. Iterator chains (`map`, `filter`, `fold`) typically compile down to the same machine code as a hand-written loop. Trait methods called through generics use static dispatch, so there is no virtual table lookup and the call can be inlined; only `dyn Trait` objects pay for dynamic dispatch. The result: code that is both ergonomic to write and competitive with C in speed.
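A small sketch of what monomorphization means in practice: the generic function below (a hypothetical `sum_of_squares`) is compiled once per concrete type it is called with, and the iterator version returns the same result as an index-based loop. Whether the generated machine code is identical depends on the optimizer and is worth checking for a specific case.

```rust
/// Generic over any type supporting addition and multiplication; the compiler
/// emits one specialized copy per concrete `T` that is actually used.
fn sum_of_squares<T>(values: &[T]) -> T
where
    T: Copy + Default + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
    values.iter().fold(T::default(), |acc, &v| acc + v * v)
}

/// Hand-written loop over `f32` for comparison with the iterator version.
fn sum_of_squares_loop(values: &[f32]) -> f32 {
    let mut total = 0.0;
    for i in 0..values.len() {
        total += values[i] * values[i];
    }
    total
}
```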
Wrapping Up
Rust is a powerful tool for pushing past the **performance ceiling** in AI infrastructure. The practical approach: prototype in Python, rewrite bottlenecks in Rust (via PyO3), or build high-performance services in Rust from the ground up.
**Learning Roadmap**:
1. [The Rust Book](https://doc.rust-lang.org/book/) — official tutorial
2. Rustlings — interactive exercises
3. Candle examples — HuggingFace GitHub
4. Axum documentation — building inference servers
5. PyO3 guide — Python integration