1use std::collections::{HashMap, VecDeque};
6use std::fmt::Debug;
7use std::net::SocketAddr;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use futures::{stream, Sink, StreamExt};
13use laminar_sql::parser::{
14 parse_streaming_sql, ShowCommand, StreamingStatement, SubscribeStatement,
15};
16use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler};
17use pgwire::api::auth::noop::NoopStartupHandler;
18use pgwire::api::auth::{
19 AuthSource, DefaultServerParameterProvider, LoginInfo, Password, StartupHandler,
20};
21use pgwire::api::portal::{Format, Portal};
22use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
23use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
24use pgwire::api::stmt::QueryParser;
25use pgwire::api::store::PortalStore;
26use pgwire::api::{ClientInfo, ClientPortalStore, PgWireServerHandlers, Type};
27use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
28use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
29use pgwire::tokio::process_socket;
30use sqlparser::ast::{
31 CloseCursor, Expr, FetchDirection, FunctionArguments, SelectItem, Set, SetExpr, Statement,
32 Value as AstValue,
33};
34use tokio::net::TcpListener;
35use tokio::sync::Mutex as TokioMutex;
36use tracing::{info, warn};
37
38use laminar_db::subscription::{PortalFrame, SubscribeStart, SubscriptionPortal};
39use laminar_db::LaminarDB;
40
41use crate::config::Secret;
42use crate::server::ServerError;
43
/// pgwire handler shared by all client connections; it implements both the
/// simple and the extended query protocol on top of one [`LaminarDB`].
pub struct LaminarPgwireHandler {
    /// Engine handle used to open subscriptions and run SHOW / SET statements.
    db: Arc<LaminarDB>,
    /// Per-connection state (cursors, transaction flag) keyed by the client's
    /// socket address. Entries are created lazily by `conn_state` and dropped
    /// by `evict_idle_peer` once they hold nothing.
    connections: parking_lot::Mutex<HashMap<SocketAddr, Arc<ConnState>>>,
}
52
53impl LaminarPgwireHandler {
54 fn new(db: Arc<LaminarDB>) -> Self {
55 Self {
56 db,
57 connections: parking_lot::Mutex::new(HashMap::new()),
58 }
59 }
60
61 fn conn_state(&self, peer: SocketAddr) -> Arc<ConnState> {
62 let mut guard = self.connections.lock();
63 Arc::clone(
64 guard
65 .entry(peer)
66 .or_insert_with(|| Arc::new(ConnState::default())),
67 )
68 }
69
70 fn evict_idle_peer(&self, peer: SocketAddr) {
71 let mut guard = self.connections.lock();
72 if let Some(state) = guard.get(&peer) {
73 if state.prune_dead_and_check_idle() {
74 guard.remove(&peer);
75 }
76 }
77 }
78}
79
80#[async_trait]
81impl NoopStartupHandler for LaminarPgwireHandler {
82 async fn post_startup<C>(
83 &self,
84 client: &mut C,
85 _message: PgWireFrontendMessage,
86 ) -> PgWireResult<()>
87 where
88 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
89 C::Error: Debug,
90 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
91 {
92 info!(peer = %client.socket_addr(), "pgwire client connected");
93 Ok(())
94 }
95}
96
97#[async_trait]
98impl SimpleQueryHandler for LaminarPgwireHandler {
99 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
100 where
101 C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
102 C::PortalStore: PortalStore,
103 C::Error: Debug,
104 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
105 {
106 if query.trim().is_empty() {
107 return Ok(vec![Response::EmptyQuery]);
108 }
109 let peer = client.socket_addr();
110 self.evict_idle_peer(peer);
111 let stmts = parse_streaming_sql(query)
112 .map_err(|e| user_error("42601", format!("parse error: {e}")))?;
113
114 if stmts.len() > 1
118 && stmts
119 .iter()
120 .any(|s| matches!(s, StreamingStatement::Subscribe(_)))
121 {
122 return Err(user_error(
123 "0A000",
124 "SUBSCRIBE must be the only statement in a simple query",
125 ));
126 }
127
128 let mut out = Vec::with_capacity(stmts.len());
129 for stmt in stmts {
130 out.push(match stmt {
131 StreamingStatement::Subscribe(s) => {
132 let portal = open_portal_for_subscribe(&self.db, &s).await?;
133 stream_subscribe_flushing(client, portal, true, None).await?;
135 return Ok(Vec::new());
136 }
137 StreamingStatement::Show(cmd) => {
138 engine_metadata_response(&self.db, &show_sql(&cmd)).await?
139 }
140 StreamingStatement::DeclareCursorForSubscribe {
141 name, subscribe, ..
142 } => {
143 let state = self.conn_state(peer);
144 handle_declare_cursor(&self.db, &state, &name.value, *subscribe).await?
145 }
146 StreamingStatement::Standard(s) => {
147 let state = self.conn_state(peer);
148 standard_or_cursor_response(&self.db, &state, *s)?
149 }
150 other => {
151 return Err(user_error(
152 "0A000",
153 format!("not supported on pgwire (use HTTP /api/v1/sql): {other:?}"),
154 ));
155 }
156 });
157 }
158 Ok(out)
159 }
160}
161
162async fn open_portal_for_subscribe(
163 db: &LaminarDB,
164 s: &SubscribeStatement,
165) -> PgWireResult<SubscriptionPortal> {
166 let name = s.name.to_string();
167 let start = match s.as_of_epoch {
168 Some(n) => SubscribeStart::AsOfEpoch(n),
169 None => SubscribeStart::Tail,
170 };
171 db.open_subscription(&name, s.filter_sql.as_deref(), start)
172 .await
173 .map_err(|e| user_error("42P01", format!("SUBSCRIBE '{name}': {e}")))
174}
175
/// Streams a subscription directly onto the client sink, flushing after every
/// non-empty batch so rows reach the client as they are produced.
///
/// * `send_row_desc` — when true, a `RowDescription` is sent and flushed before
///   any row, and a `SUBSCRIBE` `CommandComplete` carrying the row count is sent
///   when the portal ends. When false neither message is written.
/// * `result_format` — per-column wire format; `None` selects text everywhere.
///
/// Returns `Ok(())` once the portal reports end-of-stream.
///
/// # Errors
/// A `Lagged` frame aborts with SQLSTATE `54000`; row-encoding and sink errors
/// are propagated unchanged.
async fn stream_subscribe_flushing<C>(
    client: &mut C,
    mut portal: SubscriptionPortal,
    send_row_desc: bool,
    result_format: Option<&Format>,
) -> PgWireResult<()>
where
    C: Sink<PgWireBackendMessage> + Unpin + Send,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
    use futures::SinkExt;

    let schema = portal.schema();
    let fields = std::sync::Arc::new(field_infos(&schema, result_format));

    if send_row_desc {
        let row_desc =
            pgwire::messages::data::RowDescription::new(fields.iter().map(Into::into).collect());
        client
            .feed(PgWireBackendMessage::RowDescription(row_desc))
            .await?;
        // Flush right away so the client sees the column layout before any data.
        client.flush().await?;
    }

    // Running total reported in the final command tag.
    let mut rows: usize = 0;
    loop {
        match portal.next_frame().await {
            Some(PortalFrame::Batch(b)) if b.num_rows() > 0 => {
                for row in encode_batch(&b, &fields) {
                    client.feed(PgWireBackendMessage::DataRow(row?)).await?;
                    rows += 1;
                }
                // One flush per batch rather than per row.
                client.flush().await?;
            }
            // Empty batches and barriers produce no wire traffic.
            Some(PortalFrame::Batch(_)) => {}
            Some(PortalFrame::Barrier { .. }) => {}
            Some(PortalFrame::Lagged(n)) => {
                return Err(user_error(
                    "54000",
                    format!("subscription lagged: skipped {n} messages, terminating"),
                ));
            }
            None => {
                if send_row_desc {
                    let tag = Tag::new("SUBSCRIBE").with_rows(rows);
                    client
                        .feed(PgWireBackendMessage::CommandComplete(tag.into()))
                        .await?;
                    client.flush().await?;
                }
                return Ok(());
            }
        }
    }
}
251
/// Wraps a subscription portal in a lazy `Response::Query` tagged `SUBSCRIBE`.
///
/// Rows are produced on demand: each portal batch is encoded into a queue and
/// handed out one row at a time. The stream ends when the portal ends; a
/// `Lagged` frame is surfaced as an `Err` row with SQLSTATE `54000`.
fn subscription_query_response(
    portal: SubscriptionPortal,
    result_format: Option<&Format>,
) -> Response {
    use futures::stream;
    let schema = portal.schema();
    let fields = Arc::new(field_infos(&schema, result_format));
    /// State threaded through `stream::unfold`.
    struct State {
        /// Source of frames.
        portal: SubscriptionPortal,
        /// Column descriptors used to encode every row.
        fields: Arc<Vec<FieldInfo>>,
        /// Rows already encoded from the current batch but not yet yielded.
        pending: VecDeque<PgWireResult<pgwire::messages::data::DataRow>>,
    }
    let init = State {
        portal,
        fields: Arc::clone(&fields),
        pending: VecDeque::new(),
    };
    let row_stream = stream::unfold(init, |mut s| async move {
        loop {
            // Drain queued rows before asking the portal for more.
            if let Some(row) = s.pending.pop_front() {
                return Some((row, s));
            }
            match s.portal.next_frame().await {
                None => return None,
                Some(PortalFrame::Batch(b)) if b.num_rows() > 0 => {
                    s.pending.extend(encode_batch(&b, &s.fields));
                }
                // Nothing to emit for empty batches or barriers; keep polling.
                Some(PortalFrame::Batch(_)) | Some(PortalFrame::Barrier { .. }) => {}
                Some(PortalFrame::Lagged(n)) => {
                    let err = user_error(
                        "54000",
                        format!("subscription lagged: skipped {n} messages, terminating"),
                    );
                    return Some((Err(err), s));
                }
            }
        }
    });
    let mut resp = QueryResponse::new(fields, row_stream);
    resp.set_command_tag("SUBSCRIBE");
    Response::Query(resp)
}
297
/// Shared mutable state behind a declared cursor.
struct CursorInner {
    /// The subscription portal; an async mutex because `fetch_response` holds
    /// the guard across `next_frame().await`.
    portal: TokioMutex<SubscriptionPortal>,
    /// Rows encoded from a portal batch that no FETCH has consumed yet.
    pending: parking_lot::Mutex<VecDeque<PgWireResult<pgwire::messages::data::DataRow>>>,
    /// Set once the portal ended or a lag error was reported; exhausted
    /// cursors are pruned by `ConnState::prune_dead_and_check_idle`.
    exhausted: AtomicBool,
}
313
/// Cheaply clonable handle to a declared cursor; clones share one `CursorInner`.
#[derive(Clone)]
struct ActiveCursor {
    /// Shared portal, row queue and exhaustion flag.
    inner: Arc<CursorInner>,
    /// Schema captured from the portal when the cursor was declared; used to
    /// build the field descriptors of every FETCH response.
    schema: arrow_schema::SchemaRef,
}
319
/// State kept for one client connection.
#[derive(Default)]
struct ConnState {
    /// Declared cursors, keyed by ASCII-lowercased cursor name.
    cursors: parking_lot::Mutex<HashMap<String, ActiveCursor>>,
    /// True between BEGIN and the matching COMMIT / ROLLBACK.
    in_tx: AtomicBool,
}
325
326impl ConnState {
327 fn key(name: &str) -> String {
331 name.to_ascii_lowercase()
332 }
333
334 fn insert(&self, name: &str, cursor: ActiveCursor) {
335 self.cursors.lock().insert(Self::key(name), cursor);
336 }
337
338 fn contains(&self, name: &str) -> bool {
339 self.cursors.lock().contains_key(&Self::key(name))
340 }
341
342 fn get(&self, name: &str) -> Option<ActiveCursor> {
343 self.cursors.lock().get(&Self::key(name)).cloned()
344 }
345
346 fn remove(&self, name: &str) -> bool {
347 self.cursors.lock().remove(&Self::key(name)).is_some()
348 }
349
350 fn drop_all(&self) {
351 self.cursors.lock().clear();
352 }
353
354 fn prune_dead_and_check_idle(&self) -> bool {
357 let mut cursors = self.cursors.lock();
358 cursors.retain(|_, c| !c.inner.exhausted.load(Ordering::Acquire));
359 cursors.is_empty() && !self.in_tx.load(Ordering::Acquire)
360 }
361}
362
363async fn handle_declare_cursor(
366 db: &LaminarDB,
367 state: &ConnState,
368 cursor_name: &str,
369 subscribe: SubscribeStatement,
370) -> PgWireResult<Response> {
371 if state.contains(cursor_name) {
372 return Err(user_error(
373 "42P03",
374 format!("cursor \"{cursor_name}\" already exists"),
375 ));
376 }
377 let portal = open_portal_for_subscribe(db, &subscribe).await?;
378 let schema = portal.schema();
379 state.insert(
380 cursor_name,
381 ActiveCursor {
382 inner: Arc::new(CursorInner {
383 portal: TokioMutex::new(portal),
384 pending: parking_lot::Mutex::new(VecDeque::new()),
385 exhausted: AtomicBool::new(false),
386 }),
387 schema,
388 },
389 );
390 Ok(Response::Execution(Tag::new("DECLARE CURSOR")))
391}
392
393fn fetch_direction_count(dir: &FetchDirection) -> PgWireResult<FetchTarget> {
395 match dir {
396 FetchDirection::Next | FetchDirection::Forward { limit: None } => Ok(FetchTarget::Count(1)),
397 FetchDirection::Count { limit } | FetchDirection::Forward { limit: Some(limit) } => {
398 value_to_u64(limit).map(FetchTarget::Count)
399 }
400 FetchDirection::All | FetchDirection::ForwardAll => Ok(FetchTarget::All),
401 FetchDirection::Prior
402 | FetchDirection::First
403 | FetchDirection::Last
404 | FetchDirection::Absolute { .. }
405 | FetchDirection::Relative { .. }
406 | FetchDirection::Backward { .. }
407 | FetchDirection::BackwardAll => Err(user_error(
408 "0A000",
409 "FETCH direction not supported (SUBSCRIBE cursors are forward-only): use FORWARD or NEXT",
410 )),
411 }
412}
413
/// How many rows a FETCH should deliver.
#[derive(Copy, Clone)]
enum FetchTarget {
    /// At most this many rows.
    Count(u64),
    /// Every row until the cursor is exhausted.
    All,
}
419
420fn value_to_u64(v: &AstValue) -> PgWireResult<u64> {
421 match v {
422 AstValue::Number(n, _) => n
423 .parse::<u64>()
424 .map_err(|_| user_error("22023", format!("invalid FETCH count: {n}"))),
425 other => Err(user_error(
426 "22023",
427 format!("FETCH count must be an integer, got {other}"),
428 )),
429 }
430}
431
432fn handle_fetch(
433 state: &ConnState,
434 cursor_name: &str,
435 target: FetchTarget,
436) -> PgWireResult<Response> {
437 let cursor = state
438 .get(cursor_name)
439 .ok_or_else(|| user_error("34000", format!("cursor \"{cursor_name}\" does not exist")))?;
440 Ok(fetch_response(cursor, target))
441}
442
443fn handle_close(state: &ConnState, cursor: &CloseCursor) -> PgWireResult<Response> {
444 match cursor {
445 CloseCursor::All => {
446 state.drop_all();
447 Ok(Response::Execution(Tag::new("CLOSE CURSOR ALL")))
448 }
449 CloseCursor::Specific { name } => {
450 if state.remove(&name.value) {
451 Ok(Response::Execution(Tag::new("CLOSE CURSOR")))
452 } else {
453 Err(user_error(
454 "34000",
455 format!("cursor \"{}\" does not exist", name.value),
456 ))
457 }
458 }
459 }
460}
461
462fn standard_or_cursor_response(
466 db: &LaminarDB,
467 state: &ConnState,
468 stmt: Statement,
469) -> PgWireResult<Response> {
470 match stmt {
471 Statement::StartTransaction { .. } => {
472 state.in_tx.store(true, Ordering::Release);
473 Ok(Response::Execution(Tag::new("BEGIN")))
474 }
475 Statement::Commit { .. } => {
476 state.drop_all();
477 state.in_tx.store(false, Ordering::Release);
478 Ok(Response::Execution(Tag::new("COMMIT")))
479 }
480 Statement::Rollback { .. } => {
481 state.drop_all();
482 state.in_tx.store(false, Ordering::Release);
483 Ok(Response::Execution(Tag::new("ROLLBACK")))
484 }
485 Statement::Fetch {
486 ref name,
487 ref direction,
488 ..
489 } => {
490 let target = fetch_direction_count(direction)?;
491 handle_fetch(state, &name.value, target)
492 }
493 Statement::Close { ref cursor } => handle_close(state, cursor),
494 Statement::Declare { .. } => Err(user_error(
495 "0A000",
496 "DECLARE on pgwire only supports CURSOR FOR SUBSCRIBE …",
497 )),
498 other => standard_response(db, other),
499 }
500}
501
502fn standard_response(db: &LaminarDB, stmt: Statement) -> PgWireResult<Response> {
506 match stmt {
507 Statement::StartTransaction { .. } => Ok(Response::Execution(Tag::new("BEGIN"))),
508 Statement::Commit { .. } => Ok(Response::Execution(Tag::new("COMMIT"))),
509 Statement::Rollback { .. } => Ok(Response::Execution(Tag::new("ROLLBACK"))),
510 Statement::Set(s) => apply_set(db, s),
511 Statement::Query(q) => driver_select_response(*q),
512 Statement::Insert { .. }
513 | Statement::Update { .. }
514 | Statement::Delete { .. }
515 | Statement::CreateTable { .. }
516 | Statement::CreateView { .. }
517 | Statement::Drop { .. } => Err(user_error(
518 "0A000",
519 "DDL/DML is not supported on pgwire; use HTTP /api/v1/sql",
520 )),
521 other => Err(user_error(
522 "0A000",
523 format!("not supported on pgwire: {other}"),
524 )),
525 }
526}
527
528fn driver_select_response(query: sqlparser::ast::Query) -> PgWireResult<Response> {
532 let SetExpr::Select(select) = *query.body else {
533 return Err(unsupported_select());
534 };
535 if select.projection.len() != 1 || !select.from.is_empty() || select.selection.is_some() {
536 return Err(unsupported_select());
537 }
538 let SelectItem::UnnamedExpr(expr) = &select.projection[0] else {
539 return Err(unsupported_select());
540 };
541
542 match expr {
543 Expr::Value(v) => match &v.value {
544 sqlparser::ast::Value::Number(n, _) => {
545 let parsed: i32 = n.parse().map_err(|_| unsupported_select())?;
546 Ok(text_response("?column?", Type::INT4, parsed.to_string()))
547 }
548 sqlparser::ast::Value::SingleQuotedString(s) => {
549 Ok(text_response("?column?", Type::VARCHAR, s.clone()))
550 }
551 _ => Err(unsupported_select()),
552 },
553 Expr::Function(func) => {
554 if !matches!(func.args, FunctionArguments::List(ref a) if a.args.is_empty())
557 && !matches!(func.args, FunctionArguments::None)
558 {
559 return Err(unsupported_select());
560 }
561 let name = func.name.to_string().to_ascii_lowercase();
562 let (col, ty, value) = match name.as_str() {
563 "version" | "pg_catalog.version" => (
564 "version",
565 Type::VARCHAR,
566 format!("LaminarDB {} on pgwire", env!("CARGO_PKG_VERSION")),
567 ),
568 "current_schema" | "pg_catalog.current_schema" => {
569 ("current_schema", Type::VARCHAR, "public".to_string())
570 }
571 "current_database" | "pg_catalog.current_database" => {
572 ("current_database", Type::VARCHAR, "laminar".to_string())
573 }
574 "current_user" | "session_user" | "user" => {
575 ("current_user", Type::VARCHAR, "laminar".to_string())
576 }
577 _ => return Err(unsupported_select()),
578 };
579 Ok(text_response(col, ty, value))
580 }
581 _ => Err(unsupported_select()),
582 }
583}
584
585fn unsupported_select() -> PgWireError {
586 user_error(
587 "0A000",
588 "pgwire SELECT is limited to literals and connect-time builtins; use HTTP /api/v1/sql",
589 )
590}
591
592fn apply_set(db: &LaminarDB, set: Set) -> PgWireResult<Response> {
596 match set {
597 Set::SingleAssignment {
598 variable, values, ..
599 } => {
600 let key = variable.to_string();
601 let value = values
602 .first()
603 .map(ToString::to_string)
604 .unwrap_or_default()
605 .trim_matches('\'')
606 .to_string();
607 db.set_session_property(&key, &value);
608 Ok(Response::Execution(Tag::new("SET")))
609 }
610 Set::SetTransaction { .. } => Err(user_error(
612 "0A000",
613 "SET TRANSACTION is not supported (no transactional semantics)",
614 )),
615 _ => Ok(Response::Execution(Tag::new("SET"))),
619 }
620}
621
622fn user_error(code: &str, msg: impl Into<String>) -> PgWireError {
623 PgWireError::UserError(Box::new(ErrorInfo::new(
624 "ERROR".into(),
625 code.into(),
626 msg.into(),
627 )))
628}
629
630fn show_sql(cmd: &ShowCommand) -> String {
634 match cmd {
635 ShowCommand::Sources => "SHOW SOURCES".into(),
636 ShowCommand::Sinks => "SHOW SINKS".into(),
637 ShowCommand::Queries => "SHOW QUERIES".into(),
638 ShowCommand::MaterializedViews => "SHOW MATERIALIZED VIEWS".into(),
639 ShowCommand::Streams => "SHOW STREAMS".into(),
640 ShowCommand::Tables => "SHOW TABLES".into(),
641 ShowCommand::CheckpointStatus => "SHOW CHECKPOINT STATUS".into(),
642 ShowCommand::CreateSource { name } => format!("SHOW CREATE SOURCE {name}"),
643 ShowCommand::CreateSink { name } => format!("SHOW CREATE SINK {name}"),
644 }
645}
646
647async fn engine_metadata_response(db: &LaminarDB, sql: &str) -> PgWireResult<Response> {
649 use laminar_db::ExecuteResult;
650 let result = db
651 .execute(sql)
652 .await
653 .map_err(|e| user_error("XX000", e.to_string()))?;
654 let ExecuteResult::Metadata(batch) = result else {
655 return Err(user_error("XX000", "SHOW did not return metadata"));
656 };
657 Ok(record_batch_response(batch))
658}
659
660fn text_response(col: &str, ty: Type, value: String) -> Response {
662 let schema = Arc::new(vec![FieldInfo::new(
663 col.into(),
664 None,
665 None,
666 ty,
667 FieldFormat::Text,
668 )]);
669 let schema_for_row = Arc::clone(&schema);
670 let row_stream = stream::iter(std::iter::once(Ok::<_, PgWireError>(()))).map(move |_| {
671 let mut enc = DataRowEncoder::new(Arc::clone(&schema_for_row));
672 enc.encode_field(&Some(value.as_str()))?;
673 Ok(enc.take_row())
674 });
675 Response::Query(QueryResponse::new(schema, row_stream))
676}
677
678fn record_batch_response(batch: arrow_array::RecordBatch) -> Response {
679 let fields = Arc::new(field_infos(&batch.schema(), None));
680 let nrows = batch.num_rows();
681
682 let mut rows = Vec::with_capacity(nrows);
685 {
686 let opts = arrow_cast::display::FormatOptions::default();
687 let formatters: Vec<_> = batch
688 .columns()
689 .iter()
690 .map(|c| arrow_cast::display::ArrayFormatter::try_new(c.as_ref(), &opts))
691 .collect::<Result<_, _>>()
692 .unwrap_or_default();
693 for row in 0..nrows {
694 rows.push(encode_row(&batch, row, &fields, &formatters));
695 }
696 }
697
698 let row_stream = stream::iter(rows);
699 Response::Query(QueryResponse::new(fields, row_stream))
700}
701
/// Builds the lazy row stream answering a FETCH on `cursor`.
///
/// Rows come first from the cursor's shared `pending` queue, then from the
/// portal. The stream stops after `target` rows, or when the cursor is
/// exhausted. Rows encoded but not consumed stay in `pending` for the next
/// FETCH. A `Lagged` frame marks the cursor exhausted and is surfaced as an
/// `Err` row with SQLSTATE `54000`.
fn fetch_response(cursor: ActiveCursor, target: FetchTarget) -> Response {
    let fields = Arc::new(field_infos(&cursor.schema, None));
    // `None` means "no limit".
    let remaining = match target {
        FetchTarget::Count(n) => Some(n),
        FetchTarget::All => None,
    };

    /// State threaded through `stream::unfold`.
    struct State {
        cursor: ActiveCursor,
        fields: Arc<Vec<FieldInfo>>,
        /// Rows still owed to this FETCH; `None` for FETCH ALL.
        remaining: Option<u64>,
    }

    let init = State {
        cursor,
        fields: Arc::clone(&fields),
        remaining,
    };

    let row_stream = stream::unfold(init, |mut s| async move {
        loop {
            if matches!(s.remaining, Some(0)) {
                return None;
            }
            // The guard is a temporary, so the lock is released before any await.
            let popped = s.cursor.inner.pending.lock().pop_front();
            if let Some(row) = popped {
                if let Some(n) = s.remaining.as_mut() {
                    *n = n.saturating_sub(1);
                }
                return Some((row, s));
            }
            if s.cursor.inner.exhausted.load(Ordering::Acquire) {
                return None;
            }

            // Queue is empty: wait for the portal's next frame.
            let next = s.cursor.inner.portal.lock().await.next_frame().await;
            match next {
                None => {
                    s.cursor.inner.exhausted.store(true, Ordering::Release);
                    return None;
                }
                Some(PortalFrame::Batch(b)) if b.num_rows() > 0 => {
                    let encoded = encode_batch(&b, &s.fields);
                    s.cursor.inner.pending.lock().extend(encoded);
                }
                Some(PortalFrame::Batch(_)) => {}
                Some(PortalFrame::Barrier { .. }) => {
                    // Barriers carry no rows; keep polling.
                }
                Some(PortalFrame::Lagged(n)) => {
                    s.cursor.inner.exhausted.store(true, Ordering::Release);
                    let err = user_error(
                        "54000",
                        format!("subscription lagged: skipped {n} messages, terminating cursor"),
                    );
                    return Some((Err(err), s));
                }
            }
        }
    });
    Response::Query(QueryResponse::new(fields, row_stream))
}
771
772fn encode_batch(
773 batch: &arrow_array::RecordBatch,
774 fields: &Arc<Vec<FieldInfo>>,
775) -> Vec<PgWireResult<pgwire::messages::data::DataRow>> {
776 let opts = arrow_cast::display::FormatOptions::default();
777 let formatters: Vec<_> = match batch
778 .columns()
779 .iter()
780 .map(|c| arrow_cast::display::ArrayFormatter::try_new(c.as_ref(), &opts))
781 .collect::<Result<_, _>>()
782 {
783 Ok(f) => f,
784 Err(e) => {
785 return vec![Err(user_error("XX000", format!("format column: {e}")))];
786 }
787 };
788 (0..batch.num_rows())
789 .map(|row| encode_row(batch, row, fields, &formatters))
790 .collect()
791}
792
793fn field_infos(schema: &arrow_schema::Schema, result_format: Option<&Format>) -> Vec<FieldInfo> {
796 schema
797 .fields()
798 .iter()
799 .enumerate()
800 .map(|(i, f)| {
801 let format = result_format.map_or(FieldFormat::Text, |rf| rf.format_for(i));
802 FieldInfo::new(
803 f.name().clone(),
804 None,
805 None,
806 arrow_to_pg_type(f.data_type()),
807 format,
808 )
809 })
810 .collect()
811}
812
813fn encode_row(
814 batch: &arrow_array::RecordBatch,
815 row: usize,
816 fields: &Arc<Vec<FieldInfo>>,
817 formatters: &[arrow_cast::display::ArrayFormatter<'_>],
818) -> PgWireResult<pgwire::messages::data::DataRow> {
819 let mut enc = DataRowEncoder::new(Arc::clone(fields));
820 for (i, col) in batch.columns().iter().enumerate() {
821 let info = &fields[i];
822 match info.format() {
823 FieldFormat::Text => encode_field_text(&mut enc, col.as_ref(), row, &formatters[i])?,
824 FieldFormat::Binary => encode_field_binary(&mut enc, col.as_ref(), row, info.name())?,
825 }
826 }
827 Ok(enc.take_row())
828}
829
830fn encode_field_text(
831 enc: &mut DataRowEncoder,
832 col: &dyn arrow_array::Array,
833 row: usize,
834 formatter: &arrow_cast::display::ArrayFormatter<'_>,
835) -> PgWireResult<()> {
836 use arrow_schema::DataType;
837 if col.is_null(row) {
838 return enc.encode_field(&None::<&str>);
839 }
840 if matches!(col.data_type(), DataType::List(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8))
843 {
844 return enc.encode_field(&Some(pg_text_array_literal(&list_text_elements(col, row))));
845 }
846 enc.encode_field(&Some(formatter.value(row).to_string()))
847}
848
849fn list_text_elements(col: &dyn arrow_array::Array, row: usize) -> Vec<Option<String>> {
851 use arrow_array::cast::AsArray;
852 use arrow_array::Array;
853 use arrow_schema::DataType;
854 let values = col.as_list::<i32>().value(row);
855 if matches!(values.data_type(), DataType::LargeUtf8) {
856 let s = values.as_string::<i64>();
857 (0..s.len())
858 .map(|i| (!s.is_null(i)).then(|| s.value(i).to_owned()))
859 .collect()
860 } else {
861 let s = values.as_string::<i32>();
862 (0..s.len())
863 .map(|i| (!s.is_null(i)).then(|| s.value(i).to_owned()))
864 .collect()
865 }
866}
867
/// Renders string elements as a PostgreSQL text-array literal.
///
/// Every non-null element is double-quoted with `"` and `\` backslash-escaped;
/// nulls are written as the bare word `NULL`. Elements are comma-separated and
/// wrapped in braces.
fn pg_text_array_literal(elements: &[Option<String>]) -> String {
    let rendered: Vec<String> = elements
        .iter()
        .map(|element| match element {
            None => String::from("NULL"),
            Some(text) => {
                let mut quoted = String::with_capacity(text.len() + 2);
                quoted.push('"');
                for ch in text.chars() {
                    if matches!(ch, '"' | '\\') {
                        quoted.push('\\');
                    }
                    quoted.push(ch);
                }
                quoted.push('"');
                quoted
            }
        })
        .collect();
    format!("{{{}}}", rendered.join(","))
}
893
/// Writes one cell in binary format.
///
/// Nulls are encoded as SQL NULL. Integer and float types narrower than the
/// advertised PostgreSQL type are widened first; timestamps and dates are
/// converted to date-time values; lists of UTF-8 strings are sent as a vector
/// of optional strings.
///
/// # Errors
/// SQLSTATE `22008` when a timestamp or date is out of range, and `0A000`
/// (naming the column `name`) for any Arrow type without a binary encoding.
fn encode_field_binary(
    enc: &mut DataRowEncoder,
    col: &dyn arrow_array::Array,
    row: usize,
    name: &str,
) -> PgWireResult<()> {
    use arrow_array::{cast::AsArray, types::*};
    use arrow_schema::DataType;

    if col.is_null(row) {
        return enc.encode_field(&None::<&str>);
    }

    // Encodes a primitive cell, optionally widening it with a lossless `From`
    // conversion to the Rust type matching the advertised PostgreSQL type.
    macro_rules! prim {
        ($ty:ty as $cast:ty) => {
            enc.encode_field(&Some(<$cast>::from(col.as_primitive::<$ty>().value(row))))
        };
        ($ty:ty) => {
            enc.encode_field(&Some(col.as_primitive::<$ty>().value(row)))
        };
    }

    match col.data_type() {
        DataType::Int8 => prim!(Int8Type as i32),
        DataType::Int16 => prim!(Int16Type as i32),
        DataType::Int32 => prim!(Int32Type),
        DataType::Int64 => prim!(Int64Type),
        DataType::UInt8 => prim!(UInt8Type as i32),
        DataType::UInt16 => prim!(UInt16Type as i32),
        DataType::UInt32 => prim!(UInt32Type as i64),
        DataType::UInt64 => {
            let v = col.as_primitive::<UInt64Type>().value(row);
            // Values above i64::MAX are clamped to i64::MAX rather than rejected.
            enc.encode_field(&Some(i64::try_from(v).unwrap_or(i64::MAX)))
        }
        DataType::Float32 => prim!(Float32Type as f64),
        DataType::Float64 => prim!(Float64Type),
        DataType::Boolean => enc.encode_field(&Some(col.as_boolean().value(row))),
        DataType::Utf8 => enc.encode_field(&Some(col.as_string::<i32>().value(row))),
        DataType::LargeUtf8 => enc.encode_field(&Some(col.as_string::<i64>().value(row))),
        DataType::Timestamp(unit, _tz) => {
            use arrow_array::temporal_conversions::{
                timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_s_to_datetime,
                timestamp_us_to_datetime,
            };
            use arrow_schema::TimeUnit;
            // Keep the raw value alongside the conversion for the error message.
            let (raw, dt) = match unit {
                TimeUnit::Second => {
                    let v = col.as_primitive::<TimestampSecondType>().value(row);
                    (v, timestamp_s_to_datetime(v))
                }
                TimeUnit::Millisecond => {
                    let v = col.as_primitive::<TimestampMillisecondType>().value(row);
                    (v, timestamp_ms_to_datetime(v))
                }
                TimeUnit::Microsecond => {
                    let v = col.as_primitive::<TimestampMicrosecondType>().value(row);
                    (v, timestamp_us_to_datetime(v))
                }
                TimeUnit::Nanosecond => {
                    let v = col.as_primitive::<TimestampNanosecondType>().value(row);
                    (v, timestamp_ns_to_datetime(v))
                }
            };
            let dt =
                dt.ok_or_else(|| user_error("22008", format!("timestamp out of range: {raw}")))?;
            enc.encode_field(&Some(dt))
        }
        DataType::Date32 => {
            let v = col.as_primitive::<Date32Type>().value(row);
            let dt = arrow_array::temporal_conversions::date32_to_datetime(v)
                .ok_or_else(|| user_error("22008", format!("DATE out of range: {v}")))?;
            enc.encode_field(&Some(dt.date()))
        }
        DataType::Date64 => {
            let v = col.as_primitive::<Date64Type>().value(row);
            let dt = arrow_array::temporal_conversions::date64_to_datetime(v)
                .ok_or_else(|| user_error("22008", format!("DATE out of range: {v}")))?;
            enc.encode_field(&Some(dt.date()))
        }
        DataType::List(field)
            if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) =>
        {
            enc.encode_field(&Some(list_text_elements(col, row)))
        }
        other => Err(user_error(
            "0A000",
            format!("binary format not supported for column '{name}' (type {other:?})"),
        )),
    }
}
999
1000fn arrow_to_pg_type(dt: &arrow_schema::DataType) -> Type {
1001 use arrow_schema::DataType;
1002 match dt {
1003 DataType::Int8 | DataType::Int16 | DataType::Int32 => Type::INT4,
1004 DataType::Int64 | DataType::UInt32 | DataType::UInt64 => Type::INT8,
1005 DataType::UInt8 | DataType::UInt16 => Type::INT4,
1006 DataType::Float32 | DataType::Float64 => Type::FLOAT8,
1007 DataType::Utf8 | DataType::LargeUtf8 => Type::VARCHAR,
1008 DataType::Boolean => Type::BOOL,
1009 DataType::Timestamp(_, _) => Type::TIMESTAMP,
1010 DataType::Date32 | DataType::Date64 => Type::DATE,
1011 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => Type::NUMERIC,
1012 DataType::List(field)
1013 if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) =>
1014 {
1015 Type::TEXT_ARRAY
1016 }
1017 _ => Type::TEXT,
1018 }
1019}
1020
/// Credential store backing MD5 password authentication.
#[derive(Debug)]
struct LaminarAuthSource {
    /// User name → stored secret. The secret is either a plaintext password or
    /// a pre-hashed `md5<32 hex>` value (see `parse_pre_hashed_md5`).
    users: Arc<HashMap<String, Secret>>,
}
1029
/// Recognises a pre-hashed MD5 secret of the form `md5` + 32 lowercase hex
/// digits and returns the hex part; anything else yields `None`.
pub(crate) fn parse_pre_hashed_md5(stored: &str) -> Option<&str> {
    stored.strip_prefix("md5").filter(|digest| {
        digest.len() == 32
            && digest
                .bytes()
                .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
    })
}
1041
1042fn outer_md5_challenge(inner_hex: &str, salt: &[u8]) -> String {
1046 use md5::{Digest, Md5};
1047 let mut hasher = Md5::new();
1048 hasher.update(inner_hex.as_bytes());
1049 hasher.update(salt);
1050 format!("md5{:x}", hasher.finalize())
1051}
1052
1053#[async_trait]
1054impl AuthSource for LaminarAuthSource {
1055 async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
1056 let user = login.user().unwrap_or("");
1057 let stored = self
1061 .users
1062 .get(user)
1063 .ok_or_else(|| PgWireError::InvalidPassword(user.to_string()))?;
1064 let salt: [u8; 4] = rand::random();
1065 let expected = match parse_pre_hashed_md5(stored.expose()) {
1066 Some(inner_hex) => outer_md5_challenge(inner_hex, &salt),
1067 None => hash_md5_password(user, stored.expose(), &salt),
1068 };
1069 Ok(Password::new(Some(salt.to_vec()), expected.into_bytes()))
1070 }
1071}
1072
/// MD5 startup handler specialised to this server's credential source and the
/// default server-parameter provider.
type Md5Handler = Md5PasswordAuthStartupHandler<LaminarAuthSource, DefaultServerParameterProvider>;
1074
/// Startup/authentication strategy chosen when the handler factory is built.
enum StartupAuth {
    /// No users configured: startup is delegated to the query handler's
    /// no-op startup implementation.
    Trust(Arc<LaminarPgwireHandler>),
    /// Users configured: clients authenticate with MD5 passwords.
    Md5(Arc<Md5Handler>),
}
1082
#[async_trait]
impl StartupHandler for StartupAuth {
    /// Forwards the startup message to whichever handler this variant wraps.
    async fn on_startup<C>(
        &self,
        client: &mut C,
        message: PgWireFrontendMessage,
    ) -> PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        match self {
            Self::Trust(h) => h.on_startup(client, message).await,
            Self::Md5(h) => h.on_startup(client, message).await,
        }
    }
}
1101
/// Bundles the handlers pgwire needs for each connection.
pub struct LaminarHandlerFactory {
    /// Shared query handler, used for both simple and extended queries.
    handler: Arc<LaminarPgwireHandler>,
    /// Startup handler: trust or MD5, decided in `new`.
    startup: Arc<StartupAuth>,
}
1106
1107impl LaminarHandlerFactory {
1108 fn new(db: Arc<LaminarDB>, users: HashMap<String, Secret>) -> Self {
1109 let handler = Arc::new(LaminarPgwireHandler::new(db));
1110 let startup = if users.is_empty() {
1111 Arc::new(StartupAuth::Trust(Arc::clone(&handler)))
1112 } else {
1113 let auth = LaminarAuthSource {
1114 users: Arc::new(users),
1115 };
1116 let md5 = Md5PasswordAuthStartupHandler::new(
1117 Arc::new(auth),
1118 Arc::new(DefaultServerParameterProvider::default()),
1119 );
1120 Arc::new(StartupAuth::Md5(Arc::new(md5)))
1121 };
1122 Self { handler, startup }
1123 }
1124}
1125
impl PgWireServerHandlers for LaminarHandlerFactory {
    /// Returns the shared handler for the simple query protocol.
    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
        Arc::clone(&self.handler)
    }

    /// Returns the same shared handler for the extended query protocol.
    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
        Arc::clone(&self.handler)
    }

    /// Returns the startup handler selected when the factory was built.
    fn startup_handler(&self) -> Arc<impl StartupHandler> {
        Arc::clone(&self.startup)
    }
}
1139
/// Parsed statement stored by the extended query protocol between Parse and
/// Execute.
#[derive(Clone, Debug)]
pub enum LaminarStmt {
    /// `SUBSCRIBE`, with the stream's schema resolved at parse time.
    Subscribe {
        /// Name of the subscribed stream.
        name: String,
        /// Optional filter expression, as SQL text.
        filter_sql: Option<String>,
        /// Optional epoch to start from.
        as_of_epoch: Option<u64>,
        /// Schema looked up when the statement was parsed; used to describe
        /// the result columns.
        schema: arrow_schema::SchemaRef,
    },
    /// A `SHOW ...` metadata command.
    Show(ShowCommand),
    /// Any other statement from the standard SQL parser.
    Standard(Box<Statement>),
}
1154
/// Query parser for the extended protocol; it needs the engine to resolve
/// subscription schemas while parsing.
#[derive(Clone)]
pub struct LaminarQueryParser {
    /// Engine handle used for subscription schema lookups.
    db: Arc<LaminarDB>,
}
1162
1163#[async_trait]
1164impl QueryParser for LaminarQueryParser {
1165 type Statement = LaminarStmt;
1166
1167 async fn parse_sql<C>(
1168 &self,
1169 _client: &C,
1170 sql: &str,
1171 _types: &[Option<Type>],
1172 ) -> PgWireResult<Self::Statement>
1173 where
1174 C: ClientInfo + Unpin + Send + Sync,
1175 {
1176 let mut stmts = parse_streaming_sql(sql)
1177 .map_err(|e| user_error("42601", format!("parse error: {e}")))?;
1178 let stmt = stmts
1179 .pop()
1180 .ok_or_else(|| user_error("42601", "empty statement"))?;
1181 if !stmts.is_empty() {
1182 return Err(user_error(
1183 "42601",
1184 "extended query: multiple statements per Parse are not supported",
1185 ));
1186 }
1187
1188 match stmt {
1189 StreamingStatement::Subscribe(s) => {
1190 let name = s.name.to_string();
1191 let (schema, _) = self.db.lookup_subscription_schema(&name).ok_or_else(|| {
1192 user_error("42P01", format!("SUBSCRIBE '{name}': stream not found"))
1193 })?;
1194 Ok(LaminarStmt::Subscribe {
1195 name,
1196 filter_sql: s.filter_sql,
1197 as_of_epoch: s.as_of_epoch,
1198 schema,
1199 })
1200 }
1201 StreamingStatement::Show(cmd) => Ok(LaminarStmt::Show(cmd)),
1202 StreamingStatement::Standard(s) => Ok(LaminarStmt::Standard(s)),
1203 other => Err(user_error(
1204 "0A000",
1205 format!("not supported on pgwire (use HTTP /api/v1/sql): {other:?}"),
1206 )),
1207 }
1208 }
1209
1210 fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
1211 Ok(Vec::new())
1213 }
1214
1215 fn get_result_schema(
1216 &self,
1217 stmt: &Self::Statement,
1218 column_format: Option<&Format>,
1219 ) -> PgWireResult<Vec<FieldInfo>> {
1220 match stmt {
1224 LaminarStmt::Subscribe { schema, .. } => Ok(field_infos(schema, column_format)),
1225 LaminarStmt::Show(_) | LaminarStmt::Standard(_) => Ok(Vec::new()),
1226 }
1227 }
1228}
1229
1230#[async_trait]
1231impl ExtendedQueryHandler for LaminarPgwireHandler {
1232 type Statement = LaminarStmt;
1233 type QueryParser = LaminarQueryParser;
1234
1235 fn query_parser(&self) -> Arc<Self::QueryParser> {
1236 Arc::new(LaminarQueryParser {
1237 db: Arc::clone(&self.db),
1238 })
1239 }
1240
1241 async fn do_query<C>(
1242 &self,
1243 client: &mut C,
1244 portal: &Portal<Self::Statement>,
1245 max_rows: usize,
1246 ) -> PgWireResult<Response>
1247 where
1248 C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
1249 C::PortalStore: PortalStore<Statement = Self::Statement>,
1250 C::Error: Debug,
1251 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
1252 {
1253 match &portal.statement.statement {
1254 LaminarStmt::Subscribe {
1255 name,
1256 filter_sql,
1257 as_of_epoch,
1258 ..
1259 } => {
1260 let start = match as_of_epoch {
1261 Some(n) => SubscribeStart::AsOfEpoch(*n),
1262 None => SubscribeStart::Tail,
1263 };
1264 let sub = self
1265 .db
1266 .open_subscription(name, filter_sql.as_deref(), start)
1267 .await
1268 .map_err(|e| user_error("42P01", format!("SUBSCRIBE '{name}': {e}")))?;
1269 if max_rows == 0 {
1270 stream_subscribe_flushing(
1273 client,
1274 sub,
1275 false,
1276 Some(&portal.result_column_format),
1277 )
1278 .await?;
1279 Ok(Response::Execution(Tag::new("SUBSCRIBE")))
1280 } else {
1281 Ok(subscription_query_response(
1285 sub,
1286 Some(&portal.result_column_format),
1287 ))
1288 }
1289 }
1290 LaminarStmt::Show(cmd) => engine_metadata_response(&self.db, &show_sql(cmd)).await,
1291 LaminarStmt::Standard(s) => standard_response(&self.db, *s.clone()),
1292 }
1293 }
1294
1295 async fn on_sync<C>(
1304 &self,
1305 client: &mut C,
1306 _message: pgwire::messages::extendedquery::Sync,
1307 ) -> PgWireResult<()>
1308 where
1309 C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
1310 C::PortalStore: PortalStore<Statement = Self::Statement>,
1311 C::Error: Debug,
1312 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
1313 {
1314 use futures::SinkExt;
1315 use pgwire::messages::response::ReadyForQuery;
1316
1317 client.portal_store().rm_portal("");
1319
1320 client
1321 .send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
1322 client.transaction_status(),
1323 )))
1324 .await?;
1325 client.flush().await?;
1326 Ok(())
1327 }
1328}
1329
/// Borrowed TLS configuration handed to `serve` and `load_tls_acceptor`.
pub struct TlsPaths<'a> {
    /// PEM file holding the server certificate chain.
    pub cert: &'a std::path::Path,
    /// PEM file holding the server private key.
    pub key: &'a std::path::Path,
    /// Lowest TLS protocol version the server will negotiate.
    pub min_version: TlsMinVersion,
    /// Optional PEM file of CA certificates; when set, `load_tls_acceptor`
    /// installs a client-certificate verifier built from it.
    pub client_ca: Option<&'a std::path::Path>,
}
1338
/// Owned counterpart of [`TlsPaths`], stored in [`TlsReloadState`] so the
/// same files can be re-read on reload.
#[derive(Debug, Clone)]
struct TlsConfigPaths {
    /// Server certificate chain file.
    cert: std::path::PathBuf,
    /// Server private key file.
    key: std::path::PathBuf,
    /// Lowest TLS protocol version to negotiate.
    min_version: TlsMinVersion,
    /// Optional client CA bundle file.
    client_ca: Option<std::path::PathBuf>,
}
1349
1350impl TlsConfigPaths {
1351 fn from_paths(paths: &TlsPaths<'_>) -> Self {
1352 Self {
1353 cert: paths.cert.to_path_buf(),
1354 key: paths.key.to_path_buf(),
1355 min_version: paths.min_version,
1356 client_ca: paths.client_ca.map(|p| p.to_path_buf()),
1357 }
1358 }
1359
1360 fn borrow(&self) -> TlsPaths<'_> {
1361 TlsPaths {
1362 cert: &self.cert,
1363 key: &self.key,
1364 min_version: self.min_version,
1365 client_ca: self.client_ca.as_deref(),
1366 }
1367 }
1368}
1369
/// TLS state shared between the accept loop and the file watcher.
pub struct TlsReloadState {
    /// Files the acceptor was loaded from; re-read by `try_reload_tls`.
    paths: TlsConfigPaths,
    /// Current acceptor. `try_reload_tls` replaces the inner `Arc`;
    /// `snapshot` clones it.
    acceptor: parking_lot::Mutex<Arc<tokio_rustls::TlsAcceptor>>,
}
1377
1378impl TlsReloadState {
1379 fn snapshot(&self) -> Arc<tokio_rustls::TlsAcceptor> {
1380 Arc::clone(&self.acceptor.lock())
1381 }
1382}
1383
1384#[allow(clippy::result_large_err)]
1388pub(crate) fn try_reload_tls(state: &TlsReloadState) -> Result<(), ServerError> {
1389 let new_acceptor = load_tls_acceptor(state.paths.borrow())?;
1390 *state.acceptor.lock() = Arc::new(new_acceptor);
1391 Ok(())
1392}
1393
/// Watches the TLS files in `state` and reloads the acceptor when they change.
///
/// Runs until the change channel closes. If a path cannot be canonicalized,
/// the watcher cannot be created, or a directory cannot be watched, a warning
/// is logged and the function returns, leaving reload disabled.
///
/// * `state` — paths to watch and the acceptor slot to refresh.
/// * `debounce` — delay between the first change signal and the reload;
///   signals that arrive during the delay are folded into that one reload.
async fn watch_tls_files(state: Arc<TlsReloadState>, debounce: std::time::Duration) {
    use crossfire::{mpsc, MTx};
    use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};

    // Track each configured file both as written and in canonical form.
    let mut raw_targets: Vec<std::path::PathBuf> = Vec::new();
    let mut canon_targets: Vec<std::path::PathBuf> = Vec::new();
    for path in [
        Some(state.paths.cert.clone()),
        Some(state.paths.key.clone()),
        state.paths.client_ca.clone(),
    ]
    .into_iter()
    .flatten()
    {
        match path.canonicalize() {
            Ok(canonical) => {
                canon_targets.push(canonical);
                raw_targets.push(path);
            }
            Err(e) => {
                warn!(
                    path = %path.display(),
                    error = %e,
                    "pgwire TLS watcher: cannot canonicalize path; reload disabled",
                );
                return;
            }
        }
    }
    // Watch the parent directories of both forms rather than the files.
    // NOTE(review): presumably so replace-by-rename and symlink swaps are
    // noticed — confirm.
    let mut dirs: Vec<std::path::PathBuf> = raw_targets
        .iter()
        .chain(canon_targets.iter())
        .filter_map(|p| p.parent().map(|d| d.to_path_buf()))
        .collect();
    dirs.sort();
    dirs.dedup();

    // The notify callback is synchronous, so it signals the async loop
    // through the blocking side of a bounded channel.
    let (tx, rx) = mpsc::bounded_async::<()>(16);
    let blocking_tx: MTx<_> = tx.clone().into_blocking();
    let watch_raw = raw_targets.clone();
    let watch_canon = canon_targets.clone();

    let mut watcher: RecommendedWatcher = match notify::recommended_watcher(
        move |result: Result<Event, notify::Error>| match result {
            Ok(event) => {
                // An event is relevant when one of its paths equals a
                // configured path or canonicalizes to a canonical target.
                let touched = event.paths.iter().any(|p| {
                    watch_raw.iter().any(|t| t == p)
                        || p.canonicalize()
                            .ok()
                            .as_ref()
                            .is_some_and(|c| watch_canon.contains(c))
                });
                if touched {
                    // The send result is ignored.
                    let _ = blocking_tx.send(());
                }
            }
            Err(e) => warn!(error = %e, "pgwire TLS watcher: notify error"),
        },
    ) {
        Ok(w) => w,
        Err(e) => {
            warn!(error = %e, "pgwire TLS watcher: failed to create watcher; reload disabled");
            return;
        }
    };

    for dir in &dirs {
        if let Err(e) = watcher.watch(dir, RecursiveMode::NonRecursive) {
            warn!(
                dir = %dir.display(),
                error = %e,
                "pgwire TLS watcher: failed to watch directory; reload disabled",
            );
            return;
        }
    }
    info!(
        files = ?raw_targets.iter().map(|p| p.display().to_string()).collect::<Vec<_>>(),
        "pgwire TLS watcher started",
    );

    loop {
        // Wait for the first change signal; a closed channel ends the task.
        if rx.recv().await.is_err() {
            return;
        }
        // Debounce, then drain the signals that queued up meanwhile so a
        // burst of events causes a single reload.
        tokio::time::sleep(debounce).await;
        while rx.try_recv().is_ok() {}

        match try_reload_tls(&state) {
            Ok(()) => tracing::info!(
                target: "audit",
                event = "pgwire.tls_reload",
                outcome = "ok",
            ),
            Err(e) => tracing::warn!(
                target: "audit",
                event = "pgwire.tls_reload",
                outcome = "failed",
                error = %e,
                "pgwire TLS reload failed; previous certificate kept",
            ),
        }
    }
}
1506
/// Lowest TLS protocol version the pgwire listener will negotiate.
#[derive(Clone, Copy, Debug)]
pub enum TlsMinVersion {
    /// Allow TLS 1.2 and TLS 1.3 (configuration string `"1.2"`).
    V1_2,
    /// Allow TLS 1.3 only (configuration string `"1.3"`).
    V1_3,
}
1515
1516impl TlsMinVersion {
1517 pub(crate) fn from_config_str(s: &str) -> Option<Self> {
1518 match s {
1519 "1.2" => Some(Self::V1_2),
1520 "1.3" => Some(Self::V1_3),
1521 _ => None,
1522 }
1523 }
1524
1525 fn versions(self) -> &'static [&'static tokio_rustls::rustls::SupportedProtocolVersion] {
1526 use tokio_rustls::rustls::version::{TLS12, TLS13};
1527 static BOTH: &[&tokio_rustls::rustls::SupportedProtocolVersion] = &[&TLS12, &TLS13];
1528 static ONLY_13: &[&tokio_rustls::rustls::SupportedProtocolVersion] = &[&TLS13];
1529 match self {
1530 Self::V1_2 => BOTH,
1531 Self::V1_3 => ONLY_13,
1532 }
1533 }
1534
1535 fn label(self) -> &'static str {
1536 match self {
1537 Self::V1_2 => "1.2",
1538 Self::V1_3 => "1.3",
1539 }
1540 }
1541}
1542
1543#[cfg(unix)]
1545fn warn_if_key_world_readable(file: &std::fs::File, path: &std::path::Path) {
1546 use std::os::unix::fs::MetadataExt;
1547 if let Ok(meta) = file.metadata() {
1548 let mode = meta.mode();
1549 if mode & 0o077 != 0 {
1550 warn!(
1551 path = %path.display(),
1552 mode = format!("{:o}", mode & 0o777),
1553 "pgwire_tls_key permissions are too broad; tighten to 0600",
1554 );
1555 }
1556 }
1557}
1558
/// Non-Unix stub: no permission check is performed on these targets.
#[cfg(not(unix))]
fn warn_if_key_world_readable(_file: &std::fs::File, _path: &std::path::Path) {}
1561
/// Per-IP record of recent authentication failures, used by the accept loop
/// to throttle clients.
#[derive(Debug, Default)]
struct FailureTracker {
    /// Failure timestamps per client IP, oldest at the front of each deque.
    inner: parking_lot::Mutex<
        HashMap<std::net::IpAddr, std::collections::VecDeque<std::time::Instant>>,
    >,
}
1569
1570impl FailureTracker {
1571 fn is_blocked(&self, ip: std::net::IpAddr, limit: u32, window: std::time::Duration) -> bool {
1572 if limit == 0 {
1573 return false;
1574 }
1575 let cutoff = std::time::Instant::now() - window;
1576 let mut inner = self.inner.lock();
1577 let Some(failures) = inner.get_mut(&ip) else {
1578 return false;
1579 };
1580 while failures.front().is_some_and(|t| *t < cutoff) {
1581 failures.pop_front();
1582 }
1583 let blocked = failures.len() >= limit as usize;
1584 if failures.is_empty() {
1585 inner.remove(&ip);
1586 }
1587 blocked
1588 }
1589
1590 fn record_failure(&self, ip: std::net::IpAddr) {
1591 let mut inner = self.inner.lock();
1592 if !inner.contains_key(&ip) && inner.len() >= MAX_TRACKED_IPS {
1594 if let Some(oldest) = inner
1595 .iter()
1596 .min_by_key(|(_, q)| q.back().copied())
1597 .map(|(k, _)| *k)
1598 {
1599 inner.remove(&oldest);
1600 }
1601 }
1602 inner
1603 .entry(ip)
1604 .or_default()
1605 .push_back(std::time::Instant::now());
1606 }
1607}
1608
/// Maximum number of distinct client IPs `FailureTracker` keeps; beyond
/// this, `record_failure` evicts an existing entry before inserting.
const MAX_TRACKED_IPS: usize = 4096;
1610
/// Buckets the result of one connection into an audit-log label.
///
/// * `"ok"` — the session ended without error.
/// * `"auth_failed"` — the error text contains SQLSTATE `28P01`.
/// * `"tls_failed"` — the error text mentions a handshake failure or TLS,
///   in any letter case (so `rustls`, `tls` and `TLS` all match).
/// * `"error"` — anything else.
///
/// Classification goes by the error's display text because the caller only
/// has an `io::Error` to inspect.
fn classify_outcome(result: &Result<(), std::io::Error>) -> &'static str {
    let Err(e) = result else {
        return "ok";
    };
    let msg = e.to_string();
    // SQLSTATE codes are upper-case by definition; match them verbatim.
    if msg.contains("28P01") {
        return "auth_failed";
    }
    // TLS errors are matched case-insensitively. "rustls" needs no separate
    // check because it already contains "tls".
    let lowered = msg.to_ascii_lowercase();
    if lowered.contains("handshakefailure") || lowered.contains("tls") {
        "tls_failed"
    } else {
        "error"
    }
}
1630
1631#[allow(clippy::result_large_err)]
1633fn check_cert_expiry(
1634 der: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
1635 path: &std::path::Path,
1636) -> Result<(), ServerError> {
1637 use x509_parser::prelude::FromDer;
1638 let (_, cert) = x509_parser::certificate::X509Certificate::from_der(der.as_ref())
1639 .map_err(|e| ServerError::Http(format!("parse pgwire_tls_cert {}: {e}", path.display())))?;
1640 let now = x509_parser::time::ASN1Time::now();
1641 let not_after = cert.validity().not_after;
1642 if not_after < now {
1643 return Err(ServerError::Http(format!(
1644 "pgwire_tls_cert {} expired at {not_after}",
1645 path.display()
1646 )));
1647 }
1648 let remaining = not_after.to_datetime() - now.to_datetime();
1649 if remaining <= time::Duration::days(30) {
1650 warn!(
1651 path = %path.display(),
1652 expires_at = %not_after,
1653 "pgwire_tls_cert expires within 30 days; rotate before it lapses",
1654 );
1655 }
1656 Ok(())
1657}
1658
/// Asks rustls to install the aws-lc-rs crypto provider as the process
/// default. The returned `Result` is deliberately discarded, so calling this
/// repeatedly is harmless.
fn ensure_tls_provider() {
    let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default();
}
1663
1664#[allow(clippy::result_large_err)]
1665fn load_tls_acceptor(paths: TlsPaths<'_>) -> Result<tokio_rustls::TlsAcceptor, ServerError> {
1666 use std::fs::File;
1667 use std::io::BufReader;
1668
1669 ensure_tls_provider();
1670
1671 let cert_file = File::open(paths.cert)
1672 .map_err(|e| ServerError::Http(format!("open pgwire_tls_cert: {e}")))?;
1673 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1674 .collect::<Result<Vec<_>, _>>()
1675 .map_err(|e| ServerError::Http(format!("parse pgwire_tls_cert: {e}")))?;
1676 if certs.is_empty() {
1677 return Err(ServerError::Http(format!(
1678 "pgwire_tls_cert {} contains no certificates",
1679 paths.cert.display()
1680 )));
1681 }
1682 for cert in &certs {
1683 check_cert_expiry(cert, paths.cert)?;
1684 }
1685
1686 let key_file = File::open(paths.key)
1687 .map_err(|e| ServerError::Http(format!("open pgwire_tls_key: {e}")))?;
1688 warn_if_key_world_readable(&key_file, paths.key);
1689 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1690 .map_err(|e| ServerError::Http(format!("parse pgwire_tls_key: {e}")))?
1691 .ok_or_else(|| {
1692 ServerError::Http(format!(
1693 "pgwire_tls_key {} contains no private key",
1694 paths.key.display()
1695 ))
1696 })?;
1697
1698 let builder = tokio_rustls::rustls::ServerConfig::builder_with_protocol_versions(
1699 paths.min_version.versions(),
1700 );
1701 let builder = match paths.client_ca {
1702 Some(ca_path) => {
1703 let verifier = build_client_cert_verifier(ca_path)?;
1704 builder.with_client_cert_verifier(verifier)
1705 }
1706 None => builder.with_no_client_auth(),
1707 };
1708 let server_config = builder
1709 .with_single_cert(certs, key)
1710 .map_err(|e| ServerError::Http(format!("rustls server config: {e}")))?;
1711 Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
1712}
1713
1714#[allow(clippy::result_large_err)]
1715fn build_client_cert_verifier(
1716 ca_path: &std::path::Path,
1717) -> Result<Arc<dyn tokio_rustls::rustls::server::danger::ClientCertVerifier>, ServerError> {
1718 use std::fs::File;
1719 use std::io::BufReader;
1720 use tokio_rustls::rustls::server::WebPkiClientVerifier;
1721 use tokio_rustls::rustls::RootCertStore;
1722
1723 let file = File::open(ca_path)
1724 .map_err(|e| ServerError::Http(format!("open pgwire_tls_client_ca: {e}")))?;
1725 let mut roots = RootCertStore::empty();
1726 let mut added = 0usize;
1727 for cert in rustls_pemfile::certs(&mut BufReader::new(file)) {
1728 let cert =
1729 cert.map_err(|e| ServerError::Http(format!("parse pgwire_tls_client_ca: {e}")))?;
1730 roots
1731 .add(cert)
1732 .map_err(|e| ServerError::Http(format!("invalid CA in pgwire_tls_client_ca: {e}")))?;
1733 added += 1;
1734 }
1735 if added == 0 {
1736 return Err(ServerError::Http(format!(
1737 "pgwire_tls_client_ca {} contains no certificates",
1738 ca_path.display()
1739 )));
1740 }
1741 WebPkiClientVerifier::builder(Arc::new(roots))
1742 .build()
1743 .map_err(|e| ServerError::Http(format!("build client-cert verifier: {e}")))
1744}
1745
/// Binds the pgwire listener and spawns its accept loop.
///
/// # Parameters
///
/// * `db` — database served to every connection.
/// * `bind` — socket address to listen on, e.g. `127.0.0.1:0`.
/// * `users` — credentials; an empty map selects trust auth, otherwise md5.
/// * `allow_remote` — opt-in required to bind a non-loopback address (md5
///   mode only; trust mode never binds non-loopback).
/// * `tls` — TLS files; `None` disables TLS.
/// * `max_connections` — connections beyond this many in-flight sessions are
///   dropped right after accept.
/// * `max_auth_failures_per_min` — per-IP failure limit over a 60-second
///   window; `0` disables the throttle.
///
/// # Returns
///
/// The bound local address and the handle of the accept-loop task.
///
/// # Errors
///
/// Returns `ServerError::Http` when `bind` does not parse, the bind policy
/// is violated, the TLS files cannot be loaded, or the socket cannot be
/// bound.
pub async fn serve(
    db: Arc<LaminarDB>,
    bind: &str,
    users: HashMap<String, Secret>,
    allow_remote: bool,
    tls: Option<TlsPaths<'_>>,
    max_connections: usize,
    max_auth_failures_per_min: u32,
) -> Result<(SocketAddr, tokio::task::JoinHandle<()>), ServerError> {
    let addr: SocketAddr = bind
        .parse()
        .map_err(|e| ServerError::Http(format!("invalid pgwire_bind '{bind}': {e}")))?;

    // Bind policy: trust auth is loopback-only; md5 needs an explicit opt-in
    // before listening on a non-loopback address.
    let auth_mode = if users.is_empty() { "trust" } else { "md5" };
    let is_remote_bind = !addr.ip().is_loopback();
    match (auth_mode, is_remote_bind, allow_remote) {
        ("trust", true, _) => {
            return Err(ServerError::Http(format!(
                "pgwire_bind '{addr}' is not loopback and pgwire_users is empty (trust auth); \
                 configure pgwire_users + pgwire_allow_remote=true, or bind to 127.0.0.1"
            )))
        }
        ("md5", true, false) => {
            return Err(ServerError::Http(format!(
                "pgwire_bind '{addr}' is not loopback; set pgwire_allow_remote=true to opt in"
            )))
        }
        _ => {}
    }

    // Load TLS before binding so a bad certificate fails startup early.
    let tls_min_label = tls.as_ref().map(|p| p.min_version.label());
    let mtls_on = tls.as_ref().is_some_and(|p| p.client_ca.is_some());
    let tls_state: Option<Arc<TlsReloadState>> = match tls {
        Some(paths) => {
            let acceptor = load_tls_acceptor(TlsPaths {
                cert: paths.cert,
                key: paths.key,
                min_version: paths.min_version,
                client_ca: paths.client_ca,
            })?;
            Some(Arc::new(TlsReloadState {
                paths: TlsConfigPaths::from_paths(&paths),
                acceptor: parking_lot::Mutex::new(Arc::new(acceptor)),
            }))
        }
        None => None,
    };

    let listener = TcpListener::bind(addr)
        .await
        .map_err(|e| ServerError::Http(format!("pgwire bind {addr}: {e}")))?;
    let local_addr = listener
        .local_addr()
        .map_err(|e| ServerError::Http(format!("pgwire local_addr: {e}")))?;

    let factory = Arc::new(LaminarHandlerFactory::new(db, users));
    let tls_mode = if tls_state.is_some() { "on" } else { "off" };
    let tls_min = tls_min_label.unwrap_or("-");
    let mtls = if mtls_on { "on" } else { "off" };
    // Trust mode is logged at warn level so it stands out.
    if auth_mode == "trust" {
        warn!(
            addr = %local_addr,
            tls = tls_mode,
            tls_min,
            mtls,
            "pgwire listening with TRUST auth — any client reaching this address is admin",
        );
    } else {
        info!(
            addr = %local_addr,
            auth = auth_mode,
            tls = tls_mode,
            tls_min,
            mtls,
            "pgwire listening",
        );
    }

    let failures = Arc::new(FailureTracker::default());
    let watcher_state = tls_state.as_ref().map(Arc::clone);
    // LAMINAR_DISABLE_FILE_WATCH set to "1" or "true" turns the TLS file
    // watcher off.
    let watcher_disabled =
        std::env::var("LAMINAR_DISABLE_FILE_WATCH").is_ok_and(|v| v == "1" || v == "true");
    let handle = tokio::spawn(async move {
        // Both sets are owned by this task, so the sessions and the watcher
        // are dropped together with the accept loop.
        let mut sessions: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
        let mut watcher_set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
        if let (Some(state), false) = (watcher_state, watcher_disabled) {
            watcher_set.spawn(async move {
                watch_tls_files(state, std::time::Duration::from_millis(500)).await;
            });
        }
        loop {
            tokio::select! {
                // Reap finished sessions so `sessions.len()` reflects only
                // connections still in flight.
                Some(_) = sessions.join_next(), if !sessions.is_empty() => {
                }
                Some(_) = watcher_set.join_next(), if !watcher_set.is_empty() => {}
                accepted = listener.accept() => {
                    match accepted {
                        Ok((sock, peer)) => {
                            // Connection cap: drop the socket without a
                            // protocol-level reply.
                            if sessions.len() >= max_connections {
                                tracing::info!(
                                    target: "audit",
                                    event = "pgwire.connection_rejected",
                                    peer = %peer,
                                    reason = "max_connections",
                                    in_flight = sessions.len(),
                                );
                                drop(sock);
                                continue;
                            }
                            // Per-IP throttle over a 60-second window.
                            if failures.is_blocked(
                                peer.ip(),
                                max_auth_failures_per_min,
                                std::time::Duration::from_secs(60),
                            ) {
                                tracing::warn!(
                                    target: "audit",
                                    event = "pgwire.connection_rejected",
                                    peer = %peer,
                                    reason = "auth_failure_throttle",
                                );
                                drop(sock);
                                continue;
                            }
                            let factory_ref = Arc::clone(&factory);
                            // Snapshot the acceptor per connection so a
                            // reload applies to connections accepted later.
                            let tls_ref: Option<tokio_rustls::TlsAcceptor> =
                                tls_state.as_ref().map(|s| (*s.snapshot()).clone());
                            let failures_ref = Arc::clone(&failures);
                            let peer_str = peer.to_string();
                            tracing::info!(
                                target: "audit",
                                event = "pgwire.connection_accepted",
                                peer = %peer,
                                auth = auth_mode,
                                tls = tls_mode,
                            );
                            let peer_ip = peer.ip();
                            sessions.spawn(async move {
                                let result = process_socket(sock, tls_ref, factory_ref).await;
                                let outcome = classify_outcome(&result);
                                // Only authentication failures count toward
                                // the throttle.
                                if outcome == "auth_failed" {
                                    failures_ref.record_failure(peer_ip);
                                }
                                tracing::info!(
                                    target: "audit",
                                    event = "pgwire.connection_closed",
                                    peer = %peer_str,
                                    outcome,
                                );
                                if let Err(e) = result {
                                    warn!(peer = %peer_str, error = %e, "pgwire connection error");
                                }
                            });
                        }
                        Err(e) => {
                            // Pause briefly after an accept error before the
                            // next attempt.
                            warn!(error = %e, "pgwire accept failed");
                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        }
                    }
                }
            }
        }
    });
    Ok((local_addr, handle))
}
1918
/// Unit tests for statement dispatch, outcome classification, the
/// auth-failure tracker and the bind policy enforced by `serve`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns its first statement; panics on a parse error
    /// or an empty result.
    fn parse_one(sql: &str) -> StreamingStatement {
        parse_streaming_sql(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Parses `sql` and unwraps it as a standard statement; panics for any
    /// other statement kind.
    fn standard(sql: &str) -> Statement {
        match parse_one(sql) {
            StreamingStatement::Standard(s) => *s,
            other => panic!("expected Standard, got {other:?}"),
        }
    }

    /// Empty arrays, quoted elements, NULL elements and escaping of `"` and
    /// `\` in the text-array literal.
    #[test]
    fn pg_text_array_literal_quotes_nulls_and_escapes() {
        assert_eq!(pg_text_array_literal(&[]), "{}");
        assert_eq!(
            pg_text_array_literal(&[Some("en".into()), Some("ja".into())]),
            r#"{"en","ja"}"#
        );
        assert_eq!(
            pg_text_array_literal(&[None, Some("x".into())]),
            r#"{NULL,"x"}"#
        );
        assert_eq!(
            pg_text_array_literal(&[Some("a\"b\\c".into())]),
            r#"{"a\"b\\c"}"#
        );
    }

    /// `SELECT 1` succeeds regardless of keyword case or a leading comment.
    #[tokio::test]
    async fn select_one_dispatches() {
        let db = LaminarDB::open().unwrap();
        for sql in ["SELECT 1", "select 1", "/* hint */ SELECT 1"] {
            standard_response(&db, standard(sql)).unwrap();
        }
    }

    /// Builtin probes parse as standard statements and can be dispatched;
    /// the response itself is discarded, not asserted.
    #[tokio::test]
    async fn driver_select_builtins_dispatch() {
        let db = LaminarDB::open().unwrap();
        for sql in [
            "SELECT version()",
            "SELECT current_schema()",
            "SELECT current_database()",
            "SELECT current_user",
        ] {
            let _ = standard_response(&db, standard(sql));
        }
    }

    /// A `SELECT` with a `FROM` clause is rejected.
    #[tokio::test]
    async fn select_with_from_is_rejected() {
        let db = LaminarDB::open().unwrap();
        let err = standard_response(&db, standard("SELECT 1 FROM foo")).unwrap_err();
        assert!(err.to_string().contains("limited to literals"));
    }

    /// DDL is rejected with a pointer to the HTTP SQL endpoint.
    #[tokio::test]
    async fn ddl_routed_to_http() {
        let db = LaminarDB::open().unwrap();
        let err = standard_response(&db, standard("CREATE TABLE foo (id INT)")).unwrap_err();
        assert!(err.to_string().contains("HTTP /api/v1/sql"));
    }

    /// Transaction-control statements are accepted.
    #[tokio::test]
    async fn transaction_control_dispatches() {
        let db = LaminarDB::open().unwrap();
        for sql in [
            "BEGIN",
            "BEGIN TRANSACTION",
            "START TRANSACTION",
            "COMMIT",
            "ROLLBACK",
        ] {
            standard_response(&db, standard(sql)).unwrap();
        }
    }

    /// `SET name = value` is stored as a session property.
    #[tokio::test]
    async fn set_writes_to_session_properties() {
        let db = LaminarDB::open().unwrap();
        standard_response(&db, standard("SET extra_float_digits = 3")).unwrap();
        assert_eq!(
            db.get_session_property("extra_float_digits").as_deref(),
            Some("3"),
        );
    }

    /// `SET TRANSACTION ...` is rejected, and the error names it.
    #[tokio::test]
    async fn set_transaction_isolation_is_rejected() {
        let db = LaminarDB::open().unwrap();
        let err = standard_response(
            &db,
            standard("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"),
        )
        .unwrap_err();
        assert!(err.to_string().contains("SET TRANSACTION"));
    }

    /// The parser splits a semicolon-separated script into its statements.
    #[test]
    fn multi_statement_parses() {
        let stmts = parse_streaming_sql("BEGIN; SELECT 1; COMMIT").unwrap();
        assert_eq!(stmts.len(), 3);
    }

    /// One representative message per audit bucket.
    #[test]
    fn classify_outcome_buckets_errors() {
        use std::io::{Error, ErrorKind};
        assert_eq!(super::classify_outcome(&Ok(())), "ok");
        assert_eq!(
            super::classify_outcome(&Err(Error::other("FATAL: 28P01 bad pass"))),
            "auth_failed"
        );
        assert_eq!(
            super::classify_outcome(&Err(Error::other("rustls HandshakeFailure"))),
            "tls_failed"
        );
        assert_eq!(
            super::classify_outcome(&Err(Error::new(ErrorKind::BrokenPipe, "broken"))),
            "error"
        );
    }

    /// An IP is not blocked until `limit` failures have been recorded.
    #[test]
    fn failure_tracker_blocks_after_threshold() {
        use std::net::{IpAddr, Ipv4Addr};
        use std::time::Duration;
        let ip: IpAddr = Ipv4Addr::LOCALHOST.into();
        let tracker = super::FailureTracker::default();
        let limit = 3;
        let window = Duration::from_secs(60);

        for _ in 0..limit {
            assert!(!tracker.is_blocked(ip, limit, window));
            tracker.record_failure(ip);
        }
        assert!(tracker.is_blocked(ip, limit, window));
    }

    /// A limit of zero disables throttling however many failures exist.
    #[test]
    fn failure_tracker_disabled_when_limit_zero() {
        use std::net::{IpAddr, Ipv4Addr};
        use std::time::Duration;
        let ip: IpAddr = Ipv4Addr::LOCALHOST.into();
        let tracker = super::FailureTracker::default();
        for _ in 0..100 {
            tracker.record_failure(ip);
        }
        assert!(!tracker.is_blocked(ip, 0, Duration::from_secs(60)));
    }

    /// With a zero-length window, failures recorded earlier count as expired.
    // NOTE(review): relies on the clock advancing between `record_failure`
    // and `is_blocked` — confirm this cannot flake on coarse clocks.
    #[test]
    fn failure_tracker_expires_old_entries() {
        use std::net::{IpAddr, Ipv4Addr};
        use std::time::Duration;
        let ip: IpAddr = Ipv4Addr::LOCALHOST.into();
        let tracker = super::FailureTracker::default();
        for _ in 0..5 {
            tracker.record_failure(ip);
        }
        assert!(!tracker.is_blocked(ip, 5, Duration::from_secs(0)));
    }

    /// Recording failures from more than `MAX_TRACKED_IPS` distinct IPs
    /// never grows the map beyond the cap.
    #[test]
    fn failure_tracker_caps_distinct_ips() {
        use std::net::{IpAddr, Ipv4Addr};
        let tracker = super::FailureTracker::default();
        for i in 0..(super::MAX_TRACKED_IPS + 100) {
            // Spread `i` over the last two octets to get distinct addresses.
            #[allow(clippy::cast_possible_truncation)]
            let ip: IpAddr = Ipv4Addr::new(10, 0, (i / 256) as u8, (i % 256) as u8).into();
            tracker.record_failure(ip);
        }
        let len = tracker.inner.lock().len();
        assert!(
            len <= super::MAX_TRACKED_IPS,
            "tracker exceeded cap: {len} > {}",
            super::MAX_TRACKED_IPS
        );
    }

    /// Trust auth (no users) must refuse a non-loopback bind.
    #[tokio::test]
    async fn serve_rejects_remote_bind_in_trust_mode() {
        let db = Arc::new(LaminarDB::open().expect("db opens"));
        let err = serve(db, "0.0.0.0:0", HashMap::new(), false, None, 256, 10)
            .await
            .expect_err("trust + 0.0.0.0 must fail");
        assert!(err.to_string().contains("trust auth"), "got: {err}");
    }

    /// md5 auth must refuse a non-loopback bind unless `allow_remote` is set.
    #[tokio::test]
    async fn serve_rejects_remote_bind_without_explicit_optin() {
        let db = Arc::new(LaminarDB::open().expect("db opens"));
        let mut users = HashMap::new();
        users.insert("alice".into(), Secret::new("wonderland-key"));
        let err = serve(db, "0.0.0.0:0", users, false, None, 256, 10)
            .await
            .expect_err("md5 + 0.0.0.0 without allow_remote must fail");
        assert!(
            err.to_string().contains("pgwire_allow_remote"),
            "got: {err}"
        );
    }
}
2134
2135#[cfg(test)]
2136mod integration_tests {
2137 use std::collections::HashMap;
2143 use std::sync::Arc;
2144
2145 use laminar_db::LaminarDB;
2146 use tokio_postgres::{NoTls, SimpleQueryMessage};
2147
2148 use super::Secret;
2149
2150 async fn spawn_server_with(
2151 users: HashMap<String, Secret>,
2152 ) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
2153 let db = Arc::new(LaminarDB::open().expect("db opens"));
2154 db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
2155 .await
2156 .expect("create source");
2157 db.execute(
2158 "CREATE MATERIALIZED VIEW prices AS \
2159 SELECT symbol, price FROM trades",
2160 )
2161 .await
2162 .expect("create mv");
2163 db.start().await.expect("db starts");
2164
2165 let (addr, handle) =
2166 super::serve(Arc::clone(&db), "127.0.0.1:0", users, false, None, 256, 10)
2167 .await
2168 .expect("pgwire serve");
2169 (addr, handle)
2170 }
2171
2172 async fn spawn_server() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
2173 spawn_server_with(HashMap::new()).await
2174 }
2175
2176 async fn connect(addr: std::net::SocketAddr) -> tokio_postgres::Client {
2177 let conn_str = format!(
2178 "host={} port={} user=any dbname=laminardb",
2179 addr.ip(),
2180 addr.port()
2181 );
2182 let (client, conn) = tokio_postgres::connect(&conn_str, NoTls)
2183 .await
2184 .expect("pgwire connect");
2185 tokio::spawn(async move {
2186 let _ = conn.await;
2187 });
2188 client
2189 }
2190
2191 fn first_row_value(messages: &[SimpleQueryMessage], col: usize) -> Option<&str> {
2192 messages.iter().find_map(|m| match m {
2193 SimpleQueryMessage::Row(r) => r.get(col),
2194 _ => None,
2195 })
2196 }
2197
2198 #[tokio::test]
2199 async fn handshake_and_builtins() {
2200 let (addr, handle) = spawn_server().await;
2201 let client = connect(addr).await;
2202
2203 let messages = client
2204 .simple_query("SELECT version()")
2205 .await
2206 .expect("version");
2207 let v = first_row_value(&messages, 0).expect("row");
2208 assert!(v.contains("LaminarDB"), "version: {v}");
2209
2210 let messages = client
2211 .simple_query("SELECT current_database()")
2212 .await
2213 .expect("current_database");
2214 assert_eq!(first_row_value(&messages, 0), Some("laminar"));
2215
2216 handle.abort();
2217 }
2218
2219 #[tokio::test]
2220 async fn show_streams_runs() {
2221 let (addr, handle) = spawn_server().await;
2222 let client = connect(addr).await;
2223
2224 client
2227 .simple_query("SHOW STREAMS")
2228 .await
2229 .expect("SHOW STREAMS");
2230
2231 handle.abort();
2232 }
2233
2234 #[tokio::test]
2235 async fn subscribe_unknown_name_returns_pg_error() {
2236 let (addr, handle) = spawn_server().await;
2237 let client = connect(addr).await;
2238
2239 let err = client
2240 .simple_query("SUBSCRIBE no_such_view")
2241 .await
2242 .expect_err("must fail");
2243 let db_err = err.as_db_error().expect("typed PG error");
2244 assert!(
2245 db_err.message().contains("no_such_view"),
2246 "message: {}",
2247 db_err.message()
2248 );
2249
2250 handle.abort();
2251 }
2252
2253 #[tokio::test]
2254 async fn subscribe_with_valid_where_is_accepted() {
2255 let (addr, handle) = spawn_server().await;
2259 let client = connect(addr).await;
2260
2261 let r = tokio::time::timeout(
2262 std::time::Duration::from_millis(500),
2263 client.simple_query("SUBSCRIBE prices WHERE symbol = 'AAPL'"),
2264 )
2265 .await;
2266
2267 assert!(
2268 r.is_err(),
2269 "subscribe must stay open until timeout, got: {r:?}"
2270 );
2271
2272 handle.abort();
2273 }
2274
2275 #[tokio::test]
2276 async fn subscribe_with_unknown_column_in_where_returns_pg_error() {
2277 let (addr, handle) = spawn_server().await;
2278 let client = connect(addr).await;
2279
2280 let err = client
2281 .simple_query("SUBSCRIBE prices WHERE no_such_col > 1")
2282 .await
2283 .expect_err("must fail");
2284 let db_err = err.as_db_error().expect("typed PG error");
2285 assert!(
2286 db_err.message().contains("no_such_col"),
2287 "filter error must name the bad column, got: {}",
2288 db_err.message()
2289 );
2290
2291 handle.abort();
2292 }
2293
2294 #[tokio::test]
2295 async fn subscribe_as_of_unretained_returns_pg_error() {
2296 let (addr, handle) = spawn_server().await;
2299 let client = connect(addr).await;
2300
2301 let err = client
2302 .simple_query("SUBSCRIBE prices AS OF EPOCH 1")
2303 .await
2304 .expect_err("must fail");
2305 let db_err = err.as_db_error().expect("typed PG error");
2306 assert!(
2307 db_err.message().contains("no longer retained"),
2308 "message: {}",
2309 db_err.message()
2310 );
2311
2312 handle.abort();
2313 }
2314
2315 #[tokio::test]
2319 async fn subscribe_streams_emitted_rows_over_the_wire() {
2320 use std::time::Duration;
2321
2322 use arrow_array::{Float64Array, RecordBatch, StringArray};
2323
2324 let db = Arc::new(LaminarDB::open().expect("db opens"));
2325 db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
2326 .await
2327 .expect("create source");
2328 db.execute("CREATE MATERIALIZED VIEW prices AS SELECT symbol, price FROM trades")
2329 .await
2330 .expect("create mv");
2331 db.start().await.expect("db starts");
2332 let (addr, handle) = super::serve(
2333 Arc::clone(&db),
2334 "127.0.0.1:0",
2335 HashMap::new(),
2336 false,
2337 None,
2338 256,
2339 10,
2340 )
2341 .await
2342 .expect("serve");
2343 let mut client = connect(addr).await;
2344 let txn = client.transaction().await.expect("begin");
2345
2346 let stmt = txn.prepare("SUBSCRIBE prices").await.expect("prepare");
2349 let portal = txn.bind(&stmt, &[]).await.expect("bind");
2350
2351 let pusher = tokio::spawn({
2352 let db = Arc::clone(&db);
2353 async move {
2354 tokio::time::sleep(Duration::from_millis(300)).await;
2355 let src = db.source_untyped("trades").expect("source handle");
2356 let batch = RecordBatch::try_new(
2357 src.schema().clone(),
2358 vec![
2359 Arc::new(StringArray::from(vec!["AAPL", "MSFT"])),
2360 Arc::new(Float64Array::from(vec![100.0, 200.0])),
2361 ],
2362 )
2363 .expect("batch");
2364 src.push_arrow(batch).expect("push");
2365 }
2366 });
2367
2368 let rows = tokio::time::timeout(Duration::from_secs(10), txn.query_portal(&portal, 2))
2369 .await
2370 .expect("read did not time out")
2371 .expect("query_portal");
2372 pusher.await.expect("pusher");
2373
2374 let mut symbols: Vec<String> = rows
2375 .iter()
2376 .map(|r| r.get::<_, &str>(0).to_string())
2377 .collect();
2378 symbols.sort();
2379 assert_eq!(
2380 symbols,
2381 ["AAPL", "MSFT"],
2382 "both emitted rows arrive over pgwire"
2383 );
2384
2385 handle.abort();
2386 }
2387
2388 #[tokio::test]
2392 async fn subscribe_decodes_text_array_in_binary_format() {
2393 use std::time::Duration;
2394
2395 use arrow_array::{Int64Array, RecordBatch};
2396
2397 let db = Arc::new(LaminarDB::open().expect("db opens"));
2398 db.execute("CREATE SOURCE feed (id BIGINT)")
2399 .await
2400 .expect("create source");
2401 db.execute(
2402 "CREATE MATERIALIZED VIEW tagged AS SELECT id, make_array('en','ja') AS tags FROM feed",
2403 )
2404 .await
2405 .expect("create mv");
2406 db.start().await.expect("db starts");
2407 let (addr, handle) = super::serve(
2408 Arc::clone(&db),
2409 "127.0.0.1:0",
2410 HashMap::new(),
2411 false,
2412 None,
2413 256,
2414 10,
2415 )
2416 .await
2417 .expect("serve");
2418 let mut client = connect(addr).await;
2419 let txn = client.transaction().await.expect("begin");
2420 let stmt = txn.prepare("SUBSCRIBE tagged").await.expect("prepare");
2421 let portal = txn.bind(&stmt, &[]).await.expect("bind");
2422
2423 let pusher = tokio::spawn({
2424 let db = Arc::clone(&db);
2425 async move {
2426 tokio::time::sleep(Duration::from_millis(300)).await;
2427 let src = db.source_untyped("feed").expect("source handle");
2428 let batch = RecordBatch::try_new(
2429 src.schema().clone(),
2430 vec![Arc::new(Int64Array::from(vec![1_i64]))],
2431 )
2432 .expect("batch");
2433 src.push_arrow(batch).expect("push");
2434 }
2435 });
2436
2437 let rows = tokio::time::timeout(Duration::from_secs(10), txn.query_portal(&portal, 1))
2438 .await
2439 .expect("read did not time out")
2440 .expect("query_portal");
2441 pusher.await.expect("pusher");
2442
2443 assert_eq!(rows.len(), 1);
2444 let tags: Vec<String> = rows[0].get(1);
2445 assert_eq!(
2446 tags,
2447 vec!["en".to_string(), "ja".to_string()],
2448 "TEXT[] decoded over the binary wire"
2449 );
2450
2451 handle.abort();
2452 }
2453
2454 #[tokio::test]
2455 async fn ddl_returns_pg_error_pointing_at_http() {
2456 let (addr, handle) = spawn_server().await;
2457 let client = connect(addr).await;
2458
2459 let err = client
2460 .simple_query("CREATE SOURCE more_trades (sym VARCHAR)")
2461 .await
2462 .expect_err("DDL must be rejected");
2463 let db_err = err.as_db_error().expect("typed PG error");
2464 assert!(
2465 db_err.message().contains("/api/v1/sql"),
2466 "message: {}",
2467 db_err.message()
2468 );
2469
2470 handle.abort();
2471 }
2472
2473 async fn md5_users() -> HashMap<String, Secret> {
2474 let mut u = HashMap::new();
2475 u.insert("alice".to_string(), Secret::new(TEST_PASSWORD));
2476 u
2477 }
2478
/// Password used for user `alice` throughout the md5 authentication tests, both as
/// the stored secret and as the password the client presents.
const TEST_PASSWORD: &str = "wonderland-key";
2480
2481 async fn connect_with_password(
2482 addr: std::net::SocketAddr,
2483 user: &str,
2484 password: &str,
2485 ) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
2486 let conn_str = format!(
2487 "host={} port={} user={user} password={password} dbname=laminardb",
2488 addr.ip(),
2489 addr.port()
2490 );
2491 let (client, conn) = tokio_postgres::connect(&conn_str, NoTls).await?;
2492 tokio::spawn(async move {
2493 let _ = conn.await;
2494 });
2495 Ok(client)
2496 }
2497
2498 #[tokio::test]
2499 async fn md5_auth_accepts_correct_password() {
2500 let (addr, handle) = spawn_server_with(md5_users().await).await;
2501
2502 let client = connect_with_password(addr, "alice", TEST_PASSWORD)
2503 .await
2504 .expect("auth must succeed");
2505
2506 let messages = client
2507 .simple_query("SELECT version()")
2508 .await
2509 .expect("query after auth");
2510 let v = first_row_value(&messages, 0).expect("row");
2511 assert!(v.contains("LaminarDB"), "version: {v}");
2512
2513 handle.abort();
2514 }
2515
2516 #[tokio::test]
2517 async fn md5_auth_rejects_wrong_password() {
2518 let (addr, handle) = spawn_server_with(md5_users().await).await;
2519
2520 let err = connect_with_password(addr, "alice", "not-the-password")
2521 .await
2522 .expect_err("auth must fail");
2523
2524 let db_err = err.as_db_error().expect("typed PG error");
2525 assert_eq!(db_err.code().code(), "28P01", "got: {db_err:?}");
2526
2527 handle.abort();
2528 }
2529
2530 fn md5_users_prehashed(user: &str, password: &str) -> HashMap<String, Secret> {
2533 use md5::{Digest, Md5};
2534 let mut h = Md5::new();
2535 h.update(password.as_bytes());
2536 h.update(user.as_bytes());
2537 let inner = format!("{:x}", h.finalize());
2538 let mut u = HashMap::new();
2539 u.insert(user.to_string(), Secret::new(format!("md5{inner}")));
2540 u
2541 }
2542
2543 #[tokio::test]
2544 async fn md5_auth_accepts_correct_password_against_prehash() {
2545 let (addr, handle) = spawn_server_with(md5_users_prehashed("alice", TEST_PASSWORD)).await;
2546 let client = connect_with_password(addr, "alice", TEST_PASSWORD)
2547 .await
2548 .expect("auth must succeed against pre-hashed entry");
2549 let messages = client
2550 .simple_query("SELECT version()")
2551 .await
2552 .expect("query after auth");
2553 let v = first_row_value(&messages, 0).expect("row");
2554 assert!(v.contains("LaminarDB"), "version: {v}");
2555 handle.abort();
2556 }
2557
2558 #[tokio::test]
2559 async fn md5_auth_rejects_wrong_password_against_prehash() {
2560 let (addr, handle) = spawn_server_with(md5_users_prehashed("alice", TEST_PASSWORD)).await;
2561 let err = connect_with_password(addr, "alice", "not-the-password")
2562 .await
2563 .expect_err("auth must fail");
2564 let db_err = err.as_db_error().expect("typed PG error");
2565 assert_eq!(db_err.code().code(), "28P01", "got: {db_err:?}");
2566 handle.abort();
2567 }
2568
/// `parse_pre_hashed_md5` accepts only `md5` followed by exactly 32 lowercase hex
/// digits and returns the hex part; every other shape yields `None`.
#[test]
fn parse_pre_hashed_md5_strict_format() {
    let inner = "5d41402abc4b2a76b9719d911017c592";

    // The one accepted shape: prefix + 32 lowercase hex digits.
    let accepted = format!("md5{inner}");
    assert_eq!(super::parse_pre_hashed_md5(&accepted), Some(inner));

    let rejected = [
        // Too short after the prefix.
        "md5short",
        // Right length, but uppercase hex.
        "md55D41402ABC4B2A76B9719D911017C592",
        // Bare hex without the `md5` prefix.
        inner,
        // Right length, but not hex at all.
        "md5zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz",
    ];
    for candidate in rejected {
        assert_eq!(super::parse_pre_hashed_md5(candidate), None);
    }
}
2589
2590 #[tokio::test]
2591 async fn md5_auth_rejects_unknown_user() {
2592 let (addr, handle) = spawn_server_with(md5_users().await).await;
2593
2594 let err = connect_with_password(addr, "mallory", "anything")
2595 .await
2596 .expect_err("auth must fail");
2597
2598 let db_err = err.as_db_error().expect("typed PG error");
2599 assert_eq!(db_err.code().code(), "28P01", "got: {db_err:?}");
2600
2601 handle.abort();
2602 }
2603
2604 #[tokio::test]
2605 async fn connection_cap_drops_excess_clients() {
2606 let db = Arc::new(LaminarDB::open().expect("db opens"));
2608 db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
2609 .await
2610 .expect("create source");
2611 db.execute("CREATE MATERIALIZED VIEW prices AS SELECT symbol, price FROM trades")
2612 .await
2613 .expect("create mv");
2614 db.start().await.expect("db starts");
2615 let (addr, handle) = super::serve(
2616 Arc::clone(&db),
2617 "127.0.0.1:0",
2618 HashMap::new(),
2619 false,
2620 None,
2621 1,
2622 10,
2623 )
2624 .await
2625 .expect("pgwire serve");
2626
2627 let first = connect(addr).await;
2629 let _bg = tokio::spawn(async move {
2630 let _ = first.simple_query("SUBSCRIBE prices").await;
2631 });
2632 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
2634
2635 let conn_str = format!(
2639 "host={} port={} user=any dbname=laminardb",
2640 addr.ip(),
2641 addr.port()
2642 );
2643 let result = tokio_postgres::connect(&conn_str, NoTls).await;
2644 assert!(result.is_err(), "second connect must be refused");
2645
2646 handle.abort();
2647 }
2648
2649 fn self_signed_pem() -> (tempfile::TempDir, std::path::PathBuf, std::path::PathBuf) {
2652 let cert =
2653 rcgen::generate_simple_self_signed(vec!["localhost".into()]).expect("rcgen issue cert");
2654 let dir = tempfile::tempdir().expect("tempdir");
2655 let cert_path = dir.path().join("cert.pem");
2656 let key_path = dir.path().join("key.pem");
2657 std::fs::write(&cert_path, cert.cert.pem()).unwrap();
2658 std::fs::write(&key_path, cert.key_pair.serialize_pem()).unwrap();
2659 (dir, cert_path, key_path)
2660 }
2661
/// Test PKI bundle returned by `mint_ca_and_client_leaf`: a CA certificate on disk
/// plus a client leaf certificate and key signed by that CA, held in memory.
struct MintedClientPki {
    /// Temp directory containing `ca_pem_path`; held only to keep the file alive.
    _dir: tempfile::TempDir,
    /// Path of the CA certificate, written as PEM.
    ca_pem_path: std::path::PathBuf,
    /// Client certificate chain in DER form (the leaf only).
    leaf_chain: Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>,
    /// Private key of the client leaf certificate (PKCS#8 DER).
    leaf_key: tokio_rustls::rustls::pki_types::PrivateKeyDer<'static>,
}
2672
2673 fn mint_ca_and_client_leaf(common_name: &str) -> MintedClientPki {
2674 use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
2675
2676 let mut ca_params = rcgen::CertificateParams::new(vec!["mtls-test-ca".into()]).unwrap();
2677 ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
2678 let ca_key = rcgen::KeyPair::generate().unwrap();
2679 let ca_cert = ca_params.self_signed(&ca_key).unwrap();
2680
2681 let mut leaf_params = rcgen::CertificateParams::new(vec![common_name.into()]).unwrap();
2682 leaf_params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ClientAuth];
2683 let leaf_key = rcgen::KeyPair::generate().unwrap();
2684 let leaf_cert = leaf_params.signed_by(&leaf_key, &ca_cert, &ca_key).unwrap();
2685
2686 let dir = tempfile::tempdir().unwrap();
2687 let ca_pem_path = dir.path().join("ca.pem");
2688 std::fs::write(&ca_pem_path, ca_cert.pem()).unwrap();
2689
2690 let leaf_chain = vec![CertificateDer::from(leaf_cert.der().to_vec())];
2691 let leaf_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(leaf_key.serialize_der()));
2692
2693 MintedClientPki {
2694 _dir: dir,
2695 ca_pem_path,
2696 leaf_chain,
2697 leaf_key,
2698 }
2699 }
2700
2701 fn make_client_tls(
2704 server_cert_path: &std::path::Path,
2705 client_auth: Option<(
2706 Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>,
2707 tokio_rustls::rustls::pki_types::PrivateKeyDer<'static>,
2708 )>,
2709 ) -> tokio_postgres_rustls::MakeRustlsConnect {
2710 super::ensure_tls_provider();
2711 let cert_bytes = std::fs::read(server_cert_path).unwrap();
2712 let mut roots = tokio_rustls::rustls::RootCertStore::empty();
2713 for c in rustls_pemfile::certs(&mut std::io::Cursor::new(cert_bytes))
2714 .collect::<Result<Vec<_>, _>>()
2715 .unwrap()
2716 {
2717 roots.add(c).unwrap();
2718 }
2719 let builder = tokio_rustls::rustls::ClientConfig::builder().with_root_certificates(roots);
2720 let client_cfg = match client_auth {
2721 Some((chain, key)) => builder.with_client_auth_cert(chain, key).unwrap(),
2722 None => builder.with_no_client_auth(),
2723 };
2724 tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg)
2725 }
2726
2727 fn expired_self_signed_pem() -> (tempfile::TempDir, std::path::PathBuf, std::path::PathBuf) {
2729 let mut params = rcgen::CertificateParams::new(vec!["localhost".into()]).unwrap();
2730 let one_year_ago = time::OffsetDateTime::now_utc() - time::Duration::days(365);
2731 params.not_before = one_year_ago - time::Duration::days(2);
2732 params.not_after = one_year_ago;
2733 let key = rcgen::KeyPair::generate().unwrap();
2734 let cert = params.self_signed(&key).unwrap();
2735 let dir = tempfile::tempdir().unwrap();
2736 let cert_path = dir.path().join("cert.pem");
2737 let key_path = dir.path().join("key.pem");
2738 std::fs::write(&cert_path, cert.pem()).unwrap();
2739 std::fs::write(&key_path, key.serialize_pem()).unwrap();
2740 (dir, cert_path, key_path)
2741 }
2742
2743 #[tokio::test]
2744 async fn tls_load_rejects_expired_cert() {
2745 let (_dir, cert_path, key_path) = expired_self_signed_pem();
2746 let db = Arc::new(LaminarDB::open().expect("db opens"));
2747 db.start().await.expect("db starts");
2748 let err = super::serve(
2749 Arc::clone(&db),
2750 "127.0.0.1:0",
2751 HashMap::new(),
2752 false,
2753 Some(super::TlsPaths {
2754 cert: &cert_path,
2755 key: &key_path,
2756 min_version: super::TlsMinVersion::V1_2,
2757 client_ca: None,
2758 }),
2759 256,
2760 10,
2761 )
2762 .await
2763 .expect_err("expired cert must be rejected");
2764 assert!(err.to_string().contains("expired"), "got: {err}");
2765 }
2766
2767 #[tokio::test]
2768 async fn tls_min_1_3_rejects_tls_1_2_client() {
2769 let (_dir, cert_path, key_path) = self_signed_pem();
2770 let db = Arc::new(LaminarDB::open().expect("db opens"));
2771 db.start().await.expect("db starts");
2772 let (addr, handle) = super::serve(
2773 Arc::clone(&db),
2774 "127.0.0.1:0",
2775 HashMap::new(),
2776 false,
2777 Some(super::TlsPaths {
2778 cert: &cert_path,
2779 key: &key_path,
2780 min_version: super::TlsMinVersion::V1_3,
2781 client_ca: None,
2782 }),
2783 256,
2784 10,
2785 )
2786 .await
2787 .expect("pgwire serve");
2788
2789 let cert_bytes = std::fs::read(&cert_path).unwrap();
2790 let mut roots = tokio_rustls::rustls::RootCertStore::empty();
2791 for c in rustls_pemfile::certs(&mut std::io::Cursor::new(cert_bytes))
2792 .collect::<Result<Vec<_>, _>>()
2793 .unwrap()
2794 {
2795 roots.add(c).unwrap();
2796 }
2797 super::ensure_tls_provider();
2798 let client_cfg = tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
2800 &tokio_rustls::rustls::version::TLS12,
2801 ])
2802 .with_root_certificates(roots)
2803 .with_no_client_auth();
2804
2805 let conn_str = format!(
2806 "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2807 addr.ip(),
2808 addr.port(),
2809 );
2810 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
2811 let err = match tokio_postgres::connect(&conn_str, tls).await {
2812 Ok(_) => panic!("TLS 1.2 client must be refused by a 1.3-min server"),
2813 Err(e) => e,
2814 };
2815 let chain = std::iter::successors(Some(&err as &dyn std::error::Error), |e| e.source())
2818 .map(|e| e.to_string())
2819 .collect::<Vec<_>>()
2820 .join(" | ");
2821 assert!(
2822 chain.contains("ProtocolVersion") || chain.contains("incompatible"),
2823 "expected a TLS version-mismatch error, got: {chain}"
2824 );
2825
2826 handle.abort();
2827 }
2828
2829 #[tokio::test]
2830 async fn tls_handshake_succeeds() {
2831 let (_dir, cert_path, key_path) = self_signed_pem();
2832 let db = Arc::new(LaminarDB::open().expect("db opens"));
2833 db.start().await.expect("db starts");
2834 let (addr, handle) = super::serve(
2835 Arc::clone(&db),
2836 "127.0.0.1:0",
2837 HashMap::new(),
2838 false,
2839 Some(super::TlsPaths {
2840 cert: &cert_path,
2841 key: &key_path,
2842 min_version: super::TlsMinVersion::V1_2,
2843 client_ca: None,
2844 }),
2845 256,
2846 10,
2847 )
2848 .await
2849 .expect("pgwire serve");
2850
2851 let cert_bytes = std::fs::read(&cert_path).unwrap();
2853 let mut roots = tokio_rustls::rustls::RootCertStore::empty();
2854 for c in rustls_pemfile::certs(&mut std::io::Cursor::new(cert_bytes))
2855 .collect::<Result<Vec<_>, _>>()
2856 .unwrap()
2857 {
2858 roots.add(c).unwrap();
2859 }
2860 super::ensure_tls_provider();
2861 let client_cfg = tokio_rustls::rustls::ClientConfig::builder()
2862 .with_root_certificates(roots)
2863 .with_no_client_auth();
2864
2865 let conn_str = format!(
2866 "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2867 addr.ip(),
2868 addr.port(),
2869 );
2870 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
2871 let (client, conn) = tokio_postgres::connect(&conn_str, tls)
2872 .await
2873 .expect("TLS handshake + connect");
2874 tokio::spawn(async move {
2875 let _ = conn.await;
2876 });
2877
2878 let messages = client
2879 .simple_query("SELECT version()")
2880 .await
2881 .expect("query over TLS");
2882 let v = first_row_value(&messages, 0).expect("row");
2883 assert!(v.contains("LaminarDB"), "version: {v}");
2884
2885 handle.abort();
2886 }
2887
2888 #[tokio::test]
2891 async fn mtls_rejects_client_without_cert() {
2892 let (_dir, cert_path, key_path) = self_signed_pem();
2893 let pki = mint_ca_and_client_leaf("alice");
2894 let db = Arc::new(LaminarDB::open().expect("db opens"));
2895 db.start().await.expect("db starts");
2896 let (addr, handle) = super::serve(
2897 Arc::clone(&db),
2898 "127.0.0.1:0",
2899 HashMap::new(),
2900 false,
2901 Some(super::TlsPaths {
2902 cert: &cert_path,
2903 key: &key_path,
2904 min_version: super::TlsMinVersion::V1_2,
2905 client_ca: Some(&pki.ca_pem_path),
2906 }),
2907 256,
2908 10,
2909 )
2910 .await
2911 .expect("pgwire serve");
2912
2913 let tls = make_client_tls(&cert_path, None);
2914 let conn_str = format!(
2915 "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2916 addr.ip(),
2917 addr.port(),
2918 );
2919 let err = match tokio_postgres::connect(&conn_str, tls).await {
2920 Ok(_) => panic!("client without a cert must be refused under mTLS"),
2921 Err(e) => e,
2922 };
2923 assert!(
2924 err_chain(&err).contains("CertificateRequired")
2925 || err_chain(&err).contains("HandshakeFailure")
2926 || err_chain(&err).contains("certificate required"),
2927 "expected a missing-client-cert error, got: {}",
2928 err_chain(&err),
2929 );
2930 handle.abort();
2931 }
2932
2933 #[tokio::test]
2935 async fn mtls_rejects_untrusted_client_cert() {
2936 let (_dir, cert_path, key_path) = self_signed_pem();
2937 let trusted = mint_ca_and_client_leaf("trusted");
2938 let stranger = mint_ca_and_client_leaf("stranger");
2939 let db = Arc::new(LaminarDB::open().expect("db opens"));
2940 db.start().await.expect("db starts");
2941 let (addr, handle) = super::serve(
2942 Arc::clone(&db),
2943 "127.0.0.1:0",
2944 HashMap::new(),
2945 false,
2946 Some(super::TlsPaths {
2947 cert: &cert_path,
2948 key: &key_path,
2949 min_version: super::TlsMinVersion::V1_2,
2950 client_ca: Some(&trusted.ca_pem_path),
2951 }),
2952 256,
2953 10,
2954 )
2955 .await
2956 .expect("pgwire serve");
2957
2958 let tls = make_client_tls(
2960 &cert_path,
2961 Some((stranger.leaf_chain.clone(), stranger.leaf_key.clone_key())),
2962 );
2963 let conn_str = format!(
2964 "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
2965 addr.ip(),
2966 addr.port(),
2967 );
2968 let err = match tokio_postgres::connect(&conn_str, tls).await {
2969 Ok(_) => panic!("untrusted client cert must be refused"),
2970 Err(e) => e,
2971 };
2972 let chain = err_chain(&err);
2977 assert!(
2978 chain.contains("UnknownCA")
2979 || chain.contains("BadCertificate")
2980 || chain.contains("CertificateUnknown")
2981 || chain.contains("DecryptError")
2982 || chain.contains("HandshakeFailure"),
2983 "expected a cert-rejection alert, got: {chain}",
2984 );
2985 handle.abort();
2986 }
2987
2988 #[tokio::test]
2991 async fn mtls_accepts_trusted_client_cert() {
2992 let (_dir, cert_path, key_path) = self_signed_pem();
2993 let pki = mint_ca_and_client_leaf("alice");
2994 let db = Arc::new(LaminarDB::open().expect("db opens"));
2995 db.start().await.expect("db starts");
2996 let (addr, handle) = super::serve(
2997 Arc::clone(&db),
2998 "127.0.0.1:0",
2999 HashMap::new(),
3000 false,
3001 Some(super::TlsPaths {
3002 cert: &cert_path,
3003 key: &key_path,
3004 min_version: super::TlsMinVersion::V1_2,
3005 client_ca: Some(&pki.ca_pem_path),
3006 }),
3007 256,
3008 10,
3009 )
3010 .await
3011 .expect("pgwire serve");
3012
3013 let tls = make_client_tls(
3014 &cert_path,
3015 Some((pki.leaf_chain.clone(), pki.leaf_key.clone_key())),
3016 );
3017 let conn_str = format!(
3018 "host=localhost hostaddr={} port={} user=any dbname=laminardb sslmode=require",
3019 addr.ip(),
3020 addr.port(),
3021 );
3022 let (client, conn) = tokio_postgres::connect(&conn_str, tls)
3023 .await
3024 .expect("mTLS handshake + connect");
3025 tokio::spawn(async move {
3026 let _ = conn.await;
3027 });
3028
3029 let messages = client
3030 .simple_query("SELECT version()")
3031 .await
3032 .expect("query over mTLS");
3033 let v = first_row_value(&messages, 0).expect("row");
3034 assert!(v.contains("LaminarDB"), "version: {v}");
3035 handle.abort();
3036 }
3037
3038 fn build_reload_state(cert: &std::path::Path, key: &std::path::Path) -> super::TlsReloadState {
3041 let paths = super::TlsPaths {
3042 cert,
3043 key,
3044 min_version: super::TlsMinVersion::V1_2,
3045 client_ca: None,
3046 };
3047 let acceptor = super::load_tls_acceptor(super::TlsPaths {
3048 cert: paths.cert,
3049 key: paths.key,
3050 min_version: paths.min_version,
3051 client_ca: paths.client_ca,
3052 })
3053 .expect("initial acceptor loads");
3054 super::TlsReloadState {
3055 paths: super::TlsConfigPaths::from_paths(&paths),
3056 acceptor: parking_lot::Mutex::new(Arc::new(acceptor)),
3057 }
3058 }
3059
/// After the cert/key files are overwritten with a fresh valid pair, `try_reload_tls`
/// succeeds and `snapshot()` returns a different `Arc` than before.
#[test]
fn tls_reload_swaps_acceptor_on_valid_pair() {
    let dir = tempfile::tempdir().unwrap();
    let root = dir.path();
    let (cert_path, key_path) = (root.join("cert.pem"), root.join("key.pem"));

    // Initial pair on disk.
    let original = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
    std::fs::write(&cert_path, original.cert.pem()).unwrap();
    std::fs::write(&key_path, original.key_pair.serialize_pem()).unwrap();

    let state = build_reload_state(&cert_path, &key_path);
    let before = state.snapshot();

    // Rotate: overwrite both files with a newly generated pair.
    let rotated = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
    std::fs::write(&cert_path, rotated.cert.pem()).unwrap();
    std::fs::write(&key_path, rotated.key_pair.serialize_pem()).unwrap();

    super::try_reload_tls(&state).expect("reload succeeds");
    let after = state.snapshot();
    let same_acceptor = Arc::ptr_eq(&before, &after);
    assert!(
        !same_acceptor,
        "acceptor pointer must change after a successful reload",
    );
}
3087
/// When the certificate file is overwritten with garbage, `try_reload_tls` fails, the
/// previously loaded acceptor stays in place, and the error names `pgwire_tls_cert`.
#[test]
fn tls_reload_keeps_old_acceptor_on_garbage() {
    let dir = tempfile::tempdir().unwrap();
    let root = dir.path();
    let (cert_path, key_path) = (root.join("cert.pem"), root.join("key.pem"));

    let original = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
    std::fs::write(&cert_path, original.cert.pem()).unwrap();
    std::fs::write(&key_path, original.key_pair.serialize_pem()).unwrap();

    let state = build_reload_state(&cert_path, &key_path);
    let before = state.snapshot();

    // Corrupt only the certificate; the key file is left untouched.
    std::fs::write(&cert_path, b"this is not a certificate").unwrap();
    let err = super::try_reload_tls(&state).expect_err("reload must fail");

    let after = state.snapshot();
    let same_acceptor = Arc::ptr_eq(&before, &after);
    assert!(same_acceptor, "acceptor must be unchanged on reload failure",);

    let mentions_cert_setting = err.to_string().contains("pgwire_tls_cert");
    assert!(
        mentions_cert_setting,
        "error should mention pgwire_tls_cert, got: {err}",
    );
}
3115
/// Renders `err` followed by every error reachable through `source()`, each via its
/// `Display` form, joined with `" | "` (outermost first).
fn err_chain(err: &(dyn std::error::Error + 'static)) -> String {
    let mut rendered = vec![err.to_string()];
    let mut cursor = err.source();
    while let Some(cause) = cursor {
        rendered.push(cause.to_string());
        cursor = cause.source();
    }
    rendered.join(" | ")
}
3124
3125 async fn push_one_trade(
3129 db: &Arc<LaminarDB>,
3130 symbol: &str,
3131 price: f64,
3132 ) -> arrow_schema::SchemaRef {
3133 let handle = db.source_untyped("trades").expect("source handle");
3134 let schema = handle.schema().clone();
3135 let batch = arrow_array::RecordBatch::try_new(
3136 Arc::clone(&schema),
3137 vec![
3138 Arc::new(arrow_array::StringArray::from(vec![symbol])),
3139 Arc::new(arrow_array::Float64Array::from(vec![price])),
3140 ],
3141 )
3142 .expect("batch");
3143 handle.push_arrow(batch).expect("push");
3144 schema
3145 }
3146
3147 async fn spawn_with_data() -> (
3150 Arc<LaminarDB>,
3151 std::net::SocketAddr,
3152 tokio::task::JoinHandle<()>,
3153 ) {
3154 let db = Arc::new(LaminarDB::open().expect("db opens"));
3155 db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
3156 .await
3157 .expect("create source");
3158 db.execute(
3159 "CREATE MATERIALIZED VIEW prices AS \
3160 SELECT symbol, price FROM trades",
3161 )
3162 .await
3163 .expect("create mv");
3164 db.start().await.expect("db starts");
3165
3166 let (addr, handle) = super::serve(
3167 Arc::clone(&db),
3168 "127.0.0.1:0",
3169 HashMap::new(),
3170 false,
3171 None,
3172 256,
3173 10,
3174 )
3175 .await
3176 .expect("pgwire serve");
3177 (db, addr, handle)
3178 }
3179
3180 async fn spawn_with_retained_data() -> (
3184 Arc<LaminarDB>,
3185 std::net::SocketAddr,
3186 tokio::task::JoinHandle<()>,
3187 ) {
3188 let db = Arc::new(LaminarDB::open().expect("db opens"));
3189 db.execute("CREATE SOURCE trades (symbol VARCHAR, price DOUBLE)")
3190 .await
3191 .expect("create source");
3192 db.execute(
3193 "CREATE STREAM prices AS SELECT symbol, price FROM trades \
3194 WITH ('retain_history' = '4mb')",
3195 )
3196 .await
3197 .expect("create stream");
3198 db.start().await.expect("db starts");
3199
3200 let (addr, handle) = super::serve(
3201 Arc::clone(&db),
3202 "127.0.0.1:0",
3203 HashMap::new(),
3204 false,
3205 None,
3206 256,
3207 10,
3208 )
3209 .await
3210 .expect("pgwire serve");
3211 (db, addr, handle)
3212 }
3213
3214 #[tokio::test]
3218 async fn extended_query_describe_subscribe_returns_columns() {
3219 let (_db, addr, handle) = spawn_with_data().await;
3220 let client = connect(addr).await;
3221
3222 let stmt = client
3223 .prepare("SUBSCRIBE prices")
3224 .await
3225 .expect("prepare SUBSCRIBE prices");
3226
3227 let cols = stmt.columns();
3228 assert_eq!(cols.len(), 2, "expected 2 columns, got {}", cols.len());
3229 assert_eq!(cols[0].name(), "symbol");
3230 assert_eq!(cols[1].name(), "price");
3231 assert_eq!(cols[0].type_(), &tokio_postgres::types::Type::VARCHAR);
3232 assert_eq!(cols[1].type_(), &tokio_postgres::types::Type::FLOAT8);
3233
3234 handle.abort();
3235 }
3236
3237 #[tokio::test]
3240 async fn extended_query_prepare_unknown_stream_errors() {
3241 let (_db, addr, handle) = spawn_with_data().await;
3242 let client = connect(addr).await;
3243
3244 let err = client
3245 .prepare("SUBSCRIBE no_such_view")
3246 .await
3247 .expect_err("must fail at Parse");
3248 let db_err = err.as_db_error().expect("typed PG error");
3249 assert!(db_err.message().contains("no_such_view"));
3250
3251 handle.abort();
3252 }
3253
3254 #[tokio::test]
3258 async fn extended_query_binary_chunked_subscribe() {
3259 let (db, addr, handle) = spawn_with_data().await;
3260 let mut client = connect(addr).await;
3261
3262 let tx = client.transaction().await.expect("BEGIN");
3266 let stmt = tx.prepare("SUBSCRIBE prices").await.expect("prepare");
3267 let portal = tx.bind(&stmt, &[]).await.expect("bind portal");
3268
3269 let pusher = {
3276 let db = Arc::clone(&db);
3277 tokio::spawn(async move {
3278 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
3279 push_one_trade(&db, "AAPL", 150.5).await;
3280 push_one_trade(&db, "GOOG", 2700.25).await;
3281 })
3282 };
3283
3284 let first = tokio::time::timeout(
3285 std::time::Duration::from_secs(3),
3286 tx.query_portal(&portal, 1),
3287 )
3288 .await
3289 .expect("first chunk arrives within 3s")
3290 .expect("query_portal #1");
3291 assert_eq!(first.len(), 1);
3292 let symbol: &str = first[0].get(0);
3293 let price: f64 = first[0].get(1);
3294 assert_eq!(symbol, "AAPL");
3295 assert!((price - 150.5).abs() < 1e-9);
3296
3297 let second = tokio::time::timeout(
3298 std::time::Duration::from_secs(3),
3299 tx.query_portal(&portal, 1),
3300 )
3301 .await
3302 .expect("second chunk arrives within 3s")
3303 .expect("query_portal #2");
3304 assert_eq!(second.len(), 1);
3305 let symbol: &str = second[0].get(0);
3306 let price: f64 = second[0].get(1);
3307 assert_eq!(symbol, "GOOG");
3308 assert!((price - 2700.25).abs() < 1e-9);
3309
3310 pusher.await.expect("push task");
3311 handle.abort();
3312 }
3313
3314 #[tokio::test]
3320 async fn extended_query_binary_timestamp() {
3321 let db = Arc::new(LaminarDB::open().expect("db opens"));
3322 db.execute(
3326 "CREATE SOURCE events (ts TIMESTAMP, sym VARCHAR, \
3327 WATERMARK FOR ts AS ts - INTERVAL '0' SECOND)",
3328 )
3329 .await
3330 .expect("create source");
3331 db.execute("CREATE MATERIALIZED VIEW ev AS SELECT ts, sym FROM events")
3332 .await
3333 .expect("create mv");
3334 db.start().await.expect("db starts");
3335
3336 let (addr, handle) = super::serve(
3337 Arc::clone(&db),
3338 "127.0.0.1:0",
3339 HashMap::new(),
3340 false,
3341 None,
3342 256,
3343 10,
3344 )
3345 .await
3346 .expect("pgwire serve");
3347
3348 let mut client = connect(addr).await;
3349 let tx = client.transaction().await.expect("BEGIN");
3350 let stmt = tx.prepare("SUBSCRIBE ev").await.expect("prepare");
3351 let portal = tx.bind(&stmt, &[]).await.expect("bind");
3352
3353 let expected = chrono::NaiveDate::from_ymd_opt(2026, 5, 9)
3354 .unwrap()
3355 .and_hms_opt(0, 0, 0)
3356 .unwrap();
3357 let ts_us = expected.and_utc().timestamp_micros();
3358
3359 let pusher = {
3364 let db = Arc::clone(&db);
3365 tokio::spawn(async move {
3366 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
3367 let src = db.source_untyped("events").expect("source");
3368 let batch = arrow_array::RecordBatch::try_new(
3369 src.schema().clone(),
3370 vec![
3371 Arc::new(arrow_array::TimestampMicrosecondArray::from(vec![ts_us])),
3372 Arc::new(arrow_array::StringArray::from(vec!["AAPL"])),
3373 ],
3374 )
3375 .expect("batch");
3376 src.push_arrow(batch).expect("push");
3377 })
3378 };
3379
3380 let rows = tokio::time::timeout(
3381 std::time::Duration::from_secs(3),
3382 tx.query_portal(&portal, 1),
3383 )
3384 .await
3385 .expect("row arrives within 3s")
3386 .expect("query_portal");
3387 assert_eq!(rows.len(), 1);
3388
3389 let ts: chrono::NaiveDateTime = rows[0].get(0);
3390 let sym: &str = rows[0].get(1);
3391 assert_eq!(ts, expected);
3392 assert_eq!(sym, "AAPL");
3393
3394 pusher.await.expect("push task");
3395 handle.abort();
3396 }
3397
3398 #[tokio::test]
3402 async fn extended_query_ddl_rejected() {
3403 let (_db, addr, handle) = spawn_with_data().await;
3404 let client = connect(addr).await;
3405
3406 let err = client
3407 .prepare("CREATE SOURCE more_trades (sym VARCHAR)")
3408 .await
3409 .expect_err("DDL must be rejected at Parse");
3410 let db_err = err.as_db_error().expect("typed PG error");
3411 assert!(
3412 db_err.message().contains("/api/v1/sql"),
3413 "message: {}",
3414 db_err.message()
3415 );
3416
3417 handle.abort();
3418 }
3419
3420 #[tokio::test]
3424 async fn cursor_declare_fetch_close_happy_path() {
3425 let (db, addr, handle) = spawn_with_retained_data().await;
3426 let client = connect(addr).await;
3427
3428 for i in 0..4 {
3429 push_one_trade(&db, &format!("S{i}"), i as f64).await;
3430 }
3431
3432 client.simple_query("BEGIN").await.expect("BEGIN");
3433 client
3434 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3435 .await
3436 .expect("DECLARE");
3437
3438 let messages = client
3439 .simple_query("FETCH 2 FROM c")
3440 .await
3441 .expect("FETCH 2");
3442 let row_count = messages
3443 .iter()
3444 .filter(|m| matches!(m, SimpleQueryMessage::Row(_)))
3445 .count();
3446 assert_eq!(row_count, 2, "expected exactly 2 rows from FETCH 2");
3447
3448 client.simple_query("CLOSE c").await.expect("CLOSE");
3449 client.simple_query("COMMIT").await.expect("COMMIT");
3450
3451 handle.abort();
3452 }
3453
3454 #[tokio::test]
3457 async fn cursor_commit_closes_cursors() {
3458 let (_db, addr, handle) = spawn_with_data().await;
3459 let client = connect(addr).await;
3460
3461 client.simple_query("BEGIN").await.expect("BEGIN");
3462 client
3463 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3464 .await
3465 .expect("DECLARE");
3466 client.simple_query("COMMIT").await.expect("COMMIT");
3467
3468 let err = client
3469 .simple_query("FETCH 1 FROM c")
3470 .await
3471 .expect_err("FETCH after COMMIT must fail");
3472 let db_err = err.as_db_error().expect("typed PG error");
3473 assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3474
3475 handle.abort();
3476 }
3477
3478 #[tokio::test]
3480 async fn cursor_rollback_closes_cursors() {
3481 let (_db, addr, handle) = spawn_with_data().await;
3482 let client = connect(addr).await;
3483
3484 client.simple_query("BEGIN").await.expect("BEGIN");
3485 client
3486 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3487 .await
3488 .expect("DECLARE");
3489 client.simple_query("ROLLBACK").await.expect("ROLLBACK");
3490
3491 let err = client
3492 .simple_query("FETCH 1 FROM c")
3493 .await
3494 .expect_err("FETCH after ROLLBACK must fail");
3495 let db_err = err.as_db_error().expect("typed PG error");
3496 assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3497
3498 handle.abort();
3499 }
3500
3501 #[tokio::test]
3504 async fn cursor_close_explicit() {
3505 let (_db, addr, handle) = spawn_with_data().await;
3506 let client = connect(addr).await;
3507
3508 client
3509 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3510 .await
3511 .expect("DECLARE");
3512 client.simple_query("CLOSE c").await.expect("CLOSE");
3513
3514 let err = client
3515 .simple_query("FETCH 1 FROM c")
3516 .await
3517 .expect_err("FETCH after CLOSE must fail");
3518 let db_err = err.as_db_error().expect("typed PG error");
3519 assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3520
3521 handle.abort();
3522 }
3523
3524 #[tokio::test]
3526 async fn cursor_unsupported_modifiers_rejected() {
3527 let (_db, addr, handle) = spawn_with_data().await;
3528 let client = connect(addr).await;
3529
3530 for sql in [
3531 "DECLARE c SCROLL CURSOR FOR SUBSCRIBE prices",
3532 "DECLARE c BINARY CURSOR FOR SUBSCRIBE prices",
3533 "DECLARE c CURSOR WITH HOLD FOR SUBSCRIBE prices",
3534 "DECLARE c INSENSITIVE CURSOR FOR SUBSCRIBE prices",
3535 ] {
3536 let err = client
3537 .simple_query(sql)
3538 .await
3539 .expect_err(&format!("{sql} must fail"));
3540 let db_err = err.as_db_error().expect("typed PG error");
3541 assert_eq!(
3542 db_err.code().code(),
3543 "42601",
3544 "{sql}: expected parse error, got {db_err:?}"
3545 );
3546 }
3547
3548 handle.abort();
3549 }
3550
3551 #[tokio::test]
3554 async fn cursor_backward_directions_rejected() {
3555 let (_db, addr, handle) = spawn_with_data().await;
3556 let client = connect(addr).await;
3557
3558 client
3559 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3560 .await
3561 .expect("DECLARE");
3562
3563 for sql in [
3564 "FETCH PRIOR FROM c",
3565 "FETCH BACKWARD 1 FROM c",
3566 "FETCH FIRST FROM c",
3567 "FETCH LAST FROM c",
3568 "FETCH ABSOLUTE 1 FROM c",
3569 "FETCH RELATIVE 1 FROM c",
3570 ] {
3571 let err = client
3572 .simple_query(sql)
3573 .await
3574 .expect_err(&format!("{sql} must fail"));
3575 let db_err = err.as_db_error().expect("typed PG error");
3576 assert_eq!(db_err.code().code(), "0A000", "{sql}: got {db_err:?}");
3577 }
3578
3579 client.simple_query("CLOSE c").await.expect("CLOSE");
3580 handle.abort();
3581 }
3582
3583 #[tokio::test]
3586 async fn cursor_for_non_subscribe_rejected() {
3587 let (_db, addr, handle) = spawn_with_data().await;
3588 let client = connect(addr).await;
3589
3590 let err = client
3591 .simple_query("DECLARE c CURSOR FOR SELECT 1")
3592 .await
3593 .expect_err("DECLARE FOR SELECT must fail");
3594 let db_err = err.as_db_error().expect("typed PG error");
3595 assert_eq!(db_err.code().code(), "0A000", "got {db_err:?}");
3596
3597 handle.abort();
3598 }
3599
3600 #[tokio::test]
3602 async fn cursor_fetch_unknown_name_errors() {
3603 let (_db, addr, handle) = spawn_with_data().await;
3604 let client = connect(addr).await;
3605
3606 let err = client
3607 .simple_query("FETCH 1 FROM nope")
3608 .await
3609 .expect_err("must fail");
3610 let db_err = err.as_db_error().expect("typed PG error");
3611 assert_eq!(db_err.code().code(), "34000", "got {db_err:?}");
3612
3613 handle.abort();
3614 }
3615
3616 #[tokio::test]
3621 async fn cursor_fetch_preserves_leftover_rows_in_one_batch() {
3622 let (db, addr, handle) = spawn_with_retained_data().await;
3623 let client = connect(addr).await;
3624
3625 let src = db.source_untyped("trades").expect("source");
3626 let batch = arrow_array::RecordBatch::try_new(
3627 src.schema().clone(),
3628 vec![
3629 Arc::new(arrow_array::StringArray::from(vec!["AAPL", "GOOG"])),
3630 Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])),
3631 ],
3632 )
3633 .expect("batch");
3634 src.push_arrow(batch).expect("push");
3635
3636 client
3637 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3638 .await
3639 .expect("DECLARE");
3640
3641 let first = client
3642 .simple_query("FETCH 1 FROM c")
3643 .await
3644 .expect("FETCH 1");
3645 let r1: Vec<&str> = first
3646 .iter()
3647 .filter_map(|m| match m {
3648 SimpleQueryMessage::Row(r) => r.get(0),
3649 _ => None,
3650 })
3651 .collect();
3652 assert_eq!(r1, vec!["AAPL"]);
3653
3654 let second = client
3655 .simple_query("FETCH 1 FROM c")
3656 .await
3657 .expect("FETCH 1");
3658 let r2: Vec<&str> = second
3659 .iter()
3660 .filter_map(|m| match m {
3661 SimpleQueryMessage::Row(r) => r.get(0),
3662 _ => None,
3663 })
3664 .collect();
3665 assert_eq!(r2, vec!["GOOG"]);
3666
3667 client.simple_query("CLOSE c").await.expect("CLOSE");
3668 handle.abort();
3669 }
3670
3671 #[tokio::test]
3673 async fn cursor_duplicate_declare_rejected() {
3674 let (_db, addr, handle) = spawn_with_data().await;
3675 let client = connect(addr).await;
3676
3677 client
3678 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3679 .await
3680 .expect("first DECLARE");
3681
3682 let err = client
3683 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3684 .await
3685 .expect_err("duplicate DECLARE must fail");
3686 let db_err = err.as_db_error().expect("typed PG error");
3687 assert_eq!(db_err.code().code(), "42P03", "got {db_err:?}");
3688
3689 client.simple_query("CLOSE c").await.expect("CLOSE");
3691 client
3692 .simple_query("DECLARE c CURSOR FOR SUBSCRIBE prices")
3693 .await
3694 .expect("re-DECLARE after CLOSE");
3695 client.simple_query("CLOSE c").await.expect("CLOSE again");
3696
3697 handle.abort();
3698 }
3699
3700 #[tokio::test]
3702 async fn cursor_name_case_insensitive() {
3703 let (db, addr, handle) = spawn_with_retained_data().await;
3704 let client = connect(addr).await;
3705 push_one_trade(&db, "AAPL", 1.0).await;
3706
3707 client
3708 .simple_query("DECLARE MyCursor CURSOR FOR SUBSCRIBE prices")
3709 .await
3710 .expect("DECLARE");
3711
3712 let messages = client
3713 .simple_query("FETCH 1 FROM mycursor")
3714 .await
3715 .expect("FETCH from lowercased name");
3716 let row_count = messages
3717 .iter()
3718 .filter(|m| matches!(m, SimpleQueryMessage::Row(_)))
3719 .count();
3720 assert_eq!(row_count, 1);
3721
3722 client.simple_query("CLOSE MYCURSOR").await.expect("CLOSE");
3723 handle.abort();
3724 }
3725}