- Authors

- Name
- Youngju Kim
- @fjvbn20031
- はじめに
- 1. Rust基礎: AIエンジニアの視点
- 2. Rust AIエコシステム
- 3. 高性能推論サーバー: axum + Candle
- 4. Python連携: PyO3
- 5. WASM for AI: ブラウザ推論
- 6. メモリ安全ML: Polarsデータパイプライン
- 7. ケーススタディ
- クイズ
- まとめ
はじめに
AI開発でPythonは依然として主流ですが、パフォーマンスと安全性が求められるプロダクションAIインフラにおいて、Rustが急速に注目されています。
HuggingFaceのCandle、高速PythonリンターRuff、Pydantic v2のRustコア — AIエコシステムの主要ツールがRustで書き直されています。このガイドはAIエンジニアのための実践的なRust入門です。
なぜAIシステムにRustなのか?
| 項目 | Python | Rust |
|---|---|---|
| 実行速度 | 普通 | C++レベル |
| メモリ安全性 | GC依存 | コンパイル時保証 |
| 並行性 | GIL制約あり | データ競合なし |
| バイナリサイズ | ランタイム必要 | 単一バイナリ |
| エッジ配備 | 困難 | WASM、組み込み対応 |
1. Rust基礎: AIエンジニアの視点
Ownership(所有権)とメモリ安全性
Rustの核心はOwnershipシステムです。ガベージコレクターなしにメモリ安全性を保証します。
fn main() {
// 所有権の移動 (move)
let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];
let moved = tensor_data; // tensor_dataはもう使えない
// 借用 (borrowing) — 所有権移転なしに参照
let weights = vec![0.1f32, 0.2, 0.3];
let sum = compute_sum(&weights); // 不変参照
println!("weights still valid: {:?}, sum: {}", weights, sum);
}
fn compute_sum(data: &[f32]) -> f32 {
data.iter().sum()
}
ライフタイム (Lifetimes)
ライフタイムは参照が有効な範囲をコンパイラに伝えます。AIモデルで重みデータをコピーなしに安全に参照するときに重要です。
// ライフタイム注釈: 返された参照は入力スライスより長生きできない
fn longest_sequence<'a>(seq1: &'a [f32], seq2: &'a [f32]) -> &'a [f32] {
if seq1.len() > seq2.len() {
seq1
} else {
seq2
}
}
struct ModelWeights<'a> {
data: &'a [f32], // 外部バッファを参照 (ゼロコピー)
shape: (usize, usize),
}
impl<'a> ModelWeights<'a> {
fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
ModelWeights { data, shape: (rows, cols) }
}
fn get(&self, row: usize, col: usize) -> f32 {
self.data[row * self.shape.1 + col]
}
}
トレイト (Traits) とジェネリクス (Generics)
Pythonのダックタイピングと異なり、Rustはコンパイル時に型チェックを行います。
// AI演算のためのトレイト定義
trait Activation {
fn forward(&self, x: f32) -> f32;
fn backward(&self, x: f32) -> f32; // 微分
}
struct ReLU;
struct Sigmoid;
impl Activation for ReLU {
fn forward(&self, x: f32) -> f32 {
x.max(0.0)
}
fn backward(&self, x: f32) -> f32 {
if x > 0.0 { 1.0 } else { 0.0 }
}
}
impl Activation for Sigmoid {
fn forward(&self, x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
fn backward(&self, x: f32) -> f32 {
let s = self.forward(x);
s * (1.0 - s)
}
}
// ジェネリックレイヤー: どんな活性化関数でも使用可能
fn apply_activation<A: Activation>(activation: &A, inputs: &[f32]) -> Vec<f32> {
inputs.iter().map(|&x| activation.forward(x)).collect()
}
fn main() {
let relu = ReLU;
let data = vec![-1.0, 0.5, 2.0, -0.3];
let output = apply_activation(&relu, &data);
println!("ReLU output: {:?}", output); // [0.0, 0.5, 2.0, 0.0]
}
Arc vs Rc: マルチスレッドAI推論
use std::sync::{Arc, Mutex};
use std::thread;
// Rcはシングルスレッド専用 — AI推論サーバーでは使用不可
// Arc (Atomic Reference Count) はスレッドセーフ
fn multi_thread_inference() {
// 複数スレッドがモデル重みを共有
let model_weights: Arc<Vec<f32>> = Arc::new(vec![0.1, 0.2, 0.3, 0.4]);
let results: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
let mut handles = vec![];
for i in 0..4 {
let weights = Arc::clone(&model_weights);
let results = Arc::clone(&results);
let handle = thread::spawn(move || {
// 各スレッドが同じ重みで推論を実行
let inference_result = weights.iter().sum::<f32>() * i as f32;
results.lock().unwrap().push(inference_result);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("All results: {:?}", results.lock().unwrap());
}
2. Rust AIエコシステム
Candle: HuggingFaceのRust MLフレームワーク
CandleはPyTorchのRust代替で、最小依存性とWASMサポートが強みです。
# Cargo.toml
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module, VarBuilder, VarMap};
// シンプルなフィードフォワードネットワーク
struct FeedForward {
fc1: Linear,
fc2: Linear,
}
impl FeedForward {
fn new(vb: VarBuilder) -> candle_core::Result<Self> {
let fc1 = candle_nn::linear(784, 256, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(256, 10, vb.pp("fc2"))?;
Ok(FeedForward { fc1, fc2 })
}
}
impl Module for FeedForward {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let x = self.fc1.forward(x)?;
let x = x.relu()?;
self.fc2.forward(&x)
}
}
fn candle_tensor_ops() -> candle_core::Result<()> {
let device = Device::Cpu; // または Device::new_cuda(0)?
// テンソル作成
let a = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
let b = Tensor::randn(0f32, 1f32, (4, 5), &device)?;
// 行列積
let c = a.matmul(&b)?;
println!("Shape: {:?}", c.shape());
// 要素ごとの演算
let x = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device)?;
let softmax = candle_nn::ops::softmax(&x, 1)?;
println!("Softmax: {:?}", softmax.to_vec2::<f32>()?);
Ok(())
}
Burn: バックエンド非依存MLフレームワーク
use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};
#[derive(Module, Debug)]
struct SimpleNet<B: Backend> {
fc1: Linear<B>,
activation: Relu,
fc2: Linear<B>,
}
impl<B: Backend> SimpleNet<B> {
fn new(device: &B::Device) -> Self {
SimpleNet {
fc1: LinearConfig::new(128, 64).init(device),
activation: Relu::new(),
fc2: LinearConfig::new(64, 10).init(device),
}
}
fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.fc1.forward(x);
let x = self.activation.forward(x);
self.fc2.forward(x)
}
}
Linfa: Scikit-learn代替
use linfa::prelude::*;
use linfa_clustering::KMeans;
use ndarray::Array2;
fn kmeans_clustering() -> Result<(), Box<dyn std::error::Error>> {
// データ生成 (埋め込みベクトルと仮定)
let data: Array2<f64> = Array2::from_shape_fn((100, 10), |(i, j)| {
(i as f64 * 0.1 + j as f64) % 5.0
});
let dataset = Dataset::from(data);
// K-meansクラスタリング (RAGシステムで類似文書のグループ化)
let model = KMeans::params(5)
.tolerance(1e-5)
.fit(&dataset)?;
let predictions = model.predict(&dataset);
println!("クラスタ割り当て: {:?}", &predictions.targets()[..10]);
Ok(())
}
3. 高性能推論サーバー: axum + Candle
tokio非同期ベースのLLM推論サーバー
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
candle-core = "0.8"
candle-transformers = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::post,
Router,
};
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Deserialize)]
struct InferenceRequest {
prompt: String,
max_tokens: Option<usize>,
temperature: Option<f32>,
}
#[derive(Serialize)]
struct InferenceResponse {
text: String,
tokens_generated: usize,
latency_ms: u64,
}
// モデル状態をArcで共有 (スレッドセーフ)
struct AppState {
device: Device,
// model: Arc<Mutex<YourLLMModel>>,
}
async fn inference_handler(
State(state): State<Arc<AppState>>,
Json(req): Json<InferenceRequest>,
) -> Result<Json<InferenceResponse>, StatusCode> {
let start = std::time::Instant::now();
let max_tokens = req.max_tokens.unwrap_or(100);
let _temperature = req.temperature.unwrap_or(0.7);
// ダミー推論 (実際はCandleモデルのforward passを使用)
let generated_text = format!("Response to: {}", req.prompt);
let latency = start.elapsed().as_millis() as u64;
Ok(Json(InferenceResponse {
text: generated_text,
tokens_generated: max_tokens,
latency_ms: latency,
}))
}
async fn health_handler() -> &'static str {
"OK"
}
#[tokio::main]
async fn main() {
let state = Arc::new(AppState {
device: Device::Cpu,
});
let app = Router::new()
.route("/inference", post(inference_handler))
.route("/health", axum::routing::get(health_handler))
.with_state(state);
let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
println!("AI推論サーバー起動: ポート8080");
axum::serve(listener, app).await.unwrap();
}
バッチ推論最適化
use tokio::sync::mpsc;
struct BatchRequest {
id: u64,
prompt: String,
response_tx: tokio::sync::oneshot::Sender<String>,
}
// バッチ処理ワーカー — 複数リクエストをまとめてGPU効率を最大化
async fn batch_inference_worker(
mut rx: mpsc::Receiver<BatchRequest>,
batch_size: usize,
timeout_ms: u64,
) {
loop {
let mut batch = Vec::new();
// 最初のリクエストを待機
if let Some(req) = rx.recv().await {
batch.push(req);
} else {
break;
}
// タイムアウト内にバッチを満たす
let deadline = tokio::time::Instant::now()
+ tokio::time::Duration::from_millis(timeout_ms);
while batch.len() < batch_size {
match tokio::time::timeout_at(deadline, rx.recv()).await {
Ok(Some(req)) => batch.push(req),
_ => break,
}
}
// バッチ処理 (実際はGPUで並列推論)
for req in batch {
let response = format!("バッチ応答: {}", req.prompt);
let _ = req.response_tx.send(response);
}
}
}
4. Python連携: PyO3
RustコードをPythonから呼び出す
[lib]
name = "rust_ai_tools"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
numpy = "0.22"
use pyo3::prelude::*;
// Pythonから呼び出し可能なRust関数
#[pyfunction]
fn fast_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> PyResult<f32> {
if a.len() != b.len() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"ベクトルの長さが一致しません"
));
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
Ok(dot / (norm_a * norm_b))
}
// GIL解放で真の並列処理
#[pyfunction]
fn batch_normalize(
py: Python,
embeddings: Vec<Vec<f32>>,
) -> PyResult<Vec<Vec<f32>>> {
// GILを解放してRustスレッドプールを使用
let result = py.allow_threads(|| {
embeddings.iter().map(|emb| {
let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
emb.iter().map(|x| x / norm).collect()
}).collect::<Vec<Vec<f32>>>()
});
Ok(result)
}
// Pythonモジュール登録
#[pymodule]
fn rust_ai_tools(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fast_cosine_similarity, m)?)?;
m.add_function(wrap_pyfunction!(batch_normalize, m)?)?;
Ok(())
}
maturinでビルドとデプロイ
# maturinインストール
pip install maturin
# 開発モードビルド (仮想環境に直接インストール)
maturin develop
# Pythonでの使用例:
# import rust_ai_tools
# sim = rust_ai_tools.fast_cosine_similarity([1.0, 0.0], [0.0, 1.0])
# print(sim) # 0.0
# PyPI配布用wheelビルド
maturin build --release
# pyproject.toml設定例:
# [build-system]
# requires = ["maturin>=1.0,<2.0"]
# build-backend = "maturin"
真の並列処理: GIL解放 + rayon
use pyo3::prelude::*;
use rayon::prelude::*;
// rayonによる真の並列埋め込み検索
#[pyfunction]
fn parallel_embedding_search(
py: Python,
query: Vec<f32>,
corpus: Vec<Vec<f32>>,
top_k: usize,
) -> PyResult<Vec<usize>> {
// GIL解放 — RustスレッドプールがPythonのGILなしで動作
let indices = py.allow_threads(|| {
let mut scores: Vec<(usize, f32)> = corpus
.par_iter() // 並列イテレーター
.enumerate()
.map(|(i, emb)| {
let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
(i, dot)
})
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
scores.iter().take(top_k).map(|(i, _)| *i).collect::<Vec<_>>()
});
Ok(indices)
}
5. WASM for AI: ブラウザ推論
wasm-bindgenでブラウザ内AI実行
[lib]
crate-type = ["cdylib"]
[dependencies]
wasm-bindgen = "0.2"
candle-core = { version = "0.8", features = ["wasm"] }
use wasm_bindgen::prelude::*;
// JavaScriptから呼び出し可能な構造体
#[wasm_bindgen]
pub struct MiniModel {
weights: Vec<f32>,
input_size: usize,
output_size: usize,
}
#[wasm_bindgen]
impl MiniModel {
#[wasm_bindgen(constructor)]
pub fn new(input_size: usize, output_size: usize) -> MiniModel {
let weights = vec![0.01f32; input_size * output_size];
MiniModel { weights, input_size, output_size }
}
// JavaScript Float32Arrayを受け取って推論
pub fn predict(&self, input: &[f32]) -> Vec<f32> {
let mut output = vec![0.0f32; self.output_size];
for i in 0..self.output_size {
for j in 0..self.input_size {
output[i] += input[j] * self.weights[i * self.input_size + j];
}
// ReLU活性化
output[i] = output[i].max(0.0);
}
output
}
pub fn load_weights(&mut self, weights: Vec<f32>) {
self.weights = weights;
}
}
// 埋め込みのコサイン類似度 (ブラウザでRAG可能)
#[wasm_bindgen]
pub fn cosine_similarity_wasm(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot / (na * nb)
}
// JavaScriptからWASMモジュールを使用
import init, { MiniModel, cosine_similarity_wasm } from './rust_ai_wasm.js'
async function runBrowserInference() {
await init()
const model = new MiniModel(128, 10)
const input = new Float32Array(128).fill(0.5)
const output = model.predict(input)
console.log('ブラウザ推論結果:', output)
// ベクトル類似度検索
const queryEmb = new Float32Array(384).map(() => Math.random())
const docEmb = new Float32Array(384).map(() => Math.random())
const sim = cosine_similarity_wasm(queryEmb, docEmb)
console.log('類似度:', sim)
}
6. メモリ安全ML: Polarsデータパイプライン
Polarsで大量の学習データを処理
[dependencies]
polars = { version = "0.44", features = ["lazy", "parquet", "json", "csv"] }
use polars::prelude::*;
fn preprocess_training_data() -> PolarsResult<DataFrame> {
// Parquetファイルの遅延ロード (メモリ効率)
let df = LazyFrame::scan_parquet("training_data.parquet", Default::default())?
.filter(col("label").is_not_null())
.with_column(
// テキスト長特徴の抽出
col("text").str().len_chars().alias("text_length")
)
.with_column(
// 正規化
(col("score") - col("score").mean()) / col("score").std(1)
)
.filter(col("text_length").gt(10))
.select([col("text"), col("label"), col("score"), col("text_length")])
.collect()?;
println!("Shape: {:?}", df.shape());
println!("{}", df.head(Some(5)));
Ok(df)
}
fn batch_generator(df: &DataFrame, batch_size: usize) -> Vec<DataFrame> {
let n = df.height();
(0..n).step_by(batch_size).map(|start| {
let end = (start + batch_size).min(n);
df.slice(start as i64, end - start)
}).collect()
}
fn compute_label_stats(df: &DataFrame) -> PolarsResult<DataFrame> {
df.clone().lazy()
.group_by([col("label")])
.agg([
col("score").mean().alias("avg_score"),
col("score").std(1).alias("std_score"),
col("text_length").mean().alias("avg_length"),
col("label").count().alias("count"),
])
.sort(["count"], SortMultipleOptions::default().with_order_descending(true))
.collect()
}
Apache Arrowによるゼロコピーデータ共有
use arrow::array::{Float32Array, StringArray};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
fn create_embedding_batch(
texts: Vec<String>,
embeddings: Vec<Vec<f32>>,
) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new("embedding_dim0", DataType::Float32, false),
]));
let text_array = StringArray::from(texts);
// 最初の次元のみ保存 (例示)
let emb_dim: Vec<f32> = embeddings.iter().map(|e| e[0]).collect();
let emb_array = Float32Array::from(emb_dim);
RecordBatch::try_new(
schema,
vec![Arc::new(text_array), Arc::new(emb_array)],
).unwrap()
}
7. ケーススタディ
Hugging Face Candle
HuggingFaceはPython PyTorchモデルをRust Candleに移植しています:
- バイナリサイズ: PyTorch 250MB以上 vs Candle 約5MB
- WASMサポートでブラウザ直接推論が可能
- CUDAカーネルの再利用も可能
Ruff: Python LinterのRust書き直し
RuffはFlake8より100倍高速なPythonリンターです。大規模MLコードベースでその差が実感できます。
# 従来のPythonリンター
time flake8 large_ml_project/ # 約30秒
# Ruff (Rust)
time ruff check large_ml_project/ # 約0.3秒
Pydantic v2のRustコア
Pydantic v2はコアバリデーションロジックをRust (pydantic-core) で書き直し、5〜50倍のパフォーマンス向上を達成しました。AI APIサーバーでのリクエスト検証レイテンシが大幅に削減されました。
クイズ
Q1. RustにおけるArcとRcの違いと、マルチスレッドAI推論でArcが必要な理由は?
答え: Arc (Atomic Reference Count) はアトミック操作で参照カウントを管理するためスレッドセーフですが、Rc (Reference Count) はシングルスレッド専用です。
解説: AI推論サーバーでは複数のリクエストスレッドが同じモデル重みを共有します。RcはSendトレイトを実装していないため、スレッド間で移動できません。Arcは参照カウントの増減にアトミック操作 (atomic fetch-add) を使用し、データ競合なしに複数スレッドが同じデータを参照できます。性能差はナノ秒レベルですが、安全性はコンパイル時に保証されます。
Q2. PyO3のGIL処理方式と、PythonからRust並列コードを呼び出す際のパフォーマンス利点は?
答え: PyO3は py.allow_threads() でGILを解放し、RustコードがPythonのGIL制約なしに真の並列実行を行います。
解説: PythonのGIL (Global Interpreter Lock) は一度に一つのPythonスレッドしか実行できないよう制限します。allow_threads を呼び出すと現在のスレッドがGILを解放し、Rustコード (rayon等) がOSスレッドを自由に使用できます。100万件の埋め込み類似度計算はPythonシングルスレッドで数十秒かかりますが、Rust + rayon並列化で数百ミリ秒に短縮できます。
Q3. CandleがPyTorchよりエッジ配備に有利な理由は?
答え: CandleはRustの単一バイナリとして配布可能でWASMをサポートし、PyTorchランタイム (250MB以上) 不要で約5MBのバイナリを生成できます。
解説: PyTorchはPythonランタイム、libtorch、CUDAドライバーなど大規模な依存性が必要です。Candleの静的リンクにより、必要な機能だけを含む小さな実行ファイルが生成されます。WebAssemblyにコンパイルするとブラウザで直接推論が可能になり、サーバーコストなしでクライアントサイドAIを実現できます。Raspberry Piや組み込みデバイスでもPythonなしで動作します。
Q4. Polarsが大容量データ処理でPandasより高速な内部実装の原理は?
答え: PolarsはRustで書かれており、Apache Arrowカラム型メモリフォーマット、SIMD最適化、遅延評価によるクエリ最適化を活用しています。
解説: PandasはPythonベースで行単位処理が多くGIL制約があります。Polarsは (1) Arrowカラム型フォーマットでキャッシュ効率が高く、(2) RustのSIMD自動ベクトル化で数値演算が高速で、(3) lazy APIでクエリプランを最適化 (predicate pushdown、projection pushdown) し、(4) GILなしにマルチスレッド並列処理を行います。一般的にPandasと比べて5〜20倍高速で、メモリ使用量も少なくなります。
Q5. Rustのゼロコスト抽象化 (zero-cost abstraction) がC++の性能に匹敵する理由は?
答え: Rustコンパイラー (LLVMバックエンド) が高水準の抽象化 (トレイト、ジェネリクス、イテレーター) をランタイムオーバーヘッドなしに最適化された機械語に変換します。
解説: C++のテンプレートとインライニングと同様に、Rustのジェネリクスは単相化 (monomorphization) によってコンパイル時に各型の具体的なコードを生成します。イテレーターチェーン (map、filter、fold) はCスタイルのforループと同じアセンブリを生成します。トレイトメソッドは静的ディスパッチ (static dispatch) で仮想関数テーブル参照なしにインライン化されます。結果として「書きやすくてCと同じ速さ」なコードが可能です。
まとめ
RustはAIインフラのパフォーマンスの壁を突破するための強力なツールです。Pythonでプロトタイプを作り、ボトルネックになった部分をRustで書き直す (PyO3)、または最初から高性能サービスをRustで構築する戦略が現実的です。
学習ロードマップ:
- The Rust Book — 公式チュートリアル
- Rustlings — インタラクティブ演習
- Candleサンプル — HuggingFace GitHub
- Axum公式ドキュメント — 推論サーバー構築
- PyO3ガイド — Python連携