Skip to main content

laminar_sql/datafusion/
ai_udf.rs

//! Marker UDFs for the `ai_*` SQL functions.
//!
//! These exist so that `ai_classify`, `ai_embed`, … resolve as known functions
//! with a fixed return type, the same way `tumble()` and `watermark()` are
//! markers the engine rewrites rather than evaluates. In the normal path the AI
//! detector lifts an `ai_*` call out of the projection and replaces it with a
//! computed column before DataFusion plans the residual query, so these markers
//! are never invoked. If one survives to execution — used in a position the
//! detector does not rewrite — `invoke_with_args` returns a clear error rather
//! than producing a wrong value.
//!
//! The function name → return type map here must stay in step with the
//! name → task map in `laminar-db`'s `sql_analysis` (the eight locked AI
//! functions). This crate does not depend on `laminar-ai`, so the two lists are
//! maintained independently and must be kept in sync by hand.
16
17use std::any::Any;
18use std::hash::{Hash, Hasher};
19use std::sync::Arc;
20
21use arrow::datatypes::{DataType, Field};
22use datafusion_common::{exec_err, Result};
23use datafusion_expr::{
24    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
25};
26
27/// A placeholder scalar UDF for an `ai_*` function. Holds the function's name
28/// and fixed return type; errors if evaluated directly.
29#[derive(Debug)]
30pub struct AiFunctionMarker {
31    name: &'static str,
32    signature: Signature,
33    return_type: DataType,
34}
35
36impl AiFunctionMarker {
37    /// Create a marker for `name` returning `return_type`.
38    #[must_use]
39    pub fn new(name: &'static str, return_type: DataType) -> Self {
40        Self {
41            name,
42            // Accept any arguments (input plus `model =>` / `labels =>` named
43            // args). Volatile so it is never const-folded away before the
44            // detector can see it.
45            signature: Signature::variadic_any(Volatility::Volatile),
46            return_type,
47        }
48    }
49}
50
51impl PartialEq for AiFunctionMarker {
52    fn eq(&self, other: &Self) -> bool {
53        self.name == other.name
54    }
55}
56
57impl Eq for AiFunctionMarker {}
58
59impl Hash for AiFunctionMarker {
60    fn hash<H: Hasher>(&self, state: &mut H) {
61        self.name.hash(state);
62    }
63}
64
65impl ScalarUDFImpl for AiFunctionMarker {
66    fn as_any(&self) -> &dyn Any {
67        self
68    }
69
70    fn name(&self) -> &'static str {
71        self.name
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79        Ok(self.return_type.clone())
80    }
81
82    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
83        exec_err!(
84            "{} is an AI function and must be the top-level expression of a SELECT \
85             projection over a stream; it cannot be evaluated in this position",
86            self.name
87        )
88    }
89}
90
91/// Build the eight `ai_*` marker UDFs, ready to register on a session context.
92///
93/// Text-generating and discriminative-label tasks return `Utf8`; `ai_sentiment`
94/// returns a `Float64` score in `[-1, 1]`; `ai_embed` returns a `List<Float32>`
95/// embedding.
96#[must_use]
97pub fn ai_function_markers() -> Vec<ScalarUDF> {
98    let embedding = DataType::List(Arc::new(Field::new("item", DataType::Float32, true)));
99    let specs: [(&'static str, DataType); 8] = [
100        ("ai_classify", DataType::Utf8),
101        ("ai_sentiment", DataType::Float64),
102        ("ai_embed", embedding),
103        ("ai_extract", DataType::Utf8),
104        ("ai_complete", DataType::Utf8),
105        ("ai_summarize", DataType::Utf8),
106        ("ai_translate", DataType::Utf8),
107        ("ai_gen", DataType::Utf8),
108    ];
109    specs
110        .into_iter()
111        .map(|(name, return_type)| {
112            ScalarUDF::new_from_impl(AiFunctionMarker::new(name, return_type))
113        })
114        .collect()
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::config::ConfigOptions;
    use std::collections::HashSet;

    /// Every marker is registered exactly once, under its expected name, with
    /// the return type documented on `ai_function_markers`.
    #[test]
    fn registers_eight_markers_with_expected_return_types() {
        let markers = ai_function_markers();
        assert_eq!(markers.len(), 8);

        // Duplicate names would make registration order decide which marker wins.
        let names: HashSet<&str> = markers.iter().map(|m| m.name()).collect();
        assert_eq!(names.len(), markers.len(), "marker names must be unique");

        let return_type_of = |name: &str| {
            markers
                .iter()
                .find(|m| m.name() == name)
                .unwrap_or_else(|| panic!("missing marker {name}"))
                .return_type(&[])
                .unwrap()
        };

        for name in [
            "ai_classify",
            "ai_extract",
            "ai_complete",
            "ai_summarize",
            "ai_translate",
            "ai_gen",
        ] {
            assert_eq!(return_type_of(name), DataType::Utf8, "{name}");
        }
        assert_eq!(return_type_of("ai_sentiment"), DataType::Float64);
        match return_type_of("ai_embed") {
            DataType::List(item) => assert_eq!(item.data_type(), &DataType::Float32),
            other => panic!("ai_embed should return a list, got {other:?}"),
        }
    }

    /// Evaluating a marker must fail, and the error must name the offending
    /// function so users can find the call in their query.
    #[test]
    fn invoking_a_marker_is_an_error() {
        let marker = AiFunctionMarker::new("ai_classify", DataType::Utf8);
        let args = ScalarFunctionArgs {
            args: vec![],
            arg_fields: vec![],
            number_rows: 1,
            return_field: Arc::new(Field::new("out", DataType::Utf8, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let err = marker
            .invoke_with_args(args)
            .expect_err("markers must never evaluate");
        assert!(
            err.to_string().contains("ai_classify"),
            "unexpected error: {err}"
        );
    }
}