slang_frontend/semantic_analysis/visitors/
statement_visitor.rs1use slang_ir::Location;
2use slang_ir::ast::*;
3use slang_shared::{CompilationContext, SymbolKind};
4use slang_types::{TYPE_NAME_U32, TYPE_NAME_U64, TypeId};
5
6use super::super::{error::SemanticAnalysisError, traits::SemanticResult, type_system};
7use super::expression_visitor::ExpressionVisitor;
8
9pub struct StatementVisitor<'a> {
15 context: &'a mut CompilationContext,
16 current_return_type: Option<TypeId>,
17}
18
19impl<'a> StatementVisitor<'a> {
20 pub fn new(context: &'a mut CompilationContext) -> Self {
25 Self {
26 context,
27 current_return_type: None,
28 }
29 }
30
31 pub fn with_return_type(
37 context: &'a mut CompilationContext,
38 return_type: Option<TypeId>,
39 ) -> Self {
40 Self {
41 context,
42 current_return_type: return_type,
43 }
44 }
45
46 pub fn set_return_type(&mut self, return_type: Option<TypeId>) {
50 self.current_return_type = return_type;
51 }
52
53 pub fn current_return_type(&self) -> Option<&TypeId> {
55 self.current_return_type.as_ref()
56 }
57
58 pub fn visit_function_declaration(
60 &mut self,
61 fn_decl: &FunctionDeclarationStmt,
62 ) -> SemanticResult {
63 let mut param_types = Vec::new();
64 for param in &fn_decl.parameters {
65 param_types.push(param.param_type.clone());
66 }
67
68 let function_type_id = self
69 .context
70 .register_function_type(param_types.clone(), fn_decl.return_type.clone());
71
72 if self
73 .context
74 .define_symbol(
75 fn_decl.name.clone(),
76 SymbolKind::Function,
77 function_type_id,
78 false,
79 )
80 .is_err()
81 {
82 return Err(SemanticAnalysisError::SymbolRedefinition {
83 name: fn_decl.name.clone(),
84 kind: "function".to_string(),
85 location: fn_decl.location,
86 });
87 }
88
89 let previous_return_type = self.current_return_type.clone();
90 self.current_return_type = Some(fn_decl.return_type.clone());
91
92 self.context.begin_scope();
93 for param in &fn_decl.parameters {
94 if self
95 .context
96 .define_symbol(
97 param.name.clone(),
98 SymbolKind::Variable,
99 param.param_type.clone(),
100 true,
101 )
102 .is_err()
103 {
104 return Err(SemanticAnalysisError::SymbolRedefinition {
105 name: param.name.clone(),
106 kind: "parameter".to_string(),
107 location: fn_decl.location,
108 });
109 }
110 }
111
112 let result = self.analyze_function_body(&fn_decl.body);
115
116 self.current_return_type = previous_return_type;
117 self.context.end_scope();
118
119 result.and(Ok(fn_decl.return_type.clone()))
120 }
121
122 pub fn visit_return_statement(&mut self, return_stmt: &ReturnStatement) -> SemanticResult {
124 let error_location = match &return_stmt.value {
125 Some(expr) => expr.location(),
126 None => return_stmt.location,
127 };
128
129 if let Some(expected_type) = &self.current_return_type {
130 let expected_type = expected_type.clone();
131 if let Some(expr) = &return_stmt.value {
132 return self.check_return_expr_type_internal(
133 expr,
134 &expected_type,
135 &expr.location(),
136 );
137 } else if expected_type != TypeId::unknown() && expected_type != TypeId::unit() {
138 return Err(SemanticAnalysisError::MissingReturnValue {
139 expected: expected_type.clone(),
140 location: error_location,
141 });
142 }
143
144 Ok(TypeId::unit())
146 } else {
147 Err(SemanticAnalysisError::ReturnOutsideFunction {
148 location: error_location,
149 })
150 }
151 }
152
153 pub fn visit_let_statement(&mut self, let_stmt: &LetStatement) -> SemanticResult {
155 if let Expression::Unary(unary_expr) = &let_stmt.value {
157 if unary_expr.operator == UnaryOperator::Negate {
158 if let Expression::Literal(lit) = &*unary_expr.right {
159 if let LiteralValue::UnspecifiedInteger(n) = &lit.value {
160 if self.context.get_type_name(&let_stmt.expr_type) == TYPE_NAME_U32
161 || self.context.get_type_name(&let_stmt.expr_type) == TYPE_NAME_U64
162 {
163 let negative_value = -n;
164 return Err(SemanticAnalysisError::ValueOutOfRange {
165 value: negative_value.to_string(),
166 target_type: let_stmt.expr_type.clone(),
167 is_float: false,
168 location: let_stmt.location,
169 });
170 }
171 }
172 }
173 }
174 }
175
176 if let Some(symbol) = self.context.lookup_symbol(&let_stmt.name) {
178 if symbol.kind() == SymbolKind::Type {
179 return Err(SemanticAnalysisError::SymbolRedefinition {
180 name: let_stmt.name.clone(),
181 kind: "variable (conflicts with type)".to_string(),
182 location: let_stmt.location,
183 });
184 } else if symbol.kind() == SymbolKind::Function {
185 return Err(SemanticAnalysisError::SymbolRedefinition {
186 name: let_stmt.name.clone(),
187 kind: "variable (conflicts with function)".to_string(),
188 location: let_stmt.location,
189 });
190 }
191 }
192
193 let expr_type = self.visit_expression(&let_stmt.value)?;
195 let final_type = self.determine_let_statement_type(let_stmt, expr_type)?;
196 let final_type = type_system::finalize_inferred_type(final_type);
197
198 if self
199 .context
200 .define_symbol(
201 let_stmt.name.clone(),
202 SymbolKind::Variable,
203 final_type.clone(),
204 let_stmt.is_mutable,
205 )
206 .is_err()
207 {
208 return Err(SemanticAnalysisError::VariableRedefinition {
209 name: let_stmt.name.clone(),
210 location: let_stmt.location,
211 });
212 }
213
214 Ok(final_type)
215 }
216
217 pub fn visit_assignment_statement(
219 &mut self,
220 assign_stmt: &AssignmentStatement,
221 ) -> SemanticResult {
222 let (var_type_id, is_mutable) =
224 if let Some(var_info) = self.resolve_variable(&assign_stmt.name) {
225 (var_info.type_id.clone(), var_info.is_mutable())
226 } else {
227 return Err(SemanticAnalysisError::UndefinedVariable {
228 name: assign_stmt.name.clone(),
229 location: assign_stmt.location,
230 });
231 };
232
233 if !is_mutable {
235 return Err(SemanticAnalysisError::AssignmentToImmutableVariable {
236 name: assign_stmt.name.clone(),
237 location: assign_stmt.location,
238 });
239 }
240
241 let expr_type = self.visit_expression(&assign_stmt.value)?;
243
244 if var_type_id == expr_type
245 || expr_type == TypeId::unspecified_int()
246 || expr_type == TypeId::unspecified_float()
247 {
248 Ok(var_type_id)
249 } else {
250 Err(SemanticAnalysisError::TypeMismatch {
251 expected: var_type_id,
252 actual: expr_type,
253 context: Some(format!("assignment to variable '{}'", assign_stmt.name)),
254 location: assign_stmt.location,
255 })
256 }
257 }
258
259 pub fn visit_type_definition_statement(
261 &mut self,
262 type_def: &TypeDefinitionStmt,
263 ) -> SemanticResult {
264 if self.context.lookup_symbol(&type_def.name).is_some() {
265 return Err(SemanticAnalysisError::SymbolRedefinition {
266 name: type_def.name.clone(),
267 kind: "type".to_string(),
268 location: type_def.location,
269 });
270 }
271
272 let mut field_types_for_registration = Vec::new();
273 for (name, type_id) in &type_def.fields {
274 if *type_id == TypeId::unknown()
275 || *type_id == TypeId::unspecified_int()
276 || *type_id == TypeId::unspecified_float()
277 {
278 return Err(SemanticAnalysisError::InvalidFieldType {
279 struct_name: type_def.name.clone(),
280 field_name: name.clone(),
281 type_id: type_id.clone(),
282 location: type_def.location,
283 });
284 }
285 field_types_for_registration.push((name.clone(), type_id.clone()));
286 }
287
288 match self
289 .context
290 .register_struct_type(type_def.name.clone(), field_types_for_registration)
291 {
292 Ok(type_id) => Ok(type_id),
293 Err(_) => Err(SemanticAnalysisError::SymbolRedefinition {
294 name: type_def.name.clone(),
295 kind: "type".to_string(),
296 location: type_def.location,
297 }),
298 }
299 }
300
301 pub fn visit_expression_statement(&mut self, expr: &Expression) -> SemanticResult {
303 self.visit_expression(expr)
304 }
305
306 pub fn visit_if_statement(&mut self, if_stmt: &IfStatement) -> SemanticResult {
308 let condition_type = self.visit_expression(&if_stmt.condition)?;
309 if condition_type != TypeId::bool() {
310 return Err(SemanticAnalysisError::TypeMismatch {
311 expected: TypeId::bool(),
312 actual: condition_type,
313 context: Some("if condition".to_string()),
314 location: if_stmt.condition.location(),
315 });
316 }
317
318 self.visit_block_expression(&if_stmt.then_branch)?;
319
320 if let Some(else_branch) = &if_stmt.else_branch {
321 self.visit_block_expression(else_branch)?;
322 }
323
324 Ok(TypeId::unit())
325 }
326
327 fn resolve_variable(&self, name: &str) -> Option<&slang_shared::Symbol> {
330 self.context
331 .lookup_symbol(name)
332 .filter(|symbol| symbol.kind() == SymbolKind::Variable)
333 }
334
335 fn check_return_expr_type_internal(
336 &mut self,
337 expr: &Expression,
338 expected_type: &TypeId,
339 location: &Location,
340 ) -> SemanticResult {
341 let actual_type = self.visit_expression(expr)?;
342
343 if actual_type == *expected_type {
344 return Ok(actual_type);
345 }
346
347 if actual_type == TypeId::unspecified_int() {
349 if type_system::is_integer_type(self.context, expected_type) {
350 return type_system::check_unspecified_int_for_type(
351 self.context,
352 expr,
353 expected_type,
354 );
355 }
356 }
357
358 if actual_type == TypeId::unspecified_float() {
360 if type_system::is_float_type(self.context, expected_type) {
361 return type_system::check_unspecified_float_for_type(
362 self.context,
363 expr,
364 expected_type,
365 );
366 }
367 }
368
369 Err(SemanticAnalysisError::ReturnTypeMismatch {
370 expected: expected_type.clone(),
371 actual: actual_type,
372 location: *location,
373 })
374 }
375
376 fn determine_let_statement_type(
377 &mut self,
378 let_stmt: &LetStatement,
379 expr_type: TypeId,
380 ) -> SemanticResult {
381 type_system::determine_let_statement_type(self.context, let_stmt, expr_type)
382 }
383
384 fn visit_expression(&mut self, expr: &Expression) -> SemanticResult {
387 let mut expr_visitor =
389 ExpressionVisitor::with_return_type(self.context, self.current_return_type.clone());
390 expr_visitor.visit_expression(expr)
391 }
392
393 fn visit_block_expression(&mut self, block: &BlockExpr) -> SemanticResult {
394 let mut expr_visitor =
396 ExpressionVisitor::with_return_type(self.context, self.current_return_type.clone());
397 expr_visitor.visit_block_expression(block)
398 }
399
400 fn analyze_function_body(&mut self, body: &BlockExpr) -> SemanticResult {
401 self.visit_block_expression(body)
402 }
403}