slang_ir/
ast.rs

1use crate::Location;
2use crate::Visitor;
3use slang_types::types::TypeId;
4use std::fmt::Display;
5
/// Binary operators that can appear between two operand expressions.
///
/// The type is a plain field-less enum, so it is cheap to copy, can be
/// compared for equality, and can be used as a hash-map key.
///
/// Its `Display` implementation renders the operator's textual symbol
/// (for example `Add` is written as `+`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// Addition operator (`+`)
    Add,
    /// Subtraction operator (`-`)
    Subtract,
    /// Multiplication operator (`*`)
    Multiply,
    /// Division operator (`/`)
    Divide,
    /// Greater than operator (`>`)
    GreaterThan,
    /// Less than operator (`<`)
    LessThan,
    /// Greater than or equal to operator (`>=`)
    GreaterThanOrEqual,
    /// Less than or equal to operator (`<=`)
    LessThanOrEqual,
    /// Equality operator (`==`)
    Equal,
    /// Not equal operator (`!=`)
    NotEqual,
    /// Logical AND operator (`&&`)
    And,
    /// Logical OR operator (`||`)
    Or,
}

impl Display for BinaryOperator {
    /// Writes the operator's symbol to `f`.
    ///
    /// The symbol is emitted through `Formatter::pad`, so width, fill and
    /// alignment flags supplied by the caller (e.g. `{:>3}`) are honoured
    /// instead of being silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            BinaryOperator::Add => "+",
            BinaryOperator::Subtract => "-",
            BinaryOperator::Multiply => "*",
            BinaryOperator::Divide => "/",
            BinaryOperator::GreaterThan => ">",
            BinaryOperator::LessThan => "<",
            BinaryOperator::GreaterThanOrEqual => ">=",
            BinaryOperator::LessThanOrEqual => "<=",
            BinaryOperator::Equal => "==",
            BinaryOperator::NotEqual => "!=",
            BinaryOperator::And => "&&",
            BinaryOperator::Or => "||",
        };
        f.pad(symbol)
    }
}
53
/// Unary operators that apply to a single operand expression.
///
/// The type is a plain field-less enum, so it is cheap to copy, can be
/// compared for equality, and can be used as a hash-map key.
///
/// Its `Display` implementation renders the operator's textual symbol
/// (for example `Negate` is written as `-`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
    /// Negation operator (`-`)
    Negate,
    /// Logical NOT operator (`!`)
    Not,
}

impl Display for UnaryOperator {
    /// Writes the operator's symbol to `f`.
    ///
    /// The symbol is emitted through `Formatter::pad`, so width, fill and
    /// alignment flags supplied by the caller (e.g. `{:>3}`) are honoured
    /// instead of being silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            UnaryOperator::Negate => "-",
            UnaryOperator::Not => "!",
        };
        f.pad(symbol)
    }
}
71
/// Expression nodes in the AST
///
/// Each variant wraps a dedicated struct that holds the data for that kind
/// of expression.
#[derive(Debug)]
pub enum Expression {
    /// A literal value (constant)
    Literal(LiteralExpr),
    /// A binary operation (e.g., a + b)
    Binary(BinaryExpr),
    /// A variable reference
    Variable(VariableExpr),
    /// A unary operation (e.g., -x)
    Unary(UnaryExpr),
    /// A function call
    Call(FunctionCallExpr),
    /// A conditional expression (if/else)
    Conditional(ConditionalExpr),
    /// A block expression with statements and optional return value
    Block(BlockExpr),
    /// A function type expression (e.g., fn(i32, string) -> string)
    FunctionType(FunctionTypeExpr),
}
92
93impl Expression {
94    pub fn location(&self) -> Location {
95        match self {
96            Expression::Literal(e) => e.location,
97            Expression::Binary(e) => e.location,
98            Expression::Variable(e) => e.location,
99            Expression::Unary(e) => e.location,
100            Expression::Call(e) => e.location,
101            Expression::Conditional(e) => e.location,
102            Expression::Block(e) => e.location,
103            Expression::FunctionType(e) => e.location,
104        }
105    }
106}
107
/// Statement nodes in the AST
///
/// Each variant wraps the struct (or, for `Expression`, the expression node)
/// that holds the data for that kind of statement.
#[derive(Debug)]
pub enum Statement {
    /// Variable declaration
    Let(LetStatement),
    /// Variable assignment
    Assignment(AssignmentStatement),
    /// Expression statement (an expression used in statement position)
    Expression(Expression),
    /// Type definition (e.g., struct)
    TypeDefinition(TypeDefinitionStmt),
    /// Function declaration
    FunctionDeclaration(FunctionDeclarationStmt),
    /// Return statement
    Return(ReturnStatement),
    /// Conditional statement (if/else)
    If(IfStatement),
}
126
/// A function call expression
///
/// The callee is stored by name only; the arguments are full expression
/// nodes kept in a vector.
#[derive(Debug)]
pub struct FunctionCallExpr {
    /// Name of the function being called
    pub name: String,
    /// Arguments passed to the function
    pub arguments: Vec<Expression>,
    /// Type of the function call expression
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
139
/// A conditional expression (if/else)
///
/// The sub-expressions are boxed because `Expression` itself contains
/// `ConditionalExpr`, making the type recursive. Unlike `IfStatement`,
/// the else branch here is not optional.
#[derive(Debug)]
pub struct ConditionalExpr {
    /// Condition to evaluate
    pub condition: Box<Expression>,
    /// Expression to evaluate if condition is true
    pub then_branch: Box<Expression>,
    /// Expression to evaluate if condition is false (always present for expressions)
    pub else_branch: Box<Expression>,
    /// Type of the conditional expression
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
154
/// A block expression containing statements and an optional return value
///
/// Also used directly (not wrapped in `Expression`) as the body of
/// `FunctionDeclarationStmt` and as the branches of `IfStatement`.
#[derive(Debug)]
pub struct BlockExpr {
    /// Statements in the block
    pub statements: Vec<Statement>,
    /// Optional final expression that becomes the return value (without semicolon);
    /// `None` when the block has no such trailing expression
    pub return_expr: Option<Box<Expression>>,
    /// Type of the block expression
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
167
/// A function parameter
///
/// Used as the element type of `FunctionDeclarationStmt::parameters`.
#[derive(Debug)]
pub struct Parameter {
    /// Parameter name
    pub name: String,
    /// Parameter type
    pub param_type: TypeId,
    /// Source code location information
    pub location: Location,
}
178
/// A function declaration statement
///
/// Bundles the function's name, its parameter list, its return type and
/// its body.
#[derive(Debug)]
pub struct FunctionDeclarationStmt {
    /// Function name
    pub name: String,
    /// Function parameters
    pub parameters: Vec<Parameter>,
    /// Function return type
    pub return_type: TypeId,
    /// Function body (block expression)
    pub body: BlockExpr,
    /// Source code location information
    pub location: Location,
}
193
/// A function type expression (e.g., fn(i32, string) -> string)
///
/// Parameter and return types are stored as `TypeId` values rather than as
/// nested expression nodes.
#[derive(Debug)]
pub struct FunctionTypeExpr {
    /// Parameter types of the function
    pub param_types: Vec<TypeId>,
    /// Return type of the function
    pub return_type: TypeId,
    /// Type of the function type expression (will be a function type)
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
206
/// A type definition statement (like struct)
#[derive(Debug)]
pub struct TypeDefinitionStmt {
    /// Name of the defined type
    pub name: String,
    /// Fields of the type with their names and types, stored as
    /// `(field name, field type)` pairs
    pub fields: Vec<(String, TypeId)>,
    /// Source code location information
    pub location: Location,
}
217
/// A literal expression
///
/// Pairs the literal's value with the type recorded for the expression.
#[derive(Debug)]
pub struct LiteralExpr {
    /// Value of the literal
    pub value: LiteralValue,
    /// Type of the literal expression
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
228
/// A variable reference expression
///
/// Note that, unlike the other expression structs defined alongside it,
/// this node carries no `expr_type` field: only the name and location
/// are stored.
#[derive(Debug)]
pub struct VariableExpr {
    /// Name of the variable being referenced
    pub name: String,
    /// Source code location information
    pub location: Location,
}
237
/// A unary expression (e.g., -x)
///
/// The operand is boxed because `Expression` itself contains `UnaryExpr`,
/// making the type recursive.
#[derive(Debug)]
pub struct UnaryExpr {
    /// The operator (e.g., -)
    pub operator: UnaryOperator,
    /// The operand
    pub right: Box<Expression>,
    /// Type of the unary expression
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
250
/// Possible values for literal expressions
///
/// Values can be cloned and compared with `==`. Equality is derived, so it
/// is variant-sensitive (`I32(1)` is not equal to `I64(1)`) and follows
/// IEEE-754 rules for the floating-point variants: a `NaN` payload never
/// compares equal, not even to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralValue {
    /// 32-bit signed integer
    I32(i32),
    /// 64-bit signed integer
    I64(i64),
    /// 32-bit unsigned integer
    U32(u32),
    /// 64-bit unsigned integer
    U64(u64),
    /// Integer without specified type (needs inference)
    UnspecifiedInteger(i64),
    /// 32-bit floating point
    F32(f32),
    /// 64-bit floating point
    F64(f64),
    /// Float without specified type (needs inference)
    UnspecifiedFloat(f64),
    /// String value
    String(String),
    /// Boolean value (true or false)
    Boolean(bool),
    /// Unit value (similar to Rust's ())
    Unit,
}
277
/// A binary expression (e.g., a + b)
///
/// Both operands are boxed because `Expression` itself contains
/// `BinaryExpr`, making the type recursive.
#[derive(Debug)]
pub struct BinaryExpr {
    /// Left operand
    pub left: Box<Expression>,
    /// Operator
    pub operator: BinaryOperator,
    /// Right operand
    pub right: Box<Expression>,
    /// Type of the binary expression
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
292
/// A variable declaration statement
///
/// The initializer is not optional: every declaration stores a `value`
/// expression.
#[derive(Debug)]
pub struct LetStatement {
    /// Name of the variable
    pub name: String,
    /// Whether the variable is mutable
    pub is_mutable: bool,
    /// Initial value for the variable
    pub value: Expression,
    /// Type of the variable
    pub expr_type: TypeId,
    /// Source code location information
    pub location: Location,
}
307
/// A variable assignment statement
///
/// The assignment target is stored by name only; no type field is kept on
/// this node.
#[derive(Debug)]
pub struct AssignmentStatement {
    /// Name of the variable being assigned
    pub name: String,
    /// New value for the variable
    pub value: Expression,
    /// Source code location information
    pub location: Location,
}
318
/// A conditional statement (if/else)
///
/// Unlike `ConditionalExpr`, both branches are `BlockExpr` values, the else
/// branch may be absent, and the node carries no type field.
#[derive(Debug)]
pub struct IfStatement {
    /// Condition to evaluate
    pub condition: Expression,
    /// Block expression to execute if condition is true
    pub then_branch: BlockExpr,
    /// Optional block expression to execute if condition is false
    pub else_branch: Option<BlockExpr>,
    /// Source code location information
    pub location: Location,
}
331
/// A return statement
#[derive(Debug)]
pub struct ReturnStatement {
    /// Optional expression to return; `None` when the statement carries no value
    pub value: Option<Expression>,
    /// Source code location information
    pub location: Location,
}
340
341impl Statement {
342    /// Accepts a visitor for this statement
343    ///
344    /// ### Arguments
345    /// * `visitor` - The visitor to accept
346    ///
347    /// ### Returns
348    /// The result of the visitor's visit method for this statement
349    pub fn accept<T>(&self, visitor: &mut dyn Visitor<T>) -> T {
350        match self {
351            Statement::Let(let_stmt) => visitor.visit_let_statement(let_stmt),
352            Statement::Assignment(assign_stmt) => visitor.visit_assignment_statement(assign_stmt),
353            Statement::Expression(expr) => visitor.visit_expression_statement(expr),
354            Statement::TypeDefinition(type_def) => {
355                visitor.visit_type_definition_statement(type_def)
356            }
357            Statement::FunctionDeclaration(fn_decl) => {
358                visitor.visit_function_declaration_statement(fn_decl)
359            }
360            Statement::Return(return_stmt) => visitor.visit_return_statement(return_stmt),
361            Statement::If(if_stmt) => visitor.visit_if_statement(if_stmt),
362        }
363    }
364}
365
366impl Expression {
367    /// Accepts a visitor for this expression
368    ///
369    /// ### Arguments
370    /// * `visitor` - The visitor to accept
371    ///
372    /// ### Returns
373    /// The result of the visitor's visit method for this expression
374    pub fn accept<T>(&self, visitor: &mut dyn Visitor<T>) -> T {
375        match self {
376            Expression::Literal(lit) => visitor.visit_literal_expression(lit),
377            Expression::Binary(bin) => visitor.visit_binary_expression(bin),
378            Expression::Variable(var) => visitor.visit_variable_expression(var),
379            Expression::Unary(unary) => visitor.visit_unary_expression(unary),
380            Expression::Call(call) => visitor.visit_call_expression(call),
381            Expression::Conditional(cond) => visitor.visit_conditional_expression(cond),
382            Expression::Block(block) => visitor.visit_block_expression(block),
383            Expression::FunctionType(func_type) => {
384                visitor.visit_function_type_expression(func_type)
385            }
386        }
387    }
388}