slang_shared/
compilation_context.rs

1use crate::symbol_table::SymbolData;
2use crate::{Symbol, SymbolKind, SymbolTable};
3use slang_types::{
4    FunctionType, PrimitiveType, StructType, TypeId, TypeInfo, TypeKind, TypeRegistry,
5};
6
7/// Compilation context that owns the type registry and symbol table
8pub struct CompilationContext {
9    /// The type registry that stores all types
10    type_registry: TypeRegistry,
11    /// The symbol table that stores all symbols (variables, types, functions)
12    symbol_table: SymbolTable,
13}
14
15impl Default for CompilationContext {
16    fn default() -> Self {
17        CompilationContext::new()
18    }
19}
20
21impl CompilationContext {
22    /// Creates a new compilation context with a type registry and symbol table
23    ///
24    /// Initializes the context with all primitive types registered in both the type registry
25    /// and symbol table. This includes boolean, integer, float, string, and unspecified types.
26    ///
27    /// ### Returns
28    /// A new CompilationContext instance ready for compilation
29    pub fn new() -> Self {
30        let type_registry = TypeRegistry::new_instance();
31        let mut symbol_table = SymbolTable::new();
32
33        let mut define_primitive = |ptype: PrimitiveType| {
34            let type_id = TypeId::from_primitive(ptype);
35            symbol_table
36                .define(ptype.name().to_string(), SymbolData::Type, type_id)
37                .unwrap_or_else(|_| {
38                    panic!(
39                        "Failed to define primitive type symbol for '{}'",
40                        ptype.name()
41                    )
42                });
43            type_id
44        };
45
46        define_primitive(PrimitiveType::Bool);
47        define_primitive(PrimitiveType::I32);
48        define_primitive(PrimitiveType::I64);
49        define_primitive(PrimitiveType::U32);
50        define_primitive(PrimitiveType::U64);
51        define_primitive(PrimitiveType::F32);
52        define_primitive(PrimitiveType::F64);
53        define_primitive(PrimitiveType::String);
54        define_primitive(PrimitiveType::UnspecifiedInt);
55        define_primitive(PrimitiveType::UnspecifiedFloat);
56        define_primitive(PrimitiveType::Unknown);
57
58        CompilationContext {
59            type_registry,
60            symbol_table,
61        }
62    }
63
64    /// Gets type information for a given type ID
65    ///
66    /// ### Arguments
67    /// * `id` - The type ID to look up
68    ///
69    /// ### Returns
70    /// An optional reference to the TypeInfo if the type exists
71    pub fn get_type_info(&self, id: &TypeId) -> Option<&TypeInfo> {
72        self.type_registry.get_type_info(id)
73    }
74
75    /// Gets the name of a type from its TypeId
76    ///
77    /// ### Arguments
78    /// * `type_id` - The type ID to get the name for
79    ///
80    /// ### Returns
81    /// The name of the type as a String, or a debug representation if the type is unknown
82    pub fn get_type_name(&self, type_id: &TypeId) -> String {
83        self.type_registry
84            .get_type_info(type_id)
85            .map(|t| t.name.clone())
86            .unwrap_or_else(|| format!("UnknownTypeId({:?})", type_id.0))
87    }
88
89    /// Gets the primitive type corresponding to a given type ID
90    ///
91    /// ### Arguments
92    /// * `id` - The type ID to look up
93    ///
94    /// ### Returns
95    /// An optional PrimitiveType if the type ID corresponds to a primitive type
96    pub fn get_primitive_type_from_id(&self, id: &TypeId) -> Option<PrimitiveType> {
97        self.type_registry.get_primitive_type(id)
98    }
99
100    /// Checks if a type ID corresponds to a primitive type
101    ///
102    /// ### Arguments
103    /// * `id` - The type ID to check
104    ///
105    /// ### Returns
106    /// True if the type is a primitive type, false otherwise
107    pub fn is_primitive_type(&self, id: &TypeId) -> bool {
108        self.type_registry.is_primitive_type(id)
109    }
110
111    /// Checks if a type fulfills a given predicate function
112    ///
113    /// ### Arguments
114    /// * `type_id` - The type ID to check
115    /// * `predicate` - A function that takes TypeInfo and returns a boolean
116    ///
117    /// ### Returns
118    /// True if the type exists and satisfies the predicate, false otherwise
119    pub fn type_fulfills<F>(&self, type_id: &TypeId, predicate: F) -> bool
120    where
121        F: Fn(&TypeInfo) -> bool,
122    {
123        self.get_type_info(type_id).is_some_and(predicate)
124    }
125
126    /// Checks if a type ID corresponds to a numeric type (integer or float)
127    ///
128    /// ### Arguments
129    /// * `type_id` - The type ID to check
130    ///
131    /// ### Returns
132    /// True if the type is numeric (integer or float), false otherwise
133    pub fn is_numeric_type(&self, type_id: &TypeId) -> bool {
134        self.get_primitive_type_from_id(type_id)
135            .is_some_and(|pt| pt.is_numeric())
136    }
137
138    /// Checks if a type ID corresponds to an integer type
139    ///
140    /// ### Arguments
141    /// * `type_id` - The type ID to check
142    ///
143    /// ### Returns
144    /// True if the type is an integer type (signed or unsigned), false otherwise
145    pub fn is_integer_type(&self, type_id: &TypeId) -> bool {
146        self.get_primitive_type_from_id(type_id)
147            .is_some_and(|pt| pt.is_integer())
148    }
149
150    /// Checks if a type ID corresponds to a floating-point type
151    ///
152    /// ### Arguments
153    /// * `type_id` - The type ID to check
154    ///
155    /// ### Returns
156    /// True if the type is a floating-point type (f32 or f64), false otherwise
157    pub fn is_float_type(&self, type_id: &TypeId) -> bool {
158        self.get_primitive_type_from_id(type_id)
159            .is_some_and(|pt| pt.is_float())
160    }
161
162    /// Checks if a type ID corresponds to a signed integer type
163    ///
164    /// ### Arguments
165    /// * `type_id` - The type ID to check
166    ///
167    /// ### Returns
168    /// True if the type is a signed integer type (i32 or i64), false otherwise
169    pub fn is_signed_integer_type(&self, type_id: &TypeId) -> bool {
170        self.get_primitive_type_from_id(type_id)
171            .is_some_and(|pt| pt.is_signed_integer())
172    }
173
174    /// Checks if a type ID corresponds to an unsigned integer type
175    ///
176    /// ### Arguments
177    /// * `type_id` - The type ID to check
178    ///
179    /// ### Returns
180    /// True if the type is an unsigned integer type (u32 or u64), false otherwise
181    pub fn is_unsigned_integer_type(&self, type_id: &TypeId) -> bool {
182        self.get_primitive_type_from_id(type_id)
183            .is_some_and(|pt| pt.is_unsigned_integer())
184    }
185
186    /// Gets the bit width of a type
187    ///
188    /// ### Arguments
189    /// * `type_id` - The type ID to get the bit width for
190    ///
191    /// ### Returns
192    /// The bit width of the type, or 0 if the type is not a primitive type
193    pub fn get_bit_width(&self, type_id: &TypeId) -> u8 {
194        self.get_primitive_type_from_id(type_id)
195            .map_or(0, |pt| pt.bit_width())
196    }
197
198    /// Checks if an integer value is within the valid range for a given type
199    ///
200    /// ### Arguments
201    /// * `value` - The integer value to check
202    /// * `type_id` - The type ID to check the value against
203    ///
204    /// ### Returns
205    /// True if the value is within the valid range for the type, false otherwise
206    pub fn check_value_in_range(&self, value: &i64, type_id: &TypeId) -> bool {
207        self.type_registry.check_value_in_range(value, type_id)
208    }
209
210    /// Checks if a floating-point value is within the valid range for a given type
211    ///
212    /// ### Arguments
213    /// * `value` - The floating-point value to check
214    /// * `type_id` - The type ID to check the value against
215    ///
216    /// ### Returns
217    /// True if the value is within the valid range for the type, false otherwise
218    pub fn check_float_value_in_range(&self, value: &f64, type_id: &TypeId) -> bool {
219        self.type_registry
220            .check_float_value_in_range(value, type_id)
221    }
222
223    /// Defines a symbol in the symbol table
224    ///
225    /// ### Arguments
226    /// * `name` - The name of the symbol
227    /// * `kind` - The kind of symbol (variable, type, function)
228    /// * `type_id` - The type ID associated with the symbol
229    /// * `is_mutable` - Whether the symbol is mutable (only relevant for variables)
230    ///
231    /// ### Returns
232    /// A Result indicating success or an error message if the symbol cannot be defined
233    pub fn define_symbol(
234        &mut self,
235        name: String,
236        kind: SymbolKind,
237        type_id: TypeId,
238        is_mutable: bool,
239    ) -> Result<(), String> {
240        let data = match kind {
241            SymbolKind::Type => SymbolData::Type,
242            SymbolKind::Variable => SymbolData::Variable { is_mutable },
243            SymbolKind::Function => SymbolData::Function,
244        };
245        self.symbol_table.define(name, data, type_id)
246    }
247
248    /// Looks up a symbol in the symbol table by name
249    ///
250    /// ### Arguments
251    /// * `name` - The name of the symbol to look up
252    ///
253    /// ### Returns
254    /// An optional reference to the Symbol if found, None otherwise
255    pub fn lookup_symbol(&self, name: &str) -> Option<&Symbol> {
256        self.symbol_table.lookup(name)
257    }
258
259    /// Registers a custom type with the given name and type kind
260    ///
261    /// ### Arguments
262    /// * `name` - The name of the custom type
263    /// * `type_kind` - The kind of type to register (struct, enum, etc.)
264    ///
265    /// ### Returns
266    /// A Result containing the TypeId of the registered type or an error message if the name is already defined
267    pub fn register_custom_type(
268        &mut self,
269        name: &str,
270        type_kind: TypeKind,
271    ) -> Result<TypeId, String> {
272        if self.symbol_table.lookup(name).is_some() {
273            return Err(format!("Symbol '{}' is already defined.", name));
274        }
275
276        let type_id = self.type_registry.register_type(name, type_kind);
277        self.symbol_table
278            .define(name.to_string(), SymbolData::Type, type_id)?;
279        Ok(type_id)
280    }
281
282    /// Registers a new struct type with the given name and fields
283    ///
284    /// ### Arguments
285    /// * `name` - The name of the struct type
286    /// * `fields` - A vector of tuples containing field names and their type IDs
287    ///
288    /// ### Returns
289    /// A Result containing the TypeId of the registered struct type or an error message
290    pub fn register_struct_type(
291        &mut self,
292        name: String,
293        fields: Vec<(String, TypeId)>,
294    ) -> Result<TypeId, String> {
295        let struct_type = StructType::new(name.clone(), fields);
296        let type_kind = TypeKind::Struct(struct_type);
297        self.register_custom_type(&name, type_kind)
298    }
299
300    /// Registers a function type and returns its TypeId
301    pub fn register_function_type(
302        &mut self,
303        param_types: Vec<TypeId>,
304        return_type: TypeId,
305    ) -> TypeId {
306        self.type_registry
307            .register_function_type(param_types, return_type)
308    }
309
310    /// Checks if a type is a function type
311    pub fn is_function_type(&self, type_id: &TypeId) -> bool {
312        if let Some(type_info) = self.type_registry.get_type_info(type_id) {
313            matches!(type_info.kind, TypeKind::Function(_))
314        } else {
315            false
316        }
317    }
318
319    /// Gets function type information
320    pub fn get_function_type(&self, type_id: &TypeId) -> Option<&FunctionType> {
321        if let Some(type_info) = self.type_registry.get_type_info(type_id) {
322            if let TypeKind::Function(ref function_type) = type_info.kind {
323                Some(function_type)
324            } else {
325                None
326            }
327        } else {
328            None
329        }
330    }
331
332    /// Begins a new scope by calling the symbol table
333    /// Used when entering a block, function, or other lexical scope.
334    pub fn begin_scope(&mut self) {
335        self.symbol_table.begin_scope();
336    }
337
338    /// Ends the current scope by calling the symbol table
339    /// Used when exiting a block, function, or other lexical scope.
340    pub fn end_scope(&mut self) {
341        self.symbol_table.end_scope();
342    }
343}