WIP V2: Experimental: Metal backend #441

Merged
GreemDev merged 369 commits from new-metal into master 2024-12-24 06:55:16 +00:00
10 changed files with 110 additions and 77 deletions
Showing only changes of commit becf828d0a - Show all commits

View File

@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int AdditionalArgCount = 2;
public const int AdditionalArgCount = 1;
public StructuredFunction CurrentFunction { get; set; }

View File

@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined;
}
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage)
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
{
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
if (isMainFunc)
{
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
}
switch (stage)
{
case ShaderStage.Vertex:
@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared)
{
string prefix = isShared ? "threadgroup " : string.Empty;
foreach (var memory in memories)
{
string arraySize = "";
@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
arraySize = $"[{memory.ArrayLength}]";
}
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {memory.Name}{arraySize};");
context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};");
}
}
@ -128,23 +135,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
foreach (BufferDefinition buffer in buffers)
{
context.AppendLine($"struct Struct_{buffer.Name}");
context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}");
context.EnterScope();
foreach (StructureField field in buffer.Type.Fields)
{
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
string arraySuffix = "";
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
}
else
if (field.Type.HasFlag(AggregateType.Array))
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name};");
if (field.ArrayLength > 0)
{
arraySuffix = $"[{field.ArrayLength}]";
}
else
{
// Unsized array: declare it with a placeholder length of 1. Probably undefined behavior (UB), but this is the approach that MoltenVK (MVK) takes
arraySuffix = "[1]";
}
}
context.AppendLine($"{typeName} {field.Name}{arraySuffix};");
}
context.LeaveScope(";");
@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "uint3",
IoVariable.VertexId => "uint",
IoVariable.VertexIndex => "uint",
IoVariable.PointCoord => "float2",
_ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false))
};
string name = ioDefinition.IoVariable switch
@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "global_id",
IoVariable.VertexId => "vertex_id",
IoVariable.VertexIndex => "vertex_index",
IoVariable.PointCoord => "point_coord",
_ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"
};
string suffix = ioDefinition.IoVariable switch
@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.VertexId => "[[vertex_id]]",
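// TODO: Avoid potential redeclaration (VertexId and VertexIndex both map to [[vertex_id]], so presumably using both would declare the attribute twice)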
IoVariable.VertexIndex => "[[vertex_id]]",
IoVariable.PointCoord => "[[point_coord]]",
IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]",
_ => ""
};

View File

@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string IAttributePrefix = "inAttr";
public const string OAttributePrefix = "outAttr";
public const string StructPrefix = "struct";
public const string ArgumentNamePrefix = "a";
public const string UndefinedName = "0";

View File

@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
@ -39,11 +39,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
int arity = (int)(info.Type & InstType.ArityMask);
string args = string.Empty;
StringBuilder builder = new();
if (atomic)
if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
{
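// Messy special case: for atomics on storage buffers or shared memory, emit the GenerateLoadOrStore expression first, then the remaining sources (typed S32 for AtomicMaxS32/AtomicMinS32, U32 otherwise)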
builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
? AggregateType.S32
: AggregateType.U32;
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
{
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}");
}
}
else
{
@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
if (argIndex != 0)
{
args += ", ";
builder.Append(", ");
}
AggregateType dstType = GetSrcVarType(inst, argIndex);
args += GetSourceExpr(context, operation.GetSource(argIndex), dstType);
builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType));
}
}
return info.OpName + '(' + args + ')';
return $"{info.OpName}({builder})";
}
else if ((info.Type & InstType.Op) != 0)
{
@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
switch (inst & Instruction.Mask)
{
case Instruction.Barrier:
return "|| BARRIER ||";
return "threadgroup_barrier(mem_flags::mem_threadgroup)";
case Instruction.Call:
return Call(context, operation);
case Instruction.FSIBegin:

View File

@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value);
int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.AdditionalArgCount];
int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[argCount + additionalArgCount];
// Additional arguments
args[0] = "in";
args[1] = "support_buffer";
if (context.Definitions.Stage != ShaderStage.Compute)
{
args[0] = "in";
args[1] = "support_buffer";
}
else
{
args[0] = "support_buffer";
}
int argIndex = CodeGenContext.AdditionalArgCount;
int argIndex = additionalArgCount;
for (int i = 0; i < argCount; i++)
{
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));

View File

@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle");
Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down");
Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up");
Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor");
Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle");
Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down");
Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up");
Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor");
Add(Instruction.Sine, InstType.CallUnary, "sin");
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special);

View File

@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name;
}
varName += "->" + field.Name;
varType = field.Type;
break;

View File

@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.InstanceId => ("instance_id", AggregateType.S32),
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32),
IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),

View File

@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc));
context.EnterScope();
Declarations.DeclareLocals(context, function, stage);
Declarations.DeclareLocals(context, function, stage, isMainFunc);
PrintBlock(context, function.MainBlock, isMainFunc);
@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage,
bool isMainFunc = false)
{
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount;
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc)
{
args[0] = "FragmentIn in";
args[1] = "constant Struct_support_buffer* support_buffer";
if (stage != ShaderStage.Compute)
{
args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in";
args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
else
{
args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
}
int argIndex = additionalArgCount;
@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
}
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{
// Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
}
foreach (var texture in context.Properties.Textures.Values)
@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
}
}
return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})";
var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
var indent = new string(' ', funcPrefix.Length);
return $"{funcPrefix}{string.Join($", \n{indent}", args)})";
}
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)

View File

@ -10,25 +10,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static bool TryFormat(int value, AggregateType dstType, out string formatted)
{
if (dstType == AggregateType.FP32)
switch (dstType)
{
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
}
else if (dstType == AggregateType.S32)
{
formatted = FormatInt(value);
}
else if (dstType == AggregateType.U32)
{
formatted = FormatUint((uint)value);
}
else if (dstType == AggregateType.Bool)
{
formatted = value != 0 ? "true" : "false";
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
case AggregateType.FP32:
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
case AggregateType.S32:
formatted = FormatInt(value);
break;
case AggregateType.U32:
formatted = FormatUint((uint)value);
break;
case AggregateType.Bool:
formatted = value != 0 ? "true" : "false";
break;
default:
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
return true;
@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static string FormatInt(int value, AggregateType dstType)
{
if (dstType == AggregateType.S32)
return dstType switch
{
return FormatInt(value);
}
else if (dstType == AggregateType.U32)
{
return FormatUint((uint)value);
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
AggregateType.S32 => FormatInt(value),
AggregateType.U32 => FormatUint((uint)value),
_ => throw new ArgumentException($"Invalid variable type \"{dstType}\".")
};
}
public static string FormatInt(int value)