feat: first working version of new allocator
This commit is contained in:
@@ -6,6 +6,7 @@ use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use syn::parse_macro_input;
|
||||
use syn::spanned::Spanned;
|
||||
|
||||
fn add_line_numbers(input: String) -> String {
|
||||
return input
|
||||
@@ -16,86 +17,145 @@ fn add_line_numbers(input: String) -> String {
|
||||
.join("\n");
|
||||
}
|
||||
|
||||
fn read_node_definition(file_path: &Path) -> NodeDefinition {
|
||||
let project_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
|
||||
let full_path = Path::new(&project_dir).join(file_path);
|
||||
let json_content = fs::read_to_string(&full_path).unwrap_or_else(|err| {
|
||||
panic!(
|
||||
"Failed to read JSON file at '{}/{}': {}",
|
||||
project_dir,
|
||||
file_path.to_string_lossy(),
|
||||
err
|
||||
)
|
||||
});
|
||||
serde_json::from_str(&json_content).unwrap_or_else(|err| {
|
||||
panic!(
|
||||
"JSON file contains invalid JSON: \n{} \n{}",
|
||||
err,
|
||||
add_line_numbers(json_content.clone())
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn nodarium_execute(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let input_fn = parse_macro_input!(item as syn::ItemFn);
|
||||
let _fn_name = &input_fn.sig.ident;
|
||||
let _fn_vis = &input_fn.vis;
|
||||
let fn_name = &input_fn.sig.ident;
|
||||
let fn_vis = &input_fn.vis;
|
||||
let fn_body = &input_fn.block;
|
||||
let inner_fn_name = syn::Ident::new(&format!("__nodarium_inner_{}", fn_name), fn_name.span());
|
||||
|
||||
let first_arg_ident = if let Some(syn::FnArg::Typed(pat_type)) = input_fn.sig.inputs.first() {
|
||||
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
|
||||
&pat_ident.ident
|
||||
} else {
|
||||
panic!("Expected a simple identifier for the first argument");
|
||||
}
|
||||
} else {
|
||||
panic!("The execute function must have at least one argument (the input slice)");
|
||||
};
|
||||
let def: NodeDefinition = read_node_definition(Path::new("src/input.json"));
|
||||
|
||||
let input_count = def.inputs.as_ref().map(|i| i.len()).unwrap_or(0);
|
||||
|
||||
validate_signature(&input_fn.sig, input_count, &def);
|
||||
|
||||
let input_param_names: Vec<_> = input_fn
|
||||
.sig
|
||||
.inputs
|
||||
.iter()
|
||||
.filter_map(|arg| {
|
||||
if let syn::FnArg::Typed(pat_type) = arg {
|
||||
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
|
||||
Some(pat_ident.ident.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let arg_names: Vec<_> = (0..input_count)
|
||||
.map(|i| syn::Ident::new(&format!("arg{}", i), input_fn.sig.span()))
|
||||
.collect();
|
||||
|
||||
// We create a wrapper that handles the C ABI and pointer math
|
||||
let expanded = quote! {
|
||||
extern "C" {
|
||||
fn host_log_panic(ptr: *const u8, len: usize);
|
||||
fn host_log(ptr: *const u8, len: usize);
|
||||
fn __nodarium_log(ptr: *const u8, len: usize);
|
||||
fn __nodarium_log_panic(ptr: *const u8, len: usize);
|
||||
}
|
||||
|
||||
fn setup_panic_hook() {
|
||||
static SET_HOOK: std::sync::Once = std::sync::Once::new();
|
||||
SET_HOOK.call_once(|| {
|
||||
#fn_vis fn #inner_fn_name(#( #input_param_names: *const i32 ),*) -> Vec<i32> {
|
||||
#fn_body
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
#fn_vis extern "C" fn execute(output_pos: i32, #( #arg_names: i32 ),*) -> i32 {
|
||||
static PANIC_HOOK_SET: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
||||
|
||||
if !PANIC_HOOK_SET.load(std::sync::atomic::Ordering::SeqCst) {
|
||||
std::panic::set_hook(Box::new(|info| {
|
||||
let msg = info.to_string();
|
||||
unsafe { host_log_panic(msg.as_ptr(), msg.len()); }
|
||||
unsafe { __nodarium_log_panic(msg.as_ptr(), msg.len()); }
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn __alloc(len: usize) -> *mut i32 {
|
||||
let mut buf = Vec::with_capacity(len);
|
||||
let ptr = buf.as_mut_ptr();
|
||||
std::mem::forget(buf);
|
||||
ptr
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn __free(ptr: *mut i32, len: usize) {
|
||||
unsafe {
|
||||
let _ = Vec::from_raw_parts(ptr, 0, len);
|
||||
PANIC_HOOK_SET.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
static mut OUTPUT_BUFFER: Vec<i32> = Vec::new();
|
||||
let result = #inner_fn_name(
|
||||
#( #arg_names as *const i32 ),*
|
||||
);
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn execute(ptr: *const i32, len: usize) -> *mut i32 {
|
||||
setup_panic_hook();
|
||||
// 1. Convert raw pointer to slice
|
||||
let input = unsafe { core::slice::from_raw_parts(ptr, len) };
|
||||
|
||||
// 2. Call the logic (which we define below)
|
||||
let result_data: Vec<i32> = internal_logic(input);
|
||||
|
||||
// 3. Use the static buffer for the result
|
||||
let result_len = result_data.len();
|
||||
let len_bytes = result.len() * 4;
|
||||
unsafe {
|
||||
OUTPUT_BUFFER.clear();
|
||||
OUTPUT_BUFFER.reserve(result_len + 1);
|
||||
OUTPUT_BUFFER.push(result_len as i32);
|
||||
OUTPUT_BUFFER.extend(result_data);
|
||||
|
||||
OUTPUT_BUFFER.as_mut_ptr()
|
||||
let src = result.as_ptr() as *const u8;
|
||||
let dst = output_pos as *mut u8;
|
||||
dst.copy_from_nonoverlapping(src, len_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
fn internal_logic(#first_arg_ident: &[i32]) -> Vec<i32> {
|
||||
#fn_body
|
||||
core::mem::forget(result);
|
||||
|
||||
len_bytes as i32
|
||||
}
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
fn validate_signature(fn_sig: &syn::Signature, expected_inputs: usize, def: &NodeDefinition) {
|
||||
let param_count = fn_sig.inputs.len();
|
||||
if param_count != expected_inputs {
|
||||
panic!(
|
||||
"Execute function has {} parameters but definition has {} inputs\n\
|
||||
Definition inputs: {:?}\n\
|
||||
Expected signature:\n\
|
||||
pub fn execute({}) -> Vec<i32>",
|
||||
param_count,
|
||||
expected_inputs,
|
||||
def.inputs
|
||||
.as_ref()
|
||||
.map(|i| i.keys().collect::<Vec<_>>())
|
||||
.unwrap_or_default(),
|
||||
(0..expected_inputs)
|
||||
.map(|i| format!("arg{}: *const i32", i))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
);
|
||||
}
|
||||
|
||||
match &fn_sig.output {
|
||||
syn::ReturnType::Type(_, ty) => {
|
||||
let is_vec = match &**ty {
|
||||
syn::Type::Path(tp) => tp
|
||||
.path
|
||||
.segments
|
||||
.first()
|
||||
.map(|seg| seg.ident == "Vec")
|
||||
.unwrap_or(false),
|
||||
_ => false,
|
||||
};
|
||||
if !is_vec {
|
||||
panic!("Execute function must return Vec<i32>");
|
||||
}
|
||||
}
|
||||
syn::ReturnType::Default => {
|
||||
panic!("Execute function must return Vec<i32>");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn nodarium_definition_file(input: TokenStream) -> TokenStream {
|
||||
let path_lit = syn::parse_macro_input!(input as syn::LitStr);
|
||||
@@ -105,30 +165,26 @@ pub fn nodarium_definition_file(input: TokenStream) -> TokenStream {
|
||||
let full_path = Path::new(&project_dir).join(&file_path);
|
||||
|
||||
let json_content = fs::read_to_string(&full_path).unwrap_or_else(|err| {
|
||||
panic!("Failed to read JSON file at '{}/{}': {}", project_dir, file_path, err)
|
||||
panic!(
|
||||
"Failed to read JSON file at '{}/{}': {}",
|
||||
project_dir, file_path, err
|
||||
)
|
||||
});
|
||||
|
||||
let _: NodeDefinition = serde_json::from_str(&json_content).unwrap_or_else(|err| {
|
||||
panic!("JSON file contains invalid JSON: \n{} \n{}", err, add_line_numbers(json_content.clone()))
|
||||
panic!(
|
||||
"JSON file contains invalid JSON: \n{} \n{}",
|
||||
err,
|
||||
add_line_numbers(json_content.clone())
|
||||
)
|
||||
});
|
||||
|
||||
// We use the span from the input path literal
|
||||
let bytes = syn::LitByteStr::new(json_content.as_bytes(), path_lit.span());
|
||||
let len = json_content.len();
|
||||
|
||||
let expanded = quote! {
|
||||
#[link_section = "nodarium_definition"]
|
||||
static DEFINITION_DATA: [u8; #len] = *#bytes;
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn get_definition_ptr() -> *const u8 {
|
||||
DEFINITION_DATA.as_ptr()
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn get_definition_len() -> usize {
|
||||
DEFINITION_DATA.len()
|
||||
}
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
|
||||
Reference in New Issue
Block a user