feat: first working version of new allocator

This commit is contained in:
Max Richter
2026-01-22 18:48:16 +01:00
parent 841b447ac3
commit 47882a832d
22 changed files with 1668 additions and 273 deletions

View File

@@ -6,6 +6,7 @@ use std::env;
use std::fs;
use std::path::Path;
use syn::parse_macro_input;
use syn::spanned::Spanned;
fn add_line_numbers(input: String) -> String {
return input
@@ -16,86 +17,145 @@ fn add_line_numbers(input: String) -> String {
.join("\n");
}
fn read_node_definition(file_path: &Path) -> NodeDefinition {
let project_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let full_path = Path::new(&project_dir).join(file_path);
let json_content = fs::read_to_string(&full_path).unwrap_or_else(|err| {
panic!(
"Failed to read JSON file at '{}/{}': {}",
project_dir,
file_path.to_string_lossy(),
err
)
});
serde_json::from_str(&json_content).unwrap_or_else(|err| {
panic!(
"JSON file contains invalid JSON: \n{} \n{}",
err,
add_line_numbers(json_content.clone())
)
})
}
#[proc_macro_attribute]
pub fn nodarium_execute(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as syn::ItemFn);
let _fn_name = &input_fn.sig.ident;
let _fn_vis = &input_fn.vis;
let fn_name = &input_fn.sig.ident;
let fn_vis = &input_fn.vis;
let fn_body = &input_fn.block;
let inner_fn_name = syn::Ident::new(&format!("__nodarium_inner_{}", fn_name), fn_name.span());
let first_arg_ident = if let Some(syn::FnArg::Typed(pat_type)) = input_fn.sig.inputs.first() {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
&pat_ident.ident
} else {
panic!("Expected a simple identifier for the first argument");
}
} else {
panic!("The execute function must have at least one argument (the input slice)");
};
let def: NodeDefinition = read_node_definition(Path::new("src/input.json"));
let input_count = def.inputs.as_ref().map(|i| i.len()).unwrap_or(0);
validate_signature(&input_fn.sig, input_count, &def);
let input_param_names: Vec<_> = input_fn
.sig
.inputs
.iter()
.filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
Some(pat_ident.ident.clone())
} else {
None
}
} else {
None
}
})
.collect();
let arg_names: Vec<_> = (0..input_count)
.map(|i| syn::Ident::new(&format!("arg{}", i), input_fn.sig.span()))
.collect();
// We create a wrapper that handles the C ABI and pointer math
let expanded = quote! {
extern "C" {
fn host_log_panic(ptr: *const u8, len: usize);
fn host_log(ptr: *const u8, len: usize);
fn __nodarium_log(ptr: *const u8, len: usize);
fn __nodarium_log_panic(ptr: *const u8, len: usize);
}
fn setup_panic_hook() {
static SET_HOOK: std::sync::Once = std::sync::Once::new();
SET_HOOK.call_once(|| {
#fn_vis fn #inner_fn_name(#( #input_param_names: *const i32 ),*) -> Vec<i32> {
#fn_body
}
#[no_mangle]
#fn_vis extern "C" fn execute(output_pos: i32, #( #arg_names: i32 ),*) -> i32 {
static PANIC_HOOK_SET: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
if !PANIC_HOOK_SET.load(std::sync::atomic::Ordering::SeqCst) {
std::panic::set_hook(Box::new(|info| {
let msg = info.to_string();
unsafe { host_log_panic(msg.as_ptr(), msg.len()); }
unsafe { __nodarium_log_panic(msg.as_ptr(), msg.len()); }
}));
});
}
#[no_mangle]
pub extern "C" fn __alloc(len: usize) -> *mut i32 {
let mut buf = Vec::with_capacity(len);
let ptr = buf.as_mut_ptr();
std::mem::forget(buf);
ptr
}
#[no_mangle]
pub extern "C" fn __free(ptr: *mut i32, len: usize) {
unsafe {
let _ = Vec::from_raw_parts(ptr, 0, len);
PANIC_HOOK_SET.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
static mut OUTPUT_BUFFER: Vec<i32> = Vec::new();
let result = #inner_fn_name(
#( #arg_names as *const i32 ),*
);
#[no_mangle]
pub extern "C" fn execute(ptr: *const i32, len: usize) -> *mut i32 {
setup_panic_hook();
// 1. Convert raw pointer to slice
let input = unsafe { core::slice::from_raw_parts(ptr, len) };
// 2. Call the logic (which we define below)
let result_data: Vec<i32> = internal_logic(input);
// 3. Use the static buffer for the result
let result_len = result_data.len();
let len_bytes = result.len() * 4;
unsafe {
OUTPUT_BUFFER.clear();
OUTPUT_BUFFER.reserve(result_len + 1);
OUTPUT_BUFFER.push(result_len as i32);
OUTPUT_BUFFER.extend(result_data);
OUTPUT_BUFFER.as_mut_ptr()
let src = result.as_ptr() as *const u8;
let dst = output_pos as *mut u8;
dst.copy_from_nonoverlapping(src, len_bytes);
}
}
fn internal_logic(#first_arg_ident: &[i32]) -> Vec<i32> {
#fn_body
core::mem::forget(result);
len_bytes as i32
}
};
TokenStream::from(expanded)
}
fn validate_signature(fn_sig: &syn::Signature, expected_inputs: usize, def: &NodeDefinition) {
let param_count = fn_sig.inputs.len();
if param_count != expected_inputs {
panic!(
"Execute function has {} parameters but definition has {} inputs\n\
Definition inputs: {:?}\n\
Expected signature:\n\
pub fn execute({}) -> Vec<i32>",
param_count,
expected_inputs,
def.inputs
.as_ref()
.map(|i| i.keys().collect::<Vec<_>>())
.unwrap_or_default(),
(0..expected_inputs)
.map(|i| format!("arg{}: *const i32", i))
.collect::<Vec<_>>()
.join(", ")
);
}
match &fn_sig.output {
syn::ReturnType::Type(_, ty) => {
let is_vec = match &**ty {
syn::Type::Path(tp) => tp
.path
.segments
.first()
.map(|seg| seg.ident == "Vec")
.unwrap_or(false),
_ => false,
};
if !is_vec {
panic!("Execute function must return Vec<i32>");
}
}
syn::ReturnType::Default => {
panic!("Execute function must return Vec<i32>");
}
}
}
#[proc_macro]
pub fn nodarium_definition_file(input: TokenStream) -> TokenStream {
let path_lit = syn::parse_macro_input!(input as syn::LitStr);
@@ -105,30 +165,26 @@ pub fn nodarium_definition_file(input: TokenStream) -> TokenStream {
let full_path = Path::new(&project_dir).join(&file_path);
let json_content = fs::read_to_string(&full_path).unwrap_or_else(|err| {
panic!("Failed to read JSON file at '{}/{}': {}", project_dir, file_path, err)
panic!(
"Failed to read JSON file at '{}/{}': {}",
project_dir, file_path, err
)
});
let _: NodeDefinition = serde_json::from_str(&json_content).unwrap_or_else(|err| {
panic!("JSON file contains invalid JSON: \n{} \n{}", err, add_line_numbers(json_content.clone()))
panic!(
"JSON file contains invalid JSON: \n{} \n{}",
err,
add_line_numbers(json_content.clone())
)
});
// We use the span from the input path literal
let bytes = syn::LitByteStr::new(json_content.as_bytes(), path_lit.span());
let len = json_content.len();
let expanded = quote! {
#[link_section = "nodarium_definition"]
static DEFINITION_DATA: [u8; #len] = *#bytes;
#[no_mangle]
pub extern "C" fn get_definition_ptr() -> *const u8 {
DEFINITION_DATA.as_ptr()
}
#[no_mangle]
pub extern "C" fn get_definition_len() -> usize {
DEFINITION_DATA.len()
}
};
TokenStream::from(expanded)

View File

@@ -13,6 +13,7 @@ log.mute();
export class RemoteNodeRegistry implements NodeRegistry {
status: 'loading' | 'ready' | 'error' = 'loading';
private nodes: Map<string, NodeDefinition> = new Map();
private memory = new WebAssembly.Memory({ initial: 1024, maximum: 8192 });
constructor(
private url: string,
@@ -127,7 +128,10 @@ export class RemoteNodeRegistry implements NodeRegistry {
}
async register(wasmBuffer: ArrayBuffer) {
const wrapper = createWasmWrapper(wasmBuffer);
const wrapper = createWasmWrapper(
wasmBuffer,
this.memory
);
const definition = NodeDefinitionSchema.safeParse(wrapper.get_definition());
@@ -139,10 +143,7 @@ export class RemoteNodeRegistry implements NodeRegistry {
this.cache.set(definition.data.id, wasmBuffer);
}
let node = {
...definition.data,
execute: wrapper.execute
};
let node = { ...definition.data, execute: wrapper.execute };
this.nodes.set(definition.data.id, node);

View File

@@ -65,7 +65,7 @@ export const NodeSchema = z.object({
export type SerializedNode = z.infer<typeof NodeSchema>;
export type NodeDefinition = z.infer<typeof NodeDefinitionSchema> & {
execute(input: Int32Array): Int32Array;
execute(outputPos: number, args: number[]): number;
};
export type Socket = {

View File

@@ -9,6 +9,10 @@ description = "A collection of utilities for Nodarium"
[lib]
crate-type = ["rlib"]
[features]
default = ["std"]
std = []
[dependencies]
glam = "0.30.10"
noise = "0.9.0"

View File

@@ -10,6 +10,55 @@ pub fn decode_float(bits: i32) -> f32 {
f32::from_bits(bits)
}
#[inline]
pub unsafe fn read_i32(ptr: *const i32) -> i32 {
*ptr
}
#[inline]
pub unsafe fn read_f32(ptr: *const i32) -> f32 {
f32::from_bits(*ptr as u32)
}
#[inline]
pub unsafe fn read_bool(ptr: *const i32) -> bool {
*ptr != 0
}
#[inline]
pub unsafe fn read_vec3(ptr: *const i32) -> [f32; 3] {
let p = ptr as *const f32;
[p.read(), p.add(1).read(), p.add(2).read()]
}
#[inline]
pub unsafe fn read_i32_slice(ptr: *const i32, len: usize) -> Vec<i32> {
std::slice::from_raw_parts(ptr, len).to_vec()
}
#[inline]
pub unsafe fn read_f32_slice(ptr: *const i32, len: usize) -> Vec<f32> {
std::slice::from_raw_parts(ptr as *const f32, len).to_vec()
}
#[inline]
pub unsafe fn read_f32_default(ptr: *const i32, default: f32) -> f32 {
if ptr.is_null() {
default
} else {
read_f32(ptr)
}
}
#[inline]
pub unsafe fn read_i32_default(ptr: *const i32, default: i32) -> i32 {
if ptr.is_null() {
default
} else {
read_i32(ptr)
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -8,30 +8,30 @@ pub mod geometry;
extern "C" {
#[cfg(target_arch = "wasm32")]
pub fn host_log(ptr: *const u8, len: usize);
pub fn __nodarium_log(ptr: *const u8, len: usize);
}
#[cfg(debug_assertions)]
// #[cfg(debug_assertions)]
#[macro_export]
macro_rules! log {
($($t:tt)*) => {{
let msg = std::format!($($t)*);
#[cfg(target_arch = "wasm32")]
unsafe {
$crate::host_log(msg.as_ptr(), msg.len());
$crate::__nodarium_log(msg.as_ptr(), msg.len());
}
#[cfg(not(target_arch = "wasm32"))]
println!("{}", msg);
}}
}
#[cfg(not(debug_assertions))]
#[macro_export]
macro_rules! log {
($($arg:tt)*) => {{
// This will expand to nothing in release builds
}};
}
// #[cfg(not(debug_assertions))]
// #[macro_export]
// macro_rules! log {
// ($($arg:tt)*) => {{
// // This will expand to nothing in release builds
// }};
// }
#[allow(dead_code)]
#[rustfmt::skip]

View File

@@ -1,24 +1,23 @@
interface NodariumExports extends WebAssembly.Exports {
memory: WebAssembly.Memory;
execute: (ptr: number, len: number) => number;
__free: (ptr: number, len: number) => void;
__alloc: (len: number) => number;
execute: (outputPos: number, ...args: number[]) => number;
}
export function createWasmWrapper(buffer: ArrayBuffer) {
export function createWasmWrapper(buffer: ArrayBuffer, memory: WebAssembly.Memory) {
let exports: NodariumExports;
const importObject = {
env: {
host_log_panic: (ptr: number, len: number) => {
memory: memory,
__nodarium_log_panic: (ptr: number, len: number) => {
if (!exports) return;
const view = new Uint8Array(exports.memory.buffer, ptr, len);
console.error("RUST PANIC:", new TextDecoder().decode(view));
const view = new Uint8Array(memory.buffer, ptr, len);
console.error('WASM PANIC:', new TextDecoder().decode(view));
},
host_log: (ptr: number, len: number) => {
__nodarium_log: (ptr: number, len: number) => {
if (!exports) return;
const view = new Uint8Array(exports.memory.buffer, ptr, len);
console.log("RUST:", new TextDecoder().decode(view));
const view = new Uint8Array(memory.buffer, ptr, len);
console.log('WASM:', new TextDecoder().decode(view));
}
}
};
@@ -27,23 +26,13 @@ export function createWasmWrapper(buffer: ArrayBuffer) {
const instance = new WebAssembly.Instance(module, importObject);
exports = instance.exports as NodariumExports;
function execute(args: Int32Array) {
const inPtr = exports.__alloc(args.length);
new Int32Array(exports.memory.buffer).set(args, inPtr / 4);
const outPtr = exports.execute(inPtr, args.length);
const i32Result = new Int32Array(exports.memory.buffer);
const outLen = i32Result[outPtr / 4];
const out = i32Result.slice(outPtr / 4 + 1, outPtr / 4 + 1 + outLen);
exports.__free(inPtr, args.length);
return out;
function execute(outputPos: number, args: number[]): number {
console.log('WASM_WRAPPER', { outputPos, args });
return exports.execute(outputPos, ...args);
}
function get_definition() {
const sections = WebAssembly.Module.customSections(module, "nodarium_definition");
const sections = WebAssembly.Module.customSections(module, 'nodarium_definition');
if (sections.length > 0) {
const decoder = new TextDecoder();
const jsonString = decoder.decode(sections[0]);