Skip to main content

laminar_ai/
provider.rs

1//! The single transport abstraction over a model backend.
2//!
3//! An [`InferenceProvider`] does I/O only: a homogeneous batch of inputs in, a
4//! homogeneous batch of outputs plus usage out. It knows nothing about SQL tasks
5//! or output columns — framing a request and turning the response into a task's
6//! output column is the adapter's job. Implementors: Anthropic, an
7//! OpenAI-compatible provider (OpenAI / Azure / vLLM via `base_url`), and a
8//! local ONNX Runtime provider.
9
10use async_trait::async_trait;
11use thiserror::Error;
12
13use crate::registry::Task;
14
15/// One batch of inputs to run through a model. A request is homogeneous: a
16/// single task, a single model, and one input string per row in order.
17#[derive(Debug, Clone, PartialEq)]
18pub struct InferenceRequest {
19    /// The task being performed.
20    pub task: Task,
21    /// Runtime model identifier — the vendor model id for remote backends, or
22    /// the weight source for local backends.
23    pub model: String,
24    /// One input string per row, in row order.
25    pub inputs: Vec<String>,
26    /// Task-shaping parameters that also version the result cache.
27    pub params: InferenceParams,
28}
29
/// Request-shaping knobs. They feed the cache's `params_version`, which keeps
/// identical text run under different parameters from colliding. Generation
/// knobs (`max_tokens`, `temperature`, …) get added here as backends start
/// consuming them.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct InferenceParams {
    /// Candidate labels for classification. Remote classify requires them; a
    /// local classifier needs them to match, or be a subset of, the model's
    /// intrinsic labels, which is validated at plan time.
    pub labels: Option<Vec<String>>,
}
40
/// The rows a batch produced, in one shape per request. Classify and generate
/// batches come back as text. Embed batches come back as numeric vectors, as
/// do a local classifier's raw logits that the adapter has yet to softmax.
/// Sentiment batches carry a single scalar score per row; that shape is the
/// adapter's output and never a raw provider shape.
#[derive(Debug, Clone, PartialEq)]
pub enum InferenceOutputs {
    /// A text output for each input row.
    Text(Vec<String>),
    /// A numeric vector for each input row: embeddings or classifier logits.
    Vectors(Vec<Vec<f32>>),
    /// A scalar score for each input row. The adapter produces this for
    /// `ai_sentiment` (continuous, in `[-1, 1]`); providers never return it
    /// directly.
    Scores(Vec<f64>),
}
57
58impl InferenceOutputs {
59    /// Number of rows produced. Must equal the request's input count.
60    #[must_use]
61    pub fn len(&self) -> usize {
62        match self {
63            InferenceOutputs::Text(v) => v.len(),
64            InferenceOutputs::Vectors(v) => v.len(),
65            InferenceOutputs::Scores(v) => v.len(),
66        }
67    }
68
69    /// Whether no rows were produced.
70    #[must_use]
71    pub fn is_empty(&self) -> bool {
72        self.len() == 0
73    }
74}
75
/// What one batch call cost, in tokens and money. Remote backends report what
/// the provider charged; local backends report [`Usage::ZERO`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Usage {
    /// Billed prompt (input) tokens.
    pub input_tokens: u64,
    /// Billed completion (output) tokens.
    pub output_tokens: u64,
    /// Metered cost in micro-USD, i.e. millionths of a dollar. Local
    /// backends report 0.
    pub cost_micros: u64,
}
87
88impl Usage {
89    /// Zero usage — what local, deterministic backends report.
90    pub const ZERO: Usage = Usage {
91        input_tokens: 0,
92        output_tokens: 0,
93        cost_micros: 0,
94    };
95}
96
97/// The result of a batch inference call: outputs aligned 1:1 with the request's
98/// inputs, plus usage.
99#[derive(Debug, Clone, PartialEq)]
100pub struct InferenceResponse {
101    /// Per-row outputs, in input order.
102    pub outputs: InferenceOutputs,
103    /// Token/cost accounting for the call.
104    pub usage: Usage,
105}
106
107/// Transport over a model backend. Implementors perform I/O only — no task
108/// framing, no result parsing. Shared as `Arc<dyn InferenceProvider>` and
109/// driven from the Ring 1 inference worker, never from Ring 0.
110#[async_trait]
111pub trait InferenceProvider: Send + Sync {
112    /// Run one batch of inputs through the model.
113    ///
114    /// # Errors
115    ///
116    /// Returns [`ProviderError`] on transport failure, timeout, rate limiting,
117    /// a malformed response, or an unsupported task.
118    async fn infer_batch(
119        &self,
120        request: InferenceRequest,
121    ) -> Result<InferenceResponse, ProviderError>;
122
123    /// Stable backend-kind identity for logging and the `laminar.ai_calls`
124    /// log (e.g. `anthropic`, `openai`, `local`). Constant per implementor.
125    fn name(&self) -> &'static str;
126
127    /// Classifier labels intrinsic to a model, discovered from its own metadata.
128    /// Returns `None` for backends that have none (remote providers, embedding
129    /// models). A local classifier returns its `config.json` `id2label` once the
130    /// model is on disk — the seam that lets a lazily downloaded classifier score
131    /// without the labels having been known at startup. The default is `None`.
132    fn intrinsic_labels(&self, _model: &str) -> Option<Vec<String>> {
133        None
134    }
135}
136
137/// Errors a provider can return for a batch call.
138#[derive(Debug, Error)]
139pub enum ProviderError {
140    /// Network or connection failure.
141    #[error("transport error: {0}")]
142    Transport(String),
143
144    /// The call exceeded its deadline.
145    #[error("request timed out after {0} ms")]
146    Timeout(u64),
147
148    /// The provider signalled rate limiting.
149    #[error("rate limited by provider")]
150    RateLimited,
151
152    /// The response could not be parsed into the expected shape.
153    #[error("malformed response: {0}")]
154    BadResponse(String),
155
156    /// The provider cannot perform the requested task.
157    #[error("provider does not support task '{0}'")]
158    UnsupportedTask(Task),
159}