// laminar_connectors/websocket/parser.rs
1//! Message parsing: WebSocket frames → Arrow `RecordBatch`.
2//!
3//! Converts incoming WebSocket text/binary messages into Arrow
4//! `RecordBatch` rows for ingestion into Ring 0.
5
6use std::sync::Arc;
7
8use arrow_array::builder::BinaryBuilder;
9use arrow_array::{Array, RecordBatch};
10use arrow_schema::{DataType, Field, Schema, SchemaRef};
11
12use crate::error::ConnectorError;
13use crate::schema::csv::{CsvDecoder, CsvDecoderConfig};
14use crate::schema::json::decoder::{JsonDecoder, JsonDecoderConfig};
15use crate::schema::traits::FormatDecoder;
16use crate::schema::types::RawRecord;
17
18use super::source_config::MessageFormat;
19
/// Parses raw WebSocket messages into Arrow `RecordBatch` data.
///
/// Construct via [`MessageParser::new`]; exactly one of the decoder
/// fields is populated, chosen by the configured [`MessageFormat`].
pub struct MessageParser {
    /// The output schema (returned by `schema()` and used for empty batches).
    schema: SchemaRef,
    /// The wire format of incoming messages; selects the parse path.
    format: MessageFormat,
    /// Type-aware JSON decoder (`Some` only for JSON/JsonLines formats).
    json_decoder: Option<JsonDecoder>,
    /// Type-aware CSV decoder (`Some` only for CSV format).
    csv_decoder: Option<CsvDecoder>,
}
31
32impl MessageParser {
33    /// Creates a new parser for the given schema, format, and JSON decoder config.
34    #[must_use]
35    pub fn new(
36        schema: SchemaRef,
37        format: MessageFormat,
38        decoder_config: JsonDecoderConfig,
39    ) -> Self {
40        let json_decoder = match &format {
41            MessageFormat::Json | MessageFormat::JsonLines => {
42                Some(JsonDecoder::with_config(schema.clone(), decoder_config))
43            }
44            _ => None,
45        };
46        let csv_decoder = match &format {
47            MessageFormat::Csv {
48                delimiter,
49                has_header,
50            } => {
51                #[allow(clippy::cast_possible_truncation)]
52                let csv_config = CsvDecoderConfig {
53                    delimiter: *delimiter as u8,
54                    has_header: *has_header,
55                    ..CsvDecoderConfig::default()
56                };
57                Some(CsvDecoder::with_config(schema.clone(), csv_config))
58            }
59            _ => None,
60        };
61        Self {
62            schema,
63            format,
64            json_decoder,
65            csv_decoder,
66        }
67    }
68
69    /// Returns the output schema.
70    #[must_use]
71    pub fn schema(&self) -> SchemaRef {
72        self.schema.clone()
73    }
74
75    /// Parses a batch of raw message payloads into a `RecordBatch`.
76    ///
77    /// # Errors
78    ///
79    /// Returns `ConnectorError::Serde` if parsing fails.
80    pub fn parse_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
81        if messages.is_empty() {
82            return Ok(RecordBatch::new_empty(self.schema.clone()));
83        }
84
85        match &self.format {
86            MessageFormat::Json | MessageFormat::JsonLines => self.parse_json_batch(messages),
87            MessageFormat::Binary => self.parse_binary_batch(messages),
88            MessageFormat::Csv { .. } => self.parse_csv_batch(messages),
89        }
90    }
91
92    /// Parses JSON messages into a `RecordBatch`.
93    ///
94    /// Uses the type-aware [`JsonDecoder`] to coerce JSON values to the
95    /// Arrow types declared in the schema.
96    fn parse_json_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
97        let decoder = self.json_decoder.as_ref().ok_or_else(|| {
98            ConnectorError::Internal("json_decoder not initialized for JSON format".into())
99        })?;
100        let records: Vec<RawRecord> = messages
101            .iter()
102            .map(|m| RawRecord::new(m.to_vec()))
103            .collect();
104        decoder.decode_batch(&records).map_err(ConnectorError::from)
105    }
106
107    /// Parses binary messages — each message becomes a single row with a
108    /// `LargeBinary` column named "payload".
109    #[allow(clippy::unused_self)]
110    fn parse_binary_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
111        let mut builder =
112            BinaryBuilder::with_capacity(messages.len(), messages.iter().map(|m| m.len()).sum());
113        for msg in messages {
114            builder.append_value(msg);
115        }
116
117        let schema = Arc::new(Schema::new(vec![Field::new(
118            "payload",
119            DataType::Binary,
120            false,
121        )]));
122        let arrays: Vec<Arc<dyn arrow_array::Array>> = vec![Arc::new(builder.finish())];
123
124        RecordBatch::try_new(schema, arrays).map_err(|e| {
125            ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
126                "failed to build binary RecordBatch: {e}"
127            )))
128        })
129    }
130
131    /// Parses CSV text messages into a `RecordBatch`.
132    ///
133    /// Delegates to [`CsvDecoder`] for schema-directed type coercion.
134    fn parse_csv_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
135        let decoder = self.csv_decoder.as_ref().ok_or_else(|| {
136            ConnectorError::Internal("csv_decoder not initialized for CSV format".into())
137        })?;
138        let records: Vec<RawRecord> = messages
139            .iter()
140            .map(|m| RawRecord::new(m.to_vec()))
141            .collect();
142        decoder.decode_batch(&records).map_err(ConnectorError::from)
143    }
144}
145
146/// Max event time (epoch ms) from a named `Timestamp(_)` column.
147/// `Ok(None)` when every row is null.
148///
149/// # Errors
150///
151/// `SchemaMismatch` if `field` is missing or isn't a `Timestamp(_)`.
152pub fn extract_max_event_time(
153    batch: &RecordBatch,
154    field: &str,
155) -> Result<Option<i64>, ConnectorError> {
156    let col_idx = batch.schema().index_of(field).map_err(|_| {
157        ConnectorError::SchemaMismatch(format!(
158            "event-time column '{field}' not found in batch schema"
159        ))
160    })?;
161    let arr = laminar_core::time::cast_to_millis_array(batch.column(col_idx).as_ref())
162        .map_err(|e| ConnectorError::SchemaMismatch(format!("event-time column '{field}': {e}")))?;
163    Ok((0..arr.len())
164        .filter(|&i| !arr.is_null(i))
165        .map(|i| arr.value(i))
166        .max())
167}
168
/// Creates a default schema for JSON messages when no explicit schema is provided.
///
/// Infers nullable field types from the top-level object of the sample
/// message. To navigate into a nested object first, use
/// [`infer_schema_from_json_with_path`].
///
/// # Errors
///
/// Returns `ConnectorError::Serde` if the sample is not valid UTF-8 or valid JSON,
/// or if the top-level value is not a JSON object.
pub fn infer_schema_from_json(sample: &[u8]) -> Result<SchemaRef, ConnectorError> {
    infer_schema_from_json_with_path(sample, None)
}
181
182/// Like [`infer_schema_from_json`] but navigates a `json.path` first.
183///
184/// # Errors
185///
186/// Returns `ConnectorError::Serde` if the sample is not valid UTF-8 or valid JSON,
187/// if a path segment is not found, or if the target is not a JSON object.
188pub fn infer_schema_from_json_with_path(
189    sample: &[u8],
190    json_path: Option<&[String]>,
191) -> Result<SchemaRef, ConnectorError> {
192    let text = std::str::from_utf8(sample).map_err(|e| {
193        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
194            "invalid UTF-8: {e}"
195        )))
196    })?;
197
198    let value: serde_json::Value = serde_json::from_str(text)
199        .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
200
201    let target = if let Some(path) = json_path {
202        let mut current = &value;
203        for segment in path {
204            current = current.get(segment.as_str()).ok_or_else(|| {
205                ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
206                    "json.path segment '{segment}' not found during inference"
207                )))
208            })?;
209        }
210        current
211    } else {
212        &value
213    };
214
215    let obj = target.as_object().ok_or_else(|| {
216        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(
217            "schema inference requires a JSON object".into(),
218        ))
219    })?;
220
221    let fields: Vec<Field> = obj
222        .iter()
223        .map(|(key, val)| {
224            let dt = match val {
225                serde_json::Value::Bool(_) => DataType::Boolean,
226                serde_json::Value::Number(n) => {
227                    if n.is_f64() {
228                        DataType::Float64
229                    } else {
230                        DataType::Int64
231                    }
232                }
233                _ => DataType::Utf8,
234            };
235            Field::new(key, dt, true)
236        })
237        .collect();
238
239    Ok(Arc::new(Schema::new(fields)))
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-column all-Utf8 schema shared by the JSON/CSV happy-path tests.
    fn json_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ]))
    }

    #[test]
    fn test_parse_json_batch() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![
            br#"{"id": "1", "value": "hello"}"#,
            br#"{"id": "2", "value": "world"}"#,
        ];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    #[test]
    fn test_parse_json_missing_field() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![br#"{"id": "1"}"#];

        // A field absent from the payload becomes a null in its column.
        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert!(batch.column(1).is_null(0));
    }

    #[test]
    fn test_parse_json_numeric_values() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        // JSON number 42 must coerce into the Utf8 "value" column without error.
        let messages: Vec<&[u8]> = vec![br#"{"id": "1", "value": 42}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[test]
    fn test_parse_binary_batch() {
        // Matches the fixed schema parse_binary_batch builds internally.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "payload",
            DataType::Binary,
            false,
        )]));
        let parser =
            MessageParser::new(schema, MessageFormat::Binary, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![b"hello", b"world"];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    #[test]
    fn test_parse_csv_batch() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Csv {
                delimiter: ',',
                has_header: false,
            },
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![b"1,hello", b"2,world"];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    #[test]
    fn test_parse_empty() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![];

        // Empty input short-circuits to an empty batch, never an error.
        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    #[test]
    fn test_parse_invalid_json() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![b"not json"];

        assert!(parser.parse_batch(&messages).is_err());
    }

    #[test]
    fn test_parse_json_typed_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let parser = MessageParser::new(schema, MessageFormat::Json, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![
            br#"{"id": 1, "price": 99.5, "name": "Widget"}"#,
            br#"{"id": 2, "price": 10.0, "name": "Gadget"}"#,
        ];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);

        // Columns should have the declared types, not Utf8.
        assert_eq!(batch.column(0).data_type(), &DataType::Int64);
        assert_eq!(batch.column(1).data_type(), &DataType::Float64);
        assert_eq!(batch.column(2).data_type(), &DataType::Utf8);

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 2);
    }

    #[test]
    fn test_parse_json_coerces_string_numbers() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "price",
            DataType::Float64,
            false,
        )]));
        let parser = MessageParser::new(schema, MessageFormat::Json, JsonDecoderConfig::default());
        // String "187.52" must be coerced to the declared Float64 type.
        let messages: Vec<&[u8]> = vec![br#"{"price": "187.52"}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.column(0).data_type(), &DataType::Float64);
        let prices = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Float64Array>()
            .unwrap();
        assert!((prices.value(0) - 187.52).abs() < f64::EPSILON);
    }

    #[test]
    fn test_infer_schema() {
        let sample = br#"{"name": "Alice", "age": 30, "active": true, "score": 99.5}"#;
        let schema = infer_schema_from_json(sample).unwrap();

        assert_eq!(schema.fields().len(), 4);
        let age_field = schema.field_with_name("age").unwrap();
        assert_eq!(age_field.data_type(), &DataType::Int64);
        let active_field = schema.field_with_name("active").unwrap();
        assert_eq!(active_field.data_type(), &DataType::Boolean);
        let score_field = schema.field_with_name("score").unwrap();
        assert_eq!(score_field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_extract_max_event_time_millis() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
            false,
        )]));
        let ts = arrow_array::TimestampMillisecondArray::from(vec![1000, 3000, 2000]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3000));
    }

    #[test]
    fn test_extract_max_event_time_nanos_rescaled() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None),
            false,
        )]));
        // Nanosecond inputs should be rescaled to epoch milliseconds.
        let ts = arrow_array::TimestampNanosecondArray::from(vec![
            1_000_000_000,
            3_000_000_000,
            2_000_000_000,
        ]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3_000));
    }

    #[test]
    fn test_extract_max_event_time_missing_column_errors() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let ids = arrow_array::Int64Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids)]).unwrap();

        assert!(extract_max_event_time(&batch, "ts").is_err());
    }

    #[test]
    fn test_extract_max_event_time_non_timestamp_column_errors() {
        // Right name, wrong type: Int64 is not a Timestamp(_) column.
        let schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, false)]));
        let ts = arrow_array::Int64Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert!(extract_max_event_time(&batch, "ts").is_err());
    }

    #[test]
    fn test_extract_max_event_time_with_nulls() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
            true,
        )]));
        // Nulls are skipped; max is taken over valid rows only.
        let ts =
            arrow_array::TimestampMillisecondArray::from(vec![Some(1000), None, Some(3000), None]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3000));
    }
}