chore: make some things a bit more typesafe

This commit is contained in:
Max Richter
2025-11-23 15:15:38 +01:00
parent 2dcd797762
commit 64ea7ac349
4 changed files with 252 additions and 162 deletions

View File

@@ -1,4 +1,11 @@
import type { Edge, Graph, Node, NodeInput, NodeRegistry, Socket, } from "@nodes/types"; import type {
Edge,
Graph,
Node,
NodeInput,
NodeRegistry,
Socket,
} from "@nodes/types";
import { fastHashString } from "@nodes/utils"; import { fastHashString } from "@nodes/utils";
import { writable, type Writable } from "svelte/store"; import { writable, type Writable } from "svelte/store";
import EventEmitter from "./helpers/EventEmitter.js"; import EventEmitter from "./helpers/EventEmitter.js";
@@ -10,17 +17,29 @@ const logger = createLogger("graph-manager");
logger.mute(); logger.mute();
const clone = "structuredClone" in self ? self.structuredClone : (args: any) => JSON.parse(JSON.stringify(args)); const clone =
"structuredClone" in self
? self.structuredClone
: (args: any) => JSON.parse(JSON.stringify(args));
function areSocketsCompatible(output: string | undefined, inputs: string | string[] | undefined) { function areSocketsCompatible(
output: string | undefined,
inputs: string | string[] | undefined,
) {
if (Array.isArray(inputs) && output) { if (Array.isArray(inputs) && output) {
return inputs.includes(output); return inputs.includes(output);
} }
return inputs === output; return inputs === output;
} }
export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "settings": { types: Record<string, NodeInput>, values: Record<string, unknown> } }> { export class GraphManager extends EventEmitter<{
save: Graph;
result: any;
settings: {
types: Record<string, NodeInput>;
values: Record<string, unknown>;
};
}> {
status: Writable<"loading" | "idle" | "error"> = writable("loading"); status: Writable<"loading" | "idle" | "error"> = writable("loading");
loaded = false; loaded = false;
@@ -62,24 +81,32 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
serialize(): Graph { serialize(): Graph {
logger.group("serializing graph") logger.group("serializing graph");
const nodes = Array.from(this._nodes.values()).map(node => ({ const nodes = Array.from(this._nodes.values()).map((node) => ({
id: node.id, id: node.id,
position: [...node.position], position: [...node.position],
type: node.type, type: node.type,
props: node.props, props: node.props,
})) as Node[]; })) as Node[];
const edges = this._edges.map(edge => [edge[0].id, edge[1], edge[2].id, edge[3]]) as Graph["edges"]; const edges = this._edges.map((edge) => [
const serialized = { id: this.graph.id, settings: this.settings, nodes, edges }; edge[0].id,
edge[1],
edge[2].id,
edge[3],
]) as Graph["edges"];
const serialized = {
id: this.graph.id,
settings: this.settings,
nodes,
edges,
};
logger.groupEnd(); logger.groupEnd();
return clone(serialized); return clone(serialized);
} }
private lastSettingsHash = 0; private lastSettingsHash = 0;
setSettings(settings: Record<string, unknown>) { setSettings(settings: Record<string, unknown>) {
let hash = fastHashString(JSON.stringify(settings)); let hash = fastHashString(JSON.stringify(settings));
if (hash === this.lastSettingsHash) return; if (hash === this.lastSettingsHash) return;
this.lastSettingsHash = hash; this.lastSettingsHash = hash;
@@ -89,8 +116,6 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
this.execute(); this.execute();
} }
getNodeDefinitions() { getNodeDefinitions() {
return this.registry.getAllNodes(); return this.registry.getAllNodes();
} }
@@ -104,7 +129,7 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
nodes.add(n); nodes.add(n);
const children = this.getChildrenOfNode(n); const children = this.getChildrenOfNode(n);
const parents = this.getParentsOfNode(n); const parents = this.getParentsOfNode(n);
const newNodes = [...children, ...parents].filter(n => !nodes.has(n)); const newNodes = [...children, ...parents].filter((n) => !nodes.has(n));
stack.push(...newNodes); stack.push(...newNodes);
} }
return [...nodes.values()]; return [...nodes.values()];
@@ -116,9 +141,16 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
const children = node.tmp?.children || []; const children = node.tmp?.children || [];
for (const child of children) { for (const child of children) {
if (nodes.includes(child)) { if (nodes.includes(child)) {
const edge = this._edges.find(e => e[0].id === node.id && e[2].id === child.id); const edge = this._edges.find(
(e) => e[0].id === node.id && e[2].id === child.id,
);
if (edge) { if (edge) {
edges.push([edge[0].id, edge[1], edge[2].id, edge[3]] as [number, number, number, string]); edges.push([edge[0].id, edge[1], edge[2].id, edge[3]] as [
number,
number,
number,
string,
]);
} }
} }
} }
@@ -127,25 +159,26 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
return edges; return edges;
} }
private _init(graph: Graph) { private _init(graph: Graph) {
const nodes = new Map(graph.nodes.map(node => { const nodes = new Map(
const nodeType = this.registry.getNode(node.type); graph.nodes.map((node) => {
if (nodeType) { const nodeType = this.registry.getNode(node.type);
node.tmp = { if (nodeType) {
random: (Math.random() - 0.5) * 2, node.tmp = {
type: nodeType random: (Math.random() - 0.5) * 2,
}; type: nodeType,
} };
return [node.id, node] }
})); return [node.id, node];
}),
);
const edges = graph.edges.map((edge) => { const edges = graph.edges.map((edge) => {
const from = nodes.get(edge[0]); const from = nodes.get(edge[0]);
const to = nodes.get(edge[2]); const to = nodes.get(edge[2]);
if (!from || !to) { if (!from || !to) {
throw new Error("Edge references non-existing node"); throw new Error("Edge references non-existing node");
}; }
from.tmp = from.tmp || {}; from.tmp = from.tmp || {};
from.tmp.children = from.tmp.children || []; from.tmp.children = from.tmp.children || [];
from.tmp.children.push(to); from.tmp.children.push(to);
@@ -153,17 +186,15 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
to.tmp.parents = to.tmp.parents || []; to.tmp.parents = to.tmp.parents || [];
to.tmp.parents.push(from); to.tmp.parents.push(from);
return [from, edge[1], to, edge[3]] as Edge; return [from, edge[1], to, edge[3]] as Edge;
}) });
this.edges.set(edges); this.edges.set(edges);
this.nodes.set(nodes); this.nodes.set(nodes);
this.execute(); this.execute();
} }
async load(graph: Graph) { async load(graph: Graph) {
const a = performance.now(); const a = performance.now();
this.loaded = false; this.loaded = false;
@@ -171,7 +202,7 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
this.status.set("loading"); this.status.set("loading");
this.id.set(graph.id); this.id.set(graph.id);
const nodeIds = Array.from(new Set([...graph.nodes.map(n => n.type)])); const nodeIds = Array.from(new Set([...graph.nodes.map((n) => n.type)]));
await this.registry.load(nodeIds); await this.registry.load(nodeIds);
for (const node of this.graph.nodes) { for (const node of this.graph.nodes) {
@@ -186,9 +217,12 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
node.tmp.type = nodeType; node.tmp.type = nodeType;
} }
// load settings // load settings
const settingTypes: Record<string, NodeInput> = {}; const settingTypes: Record<
string,
// Optional metadata to map settings to specific nodes
NodeInput & { __node_type: string; __node_input: string }
> = {};
const settingValues = graph.settings || {}; const settingValues = graph.settings || {};
const types = this.getNodeDefinitions(); const types = this.getNodeDefinitions();
for (const type of types) { for (const type of types) {
@@ -196,8 +230,15 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
for (const key in type.inputs) { for (const key in type.inputs) {
let settingId = type.inputs[key].setting; let settingId = type.inputs[key].setting;
if (settingId) { if (settingId) {
settingTypes[settingId] = { __node_type: type.id, __node_input: key, ...type.inputs[key] }; settingTypes[settingId] = {
if (settingValues[settingId] === undefined && "value" in type.inputs[key]) { __node_type: type.id,
__node_input: key,
...type.inputs[key],
};
if (
settingValues[settingId] === undefined &&
"value" in type.inputs[key]
) {
settingValues[settingId] = type.inputs[key].value; settingValues[settingId] = type.inputs[key].value;
} }
} }
@@ -220,7 +261,6 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
setTimeout(() => this.execute(), 100); setTimeout(() => this.execute(), 100);
} }
getAllNodes() { getAllNodes() {
return Array.from(this._nodes.values()); return Array.from(this._nodes.values());
} }
@@ -234,7 +274,6 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
async loadNode(id: string) { async loadNode(id: string) {
await this.registry.load([id]); await this.registry.load([id]);
const nodeType = this.registry.getNode(id); const nodeType = this.registry.getNode(id);
@@ -247,7 +286,10 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
let settingId = nodeType.inputs[key].setting; let settingId = nodeType.inputs[key].setting;
if (settingId) { if (settingId) {
settingTypes[settingId] = nodeType.inputs[key]; settingTypes[settingId] = nodeType.inputs[key];
if (settingValues[settingId] === undefined && "value" in nodeType.inputs[key]) { if (
settingValues[settingId] === undefined &&
"value" in nodeType.inputs[key]
) {
settingValues[settingId] = nodeType.inputs[key].value; settingValues[settingId] = nodeType.inputs[key].value;
} }
} }
@@ -266,7 +308,7 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
const child = stack.pop(); const child = stack.pop();
if (!child) continue; if (!child) continue;
children.push(child); children.push(child);
stack.push(...child.tmp?.children || []); stack.push(...(child.tmp?.children || []));
} }
return children; return children;
} }
@@ -278,10 +320,10 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
const fromParents = this.getParentsOfNode(from); const fromParents = this.getParentsOfNode(from);
if (toParents.includes(from)) { if (toParents.includes(from)) {
const fromChildren = this.getChildrenOfNode(from); const fromChildren = this.getChildrenOfNode(from);
return toParents.filter(n => fromChildren.includes(n)); return toParents.filter((n) => fromChildren.includes(n));
} else if (fromParents.includes(to)) { } else if (fromParents.includes(to)) {
const toChildren = this.getChildrenOfNode(to); const toChildren = this.getChildrenOfNode(to);
return fromParents.filter(n => toChildren.includes(n)); return fromParents.filter((n) => toChildren.includes(n));
} else { } else {
// these two nodes are not connected // these two nodes are not connected
return; return;
@@ -289,7 +331,6 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
removeNode(node: Node, { restoreEdges = false } = {}) { removeNode(node: Node, { restoreEdges = false } = {}) {
const edgesToNode = this._edges.filter((edge) => edge[2].id === node.id); const edgesToNode = this._edges.filter((edge) => edge[2].id === node.id);
const edgesFromNode = this._edges.filter((edge) => edge[0].id === node.id); const edgesFromNode = this._edges.filter((edge) => edge[0].id === node.id);
for (const edge of [...edgesToNode, ...edgesFromNode]) { for (const edge of [...edgesToNode, ...edgesFromNode]) {
@@ -297,15 +338,17 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
if (restoreEdges) { if (restoreEdges) {
const outputSockets = edgesToNode.map(e => [e[0], e[1]] as const); const outputSockets = edgesToNode.map((e) => [e[0], e[1]] as const);
const inputSockets = edgesFromNode.map(e => [e[2], e[3]] as const); const inputSockets = edgesFromNode.map((e) => [e[2], e[3]] as const);
for (const [to, toSocket] of inputSockets) { for (const [to, toSocket] of inputSockets) {
for (const [from, fromSocket] of outputSockets) { for (const [from, fromSocket] of outputSockets) {
const outputType = from.tmp?.type?.outputs?.[fromSocket]; const outputType = from.tmp?.type?.outputs?.[fromSocket];
const inputType = to?.tmp?.type?.inputs?.[toSocket]?.type; const inputType = to?.tmp?.type?.inputs?.[toSocket]?.type;
if (outputType === inputType) { if (outputType === inputType) {
this.createEdge(from, fromSocket, to, toSocket, { applyUpdate: false }); this.createEdge(from, fromSocket, to, toSocket, {
applyUpdate: false,
});
continue; continue;
} }
} }
@@ -318,7 +361,7 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
nodes.delete(node.id); nodes.delete(node.id);
return nodes; return nodes;
}); });
this.execute() this.execute();
this.save(); this.save();
} }
@@ -328,7 +371,6 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
createGraph(nodes: Node[], edges: [number, number, number, string][]) { createGraph(nodes: Node[], edges: [number, number, number, string][]) {
// map old ids to new ids // map old ids to new ids
const idMap = new Map<number, number>(); const idMap = new Map<number, number>();
@@ -344,9 +386,9 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
return { ...node, id, tmp: { type } }; return { ...node, id, tmp: { type } };
}); });
const _edges = edges.map(edge => { const _edges = edges.map((edge) => {
const from = nodes.find(n => n.id === idMap.get(edge[0])); const from = nodes.find((n) => n.id === idMap.get(edge[0]));
const to = nodes.find(n => n.id === idMap.get(edge[2])); const to = nodes.find((n) => n.id === idMap.get(edge[2]));
if (!from || !to) { if (!from || !to) {
throw new Error("Edge references non-existing node"); throw new Error("Edge references non-existing node");
@@ -375,15 +417,28 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
return nodes; return nodes;
} }
createNode({ type, position, props = {} }: { type: Node["type"], position: Node["position"], props: Node["props"] }) { createNode({
type,
position,
props = {},
}: {
type: Node["type"];
position: Node["position"];
props: Node["props"];
}) {
const nodeType = this.registry.getNode(type); const nodeType = this.registry.getNode(type);
if (!nodeType) { if (!nodeType) {
logger.error(`Node type not found: ${type}`); logger.error(`Node type not found: ${type}`);
return; return;
} }
const node: Node = { id: this.createNodeId(), type, position, tmp: { type: nodeType }, props }; const node: Node = {
id: this.createNodeId(),
type,
position,
tmp: { type: nodeType },
props,
};
this.nodes.update((nodes) => { this.nodes.update((nodes) => {
nodes.set(node.id, node); nodes.set(node.id, node);
@@ -393,16 +448,23 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
this.save(); this.save();
} }
createEdge(from: Node, fromSocket: number, to: Node, toSocket: string, { applyUpdate = true } = {}) { createEdge(
from: Node,
fromSocket: number,
to: Node,
toSocket: string,
{ applyUpdate = true } = {},
) {
const existingEdges = this.getEdgesToNode(to); const existingEdges = this.getEdgesToNode(to);
// check if this exact edge already exists // check if this exact edge already exists
const existingEdge = existingEdges.find(e => e[0].id === from.id && e[1] === fromSocket && e[3] === toSocket); const existingEdge = existingEdges.find(
(e) => e[0].id === from.id && e[1] === fromSocket && e[3] === toSocket,
);
if (existingEdge) { if (existingEdge) {
logger.error("Edge already exists", existingEdge); logger.error("Edge already exists", existingEdge);
return; return;
}; }
// check if socket types match // check if socket types match
const fromSocketType = from.tmp?.type?.outputs?.[fromSocket]; const fromSocketType = from.tmp?.type?.outputs?.[fromSocket];
@@ -412,11 +474,15 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
if (!areSocketsCompatible(fromSocketType, toSocketType)) { if (!areSocketsCompatible(fromSocketType, toSocketType)) {
logger.error(`Socket types do not match: ${fromSocketType} !== ${toSocketType}`); logger.error(
`Socket types do not match: ${fromSocketType} !== ${toSocketType}`,
);
return; return;
} }
const edgeToBeReplaced = this._edges.find(e => e[2].id === to.id && e[3] === toSocket); const edgeToBeReplaced = this._edges.find(
(e) => e[2].id === to.id && e[3] === toSocket,
);
if (edgeToBeReplaced) { if (edgeToBeReplaced) {
this.removeEdge(edgeToBeReplaced, { applyDeletion: false }); this.removeEdge(edgeToBeReplaced, { applyDeletion: false });
} }
@@ -450,14 +516,12 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
} }
redo() { redo() {
const nextState = this.history.redo(); const nextState = this.history.redo();
if (nextState) { if (nextState) {
this._init(nextState); this._init(nextState);
this.emit("save", this.serialize()); this.emit("save", this.serialize());
} }
} }
startUndoGroup() { startUndoGroup() {
@@ -482,30 +546,30 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
const stack = node.tmp?.parents?.slice(0); const stack = node.tmp?.parents?.slice(0);
while (stack?.length) { while (stack?.length) {
if (parents.length > 1000000) { if (parents.length > 1000000) {
logger.warn("Infinite loop detected") logger.warn("Infinite loop detected");
break; break;
} }
const parent = stack.pop(); const parent = stack.pop();
if (!parent) continue; if (!parent) continue;
parents.push(parent); parents.push(parent);
stack.push(...parent.tmp?.parents || []); stack.push(...(parent.tmp?.parents || []));
} }
return parents.reverse(); return parents.reverse();
} }
getPossibleSockets({ node, index }: Socket): [Node, string | number][] { getPossibleSockets({ node, index }: Socket): [Node, string | number][] {
const nodeType = node?.tmp?.type; const nodeType = node?.tmp?.type;
if (!nodeType) return []; if (!nodeType) return [];
const sockets: [Node, string | number][] = [] const sockets: [Node, string | number][] = [];
// if index is a string, we are an input looking for outputs // if index is a string, we are an input looking for outputs
if (typeof index === "string") { if (typeof index === "string") {
// filter out self and child nodes // filter out self and child nodes
const children = new Set(this.getChildrenOfNode(node).map(n => n.id)); const children = new Set(this.getChildrenOfNode(node).map((n) => n.id));
const nodes = this.getAllNodes().filter(n => n.id !== node.id && !children.has(n.id)); const nodes = this.getAllNodes().filter(
(n) => n.id !== node.id && !children.has(n.id),
);
const ownType = nodeType?.inputs?.[index].type; const ownType = nodeType?.inputs?.[index].type;
@@ -519,16 +583,21 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
} }
} }
} else if (typeof index === "number") { } else if (typeof index === "number") {
// if index is a number, we are an output looking for inputs // if index is a number, we are an output looking for inputs
// filter out self and parent nodes // filter out self and parent nodes
const parents = new Set(this.getParentsOfNode(node).map(n => n.id)); const parents = new Set(this.getParentsOfNode(node).map((n) => n.id));
const nodes = this.getAllNodes().filter(n => n.id !== node.id && !parents.has(n.id)); const nodes = this.getAllNodes().filter(
(n) => n.id !== node.id && !parents.has(n.id),
);
// get edges from this socket // get edges from this socket
const edges = new Map(this.getEdgesFromNode(node).filter(e => e[1] === index).map(e => [e[2].id, e[3]])); const edges = new Map(
this.getEdgesFromNode(node)
.filter((e) => e[1] === index)
.map((e) => [e[2].id, e[3]]),
);
const ownType = nodeType.outputs?.[index]; const ownType = nodeType.outputs?.[index];
@@ -536,11 +605,13 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
const inputs = node?.tmp?.type?.inputs; const inputs = node?.tmp?.type?.inputs;
if (!inputs) continue; if (!inputs) continue;
for (const key in inputs) { for (const key in inputs) {
const otherType = [inputs[key].type]; const otherType = [inputs[key].type];
otherType.push(...(inputs[key].accepts || [])); otherType.push(...(inputs[key].accepts || []));
if (areSocketsCompatible(ownType, otherType) && edges.get(node.id) !== key) { if (
areSocketsCompatible(ownType, otherType) &&
edges.get(node.id) !== key
) {
sockets.push([node, key]); sockets.push([node, key]);
} }
} }
@@ -548,38 +619,46 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
} }
return sockets; return sockets;
} }
removeEdge(edge: Edge, { applyDeletion = true }: { applyDeletion?: boolean } = {}) { removeEdge(
edge: Edge,
{ applyDeletion = true }: { applyDeletion?: boolean } = {},
) {
const id0 = edge[0].id; const id0 = edge[0].id;
const sid0 = edge[1]; const sid0 = edge[1];
const id2 = edge[2].id; const id2 = edge[2].id;
const sid2 = edge[3]; const sid2 = edge[3];
const _edge = this._edges.find((e) => e[0].id === id0 && e[1] === sid0 && e[2].id === id2 && e[3] === sid2); const _edge = this._edges.find(
(e) =>
e[0].id === id0 && e[1] === sid0 && e[2].id === id2 && e[3] === sid2,
);
if (!_edge) return; if (!_edge) return;
edge[0].tmp = edge[0].tmp || {}; edge[0].tmp = edge[0].tmp || {};
if (edge[0].tmp.children) { if (edge[0].tmp.children) {
edge[0].tmp.children = edge[0].tmp.children.filter(n => n.id !== id2); edge[0].tmp.children = edge[0].tmp.children.filter(
(n: Node) => n.id !== id2,
);
} }
edge[2].tmp = edge[2].tmp || {}; edge[2].tmp = edge[2].tmp || {};
if (edge[2].tmp.parents) { if (edge[2].tmp.parents) {
edge[2].tmp.parents = edge[2].tmp.parents.filter(n => n.id !== id0); edge[2].tmp.parents = edge[2].tmp.parents.filter(
(n: Node) => n.id !== id0,
);
} }
if (applyDeletion) { if (applyDeletion) {
this.edges.update((edges) => { this.edges.update((edges) => {
return edges.filter(e => e !== _edge); return edges.filter((e) => e !== _edge);
}); });
this.execute(); this.execute();
this.save(); this.save();
} else { } else {
this._edges = this._edges.filter(e => e !== _edge); this._edges = this._edges.filter((e) => e !== _edge);
} }
} }
getEdgesToNode(node: Node) { getEdgesToNode(node: Node) {
@@ -605,7 +684,4 @@ export class GraphManager extends EventEmitter<{ "save": Graph, "result": any, "
}) })
.filter(Boolean) as unknown as [Node, number, Node, string][]; .filter(Boolean) as unknown as [Node, number, Node, string][];
} }
} }

View File

@@ -1,19 +1,24 @@
import { NodeDefinitionSchema, type AsyncCache, type NodeDefinition, type NodeRegistry } from "@nodes/types"; import {
NodeDefinitionSchema,
type AsyncCache,
type NodeDefinition,
type NodeRegistry,
} from "@nodes/types";
import { createLogger, createWasmWrapper } from "@nodes/utils"; import { createLogger, createWasmWrapper } from "@nodes/utils";
const log = createLogger("node-registry"); const log = createLogger("node-registry");
log.mute(); log.mute();
export class RemoteNodeRegistry implements NodeRegistry { export class RemoteNodeRegistry implements NodeRegistry {
status: "loading" | "ready" | "error" = "loading"; status: "loading" | "ready" | "error" = "loading";
private nodes: Map<string, NodeDefinition> = new Map(); private nodes: Map<string, NodeDefinition> = new Map();
cache?: AsyncCache<ArrayBuffer>;
fetch: typeof fetch = globalThis.fetch.bind(globalThis); fetch: typeof fetch = globalThis.fetch.bind(globalThis);
constructor(private url: string, private cache?: AsyncCache<ArrayBuffer>) { } constructor(
private url: string,
private cache?: AsyncCache<ArrayBuffer>,
) {}
async fetchUsers() { async fetchUsers() {
const response = await this.fetch(`${this.url}/nodes/users.json`); const response = await this.fetch(`${this.url}/nodes/users.json`);
@@ -32,7 +37,9 @@ export class RemoteNodeRegistry implements NodeRegistry {
} }
async fetchCollection(userCollectionId: `${string}/${string}`) { async fetchCollection(userCollectionId: `${string}/${string}`) {
const response = await this.fetch(`${this.url}/nodes/${userCollectionId}.json`); const response = await this.fetch(
`${this.url}/nodes/${userCollectionId}.json`,
);
if (!response.ok) { if (!response.ok) {
throw new Error(`Failed to load collection ${userCollectionId}`); throw new Error(`Failed to load collection ${userCollectionId}`);
} }
@@ -44,20 +51,16 @@ export class RemoteNodeRegistry implements NodeRegistry {
if (!response.ok) { if (!response.ok) {
throw new Error(`Failed to load node definition ${nodeId}`); throw new Error(`Failed to load node definition ${nodeId}`);
} }
return response.json() return response.json();
} }
private async fetchNodeWasm(nodeId: `${string}/${string}/${string}`) { private async fetchNodeWasm(nodeId: `${string}/${string}/${string}`) {
const fetchNode = async () => { const fetchNode = async () => {
const response = await this.fetch(`${this.url}/nodes/${nodeId}.wasm`); const response = await this.fetch(`${this.url}/nodes/${nodeId}.wasm`);
return response.arrayBuffer(); return response.arrayBuffer();
} };
const res = await Promise.race([ const res = await Promise.race([fetchNode(), this.cache?.get(nodeId)]);
fetchNode(),
this.cache?.get(nodeId)
]);
if (!res) { if (!res) {
throw new Error(`Failed to load node wasm ${nodeId}`); throw new Error(`Failed to load node wasm ${nodeId}`);
@@ -69,18 +72,17 @@ export class RemoteNodeRegistry implements NodeRegistry {
async load(nodeIds: `${string}/${string}/${string}`[]) { async load(nodeIds: `${string}/${string}/${string}`[]) {
const a = performance.now(); const a = performance.now();
const nodes = await Promise.all([...new Set(nodeIds).values()].map(async id => { const nodes = await Promise.all(
[...new Set(nodeIds).values()].map(async (id) => {
if (this.nodes.has(id)) {
return this.nodes.get(id)!;
}
if (this.nodes.has(id)) { const wasmBuffer = await this.fetchNodeWasm(id);
return this.nodes.get(id)!;
}
const wasmBuffer = await this.fetchNodeWasm(id);
return this.register(wasmBuffer);
}));
return this.register(wasmBuffer);
}),
);
const duration = performance.now() - a; const duration = performance.now() - a;
@@ -90,11 +92,10 @@ export class RemoteNodeRegistry implements NodeRegistry {
log.groupEnd(); log.groupEnd();
this.status = "ready"; this.status = "ready";
return nodes return nodes;
} }
async register(wasmBuffer: ArrayBuffer) { async register(wasmBuffer: ArrayBuffer) {
const wrapper = createWasmWrapper(wasmBuffer); const wrapper = createWasmWrapper(wasmBuffer);
const definition = NodeDefinitionSchema.safeParse(wrapper.get_definition()); const definition = NodeDefinitionSchema.safeParse(wrapper.get_definition());
@@ -110,8 +111,8 @@ export class RemoteNodeRegistry implements NodeRegistry {
let node = { let node = {
...definition.data, ...definition.data,
execute: wrapper.execute execute: wrapper.execute,
} };
this.nodes.set(definition.data.id, node); this.nodes.set(definition.data.id, node);

View File

@@ -2,17 +2,17 @@ import { Graph, NodeDefinition, NodeId } from "./types";
export interface NodeRegistry { export interface NodeRegistry {
/** /**
* The status of the node registry * The status of the node registry
* @remarks The status should be "loading" when the registry is loading, "ready" when the registry is ready, and "error" if an error occurred while loading the registry * @remarks The status should be "loading" when the registry is loading, "ready" when the registry is ready, and "error" if an error occurred while loading the registry
*/ */
status: "loading" | "ready" | "error"; status: "loading" | "ready" | "error";
/** /**
* Load the nodes with the given ids * Load the nodes with the given ids
* @param nodeIds - The ids of the nodes to load * @param nodeIds - The ids of the nodes to load
* @returns A promise that resolves when the nodes are loaded * @returns A promise that resolves when the nodes are loaded
* @throws An error if the nodes could not be loaded * @throws An error if the nodes could not be loaded
* @remarks This method should be called before calling getNode or getAllNodes * @remarks This method should be called before calling getNode or getAllNodes
*/ */
load: (nodeIds: NodeId[]) => Promise<NodeDefinition[]>; load: (nodeIds: NodeId[]) => Promise<NodeDefinition[]>;
/** /**
* Get a node by id * Get a node by id
@@ -27,30 +27,30 @@ export interface NodeRegistry {
getAllNodes: () => NodeDefinition[]; getAllNodes: () => NodeDefinition[];
/** /**
* Register a new node * Register a new node
* @param wasmBuffer - The WebAssembly buffer for the node * @param wasmBuffer - The WebAssembly buffer for the node
* @returns The node definition * @returns The node definition
*/ */
register: (wasmBuffer: ArrayBuffer) => Promise<NodeDefinition>; register: (wasmBuffer: ArrayBuffer) => Promise<NodeDefinition>;
cache?: AsyncCache<ArrayBuffer>;
} }
export interface RuntimeExecutor { export interface RuntimeExecutor {
/** /**
* Execute the given graph * Execute the given graph
* @param graph - The graph to execute * @param graph - The graph to execute
* @returns The result of the execution * @returns The result of the execution
*/ */
execute: (graph: Graph, settings: Record<string, unknown>) => Promise<Int32Array>; execute: (
graph: Graph,
settings: Record<string, unknown>,
) => Promise<Int32Array>;
} }
export interface SyncCache<T = unknown> { export interface SyncCache<T = unknown> {
/** /**
* The maximum number of items that can be stored in the cache * The maximum number of items that can be stored in the cache
* @remarks When the cache size exceeds this value, the oldest items should be removed * @remarks When the cache size exceeds this value, the oldest items should be removed
*/ */
size: number; size: number;
/** /**
@@ -69,14 +69,13 @@ export interface SyncCache<T = unknown> {
* Clear the cache * Clear the cache
*/ */
clear: () => void; clear: () => void;
} }
export interface AsyncCache<T = unknown> { export interface AsyncCache<T = unknown> {
/** /**
* The maximum number of items that can be stored in the cache * The maximum number of items that can be stored in the cache
* @remarks When the cache size exceeds this value, the oldest items should be removed * @remarks When the cache size exceeds this value, the oldest items should be removed
*/ */
size: number; size: number;
/** /**

View File

@@ -1,17 +1,27 @@
import { z } from "zod"; import { z } from "zod";
import { NodeInputSchema } from "./inputs"; import { NodeInputSchema } from "./inputs";
export type NodeId = `${string}/${string}/${string}`; export const NodeTypeSchema = z
.string()
.regex(/^[^/]+\/[^/]+\/[^/]+$/, "Invalid NodeId format")
.transform((value) => value as `${string}/${string}/${string}`);
export type NodeType = z.infer<typeof NodeTypeSchema>;
export const NodeSchema = z.object({ export const NodeSchema = z.object({
id: z.number(), id: z.number(),
type: z.string(), type: NodeTypeSchema,
props: z.record(z.union([z.number(), z.array(z.number())])).optional(), tmp: z.any().optional(),
meta: z.object({ props: z
title: z.string().optional(), .record(z.string(), z.union([z.number(), z.array(z.number())]))
lastModified: z.string().optional(), .optional(),
}).optional(), meta: z
position: z.tuple([z.number(), z.number()]) .object({
title: z.string().optional(),
lastModified: z.string().optional(),
})
.optional(),
position: z.tuple([z.number(), z.number()]),
}); });
export type Node = { export type Node = {
@@ -19,9 +29,9 @@ export type Node = {
depth?: number; depth?: number;
mesh?: any; mesh?: any;
random?: number; random?: number;
parents?: Node[], parents?: Node[];
children?: Node[], children?: Node[];
inputNodes?: Record<string, Node> inputNodes?: Record<string, Node>;
type?: NodeDefinition; type?: NodeDefinition;
downX?: number; downX?: number;
downY?: number; downY?: number;
@@ -30,17 +40,19 @@ export type Node = {
ref?: HTMLElement; ref?: HTMLElement;
visible?: boolean; visible?: boolean;
isMoving?: boolean; isMoving?: boolean;
} };
} & z.infer<typeof NodeSchema>; } & z.infer<typeof NodeSchema>;
export const NodeDefinitionSchema = z.object({ export const NodeDefinitionSchema = z.object({
id: z.string(), id: z.string(),
inputs: z.record(NodeInputSchema).optional(), inputs: z.record(z.string(), NodeInputSchema).optional(),
outputs: z.array(z.string()).optional(), outputs: z.array(z.string()).optional(),
meta: z.object({ meta: z
description: z.string().optional(), .object({
title: z.string().optional(), description: z.string().optional(),
}).optional(), title: z.string().optional(),
})
.optional(),
}); });
export type NodeDefinition = z.infer<typeof NodeDefinitionSchema> & { export type NodeDefinition = z.infer<typeof NodeDefinitionSchema> & {
@@ -56,12 +68,14 @@ export type Socket = {
export type Edge = [Node, number, Node, string]; export type Edge = [Node, number, Node, string];
export const GraphSchema = z.object({ export const GraphSchema = z.object({
id: z.number().optional(), id: z.number(),
meta: z.object({ meta: z
title: z.string().optional(), .object({
lastModified: z.string().optional(), title: z.string().optional(),
}).optional(), lastModified: z.string().optional(),
settings: z.record(z.any()).optional(), })
.optional(),
settings: z.record(z.string(), z.any()).optional(),
nodes: z.array(NodeSchema), nodes: z.array(NodeSchema),
edges: z.array(z.tuple([z.number(), z.number(), z.number(), z.string()])), edges: z.array(z.tuple([z.number(), z.number(), z.number(), z.string()])),
}); });