Skip to main content

laminar_connectors/websocket/
parser.rs

//! Message parsing: WebSocket frames → Arrow `RecordBatch`.
//!
//! Converts incoming WebSocket text/binary messages into Arrow
//! `RecordBatch` rows for ingestion into Ring 0.
5
6use std::sync::Arc;
7
8use arrow_array::builder::BinaryBuilder;
9use arrow_array::{Array, RecordBatch};
10use arrow_schema::{DataType, Field, Schema, SchemaRef};
11
12use crate::error::ConnectorError;
13use crate::schema::csv::{CsvDecoder, CsvDecoderConfig};
14use crate::schema::json::decoder::{JsonDecoder, JsonDecoderConfig};
15use crate::schema::traits::FormatDecoder;
16use crate::schema::types::RawRecord;
17
18use super::source_config::MessageFormat;
19
20/// Parses raw WebSocket messages into Arrow `RecordBatch` data.
21pub struct MessageParser {
22    /// The output schema.
23    schema: SchemaRef,
24    /// The message format.
25    format: MessageFormat,
26    /// Type-aware JSON decoder (set for JSON/JsonLines formats).
27    json_decoder: Option<JsonDecoder>,
28    /// Type-aware CSV decoder (set for CSV format).
29    csv_decoder: Option<CsvDecoder>,
30}
31
32impl MessageParser {
33    /// Creates a new parser for the given schema, format, and JSON decoder config.
34    #[must_use]
35    pub fn new(
36        schema: SchemaRef,
37        format: MessageFormat,
38        decoder_config: JsonDecoderConfig,
39    ) -> Self {
40        let json_decoder = match &format {
41            MessageFormat::Json | MessageFormat::JsonLines => {
42                Some(JsonDecoder::with_config(schema.clone(), decoder_config))
43            }
44            _ => None,
45        };
46        let csv_decoder = match &format {
47            MessageFormat::Csv {
48                delimiter,
49                has_header,
50            } => {
51                // `parse_format` hard-codes the CSV delimiter to ',', so this
52                // is ASCII by construction. If the delimiter ever becomes
53                // user-configurable, validate `is_ascii()` at config parse.
54                #[allow(clippy::cast_possible_truncation)]
55                let csv_config = CsvDecoderConfig {
56                    delimiter: *delimiter as u8,
57                    has_header: *has_header,
58                    ..CsvDecoderConfig::default()
59                };
60                Some(CsvDecoder::with_config(schema.clone(), csv_config))
61            }
62            _ => None,
63        };
64        Self {
65            schema,
66            format,
67            json_decoder,
68            csv_decoder,
69        }
70    }
71
72    /// Returns the output schema.
73    #[must_use]
74    pub fn schema(&self) -> SchemaRef {
75        self.schema.clone()
76    }
77
78    /// Parses a batch of raw message payloads into a `RecordBatch`.
79    ///
80    /// # Errors
81    ///
82    /// Returns `ConnectorError::Serde` if parsing fails.
83    pub fn parse_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
84        if messages.is_empty() {
85            return Ok(RecordBatch::new_empty(self.schema.clone()));
86        }
87
88        match &self.format {
89            MessageFormat::Json | MessageFormat::JsonLines => self.parse_json_batch(messages),
90            MessageFormat::Binary => self.parse_binary_batch(messages),
91            MessageFormat::Csv { .. } => self.parse_csv_batch(messages),
92        }
93    }
94
95    /// Parses JSON messages into a `RecordBatch`.
96    ///
97    /// Uses the type-aware [`JsonDecoder`] to coerce JSON values to the
98    /// Arrow types declared in the schema.
99    fn parse_json_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
100        let decoder = self.json_decoder.as_ref().ok_or_else(|| {
101            ConnectorError::Internal("json_decoder not initialized for JSON format".into())
102        })?;
103        let records: Vec<RawRecord> = messages
104            .iter()
105            .map(|m| RawRecord::new(m.to_vec()))
106            .collect();
107        decoder.decode_batch(&records).map_err(ConnectorError::from)
108    }
109
110    /// Parses binary messages — each message becomes a single row with a
111    /// `LargeBinary` column named "payload".
112    #[allow(clippy::unused_self)]
113    fn parse_binary_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
114        let mut builder =
115            BinaryBuilder::with_capacity(messages.len(), messages.iter().map(|m| m.len()).sum());
116        for msg in messages {
117            builder.append_value(msg);
118        }
119
120        let schema = Arc::new(Schema::new(vec![Field::new(
121            "payload",
122            DataType::Binary,
123            false,
124        )]));
125        let arrays: Vec<Arc<dyn arrow_array::Array>> = vec![Arc::new(builder.finish())];
126
127        RecordBatch::try_new(schema, arrays).map_err(|e| {
128            ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
129                "failed to build binary RecordBatch: {e}"
130            )))
131        })
132    }
133
134    /// Parses CSV text messages into a `RecordBatch`.
135    ///
136    /// Delegates to [`CsvDecoder`] for schema-directed type coercion.
137    fn parse_csv_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
138        let decoder = self.csv_decoder.as_ref().ok_or_else(|| {
139            ConnectorError::Internal("csv_decoder not initialized for CSV format".into())
140        })?;
141        let records: Vec<RawRecord> = messages
142            .iter()
143            .map(|m| RawRecord::new(m.to_vec()))
144            .collect();
145        decoder.decode_batch(&records).map_err(ConnectorError::from)
146    }
147}
148
149/// Max event time (epoch ms) from a named `Timestamp(_)` column.
150/// `Ok(None)` when every row is null.
151///
152/// # Errors
153///
154/// `SchemaMismatch` if `field` is missing or isn't a `Timestamp(_)`.
155pub fn extract_max_event_time(
156    batch: &RecordBatch,
157    field: &str,
158) -> Result<Option<i64>, ConnectorError> {
159    let col_idx = batch.schema().index_of(field).map_err(|_| {
160        ConnectorError::SchemaMismatch(format!(
161            "event-time column '{field}' not found in batch schema"
162        ))
163    })?;
164    let arr = laminar_core::time::cast_to_millis_array(batch.column(col_idx).as_ref())
165        .map_err(|e| ConnectorError::SchemaMismatch(format!("event-time column '{field}': {e}")))?;
166    Ok((0..arr.len())
167        .filter(|&i| !arr.is_null(i))
168        .map(|i| arr.value(i))
169        .max())
170}
171
172/// Creates a default schema for JSON messages when no explicit schema is provided.
173///
174/// Uses schema inference from the first message. If `json_path` is provided,
175/// navigates into the object before inferring fields.
176///
177/// # Errors
178///
179/// Returns `ConnectorError::Serde` if the sample is not valid UTF-8 or valid JSON,
180/// or if the top-level value is not a JSON object.
181pub fn infer_schema_from_json(sample: &[u8]) -> Result<SchemaRef, ConnectorError> {
182    infer_schema_from_json_with_path(sample, None)
183}
184
185/// Like [`infer_schema_from_json`] but navigates a `json.path` first.
186///
187/// # Errors
188///
189/// Returns `ConnectorError::Serde` if the sample is not valid UTF-8 or valid JSON,
190/// if a path segment is not found, or if the target is not a JSON object.
191pub fn infer_schema_from_json_with_path(
192    sample: &[u8],
193    json_path: Option<&[String]>,
194) -> Result<SchemaRef, ConnectorError> {
195    let text = std::str::from_utf8(sample).map_err(|e| {
196        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
197            "invalid UTF-8: {e}"
198        )))
199    })?;
200
201    let value: serde_json::Value = serde_json::from_str(text)
202        .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
203
204    let target = if let Some(path) = json_path {
205        let mut current = &value;
206        for segment in path {
207            current = current.get(segment.as_str()).ok_or_else(|| {
208                ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
209                    "json.path segment '{segment}' not found during inference"
210                )))
211            })?;
212        }
213        current
214    } else {
215        &value
216    };
217
218    let obj = target.as_object().ok_or_else(|| {
219        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(
220            "schema inference requires a JSON object".into(),
221        ))
222    })?;
223
224    let fields: Vec<Field> = obj
225        .iter()
226        .map(|(key, val)| {
227            let dt = match val {
228                serde_json::Value::Bool(_) => DataType::Boolean,
229                serde_json::Value::Number(n) => {
230                    if n.is_f64() {
231                        DataType::Float64
232                    } else {
233                        DataType::Int64
234                    }
235                }
236                _ => DataType::Utf8,
237            };
238            Field::new(key, dt, true)
239        })
240        .collect();
241
242    Ok(Arc::new(Schema::new(fields)))
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema with two nullable string columns, `id` and `value`.
    fn json_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a parser with the default JSON decoder configuration.
    fn parser_with(schema: SchemaRef, format: MessageFormat) -> MessageParser {
        MessageParser::new(schema, format, JsonDecoderConfig::default())
    }

    /// JSON parser over [`json_schema`].
    fn json_parser() -> MessageParser {
        parser_with(json_schema(), MessageFormat::Json)
    }

    /// Wraps one column in a `RecordBatch` described by `field`.
    fn single_column_batch(field: Field, column: Arc<dyn Array>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![field]));
        RecordBatch::try_new(schema, vec![column]).unwrap()
    }

    /// A `ts` timestamp field without a time zone.
    fn ts_field(unit: arrow_schema::TimeUnit, nullable: bool) -> Field {
        Field::new("ts", DataType::Timestamp(unit, None), nullable)
    }

    #[test]
    fn test_parse_json_batch() {
        let payloads: Vec<&[u8]> = vec![
            br#"{"id": "1", "value": "hello"}"#,
            br#"{"id": "2", "value": "world"}"#,
        ];

        let batch = json_parser().parse_batch(&payloads).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    #[test]
    fn test_parse_json_missing_field() {
        let payloads: Vec<&[u8]> = vec![br#"{"id": "1"}"#];

        let batch = json_parser().parse_batch(&payloads).unwrap();
        assert_eq!(batch.num_rows(), 1);
        // The absent `value` key must surface as a null, not an error.
        assert!(batch.column(1).is_null(0));
    }

    #[test]
    fn test_parse_json_numeric_values() {
        let payloads: Vec<&[u8]> = vec![br#"{"id": "1", "value": 42}"#];

        let batch = json_parser().parse_batch(&payloads).unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[test]
    fn test_parse_binary_batch() {
        let payload_schema = Arc::new(Schema::new(vec![Field::new(
            "payload",
            DataType::Binary,
            false,
        )]));
        let parser = parser_with(payload_schema, MessageFormat::Binary);
        let payloads: Vec<&[u8]> = vec![b"hello", b"world"];

        let batch = parser.parse_batch(&payloads).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    #[test]
    fn test_parse_csv_batch() {
        let csv_format = MessageFormat::Csv {
            delimiter: ',',
            has_header: false,
        };
        let parser = parser_with(json_schema(), csv_format);
        let payloads: Vec<&[u8]> = vec![b"1,hello", b"2,world"];

        let batch = parser.parse_batch(&payloads).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    #[test]
    fn test_parse_empty() {
        let payloads: Vec<&[u8]> = Vec::new();

        let batch = json_parser().parse_batch(&payloads).unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    #[test]
    fn test_parse_invalid_json() {
        let payloads: Vec<&[u8]> = vec![b"not json"];

        let outcome = json_parser().parse_batch(&payloads);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_json_typed_columns() {
        let typed_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let parser = parser_with(typed_schema, MessageFormat::Json);
        let payloads: Vec<&[u8]> = vec![
            br#"{"id": 1, "price": 99.5, "name": "Widget"}"#,
            br#"{"id": 2, "price": 10.0, "name": "Gadget"}"#,
        ];

        let batch = parser.parse_batch(&payloads).unwrap();
        assert_eq!(batch.num_rows(), 2);

        // Columns should have the declared types, not Utf8.
        let expected_types = [DataType::Int64, DataType::Float64, DataType::Utf8];
        for (idx, expected) in expected_types.iter().enumerate() {
            assert_eq!(batch.column(idx).data_type(), expected);
        }

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 2);
    }

    #[test]
    fn test_parse_json_coerces_string_numbers() {
        let price_schema = Arc::new(Schema::new(vec![Field::new(
            "price",
            DataType::Float64,
            false,
        )]));
        let parser = parser_with(price_schema, MessageFormat::Json);
        let payloads: Vec<&[u8]> = vec![br#"{"price": "187.52"}"#];

        let batch = parser.parse_batch(&payloads).unwrap();
        let price_col = batch.column(0);
        assert_eq!(price_col.data_type(), &DataType::Float64);

        let prices = price_col
            .as_any()
            .downcast_ref::<arrow_array::Float64Array>()
            .unwrap();
        assert!((prices.value(0) - 187.52).abs() < f64::EPSILON);
    }

    #[test]
    fn test_infer_schema() {
        let sample = br#"{"name": "Alice", "age": 30, "active": true, "score": 99.5}"#;
        let schema = infer_schema_from_json(sample).unwrap();

        assert_eq!(schema.fields().len(), 4);

        let type_of = |name: &str| schema.field_with_name(name).unwrap().data_type().clone();
        assert_eq!(type_of("age"), DataType::Int64);
        assert_eq!(type_of("active"), DataType::Boolean);
        assert_eq!(type_of("score"), DataType::Float64);
    }

    #[test]
    fn test_extract_max_event_time_millis() {
        let ts = arrow_array::TimestampMillisecondArray::from(vec![1000, 3000, 2000]);
        let batch = single_column_batch(
            ts_field(arrow_schema::TimeUnit::Millisecond, false),
            Arc::new(ts),
        );

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3000));
    }

    #[test]
    fn test_extract_max_event_time_nanos_rescaled() {
        let ts = arrow_array::TimestampNanosecondArray::from(vec![
            1_000_000_000,
            3_000_000_000,
            2_000_000_000,
        ]);
        let batch = single_column_batch(
            ts_field(arrow_schema::TimeUnit::Nanosecond, false),
            Arc::new(ts),
        );

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3_000));
    }

    #[test]
    fn test_extract_max_event_time_missing_column_errors() {
        let ids = arrow_array::Int64Array::from(vec![1, 2, 3]);
        let batch = single_column_batch(Field::new("id", DataType::Int64, false), Arc::new(ids));

        assert!(extract_max_event_time(&batch, "ts").is_err());
    }

    #[test]
    fn test_extract_max_event_time_non_timestamp_column_errors() {
        let ts = arrow_array::Int64Array::from(vec![1, 2, 3]);
        let batch = single_column_batch(Field::new("ts", DataType::Int64, false), Arc::new(ts));

        assert!(extract_max_event_time(&batch, "ts").is_err());
    }

    #[test]
    fn test_extract_max_event_time_with_nulls() {
        let ts =
            arrow_array::TimestampMillisecondArray::from(vec![Some(1000), None, Some(3000), None]);
        let batch = single_column_batch(
            ts_field(arrow_schema::TimeUnit::Millisecond, true),
            Arc::new(ts),
        );

        assert_eq!(extract_max_event_time(&batch, "ts").unwrap(), Some(3000));
    }
}