laminar_connectors/
reference.rs1#[cfg(any(test, feature = "testing"))]
15use std::collections::VecDeque;
16use std::time::Duration;
17
18use arrow_array::RecordBatch;
19
20use crate::checkpoint::SourceCheckpoint;
21use crate::error::ConnectorError;
22
/// Strategy for keeping a reference (lookup) table up to date.
///
/// NOTE(review): nothing in this file interprets these variants; the
/// per-variant descriptions below are read off the variant names and should
/// be confirmed against the code that consumes `RefreshMode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RefreshMode {
    /// Load the initial snapshot only (presumably no change tracking afterwards).
    SnapshotOnly,

    /// Load the initial snapshot, then follow incremental changes
    /// (presumably CDC delivered through `poll_changes`).
    SnapshotPlusCdc,

    /// Refresh on a fixed schedule.
    Periodic {
        /// Presumed spacing between refreshes — confirm with the scheduler.
        interval: Duration,
    },

    /// Refresh only when explicitly requested.
    Manual,
}
38
39#[async_trait::async_trait]
51pub trait ReferenceTableSource: Send {
52 async fn poll_snapshot(&mut self) -> Result<Option<RecordBatch>, ConnectorError>;
61
62 fn is_snapshot_complete(&self) -> bool;
64
65 async fn poll_changes(&mut self) -> Result<Option<RecordBatch>, ConnectorError>;
74
75 fn checkpoint(&self) -> SourceCheckpoint;
77
78 async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError>;
84
85 async fn close(&mut self) -> Result<(), ConnectorError>;
91}
92
/// In-memory [`ReferenceTableSource`] used by tests.
///
/// It serves pre-loaded batches and records which lifecycle calls were made.
/// Every field is public so tests can seed and inspect state directly.
#[cfg(any(test, feature = "testing"))]
pub struct MockReferenceTableSource {
    /// Batches handed out front-to-back by `poll_snapshot`.
    pub snapshot_batches: VecDeque<RecordBatch>,
    /// Batches handed out front-to-back by `poll_changes`.
    pub change_batches: VecDeque<RecordBatch>,
    /// Flipped to `true` once `poll_snapshot` finds `snapshot_batches` empty.
    pub snapshot_complete: bool,
    /// Flipped to `true` by `restore`.
    pub restored: bool,
    /// Flipped to `true` by `close`.
    pub closed: bool,
    /// Value returned (as a clone) by `checkpoint`.
    pub mock_checkpoint: SourceCheckpoint,
}
114
#[cfg(any(test, feature = "testing"))]
impl MockReferenceTableSource {
    /// Creates a mock from two batch lists.
    ///
    /// `poll_snapshot` yields `snapshot_batches` and `poll_changes` yields
    /// `change_batches`, each in the order given. All flags start `false`, and
    /// the reported checkpoint is `SourceCheckpoint::new(0)`.
    #[must_use]
    pub fn new(snapshot_batches: Vec<RecordBatch>, change_batches: Vec<RecordBatch>) -> Self {
        let snapshot_queue: VecDeque<RecordBatch> = snapshot_batches.into();
        let change_queue: VecDeque<RecordBatch> = change_batches.into();

        Self {
            snapshot_batches: snapshot_queue,
            change_batches: change_queue,
            snapshot_complete: false,
            restored: false,
            closed: false,
            mock_checkpoint: SourceCheckpoint::new(0),
        }
    }

    /// Creates a mock with no snapshot or change batches.
    #[must_use]
    pub fn empty() -> Self {
        Self::new(Vec::new(), Vec::new())
    }
}
136
#[cfg(any(test, feature = "testing"))]
#[async_trait::async_trait]
impl ReferenceTableSource for MockReferenceTableSource {
    /// Pops the next queued snapshot batch.
    ///
    /// When the queue is empty it records that the snapshot phase is complete
    /// and returns `Ok(None)`. It never returns an error.
    async fn poll_snapshot(&mut self) -> Result<Option<RecordBatch>, ConnectorError> {
        let next = self.snapshot_batches.pop_front();
        if next.is_none() {
            self.snapshot_complete = true;
        }
        Ok(next)
    }

    /// Reports the flag that `poll_snapshot` sets on exhaustion.
    fn is_snapshot_complete(&self) -> bool {
        self.snapshot_complete
    }

    /// Pops the next queued change batch.
    ///
    /// Returns `Ok(None)` once the queue is empty. It never returns an error.
    async fn poll_changes(&mut self) -> Result<Option<RecordBatch>, ConnectorError> {
        let next = self.change_batches.pop_front();
        Ok(next)
    }

    /// Hands back a copy of the pre-seeded `mock_checkpoint`.
    fn checkpoint(&self) -> SourceCheckpoint {
        SourceCheckpoint::clone(&self.mock_checkpoint)
    }

    /// Records that a restore happened.
    ///
    /// The checkpoint argument is ignored, so `checkpoint()` is unaffected.
    async fn restore(&mut self, _checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
        self.restored = true;
        Ok(())
    }

    /// Records that the source was closed. Repeated calls just keep the flag set.
    async fn close(&mut self) -> Result<(), ConnectorError> {
        self.closed = true;
        Ok(())
    }
}
171
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a one-column batch (`id: Int32`, non-nullable) containing `values`.
    fn test_batch(values: &[i32]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values.to_vec()))]).unwrap()
    }

    #[tokio::test]
    async fn test_mock_snapshot_exhaustion() {
        let mut src =
            MockReferenceTableSource::new(vec![test_batch(&[1, 2]), test_batch(&[3])], vec![]);

        assert!(!src.is_snapshot_complete());

        let b1 = src.poll_snapshot().await.unwrap().unwrap();
        assert_eq!(b1.num_rows(), 2);
        assert!(!src.is_snapshot_complete());

        let b2 = src.poll_snapshot().await.unwrap().unwrap();
        assert_eq!(b2.num_rows(), 1);
        assert!(!src.is_snapshot_complete());

        let none = src.poll_snapshot().await.unwrap();
        assert!(none.is_none());
        assert!(src.is_snapshot_complete());

        // Polling past the end keeps returning `None` and completion stays set.
        assert!(src.poll_snapshot().await.unwrap().is_none());
        assert!(src.is_snapshot_complete());
    }

    #[tokio::test]
    async fn test_mock_change_polling() {
        let mut src =
            MockReferenceTableSource::new(vec![], vec![test_batch(&[10]), test_batch(&[20, 30])]);

        // Drain the (empty) snapshot phase first.
        assert!(src.poll_snapshot().await.unwrap().is_none());

        let c1 = src.poll_changes().await.unwrap().unwrap();
        assert_eq!(c1.num_rows(), 1);

        let c2 = src.poll_changes().await.unwrap().unwrap();
        assert_eq!(c2.num_rows(), 2);

        assert!(src.poll_changes().await.unwrap().is_none());
    }

    // `checkpoint()` is synchronous, so this test needs no async runtime.
    #[test]
    fn test_mock_checkpoint_round_trip() {
        let mut cp = SourceCheckpoint::new(5);
        cp.set_offset("lsn", "0/ABCD");

        let mut src = MockReferenceTableSource::empty();
        // `cp` is not used again, so move it in instead of cloning.
        src.mock_checkpoint = cp;

        let returned = src.checkpoint();
        assert_eq!(returned.epoch(), 5);
        assert_eq!(returned.get_offset("lsn"), Some("0/ABCD"));
    }

    #[tokio::test]
    async fn test_mock_restore_sets_flag() {
        let mut src = MockReferenceTableSource::empty();
        assert!(!src.restored);

        let cp = SourceCheckpoint::new(1);
        src.restore(&cp).await.unwrap();
        assert!(src.restored);
    }

    #[tokio::test]
    async fn test_mock_close_idempotent() {
        let mut src = MockReferenceTableSource::empty();
        assert!(!src.closed);

        src.close().await.unwrap();
        assert!(src.closed);

        // A second close must also succeed and leave the flag set.
        src.close().await.unwrap();
        assert!(src.closed);
    }

    #[tokio::test]
    async fn test_trait_compliance_with_mock() {
        // Exercise the mock purely through the trait object.
        let mut src: Box<dyn ReferenceTableSource> = Box::new(MockReferenceTableSource::new(
            vec![test_batch(&[1])],
            vec![test_batch(&[2])],
        ));

        let batch = src.poll_snapshot().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert!(src.poll_snapshot().await.unwrap().is_none());
        assert!(src.is_snapshot_complete());

        let change = src.poll_changes().await.unwrap().unwrap();
        assert_eq!(change.num_rows(), 1);
        assert!(src.poll_changes().await.unwrap().is_none());

        // A freshly built mock reports the `SourceCheckpoint::new(0)` default.
        assert_eq!(src.checkpoint().epoch(), 0);

        let cp = SourceCheckpoint::new(0);
        src.restore(&cp).await.unwrap();

        src.close().await.unwrap();
    }
}