laminar_connectors/websocket/parser.rs
1//! Message parsing: WebSocket frames → Arrow `RecordBatch`.
2//!
3//! Converts incoming WebSocket text/binary messages into Arrow
4//! `RecordBatch` rows for ingestion into Ring 0.
5
6use std::sync::Arc;
7
8use arrow_array::builder::{BinaryBuilder, StringBuilder};
9use arrow_array::{Array, RecordBatch};
10use arrow_schema::{DataType, Field, Schema, SchemaRef};
11
12use crate::error::ConnectorError;
13use crate::schema::json::decoder::{JsonDecoder, JsonDecoderConfig};
14use crate::schema::traits::FormatDecoder;
15use crate::schema::types::RawRecord;
16
17use super::source_config::MessageFormat;
18
/// Parses raw WebSocket messages into Arrow `RecordBatch` data.
///
/// The configured [`MessageFormat`] selects the parse path taken by
/// `parse_batch` (JSON, binary passthrough, or CSV).
pub struct MessageParser {
    /// The output schema every produced `RecordBatch` targets.
    schema: SchemaRef,
    /// The message format; decides which parse path is used per batch.
    format: MessageFormat,
    /// Type-aware JSON decoder (set for JSON/JsonLines formats).
    /// `None` for all other formats — see `MessageParser::new`.
    json_decoder: Option<JsonDecoder>,
}
28
29impl MessageParser {
30    /// Creates a new parser for the given schema, format, and JSON decoder config.
31    #[must_use]
32    pub fn new(
33        schema: SchemaRef,
34        format: MessageFormat,
35        decoder_config: JsonDecoderConfig,
36    ) -> Self {
37        let json_decoder = match &format {
38            MessageFormat::Json | MessageFormat::JsonLines => {
39                Some(JsonDecoder::with_config(schema.clone(), decoder_config))
40            }
41            _ => None,
42        };
43        Self {
44            schema,
45            format,
46            json_decoder,
47        }
48    }
49
50    /// Returns the output schema.
51    #[must_use]
52    pub fn schema(&self) -> SchemaRef {
53        self.schema.clone()
54    }
55
56    /// Parses a batch of raw message payloads into a `RecordBatch`.
57    ///
58    /// # Errors
59    ///
60    /// Returns `ConnectorError::Serde` if parsing fails.
61    pub fn parse_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
62        if messages.is_empty() {
63            return Ok(RecordBatch::new_empty(self.schema.clone()));
64        }
65
66        match &self.format {
67            MessageFormat::Json | MessageFormat::JsonLines => self.parse_json_batch(messages),
68            MessageFormat::Binary => self.parse_binary_batch(messages),
69            MessageFormat::Csv { delimiter, .. } => self.parse_csv_batch(messages, *delimiter),
70        }
71    }
72
73    /// Parses JSON messages into a `RecordBatch`.
74    ///
75    /// Uses the type-aware [`JsonDecoder`] to coerce JSON values to the
76    /// Arrow types declared in the schema.
77    fn parse_json_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
78        let decoder = self.json_decoder.as_ref().ok_or_else(|| {
79            ConnectorError::Internal("json_decoder not initialized for JSON format".into())
80        })?;
81        let records: Vec<RawRecord> = messages
82            .iter()
83            .map(|m| RawRecord::new(m.to_vec()))
84            .collect();
85        decoder.decode_batch(&records).map_err(ConnectorError::from)
86    }
87
88    /// Parses binary messages — each message becomes a single row with a
89    /// `LargeBinary` column named "payload".
90    #[allow(clippy::unused_self)]
91    fn parse_binary_batch(&self, messages: &[&[u8]]) -> Result<RecordBatch, ConnectorError> {
92        let mut builder =
93            BinaryBuilder::with_capacity(messages.len(), messages.iter().map(|m| m.len()).sum());
94        for msg in messages {
95            builder.append_value(msg);
96        }
97
98        let schema = Arc::new(Schema::new(vec![Field::new(
99            "payload",
100            DataType::Binary,
101            false,
102        )]));
103        let arrays: Vec<Arc<dyn arrow_array::Array>> = vec![Arc::new(builder.finish())];
104
105        RecordBatch::try_new(schema, arrays).map_err(|e| {
106            ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
107                "failed to build binary RecordBatch: {e}"
108            )))
109        })
110    }
111
112    /// Parses CSV text messages into a `RecordBatch`.
113    ///
114    /// Each message is treated as one or more CSV rows.
115    fn parse_csv_batch(
116        &self,
117        messages: &[&[u8]],
118        delimiter: char,
119    ) -> Result<RecordBatch, ConnectorError> {
120        let mut builders: Vec<StringBuilder> = self
121            .schema
122            .fields()
123            .iter()
124            .map(|_| StringBuilder::with_capacity(messages.len(), messages.len() * 32))
125            .collect();
126
127        let num_fields = self.schema.fields().len();
128
129        for msg in messages {
130            let text = std::str::from_utf8(msg).map_err(|e| {
131                ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
132                    "invalid UTF-8 in CSV message: {e}"
133                )))
134            })?;
135
136            for line in text.lines() {
137                if line.trim().is_empty() {
138                    continue;
139                }
140
141                let fields: Vec<&str> = line.split(delimiter).collect();
142                for (i, builder) in builders.iter_mut().enumerate().take(num_fields) {
143                    if i < fields.len() {
144                        builder.append_value(fields[i].trim());
145                    } else {
146                        builder.append_null();
147                    }
148                }
149            }
150        }
151
152        let arrays: Vec<Arc<dyn arrow_array::Array>> = builders
153            .into_iter()
154            .map(|mut b| Arc::new(b.finish()) as _)
155            .collect();
156
157        RecordBatch::try_new(self.schema.clone(), arrays).map_err(|e| {
158            ConnectorError::Serde(crate::error::SerdeError::Csv(format!(
159                "failed to build CSV RecordBatch: {e}"
160            )))
161        })
162    }
163}
164
165/// Extracts the maximum event time (as epoch milliseconds) from a named column.
166///
167/// Supports `Int64` and `TimestampMillisecond` column types. Returns `None`
168/// if the column is missing, has an unsupported type, or is entirely null.
169#[must_use]
170#[allow(clippy::cast_possible_truncation)]
171pub fn extract_max_event_time(batch: &RecordBatch, field: &str) -> Option<i64> {
172    let col_idx = batch.schema().index_of(field).ok()?;
173    let col = batch.column(col_idx);
174
175    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
176        (0..arr.len())
177            .filter(|&i| !arr.is_null(i))
178            .map(|i| arr.value(i))
179            .max()
180    } else if let Some(arr) = col
181        .as_any()
182        .downcast_ref::<arrow_array::TimestampMillisecondArray>()
183    {
184        (0..arr.len())
185            .filter(|&i| !arr.is_null(i))
186            .map(|i| arr.value(i))
187            .max()
188    } else {
189        None
190    }
191}
192
/// Creates a default schema for JSON messages when no explicit schema is provided.
///
/// Infers field names and types from the fields of the sample message,
/// which must be a top-level JSON object. Equivalent to calling
/// [`infer_schema_from_json_with_path`] with no path.
///
/// # Errors
///
/// Returns `ConnectorError::Serde` if the sample is not valid UTF-8 or valid JSON,
/// or if the top-level value is not a JSON object.
pub fn infer_schema_from_json(sample: &[u8]) -> Result<SchemaRef, ConnectorError> {
    infer_schema_from_json_with_path(sample, None)
}
205
206/// Like [`infer_schema_from_json`] but navigates a `json.path` first.
207///
208/// # Errors
209///
210/// Returns `ConnectorError::Serde` if the sample is not valid UTF-8 or valid JSON,
211/// if a path segment is not found, or if the target is not a JSON object.
212pub fn infer_schema_from_json_with_path(
213    sample: &[u8],
214    json_path: Option<&[String]>,
215) -> Result<SchemaRef, ConnectorError> {
216    let text = std::str::from_utf8(sample).map_err(|e| {
217        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
218            "invalid UTF-8: {e}"
219        )))
220    })?;
221
222    let value: serde_json::Value = serde_json::from_str(text)
223        .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
224
225    let target = if let Some(path) = json_path {
226        let mut current = &value;
227        for segment in path {
228            current = current.get(segment.as_str()).ok_or_else(|| {
229                ConnectorError::Serde(crate::error::SerdeError::MalformedInput(format!(
230                    "json.path segment '{segment}' not found during inference"
231                )))
232            })?;
233        }
234        current
235    } else {
236        &value
237    };
238
239    let obj = target.as_object().ok_or_else(|| {
240        ConnectorError::Serde(crate::error::SerdeError::MalformedInput(
241            "schema inference requires a JSON object".into(),
242        ))
243    })?;
244
245    let fields: Vec<Field> = obj
246        .iter()
247        .map(|(key, val)| {
248            let dt = match val {
249                serde_json::Value::Bool(_) => DataType::Boolean,
250                serde_json::Value::Number(n) => {
251                    if n.is_f64() {
252                        DataType::Float64
253                    } else {
254                        DataType::Int64
255                    }
256                }
257                _ => DataType::Utf8,
258            };
259            Field::new(key, dt, true)
260        })
261        .collect();
262
263    Ok(Arc::new(Schema::new(fields)))
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-column all-`Utf8` schema shared by the JSON and CSV round-trip tests.
    fn json_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ]))
    }

    // Two well-formed JSON objects produce a 2x2 batch.
    #[test]
    fn test_parse_json_batch() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![
            br#"{"id": "1", "value": "hello"}"#,
            br#"{"id": "2", "value": "world"}"#,
        ];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    // A field absent from the JSON payload becomes a null cell.
    #[test]
    fn test_parse_json_missing_field() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![br#"{"id": "1"}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert!(batch.column(1).is_null(0));
    }

    // A JSON number in a Utf8-typed column still parses (decoder coerces).
    #[test]
    fn test_parse_json_numeric_values() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![br#"{"id": "1", "value": 42}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    // Binary format: one message == one row in the fixed `payload` column.
    #[test]
    fn test_parse_binary_batch() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "payload",
            DataType::Binary,
            false,
        )]));
        let parser =
            MessageParser::new(schema, MessageFormat::Binary, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![b"hello", b"world"];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    // CSV format: comma-delimited cells fill the schema's columns in order.
    #[test]
    fn test_parse_csv_batch() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Csv {
                delimiter: ',',
                has_header: false,
            },
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![b"1,hello", b"2,world"];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    // Empty input short-circuits to an empty batch (no decode attempted).
    #[test]
    fn test_parse_empty() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    // Malformed JSON surfaces as an error rather than a silent empty batch.
    #[test]
    fn test_parse_invalid_json() {
        let parser = MessageParser::new(
            json_schema(),
            MessageFormat::Json,
            JsonDecoderConfig::default(),
        );
        let messages: Vec<&[u8]> = vec![b"not json"];

        assert!(parser.parse_batch(&messages).is_err());
    }

    // The decoder must honor declared Arrow types, not default to Utf8.
    #[test]
    fn test_parse_json_typed_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let parser = MessageParser::new(schema, MessageFormat::Json, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![
            br#"{"id": 1, "price": 99.5, "name": "Widget"}"#,
            br#"{"id": 2, "price": 10.0, "name": "Gadget"}"#,
        ];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.num_rows(), 2);

        // Columns should have the declared types, not Utf8.
        assert_eq!(batch.column(0).data_type(), &DataType::Int64);
        assert_eq!(batch.column(1).data_type(), &DataType::Float64);
        assert_eq!(batch.column(2).data_type(), &DataType::Utf8);

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 2);
    }

    // String-encoded numbers ("187.52") coerce into numeric columns.
    #[test]
    fn test_parse_json_coerces_string_numbers() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "price",
            DataType::Float64,
            false,
        )]));
        let parser = MessageParser::new(schema, MessageFormat::Json, JsonDecoderConfig::default());
        let messages: Vec<&[u8]> = vec![br#"{"price": "187.52"}"#];

        let batch = parser.parse_batch(&messages).unwrap();
        assert_eq!(batch.column(0).data_type(), &DataType::Float64);
        let prices = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Float64Array>()
            .unwrap();
        assert!((prices.value(0) - 187.52).abs() < f64::EPSILON);
    }

    // Inference maps bool -> Boolean, int -> Int64, float -> Float64.
    #[test]
    fn test_infer_schema() {
        let sample = br#"{"name": "Alice", "age": 30, "active": true, "score": 99.5}"#;
        let schema = infer_schema_from_json(sample).unwrap();

        assert_eq!(schema.fields().len(), 4);
        let age_field = schema.field_with_name("age").unwrap();
        assert_eq!(age_field.data_type(), &DataType::Int64);
        let active_field = schema.field_with_name("active").unwrap();
        assert_eq!(active_field.data_type(), &DataType::Boolean);
        let score_field = schema.field_with_name("score").unwrap();
        assert_eq!(score_field.data_type(), &DataType::Float64);
    }

    // Max is found regardless of ordering in an Int64 column.
    #[test]
    fn test_extract_max_event_time_int64() {
        let schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, false)]));
        let ts = arrow_array::Int64Array::from(vec![1000, 3000, 2000]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts"), Some(3000));
    }

    // A missing column yields None rather than an error.
    #[test]
    fn test_extract_max_event_time_missing_column() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let ids = arrow_array::Int64Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts"), None);
    }

    // Null slots are skipped when computing the maximum.
    #[test]
    fn test_extract_max_event_time_with_nulls() {
        let schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)]));
        let ts = arrow_array::Int64Array::from(vec![Some(1000), None, Some(3000), None]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts"), Some(3000));
    }

    // Utf8 timestamp columns are unsupported and yield None.
    #[test]
    fn test_extract_max_event_time_unsupported_type() {
        let schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Utf8, false)]));
        let ts = arrow_array::StringArray::from(vec!["2026-01-01"]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts)]).unwrap();

        assert_eq!(extract_max_event_time(&batch, "ts"), None);
    }
}