slang_frontend/semantic_analysis/visitors/statement_visitor.rs

1use slang_ir::Location;
2use slang_ir::ast::*;
3use slang_shared::{CompilationContext, SymbolKind};
4use slang_types::{TYPE_NAME_U32, TYPE_NAME_U64, TypeId};
5
6use super::super::{error::SemanticAnalysisError, traits::SemanticResult, type_system};
7use super::expression_visitor::ExpressionVisitor;
8
9/// Handles semantic analysis for all statement types
10///
11/// This visitor is responsible for analyzing statement-level constructs
12/// including function declarations, variable declarations, assignments,
13/// return statements, type definitions, and control flow statements.
14pub struct StatementVisitor<'a> {
15    context: &'a mut CompilationContext,
16    current_return_type: Option<TypeId>,
17}
18
19impl<'a> StatementVisitor<'a> {
20    /// Create a new statement visitor
21    ///
22    /// # Arguments
23    /// * `context` - The compilation context for symbol management
24    pub fn new(context: &'a mut CompilationContext) -> Self {
25        Self {
26            context,
27            current_return_type: None,
28        }
29    }
30
31    /// Create a new statement visitor with an inherited return type context
32    ///
33    /// # Arguments
34    /// * `context` - The compilation context for symbol management
35    /// * `return_type` - The current function's return type for context inheritance
36    pub fn with_return_type(
37        context: &'a mut CompilationContext,
38        return_type: Option<TypeId>,
39    ) -> Self {
40        Self {
41            context,
42            current_return_type: return_type,
43        }
44    }
45
46    /// Set the current return type for function analysis
47    ///
48    /// This is used when analyzing function bodies to validate return statements
49    pub fn set_return_type(&mut self, return_type: Option<TypeId>) {
50        self.current_return_type = return_type;
51    }
52
53    /// Get the current return type
54    pub fn current_return_type(&self) -> Option<&TypeId> {
55        self.current_return_type.as_ref()
56    }
57
58    /// Visit a function declaration statement
59    pub fn visit_function_declaration(
60        &mut self,
61        fn_decl: &FunctionDeclarationStmt,
62    ) -> SemanticResult {
63        let mut param_types = Vec::new();
64        for param in &fn_decl.parameters {
65            param_types.push(param.param_type.clone());
66        }
67
68        let function_type_id = self
69            .context
70            .register_function_type(param_types.clone(), fn_decl.return_type.clone());
71
72        if self
73            .context
74            .define_symbol(
75                fn_decl.name.clone(),
76                SymbolKind::Function,
77                function_type_id,
78                false,
79            )
80            .is_err()
81        {
82            return Err(SemanticAnalysisError::SymbolRedefinition {
83                name: fn_decl.name.clone(),
84                kind: "function".to_string(),
85                location: fn_decl.location,
86            });
87        }
88
89        let previous_return_type = self.current_return_type.clone();
90        self.current_return_type = Some(fn_decl.return_type.clone());
91
92        self.context.begin_scope();
93        for param in &fn_decl.parameters {
94            if self
95                .context
96                .define_symbol(
97                    param.name.clone(),
98                    SymbolKind::Variable,
99                    param.param_type.clone(),
100                    true,
101                )
102                .is_err()
103            {
104                return Err(SemanticAnalysisError::SymbolRedefinition {
105                    name: param.name.clone(),
106                    kind: "parameter".to_string(),
107                    location: fn_decl.location,
108                });
109            }
110        }
111
112        // For now, we'll need to handle block expression analysis differently
113        // This will be resolved when we integrate with expression visitor
114        let result = self.analyze_function_body(&fn_decl.body);
115
116        self.current_return_type = previous_return_type;
117        self.context.end_scope();
118
119        result.and(Ok(fn_decl.return_type.clone()))
120    }
121
122    /// Visit a return statement
123    pub fn visit_return_statement(&mut self, return_stmt: &ReturnStatement) -> SemanticResult {
124        let error_location = match &return_stmt.value {
125            Some(expr) => expr.location(),
126            None => return_stmt.location,
127        };
128
129        if let Some(expected_type) = &self.current_return_type {
130            let expected_type = expected_type.clone();
131            if let Some(expr) = &return_stmt.value {
132                return self.check_return_expr_type_internal(
133                    expr,
134                    &expected_type,
135                    &expr.location(),
136                );
137            } else if expected_type != TypeId::unknown() && expected_type != TypeId::unit() {
138                return Err(SemanticAnalysisError::MissingReturnValue {
139                    expected: expected_type.clone(),
140                    location: error_location,
141                });
142            }
143
144            // Empty return is treated as returning unit
145            Ok(TypeId::unit())
146        } else {
147            Err(SemanticAnalysisError::ReturnOutsideFunction {
148                location: error_location,
149            })
150        }
151    }
152
153    /// Visit a let statement
154    pub fn visit_let_statement(&mut self, let_stmt: &LetStatement) -> SemanticResult {
155        // Check for negative values assigned to unsigned types
156        if let Expression::Unary(unary_expr) = &let_stmt.value {
157            if unary_expr.operator == UnaryOperator::Negate {
158                if let Expression::Literal(lit) = &*unary_expr.right {
159                    if let LiteralValue::UnspecifiedInteger(n) = &lit.value {
160                        if self.context.get_type_name(&let_stmt.expr_type) == TYPE_NAME_U32
161                            || self.context.get_type_name(&let_stmt.expr_type) == TYPE_NAME_U64
162                        {
163                            let negative_value = -n;
164                            return Err(SemanticAnalysisError::ValueOutOfRange {
165                                value: negative_value.to_string(),
166                                target_type: let_stmt.expr_type.clone(),
167                                is_float: false,
168                                location: let_stmt.location,
169                            });
170                        }
171                    }
172                }
173            }
174        }
175
176        // Check for symbol conflicts
177        if let Some(symbol) = self.context.lookup_symbol(&let_stmt.name) {
178            if symbol.kind() == SymbolKind::Type {
179                return Err(SemanticAnalysisError::SymbolRedefinition {
180                    name: let_stmt.name.clone(),
181                    kind: "variable (conflicts with type)".to_string(),
182                    location: let_stmt.location,
183                });
184            } else if symbol.kind() == SymbolKind::Function {
185                return Err(SemanticAnalysisError::SymbolRedefinition {
186                    name: let_stmt.name.clone(),
187                    kind: "variable (conflicts with function)".to_string(),
188                    location: let_stmt.location,
189                });
190            }
191        }
192
193        // TODO: This will need to be updated to use expression visitor
194        let expr_type = self.visit_expression(&let_stmt.value)?;
195        let final_type = self.determine_let_statement_type(let_stmt, expr_type)?;
196        let final_type = type_system::finalize_inferred_type(final_type);
197
198        if self
199            .context
200            .define_symbol(
201                let_stmt.name.clone(),
202                SymbolKind::Variable,
203                final_type.clone(),
204                let_stmt.is_mutable,
205            )
206            .is_err()
207        {
208            return Err(SemanticAnalysisError::VariableRedefinition {
209                name: let_stmt.name.clone(),
210                location: let_stmt.location,
211            });
212        }
213
214        Ok(final_type)
215    }
216
217    /// Visit an assignment statement
218    pub fn visit_assignment_statement(
219        &mut self,
220        assign_stmt: &AssignmentStatement,
221    ) -> SemanticResult {
222        // First check if variable exists and get its type and mutability
223        let (var_type_id, is_mutable) =
224            if let Some(var_info) = self.resolve_variable(&assign_stmt.name) {
225                (var_info.type_id.clone(), var_info.is_mutable())
226            } else {
227                return Err(SemanticAnalysisError::UndefinedVariable {
228                    name: assign_stmt.name.clone(),
229                    location: assign_stmt.location,
230                });
231            };
232
233        // Check mutability
234        if !is_mutable {
235            return Err(SemanticAnalysisError::AssignmentToImmutableVariable {
236                name: assign_stmt.name.clone(),
237                location: assign_stmt.location,
238            });
239        }
240
241        // TODO: This will need to be updated to use expression visitor
242        let expr_type = self.visit_expression(&assign_stmt.value)?;
243
244        if var_type_id == expr_type
245            || expr_type == TypeId::unspecified_int()
246            || expr_type == TypeId::unspecified_float()
247        {
248            Ok(var_type_id)
249        } else {
250            Err(SemanticAnalysisError::TypeMismatch {
251                expected: var_type_id,
252                actual: expr_type,
253                context: Some(format!("assignment to variable '{}'", assign_stmt.name)),
254                location: assign_stmt.location,
255            })
256        }
257    }
258
259    /// Visit a type definition statement
260    pub fn visit_type_definition_statement(
261        &mut self,
262        type_def: &TypeDefinitionStmt,
263    ) -> SemanticResult {
264        if self.context.lookup_symbol(&type_def.name).is_some() {
265            return Err(SemanticAnalysisError::SymbolRedefinition {
266                name: type_def.name.clone(),
267                kind: "type".to_string(),
268                location: type_def.location,
269            });
270        }
271
272        let mut field_types_for_registration = Vec::new();
273        for (name, type_id) in &type_def.fields {
274            if *type_id == TypeId::unknown()
275                || *type_id == TypeId::unspecified_int()
276                || *type_id == TypeId::unspecified_float()
277            {
278                return Err(SemanticAnalysisError::InvalidFieldType {
279                    struct_name: type_def.name.clone(),
280                    field_name: name.clone(),
281                    type_id: type_id.clone(),
282                    location: type_def.location,
283                });
284            }
285            field_types_for_registration.push((name.clone(), type_id.clone()));
286        }
287
288        match self
289            .context
290            .register_struct_type(type_def.name.clone(), field_types_for_registration)
291        {
292            Ok(type_id) => Ok(type_id),
293            Err(_) => Err(SemanticAnalysisError::SymbolRedefinition {
294                name: type_def.name.clone(),
295                kind: "type".to_string(),
296                location: type_def.location,
297            }),
298        }
299    }
300
301    /// Visit an expression statement
302    pub fn visit_expression_statement(&mut self, expr: &Expression) -> SemanticResult {
303        self.visit_expression(expr)
304    }
305
306    /// Visit an if statement
307    pub fn visit_if_statement(&mut self, if_stmt: &IfStatement) -> SemanticResult {
308        let condition_type = self.visit_expression(&if_stmt.condition)?;
309        if condition_type != TypeId::bool() {
310            return Err(SemanticAnalysisError::TypeMismatch {
311                expected: TypeId::bool(),
312                actual: condition_type,
313                context: Some("if condition".to_string()),
314                location: if_stmt.condition.location(),
315            });
316        }
317
318        self.visit_block_expression(&if_stmt.then_branch)?;
319
320        if let Some(else_branch) = &if_stmt.else_branch {
321            self.visit_block_expression(else_branch)?;
322        }
323
324        Ok(TypeId::unit())
325    }
326
327    // Helper methods that will be replaced when integrating with expression visitor
328
329    fn resolve_variable(&self, name: &str) -> Option<&slang_shared::Symbol> {
330        self.context
331            .lookup_symbol(name)
332            .filter(|symbol| symbol.kind() == SymbolKind::Variable)
333    }
334
335    fn check_return_expr_type_internal(
336        &mut self,
337        expr: &Expression,
338        expected_type: &TypeId,
339        location: &Location,
340    ) -> SemanticResult {
341        let actual_type = self.visit_expression(expr)?;
342
343        if actual_type == *expected_type {
344            return Ok(actual_type);
345        }
346
347        // Handle coercion of unspecified int to specific integer types
348        if actual_type == TypeId::unspecified_int() {
349            if type_system::is_integer_type(self.context, expected_type) {
350                return type_system::check_unspecified_int_for_type(
351                    self.context,
352                    expr,
353                    expected_type,
354                );
355            }
356        }
357
358        // Handle coercion of unspecified float to specific float types
359        if actual_type == TypeId::unspecified_float() {
360            if type_system::is_float_type(self.context, expected_type) {
361                return type_system::check_unspecified_float_for_type(
362                    self.context,
363                    expr,
364                    expected_type,
365                );
366            }
367        }
368
369        Err(SemanticAnalysisError::ReturnTypeMismatch {
370            expected: expected_type.clone(),
371            actual: actual_type,
372            location: *location,
373        })
374    }
375
376    fn determine_let_statement_type(
377        &mut self,
378        let_stmt: &LetStatement,
379        expr_type: TypeId,
380    ) -> SemanticResult {
381        type_system::determine_let_statement_type(self.context, let_stmt, expr_type)
382    }
383
384    // Helper methods for expression analysis
385
386    fn visit_expression(&mut self, expr: &Expression) -> SemanticResult {
387        // Create a new expression visitor with the current return type context
388        let mut expr_visitor =
389            ExpressionVisitor::with_return_type(self.context, self.current_return_type.clone());
390        expr_visitor.visit_expression(expr)
391    }
392
393    fn visit_block_expression(&mut self, block: &BlockExpr) -> SemanticResult {
394        // Create a new expression visitor with the current return type context
395        let mut expr_visitor =
396            ExpressionVisitor::with_return_type(self.context, self.current_return_type.clone());
397        expr_visitor.visit_block_expression(block)
398    }
399
400    fn analyze_function_body(&mut self, body: &BlockExpr) -> SemanticResult {
401        self.visit_block_expression(body)
402    }
403}