Skip to main content

pliron_llvm/
attributes.rs

1//! Attributes belonging to the LLVM dialect.
2
3use std::fmt::Display;
4
5use pliron::combine::{self, Parser, choice, parser::char::spaces};
6use thiserror::Error;
7
8use pliron::builtin::attributes::IntegerAttr;
9use pliron::common_traits::Verify;
10use pliron::context::Context;
11use pliron::derive::{format, pliron_attr};
12use pliron::location::Located;
13use pliron::parsable::{IntoParseResult, Parsable};
14use pliron::printable::Printable;
15use pliron::result::Result;
16use pliron::{impl_printable_for_display, input_error, verify_err_noloc};
17
18use crate::llvm_sys::core::FastmathFlags;
19
/// Integer overflow flags for arithmetic operations.
///
/// The description below is from LLVM's
/// [release notes](https://releases.llvm.org/2.6/docs/ReleaseNotes.html)
/// that added the flags:
/// "nsw" and "nuw" bits indicate that the operation is guaranteed to not overflow
/// (in the signed or unsigned case, respectively). This gives the optimizer more information
/// and can be used for things like C signed integer values, which are undefined on overflow.
///
/// The textual syntax is generated from the fields by the `format` option of
/// `pliron_attr`. The derived [`Default`] has both flags cleared (`false`).
// NOTE(review): the registered name below contains a typo ("overlflow"). It is left
// unchanged because renaming it would presumably change how this attribute is printed
// and parsed in existing IR text — confirm all users before fixing it.
#[pliron_attr(name = "llvm.integer_overlflow_flags", format, verifier = "succ")]
#[derive(PartialEq, Eq, Clone, Debug, Default, Hash)]
pub struct IntegerOverflowFlagsAttr {
    /// "No signed wrap": the operation is guaranteed not to overflow when treated as signed.
    pub nsw: bool,
    /// "No unsigned wrap": the operation is guaranteed not to overflow when treated as unsigned.
    pub nuw: bool,
}
33
/// Fast-math flags attached to floating-point operations.
///
/// A thin newtype over [`FastmathFlags`]. Unlike the `format`-derived attributes in this
/// module, its textual form is hand-written (see its `Display` and `Parsable` impls):
/// the set flags are printed between angle brackets and separated by `|`, e.g.
/// `<NNAN | ARCP>`, and the empty set prints as `<>`.
#[pliron_attr(name = "llvm.fast_math_flags", verifier = "succ")]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct FastmathFlagsAttr(pub FastmathFlags);
37
38impl Default for FastmathFlagsAttr {
39    fn default() -> Self {
40        FastmathFlagsAttr(FastmathFlags::empty())
41    }
42}
43
44impl From<FastmathFlags> for FastmathFlagsAttr {
45    fn from(value: FastmathFlags) -> Self {
46        FastmathFlagsAttr(value)
47    }
48}
49
50impl From<FastmathFlagsAttr> for FastmathFlags {
51    fn from(attr: FastmathFlagsAttr) -> Self {
52        attr.0
53    }
54}
55
56impl Display for FastmathFlagsAttr {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        write!(f, "<")?;
59        bitflags::parser::to_writer(&self.0, &mut *f)?;
60        write!(f, ">")
61    }
62}
63
64impl_printable_for_display!(FastmathFlagsAttr);
65
/// Error produced when the text inside a fast-math flags attribute's `<...>` is
/// rejected by `bitflags`' flag parser (for example, an unrecognized flag name).
/// Wraps the underlying `bitflags` parse error.
#[derive(Debug, Error)]
#[error("Error parsing fastmath flags: {0}")]
pub struct FastmathFlagParseErr(pub bitflags::parser::ParseError);
69
70impl Parsable for FastmathFlagsAttr {
71    type Arg = ();
72
73    type Parsed = Self;
74
75    fn parse<'a>(
76        state_stream: &mut pliron::parsable::StateStream<'a>,
77        _arg: Self::Arg,
78    ) -> pliron::parsable::ParseResult<'a, Self::Parsed> {
79        let pos = state_stream.loc();
80        let allowed_chars = combine::choice!(
81            combine::parser::char::space().map(|c| c.to_string()),
82            combine::parser::char::alpha_num().map(|c| c.to_string()),
83            combine::parser::char::char('|').map(|c: char| c.to_string())
84        );
85
86        let (parsed, _): (Vec<String>, _) = combine::between(
87            combine::parser::char::char('<').with(spaces()),
88            spaces().with(combine::parser::char::char('>')),
89            combine::many(allowed_chars),
90        )
91        .parse_stream(state_stream)
92        .into_result()?;
93        let parsed_string = parsed.concat();
94
95        let (fast_math_flags, _) = bitflags::parser::from_str::<FastmathFlags>(&parsed_string)
96            .map_err(|e| input_error!(pos.clone(), FastmathFlagParseErr(e)))
97            .into_parse_result()?;
98
99        Ok(FastmathFlagsAttr(fast_math_flags)).into_parse_result()
100    }
101}
102
/// Predicate (condition code) of an integer comparison.
///
/// The variants correspond to the condition codes of LLVM's `icmp` instruction:
/// `S*` variants compare the operands as signed integers, `U*` variants as unsigned.
#[pliron_attr(name = "llvm.icmp_predicate", verifier = "succ", format)]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub enum ICmpPredicateAttr {
    /// Equal.
    EQ,
    /// Not equal.
    NE,
    /// Signed less than.
    SLT,
    /// Signed less than or equal.
    SLE,
    /// Signed greater than.
    SGT,
    /// Signed greater than or equal.
    SGE,
    /// Unsigned less than.
    ULT,
    /// Unsigned less than or equal.
    ULE,
    /// Unsigned greater than.
    UGT,
    /// Unsigned greater than or equal.
    UGE,
}
117
/// Predicate (condition code) of a floating-point comparison.
///
/// The variants correspond to the condition codes of LLVM's `fcmp` instruction:
/// `O*` ("ordered") predicates are false if either operand is NaN, while `U*`
/// ("unordered") predicates are true if either operand is NaN.
#[pliron_attr(name = "llvm.fcmp_predicate", format, verifier = "succ")]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub enum FCmpPredicateAttr {
    /// Always false (no comparison is performed).
    False,
    /// Ordered and equal.
    OEQ,
    /// Ordered and greater than.
    OGT,
    /// Ordered and greater than or equal.
    OGE,
    /// Ordered and less than.
    OLT,
    /// Ordered and less than or equal.
    OLE,
    /// Ordered and not equal.
    ONE,
    /// Ordered: neither operand is NaN.
    ORD,
    /// Unordered or equal.
    UEQ,
    /// Unordered or greater than.
    UGT,
    /// Unordered or greater than or equal.
    UGE,
    /// Unordered or less than.
    ULT,
    /// Unordered or less than or equal.
    ULE,
    /// Unordered or not equal.
    UNE,
    /// Unordered: at least one operand is NaN.
    UNO,
    /// Always true (no comparison is performed).
    True,
}
138
/// An index for a GEP can be either a constant or an SSA operand.
/// Contrary to its name, this isn't an [Attribute][pliron::attribute::Attribute];
/// it is the element type of [`GepIndicesAttr`], and its textual syntax is
/// generated by `#[format]`.
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
#[format]
pub enum GepIndexAttr {
    /// This GEP index is a raw u32 compile-time constant.
    Constant(u32),
    /// This GEP index is the SSA value found at `operands[idx]` of the containing
    /// [Operation](pliron::operation::Operation).
    OperandIdx(usize),
}
150
/// The full index list of a GEP. Each entry is either a constant or a reference to
/// one of the operation's operands (see [`GepIndexAttr`]).
///
/// Written in text as a bracketed, comma-separated list of [`GepIndexAttr`]s.
#[pliron_attr(
    name = "llvm.gep_indices",
    format = "`[` vec($0, CharSpace(`,`)) `]`",
    verifier = "succ"
)]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub struct GepIndicesAttr(pub Vec<GepIndexAttr>);
158
/// An attribute that contains a list of case values for a switch operation.
///
/// Written in text as a bracketed, comma-separated list of [`IntegerAttr`]s.
/// Unlike most attributes in this module it has a hand-written `Verify` impl
/// rather than `verifier = "succ"`. That verifier requires all case values to
/// share the same integer type.
#[pliron_attr(name = "llvm.case_values", format = "`[` vec($0, CharSpace(`,`)) `]`")]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub struct CaseValuesAttr(pub Vec<IntegerAttr>);
163
/// Verification error for [`CaseValuesAttr`]: two case values have different types.
/// The two fields hold the mismatching types, already rendered as strings.
#[derive(Debug, Error)]
#[error("Case values must be of the same type, but found different types: {0} and {1}")]
pub struct CaseValuesAttrVerifyErr(pub String, pub String);
167
168impl Verify for CaseValuesAttr {
169    fn verify(&self, ctx: &Context) -> Result<()> {
170        self.0.windows(2).try_for_each(|pair| {
171            pair[0].verify(ctx)?;
172            if pair[0].get_type() != pair[1].get_type() {
173                verify_err_noloc!(CaseValuesAttrVerifyErr(
174                    pair[0].get_type().disp(ctx).to_string(),
175                    pair[1].get_type().disp(ctx).to_string()
176                ))
177            } else {
178                Ok(())
179            }
180        })
181    }
182}
183
/// Linkage kind of a global value.
///
/// The variants appear to mirror LLVM-C's `LLVMLinkage` enum one-for-one and in
/// the same order.
// NOTE(review): if any conversion to/from LLVM relies on that ordering, keep the two in
// sync. Several of these kinds (e.g. `GhostLinkage`, `DLLImportLinkage`,
// `DLLExportLinkage`, `LinkOnceODRAutoHideLinkage`) are marked obsolete in LLVM-C —
// confirm how they are handled.
#[pliron_attr(name = "llvm.linkage", format, verifier = "succ")]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub enum LinkageAttr {
    ExternalLinkage,
    AvailableExternallyLinkage,
    LinkOnceAnyLinkage,
    LinkOnceODRLinkage,
    LinkOnceODRAutoHideLinkage,
    WeakAnyLinkage,
    WeakODRLinkage,
    AppendingLinkage,
    InternalLinkage,
    PrivateLinkage,
    DLLImportLinkage,
    DLLExportLinkage,
    ExternalWeakLinkage,
    GhostLinkage,
    CommonLinkage,
    LinkerPrivateLinkage,
    LinkerPrivateWeakLinkage,
}
205
/// Constant indices used by `insertvalue` / `extractvalue`-style operations.
///
/// Written in text as a bracketed, comma-separated list of unsigned integers.
#[pliron_attr(
    name = "llvm.insert_extract_value_indices",
    format = "`[` vec($0, CharSpace(`,`)) `]`",
    verifier = "succ"
)]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub struct InsertExtractValueIndicesAttr(pub Vec<u32>);
213
/// An alignment value, written in text as the bare number.
// NOTE(review): the unit (presumably bytes, as in LLVM's `align N`) and whether 0 means
// "unspecified" are not established here — confirm against the users of this attribute.
#[pliron_attr(name = "llvm.align", format = "$0", verifier = "succ")]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub struct AlignmentAttr(pub u32);
217
/// The lane-selection mask of a vector shuffle.
///
/// Written in text as a bracketed, comma-separated list of integers.
// NOTE(review): the element type is signed, presumably so that a negative value (LLVM
// uses -1) can mark an undefined/poison result lane — confirm.
#[pliron_attr(
    name = "llvm.shuffle_vector_mask",
    format = "`[` vec($0, CharSpace(`,`)) `]`",
    verifier = "succ"
)]
#[derive(PartialEq, Eq, Clone, Debug, Hash)]
pub struct ShuffleVectorMaskAttr(pub Vec<i32>);
225
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use pliron::{
        location,
        parsable::{self, state_stream_from_iterator},
    };

    use super::*;

    #[test]
    fn test_fastmath_flags_attr_empty() {
        let flags = FastmathFlags::empty();
        assert_eq!(flags.bits(), 0);

        let ctx = &mut Context::default();
        let flags_attr: FastmathFlagsAttr = flags.into();
        expect!["<>"].assert_eq(&flags_attr.disp(ctx).to_string());

        let input = "<>";
        let mut state_stream = state_stream_from_iterator(
            input.chars(),
            parsable::State::new(ctx, location::Source::InMemory),
        );
        let (parsed, _) = FastmathFlagsAttr::parse(&mut state_stream, ()).unwrap();
        assert_eq!(parsed, flags_attr);
    }

    #[test]
    fn test_fastmath_flags_attr_set_flags() {
        let mut flags = FastmathFlags::empty();
        flags |= FastmathFlags::NNAN | FastmathFlags::NINF;
        assert!(flags.contains(FastmathFlags::NNAN));
        assert!(flags.contains(FastmathFlags::NINF));
        assert!(!flags.contains(FastmathFlags::NSZ));
    }

    #[test]
    fn test_fastmath_flags_attr_fmt() {
        let ctx = &Context::default();
        let flags: FastmathFlagsAttr = (FastmathFlags::NNAN | FastmathFlags::ARCP).into();
        expect!["<NNAN | ARCP>"].assert_eq(&flags.disp(ctx).to_string());
    }

    #[test]
    fn test_fastmath_flags_attr_fmt_fast() {
        let ctx = &Context::default();
        let flags: FastmathFlagsAttr = FastmathFlags::FAST.into();
        expect!["<NNAN | NINF | NSZ | ARCP | CONTRACT | AFN | REASSOC>"]
            .assert_eq(&flags.disp(ctx).to_string());
    }

    #[test]
    fn test_fastmath_flags_attr_parse_valid() {
        let ctx = &mut Context::default();

        let input = "<NNAN | ARCP>";
        let mut state_stream = state_stream_from_iterator(
            input.chars(),
            parsable::State::new(ctx, location::Source::InMemory),
        );
        let (parsed, _) = FastmathFlagsAttr::parse(&mut state_stream, ()).unwrap();
        assert!(parsed.0.contains(FastmathFlags::NNAN));
        assert!(parsed.0.contains(FastmathFlags::ARCP));
        // Exactly those two flags, nothing else.
        assert_eq!(parsed.0, FastmathFlags::NNAN | FastmathFlags::ARCP);
    }

    // Test input with FAST flag set
    #[test]
    fn test_fastmath_flags_attr_parse_fast() {
        let ctx = &mut Context::default();

        let input = "<FAST>";
        let mut state_stream = state_stream_from_iterator(
            input.chars(),
            parsable::State::new(ctx, location::Source::InMemory),
        );
        let (parsed, _) = FastmathFlagsAttr::parse(&mut state_stream, ()).unwrap();
        assert!(parsed.0.contains(FastmathFlags::FAST));

        // FAST also means all the other flags.
        assert!(parsed.0.contains(FastmathFlags::NNAN));
        assert!(parsed.0.contains(FastmathFlags::NINF));
        assert!(parsed.0.contains(FastmathFlags::NSZ));
        assert!(parsed.0.contains(FastmathFlags::ARCP));
        assert!(parsed.0.contains(FastmathFlags::CONTRACT));
        assert!(parsed.0.contains(FastmathFlags::AFN));
        assert!(parsed.0.contains(FastmathFlags::REASSOC));
    }

    // Whitespace inside the brackets and around `|` is optional.
    #[test]
    fn test_fastmath_flags_attr_parse_whitespace() {
        let ctx = &mut Context::default();

        let input = "< NNAN|ARCP >";
        let mut state_stream = state_stream_from_iterator(
            input.chars(),
            parsable::State::new(ctx, location::Source::InMemory),
        );
        let (parsed, _) = FastmathFlagsAttr::parse(&mut state_stream, ()).unwrap();
        assert_eq!(parsed.0, FastmathFlags::NNAN | FastmathFlags::ARCP);
    }

    // Printing an attribute and parsing the result must give back the same attribute.
    #[test]
    fn test_fastmath_flags_attr_roundtrip() {
        for flags in [
            FastmathFlags::empty(),
            FastmathFlags::NSZ,
            FastmathFlags::NNAN | FastmathFlags::ARCP,
            FastmathFlags::FAST,
        ] {
            let ctx = &mut Context::default();
            let attr: FastmathFlagsAttr = flags.into();
            let printed = attr.disp(ctx).to_string();
            let mut state_stream = state_stream_from_iterator(
                printed.chars(),
                parsable::State::new(ctx, location::Source::InMemory),
            );
            let (parsed, _) = FastmathFlagsAttr::parse(&mut state_stream, ()).unwrap();
            assert_eq!(parsed, attr, "round-trip failed for {printed}");
        }
    }

    #[test]
    fn test_fastmath_flags_attr_parse_invalid() {
        let ctx = &mut Context::default();
        let input = "<INVALIDFLAG>";
        let state_stream = state_stream_from_iterator(
            input.chars(),
            parsable::State::new(ctx, location::Source::InMemory),
        );
        match FastmathFlagsAttr::parser(()).parse(state_stream) {
            Ok((parsed, _)) => {
                panic!("Expected error, but got: {}", parsed);
            }
            Err(e) => {
                expect![[r#"
                    Parse error at line: 1, column: 1
                    Error parsing fastmath flags: unrecognized named flag `INVALIDFLAG`
                "#]]
                .assert_eq(&e.to_string());
            }
        }
    }
}