maplibre/render/graph_runner.rs

1//! Executes a [`RenderGraph`]
2
3use std::{
4    borrow::Cow,
5    collections::{HashMap, VecDeque},
6};
7
8use smallvec::{smallvec, SmallVec};
9use thiserror::Error;
10
11use crate::{
12    render::{
13        graph::{
14            Edge, NodeId, NodeRunError, NodeState, RenderContext, RenderGraph, RenderGraphContext,
15            SlotLabel, SlotType, SlotValue,
16        },
17        RenderResources,
18    },
19    tcs::world::World,
20};
21
/// Executes a [`RenderGraph`], running its nodes once their dependencies have produced
/// their outputs.
///
/// This is a stateless unit type; all work happens in [`RenderGraphRunner::run`].
pub(crate) struct RenderGraphRunner;
23
24#[derive(Error, Debug)]
25pub enum RenderGraphRunnerError {
26    #[error(transparent)]
27    NodeRunError(#[from] NodeRunError),
28    #[error("node output slot not set (index {slot_index}, name {slot_name})")]
29    EmptyNodeOutputSlot {
30        type_name: &'static str,
31        slot_index: usize,
32        slot_name: Cow<'static, str>,
33    },
34    #[error("graph (name: '{graph_name:?}') could not be run because slot '{slot_name}' at index {slot_index} has no value")]
35    MissingInput {
36        slot_index: usize,
37        slot_name: Cow<'static, str>,
38        graph_name: Option<Cow<'static, str>>,
39    },
40    #[error("attempted to use the wrong type for input slot")]
41    MismatchedInputSlotType {
42        slot_index: usize,
43        label: SlotLabel,
44        expected: SlotType,
45        actual: SlotType,
46    },
47}
48
49impl RenderGraphRunner {
50    pub fn run(
51        graph: &RenderGraph,
52        device: &wgpu::Device,
53        queue: &wgpu::Queue,
54        state: &RenderResources,
55        world: &World,
56    ) -> Result<(), RenderGraphRunnerError> {
57        let command_encoder =
58            device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
59        let mut render_context = RenderContext {
60            device,
61            command_encoder,
62        };
63
64        Self::run_graph(graph, None, &mut render_context, state, world, &[])?;
65        {
66            #[cfg(feature = "trace")]
67            let _span = tracing::info_span!("submit_graph_commands").entered();
68            queue.submit(vec![render_context.command_encoder.finish()]);
69        }
70        Ok(())
71    }
72
73    fn run_graph(
74        graph: &RenderGraph,
75        graph_name: Option<Cow<'static, str>>,
76        render_context: &mut RenderContext,
77        state: &RenderResources,
78        world: &World,
79        inputs: &[SlotValue],
80    ) -> Result<(), RenderGraphRunnerError> {
81        let mut node_outputs: HashMap<NodeId, SmallVec<[SlotValue; 4]>> = HashMap::default();
82        #[cfg(feature = "trace")]
83        let span = if let Some(name) = &graph_name {
84            tracing::info_span!("run_graph", name = name.as_ref())
85        } else {
86            tracing::info_span!("run_graph", name = "main_graph")
87        };
88        #[cfg(feature = "trace")]
89        let _guard = span.enter();
90
91        // Queue up nodes without inputs, which can be run immediately
92        let mut node_queue: VecDeque<&NodeState> = graph
93            .iter_nodes()
94            .filter(|node| node.input_slots.is_empty())
95            .collect();
96
97        // pass inputs into the graph
98        if let Some(input_node) = graph.input_node() {
99            let mut input_values: SmallVec<[SlotValue; 4]> = SmallVec::new();
100            for (i, input_slot) in input_node.input_slots.iter().enumerate() {
101                if let Some(input_value) = inputs.get(i) {
102                    if input_slot.slot_type != input_value.slot_type() {
103                        return Err(RenderGraphRunnerError::MismatchedInputSlotType {
104                            slot_index: i,
105                            actual: input_value.slot_type(),
106                            expected: input_slot.slot_type,
107                            label: input_slot.name.clone().into(),
108                        });
109                    } else {
110                        input_values.push(input_value.clone());
111                    }
112                } else {
113                    return Err(RenderGraphRunnerError::MissingInput {
114                        slot_index: i,
115                        slot_name: input_slot.name.clone(),
116                        graph_name: graph_name.clone(),
117                    });
118                }
119            }
120
121            node_outputs.insert(input_node.id, input_values);
122
123            for (_, node_state) in graph.iter_node_outputs(input_node.id).expect("node exists") {
124                node_queue.push_front(node_state);
125            }
126        }
127
128        'handle_node: while let Some(node_state) = node_queue.pop_back() {
129            // skip nodes that are already processed
130            if node_outputs.contains_key(&node_state.id) {
131                continue;
132            }
133
134            let mut slot_indices_and_inputs: SmallVec<[(usize, SlotValue); 4]> = SmallVec::new();
135            // check if all dependencies have finished running
136            for (edge, input_node) in graph
137                .iter_node_inputs(node_state.id)
138                .expect("node is in graph")
139            {
140                match edge {
141                    Edge::SlotEdge {
142                        output_index,
143                        input_index,
144                        ..
145                    } => {
146                        if let Some(outputs) = node_outputs.get(&input_node.id) {
147                            slot_indices_and_inputs
148                                .push((*input_index, outputs[*output_index].clone()));
149                        } else {
150                            node_queue.push_front(node_state);
151                            continue 'handle_node;
152                        }
153                    }
154                    Edge::NodeEdge { .. } => {
155                        if !node_outputs.contains_key(&input_node.id) {
156                            node_queue.push_front(node_state);
157                            continue 'handle_node;
158                        }
159                    }
160                }
161            }
162
163            // construct final sorted input list
164            slot_indices_and_inputs.sort_by_key(|(index, _)| *index);
165            let inputs: SmallVec<[SlotValue; 4]> = slot_indices_and_inputs
166                .into_iter()
167                .map(|(_, value)| value)
168                .collect();
169
170            assert_eq!(inputs.len(), node_state.input_slots.len());
171
172            let mut outputs: SmallVec<[Option<SlotValue>; 4]> =
173                smallvec![None; node_state.output_slots.len()];
174            {
175                let mut context = RenderGraphContext::new(graph, node_state, &inputs, &mut outputs);
176                {
177                    #[cfg(feature = "trace")]
178                    let _span = tracing::info_span!("node", name = node_state.type_name).entered();
179
180                    node_state
181                        .node
182                        .run(&mut context, render_context, state, world)?;
183                }
184
185                for run_sub_graph in context.finish() {
186                    let sub_graph = graph
187                        .get_sub_graph(&run_sub_graph.name)
188                        .expect("sub graph exists because it was validated when queued.");
189                    Self::run_graph(
190                        sub_graph,
191                        Some(run_sub_graph.name),
192                        render_context,
193                        state,
194                        world,
195                        &run_sub_graph.inputs,
196                    )?;
197                }
198            }
199
200            let mut values: SmallVec<[SlotValue; 4]> = SmallVec::new();
201            for (i, output) in outputs.into_iter().enumerate() {
202                if let Some(value) = output {
203                    values.push(value);
204                } else {
205                    let empty_slot = node_state.output_slots.get_slot(i).unwrap();
206                    return Err(RenderGraphRunnerError::EmptyNodeOutputSlot {
207                        type_name: node_state.type_name,
208                        slot_index: i,
209                        slot_name: empty_slot.name.clone(),
210                    });
211                }
212            }
213            node_outputs.insert(node_state.id, values);
214
215            for (_, node_state) in graph.iter_node_outputs(node_state.id).expect("node exists") {
216                node_queue.push_front(node_state);
217            }
218        }
219
220        Ok(())
221    }
222}