1use std::sync::atomic::{AtomicBool, Ordering};
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use arrow_array::RecordBatch;
27use crossfire::{mpsc, AsyncRx};
28use laminar_connectors::checkpoint::SourceCheckpoint;
29use laminar_connectors::connector::{DeliveryGuarantee, SourceBatch};
30use laminar_connectors::error::ConnectorError;
31use laminar_core::alloc::{PriorityClass, PriorityGuard};
32use laminar_core::checkpoint::{CheckpointBarrier, CheckpointBarrierInjector};
33use rustc_hash::{FxHashMap, FxHashSet};
34
35use super::callback::{PipelineCallback, SourceRegistration};
36use super::config::PipelineConfig;
37use crate::error::DbError;
38
/// Receiver half of the bounded channel carrying batches/barriers from all source tasks.
type SourceMsgRx = AsyncRx<mpsc::Array<SourceMsg>>;
/// Receiver half of the bounded channel carrying pipeline control-plane messages.
type ControlMsgRx = AsyncRx<mpsc::Array<super::ControlMsg>>;
43
/// Message sent from a source polling task to the coordinator loop.
enum SourceMsg {
    /// A batch of records polled from a source, together with the source's
    /// checkpoint taken immediately after the poll.
    Batch {
        /// Index into `StreamingCoordinator::source_names` / `source_handles`.
        source_idx: usize,
        batch: RecordBatch,
        /// Source position snapshot covering this batch.
        checkpoint: SourceCheckpoint,
    },
    /// A checkpoint barrier emitted by a source after the coordinator
    /// triggered a barrier round.
    Barrier {
        source_idx: usize,
        barrier: CheckpointBarrier,
        /// Source position snapshot at the barrier.
        checkpoint: SourceCheckpoint,
    },
}
68
/// Per-source bookkeeping held by the coordinator for the lifetime of one
/// spawned source polling task.
struct SourceHandle {
    /// Source name (shared with `StreamingCoordinator::source_names`).
    name: Arc<str>,
    /// Notified to ask the polling task to stop.
    shutdown: Arc<tokio::sync::Notify>,
    /// Join handle for the polling task.
    join: tokio::task::JoinHandle<()>,
    /// Used to inject checkpoint barriers into this source's stream.
    barrier_injector: CheckpointBarrierInjector,
    /// Broadcasts the most recently committed checkpoint epoch to the task
    /// (`None` until the first commit).
    epoch_committed_tx: tokio::sync::watch::Sender<Option<u64>>,
}
80
/// Coordinates N source polling tasks with a single processing loop:
/// batching, barrier alignment for checkpoints, per-source offset tracking,
/// and shutdown draining.
pub struct StreamingCoordinator {
    config: PipelineConfig,
    /// Receives batches/barriers from all source tasks.
    rx: SourceMsgRx,
    /// One handle per spawned source task (indexed by `source_idx`).
    source_handles: Vec<SourceHandle>,
    /// Source names indexed by `source_idx` (parallel to `source_handles`).
    source_names: Vec<Arc<str>>,
    /// External shutdown signal for the run loop.
    shutdown: Arc<tokio::sync::Notify>,
    /// State of the barrier-alignment round currently in flight (if any).
    pending_barrier: PendingBarrier,
    /// Monotonic id handed to the next barrier round.
    next_checkpoint_id: u64,
    /// When the last checkpoint completed; drives the interval check.
    last_checkpoint: Instant,
    /// Per-connector flags a connector can raise to request an
    /// out-of-band checkpoint.
    checkpoint_request_flags: Vec<Arc<AtomicBool>>,
    /// Batches staged for the current cycle, keyed by source name.
    source_batches_buf: FxHashMap<Arc<str>, Vec<RecordBatch>>,
    /// Batches that arrived after a barrier within the same cycle; replayed
    /// at the start of the next cycle so checkpoints only cover
    /// pre-barrier data.
    post_barrier_buf: Vec<SourceMsg>,
    /// Raw (unfiltered) batches pending watermark extraction this cycle.
    pending_watermark_batches: Vec<(Arc<str>, RecordBatch)>,
    /// Sources whose barrier has already been seen in the current cycle.
    barrier_seen: FxHashSet<usize>,
    /// Last offset per source that a successful cycle has processed.
    committed_offsets: Vec<Option<SourceCheckpoint>>,
    /// Offsets staged by the current (not yet successful) cycle.
    pending_offsets: Vec<Option<SourceCheckpoint>>,
    /// Control-plane messages applied between cycles.
    control_rx: ControlMsgRx,
}
124
/// Tracks one in-flight barrier-alignment round across all sources.
struct PendingBarrier {
    /// Id of the checkpoint being aligned.
    checkpoint_id: u64,
    /// Number of sources that must align before the checkpoint runs.
    sources_total: usize,
    /// Indices of sources whose barrier has arrived so far.
    sources_aligned: FxHashSet<usize>,
    /// Offset snapshot captured at each source's barrier, keyed by name.
    source_checkpoints: FxHashMap<String, SourceCheckpoint>,
    /// When this round started (used for the alignment timeout).
    started_at: Instant,
    /// Whether an alignment round is currently in flight.
    active: bool,
}
134
135impl PendingBarrier {
136 fn new() -> Self {
137 Self {
138 checkpoint_id: 0,
139 sources_total: 0,
140 sources_aligned: FxHashSet::default(),
141 source_checkpoints: FxHashMap::default(),
142 started_at: Instant::now(),
143 active: false,
144 }
145 }
146
147 fn reset(&mut self, checkpoint_id: u64, sources_total: usize) {
148 self.checkpoint_id = checkpoint_id;
149 self.sources_total = sources_total;
150 self.sources_aligned.clear();
151 self.source_checkpoints.clear();
152 self.started_at = Instant::now();
153 self.active = true;
154 }
155}
156
/// How long the run loop waits for source messages before executing a
/// maintenance-only cycle (checkpointing, table polling, control messages).
const IDLE_TIMEOUT: Duration = Duration::from_millis(100);

/// Outcome of a source task's `select!` over shutdown / epoch commit / poll.
enum SourceWake {
    /// Shutdown was requested (or the watch channel closed).
    Shutdown,
    /// The coordinator committed the given checkpoint epoch.
    EpochCommitted(u64),
    /// `poll_batch` completed with this result.
    Polled(Result<Option<SourceBatch>, ConnectorError>),
}
166
167impl StreamingCoordinator {
168 fn broadcast_epoch_committed(&self, epoch: u64) {
170 for handle in &self.source_handles {
171 let _ = handle.epoch_committed_tx.send(Some(epoch));
172 }
173 }
174
    /// Validates `config`, opens (and optionally restores) every source
    /// connector, and spawns one polling task per source.
    ///
    /// Each task forwards `SourceMsg::Batch`/`Barrier` messages over a
    /// single bounded channel consumed by [`Self::run`]; channel
    /// backpressure therefore propagates to the polling tasks via `send`.
    ///
    /// # Errors
    /// Returns `DbError::Config` when the configuration is invalid for the
    /// requested delivery guarantee, when `channel_capacity` is zero, or
    /// when a connector fails to open.
    #[allow(clippy::too_many_lines)]
    pub async fn new(
        sources: Vec<SourceRegistration>,
        config: PipelineConfig,
        shutdown: Arc<tokio::sync::Notify>,
        control_rx: ControlMsgRx,
    ) -> Result<Self, DbError> {
        // Exactly-once is only sound when every source can replay from a
        // checkpoint and checkpointing is actually enabled.
        if config.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
            for src in &sources {
                if !src.supports_replay {
                    return Err(DbError::Config(format!(
                        "[LDB-5031] exactly-once requires source '{}' to support replay",
                        src.name
                    )));
                }
            }
            if config.checkpoint_interval.is_none() {
                return Err(DbError::Config(
                    "[LDB-5032] exactly-once requires checkpointing to be enabled".into(),
                ));
            }
        }

        if config.channel_capacity == 0 {
            return Err(DbError::Config(
                "[LDB-0010] channel_capacity must be > 0".into(),
            ));
        }

        // One bounded channel shared by all source tasks.
        let (tx, rx) = mpsc::bounded_async::<SourceMsg>(config.channel_capacity);

        let mut source_handles = Vec::with_capacity(sources.len());
        let mut source_names = Vec::with_capacity(sources.len());
        let mut checkpoint_request_flags = Vec::new();
        let mut committed_offsets = Vec::with_capacity(sources.len());

        for (idx, src) in sources.into_iter().enumerate() {
            // Connectors may expose a flag they raise to request an
            // out-of-band checkpoint.
            if let Some(flag) = src.connector.checkpoint_requested() {
                checkpoint_request_flags.push(flag);
            }

            let task_shutdown = Arc::new(tokio::sync::Notify::new());
            let task_shutdown_clone = Arc::clone(&task_shutdown);
            let task_tx = tx.clone();
            let max_poll = config.max_poll_records;
            let poll_interval = config.fallback_poll_interval;
            let src_name = src.name.clone();
            let restore = src.restore_checkpoint.clone();
            let mut connector = src.connector;
            let connector_config = src.config;

            // A connector that cannot open is a hard configuration error.
            connector
                .open(&connector_config)
                .await
                .map_err(|e| DbError::Config(format!("source '{src_name}' open failed: {e}")))?;

            // Restore is best-effort: on failure the source starts from the
            // beginning rather than aborting startup.
            if let Some(ref cp) = restore {
                if let Err(e) = connector.restore(cp).await {
                    tracing::warn!(
                        source = %src_name, error = %e,
                        "source restore failed, starting from beginning"
                    );
                }
            }

            // Seed committed offsets with the restored position (if any).
            committed_offsets.push(src.restore_checkpoint);

            let barrier_injector = CheckpointBarrierInjector::new();
            let barrier_handle = barrier_injector.handle();

            // Watch channel through which the coordinator announces
            // committed epochs back to this task.
            let (epoch_committed_tx, mut epoch_committed_rx) =
                tokio::sync::watch::channel::<Option<u64>>(None);

            let join = tokio::spawn(async move {
                // Barrier epoch counter local to this task; it advances each
                // time a barrier is emitted downstream.
                let mut epoch: u64 = 0;

                loop {
                    // `biased` so shutdown and epoch-commit notifications win
                    // over an always-ready poll result.
                    let wake = tokio::select! {
                        biased;
                        () = task_shutdown_clone.notified() => SourceWake::Shutdown,
                        r = epoch_committed_rx.changed() => match r {
                            Ok(()) => match *epoch_committed_rx.borrow_and_update() {
                                Some(e) => SourceWake::EpochCommitted(e),
                                // Initial None value — nothing committed yet.
                                None => continue,
                            },
                            // Sender dropped: the coordinator is gone.
                            Err(_) => SourceWake::Shutdown,
                        },
                        r = connector.poll_batch(max_poll) => SourceWake::Polled(r),
                    };

                    let poll_result = match wake {
                        SourceWake::Shutdown => break,
                        SourceWake::EpochCommitted(e) => {
                            // Let the connector acknowledge the commit;
                            // failures are logged but non-fatal.
                            if let Err(err) = connector.notify_epoch_committed(e).await {
                                tracing::warn!(
                                    source = %src_name,
                                    error = %err,
                                    epoch = e,
                                    "notify_epoch_committed failed",
                                );
                            }
                            continue;
                        }
                        SourceWake::Polled(r) => r,
                    };

                    match poll_result {
                        Ok(Some(batch)) => {
                            // Snapshot the source position right after the
                            // poll so the checkpoint covers this batch.
                            let cp = connector.checkpoint();
                            let msg = SourceMsg::Batch {
                                source_idx: idx,
                                batch: batch.records,
                                checkpoint: cp,
                            };
                            // Receiver gone — coordinator shut down.
                            if task_tx.send(msg).await.is_err() {
                                break;
                            }
                        }
                        Ok(None) => {
                            // No data: back off, but stay responsive to
                            // shutdown during the sleep.
                            tokio::select! {
                                biased;
                                () = task_shutdown_clone.notified() => break,
                                () = tokio::time::sleep(poll_interval) => {}
                            }
                        }
                        Err(e) if !e.is_transient() => {
                            tracing::error!(source = %src_name, error = %e, "terminal poll error");
                            break;
                        }
                        Err(e) => {
                            // Transient error: log and retry after a delay.
                            tracing::warn!(source = %src_name, error = %e, "poll error (retrying)");
                            tokio::select! {
                                biased;
                                () = task_shutdown_clone.notified() => break,
                                () = tokio::time::sleep(poll_interval) => {}
                            }
                        }
                    }

                    // Emit a barrier if the coordinator requested one.
                    // NOTE(review): presumably `poll(epoch)` compares the
                    // requested round against this task's epoch — confirm
                    // injector semantics.
                    if let Some(barrier) = barrier_handle.poll(epoch) {
                        epoch += 1;
                        let cp = connector.checkpoint();
                        let msg = SourceMsg::Barrier {
                            source_idx: idx,
                            barrier,
                            checkpoint: cp,
                        };
                        if task_tx.send(msg).await.is_err() {
                            break;
                        }
                    }
                }

                // Best-effort drain of data the connector still has before
                // closing. NOTE(review): relies on poll_batch eventually
                // returning Ok(None) — an always-ready source would delay
                // shutdown here; confirm.
                while let Ok(Some(batch)) = connector.poll_batch(max_poll).await {
                    let cp = connector.checkpoint();
                    let msg = SourceMsg::Batch {
                        source_idx: idx,
                        batch: batch.records,
                        checkpoint: cp,
                    };
                    if task_tx.send(msg).await.is_err() {
                        break;
                    }
                }

                if let Err(e) = connector.close().await {
                    tracing::warn!(source = %src_name, error = %e, "source close error");
                }
            });

            let arc_name: Arc<str> = Arc::from(src.name.as_str());
            source_handles.push(SourceHandle {
                name: Arc::clone(&arc_name),
                shutdown: task_shutdown,
                join,
                barrier_injector,
                epoch_committed_tx,
            });
            source_names.push(arc_name);
        }

        Ok(Self {
            config,
            rx,
            source_handles,
            source_names,
            shutdown,
            pending_barrier: PendingBarrier::new(),
            next_checkpoint_id: 1,
            last_checkpoint: Instant::now(),
            checkpoint_request_flags,
            source_batches_buf: FxHashMap::default(),
            post_barrier_buf: Vec::new(),
            pending_watermark_batches: Vec::new(),
            barrier_seen: FxHashSet::default(),
            // pending_offsets is kept parallel to committed_offsets.
            pending_offsets: vec![None; committed_offsets.len()],
            committed_offsets,
            control_rx,
        })
    }
401
    /// Main loop: drives processing cycles until `shutdown` fires or all
    /// senders disconnect, then drains remaining data, joins the source
    /// tasks, and (when checkpointing is enabled) takes a final forced
    /// checkpoint.
    ///
    /// Each cycle has two phases: an event phase (ingest + SQL execution)
    /// under `PriorityClass::EventProcessing`, then a background phase
    /// (barrier handling, checkpointing, table polling, control messages)
    /// under `PriorityClass::BackgroundIo`, each bounded by configured
    /// time budgets.
    #[allow(clippy::too_many_lines)]
    pub async fn run<C: PipelineCallback>(mut self, mut callback: C) {
        // Hard cap on extra messages drained per cycle so a continuously
        // full channel cannot starve the background phase.
        const MAX_DRAIN_PER_CYCLE: usize = 10_000;

        let batch_window = self.config.batch_window;
        let mut barriers_buf: Vec<(usize, CheckpointBarrier, SourceCheckpoint)> = Vec::new();

        loop {
            let msg = tokio::select! {
                biased;
                () = self.shutdown.notified() => break,
                msg = self.rx.recv() => {
                    match msg {
                        Ok(m) => {
                            // Optional batching window: linger briefly so
                            // more messages accumulate for this cycle.
                            if !batch_window.is_zero() {
                                tokio::time::sleep(batch_window).await;
                            }
                            Some(m)
                        }
                        // All source tasks hung up — treat like shutdown.
                        Err(_) => break,
                    }
                }
                // Idle: run a maintenance-only cycle.
                () = tokio::time::sleep(IDLE_TIMEOUT) => None,
            };

            let event_priority = PriorityGuard::enter(PriorityClass::EventProcessing);
            // Fresh per-cycle state; staged offsets from an aborted cycle
            // are dropped rather than committed.
            self.source_batches_buf.clear();
            self.barrier_seen.clear();
            self.discard_pending_offsets();
            barriers_buf.clear();
            let mut cycle_events: u64 = 0;
            let cycle_start = Instant::now();

            // Replay batches deferred from the previous cycle because they
            // arrived after that cycle's barrier.
            let deferred = std::mem::take(&mut self.post_barrier_buf);
            for deferred_msg in deferred {
                self.process_msg(
                    deferred_msg,
                    &mut callback,
                    &mut barriers_buf,
                    &mut cycle_events,
                );
            }

            let had_data = msg.is_some();
            if let Some(first_msg) = msg {
                self.process_msg(
                    first_msg,
                    &mut callback,
                    &mut barriers_buf,
                    &mut cycle_events,
                );
            }

            // Opportunistically drain more queued messages, bounded by a
            // count cap and a time budget — unless the operator graph is
            // backpressured, in which case feeding it more input would only
            // grow its queues.
            let mut drain_count = 0;
            let drain_budget_ns = self.config.drain_budget_ns;
            let backpressured = had_data && callback.is_backpressured();
            if backpressured {
                tracing::debug!("operator graph backpressured — skipping drain");
            }
            #[allow(clippy::cast_possible_truncation)]
            while !backpressured
                && drain_count < MAX_DRAIN_PER_CYCLE
                && (cycle_start.elapsed().as_nanos() as u64) < drain_budget_ns
            {
                match self.rx.try_recv() {
                    Ok(msg) => {
                        self.process_msg(msg, &mut callback, &mut barriers_buf, &mut cycle_events);
                        drain_count += 1;
                    }
                    Err(_) => break,
                }
            }

            // Watermarks advance from the raw (unfiltered) batches.
            for (name, batch) in self.pending_watermark_batches.drain(..) {
                callback.extract_watermark(&name, &batch);
            }

            if !self.source_batches_buf.is_empty() || callback.has_deferred_input() {
                let wm = callback.current_watermark();
                match callback.execute_cycle(&self.source_batches_buf, wm).await {
                    Ok(results) => {
                        // Offsets become durable only after a successful
                        // cycle (at-least-once on replay after failure).
                        self.commit_pending_offsets();
                        callback.update_mv_stores(&results);
                        callback.push_to_streams(&results);
                        callback.write_to_sinks(&results).await;
                    }
                    Err(e) => {
                        self.discard_pending_offsets();
                        tracing::warn!(error = %e, "[LDB-3020] SQL cycle error");
                    }
                }
                #[allow(clippy::cast_possible_truncation)]
                let elapsed_ns = cycle_start.elapsed().as_nanos() as u64;
                callback.record_cycle(cycle_events, 0, elapsed_ns);

                if elapsed_ns >= self.config.cycle_budget_ns {
                    tracing::debug!(
                        elapsed_ms = elapsed_ns / 1_000_000,
                        budget_ms = self.config.cycle_budget_ns / 1_000_000,
                        "cycle budget exceeded — proceeding to maintenance"
                    );
                }
            }

            #[allow(clippy::cast_possible_truncation)]
            let cycle_elapsed_ns = cycle_start.elapsed().as_nanos() as u64;

            // Background phase under lower scheduling priority.
            drop(event_priority);
            let _bg_priority = PriorityGuard::enter(PriorityClass::BackgroundIo);
            let bg_start = Instant::now();
            let bg_budget = self.config.background_budget_ns;

            // Barriers are handled after the cycle so the checkpoint covers
            // exactly the data committed above.
            for (source_idx, barrier, cp) in &barriers_buf {
                self.handle_barrier(*source_idx, barrier, cp, &mut callback)
                    .await;
            }

            #[allow(clippy::cast_possible_truncation)]
            if (bg_start.elapsed().as_nanos() as u64) < bg_budget {
                self.maybe_checkpoint(&mut callback).await;
            }

            #[allow(clippy::cast_possible_truncation)]
            let bg_elapsed = bg_start.elapsed().as_nanos() as u64;
            if cycle_elapsed_ns < self.config.cycle_budget_ns && bg_elapsed < bg_budget {
                callback.poll_tables().await;
            } else {
                tracing::debug!("skipping poll_tables (budget exhausted)");
            }

            // Apply queued control-plane changes between cycles.
            while let Ok(msg) = self.control_rx.try_recv() {
                callback.apply_control(msg);
            }

            // A source that never delivers its barrier would otherwise block
            // checkpointing forever — time the alignment round out.
            if self.pending_barrier.active
                && self.pending_barrier.started_at.elapsed() > self.config.barrier_alignment_timeout
            {
                tracing::warn!(
                    checkpoint_id = self.pending_barrier.checkpoint_id,
                    "Barrier alignment timeout — cancelling checkpoint"
                );
                self.pending_barrier.active = false;
            }
        }

        // --- Shutdown sequence ---
        // Ask every source task to stop polling.
        for handle in &self.source_handles {
            handle.shutdown.notify_one();
        }

        // First drain: consume everything currently in flight (including
        // deferred post-barrier batches) and run it through one more cycle.
        self.source_batches_buf.clear();
        self.barrier_seen.clear();
        self.discard_pending_offsets();
        let mut drain_barriers: Vec<(usize, CheckpointBarrier, SourceCheckpoint)> = Vec::new();
        let mut drain_events: u64 = 0;

        loop {
            let deferred = std::mem::take(&mut self.post_barrier_buf);
            let mut got_any = !deferred.is_empty();
            for msg in deferred {
                self.process_msg(msg, &mut callback, &mut drain_barriers, &mut drain_events);
            }
            while let Ok(msg) = self.rx.try_recv() {
                got_any = true;
                self.process_msg(msg, &mut callback, &mut drain_barriers, &mut drain_events);
            }
            if !got_any {
                break;
            }
        }

        for (name, batch) in self.pending_watermark_batches.drain(..) {
            callback.extract_watermark(&name, &batch);
        }
        if !self.source_batches_buf.is_empty() || callback.has_deferred_input() {
            let wm = callback.current_watermark();
            match callback.execute_cycle(&self.source_batches_buf, wm).await {
                Ok(results) => {
                    self.commit_pending_offsets();
                    callback.update_mv_stores(&results);
                    callback.push_to_streams(&results);
                    callback.write_to_sinks(&results).await;
                }
                Err(e) => {
                    self.discard_pending_offsets();
                    tracing::warn!(error = %e, "[LDB-3020] SQL cycle error during shutdown drain");
                }
            }
        }

        // Wait for the tasks to finish their own flush/close sequence.
        for handle in std::mem::take(&mut self.source_handles) {
            if let Err(e) = handle.join.await {
                tracing::warn!(source = %handle.name, error = ?e, "source task panicked");
            }
        }

        // Second drain: tasks may have flushed more data between the drain
        // above and their exit.
        self.source_batches_buf.clear();
        self.barrier_seen.clear();
        self.discard_pending_offsets();
        drain_barriers.clear();
        while let Ok(msg) = self.rx.try_recv() {
            self.process_msg(msg, &mut callback, &mut drain_barriers, &mut drain_events);
        }
        for (name, batch) in self.pending_watermark_batches.drain(..) {
            callback.extract_watermark(&name, &batch);
        }
        if !self.source_batches_buf.is_empty() || callback.has_deferred_input() {
            let wm = callback.current_watermark();
            match callback.execute_cycle(&self.source_batches_buf, wm).await {
                Ok(results) => {
                    self.commit_pending_offsets();
                    callback.update_mv_stores(&results);
                    callback.push_to_streams(&results);
                    callback.write_to_sinks(&results).await;
                }
                Err(e) => {
                    self.discard_pending_offsets();
                    tracing::warn!(error = %e, "[LDB-3020] SQL cycle error during final drain");
                }
            }
        }

        // Final forced checkpoint so the committed offsets survive restart.
        let checkpoint_enabled = self.config.checkpoint_interval.is_some();
        if checkpoint_enabled {
            let source_offsets: FxHashMap<String, SourceCheckpoint> = self
                .committed_offsets
                .iter()
                .enumerate()
                .filter_map(|(idx, cp)| {
                    cp.as_ref().and_then(|c| {
                        self.source_names
                            .get(idx)
                            .map(|name| (name.to_string(), c.clone()))
                    })
                })
                .collect();
            if let Some(epoch) = callback.maybe_checkpoint(true, source_offsets).await {
                tracing::info!(epoch, "final checkpoint completed before shutdown");
                self.broadcast_epoch_committed(epoch);
            }
        }
    }
694
695 fn process_msg(
701 &mut self,
702 msg: SourceMsg,
703 callback: &mut impl PipelineCallback,
704 barriers: &mut Vec<(usize, CheckpointBarrier, SourceCheckpoint)>,
705 cycle_events: &mut u64,
706 ) {
707 match msg {
708 SourceMsg::Batch {
709 source_idx,
710 batch,
711 checkpoint,
712 } => {
713 if self.barrier_seen.contains(&source_idx) {
716 self.post_barrier_buf.push(SourceMsg::Batch {
717 source_idx,
718 batch,
719 checkpoint,
720 });
721 return;
722 }
723
724 if source_idx < self.pending_offsets.len() {
726 self.pending_offsets[source_idx] = Some(checkpoint);
727 }
728
729 if let Some(name) = self.source_names.get(source_idx) {
730 #[allow(clippy::cast_possible_truncation)]
731 {
732 *cycle_events += batch.num_rows() as u64;
733 }
734 if let Some(filtered) = callback.filter_late_rows(name, &batch) {
739 self.source_batches_buf
740 .entry(Arc::clone(name))
741 .or_default()
742 .push(filtered);
743 }
744 self.pending_watermark_batches
745 .push((Arc::clone(name), batch));
746 }
747 }
748 SourceMsg::Barrier {
749 source_idx,
750 barrier,
751 checkpoint,
752 } => {
753 self.barrier_seen.insert(source_idx);
754 barriers.push((source_idx, barrier, checkpoint));
755 }
756 }
757 }
758
759 fn commit_pending_offsets(&mut self) {
762 for (i, pending) in self.pending_offsets.iter_mut().enumerate() {
763 if let Some(cp) = pending.take() {
764 self.committed_offsets[i] = Some(cp);
765 }
766 }
767 }
768
769 fn discard_pending_offsets(&mut self) {
771 for slot in &mut self.pending_offsets {
772 *slot = None;
773 }
774 }
775
    /// Records one source's barrier for the in-flight checkpoint and, once
    /// every source has aligned, performs the barrier checkpoint.
    ///
    /// Barriers are ignored when no alignment round is active or when their
    /// checkpoint id does not match the round in flight (stale barriers
    /// from a timed-out round).
    async fn handle_barrier(
        &mut self,
        source_idx: usize,
        barrier: &CheckpointBarrier,
        barrier_checkpoint: &SourceCheckpoint,
        callback: &mut impl PipelineCallback,
    ) {
        if !self.pending_barrier.active
            || barrier.checkpoint_id != self.pending_barrier.checkpoint_id
        {
            return;
        }

        // Keyed by source name: this map becomes the per-source offsets
        // captured at the barrier that are handed to the callback below.
        if let Some(name) = self.source_names.get(source_idx) {
            self.pending_barrier
                .source_checkpoints
                .insert(name.to_string(), barrier_checkpoint.clone());
        }

        self.pending_barrier.sources_aligned.insert(source_idx);

        // All sources aligned — the snapshot now covers a consistent cut.
        if self.pending_barrier.sources_aligned.len() >= self.pending_barrier.sources_total {
            let checkpoints = std::mem::take(&mut self.pending_barrier.source_checkpoints);
            if let Some(epoch) = callback.checkpoint_with_barrier(checkpoints).await {
                self.broadcast_epoch_committed(epoch);
            } else {
                tracing::warn!(
                    checkpoint_id = self.pending_barrier.checkpoint_id,
                    "barrier checkpoint failed, will retry on next interval"
                );
            }
            self.pending_barrier.active = false;
            // Reset the interval clock even on failure so a persistently
            // failing checkpoint cannot retrigger every cycle.
            self.last_checkpoint = Instant::now();
        }
    }
815
    /// Periodic checkpoint driver, invoked from the background phase of
    /// each cycle.
    ///
    /// While a barrier alignment is in flight, nothing new is started.
    /// Otherwise a barrier round is triggered when the checkpoint interval
    /// has elapsed or a connector raised its checkpoint-request flag; with
    /// no sources at all the checkpoint is taken directly.
    async fn maybe_checkpoint(&mut self, callback: &mut impl PipelineCallback) {
        if self.pending_barrier.active {
            return;
        }

        // Unconditional non-forced attempt with empty offsets, every cycle.
        // NOTE(review): this runs regardless of the interval check below and
        // does not update `last_checkpoint` — confirm the double call to
        // `callback.maybe_checkpoint` is intentional.
        let offsets = FxHashMap::default();
        if let Some(epoch) = callback.maybe_checkpoint(false, offsets).await {
            self.broadcast_epoch_committed(epoch);
        }

        let should_checkpoint = self
            .config
            .checkpoint_interval
            .is_some_and(|interval| self.last_checkpoint.elapsed() >= interval)
            || self
                .checkpoint_request_flags
                .iter()
                // swap(false) consumes the request exactly once.
                .any(|f| f.swap(false, Ordering::AcqRel));

        if !should_checkpoint {
            return;
        }

        // With no sources there is nothing to align — checkpoint directly.
        if self.source_handles.is_empty() {
            let offsets = FxHashMap::default();
            if let Some(epoch) = callback.maybe_checkpoint(false, offsets).await {
                self.last_checkpoint = Instant::now();
                self.broadcast_epoch_committed(epoch);
            }
            return;
        }

        // Start a barrier round: each source task will emit a barrier, and
        // the checkpoint completes when all of them align (handle_barrier).
        let checkpoint_id = self.next_checkpoint_id;
        self.next_checkpoint_id += 1;
        self.pending_barrier
            .reset(checkpoint_id, self.source_handles.len());

        for handle in &self.source_handles {
            handle.barrier_injector.trigger(checkpoint_id, 0);
        }
    }
871}
872
873#[cfg(test)]
874mod tests {
875 use super::*;
876 use arrow::array::Int64Array;
877 use arrow::datatypes::{DataType, Field, Schema};
878 use std::sync::Arc;
879
    /// Minimal `PipelineCallback` implementation used by coordinator tests.
    struct MockCallback {
        /// Number of `execute_cycle` invocations.
        cycle_count: u32,
        /// Results returned by each cycle, in order.
        results: Vec<FxHashMap<Arc<str>, Vec<RecordBatch>>>,
        /// Fake watermark: total row count seen via `extract_watermark`.
        watermark: i64,
        /// When set, flipped to `true` if `maybe_checkpoint(force = true)`
        /// is invoked — lets tests assert the final forced checkpoint ran.
        force_checkpoint_flag: Option<Arc<std::sync::atomic::AtomicBool>>,
    }
888
889 impl MockCallback {
890 fn new() -> Self {
891 Self {
892 cycle_count: 0,
893 results: Vec::new(),
894 watermark: 0,
895 force_checkpoint_flag: None,
896 }
897 }
898 }
899
    impl PipelineCallback for MockCallback {
        /// Echoes the input batches back as the cycle result, recording a
        /// copy for later assertions.
        async fn execute_cycle(
            &mut self,
            source_batches: &FxHashMap<Arc<str>, Vec<RecordBatch>>,
            _watermark: i64,
        ) -> Result<FxHashMap<Arc<str>, Vec<RecordBatch>>, String> {
            self.cycle_count += 1;
            let results: FxHashMap<Arc<str>, Vec<RecordBatch>> = source_batches
                .iter()
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect();
            self.results.push(results.clone());
            Ok(results)
        }

        fn push_to_streams(&self, _results: &FxHashMap<Arc<str>, Vec<RecordBatch>>) {}
        async fn write_to_sinks(&mut self, _results: &FxHashMap<Arc<str>, Vec<RecordBatch>>) {}

        /// Fake watermark: advances by the number of rows in the batch.
        fn extract_watermark(&mut self, _source_name: &str, batch: &RecordBatch) {
            #[allow(clippy::cast_possible_wrap)]
            {
                self.watermark += batch.num_rows() as i64;
            }
        }

        /// No late-row filtering in the mock: every batch passes through.
        fn filter_late_rows(&self, _source_name: &str, batch: &RecordBatch) -> Option<RecordBatch> {
            Some(batch.clone())
        }

        fn current_watermark(&self) -> i64 {
            self.watermark
        }

        /// Only forced checkpoints "succeed" (epoch 1); interval-driven ones
        /// report nothing. Forced calls are recorded via the shared flag.
        async fn maybe_checkpoint(
            &mut self,
            force: bool,
            _source_offsets: FxHashMap<String, SourceCheckpoint>,
        ) -> Option<u64> {
            if force {
                if let Some(ref flag) = self.force_checkpoint_flag {
                    flag.store(true, std::sync::atomic::Ordering::SeqCst);
                }
                Some(1)
            } else {
                None
            }
        }

        /// Barrier checkpoints always succeed with epoch 1.
        async fn checkpoint_with_barrier(
            &mut self,
            _source_checkpoints: FxHashMap<String, SourceCheckpoint>,
        ) -> Option<u64> {
            Some(1)
        }

        fn record_cycle(&self, _events: u64, _batches: u64, _elapsed_ns: u64) {}
        async fn poll_tables(&mut self) {}
        fn apply_control(&mut self, _msg: crate::pipeline::ControlMsg) {}
    }
961
    /// Smoke test: a batch sent directly into the coordinator's channel is
    /// consumed and the run loop exits cleanly once `shutdown` fires.
    #[tokio::test]
    async fn test_coordinator_direct_channel() {
        let shutdown = Arc::new(tokio::sync::Notify::new());
        let (tx, rx) = mpsc::bounded_async::<SourceMsg>(64);

        let (_control_tx, control_rx) = mpsc::bounded_async::<crate::pipeline::ControlMsg>(64);
        // No real source tasks: messages are injected via `tx` instead.
        let coordinator = StreamingCoordinator {
            config: PipelineConfig {
                batch_window: Duration::ZERO,
                max_poll_records: 1000,
                channel_capacity: 64,
                fallback_poll_interval: Duration::from_millis(10),
                checkpoint_interval: None,
                delivery_guarantee: DeliveryGuarantee::AtLeastOnce,
                barrier_alignment_timeout: Duration::from_secs(30),
                cycle_budget_ns: 10_000_000,
                drain_budget_ns: 1_000_000,
                query_budget_ns: 8_000_000,
                background_budget_ns: 5_000_000,
                max_input_buf_batches: 256,
                max_input_buf_bytes: None,
                backpressure_policy: crate::config::BackpressurePolicy::Backpressure,
            },
            rx,
            source_handles: Vec::new(),
            source_names: vec![Arc::from("test_source")],
            shutdown: Arc::clone(&shutdown),
            pending_barrier: PendingBarrier::new(),
            next_checkpoint_id: 1,
            last_checkpoint: Instant::now(),
            checkpoint_request_flags: Vec::new(),
            source_batches_buf: FxHashMap::default(),
            post_barrier_buf: Vec::new(),
            pending_watermark_batches: Vec::new(),
            barrier_seen: FxHashSet::default(),
            committed_offsets: vec![None],
            pending_offsets: vec![None],
            control_rx,
        };

        let callback = MockCallback::new();

        // Three-row batch for source index 0 ("test_source").
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap();
        tx.send(SourceMsg::Batch {
            source_idx: 0,
            batch,
            checkpoint: SourceCheckpoint::new(1),
        })
        .await
        .unwrap();

        // Let the coordinator process briefly, then trigger shutdown.
        let shutdown_clone = Arc::clone(&shutdown);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            shutdown_clone.notify_one();
        });

        // The test passes if run() terminates without panicking.
        coordinator.run(callback).await;
    }
1031
    /// With checkpointing enabled, shutdown must end with a forced
    /// (`force = true`) checkpoint — observed here via the shared flag the
    /// mock sets inside `maybe_checkpoint`.
    #[tokio::test]
    async fn test_final_checkpoint_on_shutdown() {
        let shutdown = Arc::new(tokio::sync::Notify::new());
        let (tx, rx) = mpsc::bounded_async::<SourceMsg>(64);
        let (_control_tx, control_rx) = mpsc::bounded_async::<crate::pipeline::ControlMsg>(64);

        let coordinator = StreamingCoordinator {
            config: PipelineConfig {
                batch_window: Duration::ZERO,
                max_poll_records: 1000,
                channel_capacity: 64,
                fallback_poll_interval: Duration::from_millis(10),
                // Interval far longer than the test, so only the final
                // forced checkpoint can fire.
                checkpoint_interval: Some(Duration::from_secs(60)),
                delivery_guarantee: DeliveryGuarantee::AtLeastOnce,
                barrier_alignment_timeout: Duration::from_secs(30),
                cycle_budget_ns: 10_000_000,
                drain_budget_ns: 1_000_000,
                query_budget_ns: 8_000_000,
                background_budget_ns: 5_000_000,
                max_input_buf_batches: 256,
                max_input_buf_bytes: None,
                backpressure_policy: crate::config::BackpressurePolicy::Backpressure,
            },
            rx,
            source_handles: Vec::new(),
            source_names: vec![Arc::from("test_source")],
            shutdown: Arc::clone(&shutdown),
            pending_barrier: PendingBarrier::new(),
            next_checkpoint_id: 1,
            last_checkpoint: Instant::now(),
            checkpoint_request_flags: Vec::new(),
            source_batches_buf: FxHashMap::default(),
            post_barrier_buf: Vec::new(),
            pending_watermark_batches: Vec::new(),
            barrier_seen: FxHashSet::default(),
            committed_offsets: vec![None],
            pending_offsets: vec![None],
            control_rx,
        };

        // Shared flag flipped by the mock when force=true is observed.
        let force_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let mut callback = MockCallback::new();
        callback.force_checkpoint_flag = Some(Arc::clone(&force_flag));

        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1]))]).unwrap();
        tx.send(SourceMsg::Batch {
            source_idx: 0,
            batch,
            checkpoint: SourceCheckpoint::new(1),
        })
        .await
        .unwrap();

        let shutdown_clone = Arc::clone(&shutdown);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            shutdown_clone.notify_one();
        });

        coordinator.run(callback).await;

        assert!(
            force_flag.load(std::sync::atomic::Ordering::SeqCst),
            "final checkpoint with force=true should have been called"
        );
    }
1102
    /// Unit-tests `process_msg` barrier alignment: batches arriving after a
    /// source's barrier must be deferred (not included in the cycle), and
    /// only pre-barrier offsets may be staged/committed.
    #[tokio::test]
    #[allow(clippy::too_many_lines, clippy::similar_names)]
    async fn test_barrier_excludes_post_barrier_data() {
        let shutdown = Arc::new(tokio::sync::Notify::new());
        let schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, false)]));

        let (_control_tx2, control_rx2) = mpsc::bounded_async::<crate::pipeline::ControlMsg>(64);
        let mut coordinator = StreamingCoordinator {
            config: PipelineConfig {
                batch_window: Duration::ZERO,
                max_poll_records: 1000,
                channel_capacity: 64,
                fallback_poll_interval: Duration::from_millis(10),
                checkpoint_interval: None,
                delivery_guarantee: DeliveryGuarantee::AtLeastOnce,
                barrier_alignment_timeout: Duration::from_secs(30),
                cycle_budget_ns: 10_000_000,
                drain_budget_ns: 1_000_000,
                query_budget_ns: 8_000_000,
                background_budget_ns: 5_000_000,
                max_input_buf_batches: 256,
                max_input_buf_bytes: None,
                backpressure_policy: crate::config::BackpressurePolicy::Backpressure,
            },
            // Only the receiver half is needed; the sender is dropped
            // immediately since messages are fed via process_msg directly.
            rx: mpsc::bounded_async::<SourceMsg>(64).1,
            source_handles: Vec::new(),
            source_names: vec![Arc::from("s0"), Arc::from("s1")],
            shutdown: Arc::clone(&shutdown),
            pending_barrier: PendingBarrier::new(),
            next_checkpoint_id: 1,
            last_checkpoint: Instant::now(),
            checkpoint_request_flags: Vec::new(),
            source_batches_buf: FxHashMap::default(),
            post_barrier_buf: Vec::new(),
            pending_watermark_batches: Vec::new(),
            barrier_seen: FxHashSet::default(),
            committed_offsets: vec![None, None],
            pending_offsets: vec![None, None],
            control_rx: control_rx2,
        };

        let mut callback = MockCallback::new();
        let mut barriers = Vec::new();
        let mut cycle_events: u64 = 0;

        // s0: one batch before the barrier, one after.
        let batch_1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![1]))],
        )
        .unwrap();
        let batch_2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![2]))],
        )
        .unwrap();
        let barrier = CheckpointBarrier::new(1, 0);

        coordinator.process_msg(
            SourceMsg::Batch {
                source_idx: 0,
                batch: batch_1,
                checkpoint: SourceCheckpoint::new(10),
            },
            &mut callback,
            &mut barriers,
            &mut cycle_events,
        );
        coordinator.process_msg(
            SourceMsg::Barrier {
                source_idx: 0,
                barrier,
                checkpoint: SourceCheckpoint::new(10),
            },
            &mut callback,
            &mut barriers,
            &mut cycle_events,
        );
        // Arrives after s0's barrier — must be deferred, and its offset
        // (epoch 20) must NOT be staged.
        coordinator.process_msg(
            SourceMsg::Batch {
                source_idx: 0,
                batch: batch_2,
                checkpoint: SourceCheckpoint::new(20),
            },
            &mut callback,
            &mut barriers,
            &mut cycle_events,
        );

        // s1: a batch followed by its barrier — nothing deferred.
        let batch_s1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![1]))],
        )
        .unwrap();
        coordinator.process_msg(
            SourceMsg::Batch {
                source_idx: 1,
                batch: batch_s1,
                checkpoint: SourceCheckpoint::new(5),
            },
            &mut callback,
            &mut barriers,
            &mut cycle_events,
        );
        coordinator.process_msg(
            SourceMsg::Barrier {
                source_idx: 1,
                barrier,
                checkpoint: SourceCheckpoint::new(5),
            },
            &mut callback,
            &mut barriers,
            &mut cycle_events,
        );

        // Only the pre-barrier batch is staged for s0.
        let s0_batches = coordinator.source_batches_buf.get("s0").unwrap();
        assert_eq!(
            s0_batches.len(),
            1,
            "s0 should have exactly 1 pre-barrier batch"
        );
        let s0_col = s0_batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(s0_col.value(0), 1, "s0 batch should contain ts=1");

        let s1_batches = coordinator.source_batches_buf.get("s1").unwrap();
        assert_eq!(s1_batches.len(), 1, "s1 should have exactly 1 batch");

        // The post-barrier batch was deferred for the next cycle.
        assert_eq!(
            coordinator.post_barrier_buf.len(),
            1,
            "post_barrier_buf should have 1 deferred batch"
        );

        // Staged offsets reflect only pre-barrier data; nothing is
        // committed until commit_pending_offsets runs.
        assert_eq!(
            coordinator.pending_offsets[0].as_ref().unwrap().epoch(),
            10,
            "s0 pending offset should be the pre-barrier batch"
        );
        assert_eq!(
            coordinator.pending_offsets[1].as_ref().unwrap().epoch(),
            5,
            "s1 pending offset should be epoch 5"
        );
        assert!(
            coordinator.committed_offsets[0].is_none(),
            "s0 committed offset should be None before execute_cycle"
        );
        assert!(
            coordinator.committed_offsets[1].is_none(),
            "s1 committed offset should be None before execute_cycle"
        );

        // Simulate a successful cycle: staged offsets become committed.
        coordinator.commit_pending_offsets();
        assert_eq!(
            coordinator.committed_offsets[0].as_ref().unwrap().epoch(),
            10,
            "s0 committed after cycle"
        );
        assert_eq!(
            coordinator.committed_offsets[1].as_ref().unwrap().epoch(),
            5,
            "s1 committed after cycle"
        );

        assert_eq!(barriers.len(), 2, "should have barriers from both sources");
    }
1283
    /// `MockCallback` wrapper that additionally counts cycles and records
    /// the number of events handed to each cycle, for backpressure tests.
    #[allow(clippy::disallowed_types)]
    struct BackpressuredCallback {
        inner: MockCallback,
        /// Total `execute_cycle` invocations (shared with the test body).
        cycle_count: Arc<std::sync::atomic::AtomicU32>,
        /// Row count seen by each cycle, in order (shared with the test body).
        events_per_cycle: Arc<std::sync::Mutex<Vec<u64>>>,
    }
1290
    impl BackpressuredCallback {
        /// Wraps a fresh `MockCallback`, sharing the given counters with the
        /// test body so cycle activity can be observed from outside.
        #[allow(clippy::disallowed_types)]
        fn new(
            cycle_count: Arc<std::sync::atomic::AtomicU32>,
            events_per_cycle: Arc<std::sync::Mutex<Vec<u64>>>,
        ) -> Self {
            Self {
                inner: MockCallback::new(),
                cycle_count,
                events_per_cycle,
            }
        }
    }
1304
1305 impl PipelineCallback for BackpressuredCallback {
1306 async fn execute_cycle(
1307 &mut self,
1308 source_batches: &FxHashMap<Arc<str>, Vec<RecordBatch>>,
1309 watermark: i64,
1310 ) -> Result<FxHashMap<Arc<str>, Vec<RecordBatch>>, String> {
1311 self.cycle_count
1312 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1313 let total: u64 = source_batches
1314 .values()
1315 .flat_map(|bs| bs.iter())
1316 .map(|b| b.num_rows() as u64)
1317 .sum();
1318 self.events_per_cycle.lock().unwrap().push(total);
1319 self.inner.execute_cycle(source_batches, watermark).await
1320 }
1321
1322 fn push_to_streams(&self, r: &FxHashMap<Arc<str>, Vec<RecordBatch>>) {
1323 self.inner.push_to_streams(r);
1324 }
1325 async fn write_to_sinks(&mut self, r: &FxHashMap<Arc<str>, Vec<RecordBatch>>) {
1326 self.inner.write_to_sinks(r).await;
1327 }
1328 fn extract_watermark(&mut self, s: &str, b: &RecordBatch) {
1329 self.inner.extract_watermark(s, b);
1330 }
1331 fn filter_late_rows(&self, s: &str, b: &RecordBatch) -> Option<RecordBatch> {
1332 self.inner.filter_late_rows(s, b)
1333 }
1334 fn current_watermark(&self) -> i64 {
1335 self.inner.current_watermark()
1336 }
1337 async fn maybe_checkpoint(
1338 &mut self,
1339 force: bool,
1340 offsets: FxHashMap<String, SourceCheckpoint>,
1341 ) -> Option<u64> {
1342 self.inner.maybe_checkpoint(force, offsets).await
1343 }
1344 async fn checkpoint_with_barrier(
1345 &mut self,
1346 cp: FxHashMap<String, SourceCheckpoint>,
1347 ) -> Option<u64> {
1348 self.inner.checkpoint_with_barrier(cp).await
1349 }
1350 fn record_cycle(&self, e: u64, b: u64, ns: u64) {
1351 self.inner.record_cycle(e, b, ns);
1352 }
1353 async fn poll_tables(&mut self) {
1354 self.inner.poll_tables().await;
1355 }
1356 fn apply_control(&mut self, msg: crate::pipeline::ControlMsg) {
1357 self.inner.apply_control(msg);
1358 }
1359
1360 fn is_backpressured(&self) -> bool {
1361 true }
1363 }
1364
1365 #[tokio::test]
1370 async fn test_drain_skip_under_backpressure() {
1371 let shutdown = Arc::new(tokio::sync::Notify::new());
1372 let (tx, rx) = mpsc::bounded_async::<SourceMsg>(64);
1373 let (_control_tx, control_rx) = mpsc::bounded_async::<crate::pipeline::ControlMsg>(64);
1374
1375 let coordinator = StreamingCoordinator {
1376 config: PipelineConfig {
1377 batch_window: Duration::ZERO,
1378 max_poll_records: 1000,
1379 channel_capacity: 64,
1380 fallback_poll_interval: Duration::from_millis(10),
1381 checkpoint_interval: None,
1382 delivery_guarantee: DeliveryGuarantee::AtLeastOnce,
1383 barrier_alignment_timeout: Duration::from_secs(30),
1384 cycle_budget_ns: 10_000_000,
1385 drain_budget_ns: 1_000_000,
1386 query_budget_ns: 8_000_000,
1387 background_budget_ns: 5_000_000,
1388 max_input_buf_batches: 256,
1389 max_input_buf_bytes: None,
1390 backpressure_policy: crate::config::BackpressurePolicy::Backpressure,
1391 },
1392 rx,
1393 source_handles: Vec::new(),
1394 source_names: vec![Arc::from("src")],
1395 shutdown: Arc::clone(&shutdown),
1396 pending_barrier: PendingBarrier::new(),
1397 next_checkpoint_id: 1,
1398 last_checkpoint: Instant::now(),
1399 checkpoint_request_flags: Vec::new(),
1400 source_batches_buf: FxHashMap::default(),
1401 post_barrier_buf: Vec::new(),
1402 pending_watermark_batches: Vec::new(),
1403 barrier_seen: FxHashSet::default(),
1404 committed_offsets: vec![None],
1405 pending_offsets: vec![None],
1406 control_rx,
1407 };
1408
1409 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1410
1411 for i in 0..5 {
1413 let batch = RecordBatch::try_new(
1414 Arc::clone(&schema),
1415 vec![Arc::new(Int64Array::from(vec![i]))],
1416 )
1417 .unwrap();
1418 tx.send(SourceMsg::Batch {
1419 source_idx: 0,
1420 batch,
1421 checkpoint: SourceCheckpoint::new(u64::try_from(i).unwrap()),
1422 })
1423 .await
1424 .unwrap();
1425 }
1426
1427 let shutdown_clone = Arc::clone(&shutdown);
1428 tokio::spawn(async move {
1429 tokio::time::sleep(Duration::from_millis(300)).await;
1430 shutdown_clone.notify_one();
1431 });
1432
1433 let cycle_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
1434 #[allow(clippy::disallowed_types)]
1435 let events_per_cycle = Arc::new(std::sync::Mutex::new(Vec::new()));
1436 let callback =
1437 BackpressuredCallback::new(Arc::clone(&cycle_count), Arc::clone(&events_per_cycle));
1438 coordinator.run(callback).await;
1439
1440 let cycles = cycle_count.load(std::sync::atomic::Ordering::SeqCst);
1441 let epc = events_per_cycle.lock().unwrap();
1442 let total: u64 = epc.iter().sum();
1443
1444 assert_eq!(total, 5, "all events must be processed, got {total}");
1446 assert!(
1450 cycles >= 5,
1451 "expected >=5 cycles (1 event each), got {cycles} cycles with events/cycle: {epc:?}"
1452 );
1453 for (i, &events) in epc.iter().enumerate() {
1455 assert!(
1456 events <= 1,
1457 "cycle {i} saw {events} events, expected <=1 under backpressure"
1458 );
1459 }
1460 }
1461}