feat: add "*"/any type input for dev page

This commit is contained in:
Max Richter
2026-01-22 15:54:08 +01:00
committed by Max Richter
parent 8d403ba803
commit 2a14ed7202
9 changed files with 124 additions and 44 deletions

View File

@@ -25,14 +25,14 @@ const clone = 'structuredClone' in self
? self.structuredClone
: (args: unknown) => JSON.parse(JSON.stringify(args));
function areSocketsCompatible(
export function areSocketsCompatible(
output: string | undefined,
inputs: string | (string | undefined)[] | undefined
) {
if (Array.isArray(inputs) && output) {
return inputs.includes(output);
return inputs.includes('*') || inputs.includes(output);
}
return inputs === output;
return inputs === output || inputs === '*';
}
function areEdgesEqual(firstEdge: Edge, secondEdge: Edge) {
@@ -268,14 +268,7 @@ export class GraphManager extends EventEmitter<{
private _init(graph: Graph) {
const nodes = new SvelteMap(
graph.nodes.map((node) => {
const nodeType = this.registry.getNode(node.type);
const n = node as NodeInstance;
if (nodeType) {
n.state = {
type: nodeType
};
}
return [node.id, n];
return [node.id, node as NodeInstance];
})
);
@@ -300,6 +293,30 @@ export class GraphManager extends EventEmitter<{
this.execute();
}
private async loadAllCollections() {
// Fetch all nodes from all collections of the loaded nodes
const nodeIds = Array.from(new Set([...this.graph.nodes.map((n) => n.type)]));
const allCollections = new Set<`${string}/${string}`>();
for (const id of nodeIds) {
const [user, collection] = id.split('/');
allCollections.add(`${user}/${collection}`);
}
const allCollectionIds = await Promise
.all([...allCollections]
.map(async (collection) =>
remoteRegistry
.fetchCollection(collection)
.then((collection: { nodes: { id: NodeId }[] }) => {
return collection.nodes.map(n => n.id.replace(/\.wasm$/, '') as NodeId);
})
));
const missingNodeIds = [...new Set(allCollectionIds.flat())];
this.registry.load(missingNodeIds);
}
async load(graph: Graph) {
const a = performance.now();
@@ -384,7 +401,9 @@ export class GraphManager extends EventEmitter<{
this.loaded = true;
logger.log(`Graph loaded in ${performance.now() - a}ms`);
setTimeout(() => this.execute(), 100);
this.loadAllCollections(); // lazily load all nodes from all collections
}
getAllNodes() {
@@ -491,10 +510,10 @@ export class GraphManager extends EventEmitter<{
const inputs = Object.entries(to.state?.type?.inputs ?? {});
const outputs = from.state?.type?.outputs ?? [];
for (let i = 0; i < inputs.length; i++) {
const [inputName, input] = inputs[0];
const [inputName, input] = inputs[i];
for (let o = 0; o < outputs.length; o++) {
const output = outputs[0];
if (input.type === output) {
const output = outputs[o];
if (input.type === output || input.type === '*') {
return this.createEdge(from, o, to, inputName);
}
}
@@ -724,6 +743,7 @@ export class GraphManager extends EventEmitter<{
getPossibleSockets({ node, index }: Socket): [NodeInstance, string | number][] {
const nodeType = node?.state?.type;
console.log({ node: $state.snapshot(node), index, nodeType });
if (!nodeType) return [];
const sockets: [NodeInstance, string | number][] = [];
@@ -787,6 +807,7 @@ export class GraphManager extends EventEmitter<{
}
}
console.log(`Found ${sockets.length} possible sockets`, sockets);
return sockets;
}

View File

@@ -169,11 +169,14 @@ export class GraphState {
(node?.state?.y ?? node.position[1]) + 2.5 + 10 * index
];
} else {
const _index = Object.keys(node.state?.type?.inputs || {}).indexOf(index);
return [
const inputs = node.state.type?.inputs || this.graph.registry.getNode(node.type)?.inputs
|| {};
const _index = Object.keys(inputs).indexOf(index);
const pos = [
node?.state?.x ?? node.position[0],
(node?.state?.y ?? node.position[1]) + 10 + 10 * _index
];
] as [number, number];
return pos;
}
}
@@ -259,7 +262,7 @@ export class GraphState {
let { node, index, position } = socket;
// remove existing edge
// if the socket is an input socket -> remove existing edges
if (typeof index === 'string') {
const edges = this.graph.getEdgesToNode(node);
for (const edge of edges) {

View File

@@ -25,11 +25,11 @@
let {
graph,
registry,
settings = $bindable(),
activeNode = $bindable(),
backgroundType = $bindable('grid'),
snapToGrid = $bindable(true),
showHelp = $bindable(false),
settings = $bindable(),
settingTypes = $bindable(),
onsave,
onresult

View File

@@ -2,6 +2,7 @@ import {
type AsyncCache,
type NodeDefinition,
NodeDefinitionSchema,
type NodeId,
type NodeRegistry
} from '@nodarium/types';
import { createLogger, createWasmWrapper } from '@nodarium/utils';
@@ -163,6 +164,13 @@ export class RemoteNodeRegistry implements NodeRegistry {
}
getAllNodes() {
return [...this.nodes.values()];
const allNodes = [...this.nodes.values()];
log.info('getting all nodes', allNodes);
return allNodes;
}
async overwriteNode(nodeId: NodeId, node: NodeDefinition) {
log.info('Overwritten node', { nodeId, node });
this.nodes.set(nodeId, node);
}
}

View File

@@ -52,6 +52,7 @@ function getValue(input: NodeInput, value?: unknown) {
return value;
}
console.error({ input, value });
throw new Error(`Unknown input type ${input.type}`);
}
@@ -62,6 +63,8 @@ export class MemoryRuntimeExecutor implements RuntimeExecutor {
perf?: PerformanceStore;
private results: Record<string, Int32Array> = {};
constructor(
private registry: NodeRegistry,
public cache?: SyncCache<Int32Array>
@@ -170,7 +173,7 @@ export class MemoryRuntimeExecutor implements RuntimeExecutor {
);
// here we store the intermediate results of the nodes
const results: Record<string, Int32Array> = {};
this.results = {};
if (settings['randomSeed']) {
this.seed = Math.floor(Math.random() * 100000000);
@@ -201,12 +204,12 @@ export class MemoryRuntimeExecutor implements RuntimeExecutor {
// check if the input is connected to another node
const inputNode = node.state.inputNodes[key];
if (inputNode) {
if (results[inputNode.id] === undefined) {
if (this.results[inputNode.id] === undefined) {
throw new Error(
`Node ${node.type} is missing input from node ${inputNode.type}`
);
}
return results[inputNode.id];
return this.results[inputNode.id];
}
// If the value is stored in the node itself, we use that value
@@ -249,7 +252,7 @@ export class MemoryRuntimeExecutor implements RuntimeExecutor {
b = performance.now();
if (this.cache && node.id !== outputNode.id) {
this.cache.set(inputHash, results[node.id]);
this.cache.set(inputHash, this.results[node.id]);
}
this.perf?.addPoint('node/' + node_type.id, b - a);
@@ -262,7 +265,7 @@ export class MemoryRuntimeExecutor implements RuntimeExecutor {
}
// return the result of the parent of the output node
const res = results[outputNode.id];
const res = this.results[outputNode.id];
if (this.cache) {
this.cache.size = sortedNodes.length * 2;
@@ -273,6 +276,10 @@ export class MemoryRuntimeExecutor implements RuntimeExecutor {
return res as unknown as Int32Array;
}
getIntermediateResults() {
return this.results;
}
getPerformanceData() {
return this.perf?.get();
}