
Rust for AI Systems: From Ownership to High-Performance Inference Servers with Candle, PyO3, and axum

Introduction

Python dominates AI development, but for production AI infrastructure, where performance and safety are non-negotiable, Rust is gaining serious traction.

Core tools in the AI ecosystem are already being written in Rust: HuggingFace's Candle ML framework, the fast Python linter Ruff, and the pydantic-core validation engine behind Pydantic v2. This guide is a practical introduction to Rust for AI engineers.

Why Rust for AI Systems?

| Aspect          | Python           | Rust                    |
|-----------------|------------------|-------------------------|
| Runtime speed   | Moderate         | C++ level               |
| Memory safety   | GC dependent     | Compile-time guaranteed |
| Concurrency     | GIL constrained  | No data races           |
| Binary size     | Runtime required | Single binary           |
| Edge deployment | Difficult        | WASM, embedded support  |

1. Rust Fundamentals: An AI Engineer's Perspective

Ownership and Memory Safety

Rust's cornerstone is the ownership system — memory safety without a garbage collector.

fn main() {
    // Ownership move
    let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];
    let _moved = tensor_data; // tensor_data is no longer valid; using it again is a compile error

    // Borrowing — reference without transferring ownership
    let weights = vec![0.1f32, 0.2, 0.3];
    let sum = compute_sum(&weights); // immutable reference
    println!("weights still valid: {:?}, sum: {}", weights, sum);
}

fn compute_sum(data: &[f32]) -> f32 {
    data.iter().sum()
}
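
Mutable borrows complete the picture: at most one `&mut` reference may exist at a time, which is how in-place tensor updates stay race-free. A minimal sketch (the function name is illustrative):

```rust
// In-place normalization via a mutable borrow: `&mut [f32]` grants exclusive
// access, so no other reference can observe a half-updated buffer.
fn normalize_in_place(data: &mut [f32]) {
    let norm: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in data.iter_mut() {
            *x /= norm;
        }
    }
}

fn main() {
    let mut embedding = vec![3.0f32, 4.0];
    normalize_in_place(&mut embedding); // exclusive mutable borrow ends here
    println!("{:?}", embedding); // [0.6, 0.8]
}
```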

Lifetimes

Lifetimes tell the compiler how long a reference is valid. This is critical for safely referencing weight data in AI models without copying it.

// Lifetime annotation: returned reference can't outlive input slices
fn longest_sequence<'a>(seq1: &'a [f32], seq2: &'a [f32]) -> &'a [f32] {
    if seq1.len() > seq2.len() {
        seq1
    } else {
        seq2
    }
}

struct ModelWeights<'a> {
    data: &'a [f32],  // references an external buffer (zero copy)
    shape: (usize, usize),
}

impl<'a> ModelWeights<'a> {
    fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        ModelWeights { data, shape: (rows, cols) }
    }

    fn get(&self, row: usize, col: usize) -> f32 {
        self.data[row * self.shape.1 + col]
    }
}
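
A quick usage sketch of the zero-copy view (the struct is repeated so the block compiles on its own; the buffer contents are illustrative). The borrow checker guarantees the view cannot outlive the buffer it borrows from:

```rust
// ModelWeights as defined above: a zero-copy view over an external buffer.
struct ModelWeights<'a> {
    data: &'a [f32],
    shape: (usize, usize),
}

impl<'a> ModelWeights<'a> {
    fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        ModelWeights { data, shape: (rows, cols) }
    }
    fn get(&self, row: usize, col: usize) -> f32 {
        self.data[row * self.shape.1 + col]
    }
}

fn main() {
    // Imagine this buffer was memory-mapped from a safetensors file.
    let buffer: Vec<f32> = (0..6).map(|i| i as f32).collect();
    let weights = ModelWeights::new(&buffer, 2, 3); // borrows, no copy
    println!("w[1][2] = {}", weights.get(1, 2)); // index 1*3 + 2 = 5 -> 5.0
    // `weights` must go out of scope before `buffer`; the compiler enforces it.
}
```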

Traits and Generics

Unlike Python's duck typing, Rust performs type checking at compile time.

// Trait definition for AI operations
trait Activation {
    fn forward(&self, x: f32) -> f32;
    fn backward(&self, x: f32) -> f32; // derivative
}

struct ReLU;
struct Sigmoid;

impl Activation for ReLU {
    fn forward(&self, x: f32) -> f32 {
        x.max(0.0)
    }
    fn backward(&self, x: f32) -> f32 {
        if x > 0.0 { 1.0 } else { 0.0 }
    }
}

impl Activation for Sigmoid {
    fn forward(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
    fn backward(&self, x: f32) -> f32 {
        let s = self.forward(x);
        s * (1.0 - s)
    }
}

// Generic layer: works with any activation function
fn apply_activation<A: Activation>(activation: &A, inputs: &[f32]) -> Vec<f32> {
    inputs.iter().map(|&x| activation.forward(x)).collect()
}

fn main() {
    let relu = ReLU;
    let data = vec![-1.0, 0.5, 2.0, -0.3];
    let output = apply_activation(&relu, &data);
    println!("ReLU output: {:?}", output); // [0.0, 0.5, 2.0, 0.0]
}
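
When the activation isn't known at compile time (say, it comes from a config file), trait objects provide dynamic dispatch at the cost of a vtable lookup. A sketch reusing a trimmed copy of the `Activation` trait so the block stands alone (`activation_from_config` is a hypothetical helper):

```rust
// Trimmed copy of the Activation trait from above.
trait Activation {
    fn forward(&self, x: f32) -> f32;
}

struct ReLU;
struct Sigmoid;

impl Activation for ReLU {
    fn forward(&self, x: f32) -> f32 { x.max(0.0) }
}

impl Activation for Sigmoid {
    fn forward(&self, x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }
}

// `Box<dyn Activation>` erases the concrete type; calls go through a vtable.
fn activation_from_config(name: &str) -> Box<dyn Activation> {
    match name {
        "sigmoid" => Box::new(Sigmoid),
        _ => Box::new(ReLU),
    }
}

fn main() {
    let act = activation_from_config("sigmoid");
    println!("sigmoid(0) = {}", act.forward(0.0)); // 0.5
}
```

Generics (static dispatch, as in `apply_activation`) are the default choice for hot inference loops; trait objects trade a small runtime cost for runtime flexibility.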

Arc vs Rc: Multi-threaded AI Inference

use std::sync::{Arc, Mutex};
use std::thread;

// Rc is single-threaded only — unusable in AI inference servers
// Arc (Atomic Reference Count) is thread-safe

fn multi_thread_inference() {
    // Share model weights across multiple threads
    let model_weights: Arc<Vec<f32>> = Arc::new(vec![0.1, 0.2, 0.3, 0.4]);
    let results: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));

    let mut handles = vec![];

    for i in 0..4 {
        let weights = Arc::clone(&model_weights);
        let results = Arc::clone(&results);

        let handle = thread::spawn(move || {
            // Each thread runs inference with the same weights
            let inference_result = weights.iter().sum::<f32>() * i as f32;
            results.lock().unwrap().push(inference_result);
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("All results: {:?}", results.lock().unwrap());
}
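
Inference workloads are read-heavy, so `RwLock` often fits better than `Mutex`: many reader threads share the weights concurrently, while an occasional writer (for example, a hot-swap to new weights) takes exclusive access. A minimal std-only sketch:

```rust
use std::sync::{Arc, RwLock};
use std::thread;

fn main() {
    // Many readers can hold the read lock at once; a writer blocks them all.
    let weights = Arc::new(RwLock::new(vec![0.1f32, 0.2, 0.3]));

    let mut handles = vec![];
    for _ in 0..4 {
        let w = Arc::clone(&weights);
        handles.push(thread::spawn(move || {
            let guard = w.read().unwrap(); // shared read lock
            guard.iter().sum::<f32>()
        }));
    }
    for h in handles {
        let s = h.join().unwrap();
        assert!((s - 0.6).abs() < 1e-5);
    }

    // Hot-swap the weights under the exclusive write lock.
    *weights.write().unwrap() = vec![0.5f32, 0.5];
    println!("weights swapped: {:?}", weights.read().unwrap());
}
```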

2. The Rust AI Ecosystem

Candle: HuggingFace's Rust ML Framework

Candle is a Rust alternative to PyTorch, with minimal dependencies and first-class WASM support.

# Cargo.toml
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"

use candle_core::{Device, Tensor};
use candle_nn::{Linear, Module, VarBuilder};

// Simple feedforward network
struct FeedForward {
    fc1: Linear,
    fc2: Linear,
}

impl FeedForward {
    fn new(vb: VarBuilder) -> candle_core::Result<Self> {
        let fc1 = candle_nn::linear(784, 256, vb.pp("fc1"))?;
        let fc2 = candle_nn::linear(256, 10, vb.pp("fc2"))?;
        Ok(FeedForward { fc1, fc2 })
    }
}

impl Module for FeedForward {
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let x = self.fc1.forward(x)?;
        let x = x.relu()?;
        self.fc2.forward(&x)
    }
}

fn candle_tensor_ops() -> candle_core::Result<()> {
    let device = Device::Cpu; // or Device::new_cuda(0)?

    // Create tensors
    let a = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
    let b = Tensor::randn(0f32, 1f32, (4, 5), &device)?;

    // Matrix multiplication
    let c = a.matmul(&b)?;
    println!("Shape: {:?}", c.shape());

    // Element-wise operations
    let x = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device)?;
    let softmax = candle_nn::ops::softmax(&x, 1)?;
    println!("Softmax: {:?}", softmax.to_vec2::<f32>()?);

    Ok(())
}

Burn: Backend-Agnostic ML Framework

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
struct SimpleNet<B: Backend> {
    fc1: Linear<B>,
    activation: Relu,
    fc2: Linear<B>,
}

impl<B: Backend> SimpleNet<B> {
    fn new(device: &B::Device) -> Self {
        SimpleNet {
            fc1: LinearConfig::new(128, 64).init(device),
            activation: Relu::new(),
            fc2: LinearConfig::new(64, 10).init(device),
        }
    }

    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
        self.fc2.forward(x)
    }
}

Linfa: Scikit-learn Alternative

use linfa::prelude::*;
use linfa_clustering::KMeans;
use ndarray::Array2;

fn kmeans_clustering() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data (assume embedding vectors)
    let data: Array2<f64> = Array2::from_shape_fn((100, 10), |(i, j)| {
        (i as f64 * 0.1 + j as f64) % 5.0
    });

    let dataset = Dataset::from(data);

    // K-means clustering (grouping similar docs in a RAG system)
    let model = KMeans::params(5)
        .tolerance(1e-5)
        .fit(&dataset)?;

    let predictions = model.predict(&dataset);
    println!("Cluster assignments: {:?}", &predictions.to_vec()[..10]);

    Ok(())
}

3. High-Performance Inference Server: axum + Candle

tokio Async-based LLM Inference Server

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
candle-core = "0.8"
candle-transformers = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"

use axum::{
    extract::State,
    http::StatusCode,
    response::Json,
    routing::post,
    Router,
};
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Deserialize)]
struct InferenceRequest {
    prompt: String,
    max_tokens: Option<usize>,
    temperature: Option<f32>,
}

#[derive(Serialize)]
struct InferenceResponse {
    text: String,
    tokens_generated: usize,
    latency_ms: u64,
}

// Model state shared via Arc (thread-safe)
struct AppState {
    device: Device,
    // model: Arc<Mutex<YourLLMModel>>,
}

async fn inference_handler(
    State(state): State<Arc<AppState>>,
    Json(req): Json<InferenceRequest>,
) -> Result<Json<InferenceResponse>, StatusCode> {
    let start = std::time::Instant::now();

    let max_tokens = req.max_tokens.unwrap_or(100);
    let _temperature = req.temperature.unwrap_or(0.7);

    // Dummy inference (real implementation uses Candle model forward pass)
    let generated_text = format!("Response to: {}", req.prompt);
    let latency = start.elapsed().as_millis() as u64;

    Ok(Json(InferenceResponse {
        text: generated_text,
        tokens_generated: max_tokens,
        latency_ms: latency,
    }))
}

async fn health_handler() -> &'static str {
    "OK"
}

#[tokio::main]
async fn main() {
    let state = Arc::new(AppState {
        device: Device::Cpu,
    });

    let app = Router::new()
        .route("/inference", post(inference_handler))
        .route("/health", axum::routing::get(health_handler))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
    println!("AI Inference Server running on port 8080");
    axum::serve(listener, app).await.unwrap();
}

Batched Inference Optimization

use tokio::sync::mpsc;

struct BatchRequest {
    id: u64,
    prompt: String,
    response_tx: tokio::sync::oneshot::Sender<String>,
}

// Batch processing worker — bundle requests to maximize GPU utilization
async fn batch_inference_worker(
    mut rx: mpsc::Receiver<BatchRequest>,
    batch_size: usize,
    timeout_ms: u64,
) {
    loop {
        let mut batch = Vec::new();

        // Wait for first request
        if let Some(req) = rx.recv().await {
            batch.push(req);
        } else {
            break;
        }

        // Fill batch within timeout window
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_millis(timeout_ms);

        while batch.len() < batch_size {
            match tokio::time::timeout_at(deadline, rx.recv()).await {
                Ok(Some(req)) => batch.push(req),
                _ => break,
            }
        }

        // Process batch (real implementation runs parallel GPU inference)
        for req in batch {
            let response = format!("Batch response for: {}", req.prompt);
            let _ = req.response_tx.send(response);
        }
    }
}
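
The collect-then-flush pattern above can be seen in miniature with std channels, where `recv_timeout` plays the role of `timeout_at` (a tokio-free sketch for illustration; the tokio version uses an absolute deadline rather than a per-item timeout):

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

// Block for the first request, then keep draining the channel until the
// batch is full or the timeout window closes without a new item.
fn collect_batch(rx: &mpsc::Receiver<String>, batch_size: usize, window: Duration) -> Vec<String> {
    let mut batch = Vec::new();
    match rx.recv() {
        Ok(first) => batch.push(first),
        Err(_) => return batch, // channel closed, nothing to batch
    }
    while batch.len() < batch_size {
        match rx.recv_timeout(window) {
            Ok(item) => batch.push(item),
            Err(_) => break, // window expired or channel closed
        }
    }
    batch
}

fn main() {
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        for i in 0..3 {
            tx.send(format!("prompt-{i}")).unwrap();
        }
        // tx dropped here; the collector stops once the channel drains.
    });
    let batch = collect_batch(&rx, 8, Duration::from_millis(50));
    println!("batched {} requests", batch.len());
}
```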

4. Python Interop: PyO3

Calling Rust Functions from Python

[lib]
name = "rust_ai_tools"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"

use pyo3::prelude::*;

// Python-callable Rust function
#[pyfunction]
fn fast_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> PyResult<f32> {
    if a.len() != b.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Vectors must have the same length"
        ));
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    Ok(dot / (norm_a * norm_b))
}

// Accept nested Python lists (numpy arrays are converted into Rust Vecs); releasing the GIL enables true parallelism
#[pyfunction]
fn batch_normalize(
    py: Python,
    embeddings: Vec<Vec<f32>>,
) -> PyResult<Vec<Vec<f32>>> {
    // Release GIL — Rust thread pool runs freely
    let result = py.allow_threads(|| {
        embeddings.iter().map(|emb| {
            let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter().map(|x| x / norm).collect()
        }).collect::<Vec<Vec<f32>>>()
    });

    Ok(result)
}

// Register Python module
// Register Python module (PyO3 0.22 uses the Bound API)
#[pymodule]
fn rust_ai_tools(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fast_cosine_similarity, m)?)?;
    m.add_function(wrap_pyfunction!(batch_normalize, m)?)?;
    Ok(())
}

Build and Publish with maturin

# Install maturin
pip install maturin

# Development build (installs directly into virtualenv)
maturin develop

# Usage from Python:
# import rust_ai_tools
# sim = rust_ai_tools.fast_cosine_similarity([1.0, 0.0], [0.0, 1.0])
# print(sim)  # 0.0

# Build wheel for PyPI
maturin build --release

# pyproject.toml configuration:
# [build-system]
# requires = ["maturin>=1.0,<2.0"]
# build-backend = "maturin"

GIL Handling and True Parallelism

use pyo3::prelude::*;
use rayon::prelude::*;

// True parallelism via rayon
#[pyfunction]
fn parallel_embedding_search(
    py: Python,
    query: Vec<f32>,
    corpus: Vec<Vec<f32>>,
    top_k: usize,
) -> PyResult<Vec<usize>> {
    // Release GIL — Rust thread pool takes over
    let indices = py.allow_threads(|| {
        let mut scores: Vec<(usize, f32)> = corpus
            .par_iter()  // parallel iterator
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.iter().take(top_k).map(|(i, _)| *i).collect::<Vec<_>>()
    });

    Ok(indices)
}

5. WASM for AI: Browser-side Inference

In-Browser AI with wasm-bindgen

[lib]
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = "0.2"

use wasm_bindgen::prelude::*;

// JavaScript-callable struct
#[wasm_bindgen]
pub struct MiniModel {
    weights: Vec<f32>,
    input_size: usize,
    output_size: usize,
}

#[wasm_bindgen]
impl MiniModel {
    #[wasm_bindgen(constructor)]
    pub fn new(input_size: usize, output_size: usize) -> MiniModel {
        let weights = vec![0.01f32; input_size * output_size];
        MiniModel { weights, input_size, output_size }
    }

    // Accept JavaScript Float32Array and run inference
    pub fn predict(&self, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0f32; self.output_size];
        for i in 0..self.output_size {
            for j in 0..self.input_size {
                output[i] += input[j] * self.weights[i * self.input_size + j];
            }
            // ReLU activation
            output[i] = output[i].max(0.0);
        }
        output
    }

    pub fn load_weights(&mut self, weights: Vec<f32>) {
        self.weights = weights;
    }
}

// Cosine similarity for embeddings (enables client-side RAG)
#[wasm_bindgen]
pub fn cosine_similarity_wasm(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb)
}

// Using the WASM module from JavaScript
import init, { MiniModel, cosine_similarity_wasm } from './rust_ai_wasm.js'

async function runBrowserInference() {
  await init()

  const model = new MiniModel(128, 10)
  const input = new Float32Array(128).fill(0.5)
  const output = model.predict(input)
  console.log('Browser inference output:', output)

  // Vector similarity search
  const queryEmb = new Float32Array(384).map(() => Math.random())
  const docEmb = new Float32Array(384).map(() => Math.random())
  const sim = cosine_similarity_wasm(queryEmb, docEmb)
  console.log('Similarity:', sim)
}

6. Memory-Safe ML: Polars Data Pipelines

Large-Scale Training Data Processing with Polars

[dependencies]
polars = { version = "0.44", features = ["lazy", "parquet", "json", "csv"] }

use polars::prelude::*;

fn preprocess_training_data() -> PolarsResult<DataFrame> {
    // Lazy Parquet scan (memory efficient)
    let df = LazyFrame::scan_parquet("training_data.parquet", Default::default())?
        .filter(col("label").is_not_null())
        .with_column(
            // Extract text length feature
            col("text").str().len_chars().alias("text_length")
        )
        .with_column(
            // Normalize scores
            (col("score") - col("score").mean()) / col("score").std(1)
        )
        .filter(col("text_length").gt(10))
        .select([col("text"), col("label"), col("score"), col("text_length")])
        .collect()?;

    println!("Shape: {:?}", df.shape());
    println!("{}", df.head(Some(5)));

    Ok(df)
}

fn batch_generator(df: &DataFrame, batch_size: usize) -> Vec<DataFrame> {
    let n = df.height();
    (0..n).step_by(batch_size).map(|start| {
        let end = (start + batch_size).min(n);
        df.slice(start as i64, end - start)
    }).collect()
}

fn compute_label_stats(df: &DataFrame) -> PolarsResult<DataFrame> {
    df.clone().lazy()
        .group_by([col("label")])
        .agg([
            col("score").mean().alias("avg_score"),
            col("score").std(1).alias("std_score"),
            col("text_length").mean().alias("avg_length"),
            col("label").count().alias("count"),
        ])
        .sort(["count"], SortMultipleOptions::default().with_order_descending(true))
        .collect()
}

Zero-Copy Data Sharing with Apache Arrow

use arrow::array::{Float32Array, StringArray};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;

fn create_embedding_batch(
    texts: Vec<String>,
    embeddings: Vec<Vec<f32>>,
) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![
        Field::new("text", DataType::Utf8, false),
        Field::new("embedding_dim0", DataType::Float32, false),
    ]));

    let text_array = StringArray::from(texts);
    // Store first dimension as example
    let emb_dim: Vec<f32> = embeddings.iter().map(|e| e[0]).collect();
    let emb_array = Float32Array::from(emb_dim);

    RecordBatch::try_new(
        schema,
        vec![Arc::new(text_array), Arc::new(emb_array)],
    ).unwrap()
}

7. Case Studies

Hugging Face Candle

HuggingFace is porting Python PyTorch models to Rust Candle:

  • Binary size: PyTorch 250MB+ vs Candle ~5MB
  • WASM support enables direct browser inference
  • Reusable CUDA kernels

Ruff: Python Linter Rewritten in Rust

Ruff is 10-100x faster than Flake8. The speed difference is palpable on large ML codebases.

# Legacy Python linter
time flake8 large_ml_project/  # ~30 seconds

# Ruff (Rust)
time ruff check large_ml_project/  # ~0.3 seconds

Pydantic v2's Rust Core

Pydantic v2 rewrote its core validation logic in Rust (pydantic-core), achieving 5-50x performance improvements. AI API servers see dramatically reduced request validation latency.


Quiz

Q1. What is the difference between Arc and Rc in Rust, and why is Arc required for multi-threaded AI inference?

Answer: Arc (Atomic Reference Count) manages reference counts with atomic operations, making it thread-safe. Rc (Reference Count) is for single-threaded use only.

Explanation: In an AI inference server, multiple request threads share the same model weights. Rc does not implement the Send trait, so it cannot be moved between threads. Arc uses atomic fetch-add operations for reference count changes, allowing multiple threads to reference the same data without data races. The performance difference is negligible (nanosecond-level), but safety is guaranteed at compile time.

Q2. How does PyO3 handle the GIL, and what performance benefit does it provide when calling Rust parallel code from Python?

Answer: PyO3 releases the GIL via py.allow_threads(), enabling Rust code to execute truly in parallel without Python's GIL constraint.

Explanation: Python's GIL (Global Interpreter Lock) restricts execution to one Python thread at a time. Calling allow_threads releases the GIL for the current thread, letting Rust code (with rayon, for example) freely use OS threads. Computing cosine similarity over 1 million embeddings takes tens of seconds in single-threaded Python, but can be reduced to hundreds of milliseconds with Rust + rayon parallelism.

Q3. Why is Candle better suited for edge deployment than PyTorch?

Answer: Candle ships as a single Rust binary with WASM support, producing ~5MB binaries without requiring the PyTorch runtime (250MB+).

Explanation: PyTorch requires Python runtime, libtorch, CUDA drivers, and other large dependencies. Candle's static linking produces small executables containing only what's needed. Compiling to WebAssembly enables direct browser inference — client-side AI with no server costs. It also runs on Raspberry Pi and embedded devices without Python.

Q4. Why does Polars process large datasets faster than Pandas internally?

Answer: Polars is written in Rust and uses the Apache Arrow columnar memory format, SIMD optimizations, and query optimization via lazy evaluation.

Explanation: Pandas relies on Python with frequent row-wise operations and GIL constraints. Polars uses (1) Arrow columnar format for cache-efficient access, (2) Rust's automatic SIMD vectorization for numeric operations, (3) a lazy API that optimizes query plans (predicate pushdown, projection pushdown), and (4) multi-threaded parallel processing without a GIL. Typical speedup is 5-20x over Pandas, with lower memory consumption.

Q5. Why do Rust's zero-cost abstractions match C++ performance?

Answer: The Rust compiler (LLVM backend) transforms high-level abstractions (traits, generics, iterators) into optimized machine code with no runtime overhead.

Explanation: Like C++ templates and inlining, Rust generics use monomorphization to generate concrete code for each type at compile time. Iterator chains (map, filter, fold) produce the same assembly as C-style for loops. Trait methods use static dispatch — no virtual table lookups, just inlining. The result: code that is both ergonomic to write and as fast as C.
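
The iterator claim above is easy to sanity-check: both forms below compute the same dot product, and in release builds LLVM typically compiles the iterator chain down to the same machine code as the manual loop (a toy sketch; function names are illustrative):

```rust
// Same dot product, written two ways. Both are left folds in the same order,
// so they produce bit-identical results; in release mode they also compile
// to essentially the same machine code.
fn dot_loop(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0;
    for i in 0..a.len() {
        acc += a[i] * b[i];
    }
    acc
}

fn dot_iter(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn main() {
    let a = vec![1.0f32, 2.0, 3.0];
    let b = vec![4.0f32, 5.0, 6.0];
    println!("{} == {}", dot_loop(&a, &b), dot_iter(&a, &b)); // both 32
}
```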


Wrapping Up

Rust is a powerful tool for pushing past the performance ceiling in AI infrastructure. The practical approach: prototype in Python, rewrite bottlenecks in Rust (via PyO3), or build high-performance services in Rust from the ground up.

Learning Roadmap:

  1. The Rust Book — official tutorial
  2. Rustlings — interactive exercises
  3. Candle examples — HuggingFace GitHub
  4. Axum documentation — building inference servers
  5. PyO3 guide — Python integration