1use pliron::{
4 arg_err,
5 builtin::{
6 op_interfaces::{OneResultInterface, SymbolTableInterface},
7 types::{IntegerType, Signedness},
8 },
9 context::{Context, Ptr},
10 identifier::Identifier,
11 irbuild::inserter::Inserter,
12 result::Result,
13 symbol_table::SymbolTableCollection,
14 r#type::TypeObj,
15 value::Value,
16};
17
18use crate::{
19 op_interfaces::CastOpInterface,
20 ops::{FuncOp, GepIndex, GetElementPtrOp, PtrToIntOp, ZeroOp},
21 types::{FuncType, PointerType, VoidType},
22};
23
/// Errors reported by [`lookup_or_insert_function`] when the requested symbol
/// already exists in the symbol table but cannot be reused.
#[derive(Debug, thiserror::Error)]
pub enum LookupOrInsertFunctionError {
    /// The symbol exists, but the operation it names is not a [`FuncOp`].
    /// Carries the offending symbol name.
    #[error("Symbol '{0}' found but is not a function")]
    SymbolNotFunction(Identifier),
    /// A [`FuncOp`] with this name exists, but its type differs from the
    /// requested function type. Carries the offending symbol name.
    #[error("Existing function '{0}' has a different type than the one being inserted")]
    FunctionTypeMismatch(Identifier),
}
31
32pub fn lookup_or_insert_function(
36 ctx: &mut Context,
37 symbol_table_collection: &mut SymbolTableCollection,
38 symbol_table_op: Box<dyn SymbolTableInterface>,
39 name: Identifier,
40 return_type: Ptr<TypeObj>,
41 param_types: Vec<Ptr<TypeObj>>,
42 is_var_arg: bool,
43) -> Result<FuncOp> {
44 let loc = symbol_table_op.loc(ctx);
45 let func_ty = FuncType::get(ctx, return_type, param_types, is_var_arg);
46 let symbol_table = symbol_table_collection.get_symbol_table(ctx, symbol_table_op.clone());
47 if let Some(func) = symbol_table.lookup(&name) {
48 if let Some(func_op) = func.as_any().downcast_ref::<FuncOp>() {
49 let existing_func_ty = func_op.get_type(ctx);
50 if existing_func_ty != func_ty {
51 return arg_err!(loc, LookupOrInsertFunctionError::FunctionTypeMismatch(name));
52 }
53 Ok(*func_op)
54 } else {
55 arg_err!(loc, LookupOrInsertFunctionError::SymbolNotFunction(name))
56 }
57 } else {
58 let func = FuncOp::new(ctx, name.clone(), func_ty);
59 symbol_table.insert(ctx, Box::new(func), None)?;
60 Ok(func)
61 }
62}
63
64pub fn get_size_type(ctx: &mut Context) -> Ptr<TypeObj> {
66 IntegerType::get(ctx, 64, Signedness::Signless).into()
67}
68
69pub fn lookup_or_create_malloc_fn(
72 ctx: &mut Context,
73 symbol_table_collection: &mut SymbolTableCollection,
74 symbol_table_op: Box<dyn SymbolTableInterface>,
75) -> Result<FuncOp> {
76 let size_ty = get_size_type(ctx);
77 lookup_or_insert_function(
78 ctx,
79 symbol_table_collection,
80 symbol_table_op,
81 "malloc".try_into().unwrap(),
82 PointerType::get(ctx).into(),
83 vec![size_ty],
84 false,
85 )
86}
87
88pub fn lookup_or_create_free_fn(
91 ctx: &mut Context,
92 symbol_table_collection: &mut SymbolTableCollection,
93 symbol_table_op: Box<dyn SymbolTableInterface>,
94) -> Result<FuncOp> {
95 let ptr_ty = PointerType::get(ctx).into();
96 lookup_or_insert_function(
97 ctx,
98 symbol_table_collection,
99 symbol_table_op,
100 "free".try_into().unwrap(),
101 VoidType::get(ctx).into(),
102 vec![ptr_ty],
103 false,
104 )
105}
106
107pub fn compute_type_size_in_bytes(
109 ctx: &mut Context,
110 inserter: &mut dyn Inserter,
111 ty: Ptr<TypeObj>,
112) -> Value {
113 let size_ty = get_size_type(ctx);
118 let pointer_ty = PointerType::get(ctx).into();
119 let zero_op = ZeroOp::new(ctx, pointer_ty);
120 inserter.append_op(ctx, &zero_op);
121 let gep_op = GetElementPtrOp::new(
122 ctx,
123 zero_op.get_result(ctx),
124 vec![GepIndex::Constant(1)],
125 ty,
126 );
127 inserter.append_op(ctx, &gep_op);
128 let ptr_to_int_op = PtrToIntOp::new(ctx, gep_op.get_result(ctx), size_ty);
129 inserter.append_op(ctx, &ptr_to_int_op);
130 ptr_to_int_op.get_result(ctx)
131}
132
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use pliron::{
        builtin::{
            op_interfaces::{
                CallOpCallable, OneResultInterface, SingleBlockRegionInterface, SymbolOpInterface,
            },
            ops::ModuleOp,
            types::FP64Type,
        },
        context::Context,
        init_env_logger_for_tests,
        irbuild::{
            inserter::{IRInserter, Inserter, OpInsertionPoint},
            listener::DummyListener,
        },
        op::{Op, verify_op},
        result::ExpectOk,
    };

    use crate::{
        function_call_utils::{
            compute_type_size_in_bytes, get_size_type, lookup_or_create_free_fn,
            lookup_or_create_malloc_fn,
        },
        llvm_sys::{core::LLVMContext, lljit::LLVMLLJIT, target},
        ops::{CallOp, FuncOp, ReturnOp},
        to_llvm_ir::convert_module,
        types::FuncType,
    };

    /// End-to-end check of the malloc/free helpers and the size computation:
    /// builds a module whose `main` mallocs and frees a `double`-sized buffer
    /// and returns the computed size, compares the generated LLVM IR against a
    /// snapshot, then JIT-executes `main` and checks the returned size.
    #[test]
    fn test_malloc_and_free_integration() {
        init_env_logger_for_tests!();
        let mut ctx = Context::new();
        let mut symbol_table_collection = pliron::symbol_table::SymbolTableCollection::new();

        let module = ModuleOp::new(&mut ctx, "test_module".try_into().unwrap());
        let module_box = Box::new(module);

        let malloc_fn =
            lookup_or_create_malloc_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
                .expect("Failed to create malloc function");

        let free_fn =
            lookup_or_create_free_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
                .expect("Failed to create free function");

        assert_eq!(
            malloc_fn.get_symbol_name(&ctx),
            "malloc".try_into().unwrap()
        );
        assert_eq!(free_fn.get_symbol_name(&ctx), "free".try_into().unwrap());

        // A second lookup must return the already-inserted function.
        let malloc_fn_2 =
            lookup_or_create_malloc_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
                .expect("Failed to get malloc function again");

        assert!(
            malloc_fn == malloc_fn_2,
            "Expected to get the same malloc function on second lookup"
        );

        // Build `main() -> size type` at the front of the module body.
        let return_type = get_size_type(&mut ctx);
        let func_ty = FuncType::get(&mut ctx, return_type, vec![], false);
        let main_fn = FuncOp::new(&mut ctx, "main".try_into().unwrap(), func_ty);
        main_fn
            .get_operation()
            .insert_at_front(module.get_body(&ctx, 0), &ctx);

        let entry = main_fn.get_or_create_entry_block(&mut ctx);
        let mut inserter = IRInserter::<DummyListener>::new(OpInsertionPoint::AtBlockEnd(entry));

        let fp_ty = FP64Type::get(&ctx);
        let fp_ty_size = compute_type_size_in_bytes(&mut ctx, &mut inserter, fp_ty.into());

        // ptr = malloc(sizeof(double))
        let callee = CallOpCallable::Direct(malloc_fn.get_symbol_name(&ctx));
        let callee_ty = malloc_fn.get_type(&ctx);
        let args = vec![fp_ty_size];
        let malloc_call = CallOp::new(&mut ctx, callee, callee_ty, args);
        inserter.append_op(&ctx, &malloc_call);

        // free(ptr)
        let ptr_result = malloc_call.get_result(&ctx);
        let free_callee = CallOpCallable::Direct(free_fn.get_symbol_name(&ctx));
        let free_callee_ty = free_fn.get_type(&ctx);
        let free_args = vec![ptr_result];
        let free_call = CallOp::new(&mut ctx, free_callee, free_callee_ty, free_args);
        inserter.append_op(&ctx, &free_call);

        // return sizeof(double)
        let ret_op = ReturnOp::new(&mut ctx, Some(fp_ty_size));
        inserter.append_op(&ctx, &ret_op);

        verify_op(&module, &ctx).expect_ok(&ctx);

        let llvm_ctx = LLVMContext::default();
        let llvm_ir = convert_module(&ctx, &llvm_ctx, module).expect_ok(&ctx);

        expect![[r#"
            ; ModuleID = 'test_module'
            source_filename = "test_module"

            define i64 @main() {
            entry_block2v1:
              %v3 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i32 1) to i64))
              call void @free(ptr %v3)
              ret i64 ptrtoint (ptr getelementptr (double, ptr null, i32 1) to i64)
            }

            declare ptr @malloc(i64)

            declare void @free(ptr)
        "#]]
        .assert_eq(&llvm_ir.to_string());
        llvm_ir.verify().expect("Generated LLVM IR is invalid");

        target::initialize_native().expect("Failed to initialize native target for JIT");
        let jit = LLVMLLJIT::new_with_default_builder().expect("Failed to create LLJIT instance");
        jit.add_module(llvm_ir)
            .expect("Failed to add module to JIT");
        let main_addr = jit
            .lookup_symbol("main")
            .expect("Failed to lookup 'main' symbol");
        // SAFETY: `main_addr` is the address the JIT returned for the `main`
        // symbol of the module added above, whose IR (asserted by the snapshot)
        // is `define i64 @main()`: no parameters, one 64-bit integer result.
        // LLVM emits it with the C calling convention, so the pointer must be
        // typed `extern "C"` rather than with Rust's unspecified default ABI.
        // `jit` stays alive for the duration of the call.
        let jit_main =
            unsafe { std::mem::transmute::<u64, extern "C" fn() -> u64>(main_addr) };
        let jit_size = jit_main();
        assert_eq!(
            jit_size, 8,
            "Expected size of double type to be 8 bytes, got {}",
            jit_size
        );
    }
}