Skip to main content

laminardb/
config.rs

1//! TOML configuration parsing for LaminarDB server.
2//!
3//! Supports `${VAR}` and `${VAR:-default}` environment variable substitution.
4
5use std::collections::HashSet;
6use std::path::Path;
7use std::sync::LazyLock;
8use std::time::Duration;
9
10use laminar_core::state::StateBackendConfig;
11use regex::Regex;
12use serde::Deserialize;
13
/// Regex for `${VAR}` and `${VAR:-default}` patterns.
///
/// Capture group 1 is the variable name: a letter or underscore followed by
/// letters, digits, or underscores. Capture group 2 participates only when
/// the `:-` form is used and holds the default text, which runs up to the
/// first `}` and therefore cannot itself contain a closing brace.
static ENV_VAR_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::-([^}]*))?\}").expect("valid regex")
});

/// Minimum length, in characters, of a plaintext pgwire password.
///
/// NIST baseline; MD5 has no work factor, so length is the only knob.
const MIN_PGWIRE_PASSWORD_LEN: usize = 12;
21
22/// Load, parse, and validate a LaminarDB configuration file.
23pub fn load_config(path: &Path) -> Result<ServerConfig, ConfigError> {
24    let raw = std::fs::read_to_string(path).map_err(|e| ConfigError::FileRead {
25        path: path.to_path_buf(),
26        source: e,
27    })?;
28
29    let substituted = substitute_env_vars(&raw)?;
30    let config: ServerConfig =
31        toml::from_str(&substituted).map_err(|e| ConfigError::ParseError {
32            path: path.to_path_buf(),
33            source: e,
34        })?;
35
36    validate_config(&config)?;
37    Ok(config)
38}
39
40/// Substitute `${VAR}` and `${VAR:-default}` patterns with environment values.
41fn substitute_env_vars(input: &str) -> Result<String, ConfigError> {
42    let mut errors = Vec::new();
43    let result = ENV_VAR_RE.replace_all(input, |caps: &regex::Captures| {
44        let var_name = &caps[1];
45        match std::env::var(var_name) {
46            Ok(val) => val,
47            Err(_) => {
48                if let Some(default) = caps.get(2) {
49                    default.as_str().to_string()
50                } else {
51                    errors.push(var_name.to_string());
52                    String::new()
53                }
54            }
55        }
56    });
57
58    if !errors.is_empty() {
59        return Err(ConfigError::MissingEnvVars { vars: errors });
60    }
61
62    Ok(result.into_owned())
63}
64
65fn validate_config(config: &ServerConfig) -> Result<(), ConfigError> {
66    let mut errors = Vec::new();
67
68    // Collect all pipeline names
69    let pipeline_names: HashSet<&str> = config.pipelines.iter().map(|p| p.name.as_str()).collect();
70
71    // Validate: sink must reference an existing pipeline
72    for sink in &config.sinks {
73        if !pipeline_names.contains(sink.pipeline.as_str()) {
74            errors.push(format!(
75                "sink '{}' references unknown pipeline '{}'",
76                sink.name, sink.pipeline
77            ));
78        }
79    }
80
81    // Validate: no duplicate names within a section
82    let mut seen_sources = HashSet::new();
83    for source in &config.sources {
84        if !seen_sources.insert(&source.name) {
85            errors.push(format!("duplicate source name: '{}'", source.name));
86        }
87    }
88
89    let mut seen_pipelines = HashSet::new();
90    for pipeline in &config.pipelines {
91        if !seen_pipelines.insert(&pipeline.name) {
92            errors.push(format!("duplicate pipeline name: '{}'", pipeline.name));
93        }
94    }
95
96    let mut seen_sinks = HashSet::new();
97    for sink in &config.sinks {
98        if !seen_sinks.insert(&sink.name) {
99            errors.push(format!("duplicate sink name: '{}'", sink.name));
100        }
101    }
102
103    let mut seen_lookups = HashSet::new();
104    for lookup in &config.lookups {
105        if !seen_lookups.insert(&lookup.name) {
106            errors.push(format!("duplicate lookup name: '{}'", lookup.name));
107        }
108    }
109
110    // Validate: bind address is parseable
111    if config.server.bind.parse::<std::net::SocketAddr>().is_err() {
112        errors.push(format!(
113            "invalid server bind address: '{}'",
114            config.server.bind
115        ));
116    }
117    if let Some(addr) = &config.server.pgwire_bind {
118        if addr.parse::<std::net::SocketAddr>().is_err() {
119            errors.push(format!("invalid server pgwire_bind address: '{}'", addr));
120        }
121    }
122    for (user, password) in &config.server.pgwire_users {
123        if user.is_empty() {
124            errors.push("pgwire_users contains an empty username".to_string());
125        }
126        let pw = password.expose();
127        if let Some(rest) = pw.strip_prefix("md5") {
128            // pg_authid-style pre-hash: 'md5' + lowercase-hex(md5(password ‖ user)).
129            // Strict shape so a typo isn't silently treated as plaintext.
130            let valid =
131                rest.len() == 32 && rest.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f'));
132            if !valid {
133                errors.push(format!(
134                    "pgwire_users['{user}']: pre-hashed value must be 'md5' \
135                     followed by 32 lowercase hex characters"
136                ));
137            }
138        } else if password.len() < MIN_PGWIRE_PASSWORD_LEN {
139            errors.push(format!(
140                "pgwire_users['{user}']: password must be at least {MIN_PGWIRE_PASSWORD_LEN} characters"
141            ));
142        }
143    }
144    if config.server.pgwire_max_connections == 0 {
145        errors.push(
146            "pgwire_max_connections must be > 0; remove pgwire_bind to disable the listener"
147                .to_string(),
148        );
149    }
150    match (
151        &config.server.pgwire_tls_cert,
152        &config.server.pgwire_tls_key,
153    ) {
154        (Some(_), None) | (None, Some(_)) => {
155            errors.push("pgwire_tls_cert and pgwire_tls_key must be set together".to_string());
156        }
157        (Some(cert), Some(key)) => {
158            if !cert.exists() {
159                errors.push(format!("pgwire_tls_cert not found: {}", cert.display()));
160            }
161            if !key.exists() {
162                errors.push(format!("pgwire_tls_key not found: {}", key.display()));
163            }
164        }
165        (None, None) => {}
166    }
167    match config.server.pgwire_tls_min_version.as_str() {
168        "1.2" | "1.3" => {}
169        other => errors.push(format!(
170            "pgwire_tls_min_version must be \"1.2\" or \"1.3\" (got \"{other}\")"
171        )),
172    }
173    if let Some(ca) = &config.server.pgwire_tls_client_ca {
174        if config.server.pgwire_tls_cert.is_none() {
175            errors.push(
176                "pgwire_tls_client_ca requires pgwire_tls_cert + pgwire_tls_key (mTLS \
177                 layers on top of server TLS)"
178                    .to_string(),
179            );
180        }
181        if !ca.exists() {
182            errors.push(format!("pgwire_tls_client_ca not found: {}", ca.display()));
183        }
184    }
185
186    // Validate: cluster mode requires discovery and coordination
187    if config.server.mode == "cluster" {
188        if config.discovery.is_none() {
189            errors.push("mode = \"cluster\" requires a [discovery] section".to_string());
190        }
191        if config.coordination.is_none() {
192            errors.push("mode = \"cluster\" requires a [coordination] section".to_string());
193        }
194        if config.node_id.is_none() {
195            errors.push("mode = \"cluster\" requires node_id to be set".to_string());
196        }
197        // Distributed 2PC has a per-barrier cost of ~1-3s (manifest
198        // persist + durability gate + sink commit). Cadences tighter
199        // than 2s spend more than half their time on coordination.
200        if config.checkpoint.interval < Duration::from_secs(2) {
201            errors.push(format!(
202                "mode = \"cluster\": checkpoint.interval = {:?} is too tight; minimum is 2s",
203                config.checkpoint.interval,
204            ));
205        }
206    }
207
208    validate_ai(config, &mut errors);
209
210    if !errors.is_empty() {
211        return Err(ConfigError::ValidationErrors { errors });
212    }
213
214    Ok(())
215}
216
/// Top-level server configuration deserialized from `laminardb.toml`.
///
/// The list-valued sections are written in TOML under their singular names
/// (`[[source]]`, `[[lookup]]`, `[[pipeline]]`, `[[sink]]`) and renamed to
/// plural fields here.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ServerConfig {
    /// `[server]` section; defaults apply when the table is omitted.
    #[serde(default)]
    pub server: ServerSection,
    /// `[state]` section, deserialized as `laminar_core`'s `StateBackendConfig`.
    #[serde(default)]
    pub state: StateBackendConfig,
    /// `[checkpoint]` section; defaults apply when the table is omitted.
    #[serde(default)]
    pub checkpoint: CheckpointSection,
    /// `[[source]]` entries.
    #[serde(default, rename = "source")]
    pub sources: Vec<SourceConfig>,
    /// `[[lookup]]` entries.
    #[serde(default, rename = "lookup")]
    pub lookups: Vec<LookupConfig>,
    /// `[[pipeline]]` entries.
    #[serde(default, rename = "pipeline")]
    pub pipelines: Vec<PipelineConfig>,
    /// `[[sink]]` entries; each must reference a pipeline by name.
    #[serde(default, rename = "sink")]
    pub sinks: Vec<SinkConfig>,
    /// Raw SQL DDL executed before `start()`, as an alternative to structured sections.
    #[serde(default)]
    pub sql: Option<String>,
    /// `[discovery]` section; validation requires it when `server.mode` is `"cluster"`.
    pub discovery: Option<DiscoverySection>,
    /// `[coordination]` section; validation requires it when `server.mode` is `"cluster"`.
    pub coordination: Option<CoordinationSection>,
    /// Top-level `node_id` key; validation requires it when `server.mode` is `"cluster"`.
    pub node_id: Option<String>,
    /// `[ai]` — AI provider wiring and per-task default models.
    #[serde(default)]
    pub ai: AiSection,
    /// `[models.<name>]` — the AI model registry (top-level, per the contract).
    #[serde(default)]
    pub models: std::collections::HashMap<String, ModelConfig>,
}
247
/// `[server]` section.
///
/// Every field has a serde default, so an empty `[server]` table is valid.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ServerSection {
    /// Deployment mode; defaults to `"embedded"`. Validation applies extra
    /// requirements when this is exactly `"cluster"`.
    #[serde(default = "default_mode")]
    pub mode: String,
    /// Server bind address; defaults to `"127.0.0.1:8080"`. Validation
    /// requires it to parse as a `std::net::SocketAddr`.
    #[serde(default = "default_bind")]
    pub bind: String,
    /// Postgres wire bind address; `None` disables it.
    #[serde(default)]
    pub pgwire_bind: Option<String>,
    /// MD5 auth users. Empty → trust auth (loopback only).
    #[serde(default)]
    pub pgwire_users: std::collections::HashMap<String, Secret>,
    /// Required true to bind pgwire on a non-loopback address.
    #[serde(default)]
    pub pgwire_allow_remote: bool,
    /// PEM cert; pair with `pgwire_tls_key` to enable TLS.
    #[serde(default)]
    pub pgwire_tls_cert: Option<std::path::PathBuf>,
    /// PEM private key (PKCS#8 or RSA).
    #[serde(default)]
    pub pgwire_tls_key: Option<std::path::PathBuf>,
    /// PEM CA bundle. Setting this requires every connecting client to
    /// present a certificate chained to one of these roots (mTLS).
    #[serde(default)]
    pub pgwire_tls_client_ca: Option<std::path::PathBuf>,
    /// Concurrent session cap; excess accepts close immediately.
    /// Defaults to 256; validation rejects 0.
    #[serde(default = "default_pgwire_max_connections")]
    pub pgwire_max_connections: usize,
    /// Per-IP auth-failure cap in a 60s rolling window. 0 disables.
    /// Defaults to 10.
    #[serde(default = "default_pgwire_max_auth_failures_per_min")]
    pub pgwire_max_auth_failures_per_min: u32,
    /// Minimum TLS protocol version: `"1.2"` (default) or `"1.3"`. Pinning
    /// to `"1.3"` is the PCI-DSS / FedRAMP-High posture; rustls already
    /// disables TLS 1.0/1.1 unconditionally.
    #[serde(default = "default_pgwire_tls_min_version")]
    pub pgwire_tls_min_version: String,
}
286
/// Serde default for `server.pgwire_max_connections`.
fn default_pgwire_max_connections() -> usize {
    256_usize
}

/// Serde default for `server.pgwire_max_auth_failures_per_min`.
fn default_pgwire_max_auth_failures_per_min() -> u32 {
    10_u32
}

/// Serde default for `server.pgwire_tls_min_version`.
fn default_pgwire_tls_min_version() -> String {
    String::from("1.2")
}
298
299impl Default for ServerSection {
300    fn default() -> Self {
301        Self {
302            mode: default_mode(),
303            bind: default_bind(),
304            pgwire_bind: None,
305            pgwire_users: std::collections::HashMap::new(),
306            pgwire_allow_remote: false,
307            pgwire_tls_cert: None,
308            pgwire_tls_key: None,
309            pgwire_tls_client_ca: None,
310            pgwire_max_connections: default_pgwire_max_connections(),
311            pgwire_max_auth_failures_per_min: default_pgwire_max_auth_failures_per_min(),
312            pgwire_tls_min_version: default_pgwire_tls_min_version(),
313        }
314    }
315}
316
/// String that redacts itself in `Debug` output.
///
/// Deserializes transparently from a plain string value, so in the config
/// file a secret is written like any other string.
#[derive(Clone, PartialEq, Eq, Deserialize)]
#[serde(transparent)]
pub struct Secret(String);
321
322impl Secret {
323    #[cfg(test)]
324    pub fn new(value: impl Into<String>) -> Self {
325        Self(value.into())
326    }
327
328    pub fn expose(&self) -> &str {
329        &self.0
330    }
331
332    pub fn len(&self) -> usize {
333        self.0.chars().count()
334    }
335}
336
337impl std::fmt::Debug for Secret {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        f.write_str("[REDACTED]")
340    }
341}
342
/// `[checkpoint]` section.
///
/// Every field has a serde default, so the table may be omitted entirely.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct CheckpointSection {
    /// Storage URL: file:///path, s3://bucket/prefix, gs://bucket/prefix.
    /// Defaults to a `file://` URL for a `laminardb` directory under the
    /// system temp dir.
    #[serde(default = "default_checkpoint_url")]
    pub url: String,
    /// Checkpoint interval, written as a humantime string such as `"10s"`.
    /// Defaults to 10 seconds; cluster mode rejects anything under 2 seconds.
    #[serde(default = "default_checkpoint_interval", with = "humantime_serde")]
    pub interval: Duration,
    /// Number of recent checkpoints to retain before pruning.
    /// Defaults to 10.
    #[serde(default = "default_max_retained")]
    pub max_retained: usize,
    /// Cloud storage credentials/config (e.g., `aws_access_key_id`).
    #[serde(default)]
    pub storage: std::collections::HashMap<String, String>,
}
358
359impl Default for CheckpointSection {
360    fn default() -> Self {
361        Self {
362            url: default_checkpoint_url(),
363            interval: default_checkpoint_interval(),
364            max_retained: default_max_retained(),
365            storage: std::collections::HashMap::new(),
366        }
367    }
368}
369
/// Serde default for a provider's `max_concurrency`.
fn default_ai_max_concurrency() -> usize {
    8_usize
}
373
/// `[ai]` — provider wiring and per-task defaults. Models live in the top-level
/// `[models.*]` tables, per the configuration contract.
///
/// Both maps default to empty, so the whole `[ai]` table is optional.
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
pub struct AiSection {
    /// `[ai.providers.<name>]` — transport endpoints.
    #[serde(default)]
    pub providers: std::collections::HashMap<String, ProviderConfig>,
    /// `[ai.defaults]` — task name → default model name (e.g. `classify = "finbert"`).
    /// Validation requires each value to be a key in `[models]`.
    #[serde(default)]
    pub defaults: std::collections::HashMap<String, String>,
}
385
386/// `[ai.providers.<name>]`.
387#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
388pub struct ProviderConfig {
389    /// Transport kind: `anthropic`, `openai`, or `local`. Inferred from the
390    /// provider name when omitted (for the canonical names).
391    #[serde(default)]
392    pub kind: Option<String>,
393    /// Name of the environment variable holding the API key (remote providers).
394    /// The key itself is never stored in config.
395    #[serde(default)]
396    pub api_key_env: Option<String>,
397    /// Base URL (remote). Defaults per kind when omitted.
398    #[serde(default)]
399    pub base_url: Option<String>,
400    /// Maximum concurrent requests issued per batch (remote).
401    #[serde(default = "default_ai_max_concurrency")]
402    pub max_concurrency: usize,
403    /// Steady requests-per-second cap (remote). When set, calls are paced by a
404    /// token bucket — bursts are shaped, not sent unbounded. Unset = no limit.
405    #[serde(default)]
406    pub requests_per_second: Option<u32>,
407    /// Model cache directory or `object_store` URI (local provider).
408    #[serde(default)]
409    pub cache_dir: Option<String>,
410}
411
/// `[models.<name>]`.
///
/// `kind` and `task` are required keys. Which of the optional fields must be
/// present depends on `kind` and is enforced during validation: a remote
/// model needs `provider` and `model`, a local model needs `source`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ModelConfig {
    /// `local` or `remote`.
    pub kind: String,
    /// One task (`task = "classify"`) or several (`task = ["classify", "extract"]`).
    pub task: TaskSpec,
    /// Remote: the provider name (a key in `[ai.providers]`).
    #[serde(default)]
    pub provider: Option<String>,
    /// Remote: the provider-specific model id.
    #[serde(default)]
    pub model: Option<String>,
    /// Local: the weight source (`hf:org/repo` or a file/`object_store` URI).
    #[serde(default)]
    pub source: Option<String>,
}
429
/// A model's task list, written as a single string or an array.
///
/// Deserialized untagged: a TOML string becomes [`TaskSpec::One`], a TOML
/// array of strings becomes [`TaskSpec::Many`].
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(untagged)]
pub enum TaskSpec {
    /// A single task.
    One(String),
    /// Several tasks.
    Many(Vec<String>),
}
439
440impl TaskSpec {
441    /// The task names as a list.
442    #[must_use]
443    pub fn tasks(&self) -> Vec<String> {
444        match self {
445            TaskSpec::One(t) => vec![t.clone()],
446            TaskSpec::Many(ts) => ts.clone(),
447        }
448    }
449}
450
451/// Structural validation of the `[ai]` / `[models]` config — references resolve
452/// and required fields are present. Semantic checks (task names, label seam)
453/// happen when the registry is built.
454fn validate_ai(config: &ServerConfig, errors: &mut Vec<String>) {
455    for (name, model) in &config.models {
456        match model.kind.as_str() {
457            "remote" => {
458                match &model.provider {
459                    Some(p) if config.ai.providers.contains_key(p) => {}
460                    Some(p) => errors.push(format!("model '{name}': unknown provider '{p}'")),
461                    None => {
462                        errors.push(format!(
463                            "model '{name}': remote model requires a 'provider'"
464                        ));
465                    }
466                }
467                if model.model.is_none() {
468                    errors.push(format!(
469                        "model '{name}': remote model requires a 'model' id"
470                    ));
471                }
472            }
473            "local" => {
474                if model.source.is_none() {
475                    errors.push(format!("model '{name}': local model requires a 'source'"));
476                }
477            }
478            other => errors.push(format!(
479                "model '{name}': kind must be 'local' or 'remote', got '{other}'"
480            )),
481        }
482        if model.task.tasks().is_empty() {
483            errors.push(format!("model '{name}': at least one task is required"));
484        }
485    }
486
487    for (name, provider) in &config.ai.providers {
488        // Mirror runtime kind resolution exactly (explicit `kind`, else the
489        // provider name) so validation can't disagree with how the provider is
490        // actually built — a `cache_dir` on a remote provider must not excuse a
491        // missing key.
492        let kind = provider.kind.as_deref().unwrap_or(name.as_str());
493        if kind == "local" {
494            // A local provider must carry a cache_dir, or no LocalProvider can be
495            // built and local models would fail at runtime — reject it now.
496            if provider.cache_dir.is_none() {
497                errors.push(format!(
498                    "provider '{name}': local provider requires a 'cache_dir'"
499                ));
500            }
501        } else if provider.api_key_env.is_none() {
502            errors.push(format!(
503                "provider '{name}': remote provider requires 'api_key_env'"
504            ));
505        }
506    }
507
508    for (task, model_name) in &config.ai.defaults {
509        if !config.models.contains_key(model_name) {
510            errors.push(format!(
511                "ai.defaults.{task} references unknown model '{model_name}'"
512            ));
513        }
514    }
515}
516
/// `[[source]]` section.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct SourceConfig {
    /// Source name; validation rejects duplicates among sources.
    pub name: String,
    /// Connector type: "kafka", "postgres_cdc", "mysql_cdc", "generator".
    pub connector: String,
    /// Record format; defaults to `"json"`.
    #[serde(default = "default_format")]
    pub format: String,
    /// Free-form `[source.properties]` table; defaults to empty.
    #[serde(default)]
    pub properties: toml::Table,
    /// `[[source.schema]]` column list; defaults to empty.
    #[serde(default)]
    pub schema: Vec<ColumnDef>,
    /// Optional `[source.watermark]` table.
    pub watermark: Option<WatermarkConfig>,
}
531
/// Column definition within a source or lookup schema.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Column type, written in TOML under the key `type` (e.g. `type = "VARCHAR"`).
    #[serde(rename = "type")]
    pub data_type: String,
    /// Whether the column may be null; defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub nullable: bool,
}
541
/// Watermark configuration for a source.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct WatermarkConfig {
    /// Name of the column the watermark is configured on.
    pub column: String,
    /// Written as a humantime string such as `"5s"`; defaults to 5 seconds.
    #[serde(default = "default_max_ooo", with = "humantime_serde")]
    pub max_out_of_orderness: Duration,
}
549
/// `[[lookup]]` section: lookup table for enrichment joins.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct LookupConfig {
    /// Lookup name; validation rejects duplicates among lookups.
    pub name: String,
    /// Connector type: "postgres", "mysql", "redis", "csv".
    pub connector: String,
    /// Lookup strategy; defaults to `"poll"`.
    #[serde(default = "default_lookup_strategy")]
    pub strategy: String,
    /// Cache settings; defaults apply when the table is omitted.
    #[serde(default)]
    pub cache: LookupCacheConfig,
    /// Free-form properties table; defaults to empty.
    #[serde(default)]
    pub properties: toml::Table,
    /// Primary-key column names; defaults to empty.
    #[serde(default)]
    pub primary_key: Vec<String>,
    /// Column list; defaults to empty.
    #[serde(default)]
    pub schema: Vec<ColumnDef>,
}
567
/// Cache configuration for lookup tables.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct LookupCacheConfig {
    /// Cache size in bytes; defaults to 100 * 1024 * 1024.
    #[serde(default = "default_cache_size")]
    pub size_bytes: u64,
    /// Cache TTL, written as a humantime string; defaults to 300 seconds.
    #[serde(default = "default_cache_ttl", with = "humantime_serde")]
    pub ttl: Duration,
}
576
577impl Default for LookupCacheConfig {
578    fn default() -> Self {
579        Self {
580            size_bytes: default_cache_size(),
581            ttl: default_cache_ttl(),
582        }
583    }
584}
585
/// `[[pipeline]]` section.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct PipelineConfig {
    /// Pipeline name; sinks refer to it, and validation rejects duplicates.
    pub name: String,
    /// SQL text of the pipeline.
    pub sql: String,
}
592
/// `[[sink]]` section.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct SinkConfig {
    /// Sink name; validation rejects duplicates among sinks.
    pub name: String,
    /// Name of the `[[pipeline]]` this sink reads from; validation requires
    /// it to match a defined pipeline.
    pub pipeline: String,
    /// Connector type: "kafka", "postgres", "delta-lake", "iceberg", "stdout".
    pub connector: String,
    /// Delivery setting; defaults to `"at_least_once"`.
    #[serde(default = "default_delivery")]
    pub delivery: String,
    /// Free-form `[sink.properties]` table; defaults to empty.
    #[serde(default)]
    pub properties: toml::Table,
}
605
/// `[discovery]` section: delta node discovery.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct DiscoverySection {
    /// Discovery strategy name (required; no default).
    pub strategy: String,
    /// Seed entries; defaults to empty.
    #[serde(default)]
    pub seeds: Vec<String>,
    /// Gossip port; defaults to 7946.
    #[serde(default = "default_gossip_port")]
    pub gossip_port: u16,
}
615
/// `[coordination]` section: delta coordination.
///
/// Every field has a serde default, so an empty `[coordination]` table is valid.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct CoordinationSection {
    /// Coordination strategy; defaults to `"raft"`.
    #[serde(default = "default_coordination_strategy")]
    pub strategy: String,
    /// Raft port; defaults to 7947.
    #[serde(default = "default_raft_port")]
    pub raft_port: u16,
    /// Written as a humantime string such as `"1500ms"`; defaults to 1500 ms.
    #[serde(default = "default_election_timeout", with = "humantime_serde")]
    pub election_timeout: Duration,
    /// Written as a humantime string such as `"300ms"`; defaults to 300 ms.
    #[serde(default = "default_heartbeat_interval", with = "humantime_serde")]
    pub heartbeat_interval: Duration,
}
628
/// Errors produced while loading, expanding, parsing, or validating the
/// configuration file.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The config file at `path` could not be read.
    #[error("failed to read config file '{}': {source}", path.display())]
    FileRead {
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// The (environment-expanded) file contents failed TOML deserialization.
    #[error("failed to parse config file '{}': {source}", path.display())]
    ParseError {
        path: std::path::PathBuf,
        source: toml::de::Error,
    },
    /// One or more `${VAR}` references had no value and no default.
    #[error("missing environment variables: {}", vars.join(", "))]
    MissingEnvVars { vars: Vec<String> },
    /// Validation found one or more problems; each entry is one message.
    #[error("configuration validation errors:\n  - {}", errors.join("\n  - "))]
    ValidationErrors { errors: Vec<String> },
}
646
/// Serde default for `server.mode`.
fn default_mode() -> String {
    String::from("embedded")
}

/// Serde default for `server.bind`.
fn default_bind() -> String {
    String::from("127.0.0.1:8080")
}
/// Serde default for `checkpoint.url`: a `file://` URL pointing at a
/// `laminardb` directory under the system temp dir.
///
/// Backslashes are turned into forward slashes so a Windows path still forms
/// a URL with forward slashes only.
fn default_checkpoint_url() -> String {
    let dir = std::env::temp_dir().join("laminardb");
    let normalized = dir.to_string_lossy().replace('\\', "/");
    // A path that already starts with `/` supplies the third slash itself.
    let scheme = if normalized.starts_with('/') {
        "file://"
    } else {
        "file:///"
    };
    format!("{scheme}{normalized}")
}
/// Serde default for `checkpoint.max_retained`.
fn default_max_retained() -> usize {
    10_usize
}

/// Serde default for `checkpoint.interval` (10 seconds).
fn default_checkpoint_interval() -> Duration {
    Duration::new(10, 0)
}

/// Serde default for a source's `format`.
fn default_format() -> String {
    String::from("json")
}

/// Serde default for a watermark's `max_out_of_orderness` (5 seconds).
fn default_max_ooo() -> Duration {
    Duration::new(5, 0)
}

/// Serde default for a lookup's `strategy`.
fn default_lookup_strategy() -> String {
    String::from("poll")
}

/// Serde default for a column's `nullable` flag.
fn default_true() -> bool {
    true
}

/// Serde default for a lookup cache's `size_bytes` (100 * 1024 * 1024).
fn default_cache_size() -> u64 {
    const MIB: u64 = 1024 * 1024;
    100 * MIB
}

/// Serde default for a lookup cache's `ttl` (300 seconds).
fn default_cache_ttl() -> Duration {
    Duration::new(300, 0)
}

/// Serde default for a sink's `delivery`.
fn default_delivery() -> String {
    String::from("at_least_once")
}

/// Serde default for `discovery.gossip_port`.
fn default_gossip_port() -> u16 {
    7946_u16
}

/// Serde default for `coordination.strategy`.
fn default_coordination_strategy() -> String {
    String::from("raft")
}

/// Serde default for `coordination.raft_port`.
fn default_raft_port() -> u16 {
    7947_u16
}

/// Serde default for `coordination.election_timeout` (1500 ms).
fn default_election_timeout() -> Duration {
    Duration::from_millis(1_500)
}

/// Serde default for `coordination.heartbeat_interval` (300 ms).
fn default_heartbeat_interval() -> Duration {
    Duration::from_millis(300)
}
705
706// ---------------------------------------------------------------------------
707// Tests
708// ---------------------------------------------------------------------------
709
710#[cfg(test)]
711mod tests {
712    use super::*;
713
714    const AI_TOML: &str = r#"
715[server]
716
717[ai.providers.anthropic]
718api_key_env = "LAMINAR_ANTHROPIC_API_KEY"
719base_url = "https://api.anthropic.com"
720max_concurrency = 8
721
722[ai.providers.openai]
723api_key_env = "LAMINAR_OPENAI_API_KEY"
724base_url = "https://api.openai.com/v1"
725
726[ai.providers.local]
727cache_dir = "/var/lib/laminar/models"
728
729[models.finbert]
730kind = "local"
731source = "hf:onnx-community/finbert"
732task = "classify"
733
734[models.haiku]
735kind = "remote"
736provider = "anthropic"
737model = "claude-haiku-4-5-20251001"
738task = ["classify", "extract", "complete"]
739
740[ai.defaults]
741classify = "finbert"
742complete = "haiku"
743"#;
744
745    #[test]
746    fn parses_ai_section_and_models() {
747        let config: ServerConfig = toml::from_str(AI_TOML).unwrap();
748        assert_eq!(config.ai.providers.len(), 3);
749        assert_eq!(
750            config.ai.providers["anthropic"].api_key_env.as_deref(),
751            Some("LAMINAR_ANTHROPIC_API_KEY")
752        );
753        assert_eq!(config.ai.providers["openai"].max_concurrency, 8);
754        assert_eq!(
755            config.ai.providers["local"].cache_dir.as_deref(),
756            Some("/var/lib/laminar/models")
757        );
758        assert_eq!(config.models["finbert"].task.tasks(), vec!["classify"]);
759        assert_eq!(
760            config.models["haiku"].task.tasks(),
761            vec!["classify", "extract", "complete"]
762        );
763        assert_eq!(config.ai.defaults["classify"], "finbert");
764        validate_config(&config).unwrap();
765    }
766
767    #[test]
768    fn rejects_local_provider_without_cache_dir() {
769        let toml = r#"
770[server]
771[ai.providers.local]
772[models.m]
773kind = "local"
774source = "hf:x/y"
775task = "classify"
776"#;
777        let config: ServerConfig = toml::from_str(toml).unwrap();
778        assert!(validate_config(&config).is_err());
779    }
780
781    #[test]
782    fn rejects_unknown_provider_and_default() {
783        let toml = r#"
784[server]
785[ai.providers.anthropic]
786api_key_env = "K"
787[models.bad]
788kind = "remote"
789provider = "ghost"
790model = "x"
791task = "classify"
792[ai.defaults]
793classify = "missing"
794"#;
795        let config: ServerConfig = toml::from_str(toml).unwrap();
796        let err = validate_config(&config).unwrap_err();
797        let msg = format!("{err:?}");
798        assert!(msg.contains("unknown provider 'ghost'"), "{msg}");
799        assert!(msg.contains("unknown model 'missing'"), "{msg}");
800    }
801
802    #[test]
803    fn rejects_remote_provider_without_api_key_env() {
804        let toml = r#"
805[server]
806[ai.providers.openai]
807base_url = "http://localhost:8000/v1"
808[models.m]
809kind = "remote"
810provider = "openai"
811model = "x"
812task = "embed"
813"#;
814        let config: ServerConfig = toml::from_str(toml).unwrap();
815        let err = validate_config(&config).unwrap_err();
816        assert!(format!("{err:?}").contains("requires 'api_key_env'"));
817    }
818
819    #[test]
820    fn local_model_requires_source() {
821        let toml = r#"
822[server]
823[models.m]
824kind = "local"
825task = "classify"
826"#;
827        let config: ServerConfig = toml::from_str(toml).unwrap();
828        let err = validate_config(&config).unwrap_err();
829        assert!(format!("{err:?}").contains("requires a 'source'"));
830    }
831
832    #[test]
833    fn test_parse_minimal_config() {
834        let toml = "[server]\n";
835        let config: ServerConfig = toml::from_str(toml).unwrap();
836        assert_eq!(config.server.mode, "embedded");
837        assert_eq!(config.server.bind, "127.0.0.1:8080");
838        assert!(config.sources.is_empty());
839        assert!(config.pipelines.is_empty());
840        assert!(config.sinks.is_empty());
841    }
842
843    #[test]
844    fn test_parse_full_embedded_config() {
845        let toml = r#"
846[server]
847mode = "embedded"
848bind = "127.0.0.1:8080"
849
850[state]
851backend = "in_process"
852
853[checkpoint]
854url = "file:///tmp/checkpoints"
855interval = "10s"
856mode = "aligned"
857
858[[source]]
859name = "trades"
860connector = "kafka"
861format = "json"
862[source.properties]
863brokers = "localhost:9092"
864topic = "trades"
865[[source.schema]]
866name = "symbol"
867type = "VARCHAR"
868nullable = false
869[[source.schema]]
870name = "price"
871type = "DOUBLE"
872[source.watermark]
873column = "trade_time"
874max_out_of_orderness = "5s"
875
876[[pipeline]]
877name = "vwap"
878sql = "SELECT symbol, SUM(price) FROM trades GROUP BY symbol"
879
880[[sink]]
881name = "output"
882pipeline = "vwap"
883connector = "kafka"
884[sink.properties]
885topic = "vwap_output"
886"#;
887
888        let config: ServerConfig = toml::from_str(toml).unwrap();
889        assert_eq!(config.sources.len(), 1);
890        assert_eq!(config.sources[0].name, "trades");
891        assert_eq!(config.sources[0].schema.len(), 2);
892        assert!(!config.sources[0].schema[0].nullable);
893        assert!(config.sources[0].schema[1].nullable); // default true
894        assert!(config.sources[0].watermark.is_some());
895        assert_eq!(config.pipelines.len(), 1);
896        assert_eq!(config.sinks.len(), 1);
897        assert_eq!(config.sinks[0].pipeline, "vwap");
898
899        validate_config(&config).unwrap();
900    }
901
902    #[test]
903    fn test_parse_full_delta_config() {
904        let toml = r#"
905node_id = "star-1"
906
907[server]
908mode = "cluster"
909bind = "0.0.0.0:8080"
910
911[state]
912backend = "local"
913path = "/data/state"
914vnode_capacity = 256
915
916[checkpoint]
917url = "s3://bucket/checkpoints"
918interval = "30s"
919snapshot_strategy = "fork_cow"
920
921[discovery]
922strategy = "static"
923seeds = ["node-1:7946", "node-2:7946"]
924gossip_port = 7946
925
926[coordination]
927strategy = "raft"
928raft_port = 7947
929election_timeout = "1500ms"
930heartbeat_interval = "300ms"
931
932[[source]]
933name = "orders"
934connector = "kafka"
935format = "avro"
936
937[[pipeline]]
938name = "enrichment"
939sql = "SELECT * FROM orders"
940parallelism = 8
941
942[[sink]]
943name = "output"
944pipeline = "enrichment"
945connector = "kafka"
946delivery = "exactly_once"
947"#;
948
949        let config: ServerConfig = toml::from_str(toml).unwrap();
950        assert_eq!(config.node_id.as_deref(), Some("star-1"));
951        assert_eq!(config.server.mode, "cluster");
952        assert!(matches!(config.state, StateBackendConfig::Local { .. }));
953        assert!(config.discovery.is_some());
954        assert!(config.coordination.is_some());
955
956        let coord = config.coordination.as_ref().unwrap();
957        assert_eq!(coord.election_timeout, Duration::from_millis(1500));
958        assert_eq!(coord.heartbeat_interval, Duration::from_millis(300));
959
960        validate_config(&config).unwrap();
961    }
962
963    #[test]
964    fn test_env_var_substitution_resolves() {
965        std::env::set_var("LAMINAR_TEST_VAR_1", "resolved_value");
966        let input = "brokers = \"${LAMINAR_TEST_VAR_1}\"";
967        let result = substitute_env_vars(input).unwrap();
968        assert_eq!(result, "brokers = \"resolved_value\"");
969        std::env::remove_var("LAMINAR_TEST_VAR_1");
970    }
971
972    #[test]
973    fn test_env_var_substitution_with_default() {
974        // Ensure the variable is NOT set
975        std::env::remove_var("LAMINAR_TEST_UNSET_VAR");
976        let input = "brokers = \"${LAMINAR_TEST_UNSET_VAR:-localhost:9092}\"";
977        let result = substitute_env_vars(input).unwrap();
978        assert_eq!(result, "brokers = \"localhost:9092\"");
979    }
980
981    #[test]
982    fn test_env_var_substitution_missing_errors() {
983        std::env::remove_var("LAMINAR_TEST_MISSING_1");
984        std::env::remove_var("LAMINAR_TEST_MISSING_2");
985        let input = "a = \"${LAMINAR_TEST_MISSING_1}\"\nb = \"${LAMINAR_TEST_MISSING_2}\"";
986        let err = substitute_env_vars(input).unwrap_err();
987        match err {
988            ConfigError::MissingEnvVars { vars } => {
989                assert!(vars.contains(&"LAMINAR_TEST_MISSING_1".to_string()));
990                assert!(vars.contains(&"LAMINAR_TEST_MISSING_2".to_string()));
991            }
992            _ => panic!("expected MissingEnvVars"),
993        }
994    }
995
996    #[test]
997    fn test_validate_sink_references_missing_pipeline() {
998        let toml = r#"
999[[pipeline]]
1000name = "exists"
1001sql = "SELECT 1"
1002
1003[[sink]]
1004name = "broken"
1005pipeline = "nonexistent"
1006connector = "kafka"
1007"#;
1008
1009        let config: ServerConfig = toml::from_str(toml).unwrap();
1010        let err = validate_config(&config).unwrap_err();
1011        match err {
1012            ConfigError::ValidationErrors { errors } => {
1013                assert!(errors[0].contains("nonexistent"));
1014            }
1015            _ => panic!("expected ValidationErrors"),
1016        }
1017    }
1018
1019    #[test]
1020    fn test_validate_duplicate_source_names() {
1021        let toml = r#"
1022[[source]]
1023name = "dup"
1024connector = "kafka"
1025
1026[[source]]
1027name = "dup"
1028connector = "kafka"
1029
1030[[pipeline]]
1031name = "p"
1032sql = "SELECT 1"
1033"#;
1034
1035        let config: ServerConfig = toml::from_str(toml).unwrap();
1036        let err = validate_config(&config).unwrap_err();
1037        match err {
1038            ConfigError::ValidationErrors { errors } => {
1039                assert!(errors.iter().any(|e| e.contains("duplicate source")));
1040            }
1041            _ => panic!("expected ValidationErrors"),
1042        }
1043    }
1044
1045    #[test]
1046    fn test_validate_duplicate_pipeline_names() {
1047        let toml = r#"
1048[[pipeline]]
1049name = "dup"
1050sql = "SELECT 1"
1051
1052[[pipeline]]
1053name = "dup"
1054sql = "SELECT 2"
1055"#;
1056
1057        let config: ServerConfig = toml::from_str(toml).unwrap();
1058        let err = validate_config(&config).unwrap_err();
1059        match err {
1060            ConfigError::ValidationErrors { errors } => {
1061                assert!(errors.iter().any(|e| e.contains("duplicate pipeline")));
1062            }
1063            _ => panic!("expected ValidationErrors"),
1064        }
1065    }
1066
1067    #[test]
1068    fn test_cluster_mode_rejects_tight_checkpoint_interval() {
1069        // Two-phase commit in cluster mode can't keep up with sub-2s
1070        // cadence.
1071        let toml = r#"
1072node_id = "n1"
1073
1074[server]
1075mode = "cluster"
1076
1077[checkpoint]
1078interval = "500ms"
1079
1080[discovery]
1081strategy = "static"
1082seeds = ["x:1"]
1083
1084[coordination]
1085strategy = "raft"
1086"#;
1087        let config: ServerConfig = toml::from_str(toml).unwrap();
1088        let err = validate_config(&config).unwrap_err();
1089        match err {
1090            ConfigError::ValidationErrors { errors } => {
1091                assert!(
1092                    errors.iter().any(|e| e.contains("too tight")),
1093                    "expected tight-interval error, got: {errors:?}",
1094                );
1095            }
1096            _ => panic!("expected ValidationErrors"),
1097        }
1098    }
1099
1100    #[test]
1101    fn test_validate_invalid_bind_address() {
1102        let toml = r#"
1103[server]
1104bind = "not-a-socket-addr"
1105"#;
1106
1107        let config: ServerConfig = toml::from_str(toml).unwrap();
1108        let err = validate_config(&config).unwrap_err();
1109        match err {
1110            ConfigError::ValidationErrors { errors } => {
1111                assert!(errors.iter().any(|e| e.contains("invalid server bind")));
1112            }
1113            _ => panic!("expected ValidationErrors"),
1114        }
1115    }
1116
1117    #[test]
1118    fn test_validate_zero_max_connections() {
1119        let toml = r#"
1120[server]
1121pgwire_max_connections = 0
1122"#;
1123        let config: ServerConfig = toml::from_str(toml).unwrap();
1124        let err = validate_config(&config).unwrap_err();
1125        match err {
1126            ConfigError::ValidationErrors { errors } => {
1127                assert!(
1128                    errors.iter().any(|e| e.contains("must be > 0")),
1129                    "errors: {errors:?}"
1130                );
1131            }
1132            _ => panic!("expected ValidationErrors"),
1133        }
1134    }
1135
1136    #[test]
1137    fn test_validate_client_ca_requires_server_cert() {
1138        let toml = r#"
1139[server]
1140pgwire_tls_client_ca = "/does/not/matter.pem"
1141"#;
1142        let config: ServerConfig = toml::from_str(toml).unwrap();
1143        let err = validate_config(&config).unwrap_err();
1144        match err {
1145            ConfigError::ValidationErrors { errors } => {
1146                assert!(
1147                    errors
1148                        .iter()
1149                        .any(|e| e.contains("requires pgwire_tls_cert")),
1150                    "errors: {errors:?}"
1151                );
1152            }
1153            _ => panic!("expected ValidationErrors"),
1154        }
1155    }
1156
1157    #[test]
1158    fn test_validate_rejects_unknown_tls_min_version() {
1159        let toml = r#"
1160[server]
1161pgwire_tls_min_version = "1.4"
1162"#;
1163        let config: ServerConfig = toml::from_str(toml).unwrap();
1164        let err = validate_config(&config).unwrap_err();
1165        match err {
1166            ConfigError::ValidationErrors { errors } => {
1167                assert!(
1168                    errors.iter().any(|e| e.contains("pgwire_tls_min_version")),
1169                    "errors: {errors:?}"
1170                );
1171            }
1172            _ => panic!("expected ValidationErrors"),
1173        }
1174    }
1175
1176    #[test]
1177    fn test_validate_accepts_well_formed_pre_hashed_pgwire_password() {
1178        let toml = r#"
1179[server]
1180[server.pgwire_users]
1181alice = "md55d41402abc4b2a76b9719d911017c592"
1182"#;
1183        let config: ServerConfig = toml::from_str(toml).unwrap();
1184        // 35-char pre-hashed value bypasses the MIN_PGWIRE_PASSWORD_LEN gate.
1185        validate_config(&config).expect("well-formed pre-hash must validate");
1186    }
1187
1188    #[test]
1189    fn test_validate_rejects_malformed_pre_hashed_pgwire_password() {
1190        // 'md5' prefix followed by non-hex — clearly meant to be pre-hashed
1191        // but malformed; rejected so a typo doesn't slip through as plaintext.
1192        let toml = r#"
1193[server]
1194[server.pgwire_users]
1195alice = "md5zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"
1196"#;
1197        let config: ServerConfig = toml::from_str(toml).unwrap();
1198        let err = validate_config(&config).unwrap_err();
1199        match err {
1200            ConfigError::ValidationErrors { errors } => {
1201                assert!(
1202                    errors.iter().any(|e| e.contains("pre-hashed")),
1203                    "errors: {errors:?}",
1204                );
1205            }
1206            _ => panic!("expected ValidationErrors"),
1207        }
1208    }
1209
1210    #[test]
1211    fn test_validate_short_pgwire_password() {
1212        let toml = r#"
1213[server]
1214[server.pgwire_users]
1215alice = "short"
1216"#;
1217        let config: ServerConfig = toml::from_str(toml).unwrap();
1218        let err = validate_config(&config).unwrap_err();
1219        match err {
1220            ConfigError::ValidationErrors { errors } => {
1221                assert!(
1222                    errors.iter().any(|e| e.contains("at least 12 characters")),
1223                    "errors: {errors:?}"
1224                );
1225            }
1226            _ => panic!("expected ValidationErrors"),
1227        }
1228    }
1229
1230    #[test]
1231    fn test_validate_pgwire_password_redacted_in_debug() {
1232        let toml = r#"
1233[server]
1234[server.pgwire_users]
1235alice = "wonderland-key"
1236"#;
1237        let config: ServerConfig = toml::from_str(toml).unwrap();
1238        validate_config(&config).unwrap();
1239        let dump = format!("{:?}", config.server);
1240        assert!(!dump.contains("wonderland"), "secret leaked: {dump}");
1241        assert!(
1242            dump.contains("REDACTED"),
1243            "expected REDACTED marker: {dump}"
1244        );
1245    }
1246
1247    #[test]
1248    fn test_default_values_applied() {
1249        let config = ServerConfig {
1250            server: ServerSection::default(),
1251            state: StateBackendConfig::default(),
1252            checkpoint: CheckpointSection::default(),
1253            sources: vec![],
1254            lookups: vec![],
1255            pipelines: vec![],
1256            sinks: vec![],
1257            discovery: None,
1258            coordination: None,
1259            node_id: None,
1260            sql: None,
1261            ai: Default::default(),
1262            models: Default::default(),
1263        };
1264
1265        assert_eq!(config.server.mode, "embedded");
1266        assert_eq!(config.server.bind, "127.0.0.1:8080");
1267        assert!(matches!(config.state, StateBackendConfig::InProcess { .. }));
1268        assert_eq!(config.checkpoint.interval, Duration::from_secs(10));
1269    }
1270
1271    #[test]
1272    fn test_checkpoint_duration_parsing() {
1273        let toml = r#"
1274[checkpoint]
1275interval = "30s"
1276"#;
1277        let config: ServerConfig = toml::from_str(toml).unwrap();
1278        assert_eq!(config.checkpoint.interval, Duration::from_secs(30));
1279
1280        let toml2 = r#"
1281[checkpoint]
1282interval = "1m"
1283"#;
1284        let config2: ServerConfig = toml::from_str(toml2).unwrap();
1285        assert_eq!(config2.checkpoint.interval, Duration::from_secs(60));
1286
1287        let toml3 = r#"
1288[checkpoint]
1289interval = "500ms"
1290"#;
1291        let config3: ServerConfig = toml::from_str(toml3).unwrap();
1292        assert_eq!(config3.checkpoint.interval, Duration::from_millis(500));
1293    }
1294
1295    #[test]
1296    fn test_watermark_config_parsing() {
1297        let toml = r#"
1298[[source]]
1299name = "s"
1300connector = "kafka"
1301[source.watermark]
1302column = "event_time"
1303max_out_of_orderness = "10s"
1304"#;
1305        let config: ServerConfig = toml::from_str(toml).unwrap();
1306        let wm = config.sources[0].watermark.as_ref().unwrap();
1307        assert_eq!(wm.column, "event_time");
1308        assert_eq!(wm.max_out_of_orderness, Duration::from_secs(10));
1309    }
1310
1311    #[test]
1312    fn test_lookup_cache_defaults() {
1313        let cache = LookupCacheConfig::default();
1314        assert_eq!(cache.size_bytes, 100 * 1024 * 1024);
1315        assert_eq!(cache.ttl, Duration::from_secs(300));
1316    }
1317
1318    #[test]
1319    fn test_cluster_mode_requires_discovery() {
1320        let toml = r#"
1321[server]
1322mode = "cluster"
1323
1324[checkpoint]
1325interval = "10s"
1326"#;
1327        let config: ServerConfig = toml::from_str(toml).unwrap();
1328        let err = validate_config(&config).unwrap_err();
1329        match err {
1330            ConfigError::ValidationErrors { errors } => {
1331                assert!(errors.iter().any(|e| e.contains("[discovery]")));
1332                assert!(errors.iter().any(|e| e.contains("[coordination]")));
1333                assert!(errors.iter().any(|e| e.contains("node_id")));
1334            }
1335            _ => panic!("expected ValidationErrors"),
1336        }
1337    }
1338
1339    #[test]
1340    fn test_source_schema_parsing() {
1341        let toml = r#"
1342[[source]]
1343name = "test"
1344connector = "kafka"
1345[[source.schema]]
1346name = "id"
1347type = "BIGINT"
1348nullable = false
1349[[source.schema]]
1350name = "name"
1351type = "VARCHAR"
1352"#;
1353        let config: ServerConfig = toml::from_str(toml).unwrap();
1354        assert_eq!(config.sources[0].schema.len(), 2);
1355        assert_eq!(config.sources[0].schema[0].data_type, "BIGINT");
1356        assert!(!config.sources[0].schema[0].nullable);
1357        assert_eq!(config.sources[0].schema[1].data_type, "VARCHAR");
1358        assert!(config.sources[0].schema[1].nullable); // default
1359    }
1360
1361    #[test]
1362    fn test_config_error_display_messages() {
1363        let err = ConfigError::MissingEnvVars {
1364            vars: vec!["A".to_string(), "B".to_string()],
1365        };
1366        assert_eq!(err.to_string(), "missing environment variables: A, B");
1367
1368        let err = ConfigError::ValidationErrors {
1369            errors: vec!["error one".to_string(), "error two".to_string()],
1370        };
1371        let msg = err.to_string();
1372        assert!(msg.contains("error one"));
1373        assert!(msg.contains("error two"));
1374    }
1375}