pliron_llvm/llvm_sys/
lljit.rs1use bitflags::bitflags;
45use std::{mem::MaybeUninit, ptr};
46
47use llvm_sys::orc2::{
48 LLVMJITEvaluatedSymbol, LLVMJITSymbolFlags, LLVMJITSymbolGenericFlags, LLVMOrcAbsoluteSymbols,
49 LLVMOrcCSymbolMapPair, LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule,
50 LLVMOrcDisposeThreadSafeContext, LLVMOrcDisposeThreadSafeModule, lljit,
51};
52
53use crate::llvm_sys::{
54 core::{LLVMModule, handle_err},
55 cstr_to_string, to_c_str,
56};
57
58bitflags! {
59 #[derive(PartialEq, Eq, Clone, Debug, Hash, Copy)]
60 pub struct JITSymbolGenericFlags: u8 {
61 const JITSymbolGenericFlagsNone = 0;
62 const JITSymbolGenericFlagsExported = 1;
63 const JITSymbolGenericFlagsWeak = 2;
64 const JITSymbolGenericFlagsCallable = 4;
65 const JITSymbolGenericFlagsMaterializationSideEffectsOnly = 8;
66 }
67}
68
69impl From<LLVMJITSymbolGenericFlags> for JITSymbolGenericFlags {
70 fn from(value: LLVMJITSymbolGenericFlags) -> Self {
71 let mut flags = JITSymbolGenericFlags::empty();
72 if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsExported as u8) != 0
73 {
74 flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsExported;
75 }
76 if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsWeak as u8) != 0 {
77 flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsWeak;
78 }
79 if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsCallable as u8) != 0
80 {
81 flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsCallable;
82 }
83 if (value as u8)
84 & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsMaterializationSideEffectsOnly
85 as u8)
86 != 0
87 {
88 flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsMaterializationSideEffectsOnly;
89 }
90 flags
91 }
92}
93
94impl From<JITSymbolGenericFlags> for u8 {
95 fn from(value: JITSymbolGenericFlags) -> Self {
96 let mut flags = LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsNone as u8;
97 if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsExported) {
98 flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsExported as u8;
99 }
100 if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsWeak) {
101 flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsWeak as u8;
102 }
103 if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsCallable) {
104 flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsCallable as u8;
105 }
106 if value
107 .contains(JITSymbolGenericFlags::JITSymbolGenericFlagsMaterializationSideEffectsOnly)
108 {
109 flags |=
110 LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsMaterializationSideEffectsOnly
111 as u8;
112 }
113 flags
114 }
115}
116
117pub struct LLVMLLJIT(lljit::LLVMOrcLLJITRef);
118
119impl LLVMLLJIT {
120 pub fn new_with_default_builder() -> Result<Self, String> {
122 unsafe {
123 let mut jit = MaybeUninit::uninit();
124 let err = lljit::LLVMOrcCreateLLJIT(jit.as_mut_ptr(), ptr::null_mut());
125 handle_err(err)?;
126 Ok(LLVMLLJIT(jit.assume_init()))
127 }
128 }
129
130 pub fn add_module(&self, module: LLVMModule) -> Result<(), String> {
132 unsafe {
133 let tsctx = LLVMOrcCreateNewThreadSafeContext();
134 let tsm = LLVMOrcCreateNewThreadSafeModule(module.inner_ref(), tsctx);
135 let main_jd = lljit::LLVMOrcLLJITGetMainJITDylib(self.0);
136 let err = lljit::LLVMOrcLLJITAddLLVMIRModule(self.0, main_jd, tsm);
137 LLVMOrcDisposeThreadSafeContext(tsctx);
140 std::mem::forget(module);
142 handle_err(err).inspect_err(|_| {
143 LLVMOrcDisposeThreadSafeModule(tsm);
145 })
146 }
147 }
148
149 pub fn lookup_symbol(&self, name: &str) -> Result<u64, String> {
151 unsafe {
152 let mut addr = MaybeUninit::uninit();
153 let err = lljit::LLVMOrcLLJITLookup(self.0, addr.as_mut_ptr(), to_c_str(name).as_ptr());
154 handle_err(err)?;
155 Ok(addr.assume_init())
156 }
157 }
158
159 pub fn get_triple_string(&self) -> String {
161 unsafe {
162 let triple_ptr = lljit::LLVMOrcLLJITGetTripleString(self.0);
163 cstr_to_string(triple_ptr).unwrap()
164 }
165 }
166
167 pub fn add_symbol_mapping(
169 &self,
170 name: &str,
171 addr: u64,
172 flags: JITSymbolGenericFlags,
173 ) -> Result<(), String> {
174 let symbol_pool_ref =
175 unsafe { lljit::LLVMOrcLLJITMangleAndIntern(self.0, to_c_str(name).as_ptr()) };
176
177 let jit_evaluated_symbol = LLVMJITEvaluatedSymbol {
178 Address: addr,
179 Flags: LLVMJITSymbolFlags {
180 GenericFlags: flags.into(),
181 TargetFlags: 0,
182 },
183 };
184
185 let mut symbol_pair = LLVMOrcCSymbolMapPair {
186 Name: symbol_pool_ref,
187 Sym: jit_evaluated_symbol,
188 };
189
190 let materialization_unit = unsafe { LLVMOrcAbsoluteSymbols(&mut symbol_pair as *mut _, 1) };
191 let main_dylib = unsafe { lljit::LLVMOrcLLJITGetMainJITDylib(self.0) };
192
193 let res =
194 unsafe { llvm_sys::orc2::LLVMOrcJITDylibDefine(main_dylib, materialization_unit) };
195 handle_err(res)
196 }
197}
198
199impl Drop for LLVMLLJIT {
200 fn drop(&mut self) {
201 unsafe {
202 let err = lljit::LLVMOrcDisposeLLJIT(self.0);
203 if let Err(err) = handle_err(err) {
204 panic!("Error disposing LLJIT: {}", err);
205 }
206 }
207 }
208}