diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm index bb16ee73..678b8f19 100644 --- a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm +++ b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm @@ -29,6 +29,18 @@ const MVKMTLFunction MVKMTLFunctionNull = { nil, MTLSizeMake(1, 1, 1) }; #pragma mark - #pragma mark MVKShaderLibrary +uint32_t getOffsetForConstantId(const VkSpecializationInfo* pSpecInfo, uint32_t constantId) +{ + for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) { + const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx]; + if (pMapEntry->constantID == constantId) { + return pMapEntry->offset; + } + } + + return -1; +} + MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage) { if ( !_mtlLibrary ) { return MVKMTLFunctionNull; } @@ -68,6 +80,24 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreat atIndex: mtlFCIndex]; } } + + // Get the specialization constant values for the work group size + if (ep.workgroupSizeId.constant != 0) { + uint32_t widthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.width); + if (widthOffset != -1) { + ep.workgroupSize.width = &(((uint32_t*)pSpecInfo->pData)[widthOffset]); + } + + uint32_t heightOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.height); + if (heightOffset != -1) { + ep.workgroupSize.height = &(((uint32_t*)pSpecInfo->pData)[heightOffset]); + } + + uint32_t depthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.depth); + if (depthOffset != -1) { + ep.workgroupSize.depth = &(((uint32_t*)pSpecInfo->pData)[depthOffset]); + } + } } // Compile the specialized Metal function, and use it instead of the unspecialized Metal function. diff --git a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp index 635d4de4..2c5f9f42 100644 --- a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp +++ b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp @@ -322,6 +322,13 @@ void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPointsByNam mvkEP.workgroupSize.width = max(wgSize.x, minDim); mvkEP.workgroupSize.height = max(wgSize.y, minDim); mvkEP.workgroupSize.depth = max(wgSize.z, minDim); + + spirv_cross::SpecializationConstant width, height, depth; + mvkEP.workgroupSizeId.constant = compiler.get_work_group_size_specialization_constants(width, height, depth); + mvkEP.workgroupSizeId.width = width.constant_id; + mvkEP.workgroupSizeId.height = height.constant_id; + mvkEP.workgroupSizeId.depth = depth.constant_id; + entryPoints[epOrigName] = mvkEP; } } diff --git a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h index 66fa0031..aee99df2 100644 --- a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h +++ b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h @@ -132,7 +132,7 @@ namespace mvk { /** * Describes a SPIRV entry point, including the Metal function name (which may be * different than the Vulkan entry point name if the original name was illegal in Metal), - * and the number of threads in each workgroup, if the shader is a compute shader. + * and the number of threads in each workgroup or their specialization constant id, if the shader is a compute shader. */ typedef struct { std::string mtlFunctionName; @@ -141,6 +141,10 @@ namespace mvk { uint32_t height = 1; uint32_t depth = 1; } workgroupSize; + struct { + uint32_t width, height, depth; + uint32_t constant = 0; + } workgroupSizeId; } SPIRVEntryPoint; /** Holds a map of entry point info, indexed by the SPIRV entry point name. */