Skip to content
Published on

AIシステムのためのRust: Candle、PyO3、axumで高性能AI推論サーバーを構築する

Authors

はじめに

AI開発でPythonは依然として主流ですが、パフォーマンスと安全性が求められるプロダクションAIインフラにおいて、Rustが急速に注目されています。

HuggingFaceのCandle、高速PythonリンターRuff、Pydantic v2のRustコア — AIエコシステムの主要ツールがRustで書き直されています。このガイドはAIエンジニアのための実践的なRust入門です。

なぜAIシステムにRustなのか?

項目PythonRust
実行速度普通C++レベル
メモリ安全性GC依存コンパイル時保証
並行性GIL制約ありデータ競合なし
バイナリサイズランタイム必要単一バイナリ
エッジ配備困難WASM、組み込み対応

1. Rust基礎: AIエンジニアの視点

Ownership(所有権)とメモリ安全性

Rustの核心はOwnershipシステムです。ガベージコレクターなしにメモリ安全性を保証します。

fn main() {
    // 所有権の移動 (move)
    let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];
    let moved = tensor_data; // tensor_dataはもう使えない

    // 借用 (borrowing) — 所有権移転なしに参照
    let weights = vec![0.1f32, 0.2, 0.3];
    let sum = compute_sum(&weights); // 不変参照
    println!("weights still valid: {:?}, sum: {}", weights, sum);
}

fn compute_sum(data: &[f32]) -> f32 {
    data.iter().sum()
}

ライフタイム (Lifetimes)

ライフタイムは参照が有効な範囲をコンパイラに伝えます。AIモデルで重みデータをコピーなしに安全に参照するときに重要です。

// ライフタイム注釈: 返された参照は入力スライスより長生きできない
fn longest_sequence<'a>(seq1: &'a [f32], seq2: &'a [f32]) -> &'a [f32] {
    if seq1.len() > seq2.len() {
        seq1
    } else {
        seq2
    }
}

struct ModelWeights<'a> {
    data: &'a [f32],  // 外部バッファを参照 (ゼロコピー)
    shape: (usize, usize),
}

impl<'a> ModelWeights<'a> {
    fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        ModelWeights { data, shape: (rows, cols) }
    }

    fn get(&self, row: usize, col: usize) -> f32 {
        self.data[row * self.shape.1 + col]
    }
}

トレイト (Traits) とジェネリクス (Generics)

Pythonのダックタイピングと異なり、Rustはコンパイル時に型チェックを行います。

// AI演算のためのトレイト定義
trait Activation {
    fn forward(&self, x: f32) -> f32;
    fn backward(&self, x: f32) -> f32; // 微分
}

struct ReLU;
struct Sigmoid;

impl Activation for ReLU {
    fn forward(&self, x: f32) -> f32 {
        x.max(0.0)
    }
    fn backward(&self, x: f32) -> f32 {
        if x > 0.0 { 1.0 } else { 0.0 }
    }
}

impl Activation for Sigmoid {
    fn forward(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
    fn backward(&self, x: f32) -> f32 {
        let s = self.forward(x);
        s * (1.0 - s)
    }
}

// ジェネリックレイヤー: どんな活性化関数でも使用可能
fn apply_activation<A: Activation>(activation: &A, inputs: &[f32]) -> Vec<f32> {
    inputs.iter().map(|&x| activation.forward(x)).collect()
}

fn main() {
    let relu = ReLU;
    let data = vec![-1.0, 0.5, 2.0, -0.3];
    let output = apply_activation(&relu, &data);
    println!("ReLU output: {:?}", output); // [0.0, 0.5, 2.0, 0.0]
}

Arc vs Rc: マルチスレッドAI推論

use std::sync::{Arc, Mutex};
use std::thread;

// Rcはシングルスレッド専用 — AI推論サーバーでは使用不可
// Arc (Atomic Reference Count) はスレッドセーフ

fn multi_thread_inference() {
    // 複数スレッドがモデル重みを共有
    let model_weights: Arc<Vec<f32>> = Arc::new(vec![0.1, 0.2, 0.3, 0.4]);
    let results: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));

    let mut handles = vec![];

    for i in 0..4 {
        let weights = Arc::clone(&model_weights);
        let results = Arc::clone(&results);

        let handle = thread::spawn(move || {
            // 各スレッドが同じ重みで推論を実行
            let inference_result = weights.iter().sum::<f32>() * i as f32;
            results.lock().unwrap().push(inference_result);
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("All results: {:?}", results.lock().unwrap());
}

2. Rust AIエコシステム

Candle: HuggingFaceのRust MLフレームワーク

CandleはPyTorchのRust代替で、最小依存性とWASMサポートが強みです。

# Cargo.toml
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module, VarBuilder, VarMap};

// シンプルなフィードフォワードネットワーク
struct FeedForward {
    fc1: Linear,
    fc2: Linear,
}

impl FeedForward {
    fn new(vb: VarBuilder) -> candle_core::Result<Self> {
        let fc1 = candle_nn::linear(784, 256, vb.pp("fc1"))?;
        let fc2 = candle_nn::linear(256, 10, vb.pp("fc2"))?;
        Ok(FeedForward { fc1, fc2 })
    }
}

impl Module for FeedForward {
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let x = self.fc1.forward(x)?;
        let x = x.relu()?;
        self.fc2.forward(&x)
    }
}

fn candle_tensor_ops() -> candle_core::Result<()> {
    let device = Device::Cpu; // または Device::new_cuda(0)?

    // テンソル作成
    let a = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
    let b = Tensor::randn(0f32, 1f32, (4, 5), &device)?;

    // 行列積
    let c = a.matmul(&b)?;
    println!("Shape: {:?}", c.shape());

    // 要素ごとの演算
    let x = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device)?;
    let softmax = candle_nn::ops::softmax(&x, 1)?;
    println!("Softmax: {:?}", softmax.to_vec2::<f32>()?);

    Ok(())
}

Burn: バックエンド非依存MLフレームワーク

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
struct SimpleNet<B: Backend> {
    fc1: Linear<B>,
    activation: Relu,
    fc2: Linear<B>,
}

impl<B: Backend> SimpleNet<B> {
    fn new(device: &B::Device) -> Self {
        SimpleNet {
            fc1: LinearConfig::new(128, 64).init(device),
            activation: Relu::new(),
            fc2: LinearConfig::new(64, 10).init(device),
        }
    }

    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
        self.fc2.forward(x)
    }
}

Linfa: Scikit-learn代替

use linfa::prelude::*;
use linfa_clustering::KMeans;
use ndarray::Array2;

fn kmeans_clustering() -> Result<(), Box<dyn std::error::Error>> {
    // データ生成 (埋め込みベクトルと仮定)
    let data: Array2<f64> = Array2::from_shape_fn((100, 10), |(i, j)| {
        (i as f64 * 0.1 + j as f64) % 5.0
    });

    let dataset = Dataset::from(data);

    // K-meansクラスタリング (RAGシステムで類似文書のグループ化)
    let model = KMeans::params(5)
        .tolerance(1e-5)
        .fit(&dataset)?;

    let predictions = model.predict(&dataset);
    println!("クラスタ割り当て: {:?}", &predictions.targets()[..10]);

    Ok(())
}

3. 高性能推論サーバー: axum + Candle

tokio非同期ベースのLLM推論サーバー

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
candle-core = "0.8"
candle-transformers = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"
use axum::{
    extract::State,
    http::StatusCode,
    response::Json,
    routing::post,
    Router,
};
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Deserialize)]
struct InferenceRequest {
    prompt: String,
    max_tokens: Option<usize>,
    temperature: Option<f32>,
}

#[derive(Serialize)]
struct InferenceResponse {
    text: String,
    tokens_generated: usize,
    latency_ms: u64,
}

// モデル状態をArcで共有 (スレッドセーフ)
struct AppState {
    device: Device,
    // model: Arc<Mutex<YourLLMModel>>,
}

async fn inference_handler(
    State(state): State<Arc<AppState>>,
    Json(req): Json<InferenceRequest>,
) -> Result<Json<InferenceResponse>, StatusCode> {
    let start = std::time::Instant::now();

    let max_tokens = req.max_tokens.unwrap_or(100);
    let _temperature = req.temperature.unwrap_or(0.7);

    // ダミー推論 (実際はCandleモデルのforward passを使用)
    let generated_text = format!("Response to: {}", req.prompt);
    let latency = start.elapsed().as_millis() as u64;

    Ok(Json(InferenceResponse {
        text: generated_text,
        tokens_generated: max_tokens,
        latency_ms: latency,
    }))
}

async fn health_handler() -> &'static str {
    "OK"
}

#[tokio::main]
async fn main() {
    let state = Arc::new(AppState {
        device: Device::Cpu,
    });

    let app = Router::new()
        .route("/inference", post(inference_handler))
        .route("/health", axum::routing::get(health_handler))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
    println!("AI推論サーバー起動: ポート8080");
    axum::serve(listener, app).await.unwrap();
}

バッチ推論最適化

use tokio::sync::mpsc;

struct BatchRequest {
    id: u64,
    prompt: String,
    response_tx: tokio::sync::oneshot::Sender<String>,
}

// バッチ処理ワーカー — 複数リクエストをまとめてGPU効率を最大化
async fn batch_inference_worker(
    mut rx: mpsc::Receiver<BatchRequest>,
    batch_size: usize,
    timeout_ms: u64,
) {
    loop {
        let mut batch = Vec::new();

        // 最初のリクエストを待機
        if let Some(req) = rx.recv().await {
            batch.push(req);
        } else {
            break;
        }

        // タイムアウト内にバッチを満たす
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_millis(timeout_ms);

        while batch.len() < batch_size {
            match tokio::time::timeout_at(deadline, rx.recv()).await {
                Ok(Some(req)) => batch.push(req),
                _ => break,
            }
        }

        // バッチ処理 (実際はGPUで並列推論)
        for req in batch {
            let response = format!("バッチ応答: {}", req.prompt);
            let _ = req.response_tx.send(response);
        }
    }
}

4. Python連携: PyO3

RustコードをPythonから呼び出す

[lib]
name = "rust_ai_tools"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"
use pyo3::prelude::*;

// Pythonから呼び出し可能なRust関数
#[pyfunction]
fn fast_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> PyResult<f32> {
    if a.len() != b.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "ベクトルの長さが一致しません"
        ));
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    Ok(dot / (norm_a * norm_b))
}

// GIL解放で真の並列処理
#[pyfunction]
fn batch_normalize(
    py: Python,
    embeddings: Vec<Vec<f32>>,
) -> PyResult<Vec<Vec<f32>>> {
    // GILを解放してRustスレッドプールを使用
    let result = py.allow_threads(|| {
        embeddings.iter().map(|emb| {
            let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter().map(|x| x / norm).collect()
        }).collect::<Vec<Vec<f32>>>()
    });

    Ok(result)
}

// Pythonモジュール登録
#[pymodule]
fn rust_ai_tools(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fast_cosine_similarity, m)?)?;
    m.add_function(wrap_pyfunction!(batch_normalize, m)?)?;
    Ok(())
}

maturinでビルドとデプロイ

# maturinインストール
pip install maturin

# 開発モードビルド (仮想環境に直接インストール)
maturin develop

# Pythonでの使用例:
# import rust_ai_tools
# sim = rust_ai_tools.fast_cosine_similarity([1.0, 0.0], [0.0, 1.0])
# print(sim)  # 0.0

# PyPI配布用wheelビルド
maturin build --release

# pyproject.toml設定例:
# [build-system]
# requires = ["maturin>=1.0,<2.0"]
# build-backend = "maturin"

真の並列処理: GIL解放 + rayon

use pyo3::prelude::*;
use rayon::prelude::*;

// rayonによる真の並列埋め込み検索
#[pyfunction]
fn parallel_embedding_search(
    py: Python,
    query: Vec<f32>,
    corpus: Vec<Vec<f32>>,
    top_k: usize,
) -> PyResult<Vec<usize>> {
    // GIL解放 — RustスレッドプールがPythonのGILなしで動作
    let indices = py.allow_threads(|| {
        let mut scores: Vec<(usize, f32)> = corpus
            .par_iter()  // 並列イテレーター
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.iter().take(top_k).map(|(i, _)| *i).collect::<Vec<_>>()
    });

    Ok(indices)
}

5. WASM for AI: ブラウザ推論

wasm-bindgenでブラウザ内AI実行

[lib]
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = "0.2"
candle-core = { version = "0.8", features = ["wasm"] }
use wasm_bindgen::prelude::*;

// JavaScriptから呼び出し可能な構造体
#[wasm_bindgen]
pub struct MiniModel {
    weights: Vec<f32>,
    input_size: usize,
    output_size: usize,
}

#[wasm_bindgen]
impl MiniModel {
    #[wasm_bindgen(constructor)]
    pub fn new(input_size: usize, output_size: usize) -> MiniModel {
        let weights = vec![0.01f32; input_size * output_size];
        MiniModel { weights, input_size, output_size }
    }

    // JavaScript Float32Arrayを受け取って推論
    pub fn predict(&self, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0f32; self.output_size];
        for i in 0..self.output_size {
            for j in 0..self.input_size {
                output[i] += input[j] * self.weights[i * self.input_size + j];
            }
            // ReLU活性化
            output[i] = output[i].max(0.0);
        }
        output
    }

    pub fn load_weights(&mut self, weights: Vec<f32>) {
        self.weights = weights;
    }
}

// 埋め込みのコサイン類似度 (ブラウザでRAG可能)
#[wasm_bindgen]
pub fn cosine_similarity_wasm(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb)
}
// JavaScriptからWASMモジュールを使用
import init, { MiniModel, cosine_similarity_wasm } from './rust_ai_wasm.js'

async function runBrowserInference() {
  await init()

  const model = new MiniModel(128, 10)
  const input = new Float32Array(128).fill(0.5)
  const output = model.predict(input)
  console.log('ブラウザ推論結果:', output)

  // ベクトル類似度検索
  const queryEmb = new Float32Array(384).map(() => Math.random())
  const docEmb = new Float32Array(384).map(() => Math.random())
  const sim = cosine_similarity_wasm(queryEmb, docEmb)
  console.log('類似度:', sim)
}

6. メモリ安全ML: Polarsデータパイプライン

Polarsで大量の学習データを処理

[dependencies]
polars = { version = "0.44", features = ["lazy", "parquet", "json", "csv"] }
use polars::prelude::*;

fn preprocess_training_data() -> PolarsResult<DataFrame> {
    // Parquetファイルの遅延ロード (メモリ効率)
    let df = LazyFrame::scan_parquet("training_data.parquet", Default::default())?
        .filter(col("label").is_not_null())
        .with_column(
            // テキスト長特徴の抽出
            col("text").str().len_chars().alias("text_length")
        )
        .with_column(
            // 正規化
            (col("score") - col("score").mean()) / col("score").std(1)
        )
        .filter(col("text_length").gt(10))
        .select([col("text"), col("label"), col("score"), col("text_length")])
        .collect()?;

    println!("Shape: {:?}", df.shape());
    println!("{}", df.head(Some(5)));

    Ok(df)
}

fn batch_generator(df: &DataFrame, batch_size: usize) -> Vec<DataFrame> {
    let n = df.height();
    (0..n).step_by(batch_size).map(|start| {
        let end = (start + batch_size).min(n);
        df.slice(start as i64, end - start)
    }).collect()
}

fn compute_label_stats(df: &DataFrame) -> PolarsResult<DataFrame> {
    df.clone().lazy()
        .group_by([col("label")])
        .agg([
            col("score").mean().alias("avg_score"),
            col("score").std(1).alias("std_score"),
            col("text_length").mean().alias("avg_length"),
            col("label").count().alias("count"),
        ])
        .sort(["count"], SortMultipleOptions::default().with_order_descending(true))
        .collect()
}

Apache Arrowによるゼロコピーデータ共有

use arrow::array::{Float32Array, StringArray};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;

fn create_embedding_batch(
    texts: Vec<String>,
    embeddings: Vec<Vec<f32>>,
) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![
        Field::new("text", DataType::Utf8, false),
        Field::new("embedding_dim0", DataType::Float32, false),
    ]));

    let text_array = StringArray::from(texts);
    // 最初の次元のみ保存 (例示)
    let emb_dim: Vec<f32> = embeddings.iter().map(|e| e[0]).collect();
    let emb_array = Float32Array::from(emb_dim);

    RecordBatch::try_new(
        schema,
        vec![Arc::new(text_array), Arc::new(emb_array)],
    ).unwrap()
}

7. ケーススタディ

Hugging Face Candle

HuggingFaceはPython PyTorchモデルをRust Candleに移植しています:

  • バイナリサイズ: PyTorch 250MB以上 vs Candle 約5MB
  • WASMサポートでブラウザ直接推論が可能
  • CUDAカーネルの再利用も可能

Ruff: Python LinterのRust書き直し

RuffはFlake8より100倍高速なPythonリンターです。大規模MLコードベースでその差が実感できます。

# 従来のPythonリンター
time flake8 large_ml_project/  # 約30秒

# Ruff (Rust)
time ruff check large_ml_project/  # 約0.3秒

Pydantic v2のRustコア

Pydantic v2はコアバリデーションロジックをRust (pydantic-core) で書き直し、5〜50倍のパフォーマンス向上を達成しました。AI APIサーバーでのリクエスト検証レイテンシが大幅に削減されました。


クイズ

Q1. RustにおけるArcとRcの違いと、マルチスレッドAI推論でArcが必要な理由は?

答え: Arc (Atomic Reference Count) はアトミック操作で参照カウントを管理するためスレッドセーフですが、Rc (Reference Count) はシングルスレッド専用です。

解説: AI推論サーバーでは複数のリクエストスレッドが同じモデル重みを共有します。RcはSendトレイトを実装していないため、スレッド間で移動できません。Arcは参照カウントの増減にアトミック操作 (atomic fetch-add) を使用し、データ競合なしに複数スレッドが同じデータを参照できます。性能差はナノ秒レベルですが、安全性はコンパイル時に保証されます。

Q2. PyO3のGIL処理方式と、PythonからRust並列コードを呼び出す際のパフォーマンス利点は?

答え: PyO3は py.allow_threads() でGILを解放し、RustコードがPythonのGIL制約なしに真の並列実行を行います。

解説: PythonのGIL (Global Interpreter Lock) は一度に一つのPythonスレッドしか実行できないよう制限します。allow_threads を呼び出すと現在のスレッドがGILを解放し、Rustコード (rayon等) がOSスレッドを自由に使用できます。100万件の埋め込み類似度計算はPythonシングルスレッドで数十秒かかりますが、Rust + rayon並列化で数百ミリ秒に短縮できます。

Q3. CandleがPyTorchよりエッジ配備に有利な理由は?

答え: CandleはRustの単一バイナリとして配布可能でWASMをサポートし、PyTorchランタイム (250MB以上) 不要で約5MBのバイナリを生成できます。

解説: PyTorchはPythonランタイム、libtorch、CUDAドライバーなど大規模な依存性が必要です。Candleの静的リンクにより、必要な機能だけを含む小さな実行ファイルが生成されます。WebAssemblyにコンパイルするとブラウザで直接推論が可能になり、サーバーコストなしでクライアントサイドAIを実現できます。Raspberry Piや組み込みデバイスでもPythonなしで動作します。

Q4. Polarsが大容量データ処理でPandasより高速な内部実装の原理は?

答え: PolarsはRustで書かれており、Apache Arrowカラム型メモリフォーマット、SIMD最適化、遅延評価によるクエリ最適化を活用しています。

解説: PandasはPythonベースで行単位処理が多くGIL制約があります。Polarsは (1) Arrowカラム型フォーマットでキャッシュ効率が高く、(2) RustのSIMD自動ベクトル化で数値演算が高速で、(3) lazy APIでクエリプランを最適化 (predicate pushdown、projection pushdown) し、(4) GILなしにマルチスレッド並列処理を行います。一般的にPandasと比べて5〜20倍高速で、メモリ使用量も少なくなります。

Q5. Rustのゼロコスト抽象化 (zero-cost abstraction) がC++の性能に匹敵する理由は?

答え: Rustコンパイラー (LLVMバックエンド) が高水準の抽象化 (トレイト、ジェネリクス、イテレーター) をランタイムオーバーヘッドなしに最適化された機械語に変換します。

解説: C++のテンプレートとインライニングと同様に、Rustのジェネリクスは単相化 (monomorphization) によってコンパイル時に各型の具体的なコードを生成します。イテレーターチェーン (mapfilterfold) はCスタイルのforループと同じアセンブリを生成します。トレイトメソッドは静的ディスパッチ (static dispatch) で仮想関数テーブル参照なしにインライン化されます。結果として「書きやすくてCと同じ速さ」なコードが可能です。


まとめ

RustはAIインフラのパフォーマンスの壁を突破するための強力なツールです。Pythonでプロトタイプを作り、ボトルネックになった部分をRustで書き直す (PyO3)、または最初から高性能サービスをRustで構築する戦略が現実的です。

学習ロードマップ:

  1. The Rust Book — 公式チュートリアル
  2. Rustlings — インタラクティブ演習
  3. Candleサンプル — HuggingFace GitHub
  4. Axum公式ドキュメント — 推論サーバー構築
  5. PyO3ガイド — Python連携