finish redo metal shader compilation

This commit is contained in:
Greg Wells
2025-07-19 14:31:49 -04:00
parent 7d4973e27d
commit aa729d3589
6 changed files with 68 additions and 29 deletions

View File

@@ -6,7 +6,7 @@
typedef struct gnPlatformGraphicsPipeline_t { typedef struct gnPlatformGraphicsPipeline_t {
id<MTLRenderPipelineState> graphicsPipeline; id<MTLRenderPipelineState> graphicsPipeline;
id<MTLDepthStencilState> depthState; id<MTLDepthStencilState> depthState;
metalShaderMap vertexShaderMaps, fragmentShaderMaps; mtlShaderMap vertexShaderMaps, fragmentShaderMaps;
} gnPlatformGraphicsPipeline; } gnPlatformGraphicsPipeline;
gnReturnCode createMetalGraphicsPipeline(gnGraphicsPipeline graphicsPipeline, gnOutputDevice device, gnGraphicsPipelineInfo info); gnReturnCode createMetalGraphicsPipeline(gnGraphicsPipeline graphicsPipeline, gnOutputDevice device, gnGraphicsPipelineInfo info);

View File

@@ -76,10 +76,10 @@ gnReturnCode createMetalGraphicsPipeline(gnGraphicsPipeline graphicsPipeline, gn
for (int i = 0; i < info.shaderModuleCount; i++) { for (int i = 0; i < info.shaderModuleCount; i++) {
if (info.shaderModules[i]->info.stage == GN_VERTEX_SHADER_MODULE) { if (info.shaderModules[i]->info.stage == GN_VERTEX_SHADER_MODULE) {
[descriptor setVertexFunction:info.shaderModules[i]->shaderModule->function]; [descriptor setVertexFunction:info.shaderModules[i]->shaderModule->function];
graphicsPipeline->graphicsPipeline->vertexShaderMaps = info.shaderModules[i]->shaderModule->map; graphicsPipeline->graphicsPipeline->vertexShaderMaps = info.shaderModules[i]->shaderModule->shaderMap;
} else if (info.shaderModules[i]->info.stage == GN_FRAGMENT_SHADER_MODULE) { } else if (info.shaderModules[i]->info.stage == GN_FRAGMENT_SHADER_MODULE) {
[descriptor setFragmentFunction:info.shaderModules[i]->shaderModule->function]; [descriptor setFragmentFunction:info.shaderModules[i]->shaderModule->function];
graphicsPipeline->graphicsPipeline->fragmentShaderMaps = info.shaderModules[i]->shaderModule->map; graphicsPipeline->graphicsPipeline->fragmentShaderMaps = info.shaderModules[i]->shaderModule->shaderMap;
} else { } else {
return GN_UNSUPPORTED_SHADER_MODULE; return GN_UNSUPPORTED_SHADER_MODULE;
} }

View File

@@ -3,6 +3,9 @@
#include "stdlib.h" #include "stdlib.h"
#include "utils/gryphn_bool.h" #include "utils/gryphn_bool.h"
#define MAX_METAL_SETS 32
#define MAX_METAL_BINDINGS 16
typedef enum mtlShaderModuleStage { typedef enum mtlShaderModuleStage {
vertex, fragment vertex, fragment
} mtlShaderModuleStage; } mtlShaderModuleStage;
@@ -13,7 +16,26 @@ typedef struct mtlShaderOptions {
const char* entryPoint; const char* entryPoint;
} mtlShaderOptions; } mtlShaderOptions;
typedef struct mtlBinding {
uint32_t spvBinding;
uint32_t metalID;
} mtlBinding;
typedef struct mtlSetMap {
uint32_t setIndex, mtlSetIndex;
mtlBinding bindings[MAX_METAL_BINDINGS];
} mtlSetMap;
typedef struct mtlShaderMap {
mtlSetMap sets[MAX_METAL_SETS];
} mtlShaderMap;
typedef struct mtlShader {
const char* code;
mtlShaderMap map;
} mtlShader;
#ifdef __cplusplus #ifdef __cplusplus
extern "C" extern "C"
#endif #endif
const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* options); mtlShader mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* options);

View File

@@ -2,15 +2,26 @@
#include "spirv_msl.hpp" #include "spirv_msl.hpp"
#include "iostream" #include "iostream"
void handle_resources(spirv_cross::CompilerMSL& compiler, spirv_cross::SmallVector<spirv_cross::Resource>& resources) { void handle_resources(spirv_cross::CompilerMSL& compiler, spirv_cross::SmallVector<spirv_cross::Resource>& resources, mtlShaderMap* map) {
for (int i = 0; i < resources.size(); i++) { for (int i = 0; i < resources.size(); i++) {
uint32_t set = compiler.get_decoration(resources[i].id, spv::DecorationDescriptorSet); uint32_t set = compiler.get_decoration(resources[i].id, spv::DecorationDescriptorSet);
compiler.unset_decoration(resources[i].id, spv::DecorationDescriptorSet); compiler.unset_decoration(resources[i].id, spv::DecorationDescriptorSet);
compiler.set_decoration(resources[i].id, spv::DecorationDescriptorSet, set + 1); compiler.set_decoration(resources[i].id, spv::DecorationDescriptorSet, set + 1);
map->sets[set].setIndex = set;
map->sets[set].mtlSetIndex = set + 1;
} }
} }
extern "C" const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* inOptions) { void improve_map(spirv_cross::CompilerMSL& compiler, spirv_cross::SmallVector<spirv_cross::Resource>& resources, mtlShaderMap* map) {
for (int i = 0; i < resources.size(); i++) {
uint32_t set = compiler.get_decoration(resources[i].id, spv::DecorationDescriptorSet);
uint32_t binding = compiler.get_decoration(resources[i].id, spv::DecorationBinding);
map->sets[(set - 1)].bindings[binding].spvBinding = binding;
map->sets[(set - 1)].bindings[binding].metalID = compiler.get_automatic_msl_resource_binding(resources[i].id);
}
}
extern "C" mtlShader mtlCompileShader(uint32_t* code, size_t wordCount, mtlShaderOptions* inOptions) {
spirv_cross::CompilerMSL compiler(code, wordCount); spirv_cross::CompilerMSL compiler(code, wordCount);
spirv_cross::CompilerMSL::Options options; spirv_cross::CompilerMSL::Options options;
@@ -20,7 +31,7 @@ extern "C" const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlSha
options.argument_buffers = true; options.argument_buffers = true;
} else { } else {
options.set_msl_version(1); options.set_msl_version(1);
return NULL; return {};
} }
compiler.set_msl_options(options); compiler.set_msl_options(options);
if (inOptions->stage == vertex) if (inOptions->stage == vertex)
@@ -28,16 +39,33 @@ extern "C" const char* mtlCompileShader(uint32_t* code, size_t wordCount, mtlSha
else if (inOptions->stage == fragment) else if (inOptions->stage == fragment)
compiler.set_entry_point(inOptions->entryPoint, spv::ExecutionModelFragment); compiler.set_entry_point(inOptions->entryPoint, spv::ExecutionModelFragment);
else { else {
return NULL; return {};
} }
mtlShaderMap map;
for (int i = 0; i < MAX_METAL_SETS; i++) {
map.sets[i].mtlSetIndex = -1;
map.sets[i].setIndex = -1;
for (int c = 0; c < MAX_METAL_BINDINGS; c++) {
map.sets[i].bindings[c].spvBinding = -1;
map.sets[i].bindings[c].metalID = -1;
}
}
auto arg_buffers = compiler.get_shader_resources(); auto arg_buffers = compiler.get_shader_resources();
handle_resources(compiler, arg_buffers.uniform_buffers); handle_resources(compiler, arg_buffers.uniform_buffers, &map);
handle_resources(compiler, arg_buffers.storage_buffers); handle_resources(compiler, arg_buffers.storage_buffers, &map);
handle_resources(compiler, arg_buffers.sampled_images); handle_resources(compiler, arg_buffers.sampled_images, &map);
std::string returnedCode = compiler.compile(); std::string returnedCode = compiler.compile();
char* returnString = (char*)malloc(sizeof(char) * returnedCode.size());
improve_map(compiler, arg_buffers.uniform_buffers, &map);
improve_map(compiler, arg_buffers.storage_buffers, &map);
improve_map(compiler, arg_buffers.sampled_images, &map);
char* returnString = (char*)malloc(sizeof(char) * (returnedCode.size() + 1));
strcpy(returnString, returnedCode.c_str()); strcpy(returnString, returnedCode.c_str());
return returnString; return {
.code = returnString,
.map = map
};
} }

View File

@@ -1,24 +1,12 @@
#pragma once #pragma once
#include "shader_module/gryphn_shader_module.h" #include "shader_module/gryphn_shader_module.h"
#include "utils/lists/gryphn_array_list.h" #include "utils/lists/gryphn_array_list.h"
#include "metal_shader_compiler.h"
#import <Metal/Metal.h> #import <Metal/Metal.h>
#define METAL_MAX_SET_COUNT 16
#define METAL_MAX_BINDING_COUNT 16
typedef struct metalSetMap {
uint32_t bindings[METAL_MAX_BINDING_COUNT];
} metalSetMap;
typedef struct metalShaderMap {
metalSetMap sets[METAL_MAX_SET_COUNT];
uint32_t pushConstantBufferIndex;
} metalShaderMap;
typedef struct gnPlatformShaderModule_t { typedef struct gnPlatformShaderModule_t {
id<MTLFunction> function; id<MTLFunction> function;
metalShaderMap map; mtlShaderMap shaderMap;
gnBool useShaderMap;
} gnPlatformShaderModule; } gnPlatformShaderModule;
#ifdef __cplusplus #ifdef __cplusplus

View File

@@ -16,7 +16,8 @@ gnReturnCode createMetalShaderModule(gnShaderModule module, gnDevice device, gnS
}; };
if (shaderModuleInfo.stage == GN_FRAGMENT_SHADER_MODULE) options.stage = fragment; if (shaderModuleInfo.stage == GN_FRAGMENT_SHADER_MODULE) options.stage = fragment;
const char* res = mtlCompileShader(shaderModuleInfo.code, shaderModuleInfo.size / 4, &options); mtlShader shader = mtlCompileShader(shaderModuleInfo.code, shaderModuleInfo.size / 4, &options);
const char* res = shader.code;
if (res == NULL) return GN_FAILED_TO_CONVERT_SHADER_CODE; if (res == NULL) return GN_FAILED_TO_CONVERT_SHADER_CODE;
NSError* error = nil; NSError* error = nil;
@@ -45,10 +46,10 @@ gnReturnCode createMetalShaderModule(gnShaderModule module, gnDevice device, gnS
NSString* functionName = [NSString stringWithCString:name encoding:NSUTF8StringEncoding]; NSString* functionName = [NSString stringWithCString:name encoding:NSUTF8StringEncoding];
module->shaderModule->function = [shaderLib newFunctionWithName:functionName]; module->shaderModule->function = [shaderLib newFunctionWithName:functionName];
// printf("res %s\n", res);
[shaderLib release]; [shaderLib release];
free((void*)res); free((void*)res);
module->shaderModule->shaderMap = shader.map;
return GN_SUCCESS; return GN_SUCCESS;
} }