From 370bf8350af660ae77f9831959b766a68aa8eb74 Mon Sep 17 00:00:00 2001 From: freesig Date: Wed, 5 Jun 2019 13:28:24 +1000 Subject: [PATCH] support for compute shaders --- Cargo.toml | 3 +- src/layouts.rs | 96 ++++++++++++++++++++++++++++++------ src/lib.rs | 17 ++++++- src/reflection.rs | 17 ++++--- src/watch.rs | 121 ++++++++++++++++++++++++++++++++++++++-------- tests/tests.rs | 36 ++------------ 6 files changed, 212 insertions(+), 78 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d342979..3f3e864 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "shade_runner" -version = "0.1.0" +version = "0.1.1" authors = ["Tom Gowan "] edition = "2018" description = "Allows runtime hot loading of shaders for vulkano" @@ -9,7 +9,6 @@ repository = "https://github.com/freesig/shade_runner" readme = "README.md" license = "MIT" keywords = ["vulkan", "vulkano", "shaders", "hotloading"] -license_file = "LICENSE" [dependencies] notify = "4" diff --git a/src/layouts.rs b/src/layouts.rs index 5c5a5b3..55f3ce1 100644 --- a/src/layouts.rs +++ b/src/layouts.rs @@ -5,7 +5,7 @@ use vk::descriptor::descriptor::*; use vk::descriptor::pipeline_layout::*; use crate::reflection::LayoutData; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct Entry { pub frag_input: FragInput, pub frag_output: FragOutput, @@ -13,9 +13,10 @@ pub struct Entry { pub vert_input: VertInput, pub vert_output: VertOutput, pub vert_layout: VertLayout, + pub compute_layout: ComputeLayout, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct FragInput { pub inputs: Vec, } @@ -30,7 +31,7 @@ unsafe impl ShaderInterfaceDef for FragInput { pub type FragInputIter = std::vec::IntoIter; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct FragOutput { pub outputs: Vec, } @@ -46,11 +47,20 @@ unsafe impl ShaderInterfaceDef for FragOutput { pub type FragOutputIter = std::vec::IntoIter; // Layout same as with vertex shader. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct FragLayout { - pub stages: ShaderStages, pub layout_data: LayoutData, } +impl FragLayout { + const STAGES: ShaderStages = ShaderStages { + vertex: false, + tessellation_control: false, + tessellation_evaluation: false, + geometry: false, + fragment: true, + compute: false, + }; +} unsafe impl PipelineLayoutDesc for FragLayout { fn num_sets(&self) -> usize { self.layout_data.num_sets @@ -63,10 +73,10 @@ unsafe impl PipelineLayoutDesc for FragLayout { .and_then(|s|s.get(&binding)) .map(|desc| { let mut desc = desc.clone(); - desc.stages = self.stages; + desc.stages = FragLayout::STAGES; desc }) - + } fn num_push_constants_ranges(&self) -> usize { self.layout_data.num_constants @@ -75,14 +85,14 @@ unsafe impl PipelineLayoutDesc for FragLayout { self.layout_data.pc_ranges.get(num) .map(|desc| { let mut desc = *desc; - desc.stages = self.stages; + desc.stages = FragLayout::STAGES; desc }) } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct VertInput { pub inputs: Vec, } @@ -97,7 +107,7 @@ unsafe impl ShaderInterfaceDef for VertInput { pub type VertInputIter = std::vec::IntoIter; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct VertOutput { pub outputs: Vec, } @@ -113,11 +123,20 @@ unsafe impl ShaderInterfaceDef for VertOutput { pub type VertOutputIter = std::vec::IntoIter; // This structure describes layout of this stage. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct VertLayout { - pub stages: ShaderStages, pub layout_data: LayoutData, } +impl VertLayout { + const STAGES: ShaderStages = ShaderStages { + vertex: true, + tessellation_control: false, + tessellation_evaluation: false, + geometry: false, + fragment: false, + compute: false, + }; +} unsafe impl PipelineLayoutDesc for VertLayout { fn num_sets(&self) -> usize { self.layout_data.num_sets @@ -130,10 +149,57 @@ unsafe impl PipelineLayoutDesc for VertLayout { .and_then(|s|s.get(&binding)) .map(|desc| { let mut desc = desc.clone(); - desc.stages = self.stages; + desc.stages = VertLayout::STAGES; desc }) - + + } + fn num_push_constants_ranges(&self) -> usize { + self.layout_data.num_constants + } + fn push_constants_range(&self, num: usize) -> Option { + self.layout_data.pc_ranges.get(num) + .map(|desc| { + let mut desc = *desc; + desc.stages = VertLayout::STAGES; + desc + }) + + } +} + +#[derive(Debug, Clone, Default)] +pub struct ComputeLayout { + pub layout_data: LayoutData, +} + +impl ComputeLayout { + const STAGES: ShaderStages = ShaderStages { + vertex: false, + tessellation_control: false, + tessellation_evaluation: false, + geometry: false, + fragment: false, + compute: true, + }; +} + +unsafe impl PipelineLayoutDesc for ComputeLayout { + fn num_sets(&self) -> usize { + self.layout_data.num_sets + } + fn num_bindings_in_set(&self, set: usize) -> Option { + self.layout_data.num_bindings.get(&set).map(|&b|b) + } + fn descriptor(&self, set: usize, binding: usize) -> Option { + self.layout_data.descriptions.get(&set) + .and_then(|s|s.get(&binding)) + .map(|desc| { + let mut desc = desc.clone(); + desc.stages = Self::STAGES; + desc + }) + } fn num_push_constants_ranges(&self) -> usize { self.layout_data.num_constants @@ -142,7 +208,7 @@ unsafe impl PipelineLayoutDesc for VertLayout { self.layout_data.pc_ranges.get(num) .map(|desc| { let mut desc = *desc; - desc.stages = self.stages; + desc.stages = Self::STAGES; desc }) diff --git a/src/lib.rs b/src/lib.rs index 1de73db..8de337e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ use shaderc::ShaderKind; pub struct CompiledShaders { pub vertex: Vec, pub fragment: Vec, + pub compute: Vec, } /// Loads and compiles the vertex and fragment GLSL shaders from files @@ -27,7 +28,21 @@ where { let vertex = compiler::compile(vertex, ShaderKind::Vertex).map_err(Error::Compile)?; let fragment = compiler::compile(fragment, ShaderKind::Fragment).map_err(Error::Compile)?; - Ok(CompiledShaders{ vertex, fragment }) + Ok(CompiledShaders{ vertex, fragment, compute: Vec::new() }) +} + +// TODO this should be incorpoarted into load but that would be +// a breaking change. Do this in next major version +pub fn load_compute(compute: T) -> Result +where + T: AsRef, +{ + let compute = compiler::compile(compute, ShaderKind::Compute).map_err(Error::Compile)?; + Ok(CompiledShaders{ vertex: Vec::new(), fragment: Vec::new(), compute }) +} + +pub fn parse_compute(code: &CompiledShaders) -> Result { + reflection::create_compute_entry(code) } /// Parses the shaders and gives an entry point diff --git a/src/reflection.rs b/src/reflection.rs index 159303b..efe4e28 100644 --- a/src/reflection.rs +++ b/src/reflection.rs @@ -36,10 +36,6 @@ pub fn create_entry(shaders: &CompiledShaders) -> Result { outputs: fragment_interfaces.outputs, }; let frag_layout = FragLayout { - stages: ShaderStages { - fragment: true, - ..ShaderStages::none() - }, layout_data: fragment_layout, }; let vert_input = VertInput { @@ -49,10 +45,6 @@ pub fn create_entry(shaders: &CompiledShaders) -> Result { outputs: vertex_interfaces.outputs, }; let vert_layout = VertLayout { - stages: ShaderStages { - vertex: true, - ..ShaderStages::none() - }, layout_data: vertex_layout, }; Ok(Entry { @@ -62,6 +54,15 @@ pub fn create_entry(shaders: &CompiledShaders) -> Result { vert_output, frag_layout, vert_layout, + compute_layout: Default::default(), + }) +} + +pub fn create_compute_entry(shaders: &CompiledShaders) -> Result { + create_layouts(&shaders.compute).map(|layout_data| { + let mut entry = Entry::default(); + entry.compute_layout = ComputeLayout{ layout_data }; + entry }) } diff --git a/src/watch.rs b/src/watch.rs index d8a8a89..7cc9b8d 100644 --- a/src/watch.rs +++ b/src/watch.rs @@ -13,12 +13,27 @@ pub struct Watch { pub rx: Receiver>, } -struct Loader { +enum Loader { + Graphics(GraphicsLoader), + Compute(ComputeLoader), +} + +enum SrcPath { + Graphics(PathBuf, PathBuf), + Compute(PathBuf), +} + +struct GraphicsLoader { vertex: PathBuf, fragment: PathBuf, tx: Sender>, } +struct ComputeLoader { + compute: PathBuf, + tx: Sender>, +} + pub struct Message { pub shaders: CompiledShaders, pub entry: Entry, @@ -31,9 +46,28 @@ impl Watch { where T: AsRef, { - let (handler, rx) = create_watch( + let src_path = SrcPath::Graphics( vertex.as_ref().to_path_buf(), - fragment.as_ref().to_path_buf(), + fragment.as_ref().to_path_buf() + ); + let (handler, rx) = create_watch( + src_path, + frequency, + )?; + Ok(Watch { + _handler: handler, + rx, + }) + } + + pub fn create_compute(compute: T, frequency: Duration) -> Result + where + T: AsRef, + { + let src_path = SrcPath::Compute( + compute.as_ref(). to_path_buf()); + let (handler, rx) = create_watch( + src_path, frequency, )?; Ok(Watch { @@ -43,10 +77,10 @@ impl Watch { } } -impl Loader { +impl GraphicsLoader { fn create(vertex: PathBuf, fragment: PathBuf) -> (Self, Receiver>) { let (tx, rx) = mpsc::channel(); - let loader = Loader { + let loader = GraphicsLoader { vertex, fragment, tx, @@ -67,6 +101,38 @@ impl Loader { } } +impl ComputeLoader { + fn create(compute: PathBuf) -> (Self, Receiver>) { + let (tx, rx) = mpsc::channel(); + let loader = ComputeLoader { + compute, + tx, + }; + loader.reload(); + (loader, rx) + } + + fn reload(&self) { + match crate::load_compute(&self.compute) { + Ok(shaders) => { + let entry = crate::parse_compute(&shaders); + let msg = entry.map(|entry| Message { shaders, entry }); + self.tx.send(msg).ok() + } + Err(e) => self.tx.send(Err(e)).ok(), + }; + } +} + +impl Loader { + fn reload(&self) { + match self { + Loader::Graphics(g) => g.reload(), + Loader::Compute(g) => g.reload(), + } + } +} + struct Handler { thread_tx: mpsc::Sender<()>, handle: Option>, @@ -83,8 +149,7 @@ impl Drop for Handler { } fn create_watch( - vert_path: PathBuf, - frag_path: PathBuf, + src_path: SrcPath, frequency: Duration ) -> Result<(Handler, mpsc::Receiver>), Error> { let (notify_tx, notify_rx) = mpsc::channel(); @@ -92,20 +157,36 @@ fn create_watch( let mut watcher: RecommendedWatcher = Watcher::new(notify_tx, frequency).map_err(Error::FileWatch)?; - let mut vp = vert_path.clone(); - let mut fp = frag_path.clone(); - vp.pop(); - fp.pop(); - watcher - .watch(&vp, RecursiveMode::NonRecursive) - .map_err(Error::FileWatch)?; - if vp != fp { - watcher - .watch(&fp, RecursiveMode::NonRecursive) - .map_err(Error::FileWatch)?; - } + let (loader, rx) = match src_path { + SrcPath::Graphics(vert_path, frag_path) => { + let mut vp = vert_path.clone(); + let mut fp = frag_path.clone(); + vp.pop(); + fp.pop(); + watcher + .watch(&vp, RecursiveMode::NonRecursive) + .map_err(Error::FileWatch)?; + if vp != fp { + watcher + .watch(&fp, RecursiveMode::NonRecursive) + .map_err(Error::FileWatch)?; + } + + let (loader, rx) = GraphicsLoader::create(vert_path, frag_path); + (Loader::Graphics(loader), rx) + } + SrcPath::Compute(compute_path) => { + let mut cp = compute_path.clone(); + cp.pop(); + watcher + .watch(&cp, RecursiveMode::NonRecursive) + .map_err(Error::FileWatch)?; + + let (loader, rx) = ComputeLoader::create(compute_path); + (Loader::Compute(loader), rx) + } + }; - let (loader, rx) = Loader::create(vert_path, frag_path); let handle = thread::spawn(move || 'watch_loop: loop { if thread_rx.try_recv().is_ok() { diff --git a/tests/tests.rs b/tests/tests.rs index 8c96e90..9c207ba 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -81,6 +81,7 @@ where fn test_shade1() { setup(); let target = Entry { + compute_layout: Default::default(), frag_input: FragInput { inputs: Vec::new() }, frag_output: FragOutput { outputs: vec![ShaderInterfaceDefEntry { @@ -90,10 +91,6 @@ fn test_shade1() { }], }, frag_layout: FragLayout { - stages: ShaderStages { - fragment: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 0, num_bindings: HashMap::new(), @@ -113,10 +110,6 @@ fn test_shade1() { outputs: Vec::new(), }, vert_layout: VertLayout { - stages: ShaderStages { - vertex: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 0, num_bindings: HashMap::new(), @@ -134,6 +127,7 @@ fn test_shade1() { fn test_shade2() { setup(); let target = Entry { + compute_layout: Default::default(), frag_input: FragInput { inputs: vec![ ShaderInterfaceDefEntry { @@ -161,10 +155,6 @@ fn test_shade2() { }], }, frag_layout: FragLayout { - stages: ShaderStages { - fragment: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 0, num_bindings: HashMap::new(), @@ -200,10 +190,6 @@ fn test_shade2() { ], }, vert_layout: VertLayout { - stages: ShaderStages { - vertex: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 0, num_bindings: HashMap::new(), @@ -221,6 +207,7 @@ fn test_shade2() { fn test_shade3() { setup(); let target = Entry { + compute_layout: Default::default(), frag_input: FragInput { inputs: Vec::new() }, frag_output: FragOutput { outputs: vec![ShaderInterfaceDefEntry { @@ -230,10 +217,6 @@ fn test_shade3() { }], }, frag_layout: FragLayout { - stages: ShaderStages { - fragment: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 1, num_bindings: vec![(0, 1)].into_iter().collect(), @@ -277,10 +260,6 @@ fn test_shade3() { outputs: Vec::new(), }, vert_layout: VertLayout { - stages: ShaderStages { - vertex: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 0, num_bindings: HashMap::new(), @@ -309,6 +288,7 @@ fn test_shade3() { fn test_shade4() { setup(); let target = Entry { + compute_layout: Default::default(), frag_input: FragInput { inputs: Vec::new() }, frag_output: FragOutput { outputs: vec![ShaderInterfaceDefEntry { @@ -318,10 +298,6 @@ fn test_shade4() { }], }, frag_layout: FragLayout { - stages: ShaderStages { - fragment: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 0, num_bindings: HashMap::new(), @@ -348,10 +324,6 @@ fn test_shade4() { outputs: Vec::new(), }, vert_layout: VertLayout { - stages: ShaderStages { - vertex: true, - ..ShaderStages::none() - }, layout_data: LayoutData { num_sets: 0, num_bindings: HashMap::new(),