Skip to main content

laminar_connectors/lookup/
postgres_reference.rs

1//! `PostgreSQL` poll-based reference table source.
2//!
3//! Implements [`ReferenceTableSource`](crate::reference::ReferenceTableSource) via a simple `SELECT * FROM table`
4//! query. No replication slot or CDC configuration required — suitable for
5//! slowly-changing dimension tables that are refreshed by periodic snapshot.
6
7use arrow_array::RecordBatch;
8
9use crate::checkpoint::SourceCheckpoint;
10use crate::config::ConnectorConfig;
11use crate::error::ConnectorError;
12use crate::reference::ReferenceTableSource;
13
/// A [`ReferenceTableSource`] backed by a single `SELECT *` query against
/// a `PostgreSQL` table. Returns the full table as a snapshot, then completes.
pub struct PostgresReferenceTableSource {
    // Connector properties: connection details and the target `table` name.
    config: ConnectorConfig,
    // Set once the one-shot snapshot has been served (or the table was found
    // empty); subsequent polls return `None`.
    snapshot_done: bool,
}
20
21impl PostgresReferenceTableSource {
22    /// Creates a new source from a [`ConnectorConfig`].
23    #[must_use]
24    pub fn new(config: ConnectorConfig) -> Self {
25        Self {
26            config,
27            snapshot_done: false,
28        }
29    }
30
31    /// Builds a `tokio_postgres` connection string from connector properties.
32    fn connection_string(&self) -> String {
33        let props = self.config.properties();
34        // Accept pre-formed connection string under either key.
35        if let Some(cs) = props
36            .get("connection_string")
37            .or_else(|| props.get("connection"))
38        {
39            return cs.clone();
40        }
41        let mut parts = Vec::new();
42        if let Some(h) = props.get("host") {
43            parts.push(format!("host={h}"));
44        }
45        if let Some(p) = props.get("port") {
46            parts.push(format!("port={p}"));
47        }
48        if let Some(d) = props.get("database") {
49            parts.push(format!("dbname={d}"));
50        }
51        if let Some(u) = props.get("user") {
52            parts.push(format!("user={u}"));
53        }
54        if let Some(pw) = props.get("password") {
55            parts.push(format!("password={pw}"));
56        }
57        parts.join(" ")
58    }
59
60    fn table_name(&self) -> &str {
61        self.config
62            .properties()
63            .get("table")
64            .map_or("unknown", String::as_str)
65    }
66}
67
#[async_trait::async_trait]
impl ReferenceTableSource for PostgresReferenceTableSource {
    /// Serves the configured table as a one-shot snapshot.
    ///
    /// Connects, runs `SELECT *`, and converts the full result set into a
    /// single Arrow [`RecordBatch`]. Returns `Ok(None)` when the snapshot has
    /// already been served, or when the table is empty.
    ///
    /// # Errors
    ///
    /// `ConnectorError::ConnectionFailed` if the connection cannot be
    /// established; `ConnectorError::ReadError` if the query or the
    /// row-to-Arrow conversion fails.
    #[allow(clippy::too_many_lines)]
    async fn poll_snapshot(&mut self) -> Result<Option<RecordBatch>, ConnectorError> {
        // One-shot: once the snapshot has been served, keep returning None.
        if self.snapshot_done {
            return Ok(None);
        }

        let conn_str = self.connection_string();
        // NoTls only — TLS is intentionally unsupported by this connector
        // (see the hint attached to TLS-looking failures below).
        let (client, connection) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls)
            .await
            .map_err(|e| {
                let msg = e.to_string();
                // Heuristic: if the failure message looks TLS-related (server
                // requires SSL, or sslmode was configured), surface a more
                // actionable error pointing at the supported alternatives.
                if msg.contains("SSL") || msg.contains("TLS") || msg.contains("sslmode") {
                    ConnectorError::ConnectionFailed(format!(
                        "postgres connect: {e} (TLS not supported by the standalone \
                         'postgres' lookup connector — use sslmode=disable or \
                         'postgres-cdc' for TLS)"
                    ))
                } else {
                    ConnectorError::ConnectionFailed(format!("postgres connect: {e}"))
                }
            })?;

        // Drive the connection on a background task; the `Client` only makes
        // progress while the `Connection` future is being polled.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                tracing::warn!(error = %e, "postgres lookup connection error");
            }
        });

        // NOTE(review): the table name is interpolated unquoted into the SQL
        // text, so a `table` property containing quotes or semicolons would
        // alter the statement. Presumably the value comes from trusted
        // operator config — confirm, or quote the identifier.
        let table = self.table_name();
        let sql = format!("SELECT * FROM {table}");
        let rows = client
            .query(&sql, &[])
            .await
            .map_err(|e| ConnectorError::ReadError(format!("postgres query: {e}")))?;

        // Marked done before conversion: if the Arrow conversion below fails,
        // the snapshot is NOT retried on the next poll.
        self.snapshot_done = true;

        // Empty table: snapshot is complete with nothing to emit.
        if rows.is_empty() {
            return Ok(None);
        }

        // Derive the Arrow schema from the first row's postgres column
        // metadata; every field is declared nullable.
        let pg_columns = rows[0].columns();
        let fields: Vec<arrow_schema::Field> = pg_columns
            .iter()
            .map(|col| {
                let dt = pg_type_to_arrow(col.type_());
                arrow_schema::Field::new(col.name(), dt, true)
            })
            .collect();
        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));

        // Build one Arrow array per column, dispatching on the mapped type.
        let mut columns: Vec<std::sync::Arc<dyn arrow_array::Array>> =
            Vec::with_capacity(schema.fields().len());

        for (col_idx, field) in schema.fields().iter().enumerate() {
            let col_name = pg_columns[col_idx].name();
            let pg_type = pg_columns[col_idx].type_().clone();
            let array: std::sync::Arc<dyn arrow_array::Array> = match field.data_type() {
                arrow_schema::DataType::Boolean => {
                    let vals: Vec<Option<bool>> = collect_column(&rows, col_name, &pg_type)?;
                    std::sync::Arc::new(arrow_array::BooleanArray::from(vals))
                }
                arrow_schema::DataType::Int16 => {
                    let vals: Vec<Option<i16>> = collect_column(&rows, col_name, &pg_type)?;
                    std::sync::Arc::new(arrow_array::Int16Array::from(vals))
                }
                arrow_schema::DataType::Int32 => {
                    let vals: Vec<Option<i32>> = collect_column(&rows, col_name, &pg_type)?;
                    std::sync::Arc::new(arrow_array::Int32Array::from(vals))
                }
                arrow_schema::DataType::Int64 => {
                    let vals: Vec<Option<i64>> = collect_column(&rows, col_name, &pg_type)?;
                    std::sync::Arc::new(arrow_array::Int64Array::from(vals))
                }
                arrow_schema::DataType::Float32 => {
                    let vals: Vec<Option<f32>> = collect_column(&rows, col_name, &pg_type)?;
                    std::sync::Arc::new(arrow_array::Float32Array::from(vals))
                }
                arrow_schema::DataType::Float64 => {
                    let vals: Vec<Option<f64>> = collect_column(&rows, col_name, &pg_type)?;
                    std::sync::Arc::new(arrow_array::Float64Array::from(vals))
                }
                _ => {
                    // Fallback for everything mapped to Utf8: read each cell
                    // as an optional String.
                    let mut vals: Vec<Option<String>> = Vec::with_capacity(rows.len());
                    for row in &rows {
                        match row.try_get::<_, Option<String>>(col_name) {
                            Ok(v) => vals.push(v),
                            Err(_) => {
                                // Type has no String conversion — the cell is
                                // recorded as NULL rather than failing the
                                // whole batch.
                                vals.push(None);
                            }
                        }
                    }
                    let str_vals: Vec<Option<&str>> =
                        vals.iter().map(|v: &Option<String>| v.as_deref()).collect();
                    std::sync::Arc::new(arrow_array::StringArray::from(str_vals))
                }
            };
            columns.push(array);
        }

        let batch = RecordBatch::try_new(schema, columns)
            .map_err(|e| ConnectorError::ReadError(format!("arrow batch construction: {e}")))?;

        Ok(Some(batch))
    }

    /// True once the snapshot has been served (or found empty).
    fn is_snapshot_complete(&self) -> bool {
        self.snapshot_done
    }

    /// Poll-based source has no change stream; always returns `Ok(None)`.
    async fn poll_changes(&mut self) -> Result<Option<RecordBatch>, ConnectorError> {
        Ok(None)
    }

    /// Encodes snapshot progress as 0 (pending) or 1 (done).
    fn checkpoint(&self) -> SourceCheckpoint {
        SourceCheckpoint::new(u64::from(self.snapshot_done))
    }

    // NOTE(review): the checkpoint payload is ignored, so a restored source
    // always re-runs the full snapshot even if `checkpoint()` recorded
    // completion — confirm this restart semantics is intended.
    async fn restore(&mut self, _checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
        Ok(())
    }

    /// Nothing to release explicitly: the client is local to each
    /// `poll_snapshot` call and dropped when it returns.
    async fn close(&mut self) -> Result<(), ConnectorError> {
        Ok(())
    }
}
201
202/// Collect a typed column from all rows via `try_get`, returning
203/// `ConnectorError::ReadError` on conversion failure.
204fn collect_column<'a, T>(
205    rows: &'a [tokio_postgres::Row],
206    col_name: &str,
207    pg_type: &tokio_postgres::types::Type,
208) -> Result<Vec<Option<T>>, ConnectorError>
209where
210    T: tokio_postgres::types::FromSql<'a>,
211{
212    rows.iter()
213        .map(|r| {
214            r.try_get::<_, Option<T>>(col_name).map_err(|e| {
215                ConnectorError::ReadError(format!("column '{col_name}' (pg type {pg_type}): {e}"))
216            })
217        })
218        .collect()
219}
220
221/// Map a `tokio_postgres` type to an Arrow `DataType`.
222fn pg_type_to_arrow(pg_type: &tokio_postgres::types::Type) -> arrow_schema::DataType {
223    use tokio_postgres::types::Type;
224    match *pg_type {
225        Type::BOOL => arrow_schema::DataType::Boolean,
226        Type::INT2 => arrow_schema::DataType::Int16,
227        Type::INT4 => arrow_schema::DataType::Int32,
228        Type::INT8 => arrow_schema::DataType::Int64,
229        Type::FLOAT4 => arrow_schema::DataType::Float32,
230        Type::FLOAT8 => arrow_schema::DataType::Float64,
231        // TIMESTAMP, UUID, JSONB, etc. → read as Utf8 (String).
232        _ => arrow_schema::DataType::Utf8,
233    }
234}