1use std::borrow::Cow;
12use std::collections::HashMap;
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use ort::session::{Session, SessionInputValue};
19use ort::value::Tensor;
20use parking_lot::Mutex;
21
22use crate::ai::provider::{
23 InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
24};
25use crate::ai::registry::Task;
26
/// Upper bound on how long `infer_batch` waits for one batch's tokenization + inference.
///
/// Exceeding it is reported as `ProviderError::Timeout` carrying this value in milliseconds.
const INFERENCE_TIMEOUT: Duration = Duration::from_millis(60_000);
32
/// A model that has been loaded into memory and is ready for inference.
///
/// Instances are shared between requests behind an `Arc` (see `LocalProvider::loaded`).
struct LoadedModel {
    /// The ONNX Runtime session. It is used through a mutable lock guard during
    /// inference, so concurrent requests against the same model are serialized here.
    session: Mutex<Session>,
    /// Tokenizer loaded from the model's `tokenizer.json`; a default batch-longest padding
    /// is installed at load time when the file configures none.
    tokenizer: tokenizers::Tokenizer,
    /// Names of the inputs the ONNX session declares, in session order. Inference can only
    /// feed `input_ids`, `attention_mask` and `token_type_ids`.
    input_names: Vec<String>,
    /// Class labels parsed from `config.json`'s `id2label`, ordered by class id; empty when
    /// no config was available or it had no usable mapping.
    labels: Vec<String>,
}
44
/// Inference provider that runs ONNX models in-process through ONNX Runtime (`ort`).
///
/// Models are addressed by a source string: `hf:<org>/<repo>` (fetched from the Hugging
/// Face Hub), `file://<dir>`, or a bare directory path. Each model is loaded lazily on first
/// use and then cached in memory; nothing in this file evicts cached models.
pub struct LocalProvider {
    /// Root directory used as the hf-hub download cache and for resolving `hf:` sources.
    cache_dir: PathBuf,
    /// Models loaded so far, keyed by their source string.
    loaded: Mutex<HashMap<String, Arc<LoadedModel>>>,
    /// Async lock that serializes model loading, so concurrent first requests for the same
    /// source load it only once. It is held across awaits, hence the tokio mutex.
    load_lock: tokio::sync::Mutex<()>,
}
53
54impl LocalProvider {
55 #[must_use]
57 pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
58 Self {
59 cache_dir: cache_dir.into(),
60 loaded: Mutex::new(HashMap::new()),
61 load_lock: tokio::sync::Mutex::new(()),
62 }
63 }
64
65 async fn ensure_model(&self, source: &str) -> Result<Arc<LoadedModel>, ProviderError> {
68 if let Some(model) = self.loaded.lock().get(source) {
69 return Ok(Arc::clone(model));
70 }
71 let _load = self.load_lock.lock().await;
74 if let Some(model) = self.loaded.lock().get(source) {
75 return Ok(Arc::clone(model));
76 }
77
78 let loaded = if let Some(repo_id) = source.strip_prefix("hf:") {
79 let api = hf_hub::api::tokio::ApiBuilder::from_env()
80 .with_cache_dir(self.cache_dir.clone())
81 .build()
82 .map_err(|e| {
83 ProviderError::Transport(format!("failed to initialize hf-hub client: {e}"))
84 })?;
85 let repo = api.model(repo_id.to_string());
86
87 let tokenizer_path = repo.get("tokenizer.json").await.map_err(|e| {
88 ProviderError::Transport(format!("failed to download tokenizer.json: {e}"))
89 })?;
90
91 let onnx_path = match repo.get("onnx/model.onnx").await {
92 Ok(path) => path,
93 Err(_) => repo.get("model.onnx").await.map_err(|e| {
94 ProviderError::Transport(format!("failed to download model.onnx: {e}"))
95 })?,
96 };
97
98 let config_path = repo.get("config.json").await.ok();
99
100 tokio::task::spawn_blocking(move || {
102 load_model_from_paths(&onnx_path, &tokenizer_path, config_path.as_deref())
103 })
104 .await
105 .map_err(|e| ProviderError::Transport(format!("model load task: {e}")))??
106 } else {
107 let dir = model_dir(&self.cache_dir, source);
108 tokio::task::spawn_blocking(move || load_model(&dir))
110 .await
111 .map_err(|e| ProviderError::Transport(format!("model load task: {e}")))??
112 };
113
114 let loaded = Arc::new(loaded);
115 self.loaded
116 .lock()
117 .insert(source.to_string(), Arc::clone(&loaded));
118 Ok(loaded)
119 }
120}
121
/// Resolves a model source string to the directory its files are expected in.
///
/// * `hf:<repo>` → `<cache_dir>/<repo>`
/// * `file://<path>` → `<path>`
/// * anything else is used verbatim as a path
///
/// NOTE(review): for `hf:` sources this is not necessarily where hf-hub stores downloads,
/// which is why `load_labels` also consults the hf-hub cache.
#[must_use]
pub fn model_dir(cache_dir: &Path, source: &str) -> PathBuf {
    match source.strip_prefix("hf:") {
        Some(repo) => cache_dir.join(repo),
        None => PathBuf::from(source.strip_prefix("file://").unwrap_or(source)),
    }
}
134
135#[must_use]
139pub fn load_labels(cache_dir: &Path, source: &str) -> Vec<String> {
140 if let Ok(text) = std::fs::read_to_string(model_dir(cache_dir, source).join("config.json")) {
143 return parse_id2label(&text);
144 }
145 if let Some(repo_id) = source.strip_prefix("hf:") {
147 let cache = hf_hub::Cache::new(cache_dir.to_path_buf());
148 let repo = cache.repo(hf_hub::Repo::model(repo_id.to_string()));
149 if let Some(path) = repo.get("config.json") {
150 if let Ok(text) = std::fs::read_to_string(path) {
151 return parse_id2label(&text);
152 }
153 }
154 }
155 Vec::new()
156}
157
158fn parse_id2label(config_json: &str) -> Vec<String> {
159 let Ok(json) = serde_json::from_str::<serde_json::Value>(config_json) else {
160 return Vec::new();
161 };
162 let Some(map) = json.get("id2label").and_then(serde_json::Value::as_object) else {
163 return Vec::new();
164 };
165 let mut indexed: Vec<(usize, String)> = map
166 .iter()
167 .filter_map(|(k, v)| Some((k.parse().ok()?, v.as_str()?.to_string())))
168 .collect();
169 indexed.sort_by_key(|(index, _)| *index);
170 indexed.into_iter().map(|(_, label)| label).collect()
171}
172
173#[async_trait]
174impl InferenceProvider for LocalProvider {
175 async fn infer_batch(
176 &self,
177 request: InferenceRequest,
178 ) -> Result<InferenceResponse, ProviderError> {
179 if matches!(
180 request.task,
181 Task::Complete | Task::Summarize | Task::Translate | Task::Gen | Task::Extract
182 ) {
183 return Err(ProviderError::UnsupportedTask(request.task));
184 }
185 let loaded = self.ensure_model(&request.model).await?;
186 let task = request.task;
187 let inputs = request.inputs;
188 let run = tokio::task::spawn_blocking(move || run(&loaded, task, &inputs));
191 let outputs = match tokio::time::timeout(INFERENCE_TIMEOUT, run).await {
192 Ok(joined) => {
193 joined.map_err(|e| ProviderError::Transport(format!("inference task: {e}")))??
194 }
195 Err(_) => {
196 return Err(ProviderError::Timeout(
197 u64::try_from(INFERENCE_TIMEOUT.as_millis()).unwrap_or(u64::MAX),
198 ))
199 }
200 };
201 Ok(InferenceResponse {
202 outputs,
203 usage: Usage::ZERO,
204 })
205 }
206
207 fn name(&self) -> &'static str {
208 "local"
209 }
210
211 fn intrinsic_labels(&self, model: &str) -> Option<Vec<String>> {
212 let labels = self
215 .loaded
216 .lock()
217 .get(model)
218 .map_or_else(|| load_labels(&self.cache_dir, model), |m| m.labels.clone());
219 (!labels.is_empty()).then_some(labels)
220 }
221}
222
/// Picks the ONNX graph inside a local model directory.
///
/// It prefers the Hub-export layout `onnx/model.onnx`. Otherwise it returns
/// `<dir>/model.onnx`, without checking that this fallback exists.
fn onnx_path(dir: &Path) -> PathBuf {
    match dir.join("onnx").join("model.onnx") {
        nested if nested.exists() => nested,
        _ => dir.join("model.onnx"),
    }
}
233
234fn load_model_from_paths(
235 onnx_path: &Path,
236 tokenizer_path: &Path,
237 config_path: Option<&Path>,
238) -> Result<LoadedModel, ProviderError> {
239 let session = Session::builder()
240 .map_err(|e| ProviderError::Transport(format!("ort init: {e}")))?
241 .commit_from_file(onnx_path)
242 .map_err(|e| ProviderError::Transport(format!("load onnx: {e}")))?;
243 let input_names = session
244 .inputs()
245 .iter()
246 .map(|i| i.name().to_string())
247 .collect();
248 let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
249 .map_err(|e| ProviderError::Transport(format!("load tokenizer: {e}")))?;
250 if tokenizer.get_padding().is_none() {
251 tokenizer.with_padding(Some(tokenizers::PaddingParams {
252 strategy: tokenizers::PaddingStrategy::BatchLongest,
253 direction: tokenizers::PaddingDirection::Right,
254 pad_to_multiple_of: None,
255 pad_id: 0,
256 pad_type_id: 0,
257 pad_token: "[PAD]".to_string(),
258 }));
259 }
260 let labels = if let Some(path) = config_path {
261 std::fs::read_to_string(path)
262 .ok()
263 .map(|text| parse_id2label(&text))
264 .unwrap_or_default()
265 } else {
266 Vec::new()
267 };
268 Ok(LoadedModel {
269 session: Mutex::new(session),
270 tokenizer,
271 input_names,
272 labels,
273 })
274}
275
276fn load_model(dir: &Path) -> Result<LoadedModel, ProviderError> {
277 let onnx = onnx_path(dir);
278 let tokenizer_path = dir.join("tokenizer.json");
279 if !onnx.exists() || !tokenizer_path.exists() {
280 return Err(ProviderError::Transport(format!(
281 "local model files not found in {} (expected onnx/model.onnx + tokenizer.json)",
282 dir.display()
283 )));
284 }
285 let config_path = dir.join("config.json");
286 let config_path_opt = config_path.exists().then_some(config_path);
287 load_model_from_paths(&onnx, &tokenizer_path, config_path_opt.as_deref())
288}
289
290fn run(
291 loaded: &LoadedModel,
292 task: Task,
293 inputs: &[String],
294) -> Result<InferenceOutputs, ProviderError> {
295 if inputs.is_empty() {
296 return Ok(InferenceOutputs::Vectors(vec![]));
297 }
298
299 let encodings = loaded
300 .tokenizer
301 .encode_batch(inputs.to_vec(), true)
302 .map_err(|e| ProviderError::BadResponse(format!("tokenize batch: {e}")))?;
303
304 let batch_size = encodings.len();
305 if batch_size == 0 {
306 return Ok(InferenceOutputs::Vectors(vec![]));
307 }
308 let seq = encodings[0].len();
309
310 let mut stacked_ids = Vec::with_capacity(batch_size * seq);
311 let mut stacked_mask = Vec::with_capacity(batch_size * seq);
312 for encoding in &encodings {
313 stacked_ids.extend(encoding.get_ids().iter().map(|&u| i64::from(u)));
314 stacked_mask.extend(encoding.get_attention_mask().iter().map(|&u| i64::from(u)));
315 }
316
317 let batch_size_i64 = i64::try_from(batch_size).unwrap_or(i64::MAX);
318 let seq_i64 = i64::try_from(seq).unwrap_or(i64::MAX);
319 let shape = vec![batch_size_i64, seq_i64];
320
321 let mut feeds: Vec<(Cow<str>, SessionInputValue)> =
322 Vec::with_capacity(loaded.input_names.len());
323 for name in &loaded.input_names {
324 let row = match name.as_str() {
325 "input_ids" => stacked_ids.clone(),
326 "attention_mask" => stacked_mask.clone(),
327 "token_type_ids" => vec![0i64; batch_size * seq],
328 other => {
329 return Err(ProviderError::BadResponse(format!(
330 "model expects unsupported input '{other}'"
331 )))
332 }
333 };
334 let tensor = Tensor::from_array((shape.clone(), row))
335 .map_err(|e| ProviderError::Transport(format!("build tensor: {e}")))?;
336 feeds.push((Cow::Owned(name.clone()), SessionInputValue::from(tensor)));
337 }
338
339 let mut session = loaded.session.lock();
340 let outputs = session
341 .run(feeds)
342 .map_err(|e| ProviderError::Transport(format!("inference: {e}")))?;
343 let (shape, data) = outputs[0]
344 .try_extract_tensor::<f32>()
345 .map_err(|e| ProviderError::BadResponse(format!("read output: {e}")))?;
346
347 if shape.first().copied().and_then(|d| usize::try_from(d).ok()) != Some(batch_size) {
348 return Err(ProviderError::BadResponse(format!(
349 "expected batch size {batch_size}, got output shape {shape:?}"
350 )));
351 }
352
353 let mut vectors = Vec::with_capacity(batch_size);
354 if task == Task::Embed && shape.len() == 3 {
355 let seq_out = usize::try_from(shape[1]).unwrap_or(0);
357 let hidden = usize::try_from(shape[2]).unwrap_or(0);
358 for (i, encoding) in encodings.iter().enumerate() {
359 let start = i * seq_out * hidden;
360 let end = (i + 1) * seq_out * hidden;
361 let slice = &data[start..end];
362 let mask: Vec<i64> = encoding
363 .get_attention_mask()
364 .iter()
365 .map(|&u| i64::from(u))
366 .collect();
367 vectors.push(mean_pool(slice, seq_out, hidden, &mask));
368 }
369 } else {
370 let dim: usize = shape[1..]
372 .iter()
373 .map(|&d| usize::try_from(d).unwrap_or(0))
374 .product();
375 for i in 0..batch_size {
376 let start = i * dim;
377 let end = (i + 1) * dim;
378 let slice = &data[start..end];
379 vectors.push(slice.to_vec());
380 }
381 }
382
383 Ok(InferenceOutputs::Vectors(vectors))
384}
385
/// Averages the token vectors of one sequence, skipping masked-out tokens.
///
/// `data` holds `seq` consecutive rows of `hidden` floats (row `t` starts at
/// `t * hidden`). Token `t` counts when `mask[t]` is non-zero; positions past the end of
/// `mask` count as masked. When no token counts, the all-zero vector is returned.
///
/// # Panics
///
/// If `data` is shorter than the rows of the tokens that count.
fn mean_pool(data: &[f32], seq: usize, hidden: usize, mask: &[i64]) -> Vec<f32> {
    let mut sums = vec![0.0_f32; hidden];
    let mut kept = 0.0_f32;
    let is_real = |t: &usize| mask.get(*t).copied().unwrap_or(0) != 0;
    for t in (0..seq).filter(is_real) {
        kept += 1.0;
        let token = &data[t * hidden..(t + 1) * hidden];
        for (acc, &value) in sums.iter_mut().zip(token) {
            *acc += value;
        }
    }
    if kept > 0.0 {
        sums.iter_mut().for_each(|acc| *acc /= kept);
    }
    sums
}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// `model_dir` maps each source scheme to a directory: `hf:` repos nest under the cache,
    /// while `file://` URLs and bare paths are used as given.
    #[test]
    fn model_dir_resolution() {
        let cache = Path::new("/models");
        assert_eq!(
            model_dir(cache, "hf:onnx-community/finbert"),
            Path::new("/models/onnx-community/finbert")
        );
        assert_eq!(model_dir(cache, "file:///abs/dir"), Path::new("/abs/dir"));
        assert_eq!(model_dir(cache, "/some/path"), Path::new("/some/path"));
    }

    /// Labels come back ordered by their numeric `id2label` key, not by JSON order; a config
    /// without `id2label` yields no labels.
    #[test]
    fn parse_id2label_orders_by_index() {
        let config = r#"{"id2label": {"2": "positive", "0": "negative", "1": "neutral"}}"#;
        assert_eq!(
            parse_id2label(config),
            vec!["negative", "neutral", "positive"]
        );
        assert!(parse_id2label("{}").is_empty());
    }

    /// For a model that was never loaded, `intrinsic_labels` falls back to the
    /// `config.json` found in the directory `model_dir` resolves for the source.
    #[test]
    fn intrinsic_labels_read_from_a_cached_models_config() {
        let cache = tempfile::tempdir().expect("tempdir");
        let provider = LocalProvider::new(cache.path());
        let source = "hf:org/repo";

        // Nothing on disk and nothing loaded: no labels.
        assert!(provider.intrinsic_labels(source).is_none());

        // Drop a config into the resolved model directory; labels now resolve from it.
        let dir = model_dir(cache.path(), source);
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(
            dir.join("config.json"),
            r#"{"id2label": {"0": "NEGATIVE", "1": "POSITIVE"}}"#,
        )
        .unwrap();
        assert_eq!(
            provider.intrinsic_labels(source),
            Some(vec!["NEGATIVE".into(), "POSITIVE".into()])
        );
    }

    /// Masked positions (mask 0) are excluded from both the sum and the divisor.
    #[test]
    fn mean_pool_ignores_masked_tokens() {
        // Three tokens of hidden size 2; the third token (100, 100) is padding.
        let data = [1.0, 2.0, 3.0, 4.0, 100.0, 100.0];
        let pooled = mean_pool(&data, 3, 2, &[1, 1, 0]);
        assert_eq!(pooled, vec![2.0, 3.0]);
    }

    /// End-to-end check: download a real sentiment classifier from the Hub, run a batch,
    /// and map the logits back to labels through the adapter.
    #[tokio::test]
    #[ignore = "downloads a model + needs ORT_DYLIB_PATH; run with --ignored"]
    async fn classifies_with_a_real_onnx_community_model() {
        use crate::ai::adapter::parse_response;
        use crate::ai::provider::InferenceParams;
        use crate::ai::registry::BackendKind;

        let cache = tempfile::tempdir().expect("tempdir");
        let provider = LocalProvider::new(cache.path());
        let source = "hf:onnx-community/distilbert-base-uncased-finetuned-sst-2-english-ONNX";
        let request = InferenceRequest {
            task: Task::Classify,
            model: source.to_string(),
            inputs: vec![
                "this film was absolutely wonderful, I loved every minute".into(),
                "a complete waste of time, dull and disappointing".into(),
            ],
            params: InferenceParams::default(),
        };

        let response = provider.infer_batch(request).await.expect("inference");
        let InferenceOutputs::Vectors(rows) = &response.outputs else {
            panic!("local classify returns logits");
        };
        // Labels are read from disk, presumably from the hf-hub cache the download above
        // populated, independently of the provider's in-memory copy.
        let labels = load_labels(cache.path(), source);
        assert!(!labels.is_empty(), "id2label should load from config.json");
        assert_eq!(rows.len(), 2);
        assert!(
            rows.iter().all(|r| r.len() == labels.len()),
            "logit dimension must equal the label count",
        );

        // The adapter's argmax over each logit row should pick the expected sentiment.
        let InferenceOutputs::Text(out) = parse_response(
            Task::Classify,
            BackendKind::Local,
            response.outputs,
            Some(&labels),
        )
        .expect("adapt") else {
            panic!("argmax yields labels");
        };
        assert_eq!(out, vec!["POSITIVE".to_string(), "NEGATIVE".to_string()]);
    }
}