1use std::borrow::Cow;
20use std::collections::HashMap;
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use ort::session::{Session, SessionInputValue};
27use ort::value::Tensor;
28use parking_lot::Mutex;
29
30use crate::provider::{
31 InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
32};
33use crate::registry::Task;
34
/// Upper bound on a single blocking inference batch; `infer_batch` reports
/// `ProviderError::Timeout` once it is exceeded.
const INFERENCE_TIMEOUT: Duration = Duration::from_millis(60_000);
40
/// A model that has been loaded from disk and is ready to be executed by `run`.
struct LoadedModel {
    /// ONNX Runtime session. `run` locks it for the duration of a whole batch.
    session: Mutex<Session>,
    /// Tokenizer read from the model directory's `tokenizer.json`.
    tokenizer: tokenizers::Tokenizer,
    /// Input names reported by the session. `run` builds one tensor per name.
    input_names: Vec<String>,
    /// Class labels from `config.json`'s `id2label`, ordered by id (empty if absent).
    labels: Vec<String>,
}
52
/// Inference provider that runs ONNX models in-process.
///
/// Models referenced as `hf:<org>/<repo>` are downloaded into `cache_dir` on first
/// use. `file://` and plain paths are loaded from disk directly.
pub struct LocalProvider {
    /// Root directory under which `hf:` repositories are stored.
    cache_dir: PathBuf,
    /// Models loaded so far, keyed by the source string they were requested with.
    loaded: Mutex<HashMap<String, Arc<LoadedModel>>>,
    /// Async lock held while a model is downloaded/loaded, so concurrent requests
    /// for an unloaded model wait instead of loading it twice.
    load_lock: tokio::sync::Mutex<()>,
    /// HTTP client used to fetch model files.
    http: reqwest::Client,
}
62
63impl LocalProvider {
64 #[must_use]
66 pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
67 let http = reqwest::Client::builder()
70 .connect_timeout(Duration::from_secs(15))
71 .timeout(Duration::from_secs(600))
72 .build()
73 .unwrap_or_default();
74 Self {
75 cache_dir: cache_dir.into(),
76 loaded: Mutex::new(HashMap::new()),
77 load_lock: tokio::sync::Mutex::new(()),
78 http,
79 }
80 }
81
82 async fn ensure_model(&self, source: &str) -> Result<Arc<LoadedModel>, ProviderError> {
85 if let Some(model) = self.loaded.lock().get(source) {
86 return Ok(Arc::clone(model));
87 }
88 let _load = self.load_lock.lock().await;
91 if let Some(model) = self.loaded.lock().get(source) {
92 return Ok(Arc::clone(model));
93 }
94 let dir = model_dir(&self.cache_dir, source);
95 if let Some(repo) = source.strip_prefix("hf:") {
96 download_if_missing(&self.http, repo, &dir).await?;
97 }
98 let loaded = tokio::task::spawn_blocking(move || load_model(&dir))
100 .await
101 .map_err(|e| ProviderError::Transport(format!("model load task: {e}")))??;
102 let loaded = Arc::new(loaded);
103 self.loaded
104 .lock()
105 .insert(source.to_string(), Arc::clone(&loaded));
106 Ok(loaded)
107 }
108}
109
110async fn download_if_missing(
114 http: &reqwest::Client,
115 repo: &str,
116 dir: &Path,
117) -> Result<(), ProviderError> {
118 if onnx_path(dir).exists() && dir.join("tokenizer.json").exists() {
119 return Ok(());
120 }
121 for (rel, required) in [
125 ("onnx/model.onnx", true),
126 ("tokenizer.json", true),
127 ("config.json", false),
128 ] {
129 let dest = dir.join(rel);
130 if dest.exists() {
131 continue;
132 }
133 let url = format!("https://huggingface.co/{repo}/resolve/main/{rel}");
134 match download_file(http, &url, &dest).await {
135 Err(e) if required => return Err(e),
136 Ok(()) | Err(_) => {}
138 }
139 }
140 Ok(())
141}
142
143async fn download_file(
146 http: &reqwest::Client,
147 url: &str,
148 dest: &Path,
149) -> Result<(), ProviderError> {
150 let resp = http
151 .get(url)
152 .send()
153 .await
154 .map_err(|e| ProviderError::Transport(format!("download {url}: {e}")))?;
155 if !resp.status().is_success() {
156 return Err(ProviderError::Transport(format!(
157 "download {url}: HTTP {}",
158 resp.status()
159 )));
160 }
161 let bytes = resp
162 .bytes()
163 .await
164 .map_err(|e| ProviderError::Transport(format!("download {url}: {e}")))?;
165 if let Some(parent) = dest.parent() {
166 tokio::fs::create_dir_all(parent)
167 .await
168 .map_err(|e| ProviderError::Transport(format!("create {}: {e}", parent.display())))?;
169 }
170 tokio::fs::write(dest, &bytes)
171 .await
172 .map_err(|e| ProviderError::Transport(format!("write {}: {e}", dest.display())))?;
173 Ok(())
174}
175
/// Resolves a model source string to the directory that holds its files.
///
/// - `hf:<org>/<repo>` resolves to `<cache_dir>/<org>/<repo>`. Only plain path
///   segments of the repo id are kept. Root, prefix, `.` and `..` components are
///   dropped, so an id such as `hf:/abs/x` or `hf:../../x` cannot resolve outside
///   `cache_dir`. This matters because downloads are written into this directory.
/// - `file://<path>` resolves to `<path>`.
/// - Any other string is used as a filesystem path verbatim.
#[must_use]
pub fn model_dir(cache_dir: &Path, source: &str) -> PathBuf {
    if let Some(repo) = source.strip_prefix("hf:") {
        let mut dir = cache_dir.to_path_buf();
        for component in Path::new(repo).components() {
            // `Path::join` with an absolute path would replace `cache_dir`
            // entirely, and `..` would climb out of it. Keep only normal segments.
            if let std::path::Component::Normal(segment) = component {
                dir.push(segment);
            }
        }
        dir
    } else if let Some(path) = source.strip_prefix("file://") {
        PathBuf::from(path)
    } else {
        PathBuf::from(source)
    }
}
188
189#[must_use]
193pub fn load_labels(cache_dir: &Path, source: &str) -> Vec<String> {
194 std::fs::read_to_string(model_dir(cache_dir, source).join("config.json"))
195 .ok()
196 .map(|text| parse_id2label(&text))
197 .unwrap_or_default()
198}
199
200fn parse_id2label(config_json: &str) -> Vec<String> {
201 let Ok(json) = serde_json::from_str::<serde_json::Value>(config_json) else {
202 return Vec::new();
203 };
204 let Some(map) = json.get("id2label").and_then(serde_json::Value::as_object) else {
205 return Vec::new();
206 };
207 let mut indexed: Vec<(usize, String)> = map
208 .iter()
209 .filter_map(|(k, v)| Some((k.parse().ok()?, v.as_str()?.to_string())))
210 .collect();
211 indexed.sort_by_key(|(index, _)| *index);
212 indexed.into_iter().map(|(_, label)| label).collect()
213}
214
215#[async_trait]
216impl InferenceProvider for LocalProvider {
217 async fn infer_batch(
218 &self,
219 request: InferenceRequest,
220 ) -> Result<InferenceResponse, ProviderError> {
221 if matches!(
222 request.task,
223 Task::Complete | Task::Summarize | Task::Translate | Task::Gen | Task::Extract
224 ) {
225 return Err(ProviderError::UnsupportedTask(request.task));
226 }
227 let loaded = self.ensure_model(&request.model).await?;
228 let task = request.task;
229 let inputs = request.inputs;
230 let run = tokio::task::spawn_blocking(move || run(&loaded, task, &inputs));
233 let outputs = match tokio::time::timeout(INFERENCE_TIMEOUT, run).await {
234 Ok(joined) => {
235 joined.map_err(|e| ProviderError::Transport(format!("inference task: {e}")))??
236 }
237 Err(_) => {
238 return Err(ProviderError::Timeout(
239 u64::try_from(INFERENCE_TIMEOUT.as_millis()).unwrap_or(u64::MAX),
240 ))
241 }
242 };
243 Ok(InferenceResponse {
244 outputs,
245 usage: Usage::ZERO,
246 })
247 }
248
249 fn name(&self) -> &'static str {
250 "local"
251 }
252
253 fn intrinsic_labels(&self, model: &str) -> Option<Vec<String>> {
254 let labels = self
257 .loaded
258 .lock()
259 .get(model)
260 .map_or_else(|| load_labels(&self.cache_dir, model), |m| m.labels.clone());
261 (!labels.is_empty()).then_some(labels)
262 }
263}
264
/// Location of the ONNX graph inside `dir`.
///
/// Prefers the nested `onnx/model.onnx` layout when that file exists and falls
/// back to `model.onnx` directly under `dir`. The fallback is returned whether or
/// not it exists.
fn onnx_path(dir: &Path) -> PathBuf {
    Some(dir.join("onnx").join("model.onnx"))
        .filter(|candidate| candidate.exists())
        .unwrap_or_else(|| dir.join("model.onnx"))
}
275
276fn load_model(dir: &Path) -> Result<LoadedModel, ProviderError> {
277 let onnx = onnx_path(dir);
278 let tokenizer_path = dir.join("tokenizer.json");
279 if !onnx.exists() || !tokenizer_path.exists() {
280 return Err(ProviderError::Transport(format!(
281 "local model files not found in {} (expected onnx/model.onnx + tokenizer.json)",
282 dir.display()
283 )));
284 }
285 let session = Session::builder()
286 .map_err(|e| ProviderError::Transport(format!("ort init: {e}")))?
287 .commit_from_file(&onnx)
288 .map_err(|e| ProviderError::Transport(format!("load onnx: {e}")))?;
289 let input_names = session
290 .inputs()
291 .iter()
292 .map(|i| i.name().to_string())
293 .collect();
294 let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
295 .map_err(|e| ProviderError::Transport(format!("load tokenizer: {e}")))?;
296 let labels = std::fs::read_to_string(dir.join("config.json"))
297 .ok()
298 .map(|text| parse_id2label(&text))
299 .unwrap_or_default();
300 Ok(LoadedModel {
301 session: Mutex::new(session),
302 tokenizer,
303 input_names,
304 labels,
305 })
306}
307
/// Tokenizes each input and runs it through the model, one input at a time.
///
/// Each input is fed as its own `[1, seq]` tensors. For every name in
/// `loaded.input_names` a row is supplied: the token ids (`input_ids`), the
/// attention mask (`attention_mask`), or all zeros (`token_type_ids`). Any other
/// input name is rejected. The first session output is read as `f32`. For
/// `Task::Embed` with a rank-3 output it is mean-pooled over unmasked tokens.
/// Otherwise the whole output is returned flattened as one row.
///
/// The session mutex is held for the entire batch.
///
/// # Errors
/// `BadResponse` on tokenization failure, an unsupported model input, or an
/// output that cannot be read as `f32`. `Transport` if building a tensor or
/// running the session fails.
fn run(
    loaded: &LoadedModel,
    task: Task,
    inputs: &[String],
) -> Result<InferenceOutputs, ProviderError> {
    let mut session = loaded.session.lock();
    let mut vectors = Vec::with_capacity(inputs.len());
    for text in inputs {
        // NOTE(review): no truncation is applied here. Inputs longer than the
        // model's maximum length presumably fail in `session.run` unless
        // tokenizer.json configures truncation — confirm.
        let encoding = loaded
            .tokenizer
            .encode(text.as_str(), true)
            .map_err(|e| ProviderError::BadResponse(format!("tokenize: {e}")))?;
        let ids: Vec<i64> = encoding.get_ids().iter().map(|&u| i64::from(u)).collect();
        let mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&u| i64::from(u))
            .collect();
        // Sequence length as the i64 dimension ONNX tensors use.
        let seq = i64::try_from(ids.len()).unwrap_or(i64::MAX);

        // One named feed per graph input, each shaped [1, seq].
        let mut feeds: Vec<(Cow<str>, SessionInputValue)> =
            Vec::with_capacity(loaded.input_names.len());
        for name in &loaded.input_names {
            let row = match name.as_str() {
                "input_ids" => ids.clone(),
                "attention_mask" => mask.clone(),
                // A single text is encoded (no pair), so every type id is 0.
                "token_type_ids" => vec![0i64; ids.len()],
                other => {
                    return Err(ProviderError::BadResponse(format!(
                        "model expects unsupported input '{other}'"
                    )))
                }
            };
            let tensor = Tensor::from_array((vec![1i64, seq], row))
                .map_err(|e| ProviderError::Transport(format!("build tensor: {e}")))?;
            feeds.push((Cow::Owned(name.clone()), SessionInputValue::from(tensor)));
        }

        let outputs = session
            .run(feeds)
            .map_err(|e| ProviderError::Transport(format!("inference: {e}")))?;
        // Only the first output is used.
        let (shape, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| ProviderError::BadResponse(format!("read output: {e}")))?;

        if task == Task::Embed && shape.len() == 3 {
            // Rank-3 output read as (batch, seq, hidden); batch is presumably 1
            // since the inputs were fed as [1, seq] — confirm for exotic models.
            let seq_out = usize::try_from(shape[1]).unwrap_or(0);
            let hidden = usize::try_from(shape[2]).unwrap_or(0);
            vectors.push(mean_pool(data, seq_out, hidden, &mask));
        } else {
            // e.g. classification logits or an already-pooled embedding.
            vectors.push(data.to_vec());
        }
    }
    drop(session);
    Ok(InferenceOutputs::Vectors(vectors))
}
368
/// Averages the per-token rows of a row-major `[seq, hidden]` buffer over the
/// tokens whose attention-mask entry is non-zero.
///
/// `data` is read as consecutive `hidden`-wide rows, one per token. Token `t`
/// contributes only when `mask[t] != 0`, and tokens past the end of `mask` count
/// as masked. Rows missing from a short `data` buffer are ignored instead of
/// causing an out-of-bounds panic. Returns `hidden` zeros when no token
/// contributes.
fn mean_pool(data: &[f32], seq: usize, hidden: usize, mask: &[i64]) -> Vec<f32> {
    let mut pooled = vec![0.0_f32; hidden];
    // `chunks_exact(0)` panics, and with no hidden width there is nothing to pool.
    if hidden == 0 {
        return pooled;
    }
    let mut count = 0.0_f32;
    // `zip` stops at the shorter of the available rows and the mask, which
    // matches the "missing mask entry means masked" rule.
    let tokens = data.chunks_exact(hidden).take(seq).zip(mask);
    for (row, &keep) in tokens {
        if keep == 0 {
            continue;
        }
        count += 1.0;
        for (acc, &value) in pooled.iter_mut().zip(row) {
            *acc += value;
        }
    }
    if count > 0.0 {
        for value in &mut pooled {
            *value /= count;
        }
    }
    pooled
}
389
#[cfg(test)]
mod tests {
    use super::*;

    // Each supported source scheme resolves to the expected directory.
    #[test]
    fn model_dir_resolution() {
        let cache = Path::new("/models");
        assert_eq!(
            model_dir(cache, "hf:onnx-community/finbert"),
            Path::new("/models/onnx-community/finbert")
        );
        assert_eq!(model_dir(cache, "file:///abs/dir"), Path::new("/abs/dir"));
        assert_eq!(model_dir(cache, "/some/path"), Path::new("/some/path"));
    }

    // Labels are ordered by their numeric id, not by JSON key order.
    #[test]
    fn parse_id2label_orders_by_index() {
        let config = r#"{"id2label": {"2": "positive", "0": "negative", "1": "neutral"}}"#;
        assert_eq!(
            parse_id2label(config),
            vec!["negative", "neutral", "positive"]
        );
        assert!(parse_id2label("{}").is_empty());
    }

    // Labels can be read from a cached config.json without loading the model.
    #[test]
    fn intrinsic_labels_read_from_a_cached_models_config() {
        let cache = tempfile::tempdir().expect("tempdir");
        let provider = LocalProvider::new(cache.path());
        let source = "hf:org/repo";

        // Nothing is cached yet, so there are no labels.
        assert!(provider.intrinsic_labels(source).is_none());

        let dir = model_dir(cache.path(), source);
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(
            dir.join("config.json"),
            r#"{"id2label": {"0": "NEGATIVE", "1": "POSITIVE"}}"#,
        )
        .unwrap();
        assert_eq!(
            provider.intrinsic_labels(source),
            Some(vec!["NEGATIVE".into(), "POSITIVE".into()])
        );
    }

    // Three tokens of width 2. The third is masked out, so only the first two
    // rows are averaged.
    #[test]
    fn mean_pool_ignores_masked_tokens() {
        let data = [1.0, 2.0, 3.0, 4.0, 100.0, 100.0];
        let pooled = mean_pool(&data, 3, 2, &[1, 1, 0]);
        assert_eq!(pooled, vec![2.0, 3.0]);
    }

    // End-to-end: download a real sentiment model, classify two sentences, and
    // map the logits to labels through the adapter.
    #[tokio::test]
    #[ignore = "downloads a model + needs ORT_DYLIB_PATH; run with --ignored"]
    async fn classifies_with_a_real_onnx_community_model() {
        use crate::adapter::parse_response;
        use crate::provider::InferenceParams;
        use crate::registry::BackendKind;

        let cache = tempfile::tempdir().expect("tempdir");
        let provider = LocalProvider::new(cache.path());
        let source = "hf:onnx-community/distilbert-base-uncased-finetuned-sst-2-english-ONNX";
        let request = InferenceRequest {
            task: Task::Classify,
            model: source.to_string(),
            inputs: vec![
                "this film was absolutely wonderful, I loved every minute".into(),
                "a complete waste of time, dull and disappointing".into(),
            ],
            params: InferenceParams::default(),
        };

        let response = provider.infer_batch(request).await.expect("inference");
        let InferenceOutputs::Vectors(rows) = &response.outputs else {
            panic!("local classify returns logits");
        };
        // Must run after `infer_batch`: that call downloads config.json into `cache`.
        let labels = load_labels(cache.path(), source);
        assert!(!labels.is_empty(), "id2label should load from config.json");
        assert_eq!(rows.len(), 2);
        assert!(
            rows.iter().all(|r| r.len() == labels.len()),
            "logit dimension must equal the label count",
        );

        let InferenceOutputs::Text(out) = parse_response(
            Task::Classify,
            BackendKind::Local,
            response.outputs,
            Some(&labels),
        )
        .expect("adapt") else {
            panic!("argmax yields labels");
        };
        assert_eq!(out, vec!["POSITIVE".to_string(), "NEGATIVE".to_string()]);
    }
}