// slang_frontend/semantic_analysis/visitors/expression_visitor.rs

1use slang_ir::ast::*;
2use slang_shared::{CompilationContext, SymbolKind};
3use slang_types::TypeId;
4
5use super::super::{
6    error::SemanticAnalysisError, operations, traits::SemanticResult,
7    validation::TypeCheckingCoordinator,
8};
9
10/// Handles semantic analysis for all expression types
11///
12/// This visitor is responsible for analyzing expression-level constructs
13/// including binary operations, unary operations, function calls, variable
14/// references, literals, conditionals, and block expressions.
15pub struct ExpressionVisitor<'a> {
16    context: &'a mut CompilationContext,
17    current_return_type: Option<TypeId>,
18}
19
20impl<'a> ExpressionVisitor<'a> {
21    /// Create a new expression visitor
22    ///
23    /// # Arguments
24    /// * `context` - The compilation context for type information and symbol lookup
25    pub fn new(context: &'a mut CompilationContext) -> Self {
26        Self {
27            context,
28            current_return_type: None,
29        }
30    }
31
32    /// Create a new expression visitor with a specific return type context
33    ///
34    /// # Arguments
35    /// * `context` - The compilation context for type information and symbol lookup
36    /// * `current_return_type` - The current function's return type for validation
37    pub fn with_return_type(
38        context: &'a mut CompilationContext,
39        current_return_type: Option<TypeId>,
40    ) -> Self {
41        Self {
42            context,
43            current_return_type,
44        }
45    }
46
47    /// Set the current return type for function analysis
48    pub fn set_return_type(&mut self, return_type: Option<TypeId>) {
49        self.current_return_type = return_type;
50    }
51
52    /// Create a type checking coordinator for this visitor's context
53    fn create_type_coordinator(&self) -> TypeCheckingCoordinator {
54        TypeCheckingCoordinator::new(self.context)
55    }
56
57    /// Visit an expression and determine its type
58    pub fn visit_expression(&mut self, expr: &Expression) -> SemanticResult {
59        match expr {
60            Expression::Binary(bin_expr) => self.visit_binary_expression(bin_expr),
61            Expression::Unary(unary_expr) => self.visit_unary_expression(unary_expr),
62            Expression::Call(call_expr) => self.visit_call_expression(call_expr),
63            Expression::Variable(var_expr) => self.visit_variable_expression(var_expr),
64            Expression::Literal(lit_expr) => self.visit_literal_expression(lit_expr),
65            Expression::Conditional(cond_expr) => self.visit_conditional_expression(cond_expr),
66            Expression::Block(block_expr) => self.visit_block_expression(block_expr),
67            Expression::FunctionType(func_type_expr) => {
68                self.visit_function_type_expression(func_type_expr)
69            }
70        }
71    }
72
73    /// Visit a binary expression
74    pub fn visit_binary_expression(&mut self, bin_expr: &BinaryExpr) -> SemanticResult {
75        let left_type = self.visit_expression(&bin_expr.left)?;
76        let right_type = self.visit_expression(&bin_expr.right)?;
77
78        // Handle logical operations
79        if bin_expr.operator == BinaryOperator::And || bin_expr.operator == BinaryOperator::Or {
80            return operations::check_logical_operation(
81                &left_type,
82                &right_type,
83                &bin_expr.operator,
84                &bin_expr.location,
85            );
86        }
87
88        // Handle relational operations
89        if matches!(
90            bin_expr.operator,
91            BinaryOperator::GreaterThan
92                | BinaryOperator::LessThan
93                | BinaryOperator::GreaterThanOrEqual
94                | BinaryOperator::LessThanOrEqual
95                | BinaryOperator::Equal
96                | BinaryOperator::NotEqual
97        ) {
98            return operations::check_relational_operation(
99                self.context,
100                &left_type,
101                &right_type,
102                &bin_expr.operator,
103                &bin_expr.location,
104            );
105        }
106
107        // Handle arithmetic operations
108        if matches!(
109            bin_expr.operator,
110            BinaryOperator::Add
111                | BinaryOperator::Subtract
112                | BinaryOperator::Multiply
113                | BinaryOperator::Divide
114        ) {
115            if left_type == right_type {
116                return operations::check_same_type_arithmetic(
117                    self.context,
118                    &left_type,
119                    &bin_expr.operator,
120                    &bin_expr.location,
121                );
122            }
123
124            // Use coordinator for mixed arithmetic with coercion
125            let coordinator = self.create_type_coordinator();
126            return coordinator.check_mixed_arithmetic_with_coercion(
127                &left_type,
128                &right_type,
129                bin_expr,
130            );
131        }
132
133        Err(SemanticAnalysisError::OperationTypeMismatch {
134            operator: bin_expr.operator.to_string(),
135            left_type: left_type.clone(),
136            right_type: right_type.clone(),
137            location: bin_expr.location,
138        })
139    }
140
141    /// Visit a unary expression
142    pub fn visit_unary_expression(&mut self, unary_expr: &UnaryExpr) -> SemanticResult {
143        let operand_type = self.visit_expression(&unary_expr.right)?;
144
145        operations::unary::check_unary_operation(self.context, unary_expr, &operand_type)
146    }
147
148    /// Visit a function call expression
149    pub fn visit_call_expression(&mut self, call_expr: &FunctionCallExpr) -> SemanticResult {
150        let function_type = if let Some(symbol) = self.context.lookup_symbol(&call_expr.name) {
151            match symbol.kind() {
152                SymbolKind::Function => {
153                    if self.context.is_function_type(&symbol.type_id) {
154                        self.context.get_function_type(&symbol.type_id).cloned()
155                    } else {
156                        return Err(SemanticAnalysisError::UndefinedFunction {
157                            name: call_expr.name.clone(),
158                            location: call_expr.location,
159                        });
160                    }
161                }
162                SymbolKind::Variable => {
163                    if self.context.is_function_type(&symbol.type_id) {
164                        self.context.get_function_type(&symbol.type_id).cloned()
165                    } else {
166                        return Err(SemanticAnalysisError::VariableNotCallable {
167                            variable_name: call_expr.name.clone(),
168                            variable_type: symbol.type_id.clone(),
169                            location: call_expr.location,
170                        });
171                    }
172                }
173                _ => {
174                    return Err(SemanticAnalysisError::UndefinedFunction {
175                        name: call_expr.name.clone(),
176                        location: call_expr.location,
177                    });
178                }
179            }
180        } else {
181            return Err(SemanticAnalysisError::UndefinedFunction {
182                name: call_expr.name.clone(),
183                location: call_expr.location,
184            });
185        };
186
187        if let Some(func_type) = function_type {
188            // Check argument count
189            if func_type.param_types.len() != call_expr.arguments.len() {
190                return Err(SemanticAnalysisError::ArgumentCountMismatch {
191                    function_name: call_expr.name.clone(),
192                    expected: func_type.param_types.len(),
193                    actual: call_expr.arguments.len(),
194                    location: call_expr.location,
195                });
196            }
197
198            // Check argument types
199            for (i, arg) in call_expr.arguments.iter().enumerate() {
200                let param_type = func_type.param_types[i].clone();
201                let arg_type = self.visit_expression(arg)?;
202
203                if param_type == TypeId::unknown() {
204                    continue;
205                }
206
207                // Use coordinator for assignment compatibility checking
208                let coordinator = self.create_type_coordinator();
209                if !coordinator.check_assignment_compatibility(&param_type, &arg_type) {
210                    // For unspecified literals, try range validation
211                    if arg_type == TypeId::unspecified_int()
212                        || arg_type == TypeId::unspecified_float()
213                    {
214                        if coordinator
215                            .validate_literal_range(arg, &param_type)
216                            .is_err()
217                        {
218                            return Err(SemanticAnalysisError::ArgumentTypeMismatch {
219                                function_name: call_expr.name.clone(),
220                                argument_position: i + 1,
221                                expected: param_type.clone(),
222                                actual: arg_type,
223                                location: arg.location(),
224                            });
225                        }
226                    } else {
227                        // For non-literal types, it's a direct type mismatch
228                        return Err(SemanticAnalysisError::ArgumentTypeMismatch {
229                            function_name: call_expr.name.clone(),
230                            argument_position: i + 1,
231                            expected: param_type.clone(),
232                            actual: arg_type,
233                            location: arg.location(),
234                        });
235                    }
236                }
237            }
238
239            Ok(func_type.return_type.clone())
240        } else {
241            Err(SemanticAnalysisError::UndefinedFunction {
242                name: call_expr.name.clone(),
243                location: call_expr.location,
244            })
245        }
246    }
247
248    /// Visit a variable expression
249    pub fn visit_variable_expression(&mut self, var_expr: &VariableExpr) -> SemanticResult {
250        if let Some(var_info) = self.resolve_value(&var_expr.name) {
251            Ok(var_info.type_id.clone())
252        } else {
253            Err(SemanticAnalysisError::UndefinedVariable {
254                name: var_expr.name.clone(),
255                location: var_expr.location,
256            })
257        }
258    }
259
260    /// Visit a literal expression
261    pub fn visit_literal_expression(&mut self, literal_expr: &LiteralExpr) -> SemanticResult {
262        Ok(literal_expr.expr_type.clone())
263    }
264
265    /// Visit a conditional expression
266    pub fn visit_conditional_expression(&mut self, cond_expr: &ConditionalExpr) -> SemanticResult {
267        let condition_type = self.visit_expression(&cond_expr.condition)?;
268        if condition_type != TypeId::bool() {
269            return Err(SemanticAnalysisError::TypeMismatch {
270                expected: TypeId::bool(),
271                actual: condition_type,
272                context: Some("if condition".to_string()),
273                location: cond_expr.condition.location(),
274            });
275        }
276
277        let then_type = self.visit_expression(&cond_expr.then_branch)?;
278        let else_type = self.visit_expression(&cond_expr.else_branch)?;
279
280        if then_type == TypeId::unknown() {
281            Ok(else_type)
282        } else if else_type == TypeId::unknown() || then_type == else_type {
283            Ok(then_type)
284        } else {
285            Err(SemanticAnalysisError::TypeMismatch {
286                expected: then_type,
287                actual: else_type,
288                context: Some(
289                    "conditional expression branches must have the same type".to_string(),
290                ),
291                location: cond_expr.location,
292            })
293        }
294    }
295
296    /// Visit a block expression
297    pub fn visit_block_expression(&mut self, block_expr: &BlockExpr) -> SemanticResult {
298        self.context.begin_scope();
299
300        // Process all statements in the block
301        for stmt in &block_expr.statements {
302            // Create a statement visitor with the current return type context
303            let mut stmt_visitor = super::statement_visitor::StatementVisitor::with_return_type(
304                self.context,
305                self.current_return_type.clone(),
306            );
307            match stmt {
308                Statement::Let(let_stmt) => {
309                    stmt_visitor.visit_let_statement(let_stmt)?;
310                }
311                Statement::Assignment(assign_stmt) => {
312                    stmt_visitor.visit_assignment_statement(assign_stmt)?;
313                }
314                Statement::Expression(expr) => {
315                    self.visit_expression(expr)?;
316                }
317                Statement::If(if_stmt) => {
318                    stmt_visitor.visit_if_statement(if_stmt)?;
319                }
320                Statement::Return(return_stmt) => {
321                    stmt_visitor.visit_return_statement(return_stmt)?;
322                }
323                Statement::FunctionDeclaration(fn_decl) => {
324                    stmt_visitor.visit_function_declaration(fn_decl)?;
325                }
326                Statement::TypeDefinition(type_def) => {
327                    stmt_visitor.visit_type_definition_statement(type_def)?;
328                }
329            }
330        }
331
332        let block_type = if let Some(return_expr) = &block_expr.return_expr {
333            self.visit_expression(return_expr)?
334        } else {
335            TypeId::unit()
336        };
337
338        self.context.end_scope();
339
340        Ok(block_type)
341    }
342
343    /// Visit a function type expression
344    pub fn visit_function_type_expression(
345        &mut self,
346        func_type_expr: &FunctionTypeExpr,
347    ) -> SemanticResult {
348        // Validate all parameter types exist
349        for param_type in &func_type_expr.param_types {
350            if self.context.get_type_info(param_type).is_none() {
351                return Err(SemanticAnalysisError::InvalidFieldType {
352                    struct_name: "function type".to_string(),
353                    field_name: "parameter".to_string(),
354                    type_id: param_type.clone(),
355                    location: func_type_expr.location,
356                });
357            }
358        }
359
360        // Validate return type exists
361        if self
362            .context
363            .get_type_info(&func_type_expr.return_type)
364            .is_none()
365        {
366            return Err(SemanticAnalysisError::InvalidFieldType {
367                struct_name: "function type".to_string(),
368                field_name: "return type".to_string(),
369                type_id: func_type_expr.return_type.clone(),
370                location: func_type_expr.location,
371            });
372        }
373
374        Ok(func_type_expr.expr_type.clone())
375    }
376
377    // Helper methods
378
379    /// Resolve a symbol that can be used as a value (variables and functions)
380    fn resolve_value(&self, name: &str) -> Option<&slang_shared::Symbol> {
381        self.context
382            .lookup_symbol(name)
383            .filter(|symbol| matches!(symbol.kind(), SymbolKind::Variable | SymbolKind::Function))
384    }
385}