moved spe constant work group size reading at the end

This commit is contained in:
Maximilian Maldacker 2018-02-15 22:49:08 +01:00
parent 6eecd58ad2
commit e3741004b4

View File

@ -80,24 +80,6 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreat
atIndex: mtlFCIndex];
}
}
// Get the specialization constant values for the work group size
if (ep.workgroupSizeId.constant != 0) {
uint32_t widthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.width);
if (widthOffset != -1) {
ep.workgroupSize.width = &(((uint32_t*)pSpecInfo->pData)[widthOffset]);
}
uint32_t heightOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.height);
if (heightOffset != -1) {
ep.workgroupSize.height = &(((uint32_t*)pSpecInfo->pData)[heightOffset]);
}
uint32_t depthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.depth);
if (depthOffset != -1) {
ep.workgroupSize.depth = &(((uint32_t*)pSpecInfo->pData)[depthOffset]);
}
}
}
// Compile the specialized Metal function, and use it instead of the unspecialized Metal function.
@ -110,6 +92,27 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreat
} else {
mvkNotifyErrorWithText(VK_ERROR_INITIALIZATION_FAILED, "Shader module does not contain an entry point named '%s'.", mtlFuncName.UTF8String);
}
const VkSpecializationInfo* pSpecInfo = pShaderStage->pSpecializationInfo;
if (pSpecInfo) {
// Get the specialization constant values for the work group size
if (ep.workgroupSizeId.constant != 0) {
uint32_t widthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.width);
if (widthOffset != -1) {
ep.workgroupSize.width = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + widthOffset);
}
uint32_t heightOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.height);
if (heightOffset != -1) {
ep.workgroupSize.height = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + heightOffset);
}
uint32_t depthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.depth);
if (depthOffset != -1) {
ep.workgroupSize.depth = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + depthOffset);
}
}
}
return { mtlFunc, MTLSizeMake(ep.workgroupSize.width, ep.workgroupSize.height, ep.workgroupSize.depth) };
}