diff --git a/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.h b/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.h index 8cdc9aa..62bff05 100644 --- a/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.h +++ b/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.h @@ -6,7 +6,7 @@ typedef struct gnPlatformGraphicsPipeline_t { id graphicsPipeline; id depthState; - metalShaderMap vertexShaderMaps, fragmentShaderMaps; + mtlShaderMap vertexShaderMaps, fragmentShaderMaps; } gnPlatformGraphicsPipeline; gnReturnCode createMetalGraphicsPipeline(gnGraphicsPipeline graphicsPipeline, gnOutputDevice device, gnGraphicsPipelineInfo info); diff --git a/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.m b/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.m index 97d1270..196e7bb 100644 --- a/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.m +++ b/projects/apis/metal/src/pipelines/graphics_pipeline/metal_graphics_pipeline.m @@ -76,10 +76,10 @@ gnReturnCode createMetalGraphicsPipeline(gnGraphicsPipeline graphicsPipeline, gn for (int i = 0; i < info.shaderModuleCount; i++) { if (info.shaderModules[i]->info.stage == GN_VERTEX_SHADER_MODULE) { [descriptor setVertexFunction:info.shaderModules[i]->shaderModule->function]; - graphicsPipeline->graphicsPipeline->vertexShaderMaps = info.shaderModules[i]->shaderModule->map; + graphicsPipeline->graphicsPipeline->vertexShaderMaps = info.shaderModules[i]->shaderModule->shaderMap; } else if (info.shaderModules[i]->info.stage == GN_FRAGMENT_SHADER_MODULE) { [descriptor setFragmentFunction:info.shaderModules[i]->shaderModule->function]; - graphicsPipeline->graphicsPipeline->fragmentShaderMaps = info.shaderModules[i]->shaderModule->map; + graphicsPipeline->graphicsPipeline->fragmentShaderMaps = info.shaderModules[i]->shaderModule->shaderMap; } else { return GN_UNSUPPORTED_SHADER_MODULE; } diff --git a/projects/apis/metal/src/shader_module/metal_shader_compiler.h b/projects/apis/metal/src/shader_module/metal_shader_compiler.h index ab77cdc..d713499 100644 --- a/projects/apis/metal/src/shader_module/metal_shader_compiler.h +++ b/projects/apis/metal/src/shader_module/metal_shader_compiler.h @@ -3,6 +3,9 @@ #include "stdlib.h" #include "utils/gryphn_bool.h" +#define MAX_METAL_SETS 32 +#define MAX_METAL_BINDINGS 16 + typedef enum mtlShaderModuleStage { vertex, fragment } mtlShaderModuleStage; @@ -13,7 +16,26 @@ typedef struct mtlShaderOptions { const char* entryPoint; } mtlShaderOptions; +typedef struct mtlBinding { + uint32_t spvBinding; + uint32_t metalID; +} mtlBinding; + +typedef struct mtlSetMap { + uint32_t setIndex, mtlSetIndex; + mtlBinding bindings[MAX_METAL_BINDINGS]; +} mtlSetMap; + +typedef struct mtlShaderMap { + mtlSetMap sets[MAX_METAL_SETS]; +} mtlShaderMap; + +typedef struct mtlShader { + const char* code; + mtlShaderMap map; +} mtlShader; + #ifdef __cplusplus extern "C" #endif -const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* options); +mtlShader mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* options); diff --git a/projects/apis/metal/src/shader_module/metal_shader_compiler.mm b/projects/apis/metal/src/shader_module/metal_shader_compiler.mm index b2586a2..77c9209 100644 --- a/projects/apis/metal/src/shader_module/metal_shader_compiler.mm +++ b/projects/apis/metal/src/shader_module/metal_shader_compiler.mm @@ -2,15 +2,26 @@ #include "spirv_msl.hpp" #include "iostream" -void handle_resources(spirv_cross::CompilerMSL& compiler, spirv_cross::SmallVector& resources) { +void handle_resources(spirv_cross::CompilerMSL& compiler, spirv_cross::SmallVector& resources, mtlShaderMap* map) { for (int i = 0; i < resources.size(); i++) { uint32_t set = compiler.get_decoration(resources[i].id, spv::DecorationDescriptorSet); compiler.unset_decoration(resources[i].id, spv::DecorationDescriptorSet); compiler.set_decoration(resources[i].id, spv::DecorationDescriptorSet, set + 1); + map->sets[set].setIndex = set; + map->sets[set].mtlSetIndex = set + 1; } } -extern "C" const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* inOptions) { +void improve_map(spirv_cross::CompilerMSL& compiler, spirv_cross::SmallVector& resources, mtlShaderMap* map) { + for (int i = 0; i < resources.size(); i++) { + uint32_t set = compiler.get_decoration(resources[i].id, spv::DecorationDescriptorSet); + uint32_t binding = compiler.get_decoration(resources[i].id, spv::DecorationBinding); + map->sets[(set - 1)].bindings[binding].spvBinding = binding; + map->sets[(set - 1)].bindings[binding].metalID = compiler.get_automatic_msl_resource_binding(resources[i].id); + } +} + +extern "C" mtlShader mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* inOptions) { spirv_cross::CompilerMSL compiler(code, wordCount); spirv_cross::CompilerMSL::Options options; @@ -20,7 +31,7 @@ extern "C" const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlSha options.argument_buffers = true; } else { options.set_msl_version(1); - return NULL; + return {}; } compiler.set_msl_options(options); if (inOptions->stage == vertex) @@ -28,16 +39,33 @@ extern "C" const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlSha else if (inOptions->stage == fragment) compiler.set_entry_point(inOptions->entryPoint, spv::ExecutionModelFragment); else { - return NULL; + return {}; } + mtlShaderMap map; + for (int i = 0; i < MAX_METAL_SETS; i++) { + map.sets[i].mtlSetIndex = -1; + map.sets[i].setIndex = -1; + for (int c = 0; c < MAX_METAL_BINDINGS; c++) { + map.sets[i].bindings[c].spvBinding = -1; + map.sets[i].bindings[c].metalID = -1; + } + } auto arg_buffers = compiler.get_shader_resources(); - handle_resources(compiler, arg_buffers.uniform_buffers); - handle_resources(compiler, arg_buffers.storage_buffers); - handle_resources(compiler, arg_buffers.sampled_images); + handle_resources(compiler, arg_buffers.uniform_buffers, &map); + handle_resources(compiler, arg_buffers.storage_buffers, &map); + handle_resources(compiler, arg_buffers.sampled_images, &map); std::string returnedCode = compiler.compile(); - char* returnString = (char*)malloc(sizeof(char) * returnedCode.size()); + + improve_map(compiler, arg_buffers.uniform_buffers, &map); + improve_map(compiler, arg_buffers.storage_buffers, &map); + improve_map(compiler, arg_buffers.sampled_images, &map); + + char* returnString = (char*)malloc(sizeof(char) * (returnedCode.size() + 1)); strcpy(returnString, returnedCode.c_str()); - return returnString; + return { + .code = returnString, + .map = map + }; } diff --git a/projects/apis/metal/src/shader_module/metal_shader_module.h b/projects/apis/metal/src/shader_module/metal_shader_module.h index a70e514..9c5131e 100644 --- a/projects/apis/metal/src/shader_module/metal_shader_module.h +++ b/projects/apis/metal/src/shader_module/metal_shader_module.h @@ -1,24 +1,12 @@ #pragma once #include "shader_module/gryphn_shader_module.h" #include "utils/lists/gryphn_array_list.h" +#include "metal_shader_compiler.h" #import -#define METAL_MAX_SET_COUNT 16 -#define METAL_MAX_BINDING_COUNT 16 - -typedef struct metalSetMap { - uint32_t bindings[METAL_MAX_BINDING_COUNT]; -} metalSetMap; - -typedef struct metalShaderMap { - metalSetMap sets[METAL_MAX_SET_COUNT]; - uint32_t pushConstantBufferIndex; -} metalShaderMap; - typedef struct gnPlatformShaderModule_t { id function; - metalShaderMap map; - gnBool useShaderMap; + mtlShaderMap shaderMap; } gnPlatformShaderModule; #ifdef __cplusplus diff --git a/projects/apis/metal/src/shader_module/metal_shader_module.m b/projects/apis/metal/src/shader_module/metal_shader_module.m index f72891f..8ba10bf 100644 --- a/projects/apis/metal/src/shader_module/metal_shader_module.m +++ b/projects/apis/metal/src/shader_module/metal_shader_module.m @@ -16,7 +16,8 @@ gnReturnCode createMetalShaderModule(gnShaderModule module, gnDevice device, gnS }; if (shaderModuleInfo.stage == GN_FRAGMENT_SHADER_MODULE) options.stage = fragment; - const char* res = mtlCompileShader(shaderModuleInfo.code, shaderModuleInfo.size / 4, &options); + mtlShader shader = mtlCompileShader(shaderModuleInfo.code, shaderModuleInfo.size / 4, &options); + const char* res = shader.code; if (res == NULL) return GN_FAILED_TO_CONVERT_SHADER_CODE; NSError* error = nil; @@ -45,10 +46,10 @@ gnReturnCode createMetalShaderModule(gnShaderModule module, gnDevice device, gnS NSString* functionName = [NSString stringWithCString:name encoding:NSUTF8StringEncoding]; module->shaderModule->function = [shaderLib newFunctionWithName:functionName]; - // printf("res %s\n", res); [shaderLib release]; free((void*)res); + module->shaderModule->shaderMap = shader.map; return GN_SUCCESS; }