1use std::{
4 borrow::Cow,
5 collections::{HashMap, VecDeque},
6};
7
8use smallvec::{smallvec, SmallVec};
9use thiserror::Error;
10
11use crate::{
12 render::{
13 graph::{
14 Edge, NodeId, NodeRunError, NodeState, RenderContext, RenderGraph, RenderGraphContext,
15 SlotLabel, SlotType, SlotValue,
16 },
17 RenderResources,
18 },
19 tcs::world::World,
20};
21
/// Stateless executor for a [`RenderGraph`].
///
/// Its associated functions run the graph's nodes in dependency order, recursively run any
/// sub-graphs those nodes queue, and submit the recorded GPU commands.
pub(crate) struct RenderGraphRunner;
23
24#[derive(Error, Debug)]
25pub enum RenderGraphRunnerError {
26 #[error(transparent)]
27 NodeRunError(#[from] NodeRunError),
28 #[error("node output slot not set (index {slot_index}, name {slot_name})")]
29 EmptyNodeOutputSlot {
30 type_name: &'static str,
31 slot_index: usize,
32 slot_name: Cow<'static, str>,
33 },
34 #[error("graph (name: '{graph_name:?}') could not be run because slot '{slot_name}' at index {slot_index} has no value")]
35 MissingInput {
36 slot_index: usize,
37 slot_name: Cow<'static, str>,
38 graph_name: Option<Cow<'static, str>>,
39 },
40 #[error("attempted to use the wrong type for input slot")]
41 MismatchedInputSlotType {
42 slot_index: usize,
43 label: SlotLabel,
44 expected: SlotType,
45 actual: SlotType,
46 },
47}
48
49impl RenderGraphRunner {
50 pub fn run(
51 graph: &RenderGraph,
52 device: &wgpu::Device,
53 queue: &wgpu::Queue,
54 state: &RenderResources,
55 world: &World,
56 ) -> Result<(), RenderGraphRunnerError> {
57 let command_encoder =
58 device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
59 let mut render_context = RenderContext {
60 device,
61 command_encoder,
62 };
63
64 Self::run_graph(graph, None, &mut render_context, state, world, &[])?;
65 {
66 #[cfg(feature = "trace")]
67 let _span = tracing::info_span!("submit_graph_commands").entered();
68 queue.submit(vec![render_context.command_encoder.finish()]);
69 }
70 Ok(())
71 }
72
73 fn run_graph(
74 graph: &RenderGraph,
75 graph_name: Option<Cow<'static, str>>,
76 render_context: &mut RenderContext,
77 state: &RenderResources,
78 world: &World,
79 inputs: &[SlotValue],
80 ) -> Result<(), RenderGraphRunnerError> {
81 let mut node_outputs: HashMap<NodeId, SmallVec<[SlotValue; 4]>> = HashMap::default();
82 #[cfg(feature = "trace")]
83 let span = if let Some(name) = &graph_name {
84 tracing::info_span!("run_graph", name = name.as_ref())
85 } else {
86 tracing::info_span!("run_graph", name = "main_graph")
87 };
88 #[cfg(feature = "trace")]
89 let _guard = span.enter();
90
91 let mut node_queue: VecDeque<&NodeState> = graph
93 .iter_nodes()
94 .filter(|node| node.input_slots.is_empty())
95 .collect();
96
97 if let Some(input_node) = graph.input_node() {
99 let mut input_values: SmallVec<[SlotValue; 4]> = SmallVec::new();
100 for (i, input_slot) in input_node.input_slots.iter().enumerate() {
101 if let Some(input_value) = inputs.get(i) {
102 if input_slot.slot_type != input_value.slot_type() {
103 return Err(RenderGraphRunnerError::MismatchedInputSlotType {
104 slot_index: i,
105 actual: input_value.slot_type(),
106 expected: input_slot.slot_type,
107 label: input_slot.name.clone().into(),
108 });
109 } else {
110 input_values.push(input_value.clone());
111 }
112 } else {
113 return Err(RenderGraphRunnerError::MissingInput {
114 slot_index: i,
115 slot_name: input_slot.name.clone(),
116 graph_name: graph_name.clone(),
117 });
118 }
119 }
120
121 node_outputs.insert(input_node.id, input_values);
122
123 for (_, node_state) in graph.iter_node_outputs(input_node.id).expect("node exists") {
124 node_queue.push_front(node_state);
125 }
126 }
127
128 'handle_node: while let Some(node_state) = node_queue.pop_back() {
129 if node_outputs.contains_key(&node_state.id) {
131 continue;
132 }
133
134 let mut slot_indices_and_inputs: SmallVec<[(usize, SlotValue); 4]> = SmallVec::new();
135 for (edge, input_node) in graph
137 .iter_node_inputs(node_state.id)
138 .expect("node is in graph")
139 {
140 match edge {
141 Edge::SlotEdge {
142 output_index,
143 input_index,
144 ..
145 } => {
146 if let Some(outputs) = node_outputs.get(&input_node.id) {
147 slot_indices_and_inputs
148 .push((*input_index, outputs[*output_index].clone()));
149 } else {
150 node_queue.push_front(node_state);
151 continue 'handle_node;
152 }
153 }
154 Edge::NodeEdge { .. } => {
155 if !node_outputs.contains_key(&input_node.id) {
156 node_queue.push_front(node_state);
157 continue 'handle_node;
158 }
159 }
160 }
161 }
162
163 slot_indices_and_inputs.sort_by_key(|(index, _)| *index);
165 let inputs: SmallVec<[SlotValue; 4]> = slot_indices_and_inputs
166 .into_iter()
167 .map(|(_, value)| value)
168 .collect();
169
170 assert_eq!(inputs.len(), node_state.input_slots.len());
171
172 let mut outputs: SmallVec<[Option<SlotValue>; 4]> =
173 smallvec![None; node_state.output_slots.len()];
174 {
175 let mut context = RenderGraphContext::new(graph, node_state, &inputs, &mut outputs);
176 {
177 #[cfg(feature = "trace")]
178 let _span = tracing::info_span!("node", name = node_state.type_name).entered();
179
180 node_state
181 .node
182 .run(&mut context, render_context, state, world)?;
183 }
184
185 for run_sub_graph in context.finish() {
186 let sub_graph = graph
187 .get_sub_graph(&run_sub_graph.name)
188 .expect("sub graph exists because it was validated when queued.");
189 Self::run_graph(
190 sub_graph,
191 Some(run_sub_graph.name),
192 render_context,
193 state,
194 world,
195 &run_sub_graph.inputs,
196 )?;
197 }
198 }
199
200 let mut values: SmallVec<[SlotValue; 4]> = SmallVec::new();
201 for (i, output) in outputs.into_iter().enumerate() {
202 if let Some(value) = output {
203 values.push(value);
204 } else {
205 let empty_slot = node_state.output_slots.get_slot(i).unwrap();
206 return Err(RenderGraphRunnerError::EmptyNodeOutputSlot {
207 type_name: node_state.type_name,
208 slot_index: i,
209 slot_name: empty_slot.name.clone(),
210 });
211 }
212 }
213 node_outputs.insert(node_state.id, values);
214
215 for (_, node_state) in graph.iter_node_outputs(node_state.id).expect("node exists") {
216 node_queue.push_front(node_state);
217 }
218 }
219
220 Ok(())
221 }
222}