1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::{NodeState, RenderGraph, SlotInfos, SlotLabel, SlotType, SlotValue};
6use crate::render::resource::TextureView;
7
/// A request to run a sub graph, queued by [`RenderGraphContext::run_sub_graph`]
/// after its inputs were validated, and handed back by
/// [`RenderGraphContext::finish`].
///
/// NOTE(review): presumably executed by the render graph runner once the
/// current node has finished — confirm against the caller of `finish`.
pub struct RunSubGraph {
    /// Name of the sub graph to run (looked up with `RenderGraph::get_sub_graph`).
    pub name: Cow<'static, str>,
    /// Values for the sub graph's input node, in input slot order.
    pub inputs: Vec<SlotValue>,
}
14
/// Context through which a render graph node reads its input values, writes
/// its output values and queues sub graphs to run.
pub struct RenderGraphContext<'a> {
    /// Graph used by `run_sub_graph` to look up sub graphs by name.
    graph: &'a RenderGraph,
    /// State of the node being run; provides its input and output slot infos.
    node: &'a NodeState,
    /// Input values, indexed by the node's input slot index.
    inputs: &'a [SlotValue],
    /// Output values, indexed by the node's output slot index; an entry becomes
    /// `Some` when it is written by `set_output`.
    outputs: &'a mut [Option<SlotValue>],
    /// Sub graph runs queued by `run_sub_graph`, in call order.
    run_sub_graphs: Vec<RunSubGraph>,
}
30
31impl<'a> RenderGraphContext<'a> {
32 pub fn new(
34 graph: &'a RenderGraph,
35 node: &'a NodeState,
36 inputs: &'a [SlotValue],
37 outputs: &'a mut [Option<SlotValue>],
38 ) -> Self {
39 Self {
40 graph,
41 node,
42 inputs,
43 outputs,
44 run_sub_graphs: Vec::new(),
45 }
46 }
47
48 #[inline]
50 pub fn inputs(&self) -> &[SlotValue] {
51 self.inputs
52 }
53
54 pub fn input_info(&self) -> &SlotInfos {
56 &self.node.input_slots
57 }
58
59 pub fn output_info(&self) -> &SlotInfos {
61 &self.node.output_slots
62 }
63
64 pub fn get_input(&self, label: impl Into<SlotLabel>) -> Result<&SlotValue, InputSlotError> {
66 let label = label.into();
67 let index = self
68 .input_info()
69 .get_slot_index(label.clone())
70 .ok_or(InputSlotError::InvalidSlot(label))?;
71 Ok(&self.inputs[index])
72 }
73
74 pub fn get_input_texture(
76 &self,
77 label: impl Into<SlotLabel>,
78 ) -> Result<&TextureView, InputSlotError> {
79 let label = label.into();
80 match self.get_input(label.clone())? {
81 SlotValue::TextureView(value) => Ok(value),
82 value => Err(InputSlotError::MismatchedSlotType {
83 label,
84 actual: value.slot_type(),
85 expected: SlotType::TextureView,
86 }),
87 }
88 }
89
90 pub fn get_input_sampler(
92 &self,
93 label: impl Into<SlotLabel>,
94 ) -> Result<&wgpu::Sampler, InputSlotError> {
95 let label = label.into();
96 match self.get_input(label.clone())? {
97 SlotValue::Sampler(value) => Ok(value),
98 value => Err(InputSlotError::MismatchedSlotType {
99 label,
100 actual: value.slot_type(),
101 expected: SlotType::Sampler,
102 }),
103 }
104 }
105
106 pub fn get_input_buffer(
108 &self,
109 label: impl Into<SlotLabel>,
110 ) -> Result<&wgpu::Buffer, InputSlotError> {
111 let label = label.into();
112 match self.get_input(label.clone())? {
113 SlotValue::Buffer(value) => Ok(value),
114 value => Err(InputSlotError::MismatchedSlotType {
115 label,
116 actual: value.slot_type(),
117 expected: SlotType::Buffer,
118 }),
119 }
120 }
121
122 pub fn set_output(
124 &mut self,
125 label: impl Into<SlotLabel>,
126 value: impl Into<SlotValue>,
127 ) -> Result<(), OutputSlotError> {
128 let label = label.into();
129 let value = value.into();
130 let slot_index = self
131 .output_info()
132 .get_slot_index(label.clone())
133 .ok_or_else(|| OutputSlotError::InvalidSlot(label.clone()))?;
134 let slot = self
135 .output_info()
136 .get_slot(slot_index)
137 .expect("slot is valid");
138 if value.slot_type() != slot.slot_type {
139 return Err(OutputSlotError::MismatchedSlotType {
140 label,
141 actual: slot.slot_type,
142 expected: value.slot_type(),
143 });
144 }
145 self.outputs[slot_index] = Some(value);
146 Ok(())
147 }
148
149 pub fn run_sub_graph(
151 &mut self,
152 name: impl Into<Cow<'static, str>>,
153 inputs: Vec<SlotValue>,
154 ) -> Result<(), RunSubGraphError> {
155 let name = name.into();
156 let sub_graph = self
157 .graph
158 .get_sub_graph(&name)
159 .ok_or_else(|| RunSubGraphError::MissingSubGraph(name.clone()))?;
160 if let Some(input_node) = sub_graph.input_node() {
161 for (i, input_slot) in input_node.input_slots.iter().enumerate() {
162 if let Some(input_value) = inputs.get(i) {
163 if input_slot.slot_type != input_value.slot_type() {
164 return Err(RunSubGraphError::MismatchedInputSlotType {
165 graph_name: name,
166 slot_index: i,
167 actual: input_value.slot_type(),
168 expected: input_slot.slot_type,
169 label: input_slot.name.clone().into(),
170 });
171 }
172 } else {
173 return Err(RunSubGraphError::MissingInput {
174 slot_index: i,
175 slot_name: input_slot.name.clone(),
176 graph_name: name,
177 });
178 }
179 }
180 } else if !inputs.is_empty() {
181 return Err(RunSubGraphError::SubGraphHasNoInputs(name));
182 }
183
184 self.run_sub_graphs.push(RunSubGraph { name, inputs });
185
186 Ok(())
187 }
188
189 pub fn finish(self) -> Vec<RunSubGraph> {
192 self.run_sub_graphs
193 }
194}
195
196#[derive(Error, Debug, Eq, PartialEq)]
197pub enum RunSubGraphError {
198 #[error("attempted to run sub-graph `{0}`, but it does not exist")]
199 MissingSubGraph(Cow<'static, str>),
200 #[error("attempted to pass inputs to sub-graph `{0}`, which has no input slots")]
201 SubGraphHasNoInputs(Cow<'static, str>),
202 #[error("sub graph (name: `{graph_name:?}`) could not be run because slot `{slot_name}` at index {slot_index} has no value")]
203 MissingInput {
204 slot_index: usize,
205 slot_name: Cow<'static, str>,
206 graph_name: Cow<'static, str>,
207 },
208 #[error("attempted to use the wrong type for input slot")]
209 MismatchedInputSlotType {
210 graph_name: Cow<'static, str>,
211 slot_index: usize,
212 label: SlotLabel,
213 expected: SlotType,
214 actual: SlotType,
215 },
216}
217
/// Errors returned by `RenderGraphContext::set_output`.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum OutputSlotError {
    /// The node has no output slot matching the given label.
    #[error("output slot `{0:?}` does not exist")]
    InvalidSlot(SlotLabel),
    /// The type of the value being output differs from the output slot's type.
    // As worded by the message, `actual` is meant to be the type of the value
    // being output and `expected` the type declared by the output slot.
    #[error("attempted to output a value of type `{actual}` to output slot `{label:?}`, which has type `{expected}`")]
    MismatchedSlotType {
        /// Label the output slot was looked up with.
        label: SlotLabel,
        expected: SlotType,
        actual: SlotType,
    },
}
229
230#[derive(Error, Debug, Eq, PartialEq)]
231pub enum InputSlotError {
232 #[error("input slot `{0:?}` does not exist")]
233 InvalidSlot(SlotLabel),
234 #[error("attempted to retrieve a value of type `{actual}` from input slot `{label:?}`, which has type `{expected}`")]
235 MismatchedSlotType {
236 label: SlotLabel,
237 expected: SlotType,
238 actual: SlotType,
239 },
240}