1use std::time::Duration;
10
11use async_trait::async_trait;
12use futures::stream::{StreamExt, TryStreamExt};
13use serde::{Deserialize, Serialize};
14
15use crate::backends::remote::{add_usage, chat_prompt, post_json};
16use crate::provider::{
17 InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
18};
19use crate::registry::Task;
20
/// Timeout in milliseconds. It is applied as the HTTP client's request timeout and is also
/// passed to `post_json` on every chat and embedding call.
const REQUEST_TIMEOUT_MS: u64 = 60 * 1_000;

/// Retry limit passed to `post_json` for each chat or embedding request.
const MAX_RETRIES: u32 = 2;
23
24pub struct OpenAiProvider {
26 client: reqwest::Client,
27 base_url: String,
28 api_key: String,
29 max_concurrency: usize,
30}
31
32impl OpenAiProvider {
33 pub fn new(
41 base_url: impl Into<String>,
42 api_key: impl Into<String>,
43 max_concurrency: usize,
44 ) -> Result<Self, ProviderError> {
45 let client = reqwest::Client::builder()
46 .timeout(Duration::from_millis(REQUEST_TIMEOUT_MS))
47 .build()
48 .map_err(|e| ProviderError::Transport(e.to_string()))?;
49 Ok(Self {
50 client,
51 base_url: base_url.into().trim_end_matches('/').to_string(),
52 api_key: api_key.into(),
53 max_concurrency: max_concurrency.max(1),
54 })
55 }
56
57 async fn chat(&self, request: &InferenceRequest) -> Result<InferenceResponse, ProviderError> {
58 let url = format!("{}/chat/completions", self.base_url);
59
60 let bodies: Vec<ChatBody> = request
64 .inputs
65 .iter()
66 .map(|input| {
67 let (system, user) =
68 chat_prompt(request.task, input, request.params.labels.as_deref());
69 ChatBody {
70 model: request.model.clone(),
71 messages: vec![ChatMessage::system(system), ChatMessage::user(user)],
72 }
73 })
74 .collect();
75
76 let url = &url;
77 let calls = bodies.into_iter().map(|body| async move {
78 let builder = self.client.post(url).bearer_auth(&self.api_key).json(&body);
79 let response: ChatResponse =
80 post_json(builder, MAX_RETRIES, REQUEST_TIMEOUT_MS).await?;
81 parse_chat(response)
82 });
83
84 let results: Vec<(String, Usage)> = futures::stream::iter(calls)
85 .buffered(self.max_concurrency)
86 .try_collect()
87 .await?;
88
89 let mut texts = Vec::with_capacity(results.len());
90 let mut usage = Usage::ZERO;
91 for (text, call_usage) in results {
92 texts.push(text);
93 usage = add_usage(usage, call_usage);
94 }
95 Ok(InferenceResponse {
96 outputs: InferenceOutputs::Text(texts),
97 usage,
98 })
99 }
100
101 async fn embed(&self, request: &InferenceRequest) -> Result<InferenceResponse, ProviderError> {
102 let url = format!("{}/embeddings", self.base_url);
103 let body = EmbedBody {
104 model: request.model.clone(),
105 input: request.inputs.clone(),
106 };
107 let builder = self
108 .client
109 .post(&url)
110 .bearer_auth(&self.api_key)
111 .json(&body);
112 let response: EmbedResponse = post_json(builder, MAX_RETRIES, REQUEST_TIMEOUT_MS).await?;
113 let (vectors, usage) = parse_embed(response);
114 Ok(InferenceResponse {
115 outputs: InferenceOutputs::Vectors(vectors),
116 usage,
117 })
118 }
119}
120
121#[async_trait]
122impl InferenceProvider for OpenAiProvider {
123 async fn infer_batch(
124 &self,
125 request: InferenceRequest,
126 ) -> Result<InferenceResponse, ProviderError> {
127 match request.task {
128 Task::Embed => self.embed(&request).await,
129 _ => self.chat(&request).await,
130 }
131 }
132
133 fn name(&self) -> &'static str {
134 "openai"
135 }
136}
137
138#[derive(Serialize)]
141struct ChatBody {
142 model: String,
143 messages: Vec<ChatMessage>,
144}
145
146#[derive(Serialize)]
147struct ChatMessage {
148 role: &'static str,
149 content: String,
150}
151
152impl ChatMessage {
153 fn system(content: String) -> Self {
154 Self {
155 role: "system",
156 content,
157 }
158 }
159 fn user(content: String) -> Self {
160 Self {
161 role: "user",
162 content,
163 }
164 }
165}
166
167#[derive(Deserialize)]
168struct ChatResponse {
169 choices: Vec<ChatChoice>,
170 usage: Option<TokenUsage>,
171}
172
173#[derive(Deserialize)]
174struct ChatChoice {
175 message: ChatChoiceMessage,
176}
177
178#[derive(Deserialize)]
179struct ChatChoiceMessage {
180 content: String,
181}
182
183#[derive(Serialize)]
184struct EmbedBody {
185 model: String,
186 input: Vec<String>,
187}
188
189#[derive(Deserialize)]
190struct EmbedResponse {
191 data: Vec<EmbedData>,
192 usage: Option<TokenUsage>,
193}
194
195#[derive(Deserialize)]
196struct EmbedData {
197 embedding: Vec<f32>,
198 index: usize,
199}
200
201#[derive(Deserialize)]
202struct TokenUsage {
203 #[serde(default)]
204 prompt_tokens: u64,
205 #[serde(default)]
206 completion_tokens: u64,
207}
208
209fn parse_chat(response: ChatResponse) -> Result<(String, Usage), ProviderError> {
213 let content = response
214 .choices
215 .into_iter()
216 .next()
217 .map(|c| c.message.content)
218 .ok_or_else(|| ProviderError::BadResponse("chat response had no choices".to_string()))?;
219 Ok((content, token_usage(response.usage)))
220}
221
222fn parse_embed(mut response: EmbedResponse) -> (Vec<Vec<f32>>, Usage) {
224 response.data.sort_by_key(|d| d.index);
225 let usage = token_usage(response.usage);
226 let vectors = response.data.into_iter().map(|d| d.embedding).collect();
227 (vectors, usage)
228}
229
230fn token_usage(usage: Option<TokenUsage>) -> Usage {
231 usage.map_or(Usage::ZERO, |u| Usage {
232 input_tokens: u.prompt_tokens,
233 output_tokens: u.completion_tokens,
234 cost_micros: 0,
235 })
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chat_body_serializes_to_messages() {
        let body = ChatBody {
            model: "gpt-x".to_string(),
            messages: vec![
                ChatMessage::system("be terse".to_string()),
                ChatMessage::user("hello".to_string()),
            ],
        };
        let value = serde_json::to_value(&body).unwrap();
        assert_eq!(value["model"], "gpt-x");
        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][0]["content"], "be terse");
        assert_eq!(value["messages"][1]["role"], "user");
        assert_eq!(value["messages"][1]["content"], "hello");
    }

    #[test]
    fn parse_chat_extracts_content_and_tokens() {
        let json = r#"{
            "choices": [{"message": {"role": "assistant", "content": "positive"}}],
            "usage": {"prompt_tokens": 12, "completion_tokens": 1}
        }"#;
        let response: ChatResponse = serde_json::from_str(json).unwrap();
        let (text, usage) = parse_chat(response).unwrap();
        assert_eq!(text, "positive");
        assert_eq!(usage.input_tokens, 12);
        assert_eq!(usage.output_tokens, 1);
        assert_eq!(usage.cost_micros, 0);
    }

    #[test]
    fn parse_chat_takes_first_choice() {
        let json = r#"{
            "choices": [
                {"message": {"content": "first"}},
                {"message": {"content": "second"}}
            ]
        }"#;
        let response: ChatResponse = serde_json::from_str(json).unwrap();
        let (text, _) = parse_chat(response).unwrap();
        assert_eq!(text, "first");
    }

    #[test]
    fn parse_chat_without_usage_reports_no_tokens() {
        let response: ChatResponse =
            serde_json::from_str(r#"{"choices": [{"message": {"content": "ok"}}]}"#).unwrap();
        let (text, usage) = parse_chat(response).unwrap();
        assert_eq!(text, "ok");
        // Assumes `Usage::ZERO` has zero token counts, as its name suggests.
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn parse_chat_errors_without_choices() {
        let response: ChatResponse = serde_json::from_str(r#"{"choices": []}"#).unwrap();
        assert!(matches!(
            parse_chat(response),
            Err(ProviderError::BadResponse(_))
        ));
    }

    #[test]
    fn parse_embed_orders_by_index() {
        let json = r#"{
            "data": [
                {"embedding": [0.3, 0.4], "index": 1},
                {"embedding": [0.1, 0.2], "index": 0}
            ],
            "usage": {"prompt_tokens": 5}
        }"#;
        let response: EmbedResponse = serde_json::from_str(json).unwrap();
        let (vectors, usage) = parse_embed(response);
        assert_eq!(vectors, vec![vec![0.1, 0.2], vec![0.3, 0.4]]);
        assert_eq!(usage.input_tokens, 5);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn parse_embed_handles_empty_data_without_usage() {
        let response: EmbedResponse = serde_json::from_str(r#"{"data": []}"#).unwrap();
        let (vectors, usage) = parse_embed(response);
        assert!(vectors.is_empty());
        // Assumes `Usage::ZERO` has zero token counts, as its name suggests.
        assert_eq!(usage.input_tokens, 0);
    }
}