Skip to main content

pliron_llvm/
ops.rs

1//! [Op]s defined in the LLVM dialect
2
3use std::{sync::LazyLock, vec};
4
5use pliron::{
6    arg_err, arg_err_noloc,
7    attribute::{AttrObj, AttributeDict, attr_cast, attr_impls},
8    basic_block::BasicBlock,
9    builtin::{
10        attr_interfaces::{FloatAttr, TypedAttrInterface},
11        attributes::{IdentifierAttr, IntegerAttr, StringAttr, TypeAttr},
12        op_interfaces::{
13            self, ATTR_KEY_SYM_NAME, AtLeastNOpdsInterface, AtLeastNResultsInterface,
14            AtMostNOpdsInterface, AtMostNRegionsInterface, AtMostOneRegionInterface,
15            BranchOpInterface, CallOpCallable, CallOpInterface, IsTerminatorInterface,
16            IsolatedFromAboveInterface, NOpdsInterface, NResultsInterface, OneOpdInterface,
17            OneResultInterface, OperandNOfType, OperandSegmentInterface, OptionalOpdInterface,
18            ResultNOfType, SameOperandsAndResultType, SameOperandsType, SameResultsType,
19            SingleBlockRegionInterface, SymbolOpInterface, SymbolUserOpInterface,
20        },
21        type_interfaces::{FloatTypeInterface, FunctionTypeInterface},
22        types::{IntegerType, Signedness},
23    },
24    common_traits::{Named, Verify},
25    context::{Context, Ptr},
26    identifier::Identifier,
27    indented_block, input_err,
28    irbuild::{inserter::Inserter, rewriter::Rewriter},
29    irfmt::{
30        self,
31        parsers::{
32            attr_parser, block_opd_parser, delimited_list_parser, process_parsed_ssa_defs, spaced,
33            ssa_opd_parser, type_parser,
34        },
35        printers::{iter_with_sep, list_with_sep, op::typed_symb_op_header},
36    },
37    linked_list::ContainsLinkedList,
38    location::{Located, Location},
39    op::{Op, OpObj},
40    operation::Operation,
41    opts::mem2reg::{
42        AllocInfo, PromotableAllocationInterface, PromotableOpInterface, PromotableOpKind,
43    },
44    parsable::{IntoParseResult, Parsable, ParseResult, StateStream},
45    printable::{self, Printable, indented_nl},
46    region::Region,
47    result::{Error, ErrorKind, Result},
48    symbol_table::SymbolTableCollection,
49    r#type::{TypeObj, TypePtr, type_cast, type_impls},
50    utils::vec_exns::VecExtns,
51    value::Value,
52    verify_err, verify_error,
53};
54
55use crate::{
56    attributes::{
57        AlignmentAttr, CaseValuesAttr, FCmpPredicateAttr, FastmathFlagsAttr,
58        InsertExtractValueIndicesAttr, LinkageAttr, ShuffleVectorMaskAttr,
59    },
60    llvm_sys::core::{llvm_get_undef_mask_elem, llvm_lookup_intrinsic_id},
61    op_interfaces::{
62        AlignableOpInterface, BinArithOp, CastOpInterface, CastOpWithNNegInterface, FastMathFlags,
63        FloatBinArithOp, FloatBinArithOpWithFastMathFlags, IntBinArithOp,
64        IntBinArithOpWithOverflowFlag, IsDeclaration, LlvmSymbolName, NNegFlag, PointerTypeResult,
65    },
66    ops::{
67        func_op_attr_names::ATTR_KEY_LLVM_FUNC_TYPE,
68        global_op_attr_names::{ATTR_KEY_GLOBAL_INITIALIZER, ATTR_KEY_LLVM_GLOBAL_TYPE},
69    },
70    types::{ArrayType, FuncType, StructType, VectorType},
71};
72
73use pliron::combine::{
74    self, between, optional,
75    parser::{Parser, char::spaces},
76    token,
77};
78
79use pliron::derive::{op_interface_impl, pliron_op};
80use thiserror::Error;
81
82use super::{
83    attributes::{GepIndexAttr, GepIndicesAttr, ICmpPredicateAttr},
84    types::PointerType,
85};
86
87/// Equivalent to LLVM's return opcode.
88///
89/// Operands:
90///
91/// | operand | description |
92/// |-----|-------|
93/// | `arg` | any type |
94#[pliron_op(
95    name = "llvm.return",
96    format = "operands(CharSpace(`,`))",
97    interfaces = [IsTerminatorInterface, NResultsInterface<0>, AtMostNOpdsInterface<1>, OptionalOpdInterface],
98    verifier = "succ"
99)]
100pub struct ReturnOp;
101impl ReturnOp {
102    /// Create a new [ReturnOp]
103    pub fn new(ctx: &mut Context, value: Option<Value>) -> Self {
104        let op = Operation::new(
105            ctx,
106            Self::get_concrete_op_info(),
107            vec![],
108            value.into_iter().collect(),
109            vec![],
110            0,
111        );
112        ReturnOp { op }
113    }
114
115    /// Get the returned value, if it exists.
116    pub fn retval(&self, ctx: &Context) -> Option<Value> {
117        self.get_operand(ctx)
118    }
119}
120
121/// Equivalent to LLVM's unreachable opcode.
122/// No operands or results.
123#[pliron_op(
124    name = "llvm.unreachable",
125    format,
126    interfaces = [IsTerminatorInterface, NOpdsInterface<0>, NResultsInterface<0>],
127    verifier = "succ"
128)]
129pub struct UnreachableOp;
130
131impl UnreachableOp {
132    /// Create a new [UnreachableOp]
133    pub fn new(ctx: &mut Context) -> Self {
134        let op = Operation::new(ctx, Self::get_concrete_op_info(), vec![], vec![], vec![], 0);
135        UnreachableOp { op }
136    }
137}
138
139macro_rules! new_int_bin_op_with_format {
140    (   $(#[$outer:meta])*
141        $op_name:ident, $op_id:literal, $format:literal
142    ) => {
143        $(#[$outer])*
144        /// ### Operands:
145        ///
146        /// | operand | description |
147        /// |-----|-------|
148        /// | `lhs` | Signless integer |
149        /// | `rhs` | Signless integer |
150        ///
151        /// ### Result(s):
152        ///
153        /// | result | description |
154        /// |-----|-------|
155        /// | `res` | Signless integer |
156        #[pliron_op(
157            name = $op_id,
158            format = $format,
159            interfaces = [
160                OneResultInterface, NResultsInterface<1>, SameOperandsType, SameResultsType,
161                AtLeastNOpdsInterface<1>, AtLeastNResultsInterface<1>,
162                SameOperandsAndResultType, BinArithOp, IntBinArithOp, NOpdsInterface<2>
163            ],
164            verifier = "succ"
165        )]
166        pub struct $op_name;
167    }
168}
169
170macro_rules! new_int_bin_op {
171    (   $(#[$outer:meta])*
172        $op_name:ident, $op_id:literal
173    ) => {
174        new_int_bin_op_with_format!(
175            $(#[$outer])*
176            $op_name,
177            $op_id,
178            "$0 `, ` $1 ` : ` type($0)"
179        );
180    }
181}
182
183macro_rules! new_int_bin_op_with_overflow {
184    (   $(#[$outer:meta])*
185        $op_name:ident, $op_id:literal
186    ) => {
187        new_int_bin_op_with_format!(
188            $(#[$outer])*
189            /// ### Attributes:
190            ///
191            /// | key | value | via Interface |
192            /// |-----|-------| --------------
193            /// | [ATTR_KEY_INTEGER_OVERFLOW_FLAGS](super::op_interfaces::ATTR_KEY_INTEGER_OVERFLOW_FLAGS) | [IntegerOverflowFlagsAttr](super::attributes::IntegerOverflowFlagsAttr) | [IntBinArithOpWithOverflowFlag] |
194            $op_name,
195            $op_id,
196            "$0 `, ` $1 ` <` attr($llvm_integer_overflow_flags, `super::attributes::IntegerOverflowFlagsAttr`) `>` `: ` type($0)"
197        );
198        #[pliron::derive::op_interface_impl]
199        impl IntBinArithOpWithOverflowFlag for $op_name {}
200    }
201}
202
203new_int_bin_op_with_overflow!(
204    /// Equivalent to LLVM's Add opcode.
205    AddOp,
206    "llvm.add"
207);
208
209new_int_bin_op_with_overflow!(
210    /// Equivalent to LLVM's Sub opcode.
211    SubOp,
212    "llvm.sub"
213);
214
215new_int_bin_op_with_overflow!(
216    /// Equivalent to LLVM's Mul opcode.
217    MulOp,
218    "llvm.mul"
219);
220
221new_int_bin_op_with_overflow!(
222    /// Equivalent to LLVM's Shl opcode.
223    ShlOp,
224    "llvm.shl"
225);
226
227new_int_bin_op!(
228    /// Equivalent to LLVM's UDiv opcode.
229    UDivOp,
230    "llvm.udiv"
231);
232
233new_int_bin_op!(
234    /// Equivalent to LLVM's SDiv opcode.
235    SDivOp,
236    "llvm.sdiv"
237);
238
239new_int_bin_op!(
240    /// Equivalent to LLVM's URem opcode.
241    URemOp,
242    "llvm.urem"
243);
244
245new_int_bin_op!(
246    /// Equivalent to LLVM's SRem opcode.
247    SRemOp,
248    "llvm.srem"
249);
250
251new_int_bin_op!(
252    /// Equivalent to LLVM's And opcode.
253    AndOp,
254    "llvm.and"
255);
256
257new_int_bin_op!(
258    /// Equivalent to LLVM's Or opcode.
259    OrOp,
260    "llvm.or"
261);
262
263new_int_bin_op!(
264    /// Equivalent to LLVM's Xor opcode.
265    XorOp,
266    "llvm.xor"
267);
268
269new_int_bin_op!(
270    /// Equivalent to LLVM's LShr opcode.
271    LShrOp,
272    "llvm.lshr"
273);
274
275new_int_bin_op!(
276    /// Equivalent to LLVM's AShr opcode.
277    AShrOp,
278    "llvm.ashr"
279);
280
281#[derive(Error, Debug)]
282pub enum ICmpOpVerifyErr {
283    #[error("Result must be (possibly vector of) 1-bit integer (bool)")]
284    ResultNotBool,
285    #[error("Operand must be (possibly vector of) integer or pointer types")]
286    IncorrectOperandsType,
287    #[error("Missing or incorrect predicate attribute")]
288    PredAttrErr,
289    #[error("Vector operand and result types must have the same number of elements")]
290    MismatchedVectorNumElements,
291}
292
293/// Equivalent to LLVM's ICmp opcode.
294/// ### Operand(s):
295/// | operand | description |
296/// |-----|-------|
297/// | `lhs` | Signless integer or pointer |
298/// | `rhs` | Signless integer or pointer |
299///
300/// ### Result(s):
301///
302/// | result | description |
303/// |-----|-------|
304/// | `res` | 1-bit signless integer |
305#[pliron_op(
306    name = "llvm.icmp",
307    format = "$0 ` <` attr($icmp_predicate, $ICmpPredicateAttr) `> ` $1 ` : ` type($0)",
308    interfaces = [SameOperandsType, AtLeastNOpdsInterface<1>, OneResultInterface, NResultsInterface<1>],
309    attributes = (icmp_predicate: ICmpPredicateAttr)
310)]
311pub struct ICmpOp;
312
313impl ICmpOp {
314    /// Create a new [ICmpOp]
315    pub fn new(ctx: &mut Context, pred: ICmpPredicateAttr, lhs: Value, rhs: Value) -> Self {
316        use pliron::r#type::Typed;
317
318        // Determine the result type.
319        let bool_ty = IntegerType::get(ctx, 1, Signedness::Signless);
320        let opd_type = lhs.get_type(ctx);
321        let vec_details = opd_type
322            .deref(ctx)
323            .downcast_ref::<VectorType>()
324            .map(|vec_ty| (vec_ty.num_elements(), vec_ty.kind()));
325        let res_ty = if let Some((num_elements, kind)) = vec_details {
326            VectorType::get(ctx, bool_ty.into(), num_elements, kind).into()
327        } else {
328            bool_ty.into()
329        };
330
331        let op = Operation::new(
332            ctx,
333            Self::get_concrete_op_info(),
334            vec![res_ty],
335            vec![lhs, rhs],
336            vec![],
337            0,
338        );
339        let op = ICmpOp { op };
340        op.set_attr_icmp_predicate(ctx, pred);
341        op
342    }
343
344    /// Get the predicate
345    pub fn predicate(&self, ctx: &Context) -> ICmpPredicateAttr {
346        self.get_attr_icmp_predicate(ctx)
347            .expect("ICmpOp missing or incorrect predicate attribute type")
348            .clone()
349    }
350}
351
352impl Verify for ICmpOp {
353    fn verify(&self, ctx: &Context) -> Result<()> {
354        let loc = self.loc(ctx);
355
356        if self.get_attr_icmp_predicate(ctx).is_none() {
357            verify_err!(loc.clone(), ICmpOpVerifyErr::PredAttrErr)?
358        }
359
360        let mut res_ty = self.result_type(ctx);
361        let mut vec_num_elements = None;
362        if let Some(vec_ty) = res_ty.deref(ctx).downcast_ref::<VectorType>() {
363            res_ty = vec_ty.elem_type();
364            vec_num_elements = Some(vec_ty.num_elements());
365        }
366        let res_ty = res_ty.deref(ctx);
367        let Some(res_ty) = res_ty.downcast_ref::<IntegerType>() else {
368            return verify_err!(loc, ICmpOpVerifyErr::ResultNotBool);
369        };
370        if res_ty.width() != 1 {
371            return verify_err!(loc, ICmpOpVerifyErr::ResultNotBool);
372        }
373
374        let mut opd_ty = self.operand_type(ctx);
375        if let Some(vec_ty) = opd_ty.deref(ctx).downcast_ref::<VectorType>() {
376            opd_ty = vec_ty.elem_type();
377            // Ensure that the number of elements matches the result type's number of elements.
378            if vec_num_elements.is_none_or(|num_elements| vec_ty.num_elements() != num_elements) {
379                return verify_err!(loc, ICmpOpVerifyErr::MismatchedVectorNumElements);
380            }
381        } else if vec_num_elements.is_some() {
382            return verify_err!(loc, ICmpOpVerifyErr::MismatchedVectorNumElements);
383        }
384        let opd_ty = opd_ty.deref(ctx);
385        if !(opd_ty.is::<IntegerType>() || opd_ty.is::<PointerType>()) {
386            return verify_err!(loc, ICmpOpVerifyErr::IncorrectOperandsType);
387        }
388
389        Ok(())
390    }
391}
392
393#[derive(Error, Debug)]
394pub enum AllocaOpVerifyErr {
395    #[error("Operand must be a signless integer")]
396    OperandType,
397    #[error("Missing or incorrect type of attribute for element type")]
398    ElemTypeAttr,
399}
400
401/// Equivalent to LLVM's Alloca opcode.
402/// ### Operands
403/// | operand | description |
404/// |-----|-------|
405/// | `array_size` | Signless integer |
406///
407/// ### Result(s):
408///
409/// | result | description |
410/// |-----|-------|
411/// | `res` | [PointerType] |
412#[pliron_op(
413    name = "llvm.alloca",
414    format = "`[` attr($alloca_element_type, $TypeAttr) ` x ` $0 `]` ` ` \
415    opt_attr($llvm_alignment, $AlignmentAttr, label($align), delimiters(`[`, `]`)) \
416    ` : ` type($0)",
417    interfaces = [
418        OneResultInterface,
419        OneOpdInterface,
420        NResultsInterface<1>,
421        NOpdsInterface<1>,
422        AlignableOpInterface,
423        OperandNOfType<0, IntegerType>,
424        ResultNOfType<0, PointerType>,
425    ],
426    attributes = (alloca_element_type: TypeAttr)
427)]
428pub struct AllocaOp;
429impl Verify for AllocaOp {
430    fn verify(&self, ctx: &Context) -> Result<()> {
431        let loc = self.loc(ctx);
432        // Ensure correctness of element type.
433        if self.get_attr_alloca_element_type(ctx).is_none() {
434            verify_err!(loc, AllocaOpVerifyErr::ElemTypeAttr)?
435        }
436        Ok(())
437    }
438}
439
440#[op_interface_impl]
441impl PointerTypeResult for AllocaOp {
442    fn result_pointee_type(&self, ctx: &Context) -> Ptr<TypeObj> {
443        self.get_attr_alloca_element_type(ctx)
444            .expect("AllocaOp missing or incorrect type for elem_type attribute")
445            .get_type(ctx)
446    }
447}
448
449impl AllocaOp {
450    /// Create a new [AllocaOp]
451    pub fn new(ctx: &mut Context, elem_type: Ptr<TypeObj>, size: Value) -> Self {
452        let ptr_ty = PointerType::get(ctx).into();
453        let op = Operation::new(
454            ctx,
455            Self::get_concrete_op_info(),
456            vec![ptr_ty],
457            vec![size],
458            vec![],
459            0,
460        );
461        let op = AllocaOp { op };
462        op.set_attr_alloca_element_type(ctx, TypeAttr::new(elem_type));
463        op
464    }
465}
466
467#[derive(Error, Debug)]
468#[error("Register Promotion: Allocation info provided is not related to this operation")]
469pub struct UnrelatedAllocInfo;
470
471#[op_interface_impl]
472impl PromotableAllocationInterface for AllocaOp {
473    fn alloc_info(&self, ctx: &Context) -> Vec<AllocInfo> {
474        vec![AllocInfo {
475            ptr: self.get_result(ctx),
476            ty: self.result_pointee_type(ctx),
477        }]
478    }
479
480    fn default_value(
481        &self,
482        ctx: &mut Context,
483        inserter: &mut dyn Inserter,
484        alloc_info: &AllocInfo,
485    ) -> Result<Value> {
486        if alloc_info.ptr != self.get_result(ctx) {
487            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
488        }
489        let poison = PoisonOp::new(ctx, alloc_info.ty);
490        let poison_val = poison.get_result(ctx);
491        inserter.insert_op(ctx, &poison);
492        Ok(poison_val)
493    }
494
495    fn promote(
496        &self,
497        ctx: &mut Context,
498        rewriter: &mut dyn Rewriter,
499        alloc_infos: &[AllocInfo],
500    ) -> Result<()> {
501        if alloc_infos.len() != 1 || alloc_infos[0].ptr != self.get_result(ctx) {
502            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
503        }
504        rewriter.erase_operation(ctx, self.get_operation());
505        Ok(())
506    }
507}
508
509/// Equivalent to LLVM's Bitcast opcode.
510/// ### Operands
511/// | operand | description |
512/// |-----|-------|
513/// | `arg` | non-aggregate LLVM type |
514///
515/// ### Result(s):
516///
517/// | result | description |
518/// |-----|-------|
519/// | `res` | non-aggregate LLVM type |
520#[pliron_op(
521    name = "llvm.bitcast",
522    format = "$0 ` to ` type($0)",
523    interfaces = [
524        OneResultInterface,
525        OneOpdInterface,
526        NResultsInterface<1>,
527        NOpdsInterface<1>,
528        CastOpInterface
529    ],
530    verifier = "succ"
531)]
532pub struct BitcastOp;
533
534#[derive(Error, Debug)]
535pub enum IntToPtrOpErr {
536    #[error("Operand must be a signless integer")]
537    OperandTypeErr,
538    #[error("Result must be a pointer type")]
539    ResultTypeErr,
540}
541
542/// Equivalent to LLVM's IntToPtr opcode.
543/// ### Operands
544/// | operand | description |
545/// |-----|-------|
546/// | `arg` | Signless integer |
547////
548/// ### Result(s):
549///
550/// | result | description |
551/// |-----|-------|
552/// | `res` | [PointerType] |
553///
554#[pliron_op(
555    name = "llvm.inttoptr",
556    format = "$0 ` to ` type($0)",
557    interfaces = [
558        OneResultInterface,
559        OneOpdInterface,
560        NResultsInterface<1>,
561        NOpdsInterface<1>,
562        CastOpInterface,
563        OperandNOfType<0, IntegerType>,
564        ResultNOfType<0, PointerType>
565     ],
566     verifier = "succ"
567)]
568pub struct IntToPtrOp;
569
570#[derive(Error, Debug)]
571pub enum PtrToIntOpErr {
572    #[error("Operand must be a pointer type")]
573    OperandTypeErr,
574    #[error("Result must be a signless integer type")]
575    ResultTypeErr,
576}
577
578/// Equivalent to LLVM's PtrToInt opcode.
579/// ### Operands
580/// | operand | description |
581/// |-----|-------|
582/// | `arg` | [PointerType] |
583///
584/// ### Result(s):
585/// | result | description |
586/// |-----|-------|
587/// | `res` | Signless integer |
588#[pliron_op(
589    name = "llvm.ptrtoint",
590    format = "$0 ` to ` type($0)",
591    interfaces = [
592        OneResultInterface,
593        OneOpdInterface,
594        NResultsInterface<1>,
595        NOpdsInterface<1>,
596        CastOpInterface,
597        OperandNOfType<0, PointerType>,
598        ResultNOfType<0, IntegerType>,
599    ],
600    verifier = "succ"
601)]
602pub struct PtrToIntOp;
603
604/// Equivalent to LLVM's Unconditional Branch.
605/// ### Operands
606/// | operand | description |
607/// |-----|-------|
608/// | `dest_opds` | Any number of operands with any LLVM type |
609///
610/// ### Successors:
611///
612/// | Successor | description |
613/// |-----|-------|
614/// | `dest` | Any successor |
615#[pliron_op(
616    name = "llvm.br",
617    format = "succ($0) `(` operands(CharSpace(`,`)) `)`",
618    interfaces = [IsTerminatorInterface, NResultsInterface<0>],
619    verifier = "succ"
620)]
621pub struct BrOp;
622
623#[op_interface_impl]
624impl BranchOpInterface for BrOp {
625    fn successor_operands(&self, ctx: &Context, succ_idx: usize) -> Vec<Value> {
626        assert!(succ_idx == 0, "BrOp has exactly one successor");
627        self.get_operation().deref(ctx).operands().collect()
628    }
629
630    fn add_successor_operand(&self, ctx: &mut Context, succ_idx: usize, operand: Value) -> usize {
631        assert!(succ_idx == 0, "BrOp has exactly one successor");
632        Operation::push_operand(self.get_operation(), ctx, operand)
633    }
634
635    fn remove_successor_operand(
636        &self,
637        ctx: &mut Context,
638        succ_idx: usize,
639        opd_idx: usize,
640    ) -> Value {
641        assert!(succ_idx == 0, "BrOp has exactly one successor");
642        Operation::remove_operand(self.get_operation(), ctx, opd_idx)
643    }
644}
645
646impl BrOp {
647    /// Create anew [BrOp].
648    pub fn new(ctx: &mut Context, dest: Ptr<BasicBlock>, dest_opds: Vec<Value>) -> Self {
649        BrOp {
650            op: Operation::new(
651                ctx,
652                Self::get_concrete_op_info(),
653                vec![],
654                dest_opds,
655                vec![dest],
656                0,
657            ),
658        }
659    }
660}
661
662/// Equivalent to LLVM's Conditional Branch.
663/// ### Operands
664/// | operand | description |
665/// |-----|-------|
666/// | `condition` | 1-bit signless integer |
667/// | `true_dest_opds` | Any number of operands with any LLVM type |
668/// | `false_dest_opds` | Any number of operands with any LLVM type |
669///
670/// ### Successors:
671///
672/// | Successor | description |
673/// |-----|-------|
674/// | `true_dest` | Any successor |
675/// | `false_dest` | Any successor |
676#[pliron_op(
677    name = "llvm.cond_br",
678    interfaces = [IsTerminatorInterface, NResultsInterface<0>],
679)]
680pub struct CondBrOp;
681impl CondBrOp {
682    /// Create a new [CondBrOp].
683    pub fn new(
684        ctx: &mut Context,
685        condition: Value,
686        true_dest: Ptr<BasicBlock>,
687        true_dest_opds: Vec<Value>,
688        false_dest: Ptr<BasicBlock>,
689        false_dest_opds: Vec<Value>,
690    ) -> Self {
691        let (operands, segment_sizes) =
692            Self::compute_segment_sizes(vec![vec![condition], true_dest_opds, false_dest_opds]);
693
694        let op = CondBrOp {
695            op: Operation::new(
696                ctx,
697                Self::get_concrete_op_info(),
698                vec![],
699                operands,
700                vec![true_dest, false_dest],
701                0,
702            ),
703        };
704
705        // Set the operand segment sizes attribute.
706        op.set_operand_segment_sizes(ctx, segment_sizes);
707        op
708    }
709
710    /// Get the condition value for the branch.
711    pub fn condition(&self, ctx: &Context) -> Value {
712        self.op.deref(ctx).get_operand(0)
713    }
714}
715
716#[derive(Error, Debug)]
717enum CondBrOpVerifyErr {
718    #[error("Condition operand must be a 1-bit signless integer (i1) or vector of i1")]
719    IncorrectConditionType,
720}
721
722impl Verify for CondBrOp {
723    fn verify(&self, ctx: &Context) -> Result<()> {
724        use pliron::r#type::Typed;
725        // Ensure that the condition is a 1-bit signless integer
726        let condition_ty = self.condition(ctx).get_type(ctx);
727        let condition_ty = condition_ty.deref(ctx);
728        let condition_int_ty = condition_ty.downcast_ref::<IntegerType>().ok_or_else(|| {
729            verify_error!(self.loc(ctx), CondBrOpVerifyErr::IncorrectConditionType)
730        })?;
731        if condition_int_ty.width() != 1 || condition_int_ty.signedness() != Signedness::Signless {
732            verify_err!(self.loc(ctx), CondBrOpVerifyErr::IncorrectConditionType)?
733        }
734        Ok(())
735    }
736}
737
738#[op_interface_impl]
739impl OperandSegmentInterface for CondBrOp {}
740
741impl Printable for CondBrOp {
742    fn fmt(
743        &self,
744        ctx: &Context,
745        _state: &pliron::printable::State,
746        f: &mut std::fmt::Formatter<'_>,
747    ) -> std::fmt::Result {
748        let op = self.get_operation().deref(ctx);
749        let condition = op.get_operand(0);
750        let true_dest_opds = self.successor_operands(ctx, 0);
751        let false_dest_opds = self.successor_operands(ctx, 1);
752        let res = write!(
753            f,
754            "{} if {} ^{}({}) else ^{}({})",
755            Self::get_opid_static(),
756            condition.disp(ctx),
757            op.get_successor(0).deref(ctx).unique_name(ctx),
758            iter_with_sep(
759                true_dest_opds.iter(),
760                pliron::printable::ListSeparator::CharSpace(',')
761            )
762            .disp(ctx),
763            op.get_successor(1).deref(ctx).unique_name(ctx),
764            iter_with_sep(
765                false_dest_opds.iter(),
766                pliron::printable::ListSeparator::CharSpace(',')
767            )
768            .disp(ctx),
769        );
770        res
771    }
772}
773
774impl Parsable for CondBrOp {
775    type Arg = Vec<(Identifier, Location)>;
776    type Parsed = OpObj;
777    fn parse<'a>(
778        state_stream: &mut StateStream<'a>,
779        results: Self::Arg,
780    ) -> ParseResult<'a, Self::Parsed> {
781        if !results.is_empty() {
782            input_err!(
783                state_stream.loc(),
784                op_interfaces::NResultsVerifyErr(0, results.len())
785            )?
786        }
787
788        // Parse the condition operand.
789        let r#if = irfmt::parsers::spaced::<StateStream, _>(combine::parser::char::string("if"));
790
791        let condition = ssa_opd_parser();
792
793        let true_operands = delimited_list_parser('(', ')', ',', ssa_opd_parser());
794
795        let r_else =
796            irfmt::parsers::spaced::<StateStream, _>(combine::parser::char::string("else"));
797
798        let false_operands = delimited_list_parser('(', ')', ',', ssa_opd_parser());
799
800        let final_parser = r#if
801            .with(spaced(condition))
802            .and(spaced(block_opd_parser()))
803            .and(true_operands)
804            .and(spaced(r_else).with(spaced(block_opd_parser()).and(false_operands)));
805
806        final_parser
807            .then(
808                move |(((condition, true_dest), true_dest_opds), (false_dest, false_dest_opds))| {
809                    let results = results.clone();
810                    combine::parser(move |parsable_state: &mut StateStream<'a>| {
811                        let ctx = &mut parsable_state.state.ctx;
812                        let op = CondBrOp::new(
813                            ctx,
814                            condition,
815                            true_dest,
816                            true_dest_opds.clone(),
817                            false_dest,
818                            false_dest_opds.clone(),
819                        );
820
821                        process_parsed_ssa_defs(parsable_state, &results, op.get_operation())?;
822                        Ok(OpObj::new(op)).into_parse_result()
823                    })
824                },
825            )
826            .parse_stream(state_stream)
827            .into()
828    }
829}
830
831#[op_interface_impl]
832impl BranchOpInterface for CondBrOp {
833    fn successor_operands(&self, ctx: &Context, succ_idx: usize) -> Vec<Value> {
834        assert!(
835            succ_idx == 0 || succ_idx == 1,
836            "CondBrOp has exactly two successors"
837        );
838
839        // Skip the first segment, which is the condition.
840        self.get_segment(ctx, succ_idx + 1)
841    }
842
843    fn add_successor_operand(&self, ctx: &mut Context, succ_idx: usize, operand: Value) -> usize {
844        // The successor operands start at segment 1, since segment 0 is the condition operand.
845        self.push_to_segment(ctx, succ_idx + 1, operand)
846    }
847
848    fn remove_successor_operand(
849        &self,
850        ctx: &mut Context,
851        succ_idx: usize,
852        opd_idx: usize,
853    ) -> Value {
854        // The successor operands start at segment 1, since segment 0 is the condition operand.
855        self.remove_from_segment(ctx, succ_idx + 1, opd_idx)
856    }
857}
858
859/// Equivalent to LLVM's Switch opcode.
860///
861/// ### Operands
862/// | operand | description |
863/// |-----|-------|
864/// | `condition` | 1-bit signless integer |
865/// | `default_dest_opds` | variadic of any type |
866/// | `case_dest_opds` | variadic of any type |
867///
868/// ### Successors:
869/// | Successor | description |
870/// |-----|-------|
871/// | `default_dest` | any successor |
872/// | `case_dests` | any successor(s) |
873#[pliron_op(
874    name = "llvm.switch",
875    interfaces = [IsTerminatorInterface, NResultsInterface<0>],
876    attributes = (switch_case_values: CaseValuesAttr)
877)]
878pub struct SwitchOp;
879
880/// One case of a switch statement.
881#[derive(Clone)]
882pub struct SwitchCase {
883    /// The value being matched against.
884    pub value: IntegerAttr,
885    /// The destination block to jump to if this case is taken.
886    pub dest: Ptr<BasicBlock>,
887    /// The operands to pass to the destination block.
888    pub dest_opds: Vec<Value>,
889}
890
891impl Printable for SwitchCase {
892    fn fmt(
893        &self,
894        ctx: &Context,
895        _state: &pliron::printable::State,
896        f: &mut std::fmt::Formatter<'_>,
897    ) -> std::fmt::Result {
898        write!(
899            f,
900            "{{ {}: ^{}({}) }}",
901            self.value.disp(ctx),
902            self.dest.deref(ctx).unique_name(ctx),
903            list_with_sep(
904                &self.dest_opds,
905                pliron::printable::ListSeparator::CharSpace(',')
906            )
907            .disp(ctx)
908        )
909    }
910}
911
912impl Parsable for SwitchCase {
913    type Arg = ();
914    type Parsed = Self;
915
916    fn parse<'a>(
917        state_stream: &mut StateStream<'a>,
918        _arg: Self::Arg,
919    ) -> ParseResult<'a, Self::Parsed> {
920        let mut parser = between(
921            token('{'),
922            token('}'),
923            (
924                spaced(IntegerAttr::parser(())),
925                spaced(token(':')),
926                spaced(block_opd_parser()),
927                delimited_list_parser('(', ')', ',', ssa_opd_parser()),
928                spaces(),
929            ),
930        );
931
932        let ((value, _colon, dest, dest_opds, _spaces), _) =
933            parser.parse_stream(state_stream).into_result()?;
934
935        Ok(SwitchCase {
936            value,
937            dest,
938            dest_opds,
939        })
940        .into_parse_result()
941    }
942}
943
944impl Printable for SwitchOp {
945    fn fmt(
946        &self,
947        ctx: &Context,
948        state: &pliron::printable::State,
949        f: &mut std::fmt::Formatter<'_>,
950    ) -> std::fmt::Result {
951        let op = self.get_operation().deref(ctx);
952        let condition = op.get_operand(0);
953
954        let default_successor = op
955            .successors()
956            .next()
957            .expect("SwitchOp must have at least one successor");
958        let num_total_successors = op.get_num_successors();
959
960        write!(
961            f,
962            "{} {}, ^{}({})",
963            Self::get_opid_static(),
964            condition.disp(ctx),
965            default_successor.unique_name(ctx).disp(ctx),
966            iter_with_sep(
967                self.successor_operands(ctx, 0).iter(),
968                pliron::printable::ListSeparator::CharSpace(',')
969            )
970            .disp(ctx),
971        )?;
972
973        if num_total_successors < 2 {
974            writeln!(f, "[]")?;
975            return Ok(());
976        }
977
978        let cases = self.cases(ctx);
979
980        write!(f, "{}[", indented_nl(state))?;
981        indented_block!(state, {
982            write!(f, "{}", indented_nl(state))?;
983            list_with_sep(&cases, pliron::printable::ListSeparator::CharNewline(','))
984                .fmt(ctx, state, f)?;
985        });
986        write!(f, "{}]", indented_nl(state))?;
987
988        Ok(())
989    }
990}
991
992impl Parsable for SwitchOp {
993    type Arg = Vec<(Identifier, Location)>;
994    type Parsed = OpObj;
995
996    fn parse<'a>(
997        state_stream: &mut StateStream<'a>,
998        arg: Self::Arg,
999    ) -> ParseResult<'a, Self::Parsed> {
1000        if !arg.is_empty() {
1001            input_err!(
1002                state_stream.loc(),
1003                op_interfaces::NResultsVerifyErr(0, arg.len())
1004            )?
1005        }
1006
1007        // Parse the condition operand.
1008        let condition = ssa_opd_parser().skip(spaced(token(',')));
1009        let default_successor = block_opd_parser();
1010        let default_operands = delimited_list_parser('(', ')', ',', ssa_opd_parser());
1011        let cases = delimited_list_parser('[', ']', ',', SwitchCase::parser(()));
1012
1013        let final_parser = spaced(condition)
1014            .and(default_successor)
1015            .skip(spaces())
1016            .and(default_operands)
1017            .skip(spaces())
1018            .and(cases);
1019
1020        final_parser
1021            .then(
1022                move |(((condition, default_dest), default_dest_opds), cases)| {
1023                    let results = arg.clone();
1024                    combine::parser(move |parsable_state: &mut StateStream<'a>| {
1025                        let ctx = &mut parsable_state.state.ctx;
1026                        let op = SwitchOp::new(
1027                            ctx,
1028                            condition,
1029                            default_dest,
1030                            default_dest_opds.clone(),
1031                            cases.clone(),
1032                        );
1033
1034                        process_parsed_ssa_defs(parsable_state, &results, op.get_operation())?;
1035                        Ok(OpObj::new(op)).into_parse_result()
1036                    })
1037                },
1038            )
1039            .parse_stream(state_stream)
1040            .into()
1041    }
1042}
1043
1044impl SwitchOp {
1045    /// Create a new [SwitchOp].
1046    pub fn new(
1047        ctx: &mut Context,
1048        condition: Value,
1049        default_dest: Ptr<BasicBlock>,
1050        default_dest_opds: Vec<Value>,
1051        cases: Vec<SwitchCase>,
1052    ) -> Self {
1053        let case_values: Vec<IntegerAttr> = cases.iter().map(|case| case.value.clone()).collect();
1054
1055        let case_operands = cases
1056            .iter()
1057            .map(|case| case.dest_opds.clone())
1058            .collect::<Vec<_>>();
1059
1060        let mut operand_segments = vec![vec![condition], default_dest_opds];
1061        operand_segments.extend(case_operands);
1062        let (operands, segment_sizes) = Self::compute_segment_sizes(operand_segments);
1063
1064        let case_dests = cases.iter().map(|case| case.dest);
1065        let successors = vec![default_dest].into_iter().chain(case_dests).collect();
1066        let op = SwitchOp {
1067            op: Operation::new(
1068                ctx,
1069                Self::get_concrete_op_info(),
1070                vec![],
1071                operands,
1072                successors,
1073                0,
1074            ),
1075        };
1076
1077        // Set the operand segment sizes attribute.
1078        op.set_operand_segment_sizes(ctx, segment_sizes);
1079        // Set the case values
1080        op.set_attr_switch_case_values(ctx, CaseValuesAttr(case_values));
1081        op
1082    }
1083
1084    /// Get the cases of this switch operation.
1085    /// (The default case cannot be / isn't included here).
1086    pub fn cases(&self, ctx: &Context) -> Vec<SwitchCase> {
1087        let case_values = &*self
1088            .get_attr_switch_case_values(ctx)
1089            .expect("SwitchOp missing or incorrect case values attribute");
1090
1091        let op = self.get_operation().deref(ctx);
1092        // Skip the first one, which is the default successor.
1093        let successors = op.successors().skip(1);
1094
1095        successors
1096            .zip(case_values.0.iter())
1097            .enumerate()
1098            .map(|(i, (dest, value))| {
1099                // i+1 here because the first successor is the default destination.
1100                let dest_opds = self.successor_operands(ctx, i + 1);
1101                SwitchCase {
1102                    value: value.clone(),
1103                    dest,
1104                    dest_opds,
1105                }
1106            })
1107            .collect()
1108    }
1109
1110    /// Get the condition value for the switch.
1111    pub fn condition(&self, ctx: &Context) -> Value {
1112        self.get_operation().deref(ctx).get_operand(0)
1113    }
1114
1115    /// Get the default destination of this switch operation.
1116    pub fn default_dest(&self, ctx: &Context) -> Ptr<BasicBlock> {
1117        self.get_operation().deref(ctx).get_successor(0)
1118    }
1119
1120    /// Get the operands to pass to the default destination.
1121    pub fn default_dest_operands(&self, ctx: &Context) -> Vec<Value> {
1122        self.successor_operands(ctx, 0)
1123    }
1124}
1125
1126#[op_interface_impl]
1127impl BranchOpInterface for SwitchOp {
1128    fn successor_operands(&self, ctx: &Context, succ_idx: usize) -> Vec<Value> {
1129        // Skip the first segment, which is the condition.
1130        self.get_segment(ctx, succ_idx + 1)
1131    }
1132
1133    fn add_successor_operand(&self, ctx: &mut Context, succ_idx: usize, operand: Value) -> usize {
1134        // The successor operands start at segment 1, since segment 0 is the condition operand.
1135        self.push_to_segment(ctx, succ_idx + 1, operand)
1136    }
1137
1138    fn remove_successor_operand(
1139        &self,
1140        ctx: &mut Context,
1141        succ_idx: usize,
1142        opd_idx: usize,
1143    ) -> Value {
1144        // The successor operands start at segment 1, since segment 0 is the condition operand.
1145        self.remove_from_segment(ctx, succ_idx + 1, opd_idx)
1146    }
1147}
1148
// A switch groups its operands into segments: segment 0 is the condition and
// segment `i + 1` holds the operands forwarded to successor `i`.
#[op_interface_impl]
impl OperandSegmentInterface for SwitchOp {}
1151
/// Errors reported when verifying a [SwitchOp].
#[derive(Error, Debug)]
pub enum SwitchOpVerifyErr {
    /// The case values attribute is missing or incorrect.
    #[error("SwitchOp has no or incorrect case values attribute")]
    CaseValuesAttrErr,
    /// The op has no successor to act as the default destination.
    #[error("SwitchOp has no or incorrect default destination")]
    DefaultDestErr,
    /// The condition operand is missing, or its type does not agree with the case values.
    #[error("SwitchOp has no condition operand or is not an integer")]
    ConditionErr,
}
1161
1162impl Verify for SwitchOp {
1163    fn verify(&self, ctx: &Context) -> Result<()> {
1164        let loc = self.loc(ctx);
1165
1166        let Some(case_values) = self.get_attr_switch_case_values(ctx) else {
1167            verify_err!(loc.clone(), SwitchOpVerifyErr::CaseValuesAttrErr)?
1168        };
1169
1170        let op = &*self.get_operation().deref(ctx);
1171
1172        if op.get_num_successors() < 1 {
1173            verify_err!(loc.clone(), SwitchOpVerifyErr::DefaultDestErr)?;
1174        }
1175
1176        if op.get_num_operands() < 1 {
1177            verify_err!(loc.clone(), SwitchOpVerifyErr::ConditionErr)?;
1178        }
1179
1180        let condition_ty = pliron::r#type::Typed::get_type(&op.get_operand(0), ctx);
1181        let condition_ty = TypePtr::<IntegerType>::from_ptr(condition_ty, ctx)?;
1182
1183        if let Some(case_value) = case_values.0.first() {
1184            // Ensure that the case value type matches the condition type.
1185            if case_value.get_type() != condition_ty {
1186                verify_err!(loc, SwitchOpVerifyErr::ConditionErr)?;
1187            }
1188        }
1189
1190        Ok(())
1191    }
1192}
1193
/// A way to express whether a GEP index is a constant or an SSA value
#[derive(Clone)]
pub enum GepIndex {
    /// A constant index, stored directly in the GEP's indices attribute.
    Constant(u32),
    /// A dynamic index, supplied as an operand of the GEP.
    Value(Value),
}
1200
1201impl Printable for GepIndex {
1202    fn fmt(
1203        &self,
1204        ctx: &Context,
1205        _state: &pliron::printable::State,
1206        f: &mut std::fmt::Formatter<'_>,
1207    ) -> std::fmt::Result {
1208        match self {
1209            GepIndex::Constant(c) => write!(f, "{c}"),
1210            GepIndex::Value(v) => write!(f, "{}", v.disp(ctx)),
1211        }
1212    }
1213}
1214
/// Errors raised when verifying a [GetElementPtrOp] or computing its indexed type.
#[derive(Error, Debug)]
pub enum GetElementPtrOpErr {
    /// The `gep_indices` attribute is missing or incorrect.
    #[error("GetElementPtrOp has no or incorrect indices attribute")]
    IndicesAttrErr,
    /// The indices cannot be applied to the source element type. Examples are a
    /// non-constant or out-of-range struct field index, indexing into an opaque struct,
    /// or indexing into a type that is neither a struct nor an array.
    #[error("The indices on this GEP are invalid for its source element type")]
    IndicesErr,
}
1222
/// Equivalent to LLVM's GetElementPtr.
///
/// The type being indexed into is held in the `gep_src_elem_type` attribute, and the
/// index list in the `gep_indices` attribute. Constant indices live directly in that
/// attribute. Dynamic indices are operands, referenced from the attribute by operand
/// index; operand 0 is always `base`.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `base` | LLVM pointer type |
/// | `dynamicIndices` | Any number of signless integers |
///
/// ### Result(s):
///
/// | result | description |
/// |-----|-------|
/// | `res` | LLVM pointer type |
#[pliron_op(
    name = "llvm.gep",
    format = "`<` attr($gep_src_elem_type, $TypeAttr) `>` ` (` operands(CharSpace(`,`)) `)` attr($gep_indices, $GepIndicesAttr) ` : ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>, ResultNOfType<0, PointerType>],
    attributes = (gep_src_elem_type: TypeAttr, gep_indices: GepIndicesAttr)
)]
pub struct GetElementPtrOp;
1242
1243#[op_interface_impl]
1244impl PointerTypeResult for GetElementPtrOp {
1245    fn result_pointee_type(&self, ctx: &Context) -> Ptr<TypeObj> {
1246        Self::indexed_type(ctx, self.src_elem_type(ctx), &self.indices(ctx))
1247            .expect("Invalid indices for GEP")
1248    }
1249}
1250
1251impl Verify for GetElementPtrOp {
1252    fn verify(&self, ctx: &Context) -> Result<()> {
1253        let loc = self.loc(ctx);
1254        // Ensure that we have the indices as an attribute.
1255        if self.get_attr_gep_indices(ctx).is_none() {
1256            verify_err!(loc, GetElementPtrOpErr::IndicesAttrErr)?
1257        }
1258
1259        if let Err(e @ Error { .. }) =
1260            Self::indexed_type(ctx, self.src_elem_type(ctx), &self.indices(ctx))
1261        {
1262            return Err(Error {
1263                kind: ErrorKind::VerificationFailed,
1264                // We reset the error origin to be from here
1265                backtrace: std::backtrace::Backtrace::capture(),
1266                ..e
1267            });
1268        }
1269
1270        Ok(())
1271    }
1272}
1273
1274impl GetElementPtrOp {
1275    /// Create a new [GetElementPtrOp]
1276    pub fn new(
1277        ctx: &mut Context,
1278        base: Value,
1279        indices: Vec<GepIndex>,
1280        src_elem_type: Ptr<TypeObj>,
1281    ) -> Self {
1282        let result_type = PointerType::get(ctx).into();
1283        let mut attr: Vec<GepIndexAttr> = Vec::new();
1284        let mut opds: Vec<Value> = vec![base];
1285        for idx in indices {
1286            match idx {
1287                GepIndex::Constant(c) => {
1288                    attr.push(GepIndexAttr::Constant(c));
1289                }
1290                GepIndex::Value(v) => {
1291                    attr.push(GepIndexAttr::OperandIdx(opds.push_back(v)));
1292                }
1293            }
1294        }
1295        let op = Operation::new(
1296            ctx,
1297            Self::get_concrete_op_info(),
1298            vec![result_type],
1299            opds,
1300            vec![],
1301            0,
1302        );
1303        let src_elem_type = TypeAttr::new(src_elem_type);
1304        let op = GetElementPtrOp { op };
1305
1306        op.set_attr_gep_indices(ctx, GepIndicesAttr(attr));
1307        op.set_attr_gep_src_elem_type(ctx, src_elem_type);
1308        op
1309    }
1310
1311    /// Get the source pointer's element type.
1312    pub fn src_elem_type(&self, ctx: &Context) -> Ptr<TypeObj> {
1313        self.get_attr_gep_src_elem_type(ctx)
1314            .expect("GetElementPtrOp missing or has incorrect src_elem_type attribute type")
1315            .get_type(ctx)
1316    }
1317
1318    /// Get the base (source) pointer of this GEP.
1319    pub fn src_ptr(&self, ctx: &Context) -> Value {
1320        self.get_operation().deref(ctx).get_operand(0)
1321    }
1322
1323    /// Get the indices of this GEP.
1324    pub fn indices(&self, ctx: &Context) -> Vec<GepIndex> {
1325        let op = &*self.op.deref(ctx);
1326        self.get_attr_gep_indices(ctx)
1327            .unwrap()
1328            .0
1329            .iter()
1330            .map(|index| match index {
1331                GepIndexAttr::Constant(c) => GepIndex::Constant(*c),
1332                GepIndexAttr::OperandIdx(i) => GepIndex::Value(op.get_operand(*i)),
1333            })
1334            .collect()
1335    }
1336
1337    /// Returns the result element type of a GEP with the given source element type and indexes.
1338    /// See [getIndexedType](https://llvm.org/doxygen/classllvm_1_1GetElementPtrInst.html#a99d4bfe49182f8d80abb1960f2c12d46)
1339    pub fn indexed_type(
1340        ctx: &Context,
1341        src_elem_type: Ptr<TypeObj>,
1342        indices: &[GepIndex],
1343    ) -> Result<Ptr<TypeObj>> {
1344        fn indexed_type_inner(
1345            ctx: &Context,
1346            src_elem_type: Ptr<TypeObj>,
1347            mut idx_itr: impl Iterator<Item = GepIndex>,
1348        ) -> Result<Ptr<TypeObj>> {
1349            let Some(idx) = idx_itr.next() else {
1350                return Ok(src_elem_type);
1351            };
1352            let src_elem_type = &*src_elem_type.deref(ctx);
1353            if let Some(st) = src_elem_type.downcast_ref::<StructType>() {
1354                let GepIndex::Constant(i) = idx else {
1355                    return arg_err_noloc!(GetElementPtrOpErr::IndicesErr);
1356                };
1357                if st.is_opaque() || i as usize >= st.num_fields() {
1358                    return arg_err_noloc!(GetElementPtrOpErr::IndicesErr);
1359                }
1360                indexed_type_inner(ctx, st.field_type(i as usize), idx_itr)
1361            } else if let Some(at) = src_elem_type.downcast_ref::<ArrayType>() {
1362                indexed_type_inner(ctx, at.elem_type(), idx_itr)
1363            } else {
1364                arg_err_noloc!(GetElementPtrOpErr::IndicesErr)
1365            }
1366        }
1367        // The first index is for the base (source) pointer. Skip that.
1368        indexed_type_inner(ctx, src_elem_type, indices.iter().skip(1).cloned())
1369    }
1370}
1371
/// Verification errors for [LoadOp].
#[derive(Error, Debug)]
pub enum LoadOpVerifyErr {
    /// The address operand does not have pointer type.
    #[error("Load operand must be a pointer")]
    OperandTypeErr,
}
1377
/// Equivalent to LLVM's Load opcode.
///
/// The result type is chosen by the creator (see [LoadOp::new]) rather than derived
/// from the address operand. An alignment may optionally be attached through the
/// `llvm_alignment` attribute.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `addr` | [PointerType] |
///
/// ### Result(s):
///
/// | result | description |
/// |-----|-------|
/// | `res` | sized LLVM type |
#[pliron_op(
    name = "llvm.load",
    format = "$0 ` ` opt_attr($llvm_alignment, $AlignmentAttr, label($align), delimiters(`[`, `]`)) ` : ` type($0)",
    interfaces = [
        OneResultInterface,
        OneOpdInterface,
        NResultsInterface<1>,
        NOpdsInterface<1>,
        AlignableOpInterface,
        OperandNOfType<0, PointerType>
    ],
    verifier = "succ"
)]
pub struct LoadOp;
1403impl LoadOp {
1404    /// Create a new [LoadOp]
1405    pub fn new(ctx: &mut Context, ptr: Value, res_ty: Ptr<TypeObj>) -> Self {
1406        LoadOp {
1407            op: Operation::new(
1408                ctx,
1409                Self::get_concrete_op_info(),
1410                vec![res_ty],
1411                vec![ptr],
1412                vec![],
1413                0,
1414            ),
1415        }
1416    }
1417
1418    /// Get the address operand
1419    pub fn address_opd(&self, ctx: &Context) -> Value {
1420        self.op.deref(ctx).get_operand(0)
1421    }
1422}
1423
1424#[op_interface_impl]
1425impl PromotableOpInterface for LoadOp {
1426    fn promotion_kind(&self, ctx: &Context, alloc_info: &AllocInfo) -> PromotableOpKind {
1427        if self.address_opd(ctx) == alloc_info.ptr {
1428            PromotableOpKind::Load
1429        } else {
1430            PromotableOpKind::NonPromotableUse
1431        }
1432    }
1433
1434    fn promote(
1435        &self,
1436        ctx: &mut Context,
1437        alloc_info_reaching_defs: &[(AllocInfo, Value)],
1438        rewriter: &mut dyn Rewriter,
1439    ) -> Result<()> {
1440        if alloc_info_reaching_defs.len() != 1 {
1441            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
1442        }
1443        let (alloc_info, reaching_def) = &alloc_info_reaching_defs[0];
1444        if self.address_opd(ctx) != alloc_info.ptr {
1445            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
1446        }
1447        rewriter.replace_operation_with_values(ctx, self.get_operation(), vec![*reaching_def]);
1448        Ok(())
1449    }
1450}
1451
/// Verification errors for [StoreOp].
#[derive(Error, Debug)]
pub enum StoreOpVerifyErr {
    /// The op does not have exactly two operands (value and address).
    #[error("Store operand must have two operands")]
    NumOpdsErr,
    /// The address operand (operand 1) does not have pointer type.
    #[error("Store operand must have a pointer as its second argument")]
    AddrOpdTypeErr,
}
1459
/// Equivalent to LLVM's Store opcode.
///
/// Operand 0 is the value being stored and operand 1 is the destination address.
/// The textual form prints them as `*<addr> <- <value>`. The op has no results.
/// An alignment may optionally be attached through the `llvm_alignment` attribute.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `value` | Sized type |
/// | `addr` | [PointerType] |
#[pliron_op(
    name = "llvm.store",
    format = "`*` $1 ` <- ` $0 ` ` opt_attr($llvm_alignment, $AlignmentAttr, label($align), delimiters(`[`, `]`))",
    interfaces = [
        NResultsInterface<0>,
        AlignableOpInterface,
        OperandNOfType<1, PointerType>,
        NOpdsInterface<2>
    ],
    verifier = "succ"
)]
pub struct StoreOp;
1478impl StoreOp {
1479    /// Create a new [StoreOp]
1480    pub fn new(ctx: &mut Context, value: Value, ptr: Value) -> Self {
1481        StoreOp {
1482            op: Operation::new(
1483                ctx,
1484                Self::get_concrete_op_info(),
1485                vec![],
1486                vec![value, ptr],
1487                vec![],
1488                0,
1489            ),
1490        }
1491    }
1492
1493    /// Get the value operand
1494    pub fn value_opd(&self, ctx: &Context) -> Value {
1495        self.op.deref(ctx).get_operand(0)
1496    }
1497
1498    /// Get the address operand
1499    pub fn address_opd(&self, ctx: &Context) -> Value {
1500        self.op.deref(ctx).get_operand(1)
1501    }
1502}
1503
1504#[op_interface_impl]
1505impl PromotableOpInterface for StoreOp {
1506    fn promotion_kind(&self, ctx: &Context, alloc_info: &AllocInfo) -> PromotableOpKind {
1507        if self.address_opd(ctx) == alloc_info.ptr {
1508            PromotableOpKind::Store(self.value_opd(ctx))
1509        } else {
1510            PromotableOpKind::NonPromotableUse
1511        }
1512    }
1513
1514    fn promote(
1515        &self,
1516        ctx: &mut Context,
1517        alloc_info_reaching_defs: &[(AllocInfo, Value)],
1518        rewriter: &mut dyn Rewriter,
1519    ) -> Result<()> {
1520        if alloc_info_reaching_defs.len() != 1 {
1521            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
1522        }
1523        let (alloc_info, _reaching_def) = &alloc_info_reaching_defs[0];
1524        if self.address_opd(ctx) != alloc_info.ptr {
1525            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
1526        }
1527        rewriter.erase_operation(ctx, self.get_operation());
1528        Ok(())
1529    }
1530}
1531
/// Equivalent to LLVM's Call opcode.
///
/// A direct call names its callee through the `llvm_call_callee` attribute, and all of
/// its operands are arguments. An indirect call has no such attribute: its operand 0 is
/// the function pointer and the remaining operands are arguments. Optional fastmath
/// flags are kept in the `llvm_call_fastmath_flags` attribute.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `callee_operands` | Optional function pointer followed by any number of parameters |
///
/// ### Result(s):
///
/// | result | description |
/// |-----|-------|
/// | `res` | LLVM type |
#[pliron_op(
    name = "llvm.call",
    interfaces = [OneResultInterface, NResultsInterface<1>],
    attributes = (llvm_call_callee: IdentifierAttr, llvm_call_fastmath_flags: FastmathFlagsAttr)
)]
pub struct CallOp;
1550
1551impl CallOp {
1552    /// Get a new [CallOp].
1553    pub fn new(
1554        ctx: &mut Context,
1555        callee: CallOpCallable,
1556        callee_ty: TypePtr<FuncType>,
1557        mut args: Vec<Value>,
1558    ) -> Self {
1559        let res_ty = callee_ty.deref(ctx).result_type();
1560        let op = match callee {
1561            CallOpCallable::Direct(cval) => {
1562                let op = Operation::new(
1563                    ctx,
1564                    Self::get_concrete_op_info(),
1565                    vec![res_ty],
1566                    args,
1567                    vec![],
1568                    0,
1569                );
1570                let op = CallOp { op };
1571                op.set_attr_llvm_call_callee(ctx, IdentifierAttr::new(cval));
1572                op
1573            }
1574            CallOpCallable::Indirect(csym) => {
1575                args.insert(0, csym);
1576                let op = Operation::new(
1577                    ctx,
1578                    Self::get_concrete_op_info(),
1579                    vec![res_ty],
1580                    args,
1581                    vec![],
1582                    0,
1583                );
1584                CallOp { op }
1585            }
1586        };
1587        op.set_callee_type(ctx, callee_ty.into());
1588        op
1589    }
1590}
1591
/// Errors reported when verifying symbol uses and call signatures.
#[derive(Error, Debug)]
pub enum SymbolUserOpVerifyErr {
    /// No symbol with the given name was found in the nearest symbol table.
    #[error("Symbol {0} not found")]
    SymbolNotFound(String),
    /// The symbol was found but is not an LLVM dialect function.
    #[error("Function {0} should have been llvm.func type")]
    NotLlvmFunc(String),
    /// An AddressOf op refers to something other than a function or a global variable.
    #[error("AddressOf Op can only refer to a function or a global variable")]
    AddressOfInvalidReference,
    /// The call's callee, argument or result types do not match; the payload explains how.
    #[error("Function call has incorrect type: {0}")]
    FuncTypeErr(String),
}
1603
1604#[op_interface_impl]
1605impl SymbolUserOpInterface for CallOp {
1606    fn verify_symbol_uses(
1607        &self,
1608        ctx: &Context,
1609        symbol_tables: &mut SymbolTableCollection,
1610    ) -> Result<()> {
1611        match self.callee(ctx) {
1612            CallOpCallable::Direct(callee_sym) => {
1613                let Some(callee) = symbol_tables.lookup_symbol_in_nearest_table(
1614                    ctx,
1615                    self.get_operation(),
1616                    &callee_sym,
1617                ) else {
1618                    return verify_err!(
1619                        self.loc(ctx),
1620                        SymbolUserOpVerifyErr::SymbolNotFound(callee_sym.to_string())
1621                    );
1622                };
1623                let Some(func_op) = (&*callee as &dyn Op).downcast_ref::<FuncOp>() else {
1624                    return verify_err!(
1625                        self.loc(ctx),
1626                        SymbolUserOpVerifyErr::NotLlvmFunc(callee_sym.to_string())
1627                    );
1628                };
1629                let func_op_ty = func_op.get_type(ctx);
1630
1631                if func_op_ty.to_ptr() != self.callee_type(ctx) {
1632                    return verify_err!(
1633                        self.loc(ctx),
1634                        SymbolUserOpVerifyErr::FuncTypeErr(format!(
1635                            "expected {}, got {}",
1636                            func_op_ty.disp(ctx),
1637                            self.callee_type(ctx).disp(ctx)
1638                        ))
1639                    );
1640                }
1641            }
1642            CallOpCallable::Indirect(pointer) => {
1643                use pliron::r#type::Typed;
1644                if !pointer.get_type(ctx).deref(ctx).is::<PointerType>() {
1645                    return verify_err!(
1646                        self.loc(ctx),
1647                        SymbolUserOpVerifyErr::FuncTypeErr("Callee must be a pointer".to_string())
1648                    );
1649                }
1650            }
1651        }
1652        Ok(())
1653    }
1654
1655    fn used_symbols(&self, ctx: &Context) -> Vec<Identifier> {
1656        match self.callee(ctx) {
1657            CallOpCallable::Direct(identifier) => vec![identifier],
1658            CallOpCallable::Indirect(_) => vec![],
1659        }
1660    }
1661}
1662
1663#[op_interface_impl]
1664impl CallOpInterface for CallOp {
1665    fn callee(&self, ctx: &Context) -> CallOpCallable {
1666        let op = self.op.deref(ctx);
1667        if let Some(callee_sym) = self.get_attr_llvm_call_callee(ctx) {
1668            CallOpCallable::Direct(callee_sym.clone().into())
1669        } else {
1670            assert!(
1671                op.get_num_operands() > 0,
1672                "Indirect call must have function pointer operand"
1673            );
1674            CallOpCallable::Indirect(op.get_operand(0))
1675        }
1676    }
1677
1678    fn args(&self, ctx: &Context) -> Vec<Value> {
1679        let op = self.op.deref(ctx);
1680        // If this is an indirect call, the first operand is the callee value.
1681        let skip = if matches!(self.callee(ctx), CallOpCallable::Direct(_)) {
1682            0
1683        } else {
1684            1
1685        };
1686        op.operands().skip(skip).collect()
1687    }
1688}
1689
1690impl Printable for CallOp {
1691    fn fmt(
1692        &self,
1693        ctx: &Context,
1694        _state: &pliron::printable::State,
1695        f: &mut std::fmt::Formatter<'_>,
1696    ) -> std::fmt::Result {
1697        let callee = self.callee(ctx);
1698        write!(
1699            f,
1700            "{} = {} ",
1701            self.get_result(ctx).disp(ctx),
1702            self.get_opid()
1703        )?;
1704        match callee {
1705            CallOpCallable::Direct(callee_sym) => {
1706                write!(f, "@{callee_sym}")?;
1707            }
1708            CallOpCallable::Indirect(callee_val) => {
1709                write!(f, "{}", callee_val.disp(ctx))?;
1710            }
1711        }
1712
1713        if let Some(fmf) = self.get_attr_llvm_call_fastmath_flags(ctx)
1714            && *fmf != FastmathFlagsAttr::default()
1715        {
1716            write!(f, " {}", fmf.disp(ctx))?;
1717        }
1718
1719        let args = self.args(ctx);
1720        let ty = self.callee_type(ctx);
1721        write!(
1722            f,
1723            " ({}) : {}",
1724            list_with_sep(&args, pliron::printable::ListSeparator::CharSpace(',')).disp(ctx),
1725            ty.disp(ctx)
1726        )?;
1727        Ok(())
1728    }
1729}
1730
1731impl Parsable for CallOp {
1732    type Arg = Vec<(Identifier, Location)>;
1733    type Parsed = OpObj;
1734
1735    fn parse<'a>(
1736        state_stream: &mut StateStream<'a>,
1737        results: Self::Arg,
1738    ) -> ParseResult<'a, Self::Parsed> {
1739        let direct_callee = combine::token('@')
1740            .with(Identifier::parser(()))
1741            .map(CallOpCallable::Direct);
1742        let indirect_callee = ssa_opd_parser().map(CallOpCallable::Indirect);
1743        let callee_parser = direct_callee.or(indirect_callee);
1744        let fastmath_flags_parser = optional(FastmathFlagsAttr::parser(()));
1745        let args_parser = delimited_list_parser('(', ')', ',', ssa_opd_parser());
1746        let ty_parser = spaced(combine::token(':')).with(TypePtr::<FuncType>::parser(()));
1747
1748        let mut final_parser = spaced(callee_parser)
1749            .and(spaced(fastmath_flags_parser))
1750            .and(spaced(args_parser))
1751            .and(ty_parser)
1752            .then(move |(((callee, fastmath_flags), args), ty)| {
1753                let results = results.clone();
1754                combine::parser(move |parsable_state: &mut StateStream<'a>| {
1755                    let ctx = &mut parsable_state.state.ctx;
1756                    let op = CallOp::new(ctx, callee.clone(), ty, args.clone());
1757                    if let Some(fmf) = &fastmath_flags {
1758                        op.set_attr_llvm_call_fastmath_flags(ctx, *fmf);
1759                    }
1760                    process_parsed_ssa_defs(parsable_state, &results, op.get_operation())?;
1761                    Ok(OpObj::new(op)).into_parse_result()
1762                })
1763            });
1764
1765        final_parser.parse_stream(state_stream).into_result()
1766    }
1767}
1768
1769impl Verify for CallOp {
1770    fn verify(&self, ctx: &Context) -> Result<()> {
1771        // Check that the argument and result types match the callee type.
1772        let callee_ty = &*self.callee_type(ctx).deref(ctx);
1773        let Some(callee_ty) = callee_ty.downcast_ref::<FuncType>() else {
1774            return verify_err!(
1775                self.loc(ctx),
1776                SymbolUserOpVerifyErr::FuncTypeErr("Callee is not a function".to_string())
1777            );
1778        };
1779        // Check the function type against the arguments.
1780        let args = self.args(ctx);
1781        let expected_args = callee_ty.arg_types();
1782        if !callee_ty.is_var_arg() && args.len() != expected_args.len() {
1783            return verify_err!(
1784                self.loc(ctx),
1785                SymbolUserOpVerifyErr::FuncTypeErr("argument count mismatch.".to_string())
1786            );
1787        }
1788        use pliron::r#type::Typed;
1789        for (arg_idx, (arg, expected_arg)) in args.iter().zip(expected_args.iter()).enumerate() {
1790            if arg.get_type(ctx) != *expected_arg {
1791                return verify_err!(
1792                    self.loc(ctx),
1793                    SymbolUserOpVerifyErr::FuncTypeErr(format!(
1794                        "argument {} type mismatch: expected {}, got {}",
1795                        arg_idx,
1796                        expected_arg.disp(ctx),
1797                        arg.get_type(ctx).disp(ctx)
1798                    ))
1799                );
1800            }
1801        }
1802
1803        if callee_ty.result_type() != self.result_type(ctx) {
1804            return verify_err!(
1805                self.loc(ctx),
1806                SymbolUserOpVerifyErr::FuncTypeErr(format!(
1807                    "result type mismatch: expected {}, got {}",
1808                    callee_ty.result_type().disp(ctx),
1809                    self.result_type(ctx).disp(ctx)
1810                ))
1811            );
1812        }
1813
1814        Ok(())
1815    }
1816}
1817
/// Undefined value of a type.
/// See MLIR's [llvm.mlir.undef](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirundef-llvmundefop).
///
/// Takes no operands; the result type is supplied when the op is created.
///
/// ### Results:
/// | result | description |
/// |-----|-------|
/// | `result` | any type |
#[pliron_op(
    name = "llvm.undef",
    format = "`: ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>],
    verifier = "succ"
)]
pub struct UndefOp;
1832
1833impl UndefOp {
1834    /// Create a new [UndefOp].
1835    pub fn new(ctx: &mut Context, result_ty: Ptr<TypeObj>) -> Self {
1836        let op = Operation::new(
1837            ctx,
1838            Self::get_concrete_op_info(),
1839            vec![result_ty],
1840            vec![],
1841            vec![],
1842            0,
1843        );
1844        UndefOp { op }
1845    }
1846}
1847
/// Poison value of a type.
/// See MLIR's [llvm.mlir.poison](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirpoison-llvmpoisonop).
///
/// Takes no operands; the result type is supplied when the op is created.
///
/// ### Results:
/// | result | description |
/// |-----|-------|
/// | `result` | any type |
#[pliron_op(
    name = "llvm.poison",
    format = "`: ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>],
    verifier = "succ"
)]
pub struct PoisonOp;
1862
1863impl PoisonOp {
1864    /// Create a new [PoisonOp].
1865    pub fn new(ctx: &mut Context, result_ty: Ptr<TypeObj>) -> Self {
1866        let op = Operation::new(
1867            ctx,
1868            Self::get_concrete_op_info(),
1869            vec![result_ty],
1870            vec![],
1871            vec![],
1872            0,
1873        );
1874        PoisonOp { op }
1875    }
1876}
1877
1878/// Freeze value of a type.
1879/// See MLIR's [llvm.mlir.freeze](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmfreeze-llvmfreezeop).
1880///
1881/// ### Results:
1882/// | result | description |
1883/// |-----|-------|
1884/// | `result` | any type |
1885///
1886/// ### Operands:
1887/// | operand | description |
1888/// |-----|-------|
1889/// | `value` | any type |
1890#[pliron_op(
1891    name = "llvm.freeze",
1892    format = "$0 ` : ` type($0)",
1893    interfaces = [OneOpdInterface, OneResultInterface, NOpdsInterface<1>, NResultsInterface<1>],
1894    verifier = "succ"
1895)]
1896pub struct FreezeOp;
1897
1898impl FreezeOp {
1899    /// Create a new [FreezeOp].
1900    pub fn new(ctx: &mut Context, value: Value) -> Self {
1901        use pliron::r#type::Typed;
1902        let result_ty = value.get_type(ctx);
1903        let op = Operation::new(
1904            ctx,
1905            Self::get_concrete_op_info(),
1906            vec![result_ty],
1907            vec![value],
1908            vec![],
1909            0,
1910        );
1911        FreezeOp { op }
1912    }
1913}
1914
1915/// Numeric (integer or floating point) constant.
1916/// See MLIR's [llvm.mlir.constant](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirconstant-llvmconstantop).
1917///
1918/// ### Results:
1919///
1920/// | result | description |
1921/// |-----|-------|
1922/// | `result` | any type |
1923#[pliron_op(
1924    name = "llvm.constant",
1925    format = "`<` $constant_value `>` ` : ` type($0)",
1926    interfaces = [NOpdsInterface<0>, OneResultInterface, NResultsInterface<1>],
1927    attributes = (constant_value)
1928)]
1929pub struct ConstantOp;
1930
1931impl ConstantOp {
1932    /// Get the constant value that this Op defines.
1933    pub fn get_value(&self, ctx: &Context) -> AttrObj {
1934        self.get_attr_constant_value(ctx).unwrap().clone()
1935    }
1936
1937    /// Create a new [ConstantOp].
1938    pub fn new(ctx: &mut Context, value: AttrObj) -> Self {
1939        let result_type = attr_cast::<dyn TypedAttrInterface>(&*value)
1940            .expect("ConstantOp const value must provide TypedAttrInterface")
1941            .get_type(ctx);
1942        let op = Operation::new(
1943            ctx,
1944            Self::get_concrete_op_info(),
1945            vec![result_type],
1946            vec![],
1947            vec![],
1948            0,
1949        );
1950        let op = ConstantOp { op };
1951        op.set_attr_constant_value(ctx, value);
1952        op
1953    }
1954}
1955
/// Verification errors for [ConstantOp].
#[derive(Error, Debug)]
#[error("{}: Unexpected type", ConstantOp::get_opid_static())]
pub enum ConstantOpVerifyErr {
    /// The `constant_value` attribute is neither an [IntegerAttr] nor a [FloatAttr].
    #[error("ConstantOp must have either an integer or a float value")]
    InvalidValue,
}
1962
1963impl Verify for ConstantOp {
1964    fn verify(&self, ctx: &Context) -> Result<()> {
1965        let loc = self.loc(ctx);
1966        let value = self.get_value(ctx);
1967        if !(value.is::<IntegerAttr>() || attr_impls::<dyn FloatAttr>(&*value)) {
1968            return verify_err!(loc, ConstantOpVerifyErr::InvalidValue)?;
1969        }
1970        Ok(())
1971    }
1972}
1973
1974/// Same as MLIR's LLVM dialect [ZeroOp](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirzero-llvmzeroop)
1975/// It creates a zero-initialized value of the specified LLVM IR dialect type.
1976/// Results:
1977///
1978/// | result | description |
1979/// |-----|-------|
1980/// | `result` | any type |
1981#[pliron_op(
1982    name = "llvm.zero",
1983    format = "`: ` type($0)",
1984    interfaces = [NOpdsInterface<0>, OneResultInterface, NResultsInterface<1>],
1985    verifier = "succ"
1986)]
1987pub struct ZeroOp;
1988
1989impl ZeroOp {
1990    /// Create a new [ZeroOp].
1991    pub fn new(ctx: &mut Context, result_ty: Ptr<TypeObj>) -> Self {
1992        let op = Operation::new(
1993            ctx,
1994            Self::get_concrete_op_info(),
1995            vec![result_ty],
1996            vec![],
1997            vec![],
1998            0,
1999        );
2000        ZeroOp { op }
2001    }
2002}
2003
2004#[derive(Error, Debug)]
2005pub enum GlobalOpVerifyErr {
2006    #[error("GlobalOp must have a type")]
2007    MissingType,
2008}
2009
2010/// Same as MLIR's LLVM dialect [GlobalOp](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirglobal-llvmglobalop)
2011/// It creates a global variable of the specified LLVM IR dialect type.
2012/// An initializer can be specified either as an attribute or in the
2013/// operation's initializer region, ending with a return.
2014#[pliron_op(
2015    name = "llvm.global",
2016    interfaces = [
2017        IsolatedFromAboveInterface,
2018        NOpdsInterface<0>,
2019        NResultsInterface<0>,
2020        SymbolOpInterface,
2021        SingleBlockRegionInterface,
2022        LlvmSymbolName,
2023        AlignableOpInterface
2024    ],
2025    attributes = (llvm_global_type: TypeAttr, global_initializer, llvm_global_linkage: LinkageAttr)
2026)]
2027pub struct GlobalOp;
2028
2029impl GlobalOp {
2030    /// Create a new [GlobalOp]. An initializer region can be added later if needed.
2031    pub fn new(ctx: &mut Context, name: Identifier, ty: Ptr<TypeObj>) -> Self {
2032        let op = Operation::new(ctx, Self::get_concrete_op_info(), vec![], vec![], vec![], 0);
2033        let op = GlobalOp { op };
2034        op.set_symbol_name(ctx, name);
2035        op.set_attr_llvm_global_type(ctx, TypeAttr::new(ty));
2036        op
2037    }
2038}
2039
2040impl pliron::r#type::Typed for GlobalOp {
2041    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
2042        pliron::r#type::Typed::get_type(
2043            &*self
2044                .get_attr_llvm_global_type(ctx)
2045                .expect("GlobalOp missing or has incorrect type attribute"),
2046            ctx,
2047        )
2048    }
2049}
2050
2051impl GlobalOp {
2052    /// Get the initializer value of this global variable.
2053    pub fn get_initializer_value(&self, ctx: &Context) -> Option<AttrObj> {
2054        self.get_attr_global_initializer(ctx).map(|v| v.clone())
2055    }
2056
2057    /// Get the initializer region's block of this global variable.
2058    /// This is a block that ends with a return operation.
2059    /// The return operation must have the same type as the global variable.
2060    pub fn get_initializer_block(&self, ctx: &Context) -> Option<Ptr<BasicBlock>> {
2061        (self.op.deref(ctx).num_regions() > 0).then(|| self.get_body(ctx, 0))
2062    }
2063
2064    /// Get the initializer region of this global variable.
2065    pub fn get_initializer_region(&self, ctx: &Context) -> Option<Ptr<Region>> {
2066        (self.op.deref(ctx).num_regions() > 0)
2067            .then(|| self.get_operation().deref(ctx).get_region(0))
2068    }
2069
2070    /// Set a simple initializer value for this global variable.
2071    pub fn set_initializer_value(&self, ctx: &Context, value: AttrObj) {
2072        self.set_attr_global_initializer(ctx, value);
2073    }
2074
2075    /// Add an initializer region (with an entry block) for this global variable.
2076    /// There shouldn't already be one.
2077    pub fn add_initializer_region(&self, ctx: &mut Context) -> Ptr<Region> {
2078        assert!(
2079            self.get_initializer_value(ctx).is_none(),
2080            "Attempt to create an initializer region when there already is an initializer value"
2081        );
2082        let region = Operation::add_region(self.get_operation(), ctx);
2083        let entry = BasicBlock::new(ctx, Some("entry".try_into().unwrap()), vec![]);
2084        entry.insert_at_front(region, ctx);
2085
2086        region
2087    }
2088}
2089
2090impl IsDeclaration for GlobalOp {
2091    fn is_declaration(&self, ctx: &Context) -> bool {
2092        self.get_initializer_value(ctx).is_none() && self.get_initializer_region(ctx).is_none()
2093    }
2094}
2095
2096impl Verify for GlobalOp {
2097    fn verify(&self, ctx: &Context) -> Result<()> {
2098        let loc = self.loc(ctx);
2099
2100        // The name must be set. That is checked by the SymbolOpInterface.
2101        // So we check that other attributes are set. Start with type.
2102        if self.get_attr_llvm_global_type(ctx).is_none() {
2103            return verify_err!(loc, GlobalOpVerifyErr::MissingType);
2104        }
2105
2106        // Check that there is at most one initializer
2107        if self.get_initializer_value(ctx).is_some() && self.get_initializer_region(ctx).is_some() {
2108            return verify_err!(loc, GlobalOpVerifyErr::MissingType);
2109        }
2110
2111        Ok(())
2112    }
2113}
2114
2115impl Printable for GlobalOp {
2116    fn fmt(
2117        &self,
2118        ctx: &Context,
2119        state: &pliron::printable::State,
2120        f: &mut std::fmt::Formatter<'_>,
2121    ) -> std::fmt::Result {
2122        write!(
2123            f,
2124            "{} @{} : {}",
2125            self.get_opid(),
2126            self.get_symbol_name(ctx),
2127            <Self as pliron::r#type::Typed>::get_type(self, ctx).disp(ctx)
2128        )?;
2129
2130        // Print attributes except for type, initializer and symbol name.
2131        let mut attributes_to_print_separately =
2132            self.op.deref(ctx).attributes.clone_skip_outlined();
2133        attributes_to_print_separately.0.retain(|key, _| {
2134            key != &*ATTR_KEY_LLVM_GLOBAL_TYPE
2135                && key != &*ATTR_KEY_SYM_NAME
2136                && key != &*ATTR_KEY_GLOBAL_INITIALIZER
2137        });
2138        indented_block!(state, {
2139            write!(
2140                f,
2141                "{}{}",
2142                indented_nl(state),
2143                attributes_to_print_separately.disp(ctx)
2144            )?;
2145        });
2146
2147        if let Some(init_value) = self.get_initializer_value(ctx) {
2148            write!(f, " = {}", init_value.disp(ctx))?;
2149        }
2150
2151        if let Some(init_region) = self.get_initializer_region(ctx) {
2152            write!(f, " = {}", init_region.print(ctx, state))?;
2153        }
2154
2155        Ok(())
2156    }
2157}
2158
2159impl Parsable for GlobalOp {
2160    type Arg = Vec<(Identifier, Location)>;
2161    type Parsed = OpObj;
2162    fn parse<'a>(
2163        state_stream: &mut StateStream<'a>,
2164        results: Self::Arg,
2165    ) -> ParseResult<'a, Self::Parsed> {
2166        let loc = state_stream.loc();
2167        if !results.is_empty() {
2168            return input_err!(loc, "GlobalOp must cannot have results")?;
2169        }
2170        let name_parser = combine::token('@').with(Identifier::parser(()));
2171        let type_parser = type_parser();
2172        let attr_dict_parser = AttributeDict::parser(());
2173
2174        let mut parser = name_parser
2175            .skip(spaced(combine::token(':')))
2176            .and(type_parser)
2177            .and(spaced(attr_dict_parser));
2178
2179        let (((name, ty), attr_dict), _) = parser.parse_stream(state_stream).into_result()?;
2180        let op = GlobalOp::new(state_stream.state.ctx, name, ty);
2181        op.get_operation()
2182            .deref_mut(state_stream.state.ctx)
2183            .attributes
2184            .0
2185            .extend(attr_dict.0);
2186
2187        enum Initializer {
2188            Value(AttrObj),
2189            Region(Ptr<Region>),
2190        }
2191        // Parse optional initializer value or region.
2192        let initializer_parser = combine::token('=').skip(spaces()).with(
2193            attr_parser()
2194                .map(Initializer::Value)
2195                .or(Region::parser(op.get_operation()).map(Initializer::Region)),
2196        );
2197
2198        let initializer = spaces()
2199            .with(combine::optional(initializer_parser))
2200            .parse_stream(state_stream)
2201            .into_result()?;
2202
2203        if let Some(initializer) = initializer.0 {
2204            match initializer {
2205                Initializer::Value(v) => op.set_initializer_value(state_stream.state.ctx, v),
2206                Initializer::Region(_r) => {
2207                    // Nothing to do since the region is already added to the operation during parsing.
2208                }
2209            }
2210        }
2211
2212        Ok(OpObj::new(op)).into_parse_result()
2213    }
2214}
2215
2216/// Same as MLIR's LLVM dialect [AddressOfOp](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmliraddressof-llvmaddressofop).
2217/// Creates an SSA value containing a pointer to a global value (function, variable etc).
2218///
2219/// ### Results:
2220///
2221/// | result | description |
2222/// |-----|-------|
2223/// | `result` | LLVM pointer type |
2224///
2225#[pliron_op(
2226    name = "llvm.addressof",
2227    format = "`@` attr($global_name, $IdentifierAttr) ` : ` type($0)",
2228    interfaces = [OneResultInterface, NResultsInterface<1>, ResultNOfType<0, PointerType>],
2229    attributes = (global_name: IdentifierAttr),
2230    verifier = "succ"
2231)]
2232pub struct AddressOfOp;
2233
2234impl AddressOfOp {
2235    /// Create a new [AddressOfOp].
2236    pub fn new(ctx: &mut Context, global_name: Identifier) -> Self {
2237        let result_type = PointerType::get(ctx).into();
2238        let op = Operation::new(
2239            ctx,
2240            Self::get_concrete_op_info(),
2241            vec![result_type],
2242            vec![],
2243            vec![],
2244            0,
2245        );
2246        let op = AddressOfOp { op };
2247        op.set_attr_global_name(ctx, IdentifierAttr::new(global_name));
2248        op
2249    }
2250
2251    /// Get the global name that this refers to.
2252    pub fn get_global_name(&self, ctx: &Context) -> Identifier {
2253        self.get_attr_global_name(ctx)
2254            .expect("AddressOfOp missing or has incorrect global_name attribute type")
2255            .clone()
2256            .into()
2257    }
2258
2259    /// If this operation referes to a global, get it.
2260    pub fn get_global(
2261        &self,
2262        ctx: &Context,
2263        symbol_tables: &mut SymbolTableCollection,
2264    ) -> Option<GlobalOp> {
2265        let global_name = self.get_global_name(ctx);
2266        symbol_tables
2267            .lookup_symbol_in_nearest_table(ctx, self.get_operation(), &global_name)
2268            .and_then(|sym_op| {
2269                (sym_op as Box<dyn Op>)
2270                    .downcast::<GlobalOp>()
2271                    .map(|op| *op)
2272                    .ok()
2273            })
2274    }
2275
2276    /// If this operation refers to a function, get it.
2277    pub fn get_function(
2278        &self,
2279        ctx: &Context,
2280        symbol_tables: &mut SymbolTableCollection,
2281    ) -> Option<FuncOp> {
2282        let global_name = self.get_global_name(ctx);
2283        symbol_tables
2284            .lookup_symbol_in_nearest_table(ctx, self.get_operation(), &global_name)
2285            .and_then(|sym_op| {
2286                (sym_op as Box<dyn Op>)
2287                    .downcast::<FuncOp>()
2288                    .map(|op| *op)
2289                    .ok()
2290            })
2291    }
2292}
2293
2294#[op_interface_impl]
2295impl SymbolUserOpInterface for AddressOfOp {
2296    fn used_symbols(&self, ctx: &Context) -> Vec<Identifier> {
2297        vec![self.get_global_name(ctx)]
2298    }
2299
2300    fn verify_symbol_uses(
2301        &self,
2302        ctx: &Context,
2303        symbol_tables: &mut SymbolTableCollection,
2304    ) -> Result<()> {
2305        let loc = self.loc(ctx);
2306        let global_name = self.get_global_name(ctx);
2307        let Some(symbol) =
2308            symbol_tables.lookup_symbol_in_nearest_table(ctx, self.get_operation(), &global_name)
2309        else {
2310            return verify_err!(
2311                loc,
2312                SymbolUserOpVerifyErr::SymbolNotFound(global_name.to_string())
2313            );
2314        };
2315
2316        // Symbol can only be a FuncOp or a GlobalOp
2317        let is_global = (&*symbol as &dyn Op).is::<GlobalOp>();
2318        let is_func = (&*symbol as &dyn Op).is::<FuncOp>();
2319        if !is_global && !is_func {
2320            return verify_err!(loc, SymbolUserOpVerifyErr::AddressOfInvalidReference);
2321        }
2322
2323        Ok(())
2324    }
2325}
2326
2327#[derive(Error, Debug)]
2328enum IntCastVerifyErr {
2329    #[error("Result must be an integer")]
2330    ResultTypeErr,
2331    #[error("Operand must be an integer")]
2332    OperandTypeErr,
2333    #[error("Result type must be larger than operand type")]
2334    ResultTypeSmallerThanOperand,
2335    #[error("Result type must be smaller than operand type")]
2336    ResultTypeLargerThanOperand,
2337    #[error("Result type must be equal to operand type")]
2338    ResultTypeEqualToOperand,
2339}
2340
2341/// Ensure that the integer cast operation is valid.
2342/// This checks that the result type is an integer and that the operand type is also an integer.
2343/// It also checks that the result type is larger or smaller than the operand type (`cmp` operand).
2344fn integer_cast_verify(op: &Operation, ctx: &Context, cmp: ICmpPredicateAttr) -> Result<()> {
2345    use pliron::r#type::Typed;
2346
2347    let loc = op.loc();
2348    let mut res_ty = op.get_type(0).deref(ctx);
2349    let mut opd_ty = op.get_operand(0).get_type(ctx).deref(ctx);
2350
2351    if let Some(vec_res_ty) = res_ty.downcast_ref::<VectorType>() {
2352        res_ty = vec_res_ty.elem_type().deref(ctx);
2353    }
2354    if let Some(vec_opd_ty) = opd_ty.downcast_ref::<VectorType>() {
2355        opd_ty = vec_opd_ty.elem_type().deref(ctx);
2356    }
2357
2358    let Some(res_ty) = res_ty.downcast_ref::<IntegerType>() else {
2359        return verify_err!(loc, IntCastVerifyErr::ResultTypeErr);
2360    };
2361    let Some(opd_ty) = opd_ty.downcast_ref::<IntegerType>() else {
2362        return verify_err!(loc, IntCastVerifyErr::OperandTypeErr);
2363    };
2364
2365    match cmp {
2366        ICmpPredicateAttr::SLT | ICmpPredicateAttr::ULT => {
2367            if res_ty.width() >= opd_ty.width() {
2368                return verify_err!(loc, IntCastVerifyErr::ResultTypeLargerThanOperand);
2369            }
2370        }
2371        ICmpPredicateAttr::SGT | ICmpPredicateAttr::UGT => {
2372            if res_ty.width() <= opd_ty.width() {
2373                return verify_err!(loc, IntCastVerifyErr::ResultTypeSmallerThanOperand);
2374            }
2375        }
2376        ICmpPredicateAttr::SLE | ICmpPredicateAttr::ULE => {
2377            if res_ty.width() > opd_ty.width() {
2378                return verify_err!(loc, IntCastVerifyErr::ResultTypeLargerThanOperand);
2379            }
2380        }
2381        ICmpPredicateAttr::SGE | ICmpPredicateAttr::UGE => {
2382            if res_ty.width() < opd_ty.width() {
2383                return verify_err!(loc, IntCastVerifyErr::ResultTypeSmallerThanOperand);
2384            }
2385        }
2386        ICmpPredicateAttr::EQ | ICmpPredicateAttr::NE => {
2387            if res_ty.width() != opd_ty.width() {
2388                return verify_err!(loc, IntCastVerifyErr::ResultTypeEqualToOperand);
2389            }
2390        }
2391    }
2392    Ok(())
2393}
2394
2395/// Equivalent to LLVM's sext opcode.
2396/// ### Operands
2397/// | operand | description |
2398/// |-----|-------|
2399/// | `arg` | Signless integer |
2400/// ### Result(s):
2401/// | result | description |
2402/// |-----|-------|
2403/// | `res` | Signless integer |
2404#[pliron_op(
2405    name = "llvm.sext",
2406    format = "$0 ` to ` type($0)",
2407    interfaces = [CastOpInterface, OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>]
2408)]
2409pub struct SExtOp;
2410impl Verify for SExtOp {
2411    fn verify(&self, ctx: &Context) -> Result<()> {
2412        integer_cast_verify(
2413            &self.get_operation().deref(ctx),
2414            ctx,
2415            ICmpPredicateAttr::SGT,
2416        )
2417    }
2418}
2419
2420/// Equivalent to LLVM's zext opcode.
2421/// ### Operands
2422/// | operand | description |
2423/// |-----|-------|
2424/// | `arg` | Signless integer |
2425/// ### Result(s):
2426/// | result | description |
2427/// |-----|-------|
2428/// | `res` | Signless integer |
2429#[pliron_op(
2430    name = "llvm.zext",
2431    format = "`<nneg=` attr($llvm_nneg_flag, `pliron::builtin::attributes::BoolAttr`) `> ` $0 ` to ` type($0)",
2432    interfaces = [
2433        CastOpInterface,
2434        OneResultInterface,
2435        OneOpdInterface,
2436        NNegFlag,
2437        CastOpWithNNegInterface,
2438        NResultsInterface<1>,
2439        NOpdsInterface<1>
2440    ]
2441)]
2442pub struct ZExtOp;
2443
2444impl Verify for ZExtOp {
2445    fn verify(&self, ctx: &Context) -> Result<()> {
2446        integer_cast_verify(
2447            &self.get_operation().deref(ctx),
2448            ctx,
2449            ICmpPredicateAttr::UGT,
2450        )
2451    }
2452}
2453
2454/// Equivalent to LLVM's FPExt opcode.
2455///
2456/// ### Operands
2457/// | operand | description |
2458/// |-----|-------|
2459/// | `arg` | Floating-point number |
2460///
2461/// ### Result(s):
2462/// | result | description |
2463/// |-----|-------|
2464/// | `res` | Floating-point number |
2465#[pliron_op(
2466    name = "llvm.fpext",
2467    format = "attr($llvm_fast_math_flags, $FastmathFlagsAttr) ` ` $0 ` to ` type($0)",
2468    interfaces = [CastOpInterface, OneResultInterface, OneOpdInterface, FastMathFlags, NResultsInterface<1>, NOpdsInterface<1>]
2469)]
2470pub struct FPExtOp;
2471
2472impl Verify for FPExtOp {
2473    fn verify(&self, ctx: &Context) -> Result<()> {
2474        // Check operand type to be a float
2475        let opd_ty = OneOpdInterface::operand_type(self, ctx).deref(ctx);
2476        let Some(opd_float_ty) = type_cast::<dyn FloatTypeInterface>(&**opd_ty) else {
2477            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2478        };
2479        let res_ty = OneResultInterface::result_type(self, ctx).deref(ctx);
2480        let Some(res_float_ty) = type_cast::<dyn FloatTypeInterface>(&**res_ty) else {
2481            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2482        };
2483
2484        let opd_size = opd_float_ty.get_semantics().bits;
2485        let res_size = res_float_ty.get_semantics().bits;
2486        if res_size <= opd_size {
2487            return verify_err!(
2488                self.loc(ctx),
2489                FloatCastVerifyErr::ResultTypeSmallerThanOperand
2490            );
2491        }
2492        Ok(())
2493    }
2494}
2495
/// Verification errors shared by the cast ops that involve floating-point types
/// (and reported by [cast_element_types] for vector shape mismatches).
#[derive(Error, Debug)]
pub enum FloatCastVerifyErr {
    /// The operand (element) type is not what the cast requires.
    #[error("Incorrect operand type")]
    OperandTypeErr,
    /// The result (element) type is not what the cast requires.
    #[error("Incorrect result type")]
    ResultTypeErr,
    /// One side is a vector and the other is not, or the vector shapes differ.
    #[error("Operand and result must both be scalars or vectors with matching shape")]
    MismatchedVectorShape,
    /// An extension whose result is not wider than its operand.
    #[error("Result type must be bigger than the operand type")]
    ResultTypeSmallerThanOperand,
    /// A truncation whose operand is not wider than its result.
    #[error("Operand type must be bigger than the result type")]
    OperandTypeSmallerThanResult,
}
2509
2510/// Equivalent to LLVM's trunc opcode.
2511/// ### Operands
2512/// | operand | description |
2513/// |-----|-------|
2514/// | `arg` | Signless integer |
2515/// ### Result(s):
2516/// | result | description |
2517/// |-----|-------|
2518/// | `res` | Signless integer |
2519#[pliron_op(
2520    name = "llvm.trunc",
2521    format = "$0 ` to ` type($0)",
2522    interfaces = [CastOpInterface, OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>]
2523)]
2524pub struct TruncOp;
2525
2526impl Verify for TruncOp {
2527    fn verify(&self, ctx: &Context) -> Result<()> {
2528        integer_cast_verify(
2529            &self.get_operation().deref(ctx),
2530            ctx,
2531            ICmpPredicateAttr::ULT,
2532        )
2533    }
2534}
2535
2536/// Equivalent to LLVM's FPTrunc opcode.
2537/// ### Operands
2538/// | operand | description |
2539/// |-----|-------|
2540/// | `arg` | Floating-point number |
2541/// ### Result(s):
2542/// | result | description |
2543/// |-----|-------|
2544/// | `res` | Floating-point number |
2545#[pliron_op(
2546    name = "llvm.fptrunc",
2547    format = "attr($llvm_fast_math_flags, $FastmathFlagsAttr) ` ` $0 ` to ` type($0)",
2548    interfaces = [CastOpInterface, OneResultInterface, OneOpdInterface, FastMathFlags, NResultsInterface<1>, NOpdsInterface<1>]
2549)]
2550pub struct FPTruncOp;
2551
2552impl Verify for FPTruncOp {
2553    fn verify(&self, ctx: &Context) -> Result<()> {
2554        // Check operand type to be a float
2555        let opd_ty = OneOpdInterface::operand_type(self, ctx).deref(ctx);
2556        let Some(opd_float_ty) = type_cast::<dyn FloatTypeInterface>(&**opd_ty) else {
2557            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2558        };
2559        let res_ty = OneResultInterface::result_type(self, ctx).deref(ctx);
2560        let Some(res_float_ty) = type_cast::<dyn FloatTypeInterface>(&**res_ty) else {
2561            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2562        };
2563
2564        let opd_size = opd_float_ty.get_semantics().bits;
2565        let res_size = res_float_ty.get_semantics().bits;
2566        if opd_size <= res_size {
2567            return verify_err!(
2568                self.loc(ctx),
2569                FloatCastVerifyErr::OperandTypeSmallerThanResult
2570            );
2571        }
2572        Ok(())
2573    }
2574}
2575
2576fn cast_element_types(
2577    opd_ty: Ptr<TypeObj>,
2578    res_ty: Ptr<TypeObj>,
2579    ctx: &Context,
2580    loc: Location,
2581) -> Result<(Ptr<TypeObj>, Ptr<TypeObj>)> {
2582    let mut opd_elem_ty = opd_ty;
2583    let mut res_elem_ty = res_ty;
2584    let mut opd_vec_shape = None;
2585    let mut res_vec_shape = None;
2586
2587    if let Some(vec_ty) = opd_ty.deref(ctx).downcast_ref::<VectorType>() {
2588        opd_elem_ty = vec_ty.elem_type();
2589        opd_vec_shape = Some((vec_ty.num_elements(), vec_ty.kind()));
2590    }
2591    if let Some(vec_ty) = res_ty.deref(ctx).downcast_ref::<VectorType>() {
2592        res_elem_ty = vec_ty.elem_type();
2593        res_vec_shape = Some((vec_ty.num_elements(), vec_ty.kind()));
2594    }
2595
2596    if opd_vec_shape != res_vec_shape {
2597        return verify_err!(loc, FloatCastVerifyErr::MismatchedVectorShape);
2598    }
2599
2600    Ok((opd_elem_ty, res_elem_ty))
2601}
2602
2603/// Equivalent to LLVM's FPToSI opcode.
2604///
2605/// ### Operands
2606/// | operand | description |
2607/// |-----|-------|
2608/// | `arg` | Floating-point number |
2609///
2610/// ### Result(s):
2611/// | result | description |
2612/// |-----|-------|
2613/// | `res` | Signed integer |
2614#[pliron_op(
2615    name = "llvm.fptosi",
2616    format = "$0 ` to ` type($0)",
2617    interfaces = [CastOpInterface, OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>]
2618)]
2619pub struct FPToSIOp;
2620
2621impl Verify for FPToSIOp {
2622    fn verify(&self, ctx: &Context) -> Result<()> {
2623        // Check that the operand is a float and the result is an integer
2624        let (opd_ty, res_ty) = cast_element_types(
2625            OneOpdInterface::operand_type(self, ctx),
2626            OneResultInterface::result_type(self, ctx),
2627            ctx,
2628            self.loc(ctx),
2629        )?;
2630        let opd_ty = opd_ty.deref(ctx);
2631        if !type_impls::<dyn FloatTypeInterface>(&**opd_ty) {
2632            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2633        };
2634        let res_ty = res_ty.deref(ctx);
2635        let Some(res_int_ty) = res_ty.downcast_ref::<IntegerType>() else {
2636            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2637        };
2638        if !res_int_ty.is_signless() {
2639            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2640        }
2641        Ok(())
2642    }
2643}
2644
2645/// Equivalent to LLVM's FPToUI opcode.
2646///
2647/// ### Operands
2648/// | operand | description |
2649/// |-----|-------|
2650/// | `arg` | Floating-point number |
2651///
2652/// ### Result(s):
2653/// | result | description |
2654/// |-----|-------|
2655/// | `res` | Unsigned integer |
2656#[pliron_op(
2657    name = "llvm.fptoui",
2658    format = "$0 ` to ` type($0)",
2659    interfaces = [CastOpInterface, OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>]
2660)]
2661pub struct FPToUIOp;
2662
2663impl Verify for FPToUIOp {
2664    fn verify(&self, ctx: &Context) -> Result<()> {
2665        // Check that the operand is a float and the result is an integer
2666        let (opd_ty, res_ty) = cast_element_types(
2667            OneOpdInterface::operand_type(self, ctx),
2668            OneResultInterface::result_type(self, ctx),
2669            ctx,
2670            self.loc(ctx),
2671        )?;
2672        let opd_ty = opd_ty.deref(ctx);
2673        if !type_impls::<dyn FloatTypeInterface>(&**opd_ty) {
2674            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2675        };
2676        let res_ty = res_ty.deref(ctx);
2677        let Some(res_int_ty) = res_ty.downcast_ref::<IntegerType>() else {
2678            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2679        };
2680        if !res_int_ty.is_signless() {
2681            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2682        }
2683        Ok(())
2684    }
2685}
2686
2687/// Equivalent to LLVM's SIToFP opcode.
2688///
2689/// ### Operands
2690/// | operand | description |
2691/// |-----|-------|
2692/// | `arg` | Signed integer |
2693///
2694/// ### Result(s):
2695/// | result | description |
2696/// |-----|-------|
2697/// | `res` | Floating-point number |
2698#[pliron_op(
2699    name = "llvm.sitofp",
2700    format = "$0 ` to ` type($0)",
2701    interfaces = [CastOpInterface, OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>]
2702)]
2703pub struct SIToFPOp;
2704
2705impl Verify for SIToFPOp {
2706    fn verify(&self, ctx: &Context) -> Result<()> {
2707        // Check that the operand is an integer and the result is a float
2708        let (opd_ty, res_ty) = cast_element_types(
2709            OneOpdInterface::operand_type(self, ctx),
2710            OneResultInterface::result_type(self, ctx),
2711            ctx,
2712            self.loc(ctx),
2713        )?;
2714        let opd_ty = opd_ty.deref(ctx);
2715        let Some(opd_ty_int) = opd_ty.downcast_ref::<IntegerType>() else {
2716            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2717        };
2718        if !opd_ty_int.is_signless() {
2719            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2720        }
2721        let res_ty = res_ty.deref(ctx);
2722        if !type_impls::<dyn FloatTypeInterface>(&**res_ty) {
2723            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2724        }
2725        Ok(())
2726    }
2727}
2728
/// Equivalent to LLVM's UIToFP opcode.
///
/// Converts an integer, whose bits are interpreted as an unsigned value, to a
/// floating-point value. After reducing operand and result types with
/// `cast_element_types`, the verifier requires a signless [IntegerType]
/// operand and a result type implementing [FloatTypeInterface].
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `arg` | Signless integer, interpreted as unsigned |
///
/// ### Attributes
/// | attribute | description |
/// |-----|-------|
/// | `llvm_nneg_flag` | `BoolAttr`, printed as `<nneg=...>` before the operand (LLVM's `nneg` flag) |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | Floating-point number |
#[pliron_op(
    name = "llvm.uitofp",
    format = "`<nneg=` attr($llvm_nneg_flag, `pliron::builtin::attributes::BoolAttr`) `> `$0 ` to ` type($0)",
    interfaces = [
        CastOpInterface,
        OneResultInterface,
        OneOpdInterface,
        CastOpWithNNegInterface,
        NNegFlag,
        NResultsInterface<1>,
        NOpdsInterface<1>
    ]
)]
pub struct UIToFPOp;
2754
2755impl Verify for UIToFPOp {
2756    fn verify(&self, ctx: &Context) -> Result<()> {
2757        // Check that the operand is an integer and the result is a float
2758        let (opd_ty, res_ty) = cast_element_types(
2759            OneOpdInterface::operand_type(self, ctx),
2760            OneResultInterface::result_type(self, ctx),
2761            ctx,
2762            self.loc(ctx),
2763        )?;
2764        let opd_ty = opd_ty.deref(ctx);
2765        let Some(opd_ty_int) = opd_ty.downcast_ref::<IntegerType>() else {
2766            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2767        };
2768        if !opd_ty_int.is_signless() {
2769            return verify_err!(self.loc(ctx), FloatCastVerifyErr::OperandTypeErr);
2770        }
2771        let res_ty = res_ty.deref(ctx);
2772        if !type_impls::<dyn FloatTypeInterface>(&**res_ty) {
2773            return verify_err!(self.loc(ctx), FloatCastVerifyErr::ResultTypeErr);
2774        }
2775        Ok(())
2776    }
2777}
2778
/// Equivalent to LLVM's InsertValue opcode.
///
/// Inserts `value` into `aggregate` at the position described by the
/// `insert_value_indices` attribute, producing a new aggregate.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `aggregate` | LLVM aggregate type |
/// | `value` | LLVM type; must equal the type at the indexed position of `aggregate` |
///
/// ### Attributes
/// | attribute | description |
/// |-----|-------|
/// | `insert_value_indices` | [InsertExtractValueIndicesAttr]: path of indices into `aggregate` |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | LLVM aggregate type ([InsertValueOp::new] uses the type of `aggregate`) |
#[pliron_op(
    name = "llvm.insert_value",
    format = "$0 attr($insert_value_indices, $InsertExtractValueIndicesAttr) `, ` $1 ` : ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>, NOpdsInterface<2>],
    attributes = (insert_value_indices: InsertExtractValueIndicesAttr)
)]
pub struct InsertValueOp;
2798
2799impl InsertValueOp {
2800    /// Create a new [InsertValueOp].
2801    /// `aggregate` is the aggregate type and `value` is the value to insert.
2802    /// `indices` is the list of indices to insert the value at.
2803    /// The `indices` must be valid for the given `aggregate` type.
2804    pub fn new(ctx: &mut Context, aggregate: Value, value: Value, indices: Vec<u32>) -> Self {
2805        use pliron::r#type::Typed;
2806
2807        let result_type = aggregate.get_type(ctx);
2808        let op = Operation::new(
2809            ctx,
2810            Self::get_concrete_op_info(),
2811            vec![result_type],
2812            vec![aggregate, value],
2813            vec![],
2814            0,
2815        );
2816        let op = InsertValueOp { op };
2817        op.set_attr_insert_value_indices(ctx, InsertExtractValueIndicesAttr(indices));
2818        op
2819    }
2820
2821    /// Get the indices for inserting value into aggregate.
2822    pub fn indices(&self, ctx: &Context) -> Vec<u32> {
2823        self.get_attr_insert_value_indices(ctx).unwrap().clone().0
2824    }
2825}
2826
2827impl Verify for InsertValueOp {
2828    fn verify(&self, ctx: &Context) -> Result<()> {
2829        let loc = self.loc(ctx);
2830        // Ensure that we have the indices as an attribute.
2831        if self.get_attr_insert_value_indices(ctx).is_none() {
2832            verify_err!(loc.clone(), InsertExtractValueErr::IndicesAttrErr)?
2833        }
2834
2835        use pliron::r#type::Typed;
2836
2837        // Check that the value we are inserting is of the correct type.
2838        let aggr_type = self.get_operation().deref(ctx).get_operand(0).get_type(ctx);
2839        let indices = self.indices(ctx);
2840        match ExtractValueOp::indexed_type(ctx, aggr_type, &indices) {
2841            Err(e @ Error { .. }) => {
2842                // We reset the error type and error origin to be from here
2843                return Err(Error {
2844                    kind: ErrorKind::VerificationFailed,
2845                    backtrace: std::backtrace::Backtrace::capture(),
2846                    ..e
2847                });
2848            }
2849            Ok(indexed_type) => {
2850                if indexed_type != self.get_operation().deref(ctx).get_operand(1).get_type(ctx) {
2851                    return verify_err!(loc, InsertExtractValueErr::ValueTypeErr);
2852                }
2853            }
2854        }
2855
2856        Ok(())
2857    }
2858}
2859
/// Equivalent to LLVM's ExtractValue opcode.
///
/// Reads the element of `aggregate` at the position described by the
/// `extract_value_indices` attribute.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `aggregate` | LLVM aggregate type |
///
/// ### Attributes
/// | attribute | description |
/// |-----|-------|
/// | `extract_value_indices` | [InsertExtractValueIndicesAttr]: path of indices into `aggregate` |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | LLVM type; must equal the type at the indexed position of `aggregate` |
#[pliron_op(
    name = "llvm.extract_value",
    format = "$0 attr($extract_value_indices, $InsertExtractValueIndicesAttr) ` : ` type($0)",
    interfaces = [OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>],
    attributes = (extract_value_indices: InsertExtractValueIndicesAttr)
)]
pub struct ExtractValueOp;
2878
2879impl Verify for ExtractValueOp {
2880    fn verify(&self, ctx: &Context) -> Result<()> {
2881        let loc = self.loc(ctx);
2882        // Ensure that we have the indices as an attribute.
2883        if self.get_attr_extract_value_indices(ctx).is_none() {
2884            verify_err!(loc.clone(), InsertExtractValueErr::IndicesAttrErr)?
2885        }
2886
2887        use pliron::r#type::Typed;
2888        // Check that the result type matches the indexed type
2889        let aggr_type = self.get_operation().deref(ctx).get_operand(0).get_type(ctx);
2890        let indices = self.indices(ctx);
2891        match Self::indexed_type(ctx, aggr_type, &indices) {
2892            Err(e @ Error { .. }) => {
2893                // We reset the error type and error origin to be from here
2894                return Err(Error {
2895                    kind: ErrorKind::VerificationFailed,
2896                    backtrace: std::backtrace::Backtrace::capture(),
2897                    ..e
2898                });
2899            }
2900            Ok(indexed_type) => {
2901                if indexed_type != self.get_operation().deref(ctx).get_type(0) {
2902                    return verify_err!(loc, InsertExtractValueErr::ValueTypeErr);
2903                }
2904            }
2905        }
2906
2907        Ok(())
2908    }
2909}
2910
2911impl ExtractValueOp {
2912    /// Create a new [ExtractValueOp].
2913    /// `aggregate` is the aggregate type and `indices` is the list of indices to extract the value from.
2914    /// The `indices` must be valid for the given `aggregate` type.
2915    /// The result type of the operation is the type of the value at the given indices.
2916    pub fn new(ctx: &mut Context, aggregate: Value, indices: Vec<u32>) -> Result<Self> {
2917        use pliron::r#type::Typed;
2918        let result_type = Self::indexed_type(ctx, aggregate.get_type(ctx), &indices)?;
2919        let op = Operation::new(
2920            ctx,
2921            Self::get_concrete_op_info(),
2922            vec![result_type],
2923            vec![aggregate],
2924            vec![],
2925            0,
2926        );
2927        let op = ExtractValueOp { op };
2928        op.set_attr_extract_value_indices(ctx, InsertExtractValueIndicesAttr(indices));
2929        Ok(op)
2930    }
2931
2932    /// Get the indices for extracting value from aggregate.
2933    pub fn indices(&self, ctx: &Context) -> Vec<u32> {
2934        self.get_attr_extract_value_indices(ctx).unwrap().clone().0
2935    }
2936
2937    /// Returns the type of the value at the given indices in the given aggregate type.
2938    pub fn indexed_type(
2939        ctx: &Context,
2940        aggr_type: Ptr<TypeObj>,
2941        indices: &[u32],
2942    ) -> Result<Ptr<TypeObj>> {
2943        fn indexed_type_inner(
2944            ctx: &Context,
2945            aggr_type: Ptr<TypeObj>,
2946            mut idx_itr: impl Iterator<Item = u32>,
2947        ) -> Result<Ptr<TypeObj>> {
2948            let Some(idx) = idx_itr.next() else {
2949                return Ok(aggr_type);
2950            };
2951            let aggr_type = &*aggr_type.deref(ctx);
2952            if let Some(st) = aggr_type.downcast_ref::<StructType>() {
2953                if st.is_opaque() || idx as usize >= st.num_fields() {
2954                    return arg_err_noloc!(InsertExtractValueErr::InvalidIndicesErr);
2955                }
2956                indexed_type_inner(ctx, st.field_type(idx as usize), idx_itr)
2957            } else if let Some(at) = aggr_type.downcast_ref::<ArrayType>() {
2958                if idx as u64 >= at.size() {
2959                    return arg_err_noloc!(InsertExtractValueErr::InvalidIndicesErr);
2960                }
2961                indexed_type_inner(ctx, at.elem_type(), idx_itr)
2962            } else {
2963                arg_err_noloc!(InsertExtractValueErr::InvalidIndicesErr)
2964            }
2965        }
2966        indexed_type_inner(ctx, aggr_type, indices.iter().cloned())
2967    }
2968}
2969
/// Errors reported for [InsertValueOp] and [ExtractValueOp].
#[derive(Error, Debug)]
pub enum InsertExtractValueErr {
    /// The indices attribute is absent or of the wrong type.
    #[error("Insert/Extract value instruction has no or incorrect indices attribute")]
    IndicesAttrErr,
    /// An index is out of range, indexes into an opaque struct, or indexes
    /// into a type that is neither a struct nor an array
    /// (see [ExtractValueOp::indexed_type]).
    #[error("Invalid indices on insert/extract value instruction")]
    InvalidIndicesErr,
    /// The inserted value (or extracted result) type differs from the type at
    /// the indexed position.
    #[error("Value being inserted / extracted does not match the type of the indexed aggregate")]
    ValueTypeErr,
}
2979
/// Equivalent to LLVM's InsertElement opcode.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `vector` | LLVM vector type |
/// | `element` | LLVM type; must equal the element type of `vector` |
/// | `index` | integer |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | LLVM vector type ([InsertElementOp::new] uses the type of `vector`) |
#[pliron_op(
    name = "llvm.insert_element",
    format = "$0 `, ` $1 `, ` $2 ` : ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>, NOpdsInterface<3>]
)]
pub struct InsertElementOp;
2999impl Verify for InsertElementOp {
3000    fn verify(&self, ctx: &Context) -> Result<()> {
3001        use pliron::r#type::Typed;
3002
3003        let loc = self.loc(ctx);
3004        let op = &*self.op.deref(ctx);
3005        let vector_ty = op.get_operand(0).get_type(ctx);
3006        let element_ty = op.get_operand(1).get_type(ctx);
3007        let index_ty = op.get_operand(2).get_type(ctx);
3008
3009        let vector_ty = vector_ty.deref(ctx);
3010        let vector_ty = vector_ty.downcast_ref::<VectorType>();
3011        if vector_ty.is_none_or(|ty| ty.elem_type() != element_ty) {
3012            return verify_err!(loc, InsertExtractElementOpVerifyErr::ElementTypeErr);
3013        }
3014
3015        if !index_ty.deref(ctx).is::<IntegerType>() {
3016            return verify_err!(loc, InsertExtractElementOpVerifyErr::IndexTypeErr);
3017        }
3018
3019        Ok(())
3020    }
3021}
3022
3023impl InsertElementOp {
3024    /// Create a new [InsertElementOp].
3025    pub fn new(ctx: &mut Context, vector: Value, element: Value, index: Value) -> Self {
3026        use pliron::r#type::Typed;
3027
3028        let result_type = vector.get_type(ctx);
3029        let op = Operation::new(
3030            ctx,
3031            Self::get_concrete_op_info(),
3032            vec![result_type],
3033            vec![vector, element, index],
3034            vec![],
3035            0,
3036        );
3037        InsertElementOp { op }
3038    }
3039
3040    /// Get the vector type of the InsertElementOp.
3041    pub fn vector_type(&self, ctx: &Context) -> TypePtr<VectorType> {
3042        let ty = self.get_operation().deref(ctx).get_type(0);
3043        TypePtr::<VectorType>::from_ptr(ty, ctx)
3044            .expect("InsertElementOp result type is not a VectorType")
3045    }
3046
3047    /// Get the vector operand of the InsertElementOp.
3048    pub fn vector_operand(&self, ctx: &Context) -> Value {
3049        self.get_operation().deref(ctx).get_operand(0)
3050    }
3051
3052    /// Get the element operand of the InsertElementOp.
3053    pub fn element_operand(&self, ctx: &Context) -> Value {
3054        self.get_operation().deref(ctx).get_operand(1)
3055    }
3056
3057    /// Get the index operand of the InsertElementOp.
3058    pub fn index_operand(&self, ctx: &Context) -> Value {
3059        self.get_operation().deref(ctx).get_operand(2)
3060    }
3061}
3062
/// Verification errors for [InsertElementOp] and [ExtractElementOp].
#[derive(Error, Debug)]
pub enum InsertExtractElementOpVerifyErr {
    /// The vector operand is not a vector, or its element type differs from
    /// the inserted element / extracted result type.
    #[error("Element type must match vector element type")]
    ElementTypeErr,
    /// The index operand is not of an accepted integer type.
    #[error("Index type must be signless integer")]
    IndexTypeErr,
}
3070
/// ExtractElementOp
/// Equivalent to LLVM's ExtractElement opcode.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `vector` | LLVM vector type |
/// | `index` | integer |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | LLVM type; must equal the element type of `vector` |
#[pliron_op(
    name = "llvm.extract_element",
    format = "$0 `, ` $1 ` : ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>, NOpdsInterface<2>]
)]
pub struct ExtractElementOp;
3088
3089impl Verify for ExtractElementOp {
3090    fn verify(&self, ctx: &Context) -> Result<()> {
3091        use pliron::r#type::Typed;
3092        let loc = self.loc(ctx);
3093        let op = &*self.op.deref(ctx);
3094        let vector_ty = op.get_operand(0).get_type(ctx);
3095        let index_ty = op.get_operand(1).get_type(ctx);
3096        let vector_ty = vector_ty.deref(ctx);
3097        let vector_ty = vector_ty.downcast_ref::<VectorType>();
3098        if vector_ty.is_none_or(|ty| ty.elem_type() != op.get_type(0)) {
3099            return verify_err!(loc, InsertExtractElementOpVerifyErr::ElementTypeErr);
3100        }
3101        if !index_ty.deref(ctx).is::<IntegerType>() {
3102            return verify_err!(loc, InsertExtractElementOpVerifyErr::IndexTypeErr);
3103        }
3104        Ok(())
3105    }
3106}
3107
3108impl ExtractElementOp {
3109    /// Create a new [ExtractElementOp].
3110    pub fn new(ctx: &mut Context, vector: Value, index: Value) -> Self {
3111        use pliron::r#type::Typed;
3112
3113        let result_type = vector
3114            .get_type(ctx)
3115            .deref(ctx)
3116            .downcast_ref::<VectorType>()
3117            .expect("ExtractElementOp vector operand must be a vector type")
3118            .elem_type();
3119
3120        let op = Operation::new(
3121            ctx,
3122            Self::get_concrete_op_info(),
3123            vec![result_type],
3124            vec![vector, index],
3125            vec![],
3126            0,
3127        );
3128        ExtractElementOp { op }
3129    }
3130
3131    /// Get the vector type of the ExtractElementOp.
3132    pub fn vector_type(&self, ctx: &Context) -> TypePtr<VectorType> {
3133        use pliron::r#type::Typed;
3134        let ty = self.vector_operand(ctx).get_type(ctx);
3135        TypePtr::<VectorType>::from_ptr(ty, ctx)
3136            .expect("ExtractElementOp vector operand type is not a VectorType")
3137    }
3138
3139    /// Get the vector operand of the ExtractElementOp.
3140    pub fn vector_operand(&self, ctx: &Context) -> Value {
3141        self.get_operation().deref(ctx).get_operand(0)
3142    }
3143
3144    /// Get the index operand of the ExtractElementOp.
3145    pub fn index_operand(&self, ctx: &Context) -> Value {
3146        self.get_operation().deref(ctx).get_operand(1)
3147    }
3148}
3149
/// ShuffleVectorOp
/// Equivalent to LLVM's ShuffleVector opcode.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `vector1` | LLVM vector type |
/// | `vector2` | LLVM vector type, identical to that of `vector1` |
///
/// ### Attributes
/// | attribute | description |
/// |-----|-------|
/// | `llvm_shuffle_vector_mask` | [ShuffleVectorMaskAttr]: the shuffle mask; see [SHUFFLE_VECTOR_UNDEF_MASK_ELEM] for the undef element |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | LLVM vector type with the operands' element type and one element per mask entry |
#[pliron_op(
    name = "llvm.shuffle_vector",
    format = "$0 `, ` $1 `, ` attr($llvm_shuffle_vector_mask, $ShuffleVectorMaskAttr) ` : ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>, NOpdsInterface<2>],
    attributes = (llvm_shuffle_vector_mask: ShuffleVectorMaskAttr)
)]
pub struct ShuffleVectorOp;
3171impl Verify for ShuffleVectorOp {
3172    fn verify(&self, ctx: &Context) -> Result<()> {
3173        use pliron::r#type::Typed;
3174
3175        let loc = self.loc(ctx);
3176        let op = &*self.op.deref(ctx);
3177        let vector1_ty = op.get_operand(0).get_type(ctx);
3178        let vector2_ty = op.get_operand(1).get_type(ctx);
3179
3180        let vector1_ty = vector1_ty.deref(ctx);
3181        let vector1_ty = vector1_ty.downcast_ref::<VectorType>();
3182        let vector2_ty = vector2_ty.deref(ctx);
3183        let vector2_ty = vector2_ty.downcast_ref::<VectorType>();
3184
3185        let (Some(v1_ty), Some(v2_ty)) = (vector1_ty, vector2_ty) else {
3186            return verify_err!(loc, ShuffleVectorOpVerifyErr::OperandsTypeErr);
3187        };
3188
3189        if v1_ty != v2_ty {
3190            return verify_err!(loc, ShuffleVectorOpVerifyErr::OperandsTypeErr);
3191        }
3192
3193        let res_ty = op.get_type(0).deref(ctx);
3194        let res_ty = res_ty.downcast_ref::<VectorType>();
3195        let Some(res_ty) = res_ty else {
3196            return verify_err!(loc, ShuffleVectorOpVerifyErr::ResultTypeErr);
3197        };
3198
3199        if res_ty.elem_type() != v1_ty.elem_type()
3200            || res_ty.num_elements() as usize
3201                != self.get_attr_llvm_shuffle_vector_mask(ctx).unwrap().0.len()
3202        {
3203            return verify_err!(loc, ShuffleVectorOpVerifyErr::ResultTypeErr);
3204        }
3205
3206        Ok(())
3207    }
3208}
3209
3210/// The undef mask element used in ShuffleVectorOp masks.
3211pub static SHUFFLE_VECTOR_UNDEF_MASK_ELEM: LazyLock<i32> = LazyLock::new(llvm_get_undef_mask_elem);
3212
3213impl ShuffleVectorOp {
3214    /// Create a new [ShuffleVectorOp].
3215    pub fn new(ctx: &mut Context, vector1: Value, vector2: Value, mask: Vec<i32>) -> Self {
3216        use pliron::r#type::Typed;
3217
3218        let (elem_ty, kind) = {
3219            let vector1_ty = vector1.get_type(ctx).deref(ctx);
3220            let opd_vec_ty = vector1_ty
3221                .downcast_ref::<VectorType>()
3222                .expect("ShuffleVectorOp vector1 operand must be a vector type");
3223            (opd_vec_ty.elem_type(), opd_vec_ty.kind())
3224        };
3225
3226        let result_type = VectorType::get(
3227            ctx,
3228            elem_ty,
3229            mask.len()
3230                .try_into()
3231                .expect("ShuffleVectorOp mask length too large"),
3232            kind,
3233        );
3234        let op = Operation::new(
3235            ctx,
3236            Self::get_concrete_op_info(),
3237            vec![result_type.into()],
3238            vec![vector1, vector2],
3239            vec![],
3240            0,
3241        );
3242
3243        let mask_attr = ShuffleVectorMaskAttr(mask);
3244        let op = ShuffleVectorOp { op };
3245        op.set_attr_llvm_shuffle_vector_mask(ctx, mask_attr);
3246        op
3247    }
3248}
3249
3250#[derive(Error, Debug)]
3251pub enum ShuffleVectorOpVerifyErr {
3252    #[error("Both operands must be equivalent vector types")]
3253    OperandsTypeErr,
3254    #[error("Result type must be a vector type with correct element type and size")]
3255    ResultTypeErr,
3256}
3257
/// Equivalent to LLVM's Select opcode.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `condition` | i1, or a vector of i1 whose length equals that of the vector operands |
/// | `true_value` | any type |
/// | `false_value` | same type as `true_value` |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | same type as `true_value` and `false_value` |
#[pliron_op(
    name = "llvm.select",
    format = "$0 ` ? ` $1 ` : ` $2 ` : ` type($0)",
    interfaces = [OneResultInterface, NResultsInterface<1>, NOpdsInterface<3>]
)]
pub struct SelectOp;
3277
3278impl SelectOp {
3279    /// Create a new [SelectOp].
3280    pub fn new(ctx: &mut Context, cond: Value, true_val: Value, false_val: Value) -> Self {
3281        use pliron::r#type::Typed;
3282
3283        let result_type = true_val.get_type(ctx);
3284        let op = Operation::new(
3285            ctx,
3286            Self::get_concrete_op_info(),
3287            vec![result_type],
3288            vec![cond, true_val, false_val],
3289            vec![],
3290            0,
3291        );
3292        SelectOp { op }
3293    }
3294}
3295
3296impl Verify for SelectOp {
3297    fn verify(&self, ctx: &Context) -> Result<()> {
3298        use pliron::r#type::Typed;
3299
3300        let loc = self.loc(ctx);
3301        let op = &*self.op.deref(ctx);
3302        let ty = op.get_type(0);
3303        let cond_ty = op.get_operand(0).get_type(ctx);
3304        let true_ty = op.get_operand(1).get_type(ctx);
3305        let false_ty = op.get_operand(2).get_type(ctx);
3306        if ty != true_ty || ty != false_ty {
3307            return verify_err!(loc, SelectOpVerifyErr::ResultTypeErr);
3308        }
3309
3310        let mut cond_ty = cond_ty.deref(ctx);
3311        if let Some(vec_ty) = cond_ty.downcast_ref::<VectorType>() {
3312            if let Some(opd_vec_ty) = ty.deref(ctx).downcast_ref::<VectorType>()
3313                && vec_ty.num_elements() == opd_vec_ty.num_elements()
3314            {
3315                // We're good, both the condition and operand are vectors of the same length
3316            } else {
3317                return verify_err!(loc, SelectOpVerifyErr::ConditionTypeErr);
3318            }
3319            cond_ty = vec_ty.elem_type().deref(ctx);
3320        }
3321
3322        let cond_ty = cond_ty.downcast_ref::<IntegerType>();
3323        if cond_ty.is_none_or(|ty| ty.width() != 1) {
3324            return verify_err!(loc, SelectOpVerifyErr::ConditionTypeErr);
3325        }
3326        Ok(())
3327    }
3328}
3329
/// Verification errors for [SelectOp].
#[derive(Error, Debug)]
pub enum SelectOpVerifyErr {
    /// The result, true-value and false-value types are not all identical.
    #[error("Result must be the same as the true and false destination types")]
    ResultTypeErr,
    /// The condition is neither `i1` nor a vector of `i1` with the operands' length.
    #[error("Condition must be an i1 or a vector of i1 equal in length to the operand vectors")]
    ConditionTypeErr,
}
3337
/// Floating-point negation
/// Equivalent to LLVM's `fneg` instruction.
///
/// ### Operands
/// | operand | description |
/// |-----|-------|
/// | `arg` | float |
///
/// ### Attributes
/// | attribute | description |
/// |-----|-------|
/// | `llvm_fast_math_flags` | [FastmathFlagsAttr] |
///
/// ### Result(s):
/// | result | description |
/// |-----|-------|
/// | `res` | float, same type as `arg` |
#[pliron_op(
    name = "llvm.fneg",
    format = "attr($llvm_fast_math_flags, $FastmathFlagsAttr) $0 ` : ` type($0)",
    interfaces = [
        OneResultInterface,
        OneOpdInterface,
        SameResultsType,
        SameOperandsType,
        SameOperandsAndResultType,
        FastMathFlags,
        NResultsInterface<1>,
        NOpdsInterface<1>,
        AtLeastNOpdsInterface<1>,
        AtLeastNResultsInterface<1>
    ]
)]
pub struct FNegOp;
3367
3368impl Verify for FNegOp {
3369    fn verify(&self, ctx: &Context) -> Result<()> {
3370        use pliron::r#type::Typed;
3371
3372        let loc = self.loc(ctx);
3373        let op = &*self.op.deref(ctx);
3374        let arg_ty = op.get_operand(0).get_type(ctx);
3375        if !type_impls::<dyn FloatTypeInterface>(&**arg_ty.deref(ctx)) {
3376            return verify_err!(loc, FNegOpVerifyErr::ArgumentMustBeFloat);
3377        }
3378        Ok(())
3379    }
3380}
3381
3382impl FNegOp {
3383    /// Create a new [FNegOp].
3384    pub fn new_with_fast_math_flags(
3385        ctx: &mut Context,
3386        arg: Value,
3387        fast_math_flags: FastmathFlagsAttr,
3388    ) -> Self {
3389        use pliron::r#type::Typed;
3390        let op = Operation::new(
3391            ctx,
3392            Self::get_concrete_op_info(),
3393            vec![arg.get_type(ctx)],
3394            vec![arg],
3395            vec![],
3396            0,
3397        );
3398        let op = FNegOp { op };
3399        op.set_fast_math_flags(ctx, fast_math_flags);
3400        op
3401    }
3402}
3403
/// Verification errors for [FNegOp].
#[derive(Error, Debug)]
pub enum FNegOpVerifyErr {
    /// The operand is not of an accepted floating-point type.
    #[error("Argument must be a float")]
    ArgumentMustBeFloat,
    /// NOTE(review): not returned by `FNegOp`'s `Verify` impl as written;
    /// confirm whether it is produced anywhere else before relying on it.
    #[error("Fast math flags must be set")]
    FastMathFlagsMustBeSet,
}
3411
// Declares a binary floating-point arithmetic op struct named `$op_name` with
// op id `$op_id`. Outer attributes given before the name (typically doc
// comments) are attached to the generated struct, followed by a shared
// operand/result table. All generated ops share one textual format
// (fast-math flags, two operands, type) and one interface list. They pass
// `verifier = "succ"` to `pliron_op` instead of having a hand-written `Verify`
// impl (presumably leaving the checks to the listed interfaces; confirm
// against the `pliron_op` documentation).
macro_rules! new_float_bin_op {
    (   $(#[$outer:meta])*
        $op_name:ident, $op_id:literal
    ) => {
        $(#[$outer])*
        /// ### Operands:
        ///
        /// | operand | description |
        /// |-----|-------|
        /// | `lhs` | float |
        /// | `rhs` | float |
        ///
        /// ### Result(s):
        ///
        /// | result | description |
        /// |-----|-------|
        /// | `res` | float |
        #[pliron_op(
            name = $op_id,
            format = "attr($llvm_fast_math_flags, $FastmathFlagsAttr) ` ` $0 `, ` $1 ` : ` type($0)",
            interfaces = [
                OneResultInterface, SameOperandsType, SameResultsType,
                AtLeastNOpdsInterface<1>, AtLeastNResultsInterface<1>,
                SameOperandsAndResultType, BinArithOp, FloatBinArithOp,
                FloatBinArithOpWithFastMathFlags, FastMathFlags, NResultsInterface<1>, NOpdsInterface<2>
            ],
            verifier = "succ"
        )]
        pub struct $op_name;
    }
}
3443
// Concrete binary floating-point arithmetic ops, one per LLVM instruction.
// The doc comment in each invocation becomes the op's leading documentation.
new_float_bin_op! {
    /// Equivalent to LLVM's `fadd` instruction.
    FAddOp,
    "llvm.fadd"
}

new_float_bin_op! {
    /// Equivalent to LLVM's `fsub` instruction.
    FSubOp,
    "llvm.fsub"
}

new_float_bin_op! {
    /// Equivalent to LLVM's `fmul` instruction.
    FMulOp,
    "llvm.fmul"
}

new_float_bin_op! {
    /// Equivalent to LLVM's `fdiv` instruction.
    FDivOp,
    "llvm.fdiv"
}

new_float_bin_op! {
    /// Equivalent to LLVM's `frem` instruction.
    FRemOp,
    "llvm.frem"
}
3473
/// Equivalent to LLVM's `fcmp` instruction.
///
/// ### Operand(s):
/// | operand | description |
/// |-----|-------|
/// | `lhs` | float |
/// | `rhs` | float, same type as `lhs` |
///
/// ### Attributes
/// | attribute | description |
/// |-----|-------|
/// | `fcmp_predicate` | [FCmpPredicateAttr]: the comparison predicate |
/// | `llvm_fast_math_flags` | [FastmathFlagsAttr] |
///
/// ### Result(s):
///
/// | result | description |
/// |-----|-------|
/// | `res` | 1-bit signless integer |
#[pliron_op(
    name = "llvm.fcmp",
    format = "attr($llvm_fast_math_flags, $FastmathFlagsAttr) ` ` $0 ` <` attr($fcmp_predicate, $FCmpPredicateAttr) `> ` $1 ` : ` type($0)",
    interfaces = [
        OneResultInterface,
        SameOperandsType,
        AtLeastNOpdsInterface<1>,
        FastMathFlags,
        NResultsInterface<1>,
        NOpdsInterface<2>
    ],
    attributes = (fcmp_predicate: FCmpPredicateAttr)
)]
pub struct FCmpOp;
3501
3502impl FCmpOp {
3503    /// Create a new [FCmpOp]
3504    pub fn new(ctx: &mut Context, pred: FCmpPredicateAttr, lhs: Value, rhs: Value) -> Self {
3505        let bool_ty = IntegerType::get(ctx, 1, Signedness::Signless);
3506        let op = Operation::new(
3507            ctx,
3508            Self::get_concrete_op_info(),
3509            vec![bool_ty.into()],
3510            vec![lhs, rhs],
3511            vec![],
3512            0,
3513        );
3514        let op = FCmpOp { op };
3515        op.set_attr_fcmp_predicate(ctx, pred);
3516        op
3517    }
3518
3519    /// Get the predicate
3520    pub fn predicate(&self, ctx: &Context) -> FCmpPredicateAttr {
3521        self.get_attr_fcmp_predicate(ctx)
3522            .expect("FCmpOp missing or incorrect predicate attribute type")
3523            .clone()
3524    }
3525}
3526
3527impl Verify for FCmpOp {
3528    fn verify(&self, ctx: &Context) -> Result<()> {
3529        let loc = self.loc(ctx);
3530
3531        if self.get_attr_fcmp_predicate(ctx).is_none() {
3532            verify_err!(loc.clone(), FCmpOpVerifyErr::PredAttrErr)?
3533        }
3534
3535        let res_ty: TypePtr<IntegerType> =
3536            TypePtr::from_ptr(self.result_type(ctx), ctx).map_err(|mut err| {
3537                err.set_loc(loc.clone());
3538                err
3539            })?;
3540
3541        if res_ty.deref(ctx).width() != 1 {
3542            return verify_err!(loc, FCmpOpVerifyErr::ResultNotBool);
3543        }
3544
3545        let opd_ty = self.operand_type(ctx).deref(ctx);
3546        if !(type_impls::<dyn FloatTypeInterface>(&**opd_ty)) {
3547            return verify_err!(loc, FCmpOpVerifyErr::IncorrectOperandsType);
3548        }
3549
3550        Ok(())
3551    }
3552}
3553
/// Verification errors for [FCmpOp].
#[derive(Error, Debug)]
pub enum FCmpOpVerifyErr {
    /// The result type is not a 1-bit integer.
    #[error("Result must be 1-bit integer (bool)")]
    ResultNotBool,
    /// The operand type does not implement [FloatTypeInterface].
    #[error("Operand must be floating point type")]
    IncorrectOperandsType,
    /// The `fcmp_predicate` attribute is absent or of the wrong type.
    #[error("Missing or incorrect predicate attribute")]
    PredAttrErr,
}
3563
/// All LLVM intrinsic calls are represented by this [Op].
/// Same as MLIR's [llvm.call_intrinsic](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmcall_intrinsic-llvmcallintrinsicop).
///
/// The operands are the call arguments. The single result has the result
/// type of the intrinsic's function type.
///
/// ### Attributes
/// | attribute | description |
/// |-----|-------|
/// | `llvm_intrinsic_name` | [StringAttr]: name of the intrinsic being called |
/// | `llvm_intrinsic_type` | [TypeAttr]: function type of the intrinsic |
/// | `llvm_intrinsic_fastmath_flags` | [FastmathFlagsAttr]: optional; not set by [CallIntrinsicOp::new] |
#[pliron_op(
    name = "llvm.call_intrinsic",
    interfaces = [OneResultInterface, NResultsInterface<1>],
    attributes = (
        llvm_intrinsic_name: StringAttr,
        llvm_intrinsic_type: TypeAttr,
        llvm_intrinsic_fastmath_flags: FastmathFlagsAttr
    )
)]
pub struct CallIntrinsicOp;
3576
3577impl CallIntrinsicOp {
3578    /// Create a new [CallIntrinsicOp].
3579    pub fn new(
3580        ctx: &mut Context,
3581        intrinsic_name: StringAttr,
3582        intrinsic_type: TypePtr<FuncType>,
3583        operands: Vec<Value>,
3584    ) -> Self {
3585        let res_ty = intrinsic_type.deref(ctx).result_type();
3586        let op = Operation::new(
3587            ctx,
3588            Self::get_concrete_op_info(),
3589            vec![res_ty],
3590            operands,
3591            vec![],
3592            0,
3593        );
3594        let op = CallIntrinsicOp { op };
3595        op.set_attr_llvm_intrinsic_name(ctx, intrinsic_name);
3596        op.set_attr_llvm_intrinsic_type(ctx, TypeAttr::new(intrinsic_type.into()));
3597        op
3598    }
3599}
3600
3601impl Printable for CallIntrinsicOp {
3602    fn fmt(
3603        &self,
3604        ctx: &Context,
3605        _state: &printable::State,
3606        f: &mut core::fmt::Formatter<'_>,
3607    ) -> core::fmt::Result {
3608        // [result = ] llvm.call_intrinsic @name <FastMathFlags> (operands) : type
3609        if let Some(res) = self.op.deref(ctx).results().next() {
3610            write!(f, "{} = ", res.disp(ctx))?;
3611        }
3612
3613        write!(
3614            f,
3615            "{} @{} ",
3616            Self::get_opid_static(),
3617            self.get_attr_llvm_intrinsic_name(ctx)
3618                .expect("CallIntrinsicOp missing or incorrect intrinsic name attribute")
3619                .disp(ctx),
3620        )?;
3621
3622        if let Some(fmf) = self.get_attr_llvm_intrinsic_fastmath_flags(ctx)
3623            && *fmf != FastmathFlagsAttr::default()
3624        {
3625            write!(f, " {} ", fmf.disp(ctx))?;
3626        }
3627
3628        write!(
3629            f,
3630            "({}) : {}",
3631            iter_with_sep(
3632                self.op.deref(ctx).operands(),
3633                printable::ListSeparator::CharSpace(',')
3634            )
3635            .disp(ctx),
3636            self.get_attr_llvm_intrinsic_type(ctx)
3637                .expect("CallIntrinsicOp missing or incorrect intrinsic type attribute")
3638                .disp(ctx),
3639        )
3640    }
3641}
3642
3643impl Parsable for CallIntrinsicOp {
3644    type Arg = Vec<(Identifier, Location)>;
3645    type Parsed = OpObj;
3646    fn parse<'a>(
3647        state_stream: &mut StateStream<'a>,
3648        results: Self::Arg,
3649    ) -> ParseResult<'a, Self::Parsed> {
3650        let pos = state_stream.loc();
3651
3652        let mut parser = (
3653            spaced(token('@').with(StringAttr::parser(()))),
3654            optional(spaced(FastmathFlagsAttr::parser(()))),
3655            delimited_list_parser('(', ')', ',', ssa_opd_parser()).skip(spaced(token(':'))),
3656            spaced(type_parser()),
3657        );
3658
3659        // Parse and build the call intrinsic op.
3660        let (iname, fmf, operands, ftype) = parser.parse_stream(state_stream).into_result()?.0;
3661
3662        let ctx = &mut state_stream.state.ctx;
3663        let intr_ty = TypePtr::<FuncType>::from_ptr(ftype, ctx).map_err(|mut err| {
3664            err.set_loc(pos);
3665            err
3666        })?;
3667        let op = CallIntrinsicOp::new(ctx, iname, intr_ty, operands);
3668        if let Some(fmf) = fmf {
3669            op.set_attr_llvm_intrinsic_fastmath_flags(ctx, fmf);
3670        }
3671        process_parsed_ssa_defs(state_stream, &results, op.get_operation())?;
3672        Ok(OpObj::new(op)).into_parse_result()
3673    }
3674}
3675
3676#[derive(Error, Debug)]
3677pub enum CallIntrinsicVerifyErr {
3678    #[error("Missing or incorrect intrinsic name attribute")]
3679    MissingIntrinsicNameAttr,
3680    #[error("Missing or incorrect intrinsic type attribute")]
3681    MissingIntrinsicTypeAttr,
3682    #[error("Number or types of operands does not match intrinsic type")]
3683    OperandsMismatch,
3684    #[error("Number or types of results does not match intrinsic type")]
3685    ResultsMismatch,
3686    #[error("Intrinsic name does not correspond to a known LLVM intrinsic")]
3687    UnknownIntrinsicName,
3688}
3689
3690impl Verify for CallIntrinsicOp {
3691    fn verify(&self, ctx: &Context) -> Result<()> {
3692        // Check that the intrinsic name and type attributes are present.
3693        let Some(name) = self.get_attr_llvm_intrinsic_name(ctx) else {
3694            return verify_err!(
3695                self.loc(ctx),
3696                CallIntrinsicVerifyErr::MissingIntrinsicNameAttr
3697            );
3698        };
3699
3700        let Some(ty) = self
3701            .get_attr_llvm_intrinsic_type(ctx)
3702            .and_then(|ty| TypePtr::<FuncType>::from_ptr(ty.get_type(ctx), ctx).ok())
3703        else {
3704            return verify_err!(
3705                self.loc(ctx),
3706                CallIntrinsicVerifyErr::MissingIntrinsicTypeAttr
3707            );
3708        };
3709
3710        let arg_types = ty.deref(ctx).arg_types();
3711        let res_type = ty.deref(ctx).result_type();
3712
3713        // Check that the operand and result types match the intrinsic type.
3714        let op = &*self.op.deref(ctx);
3715        let intrinsic_arg_types = ty.deref(ctx).arg_types();
3716        if op.operands().count() != intrinsic_arg_types.len() {
3717            return verify_err!(self.loc(ctx), CallIntrinsicVerifyErr::OperandsMismatch);
3718        }
3719
3720        for (i, operand) in op.operands().enumerate() {
3721            let opd_ty = pliron::r#type::Typed::get_type(&operand, ctx);
3722            if opd_ty != arg_types[i] {
3723                return verify_err!(self.loc(ctx), CallIntrinsicVerifyErr::OperandsMismatch);
3724            }
3725        }
3726
3727        let mut result_types = op.result_types();
3728        if let Some(result_type) = result_types.next()
3729            && result_type == res_type
3730            && result_types.next().is_none()
3731        {
3732        } else {
3733            return verify_err!(self.loc(ctx), CallIntrinsicVerifyErr::ResultsMismatch);
3734        }
3735
3736        if llvm_lookup_intrinsic_id(&<StringAttr as Into<String>>::into(name.clone())).is_none() {
3737            return verify_err!(self.loc(ctx), CallIntrinsicVerifyErr::UnknownIntrinsicName);
3738        }
3739
3740        Ok(())
3741    }
3742}
3743
/// Equivalent to LLVM's `va_arg` operation.
///
/// Takes one operand (the variadic argument list, see [VAArgOp::new]) and
/// produces one result whose type is chosen by the builder. The verifier
/// requires the operand to be of [PointerType].
#[pliron_op(
    name = "llvm.va_arg",
    format = "$0 ` : ` type($0)",
    interfaces = [OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>]
)]
pub struct VAArgOp;
3751
3752#[derive(Error, Debug)]
3753pub enum VAArgOpVerifyErr {
3754    #[error("Operand must be a pointer type")]
3755    OperandNotPointer,
3756}
3757
3758impl Verify for VAArgOp {
3759    fn verify(&self, ctx: &Context) -> Result<()> {
3760        let loc = self.loc(ctx);
3761
3762        // Check that the argument is a pointer.
3763        let opd_ty = self.operand_type(ctx).deref(ctx);
3764        if !opd_ty.is::<PointerType>() {
3765            return verify_err!(loc, VAArgOpVerifyErr::OperandNotPointer);
3766        }
3767
3768        Ok(())
3769    }
3770}
3771
3772impl VAArgOp {
3773    /// Create a new [VAArgOp].
3774    pub fn new(ctx: &mut Context, list: Value, ty: Ptr<TypeObj>) -> Self {
3775        let op = Operation::new(
3776            ctx,
3777            Self::get_concrete_op_info(),
3778            vec![ty],
3779            vec![list],
3780            vec![],
3781            0,
3782        );
3783        VAArgOp { op }
3784    }
3785}
3786
/// Equivalent to LLVM's `func` operation.
/// See [llvm.func](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmfunc-llvmllvmfuncop).
///
/// The signature is stored in the `llvm_func_type` attribute, and the name via
/// [SymbolOpInterface]; linkage is carried by `llvm_function_linkage`.
/// A function without a region is a declaration (see [IsDeclaration]);
/// otherwise its single region holds the body.
#[pliron_op(
    name = "llvm.func",
    interfaces = [
        SymbolOpInterface,
        IsolatedFromAboveInterface,
        AtMostNRegionsInterface<1>,
        AtMostOneRegionInterface,
        NResultsInterface<0>,
        NOpdsInterface<0>,
        LlvmSymbolName
    ],
    attributes = (llvm_func_type: TypeAttr, llvm_function_linkage: LinkageAttr)
)]
pub struct FuncOp;
3803
3804impl FuncOp {
3805    /// Create a new empty [FuncOp].
3806    pub fn new(ctx: &mut Context, name: Identifier, ty: TypePtr<FuncType>) -> Self {
3807        let ty_attr = TypeAttr::new(ty.into());
3808        let op = Operation::new(ctx, Self::get_concrete_op_info(), vec![], vec![], vec![], 0);
3809        let opop = FuncOp { op };
3810        opop.set_symbol_name(ctx, name);
3811        opop.set_attr_llvm_func_type(ctx, ty_attr);
3812
3813        opop
3814    }
3815
3816    /// Get the function signature (type).
3817    pub fn get_type(&self, ctx: &Context) -> TypePtr<FuncType> {
3818        let ty = attr_cast::<dyn TypedAttrInterface>(&*self.get_attr_llvm_func_type(ctx).unwrap())
3819            .unwrap()
3820            .get_type(ctx);
3821        TypePtr::from_ptr(ty, ctx).unwrap()
3822    }
3823
3824    /// Get the entry block (if it exists) of this function.
3825    pub fn get_entry_block(&self, ctx: &Context) -> Option<Ptr<BasicBlock>> {
3826        self.op
3827            .deref(ctx)
3828            .regions()
3829            .next()
3830            .and_then(|region| region.deref(ctx).get_head())
3831    }
3832
3833    /// Get the entry block of this function, creating it if it does not exist.
3834    pub fn get_or_create_entry_block(&self, ctx: &mut Context) -> Ptr<BasicBlock> {
3835        if let Some(entry_block) = self.get_entry_block(ctx) {
3836            return entry_block;
3837        }
3838
3839        // Create an empty entry block.
3840        assert!(
3841            self.op.deref(ctx).regions().next().is_none(),
3842            "FuncOp already has a region, but no block inside it"
3843        );
3844        let region = Operation::add_region(self.op, ctx);
3845        let arg_types = self.get_type(ctx).deref(ctx).arg_types().clone();
3846        let body = BasicBlock::new(ctx, Some("entry".try_into().unwrap()), arg_types);
3847        body.insert_at_front(region, ctx);
3848        body
3849    }
3850}
3851
3852impl pliron::r#type::Typed for FuncOp {
3853    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
3854        self.get_type(ctx).into()
3855    }
3856}
3857
3858impl Printable for FuncOp {
3859    fn fmt(
3860        &self,
3861        ctx: &Context,
3862        state: &printable::State,
3863        f: &mut core::fmt::Formatter<'_>,
3864    ) -> core::fmt::Result {
3865        typed_symb_op_header(self).fmt(ctx, state, f)?;
3866
3867        // Print attributes except for function type and symbol name.
3868        let mut attributes_to_print_separately =
3869            self.op.deref(ctx).attributes.clone_skip_outlined();
3870        attributes_to_print_separately
3871            .0
3872            .retain(|key, _| key != &*ATTR_KEY_LLVM_FUNC_TYPE && key != &*ATTR_KEY_SYM_NAME);
3873        indented_block!(state, {
3874            write!(
3875                f,
3876                "{}{}",
3877                indented_nl(state),
3878                attributes_to_print_separately.disp(ctx)
3879            )?;
3880        });
3881
3882        if let Some(r) = self.get_region(ctx) {
3883            write!(f, " ")?;
3884            r.fmt(ctx, state, f)?;
3885        }
3886        Ok(())
3887    }
3888}
3889
3890impl Parsable for FuncOp {
3891    type Arg = Vec<(Identifier, Location)>;
3892    type Parsed = OpObj;
3893    fn parse<'a>(
3894        state_stream: &mut StateStream<'a>,
3895        results: Self::Arg,
3896    ) -> ParseResult<'a, Self::Parsed> {
3897        if !results.is_empty() {
3898            input_err!(
3899                state_stream.loc(),
3900                op_interfaces::NResultsVerifyErr(0, results.len())
3901            )?
3902        }
3903
3904        let op = Operation::new(
3905            state_stream.state.ctx,
3906            Self::get_concrete_op_info(),
3907            vec![],
3908            vec![],
3909            vec![],
3910            0,
3911        );
3912
3913        let mut parser = (
3914            spaced(token('@').with(Identifier::parser(()))).skip(spaced(token(':'))),
3915            spaced(type_parser()),
3916            spaced(AttributeDict::parser(())),
3917            spaced(optional(Region::parser(op))),
3918        );
3919
3920        // Parse and build the function, providing name and type details.
3921        parser
3922            .parse_stream(state_stream)
3923            .map(|(fname, fty, attrs, _region)| -> OpObj {
3924                let ctx = &mut state_stream.state.ctx;
3925                op.deref_mut(ctx).attributes = attrs;
3926                let ty_attr = TypeAttr::new(fty);
3927                let opop = FuncOp { op };
3928                opop.set_symbol_name(ctx, fname);
3929                opop.set_attr_llvm_func_type(ctx, ty_attr);
3930                OpObj::new(opop)
3931            })
3932            .into()
3933    }
3934}
3935
3936#[derive(Error, Debug)]
3937#[error("llvm.func op does not have llvm.func type")]
3938pub struct FuncOpTypeErr;
3939
3940impl Verify for FuncOp {
3941    fn verify(&self, _ctx: &Context) -> Result<()> {
3942        Ok(())
3943    }
3944}
3945
3946impl IsDeclaration for FuncOp {
3947    fn is_declaration(&self, ctx: &Context) -> bool {
3948        self.get_region(ctx).is_none()
3949    }
3950}