1use std::sync::Arc;
7
8use arrow::array::{Array, TimestampMillisecondArray};
9use arrow::datatypes::{DataType, Schema};
10use arrow::record_batch::RecordBatch;
11
12use super::cast::cast_to_millis_array;
13
/// Selects which column of a record batch holds the event timestamp.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TimestampField {
    /// Look the column up by its field name in the schema.
    Name(String),
    /// Use the column at this zero-based position in the schema.
    Index(usize),
}
22
/// Strategy for reducing a batch's timestamp column to a single value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ExtractionMode {
    /// Use the timestamp of the first row (the default).
    #[default]
    First,
    /// Use the timestamp of the last row.
    Last,
    /// Use the largest non-null timestamp in the column.
    Max,
    /// Use the smallest non-null timestamp in the column.
    Min,
}
36
37#[derive(Debug, thiserror::Error)]
39pub enum EventTimeError {
40 #[error("Column not found: {0}")]
42 ColumnNotFound(String),
43
44 #[error("Column index {index} out of bounds (batch has {num_columns} columns)")]
46 IndexOutOfBounds {
47 index: usize,
49 num_columns: usize,
51 },
52
53 #[error("event-time column must be Timestamp(_), found {found}")]
55 IncompatibleType {
56 found: String,
58 },
59
60 #[error("Null timestamp at row {row}")]
62 NullTimestamp {
63 row: usize,
65 },
66
67 #[error("Cannot extract timestamp from empty batch")]
69 EmptyBatch,
70
71 #[error("Arrow cast to Timestamp(Millisecond) failed: {0}")]
73 CastFailed(String),
74}
75
76#[derive(Debug)]
78pub struct EventTimeExtractor {
79 field: TimestampField,
80 mode: ExtractionMode,
81 cached_index: Option<usize>,
82}
83
84impl EventTimeExtractor {
85 #[must_use]
88 pub fn from_column(name: &str) -> Self {
89 Self {
90 field: TimestampField::Name(name.to_string()),
91 mode: ExtractionMode::default(),
92 cached_index: None,
93 }
94 }
95
96 #[must_use]
99 pub fn from_index(index: usize) -> Self {
100 Self {
101 field: TimestampField::Index(index),
102 mode: ExtractionMode::default(),
103 cached_index: Some(index),
104 }
105 }
106
107 #[must_use]
109 pub fn with_mode(mut self, mode: ExtractionMode) -> Self {
110 self.mode = mode;
111 self
112 }
113
114 #[must_use]
116 pub fn mode(&self) -> ExtractionMode {
117 self.mode
118 }
119
120 pub fn validate_schema(&self, schema: &Schema) -> Result<(), EventTimeError> {
127 let (_, data_type) = self.resolve_column(schema)?;
128 if !matches!(data_type, DataType::Timestamp(_, _)) {
129 return Err(EventTimeError::IncompatibleType {
130 found: format!("{data_type:?}"),
131 });
132 }
133 Ok(())
134 }
135
136 pub fn extract(&mut self, batch: &RecordBatch) -> Result<i64, EventTimeError> {
144 if batch.num_rows() == 0 {
145 return Err(EventTimeError::EmptyBatch);
146 }
147 let index = self.get_column_index(batch.schema().as_ref())?;
148 let column = batch.column(index);
149 self.extract_from_column(column)
150 }
151
152 fn get_column_index(&mut self, schema: &Schema) -> Result<usize, EventTimeError> {
153 if let Some(idx) = self.cached_index {
154 if idx < schema.fields().len() {
155 return Ok(idx);
156 }
157 }
158 let (index, _) = self.resolve_column(schema)?;
159 self.cached_index = Some(index);
160 Ok(index)
161 }
162
163 fn resolve_column<'a>(
164 &self,
165 schema: &'a Schema,
166 ) -> Result<(usize, &'a DataType), EventTimeError> {
167 match &self.field {
168 TimestampField::Name(name) => {
169 let index = schema
170 .index_of(name)
171 .map_err(|_| EventTimeError::ColumnNotFound(name.clone()))?;
172 Ok((index, schema.field(index).data_type()))
173 }
174 TimestampField::Index(index) => {
175 if *index >= schema.fields().len() {
176 return Err(EventTimeError::IndexOutOfBounds {
177 index: *index,
178 num_columns: schema.fields().len(),
179 });
180 }
181 Ok((*index, schema.field(*index).data_type()))
182 }
183 }
184 }
185
186 fn extract_from_column(&self, column: &Arc<dyn Array>) -> Result<i64, EventTimeError> {
187 let ms = cast_to_millis_array(column.as_ref()).map_err(|e| {
188 if matches!(column.data_type(), DataType::Timestamp(_, _)) {
189 EventTimeError::CastFailed(e.0)
190 } else {
191 EventTimeError::IncompatibleType { found: e.0 }
192 }
193 })?;
194 match self.mode {
195 ExtractionMode::First => read_indexed(&ms, 0),
196 ExtractionMode::Last => read_indexed(&ms, ms.len() - 1),
197 ExtractionMode::Max => fold_non_null(&ms, i64::MIN, i64::max),
198 ExtractionMode::Min => fold_non_null(&ms, i64::MAX, i64::min),
199 }
200 }
201}
202
203fn read_indexed(arr: &TimestampMillisecondArray, idx: usize) -> Result<i64, EventTimeError> {
204 if arr.is_null(idx) {
205 Err(EventTimeError::NullTimestamp { row: idx })
206 } else {
207 Ok(arr.value(idx))
208 }
209}
210
211fn fold_non_null<F>(arr: &TimestampMillisecondArray, init: i64, f: F) -> Result<i64, EventTimeError>
212where
213 F: Fn(i64, i64) -> i64,
214{
215 let mut out = init;
216 let mut found = false;
217 for i in 0..arr.len() {
218 if !arr.is_null(i) {
219 found = true;
220 out = f(out, arr.value(i));
221 }
222 }
223 if found {
224 Ok(out)
225 } else {
226 Err(EventTimeError::NullTimestamp { row: 0 })
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        ArrayRef, Int64Builder, TimestampMicrosecondBuilder, TimestampMillisecondBuilder,
        TimestampNanosecondBuilder, TimestampSecondBuilder,
    };
    use arrow::datatypes::{Field, TimeUnit};
    use std::sync::Arc;

    /// Wraps `array` in a one-column batch whose nullable field is named `ts`.
    fn single_column_batch(array: ArrayRef, data_type: DataType) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("ts", data_type, true)]));
        RecordBatch::try_new(schema, vec![array]).unwrap()
    }

    /// Builds a `ts` batch of millisecond timestamps.
    fn make_ms_batch(values: &[Option<i64>]) -> RecordBatch {
        let mut b = TimestampMillisecondBuilder::new();
        for v in values {
            b.append_option(*v);
        }
        single_column_batch(
            Arc::new(b.finish()),
            DataType::Timestamp(TimeUnit::Millisecond, None),
        )
    }

    /// Builds a `ts` batch of nanosecond timestamps.
    fn make_ns_batch(values: &[Option<i64>]) -> RecordBatch {
        let mut b = TimestampNanosecondBuilder::new();
        for v in values {
            b.append_option(*v);
        }
        single_column_batch(
            Arc::new(b.finish()),
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
    }

    /// Builds a `ts` batch of microsecond timestamps.
    fn make_us_batch(values: &[Option<i64>]) -> RecordBatch {
        let mut b = TimestampMicrosecondBuilder::new();
        for v in values {
            b.append_option(*v);
        }
        single_column_batch(
            Arc::new(b.finish()),
            DataType::Timestamp(TimeUnit::Microsecond, None),
        )
    }

    /// Builds a `ts` batch of second-resolution timestamps.
    fn make_s_batch(values: &[Option<i64>]) -> RecordBatch {
        let mut b = TimestampSecondBuilder::new();
        for v in values {
            b.append_option(*v);
        }
        single_column_batch(
            Arc::new(b.finish()),
            DataType::Timestamp(TimeUnit::Second, None),
        )
    }

    #[test]
    fn test_extract_millis() {
        let batch = make_ms_batch(&[Some(1_705_312_200_000)]);
        let mut extractor = EventTimeExtractor::from_column("ts");
        assert_eq!(extractor.extract(&batch).unwrap(), 1_705_312_200_000);
    }

    #[test]
    fn test_extract_nanos_is_rescaled_to_millis() {
        let batch = make_ns_batch(&[Some(1_705_312_200_000_000_000)]);
        let mut extractor = EventTimeExtractor::from_column("ts");
        assert_eq!(extractor.extract(&batch).unwrap(), 1_705_312_200_000);
    }

    #[test]
    fn test_extract_micros_is_rescaled_to_millis() {
        let batch = make_us_batch(&[Some(1_705_312_200_000_000)]);
        let mut extractor = EventTimeExtractor::from_column("ts");
        assert_eq!(extractor.extract(&batch).unwrap(), 1_705_312_200_000);
    }

    #[test]
    fn test_extract_seconds_is_rescaled_to_millis() {
        let batch = make_s_batch(&[Some(1_705_312_200)]);
        let mut extractor = EventTimeExtractor::from_column("ts");
        assert_eq!(extractor.extract(&batch).unwrap(), 1_705_312_200_000);
    }

    #[test]
    fn test_mode_first() {
        let batch = make_ms_batch(&[Some(100), Some(200), Some(150)]);
        let mut extractor = EventTimeExtractor::from_column("ts").with_mode(ExtractionMode::First);
        assert_eq!(extractor.extract(&batch).unwrap(), 100);
    }

    #[test]
    fn test_mode_last() {
        let batch = make_ms_batch(&[Some(100), Some(200), Some(150)]);
        let mut extractor = EventTimeExtractor::from_column("ts").with_mode(ExtractionMode::Last);
        assert_eq!(extractor.extract(&batch).unwrap(), 150);
    }

    #[test]
    fn test_mode_max() {
        let batch = make_ms_batch(&[Some(100), Some(200), Some(150)]);
        let mut extractor = EventTimeExtractor::from_column("ts").with_mode(ExtractionMode::Max);
        assert_eq!(extractor.extract(&batch).unwrap(), 200);
    }

    #[test]
    fn test_mode_min() {
        let batch = make_ms_batch(&[Some(100), Some(200), Some(150)]);
        let mut extractor = EventTimeExtractor::from_column("ts").with_mode(ExtractionMode::Min);
        assert_eq!(extractor.extract(&batch).unwrap(), 100);
    }

    #[test]
    fn test_max_skips_nulls() {
        let batch = make_ms_batch(&[Some(100), None, Some(200), Some(150)]);
        let mut extractor = EventTimeExtractor::from_column("ts").with_mode(ExtractionMode::Max);
        assert_eq!(extractor.extract(&batch).unwrap(), 200);
    }

    #[test]
    fn test_min_skips_nulls() {
        let batch = make_ms_batch(&[None, Some(200), Some(100), None]);
        let mut extractor = EventTimeExtractor::from_column("ts").with_mode(ExtractionMode::Min);
        assert_eq!(extractor.extract(&batch).unwrap(), 100);
    }

    #[test]
    fn test_column_not_found() {
        let batch = make_ms_batch(&[Some(100)]);
        let mut extractor = EventTimeExtractor::from_column("missing");
        assert!(matches!(
            extractor.extract(&batch),
            Err(EventTimeError::ColumnNotFound(_))
        ));
    }

    #[test]
    fn test_index_out_of_bounds() {
        let batch = make_ms_batch(&[Some(100)]);
        let mut extractor = EventTimeExtractor::from_index(3);
        assert!(matches!(
            extractor.extract(&batch),
            Err(EventTimeError::IndexOutOfBounds {
                index: 3,
                num_columns: 1
            })
        ));
    }

    #[test]
    fn test_non_timestamp_column_is_rejected() {
        let mut b = Int64Builder::new();
        b.append_value(100);
        let batch = single_column_batch(Arc::new(b.finish()), DataType::Int64);

        let mut extractor = EventTimeExtractor::from_column("ts");
        assert!(matches!(
            extractor.extract(&batch),
            Err(EventTimeError::IncompatibleType { .. })
        ));
    }

    #[test]
    fn test_empty_batch() {
        let batch = make_ms_batch(&[]);
        let mut extractor = EventTimeExtractor::from_column("ts");
        assert!(matches!(
            extractor.extract(&batch),
            Err(EventTimeError::EmptyBatch)
        ));
    }

    #[test]
    fn test_null_first_row() {
        let batch = make_ms_batch(&[None, Some(100)]);
        let mut extractor = EventTimeExtractor::from_column("ts").with_mode(ExtractionMode::First);
        assert!(matches!(
            extractor.extract(&batch),
            Err(EventTimeError::NullTimestamp { row: 0 })
        ));
    }

    #[test]
    fn test_column_index_caching() {
        let batch = make_ms_batch(&[Some(100)]);
        let mut extractor = EventTimeExtractor::from_column("ts");

        assert!(extractor.cached_index.is_none());
        let _ = extractor.extract(&batch).unwrap();
        assert_eq!(extractor.cached_index, Some(0));
        assert_eq!(extractor.extract(&batch).unwrap(), 100);
    }

    #[test]
    fn test_from_index_skips_name_lookup() {
        let batch = make_ms_batch(&[Some(100)]);
        let mut extractor = EventTimeExtractor::from_index(0);
        assert_eq!(extractor.cached_index, Some(0));
        assert_eq!(extractor.extract(&batch).unwrap(), 100);
    }

    #[test]
    fn test_validate_schema_ok() {
        let schema = Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            true,
        )]);
        let extractor = EventTimeExtractor::from_column("ts");
        assert!(extractor.validate_schema(&schema).is_ok());
    }

    #[test]
    fn test_validate_schema_rejects_non_timestamp() {
        let schema = Schema::new(vec![Field::new("ts", DataType::Int64, true)]);
        let extractor = EventTimeExtractor::from_column("ts");
        assert!(matches!(
            extractor.validate_schema(&schema),
            Err(EventTimeError::IncompatibleType { .. })
        ));
    }
}