Merge pull request #52 from mmaldacker/feature/workgroupsize_spe_const
Work group size specialization constant
This commit is contained in:
commit
1a8b3004bd
@ -29,6 +29,18 @@ const MVKMTLFunction MVKMTLFunctionNull = { nil, MTLSizeMake(1, 1, 1) };
|
||||
#pragma mark -
|
||||
#pragma mark MVKShaderLibrary
|
||||
|
||||
uint32_t getOffsetForConstantId(const VkSpecializationInfo* pSpecInfo, uint32_t constantId)
|
||||
{
|
||||
for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) {
|
||||
const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx];
|
||||
if (pMapEntry->constantID == constantId) {
|
||||
return pMapEntry->offset;
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage) {
|
||||
|
||||
if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }
|
||||
@ -80,6 +92,27 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreat
|
||||
} else {
|
||||
mvkNotifyErrorWithText(VK_ERROR_INITIALIZATION_FAILED, "Shader module does not contain an entry point named '%s'.", mtlFuncName.UTF8String);
|
||||
}
|
||||
|
||||
const VkSpecializationInfo* pSpecInfo = pShaderStage->pSpecializationInfo;
|
||||
if (pSpecInfo) {
|
||||
// Get the specialization constant values for the work group size
|
||||
if (ep.workgroupSizeId.constant != 0) {
|
||||
uint32_t widthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.width);
|
||||
if (widthOffset != -1) {
|
||||
ep.workgroupSize.width = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + widthOffset);
|
||||
}
|
||||
|
||||
uint32_t heightOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.height);
|
||||
if (heightOffset != -1) {
|
||||
ep.workgroupSize.height = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + heightOffset);
|
||||
}
|
||||
|
||||
uint32_t depthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.depth);
|
||||
if (depthOffset != -1) {
|
||||
ep.workgroupSize.depth = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + depthOffset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { mtlFunc, MTLSizeMake(ep.workgroupSize.width, ep.workgroupSize.height, ep.workgroupSize.depth) };
|
||||
}
|
||||
|
@ -322,6 +322,13 @@ void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPointsByNam
|
||||
mvkEP.workgroupSize.width = max(wgSize.x, minDim);
|
||||
mvkEP.workgroupSize.height = max(wgSize.y, minDim);
|
||||
mvkEP.workgroupSize.depth = max(wgSize.z, minDim);
|
||||
|
||||
spirv_cross::SpecializationConstant width, height, depth;
|
||||
mvkEP.workgroupSizeId.constant = compiler.get_work_group_size_specialization_constants(width, height, depth);
|
||||
mvkEP.workgroupSizeId.width = width.constant_id;
|
||||
mvkEP.workgroupSizeId.height = height.constant_id;
|
||||
mvkEP.workgroupSizeId.depth = depth.constant_id;
|
||||
|
||||
entryPoints[epOrigName] = mvkEP;
|
||||
}
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ namespace mvk {
|
||||
/**
|
||||
* Describes a SPIRV entry point, including the Metal function name (which may be
|
||||
* different than the Vulkan entry point name if the original name was illegal in Metal),
|
||||
* and the number of threads in each workgroup, if the shader is a compute shader.
|
||||
* and the number of threads in each workgroup or their specialization constant id, if the shader is a compute shader.
|
||||
*/
|
||||
typedef struct {
|
||||
std::string mtlFunctionName;
|
||||
@ -141,6 +141,10 @@ namespace mvk {
|
||||
uint32_t height = 1;
|
||||
uint32_t depth = 1;
|
||||
} workgroupSize;
|
||||
struct {
|
||||
uint32_t width, height, depth;
|
||||
uint32_t constant = 0;
|
||||
} workgroupSizeId;
|
||||
} SPIRVEntryPoint;
|
||||
|
||||
/** Holds a map of entry point info, indexed by the SPIRV entry point name. */
|
||||
|
Loading…
x
Reference in New Issue
Block a user