1use std::{
4 borrow::Cow,
5 collections::{HashMap, VecDeque},
6};
7
8use log::error;
9use smallvec::{smallvec, SmallVec};
10use thiserror::Error;
11
12use crate::{
13 render::{
14 graph::{
15 Edge, NodeId, NodeRunError, NodeState, RenderContext, RenderGraph, RenderGraphContext,
16 SlotLabel, SlotType, SlotValue,
17 },
18 RenderResources,
19 },
20 tcs::world::World,
21};
22
/// Executes a [`RenderGraph`]: runs its nodes (and any sub graphs they queue) once every node
/// they depend on has produced its outputs, then submits the recorded GPU commands.
///
/// The type carries no state; all functionality lives in its associated functions.
pub(crate) struct RenderGraphRunner;
24
25#[derive(Error, Debug)]
26pub enum RenderGraphRunnerError {
27 #[error(transparent)]
28 NodeRunError(#[from] NodeRunError),
29 #[error("node output slot not set (index {slot_index}, name {slot_name})")]
30 EmptyNodeOutputSlot {
31 type_name: &'static str,
32 slot_index: usize,
33 slot_name: Cow<'static, str>,
34 },
35 #[error("graph (name: '{graph_name:?}') could not be run because slot '{slot_name}' at index {slot_index} has no value")]
36 MissingInput {
37 slot_index: usize,
38 slot_name: Cow<'static, str>,
39 graph_name: Option<Cow<'static, str>>,
40 },
41 #[error("attempted to use the wrong type for input slot")]
42 MismatchedInputSlotType {
43 slot_index: usize,
44 label: SlotLabel,
45 expected: SlotType,
46 actual: SlotType,
47 },
48}
49
50impl RenderGraphRunner {
51 pub fn run(
52 graph: &RenderGraph,
53 device: &wgpu::Device,
54 queue: &wgpu::Queue,
55 state: &RenderResources,
56 world: &World,
57 ) -> Result<(), RenderGraphRunnerError> {
58 let command_encoder =
59 device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
60 let mut render_context = RenderContext {
61 device,
62 command_encoder,
63 };
64
65 Self::run_graph(graph, None, &mut render_context, state, world, &[])?;
66 {
67 #[cfg(feature = "trace")]
68 let _span = tracing::info_span!("submit_graph_commands").entered();
69 queue.submit(vec![render_context.command_encoder.finish()]);
70 }
71 Ok(())
72 }
73
74 fn run_graph(
75 graph: &RenderGraph,
76 graph_name: Option<Cow<'static, str>>,
77 render_context: &mut RenderContext,
78 state: &RenderResources,
79 world: &World,
80 inputs: &[SlotValue],
81 ) -> Result<(), RenderGraphRunnerError> {
82 let mut node_outputs: HashMap<NodeId, SmallVec<[SlotValue; 4]>> = HashMap::default();
83 #[cfg(feature = "trace")]
84 let span = if let Some(name) = &graph_name {
85 tracing::info_span!("run_graph", name = name.as_ref())
86 } else {
87 tracing::info_span!("run_graph", name = "main_graph")
88 };
89 #[cfg(feature = "trace")]
90 let _guard = span.enter();
91
92 let mut node_queue: VecDeque<&NodeState> = graph
94 .iter_nodes()
95 .filter(|node| node.input_slots.is_empty())
96 .collect();
97
98 if let Some(input_node) = graph.input_node() {
100 let mut input_values: SmallVec<[SlotValue; 4]> = SmallVec::new();
101 for (i, input_slot) in input_node.input_slots.iter().enumerate() {
102 if let Some(input_value) = inputs.get(i) {
103 if input_slot.slot_type != input_value.slot_type() {
104 return Err(RenderGraphRunnerError::MismatchedInputSlotType {
105 slot_index: i,
106 actual: input_value.slot_type(),
107 expected: input_slot.slot_type,
108 label: input_slot.name.clone().into(),
109 });
110 } else {
111 input_values.push(input_value.clone());
112 }
113 } else {
114 return Err(RenderGraphRunnerError::MissingInput {
115 slot_index: i,
116 slot_name: input_slot.name.clone(),
117 graph_name: graph_name.clone(),
118 });
119 }
120 }
121
122 node_outputs.insert(input_node.id, input_values);
123
124 for (_, node_state) in graph.iter_node_outputs(input_node.id).expect("node exists") {
125 node_queue.push_front(node_state);
126 }
127 }
128
129 'handle_node: while let Some(node_state) = node_queue.pop_back() {
130 if node_outputs.contains_key(&node_state.id) {
132 continue;
133 }
134
135 let mut slot_indices_and_inputs: SmallVec<[(usize, SlotValue); 4]> = SmallVec::new();
136 for (edge, input_node) in graph
138 .iter_node_inputs(node_state.id)
139 .expect("node is in graph")
140 {
141 match edge {
142 Edge::SlotEdge {
143 output_index,
144 input_index,
145 ..
146 } => {
147 if let Some(outputs) = node_outputs.get(&input_node.id) {
148 slot_indices_and_inputs
149 .push((*input_index, outputs[*output_index].clone()));
150 } else {
151 node_queue.push_front(node_state);
152 continue 'handle_node;
153 }
154 }
155 Edge::NodeEdge { .. } => {
156 if !node_outputs.contains_key(&input_node.id) {
157 node_queue.push_front(node_state);
158 continue 'handle_node;
159 }
160 }
161 }
162 }
163
164 slot_indices_and_inputs.sort_by_key(|(index, _)| *index);
166 let inputs: SmallVec<[SlotValue; 4]> = slot_indices_and_inputs
167 .into_iter()
168 .map(|(_, value)| value)
169 .collect();
170
171 assert_eq!(inputs.len(), node_state.input_slots.len());
172
173 let mut outputs: SmallVec<[Option<SlotValue>; 4]> =
174 smallvec![None; node_state.output_slots.len()];
175 {
176 let mut context = RenderGraphContext::new(graph, node_state, &inputs, &mut outputs);
177 {
178 #[cfg(feature = "trace")]
179 let _span = tracing::info_span!("node", name = node_state.type_name).entered();
180
181 node_state
182 .node
183 .run(&mut context, render_context, state, world)?;
184 }
185
186 for run_sub_graph in context.finish() {
187 let sub_graph = graph
188 .get_sub_graph(&run_sub_graph.name)
189 .expect("sub graph exists because it was validated when queued.");
190 Self::run_graph(
191 sub_graph,
192 Some(run_sub_graph.name),
193 render_context,
194 state,
195 world,
196 &run_sub_graph.inputs,
197 )?;
198 }
199 }
200
201 let mut values: SmallVec<[SlotValue; 4]> = SmallVec::new();
202 for (i, output) in outputs.into_iter().enumerate() {
203 if let Some(value) = output {
204 values.push(value);
205 } else {
206 let empty_slot = node_state.output_slots.get_slot(i).unwrap();
207 return Err(RenderGraphRunnerError::EmptyNodeOutputSlot {
208 type_name: node_state.type_name,
209 slot_index: i,
210 slot_name: empty_slot.name.clone(),
211 });
212 }
213 }
214 node_outputs.insert(node_state.id, values);
215
216 for (_, node_state) in graph.iter_node_outputs(node_state.id).expect("node exists") {
217 node_queue.push_front(node_state);
218 }
219 }
220
221 Ok(())
222 }
223}