added workgroupssize specialization constants
This commit is contained in:
parent
ddcb833da2
commit
6eecd58ad2
@ -29,6 +29,18 @@ const MVKMTLFunction MVKMTLFunctionNull = { nil, MTLSizeMake(1, 1, 1) };
|
|||||||
#pragma mark -
|
#pragma mark -
|
||||||
#pragma mark MVKShaderLibrary
|
#pragma mark MVKShaderLibrary
|
||||||
|
|
||||||
|
uint32_t getOffsetForConstantId(const VkSpecializationInfo* pSpecInfo, uint32_t constantId)
|
||||||
|
{
|
||||||
|
for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) {
|
||||||
|
const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx];
|
||||||
|
if (pMapEntry->constantID == constantId) {
|
||||||
|
return pMapEntry->offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage) {
|
MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage) {
|
||||||
|
|
||||||
if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }
|
if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }
|
||||||
@ -68,6 +80,24 @@ MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreat
|
|||||||
atIndex: mtlFCIndex];
|
atIndex: mtlFCIndex];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the specialization constant values for the work group size
|
||||||
|
if (ep.workgroupSizeId.constant != 0) {
|
||||||
|
uint32_t widthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.width);
|
||||||
|
if (widthOffset != -1) {
|
||||||
|
ep.workgroupSize.width = &(((uint32_t*)pSpecInfo->pData)[widthOffset]);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t heightOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.height);
|
||||||
|
if (heightOffset != -1) {
|
||||||
|
ep.workgroupSize.height = &(((uint32_t*)pSpecInfo->pData)[heightOffset]);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t depthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.depth);
|
||||||
|
if (depthOffset != -1) {
|
||||||
|
ep.workgroupSize.depth = &(((uint32_t*)pSpecInfo->pData)[depthOffset]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile the specialized Metal function, and use it instead of the unspecialized Metal function.
|
// Compile the specialized Metal function, and use it instead of the unspecialized Metal function.
|
||||||
|
@ -322,6 +322,13 @@ void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPointsByNam
|
|||||||
mvkEP.workgroupSize.width = max(wgSize.x, minDim);
|
mvkEP.workgroupSize.width = max(wgSize.x, minDim);
|
||||||
mvkEP.workgroupSize.height = max(wgSize.y, minDim);
|
mvkEP.workgroupSize.height = max(wgSize.y, minDim);
|
||||||
mvkEP.workgroupSize.depth = max(wgSize.z, minDim);
|
mvkEP.workgroupSize.depth = max(wgSize.z, minDim);
|
||||||
|
|
||||||
|
spirv_cross::SpecializationConstant width, height, depth;
|
||||||
|
mvkEP.workgroupSizeId.constant = compiler.get_work_group_size_specialization_constants(width, height, depth);
|
||||||
|
mvkEP.workgroupSizeId.width = width.constant_id;
|
||||||
|
mvkEP.workgroupSizeId.height = height.constant_id;
|
||||||
|
mvkEP.workgroupSizeId.depth = depth.constant_id;
|
||||||
|
|
||||||
entryPoints[epOrigName] = mvkEP;
|
entryPoints[epOrigName] = mvkEP;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -132,7 +132,7 @@ namespace mvk {
|
|||||||
/**
|
/**
|
||||||
* Describes a SPIRV entry point, including the Metal function name (which may be
|
* Describes a SPIRV entry point, including the Metal function name (which may be
|
||||||
* different than the Vulkan entry point name if the original name was illegal in Metal),
|
* different than the Vulkan entry point name if the original name was illegal in Metal),
|
||||||
* and the number of threads in each workgroup, if the shader is a compute shader.
|
* and the number of threads in each workgroup or their specialization constant id, if the shader is a compute shader.
|
||||||
*/
|
*/
|
||||||
typedef struct {
|
typedef struct {
|
||||||
std::string mtlFunctionName;
|
std::string mtlFunctionName;
|
||||||
@ -141,6 +141,10 @@ namespace mvk {
|
|||||||
uint32_t height = 1;
|
uint32_t height = 1;
|
||||||
uint32_t depth = 1;
|
uint32_t depth = 1;
|
||||||
} workgroupSize;
|
} workgroupSize;
|
||||||
|
struct {
|
||||||
|
uint32_t width, height, depth;
|
||||||
|
uint32_t constant = 0;
|
||||||
|
} workgroupSizeId;
|
||||||
} SPIRVEntryPoint;
|
} SPIRVEntryPoint;
|
||||||
|
|
||||||
/** Holds a map of entry point info, indexed by the SPIRV entry point name. */
|
/** Holds a map of entry point info, indexed by the SPIRV entry point name. */
|
||||||
|
Loading…
x
Reference in New Issue
Block a user