diff --git a/src/error.rs b/src/error.rs index b8aa966..2eb1075 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,10 @@ #[derive(Debug)] pub enum Error { Compile(shaderc::Error), + Layout(ConvertError), +} + +#[derive(Debug)] +pub enum ConvertError { + Unimplemented, } diff --git a/src/lib.rs b/src/lib.rs index baf3426..f6442bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ mod watch; pub use layouts::*; pub use reflection::LayoutData; pub use watch::{Message, Watch}; -pub use error::Error; +pub use error::{Error, ConvertError}; use spirv_reflect as sr; use vulkano as vk; @@ -20,7 +20,6 @@ pub struct CompiledShaders { pub fragment: Vec, } - pub fn load(vertex: T, fragment: T) -> Result where T: AsRef, @@ -30,6 +29,6 @@ where Ok(CompiledShaders{ vertex, fragment }) } -pub fn parse(code: &CompiledShaders) -> Entry { +pub fn parse(code: &CompiledShaders) -> Result { reflection::create_entry(code) } diff --git a/src/reflection.rs b/src/reflection.rs index 1940e30..13ae3e9 100644 --- a/src/reflection.rs +++ b/src/reflection.rs @@ -5,8 +5,10 @@ use crate::vk::descriptor::descriptor::*; use crate::vk::descriptor::pipeline_layout::PipelineLayoutDescPcRange; use crate::vk::pipeline::shader::ShaderInterfaceDefEntry; use crate::CompiledShaders; +use crate::error::Error; use std::borrow::Cow; use std::collections::HashMap; +use std::convert::TryFrom; pub struct ShaderInterfaces { pub inputs: Vec, @@ -22,11 +24,11 @@ pub struct LayoutData { pub pc_ranges: Vec, } -pub fn create_entry(shaders: &CompiledShaders) -> Entry { +pub fn create_entry(shaders: &CompiledShaders) -> Result { let vertex_interfaces = create_interfaces(&shaders.vertex); - let vertex_layout = create_layouts(&shaders.vertex); + let vertex_layout = create_layouts(&shaders.vertex)?; let fragment_interfaces = create_interfaces(&shaders.fragment); - let fragment_layout = create_layouts(&shaders.fragment); + let fragment_layout = create_layouts(&shaders.fragment)?; let frag_input = FragInput { inputs: fragment_interfaces.inputs, }; @@ -53,14 +55,14 @@ pub fn create_entry(shaders: &CompiledShaders) -> Entry { }, layout_data: vertex_layout, }; - Entry { + Ok(Entry { frag_input, frag_output, vert_input, vert_output, frag_layout, vert_layout, - } + }) } fn create_interfaces(data: &[u32]) -> ShaderInterfaces { @@ -105,11 +107,12 @@ fn create_interfaces(data: &[u32]) -> ShaderInterfaces { .expect("failed to load module") } -fn create_layouts(data: &[u32]) -> LayoutData { +fn create_layouts(data: &[u32]) -> Result { sr::ShaderModule::load_u32_data(data) .map(|m| { - let (num_sets, num_bindings, descriptions) = m - .enumerate_descriptor_sets(None) + //let (num_sets, num_bindings, descriptions) = m + let descs = + m.enumerate_descriptor_sets(None) .map(|sets| { let num_sets = sets.len(); let num_bindings = sets @@ -130,7 +133,8 @@ fn create_layouts(data: &[u32]) -> LayoutData { descriptor_type: b.descriptor_type, image: b.image, }; - let ty = SpirvTy::::from(info).inner(); + let ty = SpirvTy::::try_from(info)?; + let ty = ty.inner(); let stages = ShaderStages::none(); let d = DescriptorDesc { ty, @@ -140,17 +144,19 @@ fn create_layouts(data: &[u32]) -> LayoutData { // it's correct readonly: true, }; - (b.binding as usize, d) + Ok((b.binding as usize, d)) }) + .flat_map(|d| d.ok()) .collect::>(); (i.set as usize, desc) }) .collect::>>(); (num_sets, num_bindings, descriptions) }) - .expect("Failed to pass descriptors"); - let (num_constants, pc_ranges) = m - .enumerate_push_constant_blocks(None) + .into_iter(); + //let (num_constants, pc_ranges) = m + let pcs = + m.enumerate_push_constant_blocks(None) .map(|constants| { let num_constants = constants.len(); let pc_ranges = constants @@ -163,14 +169,14 @@ fn create_layouts(data: &[u32]) -> LayoutData { .collect::>(); (num_constants, pc_ranges) }) - .expect("Failed to pass push constants"); + .into_iter(); + descs.flat_map(|(num_sets, num_bindings, descriptions)| pcs.map(|(num_constants, pc_ranges)| LayoutData { num_sets, num_bindings, descriptions, num_constants, pc_ranges, - } + })).next() }) - .expect("failed to load module") } diff --git a/src/srvk.rs b/src/srvk.rs index a1d1d52..e7f7643 100644 --- a/src/srvk.rs +++ b/src/srvk.rs @@ -1,7 +1,9 @@ use crate::sr; use crate::vk; +use crate::error::{ConvertError, Error}; use vk::descriptor::descriptor::*; use vk::format::Format; +use std::convert::TryFrom; pub struct SpirvTy { inner: T, @@ -12,34 +14,35 @@ pub struct DescriptorDescInfo { pub image: sr::types::ReflectImageTraits, } + impl SpirvTy { pub fn inner(self) -> T { self.inner } } -impl From for SpirvTy { - fn from(d: DescriptorDescInfo) -> Self { +impl TryFrom for SpirvTy { + type Error = Error; + fn try_from(d: DescriptorDescInfo) -> Result { use sr::types::ReflectDescriptorType as SR; use DescriptorDescTy as VK; - let t = match d.descriptor_type { - SR::Undefined => unreachable!(), - SR::Sampler => VK::Sampler, - SR::CombinedImageSampler => VK::CombinedImageSampler(SpirvTy::from(d.image).inner()), - SR::SampledImage => unreachable!(), - SR::StorageImage => unreachable!(), - SR::UniformTexelBuffer => unreachable!(), - SR::StorageTexelBuffer => unreachable!(), - SR::UniformBuffer => unreachable!(), - SR::StorageBuffer => unreachable!(), - SR::UniformBufferDynamic => unreachable!(), - SR::StorageBufferDynamic => unreachable!(), - SR::InputAttachment => unreachable!(), - SR::AccelerationStructureNV => unreachable!(), - }; - SpirvTy { - inner: t, + match d.descriptor_type { + SR::Undefined => Err(ConvertError::Unimplemented), + SR::Sampler => Ok(VK::Sampler), + SR::CombinedImageSampler => Ok(VK::CombinedImageSampler(SpirvTy::from(d.image).inner())), + SR::SampledImage => Err(ConvertError::Unimplemented), + SR::StorageImage => Err(ConvertError::Unimplemented), + SR::UniformTexelBuffer => Err(ConvertError::Unimplemented), + SR::StorageTexelBuffer => Err(ConvertError::Unimplemented), + SR::UniformBuffer => Err(ConvertError::Unimplemented), + SR::StorageBuffer => Err(ConvertError::Unimplemented), + SR::UniformBufferDynamic => Err(ConvertError::Unimplemented), + SR::StorageBufferDynamic => Err(ConvertError::Unimplemented), + SR::InputAttachment => Err(ConvertError::Unimplemented), + SR::AccelerationStructureNV => Err(ConvertError::Unimplemented), } + .map(|t| SpirvTy{ inner: t }) + .map_err(|e| Error::Layout(e)) } } @@ -136,7 +139,7 @@ impl From for SpirvTy { use sr::types::ReflectFormat::*; use Format::*; let t = match f { - Undefined => unreachable!(), + Undefined => unimplemented!(), R32_UINT => R32Uint, R32_SINT => R32Sint, R32_SFLOAT => R32Sfloat,