use std::sync::Arc;

use arrow_array::builder::BinaryBuilder;
use arrow_array::{Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SchemaRef};

use crate::error::ConnectorError;
use crate::schema::csv::{CsvDecoder, CsvDecoderConfig};
use crate::schema::json::decoder::{JsonDecoder, JsonDecoderConfig};
use crate::schema::traits::FormatDecoder;
use crate::schema::types::RawRecord;

use super::source_config::MessageFormat;

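/// Parses raw message payloads into Arrow [`RecordBatch`]es according to a
/// configured [`MessageFormat`], reusing pre-built JSON/CSV decoders.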
pub struct MessageParser {
    /// Target Arrow schema for decoded batches.
    schema: SchemaRef,
    /// Wire format of incoming message payloads.
    format: MessageFormat,
    /// Present only when `format` is JSON or JSON Lines.
    json_decoder: Option<JsonDecoder>,
    /// Present only when `format` is CSV.
    csv_decoder: Option<CsvDecoder>,
}

impl MessageParser {
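    /// Creates a parser for `format`, pre-building the matching JSON or CSV
    /// decoder so per-batch parsing does no extra setup.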
    #[must_use]
    pub fn new(
        schema: SchemaRef,
        format: MessageFormat,
        decoder_config: JsonDecoderConfig,
    ) -> Self {
        // Only the decoder matching the configured format is constructed.
        let json_decoder = match &format {
            MessageFormat::Json | MessageFormat::JsonLines => {
                Some(JsonDecoder::with_config(schema.clone(), decoder_config))
            }
            _ => None,
        };
        let csv_decoder = match &format {
            MessageFormat::Csv {
                delimiter,
                has_header,
            } => {
                // `delimiter` is expected to be a single ASCII character, so
                // the truncating cast to `u8` is deliberate.
                #[allow(clippy::cast_possible_truncation)]
                let csv_config = CsvDecoderConfig {
                    delimiter: *delimiter as u8,
                    has_header: *has_header,
                    ..CsvDecoderConfig::default()
                };
                Some(CsvDecoder::with_config(schema.clone(), csv_config))
            }
            _ => None,
        };
        Self {
            schema,
            format,
            json_decoder,
            csv_decoder,
        }
    }

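    /// Returns the Arrow schema produced by [`Self::parse_batch`].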
    #[must_use]
    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

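    /// Parses a batch of raw message payloads into a single [`RecordBatch`].
    ///
    /// An empty input yields an empty batch with the configured schema.
    ///
    /// # Errors
    ///
    /// Returns an error if any payload cannot be decoded in the configured
    /// format or does not match the target schema.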
    pub fn parse_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
        if messages.is_empty() {
            return Ok(RecordBatch::new_empty(self.schema.clone()));
        }

        match &self.format {
            MessageFormat::Json | MessageFormat::JsonLines => self.parse_json_batch(messages),
            MessageFormat::Binary => self.parse_binary_batch(messages),
            MessageFormat::Csv { .. } => self.parse_csv_batch(messages),
        }
    }

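    /// Decodes JSON / JSON Lines payloads with the pre-built [`JsonDecoder`].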
    fn parse_json_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
        let decoder = self.json_decoder.as_ref().ok_or_else(|| {
            ConnectorError::Internal("json_decoder not initialized for JSON format".into())
        })?;
        let records: Vec<RawRecord> = messages
            .iter()
            .map(|m| RawRecord::new(m.to_vec()))
            .collect();
        decoder.decode_batch(&records).map_err(ConnectorError::from)
    }

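    /// Passes payloads through untouched as a single non-null `payload`
    /// Binary column; the parser's configured schema is intentionally ignored.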
    #[allow(clippy::unused_self)]
    fn parse_binary_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
        let mut builder =
            BinaryBuilder::with_capacity(messages.len(), messages.iter().map(|m| m.len()).sum());
        for msg in messages {
            builder.append_value(msg);
        }

        let schema = Arc::new(Schema::new(vec![Field::new(
            "payload",
            DataType::Binary,
            false,
        )]));
        let arrays: Vec<Arc<dyn arrow_array::Array>> = vec![Arc::new(builder.finish())];

        RecordBatch::try_new(schema, arrays).map_err(|e| {
            ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
                "failed to build binary RecordBatch: {e}"
            )))
        })
    }

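    /// Decodes CSV payloads with the pre-built [`CsvDecoder`].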
    fn parse_csv_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
        let decoder = self.csv_decoder.as_ref().ok_or_else(|| {
            ConnectorError::Internal("csv_decoder not initialized for CSV format".into())
        })?;
        let records: Vec<RawRecord> = messages
            .iter()
            .map(|m| RawRecord::new(m.to_vec()))
            .collect();
        decoder.decode_batch(&records).map_err(ConnectorError::from)
    }
}

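/// Returns the maximum event time in `field`, normalized to epoch milliseconds.
///
/// Null rows are skipped; `Ok(None)` means the column held no non-null values.
///
/// # Errors
///
/// Returns an error if `field` is missing from the batch schema or cannot be
/// interpreted as a timestamp-like column.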
pub fn extract_max_event_time(
    batch: &RecordBatch,
    field: &str,
) -> Result<Option<i64>, ConnectorError> {
    let col_idx = batch.schema().index_of(field).map_err(|_| {
        ConnectorError::SchemaMismatch(format!(
            "event-time column '{field}' not found in batch schema"
        ))
    })?;
    let arr = laminar_core::time::cast_to_millis_array(batch.column(col_idx).as_ref())
        .map_err(|e| ConnectorError::SchemaMismatch(format!("event-time column '{field}': {e}")))?;
    Ok((0..arr.len())
        .filter(|&i| !arr.is_null(i))
        .map(|i| arr.value(i))
        .max())
}

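/// Infers an Arrow schema from a single JSON object sample.
///
/// # Errors
///
/// Returns an error if the sample is not valid UTF-8, not valid JSON, or not
/// a JSON object.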
pub fn infer_schema_from_json(sample: &[u8]) -> Result<SchemaRef, ConnectorError> {
    infer_schema_from_json_with_path(sample, None)
}

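/// Infers an Arrow schema from a JSON sample, optionally descending through
/// the object fields named by `json_path` first.
///
/// Booleans map to `Boolean`, integers to `Int64`, other numbers to
/// `Float64`, and all remaining values to `Utf8`; every field is nullable.
///
/// # Errors
///
/// Returns an error if the sample is not valid UTF-8 or JSON, if a
/// `json_path` segment is missing, or if the target value is not an object.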
pub fn infer_schema_from_json_with_path(
    sample: &[u8],
    json_path: Option<&[String]>,
) -> Result<SchemaRef, ConnectorError> {
    let text = std::str::from_utf8(sample).map_err(|e| {
        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
            "invalid UTF-8: {e}"
        )))
    })?;

    let value: serde_json::Value = serde_json::from_str(text)
        .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;

    // Walk the optional json.path down to the nested object used for inference.
    let target = if let Some(path) = json_path {
        let mut current = &value;
        for segment in path {
            current = current.get(segment.as_str()).ok_or_else(|| {
                ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
                    "json.path segment '{segment}' not found during inference"
                )))
            })?;
        }
        current
    } else {
        &value
    };

    let obj = target.as_object().ok_or_else(|| {
        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(
            "schema inference requires a JSON object".into(),
        ))
    })?;

    // Map each top-level value to the closest Arrow type; anything that is not
    // a boolean or a number falls back to Utf8.
    let fields: Vec<Field> = obj
        .iter()
        .map(|(key, val)| {
            let dt = match val {
                serde_json::Value::Bool(_) => DataType::Boolean,
                serde_json::Value::Number(n) => {
                    if n.is_f64() {
                        DataType::Float64
                    } else {
                        DataType::Int64
                    }
                }
                _ => DataType::Utf8,
            };
            Field::new(key, dt, true)
        })
        .collect();

    Ok(Arc::new(Schema::new(fields)))
}

#[cfg(test)]
mod tests {
    use super::*;

    fn json_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ]))
    }

    #[test]
    fn test_parse_json_batch() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![
            br#"{"id": "1", "value": "hello"}"#,
            br#"{"id": "2", "value": "world"}"#,
        ];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    #[test]
    fn test_parse_json_missing_field() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![br#"{"id": "1"}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert!(batch.column(1).is_null(0));
    }

    #[test]
    fn test_parse_json_numeric_values() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![br#"{"id": "1", "value": 42}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[test]
    fn test_parse_binary_batch() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "payload",
            DataType::Binary,
            false,
        )]));
        let parser =
            MessageParser::new(schema, MessageFormat::Binary, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![b"hello", b"world"];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    #[test]
    fn test_parse_csv_batch() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Csv {
                delimiter: ',',
                has_header: false,
            },
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![b"1,hello", b"2,world"];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    #[test]
    fn test_parse_empty() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    #[test]
    fn test_parse_invalid_json() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![b"not json"];

        assert!(parser.parse_batch(&messages).is_err());
    }

    #[test]
    fn test_parse_json_typed_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let parser = MessageParser::new(schema, MessageFormat::Json, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![
            br#"{"id": 1, "price": 99.5, "name": "Widget"}"#,
            br#"{"id": 2, "price": 10.0, "name": "Gadget"}"#,
        ];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);

        assert_eq!(batch.column(0).data_type(), &DataType::Int64);
        assert_eq!(batch.column(1).data_type(), &DataType::Float64);
        assert_eq!(batch.column(2).data_type(), &DataType::Utf8);

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 2);
    }

    #[test]
    fn test_parse_json_coerces_string_numbers() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "price",
            DataType::Float64,
            false,
        )]));
        let parser = MessageParser::new(schema, MessageFormat::Json, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![br#"{"price": "187.52"}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.column(0).data_type(), &DataType::Float64);
        let prices = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Float64Array>()
            .unwrap();
        assert!((prices.value(0) - 187.52).abs() < f64::EPSILON);
    }

    #[test]
    fn test_infer_schema() {
        let sample = br#"{"name": "Alice", "age": 30, "active": true, "score": 99.5}"#;
        let schema = infer_schema_from_json(sample).unwrap();

        assert_eq!(schema.fields().len(), 4);
        let age_field = schema.field_with_name("age").unwrap();
        assert_eq!(age_field.data_type(), &DataType::Int64);
        let active_field = schema.field_with_name("active").unwrap();
        assert_eq!(active_field.data_type(), &DataType::Boolean);
        let score_field = schema.field_with_name("score").unwrap();
        assert_eq!(score_field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_extract_max_event_time_millis() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
            false,
        )]));
        let ts = arrow_array::TimestampMillisecondArray::from(vec![1000, 3000, 2000]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3000));
    }

    #[test]
    fn test_extract_max_event_time_nanos_rescaled() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None),
            false,
        )]));
        let ts = arrow_array::TimestampNanosecondArray::from(vec![
            1_000_000_000,
            3_000_000_000,
            2_000_000_000,
        ]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3_000));
    }

    #[test]
    fn test_extract_max_event_time_missing_column_errors() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let ids = arrow_array::Int64Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids)]).unwrap();

        assert!(extract_max_event_time(&batch, "ts").is_err());
    }

    #[test]
    fn test_extract_max_event_time_non_timestamp_column_errors() {
        let schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, false)]));
        let ts = arrow_array::Int64Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert!(extract_max_event_time(&batch, "ts").is_err());
    }

    #[test]
    fn test_extract_max_event_time_with_nulls() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
            true,
        )]));
        let ts =
            arrow_array::TimestampMillisecondArray::from(vec![Some(1000), None, Some(3000), None]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3000));
    }
}