Merge pull request #52 from mmaldacker/feature/workgroupsize_spe_const

Work group size specialization constant
This commit is contained in:
Bill Hollings 2018-02-19 14:50:17 -05:00 committed by GitHub
commit 1a8b3004bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 1 deletions

View File

@ -29,6 +29,18 @@ const MVKMTLFunction MVKMTLFunctionNull = { nil, MTLSizeMake(1, 1, 1) };
#pragma mark -
#pragma mark MVKShaderLibrary
uint32_t getOffsetForConstantId(const VkSpecializationInfo* pSpecInfo, uint32_t constantId)
{
for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) {
const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx];
if (pMapEntry->constantID == constantId) {
return pMapEntry->offset;
}
}
return -1;
}
MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage) {
if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }
@ -80,6 +92,27 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreat
} else {
mvkNotifyErrorWithText(VK_ERROR_INITIALIZATION_FAILED, "Shader module does not contain an entry point named '%s'.", mtlFuncName.UTF8String);
}
const VkSpecializationInfo* pSpecInfo = pShaderStage->pSpecializationInfo;
if (pSpecInfo) {
// Get the specialization constant values for the work group size
if (ep.workgroupSizeId.constant != 0) {
uint32_t widthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.width);
if (widthOffset != -1) {
ep.workgroupSize.width = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + widthOffset);
}
uint32_t heightOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.height);
if (heightOffset != -1) {
ep.workgroupSize.height = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + heightOffset);
}
uint32_t depthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.depth);
if (depthOffset != -1) {
ep.workgroupSize.depth = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + depthOffset);
}
}
}
return { mtlFunc, MTLSizeMake(ep.workgroupSize.width, ep.workgroupSize.height, ep.workgroupSize.depth) };
}

View File

@ -322,6 +322,13 @@ void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPointsByNam
mvkEP.workgroupSize.width = max(wgSize.x, minDim);
mvkEP.workgroupSize.height = max(wgSize.y, minDim);
mvkEP.workgroupSize.depth = max(wgSize.z, minDim);
spirv_cross::SpecializationConstant width, height, depth;
mvkEP.workgroupSizeId.constant = compiler.get_work_group_size_specialization_constants(width, height, depth);
mvkEP.workgroupSizeId.width = width.constant_id;
mvkEP.workgroupSizeId.height = height.constant_id;
mvkEP.workgroupSizeId.depth = depth.constant_id;
entryPoints[epOrigName] = mvkEP;
}
}

View File

@ -132,7 +132,7 @@ namespace mvk {
/**
* Describes a SPIRV entry point, including the Metal function name (which may be
* different than the Vulkan entry point name if the original name was illegal in Metal),
* and the number of threads in each workgroup, if the shader is a compute shader.
* and the number of threads in each workgroup or their specialization constant id, if the shader is a compute shader.
*/
typedef struct {
std::string mtlFunctionName;
@ -141,6 +141,10 @@ namespace mvk {
uint32_t height = 1;
uint32_t depth = 1;
} workgroupSize;
struct {
uint32_t width, height, depth;
uint32_t constant = 0;
} workgroupSizeId;
} SPIRVEntryPoint;
/** Holds a map of entry point info, indexed by the SPIRV entry point name. */