// laminar_core/operator/partitioned_topk.rs
//! # Partitioned Top-K Operator
//!
//! Per-group top-K supporting the `ROW_NUMBER() OVER (PARTITION BY ... ORDER BY ...) WHERE rn <= N`
//! pattern. Each partition key gets an independent top-K heap.
//!
//! ## Memory Bounds
//!
//! Memory is O(P * K) where P = distinct partitions, K = per-partition limit.
//! A `max_partitions` safety limit prevents unbounded partition growth.
//!
//! ## Emit Strategies
//!
//! Same as the global top-K: `OnUpdate`, `OnWatermark`, or `Periodic`.
//! Changelog records are emitted per-partition.
16use super::topk::{
17    encode_f64, encode_i64, encode_not_null, encode_null, encode_utf8, TopKEmitStrategy,
18    TopKSortColumn,
19};
20use super::window::ChangelogRecord;
21use super::{
22    Event, Operator, OperatorContext, OperatorError, OperatorState, Output, OutputVec, Timer,
23};
24use arrow_array::{Array, Float64Array, Int64Array, StringArray, TimestampMicrosecondArray};
25use arrow_schema::DataType;
26use rustc_hash::FxHashMap;
27
/// Configuration for a partition key column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionColumn {
    /// Column name in the event schema.
    pub column_name: String,
}

impl PartitionColumn {
    /// Creates a new partition column referring to `name` in the event schema.
    ///
    /// Accepts anything convertible into an owned `String` (`&str`, `String`, ...).
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let column_name = name.into();
        Self { column_name }
    }
}
44
/// An entry in a per-partition top-K list.
///
/// Entries are kept sorted ascending by `sort_key`, so index 0 is the
/// best-ranked row and the last element is the eviction candidate.
#[derive(Debug, Clone)]
struct PartitionEntry {
    /// Memcomparable sort key (see `extract_sort_key`); plain byte-wise
    /// comparison matches the configured ORDER BY semantics.
    sort_key: Vec<u8>,
    /// The original event, cloned into changelog records on emission.
    event: Event,
}
53
/// Partitioned top-K operator.
///
/// Maintains an independent sorted top-K list per partition key.
/// Supports the `ROW_NUMBER() OVER (PARTITION BY ... ORDER BY ...) WHERE rn <= N` pattern.
pub struct PartitionedTopKOperator {
    /// Operator identifier for checkpointing.
    operator_id: String,
    /// Number of top entries kept per partition (the K in top-K).
    k: usize,
    /// Partition key columns.
    partition_columns: Vec<PartitionColumn>,
    /// Sort column specifications.
    sort_columns: Vec<TopKSortColumn>,
    /// Per-partition entry lists, keyed by encoded partition key bytes;
    /// each list is kept sorted ascending by memcomparable sort key.
    partitions: FxHashMap<Vec<u8>, Vec<PartitionEntry>>,
    /// Emission strategy.
    emit_strategy: TopKEmitStrategy,
    /// Pending changelog records (buffered for OnWatermark/Periodic strategies).
    pending_changes: Vec<ChangelogRecord>,
    /// Monotonic sequence counter, incremented per accepted event and
    /// persisted in checkpoints.
    sequence_counter: u64,
    /// Maximum number of partitions (memory safety); events that would create
    /// partitions beyond this cap are dropped.
    max_partitions: usize,
    /// Cached column indices for partition columns — resolved on first event.
    cached_partition_indices: Vec<Option<usize>>,
    /// Cached column indices for sort columns — resolved on first event.
    cached_sort_indices: Vec<Option<usize>>,
}
82
83impl PartitionedTopKOperator {
84    /// Creates a new partitioned top-K operator.
85    #[must_use]
86    pub fn new(
87        operator_id: String,
88        k: usize,
89        partition_columns: Vec<PartitionColumn>,
90        sort_columns: Vec<TopKSortColumn>,
91        emit_strategy: TopKEmitStrategy,
92        max_partitions: usize,
93    ) -> Self {
94        let num_partition_cols = partition_columns.len();
95        let num_sort_cols = sort_columns.len();
96        Self {
97            operator_id,
98            k,
99            partition_columns,
100            sort_columns,
101            partitions: FxHashMap::default(),
102            emit_strategy,
103            pending_changes: Vec::new(),
104            sequence_counter: 0,
105            max_partitions,
106            cached_partition_indices: vec![None; num_partition_cols],
107            cached_sort_indices: vec![None; num_sort_cols],
108        }
109    }
110
    /// Returns the number of active partitions (distinct partition keys seen).
    #[must_use]
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }
116
117    /// Returns the total number of entries across all partitions.
118    #[must_use]
119    pub fn total_entries(&self) -> usize {
120        self.partitions.values().map(Vec::len).sum()
121    }
122
123    /// Returns the number of entries in a specific partition.
124    #[must_use]
125    pub fn partition_size(&self, partition_key: &[u8]) -> usize {
126        self.partitions.get(partition_key).map_or(0, Vec::len)
127    }
128
    /// Returns the number of buffered changelog records awaiting a flush
    /// (only populated under the `OnWatermark`/`Periodic` emit strategies).
    #[must_use]
    pub fn pending_changes_count(&self) -> usize {
        self.pending_changes.len()
    }
134
    /// Extracts the partition key from an event.
    ///
    /// Builds a byte key by concatenating, per configured partition column:
    /// a marker byte (0x00 = null/missing/unsupported, 0x01 = present)
    /// followed by the value of row 0 — raw little-endian bytes for Int64,
    /// the IEEE bit pattern for Float64 (equality only; ordering is not
    /// needed for a partition key), or UTF-8 bytes plus a 0x00 terminator
    /// for Utf8.
    ///
    /// Caches column indices on first call to avoid per-event schema lookups.
    ///
    /// NOTE(review): only row 0 is read and `is_null(0)` assumes the batch
    /// has at least one row — confirm callers never pass zero-row batches.
    /// NOTE(review): the encoding is not strictly prefix-free (missing
    /// columns and nulls both contribute a bare 0x00, and UTF-8 values may
    /// contain interior 0x00 bytes), so pathological inputs could collide
    /// into one partition — verify this is acceptable for trusted schemas.
    fn extract_partition_key(&mut self, event: &Event) -> Vec<u8> {
        let batch = &event.data;
        let mut key = Vec::new();

        for (i, col) in self.partition_columns.iter().enumerate() {
            // Resolve and memoize the column index on first use.
            let col_idx = if let Some(idx) = self.cached_partition_indices[i] {
                idx
            } else {
                let Ok(idx) = batch.schema().index_of(&col.column_name) else {
                    // Column absent from this schema: contribute a marker byte
                    // so the key still has a deterministic shape.
                    key.push(0x00);
                    continue;
                };
                self.cached_partition_indices[i] = Some(idx);
                idx
            };

            let array = batch.column(col_idx);

            if array.is_null(0) {
                key.push(0x00); // null marker
                continue;
            }

            key.push(0x01); // non-null marker

            match array.data_type() {
                DataType::Int64 => {
                    if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
                        key.extend_from_slice(&arr.value(0).to_le_bytes());
                    } else {
                        key.push(0x00); // downcast failed: degrade to marker
                    }
                }
                DataType::Utf8 => {
                    if let Some(arr) = array.as_any().downcast_ref::<StringArray>() {
                        let val = arr.value(0);
                        key.extend_from_slice(val.as_bytes());
                        key.push(0x00); // terminator so adjacent columns don't merge
                    } else {
                        key.push(0x00);
                    }
                }
                DataType::Float64 => {
                    if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
                        key.extend_from_slice(&arr.value(0).to_bits().to_le_bytes());
                    } else {
                        key.push(0x00);
                    }
                }
                _ => {
                    key.push(0x00); // unsupported type marker
                }
            }
        }

        key
    }
195
    /// Extracts a memcomparable sort key from an event.
    ///
    /// Encodes each configured sort column so that plain byte-wise comparison
    /// of the resulting keys matches the ORDER BY semantics; descending order
    /// and null placement are handled by the `encode_*` helpers shared with
    /// the global top-K operator. Missing columns, failed downcasts, and
    /// unsupported types are all encoded as nulls.
    ///
    /// Caches column indices on first call to avoid per-event schema lookups.
    ///
    /// NOTE(review): only row 0 is read and `is_null(0)` assumes a non-empty
    /// batch — confirm callers never pass zero-row batches.
    fn extract_sort_key(&mut self, event: &Event) -> Vec<u8> {
        let batch = &event.data;
        let mut key = Vec::new();

        for (i, col_spec) in self.sort_columns.iter().enumerate() {
            // Resolve and memoize the column index on first use.
            let col_idx = if let Some(idx) = self.cached_sort_indices[i] {
                idx
            } else {
                let Ok(idx) = batch.schema().index_of(&col_spec.column_name) else {
                    // Column absent: sort as null so key shape stays stable.
                    encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
                    continue;
                };
                self.cached_sort_indices[i] = Some(idx);
                idx
            };

            let array = batch.column(col_idx);

            if array.is_null(0) {
                encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
                continue;
            }

            match array.data_type() {
                DataType::Int64 => {
                    if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
                        encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
                        encode_i64(arr.value(0), col_spec.descending, &mut key);
                    } else {
                        encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
                    }
                }
                DataType::Float64 => {
                    if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
                        encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
                        encode_f64(arr.value(0), col_spec.descending, &mut key);
                    } else {
                        encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
                    }
                }
                DataType::Utf8 => {
                    if let Some(arr) = array.as_any().downcast_ref::<StringArray>() {
                        encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
                        encode_utf8(arr.value(0), col_spec.descending, &mut key);
                    } else {
                        encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
                    }
                }
                DataType::Timestamp(_, _) => {
                    // Timestamps sort by their raw i64 value (microseconds).
                    if let Some(arr) = array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
                        encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
                        encode_i64(arr.value(0), col_spec.descending, &mut key);
                    } else {
                        encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
                    }
                }
                _ => {
                    // Unsupported sort type: treat as null.
                    encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
                }
            }
        }

        key
    }
263
    /// Processes an event for a specific partition, returning changelog records.
    ///
    /// Inserts the event into the partition's sorted entry list if it ranks
    /// within the top K, emitting:
    /// - an Insert for the new row,
    /// - an UpdateBefore/UpdateAfter pair for each surviving row whose rank
    ///   shifted down, and
    /// - a Delete for the row evicted when the partition exceeds K.
    ///
    /// Returns an empty Vec when the event does not enter the top-K
    /// (the sequence counter is not advanced in that case).
    fn process_partition(
        &mut self,
        partition_key: Vec<u8>,
        event: &Event,
        emit_timestamp: i64,
    ) -> Vec<ChangelogRecord> {
        let sort_key = self.extract_sort_key(event);

        let entries = self.partitions.entry(partition_key).or_default();

        // Check if event enters this partition's top-K. Entries are sorted
        // ascending by sort key, so the last entry is the current worst;
        // ties with the worst entry lose.
        if entries.len() >= self.k {
            if let Some(worst) = entries.last() {
                if sort_key >= worst.sort_key {
                    return Vec::new(); // Doesn't enter top-K
                }
            }
        }

        // Find insertion position (binary search keeps the list sorted).
        let insert_pos = entries
            .binary_search_by(|entry| entry.sort_key.as_slice().cmp(&sort_key))
            .unwrap_or_else(|pos| pos);

        let new_entry = PartitionEntry {
            sort_key,
            event: event.clone(),
        };
        entries.insert(insert_pos, new_entry);

        let mut changes = Vec::new();

        // Generate insert changelog
        changes.push(ChangelogRecord::insert(event.clone(), emit_timestamp));

        // Generate rank change retractions for shifted entries. The
        // `take(len.min(k))` bound excludes the over-capacity tail entry that
        // is evicted (with its own Delete record) below.
        for entry in entries
            .iter()
            .take(entries.len().min(self.k))
            .skip(insert_pos + 1)
        {
            let shifted_event = &entry.event;
            let (before, after) = ChangelogRecord::update(
                shifted_event.clone(),
                shifted_event.clone(),
                emit_timestamp,
            );
            changes.push(before);
            changes.push(after);
        }

        // Evict worst entry if over capacity
        if entries.len() > self.k {
            let evicted = entries.pop().unwrap();
            changes.push(ChangelogRecord::delete(evicted.event, emit_timestamp));
        }

        self.sequence_counter += 1;
        changes
    }
325
326    /// Flushes pending changelog records as Output.
327    fn flush_pending(&mut self) -> OutputVec {
328        let mut outputs = OutputVec::new();
329        for record in self.pending_changes.drain(..) {
330            outputs.push(Output::Changelog(record));
331        }
332        outputs
333    }
334}
335
336impl Operator for PartitionedTopKOperator {
    /// Routes an event to its partition's top-K and emits or buffers the
    /// resulting changelog records depending on the emit strategy.
    fn process(&mut self, event: &Event, _ctx: &mut OperatorContext) -> OutputVec {
        let partition_key = self.extract_partition_key(event);

        // Check max partitions limit: an event that would create a brand-new
        // partition beyond the cap is dropped entirely (no output).
        if !self.partitions.contains_key(&partition_key)
            && self.partitions.len() >= self.max_partitions
        {
            // Reject: too many partitions
            return OutputVec::new();
        }

        let emit_timestamp = event.timestamp;
        let changes = self.process_partition(partition_key, event, emit_timestamp);

        match &self.emit_strategy {
            TopKEmitStrategy::OnUpdate => {
                // Emit every change immediately.
                let mut outputs = OutputVec::new();
                for record in changes {
                    outputs.push(Output::Changelog(record));
                }
                outputs
            }
            TopKEmitStrategy::OnWatermark | TopKEmitStrategy::Periodic(_) => {
                // Buffer until flushed. NOTE(review): `on_timer` below only
                // flushes for `Periodic`; the flush path for `OnWatermark` is
                // not visible in this file — confirm a watermark hook drains
                // `pending_changes`, otherwise they accumulate unboundedly.
                self.pending_changes.extend(changes);
                OutputVec::new()
            }
        }
    }
365
366    fn on_timer(&mut self, _timer: Timer, _ctx: &mut OperatorContext) -> OutputVec {
367        match &self.emit_strategy {
368            TopKEmitStrategy::Periodic(_) => self.flush_pending(),
369            _ => OutputVec::new(),
370        }
371    }
372
    /// Serializes partition state for checkpointing.
    ///
    /// Wire format (all integers little-endian):
    /// partition count (u64), sequence counter (u64), then per partition:
    /// key length (u64) + key bytes, entry count (u64), and per entry:
    /// sort-key length (u64) + sort-key bytes + event timestamp (i64).
    ///
    /// Note: event record batches are NOT serialized — only sort keys and
    /// timestamps — so `restore` rebuilds entries with empty payloads.
    fn checkpoint(&self) -> OperatorState {
        let mut data = Vec::new();

        // Write partition count
        let num_partitions = self.partitions.len() as u64;
        data.extend_from_slice(&num_partitions.to_le_bytes());

        // Write sequence counter
        data.extend_from_slice(&self.sequence_counter.to_le_bytes());

        // Write each partition
        for (key, entries) in &self.partitions {
            // Partition key length + bytes
            let key_len = key.len() as u64;
            data.extend_from_slice(&key_len.to_le_bytes());
            data.extend_from_slice(key);

            // Entry count
            let entry_count = entries.len() as u64;
            data.extend_from_slice(&entry_count.to_le_bytes());

            // Each entry: sort_key_len + sort_key + timestamp
            for entry in entries {
                let sk_len = entry.sort_key.len() as u64;
                data.extend_from_slice(&sk_len.to_le_bytes());
                data.extend_from_slice(&entry.sort_key);
                data.extend_from_slice(&entry.event.timestamp.to_le_bytes());
            }
        }

        OperatorState {
            operator_id: self.operator_id.clone(),
            data,
        }
    }
408
409    #[allow(clippy::cast_possible_truncation)] // Checkpoint wire format uses u64 for counts
410    fn restore(&mut self, state: OperatorState) -> Result<(), OperatorError> {
411        if state.data.len() < 16 {
412            return Err(OperatorError::SerializationFailed(
413                "PartitionedTopK checkpoint data too short".to_string(),
414            ));
415        }
416
417        let mut offset = 0;
418
419        let num_partitions = u64::from_le_bytes(
420            state.data[offset..offset + 8]
421                .try_into()
422                .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
423        ) as usize;
424        offset += 8;
425
426        self.sequence_counter = u64::from_le_bytes(
427            state.data[offset..offset + 8]
428                .try_into()
429                .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
430        );
431        offset += 8;
432
433        self.partitions.clear();
434
435        for _ in 0..num_partitions {
436            if offset + 8 > state.data.len() {
437                return Err(OperatorError::SerializationFailed(
438                    "PartitionedTopK checkpoint truncated".to_string(),
439                ));
440            }
441            let key_len = u64::from_le_bytes(
442                state.data[offset..offset + 8]
443                    .try_into()
444                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
445            ) as usize;
446            offset += 8;
447
448            if offset + key_len + 8 > state.data.len() {
449                return Err(OperatorError::SerializationFailed(
450                    "PartitionedTopK checkpoint truncated at key".to_string(),
451                ));
452            }
453            let partition_key = state.data[offset..offset + key_len].to_vec();
454            offset += key_len;
455
456            let entry_count = u64::from_le_bytes(
457                state.data[offset..offset + 8]
458                    .try_into()
459                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
460            ) as usize;
461            offset += 8;
462
463            let mut entries = Vec::with_capacity(entry_count);
464            for _ in 0..entry_count {
465                if offset + 8 > state.data.len() {
466                    return Err(OperatorError::SerializationFailed(
467                        "PartitionedTopK checkpoint truncated at entry".to_string(),
468                    ));
469                }
470                let sk_len = u64::from_le_bytes(
471                    state.data[offset..offset + 8]
472                        .try_into()
473                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
474                ) as usize;
475                offset += 8;
476
477                if offset + sk_len + 8 > state.data.len() {
478                    return Err(OperatorError::SerializationFailed(
479                        "PartitionedTopK checkpoint truncated at sort key".to_string(),
480                    ));
481                }
482                let sort_key = state.data[offset..offset + sk_len].to_vec();
483                offset += sk_len;
484
485                let timestamp = i64::from_le_bytes(
486                    state.data[offset..offset + 8]
487                        .try_into()
488                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
489                );
490                offset += 8;
491
492                let batch = arrow_array::RecordBatch::new_empty(std::sync::Arc::new(
493                    arrow_schema::Schema::empty(),
494                ));
495                entries.push(PartitionEntry {
496                    sort_key,
497                    event: Event::new(timestamp, batch),
498                });
499            }
500
501            self.partitions.insert(partition_key, entries);
502        }
503
504        Ok(())
505    }
506}
507
508#[cfg(test)]
509#[allow(clippy::uninlined_format_args)]
510#[allow(clippy::cast_precision_loss)]
511mod tests {
512    use super::super::window::CdcOperation;
513    use super::*;
514    use crate::state::InMemoryStore;
515    use crate::time::{BoundedOutOfOrdernessGenerator, TimerService};
516    use arrow_array::{Float64Array, Int64Array, RecordBatch, StringArray};
517    use arrow_schema::{DataType, Field, Schema};
518    use std::sync::Arc;
519
    /// Builds a single-row event with a Utf8 `category` column and a Float64
    /// `price` column at the given event timestamp.
    fn make_trade(timestamp: i64, category: &str, price: f64) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("category", DataType::Utf8, false),
            Field::new("price", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![category])),
                Arc::new(Float64Array::from(vec![price])),
            ],
        )
        .unwrap();
        Event::new(timestamp, batch)
    }
535
    /// Builds a single-row event with a Utf8 `category` column and an Int64
    /// `value` column at the given event timestamp.
    fn make_trade_int(timestamp: i64, category: &str, value: i64) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("category", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![category])),
                Arc::new(Int64Array::from(vec![value])),
            ],
        )
        .unwrap();
        Event::new(timestamp, batch)
    }
551
    /// Builds an `OperatorContext` borrowing the given timer service, state
    /// store, and watermark generator, with both clocks set to 0.
    fn create_test_context<'a>(
        timers: &'a mut TimerService,
        state: &'a mut dyn crate::state::StateStore,
        watermark_gen: &'a mut dyn crate::time::WatermarkGenerator,
    ) -> OperatorContext<'a> {
        OperatorContext {
            event_time: 0,
            processing_time: 0,
            timers,
            state,
            watermark_generator: watermark_gen,
            operator_index: 0,
        }
    }
566
    /// Builds the standard test operator: partition by `category`, order by
    /// `price` descending, `OnUpdate` emission.
    fn create_partitioned_topk(k: usize, max_partitions: usize) -> PartitionedTopKOperator {
        PartitionedTopKOperator::new(
            "test_ptopk".to_string(),
            k,
            vec![PartitionColumn::new("category")],
            vec![TopKSortColumn::descending("price")],
            TopKEmitStrategy::OnUpdate,
            max_partitions,
        )
    }
577
578    #[test]
579    fn test_partitioned_topk_single_partition() {
580        let mut op = create_partitioned_topk(3, 100);
581        let mut timers = TimerService::new();
582        let mut state = InMemoryStore::new();
583        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
584
585        let trades = vec![
586            make_trade(1, "A", 100.0),
587            make_trade(2, "A", 200.0),
588            make_trade(3, "A", 150.0),
589        ];
590
591        for trade in &trades {
592            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
593            op.process(trade, &mut ctx);
594        }
595
596        assert_eq!(op.partition_count(), 1);
597        assert_eq!(op.total_entries(), 3);
598    }
599
600    #[test]
601    fn test_partitioned_topk_multiple_partitions() {
602        let mut op = create_partitioned_topk(2, 100);
603        let mut timers = TimerService::new();
604        let mut state = InMemoryStore::new();
605        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
606
607        let trades = vec![
608            make_trade(1, "A", 100.0),
609            make_trade(2, "B", 200.0),
610            make_trade(3, "A", 150.0),
611            make_trade(4, "B", 250.0),
612        ];
613
614        for trade in &trades {
615            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
616            op.process(trade, &mut ctx);
617        }
618
619        assert_eq!(op.partition_count(), 2);
620        assert_eq!(op.total_entries(), 4);
621    }
622
    /// Inserting a better entry into a full partition evicts that partition's
    /// worst entry, producing both Insert and Delete changelog records.
    #[test]
    fn test_partitioned_topk_eviction_in_partition() {
        let mut op = create_partitioned_topk(2, 100);
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        // Fill partition "A" to capacity
        let e1 = make_trade(1, "A", 200.0);
        let e2 = make_trade(2, "A", 150.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e2, &mut ctx);

        // Better entry evicts worst in partition
        let e3 = make_trade(3, "A", 300.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&e3, &mut ctx);

        // Should evict price=150 and keep 300, 200
        assert_eq!(op.total_entries(), 2);
        assert!(!outputs.is_empty());

        // Verify we have Insert + Delete among outputs
        let mut has_insert = false;
        let mut has_delete = false;
        for output in &outputs {
            if let Output::Changelog(rec) = output {
                match rec.operation {
                    CdcOperation::Insert => has_insert = true,
                    CdcOperation::Delete => has_delete = true,
                    _ => {}
                }
            }
        }
        assert!(has_insert);
        assert!(has_delete);
    }
662
    /// Filling one partition does not cause evictions in another: partitions
    /// maintain fully independent top-K lists.
    #[test]
    fn test_partitioned_topk_no_cross_partition_eviction() {
        let mut op = create_partitioned_topk(2, 100);
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        // Fill partition "A" to capacity
        let e1 = make_trade(1, "A", 200.0);
        let e2 = make_trade(2, "A", 150.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e2, &mut ctx);

        // New entry in partition "B" does NOT evict from "A"
        let e3 = make_trade(3, "B", 50.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e3, &mut ctx);

        assert_eq!(op.partition_count(), 2);
        assert_eq!(op.total_entries(), 3); // 2 in A + 1 in B
    }
686
    /// The OnUpdate strategy emits changelog records immediately, starting
    /// with an Insert for the first accepted event.
    #[test]
    fn test_partitioned_topk_emit_on_update() {
        let mut op = create_partitioned_topk(3, 100);
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let trade = make_trade(1, "A", 100.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&trade, &mut ctx);

        // OnUpdate: should emit immediately
        assert!(!outputs.is_empty());
        match &outputs[0] {
            Output::Changelog(rec) => {
                assert_eq!(rec.operation, CdcOperation::Insert);
            }
            _ => panic!("Expected Changelog output"),
        }
    }
707
    /// The OnWatermark strategy buffers changelog records into
    /// `pending_changes` instead of emitting them from `process`.
    #[test]
    fn test_partitioned_topk_emit_on_watermark() {
        let mut op = PartitionedTopKOperator::new(
            "test_ptopk".to_string(),
            2,
            vec![PartitionColumn::new("category")],
            vec![TopKSortColumn::descending("price")],
            TopKEmitStrategy::OnWatermark,
            100,
        );

        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let trade = make_trade(1, "A", 100.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&trade, &mut ctx);

        // OnWatermark: should buffer, not emit
        assert!(outputs.is_empty());
        assert!(op.pending_changes_count() > 0);
    }
731
732    #[test]
733    fn test_partitioned_topk_empty_partition() {
734        let op = create_partitioned_topk(3, 100);
735        assert_eq!(op.partition_count(), 0);
736        assert_eq!(op.total_entries(), 0);
737    }
738
    /// Events that would create partitions beyond `max_partitions` are
    /// silently dropped without creating the new partition.
    #[test]
    fn test_partitioned_topk_max_partitions() {
        let mut op = create_partitioned_topk(2, 2); // max 2 partitions
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        // Create 2 partitions
        let e1 = make_trade(1, "A", 100.0);
        let e2 = make_trade(2, "B", 200.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e2, &mut ctx);

        // Third partition rejected
        let e3 = make_trade(3, "C", 300.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&e3, &mut ctx);

        assert!(outputs.is_empty());
        assert_eq!(op.partition_count(), 2);
    }
762
    /// With K = 1 only the single best entry per partition is retained;
    /// a better event replaces the previous one.
    #[test]
    fn test_partitioned_topk_k_equals_one() {
        let mut op = create_partitioned_topk(1, 100);
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade(1, "A", 100.0);
        let e2 = make_trade(2, "A", 200.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e2, &mut ctx);

        // Only best entry kept per partition
        assert_eq!(op.total_entries(), 1);
    }
780
    /// Partition keys combine all configured partition columns, so events
    /// differing in any key column land in different partitions.
    #[test]
    fn test_partitioned_topk_multi_column_partition_key() {
        let mut op = PartitionedTopKOperator::new(
            "test_ptopk".to_string(),
            3,
            vec![
                PartitionColumn::new("category"),
                PartitionColumn::new("value"),
            ],
            vec![TopKSortColumn::descending("value")],
            TopKEmitStrategy::OnUpdate,
            100,
        );

        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade_int(1, "A", 100);
        let e2 = make_trade_int(2, "A", 200);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e2, &mut ctx);

        // Each (category, value) combo is a unique partition
        assert_eq!(op.partition_count(), 2);
    }
809
810    #[test]
811    fn test_partitioned_topk_multi_column_sort() {
812        let mut op = PartitionedTopKOperator::new(
813            "test_ptopk".to_string(),
814            3,
815            vec![PartitionColumn::new("category")],
816            vec![TopKSortColumn::descending("price")],
817            TopKEmitStrategy::OnUpdate,
818            100,
819        );
820
821        let mut timers = TimerService::new();
822        let mut state = InMemoryStore::new();
823        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
824
825        let trades = vec![
826            make_trade(1, "A", 100.0),
827            make_trade(2, "A", 300.0),
828            make_trade(3, "A", 200.0),
829        ];
830
831        for trade in &trades {
832            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
833            op.process(trade, &mut ctx);
834        }
835
836        assert_eq!(op.total_entries(), 3);
837    }
838
839    #[test]
840    fn test_partitioned_topk_checkpoint_restore() {
841        let mut op = create_partitioned_topk(3, 100);
842        let mut timers = TimerService::new();
843        let mut state = InMemoryStore::new();
844        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
845
846        let trades = vec![
847            make_trade(1, "A", 100.0),
848            make_trade(2, "B", 200.0),
849            make_trade(3, "A", 150.0),
850        ];
851
852        for trade in &trades {
853            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
854            op.process(trade, &mut ctx);
855        }
856
857        let checkpoint = op.checkpoint();
858        assert_eq!(checkpoint.operator_id, "test_ptopk");
859
860        let mut op2 = create_partitioned_topk(3, 100);
861        op2.restore(checkpoint).unwrap();
862
863        assert_eq!(op2.partition_count(), 2);
864        assert_eq!(op2.total_entries(), 3);
865    }
866
867    #[test]
868    fn test_partitioned_topk_rank_changes() {
869        let mut op = create_partitioned_topk(3, 100);
870        let mut timers = TimerService::new();
871        let mut state = InMemoryStore::new();
872        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
873
874        // Insert two entries
875        let e1 = make_trade(1, "A", 100.0);
876        let e2 = make_trade(2, "A", 200.0);
877        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
878        op.process(&e1, &mut ctx);
879        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
880        op.process(&e2, &mut ctx);
881
882        // Insert between them causes rank change for price=100
883        let e3 = make_trade(3, "A", 150.0);
884        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
885        let outputs = op.process(&e3, &mut ctx);
886
887        // Should have Insert + UpdateBefore + UpdateAfter
888        let mut has_update_before = false;
889        let mut has_update_after = false;
890        for output in &outputs {
891            if let Output::Changelog(rec) = output {
892                match rec.operation {
893                    CdcOperation::UpdateBefore => has_update_before = true,
894                    CdcOperation::UpdateAfter => has_update_after = true,
895                    _ => {}
896                }
897            }
898        }
899        assert!(has_update_before);
900        assert!(has_update_after);
901    }
902
903    #[test]
904    fn test_partitioned_topk_row_number_pattern() {
905        // Simulates ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) WHERE rn <= 2
906        let mut op = create_partitioned_topk(2, 100);
907        let mut timers = TimerService::new();
908        let mut state = InMemoryStore::new();
909        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
910
911        let trades = vec![
912            make_trade(1, "tech", 100.0),
913            make_trade(2, "tech", 200.0),
914            make_trade(3, "tech", 150.0), // evicts 100
915            make_trade(4, "finance", 300.0),
916            make_trade(5, "finance", 250.0),
917        ];
918
919        for trade in &trades {
920            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
921            op.process(trade, &mut ctx);
922        }
923
924        assert_eq!(op.partition_count(), 2);
925        assert_eq!(op.total_entries(), 4); // 2 in tech + 2 in finance
926    }
927
928    #[test]
929    fn test_partitioned_topk_string_partition_key() {
930        let mut op = create_partitioned_topk(3, 100);
931        let mut timers = TimerService::new();
932        let mut state = InMemoryStore::new();
933        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
934
935        let trades = vec![
936            make_trade(1, "electronics", 100.0),
937            make_trade(2, "clothing", 200.0),
938            make_trade(3, "electronics", 150.0),
939        ];
940
941        for trade in &trades {
942            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
943            op.process(trade, &mut ctx);
944        }
945
946        assert_eq!(op.partition_count(), 2);
947    }
948
949    #[test]
950    fn test_partitioned_topk_null_partition_key() {
951        let mut op = create_partitioned_topk(3, 100);
952        let mut timers = TimerService::new();
953        let mut state = InMemoryStore::new();
954        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
955
956        // Create event with null category
957        let schema = Arc::new(Schema::new(vec![
958            Field::new("category", DataType::Utf8, true),
959            Field::new("price", DataType::Float64, false),
960        ]));
961        let batch = RecordBatch::try_new(
962            schema,
963            vec![
964                Arc::new(StringArray::new_null(1)),
965                Arc::new(Float64Array::from(vec![100.0])),
966            ],
967        )
968        .unwrap();
969        let null_event = Event::new(1, batch);
970
971        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
972        op.process(&null_event, &mut ctx);
973
974        // Null partition key should still create a partition
975        assert_eq!(op.partition_count(), 1);
976    }
977
978    #[test]
979    fn test_partitioned_topk_changelog_per_partition() {
980        let mut op = create_partitioned_topk(2, 100);
981        let mut timers = TimerService::new();
982        let mut state = InMemoryStore::new();
983        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
984
985        // Insert in partition A
986        let e1 = make_trade(1, "A", 100.0);
987        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
988        let out_a = op.process(&e1, &mut ctx);
989
990        // Insert in partition B
991        let e2 = make_trade(2, "B", 200.0);
992        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
993        let out_b = op.process(&e2, &mut ctx);
994
995        // Both should independently emit Insert changelog
996        assert_eq!(out_a.len(), 1);
997        assert_eq!(out_b.len(), 1);
998    }
999
1000    #[test]
1001    fn test_partitioned_topk_large_partitions() {
1002        let mut op = create_partitioned_topk(5, 1000);
1003        let mut timers = TimerService::new();
1004        let mut state = InMemoryStore::new();
1005        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1006
1007        // Create many partitions with a few entries each
1008        for i in 0..50 {
1009            let category = format!("cat_{}", i);
1010            for j in 0..3 {
1011                let trade = make_trade(i * 100 + j, &category, j as f64 * 10.0);
1012                let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1013                op.process(&trade, &mut ctx);
1014            }
1015        }
1016
1017        assert_eq!(op.partition_count(), 50);
1018        assert_eq!(op.total_entries(), 150); // 50 partitions * 3 entries each
1019    }
1020}