1use std::{borrow::Cow, collections::HashMap, fmt::Debug};
2
3use super::{
4 Edge, EdgeExistence, Node, NodeId, NodeLabel, NodeRunError, NodeState, RenderGraphContext,
5 RenderGraphError, SlotInfo, SlotLabel,
6};
7use crate::{
8 render::{graph::RenderContext, RenderResources},
9 tcs::world::World,
10};
11
12#[derive(Default)]
53pub struct RenderGraph {
54 nodes: HashMap<NodeId, NodeState>,
55 node_names: HashMap<Cow<'static, str>, NodeId>,
56 sub_graphs: HashMap<Cow<'static, str>, RenderGraph>,
57 input_node: Option<NodeId>,
58
59 current_id: usize,
60}
61
62impl RenderGraph {
63 pub const INPUT_NODE_NAME: &'static str = "GraphInputNode";
65
66 pub fn update(&mut self, state: &mut RenderResources) {
68 for node in self.nodes.values_mut() {
69 node.node.update(state);
70 }
71
72 for sub_graph in self.sub_graphs.values_mut() {
73 sub_graph.update(state);
74 }
75 }
76
77 pub fn set_input(&mut self, inputs: Vec<SlotInfo>) -> NodeId {
79 assert!(self.input_node.is_none(), "Graph already has an input node");
80
81 let id = self.add_node("GraphInputNode", GraphInputNode { inputs });
82 self.input_node = Some(id);
83 id
84 }
85
86 #[inline]
88 pub fn input_node(&self) -> Option<&NodeState> {
89 self.input_node.and_then(|id| self.get_node_state(id).ok())
90 }
91
92 pub fn add_node<T>(&mut self, name: impl Into<Cow<'static, str>>, node: T) -> NodeId
95 where
96 T: Node,
97 {
98 let id = NodeId::new(self.current_id);
99 self.current_id += 1;
100 let name = name.into();
101 let mut node_state = NodeState::new(id, node);
102 node_state.name = Some(name.clone());
103 self.nodes.insert(id, node_state);
104 self.node_names.insert(name, id);
105 id
106 }
107
108 pub fn remove_node(
111 &mut self,
112 name: impl Into<Cow<'static, str>>,
113 ) -> Result<(), RenderGraphError> {
114 let name = name.into();
115 if let Some(id) = self.node_names.remove(&name) {
116 if let Some(node_state) = self.nodes.remove(&id) {
117 for input_edge in node_state.edges.input_edges().iter() {
120 match input_edge {
121 Edge::SlotEdge {
122 output_node,
123 output_index: _,
124 input_node: _,
125 input_index: _,
126 } => {
127 if let Ok(output_node) = self.get_node_state_mut(*output_node) {
128 output_node.edges.remove_output_edge(input_edge.clone())?;
129 }
130 }
131 Edge::NodeEdge {
132 input_node: _,
133 output_node,
134 } => {
135 if let Ok(output_node) = self.get_node_state_mut(*output_node) {
136 output_node.edges.remove_output_edge(input_edge.clone())?;
137 }
138 }
139 }
140 }
141 for output_edge in node_state.edges.output_edges().iter() {
144 match output_edge {
145 Edge::SlotEdge {
146 output_node: _,
147 output_index: _,
148 input_node,
149 input_index: _,
150 } => {
151 if let Ok(input_node) = self.get_node_state_mut(*input_node) {
152 input_node.edges.remove_input_edge(output_edge.clone())?;
153 }
154 }
155 Edge::NodeEdge {
156 output_node: _,
157 input_node,
158 } => {
159 if let Ok(input_node) = self.get_node_state_mut(*input_node) {
160 input_node.edges.remove_input_edge(output_edge.clone())?;
161 }
162 }
163 }
164 }
165 }
166 }
167
168 Ok(())
169 }
170
171 pub fn get_node_state(
173 &self,
174 label: impl Into<NodeLabel>,
175 ) -> Result<&NodeState, RenderGraphError> {
176 let label = label.into();
177 let node_id = self.get_node_id(&label)?;
178 self.nodes
179 .get(&node_id)
180 .ok_or(RenderGraphError::InvalidNode(label))
181 }
182
183 pub fn get_node_state_mut(
185 &mut self,
186 label: impl Into<NodeLabel>,
187 ) -> Result<&mut NodeState, RenderGraphError> {
188 let label = label.into();
189 let node_id = self.get_node_id(&label)?;
190 self.nodes
191 .get_mut(&node_id)
192 .ok_or(RenderGraphError::InvalidNode(label))
193 }
194
195 pub fn get_node_id(&self, label: impl Into<NodeLabel>) -> Result<NodeId, RenderGraphError> {
197 let label = label.into();
198 match label {
199 NodeLabel::Id(id) => Ok(id),
200 NodeLabel::Name(ref name) => self
201 .node_names
202 .get(name)
203 .cloned()
204 .ok_or(RenderGraphError::InvalidNode(label)),
205 }
206 }
207
208 pub fn get_node<T>(&self, label: impl Into<NodeLabel>) -> Result<&T, RenderGraphError>
210 where
211 T: Node,
212 {
213 self.get_node_state(label).and_then(|n| n.node())
214 }
215
216 pub fn get_node_mut<T>(
218 &mut self,
219 label: impl Into<NodeLabel>,
220 ) -> Result<&mut T, RenderGraphError>
221 where
222 T: Node,
223 {
224 self.get_node_state_mut(label).and_then(|n| n.node_mut())
225 }
226
227 pub fn add_slot_edge(
230 &mut self,
231 output_node: impl Into<NodeLabel>,
232 output_slot: impl Into<SlotLabel>,
233 input_node: impl Into<NodeLabel>,
234 input_slot: impl Into<SlotLabel>,
235 ) -> Result<(), RenderGraphError> {
236 let output_slot = output_slot.into();
237 let input_slot = input_slot.into();
238 let output_node_id = self.get_node_id(output_node)?;
239 let input_node_id = self.get_node_id(input_node)?;
240
241 let output_index = self
242 .get_node_state(output_node_id)?
243 .output_slots
244 .get_slot_index(output_slot.clone())
245 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
246 let input_index = self
247 .get_node_state(input_node_id)?
248 .input_slots
249 .get_slot_index(input_slot.clone())
250 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
251
252 let edge = Edge::SlotEdge {
253 output_node: output_node_id,
254 output_index,
255 input_node: input_node_id,
256 input_index,
257 };
258
259 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
260
261 {
262 let output_node = self.get_node_state_mut(output_node_id)?;
263 output_node.edges.add_output_edge(edge.clone())?;
264 }
265 let input_node = self.get_node_state_mut(input_node_id)?;
266 input_node.edges.add_input_edge(edge)?;
267
268 Ok(())
269 }
270
271 pub fn remove_slot_edge(
274 &mut self,
275 output_node: impl Into<NodeLabel>,
276 output_slot: impl Into<SlotLabel>,
277 input_node: impl Into<NodeLabel>,
278 input_slot: impl Into<SlotLabel>,
279 ) -> Result<(), RenderGraphError> {
280 let output_slot = output_slot.into();
281 let input_slot = input_slot.into();
282 let output_node_id = self.get_node_id(output_node)?;
283 let input_node_id = self.get_node_id(input_node)?;
284
285 let output_index = self
286 .get_node_state(output_node_id)?
287 .output_slots
288 .get_slot_index(output_slot.clone())
289 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
290 let input_index = self
291 .get_node_state(input_node_id)?
292 .input_slots
293 .get_slot_index(input_slot.clone())
294 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
295
296 let edge = Edge::SlotEdge {
297 output_node: output_node_id,
298 output_index,
299 input_node: input_node_id,
300 input_index,
301 };
302
303 self.validate_edge(&edge, EdgeExistence::Exists)?;
304
305 {
306 let output_node = self.get_node_state_mut(output_node_id)?;
307 output_node.edges.remove_output_edge(edge.clone())?;
308 }
309 let input_node = self.get_node_state_mut(input_node_id)?;
310 input_node.edges.remove_input_edge(edge)?;
311
312 Ok(())
313 }
314
315 pub fn add_node_edge(
318 &mut self,
319 output_node: impl Into<NodeLabel>,
320 input_node: impl Into<NodeLabel>,
321 ) -> Result<(), RenderGraphError> {
322 let output_node_id = self.get_node_id(output_node)?;
323 let input_node_id = self.get_node_id(input_node)?;
324
325 let edge = Edge::NodeEdge {
326 output_node: output_node_id,
327 input_node: input_node_id,
328 };
329
330 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
331
332 {
333 let output_node = self.get_node_state_mut(output_node_id)?;
334 output_node.edges.add_output_edge(edge.clone())?;
335 }
336 let input_node = self.get_node_state_mut(input_node_id)?;
337 input_node.edges.add_input_edge(edge)?;
338
339 Ok(())
340 }
341
342 pub fn remove_node_edge(
345 &mut self,
346 output_node: impl Into<NodeLabel>,
347 input_node: impl Into<NodeLabel>,
348 ) -> Result<(), RenderGraphError> {
349 let output_node_id = self.get_node_id(output_node)?;
350 let input_node_id = self.get_node_id(input_node)?;
351
352 let edge = Edge::NodeEdge {
353 output_node: output_node_id,
354 input_node: input_node_id,
355 };
356
357 self.validate_edge(&edge, EdgeExistence::Exists)?;
358
359 {
360 let output_node = self.get_node_state_mut(output_node_id)?;
361 output_node.edges.remove_output_edge(edge.clone())?;
362 }
363 let input_node = self.get_node_state_mut(input_node_id)?;
364 input_node.edges.remove_input_edge(edge)?;
365
366 Ok(())
367 }
368
369 pub fn validate_edge(
372 &mut self,
373 edge: &Edge,
374 should_exist: EdgeExistence,
375 ) -> Result<(), RenderGraphError> {
376 if should_exist == EdgeExistence::Exists && !self.has_edge(edge) {
377 return Err(RenderGraphError::EdgeDoesNotExist(edge.clone()));
378 } else if should_exist == EdgeExistence::DoesNotExist && self.has_edge(edge) {
379 return Err(RenderGraphError::EdgeAlreadyExists(edge.clone()));
380 }
381
382 match *edge {
383 Edge::SlotEdge {
384 output_node,
385 output_index,
386 input_node,
387 input_index,
388 } => {
389 let output_node_state = self.get_node_state(output_node)?;
390 let input_node_state = self.get_node_state(input_node)?;
391
392 let output_slot = output_node_state
393 .output_slots
394 .get_slot(output_index)
395 .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index(
396 output_index,
397 )))?;
398 let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or(
399 RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)),
400 )?;
401
402 if let Some(Edge::SlotEdge {
403 output_node: current_output_node,
404 ..
405 }) = input_node_state.edges.input_edges().iter().find(|e| {
406 if let Edge::SlotEdge {
407 input_index: current_input_index,
408 ..
409 } = e
410 {
411 input_index == *current_input_index
412 } else {
413 false
414 }
415 }) {
416 if should_exist == EdgeExistence::DoesNotExist {
417 return Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
418 node: input_node,
419 input_slot: input_index,
420 occupied_by_node: *current_output_node,
421 });
422 }
423 }
424
425 if output_slot.slot_type != input_slot.slot_type {
426 return Err(RenderGraphError::MismatchedNodeSlots {
427 output_node,
428 output_slot: output_index,
429 input_node,
430 input_slot: input_index,
431 });
432 }
433 }
434 Edge::NodeEdge { .. } => { }
435 }
436
437 Ok(())
438 }
439
440 pub fn has_edge(&self, edge: &Edge) -> bool {
442 let output_node_state = self.get_node_state(edge.get_output_node());
443 let input_node_state = self.get_node_state(edge.get_input_node());
444 if let Ok(output_node_state) = output_node_state {
445 if output_node_state.edges.output_edges().contains(edge) {
446 if let Ok(input_node_state) = input_node_state {
447 if input_node_state.edges.input_edges().contains(edge) {
448 return true;
449 }
450 }
451 }
452 }
453
454 false
455 }
456
457 pub fn iter_nodes(&self) -> impl Iterator<Item = &NodeState> {
459 self.nodes.values()
460 }
461
462 pub fn iter_nodes_mut(&mut self) -> impl Iterator<Item = &mut NodeState> {
464 self.nodes.values_mut()
465 }
466
467 pub fn iter_sub_graphs(&self) -> impl Iterator<Item = (&str, &RenderGraph)> {
469 self.sub_graphs
470 .iter()
471 .map(|(name, graph)| (name.as_ref(), graph))
472 }
473
474 pub fn iter_sub_graphs_mut(&mut self) -> impl Iterator<Item = (&str, &mut RenderGraph)> {
476 self.sub_graphs
477 .iter_mut()
478 .map(|(name, graph)| (name.as_ref(), graph))
479 }
480
481 pub fn iter_node_inputs(
484 &self,
485 label: impl Into<NodeLabel>,
486 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
487 let node = self.get_node_state(label)?;
488 Ok(node
489 .edges
490 .input_edges()
491 .iter()
492 .map(|edge| (edge, edge.get_output_node()))
493 .map(move |(edge, output_node_id)| {
494 (edge, self.get_node_state(output_node_id).unwrap())
495 }))
496 }
497
498 pub fn iter_node_outputs(
501 &self,
502 label: impl Into<NodeLabel>,
503 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
504 let node = self.get_node_state(label)?;
505 Ok(node
506 .edges
507 .output_edges()
508 .iter()
509 .map(|edge| (edge, edge.get_input_node()))
510 .map(move |(edge, input_node_id)| (edge, self.get_node_state(input_node_id).unwrap())))
511 }
512
513 pub fn add_sub_graph(&mut self, name: impl Into<Cow<'static, str>>, sub_graph: RenderGraph) {
516 self.sub_graphs.insert(name.into(), sub_graph);
517 }
518
519 pub fn remove_sub_graph(&mut self, name: impl Into<Cow<'static, str>>) {
522 self.sub_graphs.remove(&name.into());
523 }
524
525 pub fn get_sub_graph(&self, name: impl AsRef<str>) -> Option<&RenderGraph> {
527 self.sub_graphs.get(name.as_ref())
528 }
529
530 pub fn get_sub_graph_mut(&mut self, name: impl AsRef<str>) -> Option<&mut RenderGraph> {
532 self.sub_graphs.get_mut(name.as_ref())
533 }
534}
535
536impl Debug for RenderGraph {
537 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
538 for node in self.iter_nodes() {
539 writeln!(f, "{:?}", node.id)?;
540 writeln!(f, " in: {:?}", node.input_slots)?;
541 writeln!(f, " out: {:?}", node.output_slots)?;
542 }
543
544 Ok(())
545 }
546}
547
548pub struct GraphInputNode {
551 inputs: Vec<SlotInfo>,
552}
553
554impl Node for GraphInputNode {
555 fn input(&self) -> Vec<SlotInfo> {
556 self.inputs.clone()
557 }
558
559 fn output(&self) -> Vec<SlotInfo> {
560 self.inputs.clone()
561 }
562
563 fn run(
564 &self,
565 graph: &mut RenderGraphContext,
566 _render_context: &mut RenderContext,
567 _state: &RenderResources,
568 _world: &World,
569 ) -> Result<(), NodeRunError> {
570 for i in 0..graph.inputs().len() {
571 let input = graph.inputs()[i].clone();
572 graph.set_output(i, input)?;
573 }
574 Ok(())
575 }
576}
577
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::{
        Edge, Node, NodeId, NodeRunError, RenderGraph, RenderGraphContext, RenderGraphError,
        SlotInfo,
    };
    use crate::{
        render::{
            graph::{RenderContext, SlotType},
            RenderResources,
        },
        tcs::world::World,
    };

    /// Minimal node with `TextureView` slots named `in_{i}` / `out_{i}`.
    #[derive(Debug)]
    struct TestNode {
        inputs: Vec<SlotInfo>,
        outputs: Vec<SlotInfo>,
    }

    impl TestNode {
        pub fn new(inputs: usize, outputs: usize) -> Self {
            TestNode {
                inputs: (0..inputs)
                    .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView))
                    .collect(),
                outputs: (0..outputs)
                    .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView))
                    .collect(),
            }
        }
    }

    impl Node for TestNode {
        fn input(&self) -> Vec<SlotInfo> {
            self.inputs.clone()
        }

        fn output(&self) -> Vec<SlotInfo> {
            self.outputs.clone()
        }

        fn run(
            &self,
            _graph: &mut RenderGraphContext,
            _render_context: &mut RenderContext,
            _state: &RenderResources,
            _world: &World,
        ) -> Result<(), NodeRunError> {
            Ok(())
        }
    }

    /// Ids of the nodes feeding `name`.
    fn input_nodes(name: &'static str, graph: &RenderGraph) -> HashSet<NodeId> {
        graph
            .iter_node_inputs(name)
            .unwrap()
            .map(|(_edge, node)| node.id)
            .collect::<HashSet<NodeId>>()
    }

    /// Ids of the nodes fed by `name`.
    fn output_nodes(name: &'static str, graph: &RenderGraph) -> HashSet<NodeId> {
        graph
            .iter_node_outputs(name)
            .unwrap()
            .map(|(_edge, node)| node.id)
            .collect::<HashSet<NodeId>>()
    }

    #[test]
    fn test_graph_edges() {
        let mut graph = RenderGraph::default();
        let a_id = graph.add_node("A", TestNode::new(0, 1));
        let b_id = graph.add_node("B", TestNode::new(0, 1));
        let c_id = graph.add_node("C", TestNode::new(1, 1));
        let d_id = graph.add_node("D", TestNode::new(1, 0));

        graph.add_slot_edge("A", "out_0", "C", "in_0").unwrap();
        graph.add_node_edge("B", "C").unwrap();
        graph.add_slot_edge("C", 0, "D", 0).unwrap();

        assert!(input_nodes("A", &graph).is_empty(), "A has no inputs");
        assert_eq!(
            output_nodes("A", &graph),
            HashSet::from([c_id]),
            "A outputs to C"
        );

        assert!(input_nodes("B", &graph).is_empty(), "B has no inputs");
        assert_eq!(
            output_nodes("B", &graph),
            HashSet::from([c_id]),
            "B outputs to C"
        );

        assert_eq!(
            input_nodes("C", &graph),
            HashSet::from([a_id, b_id]),
            "A and B input to C"
        );
        assert_eq!(
            output_nodes("C", &graph),
            HashSet::from([d_id]),
            "C outputs to D"
        );

        assert_eq!(
            input_nodes("D", &graph),
            HashSet::from([c_id]),
            "C inputs to D"
        );
        assert!(output_nodes("D", &graph).is_empty(), "D has no outputs");
    }

    #[test]
    fn test_get_node_typed() {
        struct MyNode {
            value: usize,
        }

        impl Node for MyNode {
            fn run(
                &self,
                _graph: &mut RenderGraphContext,
                _render_context: &mut RenderContext,
                _state: &RenderResources,
                _world: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();

        graph.add_node("A", MyNode { value: 42 });

        let node: &MyNode = graph.get_node("A").unwrap();
        assert_eq!(node.value, 42, "node value matches");

        let result: Result<&TestNode, RenderGraphError> = graph.get_node("A");
        assert_eq!(
            result.unwrap_err(),
            RenderGraphError::WrongNodeType,
            "expect a wrong node type error"
        );
    }

    #[test]
    fn test_slot_already_occupied() {
        let mut graph = RenderGraph::default();

        graph.add_node("A", TestNode::new(0, 1));
        graph.add_node("B", TestNode::new(0, 1));
        graph.add_node("C", TestNode::new(1, 1));

        graph.add_slot_edge("A", 0, "C", 0).unwrap();
        assert_eq!(
            graph.add_slot_edge("B", 0, "C", 0),
            Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
                node: graph.get_node_id("C").unwrap(),
                input_slot: 0,
                occupied_by_node: graph.get_node_id("A").unwrap(),
            }),
            "Adding to a slot that is already occupied should return an error"
        );
    }

    #[test]
    fn test_edge_already_exists() {
        let mut graph = RenderGraph::default();

        graph.add_node("A", TestNode::new(0, 1));
        graph.add_node("B", TestNode::new(1, 0));

        graph.add_slot_edge("A", 0, "B", 0).unwrap();
        assert_eq!(
            graph.add_slot_edge("A", 0, "B", 0),
            Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge {
                output_node: graph.get_node_id("A").unwrap(),
                output_index: 0,
                input_node: graph.get_node_id("B").unwrap(),
                input_index: 0,
            })),
            "Adding to a duplicate edge should return an error"
        );
    }

    #[test]
    fn test_remove_slot_edge() {
        let mut graph = RenderGraph::default();

        graph.add_node("A", TestNode::new(0, 1));
        graph.add_node("B", TestNode::new(1, 0));

        graph.add_slot_edge("A", 0, "B", 0).unwrap();
        graph.remove_slot_edge("A", 0, "B", 0).unwrap();

        assert!(output_nodes("A", &graph).is_empty(), "A no longer outputs");
        assert!(input_nodes("B", &graph).is_empty(), "B no longer has inputs");
        assert_eq!(
            graph.remove_slot_edge("A", 0, "B", 0),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::SlotEdge {
                output_node: graph.get_node_id("A").unwrap(),
                output_index: 0,
                input_node: graph.get_node_id("B").unwrap(),
                input_index: 0,
            })),
            "Removing a missing edge should return an error"
        );

        // The freed input slot can be connected again.
        graph.add_slot_edge("A", 0, "B", 0).unwrap();
    }

    #[test]
    fn test_remove_node_edge() {
        let mut graph = RenderGraph::default();

        graph.add_node("A", TestNode::new(0, 0));
        graph.add_node("B", TestNode::new(0, 0));

        graph.add_node_edge("A", "B").unwrap();
        graph.remove_node_edge("A", "B").unwrap();

        assert!(output_nodes("A", &graph).is_empty(), "A no longer outputs");
        assert!(input_nodes("B", &graph).is_empty(), "B no longer has inputs");
        assert_eq!(
            graph.remove_node_edge("A", "B"),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::NodeEdge {
                output_node: graph.get_node_id("A").unwrap(),
                input_node: graph.get_node_id("B").unwrap(),
            })),
            "Removing a missing edge should return an error"
        );
    }

    #[test]
    fn test_remove_node_detaches_edges() {
        let mut graph = RenderGraph::default();

        let a_id = graph.add_node("A", TestNode::new(0, 1));
        graph.add_node("B", TestNode::new(1, 1));
        graph.add_node("C", TestNode::new(1, 0));

        graph.add_slot_edge("A", 0, "B", 0).unwrap();
        graph.add_slot_edge("B", 0, "C", 0).unwrap();

        graph.remove_node("B").unwrap();

        assert!(graph.get_node_id("B").is_err(), "B is gone");
        assert!(output_nodes("A", &graph).is_empty(), "A lost its edge to B");
        assert!(input_nodes("C", &graph).is_empty(), "C lost its edge from B");

        // C's input slot is free again, so A can feed it directly.
        graph.add_slot_edge("A", 0, "C", 0).unwrap();
        assert_eq!(input_nodes("C", &graph), HashSet::from([a_id]));

        assert_eq!(
            graph.remove_node("missing"),
            Ok(()),
            "Removing an unknown node is a no-op"
        );
    }
}