Skip to main content

laminar_ai/backends/
local.rs

1//! Local inference via ONNX Runtime (`ort`, loaded dynamically). Encoder models
2//! only — the BERT / DistilBERT / MiniLM family: classification/sentiment yield
3//! logits (the adapter argmaxes), embedding yields a mean-pooled vector.
4//! Generative tasks are rejected. A model is loaded once per source and cached;
5//! the forward pass runs on `spawn_blocking`, off the Ring 1 task, under a
6//! deadline so a pathological model can never stall the worker (and the
7//! watermark behind it). ONNX Runtime is loaded at runtime, so `onnxruntime.dll`
8//! / `.so` (ORT >= 1.24) must be on the search path or named by `ORT_DYLIB_PATH`.
9//!
10//! A `source` resolves to a directory laid out like a Hugging Face export —
11//! `onnx/model.onnx` + `tokenizer.json` (+ optional `config.json` for labels):
12//! `hf:org/repo` → `<cache_dir>/org/repo`, `file://<path>` or a bare path used
13//! as-is. A missing `hf:` repo is downloaded from the Hugging Face CDN on first
14//! use (public repos only). Classifier labels come from the model's own
15//! `config.json` `id2label`; [`LocalProvider::intrinsic_labels`] resolves them on
16//! demand, so a model that downloads lazily scores correctly once it is cached —
17//! no restart, no externally supplied label list.
18
19use std::borrow::Cow;
20use std::collections::HashMap;
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use ort::session::{Session, SessionInputValue};
27use ort::value::Tensor;
28use parking_lot::Mutex;
29
30use crate::provider::{
31    InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
32};
33use crate::registry::Task;
34
/// Upper bound on one batch's synchronous ONNX Runtime forward pass.
///
/// `infer_batch` wraps the blocking inference task in a timeout of this length
/// so a model that never returns cannot pin the inference worker (and the
/// watermark held behind it). When the deadline fires the batch fails; the
/// blocking thread itself cannot be cancelled and is simply left to finish.
const INFERENCE_TIMEOUT: Duration = Duration::from_secs(60);
40
41/// A loaded model: an ONNX Runtime session, its tokenizer, and the model's input
42/// names (`input_ids`, `attention_mask`, and optionally `token_type_ids`) that
43/// drive how each row is fed. `Session::run` takes `&mut`, so it sits behind a
44/// mutex — a model serves one batch at a time.
45struct LoadedModel {
46    session: Mutex<Session>,
47    tokenizer: tokenizers::Tokenizer,
48    input_names: Vec<String>,
49    /// `config.json` `id2label`, read once at load; empty if absent.
50    labels: Vec<String>,
51}
52
53/// Local ONNX provider, backed by a model cache directory.
54pub struct LocalProvider {
55    cache_dir: PathBuf,
56    loaded: Mutex<HashMap<String, Arc<LoadedModel>>>,
57    /// Serializes the download-and-compile path so concurrent misses for the
58    /// same model don't fetch and build it more than once.
59    load_lock: tokio::sync::Mutex<()>,
60    http: reqwest::Client,
61}
62
63impl LocalProvider {
64    /// Create a provider that resolves models under `cache_dir`.
65    #[must_use]
66    pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
67        // A connect deadline plus a generous total cap bounds a hung download
68        // without killing a legitimately large model fetch.
69        let http = reqwest::Client::builder()
70            .connect_timeout(Duration::from_secs(15))
71            .timeout(Duration::from_secs(600))
72            .build()
73            .unwrap_or_default();
74        Self {
75            cache_dir: cache_dir.into(),
76            loaded: Mutex::new(HashMap::new()),
77            load_lock: tokio::sync::Mutex::new(()),
78            http,
79        }
80    }
81
82    /// Resolve a model to a loaded, cached plan: serve from cache, else download
83    /// (for an absent `hf:` repo) and compile on the blocking pool.
84    async fn ensure_model(&self, source: &str) -> Result<Arc<LoadedModel>, ProviderError> {
85        if let Some(model) = self.loaded.lock().get(source) {
86            return Ok(Arc::clone(model));
87        }
88        // Hold the load lock across download+compile and re-check the cache: a
89        // concurrent miss may have finished loading this model while we waited.
90        let _load = self.load_lock.lock().await;
91        if let Some(model) = self.loaded.lock().get(source) {
92            return Ok(Arc::clone(model));
93        }
94        let dir = model_dir(&self.cache_dir, source);
95        if let Some(repo) = source.strip_prefix("hf:") {
96            download_if_missing(&self.http, repo, &dir).await?;
97        }
98        // Compiling the ONNX graph is heavy, blocking work — keep it off Ring 1.
99        let loaded = tokio::task::spawn_blocking(move || load_model(&dir))
100            .await
101            .map_err(|e| ProviderError::Transport(format!("model load task: {e}")))??;
102        let loaded = Arc::new(loaded);
103        self.loaded
104            .lock()
105            .insert(source.to_string(), Arc::clone(&loaded));
106        Ok(loaded)
107    }
108}
109
110/// Download `model.onnx`, `tokenizer.json` (required) and `config.json`
111/// (optional — labels only) from the Hugging Face CDN if they are not already
112/// present. Public repos only; no auth.
113async fn download_if_missing(
114    http: &reqwest::Client,
115    repo: &str,
116    dir: &Path,
117) -> Result<(), ProviderError> {
118    if onnx_path(dir).exists() && dir.join("tokenizer.json").exists() {
119        return Ok(());
120    }
121    // Paths mirror the repo layout (the graph lives under `onnx/`); config.json
122    // is optional and only feeds label derivation, so any failure to fetch it —
123    // transport error or a 404 — is non-fatal.
124    for (rel, required) in [
125        ("onnx/model.onnx", true),
126        ("tokenizer.json", true),
127        ("config.json", false),
128    ] {
129        let dest = dir.join(rel);
130        if dest.exists() {
131            continue;
132        }
133        let url = format!("https://huggingface.co/{repo}/resolve/main/{rel}");
134        match download_file(http, &url, &dest).await {
135            Err(e) if required => return Err(e),
136            // Success, or an optional file (config.json) we couldn't fetch — skip.
137            Ok(()) | Err(_) => {}
138        }
139    }
140    Ok(())
141}
142
143/// Fetch one file from `url` into `dest`, creating parent directories. Errors on
144/// transport failure, a non-success status, or a write failure.
145async fn download_file(
146    http: &reqwest::Client,
147    url: &str,
148    dest: &Path,
149) -> Result<(), ProviderError> {
150    let resp = http
151        .get(url)
152        .send()
153        .await
154        .map_err(|e| ProviderError::Transport(format!("download {url}: {e}")))?;
155    if !resp.status().is_success() {
156        return Err(ProviderError::Transport(format!(
157            "download {url}: HTTP {}",
158            resp.status()
159        )));
160    }
161    let bytes = resp
162        .bytes()
163        .await
164        .map_err(|e| ProviderError::Transport(format!("download {url}: {e}")))?;
165    if let Some(parent) = dest.parent() {
166        tokio::fs::create_dir_all(parent)
167            .await
168            .map_err(|e| ProviderError::Transport(format!("create {}: {e}", parent.display())))?;
169    }
170    tokio::fs::write(dest, &bytes)
171        .await
172        .map_err(|e| ProviderError::Transport(format!("write {}: {e}", dest.display())))?;
173    Ok(())
174}
175
/// Map a model `source` to the directory its files live in.
///
/// * `hf:org/repo` resolves to `<cache_dir>/org/repo`.
/// * `file://<path>` resolves to `<path>`.
/// * anything else is taken as a path and used unchanged.
///
/// The repo part of an `hf:` source is joined onto `cache_dir` as given; no
/// normalisation or sanitisation is applied.
#[must_use]
pub fn model_dir(cache_dir: &Path, source: &str) -> PathBuf {
    if let Some(repo) = source.strip_prefix("hf:") {
        return cache_dir.join(repo);
    }
    let local = source.strip_prefix("file://").unwrap_or(source);
    PathBuf::from(local)
}
188
189/// Classifier labels from a model's `config.json` `id2label`, ordered by index.
190/// Empty if the file is absent or has no `id2label` — the registry uses this to
191/// auto-derive a local classifier's labels.
192#[must_use]
193pub fn load_labels(cache_dir: &Path, source: &str) -> Vec<String> {
194    std::fs::read_to_string(model_dir(cache_dir, source).join("config.json"))
195        .ok()
196        .map(|text| parse_id2label(&text))
197        .unwrap_or_default()
198}
199
200fn parse_id2label(config_json: &str) -> Vec<String> {
201    let Ok(json) = serde_json::from_str::<serde_json::Value>(config_json) else {
202        return Vec::new();
203    };
204    let Some(map) = json.get("id2label").and_then(serde_json::Value::as_object) else {
205        return Vec::new();
206    };
207    let mut indexed: Vec<(usize, String)> = map
208        .iter()
209        .filter_map(|(k, v)| Some((k.parse().ok()?, v.as_str()?.to_string())))
210        .collect();
211    indexed.sort_by_key(|(index, _)| *index);
212    indexed.into_iter().map(|(_, label)| label).collect()
213}
214
215#[async_trait]
216impl InferenceProvider for LocalProvider {
217    async fn infer_batch(
218        &self,
219        request: InferenceRequest,
220    ) -> Result<InferenceResponse, ProviderError> {
221        if matches!(
222            request.task,
223            Task::Complete | Task::Summarize | Task::Translate | Task::Gen | Task::Extract
224        ) {
225            return Err(ProviderError::UnsupportedTask(request.task));
226        }
227        let loaded = self.ensure_model(&request.model).await?;
228        let task = request.task;
229        let inputs = request.inputs;
230        // The forward pass is synchronous CPU work — keep it off the Ring 1 task,
231        // under a deadline so a wedged model cannot stall the worker indefinitely.
232        let run = tokio::task::spawn_blocking(move || run(&loaded, task, &inputs));
233        let outputs = match tokio::time::timeout(INFERENCE_TIMEOUT, run).await {
234            Ok(joined) => {
235                joined.map_err(|e| ProviderError::Transport(format!("inference task: {e}")))??
236            }
237            Err(_) => {
238                return Err(ProviderError::Timeout(
239                    u64::try_from(INFERENCE_TIMEOUT.as_millis()).unwrap_or(u64::MAX),
240                ))
241            }
242        };
243        Ok(InferenceResponse {
244            outputs,
245            usage: Usage::ZERO,
246        })
247    }
248
249    fn name(&self) -> &'static str {
250        "local"
251    }
252
253    fn intrinsic_labels(&self, model: &str) -> Option<Vec<String>> {
254        // Served from the loaded model (read once at load); only an unloaded model
255        // — i.e. before its first inference — touches disk here.
256        let labels = self
257            .loaded
258            .lock()
259            .get(model)
260            .map_or_else(|| load_labels(&self.cache_dir, model), |m| m.labels.clone());
261        (!labels.is_empty()).then_some(labels)
262    }
263}
264
/// Locate the model graph inside `dir`.
///
/// Prefers `onnx/model.onnx`, the layout Optimum / transformers.js exports
/// produce. When that file does not exist the flat `model.onnx` of a
/// hand-placed directory is returned instead, whether or not it exists.
fn onnx_path(dir: &Path) -> PathBuf {
    let exported = dir.join("onnx").join("model.onnx");
    if !exported.exists() {
        return dir.join("model.onnx");
    }
    exported
}
275
276fn load_model(dir: &Path) -> Result<LoadedModel, ProviderError> {
277    let onnx = onnx_path(dir);
278    let tokenizer_path = dir.join("tokenizer.json");
279    if !onnx.exists() || !tokenizer_path.exists() {
280        return Err(ProviderError::Transport(format!(
281            "local model files not found in {} (expected onnx/model.onnx + tokenizer.json)",
282            dir.display()
283        )));
284    }
285    let session = Session::builder()
286        .map_err(|e| ProviderError::Transport(format!("ort init: {e}")))?
287        .commit_from_file(&onnx)
288        .map_err(|e| ProviderError::Transport(format!("load onnx: {e}")))?;
289    let input_names = session
290        .inputs()
291        .iter()
292        .map(|i| i.name().to_string())
293        .collect();
294    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
295        .map_err(|e| ProviderError::Transport(format!("load tokenizer: {e}")))?;
296    let labels = std::fs::read_to_string(dir.join("config.json"))
297        .ok()
298        .map(|text| parse_id2label(&text))
299        .unwrap_or_default();
300    Ok(LoadedModel {
301        session: Mutex::new(session),
302        tokenizer,
303        input_names,
304        labels,
305    })
306}
307
308fn run(
309    loaded: &LoadedModel,
310    task: Task,
311    inputs: &[String],
312) -> Result<InferenceOutputs, ProviderError> {
313    let mut session = loaded.session.lock();
314    let mut vectors = Vec::with_capacity(inputs.len());
315    for text in inputs {
316        let encoding = loaded
317            .tokenizer
318            .encode(text.as_str(), true)
319            .map_err(|e| ProviderError::BadResponse(format!("tokenize: {e}")))?;
320        let ids: Vec<i64> = encoding.get_ids().iter().map(|&u| i64::from(u)).collect();
321        let mask: Vec<i64> = encoding
322            .get_attention_mask()
323            .iter()
324            .map(|&u| i64::from(u))
325            .collect();
326        let seq = i64::try_from(ids.len()).unwrap_or(i64::MAX);
327
328        // Feed each input the model declares (batch of 1, [1, seq]); BERT adds
329        // token_type_ids (all zero, single segment), DistilBERT omits it.
330        let mut feeds: Vec<(Cow<str>, SessionInputValue)> =
331            Vec::with_capacity(loaded.input_names.len());
332        for name in &loaded.input_names {
333            let row = match name.as_str() {
334                "input_ids" => ids.clone(),
335                "attention_mask" => mask.clone(),
336                "token_type_ids" => vec![0i64; ids.len()],
337                other => {
338                    return Err(ProviderError::BadResponse(format!(
339                        "model expects unsupported input '{other}'"
340                    )))
341                }
342            };
343            let tensor = Tensor::from_array((vec![1i64, seq], row))
344                .map_err(|e| ProviderError::Transport(format!("build tensor: {e}")))?;
345            feeds.push((Cow::Owned(name.clone()), SessionInputValue::from(tensor)));
346        }
347
348        let outputs = session
349            .run(feeds)
350            .map_err(|e| ProviderError::Transport(format!("inference: {e}")))?;
351        let (shape, data) = outputs[0]
352            .try_extract_tensor::<f32>()
353            .map_err(|e| ProviderError::BadResponse(format!("read output: {e}")))?;
354
355        if task == Task::Embed && shape.len() == 3 {
356            // last_hidden_state [1, seq, hidden] → mean-pool over real tokens.
357            let seq_out = usize::try_from(shape[1]).unwrap_or(0);
358            let hidden = usize::try_from(shape[2]).unwrap_or(0);
359            vectors.push(mean_pool(data, seq_out, hidden, &mask));
360        } else {
361            // classify/sentiment logits, or a pre-pooled embedding.
362            vectors.push(data.to_vec());
363        }
364    }
365    drop(session);
366    Ok(InferenceOutputs::Vectors(vectors))
367}
368
/// Average the rows of a row-major `[seq, hidden]` block, counting only the
/// tokens whose `mask` entry is non-zero.
///
/// A position past the end of `mask` counts as masked. When no token is
/// counted the result is all zeros.
fn mean_pool(data: &[f32], seq: usize, hidden: usize, mask: &[i64]) -> Vec<f32> {
    let mut sums = vec![0.0_f32; hidden];
    let mut kept = 0.0_f32;

    for t in 0..seq {
        let attended = matches!(mask.get(t), Some(&m) if m != 0);
        if !attended {
            continue;
        }
        kept += 1.0;
        let base = t * hidden;
        for (h, slot) in sums.iter_mut().enumerate() {
            *slot += data[base + h];
        }
    }

    if kept > 0.0 {
        sums.iter_mut().for_each(|slot| *slot /= kept);
    }
    sums
}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Each of the three `source` forms resolves to the expected directory.
    #[test]
    fn model_dir_resolution() {
        let cache = Path::new("/models");

        let hub = model_dir(cache, "hf:onnx-community/finbert");
        assert_eq!(hub, Path::new("/models/onnx-community/finbert"));

        let file_url = model_dir(cache, "file:///abs/dir");
        assert_eq!(file_url, Path::new("/abs/dir"));

        let bare = model_dir(cache, "/some/path");
        assert_eq!(bare, Path::new("/some/path"));
    }

    /// Labels come back sorted by numeric index; a config without `id2label`
    /// yields none.
    #[test]
    fn parse_id2label_orders_by_index() {
        let config = r#"{"id2label": {"2": "positive", "0": "negative", "1": "neutral"}}"#;
        let labels = parse_id2label(config);
        assert_eq!(labels, vec!["negative", "neutral", "positive"]);

        assert!(parse_id2label("{}").is_empty());
    }

    /// Before a model is loaded, its labels are read from the cached `config.json`.
    #[test]
    fn intrinsic_labels_read_from_a_cached_models_config() {
        let cache = tempfile::tempdir().expect("tempdir");
        let provider = LocalProvider::new(cache.path());
        let source = "hf:org/repo";

        // Nothing on disk yet, so there are no labels to report.
        assert!(provider.intrinsic_labels(source).is_none());

        let model_root = model_dir(cache.path(), source);
        std::fs::create_dir_all(&model_root).unwrap();
        let config_path = model_root.join("config.json");
        std::fs::write(
            config_path,
            r#"{"id2label": {"0": "NEGATIVE", "1": "POSITIVE"}}"#,
        )
        .unwrap();

        let expected: Vec<String> = vec!["NEGATIVE".into(), "POSITIVE".into()];
        assert_eq!(provider.intrinsic_labels(source), Some(expected));
    }

    /// Padding rows (mask 0) do not contribute to the mean.
    #[test]
    fn mean_pool_ignores_masked_tokens() {
        // seq=3, hidden=2; the last token is padding.
        let hidden_states = [1.0, 2.0, 3.0, 4.0, 100.0, 100.0];
        let attention = [1, 1, 0];
        let pooled = mean_pool(&hidden_states, 3, 2, &attention);
        assert_eq!(pooled, vec![2.0, 3.0]); // rows 0 and 1 only
    }

    /// End-to-end check against a real export: download a DistilBERT SST-2
    /// sentiment classifier from the Hugging Face CDN (resolving the `onnx/`
    /// layout), tokenize, run it through ONNX Runtime, and confirm the argmax
    /// labels match clearly positive and clearly negative inputs.
    ///
    /// Opt-in: needs network access, a ~268 MB model download, and an ONNX
    /// Runtime that can be loaded at runtime
    /// (`ORT_DYLIB_PATH=/path/to/onnxruntime.dll`, ORT >= 1.24).
    #[tokio::test]
    #[ignore = "downloads a model + needs ORT_DYLIB_PATH; run with --ignored"]
    async fn classifies_with_a_real_onnx_community_model() {
        use crate::adapter::parse_response;
        use crate::provider::InferenceParams;
        use crate::registry::BackendKind;

        let cache = tempfile::tempdir().expect("tempdir");
        let provider = LocalProvider::new(cache.path());
        let source = "hf:onnx-community/distilbert-base-uncased-finetuned-sst-2-english-ONNX";

        let texts: Vec<String> = vec![
            "this film was absolutely wonderful, I loved every minute".into(),
            "a complete waste of time, dull and disappointing".into(),
        ];
        let request = InferenceRequest {
            task: Task::Classify,
            model: source.to_string(),
            inputs: texts,
            params: InferenceParams::default(),
        };

        let response = provider.infer_batch(request).await.expect("inference");
        let InferenceOutputs::Vectors(rows) = &response.outputs else {
            panic!("local classify returns logits");
        };

        let labels = load_labels(cache.path(), source);
        assert!(!labels.is_empty(), "id2label should load from config.json");
        assert_eq!(rows.len(), 2);
        for row in rows {
            assert!(
                row.len() == labels.len(),
                "logit dimension must equal the label count",
            );
        }

        let adapted = parse_response(
            Task::Classify,
            BackendKind::Local,
            response.outputs,
            Some(&labels),
        )
        .expect("adapt");
        let InferenceOutputs::Text(out) = adapted else {
            panic!("argmax yields labels");
        };
        assert_eq!(out, vec!["POSITIVE".to_string(), "NEGATIVE".to_string()]);
    }
}
495}