Skip to main content

pliron_llvm/
function_call_utils.rs

1//! Helper functions to call common simple C functions
2
3use pliron::{
4    arg_err,
5    builtin::{
6        op_interfaces::{OneResultInterface, SymbolTableInterface},
7        types::{IntegerType, Signedness},
8    },
9    context::{Context, Ptr},
10    identifier::Identifier,
11    irbuild::inserter::Inserter,
12    result::Result,
13    symbol_table::SymbolTableCollection,
14    r#type::TypeObj,
15    value::Value,
16};
17
18use crate::{
19    op_interfaces::CastOpInterface,
20    ops::{FuncOp, GepIndex, GetElementPtrOp, PtrToIntOp, ZeroOp},
21    types::{FuncType, PointerType, VoidType},
22};
23
24#[derive(Debug, thiserror::Error)]
25pub enum LookupOrInsertFunctionError {
26    #[error("Symbol '{0}' found but is not a function")]
27    SymbolNotFunction(Identifier),
28    #[error("Existing function '{0}' has a different type than the one being inserted")]
29    FunctionTypeMismatch(Identifier),
30}
31
32/// Looks up a function by name in the given symbol table.
33/// If it exists, checks that its type matches the provided type.
34/// If it doesn't exist, inserts a new function with the given name and type.
35pub fn lookup_or_insert_function(
36    ctx: &mut Context,
37    symbol_table_collection: &mut SymbolTableCollection,
38    symbol_table_op: Box<dyn SymbolTableInterface>,
39    name: Identifier,
40    return_type: Ptr<TypeObj>,
41    param_types: Vec<Ptr<TypeObj>>,
42    is_var_arg: bool,
43) -> Result<FuncOp> {
44    let loc = symbol_table_op.loc(ctx);
45    let func_ty = FuncType::get(ctx, return_type, param_types, is_var_arg);
46    let symbol_table = symbol_table_collection.get_symbol_table(ctx, symbol_table_op.clone());
47    if let Some(func) = symbol_table.lookup(&name) {
48        if let Some(func_op) = func.as_any().downcast_ref::<FuncOp>() {
49            let existing_func_ty = func_op.get_type(ctx);
50            if existing_func_ty != func_ty {
51                return arg_err!(loc, LookupOrInsertFunctionError::FunctionTypeMismatch(name));
52            }
53            Ok(*func_op)
54        } else {
55            arg_err!(loc, LookupOrInsertFunctionError::SymbolNotFunction(name))
56        }
57    } else {
58        let func = FuncOp::new(ctx, name.clone(), func_ty);
59        symbol_table.insert(ctx, Box::new(func), None)?;
60        Ok(func)
61    }
62}
63
64/// Get the type used to represet size
65pub fn get_size_type(ctx: &mut Context) -> Ptr<TypeObj> {
66    IntegerType::get(ctx, 64, Signedness::Signless).into()
67}
68
69/// Get a declaration to the `malloc` function,
70/// inserting it if it doesn't already exist in the symbol table.
71pub fn lookup_or_create_malloc_fn(
72    ctx: &mut Context,
73    symbol_table_collection: &mut SymbolTableCollection,
74    symbol_table_op: Box<dyn SymbolTableInterface>,
75) -> Result<FuncOp> {
76    let size_ty = get_size_type(ctx);
77    lookup_or_insert_function(
78        ctx,
79        symbol_table_collection,
80        symbol_table_op,
81        "malloc".try_into().unwrap(),
82        PointerType::get(ctx).into(),
83        vec![size_ty],
84        false,
85    )
86}
87
88/// Get a declaration to the `free` function,
89/// inserting it if it doesn't already exist in the symbol table.
90pub fn lookup_or_create_free_fn(
91    ctx: &mut Context,
92    symbol_table_collection: &mut SymbolTableCollection,
93    symbol_table_op: Box<dyn SymbolTableInterface>,
94) -> Result<FuncOp> {
95    let ptr_ty = PointerType::get(ctx).into();
96    lookup_or_insert_function(
97        ctx,
98        symbol_table_collection,
99        symbol_table_op,
100        "free".try_into().unwrap(),
101        VoidType::get(ctx).into(),
102        vec![ptr_ty],
103        false,
104    )
105}
106
107/// Compute size of a type in bytes
108pub fn compute_type_size_in_bytes(
109    ctx: &mut Context,
110    inserter: &mut dyn Inserter,
111    ty: Ptr<TypeObj>,
112) -> Value {
113    // This is LLVM's expansion for sizeof
114    // (as per a comment in MLIR's `ConvertToLLVMPattern::getSizeInBytes`)
115    //   %0 = getelementptr %ty* null, %sizeType 1
116    //   %1 = ptrtoint %ty* %0 to %sizeType
117    let size_ty = get_size_type(ctx);
118    let pointer_ty = PointerType::get(ctx).into();
119    let zero_op = ZeroOp::new(ctx, pointer_ty);
120    inserter.append_op(ctx, &zero_op);
121    let gep_op = GetElementPtrOp::new(
122        ctx,
123        zero_op.get_result(ctx),
124        vec![GepIndex::Constant(1)],
125        ty,
126    );
127    inserter.append_op(ctx, &gep_op);
128    let ptr_to_int_op = PtrToIntOp::new(ctx, gep_op.get_result(ctx), size_ty);
129    inserter.append_op(ctx, &ptr_to_int_op);
130    ptr_to_int_op.get_result(ctx)
131}
132
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use pliron::{
        builtin::{
            op_interfaces::{
                CallOpCallable, OneResultInterface, SingleBlockRegionInterface, SymbolOpInterface,
            },
            ops::ModuleOp,
            types::FP64Type,
        },
        context::Context,
        init_env_logger_for_tests,
        irbuild::{
            inserter::{IRInserter, Inserter, OpInsertionPoint},
            listener::DummyListener,
        },
        op::{Op, verify_op},
        result::ExpectOk,
    };

    use crate::{
        function_call_utils::{
            compute_type_size_in_bytes, get_size_type, lookup_or_create_free_fn,
            lookup_or_create_malloc_fn,
        },
        llvm_sys::{core::LLVMContext, lljit::LLVMLLJIT, target},
        ops::{CallOp, FuncOp, ReturnOp},
        to_llvm_ir::convert_module,
        types::FuncType,
    };

    /// End-to-end check of the helpers in this module: declares `malloc` and
    /// `free`, builds a `main` that allocates and frees a `double`-sized
    /// buffer and returns that size, compares the generated LLVM IR against a
    /// snapshot, and finally runs `main` through the JIT.
    #[test]
    fn test_malloc_and_free_integration() {
        init_env_logger_for_tests!();
        let mut ctx = Context::new();
        let mut symbol_table_collection = pliron::symbol_table::SymbolTableCollection::new();

        // Create a module
        let module = ModuleOp::new(&mut ctx, "test_module".try_into().unwrap());
        let module_box = Box::new(module);

        // Get malloc function
        let malloc_fn =
            lookup_or_create_malloc_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
                .expect("Failed to create malloc function");

        // Get free function
        let free_fn =
            lookup_or_create_free_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
                .expect("Failed to create free function");

        // Verify both functions were created
        assert_eq!(
            malloc_fn.get_symbol_name(&ctx),
            "malloc".try_into().unwrap()
        );
        assert_eq!(free_fn.get_symbol_name(&ctx), "free".try_into().unwrap());

        // Verify calling them again returns the same functions
        let malloc_fn_2 =
            lookup_or_create_malloc_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
                .expect("Failed to get malloc function again");

        assert!(
            malloc_fn == malloc_fn_2,
            "Expected to get the same malloc function on second lookup"
        );

        // Create a main function
        let return_type = get_size_type(&mut ctx);
        let func_ty = FuncType::get(&mut ctx, return_type, vec![], false);
        let main_fn = FuncOp::new(&mut ctx, "main".try_into().unwrap(), func_ty);
        main_fn
            .get_operation()
            .insert_at_front(module.get_body(&ctx, 0), &ctx);

        // Insert calls to malloc and free in the entry block of main
        let entry = main_fn.get_or_create_entry_block(&mut ctx);
        let mut inserter = IRInserter::<DummyListener>::new(OpInsertionPoint::AtBlockEnd(entry));

        let fp_ty = FP64Type::get(&ctx);
        let fp_ty_size = compute_type_size_in_bytes(&mut ctx, &mut inserter, fp_ty.into());

        let callee = CallOpCallable::Direct(malloc_fn.get_symbol_name(&ctx));
        let callee_ty = malloc_fn.get_type(&ctx);
        let args = vec![fp_ty_size];
        let malloc_call = CallOp::new(&mut ctx, callee, callee_ty, args);
        inserter.append_op(&ctx, &malloc_call);

        let ptr_result = malloc_call.get_result(&ctx);
        let free_callee = CallOpCallable::Direct(free_fn.get_symbol_name(&ctx));
        let free_callee_ty = free_fn.get_type(&ctx);
        let free_args = vec![ptr_result];
        let free_call = CallOp::new(&mut ctx, free_callee, free_callee_ty, free_args);
        inserter.append_op(&ctx, &free_call);

        // `main` returns the computed size so the JIT run below can check it.
        let ret_op = ReturnOp::new(&mut ctx, Some(fp_ty_size));
        inserter.append_op(&ctx, &ret_op);

        verify_op(&module, &ctx).expect_ok(&ctx);

        // Convert to LLVM
        let llvm_ctx = LLVMContext::default();
        let llvm_ir = convert_module(&ctx, &llvm_ctx, module).expect_ok(&ctx);

        expect![[r#"
            ; ModuleID = 'test_module'
            source_filename = "test_module"

            define i64 @main() {
            entry_block2v1:
              %v3 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i32 1) to i64))
              call void @free(ptr %v3)
              ret i64 ptrtoint (ptr getelementptr (double, ptr null, i32 1) to i64)
            }

            declare ptr @malloc(i64)

            declare void @free(ptr)
        "#]]
        .assert_eq(&llvm_ir.to_string());
        llvm_ir.verify().expect("Generated LLVM IR is invalid");

        // Execute the LLVM IR using the JIT and check it runs without errors
        target::initialize_native().expect("Failed to initialize native target for JIT");
        let jit = LLVMLLJIT::new_with_default_builder().expect("Failed to create LLJIT instance");
        jit.add_module(llvm_ir)
            .expect("Failed to add module to JIT");
        let main_addr = jit
            .lookup_symbol("main")
            .expect("Failed to lookup 'main' symbol");
        // SAFETY: `main_addr` is the address the JIT reported for `main`, which
        // the IR above defines as `i64 @main()` with LLVM's default (C) calling
        // convention: no arguments, one 64-bit integer result. The function
        // pointer type must therefore be `extern "C"`; a plain Rust-ABI `fn`
        // pointer would not be guaranteed to match.
        let main_fn = unsafe { std::mem::transmute::<u64, extern "C" fn() -> u64>(main_addr) };
        let fp_ty_size = main_fn();
        assert_eq!(
            fp_ty_size, 8,
            "Expected size of double type to be 8 bytes, got {}",
            fp_ty_size
        );
    }
}
273}