Fix compute shader workgroup size specialization.

Support separate specialization for each workgroup dimension.
Support zero as a specialization ID value.
Cleanup MoltenVKShaderConverterTool.
Update to latest SPIRV-Cross version.
Update MoltenVK version to 1.0.14.
This commit is contained in:
Bill Hollings 2018-07-03 13:57:53 -04:00
parent 5280e9515a
commit e203b93285
9 changed files with 110 additions and 114 deletions

View File

@ -1 +1 @@
d67e586b2e16a46a5cc1515093e8a04bff31c594
a6814a405abe81545bd3b0a50d374735001173c1

View File

@ -48,7 +48,7 @@ extern "C" {
*/
#define MVK_VERSION_MAJOR 1
#define MVK_VERSION_MINOR 0
#define MVK_VERSION_PATCH 13
#define MVK_VERSION_PATCH 14
#define MVK_MAKE_VERSION(major, minor, patch) (((major) * 10000) + ((minor) * 100) + (patch))
#define MVK_VERSION MVK_MAKE_VERSION(MVK_VERSION_MAJOR, MVK_VERSION_MINOR, MVK_VERSION_PATCH)

View File

@ -485,16 +485,19 @@ typedef enum {
// Ceral archive definitions
namespace mvk {
template<class Archive>
void serialize(Archive & archive, SPIRVWorkgroupSizeDimension& wsd) {
archive(wsd.size,
wsd.specializationID,
wsd.isSpecialized);
}
template<class Archive>
void serialize(Archive & archive, SPIRVEntryPoint& ep) {
archive(ep.mtlFunctionName,
ep.workgroupSize.width,
ep.workgroupSize.height,
ep.workgroupSize.depth,
ep.workgroupSizeId.width,
ep.workgroupSizeId.height,
ep.workgroupSizeId.depth,
ep.workgroupSizeId.constant);
ep.workgroupSize.depth);
}
template<class Archive>

View File

@ -30,14 +30,19 @@ const MVKMTLFunction MVKMTLFunctionNull = { nil, MTLSizeMake(1, 1, 1) };
#pragma mark -
#pragma mark MVKShaderLibrary
static uint32_t getOffsetForConstantId(const VkSpecializationInfo* pSpecInfo, uint32_t constantId)
{
for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) {
const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx];
if (pMapEntry->constantID == constantId) { return pMapEntry->offset; }
}
// If the size of the workgroup dimension is specialized, extract it from the
// specialization info, otherwise use the value specified in the SPIR-V shader code.
static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgDim, const VkSpecializationInfo* pSpecInfo) {
if (wgDim.isSpecialized && pSpecInfo) {
for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) {
const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx];
if (pMapEntry->constantID == wgDim.specializationID) {
return *reinterpret_cast<uint32_t*>((uintptr_t)pSpecInfo->pData + pMapEntry->offset) ;
}
}
}
return -1;
return wgDim.size;
}
MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkSpecializationInfo* pSpecializationInfo) {
@ -88,27 +93,10 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkSpecializationInfo* pSpe
mvkNotifyErrorWithText(VK_ERROR_INITIALIZATION_FAILED, "Shader module does not contain an entry point named '%s'.", mtlFuncName.UTF8String);
}
if (pSpecializationInfo) {
// Get the specialization constant values for the work group size
if (_entryPoint.workgroupSizeId.constant != 0) {
uint32_t widthOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.width);
if (widthOffset != -1) {
_entryPoint.workgroupSize.width = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecializationInfo->pData + widthOffset);
}
return { mtlFunc, MTLSizeMake(getWorkgroupDimensionSize(_entryPoint.workgroupSize.width, pSpecializationInfo),
getWorkgroupDimensionSize(_entryPoint.workgroupSize.height, pSpecializationInfo),
getWorkgroupDimensionSize(_entryPoint.workgroupSize.depth, pSpecializationInfo)) };
uint32_t heightOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.height);
if (heightOffset != -1) {
_entryPoint.workgroupSize.height = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecializationInfo->pData + heightOffset);
}
uint32_t depthOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.depth);
if (depthOffset != -1) {
_entryPoint.workgroupSize.depth = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecializationInfo->pData + depthOffset);
}
}
}
return { mtlFunc, MTLSizeMake(_entryPoint.workgroupSize.width, _entryPoint.workgroupSize.height, _entryPoint.workgroupSize.depth) };
}
// Returns the MTLFunctionConstant with the specified ID from the specified array of function constants.

View File

@ -124,8 +124,8 @@ MVK_PUBLIC_SYMBOL void SPIRVToMSLConverterContext::alignUsageWith(const SPIRVToM
#pragma mark -
#pragma mark SPIRVToMSLConverter
/** Populates content extracted from the SPRI-V compiler. */
void populateFromCompiler(spirv_cross::Compiler* pCompiler, SPIRVEntryPoint& entryPoint, SPIRVToMSLConverterOptions& options);
// Populates the entry point with info extracted from the SPRI-V compiler.
void populateEntryPoint(SPIRVEntryPoint& entryPoint, spirv_cross::Compiler* pCompiler, SPIRVToMSLConverterOptions& options);
MVK_PUBLIC_SYMBOL void SPIRVToMSLConverter::setSPIRV(const vector<uint32_t>& spirv) { _spirv = spirv; }
@ -224,7 +224,7 @@ MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverter::convert(SPIRVToMSLConverterContext&
#endif
// Populate content extracted from the SPRI-V compiler.
populateFromCompiler(pMSLCompiler, _entryPoint, context.options);
populateEntryPoint(_entryPoint, pMSLCompiler, context.options);
// To check GLSL conversion
if (shouldLogGLSL) {
@ -334,7 +334,14 @@ void SPIRVToMSLConverter::logSource(string& src, const char* srcLang, const char
#pragma mark Support functions
void populateFromCompiler(spirv_cross::Compiler* pCompiler, SPIRVEntryPoint& entryPoint, SPIRVToMSLConverterOptions& options) {
// Populate a workgroup size dimension.
void populateWorkgroupDimension(SPIRVWorkgroupSizeDimension& wgDim, uint32_t size, spirv_cross::SpecializationConstant& spvSpecConst) {
wgDim.size = max(size, 1u);
wgDim.isSpecialized = (spvSpecConst.id != 0);
wgDim.specializationID = spvSpecConst.constant_id;
}
void populateEntryPoint(SPIRVEntryPoint& entryPoint, spirv_cross::Compiler* pCompiler, SPIRVToMSLConverterOptions& options) {
if ( !pCompiler ) { return; }
@ -349,19 +356,13 @@ void populateFromCompiler(spirv_cross::Compiler* pCompiler, SPIRVEntryPoint& ent
}
}
uint32_t minDim = 1;
auto& wgSize = spvEP.workgroup_size;
spirv_cross::SpecializationConstant widthSC, heightSC, depthSC;
pCompiler->get_work_group_size_specialization_constants(widthSC, heightSC, depthSC);
entryPoint.mtlFunctionName = spvEP.name;
entryPoint.workgroupSize.width = max(wgSize.x, minDim);
entryPoint.workgroupSize.height = max(wgSize.y, minDim);
entryPoint.workgroupSize.depth = max(wgSize.z, minDim);
spirv_cross::SpecializationConstant width, height, depth;
entryPoint.workgroupSizeId.constant = pCompiler->get_work_group_size_specialization_constants(width, height, depth);
entryPoint.workgroupSizeId.width = width.constant_id;
entryPoint.workgroupSizeId.height = height.constant_id;
entryPoint.workgroupSizeId.depth = depth.constant_id;
populateWorkgroupDimension(entryPoint.workgroupSize.width, spvEP.workgroup_size.x, widthSC);
populateWorkgroupDimension(entryPoint.workgroupSize.height, spvEP.workgroup_size.y, heightSC);
populateWorkgroupDimension(entryPoint.workgroupSize.depth, spvEP.workgroup_size.z, depthSC);
}
MVK_PUBLIC_SYMBOL void mvk::logSPIRV(vector<uint32_t>& spirv, string& spvLog) {

View File

@ -138,24 +138,29 @@ namespace mvk {
} SPIRVToMSLConverterContext;
/**
* Describes one dimension of the workgroup size of a SPIR-V entry point, including whether
* it is specialized, and if so, the value of the corresponding specialization ID, which
* is used to map to a value which will be provided when the MSL is compiled into a pipeline.
*/
typedef struct {
uint32_t size = 1;
uint32_t specializationID = 0;
bool isSpecialized = false;
} SPIRVWorkgroupSizeDimension;
/**
* Describes a SPIRV entry point, including the Metal function name (which may be
* different than the Vulkan entry point name if the original name was illegal in Metal),
* and the number of threads in each workgroup or their specialization constant id, if the shader is a compute shader.
* and the size of each workgroup, if the shader is a compute shader.
*/
typedef struct {
std::string mtlFunctionName = "main0";
struct {
uint32_t width = 1;
uint32_t height = 1;
uint32_t depth = 1;
} workgroupSize;
struct {
uint32_t width = 1;
uint32_t height = 1;
uint32_t depth = 1;
uint32_t constant = 0;
} workgroupSizeId;
} SPIRVEntryPoint;
typedef struct {
std::string mtlFunctionName = "main0";
struct {
SPIRVWorkgroupSizeDimension width;
SPIRVWorkgroupSizeDimension height;
SPIRVWorkgroupSizeDimension depth;
} workgroupSize;
} SPIRVEntryPoint;
/** Special constant used in a MSLResourceBinding descriptorSet element to indicate the bindings for the push constants. */
static const uint32_t kPushConstDescSet = std::numeric_limits<uint32_t>::max();

View File

@ -67,7 +67,11 @@
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "/Users/bill/Documents/Dev/iOSProjects/Molten/MoltenVK/External/SPIRV-Cross/shaders-msl"
argument = "path-to-shader-directory"
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "-r"
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
@ -75,7 +79,7 @@
isEnabled = "YES">
</CommandLineArgument>
<CommandLineArgument
argument = "/Users/bill/Desktop/texture_buffer.vert"
argument = "path-to-GLSL-shader-file"
isEnabled = "YES">
</CommandLineArgument>
<CommandLineArgument
@ -83,7 +87,7 @@
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "/Users/bill/Documents/Dev/iOSProjects/Molten/Support/2018/MVK_Issue_112/second/vert_bin.spv"
argument = "path-to-SPIR-V-shader-file"
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument

View File

@ -26,16 +26,16 @@ using namespace std;
using namespace mvk;
/** The default list of vertex file extensions. */
// The default list of vertex file extensions.
static const char* _defaultVertexShaderExtns = "vs vsh vert vertex";
/** The default list of fragment file extensions. */
// The default list of fragment file extensions.
static const char* _defaultFragShaderExtns = "fs fsh frag fragment";
/** The default list of compute file extensions. */
// The default list of compute file extensions.
static const char* _defaultCompShaderExtns = "cp cmp comp compute kn kl krn kern kernel";
/** The default list of SPIR-V file extensions. */
// The default list of SPIR-V file extensions.
static const char* _defaultSPIRVShaderExtns = "spv spirv";
@ -75,10 +75,8 @@ bool MoltenVKShaderConverterTool::processFile(string filePath) {
return false;
}
/**
* Read GLSL code from a GLSL file, convert to SPIR-V, and optionally MSL,
* and write the SPIR-V and/or MSL code to files.
*/
// Read GLSL code from a GLSL file, convert to SPIR-V, and optionally MSL,
// and write the SPIR-V and/or MSL code to files.
bool MoltenVKShaderConverterTool::convertGLSL(string& glslInFile,
string& spvOutFile,
string& mslOutFile,
@ -149,9 +147,8 @@ bool MoltenVKShaderConverterTool::convertGLSL(string& glslInFile,
return convertSPIRV(spv, glslInFile, mslOutFile, false);
}
/** Read SPIR-V code from a SPIR-V file, convert to MSL, and write the MSL code to files. */
bool MoltenVKShaderConverterTool::convertSPIRV(string& spvInFile,
string& mslOutFile) {
// Read SPIR-V code from a SPIR-V file, convert to MSL, and write the MSL code to files.
bool MoltenVKShaderConverterTool::convertSPIRV(string& spvInFile, string& mslOutFile) {
string path;
vector<char> fileContents;
vector<uint32_t> spv;
@ -177,11 +174,11 @@ bool MoltenVKShaderConverterTool::convertSPIRV(string& spvInFile,
return convertSPIRV(spv, spvInFile, mslOutFile, _shouldLogConversions);
}
/** Read SPIR-V code from an array, convert to MSL, and write the MSL code to files. */
// Read SPIR-V code from an array, convert to MSL, and write the MSL code to files.
bool MoltenVKShaderConverterTool::convertSPIRV(const vector<uint32_t>& spv,
string& inFile,
string& mslOutFile,
bool shouldLogSPV) {
string& inFile,
string& mslOutFile,
bool shouldLogSPV) {
if ( !_shouldWriteMSL ) { return true; }
// Derive the context under which conversion will occur
@ -236,10 +233,10 @@ bool MoltenVKShaderConverterTool::isSPIRVFileExtension(string& pathExtension) {
return false;
}
/** Log the specified message to the console. */
// Log the specified message to the console.
void MoltenVKShaderConverterTool::log(const char* logMsg) { printf("%s\n", logMsg); }
/** Display usage information about this application on the console. */
// Display usage information about this application on the console.
void MoltenVKShaderConverterTool::showUsage() {
string line = "\n\e[1m" + _processName + "\e[0m converts OpenGL Shading Language (GLSL) source code to";
log((const char*)line.c_str());
@ -252,9 +249,9 @@ void MoltenVKShaderConverterTool::showUsage() {
log("\nUse the -so or -mo option to indicate the desired type of output");
log("(SPIR-V or MSL, respectively).");
log("\nUsage:");
log(" -d [\"dirPath\"] - Path to a directory containing GLSL shader source code");
log(" files. The dirPath may be omitted to use the current");
log(" working directory.");
log(" -d [\"dirPath\"] - Path to a directory containing GLSL or SPIR-V shader");
log(" source code files. The dirPath may be omitted to use");
log(" the current working directory.");
log(" -r - (when using -d) Process directories recursively.");
log(" -gi [\"glslInFile\"] - Indicates that GLSL shader code should be input.");
log(" The optional path parameter specifies the path to a");
@ -334,7 +331,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) {
if (equal(arg, "-d", false)) {
int optIdx = argIdx;
argIdx = optionParam(_directoryPath, argIdx, argc, argv);
argIdx = optionalParam(_directoryPath, argIdx, argc, argv);
if (argIdx == optIdx) { return false; }
_directoryPath = absolutePath(_directoryPath);
continue;
@ -347,32 +344,32 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) {
if (equal(arg, "-gi", true)) {
_shouldReadGLSL = true;
argIdx = optionParam(_glslInFilePath, argIdx, argc, argv);
argIdx = optionalParam(_glslInFilePath, argIdx, argc, argv);
continue;
}
if (equal(arg, "-si", true)) {
_shouldReadSPIRV = true;
argIdx = optionParam(_spvInFilePath, argIdx, argc, argv);
argIdx = optionalParam(_spvInFilePath, argIdx, argc, argv);
continue;
}
if (equal(arg, "-so", true)) {
_shouldWriteSPIRV = true;
argIdx = optionParam(_spvOutFilePath, argIdx, argc, argv);
argIdx = optionalParam(_spvOutFilePath, argIdx, argc, argv);
continue;
}
if (equal(arg, "-mo", true)) {
_shouldWriteMSL = true;
argIdx = optionParam(_mslOutFilePath, argIdx, argc, argv);
argIdx = optionalParam(_mslOutFilePath, argIdx, argc, argv);
continue;
}
if (equal(arg, "-t", true)) {
int optIdx = argIdx;
string shdrTypeStr;
argIdx = optionParam(shdrTypeStr, argIdx, argc, argv);
argIdx = optionalParam(shdrTypeStr, argIdx, argc, argv);
if (argIdx == optIdx || shdrTypeStr.length() == 0) { return false; }
switch (shdrTypeStr.front()) {
@ -416,7 +413,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) {
if (equal(arg, "-vx", true)) {
int optIdx = argIdx;
string shdrExtnStr;
argIdx = optionParam(shdrExtnStr, argIdx, argc, argv);
argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv);
if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; }
extractTokens(shdrExtnStr, _glslVtxFileExtns);
continue;
@ -425,7 +422,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) {
if (equal(arg, "-fx", true)) {
int optIdx = argIdx;
string shdrExtnStr;
argIdx = optionParam(shdrExtnStr, argIdx, argc, argv);
argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv);
if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; }
extractTokens(shdrExtnStr, _glslFragFileExtns);
continue;
@ -434,7 +431,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) {
if (equal(arg, "-cx", true)) {
int optIdx = argIdx;
string shdrExtnStr;
argIdx = optionParam(shdrExtnStr, argIdx, argc, argv);
argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv);
if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; }
extractTokens(shdrExtnStr, _glslCompFileExtns);
continue;
@ -443,7 +440,7 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) {
if (equal(arg, "-sx", true)) {
int optIdx = argIdx;
string shdrExtnStr;
argIdx = optionParam(shdrExtnStr, argIdx, argc, argv);
argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv);
if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; }
extractTokens(shdrExtnStr, _spvFileExtns);
continue;
@ -459,21 +456,19 @@ bool MoltenVKShaderConverterTool::parseArgs(int argc, const char* argv[]) {
return true;
}
/** Returns whether the specified command line arg is an option arg. */
// Returns whether the specified command line arg is an option arg.
bool MoltenVKShaderConverterTool::isOptionArg(string& arg) {
return (arg.length() > 1 && arg.front() == '-');
}
/**
* Sets the contents of the specified string to the parameter part of the option at the
* specified arg index, and increments and returns the option index. If no parameter was
* provided for the option, the string will be set to an empty string, and the returned
* index will be the same as the specified index.
*/
int MoltenVKShaderConverterTool::optionParam(string& optionParamResult,
int optionArgIndex,
int argc,
const char* argv[]) {
// Sets the contents of the specified string to the parameter part of the option at the
// specified arg index, and increments and returns the option index. If no parameter was
// provided for the option, the string will be set to an empty string, and the returned
// index will be the same as the specified index.
int MoltenVKShaderConverterTool::optionalParam(string& optionParamResult,
int optionArgIndex,
int argc,
const char* argv[]) {
int optParamIdx = optionArgIndex + 1;
if (optParamIdx < argc) {
string arg(argv[optParamIdx]);
@ -490,7 +485,7 @@ int MoltenVKShaderConverterTool::optionParam(string& optionParamResult,
#pragma mark -
#pragma mark Support functions
/** Template function for tokenizing the components of a string into a vector. */
// Template function for tokenizing the components of a string into a vector.
template <typename Container>
Container& split(Container& result,
const typename Container::value_type& s,
@ -516,7 +511,7 @@ void mvk::extractTokens(string str, vector<string>& tokens) {
split(tokens, str, " \t\n\f", false);
}
/** Compares the specified characters ignoring case. */
// Compares the specified characters ignoring case.
static bool compareIgnoringCase(unsigned char a, unsigned char b) {
return tolower(a) == tolower(b);
}

View File

@ -69,10 +69,10 @@ namespace mvk {
void log(const char* logMsg);
void showUsage();
bool isOptionArg(std::string& arg);
int optionParam(std::string& optionParamResult,
int optionArgIndex,
int argc,
const char* argv[]);
int optionalParam(std::string& optionParamResult,
int optionArgIndex,
int argc,
const char* argv[]);
std::string _processName;
std::string _directoryPath;