From e203b93285680cc38986ba068b51db09e69350be Mon Sep 17 00:00:00 2001 From: Bill Hollings Date: Tue, 3 Jul 2018 13:57:53 -0400 Subject: [PATCH] Fix compute shader workgroup size specialization. Support separate specialization for each workgroup dimension. Support zero as a specialization ID value. Cleanup MoltenVKShaderConverterTool. Update to latest SPIRV-Cross version. Update MoltenVK version to 1.0.14. --- ExternalRevisions/SPIRV-Cross_repo_revision | 2 +- MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h | 2 +- MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm | 13 +-- .../MoltenVK/GPUObjects/MVKShaderModule.mm | 42 ++++------ .../SPIRVToMSLConverter.cpp | 31 +++---- .../SPIRVToMSLConverter.h | 35 ++++---- .../MoltenVKShaderConverter.xcscheme | 10 ++- .../MoltenVKShaderConverterTool.cpp | 81 +++++++++---------- .../MoltenVKShaderConverterTool.h | 8 +- 9 files changed, 110 insertions(+), 114 deletions(-) diff --git a/ExternalRevisions/SPIRV-Cross_repo_revision b/ExternalRevisions/SPIRV-Cross_repo_revision index 0f1794a3..3c97b234 100644 --- a/ExternalRevisions/SPIRV-Cross_repo_revision +++ b/ExternalRevisions/SPIRV-Cross_repo_revision @@ -1 +1 @@ -d67e586b2e16a46a5cc1515093e8a04bff31c594 +a6814a405abe81545bd3b0a50d374735001173c1 diff --git a/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h b/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h index d0df3e36..71496d40 100644 --- a/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h +++ b/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h @@ -48,7 +48,7 @@ extern "C" { */ #define MVK_VERSION_MAJOR 1 #define MVK_VERSION_MINOR 0 -#define MVK_VERSION_PATCH 13 +#define MVK_VERSION_PATCH 14 #define MVK_MAKE_VERSION(major, minor, patch) (((major) * 10000) + ((minor) * 100) + (patch)) #define MVK_VERSION MVK_MAKE_VERSION(MVK_VERSION_MAJOR, MVK_VERSION_MINOR, MVK_VERSION_PATCH) diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm index 1a7dee94..c700aca5 100644 --- a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm +++ b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm @@ -485,16 +485,19 @@ typedef enum { // Ceral archive definitions namespace mvk { + template + void serialize(Archive & archive, SPIRVWorkgroupSizeDimension& wsd) { + archive(wsd.size, + wsd.specializationID, + wsd.isSpecialized); + } + template void serialize(Archive & archive, SPIRVEntryPoint& ep) { archive(ep.mtlFunctionName, ep.workgroupSize.width, ep.workgroupSize.height, - ep.workgroupSize.depth, - ep.workgroupSizeId.width, - ep.workgroupSizeId.height, - ep.workgroupSizeId.depth, - ep.workgroupSizeId.constant); + ep.workgroupSize.depth); } template diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm index 7bec627f..30e90a95 100644 --- a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm +++ b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm @@ -30,14 +30,19 @@ const MVKMTLFunction MVKMTLFunctionNull = { nil, MTLSizeMake(1, 1, 1) }; #pragma mark - #pragma mark MVKShaderLibrary -static uint32_t getOffsetForConstantId(const VkSpecializationInfo* pSpecInfo, uint32_t constantId) -{ - for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) { - const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx]; - if (pMapEntry->constantID == constantId) { return pMapEntry->offset; } - } +// If the size of the workgroup dimension is specialized, extract it from the +// specialization info, otherwise use the value specified in the SPIR-V shader code. +static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgDim, const VkSpecializationInfo* pSpecInfo) { + if (wgDim.isSpecialized && pSpecInfo) { + for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) { + const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx]; + if (pMapEntry->constantID == wgDim.specializationID) { + return *reinterpret_cast((uintptr_t)pSpecInfo->pData + pMapEntry->offset) ; + } + } + } - return -1; + return wgDim.size; } MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkSpecializationInfo* pSpecializationInfo) { @@ -88,27 +93,10 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkSpecializationInfo* pSpe mvkNotifyErrorWithText(VK_ERROR_INITIALIZATION_FAILED, "Shader module does not contain an entry point named '%s'.", mtlFuncName.UTF8String); } - if (pSpecializationInfo) { - // Get the specialization constant values for the work group size - if (_entryPoint.workgroupSizeId.constant != 0) { - uint32_t widthOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.width); - if (widthOffset != -1) { - _entryPoint.workgroupSize.width = *reinterpret_cast((uint8_t*)pSpecializationInfo->pData + widthOffset); - } + return { mtlFunc, MTLSizeMake(getWorkgroupDimensionSize(_entryPoint.workgroupSize.width, pSpecializationInfo), + getWorkgroupDimensionSize(_entryPoint.workgroupSize.height, pSpecializationInfo), + getWorkgroupDimensionSize(_entryPoint.workgroupSize.depth, pSpecializationInfo)) }; - uint32_t heightOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.height); - if (heightOffset != -1) { - _entryPoint.workgroupSize.height = *reinterpret_cast((uint8_t*)pSpecializationInfo->pData + heightOffset); - } - - uint32_t depthOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.depth); - if (depthOffset != -1) { - _entryPoint.workgroupSize.depth = *reinterpret_cast((uint8_t*)pSpecializationInfo->pData + depthOffset); - } - } - } - - return { mtlFunc, MTLSizeMake(_entryPoint.workgroupSize.width, _entryPoint.workgroupSize.height, _entryPoint.workgroupSize.depth) }; } // Returns the MTLFunctionConstant with the specified ID from the specified array of function constants. diff --git a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp index f1e3308d..3b868c2e 100644 --- a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp +++ b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp @@ -124,8 +124,8 @@ MVK_PUBLIC_SYMBOL void SPIRVToMSLConverterContext::alignUsageWith(const SPIRVToM #pragma mark - #pragma mark SPIRVToMSLConverter -/** Populates content extracted from the SPRI-V compiler. */ -void populateFromCompiler(spirv_cross::Compiler* pCompiler, SPIRVEntryPoint& entryPoint, SPIRVToMSLConverterOptions& options); +// Populates the entry point with info extracted from the SPRI-V compiler. +void populateEntryPoint(SPIRVEntryPoint& entryPoint, spirv_cross::Compiler* pCompiler, SPIRVToMSLConverterOptions& options); MVK_PUBLIC_SYMBOL void SPIRVToMSLConverter::setSPIRV(const vector& spirv) { _spirv = spirv; } @@ -224,7 +224,7 @@ MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverter::convert(SPIRVToMSLConverterContext& #endif // Populate content extracted from the SPRI-V compiler. - populateFromCompiler(pMSLCompiler, _entryPoint, context.options); + populateEntryPoint(_entryPoint, pMSLCompiler, context.options); // To check GLSL conversion if (shouldLogGLSL) { @@ -334,7 +334,14 @@ void SPIRVToMSLConverter::logSource(string& src, const char* srcLang, const char #pragma mark Support functions -void populateFromCompiler(spirv_cross::Compiler* pCompiler, SPIRVEntryPoint& entryPoint, SPIRVToMSLConverterOptions& options) { +// Populate a workgroup size dimension. +void populateWorkgroupDimension(SPIRVWorkgroupSizeDimension& wgDim, uint32_t size, spirv_cross::SpecializationConstant& spvSpecConst) { + wgDim.size = max(size, 1u); + wgDim.isSpecialized = (spvSpecConst.id != 0); + wgDim.specializationID = spvSpecConst.constant_id; +} + +void populateEntryPoint(SPIRVEntryPoint& entryPoint, spirv_cross::Compiler* pCompiler, SPIRVToMSLConverterOptions& options) { if ( !pCompiler ) { return; } @@ -349,19 +356,13 @@ void populateFromCompiler(spirv_cross::Compiler* pCompiler, SPIRVEntryPoint& ent } } - uint32_t minDim = 1; - auto& wgSize = spvEP.workgroup_size; + spirv_cross::SpecializationConstant widthSC, heightSC, depthSC; + pCompiler->get_work_group_size_specialization_constants(widthSC, heightSC, depthSC); entryPoint.mtlFunctionName = spvEP.name; - entryPoint.workgroupSize.width = max(wgSize.x, minDim); - entryPoint.workgroupSize.height = max(wgSize.y, minDim); - entryPoint.workgroupSize.depth = max(wgSize.z, minDim); - - spirv_cross::SpecializationConstant width, height, depth; - entryPoint.workgroupSizeId.constant = pCompiler->get_work_group_size_specialization_constants(width, height, depth); - entryPoint.workgroupSizeId.width = width.constant_id; - entryPoint.workgroupSizeId.height = height.constant_id; - entryPoint.workgroupSizeId.depth = depth.constant_id; + populateWorkgroupDimension(entryPoint.workgroupSize.width, spvEP.workgroup_size.x, widthSC); + populateWorkgroupDimension(entryPoint.workgroupSize.height, spvEP.workgroup_size.y, heightSC); + populateWorkgroupDimension(entryPoint.workgroupSize.depth, spvEP.workgroup_size.z, depthSC); } MVK_PUBLIC_SYMBOL void mvk::logSPIRV(vector& spirv, string& spvLog) { diff --git a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h index 0106cde2..ba2ced3a 100644 --- a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h +++ b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h @@ -138,24 +138,29 @@ namespace mvk { } SPIRVToMSLConverterContext; /** + * Describes one dimension of the workgroup size of a SPIR-V entry point, including whether + * it is specialized, and if so, the value of the corresponding specialization ID, which + * is used to map to a value which will be provided when the MSL is compiled into a pipeline. + */ + typedef struct { + uint32_t size = 1; + uint32_t specializationID = 0; + bool isSpecialized = false; + } SPIRVWorkgroupSizeDimension; + + /** * Describes a SPIRV entry point, including the Metal function name (which may be * different than the Vulkan entry point name if the original name was illegal in Metal), - * and the number of threads in each workgroup or their specialization constant id, if the shader is a compute shader. + * and the size of each workgroup, if the shader is a compute shader. */ - typedef struct { - std::string mtlFunctionName = "main0"; - struct { - uint32_t width = 1; - uint32_t height = 1; - uint32_t depth = 1; - } workgroupSize; - struct { - uint32_t width = 1; - uint32_t height = 1; - uint32_t depth = 1; - uint32_t constant = 0; - } workgroupSizeId; - } SPIRVEntryPoint; + typedef struct { + std::string mtlFunctionName = "main0"; + struct { + SPIRVWorkgroupSizeDimension width; + SPIRVWorkgroupSizeDimension height; + SPIRVWorkgroupSizeDimension depth; + } workgroupSize; + } SPIRVEntryPoint; /** Special constant used in a MSLResourceBinding descriptorSet element to indicate the bindings for the push constants. */ static const uint32_t kPushConstDescSet = std::numeric_limits::max(); diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/xcshareddata/xcschemes/MoltenVKShaderConverter.xcscheme b/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/xcshareddata/xcschemes/MoltenVKShaderConverter.xcscheme index c3d1e188..04cdef2b 100644 --- a/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/xcshareddata/xcschemes/MoltenVKShaderConverter.xcscheme +++ b/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/xcshareddata/xcschemes/MoltenVKShaderConverter.xcscheme @@ -67,7 +67,11 @@ isEnabled = "NO"> + + fileContents; vector spv; @@ -177,11 +174,11 @@ bool MoltenVKShaderConverterTool::convertSPIRV(string& spvInFile, return convertSPIRV(spv, spvInFile, mslOutFile, _shouldLogConversions); } -/** Read SPIR-V code from an array, convert to MSL, and write the MSL code to files. */ +// Read SPIR-V code from an array, convert to MSL, and write the MSL code to files. bool MoltenVKShaderConverterTool::convertSPIRV(const vector& spv, - string& inFile, - string& mslOutFile, - bool shouldLogSPV) { + string& inFile, + string& mslOutFile, + bool shouldLogSPV) { if ( !_shouldWriteMSL ) { return true; } // Derive the context under which conversion will occur @@ -236,10 +233,10 @@ bool MoltenVKShaderConverterTool::isSPIRVFileExtension(string& pathExtension) { return false; } -/** Log the specified message to the console. */ +// Log the specified message to the console. void MoltenVKShaderConverterTool::log(const char* logMsg) { printf("%s\n", logMsg); } -/** Display usage information about this application on the console. */ +// Display usage information about this application on the console. void MoltenVKShaderConverterTool::showUsage() { string line = "\n\e[1m" + _processName + "\e[0m converts OpenGL Shading Language (GLSL) source code to"; log((const char*)line.c_str()); @@ -252,9 +249,9 @@ void MoltenVKShaderConverterTool::showUsage() { log("\nUse the -so or -mo option to indicate the desired type of output"); log("(SPIR-V or MSL, respectively)."); log("\nUsage:"); - log(" -d [\"dirPath\"] - Path to a directory containing GLSL shader source code"); - log(" files. The dirPath may be omitted to use the current"); - log(" working directory."); + log(" -d [\"dirPath\"] - Path to a directory containing GLSL or SPIR-V shader"); + log(" source code files. The dirPath may be omitted to use"); + log(" the current working directory."); log(" -r - (when using -d) Process directories recursively."); log(" -gi [\"glslInFile\"] - Indicates that GLSL shader code should be input."); log(" The optional path parameter specifies the path to a"); @@ -334,7 +331,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) { if (equal(arg, "-d", false)) { int optIdx = argIdx; - argIdx = optionParam(_directoryPath, argIdx, argc, argv); + argIdx = optionalParam(_directoryPath, argIdx, argc, argv); if (argIdx == optIdx) { return false; } _directoryPath = absolutePath(_directoryPath); continue; @@ -347,32 +344,32 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) { if (equal(arg, "-gi", true)) { _shouldReadGLSL = true; - argIdx = optionParam(_glslInFilePath, argIdx, argc, argv); + argIdx = optionalParam(_glslInFilePath, argIdx, argc, argv); continue; } if (equal(arg, "-si", true)) { _shouldReadSPIRV = true; - argIdx = optionParam(_spvInFilePath, argIdx, argc, argv); + argIdx = optionalParam(_spvInFilePath, argIdx, argc, argv); continue; } if (equal(arg, "-so", true)) { _shouldWriteSPIRV = true; - argIdx = optionParam(_spvOutFilePath, argIdx, argc, argv); + argIdx = optionalParam(_spvOutFilePath, argIdx, argc, argv); continue; } if (equal(arg, "-mo", true)) { _shouldWriteMSL = true; - argIdx = optionParam(_mslOutFilePath, argIdx, argc, argv); + argIdx = optionalParam(_mslOutFilePath, argIdx, argc, argv); continue; } if (equal(arg, "-t", true)) { int optIdx = argIdx; string shdrTypeStr; - argIdx = optionParam(shdrTypeStr, argIdx, argc, argv); + argIdx = optionalParam(shdrTypeStr, argIdx, argc, argv); if (argIdx == optIdx || shdrTypeStr.length() == 0) { return false; } switch (shdrTypeStr.front()) { @@ -416,7 +413,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) { if (equal(arg, "-vx", true)) { int optIdx = argIdx; string shdrExtnStr; - argIdx = optionParam(shdrExtnStr, argIdx, argc, argv); + argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv); if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; } extractTokens(shdrExtnStr, _glslVtxFileExtns); continue; @@ -425,7 +422,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) { if (equal(arg, "-fx", true)) { int optIdx = argIdx; string shdrExtnStr; - argIdx = optionParam(shdrExtnStr, argIdx, argc, argv); + argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv); if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; } extractTokens(shdrExtnStr, _glslFragFileExtns); continue; @@ -434,7 +431,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) { if (equal(arg, "-cx", true)) { int optIdx = argIdx; string shdrExtnStr; - argIdx = optionParam(shdrExtnStr, argIdx, argc, argv); + argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv); if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; } extractTokens(shdrExtnStr, _glslCompFileExtns); continue; @@ -443,7 +440,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) { if (equal(arg, "-sx", true)) { int optIdx = argIdx; string shdrExtnStr; - argIdx = optionParam(shdrExtnStr, argIdx, argc, argv); + argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv); if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; } extractTokens(shdrExtnStr, _spvFileExtns); continue; @@ -459,21 +456,19 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) { return true; } -/** Returns whether the specified command line arg is an option arg. */ +// Returns whether the specified command line arg is an option arg. bool MoltenVKShaderConverterTool::isOptionArg(string& arg) { return (arg.length() > 1 && arg.front() == '-'); } -/** - * Sets the contents of the specified string to the parameter part of the option at the - * specified arg index, and increments and returns the option index. If no parameter was - * provided for the option, the string will be set to an empty string, and the returned - * index will be the same as the specified index. - */ -int MoltenVKShaderConverterTool::optionParam(string& optionParamResult, - int optionArgIndex, - int argc, - const char* argv[]) { +// Sets the contents of the specified string to the parameter part of the option at the +// specified arg index, and increments and returns the option index. If no parameter was +// provided for the option, the string will be set to an empty string, and the returned +// index will be the same as the specified index. +int MoltenVKShaderConverterTool::optionalParam(string& optionParamResult, + int optionArgIndex, + int argc, + const char* argv[]) { int optParamIdx = optionArgIndex + 1; if (optParamIdx < argc) { string arg(argv[optParamIdx]); @@ -490,7 +485,7 @@ int MoltenVKShaderConverterTool::optionParam(string& optionParamResult, #pragma mark - #pragma mark Support functions -/** Template function for tokenizing the components of a string into a vector. */ +// Template function for tokenizing the components of a string into a vector. template Container& split(Container& result, const typename Container::value_type& s, @@ -516,7 +511,7 @@ void mvk::extractTokens(string str, vector& tokens) { split(tokens, str, " \t\n\f", false); } -/** Compares the specified characters ignoring case. */ +// Compares the specified characters ignoring case. static bool compareIgnoringCase(unsigned char a, unsigned char b) { return tolower(a) == tolower(b); } diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h index 2a6f909f..13ba6265 100644 --- a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h +++ b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h @@ -69,10 +69,10 @@ namespace mvk { void log(const char* logMsg); void showUsage(); bool isOptionArg(std::string& arg); - int optionParam(std::string& optionParamResult, - int optionArgIndex, - int argc, - const char* argv[]); + int optionalParam(std::string& optionParamResult, + int optionArgIndex, + int argc, + const char* argv[]); std::string _processName; std::string _directoryPath;