diff --git a/app/src/lib/graph-interface/components/AddMenu.svelte b/app/src/lib/graph-interface/components/AddMenu.svelte index 36d46c6..391be4f 100644 --- a/app/src/lib/graph-interface/components/AddMenu.svelte +++ b/app/src/lib/graph-interface/components/AddMenu.svelte @@ -183,7 +183,7 @@ activeNodeId = node.id; }} > - {node.id.split('/').at(-1)} + {node.meta?.title ?? node.id.split('/').at(-1)} {/each} diff --git a/app/src/lib/graph-interface/graph-manager.svelte.ts b/app/src/lib/graph-interface/graph-manager.svelte.ts index 1a8907f..31c51a0 100644 --- a/app/src/lib/graph-interface/graph-manager.svelte.ts +++ b/app/src/lib/graph-interface/graph-manager.svelte.ts @@ -10,6 +10,7 @@ import type { NodeRegistry, Socket } from '@nodarium/types'; +import { type GroupDefinition } from '@nodarium/types'; import { fastHashString } from '@nodarium/utils'; import { createLogger } from '@nodarium/utils'; import { SvelteMap, SvelteSet } from 'svelte/reactivity'; @@ -67,7 +68,7 @@ export class GraphManager extends EventEmitter<{ status = $state<'loading' | 'idle' | 'error'>(); loaded = false; - graph: Graph = { id: 0, nodes: [], edges: [] }; + graph: Graph = { id: 0, nodes: [], edges: [], groups: [] }; id = $state(0); nodes = new SvelteMap(); @@ -110,10 +111,36 @@ export class GraphManager extends EventEmitter<{ edge[2].id, edge[3] ]) as Graph['edges']; + + const groups = this.graph.groups?.map((group) => { + const groupNodes = group.nodes.map((node) => ({ + id: node.id, + position: [...node.position], + type: node.type, + props: node.props + })) as NodeInstance[]; + + const groupEdges = this.edges.map((edge) => [ + edge[0].id, + edge[1], + edge[2].id, + edge[3] + ]) as Graph['edges']; + + return { + id: group.id, + inputs: group.inputs, + outputs: group.outputs, + nodes: groupNodes, + edges: groupEdges + }; + }); + const serialized = { id: this.graph.id, settings: $state.snapshot(this.settings), meta: $state.snapshot(this.graph.meta), + groups, nodes, edges }; @@ -311,13 +338,16 @@ export class GraphManager extends EventEmitter<{ logger.info('loading graph', { nodes: graph.nodes, edges: graph.edges, id: graph.id }); - const nodeIds = Array.from(new SvelteSet([...graph.nodes.map((n) => n.type)])); + const nodeIds = Array + .from(new SvelteSet([...graph.nodes.map((n) => n.type)])) + .filter(n => !n.startsWith('__internal/')); await this.registry.load(nodeIds); // Fetch all nodes from all collections of the loaded nodes const allCollections = new SvelteSet<`${string}/${string}`>(); for (const id of nodeIds) { const [user, collection] = id.split('/'); + if (user === '__internal') continue; allCollections.add(`${user}/${collection}`); } for (const collection of allCollections) { @@ -333,7 +363,7 @@ export class GraphManager extends EventEmitter<{ for (const node of this.graph.nodes) { const nodeType = this.registry.getNode(node.type); - if (!nodeType) { + if (!nodeType && !node.type.startsWith('__internal/')) { logger.error(`Node type not found: ${node.type}`); this.status = 'error'; return; @@ -389,15 +419,47 @@ export class GraphManager extends EventEmitter<{ } getAllNodes() { - return Array.from(this.nodes.values()); + this.graph.groups ??= []; + if (!this.graph.groups.length) { + this.graph.groups.push({ + id: 0, + nodes: [], + edges: [] + }); + } + + return Array + .from(this.nodes.values()); } getNode(id: number) { return this.nodes.get(id); } - getNodeType(id: string) { - return this.registry.getNode(id); + getNodeType(node: NodeInstance) { + // Construct the inputs on the fly + if (node.type === '__internal/group/instance') { + const groupDefinition = this.getGroup(node.props?.groupId as number); + + const inputs = { + 'groupId': { + type: 'select', + label: '', + value: node.props?.groupId.toString(), + internal: true, + options: this.graph.groups.map(g => g.id.toString()) + }, + ...(node.state.type?.inputs || {}), + ...groupDefinition?.inputs + }; + + return { + ...node.state.type, + inputs + } as NodeDefinition; + } + + return node.state.type; } async loadNodeType(id: NodeId) { @@ -502,6 +564,14 @@ export class GraphManager extends EventEmitter<{ } } + createGroupId() { + return Math.max(0, ...this.graph.groups.keys()) + 1; + } + + getGroup(id: number) { + return this.graph.groups.find(g => g.id === id); + } + createNodeId() { return Math.max(0, ...this.nodes.keys()) + 1; } @@ -579,6 +649,26 @@ export class GraphManager extends EventEmitter<{ return node; } + createGroupNode(position: [number, number], groupDefinition: GroupDefinition): NodeInstance { + this.graph.groups ??= []; + this.graph.groups.push(groupDefinition); + const node = { + id: this.createNodeId(), + type: '__internal/group/instance', + meta: { + title: 'Group' + }, + props: { + groupId: groupDefinition.id + }, + position, + state: {} + } as const; + + this.nodes.set(node.id, node); + return node; + } + createEdge( from: NodeInstance, fromSocket: number, @@ -597,11 +687,14 @@ export class GraphManager extends EventEmitter<{ return; } + const fromType = this.getNodeType(from); + const toType = this.getNodeType(to); + // check if socket types match - const fromSocketType = from.state?.type?.outputs?.[fromSocket]; - const toSocketType = [to.state?.type?.inputs?.[toSocket]?.type]; - if (to.state?.type?.inputs?.[toSocket]?.accepts) { - toSocketType.push(...(to?.state?.type?.inputs?.[toSocket]?.accepts || [])); + const fromSocketType = fromType?.outputs?.[fromSocket]; + const toSocketType = [toType?.inputs?.[toSocket]?.type]; + if (toType?.inputs?.[toSocket]?.accepts) { + toSocketType.push(...(toType?.inputs?.[toSocket]?.accepts || [])); } if (!areSocketsCompatible(fromSocketType, toSocketType)) { @@ -724,7 +817,7 @@ export class GraphManager extends EventEmitter<{ } getPossibleSockets({ node, index }: Socket): [NodeInstance, string | number][] { - const nodeType = node?.state?.type; + const nodeType = this.getNodeType(node); if (!nodeType) return []; const sockets: [NodeInstance, string | number][] = []; @@ -740,7 +833,7 @@ export class GraphManager extends EventEmitter<{ const ownType = nodeType?.inputs?.[index].type; for (const node of nodes) { - const nodeType = node?.state?.type; + const nodeType = this.getNodeType(node); const inputs = nodeType?.outputs; if (!inputs) continue; for (let index = 0; index < inputs.length; index++) { @@ -772,7 +865,7 @@ export class GraphManager extends EventEmitter<{ const ownType = nodeType.outputs?.[index]; for (const node of nodes) { - const inputs = node?.state?.type?.inputs; + const inputs = this.getNodeType(node)?.inputs; if (!inputs) continue; for (const key in inputs) { const otherType = [inputs[key].type]; diff --git a/app/src/lib/graph-interface/graph-state.svelte.ts b/app/src/lib/graph-interface/graph-state.svelte.ts index 7127bb1..53cb68b 100644 --- a/app/src/lib/graph-interface/graph-state.svelte.ts +++ b/app/src/lib/graph-interface/graph-state.svelte.ts @@ -1,11 +1,11 @@ import { animate, lerp } from '$lib/helpers'; -import type { NodeInstance, Socket } from '@nodarium/types'; +import type { Box, Edge, GroupDefinition, NodeInput, NodeInstance, Socket } from '@nodarium/types'; import { getContext, setContext } from 'svelte'; import { SvelteMap, SvelteSet } from 'svelte/reactivity'; import type { OrthographicCamera, Vector3 } from 'three'; import type { GraphManager } from './graph-manager.svelte'; import { ColorGenerator } from './graph/colors'; -import { getNodeHeight, getSocketPosition } from './helpers/nodeHelpers'; +import { getNodeHeight, getParameterHeight } from './helpers/nodeHelpers'; const graphStateKey = Symbol('graph-state'); export function getGraphState() { @@ -203,7 +203,7 @@ export class GraphState { } const debugNode = this.graph.createNode({ - type: 'max/plantarium/debug', + type: '__internal/node/debug', position: [node.position[0] + 30, node.position[1]], props: {} }); @@ -240,6 +240,119 @@ export class GraphState { }; } + groupSelectedNodes(nodeIds = [...this.selectedNodes.keys(), this.activeNodeId]) { + const ids = new Set(nodeIds); + const nodes = [ + ...ids.values().map(id => this.graph.getNode(id)).filter(Boolean) + ] as NodeInstance[]; + + const incomingEdges = this.graph.edges.filter((edge) => + ids.has(edge[2].id) && !ids.has(edge[0].id) + ); + const groupInputs = new Map(); + for (const edge of incomingEdges) { + groupInputs.set(`${edge[0].id}-${edge[1]}`, edge); + } + + const outgoingEdges = this.graph.edges.filter((edge) => + ids.has(edge[0].id) && !ids.has(edge[2].id) + ); + const groupOutputs = new Map(); + for (const edge of outgoingEdges) { + groupOutputs.set(`${edge[2].id}-${edge[3]}`, edge); + } + + const inputs: Record = {}; + [...groupInputs.values()].forEach((edge, i) => { + const input = { + label: `Input ${i}`, + type: edge[0].state.type?.outputs?.[edge[1]] || '*' + }; + inputs[`input_${i}`] = input as NodeInput; + }); + + const outputs = [...groupOutputs.values()].map((edge, i) => ({ + label: `Output ${i}`, + type: edge[2].state.type?.inputs?.[edge[3]].type + })); + + const groupPosition = [0, 0] as [number, number]; + const bounds: Box = { minX: Infinity, maxX: -Infinity, minY: Infinity, maxY: -Infinity }; + for (const node of nodes) { + groupPosition[0] += node.position[0]; + groupPosition[1] += node.position[1]; + bounds.minX = Math.min(bounds.minX, node.position[0]); + bounds.maxX = Math.max(bounds.maxX, node.position[0]); + bounds.minY = Math.min(bounds.minY, node.position[1]); + bounds.maxY = Math.max(bounds.maxY, node.position[1]); + } + groupPosition[0] /= nodes.length; + groupPosition[1] /= nodes.length; + + const groupInputNode: NodeInstance = { + id: this.graph.createNodeId(), + type: '__internal/group/input', + position: [bounds.minX - 50, (bounds.minY + bounds.maxY) / 2], + state: {} + }; + + const groupOutputNode: NodeInstance = { + id: this.graph.createNodeId(), + type: '__internal/group/output', + position: [bounds.maxX + 25, (bounds.minY + bounds.maxY) / 2], + state: {} + }; + + // Edges that are inside the group + const internalEdges = this.graph.edges.filter((edge) => { + return ids.has(edge[0].id) || ids.has(edge[2].id); + }).map((edge) => { + // Going in from the group + if (!ids.has(edge[0].id)) { + return [groupInputNode, 0, edge[2], edge[3]]; + // Going out to the group + } else if (!ids.has(edge[2].id)) { + return [edge[0], edge[1], groupOutputNode, 'Out']; + } + return edge; + }); + + const groupId = this.graph.createGroupId(); + const groupDefinition: GroupDefinition = { + id: groupId, + inputs: inputs, + outputs: outputs, + edges: internalEdges, + nodes: [groupInputNode, ...nodes, groupOutputNode] + }; + const groupNode = this.graph.createGroupNode(groupPosition, groupDefinition); + + // Update the edges that are now inside + // the group to be connected to that group node + const externalEdges = this.graph.edges.map((edge) => { + if (ids.has(edge[2].id)) { + // Edge going into the group + return [edge[0], edge[1], groupNode, 'input_0'] as Edge; + } else if (ids.has(edge[0].id)) { + // Edge going out of the group + return [groupNode, 0, edge[2], edge[3]] as Edge; + } + return edge; + }); + + for (const node of nodes) { + this.graph.nodes.delete(node.id); + } + this.graph.edges = externalEdges; + this.graph.saveUndoGroup(); + console.log( + $state.snapshot({ + groupNode, + groupDefinition + }) + ); + } + centerNode(node?: NodeInstance) { const average = [0, 0, 4]; if (node) { @@ -301,7 +414,7 @@ export class GraphState { if (edge[3] === index) { node = edge[0]; index = edge[1]; - position = getSocketPosition(node, index); + position = this.getSocketPosition(node, index); this.graph.removeEdge(edge); break; } @@ -321,7 +434,7 @@ export class GraphState { return { node, index, - position: getSocketPosition(node, index) + position: this.getSocketPosition(node, index) }; }); } @@ -358,7 +471,7 @@ export class GraphState { for (const node of this.graph.nodes.values()) { const x = node.position[0]; const y = node.position[1]; - const height = getNodeHeight(node.state.type!); + const height = getNodeHeight(this.graph.getNodeType(node)!); if (downX > x && downX < x + 20 && downY > y && downY < y + height) { clickedNodeId = node.id; break; @@ -370,7 +483,7 @@ export class GraphState { } isNodeInView(node: NodeInstance) { - const height = getNodeHeight(node.state.type!); + const height = getNodeHeight(this.graph.getNodeType(node)!); const width = 20; return node.position[0] > this.cameraBounds[0] - width && node.position[0] < this.cameraBounds[1] @@ -381,4 +494,38 @@ export class GraphState { openNodePalette() { this.addMenuPosition = [this.mousePosition[0], this.mousePosition[1]]; } + + enterGroupNode() { + if (this.activeNodeId === -1) return; + const selectedNode = this.graph.getNode(this.activeNodeId); + if (!selectedNode || selectedNode.type.startsWith('__internal/group/instance')) return; + } + + getSocketPosition( + node: NodeInstance, + index: string | number + ): [number, number] { + if (typeof index === 'number') { + return [ + (node?.state?.x ?? node.position[0]) + 20, + (node?.state?.y ?? node.position[1]) + 2.5 + 10 * index + ]; + } else { + let height = 5; + const nodeType = this.graph.getNodeType(node)!; + const inputs = nodeType.inputs || {}; + for (const inputKey in inputs) { + const h = getParameterHeight(nodeType, inputKey) / 10; + if (inputKey === index) { + height += h / 2; + break; + } + height += h; + } + return [ + node?.state?.x ?? node.position[0], + (node?.state?.y ?? node.position[1]) + height + ]; + } + } } diff --git a/app/src/lib/graph-interface/graph/Graph.svelte b/app/src/lib/graph-interface/graph/Graph.svelte index 2460265..108b9a3 100644 --- a/app/src/lib/graph-interface/graph/Graph.svelte +++ b/app/src/lib/graph-interface/graph/Graph.svelte @@ -11,7 +11,6 @@ import Debug from '../debug/Debug.svelte'; import EdgeEl from '../edges/Edge.svelte'; import { getGraphManager, getGraphState } from '../graph-state.svelte'; - import { getSocketPosition } from '../helpers/nodeHelpers'; import NodeEl from '../node/Node.svelte'; import { maxZoom, minZoom } from './constants'; import { FileDropEventManager } from './drop.events'; @@ -39,8 +38,8 @@ return [0, 0, 0, 0]; } - const pos1 = getSocketPosition(fromNode, edge[1]); - const pos2 = getSocketPosition(toNode, edge[3]); + const pos1 = graphState.getSocketPosition(fromNode, edge[1]); + const pos2 = graphState.getSocketPosition(toNode, edge[3]); return [pos1[0], pos1[1], pos2[0], pos2[1]]; } @@ -96,11 +95,13 @@ graphState.addMenuPosition = null; } - function getSocketType(node: NodeInstance, index: number | string): string { + function getSocketType(node: NodeInstance, index: number | string, e: unknown): string { + const nodeType = graph.getNodeType(node); + console.log($state.snapshot({ nodeType, index, e })); if (typeof index === 'string') { - return node.state.type?.inputs?.[index].type || 'unknown'; + return nodeType?.inputs?.[index].type || 'unknown'; } - return node.state.type?.outputs?.[index] || 'unknown'; + return nodeType?.outputs?.[index] || 'unknown'; } @@ -182,8 +183,8 @@ {#if graphState.activeSocket} - {#each graph.nodes.values() as node (node.id)} + {#each graph.getAllNodes() as node (node.id)} = {}; export function getNodeHeight(node: NodeDefinition) { + if (!node) { + console.trace('Node is undefined', node); + } if (node.id in nodeHeightCache) { return nodeHeightCache[node.id]; } diff --git a/app/src/lib/graph-interface/keymaps.ts b/app/src/lib/graph-interface/keymaps.ts index 3614ed1..661344e 100644 --- a/app/src/lib/graph-interface/keymaps.ts +++ b/app/src/lib/graph-interface/keymaps.ts @@ -54,6 +54,19 @@ export function setupKeymaps(keymap: Keymap, graph: GraphManager, graphState: Gr } }); + keymap.addShortcut({ + key: 'g', + ctrl: true, + description: 'Group selected nodes', + callback: () => graphState.groupSelectedNodes() + }); + + keymap.addShortcut({ + key: 'Tab', + description: 'Enter selected node group', + callback: () => graphState.enterGroupNode() + }); + keymap.addShortcut({ key: 'A', shift: true, diff --git a/app/src/lib/graph-interface/node/Node.svelte b/app/src/lib/graph-interface/node/Node.svelte index cba7fd7..438183f 100644 --- a/app/src/lib/graph-interface/node/Node.svelte +++ b/app/src/lib/graph-interface/node/Node.svelte @@ -3,13 +3,14 @@ import type { NodeInstance } from '@nodarium/types'; import { T } from '@threlte/core'; import { type Mesh } from 'three'; - import { getGraphState } from '../graph-state.svelte'; + import { getGraphManager, getGraphState } from '../graph-state.svelte'; import { colors } from '../graph/colors.svelte'; import { getNodeHeight, getParameterHeight } from '../helpers/nodeHelpers'; import NodeFrag from './Node.frag'; import NodeVert from './Node.vert'; import NodeHtml from './NodeHTML.svelte'; + const graph = getGraphManager(); const graphState = getGraphState(); type Props = { @@ -18,7 +19,7 @@ }; let { node = $bindable(), inView }: Props = $props(); - const nodeType = $derived(node.state.type!); + const nodeType = $derived(graph.getNodeType(node)!); const isActive = $derived(graphState.activeNodeId === node.id); const isSelected = $derived(graphState.selectedNodes.has(node.id)); @@ -40,7 +41,11 @@ let meshRef: Mesh | undefined = $state(); - const height = getNodeHeight(node.state.type!); + const height = $derived(getNodeHeight(nodeType)); + + if (node.type.startsWith('__internal/')) { + $inspect({ node, nodeType, height, sectionHeights }); + } const zoom = $derived(graphState.cameraPosition[2]); diff --git a/app/src/lib/graph-interface/node/NodeHTML.svelte b/app/src/lib/graph-interface/node/NodeHTML.svelte index 8cb2b8d..75f45f2 100644 --- a/app/src/lib/graph-interface/node/NodeHTML.svelte +++ b/app/src/lib/graph-interface/node/NodeHTML.svelte @@ -1,11 +1,12 @@