
Rust for AI Systems: From Ownership to High-Performance Inference Servers with Candle, PyO3, and axum


Introduction

Python dominates AI development, but for production AI infrastructure where **performance and safety** are non-negotiable, **Rust** is gaining serious traction.

HuggingFace's **Candle**, the Python linter **Ruff**, and Pydantic v2's Rust core show that key tools in the AI ecosystem are already being written or rewritten in Rust. This guide is a practical Rust introduction for AI engineers.

Why Rust for AI Systems?

| Aspect | Python | Rust |
| --------------- | ---------------- | ----------------------- |
| Runtime speed | Moderate (interpreter overhead) | Comparable to C++ |
| Memory safety | Managed at runtime (GC) | Enforced at compile time |
| Concurrency | Constrained by the GIL | No data races in safe Rust |
| Binary size | Interpreter and runtime required | Single self-contained binary |
| Edge deployment | Difficult | WASM and embedded targets supported |

1. Rust Fundamentals: An AI Engineer's Perspective

Ownership and Memory Safety

Rust's cornerstone is the **ownership system** — memory safety without a garbage collector.

fn main() {

// Ownership move

let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];

let moved = tensor_data; // tensor_data is no longer valid

// Borrowing — reference without transferring ownership

let weights = vec![0.1f32, 0.2, 0.3];

let sum = compute_sum(&weights); // immutable reference

println!("weights still valid: {:?}, sum: {}", weights, sum);

}

fn compute_sum(data: &[f32]) -> f32 {

data.iter().sum()

}
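
The example above covers moves and shared borrows. As a small complementary sketch (the helper name `scale_in_place` is illustrative and not part of the original code), a mutable borrow lets a function update a buffer in place without taking ownership of it:

```rust
/// Scales every element of `data` in place through an exclusive (mutable) borrow.
fn scale_in_place(data: &mut [f32], factor: f32) {
    for x in data.iter_mut() {
        *x *= factor;
    }
}

/// Sums a slice through a shared (immutable) borrow.
fn compute_sum(data: &[f32]) -> f32 {
    data.iter().sum()
}
```

While `scale_in_place` holds the `&mut` borrow, the compiler rejects any other access to the same buffer. Once it returns, the caller still owns the vector and can pass it to `compute_sum` as before.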

Lifetimes

Lifetimes tell the compiler how long a reference is valid. They matter whenever a model needs to reference weight data safely without copying it.

// Lifetime annotation: returned reference can't outlive input slices

fn longest_sequence<'a>(seq1: &'a [f32], seq2: &'a [f32]) -> &'a [f32] {

if seq1.len() > seq2.len() {

seq1

} else {

seq2

}

}

struct ModelWeights<'a> {

data: &'a [f32], // references an external buffer (zero copy)

shape: (usize, usize),

}

impl<'a> ModelWeights<'a> {

fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {

ModelWeights { data, shape: (rows, cols) }

}

fn get(&self, row: usize, col: usize) -> f32 {

self.data[row * self.shape.1 + col]

}

}
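
Note that `get` above panics when the indices are out of range. A possible variation, sketched here under the assumption that returning `Option` is acceptable to callers (the `WeightView` type and its methods are hypothetical names), validates the shape once and hands out borrowed rows that live as long as the underlying buffer:

```rust
/// A borrowed, row-major view over a weight matrix (no copy of the data).
struct WeightView<'a> {
    data: &'a [f32],
    rows: usize,
    cols: usize,
}

impl<'a> WeightView<'a> {
    /// Returns `None` if the buffer length does not match `rows * cols`.
    fn new(data: &'a [f32], rows: usize, cols: usize) -> Option<Self> {
        if rows.checked_mul(cols)? == data.len() {
            Some(WeightView { data, rows, cols })
        } else {
            None
        }
    }

    /// Borrows one row; the returned slice lives as long as the original buffer.
    fn row(&self, r: usize) -> Option<&'a [f32]> {
        if r >= self.rows {
            return None;
        }
        let start = r * self.cols;
        self.data.get(start..start + self.cols)
    }
}
```

Because `row` returns `&'a [f32]` rather than a slice tied to `&self`, the row can outlive the view itself, but never the buffer it points into.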

Traits and Generics

Unlike Python's duck typing, Rust performs type checking at **compile time**.

// Trait definition for AI operations

trait Activation {

fn forward(&self, x: f32) -> f32;

fn backward(&self, x: f32) -> f32; // derivative

}

struct ReLU;

struct Sigmoid;

impl Activation for ReLU {

fn forward(&self, x: f32) -> f32 {

x.max(0.0)

}

fn backward(&self, x: f32) -> f32 {

if x > 0.0 { 1.0 } else { 0.0 }

}

}

impl Activation for Sigmoid {

fn forward(&self, x: f32) -> f32 {

1.0 / (1.0 + (-x).exp())

}

fn backward(&self, x: f32) -> f32 {

let s = self.forward(x);

s * (1.0 - s)

}

}

// Generic layer: works with any activation function

fn apply_activation<A: Activation>(activation: &A, inputs: &[f32]) -> Vec<f32> {

inputs.iter().map(|&x| activation.forward(x)).collect()

}

fn main() {

let relu = ReLU;

let data = vec![-1.0, 0.5, 2.0, -0.3];

let output = apply_activation(&relu, &data);

println!("ReLU output: {:?}", output); // [0.0, 0.5, 2.0, 0.0]

}
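
The generic `apply_activation` above is resolved at compile time. When the activation is only known at runtime (for example, read from a config file), trait objects are the usual alternative. The following is a minimal sketch; `activation_from_name` and `apply_dyn` are hypothetical helpers, and the trait is trimmed to `forward` to keep the block self-contained:

```rust
trait Activation {
    fn forward(&self, x: f32) -> f32;
}

struct ReLU;
struct Sigmoid;

impl Activation for ReLU {
    fn forward(&self, x: f32) -> f32 {
        x.max(0.0)
    }
}

impl Activation for Sigmoid {
    fn forward(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
}

/// Picks an activation by name at runtime; `None` for unknown names.
fn activation_from_name(name: &str) -> Option<Box<dyn Activation>> {
    match name {
        "relu" => Some(Box::new(ReLU)),
        "sigmoid" => Some(Box::new(Sigmoid)),
        _ => None,
    }
}

/// Applies a dynamically dispatched activation element-wise.
fn apply_dyn(activation: &dyn Activation, inputs: &[f32]) -> Vec<f32> {
    inputs.iter().map(|&x| activation.forward(x)).collect()
}
```

The trade-off is a vtable lookup per call in exchange for runtime flexibility. The generic version avoids that lookup but needs the concrete type at compile time.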

Arc vs Rc: Multi-threaded AI Inference

use std::sync::{Arc, Mutex};

use std::thread;

// Rc is not Send, so it cannot be shared across threads

// Arc (Atomic Reference Count) is thread-safe and suits multi-threaded inference servers

fn multi_thread_inference() {

// Share model weights across multiple threads

let model_weights: Arc<Vec<f32>> = Arc::new(vec![0.1, 0.2, 0.3, 0.4]);

let results: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));

let mut handles = vec![];

for i in 0..4 {

let weights = Arc::clone(&model_weights);

let results = Arc::clone(&results);

let handle = thread::spawn(move || {

// Each thread runs inference with the same weights

let inference_result = weights.iter().sum::<f32>() * i as f32;

results.lock().unwrap().push(inference_result);

});

handles.push(handle);

}

for handle in handles {

handle.join().unwrap();

}

println!("All results: {:?}", results.lock().unwrap());

}
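
When the worker threads only need to live for the duration of one function call, scoped threads from the standard library are an alternative worth knowing. This sketch (the function name `scoped_inference` is illustrative) borrows the weights directly and collects results from the join handles, so it needs neither `Arc` nor `Mutex`:

```rust
use std::thread;

/// Runs one toy "inference" per worker, borrowing the shared weights.
/// `thread::scope` guarantees all workers finish before `weights` goes away.
fn scoped_inference(weights: &[f32], workers: usize) -> Vec<f32> {
    thread::scope(|s| {
        let handles: Vec<_> = (0..workers)
            .map(|i| s.spawn(move || weights.iter().sum::<f32>() * i as f32))
            .collect();
        // Joining in spawn order keeps the results deterministic.
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    })
}
```

`Arc` remains the right tool when threads outlive the caller, as in a long-running server. Scoped threads fit short fork-join style computations.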

2. The Rust AI Ecosystem

Candle: HuggingFace's Rust ML Framework

**Candle** is a minimalist ML framework for Rust from Hugging Face. Its API resembles PyTorch, it has few dependencies, and it can target WASM.

Cargo.toml

[dependencies]

candle-core = "0.8"

candle-nn = "0.8"

candle-transformers = "0.8"

use candle_core::{Device, Tensor};

use candle_nn::{Linear, Module, VarBuilder};

// Simple feedforward network

struct FeedForward {

fc1: Linear,

fc2: Linear,

}

impl FeedForward {

fn new(vb: VarBuilder) -> candle_core::Result<Self> {

let fc1 = candle_nn::linear(784, 256, vb.pp("fc1"))?;

let fc2 = candle_nn::linear(256, 10, vb.pp("fc2"))?;

Ok(FeedForward { fc1, fc2 })

}

}

impl Module for FeedForward {

fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {

let x = self.fc1.forward(x)?;

let x = x.relu()?;

self.fc2.forward(&x)

}

}

fn candle_tensor_ops() -> candle_core::Result<()> {

let device = Device::Cpu; // or Device::new_cuda(0)?

// Create tensors

let a = Tensor::randn(0f32, 1f32, (3, 4), &device)?;

let b = Tensor::randn(0f32, 1f32, (4, 5), &device)?;

// Matrix multiplication

let c = a.matmul(&b)?;

println!("Shape: {:?}", c.shape());

// Element-wise operations

let x = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device)?;

let softmax = candle_nn::ops::softmax(&x, 1)?;

println!("Softmax: {:?}", softmax.to_vec2::<f32>()?);

Ok(())

}
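
To make the softmax call above less of a black box, here is a plain-Rust sketch of what a softmax computes over one row of logits. It is not Candle's implementation, only the standard numerically stable formulation (subtract the maximum before exponentiating):

```rust
/// Numerically stable softmax: subtracts the max before exponentiating.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Without the max subtraction, large logits such as `1000.0` would overflow `exp` and produce NaN. With it, the largest exponent is always `exp(0) = 1`.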

Burn: Backend-Agnostic ML Framework

use burn::prelude::*;

use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]

struct SimpleNet<B: Backend> {

fc1: Linear<B>,

activation: Relu,

fc2: Linear<B>,

}

impl<B: Backend> SimpleNet<B> {

fn new(device: &B::Device) -> Self {

SimpleNet {

fc1: LinearConfig::new(128, 64).init(device),

activation: Relu::new(),

fc2: LinearConfig::new(64, 10).init(device),

}

}

fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {

let x = self.fc1.forward(x);

let x = self.activation.forward(x);

self.fc2.forward(x)

}

}

Linfa: Scikit-learn Alternative

use linfa::prelude::*;

use linfa_clustering::KMeans;

use ndarray::Array2;

fn kmeans_clustering() -> Result<(), Box<dyn std::error::Error>> {

// Generate data (assume embedding vectors)

let data: Array2<f64> = Array2::from_shape_fn((100, 10), |(i, j)| {

(i as f64 * 0.1 + j as f64) % 5.0

});

let dataset = Dataset::from(data);

// K-means clustering (grouping similar docs in a RAG system)

let model = KMeans::params(5)

.tolerance(1e-5)

.fit(&dataset)?;

let predictions = model.predict(&dataset); // Array1<usize>: one cluster index per row

println!("Cluster assignments: {:?}", predictions.iter().take(10).collect::<Vec<_>>());

Ok(())

}
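
Conceptually, predicting with a fitted k-means model means assigning each point to its nearest centroid. The sketch below shows that assignment step in plain Rust. It is not Linfa's code, and the function names are illustrative:

```rust
/// Squared Euclidean distance between two equal-length vectors.
fn squared_distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Returns the index of the closest centroid, or `None` if there are none.
fn nearest_centroid(point: &[f64], centroids: &[Vec<f64>]) -> Option<usize> {
    centroids
        .iter()
        .enumerate()
        .map(|(i, c)| (i, squared_distance(point, c)))
        .min_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(i, _)| i)
}
```

Squared distance is enough for picking the minimum, which avoids a square root per comparison.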

3. High-Performance Inference Server: axum + Candle

Tokio-Based Async LLM Inference Server

[dependencies]

axum = "0.7"

tokio = { version = "1", features = ["full"] }

candle-core = "0.8"

candle-transformers = "0.8"

serde = { version = "1", features = ["derive"] }

serde_json = "1"

tower = "0.4"

use axum::{

extract::State,

http::StatusCode,

response::Json,

routing::post,

Router,

};

use candle_core::Device;

use serde::{Deserialize, Serialize};

use std::sync::Arc;

#[derive(Deserialize)]

struct InferenceRequest {

prompt: String,

max_tokens: Option<usize>,

temperature: Option<f32>,

}

#[derive(Serialize)]

struct InferenceResponse {

text: String,

tokens_generated: usize,

latency_ms: u64,

}

// Model state shared via Arc (thread-safe)

struct AppState {

device: Device,

// model: Arc<Mutex<YourLLMModel>>,

}

async fn inference_handler(

State(state): State<Arc<AppState>>,

Json(req): Json<InferenceRequest>,

) -> Result<Json<InferenceResponse>, StatusCode> {

let start = std::time::Instant::now();

let max_tokens = req.max_tokens.unwrap_or(100);

let _temperature = req.temperature.unwrap_or(0.7);

// Dummy inference (real implementation uses Candle model forward pass)

let generated_text = format!("Response to: {}", req.prompt);

let latency = start.elapsed().as_millis() as u64;

Ok(Json(InferenceResponse {

text: generated_text,

tokens_generated: max_tokens, // placeholder; a real model would report the actual count

latency_ms: latency,

}))

}

async fn health_handler() -> &'static str {

"OK"

}

#[tokio::main]

async fn main() {

let state = Arc::new(AppState {

device: Device::Cpu,

});

let app = Router::new()

.route("/inference", post(inference_handler))

.route("/health", axum::routing::get(health_handler))

.with_state(state);

let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();

println!("AI Inference Server running on port 8080");

axum::serve(listener, app).await.unwrap();

}
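
The handler above applies defaults to `max_tokens` and `temperature` but does not validate them. One way this could be handled is sketched below. The limits, the `SamplingParams` type, and `resolve_params` are assumptions for illustration, and a handler would map the error to a `400 Bad Request`:

```rust
/// Sampling parameters after defaults and limits are applied.
#[derive(Debug, PartialEq)]
struct SamplingParams {
    max_tokens: usize,
    temperature: f32,
}

/// Hypothetical server-side limit; pick a value that suits your model.
const MAX_TOKENS_LIMIT: usize = 2048;

/// Fills in defaults and rejects out-of-range values with a message.
fn resolve_params(
    max_tokens: Option<usize>,
    temperature: Option<f32>,
) -> Result<SamplingParams, String> {
    let max_tokens = max_tokens.unwrap_or(100);
    let temperature = temperature.unwrap_or(0.7);
    if max_tokens == 0 || max_tokens > MAX_TOKENS_LIMIT {
        return Err(format!("max_tokens must be in 1..={}", MAX_TOKENS_LIMIT));
    }
    // `contains` is false for NaN, so NaN temperatures are rejected too.
    if !(0.0..=2.0).contains(&temperature) {
        return Err("temperature must be in 0.0..=2.0".to_string());
    }
    Ok(SamplingParams { max_tokens, temperature })
}
```

Keeping this logic in a plain function makes it easy to unit-test without starting the server.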

Batched Inference Optimization

use tokio::sync::mpsc;

struct BatchRequest {

id: u64,

prompt: String,

response_tx: tokio::sync::oneshot::Sender<String>,

}

// Batch processing worker — bundle requests to maximize GPU utilization

async fn batch_inference_worker(

mut rx: mpsc::Receiver<BatchRequest>,

batch_size: usize,

timeout_ms: u64,

) {

loop {

let mut batch = Vec::new();

// Wait for first request

if let Some(req) = rx.recv().await {

batch.push(req);

} else {

break;

}

// Fill batch within timeout window

let deadline = tokio::time::Instant::now()

+ tokio::time::Duration::from_millis(timeout_ms);

while batch.len() < batch_size {

match tokio::time::timeout_at(deadline, rx.recv()).await {

Ok(Some(req)) => batch.push(req),

_ => break,

}

}

// Process batch (a real implementation would run one batched forward pass on the GPU)

for req in batch {

let response = format!("Batch response for: {}", req.prompt);

let _ = req.response_tx.send(response);

}

}

}
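
The batching rule in the worker (block for the first request, then fill the batch until it is full or the window closes) can be tried without an async runtime. The sketch below swaps Tokio's channel for the blocking `std::sync::mpsc` channel, so it illustrates the same policy rather than the worker itself. `collect_batch` is a hypothetical helper:

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::{Duration, Instant};

/// Blocks for the first item, then keeps collecting until the batch is full
/// or the time window closes. Returns `None` once the channel is closed and empty.
fn collect_batch<T>(rx: &Receiver<T>, batch_size: usize, window: Duration) -> Option<Vec<T>> {
    let first = rx.recv().ok()?;
    let mut batch = vec![first];
    let deadline = Instant::now() + window;
    while batch.len() < batch_size {
        let remaining = deadline.saturating_duration_since(Instant::now());
        match rx.recv_timeout(remaining) {
            Ok(item) => batch.push(item),
            Err(RecvTimeoutError::Timeout) | Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    Some(batch)
}
```

The window trades latency for throughput: a longer window yields fuller batches but makes the first request in each batch wait longer.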

4. Python Interop: PyO3

Calling Rust Functions from Python

[lib]

name = "rust_ai_tools"

crate-type = ["cdylib"]

[dependencies]

pyo3 = { version = "0.22", features = ["extension-module"] }

numpy = "0.22"

use pyo3::prelude::*;

// Python-callable Rust function

#[pyfunction]

fn fast_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> PyResult<f32> {

if a.len() != b.len() {

return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(

"Vectors must have the same length"

));

}

let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();

let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

Ok(dot / (norm_a * norm_b))

}

// Accepts a list of embeddings (converted to Vec<Vec<f32>>); releasing the GIL lets other Python threads run meanwhile

#[pyfunction]

fn batch_normalize(

py: Python,

embeddings: Vec<Vec<f32>>,

) -> PyResult<Vec<Vec<f32>>> {

// Release the GIL while the pure-Rust computation runs (single-threaded in this function)

let result = py.allow_threads(|| {

embeddings.iter().map(|emb| {

let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();

emb.iter().map(|x| if norm > 0.0 { x / norm } else { 0.0 }).collect() // zero vectors stay zero instead of becoming NaN

}).collect::<Vec<Vec<f32>>>()

});

Ok(result)

}

// Register Python module

#[pymodule]

fn rust_ai_tools(m: &Bound<'_, PyModule>) -> PyResult<()> {

m.add_function(wrap_pyfunction!(fast_cosine_similarity, m)?)?;

m.add_function(wrap_pyfunction!(batch_normalize, m)?)?;

Ok(())

}
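
One edge case the cosine similarity above does not cover is a zero vector, where the denominator is zero and the result is NaN. A pure-Rust sketch of a guarded version follows. Returning `Option` is an assumption for illustration, and a PyO3 wrapper could map `None` to a `ValueError`:

```rust
/// Cosine similarity that returns `None` for mismatched lengths
/// or when either vector has zero norm (where the ratio is undefined).
fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() {
        return None;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    Some(dot / (norm_a * norm_b))
}
```

Keeping the numeric core free of PyO3 types also lets it be tested with plain `cargo test`.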

Build and Publish with maturin

Install maturin

pip install maturin

Development build (installs directly into virtualenv)

maturin develop

Usage from Python:

import rust_ai_tools

sim = rust_ai_tools.fast_cosine_similarity([1.0, 0.0], [0.0, 1.0])

print(sim) # 0.0

Build wheel for PyPI

maturin build --release

pyproject.toml configuration:

[build-system]

requires = ["maturin>=1.0,<2.0"]

build-backend = "maturin"

GIL Handling and True Parallelism

use pyo3::prelude::*;

use rayon::prelude::*; // requires `rayon = "1"` under [dependencies] in Cargo.toml

// True parallelism via rayon

#[pyfunction]

fn parallel_embedding_search(

py: Python,

query: Vec<f32>,

corpus: Vec<Vec<f32>>,

top_k: usize,

) -> PyResult<Vec<usize>> {

// Release GIL — Rust thread pool takes over

let indices = py.allow_threads(|| {

let mut scores: Vec<(usize, f32)> = corpus

.par_iter() // parallel iterator

.enumerate()

.map(|(i, emb)| {

let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();

(i, dot)

})

.collect();

scores.sort_by(|a, b| b.1.total_cmp(&a.1)); // total_cmp does not panic on NaN scores

scores.iter().take(top_k).map(|(i, _)| *i).collect::<Vec<_>>()

});

Ok(indices)

}
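
The search above sorts every score even though only `top_k` results are returned. For large corpora, a partial selection can avoid the full sort. This is a sketch using only the standard library, with `top_k_indices` as a hypothetical helper that operates on precomputed scores:

```rust
/// Returns the indices of the `top_k` highest scores, best first.
/// Uses a partial selection so only the winners are fully sorted.
fn top_k_indices(scores: &[f32], top_k: usize) -> Vec<usize> {
    let k = top_k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }
    let mut indices: Vec<usize> = (0..scores.len()).collect();
    // Descending order; `total_cmp` gives NaN a defined position instead of panicking.
    let by_score_desc = |a: &usize, b: &usize| scores[*b].total_cmp(&scores[*a]);
    // Move the k best indices into the first k slots (in arbitrary order).
    indices.select_nth_unstable_by(k - 1, by_score_desc);
    indices.truncate(k);
    indices.sort_by(by_score_desc);
    indices
}
```

Selection is linear on average, so the cost is roughly O(n + k log k) instead of O(n log n) for a full sort.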

5. WASM for AI: Browser-side Inference

In-Browser AI with wasm-bindgen

[lib]

crate-type = ["cdylib"]

[dependencies]

wasm-bindgen = "0.2"

candle-core = "0.8" # optional: the MiniModel example below does not use Candle

use wasm_bindgen::prelude::*;

// JavaScript-callable struct

#[wasm_bindgen]

pub struct MiniModel {

weights: Vec<f32>,

input_size: usize,

output_size: usize,

}

#[wasm_bindgen]

impl MiniModel {

#[wasm_bindgen(constructor)]

pub fn new(input_size: usize, output_size: usize) -> MiniModel {

let weights = vec![0.01f32; input_size * output_size];

MiniModel { weights, input_size, output_size }

}

// Accept JavaScript Float32Array and run inference

pub fn predict(&self, input: &[f32]) -> Vec<f32> {

let mut output = vec![0.0f32; self.output_size];

for i in 0..self.output_size {

for j in 0..self.input_size {

output[i] += input[j] * self.weights[i * self.input_size + j];

}

// ReLU activation

output[i] = output[i].max(0.0);

}

output

}

pub fn load_weights(&mut self, weights: Vec<f32>) {

self.weights = weights;

}

}

// Cosine similarity for embeddings (enables client-side RAG)

#[wasm_bindgen]

pub fn cosine_similarity_wasm(a: &[f32], b: &[f32]) -> f32 {

let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();

let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

dot / (na * nb)

}

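// Using the WASM module from JavaScript (assumes `init`, `MiniModel`, and `cosine_similarity_wasm` are imported from the generated bindings)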

async function runBrowserInference() {

await init()

const model = new MiniModel(128, 10)

const input = new Float32Array(128).fill(0.5)

const output = model.predict(input)

console.log('Browser inference output:', output)

// Vector similarity search

const queryEmb = new Float32Array(384).map(() => Math.random())

const docEmb = new Float32Array(384).map(() => Math.random())

const sim = cosine_similarity_wasm(queryEmb, docEmb)

console.log('Similarity:', sim)

}
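
`MiniModel::predict` indexes `input[j]` directly, so an input shorter than `input_size` would panic inside the WASM module. A defensive variant of the same dense-plus-ReLU computation might look like the following sketch. `dense_relu` is a hypothetical free function rather than part of the `MiniModel` API, and it is written without `wasm_bindgen` so it runs natively:

```rust
/// Dense layer followed by ReLU over row-major weights (`output_size` x `input.len()`).
/// Returns `None` when the weight buffer does not match the requested shape.
fn dense_relu(weights: &[f32], input: &[f32], output_size: usize) -> Option<Vec<f32>> {
    if weights.len() != output_size.checked_mul(input.len())? {
        return None;
    }
    if input.is_empty() {
        // `chunks_exact(0)` would panic, so handle the degenerate case explicitly.
        return Some(vec![0.0; output_size]);
    }
    let output = weights
        .chunks_exact(input.len()) // one chunk per output neuron
        .map(|row| {
            let z: f32 = row.iter().zip(input).map(|(w, x)| w * x).sum();
            z.max(0.0)
        })
        .collect();
    Some(output)
}
```

Iterating over `chunks_exact` also removes the manual index arithmetic of the nested loops.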

6. Memory-Safe ML: Polars Data Pipelines

Large-Scale Training Data Processing with Polars

[dependencies]

polars = { version = "0.44", features = ["lazy", "parquet", "json", "csv", "strings"] }

use polars::prelude::*;

fn preprocess_training_data() -> PolarsResult<DataFrame> {

// Lazy Parquet scan (memory efficient)

let df = LazyFrame::scan_parquet("training_data.parquet", Default::default())?

.filter(col("label").is_not_null())

.with_column(

// Extract text length feature

col("text").str().len_chars().alias("text_length")

)

.with_column(

// Normalize scores

(col("score") - col("score").mean()) / col("score").std(1)

)

.filter(col("text_length").gt(10))

.select([col("text"), col("label"), col("score"), col("text_length")])

.collect()?;

println!("Shape: {:?}", df.shape());

println!("{}", df.head(Some(5)));

Ok(df)

}

fn batch_generator(df: &DataFrame, batch_size: usize) -> Vec<DataFrame> {

let n = df.height();

(0..n).step_by(batch_size).map(|start| {

let end = (start + batch_size).min(n);

df.slice(start as i64, end - start)

}).collect()

}

fn compute_label_stats(df: &DataFrame) -> PolarsResult<DataFrame> {

df.clone().lazy()

.group_by([col("label")])

.agg([

col("score").mean().alias("avg_score"),

col("score").std(1).alias("std_score"),

col("text_length").mean().alias("avg_length"),

col("label").count().alias("count"),

])

.sort(["count"], SortMultipleOptions::default().with_order_descending(true))

.collect()

}
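
The score normalization in `preprocess_training_data` is a z-score with a sample standard deviation (`std(1)` means one delta degree of freedom). As a plain-Rust sketch of the same arithmetic, with `z_scores` as an illustrative name:

```rust
/// Standardizes values to zero mean and unit sample standard deviation (ddof = 1).
/// Returns `None` when fewer than two values are given or the spread is zero.
fn z_scores(values: &[f64]) -> Option<Vec<f64>> {
    let n = values.len();
    if n < 2 {
        return None;
    }
    let mean = values.iter().sum::<f64>() / n as f64;
    let var = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
    let std = var.sqrt();
    if std == 0.0 {
        return None;
    }
    Some(values.iter().map(|v| (v - mean) / std).collect())
}
```

For `[1.0, 2.0, 3.0]` the mean is 2 and the sample standard deviation is 1, so the result is `[-1.0, 0.0, 1.0]`.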

Zero-Copy Data Sharing with Apache Arrow

use arrow::array::{Float32Array, StringArray};

use arrow::record_batch::RecordBatch;

use arrow::datatypes::{DataType, Field, Schema};

use std::sync::Arc;

fn create_embedding_batch(

texts: Vec<String>,

embeddings: Vec<Vec<f32>>,

) -> RecordBatch {

let schema = Arc::new(Schema::new(vec![

Field::new("text", DataType::Utf8, false),

Field::new("embedding_dim0", DataType::Float32, false),

]));

let text_array = StringArray::from(texts);

// Store first dimension as example

let emb_dim: Vec<f32> = embeddings.iter().map(|e| e[0]).collect();

let emb_array = Float32Array::from(emb_dim);

RecordBatch::try_new(

schema,

vec![Arc::new(text_array), Arc::new(emb_array)],

).unwrap()

}
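
The example above stores only the first embedding dimension. Columnar formats generally keep full fixed-width vectors in one contiguous buffer, similar in spirit to Arrow's fixed-size list layout. The following standard-library sketch illustrates that idea without the Arrow API, and both function names are hypothetical:

```rust
/// Flattens equally sized embeddings into one contiguous buffer plus the dimension.
/// Returns `None` if the input is empty or the rows have different lengths.
fn flatten_embeddings(embeddings: &[Vec<f32>]) -> Option<(Vec<f32>, usize)> {
    let dim = embeddings.first()?.len();
    if embeddings.iter().any(|e| e.len() != dim) {
        return None;
    }
    let flat: Vec<f32> = embeddings.iter().flatten().copied().collect();
    Some((flat, dim))
}

/// Borrows row `i` back out of the flat buffer without copying.
fn embedding_at(flat: &[f32], dim: usize, i: usize) -> Option<&[f32]> {
    let start = i.checked_mul(dim)?;
    flat.get(start..start.checked_add(dim)?)
}
```

A single flat buffer is what makes zero-copy sharing practical: a consumer only needs the pointer, the length, and the dimension to view every embedding.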

7. Case Studies

Hugging Face Candle

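Hugging Face maintains Candle implementations of many models that were originally written in PyTorch. Commonly cited advantages:

- Binary size: roughly 5MB for a Candle binary versus 250MB+ for a PyTorch installation (figures vary with features and platform)

- WASM support enables direct browser inference

- Reusable CUDA kernels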

Ruff: Python Linter Rewritten in Rust

Ruff is reported to be 10-100x faster than Flake8, and the difference is noticeable on large ML codebases. The timings below are illustrative rather than measured benchmarks.

Legacy Python linter

time flake8 large_ml_project/ # ~30 seconds

Ruff (Rust)

time ruff check large_ml_project/ # ~0.3 seconds

Pydantic v2's Rust Core

Pydantic v2 rewrote its core validation logic in Rust (pydantic-core), with reported speedups of roughly 5-50x over v1. For AI API servers, this can noticeably reduce request validation latency.

Quiz

**Answer**: Arc (Atomic Reference Count) manages reference counts with atomic operations, making it thread-safe. Rc (Reference Count) is for single-threaded use only.

**Explanation**: In an AI inference server, multiple request threads share the same model weights. Rc does not implement the Send trait, so it cannot be moved between threads. Arc uses atomic fetch-add operations for reference count changes, allowing multiple threads to reference the same data without data races. The performance difference is negligible (nanosecond-level), but safety is guaranteed at compile time.

**Answer**: PyO3 releases the GIL via `py.allow_threads()`, enabling Rust code to execute truly in parallel without Python's GIL constraint.

**Explanation**: Python's GIL (Global Interpreter Lock) restricts execution of Python bytecode to one thread at a time. Calling `allow_threads` releases the GIL for the current thread, letting Rust code (with rayon, for example) freely use OS threads while other Python threads continue to run. As a rough order of magnitude, computing cosine similarity over 1 million embeddings can take tens of seconds in a pure-Python loop and may drop to hundreds of milliseconds with Rust + rayon, depending on hardware and embedding size.

**Answer**: Candle applications compile to a single Rust binary (or to WASM) of around 5MB, without requiring the PyTorch runtime (250MB+).

**Explanation**: A typical PyTorch deployment needs the Python runtime, libtorch, and for GPU use the CUDA libraries, all of which are large dependencies. Candle's static linking produces small executables containing only what is needed. Compiling to WebAssembly enables direct browser inference, which means client-side AI with no inference server costs. It also runs on Raspberry Pi and embedded devices without Python.

**Answer**: Polars is written in Rust and uses the Apache Arrow columnar memory format, SIMD optimizations, and query optimization via lazy evaluation.

**Explanation**: Pandas executes eagerly and mostly on a single thread, and row-wise operations such as `apply` fall back to slow Python-level loops. Polars uses (1) the Arrow columnar format for cache-efficient access, (2) SIMD vectorization for numeric operations, (3) a lazy API that optimizes query plans (predicate pushdown, projection pushdown), and (4) multi-threaded parallel processing without a GIL. Speedups of roughly 5-20x over Pandas are commonly reported, often with lower memory consumption, though results depend on the workload.

**Answer**: The Rust compiler (LLVM backend) transforms high-level abstractions (traits, generics, iterators) into optimized machine code with no runtime overhead.

**Explanation**: Like C++ templates, Rust generics use monomorphization to generate concrete code for each type at compile time. Iterator chains (`map`, `filter`, `fold`) typically compile to the same machine code as hand-written loops in optimized builds. Trait methods called through generics use static dispatch, so there are no virtual table lookups and calls can be inlined. Only trait objects (`dyn Trait`) go through a vtable. The result: code that is both ergonomic to write and close to C in speed.
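
As a concrete pair to compare, here is the same dot product written as an iterator chain and as an indexed loop. This is a sketch for experimentation. Whether the two compile to identical machine code depends on the compiler version and optimization level, so it is worth inspecting the release-mode output rather than assuming it:

```rust
/// Dot product written as an iterator chain.
fn dot_iter(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// The same computation as an explicit indexed loop.
fn dot_loop(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    let mut acc = 0.0f32;
    for i in 0..n {
        acc += a[i] * b[i];
    }
    acc
}
```

The iterator version has one practical advantage beyond style: `zip` stops at the shorter slice, so no bounds checks are needed inside the loop body.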

Wrapping Up

Rust is a powerful tool for pushing past the **performance ceiling** in AI infrastructure. The practical approach: prototype in Python, rewrite bottlenecks in Rust (via PyO3), or build high-performance services in Rust from the ground up.

**Learning Roadmap**:

1. [The Rust Book](https://doc.rust-lang.org/book/) — official tutorial

2. Rustlings — interactive exercises

3. Candle examples — HuggingFace GitHub

4. Axum documentation — building inference servers

5. PyO3 guide — Python integration
