Skip to main content

laminar_core/time/
cast.rs

//! Cast any `Timestamp(_)` array to `TimestampMillisecondArray`.
2
3use std::fmt;
4
5use arrow::array::{Array, TimestampMillisecondArray};
6use arrow::datatypes::{DataType, TimeUnit};
7
/// Error returned when a column isn't a `Timestamp(_)` type or Arrow's
/// cast kernel fails.
///
/// The wrapped `String` is the human-readable message; it is exactly what
/// the [`Display`](std::fmt::Display) impl prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CastError(pub String);
12
13impl fmt::Display for CastError {
14    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15        write!(f, "{}", self.0)
16    }
17}
18
19impl std::error::Error for CastError {}
20
21/// Cast any `Timestamp(_)` array to `TimestampMillisecondArray`.
22///
23/// # Errors
24///
25/// [`CastError`] if `array` isn't a `Timestamp(_)` or the cast fails.
26pub fn cast_to_millis_array(array: &dyn Array) -> Result<TimestampMillisecondArray, CastError> {
27    if !matches!(array.data_type(), DataType::Timestamp(_, _)) {
28        return Err(CastError(format!(
29            "event-time column must be Timestamp(_), found {:?}",
30            array.data_type()
31        )));
32    }
33    let cast = arrow::compute::cast(array, &DataType::Timestamp(TimeUnit::Millisecond, None))
34        .map_err(|e| CastError(e.to_string()))?;
35    cast.as_any()
36        .downcast_ref::<TimestampMillisecondArray>()
37        .cloned()
38        .ok_or_else(|| CastError("arrow cast did not yield TimestampMillisecond".into()))
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Array, Int64Array, TimestampMicrosecondArray, TimestampMillisecondArray,
        TimestampNanosecondArray, TimestampSecondArray,
    };
    use arrow::datatypes::{DataType, TimeUnit};

    #[test]
    fn passthrough_when_already_millis() {
        let arr = TimestampMillisecondArray::from(vec![1, 2, 3]);
        let out = cast_to_millis_array(&arr).unwrap();
        assert_eq!(out.values(), &[1, 2, 3]);
    }

    #[test]
    fn rescales_nanos() {
        let arr = TimestampNanosecondArray::from(vec![1_500_000, 2_500_000]);
        let out = cast_to_millis_array(&arr).unwrap();
        assert_eq!(out.values(), &[1, 2]);
    }

    #[test]
    fn rescales_micros() {
        let arr = TimestampMicrosecondArray::from(vec![1_500, 2_500]);
        let out = cast_to_millis_array(&arr).unwrap();
        assert_eq!(out.values(), &[1, 2]);
    }

    #[test]
    fn rescales_seconds() {
        let arr = TimestampSecondArray::from(vec![1, 2]);
        let out = cast_to_millis_array(&arr).unwrap();
        assert_eq!(out.values(), &[1_000, 2_000]);
    }

    /// A timezone-annotated input is rescaled and comes back typed
    /// `Timestamp(Millisecond, None)`.
    #[test]
    fn drops_timezone_and_rescales() {
        let arr = TimestampNanosecondArray::from(vec![1_500_000, 2_500_000])
            .with_timezone("UTC".to_string());
        let out = cast_to_millis_array(&arr).unwrap();
        assert_eq!(out.values(), &[1, 2]);
        assert_eq!(
            out.data_type(),
            &DataType::Timestamp(TimeUnit::Millisecond, None)
        );
    }

    #[test]
    fn preserves_nulls() {
        let arr = TimestampSecondArray::from(vec![Some(1), None, Some(3)]);
        let out = cast_to_millis_array(&arr).unwrap();
        assert_eq!(out.len(), 3);
        assert!(out.is_null(1));
        assert_eq!(out.value(0), 1_000);
        assert_eq!(out.value(2), 3_000);
    }

    #[test]
    fn empty_input_yields_empty_output() {
        let arr = TimestampMicrosecondArray::from(Vec::<i64>::new());
        let out = cast_to_millis_array(&arr).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn non_timestamp_errors() {
        let arr = Int64Array::from(vec![1, 2]);
        let err = cast_to_millis_array(&arr).unwrap_err();
        // The message names the offending type to aid debugging.
        assert!(err.to_string().contains("Int64"));
    }
}