Skip to main content

pliron_llvm/llvm_sys/
lljit.rs

//! Safe(r) wrappers around `llvm_sys::orc2::lljit`.
//!
//! ### Example
//! ```
//! use pliron_llvm::llvm_sys::target::initialize_native;
//! use pliron_llvm::llvm_sys::core::{LLVMContext, LLVMModule, LLVMMemoryBuffer};
//! use pliron_llvm::llvm_sys::lljit::{LLVMLLJIT, JITSymbolGenericFlags};
//! fn main() -> Result<(), String> {
//!     initialize_native()?;
//!     // Declared before `jit`, so it is dropped after the JIT and outlives the module.
//!     let context = LLVMContext::default();
//!
//!     // JIT-compiled code calls this through the C calling convention,
//!     // so it must be declared `extern "C"`.
//!     extern "C" fn my_rust_adder(a: i32, b: i32) -> i32 {
//!         a + b
//!     }
//!
//!     let ir = r#"
//!       declare i32 @my_rust_adder(i32, i32)
//!       define i32 @add(i32 %a, i32 %b) {
//!           %sum = call i32 @my_rust_adder(i32 %a, i32 %b)
//!           ret i32 %sum
//!       }"#;
//!     let ir_mb = LLVMMemoryBuffer::from_str(ir, "test_buffer");
//!     let module = LLVMModule::from_ir_in_memory_buffer(&context, ir_mb)?;
//!
//!     let jit = LLVMLLJIT::new_with_default_builder()?;
//!     jit.add_module(module)?;
//!     // Add the Rust function as a symbol mapping.
//!     let rust_adder_addr = my_rust_adder as *const () as u64;
//!     jit.add_symbol_mapping(
//!         "my_rust_adder",
//!         rust_adder_addr,
//!         JITSymbolGenericFlags::JITSymbolGenericFlagsCallable
//!             | JITSymbolGenericFlags::JITSymbolGenericFlagsExported,
//!     )?;
//!
//!     // Get the symbol address for 'add' in the LLVM module.
//!     let symbol_addr = jit.lookup_symbol("add")?;
//!     assert!(symbol_addr != 0);
//!
//!     // The JIT'd function uses the C ABI, so call it through an `extern "C"` pointer.
//!     let adder = unsafe {
//!         std::mem::transmute::<u64, extern "C" fn(i32, i32) -> i32>(symbol_addr)
//!     };
//!     assert_eq!(adder(2, 3), 5);
//!     Ok(())
//! }
//! ```
43
44use bitflags::bitflags;
45use std::{mem::MaybeUninit, ptr};
46
47use llvm_sys::orc2::{
48    LLVMJITEvaluatedSymbol, LLVMJITSymbolFlags, LLVMJITSymbolGenericFlags, LLVMOrcAbsoluteSymbols,
49    LLVMOrcCSymbolMapPair, LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule,
50    LLVMOrcDisposeThreadSafeContext, LLVMOrcDisposeThreadSafeModule, lljit,
51};
52
53use crate::llvm_sys::{
54    core::{LLVMModule, handle_err},
55    cstr_to_string, to_c_str,
56};
57
58bitflags! {
59     #[derive(PartialEq, Eq, Clone, Debug, Hash, Copy)]
60    pub struct JITSymbolGenericFlags: u8 {
61        const JITSymbolGenericFlagsNone = 0;
62        const JITSymbolGenericFlagsExported = 1;
63        const JITSymbolGenericFlagsWeak = 2;
64        const JITSymbolGenericFlagsCallable = 4;
65        const JITSymbolGenericFlagsMaterializationSideEffectsOnly = 8;
66    }
67}
68
69impl From<LLVMJITSymbolGenericFlags> for JITSymbolGenericFlags {
70    fn from(value: LLVMJITSymbolGenericFlags) -> Self {
71        let mut flags = JITSymbolGenericFlags::empty();
72        if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsExported as u8) != 0
73        {
74            flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsExported;
75        }
76        if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsWeak as u8) != 0 {
77            flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsWeak;
78        }
79        if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsCallable as u8) != 0
80        {
81            flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsCallable;
82        }
83        if (value as u8)
84            & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsMaterializationSideEffectsOnly
85                as u8)
86            != 0
87        {
88            flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsMaterializationSideEffectsOnly;
89        }
90        flags
91    }
92}
93
94impl From<JITSymbolGenericFlags> for u8 {
95    fn from(value: JITSymbolGenericFlags) -> Self {
96        let mut flags = LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsNone as u8;
97        if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsExported) {
98            flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsExported as u8;
99        }
100        if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsWeak) {
101            flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsWeak as u8;
102        }
103        if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsCallable) {
104            flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsCallable as u8;
105        }
106        if value
107            .contains(JITSymbolGenericFlags::JITSymbolGenericFlagsMaterializationSideEffectsOnly)
108        {
109            flags |=
110                LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsMaterializationSideEffectsOnly
111                    as u8;
112        }
113        flags
114    }
115}
116
117pub struct LLVMLLJIT(lljit::LLVMOrcLLJITRef);
118
119impl LLVMLLJIT {
120    /// Create a new LLJIT instance with default settings.
121    pub fn new_with_default_builder() -> Result<Self, String> {
122        unsafe {
123            let mut jit = MaybeUninit::uninit();
124            let err = lljit::LLVMOrcCreateLLJIT(jit.as_mut_ptr(), ptr::null_mut());
125            handle_err(err)?;
126            Ok(LLVMLLJIT(jit.assume_init()))
127        }
128    }
129
130    /// Add an [LLVMModule] to the JIT's main JITDylib, in its own thread-safe context.
131    pub fn add_module(&self, module: LLVMModule) -> Result<(), String> {
132        unsafe {
133            let tsctx = LLVMOrcCreateNewThreadSafeContext();
134            let tsm = LLVMOrcCreateNewThreadSafeModule(module.inner_ref(), tsctx);
135            let main_jd = lljit::LLVMOrcLLJITGetMainJITDylib(self.0);
136            let err = lljit::LLVMOrcLLJITAddLLVMIRModule(self.0, main_jd, tsm);
137            // The underlying LLVMContext will be kept alive by our ThreadSafeModule
138            // (See OrcV2CBindingsBasicUsage.c)
139            LLVMOrcDisposeThreadSafeContext(tsctx);
140            // Ownership of the module has been transferred to the JIT
141            std::mem::forget(module);
142            handle_err(err).inspect_err(|_| {
143                // Dispose of the ThreadSafeModule on error
144                LLVMOrcDisposeThreadSafeModule(tsm);
145            })
146        }
147    }
148
149    /// Lookup a symbol in the JIT.
150    pub fn lookup_symbol(&self, name: &str) -> Result<u64, String> {
151        unsafe {
152            let mut addr = MaybeUninit::uninit();
153            let err = lljit::LLVMOrcLLJITLookup(self.0, addr.as_mut_ptr(), to_c_str(name).as_ptr());
154            handle_err(err)?;
155            Ok(addr.assume_init())
156        }
157    }
158
159    /// Get the target triple string for this JIT instance.
160    pub fn get_triple_string(&self) -> String {
161        unsafe {
162            let triple_ptr = lljit::LLVMOrcLLJITGetTripleString(self.0);
163            cstr_to_string(triple_ptr).unwrap()
164        }
165    }
166
167    /// Add a symbol mapping to the JIT's main DyLib
168    pub fn add_symbol_mapping(
169        &self,
170        name: &str,
171        addr: u64,
172        flags: JITSymbolGenericFlags,
173    ) -> Result<(), String> {
174        let symbol_pool_ref =
175            unsafe { lljit::LLVMOrcLLJITMangleAndIntern(self.0, to_c_str(name).as_ptr()) };
176
177        let jit_evaluated_symbol = LLVMJITEvaluatedSymbol {
178            Address: addr,
179            Flags: LLVMJITSymbolFlags {
180                GenericFlags: flags.into(),
181                TargetFlags: 0,
182            },
183        };
184
185        let mut symbol_pair = LLVMOrcCSymbolMapPair {
186            Name: symbol_pool_ref,
187            Sym: jit_evaluated_symbol,
188        };
189
190        let materialization_unit = unsafe { LLVMOrcAbsoluteSymbols(&mut symbol_pair as *mut _, 1) };
191        let main_dylib = unsafe { lljit::LLVMOrcLLJITGetMainJITDylib(self.0) };
192
193        let res =
194            unsafe { llvm_sys::orc2::LLVMOrcJITDylibDefine(main_dylib, materialization_unit) };
195        handle_err(res)
196    }
197}
198
199impl Drop for LLVMLLJIT {
200    fn drop(&mut self) {
201        unsafe {
202            let err = lljit::LLVMOrcDisposeLLJIT(self.0);
203            if let Err(err) = handle_err(err) {
204                panic!("Error disposing LLJIT: {}", err);
205            }
206        }
207    }
208}