Skip to main content

pliron_llvm/
types.rs

1//! [Type]s defined in the LLVM dialect.
2
3use pliron::combine::{Parser, between, optional, token};
4
5use pliron::builtin::type_interfaces::FunctionTypeInterface;
6use pliron::derive::{format, pliron_type, type_interface_impl};
7use pliron::{
8    common_traits::Verify,
9    context::{Context, Ptr},
10    identifier::Identifier,
11    input_err_noloc,
12    irfmt::{
13        parsers::{delimited_list_parser, location, spaced},
14        printers::{enclosed, list_with_sep},
15    },
16    location::Located,
17    parsable::{IntoParseResult, Parsable, ParseResult, StateStream},
18    printable::{self, ListSeparator, Printable},
19    result::Result,
20    r#type::{Type, TypeObj, TypePtr},
21    verify_err_noloc,
22};
23use thiserror::Error;
24
25use std::hash::Hash;
26
27/// Represents a c-like struct type.
28/// Limitations and warnings on its usage are similar to that in MLIR.
29/// `<https://mlir.llvm.org/docs/Dialects/LLVM/#structure-types>`
30///   1. Anonymous (aka unnamed) structs cannot be recursive.
31///   2. Named structs are uniqued *only* by name, and may be recursive.
32///   3. LLVM calls anonymous structs as literal structs and
33///      named structs as identified structs.
34///   4. Named structs may be opaque, i.e., no body specificed.
35///      Recursive types may be created by first creating an opaque struct
36///      and later setting its fields (body).
37#[pliron_type(name = "llvm.struct")]
38#[derive(Debug)]
39pub struct StructType {
40    name: Option<Identifier>,
41    fields: Option<Vec<Ptr<TypeObj>>>,
42}
43
44impl StructType {
45    /// Get or create a named StructType.
46    /// If `fields` is `None`, it indicates an opaque struct.
47    /// A body can be added to opaque structs by calling this again later.
48    /// Returns an error if all of the below conditions are true:
49    ///   a. The name is already registered
50    ///   b. The body is already set (i.e, the struct is not oqaue)
51    ///   c. The fields provided here don't match with the existing body.
52    /// Since named structs only rely on the name for uniqueness,
53    /// It is not an error to provide `fields` as `None` even when
54    /// the named struct already exists and has its body set.
55    pub fn get_named(
56        ctx: &mut Context,
57        name: Identifier,
58        fields: Option<Vec<Ptr<TypeObj>>>,
59    ) -> Result<TypePtr<Self>> {
60        let self_ptr = Type::register_instance(
61            StructType {
62                name: Some(name.clone()),
63                // Uniquing happens only on the name, so this doesn't matter.
64                fields: None,
65            },
66            ctx,
67        );
68        // Verify that we created a new or equivalent existing type.
69        let mut self_ref = self_ptr.to_ptr().deref_mut(ctx);
70        let self_ref = self_ref.downcast_mut::<StructType>().unwrap();
71        assert!(self_ref.name.as_ref().unwrap() == &name);
72        if let Some(fields) = fields {
73            // We've been provided fields to be set.
74            if let Some(existing_fields) = &self_ref.fields {
75                // Fields were already set before, ensure they're same as the given ones.
76                if existing_fields != &fields {
77                    input_err_noloc!(StructErr::ExistingMismatch(name.into()))?
78                }
79            } else {
80                // Set the fields now.
81                self_ref.fields = Some(fields);
82            }
83        }
84        Ok(self_ptr)
85    }
86
87    /// Get or create a new unnamed (anonymous) struct.
88    /// These are finalized upon creation, and uniqued based on the fields.
89    pub fn get_unnamed(ctx: &mut Context, fields: Vec<Ptr<TypeObj>>) -> TypePtr<Self> {
90        Type::register_instance(
91            StructType {
92                name: None,
93                fields: Some(fields),
94            },
95            ctx,
96        )
97    }
98
99    /// If a named struct already exists, get a pointer to it.
100    pub fn get_existing_named(ctx: &Context, name: &Identifier) -> Option<TypePtr<Self>> {
101        Type::get_instance(
102            StructType {
103                name: Some(name.clone()),
104                // Named structs are uniqued only on the name.
105                fields: None,
106            },
107            ctx,
108        )
109    }
110
111    /// If an unnamed struct already exists, get a pointer to it.
112    pub fn get_existing_unnamed(ctx: &Context, fields: Vec<Ptr<TypeObj>>) -> Option<TypePtr<Self>> {
113        Type::get_instance(
114            StructType {
115                name: None,
116                fields: Some(fields),
117            },
118            ctx,
119        )
120    }
121
122    /// Does this struct not have its body set?
123    pub fn is_opaque(&self) -> bool {
124        self.fields.is_none()
125    }
126
127    /// Is this a named struct?
128    pub fn is_named(&self) -> bool {
129        self.name.is_some()
130    }
131
132    /// Get this struct's name, if it has one.
133    pub fn name(&self) -> Option<Identifier> {
134        self.name.clone()
135    }
136
137    /// Get type of the idx'th field.
138    pub fn field_type(&self, field_idx: usize) -> Ptr<TypeObj> {
139        self.fields
140            .as_ref()
141            .expect("field_type shouldn't be called on opaque types")[field_idx]
142    }
143
144    /// Get the number of fields this struct has
145    pub fn num_fields(&self) -> usize {
146        self.fields
147            .as_ref()
148            .expect("num_fields shouldn't be called on opaque types")
149            .len()
150    }
151
152    /// Get an iterator over the fields of this struct
153    pub fn fields(&self) -> impl Iterator<Item = Ptr<TypeObj>> + '_ {
154        self.fields
155            .as_ref()
156            .expect("fields shouldn't be called on opaque types")
157            .iter()
158            .cloned()
159    }
160}
161
162#[derive(Debug, Error)]
163pub enum StructErr {
164    #[error("struct cannot be both opaque and anonymous")]
165    OpaqueAndAnonymousErr,
166    #[error("struct {0} already exists and is different")]
167    ExistingMismatch(String),
168}
169
170impl Verify for StructType {
171    fn verify(&self, _ctx: &Context) -> Result<()> {
172        if self.name.is_none() && self.fields.is_none() {
173            verify_err_noloc!(StructErr::OpaqueAndAnonymousErr)?
174        }
175        Ok(())
176    }
177}
178
179impl Printable for StructType {
180    fn fmt(
181        &self,
182        ctx: &Context,
183        state: &printable::State,
184        f: &mut core::fmt::Formatter<'_>,
185    ) -> core::fmt::Result {
186        write!(f, "<")?;
187
188        use std::cell::RefCell;
189        // Ugly, but also the simplest way to avoid infinite recursion.
190        // MLIR does the same: see LLVMTypeSyntax::printStructType.
191        thread_local! {
192            // We use a vec instead of a HashMap hoping that this isn't
193            // going to be large, in which case vec would be faster.
194            static IN_PRINTING: RefCell<Vec<Identifier>>  = const { RefCell::new(vec![]) };
195        }
196        if let Some(name) = &self.name {
197            let in_printing = IN_PRINTING.with(|f| f.borrow().contains(name));
198            if in_printing {
199                return write!(f, "{}>", name.clone());
200            }
201            IN_PRINTING.with(|f| f.borrow_mut().push(name.clone()));
202            write!(f, "{name}")?;
203            if !self.is_opaque() {
204                write!(f, " ")?;
205            }
206        }
207
208        if let Some(fields) = &self.fields {
209            enclosed(
210                "{ ",
211                " }",
212                list_with_sep(fields, ListSeparator::CharSpace(',')),
213            )
214            .fmt(ctx, state, f)?;
215        }
216
217        // Done processing this struct. Remove it from the stack.
218        if let Some(name) = &self.name {
219            assert!(IN_PRINTING.with(|f| f.borrow().last().unwrap() == name));
220            IN_PRINTING.with(|f| f.borrow_mut().pop());
221        }
222        write!(f, ">")
223    }
224}
225
226impl Hash for StructType {
227    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
228        match &self.name {
229            Some(name) => name.hash(state),
230            None => self
231                .fields
232                .as_ref()
233                .expect("Anonymous struct must have its fields set")
234                .iter()
235                .for_each(|field_type| {
236                    field_type.hash(state);
237                }),
238        }
239    }
240}
241
242impl PartialEq for StructType {
243    fn eq(&self, other: &Self) -> bool {
244        match (&self.name, &other.name) {
245            (Some(name), Some(other_name)) => name == other_name,
246            (None, None) => self.fields == other.fields,
247            _ => false,
248        }
249    }
250}
251
252impl Parsable for StructType {
253    type Arg = ();
254    type Parsed = TypePtr<Self>;
255
256    fn parse<'a>(
257        state_stream: &mut StateStream<'a>,
258        _arg: Self::Arg,
259    ) -> ParseResult<'a, Self::Parsed>
260    where
261        Self: Sized,
262    {
263        let body_parser = || {
264            // Parse multiple type annotated fields separated by ',', all of it delimited by braces.
265            delimited_list_parser('{', '}', ',', Ptr::<TypeObj>::parser(()))
266        };
267
268        let named = spaced((location(), Identifier::parser(())))
269            .and(spaced(optional(body_parser())))
270            .map(|((loc, name), body_opt)| (loc, Some(name), body_opt));
271        let anonymous = spaced((location(), body_parser()))
272            .map(|(loc, body)| (loc, None::<Identifier>, Some(body)));
273
274        // A struct type is named or anonymous.
275        let mut struct_parser = between(token('<'), token('>'), named.or(anonymous));
276
277        let (loc, name_opt, body_opt) = struct_parser.parse_stream(state_stream).into_result()?.0;
278        let ctx = &mut state_stream.state.ctx;
279        if let Some(name) = name_opt {
280            StructType::get_named(ctx, name, body_opt)
281                .map_err(|mut err| {
282                    err.set_loc(loc);
283                    err
284                })
285                .into_parse_result()
286        } else {
287            Ok(StructType::get_unnamed(
288                ctx,
289                body_opt.expect("Without a name, a struct type must have a body."),
290            ))
291            .into_parse_result()
292        }
293    }
294}
295
296impl Eq for StructType {}
297
298/// An opaque pointer, corresponding to LLVM's pointer type.
299#[pliron_type(name = "llvm.ptr", generate_get = true, format, verifier = "succ")]
300#[derive(Hash, PartialEq, Eq, Debug)]
301pub struct PointerType;
302
303/// Array type, corresponding to LLVM's array type.
304#[pliron_type(
305    name = "llvm.array",
306    generate_get = true,
307    format = "`[` $size ` x ` $elem `]`",
308    verifier = "succ"
309)]
310#[derive(Hash, PartialEq, Eq, Debug)]
311pub struct ArrayType {
312    elem: Ptr<TypeObj>,
313    size: u64,
314}
315
316impl ArrayType {
317    /// Get array element type.
318    pub fn elem_type(&self) -> Ptr<TypeObj> {
319        self.elem
320    }
321
322    /// Get array size.
323    pub fn size(&self) -> u64 {
324        self.size
325    }
326}
327
328#[pliron_type(name = "llvm.void", generate_get = true, format, verifier = "succ")]
329#[derive(Hash, PartialEq, Eq, Debug)]
330pub struct VoidType;
331
332#[pliron_type(
333    name = "llvm.func",
334    generate_get = true,
335    format = "`<` $res `(` vec($args, CharSpace(`,`)) `) variadic = ` $is_var_arg `>`",
336    verifier = "succ"
337)]
338#[derive(Hash, PartialEq, Eq, Debug)]
339pub struct FuncType {
340    res: Ptr<TypeObj>,
341    args: Vec<Ptr<TypeObj>>,
342    is_var_arg: bool,
343}
344
345#[derive(Debug, Error)]
346pub enum FuncTypeErr {
347    #[error("Expected at most one result")]
348    TooManyResults,
349}
350
351impl FuncType {
352    /// Result type
353    pub fn result_type(&self) -> Ptr<TypeObj> {
354        self.res
355    }
356
357    /// Is this a variadic function type?
358    pub fn is_var_arg(&self) -> bool {
359        self.is_var_arg
360    }
361}
362
363#[type_interface_impl]
364impl FunctionTypeInterface for FuncType {
365    fn arg_types(&self) -> Vec<Ptr<TypeObj>> {
366        self.args.clone()
367    }
368    fn res_types(&self) -> Vec<Ptr<TypeObj>> {
369        vec![self.res]
370    }
371}
372
373/// Kind of vector type: fixed or scalable.
374/// See LLVM language reference for semantic details.
375#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
376#[format]
377pub enum VectorTypeKind {
378    Fixed,
379    Scalable,
380}
381
382#[pliron_type(
383    name = "llvm.vector",
384    generate_get = true,
385    format = "`<` $kind ` x ` $num_elems ` x ` $elem_ty `>`",
386    verifier = "succ"
387)]
388#[derive(Hash, PartialEq, Eq, Debug)]
389pub struct VectorType {
390    elem_ty: Ptr<TypeObj>,
391    num_elems: u32,
392    kind: VectorTypeKind,
393}
394
395impl VectorType {
396    /// Get the element type.
397    pub fn elem_type(&self) -> Ptr<TypeObj> {
398        self.elem_ty
399    }
400
401    /// Get the number of elements.
402    pub fn num_elements(&self) -> u32 {
403        self.num_elems
404    }
405
406    /// Is this a scalable vector type?
407    pub fn is_scalable(&self) -> bool {
408        self.kind == VectorTypeKind::Scalable
409    }
410
411    /// Get the scalable/fixed kind of this vector type.
412    pub fn kind(&self) -> VectorTypeKind {
413        self.kind
414    }
415}
416
417#[cfg(test)]
418mod tests {
419
420    use expect_test::expect;
421    use pliron::combine::{self, Parser, eof, token};
422    use pliron::derive::{pliron_type, verify_succ};
423
424    use crate::types::{FuncType, StructType, VoidType};
425    use pliron::{
426        builtin::types::{IntegerType, Signedness},
427        context::{Context, Ptr},
428        identifier::Identifier,
429        irfmt::parsers::{spaced, type_parser},
430        location,
431        parsable::{self, Parsable, ParseResult, StateStream, state_stream_from_iterator},
432        printable::{self, Printable},
433        result::Result,
434        r#type::{Type, TypeObj, TypePtr},
435    };
436
437    #[test]
438    fn test_struct() -> Result<()> {
439        let mut ctx = Context::new();
440        let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signless).into();
441        let linked_list_id: Identifier = "LinkedList".try_into().unwrap();
442        let linked_list_2_id: Identifier = "LinkedList2".try_into().unwrap();
443
444        // Create an opaque struct since we want a recursive type.
445        let list_struct: Ptr<TypeObj> =
446            StructType::get_named(&mut ctx, linked_list_id.clone(), None)?.into();
447        assert!(
448            list_struct
449                .deref(&ctx)
450                .downcast_ref::<StructType>()
451                .unwrap()
452                .is_opaque()
453        );
454        let list_struct_ptr = TypedPointerType::get(&mut ctx, list_struct).into();
455        let fields = vec![int64_ptr, list_struct_ptr];
456        // Set the struct body now.
457        StructType::get_named(&mut ctx, linked_list_id.clone(), Some(fields))?;
458        assert!(
459            !list_struct
460                .deref(&ctx)
461                .downcast_ref::<StructType>()
462                .unwrap()
463                .is_opaque()
464        );
465
466        let list_struct_2 = StructType::get_existing_named(&ctx, &linked_list_id)
467            .unwrap()
468            .into();
469        assert!(list_struct == list_struct_2);
470        assert!(StructType::get_existing_named(&ctx, &linked_list_2_id).is_none());
471
472        assert_eq!(
473            list_struct.disp(&ctx).to_string(),
474            "llvm.struct <LinkedList { builtin.integer i64, llvm.typed_ptr <llvm.struct <LinkedList>> }>"
475        );
476
477        let head_fields = vec![int64_ptr, list_struct_ptr];
478        let head_struct = StructType::get_unnamed(&mut ctx, head_fields.clone());
479        let head_struct2 = StructType::get_existing_unnamed(&ctx, head_fields).unwrap();
480        assert!(head_struct == head_struct2);
481        assert!(StructType::get_existing_unnamed(&ctx, vec![int64_ptr, list_struct]).is_none());
482
483        Ok(())
484    }
485
486    /// A pointer type that knows the type it points to.
487    /// This used to be in LLVM earlier, but the latest version
488    /// is now type-erased (https://llvm.org/docs/OpaquePointers.html)
489    #[verify_succ]
490    #[pliron_type(name = "llvm.typed_ptr", generate_get = true)]
491    #[derive(Hash, PartialEq, Eq, Debug)]
492    pub struct TypedPointerType {
493        to: Ptr<TypeObj>,
494    }
495
496    impl TypedPointerType {
497        /// Get, if it already exists, a pointer type.
498        pub fn get_existing(ctx: &Context, to: Ptr<TypeObj>) -> Option<TypePtr<Self>> {
499            Type::get_instance(TypedPointerType { to }, ctx)
500        }
501
502        /// Get the pointee type.
503        pub fn get_pointee_type(&self) -> Ptr<TypeObj> {
504            self.to
505        }
506    }
507
508    impl Printable for TypedPointerType {
509        fn fmt(
510            &self,
511            ctx: &Context,
512            _state: &printable::State,
513            f: &mut core::fmt::Formatter<'_>,
514        ) -> core::fmt::Result {
515            write!(f, "<{}>", self.to.disp(ctx))
516        }
517    }
518
519    impl Parsable for TypedPointerType {
520        type Arg = ();
521        type Parsed = TypePtr<Self>;
522
523        fn parse<'a>(
524            state_stream: &mut StateStream<'a>,
525            _arg: Self::Arg,
526        ) -> ParseResult<'a, Self::Parsed>
527        where
528            Self: Sized,
529        {
530            combine::between(token('<'), token('>'), spaced(type_parser()))
531                .parse_stream(state_stream)
532                .map(|pointee_ty| TypedPointerType::get(state_stream.state.ctx, pointee_ty))
533                .into()
534        }
535    }
536
537    #[test]
538    fn test_pointer_types() {
539        let mut ctx = Context::new();
540        let int32_1_ptr = IntegerType::get(&mut ctx, 32, Signedness::Signed);
541        let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signed).into();
542
543        let int64pointer_ptr = TypedPointerType { to: int64_ptr };
544        let int64pointer_ptr = Type::register_instance(int64pointer_ptr, &mut ctx);
545        assert_eq!(
546            int64pointer_ptr.disp(&ctx).to_string(),
547            "llvm.typed_ptr <builtin.integer si64>"
548        );
549        assert!(int64pointer_ptr == TypedPointerType::get(&mut ctx, int64_ptr));
550
551        assert!(
552            int64_ptr
553                .deref(&ctx)
554                .downcast_ref::<IntegerType>()
555                .unwrap()
556                .width()
557                == 64
558        );
559
560        assert!(IntegerType::get_existing(&ctx, 32, Signedness::Signed).unwrap() == int32_1_ptr);
561        assert!(TypedPointerType::get_existing(&ctx, int64_ptr).unwrap() == int64pointer_ptr);
562        assert!(int64pointer_ptr.deref(&ctx).get_pointee_type() == int64_ptr);
563    }
564
565    #[test]
566    fn test_pointer_type_parsing() {
567        let mut ctx = Context::new();
568
569        let state_stream = state_stream_from_iterator(
570            "llvm.typed_ptr <builtin.integer si64>".chars(),
571            parsable::State::new(&mut ctx, location::Source::InMemory),
572        );
573
574        let res = type_parser().parse(state_stream).unwrap().0;
575        assert_eq!(
576            &res.disp(&ctx).to_string(),
577            "llvm.typed_ptr <builtin.integer si64>"
578        );
579    }
580
581    #[test]
582    fn test_struct_type_parsing() {
583        let mut ctx = Context::new();
584
585        let state_stream = state_stream_from_iterator(
586            "llvm.struct <LinkedList { builtin.integer i64, llvm.typed_ptr <llvm.struct <LinkedList>> }>"
587                .chars(),
588            parsable::State::new(&mut ctx, location::Source::InMemory),
589        );
590
591        let res = type_parser().parse(state_stream).unwrap().0;
592        assert_eq!(
593            &res.disp(&ctx).to_string(),
594            "llvm.struct <LinkedList { builtin.integer i64, llvm.typed_ptr <llvm.struct <LinkedList>> }>"
595        );
596
597        // Test parsing an opaque struct.
598        let test_string = "llvm.struct <ExternStruct>";
599        let state_stream = state_stream_from_iterator(
600            test_string.chars(),
601            parsable::State::new(&mut ctx, location::Source::InMemory),
602        );
603        let res = type_parser().parse(state_stream).unwrap().0;
604        assert_eq!(&res.disp(&ctx).to_string(), test_string);
605        {
606            let res = res.deref(&ctx);
607            let res = res.downcast_ref::<StructType>().unwrap();
608            assert!(res.is_opaque() && res.is_named());
609        }
610
611        // Test parsing an unnamed struct.
612        let test_string = "llvm.struct <{ builtin.integer i8 }>";
613        let state_stream = state_stream_from_iterator(
614            test_string.chars(),
615            parsable::State::new(&mut ctx, location::Source::InMemory),
616        );
617        let res = type_parser().parse(state_stream).unwrap().0;
618        assert_eq!(&res.disp(&ctx).to_string(), test_string);
619        {
620            let res = res.deref(&ctx);
621            let res = res.downcast_ref::<StructType>().unwrap();
622            assert!(!res.is_opaque() && !res.is_named());
623        }
624    }
625
626    #[test]
627    fn test_struct_type_errs() {
628        let mut ctx = Context::new();
629
630        let state_stream = state_stream_from_iterator(
631            "llvm.struct < My1 { builtin.integer i8 } >".chars(),
632            parsable::State::new(&mut ctx, location::Source::InMemory),
633        );
634        let _ = type_parser().parse(state_stream).unwrap().0;
635
636        let state_stream = state_stream_from_iterator(
637            "llvm.struct < My1 { builtin.integer i16 } >".chars(),
638            parsable::State::new(&mut ctx, location::Source::InMemory),
639        );
640
641        let res = type_parser().parse(state_stream);
642        let err_msg = format!("{}", res.err().unwrap());
643
644        let expected_err_msg = expect![[r#"
645            Parse error at line: 1, column: 15
646            struct My1 already exists and is different
647        "#]];
648        expected_err_msg.assert_eq(&err_msg);
649    }
650
651    #[test]
652    fn test_functype_parsing() {
653        let mut ctx = Context::new();
654
655        let si32 = IntegerType::get(&mut ctx, 32, Signedness::Signed);
656
657        let input = "llvm.func <llvm.void (builtin.integer si32) variadic = false>";
658        let state_stream = state_stream_from_iterator(
659            input.chars(),
660            parsable::State::new(&mut ctx, location::Source::InMemory),
661        );
662
663        let res = type_parser().and(eof()).parse(state_stream).unwrap().0.0;
664
665        let void_ty = VoidType::get(&ctx);
666        assert!(res == FuncType::get(&mut ctx, void_ty.to_ptr(), vec![si32.into()], false).into());
667        assert_eq!(input, &res.disp(&ctx).to_string());
668    }
669}