laminar_sql/datafusion/
ai_udf.rs1use std::any::Any;
18use std::hash::{Hash, Hasher};
19use std::sync::Arc;
20
21use arrow::datatypes::{DataType, Field};
22use datafusion_common::{exec_err, Result};
23use datafusion_expr::{
24 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
25};
26
27#[derive(Debug)]
30pub struct AiFunctionMarker {
31 name: &'static str,
32 signature: Signature,
33 return_type: DataType,
34}
35
36impl AiFunctionMarker {
37 #[must_use]
39 pub fn new(name: &'static str, return_type: DataType) -> Self {
40 Self {
41 name,
42 signature: Signature::variadic_any(Volatility::Volatile),
46 return_type,
47 }
48 }
49}
50
51impl PartialEq for AiFunctionMarker {
52 fn eq(&self, other: &Self) -> bool {
53 self.name == other.name
54 }
55}
56
57impl Eq for AiFunctionMarker {}
58
59impl Hash for AiFunctionMarker {
60 fn hash<H: Hasher>(&self, state: &mut H) {
61 self.name.hash(state);
62 }
63}
64
65impl ScalarUDFImpl for AiFunctionMarker {
66 fn as_any(&self) -> &dyn Any {
67 self
68 }
69
70 fn name(&self) -> &'static str {
71 self.name
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79 Ok(self.return_type.clone())
80 }
81
82 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
83 exec_err!(
84 "{} is an AI function and must be the top-level expression of a SELECT \
85 projection over a stream; it cannot be evaluated in this position",
86 self.name
87 )
88 }
89}
90
91#[must_use]
97pub fn ai_function_markers() -> Vec<ScalarUDF> {
98 let embedding = DataType::List(Arc::new(Field::new("item", DataType::Float32, true)));
99 let specs: [(&'static str, DataType); 8] = [
100 ("ai_classify", DataType::Utf8),
101 ("ai_sentiment", DataType::Float64),
102 ("ai_embed", embedding),
103 ("ai_extract", DataType::Utf8),
104 ("ai_complete", DataType::Utf8),
105 ("ai_summarize", DataType::Utf8),
106 ("ai_translate", DataType::Utf8),
107 ("ai_gen", DataType::Utf8),
108 ];
109 specs
110 .into_iter()
111 .map(|(name, return_type)| {
112 ScalarUDF::new_from_impl(AiFunctionMarker::new(name, return_type))
113 })
114 .collect()
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::config::ConfigOptions;
    use std::collections::HashSet;

    /// Every marker is registered exactly once and reports its exact return
    /// type (all eight are checked, not just a sample).
    #[test]
    fn registers_eight_markers_with_expected_return_types() {
        let markers = ai_function_markers();
        assert_eq!(markers.len(), 8);

        // Duplicate names would make lookups by name ambiguous.
        let names: HashSet<&str> = markers.iter().map(|m| m.name()).collect();
        assert_eq!(names.len(), markers.len(), "marker names must be unique");

        let embedding = DataType::List(Arc::new(Field::new("item", DataType::Float32, true)));
        let expected = [
            ("ai_classify", DataType::Utf8),
            ("ai_sentiment", DataType::Float64),
            ("ai_embed", embedding),
            ("ai_extract", DataType::Utf8),
            ("ai_complete", DataType::Utf8),
            ("ai_summarize", DataType::Utf8),
            ("ai_translate", DataType::Utf8),
            ("ai_gen", DataType::Utf8),
        ];

        for (name, return_type) in expected {
            let marker = markers
                .iter()
                .find(|m| m.name() == name)
                .unwrap_or_else(|| panic!("marker `{name}` is not registered"));
            assert_eq!(
                marker.return_type(&[]).unwrap(),
                return_type,
                "unexpected return type for `{name}`"
            );
        }
    }

    /// Evaluating a marker fails, and the error names the offending function.
    #[test]
    fn invoking_a_marker_is_an_error() {
        let marker = AiFunctionMarker::new("ai_classify", DataType::Utf8);
        let args = ScalarFunctionArgs {
            args: vec![],
            arg_fields: vec![],
            number_rows: 1,
            return_field: Arc::new(Field::new("out", DataType::Utf8, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let err = marker
            .invoke_with_args(args)
            .expect_err("invoking a marker must fail");
        assert!(
            err.to_string().contains("ai_classify"),
            "error should name the function, got: {err}"
        );
    }
}