Skip to main content

laminardb/
pgwire.rs

1//! Postgres wire endpoint. Trust by default; MD5 with `pgwire_users`;
2//! TLS with `pgwire_tls_cert` + `pgwire_tls_key`. Non-loopback binds
3//! require `pgwire_allow_remote = true`.
4
5use std::collections::{HashMap, VecDeque};
6use std::fmt::Debug;
7use std::net::SocketAddr;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use futures::{stream, Sink, StreamExt};
13use laminar_sql::parser::{
14    parse_streaming_sql, ShowCommand, StreamingStatement, SubscribeStatement,
15};
16use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler};
17use pgwire::api::auth::noop::NoopStartupHandler;
18use pgwire::api::auth::{
19    AuthSource, DefaultServerParameterProvider, LoginInfo, Password, StartupHandler,
20};
21use pgwire::api::portal::{Format, Portal};
22use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
23use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
24use pgwire::api::stmt::QueryParser;
25use pgwire::api::store::PortalStore;
26use pgwire::api::{ClientInfo, ClientPortalStore, PgWireServerHandlers, Type};
27use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
28use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
29use pgwire::tokio::process_socket;
30use sqlparser::ast::{
31    CloseCursor, Expr, FetchDirection, FunctionArguments, SelectItem, Set, SetExpr, Statement,
32    Value as AstValue,
33};
34use tokio::net::TcpListener;
35use tokio::sync::Mutex as TokioMutex;
36use tracing::{info, warn};
37
38use laminar_db::subscription::{PortalFrame, SubscribeStart, SubscriptionPortal};
39use laminar_db::LaminarDB;
40
41use crate::config::Secret;
42use crate::server::ServerError;
43
44pub struct LaminarPgwireHandler {
45    db: Arc<LaminarDB>,
46    /// Per-peer SimpleQuery cursor map. Entries are evicted at the start of
47    /// every `do_query` call once their cursors are dead and no transaction
48    /// is open — pgwire 0.39 doesn't give us a connection-close hook, so this
49    /// is the cheapest way to keep stale state from leaking on port reuse.
50    connections: parking_lot::Mutex<HashMap<SocketAddr, Arc<ConnState>>>,
51}
52
53impl LaminarPgwireHandler {
54    fn new(db: Arc<LaminarDB>) -> Self {
55        Self {
56            db,
57            connections: parking_lot::Mutex::new(HashMap::new()),
58        }
59    }
60
61    fn conn_state(&self, peer: SocketAddr) -> Arc<ConnState> {
62        let mut guard = self.connections.lock();
63        Arc::clone(
64            guard
65                .entry(peer)
66                .or_insert_with(|| Arc::new(ConnState::default())),
67        )
68    }
69
70    fn evict_idle_peer(&self, peer: SocketAddr) {
71        let mut guard = self.connections.lock();
72        if let Some(state) = guard.get(&peer) {
73            if state.prune_dead_and_check_idle() {
74                guard.remove(&peer);
75            }
76        }
77    }
78}
79
80#[async_trait]
81impl NoopStartupHandler for LaminarPgwireHandler {
82    async fn post_startup<C>(
83        &self,
84        client: &mut C,
85        _message: PgWireFrontendMessage,
86    ) -> PgWireResult<()>
87    where
88        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
89        C::Error: Debug,
90        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
91    {
92        info!(peer = %client.socket_addr(), "pgwire client connected");
93        Ok(())
94    }
95}
96
97#[async_trait]
98impl SimpleQueryHandler for LaminarPgwireHandler {
99    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
100    where
101        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
102        C::PortalStore: PortalStore,
103        C::Error: Debug,
104        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
105    {
106        if query.trim().is_empty() {
107            return Ok(vec![Response::EmptyQuery]);
108        }
109        let peer = client.socket_addr();
110        self.evict_idle_peer(peer);
111        let stmts = parse_streaming_sql(query)
112            .map_err(|e| user_error("42601", format!("parse error: {e}")))?;
113
114        // SUBSCRIBE owns the socket for its lifetime; it can't share a
115        // simple-query batch with earlier or later statements. Reject
116        // up front so trailing statements aren't silently dropped.
117        if stmts.len() > 1
118            && stmts
119                .iter()
120                .any(|s| matches!(s, StreamingStatement::Subscribe(_)))
121        {
122            return Err(user_error(
123                "0A000",
124                "SUBSCRIBE must be the only statement in a simple query",
125            ));
126        }
127
128        let mut out = Vec::with_capacity(stmts.len());
129        for stmt in stmts {
130            out.push(match stmt {
131                StreamingStatement::Subscribe(s) => {
132                    let portal = open_portal_for_subscribe(&self.db, &s).await?;
133                    // Simple query is always text (no Bind result format).
134                    stream_subscribe_flushing(client, portal, true, None).await?;
135                    return Ok(Vec::new());
136                }
137                StreamingStatement::Show(cmd) => {
138                    engine_metadata_response(&self.db, &show_sql(&cmd)).await?
139                }
140                StreamingStatement::DeclareCursorForSubscribe {
141                    name, subscribe, ..
142                } => {
143                    let state = self.conn_state(peer);
144                    handle_declare_cursor(&self.db, &state, &name.value, *subscribe).await?
145                }
146                StreamingStatement::Standard(s) => {
147                    let state = self.conn_state(peer);
148                    standard_or_cursor_response(&self.db, &state, *s)?
149                }
150                other => {
151                    return Err(user_error(
152                        "0A000",
153                        format!("not supported on pgwire (use HTTP /api/v1/sql): {other:?}"),
154                    ));
155                }
156            });
157        }
158        Ok(out)
159    }
160}
161
162async fn open_portal_for_subscribe(
163    db: &LaminarDB,
164    s: &SubscribeStatement,
165) -> PgWireResult<SubscriptionPortal> {
166    let name = s.name.to_string();
167    let start = match s.as_of_epoch {
168        Some(n) => SubscribeStart::AsOfEpoch(n),
169        None => SubscribeStart::Tail,
170    };
171    db.open_subscription(&name, s.filter_sql.as_deref(), start)
172        .await
173        .map_err(|e| user_error("42P01", format!("SUBSCRIBE '{name}': {e}")))
174}
175
176/// Stream a SUBSCRIBE, flushing the `Sink` after every batch.
177///
178/// Workaround: pgwire `feed()`s `DataRow`s and only flushes at
179/// end-of-response (never, for an unbounded SUBSCRIBE) or at its ~8 KB
180/// buffer, so a sparse stream stalls. Per-batch flush is unconditional
181/// (fine — batches amortise; not per-row). Retire when pgwire flushes
182/// streaming responses upstream. Both paths need it (psql=simple,
183/// psycopg/JDBC=extended).
184///
185/// `send_row_desc`: simple query carries `RowDescription`; extended
186/// already sent it via `Describe` (caller returns `Response::Execution`
187/// for `CommandComplete`). Returns `Ok(())` only on pump exit.
188async fn stream_subscribe_flushing<C>(
189    client: &mut C,
190    mut portal: SubscriptionPortal,
191    send_row_desc: bool,
192    result_format: Option<&Format>,
193) -> PgWireResult<()>
194where
195    C: Sink<PgWireBackendMessage> + Unpin + Send,
196    C::Error: Debug,
197    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
198{
199    use futures::SinkExt;
200
201    let schema = portal.schema();
202    // Honour the extended client's per-column binary/text choice; pgwire's
203    // `Describe` advertised it, so the `DataRow` encoding must match.
204    let fields = std::sync::Arc::new(field_infos(&schema, result_format));
205
206    if send_row_desc {
207        // Equivalent to pgwire's crate-private `into_row_description`.
208        let row_desc =
209            pgwire::messages::data::RowDescription::new(fields.iter().map(Into::into).collect());
210        client
211            .feed(PgWireBackendMessage::RowDescription(row_desc))
212            .await?;
213        client.flush().await?;
214    }
215
216    let mut rows: usize = 0;
217    loop {
218        match portal.next_frame().await {
219            Some(PortalFrame::Batch(b)) if b.num_rows() > 0 => {
220                for row in encode_batch(&b, &fields) {
221                    client.feed(PgWireBackendMessage::DataRow(row?)).await?;
222                    rows += 1;
223                }
224                client.flush().await?;
225            }
226            Some(PortalFrame::Batch(_)) => {}
227            // Checkpoint barriers have no Postgres wire representation.
228            Some(PortalFrame::Barrier { .. }) => {}
229            Some(PortalFrame::Lagged(n)) => {
230                return Err(user_error(
231                    "54000",
232                    format!("subscription lagged: skipped {n} messages, terminating"),
233                ));
234            }
235            None => {
236                // Simple query owns the whole response, so emit
237                // CommandComplete here. The extended path lets pgwire
238                // emit it from the returned `Response::Execution`.
239                if send_row_desc {
240                    let tag = Tag::new("SUBSCRIBE").with_rows(rows);
241                    client
242                        .feed(PgWireBackendMessage::CommandComplete(tag.into()))
243                        .await?;
244                    client.flush().await?;
245                }
246                return Ok(());
247            }
248        }
249    }
250}
251
252/// Wrap a `SubscriptionPortal` in a pgwire `Response::Query` so the
253/// framework can chunk via `Execute(max_rows)` and emit PortalSuspended
254/// automatically. Used by the chunked extended-query path.
255fn subscription_query_response(
256    portal: SubscriptionPortal,
257    result_format: Option<&Format>,
258) -> Response {
259    use futures::stream;
260    let schema = portal.schema();
261    let fields = Arc::new(field_infos(&schema, result_format));
262    struct State {
263        portal: SubscriptionPortal,
264        fields: Arc<Vec<FieldInfo>>,
265        pending: VecDeque<PgWireResult<pgwire::messages::data::DataRow>>,
266    }
267    let init = State {
268        portal,
269        fields: Arc::clone(&fields),
270        pending: VecDeque::new(),
271    };
272    let row_stream = stream::unfold(init, |mut s| async move {
273        loop {
274            if let Some(row) = s.pending.pop_front() {
275                return Some((row, s));
276            }
277            match s.portal.next_frame().await {
278                None => return None,
279                Some(PortalFrame::Batch(b)) if b.num_rows() > 0 => {
280                    s.pending.extend(encode_batch(&b, &s.fields));
281                }
282                Some(PortalFrame::Batch(_)) | Some(PortalFrame::Barrier { .. }) => {}
283                Some(PortalFrame::Lagged(n)) => {
284                    let err = user_error(
285                        "54000",
286                        format!("subscription lagged: skipped {n} messages, terminating"),
287                    );
288                    return Some((Err(err), s));
289                }
290            }
291        }
292    });
293    let mut resp = QueryResponse::new(fields, row_stream);
294    resp.set_command_tag("SUBSCRIBE");
295    Response::Query(resp)
296}
297
298/// State that shares the cursor's lifetime: the portal, the leftover-row
299/// buffer, and the exhausted flag. Held by `Arc` so a row stream can keep
300/// reading after `ConnState::get` returns.
301struct CursorInner {
302    /// Tokio mutex because a FETCH stream holds it across `await` while
303    /// pulling frames.
304    portal: TokioMutex<SubscriptionPortal>,
305    /// Rows encoded from a prior frame that the previous FETCH didn't
306    /// consume. Without this, a multi-row batch + `FETCH 1` would drop the
307    /// leftover rows when the response stream ends.
308    pending: parking_lot::Mutex<VecDeque<PgWireResult<pgwire::messages::data::DataRow>>>,
309    /// Flipped when the pump emits `None` or `Lagged`. The next `evict_idle_peer`
310    /// pass reaps the cursor.
311    exhausted: AtomicBool,
312}
313
314#[derive(Clone)]
315struct ActiveCursor {
316    inner: Arc<CursorInner>,
317    schema: arrow_schema::SchemaRef,
318}
319
320#[derive(Default)]
321struct ConnState {
322    cursors: parking_lot::Mutex<HashMap<String, ActiveCursor>>,
323    in_tx: AtomicBool,
324}
325
326impl ConnState {
327    /// Cursor names follow PG identifier folding: unquoted → lowercase. We
328    /// don't track `quote_style`, so quoted-mixed-case cursors collapse too —
329    /// good enough for the `\set FETCH_COUNT` case this targets.
330    fn key(name: &str) -> String {
331        name.to_ascii_lowercase()
332    }
333
334    fn insert(&self, name: &str, cursor: ActiveCursor) {
335        self.cursors.lock().insert(Self::key(name), cursor);
336    }
337
338    fn contains(&self, name: &str) -> bool {
339        self.cursors.lock().contains_key(&Self::key(name))
340    }
341
342    fn get(&self, name: &str) -> Option<ActiveCursor> {
343        self.cursors.lock().get(&Self::key(name)).cloned()
344    }
345
346    fn remove(&self, name: &str) -> bool {
347        self.cursors.lock().remove(&Self::key(name)).is_some()
348    }
349
350    fn drop_all(&self) {
351        self.cursors.lock().clear();
352    }
353
354    /// Drop dead cursors and report whether the connection is now idle
355    /// (no cursors, no transaction). Single inner-lock acquisition.
356    fn prune_dead_and_check_idle(&self) -> bool {
357        let mut cursors = self.cursors.lock();
358        cursors.retain(|_, c| !c.inner.exhausted.load(Ordering::Acquire));
359        cursors.is_empty() && !self.in_tx.load(Ordering::Acquire)
360    }
361}
362
363/// Open a SUBSCRIBE behind a cursor name. Rejects with 42P03 if the name is
364/// already in use on this connection (matches PG; user must `CLOSE` first).
365async fn handle_declare_cursor(
366    db: &LaminarDB,
367    state: &ConnState,
368    cursor_name: &str,
369    subscribe: SubscribeStatement,
370) -> PgWireResult<Response> {
371    if state.contains(cursor_name) {
372        return Err(user_error(
373            "42P03",
374            format!("cursor \"{cursor_name}\" already exists"),
375        ));
376    }
377    let portal = open_portal_for_subscribe(db, &subscribe).await?;
378    let schema = portal.schema();
379    state.insert(
380        cursor_name,
381        ActiveCursor {
382            inner: Arc::new(CursorInner {
383                portal: TokioMutex::new(portal),
384                pending: parking_lot::Mutex::new(VecDeque::new()),
385                exhausted: AtomicBool::new(false),
386            }),
387            schema,
388        },
389    );
390    Ok(Response::Execution(Tag::new("DECLARE CURSOR")))
391}
392
393/// `FETCH NEXT` and bare `FETCH FORWARD` map to a single row, matching PG.
394fn fetch_direction_count(dir: &FetchDirection) -> PgWireResult<FetchTarget> {
395    match dir {
396        FetchDirection::Next | FetchDirection::Forward { limit: None } => Ok(FetchTarget::Count(1)),
397        FetchDirection::Count { limit } | FetchDirection::Forward { limit: Some(limit) } => {
398            value_to_u64(limit).map(FetchTarget::Count)
399        }
400        FetchDirection::All | FetchDirection::ForwardAll => Ok(FetchTarget::All),
401        FetchDirection::Prior
402        | FetchDirection::First
403        | FetchDirection::Last
404        | FetchDirection::Absolute { .. }
405        | FetchDirection::Relative { .. }
406        | FetchDirection::Backward { .. }
407        | FetchDirection::BackwardAll => Err(user_error(
408            "0A000",
409            "FETCH direction not supported (SUBSCRIBE cursors are forward-only): use FORWARD or NEXT",
410        )),
411    }
412}
413
/// How many rows a single FETCH should deliver.
#[derive(Copy, Clone)]
enum FetchTarget {
    /// At most this many rows.
    Count(u64),
    /// No row limit; the FETCH runs until the cursor is exhausted.
    All,
}
419
420fn value_to_u64(v: &AstValue) -> PgWireResult<u64> {
421    match v {
422        AstValue::Number(n, _) => n
423            .parse::<u64>()
424            .map_err(|_| user_error("22023", format!("invalid FETCH count: {n}"))),
425        other => Err(user_error(
426            "22023",
427            format!("FETCH count must be an integer, got {other}"),
428        )),
429    }
430}
431
432fn handle_fetch(
433    state: &ConnState,
434    cursor_name: &str,
435    target: FetchTarget,
436) -> PgWireResult<Response> {
437    let cursor = state
438        .get(cursor_name)
439        .ok_or_else(|| user_error("34000", format!("cursor \"{cursor_name}\" does not exist")))?;
440    Ok(fetch_response(cursor, target))
441}
442
443fn handle_close(state: &ConnState, cursor: &CloseCursor) -> PgWireResult<Response> {
444    match cursor {
445        CloseCursor::All => {
446            state.drop_all();
447            Ok(Response::Execution(Tag::new("CLOSE CURSOR ALL")))
448        }
449        CloseCursor::Specific { name } => {
450            if state.remove(&name.value) {
451                Ok(Response::Execution(Tag::new("CLOSE CURSOR")))
452            } else {
453                Err(user_error(
454                    "34000",
455                    format!("cursor \"{}\" does not exist", name.value),
456                ))
457            }
458        }
459    }
460}
461
462/// Wraps the original `standard_response` and intercepts cursor / transaction
463/// statements that need ConnState. Anything else falls through to the
464/// existing handler unchanged.
465fn standard_or_cursor_response(
466    db: &LaminarDB,
467    state: &ConnState,
468    stmt: Statement,
469) -> PgWireResult<Response> {
470    match stmt {
471        Statement::StartTransaction { .. } => {
472            state.in_tx.store(true, Ordering::Release);
473            Ok(Response::Execution(Tag::new("BEGIN")))
474        }
475        Statement::Commit { .. } => {
476            state.drop_all();
477            state.in_tx.store(false, Ordering::Release);
478            Ok(Response::Execution(Tag::new("COMMIT")))
479        }
480        Statement::Rollback { .. } => {
481            state.drop_all();
482            state.in_tx.store(false, Ordering::Release);
483            Ok(Response::Execution(Tag::new("ROLLBACK")))
484        }
485        Statement::Fetch {
486            ref name,
487            ref direction,
488            ..
489        } => {
490            let target = fetch_direction_count(direction)?;
491            handle_fetch(state, &name.value, target)
492        }
493        Statement::Close { ref cursor } => handle_close(state, cursor),
494        Statement::Declare { .. } => Err(user_error(
495            "0A000",
496            "DECLARE on pgwire only supports CURSOR FOR SUBSCRIBE …",
497        )),
498        other => standard_response(db, other),
499    }
500}
501
502/// Connection-setup statements: transaction control, `SET`, and a tiny set
503/// of catalog probes drivers send during handshake. Anything DDL/DML hits
504/// the "use HTTP" error.
505fn standard_response(db: &LaminarDB, stmt: Statement) -> PgWireResult<Response> {
506    match stmt {
507        Statement::StartTransaction { .. } => Ok(Response::Execution(Tag::new("BEGIN"))),
508        Statement::Commit { .. } => Ok(Response::Execution(Tag::new("COMMIT"))),
509        Statement::Rollback { .. } => Ok(Response::Execution(Tag::new("ROLLBACK"))),
510        Statement::Set(s) => apply_set(db, s),
511        Statement::Query(q) => driver_select_response(*q),
512        Statement::Insert { .. }
513        | Statement::Update { .. }
514        | Statement::Delete { .. }
515        | Statement::CreateTable { .. }
516        | Statement::CreateView { .. }
517        | Statement::Drop { .. } => Err(user_error(
518            "0A000",
519            "DDL/DML is not supported on pgwire; use HTTP /api/v1/sql",
520        )),
521        other => Err(user_error(
522            "0A000",
523            format!("not supported on pgwire: {other}"),
524        )),
525    }
526}
527
528/// Handle the `SELECT`s drivers issue at connect time. Single literal,
529/// `SELECT version()`, and `SELECT current_schema()` are answered inline.
530/// Anything else is rejected — real queries belong on `/api/v1/sql`.
531fn driver_select_response(query: sqlparser::ast::Query) -> PgWireResult<Response> {
532    let SetExpr::Select(select) = *query.body else {
533        return Err(unsupported_select());
534    };
535    if select.projection.len() != 1 || !select.from.is_empty() || select.selection.is_some() {
536        return Err(unsupported_select());
537    }
538    let SelectItem::UnnamedExpr(expr) = &select.projection[0] else {
539        return Err(unsupported_select());
540    };
541
542    match expr {
543        Expr::Value(v) => match &v.value {
544            sqlparser::ast::Value::Number(n, _) => {
545                let parsed: i32 = n.parse().map_err(|_| unsupported_select())?;
546                Ok(text_response("?column?", Type::INT4, parsed.to_string()))
547            }
548            sqlparser::ast::Value::SingleQuotedString(s) => {
549                Ok(text_response("?column?", Type::VARCHAR, s.clone()))
550            }
551            _ => Err(unsupported_select()),
552        },
553        Expr::Function(func) => {
554            // Only no-arg builtins. `func.args` is `FunctionArguments::None`
555            // for `func()`, the only shape we accept.
556            if !matches!(func.args, FunctionArguments::List(ref a) if a.args.is_empty())
557                && !matches!(func.args, FunctionArguments::None)
558            {
559                return Err(unsupported_select());
560            }
561            let name = func.name.to_string().to_ascii_lowercase();
562            let (col, ty, value) = match name.as_str() {
563                "version" | "pg_catalog.version" => (
564                    "version",
565                    Type::VARCHAR,
566                    format!("LaminarDB {} on pgwire", env!("CARGO_PKG_VERSION")),
567                ),
568                "current_schema" | "pg_catalog.current_schema" => {
569                    ("current_schema", Type::VARCHAR, "public".to_string())
570                }
571                "current_database" | "pg_catalog.current_database" => {
572                    ("current_database", Type::VARCHAR, "laminar".to_string())
573                }
574                "current_user" | "session_user" | "user" => {
575                    ("current_user", Type::VARCHAR, "laminar".to_string())
576                }
577                _ => return Err(unsupported_select()),
578            };
579            Ok(text_response(col, ty, value))
580        }
581        _ => Err(unsupported_select()),
582    }
583}
584
585fn unsupported_select() -> PgWireError {
586    user_error(
587        "0A000",
588        "pgwire SELECT is limited to literals and connect-time builtins; use HTTP /api/v1/sql",
589    )
590}
591
592/// `SET` handling. We thread plain `SET name = value` to the engine's
593/// session-property store, and refuse `SET TRANSACTION`-class statements
594/// since we don't honor isolation levels.
595fn apply_set(db: &LaminarDB, set: Set) -> PgWireResult<Response> {
596    match set {
597        Set::SingleAssignment {
598            variable, values, ..
599        } => {
600            let key = variable.to_string();
601            let value = values
602                .first()
603                .map(ToString::to_string)
604                .unwrap_or_default()
605                .trim_matches('\'')
606                .to_string();
607            db.set_session_property(&key, &value);
608            Ok(Response::Execution(Tag::new("SET")))
609        }
610        // Refuse anything that implies semantics we do not provide.
611        Set::SetTransaction { .. } => Err(user_error(
612            "0A000",
613            "SET TRANSACTION is not supported (no transactional semantics)",
614        )),
615        // Lenient pass-through for the harmless catalog-style SETs drivers
616        // issue (NAMES, TIME ZONE, ROLE...). We don't honor them, but failing
617        // the connection is worse than silently accepting.
618        _ => Ok(Response::Execution(Tag::new("SET"))),
619    }
620}
621
622fn user_error(code: &str, msg: impl Into<String>) -> PgWireError {
623    PgWireError::UserError(Box::new(ErrorInfo::new(
624        "ERROR".into(),
625        code.into(),
626        msg.into(),
627    )))
628}
629
630/// Reconstruct a single SHOW statement from the parsed variant. Used by the
631/// pgwire dispatcher so a multi-statement query (`SHOW SOURCES; SHOW SINKS`)
632/// re-executes only the matching statement, not the whole input string.
633fn show_sql(cmd: &ShowCommand) -> String {
634    match cmd {
635        ShowCommand::Sources => "SHOW SOURCES".into(),
636        ShowCommand::Sinks => "SHOW SINKS".into(),
637        ShowCommand::Queries => "SHOW QUERIES".into(),
638        ShowCommand::MaterializedViews => "SHOW MATERIALIZED VIEWS".into(),
639        ShowCommand::Streams => "SHOW STREAMS".into(),
640        ShowCommand::Tables => "SHOW TABLES".into(),
641        ShowCommand::CheckpointStatus => "SHOW CHECKPOINT STATUS".into(),
642        ShowCommand::CreateSource { name } => format!("SHOW CREATE SOURCE {name}"),
643        ShowCommand::CreateSink { name } => format!("SHOW CREATE SINK {name}"),
644    }
645}
646
647/// Run a SHOW through the engine and stream its `RecordBatch` to the wire.
648async fn engine_metadata_response(db: &LaminarDB, sql: &str) -> PgWireResult<Response> {
649    use laminar_db::ExecuteResult;
650    let result = db
651        .execute(sql)
652        .await
653        .map_err(|e| user_error("XX000", e.to_string()))?;
654    let ExecuteResult::Metadata(batch) = result else {
655        return Err(user_error("XX000", "SHOW did not return metadata"));
656    };
657    Ok(record_batch_response(batch))
658}
659
660/// Single-row `text` response with one column.
661fn text_response(col: &str, ty: Type, value: String) -> Response {
662    let schema = Arc::new(vec![FieldInfo::new(
663        col.into(),
664        None,
665        None,
666        ty,
667        FieldFormat::Text,
668    )]);
669    let schema_for_row = Arc::clone(&schema);
670    let row_stream = stream::iter(std::iter::once(Ok::<_, PgWireError>(()))).map(move |_| {
671        let mut enc = DataRowEncoder::new(Arc::clone(&schema_for_row));
672        enc.encode_field(&Some(value.as_str()))?;
673        Ok(enc.take_row())
674    });
675    Response::Query(QueryResponse::new(schema, row_stream))
676}
677
678fn record_batch_response(batch: arrow_array::RecordBatch) -> Response {
679    let fields = Arc::new(field_infos(&batch.schema(), None));
680    let nrows = batch.num_rows();
681
682    // Encode rows eagerly: SHOW outputs are tiny and this avoids the
683    // !Send formatter dance.
684    let mut rows = Vec::with_capacity(nrows);
685    {
686        let opts = arrow_cast::display::FormatOptions::default();
687        let formatters: Vec<_> = batch
688            .columns()
689            .iter()
690            .map(|c| arrow_cast::display::ArrayFormatter::try_new(c.as_ref(), &opts))
691            .collect::<Result<_, _>>()
692            .unwrap_or_default();
693        for row in 0..nrows {
694            rows.push(encode_row(&batch, row, &fields, &formatters));
695        }
696    }
697
698    let row_stream = stream::iter(rows);
699    Response::Query(QueryResponse::new(fields, row_stream))
700}
701
702/// Strict-PG FETCH: blocks until `target` rows are produced, the pump exits,
703/// or the broadcast lags. Lag/exit flips `cursor.inner.exhausted` so the next
704/// `evict_idle_peer` reaps the cursor. Text format only; SimpleQuery has no
705/// binary. Leftover rows from a multi-row frame stay in `cursor.inner.pending`
706/// so successive FETCHes consume the frame in order.
707fn fetch_response(cursor: ActiveCursor, target: FetchTarget) -> Response {
708    let fields = Arc::new(field_infos(&cursor.schema, None));
709    let remaining = match target {
710        FetchTarget::Count(n) => Some(n),
711        FetchTarget::All => None,
712    };
713
714    struct State {
715        cursor: ActiveCursor,
716        fields: Arc<Vec<FieldInfo>>,
717        remaining: Option<u64>,
718    }
719
720    let init = State {
721        cursor,
722        fields: Arc::clone(&fields),
723        remaining,
724    };
725
726    let row_stream = stream::unfold(init, |mut s| async move {
727        loop {
728            if matches!(s.remaining, Some(0)) {
729                return None;
730            }
731            // Pending rows from a prior FETCH come out first. Anything left
732            // here when remaining hits 0 stays for the next call.
733            let popped = s.cursor.inner.pending.lock().pop_front();
734            if let Some(row) = popped {
735                if let Some(n) = s.remaining.as_mut() {
736                    *n = n.saturating_sub(1);
737                }
738                return Some((row, s));
739            }
740            if s.cursor.inner.exhausted.load(Ordering::Acquire) {
741                return None;
742            }
743
744            let next = s.cursor.inner.portal.lock().await.next_frame().await;
745            match next {
746                None => {
747                    s.cursor.inner.exhausted.store(true, Ordering::Release);
748                    return None;
749                }
750                Some(PortalFrame::Batch(b)) if b.num_rows() > 0 => {
751                    let encoded = encode_batch(&b, &s.fields);
752                    s.cursor.inner.pending.lock().extend(encoded);
753                }
754                Some(PortalFrame::Batch(_)) => {}
755                Some(PortalFrame::Barrier { .. }) => {
756                    // Same as portal_to_response: PG has no out-of-band marker.
757                }
758                Some(PortalFrame::Lagged(n)) => {
759                    s.cursor.inner.exhausted.store(true, Ordering::Release);
760                    let err = user_error(
761                        "54000",
762                        format!("subscription lagged: skipped {n} messages, terminating cursor"),
763                    );
764                    return Some((Err(err), s));
765                }
766            }
767        }
768    });
769    Response::Query(QueryResponse::new(fields, row_stream))
770}
771
772fn encode_batch(
773    batch: &arrow_array::RecordBatch,
774    fields: &Arc<Vec<FieldInfo>>,
775) -> Vec<PgWireResult<pgwire::messages::data::DataRow>> {
776    let opts = arrow_cast::display::FormatOptions::default();
777    let formatters: Vec<_> = match batch
778        .columns()
779        .iter()
780        .map(|c| arrow_cast::display::ArrayFormatter::try_new(c.as_ref(), &opts))
781        .collect::<Result<_, _>>()
782    {
783        Ok(f) => f,
784        Err(e) => {
785            return vec![Err(user_error("XX000", format!("format column: {e}")))];
786        }
787    };
788    (0..batch.num_rows())
789        .map(|row| encode_row(batch, row, fields, &formatters))
790        .collect()
791}
792
793/// Build pgwire `FieldInfo`s from an Arrow schema. `result_format` (from a
794/// `Bind`) sets per-column text/binary; `None` defaults all-text.
795fn field_infos(schema: &arrow_schema::Schema, result_format: Option<&Format>) -> Vec<FieldInfo> {
796    schema
797        .fields()
798        .iter()
799        .enumerate()
800        .map(|(i, f)| {
801            let format = result_format.map_or(FieldFormat::Text, |rf| rf.format_for(i));
802            FieldInfo::new(
803                f.name().clone(),
804                None,
805                None,
806                arrow_to_pg_type(f.data_type()),
807                format,
808            )
809        })
810        .collect()
811}
812
813fn encode_row(
814    batch: &arrow_array::RecordBatch,
815    row: usize,
816    fields: &Arc<Vec<FieldInfo>>,
817    formatters: &[arrow_cast::display::ArrayFormatter<'_>],
818) -> PgWireResult<pgwire::messages::data::DataRow> {
819    let mut enc = DataRowEncoder::new(Arc::clone(fields));
820    for (i, col) in batch.columns().iter().enumerate() {
821        let info = &fields[i];
822        match info.format() {
823            FieldFormat::Text => encode_field_text(&mut enc, col.as_ref(), row, &formatters[i])?,
824            FieldFormat::Binary => encode_field_binary(&mut enc, col.as_ref(), row, info.name())?,
825        }
826    }
827    Ok(enc.take_row())
828}
829
830fn encode_field_text(
831    enc: &mut DataRowEncoder,
832    col: &dyn arrow_array::Array,
833    row: usize,
834    formatter: &arrow_cast::display::ArrayFormatter<'_>,
835) -> PgWireResult<()> {
836    use arrow_schema::DataType;
837    if col.is_null(row) {
838        return enc.encode_field(&None::<&str>);
839    }
840    // A TEXT[] column must serialize as a Postgres array literal `{..}`, not
841    // Arrow's `[..]` display, so text-mode clients parse it as an array.
842    if matches!(col.data_type(), DataType::List(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8))
843    {
844        return enc.encode_field(&Some(pg_text_array_literal(&list_text_elements(col, row))));
845    }
846    enc.encode_field(&Some(formatter.value(row).to_string()))
847}
848
849/// Owned elements of a `List<Utf8|LargeUtf8>` row, NULLs preserved.
850fn list_text_elements(col: &dyn arrow_array::Array, row: usize) -> Vec<Option<String>> {
851    use arrow_array::cast::AsArray;
852    use arrow_array::Array;
853    use arrow_schema::DataType;
854    let values = col.as_list::<i32>().value(row);
855    if matches!(values.data_type(), DataType::LargeUtf8) {
856        let s = values.as_string::<i64>();
857        (0..s.len())
858            .map(|i| (!s.is_null(i)).then(|| s.value(i).to_owned()))
859            .collect()
860    } else {
861        let s = values.as_string::<i32>();
862        (0..s.len())
863            .map(|i| (!s.is_null(i)).then(|| s.value(i).to_owned()))
864            .collect()
865    }
866}
867
/// Render `elements` as a Postgres `text[]` literal such as
/// `{"en","ja",NULL}`.
///
/// Every non-NULL element is wrapped in double quotes, with `"` and `\`
/// backslash-escaped, so commas, braces and quotes inside values stay
/// unambiguous. NULL elements are written as the bare word `NULL`.
fn pg_text_array_literal(elements: &[Option<String>]) -> String {
    let rendered: Vec<String> = elements
        .iter()
        .map(|element| match element {
            Some(text) => {
                let mut quoted = String::with_capacity(text.len() + 2);
                quoted.push('"');
                for ch in text.chars() {
                    if matches!(ch, '"' | '\\') {
                        quoted.push('\\');
                    }
                    quoted.push(ch);
                }
                quoted.push('"');
                quoted
            }
            None => "NULL".to_owned(),
        })
        .collect();
    format!("{{{}}}", rendered.join(","))
}
893
894/// Binary-encode a single Arrow value via `postgres-types` `ToSql`.
895///
896/// Coverage: Int{8,16,32,64}, UInt{8,16,32,64}, Float{32,64}, Bool,
897/// Utf8/LargeUtf8, Timestamp (any unit, naive), Date32, Date64, and
898/// `List<Utf8>` (as `text[]`). UInt64 is widened to INT8 with saturation
899/// since Postgres has no unsigned 64. Any other column type yields `0A000`.
900fn encode_field_binary(
901    enc: &mut DataRowEncoder,
902    col: &dyn arrow_array::Array,
903    row: usize,
904    name: &str,
905) -> PgWireResult<()> {
906    use arrow_array::{cast::AsArray, types::*};
907    use arrow_schema::DataType;
908
909    if col.is_null(row) {
910        return enc.encode_field(&None::<&str>);
911    }
912
913    // Pull the typed Arrow value and pass it to `DataRowEncoder`, which
914    // calls `postgres-types::ToSql` for the wire format. The `as $cast`
915    // arm widens a narrower Arrow int to the matching Postgres OID (see
916    // `arrow_to_pg_type`); only lossless `From` casts go through here.
917    macro_rules! prim {
918        ($ty:ty as $cast:ty) => {
919            enc.encode_field(&Some(<$cast>::from(col.as_primitive::<$ty>().value(row))))
920        };
921        ($ty:ty) => {
922            enc.encode_field(&Some(col.as_primitive::<$ty>().value(row)))
923        };
924    }
925
926    match col.data_type() {
927        DataType::Int8 => prim!(Int8Type as i32),
928        DataType::Int16 => prim!(Int16Type as i32),
929        DataType::Int32 => prim!(Int32Type),
930        DataType::Int64 => prim!(Int64Type),
931        DataType::UInt8 => prim!(UInt8Type as i32),
932        DataType::UInt16 => prim!(UInt16Type as i32),
933        DataType::UInt32 => prim!(UInt32Type as i64),
934        DataType::UInt64 => {
935            // PG has no unsigned 64; saturate so we never wrap.
936            let v = col.as_primitive::<UInt64Type>().value(row);
937            enc.encode_field(&Some(i64::try_from(v).unwrap_or(i64::MAX)))
938        }
939        DataType::Float32 => prim!(Float32Type as f64),
940        DataType::Float64 => prim!(Float64Type),
941        DataType::Boolean => enc.encode_field(&Some(col.as_boolean().value(row))),
942        DataType::Utf8 => enc.encode_field(&Some(col.as_string::<i32>().value(row))),
943        DataType::LargeUtf8 => enc.encode_field(&Some(col.as_string::<i64>().value(row))),
944        DataType::Timestamp(unit, _tz) => {
945            // Each unit has its own Arrow type — `PrimitiveArray<TimestampMicrosecondType>`
946            // is *not* `PrimitiveArray<Int64Type>`, so the downcast must match the unit.
947            use arrow_array::temporal_conversions::{
948                timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_s_to_datetime,
949                timestamp_us_to_datetime,
950            };
951            use arrow_schema::TimeUnit;
952            let (raw, dt) = match unit {
953                TimeUnit::Second => {
954                    let v = col.as_primitive::<TimestampSecondType>().value(row);
955                    (v, timestamp_s_to_datetime(v))
956                }
957                TimeUnit::Millisecond => {
958                    let v = col.as_primitive::<TimestampMillisecondType>().value(row);
959                    (v, timestamp_ms_to_datetime(v))
960                }
961                TimeUnit::Microsecond => {
962                    let v = col.as_primitive::<TimestampMicrosecondType>().value(row);
963                    (v, timestamp_us_to_datetime(v))
964                }
965                TimeUnit::Nanosecond => {
966                    let v = col.as_primitive::<TimestampNanosecondType>().value(row);
967                    (v, timestamp_ns_to_datetime(v))
968                }
969            };
970            let dt =
971                dt.ok_or_else(|| user_error("22008", format!("timestamp out of range: {raw}")))?;
972            enc.encode_field(&Some(dt))
973        }
974        DataType::Date32 => {
975            let v = col.as_primitive::<Date32Type>().value(row);
976            let dt = arrow_array::temporal_conversions::date32_to_datetime(v)
977                .ok_or_else(|| user_error("22008", format!("DATE out of range: {v}")))?;
978            enc.encode_field(&Some(dt.date()))
979        }
980        DataType::Date64 => {
981            let v = col.as_primitive::<Date64Type>().value(row);
982            let dt = arrow_array::temporal_conversions::date64_to_datetime(v)
983                .ok_or_else(|| user_error("22008", format!("DATE out of range: {v}")))?;
984            enc.encode_field(&Some(dt.date()))
985        }
986        DataType::List(field)
987            if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) =>
988        {
989            // `postgres-types` encodes Vec<Option<String>> as the binary
990            // text[] wire format (the column's OID is TEXT_ARRAY).
991            enc.encode_field(&Some(list_text_elements(col, row)))
992        }
993        other => Err(user_error(
994            "0A000",
995            format!("binary format not supported for column '{name}' (type {other:?})"),
996        )),
997    }
998}
999
1000fn arrow_to_pg_type(dt: &arrow_schema::DataType) -> Type {
1001    use arrow_schema::DataType;
1002    match dt {
1003        DataType::Int8 | DataType::Int16 | DataType::Int32 => Type::INT4,
1004        DataType::Int64 | DataType::UInt32 | DataType::UInt64 => Type::INT8,
1005        DataType::UInt8 | DataType::UInt16 => Type::INT4,
1006        DataType::Float32 | DataType::Float64 => Type::FLOAT8,
1007        DataType::Utf8 | DataType::LargeUtf8 => Type::VARCHAR,
1008        DataType::Boolean => Type::BOOL,
1009        DataType::Timestamp(_, _) => Type::TIMESTAMP,
1010        DataType::Date32 | DataType::Date64 => Type::DATE,
1011        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => Type::NUMERIC,
1012        DataType::List(field)
1013            if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) =>
1014        {
1015            Type::TEXT_ARRAY
1016        }
1017        _ => Type::TEXT,
1018    }
1019}
1020
/// Per-call salt + stored credential for the MD5 challenge flow. The
/// stored value is either plaintext (legacy) or `md5<32-hex>`, the same
/// format Postgres' `pg_authid` uses, where the hex is `md5(password ‖
/// user)`. The pre-hashed form lets operators avoid plaintext at rest.
#[derive(Debug)]
struct LaminarAuthSource {
    /// Stored credentials keyed by login user name. `get_password` looks
    /// up the user reported by `LoginInfo` here; a missing key is reported
    /// as `PgWireError::InvalidPassword`.
    users: Arc<HashMap<String, Secret>>,
}
1029
/// Recognise a `pg_authid`-style pre-hashed credential.
///
/// Returns the 32-character hex digest that follows the `md5` tag when
/// `stored` has exactly that shape. Only lowercase hex digits count; an
/// uppercase digit, a different length, or a missing tag yields `None`, in
/// which case the caller treats `stored` as a plaintext password.
pub(crate) fn parse_pre_hashed_md5(stored: &str) -> Option<&str> {
    stored
        .strip_prefix("md5")
        .filter(|hex| hex.len() == 32)
        .filter(|hex| {
            hex.bytes()
                .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b))
        })
}
1041
1042/// MD5 challenge response when only the inner hash is known: the client
1043/// sends `md5{hex(md5(inner_hex || salt))}` and the server precomputes
1044/// the same string for comparison.
1045fn outer_md5_challenge(inner_hex: &str, salt: &[u8]) -> String {
1046    use md5::{Digest, Md5};
1047    let mut hasher = Md5::new();
1048    hasher.update(inner_hex.as_bytes());
1049    hasher.update(salt);
1050    format!("md5{:x}", hasher.finalize())
1051}
1052
1053#[async_trait]
1054impl AuthSource for LaminarAuthSource {
1055    async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
1056        let user = login.user().unwrap_or("");
1057        // Indistinguishable from a wrong-password failure: both branches must
1058        // surface the same wire error so a client can't probe which usernames
1059        // are configured. pgwire emits exactly this variant on bad password.
1060        let stored = self
1061            .users
1062            .get(user)
1063            .ok_or_else(|| PgWireError::InvalidPassword(user.to_string()))?;
1064        let salt: [u8; 4] = rand::random();
1065        let expected = match parse_pre_hashed_md5(stored.expose()) {
1066            Some(inner_hex) => outer_md5_challenge(inner_hex, &salt),
1067            None => hash_md5_password(user, stored.expose(), &salt),
1068        };
1069        Ok(Password::new(Some(salt.to_vec()), expected.into_bytes()))
1070    }
1071}
1072
1073type Md5Handler = Md5PasswordAuthStartupHandler<LaminarAuthSource, DefaultServerParameterProvider>;
1074
1075/// Startup-phase dispatch. `Md5` requires password auth; `Trust` accepts any
1076/// connection. Selected once at listener startup based on whether
1077/// `pgwire_users` is non-empty.
1078enum StartupAuth {
1079    Trust(Arc<LaminarPgwireHandler>),
1080    Md5(Arc<Md5Handler>),
1081}
1082
1083#[async_trait]
1084impl StartupHandler for StartupAuth {
1085    async fn on_startup<C>(
1086        &self,
1087        client: &mut C,
1088        message: PgWireFrontendMessage,
1089    ) -> PgWireResult<()>
1090    where
1091        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
1092        C::Error: Debug,
1093        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
1094    {
1095        match self {
1096            Self::Trust(h) => h.on_startup(client, message).await,
1097            Self::Md5(h) => h.on_startup(client, message).await,
1098        }
1099    }
1100}
1101
1102pub struct LaminarHandlerFactory {
1103    handler: Arc<LaminarPgwireHandler>,
1104    startup: Arc<StartupAuth>,
1105}
1106
1107impl LaminarHandlerFactory {
1108    fn new(db: Arc<LaminarDB>, users: HashMap<String, Secret>) -> Self {
1109        let handler = Arc::new(LaminarPgwireHandler::new(db));
1110        let startup = if users.is_empty() {
1111            Arc::new(StartupAuth::Trust(Arc::clone(&handler)))
1112        } else {
1113            let auth = LaminarAuthSource {
1114                users: Arc::new(users),
1115            };
1116            let md5 = Md5PasswordAuthStartupHandler::new(
1117                Arc::new(auth),
1118                Arc::new(DefaultServerParameterProvider::default()),
1119            );
1120            Arc::new(StartupAuth::Md5(Arc::new(md5)))
1121        };
1122        Self { handler, startup }
1123    }
1124}
1125
1126impl PgWireServerHandlers for LaminarHandlerFactory {
1127    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
1128        Arc::clone(&self.handler)
1129    }
1130
1131    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
1132        Arc::clone(&self.handler)
1133    }
1134
1135    fn startup_handler(&self) -> Arc<impl StartupHandler> {
1136        Arc::clone(&self.startup)
1137    }
1138}
1139
/// Parsed statement carried through `Parse` → `Bind` → `Execute`.
#[derive(Clone, Debug)]
pub enum LaminarStmt {
    /// `SUBSCRIBE` with its schema resolved at parse time so `Describe` can
    /// answer before the portal is bound.
    Subscribe {
        /// Stream name, taken from the parsed statement's `name`.
        name: String,
        /// Optional filter SQL, handed to `open_subscription` on execute.
        filter_sql: Option<String>,
        /// Epoch to start from; `None` starts at the tail of the stream.
        as_of_epoch: Option<u64>,
        /// Schema found by `lookup_subscription_schema` at parse time.
        schema: arrow_schema::SchemaRef,
    },
    /// `SHOW` command; executed through `engine_metadata_response`.
    Show(ShowCommand),
    /// Any other statement; executed through `standard_response`.
    Standard(Box<Statement>),
}
1154
1155/// Resolves SQL to `LaminarStmt`, looking up stream schemas against the
1156/// live `LaminarDB` so the extended-query `Describe` returns columns
1157/// without running the query.
1158#[derive(Clone)]
1159pub struct LaminarQueryParser {
1160    db: Arc<LaminarDB>,
1161}
1162
1163#[async_trait]
1164impl QueryParser for LaminarQueryParser {
1165    type Statement = LaminarStmt;
1166
1167    async fn parse_sql<C>(
1168        &self,
1169        _client: &C,
1170        sql: &str,
1171        _types: &[Option<Type>],
1172    ) -> PgWireResult<Self::Statement>
1173    where
1174        C: ClientInfo + Unpin + Send + Sync,
1175    {
1176        let mut stmts = parse_streaming_sql(sql)
1177            .map_err(|e| user_error("42601", format!("parse error: {e}")))?;
1178        let stmt = stmts
1179            .pop()
1180            .ok_or_else(|| user_error("42601", "empty statement"))?;
1181        if !stmts.is_empty() {
1182            return Err(user_error(
1183                "42601",
1184                "extended query: multiple statements per Parse are not supported",
1185            ));
1186        }
1187
1188        match stmt {
1189            StreamingStatement::Subscribe(s) => {
1190                let name = s.name.to_string();
1191                let (schema, _) = self.db.lookup_subscription_schema(&name).ok_or_else(|| {
1192                    user_error("42P01", format!("SUBSCRIBE '{name}': stream not found"))
1193                })?;
1194                Ok(LaminarStmt::Subscribe {
1195                    name,
1196                    filter_sql: s.filter_sql,
1197                    as_of_epoch: s.as_of_epoch,
1198                    schema,
1199                })
1200            }
1201            StreamingStatement::Show(cmd) => Ok(LaminarStmt::Show(cmd)),
1202            StreamingStatement::Standard(s) => Ok(LaminarStmt::Standard(s)),
1203            other => Err(user_error(
1204                "0A000",
1205                format!("not supported on pgwire (use HTTP /api/v1/sql): {other:?}"),
1206            )),
1207        }
1208    }
1209
1210    fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
1211        // SUBSCRIBE has no `$N` placeholders.
1212        Ok(Vec::new())
1213    }
1214
1215    fn get_result_schema(
1216        &self,
1217        stmt: &Self::Statement,
1218        column_format: Option<&Format>,
1219    ) -> PgWireResult<Vec<FieldInfo>> {
1220        // SHOW and Standard are tiny single-row outputs whose schema only
1221        // materialises after execution; clients see it on Execute's
1222        // RowDescription instead.
1223        match stmt {
1224            LaminarStmt::Subscribe { schema, .. } => Ok(field_infos(schema, column_format)),
1225            LaminarStmt::Show(_) | LaminarStmt::Standard(_) => Ok(Vec::new()),
1226        }
1227    }
1228}
1229
1230#[async_trait]
1231impl ExtendedQueryHandler for LaminarPgwireHandler {
1232    type Statement = LaminarStmt;
1233    type QueryParser = LaminarQueryParser;
1234
1235    fn query_parser(&self) -> Arc<Self::QueryParser> {
1236        Arc::new(LaminarQueryParser {
1237            db: Arc::clone(&self.db),
1238        })
1239    }
1240
1241    async fn do_query<C>(
1242        &self,
1243        client: &mut C,
1244        portal: &Portal<Self::Statement>,
1245        max_rows: usize,
1246    ) -> PgWireResult<Response>
1247    where
1248        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
1249        C::PortalStore: PortalStore<Statement = Self::Statement>,
1250        C::Error: Debug,
1251        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
1252    {
1253        match &portal.statement.statement {
1254            LaminarStmt::Subscribe {
1255                name,
1256                filter_sql,
1257                as_of_epoch,
1258                ..
1259            } => {
1260                let start = match as_of_epoch {
1261                    Some(n) => SubscribeStart::AsOfEpoch(*n),
1262                    None => SubscribeStart::Tail,
1263                };
1264                let sub = self
1265                    .db
1266                    .open_subscription(name, filter_sql.as_deref(), start)
1267                    .await
1268                    .map_err(|e| user_error("42P01", format!("SUBSCRIBE '{name}': {e}")))?;
1269                if max_rows == 0 {
1270                    // Unbounded fetch — pgwire would buffer infinitely.
1271                    // Drive the stream ourselves with per-batch flushing.
1272                    stream_subscribe_flushing(
1273                        client,
1274                        sub,
1275                        false,
1276                        Some(&portal.result_column_format),
1277                    )
1278                    .await?;
1279                    Ok(Response::Execution(Tag::new("SUBSCRIBE")))
1280                } else {
1281                    // Chunked (JDBC setFetchSize / tokio-postgres query_portal).
1282                    // Hand pgwire a row stream so it honours max_rows and
1283                    // emits PortalSuspended automatically.
1284                    Ok(subscription_query_response(
1285                        sub,
1286                        Some(&portal.result_column_format),
1287                    ))
1288                }
1289            }
1290            LaminarStmt::Show(cmd) => engine_metadata_response(&self.db, &show_sql(cmd)).await,
1291            LaminarStmt::Standard(s) => standard_response(&self.db, *s.clone()),
1292        }
1293    }
1294
1295    /// Per-Sync portal cleanup: only the unnamed portal is destroyed.
1296    ///
1297    /// The pgwire 0.39 default `on_sync` calls `clear_portals()`, which wipes
1298    /// every named portal on the connection. PostgreSQL keeps named portals
1299    /// alive until `Close` or end-of-transaction, so the default would break
1300    /// any client that does `Bind named_portal; Sync; Execute named_portal;`
1301    /// — the standard JDBC / asyncpg / tokio-postgres pattern for chunked
1302    /// fetches via `setFetchSize` / `query_portal`.
1303    async fn on_sync<C>(
1304        &self,
1305        client: &mut C,
1306        _message: pgwire::messages::extendedquery::Sync,
1307    ) -> PgWireResult<()>
1308    where
1309        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
1310        C::PortalStore: PortalStore<Statement = Self::Statement>,
1311        C::Error: Debug,
1312        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
1313    {
1314        use futures::SinkExt;
1315        use pgwire::messages::response::ReadyForQuery;
1316
1317        // Drop only the unnamed portal; named portals survive Sync.
1318        client.portal_store().rm_portal("");
1319
1320        client
1321            .send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
1322                client.transaction_status(),
1323            )))
1324            .await?;
1325        client.flush().await?;
1326        Ok(())
1327    }
1328}
1329
/// Borrowed TLS settings for the pgwire listener: the file locations and
/// protocol floor that `load_tls_acceptor` receives.
pub struct TlsPaths<'a> {
    /// Server certificate file (config key `pgwire_tls_cert`).
    pub cert: &'a std::path::Path,
    /// Private key file (config key `pgwire_tls_key`).
    pub key: &'a std::path::Path,
    /// Minimum TLS protocol version accepted on the listener.
    pub min_version: TlsMinVersion,
    /// PEM bundle of CA roots; presence enables mTLS — every client must
    /// present a cert that chains to one of these roots.
    pub client_ca: Option<&'a std::path::Path>,
}
1338
1339/// Owned counterpart to `TlsPaths` that the listener keeps for the
1340/// lifetime of `serve()` so the file watcher can rebuild the acceptor
1341/// without the original config still being in scope.
1342#[derive(Debug, Clone)]
1343struct TlsConfigPaths {
1344    cert: std::path::PathBuf,
1345    key: std::path::PathBuf,
1346    min_version: TlsMinVersion,
1347    client_ca: Option<std::path::PathBuf>,
1348}
1349
1350impl TlsConfigPaths {
1351    fn from_paths(paths: &TlsPaths<'_>) -> Self {
1352        Self {
1353            cert: paths.cert.to_path_buf(),
1354            key: paths.key.to_path_buf(),
1355            min_version: paths.min_version,
1356            client_ca: paths.client_ca.map(|p| p.to_path_buf()),
1357        }
1358    }
1359
1360    fn borrow(&self) -> TlsPaths<'_> {
1361        TlsPaths {
1362            cert: &self.cert,
1363            key: &self.key,
1364            min_version: self.min_version,
1365            client_ca: self.client_ca.as_deref(),
1366        }
1367    }
1368}
1369
1370/// Live TLS acceptor + paths needed to rebuild it on cert rotation.
1371/// Reads on the accept path are a single mutex acquire and a cheap
1372/// `TlsAcceptor` clone; reloads are triggered by the file watcher.
1373pub struct TlsReloadState {
1374    paths: TlsConfigPaths,
1375    acceptor: parking_lot::Mutex<Arc<tokio_rustls::TlsAcceptor>>,
1376}
1377
1378impl TlsReloadState {
1379    fn snapshot(&self) -> Arc<tokio_rustls::TlsAcceptor> {
1380        Arc::clone(&self.acceptor.lock())
1381    }
1382}
1383
1384/// Rebuild the TLS acceptor from `state.paths` and atomically swap it in.
1385/// On any error the previous acceptor is left in place, so a bad rotation
1386/// (truncated file, expired cert) doesn't take TLS down.
1387#[allow(clippy::result_large_err)]
1388pub(crate) fn try_reload_tls(state: &TlsReloadState) -> Result<(), ServerError> {
1389    let new_acceptor = load_tls_acceptor(state.paths.borrow())?;
1390    *state.acceptor.lock() = Arc::new(new_acceptor);
1391    Ok(())
1392}
1393
1394/// Watch the cert / key / client-CA files and call `try_reload_tls` after
1395/// debounced changes. Mirrors the pattern in `watcher.rs` (parent-dir
1396/// watch, debounce, then act). Runs until the channel closes; the caller
1397/// drives shutdown by aborting the task that owns this future.
1398async fn watch_tls_files(state: Arc<TlsReloadState>, debounce: std::time::Duration) {
1399    use crossfire::{mpsc, MTx};
1400    use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
1401
1402    // Track raw + canonical paths so symlink-swap rotations and edits to
1403    // the symlink target both produce visible events.
1404    let mut raw_targets: Vec<std::path::PathBuf> = Vec::new();
1405    let mut canon_targets: Vec<std::path::PathBuf> = Vec::new();
1406    for path in [
1407        Some(state.paths.cert.clone()),
1408        Some(state.paths.key.clone()),
1409        state.paths.client_ca.clone(),
1410    ]
1411    .into_iter()
1412    .flatten()
1413    {
1414        match path.canonicalize() {
1415            Ok(canonical) => {
1416                canon_targets.push(canonical);
1417                raw_targets.push(path);
1418            }
1419            Err(e) => {
1420                warn!(
1421                    path = %path.display(),
1422                    error = %e,
1423                    "pgwire TLS watcher: cannot canonicalize path; reload disabled",
1424                );
1425                return;
1426            }
1427        }
1428    }
1429    let mut dirs: Vec<std::path::PathBuf> = raw_targets
1430        .iter()
1431        .chain(canon_targets.iter())
1432        .filter_map(|p| p.parent().map(|d| d.to_path_buf()))
1433        .collect();
1434    dirs.sort();
1435    dirs.dedup();
1436
1437    let (tx, rx) = mpsc::bounded_async::<()>(16);
1438    let blocking_tx: MTx<_> = tx.clone().into_blocking();
1439    let watch_raw = raw_targets.clone();
1440    let watch_canon = canon_targets.clone();
1441
1442    let mut watcher: RecommendedWatcher = match notify::recommended_watcher(
1443        move |result: Result<Event, notify::Error>| match result {
1444            Ok(event) => {
1445                let touched = event.paths.iter().any(|p| {
1446                    watch_raw.iter().any(|t| t == p)
1447                        || p.canonicalize()
1448                            .ok()
1449                            .as_ref()
1450                            .is_some_and(|c| watch_canon.contains(c))
1451                });
1452                if touched {
1453                    let _ = blocking_tx.send(());
1454                }
1455            }
1456            Err(e) => warn!(error = %e, "pgwire TLS watcher: notify error"),
1457        },
1458    ) {
1459        Ok(w) => w,
1460        Err(e) => {
1461            warn!(error = %e, "pgwire TLS watcher: failed to create watcher; reload disabled");
1462            return;
1463        }
1464    };
1465
1466    for dir in &dirs {
1467        if let Err(e) = watcher.watch(dir, RecursiveMode::NonRecursive) {
1468            warn!(
1469                dir = %dir.display(),
1470                error = %e,
1471                "pgwire TLS watcher: failed to watch directory; reload disabled",
1472            );
1473            return;
1474        }
1475    }
1476    info!(
1477        files = ?raw_targets.iter().map(|p| p.display().to_string()).collect::<Vec<_>>(),
1478        "pgwire TLS watcher started",
1479    );
1480
1481    loop {
1482        if rx.recv().await.is_err() {
1483            return;
1484        }
1485        // Debounce: sleep then drain so a burst of inotify events
1486        // (cert + key written separately) coalesces into one reload.
1487        tokio::time::sleep(debounce).await;
1488        while rx.try_recv().is_ok() {}
1489
1490        match try_reload_tls(&state) {
1491            Ok(()) => tracing::info!(
1492                target: "audit",
1493                event = "pgwire.tls_reload",
1494                outcome = "ok",
1495            ),
1496            Err(e) => tracing::warn!(
1497                target: "audit",
1498                event = "pgwire.tls_reload",
1499                outcome = "failed",
1500                error = %e,
1501                "pgwire TLS reload failed; previous certificate kept",
1502            ),
1503        }
1504    }
1505}
1506
1507/// Minimum TLS protocol version accepted on the pgwire listener. rustls
1508/// already disables TLS 1.0/1.1; this narrows further when an operator
1509/// needs TLS 1.3 only.
1510#[derive(Clone, Copy, Debug)]
1511pub enum TlsMinVersion {
1512    V1_2,
1513    V1_3,
1514}
1515
1516impl TlsMinVersion {
1517    pub(crate) fn from_config_str(s: &str) -> Option<Self> {
1518        match s {
1519            "1.2" => Some(Self::V1_2),
1520            "1.3" => Some(Self::V1_3),
1521            _ => None,
1522        }
1523    }
1524
1525    fn versions(self) -> &'static [&'static tokio_rustls::rustls::SupportedProtocolVersion] {
1526        use tokio_rustls::rustls::version::{TLS12, TLS13};
1527        static BOTH: &[&tokio_rustls::rustls::SupportedProtocolVersion] = &[&TLS12, &TLS13];
1528        static ONLY_13: &[&tokio_rustls::rustls::SupportedProtocolVersion] = &[&TLS13];
1529        match self {
1530            Self::V1_2 => BOTH,
1531            Self::V1_3 => ONLY_13,
1532        }
1533    }
1534
1535    fn label(self) -> &'static str {
1536        match self {
1537            Self::V1_2 => "1.2",
1538            Self::V1_3 => "1.3",
1539        }
1540    }
1541}
1542
1543/// Warn if the key file is group/other-readable.
1544#[cfg(unix)]
1545fn warn_if_key_world_readable(file: &std::fs::File, path: &std::path::Path) {
1546    use std::os::unix::fs::MetadataExt;
1547    if let Ok(meta) = file.metadata() {
1548        let mode = meta.mode();
1549        if mode & 0o077 != 0 {
1550            warn!(
1551                path = %path.display(),
1552                mode = format!("{:o}", mode & 0o777),
1553                "pgwire_tls_key permissions are too broad; tighten to 0600",
1554            );
1555        }
1556    }
1557}
1558
1559#[cfg(not(unix))]
1560fn warn_if_key_world_readable(_file: &std::fs::File, _path: &std::path::Path) {}
1561
1562/// Rolling-window auth-failure count per peer IP.
1563#[derive(Debug, Default)]
1564struct FailureTracker {
1565    inner: parking_lot::Mutex<
1566        HashMap<std::net::IpAddr, std::collections::VecDeque<std::time::Instant>>,
1567    >,
1568}
1569
1570impl FailureTracker {
1571    fn is_blocked(&self, ip: std::net::IpAddr, limit: u32, window: std::time::Duration) -> bool {
1572        if limit == 0 {
1573            return false;
1574        }
1575        let cutoff = std::time::Instant::now() - window;
1576        let mut inner = self.inner.lock();
1577        let Some(failures) = inner.get_mut(&ip) else {
1578            return false;
1579        };
1580        while failures.front().is_some_and(|t| *t < cutoff) {
1581            failures.pop_front();
1582        }
1583        let blocked = failures.len() >= limit as usize;
1584        if failures.is_empty() {
1585            inner.remove(&ip);
1586        }
1587        blocked
1588    }
1589
1590    fn record_failure(&self, ip: std::net::IpAddr) {
1591        let mut inner = self.inner.lock();
1592        // When full, evict the entry whose newest failure is oldest.
1593        if !inner.contains_key(&ip) && inner.len() >= MAX_TRACKED_IPS {
1594            if let Some(oldest) = inner
1595                .iter()
1596                .min_by_key(|(_, q)| q.back().copied())
1597                .map(|(k, _)| *k)
1598            {
1599                inner.remove(&oldest);
1600            }
1601        }
1602        inner
1603            .entry(ip)
1604            .or_default()
1605            .push_back(std::time::Instant::now());
1606    }
1607}
1608
/// Upper bound on distinct peer IPs held by `FailureTracker`; once reached,
/// recording a failure for a new IP evicts an existing entry first.
const MAX_TRACKED_IPS: usize = 4096;
1610
/// Maps a finished session's result onto a fixed audit label.
///
/// `Ok` yields `"ok"`. Errors are bucketed by substring match on their
/// rendered message: `"28P01"` yields `"auth_failed"`; otherwise any of
/// `"HandshakeFailure"`, `"rustls"` or `"tls"` yields `"tls_failed"`;
/// everything else is `"error"`. Matching is case-sensitive and the auth
/// check wins when both would match.
fn classify_outcome(result: &Result<(), std::io::Error>) -> &'static str {
    const TLS_MARKERS: [&str; 3] = ["HandshakeFailure", "rustls", "tls"];

    let Err(err) = result else {
        return "ok";
    };
    let text = err.to_string();
    if text.contains("28P01") {
        return "auth_failed";
    }
    if TLS_MARKERS.iter().any(|marker| text.contains(marker)) {
        "tls_failed"
    } else {
        "error"
    }
}
1630
1631/// Reject certs past `notAfter`; warn within 30 days.
1632#[allow(clippy::result_large_err)]
1633fn check_cert_expiry(
1634    der: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
1635    path: &std::path::Path,
1636) -> Result<(), ServerError> {
1637    use x509_parser::prelude::FromDer;
1638    let (_, cert) = x509_parser::certificate::X509Certificate::from_der(der.as_ref())
1639        .map_err(|e| ServerError::Http(format!("parse pgwire_tls_cert {}: {e}", path.display())))?;
1640    let now = x509_parser::time::ASN1Time::now();
1641    let not_after = cert.validity().not_after;
1642    if not_after < now {
1643        return Err(ServerError::Http(format!(
1644            "pgwire_tls_cert {} expired at {not_after}",
1645            path.display()
1646        )));
1647    }
1648    let remaining = not_after.to_datetime() - now.to_datetime();
1649    if remaining <= time::Duration::days(30) {
1650        warn!(
1651            path = %path.display(),
1652            expires_at = %not_after,
1653            "pgwire_tls_cert expires within 30 days; rotate before it lapses",
1654        );
1655    }
1656    Ok(())
1657}
1658
1659/// Idempotent install of aws-lc-rs as rustls' default provider.
1660fn ensure_tls_provider() {
1661    let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default();
1662}
1663
1664#[allow(clippy::result_large_err)]
1665fn load_tls_acceptor(paths: TlsPaths<'_>) -> Result<tokio_rustls::TlsAcceptor, ServerError> {
1666    use std::fs::File;
1667    use std::io::BufReader;
1668
1669    ensure_tls_provider();
1670
1671    let cert_file = File::open(paths.cert)
1672        .map_err(|e| ServerError::Http(format!("open pgwire_tls_cert: {e}")))?;
1673    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1674        .collect::<Result<Vec<_>, _>>()
1675        .map_err(|e| ServerError::Http(format!("parse pgwire_tls_cert: {e}")))?;
1676    if certs.is_empty() {
1677        return Err(ServerError::Http(format!(
1678            "pgwire_tls_cert {} contains no certificates",
1679            paths.cert.display()
1680        )));
1681    }
1682    for cert in &certs {
1683        check_cert_expiry(cert, paths.cert)?;
1684    }
1685
1686    let key_file = File::open(paths.key)
1687        .map_err(|e| ServerError::Http(format!("open pgwire_tls_key: {e}")))?;
1688    warn_if_key_world_readable(&key_file, paths.key);
1689    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1690        .map_err(|e| ServerError::Http(format!("parse pgwire_tls_key: {e}")))?
1691        .ok_or_else(|| {
1692            ServerError::Http(format!(
1693                "pgwire_tls_key {} contains no private key",
1694                paths.key.display()
1695            ))
1696        })?;
1697
1698    let builder = tokio_rustls::rustls::ServerConfig::builder_with_protocol_versions(
1699        paths.min_version.versions(),
1700    );
1701    let builder = match paths.client_ca {
1702        Some(ca_path) => {
1703            let verifier = build_client_cert_verifier(ca_path)?;
1704            builder.with_client_cert_verifier(verifier)
1705        }
1706        None => builder.with_no_client_auth(),
1707    };
1708    let server_config = builder
1709        .with_single_cert(certs, key)
1710        .map_err(|e| ServerError::Http(format!("rustls server config: {e}")))?;
1711    Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
1712}
1713
1714#[allow(clippy::result_large_err)]
1715fn build_client_cert_verifier(
1716    ca_path: &std::path::Path,
1717) -> Result<Arc<dyn tokio_rustls::rustls::server::danger::ClientCertVerifier>, ServerError> {
1718    use std::fs::File;
1719    use std::io::BufReader;
1720    use tokio_rustls::rustls::server::WebPkiClientVerifier;
1721    use tokio_rustls::rustls::RootCertStore;
1722
1723    let file = File::open(ca_path)
1724        .map_err(|e| ServerError::Http(format!("open pgwire_tls_client_ca: {e}")))?;
1725    let mut roots = RootCertStore::empty();
1726    let mut added = 0usize;
1727    for cert in rustls_pemfile::certs(&mut BufReader::new(file)) {
1728        let cert =
1729            cert.map_err(|e| ServerError::Http(format!("parse pgwire_tls_client_ca: {e}")))?;
1730        roots
1731            .add(cert)
1732            .map_err(|e| ServerError::Http(format!("invalid CA in pgwire_tls_client_ca: {e}")))?;
1733        added += 1;
1734    }
1735    if added == 0 {
1736        return Err(ServerError::Http(format!(
1737            "pgwire_tls_client_ca {} contains no certificates",
1738            ca_path.display()
1739        )));
1740    }
1741    WebPkiClientVerifier::builder(Arc::new(roots))
1742        .build()
1743        .map_err(|e| ServerError::Http(format!("build client-cert verifier: {e}")))
1744}
1745
1746pub async fn serve(
1747    db: Arc<LaminarDB>,
1748    bind: &str,
1749    users: HashMap<String, Secret>,
1750    allow_remote: bool,
1751    tls: Option<TlsPaths<'_>>,
1752    max_connections: usize,
1753    max_auth_failures_per_min: u32,
1754) -> Result<(SocketAddr, tokio::task::JoinHandle<()>), ServerError> {
1755    let addr: SocketAddr = bind
1756        .parse()
1757        .map_err(|e| ServerError::Http(format!("invalid pgwire_bind '{bind}': {e}")))?;
1758
1759    let auth_mode = if users.is_empty() { "trust" } else { "md5" };
1760    let is_remote_bind = !addr.ip().is_loopback();
1761    match (auth_mode, is_remote_bind, allow_remote) {
1762        ("trust", true, _) => {
1763            return Err(ServerError::Http(format!(
1764                "pgwire_bind '{addr}' is not loopback and pgwire_users is empty (trust auth); \
1765             configure pgwire_users + pgwire_allow_remote=true, or bind to 127.0.0.1"
1766            )))
1767        }
1768        ("md5", true, false) => {
1769            return Err(ServerError::Http(format!(
1770                "pgwire_bind '{addr}' is not loopback; set pgwire_allow_remote=true to opt in"
1771            )))
1772        }
1773        _ => {}
1774    }
1775
1776    let tls_min_label = tls.as_ref().map(|p| p.min_version.label());
1777    let mtls_on = tls.as_ref().is_some_and(|p| p.client_ca.is_some());
1778    let tls_state: Option<Arc<TlsReloadState>> = match tls {
1779        Some(paths) => {
1780            let acceptor = load_tls_acceptor(TlsPaths {
1781                cert: paths.cert,
1782                key: paths.key,
1783                min_version: paths.min_version,
1784                client_ca: paths.client_ca,
1785            })?;
1786            Some(Arc::new(TlsReloadState {
1787                paths: TlsConfigPaths::from_paths(&paths),
1788                acceptor: parking_lot::Mutex::new(Arc::new(acceptor)),
1789            }))
1790        }
1791        None => None,
1792    };
1793
1794    let listener = TcpListener::bind(addr)
1795        .await
1796        .map_err(|e| ServerError::Http(format!("pgwire bind {addr}: {e}")))?;
1797    let local_addr = listener
1798        .local_addr()
1799        .map_err(|e| ServerError::Http(format!("pgwire local_addr: {e}")))?;
1800
1801    let factory = Arc::new(LaminarHandlerFactory::new(db, users));
1802    let tls_mode = if tls_state.is_some() { "on" } else { "off" };
1803    let tls_min = tls_min_label.unwrap_or("-");
1804    let mtls = if mtls_on { "on" } else { "off" };
1805    if auth_mode == "trust" {
1806        warn!(
1807            addr = %local_addr,
1808            tls = tls_mode,
1809            tls_min,
1810            mtls,
1811            "pgwire listening with TRUST auth — any client reaching this address is admin",
1812        );
1813    } else {
1814        info!(
1815            addr = %local_addr,
1816            auth = auth_mode,
1817            tls = tls_mode,
1818            tls_min,
1819            mtls,
1820            "pgwire listening",
1821        );
1822    }
1823
1824    // Track per-connection tasks so abort on the outer JoinHandle stops
1825    // active sessions in addition to the accept loop.
1826    let failures = Arc::new(FailureTracker::default());
1827    let watcher_state = tls_state.as_ref().map(Arc::clone);
1828    let watcher_disabled =
1829        std::env::var("LAMINAR_DISABLE_FILE_WATCH").is_ok_and(|v| v == "1" || v == "true");
1830    let handle = tokio::spawn(async move {
1831        let mut sessions: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
1832        // Watcher in its own JoinSet so it doesn't count toward max_connections.
1833        let mut watcher_set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
1834        if let (Some(state), false) = (watcher_state, watcher_disabled) {
1835            watcher_set.spawn(async move {
1836                watch_tls_files(state, std::time::Duration::from_millis(500)).await;
1837            });
1838        }
1839        loop {
1840            tokio::select! {
1841                Some(_) = sessions.join_next(), if !sessions.is_empty() => {
1842                    // Reap completed sessions; nothing to do with the result.
1843                }
1844                Some(_) = watcher_set.join_next(), if !watcher_set.is_empty() => {}
1845                accepted = listener.accept() => {
1846                    match accepted {
1847                        Ok((sock, peer)) => {
1848                            if sessions.len() >= max_connections {
1849                                tracing::info!(
1850                                    target: "audit",
1851                                    event = "pgwire.connection_rejected",
1852                                    peer = %peer,
1853                                    reason = "max_connections",
1854                                    in_flight = sessions.len(),
1855                                );
1856                                drop(sock);
1857                                continue;
1858                            }
1859                            if failures.is_blocked(
1860                                peer.ip(),
1861                                max_auth_failures_per_min,
1862                                std::time::Duration::from_secs(60),
1863                            ) {
1864                                tracing::warn!(
1865                                    target: "audit",
1866                                    event = "pgwire.connection_rejected",
1867                                    peer = %peer,
1868                                    reason = "auth_failure_throttle",
1869                                );
1870                                drop(sock);
1871                                continue;
1872                            }
1873                            let factory_ref = Arc::clone(&factory);
1874                            // Snapshot the live acceptor so that an in-flight
1875                            // handshake completes against whatever cert was
1876                            // current when the socket was accepted, even if a
1877                            // hot-reload swaps it under us.
1878                            let tls_ref: Option<tokio_rustls::TlsAcceptor> =
1879                                tls_state.as_ref().map(|s| (*s.snapshot()).clone());
1880                            let failures_ref = Arc::clone(&failures);
1881                            let peer_str = peer.to_string();
1882                            tracing::info!(
1883                                target: "audit",
1884                                event = "pgwire.connection_accepted",
1885                                peer = %peer,
1886                                auth = auth_mode,
1887                                tls = tls_mode,
1888                            );
1889                            let peer_ip = peer.ip();
1890                            sessions.spawn(async move {
1891                                let result = process_socket(sock, tls_ref, factory_ref).await;
1892                                let outcome = classify_outcome(&result);
1893                                if outcome == "auth_failed" {
1894                                    failures_ref.record_failure(peer_ip);
1895                                }
1896                                tracing::info!(
1897                                    target: "audit",
1898                                    event = "pgwire.connection_closed",
1899                                    peer = %peer_str,
1900                                    outcome,
1901                                );
1902                                if let Err(e) = result {
1903                                    warn!(peer = %peer_str, error = %e, "pgwire connection error");
1904                                }
1905                            });
1906                        }
1907                        Err(e) => {
1908                            warn!(error = %e, "pgwire accept failed");
1909                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1910                        }
1911                    }
1912                }
1913            }
1914        }
1915    });
1916    Ok((local_addr, handle))
1917}
1918
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns its first statement; panics when parsing
    /// fails or yields no statement.
    fn parse_one(sql: &str) -> StreamingStatement {
        parse_streaming_sql(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Like `parse_one`, but unwraps the `Standard` variant and panics on
    /// any other statement kind.
    fn standard(sql: &str) -> Statement {
        match parse_one(sql) {
            StreamingStatement::Standard(s) => *s,
            other => panic!("expected Standard, got {other:?}"),
        }
    }

    /// Text-array rendering: empty input, plain values, NULL entries, and
    /// escaping of embedded quotes/backslashes.
    #[test]
    fn pg_text_array_literal_quotes_nulls_and_escapes() {
        assert_eq!(pg_text_array_literal(&[]), "{}");
        assert_eq!(
            pg_text_array_literal(&[Some("en".into()), Some("ja".into())]),
            r#"{"en","ja"}"#
        );
        assert_eq!(
            pg_text_array_literal(&[None, Some("x".into())]),
            r#"{NULL,"x"}"#
        );
        // Embedded quote and backslash are escaped, not left ambiguous.
        assert_eq!(
            pg_text_array_literal(&[Some("a\"b\\c".into())]),
            r#"{"a\"b\\c"}"#
        );
    }

    /// `SELECT 1` is handled regardless of keyword case or a leading comment.
    #[tokio::test]
    async fn select_one_dispatches() {
        let db = LaminarDB::open().unwrap();
        for sql in ["SELECT 1", "select 1", "/* hint */ SELECT 1"] {
            standard_response(&db, standard(sql)).unwrap();
        }
    }

    /// Smoke test: driver-probe builtins must parse and dispatch without
    /// panicking; the result itself is intentionally not asserted.
    #[tokio::test]
    async fn driver_select_builtins_dispatch() {
        let db = LaminarDB::open().unwrap();
        for sql in [
            "SELECT version()",
            "SELECT current_schema()",
            "SELECT current_database()",
            "SELECT current_user",
        ] {
            // current_user parses as Expr::Function with no parens in some versions;
            // we accept whatever the parser gives us.
            let _ = standard_response(&db, standard(sql));
        }
    }

    /// A SELECT with a FROM clause is rejected with the "limited to
    /// literals" error.
    #[tokio::test]
    async fn select_with_from_is_rejected() {
        let db = LaminarDB::open().unwrap();
        let err = standard_response(&db, standard("SELECT 1 FROM foo")).unwrap_err();
        assert!(err.to_string().contains("limited to literals"));
    }

    /// DDL is refused with an error that points at the HTTP SQL endpoint.
    #[tokio::test]
    async fn ddl_routed_to_http() {
        let db = LaminarDB::open().unwrap();
        let err = standard_response(&db, standard("CREATE TABLE foo (id INT)")).unwrap_err();
        assert!(err.to_string().contains("HTTP /api/v1/sql"));
    }

    /// Every transaction-control spelling is accepted.
    #[tokio::test]
    async fn transaction_control_dispatches() {
        let db = LaminarDB::open().unwrap();
        for sql in [
            "BEGIN",
            "BEGIN TRANSACTION",
            "START TRANSACTION",
            "COMMIT",
            "ROLLBACK",
        ] {
            standard_response(&db, standard(sql)).unwrap();
        }
    }

    /// `SET name = value` is stored as a session property readable from the db.
    #[tokio::test]
    async fn set_writes_to_session_properties() {
        let db = LaminarDB::open().unwrap();
        standard_response(&db, standard("SET extra_float_digits = 3")).unwrap();
        assert_eq!(
            db.get_session_property("extra_float_digits").as_deref(),
            Some("3"),
        );
    }

    /// `SET TRANSACTION ...` is rejected with an error naming the statement.
    #[tokio::test]
    async fn set_transaction_isolation_is_rejected() {
        let db = LaminarDB::open().unwrap();
        let err = standard_response(
            &db,
            standard("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"),
        )
        .unwrap_err();
        assert!(err.to_string().contains("SET TRANSACTION"));
    }

    /// A semicolon-separated batch parses into one statement per command.
    #[test]
    fn multi_statement_parses() {
        let stmts = parse_streaming_sql("BEGIN; SELECT 1; COMMIT").unwrap();
        assert_eq!(stmts.len(), 3);
    }

    /// One representative input per audit bucket: ok, auth, TLS, other.
    #[test]
    fn classify_outcome_buckets_errors() {
        use std::io::{Error, ErrorKind};
        assert_eq!(super::classify_outcome(&Ok(())), "ok");
        assert_eq!(
            super::classify_outcome(&Err(Error::other("FATAL: 28P01 bad pass"))),
            "auth_failed"
        );
        assert_eq!(
            super::classify_outcome(&Err(Error::other("rustls HandshakeFailure"))),
            "tls_failed"
        );
        assert_eq!(
            super::classify_outcome(&Err(Error::new(ErrorKind::BrokenPipe, "broken"))),
            "error"
        );
    }

    /// The peer stays unblocked while below `limit` failures and is blocked
    /// once `limit` failures sit inside the window.
    #[test]
    fn failure_tracker_blocks_after_threshold() {
        use std::net::{IpAddr, Ipv4Addr};
        use std::time::Duration;
        let ip: IpAddr = Ipv4Addr::LOCALHOST.into();
        let tracker = super::FailureTracker::default();
        let limit = 3;
        let window = Duration::from_secs(60);

        for _ in 0..limit {
            assert!(!tracker.is_blocked(ip, limit, window));
            tracker.record_failure(ip);
        }
        assert!(tracker.is_blocked(ip, limit, window));
    }

    /// A limit of 0 never blocks, no matter how many failures were recorded.
    #[test]
    fn failure_tracker_disabled_when_limit_zero() {
        use std::net::{IpAddr, Ipv4Addr};
        use std::time::Duration;
        let ip: IpAddr = Ipv4Addr::LOCALHOST.into();
        let tracker = super::FailureTracker::default();
        for _ in 0..100 {
            tracker.record_failure(ip);
        }
        assert!(!tracker.is_blocked(ip, 0, Duration::from_secs(60)));
    }

    /// Failures that fall outside the window no longer count toward the limit.
    #[test]
    fn failure_tracker_expires_old_entries() {
        use std::net::{IpAddr, Ipv4Addr};
        use std::time::Duration;
        let ip: IpAddr = Ipv4Addr::LOCALHOST.into();
        let tracker = super::FailureTracker::default();
        for _ in 0..5 {
            tracker.record_failure(ip);
        }
        // Window of 0 means every recorded failure is already expired.
        assert!(!tracker.is_blocked(ip, 5, Duration::from_secs(0)));
    }

    /// Recording failures from more distinct IPs than `MAX_TRACKED_IPS`
    /// must not grow the map past that cap.
    #[test]
    fn failure_tracker_caps_distinct_ips() {
        use std::net::{IpAddr, Ipv4Addr};
        let tracker = super::FailureTracker::default();
        // Push past the cap; map size must stay bounded.
        for i in 0..(super::MAX_TRACKED_IPS + 100) {
            #[allow(clippy::cast_possible_truncation)]
            let ip: IpAddr = Ipv4Addr::new(10, 0, (i / 256) as u8, (i % 256) as u8).into();
            tracker.record_failure(ip);
        }
        let len = tracker.inner.lock().len();
        assert!(
            len <= super::MAX_TRACKED_IPS,
            "tracker exceeded cap: {len} > {}",
            super::MAX_TRACKED_IPS
        );
    }

    /// Trust auth (no users) on a non-loopback address is refused.
    #[tokio::test]
    async fn serve_rejects_remote_bind_in_trust_mode() {
        let db = Arc::new(LaminarDB::open().expect("db opens"));
        let err = serve(db, "0.0.0.0:0", HashMap::new(), false, None, 256, 10)
            .await
            .expect_err("trust + 0.0.0.0 must fail");
        assert!(err.to_string().contains("trust auth"), "got: {err}");
    }

    /// With users configured, a non-loopback bind still needs
    /// `allow_remote = true`.
    #[tokio::test]
    async fn serve_rejects_remote_bind_without_explicit_optin() {
        let db = Arc::new(LaminarDB::open().expect("db opens"));
        let mut users = HashMap::new();
        users.insert("alice".into(), Secret::new("wonderland-key"));
        let err = serve(db, "0.0.0.0:0", users, false, None, 256, 10)
            .await
            .expect_err("md5 + 0.0.0.0 without allow_remote must fail");
        assert!(
            err.to_string().contains("pgwire_allow_remote"),
            "got: {err}"
        );
    }
}
2134
2135#[cfg(test)]
2136mod integration_tests {
2137    //! End-to-end pgwire driven by `tokio_postgres` against an in-process
2138    //! `LaminarDB`. Verifies the wire-protocol surface — handshake, SimpleQuery
2139    //! dispatch, error reporting — that unit tests can't reach. Engine-level
2140    //! row flow is covered in `laminar-db`'s `db::tests`.
2141
2142    use std::collections::HashMap;
2143    use std::sync::Arc;
2144
2145    use laminar_db::LaminarDB;
2146    use tokio_postgres::{NoTls, SimpleQueryMessage};
2147
2148    use super::Secret;
2149
2150    async fn spawn_server_with(
2151        users: HashMap<String, Secret>,
2152    ) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
2153        let db = Arc::new(LaminarDB::open().expect("db opens"));
2154        db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
2155            .await
2156            .expect("create source");
2157        db.execute(
2158            "CREATE MATERIALIZED VIEW prices AS \
2159             SELECT symbol, price FROM trades",
2160        )
2161        .await
2162        .expect("create mv");
2163        db.start().await.expect("db starts");
2164
2165        let (addr, handle) =
2166            super::serve(Arc::clone(&db), "127.0.0.1:0", users, false, None, 256, 10)
2167                .await
2168                .expect("pgwire serve");
2169        (addr, handle)
2170    }
2171
    /// Starts the fixture server with an empty user map, which `serve`
    /// treats as trust auth.
    async fn spawn_server() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
        spawn_server_with(HashMap::new()).await
    }
2175
2176    async fn connect(addr: std::net::SocketAddr) -> tokio_postgres::Client {
2177        let conn_str = format!(
2178            "host={} port={} user=any dbname=laminardb",
2179            addr.ip(),
2180            addr.port()
2181        );
2182        let (client, conn) = tokio_postgres::connect(&conn_str, NoTls)
2183            .await
2184            .expect("pgwire connect");
2185        tokio::spawn(async move {
2186            let _ = conn.await;
2187        });
2188        client
2189    }
2190
2191    fn first_row_value(messages: &[SimpleQueryMessage], col: usize) -> Option<&str> {
2192        messages.iter().find_map(|m| match m {
2193            SimpleQueryMessage::Row(r) => r.get(col),
2194            _ => None,
2195        })
2196    }
2197
    /// A client can complete the startup handshake and read back the
    /// `version()` and `current_database()` builtins over SimpleQuery.
    #[tokio::test]
    async fn handshake_and_builtins() {
        let (addr, handle) = spawn_server().await;
        let client = connect(addr).await;

        let messages = client
            .simple_query("SELECT version()")
            .await
            .expect("version");
        let v = first_row_value(&messages, 0).expect("row");
        assert!(v.contains("LaminarDB"), "version: {v}");

        let messages = client
            .simple_query("SELECT current_database()")
            .await
            .expect("current_database");
        assert_eq!(first_row_value(&messages, 0), Some("laminar"));

        // Stop the accept loop (and with it the server's sessions).
        handle.abort();
    }
2218
    /// `SHOW STREAMS` completes over the wire without an error.
    #[tokio::test]
    async fn show_streams_runs() {
        let (addr, handle) = spawn_server().await;
        let client = connect(addr).await;

        // No assertion on contents — just that the dispatch path returns rows
        // without error. Engine-level SHOW behavior is covered in laminar-db.
        client
            .simple_query("SHOW STREAMS")
            .await
            .expect("SHOW STREAMS");

        handle.abort();
    }
2233
    /// SUBSCRIBE on a missing view fails with a typed PG error whose message
    /// names the view.
    #[tokio::test]
    async fn subscribe_unknown_name_returns_pg_error() {
        let (addr, handle) = spawn_server().await;
        let client = connect(addr).await;

        let err = client
            .simple_query("SUBSCRIBE no_such_view")
            .await
            .expect_err("must fail");
        // `as_db_error` is `Some` only for server-sent ErrorResponse messages.
        let db_err = err.as_db_error().expect("typed PG error");
        assert!(
            db_err.message().contains("no_such_view"),
            "message: {}",
            db_err.message()
        );

        handle.abort();
    }
2252
    /// A SUBSCRIBE with a well-formed WHERE filter is accepted: the query
    /// stays open rather than returning a result or an error.
    #[tokio::test]
    async fn subscribe_with_valid_where_is_accepted() {
        // SUBSCRIBE never returns CommandComplete, so a successful compile
        // shows up as a timeout. Anything else — Ok(Ok) or Ok(Err) — is a
        // regression.
        let (addr, handle) = spawn_server().await;
        let client = connect(addr).await;

        let r = tokio::time::timeout(
            std::time::Duration::from_millis(500),
            client.simple_query("SUBSCRIBE prices WHERE symbol = 'AAPL'"),
        )
        .await;

        // `Err` here is the timeout elapsing, i.e. the expected outcome.
        assert!(
            r.is_err(),
            "subscribe must stay open until timeout, got: {r:?}"
        );

        handle.abort();
    }
2274
    /// A WHERE filter referencing a column the view lacks fails with a typed
    /// PG error whose message names that column.
    #[tokio::test]
    async fn subscribe_with_unknown_column_in_where_returns_pg_error() {
        let (addr, handle) = spawn_server().await;
        let client = connect(addr).await;

        let err = client
            .simple_query("SUBSCRIBE prices WHERE no_such_col > 1")
            .await
            .expect_err("must fail");
        let db_err = err.as_db_error().expect("typed PG error");
        assert!(
            db_err.message().contains("no_such_col"),
            "filter error must name the bad column, got: {}",
            db_err.message()
        );

        handle.abort();
    }
2293
    /// `AS OF EPOCH` on a view without retained history fails with a typed
    /// PG error mentioning that the epoch is "no longer retained".
    #[tokio::test]
    async fn subscribe_as_of_unretained_returns_pg_error() {
        // No retention configured on the `prices` MV from the default setup,
        // so AS OF EPOCH 1 must come back as a typed PG error.
        let (addr, handle) = spawn_server().await;
        let client = connect(addr).await;

        let err = client
            .simple_query("SUBSCRIBE prices AS OF EPOCH 1")
            .await
            .expect_err("must fail");
        let db_err = err.as_db_error().expect("typed PG error");
        assert!(
            db_err.message().contains("no longer retained"),
            "message: {}",
            db_err.message()
        );

        handle.abort();
    }
2314
    /// SUBSCRIBE must actually stream emitted MV rows over the socket: bind a
    /// portal, push rows into the source, and read them back via the
    /// extended-query portal (the chunked path JDBC/asyncpg use).
    #[tokio::test]
    async fn subscribe_streams_emitted_rows_over_the_wire() {
        use std::time::Duration;

        use arrow_array::{Float64Array, RecordBatch, StringArray};

        // Built inline rather than via `spawn_server` because the test needs
        // the `db` handle to push rows later.
        let db = Arc::new(LaminarDB::open().expect("db opens"));
        db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
            .await
            .expect("create source");
        db.execute("CREATE MATERIALIZED VIEW prices AS SELECT symbol, price FROM trades")
            .await
            .expect("create mv");
        db.start().await.expect("db starts");
        let (addr, handle) = super::serve(
            Arc::clone(&db),
            "127.0.0.1:0",
            HashMap::new(),
            false,
            None,
            256,
            10,
        )
        .await
        .expect("serve");
        let mut client = connect(addr).await;
        let txn = client.transaction().await.expect("begin");

        // The subscription opens when the first Execute runs, so push once the
        // read is in flight (Tail would otherwise miss earlier rows).
        let stmt = txn.prepare("SUBSCRIBE prices").await.expect("prepare");
        let portal = txn.bind(&stmt, &[]).await.expect("bind");

        // Delayed producer: waits 300 ms, then pushes two rows into `trades`.
        let pusher = tokio::spawn({
            let db = Arc::clone(&db);
            async move {
                tokio::time::sleep(Duration::from_millis(300)).await;
                let src = db.source_untyped("trades").expect("source handle");
                let batch = RecordBatch::try_new(
                    src.schema().clone(),
                    vec![
                        Arc::new(StringArray::from(vec!["AAPL", "MSFT"])),
                        Arc::new(Float64Array::from(vec![100.0, 200.0])),
                    ],
                )
                .expect("batch");
                src.push_arrow(batch).expect("push");
            }
        });

        // Ask the portal for at most 2 rows; the 10 s timeout only guards
        // against a hang.
        let rows = tokio::time::timeout(Duration::from_secs(10), txn.query_portal(&portal, 2))
            .await
            .expect("read did not time out")
            .expect("query_portal");
        pusher.await.expect("pusher");

        // Sort before comparing so the assertion does not depend on row order.
        let mut symbols: Vec<String> = rows
            .iter()
            .map(|r| r.get::<_, &str>(0).to_string())
            .collect();
        symbols.sort();
        assert_eq!(
            symbols,
            ["AAPL", "MSFT"],
            "both emitted rows arrive over pgwire"
        );

        handle.abort();
    }
2387
    /// A TEXT[] column must round-trip over the binary wire (asyncpg/JDBC
    /// request binary): the column advertises the _text OID and encodes as a
    /// Postgres array, so tokio_postgres decodes it into a Vec<String>.
    #[tokio::test]
    async fn subscribe_decodes_text_array_in_binary_format() {
        use std::time::Duration;

        use arrow_array::{Int64Array, RecordBatch};

        let db = Arc::new(LaminarDB::open().expect("db opens"));
        db.execute("CREATE SOURCE feed (id BIGINT)")
            .await
            .expect("create source");
        // The view attaches a constant two-element text array to every row.
        db.execute(
            "CREATE MATERIALIZED VIEW tagged AS SELECT id, make_array('en','ja') AS tags FROM feed",
        )
        .await
        .expect("create mv");
        db.start().await.expect("db starts");
        let (addr, handle) = super::serve(
            Arc::clone(&db),
            "127.0.0.1:0",
            HashMap::new(),
            false,
            None,
            256,
            10,
        )
        .await
        .expect("serve");
        let mut client = connect(addr).await;
        let txn = client.transaction().await.expect("begin");
        let stmt = txn.prepare("SUBSCRIBE tagged").await.expect("prepare");
        let portal = txn.bind(&stmt, &[]).await.expect("bind");

        // Delayed producer: waits 300 ms, then pushes a single row into `feed`.
        let pusher = tokio::spawn({
            let db = Arc::clone(&db);
            async move {
                tokio::time::sleep(Duration::from_millis(300)).await;
                let src = db.source_untyped("feed").expect("source handle");
                let batch = RecordBatch::try_new(
                    src.schema().clone(),
                    vec![Arc::new(Int64Array::from(vec![1_i64]))],
                )
                .expect("batch");
                src.push_arrow(batch).expect("push");
            }
        });

        // Ask the portal for at most 1 row; the 10 s timeout only guards
        // against a hang.
        let rows = tokio::time::timeout(Duration::from_secs(10), txn.query_portal(&portal, 1))
            .await
            .expect("read did not time out")
            .expect("query_portal");
        pusher.await.expect("pusher");

        assert_eq!(rows.len(), 1);
        // Column 1 is `tags`; decoding into Vec<String> panics on a type mismatch.
        let tags: Vec<String> = rows[0].get(1);
        assert_eq!(
            tags,
            vec!["en".to_string(), "ja".to_string()],
            "TEXT[] decoded over the binary wire"
        );

        handle.abort();
    }
2453
2454    #[tokio::test]
2455    async fn ddl_returns_pg_error_pointing_at_http() {
2456        let (addr, handle) = spawn_server().await;
2457        let client = connect(addr).await;
2458
2459        let err = client
2460            .simple_query("CREATE SOURCE more_trades (sym VARCHAR)")
2461            .await
2462            .expect_err("DDL must be rejected");
2463        let db_err = err.as_db_error().expect("typed PG error");
2464        assert!(
2465            db_err.message().contains("/api/v1/sql"),
2466            "message: {}",
2467            db_err.message()
2468        );
2469
2470        handle.abort();
2471    }
2472
2473    async fn md5_users() -> HashMap<String, Secret> {
2474        let mut u = HashMap::new();
2475        u.insert("alice".to_string(), Secret::new(TEST_PASSWORD));
2476        u
2477    }
2478
    /// Password of the `alice` test user; used both to build the user table
    /// and as the "correct" password in the MD5 auth tests.
    const TEST_PASSWORD: &str = "wonderland-key";
2480
2481    async fn connect_with_password(
2482        addr: std::net::SocketAddr,
2483        user: &str,
2484        password: &str,
2485    ) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
2486        let conn_str = format!(
2487            "host={} port={} user={user} password={password} dbname=laminardb",
2488            addr.ip(),
2489            addr.port()
2490        );
2491        let (client, conn) = tokio_postgres::connect(&conn_str, NoTls).await?;
2492        tokio::spawn(async move {
2493            let _ = conn.await;
2494        });
2495        Ok(client)
2496    }
2497
2498    #[tokio::test]
2499    async fn md5_auth_accepts_correct_password() {
2500        let (addr, handle) = spawn_server_with(md5_users().await).await;
2501
2502        let client = connect_with_password(addr, "alice", TEST_PASSWORD)
2503            .await
2504            .expect("auth must succeed");
2505
2506        let messages = client
2507            .simple_query("SELECT version()")
2508            .await
2509            .expect("query after auth");
2510        let v = first_row_value(&messages, 0).expect("row");
2511        assert!(v.contains("LaminarDB"), "version: {v}");
2512
2513        handle.abort();
2514    }
2515
2516    #[tokio::test]
2517    async fn md5_auth_rejects_wrong_password() {
2518        let (addr, handle) = spawn_server_with(md5_users().await).await;
2519
2520        let err = connect_with_password(addr, "alice", "not-the-password")
2521            .await
2522            .expect_err("auth must fail");
2523
2524        let db_err = err.as_db_error().expect("typed PG error");
2525        assert_eq!(db_err.code().code(), "28P01", "got: {db_err:?}");
2526
2527        handle.abort();
2528    }
2529
2530    /// Pre-hashed pgwire_users entry: stored value is `md5{hex(md5(pw||user))}`,
2531    /// matching pg_authid. Plaintext never touches disk yet auth still succeeds.
2532    fn md5_users_prehashed(user: &str, password: &str) -> HashMap<String, Secret> {
2533        use md5::{Digest, Md5};
2534        let mut h = Md5::new();
2535        h.update(password.as_bytes());
2536        h.update(user.as_bytes());
2537        let inner = format!("{:x}", h.finalize());
2538        let mut u = HashMap::new();
2539        u.insert(user.to_string(), Secret::new(format!("md5{inner}")));
2540        u
2541    }
2542
2543    #[tokio::test]
2544    async fn md5_auth_accepts_correct_password_against_prehash() {
2545        let (addr, handle) = spawn_server_with(md5_users_prehashed("alice", TEST_PASSWORD)).await;
2546        let client = connect_with_password(addr, "alice", TEST_PASSWORD)
2547            .await
2548            .expect("auth must succeed against pre-hashed entry");
2549        let messages = client
2550            .simple_query("SELECT version()")
2551            .await
2552            .expect("query after auth");
2553        let v = first_row_value(&messages, 0).expect("row");
2554        assert!(v.contains("LaminarDB"), "version: {v}");
2555        handle.abort();
2556    }
2557
2558    #[tokio::test]
2559    async fn md5_auth_rejects_wrong_password_against_prehash() {
2560        let (addr, handle) = spawn_server_with(md5_users_prehashed("alice", TEST_PASSWORD)).await;
2561        let err = connect_with_password(addr, "alice", "not-the-password")
2562            .await
2563            .expect_err("auth must fail");
2564        let db_err = err.as_db_error().expect("typed PG error");
2565        assert_eq!(db_err.code().code(), "28P01", "got: {db_err:?}");
2566        handle.abort();
2567    }
2568
2569    #[test]
2570    fn parse_pre_hashed_md5_strict_format() {
2571        // 32 lowercase hex after the tag → accepted.
2572        let inner = "5d41402abc4b2a76b9719d911017c592";
2573        assert_eq!(
2574            super::parse_pre_hashed_md5(&format!("md5{inner}")),
2575            Some(inner),
2576        );
2577        // Wrong length, uppercase hex, missing prefix, or non-hex → rejected.
2578        assert_eq!(super::parse_pre_hashed_md5("md5short"), None);
2579        assert_eq!(
2580            super::parse_pre_hashed_md5("md55D41402ABC4B2A76B9719D911017C592"),
2581            None,
2582        );
2583        assert_eq!(super::parse_pre_hashed_md5(inner), None);
2584        assert_eq!(
2585            super::parse_pre_hashed_md5("md5zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"),
2586            None,
2587        );
2588    }
2589
2590    #[tokio::test]
2591    async fn md5_auth_rejects_unknown_user() {
2592        let (addr, handle) = spawn_server_with(md5_users().await).await;
2593
2594        let err = connect_with_password(addr, "mallory", "anything")
2595            .await
2596            .expect_err("auth must fail");
2597
2598        let db_err = err.as_db_error().expect("typed PG error");
2599        assert_eq!(db_err.code().code(), "28P01", "got: {db_err:?}");
2600
2601        handle.abort();
2602    }
2603
2604    #[tokio::test]
2605    async fn connection_cap_drops_excess_clients() {
2606        // Cap of 1; first client occupies the slot, second is dropped.
2607        let db = Arc::new(LaminarDB::open().expect("db opens"));
2608        db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
2609            .await
2610            .expect("create source");
2611        db.execute("CREATE MATERIALIZED VIEW prices AS SELECT symbol, price FROM trades")
2612            .await
2613            .expect("create mv");
2614        db.start().await.expect("db starts");
2615        let (addr, handle) = super::serve(
2616            Arc::clone(&db),
2617            "127.0.0.1:0",
2618            HashMap::new(),
2619            false,
2620            None,
2621            1,
2622            10,
2623        )
2624        .await
2625        .expect("pgwire serve");
2626
2627        // First client occupies the slot via SUBSCRIBE (stays open until drop).
2628        let first = connect(addr).await;
2629        let _bg = tokio::spawn(async move {
2630            let _ = first.simple_query("SUBSCRIBE prices").await;
2631        });
2632        // Give the server a moment to register the session in the JoinSet.
2633        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
2634
2635        // Second connect: server accepts the TCP, then closes it because the
2636        // cap is hit. tokio_postgres surfaces this as an IO error during
2637        // startup. Exact string varies; just assert it failed.
2638        let conn_str = format!(
2639            "host={} port={} user=any dbname=laminardb",
2640            addr.ip(),
2641            addr.port()
2642        );
2643        let result = tokio_postgres::connect(&conn_str, NoTls).await;
2644        assert!(result.is_err(), "second connect must be refused");
2645
2646        handle.abort();
2647    }
2648
2649    /// Self-signed cert+key written to a tempdir for the duration of the
2650    /// test. `rcgen` is the well-maintained option for ad-hoc certs.
2651    fn self_signed_pem() -> (tempfile::TempDir, std::path::PathBuf, std::path::PathBuf) {
2652        let cert =
2653            rcgen::generate_simple_self_signed(vec!["localhost".into()]).expect("rcgen issue cert");
2654        let dir = tempfile::tempdir().expect("tempdir");
2655        let cert_path = dir.path().join("cert.pem");
2656        let key_path = dir.path().join("key.pem");
2657        std::fs::write(&cert_path, cert.cert.pem()).unwrap();
2658        std::fs::write(&key_path, cert.key_pair.serialize_pem()).unwrap();
2659        (dir, cert_path, key_path)
2660    }
2661
    /// CA + client-leaf bundle for mTLS tests. The CA PEM is written to a
    /// tempfile so the server can be pointed at it via `pgwire_tls_client_ca`;
    /// the leaf cert+key are returned in DER form for direct use by a rustls
    /// `ClientConfig`.
    struct MintedClientPki {
        /// Guard for the temporary directory that holds the CA PEM; never
        /// read, only held so the directory lives as long as the bundle.
        _dir: tempfile::TempDir,
        /// Location of the CA certificate in PEM form.
        ca_pem_path: std::path::PathBuf,
        /// Client certificate chain in DER form.
        leaf_chain: Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>,
        /// Private key for the client leaf, in DER form.
        leaf_key: tokio_rustls::rustls::pki_types::PrivateKeyDer<'static>,
    }
2672
2673    fn mint_ca_and_client_leaf(common_name: &str) -> MintedClientPki {
2674        use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
2675
2676        let mut ca_params = rcgen::CertificateParams::new(vec!["mtls-test-ca".into()]).unwrap();
2677        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
2678        let ca_key = rcgen::KeyPair::generate().unwrap();
2679        let ca_cert = ca_params.self_signed(&ca_key).unwrap();
2680
2681        let mut leaf_params = rcgen::CertificateParams::new(vec![common_name.into()]).unwrap();
2682        leaf_params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ClientAuth];
2683        let leaf_key = rcgen::KeyPair::generate().unwrap();
2684        let leaf_cert = leaf_params.signed_by(&leaf_key, &ca_cert, &ca_key).unwrap();
2685
2686        let dir = tempfile::tempdir().unwrap();
2687        let ca_pem_path = dir.path().join("ca.pem");
2688        std::fs::write(&ca_pem_path, ca_cert.pem()).unwrap();
2689
2690        let leaf_chain = vec![CertificateDer::from(leaf_cert.der().to_vec())];
2691        let leaf_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(leaf_key.serialize_der()));
2692
2693        MintedClientPki {
2694            _dir: dir,
2695            ca_pem_path,
2696            leaf_chain,
2697            leaf_key,
2698        }
2699    }
2700
2701    /// Builds a tokio_postgres TLS connector that trusts `server_cert_path`
2702    /// for the server hello and (optionally) presents a client cert for mTLS.
2703    fn make_client_tls(
2704        server_cert_path: &std::path::Path,
2705        client_auth: Option<(
2706            Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>,
2707            tokio_rustls::rustls::pki_types::PrivateKeyDer<'static>,
2708        )>,
2709    ) -> tokio_postgres_rustls::MakeRustlsConnect {
2710        super::ensure_tls_provider();
2711        let cert_bytes = std::fs::read(server_cert_path).unwrap();
2712        let mut roots = tokio_rustls::rustls::RootCertStore::empty();
2713        for c in rustls_pemfile::certs(&mut std::io::Cursor::new(cert_bytes))
2714            .collect::<Result<Vec<_>, _>>()
2715            .unwrap()
2716        {
2717            roots.add(c).unwrap();
2718        }
2719        let builder = tokio_rustls::rustls::ClientConfig::builder().with_root_certificates(roots);
2720        let client_cfg = match client_auth {
2721            Some((chain, key)) => builder.with_client_auth_cert(chain, key).unwrap(),
2722            None => builder.with_no_client_auth(),
2723        };
2724        tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg)
2725    }
2726
2727    /// Self-signed cert with notAfter in the past, for the expiry test.
2728    fn expired_self_signed_pem() -> (tempfile::TempDir, std::path::PathBuf, std::path::PathBuf) {
2729        let mut params = rcgen::CertificateParams::new(vec!["localhost".into()]).unwrap();
2730        let one_year_ago = time::OffsetDateTime::now_utc() - time::Duration::days(365);
2731        params.not_before = one_year_ago - time::Duration::days(2);
2732        params.not_after = one_year_ago;
2733        let key = rcgen::KeyPair::generate().unwrap();
2734        let cert = params.self_signed(&key).unwrap();
2735        let dir = tempfile::tempdir().unwrap();
2736        let cert_path = dir.path().join("cert.pem");
2737        let key_path = dir.path().join("key.pem");
2738        std::fs::write(&cert_path, cert.pem()).unwrap();
2739        std::fs::write(&key_path, key.serialize_pem()).unwrap();
2740        (dir, cert_path, key_path)
2741    }
2742
2743    #[tokio::test]
2744    async fn tls_load_rejects_expired_cert() {
2745        let (_dir, cert_path, key_path) = expired_self_signed_pem();
2746        let db = Arc::new(LaminarDB::open().expect("db opens"));
2747        db.start().await.expect("db starts");
2748        let err = super::serve(
2749            Arc::clone(&db),
2750            "127.0.0.1:0",
2751            HashMap::new(),
2752            false,
2753            Some(super::TlsPaths {
2754                cert: &cert_path,
2755                key: &key_path,
2756                min_version: super::TlsMinVersion::V1_2,
2757                client_ca: None,
2758            }),
2759            256,
2760            10,
2761        )
2762        .await
2763        .expect_err("expired cert must be rejected");
2764        assert!(err.to_string().contains("expired"), "got: {err}");
2765    }
2766
2767    #[tokio::test]
2768    async fn tls_min_1_3_rejects_tls_1_2_client() {
2769        let (_dir, cert_path, key_path) = self_signed_pem();
2770        let db = Arc::new(LaminarDB::open().expect("db opens"));
2771        db.start().await.expect("db starts");
2772        let (addr, handle) = super::serve(
2773            Arc::clone(&db),
2774            "127.0.0.1:0",
2775            HashMap::new(),
2776            false,
2777            Some(super::TlsPaths {
2778                cert: &cert_path,
2779                key: &key_path,
2780                min_version: super::TlsMinVersion::V1_3,
2781                client_ca: None,
2782            }),
2783            256,
2784            10,
2785        )
2786        .await
2787        .expect("pgwire serve");
2788
2789        let cert_bytes = std::fs::read(&cert_path).unwrap();
2790        let mut roots = tokio_rustls::rustls::RootCertStore::empty();
2791        for c in rustls_pemfile::certs(&mut std::io::Cursor::new(cert_bytes))
2792            .collect::<Result<Vec<_>, _>>()
2793            .unwrap()
2794        {
2795            roots.add(c).unwrap();
2796        }
2797        super::ensure_tls_provider();
2798        // Client pinned to TLS 1.2 only — must be refused by a 1.3-min server.
2799        let client_cfg = tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
2800            &tokio_rustls::rustls::version::TLS12,
2801        ])
2802        .with_root_certificates(roots)
2803        .with_no_client_auth();
2804
2805        let conn_str = format!(
2806            "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2807            addr.ip(),
2808            addr.port(),
2809        );
2810        let tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
2811        let err = match tokio_postgres::connect(&conn_str, tls).await {
2812            Ok(_) => panic!("TLS 1.2 client must be refused by a 1.3-min server"),
2813            Err(e) => e,
2814        };
2815        // tokio_postgres wraps the rustls error; flatten the chain so we can
2816        // assert against the version-mismatch token rustls emits.
2817        let chain = std::iter::successors(Some(&err as &dyn std::error::Error), |e| e.source())
2818            .map(|e| e.to_string())
2819            .collect::<Vec<_>>()
2820            .join(" | ");
2821        assert!(
2822            chain.contains("ProtocolVersion") || chain.contains("incompatible"),
2823            "expected a TLS version-mismatch error, got: {chain}"
2824        );
2825
2826        handle.abort();
2827    }
2828
2829    #[tokio::test]
2830    async fn tls_handshake_succeeds() {
2831        let (_dir, cert_path, key_path) = self_signed_pem();
2832        let db = Arc::new(LaminarDB::open().expect("db opens"));
2833        db.start().await.expect("db starts");
2834        let (addr, handle) = super::serve(
2835            Arc::clone(&db),
2836            "127.0.0.1:0",
2837            HashMap::new(),
2838            false,
2839            Some(super::TlsPaths {
2840                cert: &cert_path,
2841                key: &key_path,
2842                min_version: super::TlsMinVersion::V1_2,
2843                client_ca: None,
2844            }),
2845            256,
2846            10,
2847        )
2848        .await
2849        .expect("pgwire serve");
2850
2851        // Build a client TLS config that trusts the same self-signed cert.
2852        let cert_bytes = std::fs::read(&cert_path).unwrap();
2853        let mut roots = tokio_rustls::rustls::RootCertStore::empty();
2854        for c in rustls_pemfile::certs(&mut std::io::Cursor::new(cert_bytes))
2855            .collect::<Result<Vec<_>, _>>()
2856            .unwrap()
2857        {
2858            roots.add(c).unwrap();
2859        }
2860        super::ensure_tls_provider();
2861        let client_cfg = tokio_rustls::rustls::ClientConfig::builder()
2862            .with_root_certificates(roots)
2863            .with_no_client_auth();
2864
2865        let conn_str = format!(
2866            "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2867            addr.ip(),
2868            addr.port(),
2869        );
2870        let tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
2871        let (client, conn) = tokio_postgres::connect(&conn_str, tls)
2872            .await
2873            .expect("TLS handshake + connect");
2874        tokio::spawn(async move {
2875            let _ = conn.await;
2876        });
2877
2878        let messages = client
2879            .simple_query("SELECT version()")
2880            .await
2881            .expect("query over TLS");
2882        let v = first_row_value(&messages, 0).expect("row");
2883        assert!(v.contains("LaminarDB"), "version: {v}");
2884
2885        handle.abort();
2886    }
2887
2888    /// mTLS: with a client_ca configured, a client that presents no cert
2889    /// must be refused at handshake time.
2890    #[tokio::test]
2891    async fn mtls_rejects_client_without_cert() {
2892        let (_dir, cert_path, key_path) = self_signed_pem();
2893        let pki = mint_ca_and_client_leaf("alice");
2894        let db = Arc::new(LaminarDB::open().expect("db opens"));
2895        db.start().await.expect("db starts");
2896        let (addr, handle) = super::serve(
2897            Arc::clone(&db),
2898            "127.0.0.1:0",
2899            HashMap::new(),
2900            false,
2901            Some(super::TlsPaths {
2902                cert: &cert_path,
2903                key: &key_path,
2904                min_version: super::TlsMinVersion::V1_2,
2905                client_ca: Some(&pki.ca_pem_path),
2906            }),
2907            256,
2908            10,
2909        )
2910        .await
2911        .expect("pgwire serve");
2912
2913        let tls = make_client_tls(&cert_path, None);
2914        let conn_str = format!(
2915            "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2916            addr.ip(),
2917            addr.port(),
2918        );
2919        let err = match tokio_postgres::connect(&conn_str, tls).await {
2920            Ok(_) => panic!("client without a cert must be refused under mTLS"),
2921            Err(e) => e,
2922        };
2923        assert!(
2924            err_chain(&err).contains("CertificateRequired")
2925                || err_chain(&err).contains("HandshakeFailure")
2926                || err_chain(&err).contains("certificate required"),
2927            "expected a missing-client-cert error, got: {}",
2928            err_chain(&err),
2929        );
2930        handle.abort();
2931    }
2932
2933    /// mTLS: a client cert signed by an unknown CA must be refused.
2934    #[tokio::test]
2935    async fn mtls_rejects_untrusted_client_cert() {
2936        let (_dir, cert_path, key_path) = self_signed_pem();
2937        let trusted = mint_ca_and_client_leaf("trusted");
2938        let stranger = mint_ca_and_client_leaf("stranger");
2939        let db = Arc::new(LaminarDB::open().expect("db opens"));
2940        db.start().await.expect("db starts");
2941        let (addr, handle) = super::serve(
2942            Arc::clone(&db),
2943            "127.0.0.1:0",
2944            HashMap::new(),
2945            false,
2946            Some(super::TlsPaths {
2947                cert: &cert_path,
2948                key: &key_path,
2949                min_version: super::TlsMinVersion::V1_2,
2950                client_ca: Some(&trusted.ca_pem_path),
2951            }),
2952            256,
2953            10,
2954        )
2955        .await
2956        .expect("pgwire serve");
2957
2958        // Client presents a leaf signed by a CA the server doesn't know.
2959        let tls = make_client_tls(
2960            &cert_path,
2961            Some((stranger.leaf_chain.clone(), stranger.leaf_key.clone_key())),
2962        );
2963        let conn_str = format!(
2964            "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2965            addr.ip(),
2966            addr.port(),
2967        );
2968        let err = match tokio_postgres::connect(&conn_str, tls).await {
2969            Ok(_) => panic!("untrusted client cert must be refused"),
2970            Err(e) => e,
2971        };
2972        // rustls maps a verifier-rejected client cert to a fatal alert; the
2973        // exact variant depends on the protocol version and verifier path
2974        // (UnknownCA / BadCertificate on 1.2, DecryptError or
2975        // CertificateUnknown on 1.3). We assert it failed at the TLS layer.
2976        let chain = err_chain(&err);
2977        assert!(
2978            chain.contains("UnknownCA")
2979                || chain.contains("BadCertificate")
2980                || chain.contains("CertificateUnknown")
2981                || chain.contains("DecryptError")
2982                || chain.contains("HandshakeFailure"),
2983            "expected a cert-rejection alert, got: {chain}",
2984        );
2985        handle.abort();
2986    }
2987
2988    /// mTLS: a client cert signed by the configured CA is accepted, and a
2989    /// SimpleQuery completes over the encrypted+authenticated session.
2990    #[tokio::test]
2991    async fn mtls_accepts_trusted_client_cert() {
2992        let (_dir, cert_path, key_path) = self_signed_pem();
2993        let pki = mint_ca_and_client_leaf("alice");
2994        let db = Arc::new(LaminarDB::open().expect("db opens"));
2995        db.start().await.expect("db starts");
2996        let (addr, handle) = super::serve(
2997            Arc::clone(&db),
2998            "127.0.0.1:0",
2999            HashMap::new(),
3000            false,
3001            Some(super::TlsPaths {
3002                cert: &cert_path,
3003                key: &key_path,
3004                min_version: super::TlsMinVersion::V1_2,
3005                client_ca: Some(&pki.ca_pem_path),
3006            }),
3007            256,
3008            10,
3009        )
3010        .await
3011        .expect("pgwire serve");
3012
3013        let tls = make_client_tls(
3014            &cert_path,
3015            Some((pki.leaf_chain.clone(), pki.leaf_key.clone_key())),
3016        );
3017        let conn_str = format!(
3018            "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
3019            addr.ip(),
3020            addr.port(),
3021        );
3022        let (client, conn) = tokio_postgres::connect(&conn_str, tls)
3023            .await
3024            .expect("mTLS handshake + connect");
3025        tokio::spawn(async move {
3026            let _ = conn.await;
3027        });
3028
3029        let messages = client
3030            .simple_query("SELECT version()")
3031            .await
3032            .expect("query over mTLS");
3033        let v = first_row_value(&messages, 0).expect("row");
3034        assert!(v.contains("LaminarDB"), "version: {v}");
3035        handle.abort();
3036    }
3037
3038    /// Build a `TlsReloadState` directly for unit-testing the reload path
3039    /// without standing up a listener.
3040    fn build_reload_state(cert: &std::path::Path, key: &std::path::Path) -> super::TlsReloadState {
3041        let paths = super::TlsPaths {
3042            cert,
3043            key,
3044            min_version: super::TlsMinVersion::V1_2,
3045            client_ca: None,
3046        };
3047        let acceptor = super::load_tls_acceptor(super::TlsPaths {
3048            cert: paths.cert,
3049            key: paths.key,
3050            min_version: paths.min_version,
3051            client_ca: paths.client_ca,
3052        })
3053        .expect("initial acceptor loads");
3054        super::TlsReloadState {
3055            paths: super::TlsConfigPaths::from_paths(&paths),
3056            acceptor: parking_lot::Mutex::new(Arc::new(acceptor)),
3057        }
3058    }
3059
3060    /// Hot-reload: writing a fresh cert+key over the configured paths and
3061    /// calling `try_reload_tls` swaps the acceptor under the mutex.
3062    #[test]
3063    fn tls_reload_swaps_acceptor_on_valid_pair() {
3064        let dir = tempfile::tempdir().unwrap();
3065        let cert_path = dir.path().join("cert.pem");
3066        let key_path = dir.path().join("key.pem");
3067        // Initial cert.
3068        let first = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
3069        std::fs::write(&cert_path, first.cert.pem()).unwrap();
3070        std::fs::write(&key_path, first.key_pair.serialize_pem()).unwrap();
3071
3072        let state = build_reload_state(&cert_path, &key_path);
3073        let before = state.snapshot();
3074
3075        // Rotate to a brand-new pair.
3076        let second = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
3077        std::fs::write(&cert_path, second.cert.pem()).unwrap();
3078        std::fs::write(&key_path, second.key_pair.serialize_pem()).unwrap();
3079
3080        super::try_reload_tls(&state).expect("reload succeeds");
3081        let after = state.snapshot();
3082        assert!(
3083            !Arc::ptr_eq(&before, &after),
3084            "acceptor pointer must change after a successful reload",
3085        );
3086    }
3087
3088    /// Hot-reload: a corrupt cert file leaves the previous acceptor in
3089    /// place — TLS doesn't go down on a bad rotation.
3090    #[test]
3091    fn tls_reload_keeps_old_acceptor_on_garbage() {
3092        let dir = tempfile::tempdir().unwrap();
3093        let cert_path = dir.path().join("cert.pem");
3094        let key_path = dir.path().join("key.pem");
3095        let first = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
3096        std::fs::write(&cert_path, first.cert.pem()).unwrap();
3097        std::fs::write(&key_path, first.key_pair.serialize_pem()).unwrap();
3098
3099        let state = build_reload_state(&cert_path, &key_path);
3100        let before = state.snapshot();
3101
3102        // Truncate cert.pem to non-PEM garbage.
3103        std::fs::write(&cert_path, b"this is not a certificate").unwrap();
3104        let err = super::try_reload_tls(&state).expect_err("reload must fail");
3105        let after = state.snapshot();
3106        assert!(
3107            Arc::ptr_eq(&before, &after),
3108            "acceptor must be unchanged on reload failure",
3109        );
3110        assert!(
3111            err.to_string().contains("pgwire_tls_cert"),
3112            "error should mention pgwire_tls_cert, got: {err}",
3113        );
3114    }
3115
3116    /// Flatten an error and its `source()` chain to a single string for
3117    /// substring assertions.
3118    fn err_chain(err: &(dyn std::error::Error + 'static)) -> String {
3119        std::iter::successors(Some(err), |e| e.source())
3120            .map(|e| e.to_string())
3121            .collect::<Vec<_>>()
3122            .join(" | ")
3123    }
3124
3125    /// Push one row into the `trades` source so subsequent SUBSCRIBE
3126    /// reads have something to drain. Returns the schema for tests
3127    /// that want to build their own batches.
3128    async fn push_one_trade(
3129        db: &Arc<LaminarDB>,
3130        symbol: &str,
3131        price: f64,
3132    ) -> arrow_schema::SchemaRef {
3133        let handle = db.source_untyped("trades").expect("source handle");
3134        let schema = handle.schema().clone();
3135        let batch = arrow_array::RecordBatch::try_new(
3136            Arc::clone(&schema),
3137            vec![
3138                Arc::new(arrow_array::StringArray::from(vec![symbol])),
3139                Arc::new(arrow_array::Float64Array::from(vec![price])),
3140            ],
3141        )
3142        .expect("batch");
3143        handle.push_arrow(batch).expect("push");
3144        schema
3145    }
3146
3147    /// Ingest a row and return both the running server and the underlying db
3148    /// so tests can keep pushing rows after the listener is up.
3149    async fn spawn_with_data() -> (
3150        Arc<LaminarDB>,
3151        std::net::SocketAddr,
3152        tokio::task::JoinHandle<()>,
3153    ) {
3154        let db = Arc::new(LaminarDB::open().expect("db opens"));
3155        db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
3156            .await
3157            .expect("create source");
3158        db.execute(
3159            "CREATE MATERIALIZED VIEW prices AS \
3160             SELECT symbol, price FROM trades",
3161        )
3162        .await
3163        .expect("create mv");
3164        db.start().await.expect("db starts");
3165
3166        let (addr, handle) = super::serve(
3167            Arc::clone(&db),
3168            "127.0.0.1:0",
3169            HashMap::new(),
3170            false,
3171            None,
3172            256,
3173            10,
3174        )
3175        .await
3176        .expect("pgwire serve");
3177        (db, addr, handle)
3178    }
3179
3180    /// Same as `spawn_with_data`, but `prices` is a STREAM with retained
3181    /// history. Lets cursor tests push rows *before* SUBSCRIBE attaches
3182    /// without losing them — the receiver replays on attach.
3183    async fn spawn_with_retained_data() -> (
3184        Arc<LaminarDB>,
3185        std::net::SocketAddr,
3186        tokio::task::JoinHandle<()>,
3187    ) {
3188        let db = Arc::new(LaminarDB::open().expect("db opens"));
3189        db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
3190            .await
3191            .expect("create source");
3192        db.execute(
3193            "CREATE STREAM prices AS SELECT symbol, price FROM trades \
3194             WITH ('retain_history' = '4mb')",
3195        )
3196        .await
3197        .expect("create stream");
3198        db.start().await.expect("db starts");
3199
3200        let (addr, handle) = super::serve(
3201            Arc::clone(&db),
3202            "127.0.0.1:0",
3203            HashMap::new(),
3204            false,
3205            None,
3206            256,
3207            10,
3208        )
3209        .await
3210        .expect("pgwire serve");
3211        (db, addr, handle)
3212    }
3213
3214    /// `prepare()` triggers `Parse` + `Describe(Statement)`. Verifies the
3215    /// extended-query parser resolves stream schemas at parse time and
3216    /// returns column metadata to the client.
3217    #[tokio::test]
3218    async fn extended_query_describe_subscribe_returns_columns() {
3219        let (_db, addr, handle) = spawn_with_data().await;
3220        let client = connect(addr).await;
3221
3222        let stmt = client
3223            .prepare("SUBSCRIBE prices")
3224            .await
3225            .expect("prepare SUBSCRIBE prices");
3226
3227        let cols = stmt.columns();
3228        assert_eq!(cols.len(), 2, "expected 2 columns, got {}", cols.len());
3229        assert_eq!(cols[0].name(), "symbol");
3230        assert_eq!(cols[1].name(), "price");
3231        assert_eq!(cols[0].type_(), &tokio_postgres::types::Type::VARCHAR);
3232        assert_eq!(cols[1].type_(), &tokio_postgres::types::Type::FLOAT8);
3233
3234        handle.abort();
3235    }
3236
3237    /// Unknown stream → typed PG error at `Parse` time, before any rows
3238    /// are pulled.
3239    #[tokio::test]
3240    async fn extended_query_prepare_unknown_stream_errors() {
3241        let (_db, addr, handle) = spawn_with_data().await;
3242        let client = connect(addr).await;
3243
3244        let err = client
3245            .prepare("SUBSCRIBE no_such_view")
3246            .await
3247            .expect_err("must fail at Parse");
3248        let db_err = err.as_db_error().expect("typed PG error");
3249        assert!(db_err.message().contains("no_such_view"));
3250
3251        handle.abort();
3252    }
3253
3254    /// Bind + Execute with `max_rows=1` against a portal returns one row at a
3255    /// time and `PortalSuspended`. Drives the binary-format encoders for
3256    /// VARCHAR + FLOAT8.
3257    #[tokio::test]
3258    async fn extended_query_binary_chunked_subscribe() {
3259        let (db, addr, handle) = spawn_with_data().await;
3260        let mut client = connect(addr).await;
3261
3262        // tokio_postgres' `bind` + `query_portal` uses the extended-query
3263        // protocol with binary format for known column types — the path
3264        // JDBC and asyncpg take with prepared statements.
3265        let tx = client.transaction().await.expect("BEGIN");
3266        let stmt = tx.prepare("SUBSCRIBE prices").await.expect("prepare");
3267        let portal = tx.bind(&stmt, &[]).await.expect("bind portal");
3268
3269        // The MV broadcast has no receiver until `Execute` reaches the
3270        // server and runs `do_query` → `open_subscription`. We can't push
3271        // from this task before query_portal because query_portal blocks
3272        // waiting for a row, so spawn the pushes from a sibling task with
3273        // a short head start for the receiver to attach. With cap=0
3274        // retention, a push that lands before the receiver is dropped.
3275        let pusher = {
3276            let db = Arc::clone(&db);
3277            tokio::spawn(async move {
3278                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
3279                push_one_trade(&db, "AAPL", 150.5).await;
3280                push_one_trade(&db, "GOOG", 2700.25).await;
3281            })
3282        };
3283
3284        let first = tokio::time::timeout(
3285            std::time::Duration::from_secs(3),
3286            tx.query_portal(&portal, 1),
3287        )
3288        .await
3289        .expect("first chunk arrives within 3s")
3290        .expect("query_portal #1");
3291        assert_eq!(first.len(), 1);
3292        let symbol: &str = first[0].get(0);
3293        let price: f64 = first[0].get(1);
3294        assert_eq!(symbol, "AAPL");
3295        assert!((price - 150.5).abs() < 1e-9);
3296
3297        let second = tokio::time::timeout(
3298            std::time::Duration::from_secs(3),
3299            tx.query_portal(&portal, 1),
3300        )
3301        .await
3302        .expect("second chunk arrives within 3s")
3303        .expect("query_portal #2");
3304        assert_eq!(second.len(), 1);
3305        let symbol: &str = second[0].get(0);
3306        let price: f64 = second[0].get(1);
3307        assert_eq!(symbol, "GOOG");
3308        assert!((price - 2700.25).abs() < 1e-9);
3309
3310        pusher.await.expect("push task");
3311        handle.abort();
3312    }
3313
3314    /// Regression: binary encoding of `TIMESTAMP` columns must downcast
3315    /// the Arrow array as its unit-specific primitive type
3316    /// (`PrimitiveArray<TimestampMicrosecondType>`, not
3317    /// `PrimitiveArray<Int64Type>`). A bug in this branch would panic on
3318    /// the first row.
3319    #[tokio::test]
3320    async fn extended_query_binary_timestamp() {
3321        let db = Arc::new(LaminarDB::open().expect("db opens"));
3322        // `WATERMARK FOR ts AS ts - INTERVAL '0' SECOND` declares event time
3323        // so the streaming pipeline drives progress on the timestamp
3324        // column — without it, the MV stays empty.
3325        db.execute(
3326            "CREATE SOURCE events (ts TIMESTAMP, sym VARCHAR, \
3327             WATERMARK FOR ts AS ts - INTERVAL '0' SECOND)",
3328        )
3329        .await
3330        .expect("create source");
3331        db.execute("CREATE MATERIALIZED VIEW ev AS SELECT ts, sym FROM events")
3332            .await
3333            .expect("create mv");
3334        db.start().await.expect("db starts");
3335
3336        let (addr, handle) = super::serve(
3337            Arc::clone(&db),
3338            "127.0.0.1:0",
3339            HashMap::new(),
3340            false,
3341            None,
3342            256,
3343            10,
3344        )
3345        .await
3346        .expect("pgwire serve");
3347
3348        let mut client = connect(addr).await;
3349        let tx = client.transaction().await.expect("BEGIN");
3350        let stmt = tx.prepare("SUBSCRIBE ev").await.expect("prepare");
3351        let portal = tx.bind(&stmt, &[]).await.expect("bind");
3352
3353        let expected = chrono::NaiveDate::from_ymd_opt(2026, 5, 9)
3354            .unwrap()
3355            .and_hms_opt(0, 0, 0)
3356            .unwrap();
3357        let ts_us = expected.and_utc().timestamp_micros();
3358
3359        // Push from a sibling task after a short delay so the MV
3360        // broadcast receiver (created inside `Execute`) is attached
3361        // before send_batch fires. See the matching note in
3362        // `extended_query_binary_chunked_subscribe`.
3363        let pusher = {
3364            let db = Arc::clone(&db);
3365            tokio::spawn(async move {
3366                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
3367                let src = db.source_untyped("events").expect("source");
3368                let batch = arrow_array::RecordBatch::try_new(
3369                    src.schema().clone(),
3370                    vec![
3371                        Arc::new(arrow_array::TimestampMicrosecondArray::from(vec![ts_us])),
3372                        Arc::new(arrow_array::StringArray::from(vec!["AAPL"])),
3373                    ],
3374                )
3375                .expect("batch");
3376                src.push_arrow(batch).expect("push");
3377            })
3378        };
3379
3380        let rows = tokio::time::timeout(
3381            std::time::Duration::from_secs(3),
3382            tx.query_portal(&portal, 1),
3383        )
3384        .await
3385        .expect("row arrives within 3s")
3386        .expect("query_portal");
3387        assert_eq!(rows.len(), 1);
3388
3389        let ts: chrono::NaiveDateTime = rows[0].get(0);
3390        let sym: &str = rows[0].get(1);
3391        assert_eq!(ts, expected);
3392        assert_eq!(sym, "AAPL");
3393
3394        pusher.await.expect("push task");
3395        handle.abort();
3396    }
3397
3398    /// DDL on the extended-query path is refused at `Parse` with a typed
3399    /// 0A000 error pointing at the HTTP endpoint — same surface as the
3400    /// SimpleQuery path.
3401    #[tokio::test]
3402    async fn extended_query_ddl_rejected() {
3403        let (_db, addr, handle) = spawn_with_data().await;
3404        let client = connect(addr).await;
3405
3406        let err = client
3407            .prepare("CREATE SOURCE more_trades (sym VARCHAR)")
3408            .await
3409            .expect_err("DDL must be rejected at Parse");
3410        let db_err = err.as_db_error().expect("typed PG error");
3411        assert!(
3412            db_err.message().contains("/api/v1/sql"),
3413            "message: {}",
3414            db_err.message()
3415        );
3416
3417        handle.abort();
3418    }
3419
3420    /// `\set FETCH_COUNT N` flow: BEGIN; DECLARE …; FETCH N FROM …; CLOSE; COMMIT.
3421    /// All over SimpleQuery — the path psql uses when `FETCH_COUNT` is set.
3422    /// Uses the retained-history variant so we can push before SUBSCRIBE.
3423    #[tokio::test]
3424    async fn cursor_declare_fetch_close_happy_path() {
3425        let (db, addr, handle) = spawn_with_retained_data().await;
3426        let client = connect(addr).await;
3427
3428        for i in 0..4 {
3429            push_one_trade(&db, &format!("S{i}"), i as f64).await;
3430        }
3431
3432        client.simple_query("BEGIN").await.expect("BEGIN");
3433        client
3434            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3435            .await
3436            .expect("DECLARE");
3437
3438        let messages = client
3439            .simple_query("FETCH 2 FROM c")
3440            .await
3441            .expect("FETCH 2");
3442        let row_count = messages
3443            .iter()
3444            .filter(|m| matches!(m, SimpleQueryMessage::Row(_)))
3445            .count();
3446        assert_eq!(row_count, 2, "expected exactly 2 rows from FETCH 2");
3447
3448        client.simple_query("CLOSE c").await.expect("CLOSE");
3449        client.simple_query("COMMIT").await.expect("COMMIT");
3450
3451        handle.abort();
3452    }
3453
3454    /// COMMIT must close any open cursors. After COMMIT, FETCH against the
3455    /// same name returns "cursor does not exist".
3456    #[tokio::test]
3457    async fn cursor_commit_closes_cursors() {
3458        let (_db, addr, handle) = spawn_with_data().await;
3459        let client = connect(addr).await;
3460
3461        client.simple_query("BEGIN").await.expect("BEGIN");
3462        client
3463            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3464            .await
3465            .expect("DECLARE");
3466        client.simple_query("COMMIT").await.expect("COMMIT");
3467
3468        let err = client
3469            .simple_query("FETCH 1 FROM c")
3470            .await
3471            .expect_err("FETCH after COMMIT must fail");
3472        let db_err = err.as_db_error().expect("typed PG error");
3473        assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3474
3475        handle.abort();
3476    }
3477
3478    /// ROLLBACK closes cursors too — same reaper as COMMIT.
3479    #[tokio::test]
3480    async fn cursor_rollback_closes_cursors() {
3481        let (_db, addr, handle) = spawn_with_data().await;
3482        let client = connect(addr).await;
3483
3484        client.simple_query("BEGIN").await.expect("BEGIN");
3485        client
3486            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3487            .await
3488            .expect("DECLARE");
3489        client.simple_query("ROLLBACK").await.expect("ROLLBACK");
3490
3491        let err = client
3492            .simple_query("FETCH 1 FROM c")
3493            .await
3494            .expect_err("FETCH after ROLLBACK must fail");
3495        let db_err = err.as_db_error().expect("typed PG error");
3496        assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3497
3498        handle.abort();
3499    }
3500
3501    /// Explicit CLOSE works outside a transaction. PG allows DECLARE without
3502    /// BEGIN; we follow that for parity with `\set FETCH_COUNT 0` clients.
3503    #[tokio::test]
3504    async fn cursor_close_explicit() {
3505        let (_db, addr, handle) = spawn_with_data().await;
3506        let client = connect(addr).await;
3507
3508        client
3509            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3510            .await
3511            .expect("DECLARE");
3512        client.simple_query("CLOSE c").await.expect("CLOSE");
3513
3514        let err = client
3515            .simple_query("FETCH 1 FROM c")
3516            .await
3517            .expect_err("FETCH after CLOSE must fail");
3518        let db_err = err.as_db_error().expect("typed PG error");
3519        assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3520
3521        handle.abort();
3522    }
3523
3524    /// `SCROLL`, `BINARY`, `WITH HOLD` all rejected at parse time.
3525    #[tokio::test]
3526    async fn cursor_unsupported_modifiers_rejected() {
3527        let (_db, addr, handle) = spawn_with_data().await;
3528        let client = connect(addr).await;
3529
3530        for sql in [
3531            "DECLARE c SCROLL CURSOR FOR SUBSCRIBE prices",
3532            "DECLARE c BINARY CURSOR FOR SUBSCRIBE prices",
3533            "DECLARE c CURSOR WITH HOLD FOR SUBSCRIBE prices",
3534            "DECLARE c INSENSITIVE CURSOR FOR SUBSCRIBE prices",
3535        ] {
3536            let err = client
3537                .simple_query(sql)
3538                .await
3539                .expect_err(&format!("{sql} must fail"));
3540            let db_err = err.as_db_error().expect("typed PG error");
3541            assert_eq!(
3542                db_err.code().code(),
3543                "42601",
3544                "{sql}: expected parse error, got {db_err:?}"
3545            );
3546        }
3547
3548        handle.abort();
3549    }
3550
3551    /// `FETCH BACKWARD` and other reverse / absolute directions are rejected
3552    /// because SUBSCRIBE is forward-only.
3553    #[tokio::test]
3554    async fn cursor_backward_directions_rejected() {
3555        let (_db, addr, handle) = spawn_with_data().await;
3556        let client = connect(addr).await;
3557
3558        client
3559            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3560            .await
3561            .expect("DECLARE");
3562
3563        for sql in [
3564            "FETCH PRIOR FROM c",
3565            "FETCH BACKWARD 1 FROM c",
3566            "FETCH FIRST FROM c",
3567            "FETCH LAST FROM c",
3568            "FETCH ABSOLUTE 1 FROM c",
3569            "FETCH RELATIVE 1 FROM c",
3570        ] {
3571            let err = client
3572                .simple_query(sql)
3573                .await
3574                .expect_err(&format!("{sql} must fail"));
3575            let db_err = err.as_db_error().expect("typed PG error");
3576            assert_eq!(db_err.code().code(), "0A000", "{sql}: got {db_err:?}");
3577        }
3578
3579        client.simple_query("CLOSE c").await.expect("CLOSE");
3580        handle.abort();
3581    }
3582
3583    /// `DECLARE … CURSOR FOR <SELECT …>` (regular query, not SUBSCRIBE) is
3584    /// not supported on pgwire.
3585    #[tokio::test]
3586    async fn cursor_for_non_subscribe_rejected() {
3587        let (_db, addr, handle) = spawn_with_data().await;
3588        let client = connect(addr).await;
3589
3590        let err = client
3591            .simple_query("DECLARE c CURSOR FOR SELECT 1")
3592            .await
3593            .expect_err("DECLARE FOR SELECT must fail");
3594        let db_err = err.as_db_error().expect("typed PG error");
3595        assert_eq!(db_err.code().code(), "0A000", "got {db_err:?}");
3596
3597        handle.abort();
3598    }
3599
3600    /// FETCH against a name we never declared returns 34000 (invalid_cursor_name).
3601    #[tokio::test]
3602    async fn cursor_fetch_unknown_name_errors() {
3603        let (_db, addr, handle) = spawn_with_data().await;
3604        let client = connect(addr).await;
3605
3606        let err = client
3607            .simple_query("FETCH 1 FROM nope")
3608            .await
3609            .expect_err("must fail");
3610        let db_err = err.as_db_error().expect("typed PG error");
3611        assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3612
3613        handle.abort();
3614    }
3615
3616    /// A multi-row batch with `FETCH 1` repeated must return each row in
3617    /// order — leftover rows persist on the cursor instead of being dropped
3618    /// when the response stream ends. With the bug, `FETCH 1` would consume
3619    /// the batch internally, return row[0], and discard row[1].
3620    #[tokio::test]
3621    async fn cursor_fetch_preserves_leftover_rows_in_one_batch() {
3622        let (db, addr, handle) = spawn_with_retained_data().await;
3623        let client = connect(addr).await;
3624
3625        let src = db.source_untyped("trades").expect("source");
3626        let batch = arrow_array::RecordBatch::try_new(
3627            src.schema().clone(),
3628            vec![
3629                Arc::new(arrow_array::StringArray::from(vec!["AAPL", "GOOG"])),
3630                Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])),
3631            ],
3632        )
3633        .expect("batch");
3634        src.push_arrow(batch).expect("push");
3635
3636        client
3637            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3638            .await
3639            .expect("DECLARE");
3640
3641        let first = client
3642            .simple_query("FETCH 1 FROM c")
3643            .await
3644            .expect("FETCH 1");
3645        let r1: Vec<&str> = first
3646            .iter()
3647            .filter_map(|m| match m {
3648                SimpleQueryMessage::Row(r) => r.get(0),
3649                _ => None,
3650            })
3651            .collect();
3652        assert_eq!(r1, vec!["AAPL"]);
3653
3654        let second = client
3655            .simple_query("FETCH 1 FROM c")
3656            .await
3657            .expect("FETCH 1");
3658        let r2: Vec<&str> = second
3659            .iter()
3660            .filter_map(|m| match m {
3661                SimpleQueryMessage::Row(r) => r.get(0),
3662                _ => None,
3663            })
3664            .collect();
3665        assert_eq!(r2, vec!["GOOG"]);
3666
3667        client.simple_query("CLOSE c").await.expect("CLOSE");
3668        handle.abort();
3669    }
3670
3671    /// Re-DECLAREing an open cursor name returns 42P03; user must CLOSE first.
3672    #[tokio::test]
3673    async fn cursor_duplicate_declare_rejected() {
3674        let (_db, addr, handle) = spawn_with_data().await;
3675        let client = connect(addr).await;
3676
3677        client
3678            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3679            .await
3680            .expect("first DECLARE");
3681
3682        let err = client
3683            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3684            .await
3685            .expect_err("duplicate DECLARE must fail");
3686        let db_err = err.as_db_error().expect("typed PG error");
3687        assert_eq!(db_err.code().code(), "42P03", "got {db_err:?}");
3688
3689        // After CLOSE the name is free again.
3690        client.simple_query("CLOSE c").await.expect("CLOSE");
3691        client
3692            .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3693            .await
3694            .expect("re-DECLARE after CLOSE");
3695        client.simple_query("CLOSE c").await.expect("CLOSE again");
3696
3697        handle.abort();
3698    }
3699
3700    /// Cursor name lookup is case-insensitive (PG identifier folding rules).
3701    #[tokio::test]
3702    async fn cursor_name_case_insensitive() {
3703        let (db, addr, handle) = spawn_with_retained_data().await;
3704        let client = connect(addr).await;
3705        push_one_trade(&db, "AAPL", 1.0).await;
3706
3707        client
3708            .simple_query("DECLARE MyCursor CURSOR FOR SUBSCRIBE prices")
3709            .await
3710            .expect("DECLARE");
3711
3712        let messages = client
3713            .simple_query("FETCH 1 FROM mycursor")
3714            .await
3715            .expect("FETCH from lowercased name");
3716        let row_count = messages
3717            .iter()
3718            .filter(|m| matches!(m, SimpleQueryMessage::Row(_)))
3719            .count();
3720        assert_eq!(row_count, 1);
3721
3722        client.simple_query("CLOSE MYCURSOR").await.expect("CLOSE");
3723        handle.abort();
3724    }
3725}