Skip to main content

laminar_connectors/postgres/
types.rs

//! Arrow to `PostgreSQL` type mapping for sink operations.
//!
//! Maps Apache Arrow `DataType` to `PostgreSQL` SQL type names for:
//! - UNNEST array casts in upsert queries
//! - CREATE TABLE DDL generation
//! - COPY BINARY column type declarations
7
8use arrow_schema::DataType;
9
10/// Maps an Arrow `DataType` to a `PostgreSQL` SQL type name for UNNEST casts.
11///
12/// Used in batched upsert queries:
13/// ```sql
14/// INSERT INTO t (col) SELECT * FROM UNNEST($1::int8[])
15/// ```
16///
17/// # Examples
18///
19/// ```rust,ignore
20/// use arrow_schema::DataType;
21/// assert_eq!(arrow_type_to_pg_sql(&DataType::Int64), "int8");
22/// assert_eq!(arrow_type_to_pg_sql(&DataType::Utf8), "text");
23/// ```
24#[must_use]
25#[allow(clippy::match_same_arms)]
26pub fn arrow_type_to_pg_sql(dt: &DataType) -> &'static str {
27    match dt {
28        DataType::Boolean => "bool",
29        DataType::Int8 | DataType::UInt8 => "int2",
30        DataType::Int16 | DataType::UInt16 => "int2",
31        DataType::Int32 => "int4",
32        DataType::UInt32 => "int8", // Widened: no unsigned in PG
33        DataType::Int64 | DataType::UInt64 => "int8",
34        DataType::Float16 | DataType::Float32 => "float4",
35        DataType::Float64 => "float8",
36        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => "numeric",
37        DataType::Utf8 | DataType::LargeUtf8 => "text",
38        DataType::Binary | DataType::LargeBinary => "bytea",
39        DataType::FixedSizeBinary(16) => "uuid",
40        DataType::FixedSizeBinary(_) => "bytea",
41        DataType::Date32 | DataType::Date64 => "date",
42        DataType::Time32(_) | DataType::Time64(_) => "time",
43        DataType::Timestamp(_, None) => "timestamp",
44        DataType::Timestamp(_, Some(_)) => "timestamptz",
45        DataType::Duration(_) => "interval",
46        _ => "text", // Fallback for complex/nested types
47    }
48}
49
50/// Maps an Arrow `DataType` to a `PostgreSQL` DDL type for CREATE TABLE.
51///
52/// Returns a type suitable for `CREATE TABLE` column definitions.
53/// More verbose than [`arrow_type_to_pg_sql`] where needed (e.g., `DOUBLE PRECISION`).
54///
55/// # Examples
56///
57/// ```rust,ignore
58/// use arrow_schema::DataType;
59/// assert_eq!(arrow_to_pg_ddl_type(&DataType::Int64), "BIGINT");
60/// assert_eq!(arrow_to_pg_ddl_type(&DataType::Float64), "DOUBLE PRECISION");
61/// ```
62#[must_use]
63#[allow(clippy::match_same_arms)]
64pub fn arrow_to_pg_ddl_type(dt: &DataType) -> &'static str {
65    match dt {
66        DataType::Boolean => "BOOLEAN",
67        DataType::Int8 | DataType::UInt8 => "SMALLINT",
68        DataType::Int16 | DataType::UInt16 => "SMALLINT",
69        DataType::Int32 => "INTEGER",
70        DataType::UInt32 => "BIGINT",
71        DataType::Int64 | DataType::UInt64 => "BIGINT",
72        DataType::Float16 | DataType::Float32 => "REAL",
73        DataType::Float64 => "DOUBLE PRECISION",
74        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => "NUMERIC",
75        DataType::Utf8 | DataType::LargeUtf8 => "TEXT",
76        DataType::Binary | DataType::LargeBinary => "BYTEA",
77        DataType::FixedSizeBinary(16) => "UUID",
78        DataType::FixedSizeBinary(_) => "BYTEA",
79        DataType::Date32 | DataType::Date64 => "DATE",
80        DataType::Time32(_) | DataType::Time64(_) => "TIME",
81        DataType::Timestamp(_, None) => "TIMESTAMP",
82        DataType::Timestamp(_, Some(_)) => "TIMESTAMPTZ",
83        DataType::Duration(_) => "INTERVAL",
84        _ => "TEXT",
85    }
86}
87
88/// Returns the `PostgreSQL` array type suffix for UNNEST cast expressions.
89///
90/// Combines [`arrow_type_to_pg_sql`] with `[]` for array parameter casting:
91/// `$1::int8[]`
92#[must_use]
93pub fn arrow_type_to_pg_array_cast(dt: &DataType, param_index: usize) -> String {
94    format!("${}::{}[]", param_index, arrow_type_to_pg_sql(dt))
95}
96
/// Converts an Arrow array column to a boxed `PostgreSQL` array parameter for UNNEST queries.
///
/// The column is materialised as a `Vec<Option<T>>` (Arrow nulls become `None`),
/// where `T` is a Rust type that implements `postgres_types::ToSql`. The returned
/// `Box` is passed as a bind parameter to `tokio_postgres::Client::execute`.
///
/// # Supported Types
///
/// - `Boolean` → `bool`
/// - `Int8`, `UInt8`, `Int16` → `i16`; `UInt16`, `Int32` → `i32`;
///   `UInt32`, `Int64` → `i64` (lossless widening)
/// - `UInt64` → `i64` (bit-reinterpreted: values above `i64::MAX` wrap negative)
/// - `Float32` → `f32`, `Float64` → `f64`
/// - `Utf8`, `LargeUtf8` → `String`
/// - `Binary`, `LargeBinary`, `FixedSizeBinary` → `Vec<u8>`
/// - `Date32` → `chrono::NaiveDate` (dates outside chrono's range become `None`)
/// - `Timestamp` (all units) → `chrono::NaiveDateTime` without a timezone,
///   `chrono::DateTime<Utc>` with one
///
/// Every other type falls back to its Arrow display string.
///
/// # Errors
///
/// Returns `ConnectorError::Internal` if the array cannot be downcast to the
/// Arrow array type implied by its `DataType`, or if the text formatter for
/// the fallback path cannot be constructed.
#[cfg(feature = "postgres-sink")]
#[allow(
    clippy::too_many_lines,
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::missing_panics_doc
)]
pub fn arrow_column_to_pg_array(
    array: &dyn arrow_array::Array,
) -> Result<Box<dyn postgres_types::ToSql + Sync + Send>, crate::error::ConnectorError> {
    use crate::error::ConnectorError;
    use arrow_array::{
        Array as _, BinaryArray, BooleanArray, Date32Array, FixedSizeBinaryArray, Float32Array,
        Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray,
        LargeStringArray, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
        TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
        UInt8Array,
    };
    use arrow_schema::TimeUnit;

    // Downcasts `$array` to the concrete `$arrow_ty`, returning early with an
    // internal error when the dynamic type does not match.
    macro_rules! downcast {
        ($array:expr, $arrow_ty:ty) => {
            $array
                .as_any()
                .downcast_ref::<$arrow_ty>()
                .ok_or_else(|| {
                    ConnectorError::Internal(format!(
                        "downcast to {} failed",
                        stringify!($arrow_ty)
                    ))
                })?
        };
    }

    // Builds the boxed `Vec<Option<$out_ty>>` for a column: nulls map to `None`,
    // every other slot is bound to `$v` and converted by `$conv`, which must
    // itself evaluate to `Option<$out_ty>`.
    macro_rules! extract {
        ($array:expr, $arrow_ty:ty => $out_ty:ty, |$v:ident| $conv:expr) => {{
            let arr = downcast!($array, $arrow_ty);
            let vals: Vec<Option<$out_ty>> = (0..arr.len())
                .map(|i| {
                    if arr.is_null(i) {
                        None
                    } else {
                        let $v = arr.value(i);
                        $conv
                    }
                })
                .collect();
            Ok(Box::new(vals))
        }};
    }

    // Collects a timestamp column of the given concrete array type into naive
    // datetimes, interpreting the raw values in `$unit`.
    macro_rules! naive_datetimes {
        ($array:expr, $arrow_ty:ty, $unit:expr) => {{
            let arr = downcast!($array, $arrow_ty);
            (0..arr.len())
                .map(|i| {
                    if arr.is_null(i) {
                        None
                    } else {
                        to_naive_datetime(arr.value(i), $unit)
                    }
                })
                .collect::<Vec<Option<chrono::NaiveDateTime>>>()
        }};
    }

    match array.data_type() {
        DataType::Boolean => extract!(array, BooleanArray => bool, |v| Some(v)),
        DataType::Int8 => extract!(array, Int8Array => i16, |v| Some(i16::from(v))),
        DataType::UInt8 => extract!(array, UInt8Array => i16, |v| Some(i16::from(v))),
        DataType::Int16 => extract!(array, Int16Array => i16, |v| Some(v)),
        DataType::UInt16 => extract!(array, UInt16Array => i32, |v| Some(i32::from(v))),
        DataType::Int32 => extract!(array, Int32Array => i32, |v| Some(v)),
        DataType::UInt32 => extract!(array, UInt32Array => i64, |v| Some(i64::from(v))),
        DataType::Int64 => extract!(array, Int64Array => i64, |v| Some(v)),
        // PG has no unsigned 64-bit; cast to i64 (wraps for > i64::MAX).
        DataType::UInt64 => extract!(array, UInt64Array => i64, |v| Some(v as i64)),
        DataType::Float32 => extract!(array, Float32Array => f32, |v| Some(v)),
        DataType::Float64 => extract!(array, Float64Array => f64, |v| Some(v)),
        DataType::Utf8 => extract!(array, StringArray => String, |v| Some(v.to_owned())),
        DataType::LargeUtf8 => {
            extract!(array, LargeStringArray => String, |v| Some(v.to_owned()))
        }
        // Each binary layout has its own concrete array type; downcasting all of
        // them to `BinaryArray` only succeeds for plain `Binary` columns.
        DataType::Binary => extract!(array, BinaryArray => Vec<u8>, |v| Some(v.to_vec())),
        DataType::LargeBinary => {
            extract!(array, LargeBinaryArray => Vec<u8>, |v| Some(v.to_vec()))
        }
        // NOTE(review): `arrow_type_to_pg_sql` casts 16-byte columns as `uuid[]`,
        // while the values here are bound as byte vectors — confirm the binding
        // is accepted for UUID columns.
        DataType::FixedSizeBinary(_) => {
            extract!(array, FixedSizeBinaryArray => Vec<u8>, |v| Some(v.to_vec()))
        }
        DataType::Date32 => {
            // Date32 stores days relative to the Unix epoch.
            let epoch =
                chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("1970-01-01 is a valid date");
            extract!(array, Date32Array => chrono::NaiveDate, |v| {
                let days = i64::from(v);
                let offset = chrono::Days::new(days.unsigned_abs());
                // Checked arithmetic: an out-of-range date yields `None`.
                if days >= 0 {
                    epoch.checked_add_days(offset)
                } else {
                    epoch.checked_sub_days(offset)
                }
            })
        }
        DataType::Timestamp(unit, tz) => {
            // The unit in the data type determines the concrete array type.
            let vals = match unit {
                TimeUnit::Second => naive_datetimes!(array, TimestampSecondArray, unit),
                TimeUnit::Millisecond => {
                    naive_datetimes!(array, TimestampMillisecondArray, unit)
                }
                TimeUnit::Microsecond => {
                    naive_datetimes!(array, TimestampMicrosecondArray, unit)
                }
                TimeUnit::Nanosecond => naive_datetimes!(array, TimestampNanosecondArray, unit),
            };

            if tz.is_some() {
                // TIMESTAMPTZ: wrap as DateTime<Utc>
                let tz_vals: Vec<Option<chrono::DateTime<chrono::Utc>>> = vals
                    .into_iter()
                    .map(|opt| opt.map(|ndt| ndt.and_utc()))
                    .collect();
                Ok(Box::new(tz_vals))
            } else {
                Ok(Box::new(vals))
            }
        }
        // Fallback: convert to string representation
        other => {
            let formatter = arrow_cast::display::ArrayFormatter::try_new(
                array,
                &arrow_cast::display::FormatOptions::default(),
            )
            .map_err(|e| ConnectorError::Internal(format!("arrow format error: {e}")))?;
            let vals: Vec<Option<String>> = (0..array.len())
                .map(|i| {
                    if array.is_null(i) {
                        None
                    } else {
                        Some(formatter.value(i).to_string())
                    }
                })
                .collect();
            tracing::debug!(
                data_type = ?other,
                "falling back to text conversion for unsupported Arrow type"
            );
            Ok(Box::new(vals))
        }
    }
}
379
/// Converts a raw timestamp value to [`chrono::NaiveDateTime`] based on the Arrow `TimeUnit`.
///
/// `value` counts `unit`s relative to the Unix epoch and may be negative. It is
/// split into whole seconds and a non-negative sub-second remainder using
/// Euclidean division, so instants before the epoch land on the correct side:
/// `-1` ms is `1969-12-31T23:59:59.999`, not `1970-01-01T00:00:00.001`.
///
/// Returns `None` when the resulting instant is outside the range that
/// `chrono::DateTime::from_timestamp` accepts.
#[cfg(feature = "postgres-sink")]
#[allow(clippy::trivially_copy_pass_by_ref)]
fn to_naive_datetime(value: i64, unit: &arrow_schema::TimeUnit) -> Option<chrono::NaiveDateTime> {
    use arrow_schema::TimeUnit;

    // (units per second, nanoseconds per unit)
    let (units_per_sec, nanos_per_unit): (i64, u32) = match unit {
        TimeUnit::Second => (1, 1_000_000_000),
        TimeUnit::Millisecond => (1_000, 1_000_000),
        TimeUnit::Microsecond => (1_000_000, 1_000),
        TimeUnit::Nanosecond => (1_000_000_000, 1),
    };

    // Floor towards negative infinity so the remainder is always in
    // `0..units_per_sec`; truncating division would mirror pre-epoch fractions.
    let secs = value.div_euclid(units_per_sec);
    let sub_units = u32::try_from(value.rem_euclid(units_per_sec)).ok()?;
    // `sub_units < units_per_sec`, so the product stays below 1_000_000_000.
    let nanos = sub_units * nanos_per_unit;

    chrono::DateTime::from_timestamp(secs, nanos).map(|dt| dt.naive_utc())
}
402
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{Field, TimeUnit};

    use super::*;

    /// Checks the UNNEST cast name for each `(type, expected)` pair.
    fn check_sql(cases: &[(DataType, &str)]) {
        for (dt, expected) in cases {
            assert_eq!(arrow_type_to_pg_sql(dt), *expected);
        }
    }

    /// Checks the DDL type name for each `(type, expected)` pair.
    fn check_ddl(cases: &[(DataType, &str)]) {
        for (dt, expected) in cases {
            assert_eq!(arrow_to_pg_ddl_type(dt), *expected);
        }
    }

    #[test]
    fn test_boolean_mapping() {
        check_sql(&[(DataType::Boolean, "bool")]);
        check_ddl(&[(DataType::Boolean, "BOOLEAN")]);
    }

    #[test]
    fn test_integer_mappings() {
        check_sql(&[
            (DataType::Int8, "int2"),
            (DataType::Int16, "int2"),
            (DataType::Int32, "int4"),
            (DataType::Int64, "int8"),
        ]);
        check_ddl(&[(DataType::Int32, "INTEGER"), (DataType::Int64, "BIGINT")]);
    }

    #[test]
    fn test_unsigned_widening() {
        // UInt32 must widen to int8 (no unsigned in PG)
        check_sql(&[(DataType::UInt32, "int8")]);
        check_ddl(&[(DataType::UInt32, "BIGINT")]);
        check_sql(&[(DataType::UInt64, "int8")]);
    }

    #[test]
    fn test_float_mappings() {
        check_sql(&[(DataType::Float32, "float4"), (DataType::Float64, "float8")]);
        check_ddl(&[(DataType::Float64, "DOUBLE PRECISION")]);
    }

    #[test]
    fn test_decimal_mapping() {
        let decimal = DataType::Decimal128(10, 2);
        assert_eq!(arrow_type_to_pg_sql(&decimal), "numeric");
        assert_eq!(arrow_to_pg_ddl_type(&decimal), "NUMERIC");
    }

    #[test]
    fn test_string_mappings() {
        check_sql(&[(DataType::Utf8, "text"), (DataType::LargeUtf8, "text")]);
        check_ddl(&[(DataType::Utf8, "TEXT")]);
    }

    #[test]
    fn test_binary_mappings() {
        check_sql(&[(DataType::Binary, "bytea"), (DataType::LargeBinary, "bytea")]);
        check_ddl(&[(DataType::Binary, "BYTEA")]);
    }

    #[test]
    fn test_uuid_mapping() {
        let uuid_like = DataType::FixedSizeBinary(16);
        assert_eq!(arrow_type_to_pg_sql(&uuid_like), "uuid");
        assert_eq!(arrow_to_pg_ddl_type(&uuid_like), "UUID");

        // Non-16 byte fixed binary falls back to bytea
        let digest_like = DataType::FixedSizeBinary(32);
        assert_eq!(arrow_type_to_pg_sql(&digest_like), "bytea");
    }

    #[test]
    fn test_date_time_mappings() {
        let date = DataType::Date32;
        assert_eq!(arrow_type_to_pg_sql(&date), "date");
        assert_eq!(arrow_to_pg_ddl_type(&date), "DATE");

        let time = DataType::Time64(TimeUnit::Microsecond);
        assert_eq!(arrow_type_to_pg_sql(&time), "time");
        assert_eq!(arrow_to_pg_ddl_type(&time), "TIME");
    }

    #[test]
    fn test_timestamp_mappings() {
        let naive = DataType::Timestamp(TimeUnit::Microsecond, None);
        assert_eq!(arrow_type_to_pg_sql(&naive), "timestamp");
        assert_eq!(arrow_to_pg_ddl_type(&naive), "TIMESTAMP");

        let zoned = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
        assert_eq!(arrow_type_to_pg_sql(&zoned), "timestamptz");
        assert_eq!(arrow_to_pg_ddl_type(&zoned), "TIMESTAMPTZ");
    }

    #[test]
    fn test_fallback_to_text() {
        // Complex types fall back to text
        let list_of_int32 = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
        assert_eq!(arrow_type_to_pg_sql(&list_of_int32), "text");
    }

    #[test]
    fn test_array_cast_expression() {
        let cases = [
            (DataType::Int64, 1_usize, "$1::int8[]"),
            (DataType::Utf8, 3, "$3::text[]"),
            (DataType::Boolean, 2, "$2::bool[]"),
        ];
        for (dt, param_index, expected) in &cases {
            assert_eq!(arrow_type_to_pg_array_cast(dt, *param_index), *expected);
        }
    }
}