1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Expr, ExprLit, Lit, Meta, MetaNameValue, Variant, parse_macro_input};
4
5#[proc_macro_derive(NamedEnum, attributes(name))]
22pub fn derive_named_enum(input: TokenStream) -> TokenStream {
23 let input = parse_macro_input!(input as DeriveInput);
24 let enum_name = &input.ident;
25
26 let variants = if let Data::Enum(data_enum) = &input.data {
27 &data_enum.variants
28 } else {
29 panic!("NamedEnum can only be derived for enums");
30 };
31 let variant_mappings = variants
32 .iter()
33 .map(|variant| {
34 let variant_name = &variant.ident;
35 let string_name = extract_name_attribute(variant)
36 .unwrap_or_else(|| variant_name.to_string().to_lowercase());
37 (variant_name, string_name)
38 })
39 .collect::<Vec<_>>();
40
41 let type_name_arms = variant_mappings.iter().map(|(variant_name, string_name)| {
42 quote! {
43 #enum_name::#variant_name => #string_name
44 }
45 });
46
47 let from_str_arms = variant_mappings.iter().map(|(variant_name, string_name)| {
48 quote! {
49 #string_name => Some(#enum_name::#variant_name)
50 }
51 });
52
53 let expanded = quote! {
54 impl #enum_name {
55 pub const fn name(&self) -> &'static str {
56 match self {
57 #(#type_name_arms),*
58 }
59 }
60
61 pub fn from_str(s: &str) -> Option<Self> {
62 match s {
63 #(#from_str_arms),*,
64 _ => None,
65 }
66 }
67 }
68 };
69
70 proc_macro::TokenStream::from(expanded)
71}
72
73fn extract_name_attribute(variant: &Variant) -> Option<String> {
75 variant
76 .attrs
77 .iter()
78 .find(|attr| attr.path().is_ident("name"))
79 .map(|attr| match &attr.meta {
80 Meta::NameValue(MetaNameValue { value, .. }) => {
81 if let Expr::Lit(ExprLit {
82 lit: Lit::Str(lit_str),
83 ..
84 }) = value
85 {
86 lit_str.value()
87 } else {
88 panic!("name attribute must have a string literal value");
89 }
90 }
91 _ => panic!("name attribute must be in the form #[name = \"value\"]"),
92 })
93}
94
95#[proc_macro_derive(NumericEnum)]
138pub fn derive_numeric_enum(input: TokenStream) -> TokenStream {
139 let input = parse_macro_input!(input as DeriveInput);
140 let enum_name = &input.ident;
141
142 let variants = if let Data::Enum(data_enum) = &input.data {
143 &data_enum.variants
144 } else {
145 panic!("NumericEnum can only be derived for enums");
146 };
147
148 let mut next_discriminant = 0usize;
149
150 let mut variant_values = Vec::new();
151
152 for variant in variants.iter() {
153 let variant_name = &variant.ident;
154
155 let value = if let Some((_, expr)) = &variant.discriminant {
156 if let Expr::Lit(ExprLit {
157 lit: Lit::Int(lit_int),
158 ..
159 }) = expr
160 {
161 let parsed_value = lit_int
162 .base10_parse::<usize>()
163 .expect("Enum discriminant must be a valid integer");
164 next_discriminant = parsed_value + 1;
165 parsed_value
166 } else {
167 panic!("NumericEnum requires integer literals as enum discriminants");
168 }
169 } else {
170 let value = next_discriminant;
171 next_discriminant += 1;
172 value
173 };
174
175 variant_values.push((variant_name, value));
176 }
177
178 let from_int_arms = variant_values.iter().map(|(variant_name, value)| {
179 quote! {
180 #value => Some(#enum_name::#variant_name)
181 }
182 });
183
184 let expanded = quote! {
185 impl #enum_name {
186 pub fn from_int<T: Into<usize>>(value: T) -> Option<Self> {
187 let value = value.into();
188 match value {
189 #(#from_int_arms),*,
190 _ => None,
191 }
192 }
193 }
194 };
195
196 proc_macro::TokenStream::from(expanded)
197}
198
199#[proc_macro_derive(IterableEnum)]
221pub fn derive_iterable_enum(input: TokenStream) -> TokenStream {
222 let input = parse_macro_input!(input as DeriveInput);
223 let enum_name = &input.ident;
224
225 let variants = if let Data::Enum(data_enum) = &input.data {
226 &data_enum.variants
227 } else {
228 panic!("IterableEnum can only be derived for enums");
229 };
230
231 for variant in variants.iter() {
232 if !variant.fields.is_empty() {
233 panic!(
234 "IterableEnum can only be derived for enums with unit variants (no associated data)"
235 );
236 }
237 }
238
239 let variant_names = variants.iter().map(|variant| &variant.ident);
240 let variant_count = variants.len();
241
242 let expanded = quote! {
243 impl #enum_name {
244 pub fn iter() -> impl Iterator<Item = #enum_name> + Clone {
245 const VARIANTS: [#enum_name; #variant_count] = [
246 #(#enum_name::#variant_names),*
247 ];
248 VARIANTS.iter().copied()
249 }
250 }
251 };
252
253 proc_macro::TokenStream::from(expanded)
254}