Skip to main content

pliron_llvm/
op_interfaces.rs

1//! [Op] Interfaces defined in the LLVM dialect.
2
3use pliron::{
4    builtin::{
5        attributes::BoolAttr,
6        op_interfaces::{
7            NOpdsInterface, NResultsInterface, OneOpdInterface, ResultNOfType, SymbolOpInterface,
8        },
9        type_interfaces::FloatTypeInterface,
10    },
11    derive::op_interface,
12    dict_key,
13    r#type::type_cast,
14};
15use thiserror::Error;
16
17use pliron::{
18    builtin::{
19        op_interfaces::{OneResultInterface, SameOperandsAndResultType},
20        types::{IntegerType, Signedness},
21    },
22    context::{Context, Ptr},
23    location::Located,
24    op::{Op, op_cast},
25    operation::Operation,
26    result::Result,
27    r#type::{TypeObj, Typed},
28    value::Value,
29    verify_err,
30};
31
32use crate::{
33    attributes::{AlignmentAttr, FastmathFlagsAttr},
34    types::VectorType,
35};
36
37use super::{attributes::IntegerOverflowFlagsAttr, types::PointerType};
38
39/// Binary arithmetic [Op].
40#[op_interface]
41pub trait BinArithOp:
42    SameOperandsAndResultType + OneResultInterface + NOpdsInterface<2> + NResultsInterface<1>
43{
44    /// Create a new binary arithmetic operation given the operands.
45    fn new(ctx: &mut Context, lhs: Value, rhs: Value) -> Self
46    where
47        Self: Sized,
48    {
49        let op = Operation::new(
50            ctx,
51            Self::get_concrete_op_info(),
52            vec![lhs.get_type(ctx)],
53            vec![lhs, rhs],
54            vec![],
55            0,
56        );
57        Self::from_operation(op)
58    }
59
60    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
61    where
62        Self: Sized,
63    {
64        Ok(())
65    }
66}
67
68#[derive(Error, Debug)]
69#[error("Integer binary arithmetic Op can only have signless integer result/operand type")]
70pub struct IntBinArithOpErr;
71
72/// Integer binary arithmetic [Op]
73#[op_interface]
74pub trait IntBinArithOp: BinArithOp {
75    fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
76    where
77        Self: Sized,
78    {
79        let mut ty = op_cast::<dyn SameOperandsAndResultType>(op)
80            .expect("Op must impl SameOperandsAndResultType")
81            .get_type(ctx);
82
83        if let Some(vec_ty) = ty.deref(ctx).downcast_ref::<VectorType>() {
84            ty = vec_ty.elem_type();
85        }
86
87        let ty = ty.deref(ctx);
88        let Some(int_ty) = ty.downcast_ref::<IntegerType>() else {
89            return verify_err!(op.loc(ctx), IntBinArithOpErr);
90        };
91
92        if int_ty.signedness() != Signedness::Signless {
93            return verify_err!(op.loc(ctx), IntBinArithOpErr);
94        }
95
96        Ok(())
97    }
98}
99
100dict_key!(
101    /// Attribute key for integer overflow flags.
102    ATTR_KEY_INTEGER_OVERFLOW_FLAGS,
103    "llvm_integer_overflow_flags"
104);
105
106#[derive(Error, Debug)]
107#[error("IntegerOverflowFlag missing on Op")]
108pub struct IntBinArithOpWithOverflowFlagErr;
109
110/// Integer binary arithmetic [Op] with [IntegerOverflowFlagsAttr]
111#[op_interface]
112pub trait IntBinArithOpWithOverflowFlag: IntBinArithOp {
113    /// Create a new integer binary op with overflow flags set.
114    fn new_with_overflow_flag(
115        ctx: &mut Context,
116        lhs: Value,
117        rhs: Value,
118        flag: IntegerOverflowFlagsAttr,
119    ) -> Self
120    where
121        Self: Sized,
122    {
123        let op = Self::new(ctx, lhs, rhs);
124        op.set_integer_overflow_flag(ctx, flag);
125        op
126    }
127
128    /// Get the integer overflow flag on this [Op].
129    fn integer_overflow_flag(&self, ctx: &Context) -> IntegerOverflowFlagsAttr
130    where
131        Self: Sized,
132    {
133        self.get_operation()
134            .deref(ctx)
135            .attributes
136            .get::<IntegerOverflowFlagsAttr>(&ATTR_KEY_INTEGER_OVERFLOW_FLAGS)
137            .expect("Integer overflow flag missing or is of incorrect type")
138            .clone()
139    }
140
141    /// Set the integer overflow flag for this [Op].
142    fn set_integer_overflow_flag(&self, ctx: &Context, flag: IntegerOverflowFlagsAttr)
143    where
144        Self: Sized,
145    {
146        self.get_operation()
147            .deref_mut(ctx)
148            .attributes
149            .set(ATTR_KEY_INTEGER_OVERFLOW_FLAGS.clone(), flag);
150    }
151
152    fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
153    where
154        Self: Sized,
155    {
156        let op = op.get_operation().deref(ctx);
157        if op
158            .attributes
159            .get::<IntegerOverflowFlagsAttr>(&ATTR_KEY_INTEGER_OVERFLOW_FLAGS)
160            .is_none()
161        {
162            return verify_err!(op.loc(), IntBinArithOpWithOverflowFlagErr);
163        }
164
165        Ok(())
166    }
167}
168
169#[derive(Error, Debug)]
170#[error("Floating point arithmetic Op can only have signless floating point result/operand type")]
171pub struct FloatBinArithOpErr;
172
173/// Floating point binary arithmetic [Op]
174#[op_interface]
175pub trait FloatBinArithOp: BinArithOp {
176    fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
177    where
178        Self: Sized,
179    {
180        let mut ty = op_cast::<dyn SameOperandsAndResultType>(op)
181            .expect("Op must impl SameOperandsAndResultType")
182            .get_type(ctx);
183
184        if let Some(vec_ty) = ty.deref(ctx).downcast_ref::<VectorType>() {
185            ty = vec_ty.elem_type();
186        }
187
188        let ty = ty.deref(ctx);
189        if type_cast::<dyn FloatTypeInterface>(&**ty).is_none() {
190            return verify_err!(op.loc(ctx), FloatBinArithOpErr);
191        }
192        Ok(())
193    }
194}
195
196dict_key!(
197    /// Attribute key for fastmath flags.
198    ATTR_KEY_FAST_MATH_FLAGS,
199    "llvm_fast_math_flags"
200);
201
202#[derive(Error, Debug)]
203#[error("Fastmath flag missing on Op")]
204pub struct FastMathFlagMissingErr;
205
206#[op_interface]
207pub trait FastMathFlags {
208    /// Get the fast math flags on this [Op].
209    fn fast_math_flags(&self, ctx: &Context) -> FastmathFlagsAttr
210    where
211        Self: Sized,
212    {
213        *self
214            .get_operation()
215            .deref(ctx)
216            .attributes
217            .get::<FastmathFlagsAttr>(&ATTR_KEY_FAST_MATH_FLAGS)
218            .expect("Fast math flags missing or is of incorrect type")
219    }
220
221    /// Set the fast math flags for this [Op].
222    fn set_fast_math_flags(&self, ctx: &Context, flag: FastmathFlagsAttr)
223    where
224        Self: Sized,
225    {
226        self.get_operation()
227            .deref_mut(ctx)
228            .attributes
229            .set(ATTR_KEY_FAST_MATH_FLAGS.clone(), flag);
230    }
231
232    fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
233    where
234        Self: Sized,
235    {
236        let op = op.get_operation().deref(ctx);
237        if op
238            .attributes
239            .get::<FastmathFlagsAttr>(&ATTR_KEY_FAST_MATH_FLAGS)
240            .is_none()
241        {
242            return verify_err!(op.loc(), FastmathFlagMissingErr);
243        }
244
245        Ok(())
246    }
247}
248
249/// Floating point binary arithmetic [Op] with [FastmathFlagsAttr]
250#[op_interface]
251pub trait FloatBinArithOpWithFastMathFlags: FloatBinArithOp + FastMathFlags {
252    /// Create a new floating point binary op with fast math flags set.
253    fn new_with_fast_math_flags(
254        ctx: &mut Context,
255        lhs: Value,
256        rhs: Value,
257        flag: FastmathFlagsAttr,
258    ) -> Self
259    where
260        Self: Sized,
261    {
262        let op = Self::new(ctx, lhs, rhs);
263        op.set_fast_math_flags(ctx, flag);
264        op
265    }
266
267    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
268    where
269        Self: Sized,
270    {
271        Ok(())
272    }
273}
274
275#[derive(Error, Debug)]
276#[error("Fastmath flag missing on Op")]
277pub struct FastmathFlagMissingErr;
278
279dict_key!(
280    /// Attribute key for nneg flag.
281    ATTR_KEY_NNEG_FLAG,
282    "llvm_nneg_flag"
283);
284
285#[op_interface]
286pub trait NNegFlag {
287    // Get the current NNEG flag value.
288    fn nneg(&self, ctx: &Context) -> bool {
289        self.get_operation()
290            .deref(ctx)
291            .attributes
292            .get::<BoolAttr>(&ATTR_KEY_NNEG_FLAG)
293            .expect("NNEG flag missing or is of incorrect type")
294            .clone()
295            .into()
296    }
297    // Set the current NNEG flag value.
298    fn set_nneg(&self, ctx: &Context, flag: bool) {
299        self.get_operation()
300            .deref_mut(ctx)
301            .attributes
302            .set(ATTR_KEY_NNEG_FLAG.clone(), BoolAttr::new(flag));
303    }
304    fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
305    where
306        Self: Sized,
307    {
308        let op = op.get_operation().deref(ctx);
309        if op.attributes.get::<BoolAttr>(&ATTR_KEY_NNEG_FLAG).is_none() {
310            return verify_err!(op.loc(), NNegFlagMissingErr);
311        }
312
313        Ok(())
314    }
315}
316
317#[derive(Error, Debug)]
318#[error("NNEG flag missing on Op")]
319pub struct NNegFlagMissingErr;
320
321#[derive(Error, Debug)]
322#[error("Result must be a pointer type, but is not")]
323pub struct PointerTypeResultVerifyErr;
324
325/// An [Op] with a single result whose type is [PointerType]
326#[op_interface]
327pub trait PointerTypeResult: OneResultInterface + ResultNOfType<0, PointerType> {
328    /// Get the pointee type of the result pointer.
329    fn result_pointee_type(&self, ctx: &Context) -> Ptr<TypeObj>;
330
331    fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
332    where
333        Self: Sized,
334    {
335        if !op_cast::<dyn OneResultInterface>(op)
336            .expect("An Op here must impl OneResultInterface")
337            .result_type(ctx)
338            .deref(ctx)
339            .is::<PointerType>()
340        {
341            return verify_err!(op.loc(ctx), PointerTypeResultVerifyErr);
342        }
343
344        Ok(())
345    }
346}
347
348/// A Cast [Op] has one argument and one result.
349#[op_interface]
350pub trait CastOpInterface:
351    OneResultInterface + OneOpdInterface + NResultsInterface<1> + NOpdsInterface<1>
352{
353    /// Create a new cast operation given the operand.
354    fn new(ctx: &mut Context, operand: Value, res_type: Ptr<TypeObj>) -> Self
355    where
356        Self: Sized,
357    {
358        let op = Operation::new(
359            ctx,
360            Self::get_concrete_op_info(),
361            vec![res_type],
362            vec![operand],
363            vec![],
364            0,
365        );
366        Self::from_operation(op)
367    }
368
369    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
370    where
371        Self: Sized,
372    {
373        Ok(())
374    }
375}
376
377/// A Cast [Op] with NNEG flag.
378#[op_interface]
379pub trait CastOpWithNNegInterface:
380    CastOpInterface + NNegFlag + NResultsInterface<1> + NOpdsInterface<1>
381{
382    /// Create a new cast operation with nneg flag
383    fn new_with_nneg(ctx: &mut Context, operand: Value, res_type: Ptr<TypeObj>, nneg: bool) -> Self
384    where
385        Self: Sized,
386    {
387        let op = Self::new(ctx, operand, res_type);
388        op.set_nneg(ctx, nneg);
389        op
390    }
391
392    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
393    where
394        Self: Sized,
395    {
396        Ok(())
397    }
398}
399
400/// Is a global value (variable or function) declaration.
401#[op_interface]
402pub trait IsDeclaration {
403    /// Check if this global value (variable or function) is a declaration.
404    fn is_declaration(&self, ctx: &Context) -> bool
405    where
406        Self: Sized;
407
408    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
409    where
410        Self: Sized,
411    {
412        Ok(())
413    }
414}
415
416dict_key!(
417    /// Attribute key for LLVM symbol name.
418    ATTR_KEY_LLVM_SYMBOL_NAME,
419    "llvm_symbol_name"
420);
421
422/// Since LLVM symbols can have characters that are illegal in pliron,
423/// this interface provides a way to get the original LLVM symbol name.
424#[op_interface]
425pub trait LlvmSymbolName: SymbolOpInterface {
426    /// Get the original LLVM symbol name, if it's different from the pliron symbol name.
427    fn llvm_symbol_name(&self, ctx: &Context) -> Option<String> {
428        self.get_operation()
429            .deref(ctx)
430            .attributes
431            .get::<pliron::builtin::attributes::StringAttr>(&ATTR_KEY_LLVM_SYMBOL_NAME)
432            .map(|attr| attr.clone().into())
433    }
434
435    /// Set the original LLVM symbol name.
436    fn set_llvm_symbol_name(&self, ctx: &Context, name: String) {
437        self.get_operation().deref_mut(ctx).attributes.set(
438            ATTR_KEY_LLVM_SYMBOL_NAME.clone(),
439            pliron::builtin::attributes::StringAttr::new(name),
440        );
441    }
442
443    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
444    where
445        Self: Sized,
446    {
447        Ok(())
448    }
449}
450
451dict_key!(
452    /// Attribute key for alignment.
453    ATTR_KEY_LLVM_ALIGNMENT,
454    "llvm_alignment"
455);
456
457/// Ops that can have an alignment set.
458#[op_interface]
459pub trait AlignableOpInterface {
460    /// Get the alignment of this [Op], if set.
461    fn alignment(&self, ctx: &Context) -> Option<u32>
462    where
463        Self: Sized,
464    {
465        self.get_operation()
466            .deref(ctx)
467            .attributes
468            .get::<AlignmentAttr>(&ATTR_KEY_LLVM_ALIGNMENT)
469            .map(|attr| attr.0)
470    }
471
472    /// Set the alignment of this [Op].
473    fn set_alignment(&self, ctx: &Context, alignment: u32)
474    where
475        Self: Sized,
476    {
477        self.get_operation()
478            .deref_mut(ctx)
479            .attributes
480            .set(ATTR_KEY_LLVM_ALIGNMENT.clone(), AlignmentAttr(alignment));
481    }
482
483    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
484    where
485        Self: Sized,
486    {
487        Ok(())
488    }
489}