1use std::collections::VecDeque;
18use std::sync::Arc;
19
20use arrow_array::{
21 Array, Float64Array, Int64Array, RecordBatch, StringArray, TimestampMicrosecondArray,
22};
23use arrow_schema::{DataType, Field, Schema};
24use rustc_hash::FxHashMap;
25
26use super::{
27 Event, Operator, OperatorContext, OperatorError, OperatorState, Output, OutputVec, Timer,
28};
29
/// Kind of analytic function evaluated by the lag/lead operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnalyticFunctionKind {
    /// Value read from the partition's history, `offset` events back.
    Lag,
    /// Value read from a later event of the same partition; the current
    /// event is buffered until that later event arrives or a flush happens.
    Lead,
    /// First value recorded for the partition.
    FirstValue,
    /// Value of the event currently being processed.
    LastValue,
    /// Value of the event whose 1-based position in the partition equals the
    /// spec's `offset`.
    NthValue,
}
44
45impl AnalyticFunctionKind {
46 #[must_use]
48 pub fn is_lag(self) -> bool {
49 self == Self::Lag
50 }
51
52 #[must_use]
54 pub fn is_lead(self) -> bool {
55 self == Self::Lead
56 }
57}
58
/// Configuration consumed by `LagLeadOperator::new`.
#[derive(Debug, Clone)]
pub struct LagLeadConfig {
    /// Identifier copied into every checkpoint produced by the operator.
    pub operator_id: String,
    /// Analytic functions to evaluate; one output column is appended per entry,
    /// in this order.
    pub functions: Vec<LagLeadFunctionSpec>,
    /// Names of the columns whose row-0 values form the partition key.
    pub partition_columns: Vec<String>,
    /// Upper bound on tracked partitions; events that would create a partition
    /// beyond this limit produce no output.
    pub max_partitions: usize,
}
71
/// Description of a single analytic function and the column it produces.
#[derive(Debug, Clone)]
pub struct LagLeadFunctionSpec {
    /// Which analytic function to compute.
    pub function_type: AnalyticFunctionKind,
    /// Name of the input column the value is read from.
    ///
    /// NOTE(review): the operator extracts one source column per function
    /// family (the first lag spec, the first lead spec, the first positional
    /// spec), so specs of the same family share that column's values.
    pub source_column: String,
    /// Distance in events for `Lag`/`Lead`, or the 1-based position for
    /// `NthValue`. Not read for `FirstValue`/`LastValue`.
    pub offset: usize,
    /// Value emitted when the requested row is not available; `NaN` is used
    /// when this is `None`.
    pub default_value: Option<f64>,
    /// Name of the `Float64` column appended to the output batch.
    pub output_column: String,
}
86
/// Per-partition state kept between events.
#[derive(Debug, Clone)]
struct PartitionState {
    /// Most recent lag source values, oldest first; trimmed to the largest
    /// configured lag offset after every event.
    lag_history: VecDeque<f64>,
    /// Events waiting for their lead values, oldest first.
    lead_pending: VecDeque<PendingLeadEvent>,
    /// One slot per `FirstValue` function; filled by the first event seen.
    first_values: Vec<Option<f64>>,
    /// One slot per `NthValue` function; filled when `event_count` reaches the
    /// function's offset.
    nth_values: Vec<Option<f64>>,
    /// Number of events processed for this partition (incremented before the
    /// functions are evaluated).
    event_count: usize,
}
101
/// An event buffered until enough later events arrived to resolve its leads.
#[derive(Debug, Clone)]
struct PendingLeadEvent {
    /// The buffered input event, emitted (with extra columns) once resolved.
    event: Event,
    /// Number of further events of the partition needed before this one is
    /// resolved; decremented on every later event.
    remaining: usize,
    /// Lead source value extracted from `event`, read by earlier pending
    /// events as their lead value.
    value: f64,
}
112
/// Counters maintained by the operator while processing events.
#[derive(Debug, Default)]
pub struct LagLeadMetrics {
    /// Events that were accepted and processed (dropped events are not counted).
    pub events_processed: u64,
    /// Processed events for which lag values were computed.
    pub lag_lookups: u64,
    /// Events placed into a partition's lead buffer.
    pub lead_buffered: u64,
    /// Buffered events emitted, either resolved by later events or flushed.
    pub lead_flushed: u64,
    /// Number of partitions held after the most recently processed event.
    pub partitions_active: u64,
}
127
/// Streaming operator that appends LAG/LEAD/FIRST_VALUE/LAST_VALUE/NTH_VALUE
/// columns to incoming events, keeping independent state per partition key.
pub struct LagLeadOperator {
    /// Identifier written into checkpoints.
    operator_id: String,
    /// Configured functions, in output-column order.
    functions: Vec<LagLeadFunctionSpec>,
    /// Names of the columns that form the partition key.
    partition_columns: Vec<String>,
    /// State per encoded partition key.
    partitions: FxHashMap<Vec<u8>, PartitionState>,
    /// Maximum number of partitions that may be tracked.
    max_partitions: usize,
    /// Running counters.
    metrics: LagLeadMetrics,
    /// Scratch buffer reused to encode the partition key of each event.
    key_buf: Vec<u8>,
    /// `(offset, default)` of every lead function, in configuration order.
    lead_specs_cache: Vec<(usize, Option<f64>)>,
    /// Default of every lead function with `None` replaced by `NaN`.
    lead_defaults_cache: Vec<f64>,
    /// Output column name of every lead function.
    lead_columns_cache: Vec<String>,
    /// Schema index of each partition column, resolved on first use.
    cached_partition_indices: Vec<Option<usize>>,
    /// Schema index of the lag source column, resolved on first use.
    cached_lag_col_index: Option<usize>,
    /// Schema index of the lead source column, resolved on first use.
    cached_lead_col_index: Option<usize>,
}
161
162impl LagLeadOperator {
163 #[must_use]
165 pub fn new(config: LagLeadConfig) -> Self {
166 let lead_specs_cache: Vec<(usize, Option<f64>)> = config
168 .functions
169 .iter()
170 .filter(|f| f.function_type.is_lead())
171 .map(|f| (f.offset, f.default_value))
172 .collect();
173 let lead_defaults_cache: Vec<f64> = config
174 .functions
175 .iter()
176 .filter(|f| f.function_type.is_lead())
177 .map(|f| f.default_value.unwrap_or(f64::NAN))
178 .collect();
179 let lead_columns_cache: Vec<String> = config
180 .functions
181 .iter()
182 .filter(|f| f.function_type.is_lead())
183 .map(|f| f.output_column.clone())
184 .collect();
185
186 let num_partition_cols = config.partition_columns.len();
187 Self {
188 operator_id: config.operator_id,
189 functions: config.functions,
190 partition_columns: config.partition_columns,
191 partitions: FxHashMap::default(),
192 max_partitions: config.max_partitions,
193 metrics: LagLeadMetrics::default(),
194 key_buf: Vec::with_capacity(64),
195 lead_specs_cache,
196 lead_defaults_cache,
197 lead_columns_cache,
198 cached_partition_indices: vec![None; num_partition_cols],
199 cached_lag_col_index: None,
200 cached_lead_col_index: None,
201 }
202 }
203
    /// Returns the number of partitions that currently hold state.
    #[must_use]
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }
209
    /// Returns a reference to the operator's running counters.
    #[must_use]
    pub fn metrics(&self) -> &LagLeadMetrics {
        &self.metrics
    }
215
216 fn fill_partition_key(&mut self, event: &Event) {
220 self.key_buf.clear();
221 let batch = &event.data;
222
223 for (i, col_name) in self.partition_columns.iter().enumerate() {
224 let col_idx = if let Some(idx) = self.cached_partition_indices[i] {
225 idx
226 } else {
227 let Ok(idx) = batch.schema().index_of(col_name) else {
228 self.key_buf.push(0x00);
229 continue;
230 };
231 self.cached_partition_indices[i] = Some(idx);
232 idx
233 };
234
235 let array = batch.column(col_idx);
236
237 if array.is_null(0) {
238 self.key_buf.push(0x00);
239 continue;
240 }
241
242 self.key_buf.push(0x01); match array.data_type() {
245 DataType::Int64 => {
246 if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
247 self.key_buf.extend_from_slice(&arr.value(0).to_le_bytes());
248 } else {
249 self.key_buf.push(0x00);
250 }
251 }
252 DataType::Utf8 => {
253 if let Some(arr) = array.as_any().downcast_ref::<StringArray>() {
254 self.key_buf.extend_from_slice(arr.value(0).as_bytes());
255 self.key_buf.push(0x00); } else {
257 self.key_buf.push(0x00);
258 }
259 }
260 DataType::Float64 => {
261 if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
262 self.key_buf
263 .extend_from_slice(&arr.value(0).to_bits().to_le_bytes());
264 } else {
265 self.key_buf.push(0x00);
266 }
267 }
268 _ => {
269 self.key_buf.push(0x00);
270 }
271 }
272 }
273 }
274
275 fn extract_column_value(event: &Event, column: &str, cached_index: &mut Option<usize>) -> f64 {
279 let batch = &event.data;
280 let col_idx = if let Some(idx) = *cached_index {
281 idx
282 } else {
283 let Ok(idx) = batch.schema().index_of(column) else {
284 return f64::NAN;
285 };
286 *cached_index = Some(idx);
287 idx
288 };
289
290 let array = batch.column(col_idx);
291 if array.is_null(0) {
292 return f64::NAN;
293 }
294
295 match array.data_type() {
296 DataType::Float64 => {
297 if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
298 arr.value(0)
299 } else {
300 f64::NAN
301 }
302 }
303 DataType::Int64 => {
304 if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
305 #[allow(clippy::cast_precision_loss)]
306 {
307 arr.value(0) as f64
308 }
309 } else {
310 f64::NAN
311 }
312 }
313 DataType::Timestamp(_, _) => {
314 if let Some(arr) = array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
315 #[allow(clippy::cast_precision_loss)]
316 {
317 arr.value(0) as f64
318 }
319 } else {
320 f64::NAN
321 }
322 }
323 _ => f64::NAN,
324 }
325 }
326
327 fn compute_lag_values(functions: &[LagLeadFunctionSpec], state: &PartitionState) -> Vec<f64> {
329 functions
330 .iter()
331 .filter(|f| f.function_type.is_lag())
332 .map(|func| {
333 let history = &state.lag_history;
334 if history.len() >= func.offset {
335 let idx = history.len() - func.offset;
336 history[idx]
337 } else {
338 func.default_value.unwrap_or(f64::NAN)
339 }
340 })
341 .collect()
342 }
343
344 fn compute_positional_values(
348 functions: &[LagLeadFunctionSpec],
349 state: &mut PartitionState,
350 current_value: f64,
351 ) -> Vec<f64> {
352 let mut first_idx = 0usize;
353 let mut nth_idx = 0usize;
354 let mut results = Vec::new();
355
356 for func in functions {
357 match func.function_type {
358 AnalyticFunctionKind::FirstValue => {
359 if state.first_values[first_idx].is_none() {
360 state.first_values[first_idx] = Some(current_value);
361 }
362 results.push(state.first_values[first_idx].unwrap());
363 first_idx += 1;
364 }
365 AnalyticFunctionKind::LastValue => {
366 results.push(current_value);
368 }
369 AnalyticFunctionKind::NthValue => {
370 if state.nth_values[nth_idx].is_none() && state.event_count == func.offset {
372 state.nth_values[nth_idx] = Some(current_value);
373 }
374 results.push(
375 state.nth_values[nth_idx].unwrap_or(func.default_value.unwrap_or(f64::NAN)),
376 );
377 nth_idx += 1;
378 }
379 _ => {} }
381 }
382
383 results
384 }
385
386 fn build_output(
388 functions: &[LagLeadFunctionSpec],
389 event: &Event,
390 lag_values: &[f64],
391 lead_values: &[f64],
392 positional_values: &[f64],
393 ) -> Event {
394 let original_batch = &event.data;
395 let num_original = original_batch.num_columns();
396 let num_functions = functions.len();
397 let mut fields: Vec<Field> = Vec::with_capacity(num_original + num_functions);
398 fields.extend(
399 original_batch
400 .schema()
401 .fields()
402 .iter()
403 .map(|f| f.as_ref().clone()),
404 );
405 let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(num_original + num_functions);
406 columns.extend((0..num_original).map(|i| original_batch.column(i).clone()));
407
408 let mut lag_idx = 0;
409 let mut lead_idx = 0;
410 let mut pos_idx = 0;
411
412 for func in functions {
413 let value = match func.function_type {
414 AnalyticFunctionKind::Lag => {
415 let v = lag_values.get(lag_idx).copied().unwrap_or(f64::NAN);
416 lag_idx += 1;
417 v
418 }
419 AnalyticFunctionKind::Lead => {
420 let v = lead_values.get(lead_idx).copied().unwrap_or(f64::NAN);
421 lead_idx += 1;
422 v
423 }
424 AnalyticFunctionKind::FirstValue
425 | AnalyticFunctionKind::LastValue
426 | AnalyticFunctionKind::NthValue => {
427 let v = positional_values.get(pos_idx).copied().unwrap_or(f64::NAN);
428 pos_idx += 1;
429 v
430 }
431 };
432
433 fields.push(Field::new(&func.output_column, DataType::Float64, true));
434 columns.push(Arc::new(Float64Array::from(vec![value])));
435 }
436
437 let schema = Arc::new(Schema::new(fields));
438 let batch = RecordBatch::try_new(schema, columns)
439 .unwrap_or_else(|_| RecordBatch::new_empty(Arc::new(Schema::empty())));
440 Event::new(event.timestamp, batch)
441 }
442
443 #[allow(clippy::too_many_lines)]
446 fn process_event(&mut self, event: &Event) -> OutputVec {
447 self.fill_partition_key(event);
448 let partition_key = self.key_buf.clone();
449
450 if !self.partitions.contains_key(&partition_key)
452 && self.partitions.len() >= self.max_partitions
453 {
454 return OutputVec::new();
455 }
456
457 let has_lag = self.functions.iter().any(|f| f.function_type.is_lag());
458 let has_lead = !self.lead_specs_cache.is_empty();
459
460 let first_value_count = self
462 .functions
463 .iter()
464 .filter(|f| f.function_type == AnalyticFunctionKind::FirstValue)
465 .count();
466 let nth_value_count = self
467 .functions
468 .iter()
469 .filter(|f| f.function_type == AnalyticFunctionKind::NthValue)
470 .count();
471
472 let max_lag_offset = self
474 .functions
475 .iter()
476 .filter(|f| f.function_type.is_lag())
477 .map(|f| f.offset)
478 .max()
479 .unwrap_or(1);
480 let max_lead_offset = self
481 .functions
482 .iter()
483 .filter(|f| f.function_type.is_lead())
484 .map(|f| f.offset)
485 .max()
486 .unwrap_or(1);
487 let lag_source_col = self
488 .functions
489 .iter()
490 .find(|f| f.function_type.is_lag())
491 .map(|f| f.source_column.clone());
492 let lead_source_col = self
493 .functions
494 .iter()
495 .find(|f| f.function_type.is_lead())
496 .map(|f| f.source_column.clone());
497
498 let state = self
500 .partitions
501 .entry(partition_key)
502 .or_insert_with(|| PartitionState {
503 lag_history: VecDeque::new(),
504 lead_pending: VecDeque::new(),
505 first_values: vec![None; first_value_count],
506 nth_values: vec![None; nth_value_count],
507 event_count: 0,
508 });
509 state.event_count += 1;
510
511 let mut outputs = OutputVec::new();
512
513 let lag_values = if has_lag {
515 Self::compute_lag_values(&self.functions, state)
516 } else {
517 vec![]
518 };
519
520 if has_lag {
522 if let Some(col) = &lag_source_col {
523 let value = Self::extract_column_value(event, col, &mut self.cached_lag_col_index);
524 state.lag_history.push_back(value);
525 while state.lag_history.len() > max_lag_offset {
526 state.lag_history.pop_front();
527 }
528 }
529 }
530
531 let has_positional = self.functions.iter().any(|f| {
534 matches!(
535 f.function_type,
536 AnalyticFunctionKind::FirstValue
537 | AnalyticFunctionKind::LastValue
538 | AnalyticFunctionKind::NthValue
539 )
540 });
541 let positional_values = if has_positional {
542 let pos_source_col = self
543 .functions
544 .iter()
545 .find(|f| {
546 matches!(
547 f.function_type,
548 AnalyticFunctionKind::FirstValue
549 | AnalyticFunctionKind::LastValue
550 | AnalyticFunctionKind::NthValue
551 )
552 })
553 .map(|f| f.source_column.clone());
554 let current_value = pos_source_col.as_ref().map_or(f64::NAN, |col| {
555 Self::extract_column_value(event, col, &mut self.cached_lag_col_index)
556 });
557 Self::compute_positional_values(&self.functions, state, current_value)
558 } else {
559 vec![]
560 };
561
562 if has_lead {
563 let value = if let Some(col) = &lead_source_col {
565 Self::extract_column_value(event, col, &mut self.cached_lead_col_index)
566 } else {
567 f64::NAN
568 };
569
570 for pending in &mut state.lead_pending {
572 pending.remaining = pending.remaining.saturating_sub(1);
573 }
574
575 state.lead_pending.push_back(PendingLeadEvent {
576 event: event.clone(),
577 remaining: max_lead_offset,
578 value,
579 });
580 self.metrics.lead_buffered += 1;
581
582 let mut resolved_events = Vec::new();
585 while state.lead_pending.front().is_some_and(|p| p.remaining == 0) {
586 let resolved = state.lead_pending.pop_front().unwrap();
587 let lead_values: Vec<f64> = self
588 .lead_specs_cache
589 .iter()
590 .map(|(offset, default)| {
591 if *offset <= state.lead_pending.len() {
592 state.lead_pending[*offset - 1].value
593 } else {
594 default.unwrap_or(f64::NAN)
595 }
596 })
597 .collect();
598 resolved_events.push((resolved, lead_values));
599 }
600
601 for (resolved, lead_values) in resolved_events {
602 let output = Self::build_output(
603 &self.functions,
604 &resolved.event,
605 &lag_values,
606 &lead_values,
607 &positional_values,
608 );
609 outputs.push(Output::Event(output));
610 self.metrics.lead_flushed += 1;
611 }
612 } else {
613 let output =
615 Self::build_output(&self.functions, event, &lag_values, &[], &positional_values);
616 outputs.push(Output::Event(output));
617 }
618
619 self.metrics.events_processed += 1;
620 if has_lag {
621 self.metrics.lag_lookups += 1;
622 }
623 self.metrics.partitions_active = self.partitions.len() as u64;
624
625 outputs
626 }
627
628 fn flush_pending_leads(&mut self) -> OutputVec {
631 let mut outputs = OutputVec::new();
632
633 let lead_defaults = &self.lead_defaults_cache;
635 let lead_output_columns = &self.lead_columns_cache;
636
637 let mut flushed_count = 0u64;
638
639 for state in self.partitions.values_mut() {
640 while let Some(pending) = state.lead_pending.pop_front() {
641 let original_batch = &pending.event.data;
642 let num_original = original_batch.num_columns();
643 let num_lead = lead_output_columns.len();
644 let mut fields: Vec<Field> = Vec::with_capacity(num_original + num_lead);
645 fields.extend(
646 original_batch
647 .schema()
648 .fields()
649 .iter()
650 .map(|f| f.as_ref().clone()),
651 );
652 let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(num_original + num_lead);
653 columns.extend((0..num_original).map(|i| original_batch.column(i).clone()));
654
655 for (col_name, &default) in lead_output_columns.iter().zip(lead_defaults.iter()) {
656 fields.push(Field::new(col_name, DataType::Float64, true));
657 columns.push(Arc::new(Float64Array::from(vec![default])));
658 }
659
660 let schema = Arc::new(Schema::new(fields));
661 if let Ok(batch) = RecordBatch::try_new(schema, columns) {
662 let output_event = Event::new(pending.event.timestamp, batch);
663 outputs.push(Output::Event(output_event));
664 flushed_count += 1;
665 }
666 }
667 }
668
669 self.metrics.lead_flushed += flushed_count;
670 outputs
671 }
672}
673
674impl Operator for LagLeadOperator {
    /// Handles one incoming event by delegating to `process_event`; the
    /// operator context is not used.
    fn process(&mut self, event: &Event, _ctx: &mut OperatorContext) -> OutputVec {
        self.process_event(event)
    }
678
    /// Flushes every buffered lead event of every partition. Neither the timer
    /// nor the context is inspected.
    fn on_timer(&mut self, _timer: Timer, _ctx: &mut OperatorContext) -> OutputVec {
        self.flush_pending_leads()
    }
683
684 fn checkpoint(&self) -> OperatorState {
685 let mut data = Vec::new();
686
687 let num_partitions = self.partitions.len() as u64;
689 data.extend_from_slice(&num_partitions.to_le_bytes());
690
691 for (key, state) in &self.partitions {
693 let key_len = key.len() as u64;
695 data.extend_from_slice(&key_len.to_le_bytes());
696 data.extend_from_slice(key);
697
698 let history_len = state.lag_history.len() as u64;
700 data.extend_from_slice(&history_len.to_le_bytes());
701 for &val in &state.lag_history {
702 data.extend_from_slice(&val.to_le_bytes());
703 }
704
705 let pending_len = state.lead_pending.len() as u64;
707 data.extend_from_slice(&pending_len.to_le_bytes());
708 for pending in &state.lead_pending {
709 data.extend_from_slice(&pending.event.timestamp.to_le_bytes());
710 data.extend_from_slice(&(pending.remaining as u64).to_le_bytes());
711 data.extend_from_slice(&pending.value.to_le_bytes());
712 }
713 }
714
715 OperatorState {
716 operator_id: self.operator_id.clone(),
717 data,
718 }
719 }
720
721 #[allow(clippy::cast_possible_truncation)]
722 fn restore(&mut self, state: OperatorState) -> Result<(), OperatorError> {
723 if state.data.len() < 8 {
724 return Err(OperatorError::SerializationFailed(
725 "LagLead checkpoint data too short".to_string(),
726 ));
727 }
728
729 let mut offset = 0;
730
731 let num_partitions = u64::from_le_bytes(
732 state.data[offset..offset + 8]
733 .try_into()
734 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
735 ) as usize;
736 offset += 8;
737
738 self.partitions.clear();
739
740 for _ in 0..num_partitions {
741 if offset + 8 > state.data.len() {
742 return Err(OperatorError::SerializationFailed(
743 "LagLead checkpoint truncated".to_string(),
744 ));
745 }
746
747 let key_len = u64::from_le_bytes(
749 state.data[offset..offset + 8]
750 .try_into()
751 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
752 ) as usize;
753 offset += 8;
754
755 let partition_key = state.data[offset..offset + key_len].to_vec();
756 offset += key_len;
757
758 let history_len = u64::from_le_bytes(
760 state.data[offset..offset + 8]
761 .try_into()
762 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
763 ) as usize;
764 offset += 8;
765
766 let mut lag_history = VecDeque::with_capacity(history_len);
767 for _ in 0..history_len {
768 let val = f64::from_le_bytes(
769 state.data[offset..offset + 8]
770 .try_into()
771 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
772 );
773 offset += 8;
774 lag_history.push_back(val);
775 }
776
777 let pending_len = u64::from_le_bytes(
779 state.data[offset..offset + 8]
780 .try_into()
781 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
782 ) as usize;
783 offset += 8;
784
785 let mut lead_pending = VecDeque::with_capacity(pending_len);
786 for _ in 0..pending_len {
787 let timestamp = i64::from_le_bytes(
788 state.data[offset..offset + 8]
789 .try_into()
790 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
791 );
792 offset += 8;
793
794 let remaining = u64::from_le_bytes(
795 state.data[offset..offset + 8]
796 .try_into()
797 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
798 ) as usize;
799 offset += 8;
800
801 let value = f64::from_le_bytes(
802 state.data[offset..offset + 8]
803 .try_into()
804 .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
805 );
806 offset += 8;
807
808 let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
809 lead_pending.push_back(PendingLeadEvent {
810 event: Event::new(timestamp, batch),
811 remaining,
812 value,
813 });
814 }
815
816 let first_value_count = self
817 .functions
818 .iter()
819 .filter(|f| f.function_type == AnalyticFunctionKind::FirstValue)
820 .count();
821 let nth_value_count = self
822 .functions
823 .iter()
824 .filter(|f| f.function_type == AnalyticFunctionKind::NthValue)
825 .count();
826 self.partitions.insert(
827 partition_key,
828 PartitionState {
829 lag_history,
830 lead_pending,
831 first_values: vec![None; first_value_count],
832 nth_values: vec![None; nth_value_count],
833 event_count: 0,
834 },
835 );
836 }
837
838 Ok(())
839 }
840}
841
842#[cfg(test)]
843#[allow(clippy::float_cmp)]
844mod tests {
845 use super::*;
846 use crate::operator::TimerKey;
847 use crate::state::InMemoryStore;
848 use crate::time::{BoundedOutOfOrdernessGenerator, TimerService};
849
850 fn make_trade(timestamp: i64, symbol: &str, price: f64) -> Event {
851 let schema = Arc::new(Schema::new(vec![
852 Field::new("symbol", DataType::Utf8, false),
853 Field::new("price", DataType::Float64, false),
854 ]));
855 let batch = RecordBatch::try_new(
856 schema,
857 vec![
858 Arc::new(StringArray::from(vec![symbol])),
859 Arc::new(Float64Array::from(vec![price])),
860 ],
861 )
862 .unwrap();
863 Event::new(timestamp, batch)
864 }
865
    /// Builds an `OperatorContext` over the given services with event time,
    /// processing time and operator index all set to zero.
    fn create_test_context<'a>(
        timers: &'a mut TimerService,
        state: &'a mut dyn crate::state::StateStore,
        watermark_gen: &'a mut dyn crate::time::WatermarkGenerator,
    ) -> OperatorContext<'a> {
        OperatorContext {
            event_time: 0,
            processing_time: 0,
            timers,
            state,
            watermark_generator: watermark_gen,
            operator_index: 0,
        }
    }
880
881 fn lag_config(offset: usize) -> LagLeadConfig {
882 LagLeadConfig {
883 operator_id: "test_lag".to_string(),
884 functions: vec![LagLeadFunctionSpec {
885 function_type: AnalyticFunctionKind::Lag,
886 source_column: "price".to_string(),
887 offset,
888 default_value: None,
889 output_column: "prev_price".to_string(),
890 }],
891 partition_columns: vec!["symbol".to_string()],
892 max_partitions: 100,
893 }
894 }
895
896 fn lead_config(offset: usize) -> LagLeadConfig {
897 LagLeadConfig {
898 operator_id: "test_lead".to_string(),
899 functions: vec![LagLeadFunctionSpec {
900 function_type: AnalyticFunctionKind::Lead,
901 source_column: "price".to_string(),
902 offset,
903 default_value: Some(0.0),
904 output_column: "next_price".to_string(),
905 }],
906 partition_columns: vec!["symbol".to_string()],
907 max_partitions: 100,
908 }
909 }
910
911 #[test]
912 fn test_lag_first_event_returns_nan() {
913 let mut op = LagLeadOperator::new(lag_config(1));
914 let mut timers = TimerService::new();
915 let mut state = InMemoryStore::new();
916 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
917
918 let event = make_trade(1, "AAPL", 150.0);
919 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
920 let outputs = op.process(&event, &mut ctx);
921
922 assert_eq!(outputs.len(), 1);
923 if let Output::Event(e) = &outputs[0] {
924 let arr = e
925 .data
926 .column_by_name("prev_price")
927 .unwrap()
928 .as_any()
929 .downcast_ref::<Float64Array>()
930 .unwrap();
931 assert!(arr.value(0).is_nan());
932 } else {
933 panic!("Expected Event output");
934 }
935 }
936
937 #[test]
938 fn test_lag_second_event_returns_previous() {
939 let mut op = LagLeadOperator::new(lag_config(1));
940 let mut timers = TimerService::new();
941 let mut state = InMemoryStore::new();
942 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
943
944 let e1 = make_trade(1, "AAPL", 150.0);
945 let e2 = make_trade(2, "AAPL", 155.0);
946 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
947 op.process(&e1, &mut ctx);
948 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
949 let outputs = op.process(&e2, &mut ctx);
950
951 if let Output::Event(e) = &outputs[0] {
952 let arr = e
953 .data
954 .column_by_name("prev_price")
955 .unwrap()
956 .as_any()
957 .downcast_ref::<Float64Array>()
958 .unwrap();
959 assert_eq!(arr.value(0), 150.0);
960 }
961 }
962
963 #[test]
964 fn test_lag_with_default() {
965 let mut op = LagLeadOperator::new(LagLeadConfig {
966 operator_id: "test".to_string(),
967 functions: vec![LagLeadFunctionSpec {
968 function_type: AnalyticFunctionKind::Lag,
969 source_column: "price".to_string(),
970 offset: 1,
971 default_value: Some(-1.0),
972 output_column: "prev_price".to_string(),
973 }],
974 partition_columns: vec!["symbol".to_string()],
975 max_partitions: 100,
976 });
977 let mut timers = TimerService::new();
978 let mut state = InMemoryStore::new();
979 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
980
981 let event = make_trade(1, "AAPL", 150.0);
982 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
983 let outputs = op.process(&event, &mut ctx);
984
985 if let Output::Event(e) = &outputs[0] {
986 let arr = e
987 .data
988 .column_by_name("prev_price")
989 .unwrap()
990 .as_any()
991 .downcast_ref::<Float64Array>()
992 .unwrap();
993 assert_eq!(arr.value(0), -1.0);
994 }
995 }
996
997 #[test]
998 fn test_lag_offset_2() {
999 let mut op = LagLeadOperator::new(lag_config(2));
1000 let mut timers = TimerService::new();
1001 let mut state = InMemoryStore::new();
1002 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1003
1004 let events = [
1005 make_trade(1, "AAPL", 100.0),
1006 make_trade(2, "AAPL", 110.0),
1007 make_trade(3, "AAPL", 120.0),
1008 ];
1009
1010 for e in &events[..2] {
1011 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1012 op.process(e, &mut ctx);
1013 }
1014
1015 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1016 let outputs = op.process(&events[2], &mut ctx);
1017
1018 if let Output::Event(e) = &outputs[0] {
1019 let arr = e
1020 .data
1021 .column_by_name("prev_price")
1022 .unwrap()
1023 .as_any()
1024 .downcast_ref::<Float64Array>()
1025 .unwrap();
1026 assert_eq!(arr.value(0), 100.0); }
1028 }
1029
1030 #[test]
1031 fn test_lag_separate_partitions() {
1032 let mut op = LagLeadOperator::new(lag_config(1));
1033 let mut timers = TimerService::new();
1034 let mut state = InMemoryStore::new();
1035 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1036
1037 let a1 = make_trade(1, "AAPL", 150.0);
1039 let a2 = make_trade(3, "AAPL", 155.0);
1040 let g1 = make_trade(2, "GOOG", 2800.0);
1042
1043 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1044 op.process(&a1, &mut ctx);
1045 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1046 op.process(&g1, &mut ctx);
1047 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1048 let outputs = op.process(&a2, &mut ctx);
1049
1050 if let Output::Event(e) = &outputs[0] {
1052 let arr = e
1053 .data
1054 .column_by_name("prev_price")
1055 .unwrap()
1056 .as_any()
1057 .downcast_ref::<Float64Array>()
1058 .unwrap();
1059 assert_eq!(arr.value(0), 150.0);
1060 }
1061 assert_eq!(op.partition_count(), 2);
1062 }
1063
1064 #[test]
1065 fn test_lead_buffers_events() {
1066 let mut op = LagLeadOperator::new(lead_config(1));
1067 let mut timers = TimerService::new();
1068 let mut state = InMemoryStore::new();
1069 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1070
1071 let e1 = make_trade(1, "AAPL", 150.0);
1072 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1073 let outputs = op.process(&e1, &mut ctx);
1074
1075 assert!(outputs.is_empty());
1077 }
1078
1079 #[test]
1080 fn test_lead_resolves_on_next_event() {
1081 let mut op = LagLeadOperator::new(lead_config(1));
1082 let mut timers = TimerService::new();
1083 let mut state = InMemoryStore::new();
1084 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1085
1086 let e1 = make_trade(1, "AAPL", 150.0);
1087 let e2 = make_trade(2, "AAPL", 155.0);
1088
1089 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1090 op.process(&e1, &mut ctx);
1091 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1092 let outputs = op.process(&e2, &mut ctx);
1093
1094 assert_eq!(outputs.len(), 1);
1096 if let Output::Event(e) = &outputs[0] {
1097 let arr = e
1098 .data
1099 .column_by_name("next_price")
1100 .unwrap()
1101 .as_any()
1102 .downcast_ref::<Float64Array>()
1103 .unwrap();
1104 assert_eq!(arr.value(0), 155.0);
1105 }
1106 }
1107
1108 #[test]
1109 fn test_lead_flush_on_watermark() {
1110 let mut op = LagLeadOperator::new(lead_config(1));
1111 let mut timers = TimerService::new();
1112 let mut state = InMemoryStore::new();
1113 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1114
1115 let e1 = make_trade(1, "AAPL", 150.0);
1116 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1117 op.process(&e1, &mut ctx);
1118
1119 let timer = Timer {
1121 key: TimerKey::default(),
1122 timestamp: 100,
1123 };
1124 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1125 let outputs = op.on_timer(timer, &mut ctx);
1126
1127 assert_eq!(outputs.len(), 1);
1129 if let Output::Event(e) = &outputs[0] {
1130 let arr = e
1131 .data
1132 .column_by_name("next_price")
1133 .unwrap()
1134 .as_any()
1135 .downcast_ref::<Float64Array>()
1136 .unwrap();
1137 assert_eq!(arr.value(0), 0.0);
1138 }
1139 }
1140
1141 #[test]
1142 fn test_max_partitions() {
1143 let mut op = LagLeadOperator::new(LagLeadConfig {
1144 operator_id: "test".to_string(),
1145 functions: vec![LagLeadFunctionSpec {
1146 function_type: AnalyticFunctionKind::Lag,
1147 source_column: "price".to_string(),
1148 offset: 1,
1149 default_value: None,
1150 output_column: "prev_price".to_string(),
1151 }],
1152 partition_columns: vec!["symbol".to_string()],
1153 max_partitions: 2,
1154 });
1155 let mut timers = TimerService::new();
1156 let mut state = InMemoryStore::new();
1157 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1158
1159 let e1 = make_trade(1, "AAPL", 150.0);
1160 let e2 = make_trade(2, "GOOG", 2800.0);
1161 let e3 = make_trade(3, "MSFT", 300.0);
1162
1163 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1164 op.process(&e1, &mut ctx);
1165 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1166 op.process(&e2, &mut ctx);
1167 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1168 let outputs = op.process(&e3, &mut ctx);
1169
1170 assert!(outputs.is_empty()); assert_eq!(op.partition_count(), 2);
1172 }
1173
1174 #[test]
1175 fn test_checkpoint_restore() {
1176 let mut op = LagLeadOperator::new(lag_config(1));
1177 let mut timers = TimerService::new();
1178 let mut state = InMemoryStore::new();
1179 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1180
1181 let events = vec![
1182 make_trade(1, "AAPL", 100.0),
1183 make_trade(2, "AAPL", 110.0),
1184 make_trade(3, "GOOG", 2800.0),
1185 ];
1186 for e in &events {
1187 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1188 op.process(e, &mut ctx);
1189 }
1190
1191 let checkpoint = op.checkpoint();
1192 assert_eq!(checkpoint.operator_id, "test_lag");
1193
1194 let mut op2 = LagLeadOperator::new(lag_config(1));
1195 op2.restore(checkpoint).unwrap();
1196 assert_eq!(op2.partition_count(), 2);
1197 }
1198
1199 #[test]
1200 fn test_metrics() {
1201 let mut op = LagLeadOperator::new(lag_config(1));
1202 let mut timers = TimerService::new();
1203 let mut state = InMemoryStore::new();
1204 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1205
1206 let e1 = make_trade(1, "AAPL", 150.0);
1207 let e2 = make_trade(2, "AAPL", 155.0);
1208 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1209 op.process(&e1, &mut ctx);
1210 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1211 op.process(&e2, &mut ctx);
1212
1213 assert_eq!(op.metrics().events_processed, 2);
1214 assert_eq!(op.metrics().lag_lookups, 2);
1215 }
1216
1217 #[test]
1218 fn test_lead_separate_partitions() {
1219 let mut op = LagLeadOperator::new(lead_config(1));
1220 let mut timers = TimerService::new();
1221 let mut state = InMemoryStore::new();
1222 let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1223
1224 let a1 = make_trade(1, "AAPL", 150.0);
1225 let g1 = make_trade(2, "GOOG", 2800.0);
1226 let a2 = make_trade(3, "AAPL", 155.0);
1227
1228 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1229 op.process(&a1, &mut ctx);
1230 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1231 op.process(&g1, &mut ctx);
1232 let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1233 let outputs = op.process(&a2, &mut ctx);
1234
1235 assert_eq!(outputs.len(), 1);
1237 if let Output::Event(e) = &outputs[0] {
1238 let arr = e
1239 .data
1240 .column_by_name("next_price")
1241 .unwrap()
1242 .as_any()
1243 .downcast_ref::<Float64Array>()
1244 .unwrap();
1245 assert_eq!(arr.value(0), 155.0);
1246 }
1247 }
1248
    /// A freshly constructed operator has no partitions and zeroed metrics.
    #[test]
    fn test_empty_operator() {
        let op = LagLeadOperator::new(lag_config(1));
        assert_eq!(op.partition_count(), 0);
        assert_eq!(op.metrics().events_processed, 0);
    }
1255
/// `FirstValue` over `price`, partitioned by `symbol`, must report the first
/// price seen (100.0) on every subsequent event of the same partition.
#[test]
fn test_first_value() {
    let config = LagLeadConfig {
        operator_id: "first_val".to_string(),
        functions: vec![LagLeadFunctionSpec {
            function_type: AnalyticFunctionKind::FirstValue,
            source_column: "price".to_string(),
            offset: 0,
            default_value: None,
            output_column: "first_price".to_string(),
        }],
        partition_columns: vec!["symbol".to_string()],
        max_partitions: 100,
    };

    let mut op = LagLeadOperator::new(config);
    let mut timers = TimerService::new();
    let mut state = InMemoryStore::new();
    let mut wm = BoundedOutOfOrdernessGenerator::new(100);

    let events = [
        make_trade(1, "AAPL", 100.0),
        make_trade(2, "AAPL", 110.0),
        make_trade(3, "AAPL", 120.0),
    ];

    for e in &events {
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(e, &mut ctx);
        assert_eq!(outputs.len(), 1);
        if let Output::Event(out) = &outputs[0] {
            let arr = out
                .data
                .column_by_name("first_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            // Always the first AAPL price, regardless of the current event.
            assert_eq!(arr.value(0), 100.0);
        } else {
            // Fail loudly instead of silently skipping the value check.
            panic!("expected Output::Event carrying `first_price`");
        }
    }
}
1299
/// `LastValue` over `price`, partitioned by `symbol`, must report the price of
/// the event currently being processed (100.0, then 110.0, then 120.0).
#[test]
fn test_last_value() {
    let config = LagLeadConfig {
        operator_id: "last_val".to_string(),
        functions: vec![LagLeadFunctionSpec {
            function_type: AnalyticFunctionKind::LastValue,
            source_column: "price".to_string(),
            offset: 0,
            default_value: None,
            output_column: "last_price".to_string(),
        }],
        partition_columns: vec!["symbol".to_string()],
        max_partitions: 100,
    };

    let mut op = LagLeadOperator::new(config);
    let mut timers = TimerService::new();
    let mut state = InMemoryStore::new();
    let mut wm = BoundedOutOfOrdernessGenerator::new(100);

    let events = [
        make_trade(1, "AAPL", 100.0),
        make_trade(2, "AAPL", 110.0),
        make_trade(3, "AAPL", 120.0),
    ];

    // One expected `last_price` per input event, in order.
    let expected = [100.0, 110.0, 120.0];
    for (e, &exp) in events.iter().zip(expected.iter()) {
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(e, &mut ctx);
        assert_eq!(outputs.len(), 1);
        if let Output::Event(out) = &outputs[0] {
            let arr = out
                .data
                .column_by_name("last_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), exp);
        } else {
            // Fail loudly instead of silently skipping the value check.
            panic!("expected Output::Event carrying `last_price`");
        }
    }
}
1344
/// `NthValue` with `offset: 2` and a default of -1.0 must report the default
/// until the partition has seen two events, and the second price (110.0) from
/// then on.
#[test]
fn test_nth_value() {
    let config = LagLeadConfig {
        operator_id: "nth_val".to_string(),
        functions: vec![LagLeadFunctionSpec {
            function_type: AnalyticFunctionKind::NthValue,
            source_column: "price".to_string(),
            offset: 2,
            default_value: Some(-1.0),
            output_column: "second_price".to_string(),
        }],
        partition_columns: vec!["symbol".to_string()],
        max_partitions: 100,
    };

    let mut op = LagLeadOperator::new(config);
    let mut timers = TimerService::new();
    let mut state = InMemoryStore::new();
    let mut wm = BoundedOutOfOrdernessGenerator::new(100);

    let events = [
        make_trade(1, "AAPL", 100.0),
        make_trade(2, "AAPL", 110.0),
        make_trade(3, "AAPL", 120.0),
    ];

    // Expected `second_price` per input event:
    //   1st event -> default (-1.0), the second value does not exist yet
    //   2nd event -> 110.0, this event is the second value
    //   3rd event -> 110.0, the second value stays fixed
    let expected = [-1.0, 110.0, 110.0];
    for (e, &exp) in events.iter().zip(expected.iter()) {
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(e, &mut ctx);
        if let Output::Event(out) = &outputs[0] {
            let arr = out
                .data
                .column_by_name("second_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), exp);
        } else {
            // Fail loudly instead of silently skipping the value check.
            panic!("expected Output::Event carrying `second_price`");
        }
    }
}
1419}