slang_shared/compilation_context.rs
1use crate::symbol_table::SymbolData;
2use crate::{Symbol, SymbolKind, SymbolTable};
3use slang_types::{
4 FunctionType, PrimitiveType, StructType, TypeId, TypeInfo, TypeKind, TypeRegistry,
5};
6
7/// Compilation context that owns the type registry and symbol table
8pub struct CompilationContext {
9 /// The type registry that stores all types
10 type_registry: TypeRegistry,
11 /// The symbol table that stores all symbols (variables, types, functions)
12 symbol_table: SymbolTable,
13}
14
15impl Default for CompilationContext {
16 fn default() -> Self {
17 CompilationContext::new()
18 }
19}
20
21impl CompilationContext {
22 /// Creates a new compilation context with a type registry and symbol table
23 ///
24 /// Initializes the context with all primitive types registered in both the type registry
25 /// and symbol table. This includes boolean, integer, float, string, and unspecified types.
26 ///
27 /// ### Returns
28 /// A new CompilationContext instance ready for compilation
29 pub fn new() -> Self {
30 let type_registry = TypeRegistry::new_instance();
31 let mut symbol_table = SymbolTable::new();
32
33 let mut define_primitive = |ptype: PrimitiveType| {
34 let type_id = TypeId::from_primitive(ptype);
35 symbol_table
36 .define(ptype.name().to_string(), SymbolData::Type, type_id)
37 .unwrap_or_else(|_| {
38 panic!(
39 "Failed to define primitive type symbol for '{}'",
40 ptype.name()
41 )
42 });
43 type_id
44 };
45
46 define_primitive(PrimitiveType::Bool);
47 define_primitive(PrimitiveType::I32);
48 define_primitive(PrimitiveType::I64);
49 define_primitive(PrimitiveType::U32);
50 define_primitive(PrimitiveType::U64);
51 define_primitive(PrimitiveType::F32);
52 define_primitive(PrimitiveType::F64);
53 define_primitive(PrimitiveType::String);
54 define_primitive(PrimitiveType::UnspecifiedInt);
55 define_primitive(PrimitiveType::UnspecifiedFloat);
56 define_primitive(PrimitiveType::Unknown);
57
58 CompilationContext {
59 type_registry,
60 symbol_table,
61 }
62 }
63
64 /// Gets type information for a given type ID
65 ///
66 /// ### Arguments
67 /// * `id` - The type ID to look up
68 ///
69 /// ### Returns
70 /// An optional reference to the TypeInfo if the type exists
71 pub fn get_type_info(&self, id: &TypeId) -> Option<&TypeInfo> {
72 self.type_registry.get_type_info(id)
73 }
74
75 /// Gets the name of a type from its TypeId
76 ///
77 /// ### Arguments
78 /// * `type_id` - The type ID to get the name for
79 ///
80 /// ### Returns
81 /// The name of the type as a String, or a debug representation if the type is unknown
82 pub fn get_type_name(&self, type_id: &TypeId) -> String {
83 self.type_registry
84 .get_type_info(type_id)
85 .map(|t| t.name.clone())
86 .unwrap_or_else(|| format!("UnknownTypeId({:?})", type_id.0))
87 }
88
89 /// Gets the primitive type corresponding to a given type ID
90 ///
91 /// ### Arguments
92 /// * `id` - The type ID to look up
93 ///
94 /// ### Returns
95 /// An optional PrimitiveType if the type ID corresponds to a primitive type
96 pub fn get_primitive_type_from_id(&self, id: &TypeId) -> Option<PrimitiveType> {
97 self.type_registry.get_primitive_type(id)
98 }
99
100 /// Checks if a type ID corresponds to a primitive type
101 ///
102 /// ### Arguments
103 /// * `id` - The type ID to check
104 ///
105 /// ### Returns
106 /// True if the type is a primitive type, false otherwise
107 pub fn is_primitive_type(&self, id: &TypeId) -> bool {
108 self.type_registry.is_primitive_type(id)
109 }
110
111 /// Checks if a type fulfills a given predicate function
112 ///
113 /// ### Arguments
114 /// * `type_id` - The type ID to check
115 /// * `predicate` - A function that takes TypeInfo and returns a boolean
116 ///
117 /// ### Returns
118 /// True if the type exists and satisfies the predicate, false otherwise
119 pub fn type_fulfills<F>(&self, type_id: &TypeId, predicate: F) -> bool
120 where
121 F: Fn(&TypeInfo) -> bool,
122 {
123 self.get_type_info(type_id).is_some_and(predicate)
124 }
125
126 /// Checks if a type ID corresponds to a numeric type (integer or float)
127 ///
128 /// ### Arguments
129 /// * `type_id` - The type ID to check
130 ///
131 /// ### Returns
132 /// True if the type is numeric (integer or float), false otherwise
133 pub fn is_numeric_type(&self, type_id: &TypeId) -> bool {
134 self.get_primitive_type_from_id(type_id)
135 .is_some_and(|pt| pt.is_numeric())
136 }
137
138 /// Checks if a type ID corresponds to an integer type
139 ///
140 /// ### Arguments
141 /// * `type_id` - The type ID to check
142 ///
143 /// ### Returns
144 /// True if the type is an integer type (signed or unsigned), false otherwise
145 pub fn is_integer_type(&self, type_id: &TypeId) -> bool {
146 self.get_primitive_type_from_id(type_id)
147 .is_some_and(|pt| pt.is_integer())
148 }
149
150 /// Checks if a type ID corresponds to a floating-point type
151 ///
152 /// ### Arguments
153 /// * `type_id` - The type ID to check
154 ///
155 /// ### Returns
156 /// True if the type is a floating-point type (f32 or f64), false otherwise
157 pub fn is_float_type(&self, type_id: &TypeId) -> bool {
158 self.get_primitive_type_from_id(type_id)
159 .is_some_and(|pt| pt.is_float())
160 }
161
162 /// Checks if a type ID corresponds to a signed integer type
163 ///
164 /// ### Arguments
165 /// * `type_id` - The type ID to check
166 ///
167 /// ### Returns
168 /// True if the type is a signed integer type (i32 or i64), false otherwise
169 pub fn is_signed_integer_type(&self, type_id: &TypeId) -> bool {
170 self.get_primitive_type_from_id(type_id)
171 .is_some_and(|pt| pt.is_signed_integer())
172 }
173
174 /// Checks if a type ID corresponds to an unsigned integer type
175 ///
176 /// ### Arguments
177 /// * `type_id` - The type ID to check
178 ///
179 /// ### Returns
180 /// True if the type is an unsigned integer type (u32 or u64), false otherwise
181 pub fn is_unsigned_integer_type(&self, type_id: &TypeId) -> bool {
182 self.get_primitive_type_from_id(type_id)
183 .is_some_and(|pt| pt.is_unsigned_integer())
184 }
185
186 /// Gets the bit width of a type
187 ///
188 /// ### Arguments
189 /// * `type_id` - The type ID to get the bit width for
190 ///
191 /// ### Returns
192 /// The bit width of the type, or 0 if the type is not a primitive type
193 pub fn get_bit_width(&self, type_id: &TypeId) -> u8 {
194 self.get_primitive_type_from_id(type_id)
195 .map_or(0, |pt| pt.bit_width())
196 }
197
198 /// Checks if an integer value is within the valid range for a given type
199 ///
200 /// ### Arguments
201 /// * `value` - The integer value to check
202 /// * `type_id` - The type ID to check the value against
203 ///
204 /// ### Returns
205 /// True if the value is within the valid range for the type, false otherwise
206 pub fn check_value_in_range(&self, value: &i64, type_id: &TypeId) -> bool {
207 self.type_registry.check_value_in_range(value, type_id)
208 }
209
210 /// Checks if a floating-point value is within the valid range for a given type
211 ///
212 /// ### Arguments
213 /// * `value` - The floating-point value to check
214 /// * `type_id` - The type ID to check the value against
215 ///
216 /// ### Returns
217 /// True if the value is within the valid range for the type, false otherwise
218 pub fn check_float_value_in_range(&self, value: &f64, type_id: &TypeId) -> bool {
219 self.type_registry
220 .check_float_value_in_range(value, type_id)
221 }
222
223 /// Defines a symbol in the symbol table
224 ///
225 /// ### Arguments
226 /// * `name` - The name of the symbol
227 /// * `kind` - The kind of symbol (variable, type, function)
228 /// * `type_id` - The type ID associated with the symbol
229 /// * `is_mutable` - Whether the symbol is mutable (only relevant for variables)
230 ///
231 /// ### Returns
232 /// A Result indicating success or an error message if the symbol cannot be defined
233 pub fn define_symbol(
234 &mut self,
235 name: String,
236 kind: SymbolKind,
237 type_id: TypeId,
238 is_mutable: bool,
239 ) -> Result<(), String> {
240 let data = match kind {
241 SymbolKind::Type => SymbolData::Type,
242 SymbolKind::Variable => SymbolData::Variable { is_mutable },
243 SymbolKind::Function => SymbolData::Function,
244 };
245 self.symbol_table.define(name, data, type_id)
246 }
247
248 /// Looks up a symbol in the symbol table by name
249 ///
250 /// ### Arguments
251 /// * `name` - The name of the symbol to look up
252 ///
253 /// ### Returns
254 /// An optional reference to the Symbol if found, None otherwise
255 pub fn lookup_symbol(&self, name: &str) -> Option<&Symbol> {
256 self.symbol_table.lookup(name)
257 }
258
259 /// Registers a custom type with the given name and type kind
260 ///
261 /// ### Arguments
262 /// * `name` - The name of the custom type
263 /// * `type_kind` - The kind of type to register (struct, enum, etc.)
264 ///
265 /// ### Returns
266 /// A Result containing the TypeId of the registered type or an error message if the name is already defined
267 pub fn register_custom_type(
268 &mut self,
269 name: &str,
270 type_kind: TypeKind,
271 ) -> Result<TypeId, String> {
272 if self.symbol_table.lookup(name).is_some() {
273 return Err(format!("Symbol '{}' is already defined.", name));
274 }
275
276 let type_id = self.type_registry.register_type(name, type_kind);
277 self.symbol_table
278 .define(name.to_string(), SymbolData::Type, type_id)?;
279 Ok(type_id)
280 }
281
282 /// Registers a new struct type with the given name and fields
283 ///
284 /// ### Arguments
285 /// * `name` - The name of the struct type
286 /// * `fields` - A vector of tuples containing field names and their type IDs
287 ///
288 /// ### Returns
289 /// A Result containing the TypeId of the registered struct type or an error message
290 pub fn register_struct_type(
291 &mut self,
292 name: String,
293 fields: Vec<(String, TypeId)>,
294 ) -> Result<TypeId, String> {
295 let struct_type = StructType::new(name.clone(), fields);
296 let type_kind = TypeKind::Struct(struct_type);
297 self.register_custom_type(&name, type_kind)
298 }
299
300 /// Registers a function type and returns its TypeId
301 pub fn register_function_type(
302 &mut self,
303 param_types: Vec<TypeId>,
304 return_type: TypeId,
305 ) -> TypeId {
306 self.type_registry
307 .register_function_type(param_types, return_type)
308 }
309
310 /// Checks if a type is a function type
311 pub fn is_function_type(&self, type_id: &TypeId) -> bool {
312 if let Some(type_info) = self.type_registry.get_type_info(type_id) {
313 matches!(type_info.kind, TypeKind::Function(_))
314 } else {
315 false
316 }
317 }
318
319 /// Gets function type information
320 pub fn get_function_type(&self, type_id: &TypeId) -> Option<&FunctionType> {
321 if let Some(type_info) = self.type_registry.get_type_info(type_id) {
322 if let TypeKind::Function(ref function_type) = type_info.kind {
323 Some(function_type)
324 } else {
325 None
326 }
327 } else {
328 None
329 }
330 }
331
332 /// Begins a new scope by calling the symbol table
333 /// Used when entering a block, function, or other lexical scope.
334 pub fn begin_scope(&mut self) {
335 self.symbol_table.begin_scope();
336 }
337
338 /// Ends the current scope by calling the symbol table
339 /// Used when exiting a block, function, or other lexical scope.
340 pub fn end_scope(&mut self) {
341 self.symbol_table.end_scope();
342 }
343}