Skip to main content

laminar_ai/backends/
openai.rs

1//! OpenAI-compatible remote provider.
2//!
3//! Covers OpenAI, Azure OpenAI, vLLM, and local OpenAI-style servers via
4//! `base_url`. Chat completions drive the discriminative and generative tasks
5//! (one bounded-concurrent call per input row); the embeddings endpoint serves
6//! `ai_embed` in a single batched call. This is the embeddings path for the
7//! whole feature — Anthropic exposes no embeddings endpoint.
8
9use std::time::Duration;
10
11use async_trait::async_trait;
12use futures::stream::{StreamExt, TryStreamExt};
13use serde::{Deserialize, Serialize};
14
15use crate::backends::remote::{add_usage, chat_prompt, post_json};
16use crate::provider::{
17    InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
18};
19use crate::registry::Task;
20
/// Request deadline in milliseconds. Used both as the `reqwest` client timeout
/// and as the per-call timeout handed to `post_json`.
const REQUEST_TIMEOUT_MS: u64 = 60 * 1_000;
/// Retry budget handed to `post_json` for every request this provider issues.
const MAX_RETRIES: u32 = 2;
23
24/// OpenAI-compatible HTTP provider.
25pub struct OpenAiProvider {
26    client: reqwest::Client,
27    base_url: String,
28    api_key: String,
29    max_concurrency: usize,
30}
31
32impl OpenAiProvider {
33    /// Build a provider for `base_url` (e.g. `https://api.openai.com/v1`),
34    /// authenticating with `api_key` and issuing at most `max_concurrency`
35    /// concurrent chat requests per batch.
36    ///
37    /// # Errors
38    ///
39    /// Returns [`ProviderError::Transport`] if the HTTP client cannot be built.
40    pub fn new(
41        base_url: impl Into<String>,
42        api_key: impl Into<String>,
43        max_concurrency: usize,
44    ) -> Result<Self, ProviderError> {
45        let client = reqwest::Client::builder()
46            .timeout(Duration::from_millis(REQUEST_TIMEOUT_MS))
47            .build()
48            .map_err(|e| ProviderError::Transport(e.to_string()))?;
49        Ok(Self {
50            client,
51            base_url: base_url.into().trim_end_matches('/').to_string(),
52            api_key: api_key.into(),
53            max_concurrency: max_concurrency.max(1),
54        })
55    }
56
57    async fn chat(&self, request: &InferenceRequest) -> Result<InferenceResponse, ProviderError> {
58        let url = format!("{}/chat/completions", self.base_url);
59
60        // Build owned request bodies first, then map each to a future. Mapping
61        // over owned `ChatBody` (not `&input`) keeps the future-producing
62        // closure free of an input lifetime to generalize over.
63        let bodies: Vec<ChatBody> = request
64            .inputs
65            .iter()
66            .map(|input| {
67                let (system, user) =
68                    chat_prompt(request.task, input, request.params.labels.as_deref());
69                ChatBody {
70                    model: request.model.clone(),
71                    messages: vec![ChatMessage::system(system), ChatMessage::user(user)],
72                }
73            })
74            .collect();
75
76        let url = &url;
77        let calls = bodies.into_iter().map(|body| async move {
78            let builder = self.client.post(url).bearer_auth(&self.api_key).json(&body);
79            let response: ChatResponse =
80                post_json(builder, MAX_RETRIES, REQUEST_TIMEOUT_MS).await?;
81            parse_chat(response)
82        });
83
84        let results: Vec<(String, Usage)> = futures::stream::iter(calls)
85            .buffered(self.max_concurrency)
86            .try_collect()
87            .await?;
88
89        let mut texts = Vec::with_capacity(results.len());
90        let mut usage = Usage::ZERO;
91        for (text, call_usage) in results {
92            texts.push(text);
93            usage = add_usage(usage, call_usage);
94        }
95        Ok(InferenceResponse {
96            outputs: InferenceOutputs::Text(texts),
97            usage,
98        })
99    }
100
101    async fn embed(&self, request: &InferenceRequest) -> Result<InferenceResponse, ProviderError> {
102        let url = format!("{}/embeddings", self.base_url);
103        let body = EmbedBody {
104            model: request.model.clone(),
105            input: request.inputs.clone(),
106        };
107        let builder = self
108            .client
109            .post(&url)
110            .bearer_auth(&self.api_key)
111            .json(&body);
112        let response: EmbedResponse = post_json(builder, MAX_RETRIES, REQUEST_TIMEOUT_MS).await?;
113        let (vectors, usage) = parse_embed(response);
114        Ok(InferenceResponse {
115            outputs: InferenceOutputs::Vectors(vectors),
116            usage,
117        })
118    }
119}
120
121#[async_trait]
122impl InferenceProvider for OpenAiProvider {
123    async fn infer_batch(
124        &self,
125        request: InferenceRequest,
126    ) -> Result<InferenceResponse, ProviderError> {
127        match request.task {
128            Task::Embed => self.embed(&request).await,
129            _ => self.chat(&request).await,
130        }
131    }
132
133    fn name(&self) -> &'static str {
134        "openai"
135    }
136}
137
138// --- wire shapes ---
139
140#[derive(Serialize)]
141struct ChatBody {
142    model: String,
143    messages: Vec<ChatMessage>,
144}
145
146#[derive(Serialize)]
147struct ChatMessage {
148    role: &'static str,
149    content: String,
150}
151
152impl ChatMessage {
153    fn system(content: String) -> Self {
154        Self {
155            role: "system",
156            content,
157        }
158    }
159    fn user(content: String) -> Self {
160        Self {
161            role: "user",
162            content,
163        }
164    }
165}
166
167#[derive(Deserialize)]
168struct ChatResponse {
169    choices: Vec<ChatChoice>,
170    usage: Option<TokenUsage>,
171}
172
173#[derive(Deserialize)]
174struct ChatChoice {
175    message: ChatChoiceMessage,
176}
177
178#[derive(Deserialize)]
179struct ChatChoiceMessage {
180    content: String,
181}
182
183#[derive(Serialize)]
184struct EmbedBody {
185    model: String,
186    input: Vec<String>,
187}
188
189#[derive(Deserialize)]
190struct EmbedResponse {
191    data: Vec<EmbedData>,
192    usage: Option<TokenUsage>,
193}
194
195#[derive(Deserialize)]
196struct EmbedData {
197    embedding: Vec<f32>,
198    index: usize,
199}
200
201#[derive(Deserialize)]
202struct TokenUsage {
203    #[serde(default)]
204    prompt_tokens: u64,
205    #[serde(default)]
206    completion_tokens: u64,
207}
208
209/// Pull the first choice's text and token usage from a chat response. Cost is
210/// left at zero — converting tokens to dollars needs a per-model price table,
211/// which is configured separately.
212fn parse_chat(response: ChatResponse) -> Result<(String, Usage), ProviderError> {
213    let content = response
214        .choices
215        .into_iter()
216        .next()
217        .map(|c| c.message.content)
218        .ok_or_else(|| ProviderError::BadResponse("chat response had no choices".to_string()))?;
219    Ok((content, token_usage(response.usage)))
220}
221
222/// Collect embeddings in input order (sorted by `index`) plus usage.
223fn parse_embed(mut response: EmbedResponse) -> (Vec<Vec<f32>>, Usage) {
224    response.data.sort_by_key(|d| d.index);
225    let usage = token_usage(response.usage);
226    let vectors = response.data.into_iter().map(|d| d.embedding).collect();
227    (vectors, usage)
228}
229
230fn token_usage(usage: Option<TokenUsage>) -> Usage {
231    usage.map_or(Usage::ZERO, |u| Usage {
232        input_tokens: u.prompt_tokens,
233        output_tokens: u.completion_tokens,
234        cost_micros: 0,
235    })
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chat_body_serializes_to_messages() {
        let body = ChatBody {
            model: "gpt-x".to_string(),
            messages: vec![
                ChatMessage::system("be terse".to_string()),
                ChatMessage::user("hello".to_string()),
            ],
        };
        let value = serde_json::to_value(&body).unwrap();
        assert_eq!(value["model"], "gpt-x");
        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][1]["content"], "hello");
    }

    #[test]
    fn parse_chat_extracts_content_and_tokens() {
        let json = r#"{
            "choices": [{"message": {"role": "assistant", "content": "positive"}}],
            "usage": {"prompt_tokens": 12, "completion_tokens": 1}
        }"#;
        let response: ChatResponse = serde_json::from_str(json).unwrap();
        let (text, usage) = parse_chat(response).unwrap();
        assert_eq!(text, "positive");
        assert_eq!(usage.input_tokens, 12);
        assert_eq!(usage.output_tokens, 1);
    }

    #[test]
    fn parse_chat_takes_first_choice() {
        let json = r#"{
            "choices": [
                {"message": {"content": "first"}},
                {"message": {"content": "second"}}
            ]
        }"#;
        let response: ChatResponse = serde_json::from_str(json).unwrap();
        let (text, _) = parse_chat(response).unwrap();
        assert_eq!(text, "first");
    }

    #[test]
    fn parse_chat_tolerates_missing_usage() {
        let json = r#"{"choices": [{"message": {"content": "neutral"}}]}"#;
        let response: ChatResponse = serde_json::from_str(json).unwrap();
        let (text, usage) = parse_chat(response).unwrap();
        assert_eq!(text, "neutral");
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.cost_micros, 0);
    }

    #[test]
    fn parse_chat_errors_without_choices() {
        let response: ChatResponse = serde_json::from_str(r#"{"choices": []}"#).unwrap();
        assert!(parse_chat(response).is_err());
    }

    #[test]
    fn parse_embed_orders_by_index() {
        let json = r#"{
            "data": [
                {"embedding": [0.3, 0.4], "index": 1},
                {"embedding": [0.1, 0.2], "index": 0}
            ],
            "usage": {"prompt_tokens": 5}
        }"#;
        let response: EmbedResponse = serde_json::from_str(json).unwrap();
        let (vectors, usage) = parse_embed(response);
        assert_eq!(vectors, vec![vec![0.1, 0.2], vec![0.3, 0.4]]);
        assert_eq!(usage.input_tokens, 5);
    }

    #[test]
    fn parse_embed_handles_empty_data_without_usage() {
        let response: EmbedResponse = serde_json::from_str(r#"{"data": []}"#).unwrap();
        let (vectors, usage) = parse_embed(response);
        assert!(vectors.is_empty());
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn token_usage_maps_wire_counts() {
        let usage = token_usage(Some(TokenUsage {
            prompt_tokens: 7,
            completion_tokens: 2,
        }));
        assert_eq!(usage.input_tokens, 7);
        assert_eq!(usage.output_tokens, 2);
        assert_eq!(usage.cost_micros, 0);
    }

    #[test]
    fn token_usage_is_zero_when_absent() {
        let usage = token_usage(None);
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.cost_micros, 0);
    }
}