1use pliron::combine::{Parser, between, optional, token};
4
5use pliron::builtin::type_interfaces::FunctionTypeInterface;
6use pliron::derive::{format, pliron_type, type_interface_impl};
7use pliron::{
8 common_traits::Verify,
9 context::{Context, Ptr},
10 identifier::Identifier,
11 input_err_noloc,
12 irfmt::{
13 parsers::{delimited_list_parser, location, spaced},
14 printers::{enclosed, list_with_sep},
15 },
16 location::Located,
17 parsable::{IntoParseResult, Parsable, ParseResult, StateStream},
18 printable::{self, ListSeparator, Printable},
19 result::Result,
20 r#type::{Type, TypeObj, TypePtr},
21 verify_err_noloc,
22};
23use thiserror::Error;
24
25use std::hash::Hash;
26
27#[pliron_type(name = "llvm.struct")]
38#[derive(Debug)]
39pub struct StructType {
40 name: Option<Identifier>,
41 fields: Option<Vec<Ptr<TypeObj>>>,
42}
43
44impl StructType {
45 pub fn get_named(
56 ctx: &mut Context,
57 name: Identifier,
58 fields: Option<Vec<Ptr<TypeObj>>>,
59 ) -> Result<TypePtr<Self>> {
60 let self_ptr = Type::register_instance(
61 StructType {
62 name: Some(name.clone()),
63 fields: None,
65 },
66 ctx,
67 );
68 let mut self_ref = self_ptr.to_ptr().deref_mut(ctx);
70 let self_ref = self_ref.downcast_mut::<StructType>().unwrap();
71 assert!(self_ref.name.as_ref().unwrap() == &name);
72 if let Some(fields) = fields {
73 if let Some(existing_fields) = &self_ref.fields {
75 if existing_fields != &fields {
77 input_err_noloc!(StructErr::ExistingMismatch(name.into()))?
78 }
79 } else {
80 self_ref.fields = Some(fields);
82 }
83 }
84 Ok(self_ptr)
85 }
86
87 pub fn get_unnamed(ctx: &mut Context, fields: Vec<Ptr<TypeObj>>) -> TypePtr<Self> {
90 Type::register_instance(
91 StructType {
92 name: None,
93 fields: Some(fields),
94 },
95 ctx,
96 )
97 }
98
99 pub fn get_existing_named(ctx: &Context, name: &Identifier) -> Option<TypePtr<Self>> {
101 Type::get_instance(
102 StructType {
103 name: Some(name.clone()),
104 fields: None,
106 },
107 ctx,
108 )
109 }
110
111 pub fn get_existing_unnamed(ctx: &Context, fields: Vec<Ptr<TypeObj>>) -> Option<TypePtr<Self>> {
113 Type::get_instance(
114 StructType {
115 name: None,
116 fields: Some(fields),
117 },
118 ctx,
119 )
120 }
121
122 pub fn is_opaque(&self) -> bool {
124 self.fields.is_none()
125 }
126
127 pub fn is_named(&self) -> bool {
129 self.name.is_some()
130 }
131
132 pub fn name(&self) -> Option<Identifier> {
134 self.name.clone()
135 }
136
137 pub fn field_type(&self, field_idx: usize) -> Ptr<TypeObj> {
139 self.fields
140 .as_ref()
141 .expect("field_type shouldn't be called on opaque types")[field_idx]
142 }
143
144 pub fn num_fields(&self) -> usize {
146 self.fields
147 .as_ref()
148 .expect("num_fields shouldn't be called on opaque types")
149 .len()
150 }
151
152 pub fn fields(&self) -> impl Iterator<Item = Ptr<TypeObj>> + '_ {
154 self.fields
155 .as_ref()
156 .expect("fields shouldn't be called on opaque types")
157 .iter()
158 .cloned()
159 }
160}
161
162#[derive(Debug, Error)]
163pub enum StructErr {
164 #[error("struct cannot be both opaque and anonymous")]
165 OpaqueAndAnonymousErr,
166 #[error("struct {0} already exists and is different")]
167 ExistingMismatch(String),
168}
169
170impl Verify for StructType {
171 fn verify(&self, _ctx: &Context) -> Result<()> {
172 if self.name.is_none() && self.fields.is_none() {
173 verify_err_noloc!(StructErr::OpaqueAndAnonymousErr)?
174 }
175 Ok(())
176 }
177}
178
179impl Printable for StructType {
180 fn fmt(
181 &self,
182 ctx: &Context,
183 state: &printable::State,
184 f: &mut core::fmt::Formatter<'_>,
185 ) -> core::fmt::Result {
186 write!(f, "<")?;
187
188 use std::cell::RefCell;
189 thread_local! {
192 static IN_PRINTING: RefCell<Vec<Identifier>> = const { RefCell::new(vec![]) };
195 }
196 if let Some(name) = &self.name {
197 let in_printing = IN_PRINTING.with(|f| f.borrow().contains(name));
198 if in_printing {
199 return write!(f, "{}>", name.clone());
200 }
201 IN_PRINTING.with(|f| f.borrow_mut().push(name.clone()));
202 write!(f, "{name}")?;
203 if !self.is_opaque() {
204 write!(f, " ")?;
205 }
206 }
207
208 if let Some(fields) = &self.fields {
209 enclosed(
210 "{ ",
211 " }",
212 list_with_sep(fields, ListSeparator::CharSpace(',')),
213 )
214 .fmt(ctx, state, f)?;
215 }
216
217 if let Some(name) = &self.name {
219 assert!(IN_PRINTING.with(|f| f.borrow().last().unwrap() == name));
220 IN_PRINTING.with(|f| f.borrow_mut().pop());
221 }
222 write!(f, ">")
223 }
224}
225
226impl Hash for StructType {
227 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
228 match &self.name {
229 Some(name) => name.hash(state),
230 None => self
231 .fields
232 .as_ref()
233 .expect("Anonymous struct must have its fields set")
234 .iter()
235 .for_each(|field_type| {
236 field_type.hash(state);
237 }),
238 }
239 }
240}
241
242impl PartialEq for StructType {
243 fn eq(&self, other: &Self) -> bool {
244 match (&self.name, &other.name) {
245 (Some(name), Some(other_name)) => name == other_name,
246 (None, None) => self.fields == other.fields,
247 _ => false,
248 }
249 }
250}
251
252impl Parsable for StructType {
253 type Arg = ();
254 type Parsed = TypePtr<Self>;
255
256 fn parse<'a>(
257 state_stream: &mut StateStream<'a>,
258 _arg: Self::Arg,
259 ) -> ParseResult<'a, Self::Parsed>
260 where
261 Self: Sized,
262 {
263 let body_parser = || {
264 delimited_list_parser('{', '}', ',', Ptr::<TypeObj>::parser(()))
266 };
267
268 let named = spaced((location(), Identifier::parser(())))
269 .and(spaced(optional(body_parser())))
270 .map(|((loc, name), body_opt)| (loc, Some(name), body_opt));
271 let anonymous = spaced((location(), body_parser()))
272 .map(|(loc, body)| (loc, None::<Identifier>, Some(body)));
273
274 let mut struct_parser = between(token('<'), token('>'), named.or(anonymous));
276
277 let (loc, name_opt, body_opt) = struct_parser.parse_stream(state_stream).into_result()?.0;
278 let ctx = &mut state_stream.state.ctx;
279 if let Some(name) = name_opt {
280 StructType::get_named(ctx, name, body_opt)
281 .map_err(|mut err| {
282 err.set_loc(loc);
283 err
284 })
285 .into_parse_result()
286 } else {
287 Ok(StructType::get_unnamed(
288 ctx,
289 body_opt.expect("Without a name, a struct type must have a body."),
290 ))
291 .into_parse_result()
292 }
293 }
294}
295
// `PartialEq` compares named structs by name and anonymous structs by their field
// lists, and never equates a named struct with an anonymous one. That relation is
// reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for StructType {}
297
298#[pliron_type(name = "llvm.ptr", generate_get = true, format, verifier = "succ")]
300#[derive(Hash, PartialEq, Eq, Debug)]
301pub struct PointerType;
302
303#[pliron_type(
305 name = "llvm.array",
306 generate_get = true,
307 format = "`[` $size ` x ` $elem `]`",
308 verifier = "succ"
309)]
310#[derive(Hash, PartialEq, Eq, Debug)]
311pub struct ArrayType {
312 elem: Ptr<TypeObj>,
313 size: u64,
314}
315
316impl ArrayType {
317 pub fn elem_type(&self) -> Ptr<TypeObj> {
319 self.elem
320 }
321
322 pub fn size(&self) -> u64 {
324 self.size
325 }
326}
327
328#[pliron_type(name = "llvm.void", generate_get = true, format, verifier = "succ")]
329#[derive(Hash, PartialEq, Eq, Debug)]
330pub struct VoidType;
331
332#[pliron_type(
333 name = "llvm.func",
334 generate_get = true,
335 format = "`<` $res `(` vec($args, CharSpace(`,`)) `) variadic = ` $is_var_arg `>`",
336 verifier = "succ"
337)]
338#[derive(Hash, PartialEq, Eq, Debug)]
339pub struct FuncType {
340 res: Ptr<TypeObj>,
341 args: Vec<Ptr<TypeObj>>,
342 is_var_arg: bool,
343}
344
345#[derive(Debug, Error)]
346pub enum FuncTypeErr {
347 #[error("Expected at most one result")]
348 TooManyResults,
349}
350
351impl FuncType {
352 pub fn result_type(&self) -> Ptr<TypeObj> {
354 self.res
355 }
356
357 pub fn is_var_arg(&self) -> bool {
359 self.is_var_arg
360 }
361}
362
363#[type_interface_impl]
364impl FunctionTypeInterface for FuncType {
365 fn arg_types(&self) -> Vec<Ptr<TypeObj>> {
366 self.args.clone()
367 }
368 fn res_types(&self) -> Vec<Ptr<TypeObj>> {
369 vec![self.res]
370 }
371}
372
373#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
376#[format]
377pub enum VectorTypeKind {
378 Fixed,
379 Scalable,
380}
381
382#[pliron_type(
383 name = "llvm.vector",
384 generate_get = true,
385 format = "`<` $kind ` x ` $num_elems ` x ` $elem_ty `>`",
386 verifier = "succ"
387)]
388#[derive(Hash, PartialEq, Eq, Debug)]
389pub struct VectorType {
390 elem_ty: Ptr<TypeObj>,
391 num_elems: u32,
392 kind: VectorTypeKind,
393}
394
395impl VectorType {
396 pub fn elem_type(&self) -> Ptr<TypeObj> {
398 self.elem_ty
399 }
400
401 pub fn num_elements(&self) -> u32 {
403 self.num_elems
404 }
405
406 pub fn is_scalable(&self) -> bool {
408 self.kind == VectorTypeKind::Scalable
409 }
410
411 pub fn kind(&self) -> VectorTypeKind {
413 self.kind
414 }
415}
416
// Unit tests for the LLVM dialect types: struct creation, uniquing, printing and
// parsing, plus function type parsing.
#[cfg(test)]
mod tests {

    use expect_test::expect;
    use pliron::combine::{self, Parser, eof, token};
    use pliron::derive::{pliron_type, verify_succ};

    use crate::types::{FuncType, StructType, VoidType};
    use pliron::{
        builtin::types::{IntegerType, Signedness},
        context::{Context, Ptr},
        identifier::Identifier,
        irfmt::parsers::{spaced, type_parser},
        location,
        parsable::{self, Parsable, ParseResult, StateStream, state_stream_from_iterator},
        printable::{self, Printable},
        result::Result,
        r#type::{Type, TypeObj, TypePtr},
    };

    // Named (self-referential) and anonymous struct creation, lookup and printing.
    #[test]
    fn test_struct() -> Result<()> {
        let mut ctx = Context::new();
        let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signless).into();
        let linked_list_id: Identifier = "LinkedList".try_into().unwrap();
        let linked_list_2_id: Identifier = "LinkedList2".try_into().unwrap();

        // Create `LinkedList` as opaque first so that its body can point to itself.
        let list_struct: Ptr<TypeObj> =
            StructType::get_named(&mut ctx, linked_list_id.clone(), None)?.into();
        assert!(
            list_struct
                .deref(&ctx)
                .downcast_ref::<StructType>()
                .unwrap()
                .is_opaque()
        );
        let list_struct_ptr = TypedPointerType::get(&mut ctx, list_struct).into();
        let fields = vec![int64_ptr, list_struct_ptr];
        // Giving the existing named struct a body updates the same uniqued instance.
        StructType::get_named(&mut ctx, linked_list_id.clone(), Some(fields))?;
        assert!(
            !list_struct
                .deref(&ctx)
                .downcast_ref::<StructType>()
                .unwrap()
                .is_opaque()
        );

        let list_struct_2 = StructType::get_existing_named(&ctx, &linked_list_id)
            .unwrap()
            .into();
        assert!(list_struct == list_struct_2);
        assert!(StructType::get_existing_named(&ctx, &linked_list_2_id).is_none());

        // The recursive occurrence is printed by name only.
        assert_eq!(
            list_struct.disp(&ctx).to_string(),
            "llvm.struct <LinkedList { builtin.integer i64, llvm.typed_ptr <llvm.struct <LinkedList>> }>"
        );

        // Anonymous structs are uniqued by their field lists.
        let head_fields = vec![int64_ptr, list_struct_ptr];
        let head_struct = StructType::get_unnamed(&mut ctx, head_fields.clone());
        let head_struct2 = StructType::get_existing_unnamed(&ctx, head_fields).unwrap();
        assert!(head_struct == head_struct2);
        assert!(StructType::get_existing_unnamed(&ctx, vec![int64_ptr, list_struct]).is_none());

        Ok(())
    }

    #[verify_succ]
    /// A test-only pointer type that records its pointee type, used to build
    /// self-referential structs above.
    #[pliron_type(name = "llvm.typed_ptr", generate_get = true)]
    #[derive(Hash, PartialEq, Eq, Debug)]
    pub struct TypedPointerType {
        to: Ptr<TypeObj>,
    }

    impl TypedPointerType {
        /// Looks up an already registered pointer to `to`, without creating it.
        pub fn get_existing(ctx: &Context, to: Ptr<TypeObj>) -> Option<TypePtr<Self>> {
            Type::get_instance(TypedPointerType { to }, ctx)
        }

        /// The pointee type.
        pub fn get_pointee_type(&self) -> Ptr<TypeObj> {
            self.to
        }
    }

    impl Printable for TypedPointerType {
        // Prints `<pointee>`.
        fn fmt(
            &self,
            ctx: &Context,
            _state: &printable::State,
            f: &mut core::fmt::Formatter<'_>,
        ) -> core::fmt::Result {
            write!(f, "<{}>", self.to.disp(ctx))
        }
    }

    impl Parsable for TypedPointerType {
        type Arg = ();
        type Parsed = TypePtr<Self>;

        // Parses `<pointee>` and returns the uniqued pointer type.
        fn parse<'a>(
            state_stream: &mut StateStream<'a>,
            _arg: Self::Arg,
        ) -> ParseResult<'a, Self::Parsed>
        where
            Self: Sized,
        {
            combine::between(token('<'), token('>'), spaced(type_parser()))
                .parse_stream(state_stream)
                .map(|pointee_ty| TypedPointerType::get(state_stream.state.ctx, pointee_ty))
                .into()
        }
    }

    // Uniquing, printing and accessors of the test pointer type.
    #[test]
    fn test_pointer_types() {
        let mut ctx = Context::new();
        let int32_1_ptr = IntegerType::get(&mut ctx, 32, Signedness::Signed);
        let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signed).into();

        let int64pointer_ptr = TypedPointerType { to: int64_ptr };
        let int64pointer_ptr = Type::register_instance(int64pointer_ptr, &mut ctx);
        assert_eq!(
            int64pointer_ptr.disp(&ctx).to_string(),
            "llvm.typed_ptr <builtin.integer si64>"
        );
        // Registering directly and via `get` must yield the same uniqued instance.
        assert!(int64pointer_ptr == TypedPointerType::get(&mut ctx, int64_ptr));

        assert!(
            int64_ptr
                .deref(&ctx)
                .downcast_ref::<IntegerType>()
                .unwrap()
                .width()
                == 64
        );

        assert!(IntegerType::get_existing(&ctx, 32, Signedness::Signed).unwrap() == int32_1_ptr);
        assert!(TypedPointerType::get_existing(&ctx, int64_ptr).unwrap() == int64pointer_ptr);
        assert!(int64pointer_ptr.deref(&ctx).get_pointee_type() == int64_ptr);
    }

    // Round-trips a pointer type through the generic type parser and printer.
    #[test]
    fn test_pointer_type_parsing() {
        let mut ctx = Context::new();

        let state_stream = state_stream_from_iterator(
            "llvm.typed_ptr <builtin.integer si64>".chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );

        let res = type_parser().parse(state_stream).unwrap().0;
        assert_eq!(
            &res.disp(&ctx).to_string(),
            "llvm.typed_ptr <builtin.integer si64>"
        );
    }

    // Round-trips a self-referential named struct, an opaque named struct and an
    // anonymous struct.
    #[test]
    fn test_struct_type_parsing() {
        let mut ctx = Context::new();

        let state_stream = state_stream_from_iterator(
            "llvm.struct <LinkedList { builtin.integer i64, llvm.typed_ptr <llvm.struct <LinkedList>> }>"
                .chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );

        let res = type_parser().parse(state_stream).unwrap().0;
        assert_eq!(
            &res.disp(&ctx).to_string(),
            "llvm.struct <LinkedList { builtin.integer i64, llvm.typed_ptr <llvm.struct <LinkedList>> }>"
        );

        // Opaque named struct.
        let test_string = "llvm.struct <ExternStruct>";
        let state_stream = state_stream_from_iterator(
            test_string.chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );
        let res = type_parser().parse(state_stream).unwrap().0;
        assert_eq!(&res.disp(&ctx).to_string(), test_string);
        {
            let res = res.deref(&ctx);
            let res = res.downcast_ref::<StructType>().unwrap();
            assert!(res.is_opaque() && res.is_named());
        }

        // Anonymous struct.
        let test_string = "llvm.struct <{ builtin.integer i8 }>";
        let state_stream = state_stream_from_iterator(
            test_string.chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );
        let res = type_parser().parse(state_stream).unwrap().0;
        assert_eq!(&res.disp(&ctx).to_string(), test_string);
        {
            let res = res.deref(&ctx);
            let res = res.downcast_ref::<StructType>().unwrap();
            assert!(!res.is_opaque() && !res.is_named());
        }
    }

    // Re-declaring a named struct with a different body must fail, and the error
    // must point at the struct's name.
    #[test]
    fn test_struct_type_errs() {
        let mut ctx = Context::new();

        let state_stream = state_stream_from_iterator(
            "llvm.struct < My1 { builtin.integer i8 } >".chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );
        let _ = type_parser().parse(state_stream).unwrap().0;

        let state_stream = state_stream_from_iterator(
            "llvm.struct < My1 { builtin.integer i16 } >".chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );

        let res = type_parser().parse(state_stream);
        let err_msg = format!("{}", res.err().unwrap());

        // Column 15 is the position of `My1` in the input.
        let expected_err_msg = expect![[r#"
            Parse error at line: 1, column: 15
            struct My1 already exists and is different
        "#]];
        expected_err_msg.assert_eq(&err_msg);
    }

    // Parses a function type and checks it is uniqued with an equivalent `FuncType::get`,
    // and that it prints back identically.
    #[test]
    fn test_functype_parsing() {
        let mut ctx = Context::new();

        let si32 = IntegerType::get(&mut ctx, 32, Signedness::Signed);

        let input = "llvm.func <llvm.void (builtin.integer si32) variadic = false>";
        let state_stream = state_stream_from_iterator(
            input.chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );

        // `eof()` ensures the whole input was consumed.
        let res = type_parser().and(eof()).parse(state_stream).unwrap().0.0;

        let void_ty = VoidType::get(&ctx);
        assert!(res == FuncType::get(&mut ctx, void_ty.to_ptr(), vec![si32.into()], false).into());
        assert_eq!(input, &res.disp(&ctx).to_string());
    }
}