1use pliron::{
4 builtin::{
5 attributes::BoolAttr,
6 op_interfaces::{
7 NOpdsInterface, NResultsInterface, OneOpdInterface, ResultNOfType, SymbolOpInterface,
8 },
9 type_interfaces::FloatTypeInterface,
10 },
11 derive::op_interface,
12 dict_key,
13 r#type::type_cast,
14};
15use thiserror::Error;
16
17use pliron::{
18 builtin::{
19 op_interfaces::{OneResultInterface, SameOperandsAndResultType},
20 types::{IntegerType, Signedness},
21 },
22 context::{Context, Ptr},
23 location::Located,
24 op::{Op, op_cast},
25 operation::Operation,
26 result::Result,
27 r#type::{TypeObj, Typed},
28 value::Value,
29 verify_err,
30};
31
32use crate::{
33 attributes::{AlignmentAttr, FastmathFlagsAttr},
34 types::VectorType,
35};
36
37use super::{attributes::IntegerOverflowFlagsAttr, types::PointerType};
38
39#[op_interface]
41pub trait BinArithOp:
42 SameOperandsAndResultType + OneResultInterface + NOpdsInterface<2> + NResultsInterface<1>
43{
44 fn new(ctx: &mut Context, lhs: Value, rhs: Value) -> Self
46 where
47 Self: Sized,
48 {
49 let op = Operation::new(
50 ctx,
51 Self::get_concrete_op_info(),
52 vec![lhs.get_type(ctx)],
53 vec![lhs, rhs],
54 vec![],
55 0,
56 );
57 Self::from_operation(op)
58 }
59
60 fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
61 where
62 Self: Sized,
63 {
64 Ok(())
65 }
66}
67
68#[derive(Error, Debug)]
69#[error("Integer binary arithmetic Op can only have signless integer result/operand type")]
70pub struct IntBinArithOpErr;
71
72#[op_interface]
74pub trait IntBinArithOp: BinArithOp {
75 fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
76 where
77 Self: Sized,
78 {
79 let mut ty = op_cast::<dyn SameOperandsAndResultType>(op)
80 .expect("Op must impl SameOperandsAndResultType")
81 .get_type(ctx);
82
83 if let Some(vec_ty) = ty.deref(ctx).downcast_ref::<VectorType>() {
84 ty = vec_ty.elem_type();
85 }
86
87 let ty = ty.deref(ctx);
88 let Some(int_ty) = ty.downcast_ref::<IntegerType>() else {
89 return verify_err!(op.loc(ctx), IntBinArithOpErr);
90 };
91
92 if int_ty.signedness() != Signedness::Signless {
93 return verify_err!(op.loc(ctx), IntBinArithOpErr);
94 }
95
96 Ok(())
97 }
98}
99
100dict_key!(
101 ATTR_KEY_INTEGER_OVERFLOW_FLAGS,
103 "llvm_integer_overflow_flags"
104);
105
106#[derive(Error, Debug)]
107#[error("IntegerOverflowFlag missing on Op")]
108pub struct IntBinArithOpWithOverflowFlagErr;
109
110#[op_interface]
112pub trait IntBinArithOpWithOverflowFlag: IntBinArithOp {
113 fn new_with_overflow_flag(
115 ctx: &mut Context,
116 lhs: Value,
117 rhs: Value,
118 flag: IntegerOverflowFlagsAttr,
119 ) -> Self
120 where
121 Self: Sized,
122 {
123 let op = Self::new(ctx, lhs, rhs);
124 op.set_integer_overflow_flag(ctx, flag);
125 op
126 }
127
128 fn integer_overflow_flag(&self, ctx: &Context) -> IntegerOverflowFlagsAttr
130 where
131 Self: Sized,
132 {
133 self.get_operation()
134 .deref(ctx)
135 .attributes
136 .get::<IntegerOverflowFlagsAttr>(&ATTR_KEY_INTEGER_OVERFLOW_FLAGS)
137 .expect("Integer overflow flag missing or is of incorrect type")
138 .clone()
139 }
140
141 fn set_integer_overflow_flag(&self, ctx: &Context, flag: IntegerOverflowFlagsAttr)
143 where
144 Self: Sized,
145 {
146 self.get_operation()
147 .deref_mut(ctx)
148 .attributes
149 .set(ATTR_KEY_INTEGER_OVERFLOW_FLAGS.clone(), flag);
150 }
151
152 fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
153 where
154 Self: Sized,
155 {
156 let op = op.get_operation().deref(ctx);
157 if op
158 .attributes
159 .get::<IntegerOverflowFlagsAttr>(&ATTR_KEY_INTEGER_OVERFLOW_FLAGS)
160 .is_none()
161 {
162 return verify_err!(op.loc(), IntBinArithOpWithOverflowFlagErr);
163 }
164
165 Ok(())
166 }
167}
168
169#[derive(Error, Debug)]
170#[error("Floating point arithmetic Op can only have signless floating point result/operand type")]
171pub struct FloatBinArithOpErr;
172
173#[op_interface]
175pub trait FloatBinArithOp: BinArithOp {
176 fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
177 where
178 Self: Sized,
179 {
180 let mut ty = op_cast::<dyn SameOperandsAndResultType>(op)
181 .expect("Op must impl SameOperandsAndResultType")
182 .get_type(ctx);
183
184 if let Some(vec_ty) = ty.deref(ctx).downcast_ref::<VectorType>() {
185 ty = vec_ty.elem_type();
186 }
187
188 let ty = ty.deref(ctx);
189 if type_cast::<dyn FloatTypeInterface>(&**ty).is_none() {
190 return verify_err!(op.loc(ctx), FloatBinArithOpErr);
191 }
192 Ok(())
193 }
194}
195
196dict_key!(
197 ATTR_KEY_FAST_MATH_FLAGS,
199 "llvm_fast_math_flags"
200);
201
202#[derive(Error, Debug)]
203#[error("Fastmath flag missing on Op")]
204pub struct FastMathFlagMissingErr;
205
206#[op_interface]
207pub trait FastMathFlags {
208 fn fast_math_flags(&self, ctx: &Context) -> FastmathFlagsAttr
210 where
211 Self: Sized,
212 {
213 *self
214 .get_operation()
215 .deref(ctx)
216 .attributes
217 .get::<FastmathFlagsAttr>(&ATTR_KEY_FAST_MATH_FLAGS)
218 .expect("Fast math flags missing or is of incorrect type")
219 }
220
221 fn set_fast_math_flags(&self, ctx: &Context, flag: FastmathFlagsAttr)
223 where
224 Self: Sized,
225 {
226 self.get_operation()
227 .deref_mut(ctx)
228 .attributes
229 .set(ATTR_KEY_FAST_MATH_FLAGS.clone(), flag);
230 }
231
232 fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
233 where
234 Self: Sized,
235 {
236 let op = op.get_operation().deref(ctx);
237 if op
238 .attributes
239 .get::<FastmathFlagsAttr>(&ATTR_KEY_FAST_MATH_FLAGS)
240 .is_none()
241 {
242 return verify_err!(op.loc(), FastmathFlagMissingErr);
243 }
244
245 Ok(())
246 }
247}
248
249#[op_interface]
251pub trait FloatBinArithOpWithFastMathFlags: FloatBinArithOp + FastMathFlags {
252 fn new_with_fast_math_flags(
254 ctx: &mut Context,
255 lhs: Value,
256 rhs: Value,
257 flag: FastmathFlagsAttr,
258 ) -> Self
259 where
260 Self: Sized,
261 {
262 let op = Self::new(ctx, lhs, rhs);
263 op.set_fast_math_flags(ctx, flag);
264 op
265 }
266
267 fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
268 where
269 Self: Sized,
270 {
271 Ok(())
272 }
273}
274
275#[derive(Error, Debug)]
276#[error("Fastmath flag missing on Op")]
277pub struct FastmathFlagMissingErr;
278
279dict_key!(
280 ATTR_KEY_NNEG_FLAG,
282 "llvm_nneg_flag"
283);
284
285#[op_interface]
286pub trait NNegFlag {
287 fn nneg(&self, ctx: &Context) -> bool {
289 self.get_operation()
290 .deref(ctx)
291 .attributes
292 .get::<BoolAttr>(&ATTR_KEY_NNEG_FLAG)
293 .expect("NNEG flag missing or is of incorrect type")
294 .clone()
295 .into()
296 }
297 fn set_nneg(&self, ctx: &Context, flag: bool) {
299 self.get_operation()
300 .deref_mut(ctx)
301 .attributes
302 .set(ATTR_KEY_NNEG_FLAG.clone(), BoolAttr::new(flag));
303 }
304 fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
305 where
306 Self: Sized,
307 {
308 let op = op.get_operation().deref(ctx);
309 if op.attributes.get::<BoolAttr>(&ATTR_KEY_NNEG_FLAG).is_none() {
310 return verify_err!(op.loc(), NNegFlagMissingErr);
311 }
312
313 Ok(())
314 }
315}
316
317#[derive(Error, Debug)]
318#[error("NNEG flag missing on Op")]
319pub struct NNegFlagMissingErr;
320
321#[derive(Error, Debug)]
322#[error("Result must be a pointer type, but is not")]
323pub struct PointerTypeResultVerifyErr;
324
325#[op_interface]
327pub trait PointerTypeResult: OneResultInterface + ResultNOfType<0, PointerType> {
328 fn result_pointee_type(&self, ctx: &Context) -> Ptr<TypeObj>;
330
331 fn verify(op: &dyn Op, ctx: &Context) -> Result<()>
332 where
333 Self: Sized,
334 {
335 if !op_cast::<dyn OneResultInterface>(op)
336 .expect("An Op here must impl OneResultInterface")
337 .result_type(ctx)
338 .deref(ctx)
339 .is::<PointerType>()
340 {
341 return verify_err!(op.loc(ctx), PointerTypeResultVerifyErr);
342 }
343
344 Ok(())
345 }
346}
347
348#[op_interface]
350pub trait CastOpInterface:
351 OneResultInterface + OneOpdInterface + NResultsInterface<1> + NOpdsInterface<1>
352{
353 fn new(ctx: &mut Context, operand: Value, res_type: Ptr<TypeObj>) -> Self
355 where
356 Self: Sized,
357 {
358 let op = Operation::new(
359 ctx,
360 Self::get_concrete_op_info(),
361 vec![res_type],
362 vec![operand],
363 vec![],
364 0,
365 );
366 Self::from_operation(op)
367 }
368
369 fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
370 where
371 Self: Sized,
372 {
373 Ok(())
374 }
375}
376
377#[op_interface]
379pub trait CastOpWithNNegInterface:
380 CastOpInterface + NNegFlag + NResultsInterface<1> + NOpdsInterface<1>
381{
382 fn new_with_nneg(ctx: &mut Context, operand: Value, res_type: Ptr<TypeObj>, nneg: bool) -> Self
384 where
385 Self: Sized,
386 {
387 let op = Self::new(ctx, operand, res_type);
388 op.set_nneg(ctx, nneg);
389 op
390 }
391
392 fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
393 where
394 Self: Sized,
395 {
396 Ok(())
397 }
398}
399
400#[op_interface]
402pub trait IsDeclaration {
403 fn is_declaration(&self, ctx: &Context) -> bool
405 where
406 Self: Sized;
407
408 fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
409 where
410 Self: Sized,
411 {
412 Ok(())
413 }
414}
415
416dict_key!(
417 ATTR_KEY_LLVM_SYMBOL_NAME,
419 "llvm_symbol_name"
420);
421
422#[op_interface]
425pub trait LlvmSymbolName: SymbolOpInterface {
426 fn llvm_symbol_name(&self, ctx: &Context) -> Option<String> {
428 self.get_operation()
429 .deref(ctx)
430 .attributes
431 .get::<pliron::builtin::attributes::StringAttr>(&ATTR_KEY_LLVM_SYMBOL_NAME)
432 .map(|attr| attr.clone().into())
433 }
434
435 fn set_llvm_symbol_name(&self, ctx: &Context, name: String) {
437 self.get_operation().deref_mut(ctx).attributes.set(
438 ATTR_KEY_LLVM_SYMBOL_NAME.clone(),
439 pliron::builtin::attributes::StringAttr::new(name),
440 );
441 }
442
443 fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
444 where
445 Self: Sized,
446 {
447 Ok(())
448 }
449}
450
451dict_key!(
452 ATTR_KEY_LLVM_ALIGNMENT,
454 "llvm_alignment"
455);
456
457#[op_interface]
459pub trait AlignableOpInterface {
460 fn alignment(&self, ctx: &Context) -> Option<u32>
462 where
463 Self: Sized,
464 {
465 self.get_operation()
466 .deref(ctx)
467 .attributes
468 .get::<AlignmentAttr>(&ATTR_KEY_LLVM_ALIGNMENT)
469 .map(|attr| attr.0)
470 }
471
472 fn set_alignment(&self, ctx: &Context, alignment: u32)
474 where
475 Self: Sized,
476 {
477 self.get_operation()
478 .deref_mut(ctx)
479 .attributes
480 .set(ATTR_KEY_LLVM_ALIGNMENT.clone(), AlignmentAttr(alignment));
481 }
482
483 fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
484 where
485 Self: Sized,
486 {
487 Ok(())
488 }
489}