Skip to main content

pliron/
type.rs

1//! Every SSA value, such as operation results or block arguments
2//! has a type defined by the type system.
3//!
4//! The type system is open, with no fixed list of types,
5//! and there are no restrictions on the abstractions they represent.
6//!
//! See [MLIR Types](https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/#types).
8//!
9//! The [pliron_type](pliron::derive::pliron_type) proc macro from [pliron-derive]
10//! can be used to implement [Type] for a rust type.
11//!
12//! Common semantics, API and behaviour of [Type]s are
13//! abstracted into interfaces. Interfaces in pliron capture MLIR
14//! functionality of both [Traits](https://mlir.llvm.org/docs/Traits/)
15//! and [Interfaces](https://mlir.llvm.org/docs/Interfaces/).
16//! Interfaces must all implement an associated function named `verify` with
17//! the type [TypeInterfaceVerifier].
18//!
19//! Interfaces are rust Trait definitions annotated with the attribute macro
20//! [type_interface](pliron::derive::type_interface). The attribute ensures that any
21//! verifiers of super-interfaces are run prior to the verifier of this interface.
22//! Note: Super-interface verifiers *may* run multiple times for the same type.
23//!
24//! [Type]s that implement an interface must annotate the implementation with
25//! [type_interface_impl](pliron::derive::type_interface_impl) macro to ensure that
26//! the interface verifier is automatically called during verification
//! and that a `&dyn Type` object can easily be [cast](type_cast) into an interface object
//! (or checked for whether it [implements](type_impls) the interface).
30//!
31//! Use [verify_type] to verify a [Type] object.
32//! This function verifies all interfaces implemented by the type, and then the type itself.
33//! The type's verifier must explicitly invoke verifiers on any sub-objects it contains.
34//!
//! [TypeObj]s can be downcast to their concrete types using
//! [downcast_rs](https://docs.rs/downcast-rs/latest/downcast_rs/#example-without-generics).
37
38use crate::common_traits::Verify;
39use crate::context::{Arena, Context, Ptr, collect_deduped_interface_verifiers, private::ArenaObj};
40use crate::dialect::{Dialect, DialectName};
41use crate::identifier::Identifier;
42use crate::irfmt::parsers::spaced;
43use crate::location::Located;
44use crate::parsable::{Parsable, ParseResult, StateStream};
45use crate::printable::{self, Printable};
46use crate::result::Result;
47use crate::storage_uniquer::TypeValueHash;
48use crate::{arg_err_noloc, impl_printable_for_display, input_err};
49
50use combine::{Parser, parser};
51use downcast_rs::{Downcast, impl_downcast};
52use rustc_hash::FxHashMap;
53use std::cell::Ref;
54use std::fmt::Debug;
55use std::fmt::Display;
56use std::hash::{Hash, Hasher};
57use std::marker::PhantomData;
58use std::ops::Deref;
59use std::sync::LazyLock;
60use thiserror::Error;
61
62/// Basic functionality that every type in the IR must implement.
63/// Type objects (instances of a Type) are (mostly) immutable once created,
64/// and are uniqued globally. Uniquing is based on the type name (i.e.,
65/// the rust type being defined) and its contents.
66///
67/// So, for example, if we have
68/// ```rust
69///     # use pliron::derive::pliron_type;
70///     #[pliron_type(
71///         name = "test.intty",
72///         format,
73///         verifier = "succ"
74///     )]
75///     #[derive(Debug, PartialEq, Eq, Hash)]
76///     struct IntType {
77///         width: u64
78///     }
79/// ```
80/// the uniquing will include
81///   - [`std::any::TypeId::of::<IntType>()`](std::any::TypeId)
82///   - `width`
83///
84/// Types *can* have mutable contents that can be modified *after*
85/// the type is created. This enables creation of recursive types.
86/// In such a case, it is up to the type definition to ensure that
87///   1. It manually implements Hash, ignoring these mutable fields.
88///   2. A proper distinguisher content (such as a string), that is part
89///      of the hash, is used so that uniquing still works.
90pub trait Type: Printable + Verify + Downcast + Sync + Send + Debug {
91    /// Compute and get the hash for this instance of Self.
92    /// Hash collisions can be a possibility.
93    fn hash_type(&self) -> TypeValueHash;
94    /// Is self equal to an other Type?
95    fn eq_type(&self, other: &dyn Type) -> bool;
96
97    /// Get a copyable pointer to this type.
98    // Unlike in other [ArenaObj]s,
99    // we do not store a self pointer inside the object itself
100    // because that can upset taking automatic hashes of the object.
101    fn get_self_ptr(&self, ctx: &Context) -> Ptr<TypeObj> {
102        let is = |other: &TypeObj| self.eq_type(&**other);
103        let idx = ctx
104            .type_store
105            .get(self.hash_type(), &is)
106            .expect("Unregistered type object in existence");
107        Ptr {
108            idx,
109            _dummy: PhantomData::<TypeObj>,
110        }
111    }
112
113    /// Register an instance of a type in the provided [Context]
114    /// Returns a pointer to self. If the type was already registered,
115    /// a pointer to the existing object is returned.
116    fn register_instance(t: Self, ctx: &mut Context) -> TypePtr<Self>
117    where
118        Self: Sized,
119    {
120        let hash = t.hash_type();
121        let idx = ctx
122            .type_store
123            .get_or_create_unique(Box::new(t), hash, &TypeObj::eq);
124        let ptr = Ptr {
125            idx,
126            _dummy: PhantomData::<TypeObj>,
127        };
128        TypePtr(ptr, PhantomData::<Self>)
129    }
130
131    /// If an instance of `t` already exists, get a [Ptr] to it.
132    /// Consumes `t` either way.
133    fn get_instance(t: Self, ctx: &Context) -> Option<TypePtr<Self>>
134    where
135        Self: Sized,
136    {
137        let is = |other: &TypeObj| t.eq_type(&**other);
138        ctx.type_store.get(t.hash_type(), &is).map(|idx| {
139            let ptr = Ptr {
140                idx,
141                _dummy: PhantomData::<TypeObj>,
142            };
143            TypePtr(ptr, PhantomData::<Self>)
144        })
145    }
146
147    /// Get a Type's static name. This is *not* per instantiation of the type.
148    /// It is mostly useful for printing and parsing the type.
149    /// Uniquing does *not* use this, but instead uses [std::any::TypeId].
150    fn get_type_id(&self) -> TypeId;
151
152    /// Same as [get_type_id](Self::get_type_id), but without the self reference.
153    fn get_type_id_static() -> TypeId
154    where
155        Self: Sized;
156
157    #[doc(hidden)]
158    /// Verify all interfaces implemented by this Type.
159    fn verify_interfaces(&self, ctx: &Context) -> Result<()>;
160
161    /// Register this Type's [TypeId] in the dialect it belongs to.
162    fn register(ctx: &mut Context)
163    where
164        Self: Sized + Parsable<Arg = (), Parsed = TypePtr<Self>>,
165    {
166        let ptr_parser: TypeParserFn = Box::new(|&()| {
167            combine::parser(move |parsable_state: &mut StateStream<'_>| {
168                Self::parse(parsable_state, ())
169                    .map(|(typtr, r)| -> (Ptr<TypeObj>, _) { (typtr.to_ptr(), r) })
170            })
171            .boxed()
172        });
173        let typeid = Self::get_type_id_static();
174        Dialect::register(ctx, &typeid.dialect.clone()).add_type(typeid, ptr_parser);
175    }
176}
177impl_downcast!(Type);
178
179/// A storable function pointer to parse a specific [Type].
180/// The [Type]'s [Dialect] maps a [TypeId] to such a parser.
181pub(crate) type TypeParserFn = Box<
182    for<'a> fn(
183        &'a (),
184    )
185        -> Box<dyn Parser<StateStream<'a>, Output = Ptr<TypeObj>, PartialState = ()> + 'a>,
186>;
187
188/// Trait for IR entities that have a direct type.
189pub trait Typed {
190    /// Get the [Type] of the current entity.
191    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj>;
192}
193
194impl Typed for Ptr<TypeObj> {
195    fn get_type(&self, _ctx: &Context) -> Ptr<TypeObj> {
196        *self
197    }
198}
199
200impl Typed for dyn Type {
201    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
202        self.get_self_ptr(ctx)
203    }
204}
205
206impl<T: Typed + ?Sized> Typed for &T {
207    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
208        (*self).get_type(ctx)
209    }
210}
211
212impl<T: Typed + ?Sized> Typed for &mut T {
213    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
214        (**self).get_type(ctx)
215    }
216}
217
218impl<T: Typed + ?Sized> Typed for Box<T> {
219    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
220        (**self).get_type(ctx)
221    }
222}
223
224#[derive(Clone, Hash, PartialEq, Eq)]
225/// A Type's name (not including it's dialect).
226pub struct TypeName(Identifier);
227
228impl TypeName {
229    /// Create a new TypeName.
230    pub fn new(name: &str) -> TypeName {
231        TypeName(name.try_into().expect("Invalid Identifier for TypeName"))
232    }
233}
234
235impl Deref for TypeName {
236    type Target = Identifier;
237
238    fn deref(&self) -> &Self::Target {
239        &self.0
240    }
241}
242
243impl_printable_for_display!(TypeName);
244
245impl Display for TypeName {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        write!(f, "{}", self.0)
248    }
249}
250
251impl Parsable for TypeName {
252    type Arg = ();
253    type Parsed = TypeName;
254
255    fn parse<'a>(
256        state_stream: &mut crate::parsable::StateStream<'a>,
257        _arg: Self::Arg,
258    ) -> ParseResult<'a, Self::Parsed>
259    where
260        Self: Sized,
261    {
262        Identifier::parser(())
263            .map(|name| TypeName::new(&name))
264            .parse_stream(state_stream)
265            .into()
266    }
267}
268
269/// A combination of a Type's name and its dialect.
270#[derive(Clone, Hash, PartialEq, Eq)]
271pub struct TypeId {
272    pub dialect: DialectName,
273    pub name: TypeName,
274}
275
276impl Parsable for TypeId {
277    type Arg = ();
278    type Parsed = TypeId;
279
280    // Parses (but does not validate) a TypeId.
281    fn parse<'a>(
282        state_stream: &mut StateStream<'a>,
283        _arg: Self::Arg,
284    ) -> ParseResult<'a, Self::Parsed>
285    where
286        Self: Sized,
287    {
288        let mut parser = DialectName::parser(())
289            .skip(parser::char::char('.'))
290            .and(TypeName::parser(()))
291            .map(|(dialect, name)| TypeId { dialect, name });
292        parser.parse_stream(state_stream).into()
293    }
294}
295
296impl_printable_for_display!(TypeId);
297
298impl Display for TypeId {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        write!(f, "{}.{}", self.dialect, self.name)
301    }
302}
303
304/// Since we can't store the [Type] trait in the arena,
305/// we store boxed dyn objects of it instead.
306pub type TypeObj = Box<dyn Type>;
307
308impl PartialEq for TypeObj {
309    fn eq(&self, other: &Self) -> bool {
310        (**self).eq_type(&**other)
311    }
312}
313
314impl Eq for TypeObj {}
315
316impl Hash for TypeObj {
317    fn hash<H: Hasher>(&self, state: &mut H) {
318        state.write(&u64::from(self.hash_type()).to_ne_bytes())
319    }
320}
321
322impl ArenaObj for TypeObj {
323    fn get_arena(ctx: &Context) -> &Arena<Self> {
324        &ctx.type_store.unique_store
325    }
326
327    fn get_arena_mut(ctx: &mut Context) -> &mut Arena<Self> {
328        &mut ctx.type_store.unique_store
329    }
330
331    fn get_self_ptr(&self, ctx: &Context) -> Ptr<Self> {
332        self.as_ref().get_self_ptr(ctx)
333    }
334
335    fn dealloc_sub_objects(_ptr: Ptr<Self>, _ctx: &mut Context) {
336        panic!("Cannot dealloc arena sub-objects of types")
337    }
338}
339
340impl Printable for TypeObj {
341    fn fmt(
342        &self,
343        ctx: &Context,
344        state: &printable::State,
345        f: &mut std::fmt::Formatter<'_>,
346    ) -> std::fmt::Result {
347        write!(f, "{} ", self.get_type_id())?;
348        Printable::fmt(self.deref(), ctx, state, f)
349    }
350}
351
352impl Parsable for Ptr<TypeObj> {
353    type Arg = ();
354    type Parsed = Self;
355
356    fn parse<'a>(
357        state_stream: &mut StateStream<'a>,
358        _arg: Self::Arg,
359    ) -> ParseResult<'a, Self::Parsed> {
360        let loc = state_stream.loc();
361        let type_id_parser = spaced(TypeId::parser(()));
362
363        let mut type_parser = type_id_parser.then(move |type_id: TypeId| {
364            // This clone is to satify the borrow checker.
365            let loc = loc.clone();
366            combine::parser(move |parsable_state: &mut StateStream<'a>| {
367                let state = &parsable_state.state;
368                let dialect = state
369                    .ctx
370                    .dialects
371                    .get(&type_id.dialect)
372                    .expect("Dialect name parsed but dialect isn't registered");
373                let Some(type_parser) = dialect.types.get(&type_id) else {
374                    input_err!(loc.clone(), "Unregistered type {}", type_id.disp(state.ctx))?
375                };
376                type_parser(&()).parse_stream(parsable_state).into()
377            })
378        });
379
380        type_parser.parse_stream(state_stream).into_result()
381    }
382}
383
384/// Verify a [Type] object:
385/// 1. All interfaces it implements are verified
386/// 2. The type itself is verified.
387pub fn verify_type(ty: &dyn Type, ctx: &Context) -> Result<()> {
388    // Verify all interfaces implemented by this Type.
389    ty.verify_interfaces(ctx)?;
390
391    // Verify the type itself.
392    Verify::verify(ty, ctx)
393}
394
395impl Verify for TypeObj {
396    fn verify(&self, ctx: &Context) -> Result<()> {
397        verify_type(self.as_ref(), ctx)
398    }
399}
400
401/// A wrapper around [`Ptr<TypeObj>`](TypeObj) with the underlying [Type] statically marked.
402#[derive(Debug)]
403pub struct TypePtr<T: Type>(Ptr<TypeObj>, PhantomData<T>);
404
405#[derive(Error, Debug)]
406#[error("TypePtr mismatch: Constructing {expected} but provided {provided}")]
407pub struct TypePtrErr {
408    pub expected: String,
409    pub provided: String,
410}
411
412impl<T: Type> TypePtr<T> {
413    /// Return a [Ref] to the [Type]
414    /// This borrows from a RefCell and the borrow is live
415    /// as long as the returned [Ref] lives.
416    pub fn deref<'a>(&self, ctx: &'a Context) -> Ref<'a, T> {
417        Ref::map(self.0.deref(ctx), |t| {
418            t.downcast_ref::<T>()
419                .expect("Type mistmatch, inconsistent TypePtr")
420        })
421    }
422
423    /// Create a new [TypePtr] from [`Ptr<TypeObj>`](TypeObj)
424    pub fn from_ptr(ptr: Ptr<TypeObj>, ctx: &Context) -> Result<TypePtr<T>> {
425        if ptr.deref(ctx).is::<T>() {
426            Ok(TypePtr(ptr, PhantomData::<T>))
427        } else {
428            arg_err_noloc!(TypePtrErr {
429                expected: T::get_type_id_static().disp(ctx).to_string(),
430                provided: ptr.disp(ctx).to_string()
431            })
432        }
433    }
434
435    /// Erase the static rust type.
436    pub fn to_ptr(&self) -> Ptr<TypeObj> {
437        self.0
438    }
439}
440
441impl<T: Type> From<TypePtr<T>> for Ptr<TypeObj> {
442    fn from(value: TypePtr<T>) -> Self {
443        value.to_ptr()
444    }
445}
446
447impl<T: Type> Clone for TypePtr<T> {
448    fn clone(&self) -> TypePtr<T> {
449        *self
450    }
451}
452
453impl<T: Type> Copy for TypePtr<T> {}
454
455impl<T: Type> PartialEq for TypePtr<T> {
456    fn eq(&self, other: &Self) -> bool {
457        self.0 == other.0
458    }
459}
460
461impl<T: Type> Eq for TypePtr<T> {}
462
463impl<T: Type> Hash for TypePtr<T> {
464    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
465        self.0.hash(state);
466    }
467}
468
469impl<T: Type> Printable for TypePtr<T> {
470    fn fmt(
471        &self,
472        ctx: &Context,
473        state: &printable::State,
474        f: &mut core::fmt::Formatter<'_>,
475    ) -> core::fmt::Result {
476        Printable::fmt(&self.0, ctx, state, f)
477    }
478}
479
480impl<T: Type + Parsable<Arg = (), Parsed = TypePtr<T>>> Parsable for TypePtr<T> {
481    type Arg = ();
482    type Parsed = Self;
483
484    fn parse<'a>(
485        state_stream: &mut StateStream<'a>,
486        arg: Self::Arg,
487    ) -> ParseResult<'a, Self::Parsed> {
488        let loc = state_stream.loc();
489        spaced(TypeId::parser(()))
490            .then(move |type_id| {
491                let loc = loc.clone();
492                combine::parser(move |parsable_state: &mut StateStream<'a>| {
493                    if type_id != T::get_type_id_static() {
494                        input_err!(
495                            loc.clone(),
496                            "Expected type {}, but found {}",
497                            T::get_type_id_static().disp(parsable_state.state.ctx),
498                            type_id.disp(parsable_state.state.ctx)
499                        )?
500                    }
501                    T::parser(arg).parse_stream(parsable_state).into()
502                })
503            })
504            .parse_stream(state_stream)
505            .into_result()
506    }
507}
508
509impl<T: Type> Verify for TypePtr<T> {
510    fn verify(&self, ctx: &Context) -> Result<()> {
511        self.0.verify(ctx)
512    }
513}
514
/// Tag carried by every type-interface trait object (`dyn SomeInterface`).
///
/// The `#[type_interface]` macro implements this automatically. [type_cast]
/// and [type_impls] require it so that they can only target interfaces,
/// never concrete [Type]s.
#[diagnostic::on_unimplemented(
    message = "`{Self}` not a type interface.",
    label = "If `{Self}` is a trait, annotate it with #[type_interface] to be able to cast to it from a `&dyn Type`",
    note = "If you want to cast to a concrete `Type`, use `downcast_ref` instead."
)]
pub trait TypeInterfaceMarker {}
525
526/// Cast reference to a [Type] object to an interface reference.
527///
528/// Right usage: cast to an interface trait object.
529/// ```
530/// use pliron::builtin::type_interfaces::FunctionTypeInterface;
531/// use pliron::r#type::{Type, type_cast};
532///
533/// fn right_cast(ty: &dyn Type) {
534///     let _ = type_cast::<dyn FunctionTypeInterface>(ty);
535/// }
536/// ```
537///
538/// Casting to concrete [Type] types are intentionally rejected.
539/// ```compile_fail
540/// use pliron::builtin::types::IntegerType;
541/// use pliron::r#type::{Type, type_cast};
542///
543/// fn wrong_cast(ty: &dyn Type) {
544///     let _ = type_cast::<IntegerType>(ty);
545/// }
546/// ```
547/// Use [downcast_rs](https://docs.rs/downcast-rs/latest/downcast_rs/#example-without-generics)
548/// to cast to concrete [Type] types.
549pub fn type_cast<T: ?Sized + TypeInterfaceMarker + 'static>(ty: &dyn Type) -> Option<&T> {
550    crate::utils::trait_cast::any_to_trait::<T>(ty.as_any())
551}
552
553/// Does this [Type] object implement interface `T`?
554///
555/// Right usage: query using an interface trait object.
556/// ```
557/// use pliron::builtin::type_interfaces::FunctionTypeInterface;
558/// use pliron::r#type::{Type, type_impls};
559///
560/// fn right_query(ty: &dyn Type) {
561///     let _ = type_impls::<dyn FunctionTypeInterface>(ty);
562/// }
563/// ```
564///
565/// Querying with a concrete [Type] type is intentionally rejected.
566/// ```compile_fail
567/// use pliron::builtin::types::IntegerType;
568/// use pliron::r#type::{Type, type_impls};
569///
570/// fn wrong_query(ty: &dyn Type) {
571///     let _ = type_impls::<IntegerType>(ty);
572/// }
573/// ```
574pub fn type_impls<T: ?Sized + TypeInterfaceMarker + 'static>(ty: &dyn Type) -> bool {
575    type_cast::<T>(ty).is_some()
576}
577
578/// Every type interface must have a function named `verify` with this type.
579pub type TypeInterfaceVerifier = fn(&dyn Type, &Context) -> Result<()>;
580/// Function returns the list of super verifiers, followed by a self verifier, for an interface.
581pub type TypeInterfaceAllVerifiers = fn() -> Vec<TypeInterfaceVerifier>;
582
583#[doc(hidden)]
584/// A [Type] paired with an interface it implements
585/// (specifically the verifiers (including super verifiers) for that interface).
586type TypeInterfaceVerifierInfo = (std::any::TypeId, TypeInterfaceAllVerifiers);
587
/// Registry of type-interface verifier entries on non-wasm targets,
/// collected through a `linkme` distributed slice.
#[cfg(not(target_family = "wasm"))]
pub mod statics {
    use super::*;

    /// All `TypeInterfaceVerifierInfo` entries, gathered into one slice by
    /// `linkme`. Entries are submitted elsewhere (presumably by code the
    /// `#[type_interface_impl]` macro generates — confirm). Each entry is a
    /// [LazyLock], so its contents are computed on first access.
    #[::pliron::linkme::distributed_slice]
    pub static TYPE_INTERFACE_VERIFIERS: [LazyLock<TypeInterfaceVerifierInfo>] = [..];

    /// Iterate over every registered type-interface verifier entry.
    pub fn get_type_interface_verifiers()
    -> impl Iterator<Item = &'static LazyLock<TypeInterfaceVerifierInfo>> {
        TYPE_INTERFACE_VERIFIERS.iter()
    }
}
600
/// Registry of type-interface verifier entries on wasm targets, collected
/// with `inventory` instead of `linkme` (presumably because the `linkme`
/// mechanism is unavailable on wasm — confirm).
#[cfg(target_family = "wasm")]
pub mod statics {
    use super::*;
    use crate::utils::inventory::LazyLockWrapper;

    // Declare `LazyLockWrapper<TypeInterfaceVerifierInfo>` as a type whose
    // submissions `inventory` collects.
    ::pliron::inventory::collect!(LazyLockWrapper<TypeInterfaceVerifierInfo>);

    /// Iterate over every registered type-interface verifier entry,
    /// unwrapping each collected wrapper to the [LazyLock] it holds.
    pub fn get_type_interface_verifiers()
    -> impl Iterator<Item = &'static LazyLock<TypeInterfaceVerifierInfo>> {
        ::pliron::inventory::iter::<LazyLockWrapper<TypeInterfaceVerifierInfo>>().map(|llw| llw.0)
    }
}
613
614pub use statics::*;
615
616#[doc(hidden)]
617/// A map from every [Type] to its ordered (as per interface deps) list of interface verifiers.
618/// An interface's super-interfaces are to be verified before it itself is.
619pub static TYPE_INTERFACE_VERIFIERS_MAP: LazyLock<
620    FxHashMap<std::any::TypeId, Vec<TypeInterfaceVerifier>>,
621> = LazyLock::new(|| collect_deduped_interface_verifiers(get_type_interface_verifiers()));