slang_frontend/semantic_analysis/visitors/
expression_visitor.rs1use slang_ir::ast::*;
2use slang_shared::{CompilationContext, SymbolKind};
3use slang_types::TypeId;
4
5use super::super::{
6 error::SemanticAnalysisError, operations, traits::SemanticResult,
7 validation::TypeCheckingCoordinator,
8};
9
10pub struct ExpressionVisitor<'a> {
16 context: &'a mut CompilationContext,
17 current_return_type: Option<TypeId>,
18}
19
20impl<'a> ExpressionVisitor<'a> {
21 pub fn new(context: &'a mut CompilationContext) -> Self {
26 Self {
27 context,
28 current_return_type: None,
29 }
30 }
31
32 pub fn with_return_type(
38 context: &'a mut CompilationContext,
39 current_return_type: Option<TypeId>,
40 ) -> Self {
41 Self {
42 context,
43 current_return_type,
44 }
45 }
46
47 pub fn set_return_type(&mut self, return_type: Option<TypeId>) {
49 self.current_return_type = return_type;
50 }
51
52 fn create_type_coordinator(&self) -> TypeCheckingCoordinator {
54 TypeCheckingCoordinator::new(self.context)
55 }
56
57 pub fn visit_expression(&mut self, expr: &Expression) -> SemanticResult {
59 match expr {
60 Expression::Binary(bin_expr) => self.visit_binary_expression(bin_expr),
61 Expression::Unary(unary_expr) => self.visit_unary_expression(unary_expr),
62 Expression::Call(call_expr) => self.visit_call_expression(call_expr),
63 Expression::Variable(var_expr) => self.visit_variable_expression(var_expr),
64 Expression::Literal(lit_expr) => self.visit_literal_expression(lit_expr),
65 Expression::Conditional(cond_expr) => self.visit_conditional_expression(cond_expr),
66 Expression::Block(block_expr) => self.visit_block_expression(block_expr),
67 Expression::FunctionType(func_type_expr) => {
68 self.visit_function_type_expression(func_type_expr)
69 }
70 }
71 }
72
73 pub fn visit_binary_expression(&mut self, bin_expr: &BinaryExpr) -> SemanticResult {
75 let left_type = self.visit_expression(&bin_expr.left)?;
76 let right_type = self.visit_expression(&bin_expr.right)?;
77
78 if bin_expr.operator == BinaryOperator::And || bin_expr.operator == BinaryOperator::Or {
80 return operations::check_logical_operation(
81 &left_type,
82 &right_type,
83 &bin_expr.operator,
84 &bin_expr.location,
85 );
86 }
87
88 if matches!(
90 bin_expr.operator,
91 BinaryOperator::GreaterThan
92 | BinaryOperator::LessThan
93 | BinaryOperator::GreaterThanOrEqual
94 | BinaryOperator::LessThanOrEqual
95 | BinaryOperator::Equal
96 | BinaryOperator::NotEqual
97 ) {
98 return operations::check_relational_operation(
99 self.context,
100 &left_type,
101 &right_type,
102 &bin_expr.operator,
103 &bin_expr.location,
104 );
105 }
106
107 if matches!(
109 bin_expr.operator,
110 BinaryOperator::Add
111 | BinaryOperator::Subtract
112 | BinaryOperator::Multiply
113 | BinaryOperator::Divide
114 ) {
115 if left_type == right_type {
116 return operations::check_same_type_arithmetic(
117 self.context,
118 &left_type,
119 &bin_expr.operator,
120 &bin_expr.location,
121 );
122 }
123
124 let coordinator = self.create_type_coordinator();
126 return coordinator.check_mixed_arithmetic_with_coercion(
127 &left_type,
128 &right_type,
129 bin_expr,
130 );
131 }
132
133 Err(SemanticAnalysisError::OperationTypeMismatch {
134 operator: bin_expr.operator.to_string(),
135 left_type: left_type.clone(),
136 right_type: right_type.clone(),
137 location: bin_expr.location,
138 })
139 }
140
141 pub fn visit_unary_expression(&mut self, unary_expr: &UnaryExpr) -> SemanticResult {
143 let operand_type = self.visit_expression(&unary_expr.right)?;
144
145 operations::unary::check_unary_operation(self.context, unary_expr, &operand_type)
146 }
147
148 pub fn visit_call_expression(&mut self, call_expr: &FunctionCallExpr) -> SemanticResult {
150 let function_type = if let Some(symbol) = self.context.lookup_symbol(&call_expr.name) {
151 match symbol.kind() {
152 SymbolKind::Function => {
153 if self.context.is_function_type(&symbol.type_id) {
154 self.context.get_function_type(&symbol.type_id).cloned()
155 } else {
156 return Err(SemanticAnalysisError::UndefinedFunction {
157 name: call_expr.name.clone(),
158 location: call_expr.location,
159 });
160 }
161 }
162 SymbolKind::Variable => {
163 if self.context.is_function_type(&symbol.type_id) {
164 self.context.get_function_type(&symbol.type_id).cloned()
165 } else {
166 return Err(SemanticAnalysisError::VariableNotCallable {
167 variable_name: call_expr.name.clone(),
168 variable_type: symbol.type_id.clone(),
169 location: call_expr.location,
170 });
171 }
172 }
173 _ => {
174 return Err(SemanticAnalysisError::UndefinedFunction {
175 name: call_expr.name.clone(),
176 location: call_expr.location,
177 });
178 }
179 }
180 } else {
181 return Err(SemanticAnalysisError::UndefinedFunction {
182 name: call_expr.name.clone(),
183 location: call_expr.location,
184 });
185 };
186
187 if let Some(func_type) = function_type {
188 if func_type.param_types.len() != call_expr.arguments.len() {
190 return Err(SemanticAnalysisError::ArgumentCountMismatch {
191 function_name: call_expr.name.clone(),
192 expected: func_type.param_types.len(),
193 actual: call_expr.arguments.len(),
194 location: call_expr.location,
195 });
196 }
197
198 for (i, arg) in call_expr.arguments.iter().enumerate() {
200 let param_type = func_type.param_types[i].clone();
201 let arg_type = self.visit_expression(arg)?;
202
203 if param_type == TypeId::unknown() {
204 continue;
205 }
206
207 let coordinator = self.create_type_coordinator();
209 if !coordinator.check_assignment_compatibility(¶m_type, &arg_type) {
210 if arg_type == TypeId::unspecified_int()
212 || arg_type == TypeId::unspecified_float()
213 {
214 if coordinator
215 .validate_literal_range(arg, ¶m_type)
216 .is_err()
217 {
218 return Err(SemanticAnalysisError::ArgumentTypeMismatch {
219 function_name: call_expr.name.clone(),
220 argument_position: i + 1,
221 expected: param_type.clone(),
222 actual: arg_type,
223 location: arg.location(),
224 });
225 }
226 } else {
227 return Err(SemanticAnalysisError::ArgumentTypeMismatch {
229 function_name: call_expr.name.clone(),
230 argument_position: i + 1,
231 expected: param_type.clone(),
232 actual: arg_type,
233 location: arg.location(),
234 });
235 }
236 }
237 }
238
239 Ok(func_type.return_type.clone())
240 } else {
241 Err(SemanticAnalysisError::UndefinedFunction {
242 name: call_expr.name.clone(),
243 location: call_expr.location,
244 })
245 }
246 }
247
248 pub fn visit_variable_expression(&mut self, var_expr: &VariableExpr) -> SemanticResult {
250 if let Some(var_info) = self.resolve_value(&var_expr.name) {
251 Ok(var_info.type_id.clone())
252 } else {
253 Err(SemanticAnalysisError::UndefinedVariable {
254 name: var_expr.name.clone(),
255 location: var_expr.location,
256 })
257 }
258 }
259
260 pub fn visit_literal_expression(&mut self, literal_expr: &LiteralExpr) -> SemanticResult {
262 Ok(literal_expr.expr_type.clone())
263 }
264
265 pub fn visit_conditional_expression(&mut self, cond_expr: &ConditionalExpr) -> SemanticResult {
267 let condition_type = self.visit_expression(&cond_expr.condition)?;
268 if condition_type != TypeId::bool() {
269 return Err(SemanticAnalysisError::TypeMismatch {
270 expected: TypeId::bool(),
271 actual: condition_type,
272 context: Some("if condition".to_string()),
273 location: cond_expr.condition.location(),
274 });
275 }
276
277 let then_type = self.visit_expression(&cond_expr.then_branch)?;
278 let else_type = self.visit_expression(&cond_expr.else_branch)?;
279
280 if then_type == TypeId::unknown() {
281 Ok(else_type)
282 } else if else_type == TypeId::unknown() || then_type == else_type {
283 Ok(then_type)
284 } else {
285 Err(SemanticAnalysisError::TypeMismatch {
286 expected: then_type,
287 actual: else_type,
288 context: Some(
289 "conditional expression branches must have the same type".to_string(),
290 ),
291 location: cond_expr.location,
292 })
293 }
294 }
295
296 pub fn visit_block_expression(&mut self, block_expr: &BlockExpr) -> SemanticResult {
298 self.context.begin_scope();
299
300 for stmt in &block_expr.statements {
302 let mut stmt_visitor = super::statement_visitor::StatementVisitor::with_return_type(
304 self.context,
305 self.current_return_type.clone(),
306 );
307 match stmt {
308 Statement::Let(let_stmt) => {
309 stmt_visitor.visit_let_statement(let_stmt)?;
310 }
311 Statement::Assignment(assign_stmt) => {
312 stmt_visitor.visit_assignment_statement(assign_stmt)?;
313 }
314 Statement::Expression(expr) => {
315 self.visit_expression(expr)?;
316 }
317 Statement::If(if_stmt) => {
318 stmt_visitor.visit_if_statement(if_stmt)?;
319 }
320 Statement::Return(return_stmt) => {
321 stmt_visitor.visit_return_statement(return_stmt)?;
322 }
323 Statement::FunctionDeclaration(fn_decl) => {
324 stmt_visitor.visit_function_declaration(fn_decl)?;
325 }
326 Statement::TypeDefinition(type_def) => {
327 stmt_visitor.visit_type_definition_statement(type_def)?;
328 }
329 }
330 }
331
332 let block_type = if let Some(return_expr) = &block_expr.return_expr {
333 self.visit_expression(return_expr)?
334 } else {
335 TypeId::unit()
336 };
337
338 self.context.end_scope();
339
340 Ok(block_type)
341 }
342
343 pub fn visit_function_type_expression(
345 &mut self,
346 func_type_expr: &FunctionTypeExpr,
347 ) -> SemanticResult {
348 for param_type in &func_type_expr.param_types {
350 if self.context.get_type_info(param_type).is_none() {
351 return Err(SemanticAnalysisError::InvalidFieldType {
352 struct_name: "function type".to_string(),
353 field_name: "parameter".to_string(),
354 type_id: param_type.clone(),
355 location: func_type_expr.location,
356 });
357 }
358 }
359
360 if self
362 .context
363 .get_type_info(&func_type_expr.return_type)
364 .is_none()
365 {
366 return Err(SemanticAnalysisError::InvalidFieldType {
367 struct_name: "function type".to_string(),
368 field_name: "return type".to_string(),
369 type_id: func_type_expr.return_type.clone(),
370 location: func_type_expr.location,
371 });
372 }
373
374 Ok(func_type_expr.expr_type.clone())
375 }
376
377 fn resolve_value(&self, name: &str) -> Option<&slang_shared::Symbol> {
381 self.context
382 .lookup_symbol(name)
383 .filter(|symbol| matches!(symbol.kind(), SymbolKind::Variable | SymbolKind::Function))
384 }
385}