WIP V2: Experimental: Metal backend #441
@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
public const string Tab = " ";
|
||||
|
||||
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
|
||||
public const int AdditionalArgCount = 2;
|
||||
public const int AdditionalArgCount = 1;
|
||||
|
||||
public StructuredFunction CurrentFunction { get; set; }
|
||||
|
||||
|
@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined;
|
||||
}
|
||||
|
||||
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage)
|
||||
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
|
||||
{
|
||||
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
|
||||
if (isMainFunc)
|
||||
{
|
||||
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
|
||||
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
|
||||
}
|
||||
|
||||
switch (stage)
|
||||
{
|
||||
case ShaderStage.Vertex:
|
||||
@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
|
||||
private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared)
|
||||
{
|
||||
string prefix = isShared ? "threadgroup " : string.Empty;
|
||||
|
||||
foreach (var memory in memories)
|
||||
{
|
||||
string arraySize = "";
|
||||
@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
arraySize = $"[{memory.ArrayLength}]";
|
||||
}
|
||||
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
|
||||
context.AppendLine($"{typeName} {memory.Name}{arraySize};");
|
||||
context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};");
|
||||
}
|
||||
}
|
||||
|
||||
@ -128,23 +135,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
{
|
||||
foreach (BufferDefinition buffer in buffers)
|
||||
{
|
||||
context.AppendLine($"struct Struct_{buffer.Name}");
|
||||
context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}");
|
||||
context.EnterScope();
|
||||
|
||||
foreach (StructureField field in buffer.Type.Fields)
|
||||
{
|
||||
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
|
||||
{
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
string arraySuffix = "";
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
|
||||
}
|
||||
else
|
||||
if (field.Type.HasFlag(AggregateType.Array))
|
||||
{
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name};");
|
||||
if (field.ArrayLength > 0)
|
||||
{
|
||||
arraySuffix = $"[{field.ArrayLength}]";
|
||||
}
|
||||
else
|
||||
{
|
||||
// Probably UB, but this is the approach that MVK takes
|
||||
arraySuffix = "[1]";
|
||||
}
|
||||
}
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name}{arraySuffix};");
|
||||
}
|
||||
|
||||
context.LeaveScope(";");
|
||||
@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
IoVariable.GlobalId => "uint3",
|
||||
IoVariable.VertexId => "uint",
|
||||
IoVariable.VertexIndex => "uint",
|
||||
IoVariable.PointCoord => "float2",
|
||||
_ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false))
|
||||
};
|
||||
string name = ioDefinition.IoVariable switch
|
||||
@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
IoVariable.GlobalId => "global_id",
|
||||
IoVariable.VertexId => "vertex_id",
|
||||
IoVariable.VertexIndex => "vertex_index",
|
||||
IoVariable.PointCoord => "point_coord",
|
||||
_ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"
|
||||
};
|
||||
string suffix = ioDefinition.IoVariable switch
|
||||
@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
IoVariable.VertexId => "[[vertex_id]]",
|
||||
// TODO: Avoid potential redeclaration
|
||||
IoVariable.VertexIndex => "[[vertex_id]]",
|
||||
IoVariable.PointCoord => "[[point_coord]]",
|
||||
IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]",
|
||||
_ => ""
|
||||
};
|
||||
|
@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
public const string IAttributePrefix = "inAttr";
|
||||
public const string OAttributePrefix = "outAttr";
|
||||
|
||||
public const string StructPrefix = "struct";
|
||||
|
||||
public const string ArgumentNamePrefix = "a";
|
||||
|
||||
public const string UndefinedName = "0";
|
||||
|
@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
|
||||
using Ryujinx.Graphics.Shader.StructuredIr;
|
||||
using Ryujinx.Graphics.Shader.Translation;
|
||||
using System;
|
||||
|
||||
using System.Text;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
|
||||
@ -39,11 +39,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
|
||||
int arity = (int)(info.Type & InstType.ArityMask);
|
||||
|
||||
string args = string.Empty;
|
||||
StringBuilder builder = new();
|
||||
|
||||
if (atomic)
|
||||
if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
|
||||
{
|
||||
// Hell
|
||||
builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
|
||||
|
||||
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
|
||||
? AggregateType.S32
|
||||
: AggregateType.U32;
|
||||
|
||||
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
|
||||
{
|
||||
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
{
|
||||
if (argIndex != 0)
|
||||
{
|
||||
args += ", ";
|
||||
builder.Append(", ");
|
||||
}
|
||||
|
||||
AggregateType dstType = GetSrcVarType(inst, argIndex);
|
||||
|
||||
args += GetSourceExpr(context, operation.GetSource(argIndex), dstType);
|
||||
builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType));
|
||||
}
|
||||
}
|
||||
|
||||
return info.OpName + '(' + args + ')';
|
||||
return $"{info.OpName}({builder})";
|
||||
}
|
||||
else if ((info.Type & InstType.Op) != 0)
|
||||
{
|
||||
@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
switch (inst & Instruction.Mask)
|
||||
{
|
||||
case Instruction.Barrier:
|
||||
return "|| BARRIER ||";
|
||||
return "threadgroup_barrier(mem_flags::mem_threadgroup)";
|
||||
case Instruction.Call:
|
||||
return Call(context, operation);
|
||||
case Instruction.FSIBegin:
|
||||
|
@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
var functon = context.GetFunction(funcId.Value);
|
||||
|
||||
int argCount = operation.SourcesCount - 1;
|
||||
string[] args = new string[argCount + CodeGenContext.AdditionalArgCount];
|
||||
int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
|
||||
|
||||
string[] args = new string[argCount + additionalArgCount];
|
||||
|
||||
// Additional arguments
|
||||
args[0] = "in";
|
||||
args[1] = "support_buffer";
|
||||
if (context.Definitions.Stage != ShaderStage.Compute)
|
||||
{
|
||||
args[0] = "in";
|
||||
args[1] = "support_buffer";
|
||||
}
|
||||
else
|
||||
{
|
||||
args[0] = "support_buffer";
|
||||
}
|
||||
|
||||
int argIndex = CodeGenContext.AdditionalArgCount;
|
||||
int argIndex = additionalArgCount;
|
||||
for (int i = 0; i < argCount; i++)
|
||||
{
|
||||
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
|
||||
|
@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
|
||||
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
|
||||
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
|
||||
Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle");
|
||||
Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down");
|
||||
Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up");
|
||||
Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor");
|
||||
Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle");
|
||||
Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down");
|
||||
Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up");
|
||||
Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor");
|
||||
Add(Instruction.Sine, InstType.CallUnary, "sin");
|
||||
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
|
||||
Add(Instruction.Store, InstType.Special);
|
||||
|
@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
|
||||
StructureField field = buffer.Type.Fields[fieldIndex.Value];
|
||||
varName = buffer.Name;
|
||||
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
|
||||
{
|
||||
// Unsized array, the buffer is indexed instead of the field
|
||||
fieldName = "." + field.Name;
|
||||
}
|
||||
else
|
||||
{
|
||||
varName += "->" + field.Name;
|
||||
}
|
||||
varName += "->" + field.Name;
|
||||
varType = field.Type;
|
||||
break;
|
||||
|
||||
|
@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
|
||||
IoVariable.InstanceId => ("instance_id", AggregateType.S32),
|
||||
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
|
||||
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2),
|
||||
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
|
||||
IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
|
||||
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
|
||||
IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32),
|
||||
IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
|
||||
IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
|
||||
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
|
||||
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
|
||||
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
|
||||
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
|
||||
// gl_VertexIndex does not have a direct equivalent in MSL
|
||||
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
|
||||
|
@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc));
|
||||
context.EnterScope();
|
||||
|
||||
Declarations.DeclareLocals(context, function, stage);
|
||||
Declarations.DeclareLocals(context, function, stage, isMainFunc);
|
||||
|
||||
PrintBlock(context, function.MainBlock, isMainFunc);
|
||||
|
||||
@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
ShaderStage stage,
|
||||
bool isMainFunc = false)
|
||||
{
|
||||
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount;
|
||||
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
|
||||
|
||||
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
|
||||
|
||||
// All non-main functions need to be able to access the support_buffer as well
|
||||
if (!isMainFunc)
|
||||
{
|
||||
args[0] = "FragmentIn in";
|
||||
args[1] = "constant Struct_support_buffer* support_buffer";
|
||||
if (stage != ShaderStage.Compute)
|
||||
{
|
||||
args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in";
|
||||
args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
|
||||
}
|
||||
else
|
||||
{
|
||||
args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
|
||||
}
|
||||
}
|
||||
|
||||
int argIndex = additionalArgCount;
|
||||
@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
|
||||
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
|
||||
{
|
||||
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
|
||||
args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
|
||||
}
|
||||
|
||||
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
|
||||
{
|
||||
// Offset the binding by 15 to avoid clashing with the constant buffers
|
||||
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
|
||||
args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
|
||||
}
|
||||
|
||||
foreach (var texture in context.Properties.Textures.Values)
|
||||
@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
}
|
||||
}
|
||||
|
||||
return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})";
|
||||
var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
|
||||
var indent = new string(' ', funcPrefix.Length);
|
||||
|
||||
return $"{funcPrefix}{string.Join($", \n{indent}", args)})";
|
||||
}
|
||||
|
||||
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)
|
||||
|
@ -10,25 +10,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
|
||||
public static bool TryFormat(int value, AggregateType dstType, out string formatted)
|
||||
{
|
||||
if (dstType == AggregateType.FP32)
|
||||
switch (dstType)
|
||||
{
|
||||
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
|
||||
}
|
||||
else if (dstType == AggregateType.S32)
|
||||
{
|
||||
formatted = FormatInt(value);
|
||||
}
|
||||
else if (dstType == AggregateType.U32)
|
||||
{
|
||||
formatted = FormatUint((uint)value);
|
||||
}
|
||||
else if (dstType == AggregateType.Bool)
|
||||
{
|
||||
formatted = value != 0 ? "true" : "false";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
|
||||
case AggregateType.FP32:
|
||||
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
|
||||
case AggregateType.S32:
|
||||
formatted = FormatInt(value);
|
||||
break;
|
||||
case AggregateType.U32:
|
||||
formatted = FormatUint((uint)value);
|
||||
break;
|
||||
case AggregateType.Bool:
|
||||
formatted = value != 0 ? "true" : "false";
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
||||
|
||||
public static string FormatInt(int value, AggregateType dstType)
|
||||
{
|
||||
if (dstType == AggregateType.S32)
|
||||
return dstType switch
|
||||
{
|
||||
return FormatInt(value);
|
||||
}
|
||||
else if (dstType == AggregateType.U32)
|
||||
{
|
||||
return FormatUint((uint)value);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
|
||||
}
|
||||
AggregateType.S32 => FormatInt(value),
|
||||
AggregateType.U32 => FormatUint((uint)value),
|
||||
_ => throw new ArgumentException($"Invalid variable type \"{dstType}\".")
|
||||
};
|
||||
}
|
||||
|
||||
public static string FormatInt(int value)
|
||||
|
Loading…
x
Reference in New Issue
Block a user