WIP: Experimental: Metal backend #439

Closed
GreemDev wants to merge 374 commits from metal into master
10 changed files with 110 additions and 77 deletions
Showing only changes of commit 97a36298fa - Show all commits

View File

@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string Tab = " "; public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer) // The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int AdditionalArgCount = 2; public const int AdditionalArgCount = 1;
public StructuredFunction CurrentFunction { get; set; } public StructuredFunction CurrentFunction { get; set; }

View File

@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined; return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined;
} }
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage) public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
{
if (isMainFunc)
{ {
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
}
switch (stage) switch (stage)
{ {
case ShaderStage.Vertex: case ShaderStage.Vertex:
@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared) private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared)
{ {
string prefix = isShared ? "threadgroup " : string.Empty;
foreach (var memory in memories) foreach (var memory in memories)
{ {
string arraySize = ""; string arraySize = "";
@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
arraySize = $"[{memory.ArrayLength}]"; arraySize = $"[{memory.ArrayLength}]";
} }
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array); var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {memory.Name}{arraySize};"); context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};");
} }
} }
@ -128,25 +135,30 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
foreach (BufferDefinition buffer in buffers) foreach (BufferDefinition buffer in buffers)
{ {
context.AppendLine($"struct Struct_{buffer.Name}"); context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}");
context.EnterScope(); context.EnterScope();
foreach (StructureField field in buffer.Type.Fields) foreach (StructureField field in buffer.Type.Fields)
{
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
{ {
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
string arraySuffix = "";
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];"); if (field.Type.HasFlag(AggregateType.Array))
{
if (field.ArrayLength > 0)
{
arraySuffix = $"[{field.ArrayLength}]";
} }
else else
{ {
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); // Probably UB, but this is the approach that MVK takes
arraySuffix = "[1]";
context.AppendLine($"{typeName} {field.Name};");
} }
} }
context.AppendLine($"{typeName} {field.Name}{arraySuffix};");
}
context.LeaveScope(";"); context.LeaveScope(";");
context.AppendLine(); context.AppendLine();
} }
@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "uint3", IoVariable.GlobalId => "uint3",
IoVariable.VertexId => "uint", IoVariable.VertexId => "uint",
IoVariable.VertexIndex => "uint", IoVariable.VertexIndex => "uint",
IoVariable.PointCoord => "float2",
_ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false)) _ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false))
}; };
string name = ioDefinition.IoVariable switch string name = ioDefinition.IoVariable switch
@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "global_id", IoVariable.GlobalId => "global_id",
IoVariable.VertexId => "vertex_id", IoVariable.VertexId => "vertex_id",
IoVariable.VertexIndex => "vertex_index", IoVariable.VertexIndex => "vertex_index",
IoVariable.PointCoord => "point_coord",
_ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}" _ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"
}; };
string suffix = ioDefinition.IoVariable switch string suffix = ioDefinition.IoVariable switch
@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.VertexId => "[[vertex_id]]", IoVariable.VertexId => "[[vertex_id]]",
// TODO: Avoid potential redeclaration // TODO: Avoid potential redeclaration
IoVariable.VertexIndex => "[[vertex_id]]", IoVariable.VertexIndex => "[[vertex_id]]",
IoVariable.PointCoord => "[[point_coord]]",
IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]", IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]",
_ => "" _ => ""
}; };

View File

@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string IAttributePrefix = "inAttr"; public const string IAttributePrefix = "inAttr";
public const string OAttributePrefix = "outAttr"; public const string OAttributePrefix = "outAttr";
public const string StructPrefix = "struct";
public const string ArgumentNamePrefix = "a"; public const string ArgumentNamePrefix = "a";
public const string UndefinedName = "0"; public const string UndefinedName = "0";

View File

@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation; using Ryujinx.Graphics.Shader.Translation;
using System; using System;
using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
@ -39,11 +39,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
int arity = (int)(info.Type & InstType.ArityMask); int arity = (int)(info.Type & InstType.ArityMask);
string args = string.Empty; StringBuilder builder = new();
if (atomic) if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
{ {
// Hell builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
? AggregateType.S32
: AggregateType.U32;
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
{
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}");
}
} }
else else
{ {
@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{ {
if (argIndex != 0) if (argIndex != 0)
{ {
args += ", "; builder.Append(", ");
} }
AggregateType dstType = GetSrcVarType(inst, argIndex); AggregateType dstType = GetSrcVarType(inst, argIndex);
args += GetSourceExpr(context, operation.GetSource(argIndex), dstType); builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType));
} }
} }
return info.OpName + '(' + args + ')'; return $"{info.OpName}({builder})";
} }
else if ((info.Type & InstType.Op) != 0) else if ((info.Type & InstType.Op) != 0)
{ {
@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
switch (inst & Instruction.Mask) switch (inst & Instruction.Mask)
{ {
case Instruction.Barrier: case Instruction.Barrier:
return "|| BARRIER ||"; return "threadgroup_barrier(mem_flags::mem_threadgroup)";
case Instruction.Call: case Instruction.Call:
return Call(context, operation); return Call(context, operation);
case Instruction.FSIBegin: case Instruction.FSIBegin:

View File

@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value); var functon = context.GetFunction(funcId.Value);
int argCount = operation.SourcesCount - 1; int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.AdditionalArgCount]; int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[argCount + additionalArgCount];
// Additional arguments // Additional arguments
if (context.Definitions.Stage != ShaderStage.Compute)
{
args[0] = "in"; args[0] = "in";
args[1] = "support_buffer"; args[1] = "support_buffer";
}
else
{
args[0] = "support_buffer";
}
int argIndex = CodeGenContext.AdditionalArgCount; int argIndex = additionalArgCount;
for (int i = 0; i < argCount; i++) for (int i = 0; i < argCount; i++)
{ {
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i)); args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));

View File

@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3); Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle"); Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle");
Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down"); Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down");
Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up"); Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up");
Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor"); Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor");
Add(Instruction.Sine, InstType.CallUnary, "sin"); Add(Instruction.Sine, InstType.CallUnary, "sin");
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special); Add(Instruction.Store, InstType.Special);

View File

@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value]; StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name; varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name; varName += "->" + field.Name;
}
varType = field.Type; varType = field.Type;
break; break;

View File

@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32), IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.InstanceId => ("instance_id", AggregateType.S32),
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32), IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2), IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
IoVariable.PointSize => ("out.point_size", AggregateType.FP32), IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32),
IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32), IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32), IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL // gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32), IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),

View File

@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc)); context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc));
context.EnterScope(); context.EnterScope();
Declarations.DeclareLocals(context, function, stage); Declarations.DeclareLocals(context, function, stage, isMainFunc);
PrintBlock(context, function.MainBlock, isMainFunc); PrintBlock(context, function.MainBlock, isMainFunc);
@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage, ShaderStage stage,
bool isMainFunc = false) bool isMainFunc = false)
{ {
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount; int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length]; string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well // All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc) if (!isMainFunc)
{ {
args[0] = "FragmentIn in"; if (stage != ShaderStage.Compute)
args[1] = "constant Struct_support_buffer* support_buffer"; {
args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in";
args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
else
{
args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
} }
int argIndex = additionalArgCount; int argIndex = additionalArgCount;
@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values) foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{ {
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray(); args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
} }
foreach (var storageBuffers in context.Properties.StorageBuffers.Values) foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{ {
// Offset the binding by 15 to avoid clashing with the constant buffers // Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray(); args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
} }
foreach (var texture in context.Properties.Textures.Values) foreach (var texture in context.Properties.Textures.Values)
@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
} }
} }
return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})"; var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
var indent = new string(' ', funcPrefix.Length);
return $"{funcPrefix}{string.Join($", \n{indent}", args)})";
} }
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction) private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)

View File

@ -10,24 +10,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static bool TryFormat(int value, AggregateType dstType, out string formatted) public static bool TryFormat(int value, AggregateType dstType, out string formatted)
{ {
if (dstType == AggregateType.FP32) switch (dstType)
{ {
case AggregateType.FP32:
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted); return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
} case AggregateType.S32:
else if (dstType == AggregateType.S32)
{
formatted = FormatInt(value); formatted = FormatInt(value);
} break;
else if (dstType == AggregateType.U32) case AggregateType.U32:
{
formatted = FormatUint((uint)value); formatted = FormatUint((uint)value);
} break;
else if (dstType == AggregateType.Bool) case AggregateType.Bool:
{
formatted = value != 0 ? "true" : "false"; formatted = value != 0 ? "true" : "false";
} break;
else default:
{
throw new ArgumentException($"Invalid variable type \"{dstType}\"."); throw new ArgumentException($"Invalid variable type \"{dstType}\".");
} }
@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static string FormatInt(int value, AggregateType dstType) public static string FormatInt(int value, AggregateType dstType)
{ {
if (dstType == AggregateType.S32) return dstType switch
{ {
return FormatInt(value); AggregateType.S32 => FormatInt(value),
} AggregateType.U32 => FormatUint((uint)value),
else if (dstType == AggregateType.U32) _ => throw new ArgumentException($"Invalid variable type \"{dstType}\".")
{ };
return FormatUint((uint)value);
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
} }
public static string FormatInt(int value) public static string FormatInt(int value)