Skip to main content

laminar_db/ai/backends/
local.rs

//! Local inference via ONNX Runtime (`ort`, loaded dynamically): encoder models
//! only (BERT/DistilBERT/MiniLM) — classify/sentiment return logits, embed
//! returns a mean-pooled vector; generative tasks are rejected. A model is loaded
//! once per source and cached; the forward pass runs on `spawn_blocking` under
//! a deadline so a wedged model can't stall the worker. ORT loads at
//! runtime — `onnxruntime.{dll,so}` (>= 1.24) must be on the search path or named
//! by `ORT_DYLIB_PATH`. A `source` is `hf:org/repo` (downloaded on first use),
//! `file://<path>`, or a bare path; labels come from the model's `config.json`
//! `id2label`.
10
11use std::borrow::Cow;
12use std::collections::HashMap;
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use ort::session::{Session, SessionInputValue};
19use ort::value::Tensor;
20use parking_lot::Mutex;
21
22use crate::ai::provider::{
23    InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
24};
25use crate::ai::registry::Task;
26
/// Deadline applied to one batch's synchronous ONNX Runtime forward pass.
///
/// It caps how long a wedged model can occupy the inference worker (and keep
/// the watermark held). The blocking thread cannot be cancelled, so when the
/// deadline passes it is simply abandoned: the batch fails and the hold is
/// released.
const INFERENCE_TIMEOUT: Duration = Duration::from_secs(60);
32
33/// A loaded model: an ONNX Runtime session, its tokenizer, and the model's input
34/// names (`input_ids`, `attention_mask`, and optionally `token_type_ids`) that
35/// drive how each row is fed. `Session::run` takes `&mut`, so it sits behind a
36/// mutex — a model serves one batch at a time.
37struct LoadedModel {
38    session: Mutex<Session>,
39    tokenizer: tokenizers::Tokenizer,
40    input_names: Vec<String>,
41    /// `config.json` `id2label`, read once at load; empty if absent.
42    labels: Vec<String>,
43}
44
45/// Local ONNX provider, backed by a model cache directory.
46pub struct LocalProvider {
47    cache_dir: PathBuf,
48    loaded: Mutex<HashMap<String, Arc<LoadedModel>>>,
49    /// Serializes the download-and-compile path so concurrent misses for the
50    /// same model don't fetch and build it more than once.
51    load_lock: tokio::sync::Mutex<()>,
52}
53
54impl LocalProvider {
55    /// Create a provider that resolves models under `cache_dir`.
56    #[must_use]
57    pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
58        Self {
59            cache_dir: cache_dir.into(),
60            loaded: Mutex::new(HashMap::new()),
61            load_lock: tokio::sync::Mutex::new(()),
62        }
63    }
64
65    /// Resolve a model to a loaded, cached plan: serve from cache, else download
66    /// (for an absent `hf:` repo) and compile on the blocking pool.
67    async fn ensure_model(&self, source: &str) -> Result<Arc<LoadedModel>, ProviderError> {
68        if let Some(model) = self.loaded.lock().get(source) {
69            return Ok(Arc::clone(model));
70        }
71        // Hold the load lock across download+compile and re-check the cache: a
72        // concurrent miss may have finished loading this model while we waited.
73        let _load = self.load_lock.lock().await;
74        if let Some(model) = self.loaded.lock().get(source) {
75            return Ok(Arc::clone(model));
76        }
77
78        let loaded = if let Some(repo_id) = source.strip_prefix("hf:") {
79            let api = hf_hub::api::tokio::ApiBuilder::from_env()
80                .with_cache_dir(self.cache_dir.clone())
81                .build()
82                .map_err(|e| {
83                    ProviderError::Transport(format!("failed to initialize hf-hub client: {e}"))
84                })?;
85            let repo = api.model(repo_id.to_string());
86
87            let tokenizer_path = repo.get("tokenizer.json").await.map_err(|e| {
88                ProviderError::Transport(format!("failed to download tokenizer.json: {e}"))
89            })?;
90
91            let onnx_path = match repo.get("onnx/model.onnx").await {
92                Ok(path) => path,
93                Err(_) => repo.get("model.onnx").await.map_err(|e| {
94                    ProviderError::Transport(format!("failed to download model.onnx: {e}"))
95                })?,
96            };
97
98            let config_path = repo.get("config.json").await.ok();
99
100            // Compiling the ONNX graph is heavy, blocking work — keep it off Ring 1.
101            tokio::task::spawn_blocking(move || {
102                load_model_from_paths(&onnx_path, &tokenizer_path, config_path.as_deref())
103            })
104            .await
105            .map_err(|e| ProviderError::Transport(format!("model load task: {e}")))??
106        } else {
107            let dir = model_dir(&self.cache_dir, source);
108            // Compiling the ONNX graph is heavy, blocking work — keep it off Ring 1.
109            tokio::task::spawn_blocking(move || load_model(&dir))
110                .await
111                .map_err(|e| ProviderError::Transport(format!("model load task: {e}")))??
112        };
113
114        let loaded = Arc::new(loaded);
115        self.loaded
116            .lock()
117            .insert(source.to_string(), Arc::clone(&loaded));
118        Ok(loaded)
119    }
120}
121
/// Map a model `source` string to the directory its files live in.
///
/// * `hf:org/repo` resolves to `<cache_dir>/org/repo`.
/// * `file://<path>` resolves to `<path>` (the scheme is stripped).
/// * Anything else is treated as a path and returned unchanged.
#[must_use]
pub fn model_dir(cache_dir: &Path, source: &str) -> PathBuf {
    match source.strip_prefix("hf:") {
        Some(repo_id) => cache_dir.join(repo_id),
        None => PathBuf::from(source.strip_prefix("file://").unwrap_or(source)),
    }
}
134
135/// Classifier labels from a model's `config.json` `id2label`, ordered by index.
136/// Empty if the file is absent or has no `id2label` — the registry uses this to
137/// auto-derive a local classifier's labels.
138#[must_use]
139pub fn load_labels(cache_dir: &Path, source: &str) -> Vec<String> {
140    // A config.json placed directly under the model dir (hand-placed exports,
141    // `file://`/path sources, and the legacy hf cache layout) wins.
142    if let Ok(text) = std::fs::read_to_string(model_dir(cache_dir, source).join("config.json")) {
143        return parse_id2label(&text);
144    }
145    // Otherwise resolve an hf-hub-downloaded model from its cache snapshot.
146    if let Some(repo_id) = source.strip_prefix("hf:") {
147        let cache = hf_hub::Cache::new(cache_dir.to_path_buf());
148        let repo = cache.repo(hf_hub::Repo::model(repo_id.to_string()));
149        if let Some(path) = repo.get("config.json") {
150            if let Ok(text) = std::fs::read_to_string(path) {
151                return parse_id2label(&text);
152            }
153        }
154    }
155    Vec::new()
156}
157
158fn parse_id2label(config_json: &str) -> Vec<String> {
159    let Ok(json) = serde_json::from_str::<serde_json::Value>(config_json) else {
160        return Vec::new();
161    };
162    let Some(map) = json.get("id2label").and_then(serde_json::Value::as_object) else {
163        return Vec::new();
164    };
165    let mut indexed: Vec<(usize, String)> = map
166        .iter()
167        .filter_map(|(k, v)| Some((k.parse().ok()?, v.as_str()?.to_string())))
168        .collect();
169    indexed.sort_by_key(|(index, _)| *index);
170    indexed.into_iter().map(|(_, label)| label).collect()
171}
172
173#[async_trait]
174impl InferenceProvider for LocalProvider {
175    async fn infer_batch(
176        &self,
177        request: InferenceRequest,
178    ) -> Result<InferenceResponse, ProviderError> {
179        if matches!(
180            request.task,
181            Task::Complete | Task::Summarize | Task::Translate | Task::Gen | Task::Extract
182        ) {
183            return Err(ProviderError::UnsupportedTask(request.task));
184        }
185        let loaded = self.ensure_model(&request.model).await?;
186        let task = request.task;
187        let inputs = request.inputs;
188        // The forward pass is synchronous CPU work — keep it off the Ring 1 task,
189        // under a deadline so a wedged model cannot stall the worker indefinitely.
190        let run = tokio::task::spawn_blocking(move || run(&loaded, task, &inputs));
191        let outputs = match tokio::time::timeout(INFERENCE_TIMEOUT, run).await {
192            Ok(joined) => {
193                joined.map_err(|e| ProviderError::Transport(format!("inference task: {e}")))??
194            }
195            Err(_) => {
196                return Err(ProviderError::Timeout(
197                    u64::try_from(INFERENCE_TIMEOUT.as_millis()).unwrap_or(u64::MAX),
198                ))
199            }
200        };
201        Ok(InferenceResponse {
202            outputs,
203            usage: Usage::ZERO,
204        })
205    }
206
207    fn name(&self) -> &'static str {
208        "local"
209    }
210
211    fn intrinsic_labels(&self, model: &str) -> Option<Vec<String>> {
212        // Served from the loaded model (read once at load); only an unloaded model
213        // — i.e. before its first inference — touches disk here.
214        let labels = self
215            .loaded
216            .lock()
217            .get(model)
218            .map_or_else(|| load_labels(&self.cache_dir, model), |m| m.labels.clone());
219        (!labels.is_empty()).then_some(labels)
220    }
221}
222
/// Locate the model graph inside `dir`.
///
/// Optimum / transformers.js exports put it at `onnx/model.onnx`; when that
/// file does not exist the flat `model.onnx` of a hand-placed directory is
/// returned instead (without checking that it exists).
fn onnx_path(dir: &Path) -> PathBuf {
    Some(dir.join("onnx").join("model.onnx"))
        .filter(|exported| exported.exists())
        .unwrap_or_else(|| dir.join("model.onnx"))
}
233
234fn load_model_from_paths(
235    onnx_path: &Path,
236    tokenizer_path: &Path,
237    config_path: Option<&Path>,
238) -> Result<LoadedModel, ProviderError> {
239    let session = Session::builder()
240        .map_err(|e| ProviderError::Transport(format!("ort init: {e}")))?
241        .commit_from_file(onnx_path)
242        .map_err(|e| ProviderError::Transport(format!("load onnx: {e}")))?;
243    let input_names = session
244        .inputs()
245        .iter()
246        .map(|i| i.name().to_string())
247        .collect();
248    let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
249        .map_err(|e| ProviderError::Transport(format!("load tokenizer: {e}")))?;
250    if tokenizer.get_padding().is_none() {
251        tokenizer.with_padding(Some(tokenizers::PaddingParams {
252            strategy: tokenizers::PaddingStrategy::BatchLongest,
253            direction: tokenizers::PaddingDirection::Right,
254            pad_to_multiple_of: None,
255            pad_id: 0,
256            pad_type_id: 0,
257            pad_token: "[PAD]".to_string(),
258        }));
259    }
260    let labels = if let Some(path) = config_path {
261        std::fs::read_to_string(path)
262            .ok()
263            .map(|text| parse_id2label(&text))
264            .unwrap_or_default()
265    } else {
266        Vec::new()
267    };
268    Ok(LoadedModel {
269        session: Mutex::new(session),
270        tokenizer,
271        input_names,
272        labels,
273    })
274}
275
276fn load_model(dir: &Path) -> Result<LoadedModel, ProviderError> {
277    let onnx = onnx_path(dir);
278    let tokenizer_path = dir.join("tokenizer.json");
279    if !onnx.exists() || !tokenizer_path.exists() {
280        return Err(ProviderError::Transport(format!(
281            "local model files not found in {} (expected onnx/model.onnx + tokenizer.json)",
282            dir.display()
283        )));
284    }
285    let config_path = dir.join("config.json");
286    let config_path_opt = config_path.exists().then_some(config_path);
287    load_model_from_paths(&onnx, &tokenizer_path, config_path_opt.as_deref())
288}
289
290fn run(
291    loaded: &LoadedModel,
292    task: Task,
293    inputs: &[String],
294) -> Result<InferenceOutputs, ProviderError> {
295    if inputs.is_empty() {
296        return Ok(InferenceOutputs::Vectors(vec![]));
297    }
298
299    let encodings = loaded
300        .tokenizer
301        .encode_batch(inputs.to_vec(), true)
302        .map_err(|e| ProviderError::BadResponse(format!("tokenize batch: {e}")))?;
303
304    let batch_size = encodings.len();
305    if batch_size == 0 {
306        return Ok(InferenceOutputs::Vectors(vec![]));
307    }
308    let seq = encodings[0].len();
309
310    let mut stacked_ids = Vec::with_capacity(batch_size * seq);
311    let mut stacked_mask = Vec::with_capacity(batch_size * seq);
312    for encoding in &encodings {
313        stacked_ids.extend(encoding.get_ids().iter().map(|&u| i64::from(u)));
314        stacked_mask.extend(encoding.get_attention_mask().iter().map(|&u| i64::from(u)));
315    }
316
317    let batch_size_i64 = i64::try_from(batch_size).unwrap_or(i64::MAX);
318    let seq_i64 = i64::try_from(seq).unwrap_or(i64::MAX);
319    let shape = vec![batch_size_i64, seq_i64];
320
321    let mut feeds: Vec<(Cow<str>, SessionInputValue)> =
322        Vec::with_capacity(loaded.input_names.len());
323    for name in &loaded.input_names {
324        let row = match name.as_str() {
325            "input_ids" => stacked_ids.clone(),
326            "attention_mask" => stacked_mask.clone(),
327            "token_type_ids" => vec![0i64; batch_size * seq],
328            other => {
329                return Err(ProviderError::BadResponse(format!(
330                    "model expects unsupported input '{other}'"
331                )))
332            }
333        };
334        let tensor = Tensor::from_array((shape.clone(), row))
335            .map_err(|e| ProviderError::Transport(format!("build tensor: {e}")))?;
336        feeds.push((Cow::Owned(name.clone()), SessionInputValue::from(tensor)));
337    }
338
339    let mut session = loaded.session.lock();
340    let outputs = session
341        .run(feeds)
342        .map_err(|e| ProviderError::Transport(format!("inference: {e}")))?;
343    let (shape, data) = outputs[0]
344        .try_extract_tensor::<f32>()
345        .map_err(|e| ProviderError::BadResponse(format!("read output: {e}")))?;
346
347    if shape.first().copied().and_then(|d| usize::try_from(d).ok()) != Some(batch_size) {
348        return Err(ProviderError::BadResponse(format!(
349            "expected batch size {batch_size}, got output shape {shape:?}"
350        )));
351    }
352
353    let mut vectors = Vec::with_capacity(batch_size);
354    if task == Task::Embed && shape.len() == 3 {
355        // last_hidden_state [batch_size, seq_out, hidden] → mean-pool over real tokens.
356        let seq_out = usize::try_from(shape[1]).unwrap_or(0);
357        let hidden = usize::try_from(shape[2]).unwrap_or(0);
358        for (i, encoding) in encodings.iter().enumerate() {
359            let start = i * seq_out * hidden;
360            let end = (i + 1) * seq_out * hidden;
361            let slice = &data[start..end];
362            let mask: Vec<i64> = encoding
363                .get_attention_mask()
364                .iter()
365                .map(|&u| i64::from(u))
366                .collect();
367            vectors.push(mean_pool(slice, seq_out, hidden, &mask));
368        }
369    } else {
370        // classify/sentiment logits, or a pre-pooled embedding.
371        let dim: usize = shape[1..]
372            .iter()
373            .map(|&d| usize::try_from(d).unwrap_or(0))
374            .product();
375        for i in 0..batch_size {
376            let start = i * dim;
377            let end = (i + 1) * dim;
378            let slice = &data[start..end];
379            vectors.push(slice.to_vec());
380        }
381    }
382
383    Ok(InferenceOutputs::Vectors(vectors))
384}
385
/// Average the unmasked token rows of a row-major `[seq, hidden]` block.
///
/// A token counts when its `mask` entry is non-zero; tokens past the end of
/// `mask` count as masked. The result has `hidden` elements and is all zeros
/// when no token is unmasked.
///
/// # Panics
///
/// Panics if `data` is too short to hold the row of an unmasked token.
fn mean_pool(data: &[f32], seq: usize, hidden: usize, mask: &[i64]) -> Vec<f32> {
    let mut sums = vec![0.0_f32; hidden];
    let mut kept = 0.0_f32;

    let live_tokens = mask
        .iter()
        .take(seq)
        .enumerate()
        .filter(|&(_, &flag)| flag != 0)
        .map(|(token, _)| token);

    for token in live_tokens {
        let row = &data[token * hidden..(token + 1) * hidden];
        for (acc, &value) in sums.iter_mut().zip(row) {
            *acc += value;
        }
        kept += 1.0;
    }

    if kept > 0.0 {
        sums.iter_mut().for_each(|acc| *acc /= kept);
    }
    sums
}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Each source form maps to the expected directory.
    #[test]
    fn model_dir_resolution() {
        let cache = Path::new("/models");
        let cases = [
            ("hf:onnx-community/finbert", "/models/onnx-community/finbert"),
            ("file:///abs/dir", "/abs/dir"),
            ("/some/path", "/some/path"),
        ];
        for (source, expected) in cases {
            assert_eq!(model_dir(cache, source), Path::new(expected));
        }
    }

    /// Labels come back sorted by their numeric key; no `id2label` means none.
    #[test]
    fn parse_id2label_orders_by_index() {
        let config = r#"{"id2label": {"2": "positive", "0": "negative", "1": "neutral"}}"#;
        let labels = parse_id2label(config);
        assert_eq!(labels, vec!["negative", "neutral", "positive"]);

        let none = parse_id2label("{}");
        assert!(none.is_empty());
    }

    /// Before a model is loaded, its labels are read from `config.json` on disk.
    #[test]
    fn intrinsic_labels_read_from_a_cached_models_config() {
        let cache = tempfile::tempdir().expect("tempdir");
        let source = "hf:org/repo";
        let provider = LocalProvider::new(cache.path());

        // Absent until the model (its config.json) is on disk.
        assert!(provider.intrinsic_labels(source).is_none());

        let dir = model_dir(cache.path(), source);
        std::fs::create_dir_all(&dir).unwrap();
        let config_file = dir.join("config.json");
        std::fs::write(
            config_file,
            r#"{"id2label": {"0": "NEGATIVE", "1": "POSITIVE"}}"#,
        )
        .unwrap();

        let expected: Vec<String> = vec!["NEGATIVE".to_string(), "POSITIVE".to_string()];
        assert_eq!(provider.intrinsic_labels(source), Some(expected));
    }

    /// Padding rows must not contribute to the pooled mean.
    #[test]
    fn mean_pool_ignores_masked_tokens() {
        // seq=3, hidden=2; token 2 is padding (mask 0).
        let mask = [1, 1, 0];
        let data = [1.0, 2.0, 3.0, 4.0, 100.0, 100.0];
        // Mean of rows 0 and 1 only.
        assert_eq!(mean_pool(&data, 3, 2, &mask), vec![2.0, 3.0]);
    }

    /// End-to-end against a real export: resolve the `onnx/` layout, download a
    /// DistilBERT SST-2 sentiment classifier from the Hugging Face CDN, tokenize,
    /// run it through ONNX Runtime, and check the argmax labels match the
    /// sentiment of clearly positive and negative inputs.
    ///
    /// Opt-in: network + a ~268 MB model download, and ONNX Runtime must be
    /// loadable at runtime (`ORT_DYLIB_PATH=/path/to/onnxruntime.dll`, ORT >= 1.24).
    #[tokio::test]
    #[ignore = "downloads a model + needs ORT_DYLIB_PATH; run with --ignored"]
    async fn classifies_with_a_real_onnx_community_model() {
        use crate::ai::adapter::parse_response;
        use crate::ai::provider::InferenceParams;
        use crate::ai::registry::BackendKind;

        let source = "hf:onnx-community/distilbert-base-uncased-finetuned-sst-2-english-ONNX";
        let cache = tempfile::tempdir().expect("tempdir");
        let provider = LocalProvider::new(cache.path());

        let inputs: Vec<String> = vec![
            "this film was absolutely wonderful, I loved every minute".into(),
            "a complete waste of time, dull and disappointing".into(),
        ];
        let request = InferenceRequest {
            task: Task::Classify,
            model: source.to_string(),
            inputs,
            params: InferenceParams::default(),
        };

        let response = provider.infer_batch(request).await.expect("inference");
        let InferenceOutputs::Vectors(rows) = &response.outputs else {
            panic!("local classify returns logits");
        };

        let labels = load_labels(cache.path(), source);
        assert!(!labels.is_empty(), "id2label should load from config.json");
        assert_eq!(rows.len(), 2);
        let label_count = labels.len();
        assert!(
            rows.iter().all(|row| row.len() == label_count),
            "logit dimension must equal the label count",
        );

        let adapted = parse_response(
            Task::Classify,
            BackendKind::Local,
            response.outputs,
            Some(&labels),
        )
        .expect("adapt");
        let InferenceOutputs::Text(out) = adapted else {
            panic!("argmax yields labels");
        };
        assert_eq!(out, vec!["POSITIVE".to_string(), "NEGATIVE".to_string()]);
    }
}