From 9daf029f356898336de1ad0c63b6c36e261e4f9b Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Sat, 12 Nov 2022 20:20:40 -0300
Subject: [PATCH] Use vector transform feedback outputs if possible (#3832)

---
 .../Shader/DiskCache/DiskCacheHostStorage.cs  |  2 +-
 .../CodeGen/Glsl/CodeGenContext.cs            | 20 +----
 .../CodeGen/Glsl/Declarations.cs              | 36 ++++++--
 .../CodeGen/Glsl/GlslGenerator.cs             |  2 +-
 .../CodeGen/Glsl/Instructions/InstGen.cs      |  2 +-
 .../Glsl/Instructions/InstGenMemory.cs        |  4 +-
 .../CodeGen/Glsl/OperandManager.cs            | 22 +++--
 .../CodeGen/Spirv/CodeGenContext.cs           | 28 +++---
 .../CodeGen/Spirv/Declarations.cs             | 87 +++++++++++++------
 .../CodeGen/Spirv/SpirvGenerator.cs           | 14 ++-
 .../StructuredIr/StructuredProgram.cs         |  6 +-
 .../StructuredIr/StructuredProgramInfo.cs     | 35 ++++++++
 .../Translation/ShaderConfig.cs               |  4 +
 13 files changed, 180 insertions(+), 82 deletions(-)

diff --git a/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs b/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
index 3f3a3c50e..e728c48c4 100644
--- a/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
@@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache
         private const ushort FileFormatVersionMajor = 1;
         private const ushort FileFormatVersionMinor = 2;
         private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor;
-        private const uint CodeGenVersion = 3833;
+        private const uint CodeGenVersion = 3831;
 
         private const string SharedTocFileName = "shared.toc";
         private const string SharedDataFileName = "shared.data";
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs
index 418af6cb7..9eb20f6f8 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs
@@ -10,12 +10,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
 
         public StructuredFunction CurrentFunction { get; set; }
 
+        public StructuredProgramInfo Info { get; }
+
         public ShaderConfig Config { get; }
 
         public OperandManager OperandManager { get; }
 
-        private readonly StructuredProgramInfo _info;
-
         private readonly StringBuilder _sb;
 
         private int _level;
@@ -24,7 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
 
         public CodeGenContext(StructuredProgramInfo info, ShaderConfig config)
         {
-            _info = info;
+            Info = info;
             Config = config;
 
             OperandManager = new OperandManager();
@@ -72,19 +72,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
 
         public StructuredFunction GetFunction(int id)
         {
-            return _info.Functions[id];
-        }
-
-        public TransformFeedbackOutput GetTransformFeedbackOutput(int location, int component)
-        {
-            int index = (AttributeConsts.UserAttributeBase / 4) + location * 4 + component;
-            return _info.TransformFeedbackOutputs[index];
-        }
-
-        public TransformFeedbackOutput GetTransformFeedbackOutput(int location)
-        {
-            int index = location / 4;
-            return _info.TransformFeedbackOutputs[index];
+            return Info.Functions[id];
         }
 
         private void UpdateIndentation()
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs
index 91fd286d4..4f2751b12 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs
@@ -210,7 +210,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
 
                 if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline)
                 {
-                    var tfOutput = context.GetTransformFeedbackOutput(AttributeConsts.PositionX);
+                    var tfOutput = context.Info.GetTransformFeedbackOutput(AttributeConsts.PositionX);
                     if (tfOutput.Valid)
                     {
                         context.AppendLine($"layout (xfb_buffer = {tfOutput.Buffer}, xfb_offset = {tfOutput.Offset}, xfb_stride = {tfOutput.Stride}) out gl_PerVertex");
@@ -604,19 +604,45 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
 
             if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline)
             {
-                for (int c = 0; c < 4; c++)
+                int attrOffset = AttributeConsts.UserAttributeBase + attr * 16;
+                int components = context.Config.LastInPipeline ? context.Info.GetTransformFeedbackOutputComponents(attrOffset) : 1;
+
+                if (components > 1)
                 {
-                    char swzMask = "xyzw"[c];
+                    string type = components switch
+                    {
+                        2 => "vec2",
+                        3 => "vec3",
+                        4 => "vec4",
+                        _ => "float"
+                    };
 
                     string xfb = string.Empty;
 
-                    var tfOutput = context.GetTransformFeedbackOutput(attr, c);
+                    var tfOutput = context.Info.GetTransformFeedbackOutput(attrOffset);
                     if (tfOutput.Valid)
                     {
                         xfb = $", xfb_buffer = {tfOutput.Buffer}, xfb_offset = {tfOutput.Offset}, xfb_stride = {tfOutput.Stride}";
                     }
 
-                    context.AppendLine($"layout (location = {attr}, component = {c}{xfb}) out float {name}_{swzMask};");
+                    context.AppendLine($"layout (location = {attr}{xfb}) out {type} {name};");
+                }
+                else
+                {
+                    for (int c = 0; c < 4; c++)
+                    {
+                        char swzMask = "xyzw"[c];
+
+                        string xfb = string.Empty;
+
+                        var tfOutput = context.Info.GetTransformFeedbackOutput(attrOffset + c * 4);
+                        if (tfOutput.Valid)
+                        {
+                            xfb = $", xfb_buffer = {tfOutput.Buffer}, xfb_offset = {tfOutput.Offset}, xfb_stride = {tfOutput.Stride}";
+                        }
+
+                        context.AppendLine($"layout (location = {attr}, component = {c}{xfb}) out float {name}_{swzMask};");
+                    }
                 }
             }
             else
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs
index e9dbdd2d3..e1b8eb6ec 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs
@@ -134,7 +134,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
                     if (assignment.Destination is AstOperand operand && operand.Type.IsAttribute())
                     {
                         bool perPatch = operand.Type == OperandType.AttributePerPatch;
-                        dest = OperandManager.GetOutAttributeName(operand.Value, context.Config, perPatch);
+                        dest = OperandManager.GetOutAttributeName(context, operand.Value, perPatch);
                     }
                     else
                     {
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs
index 388285a8f..b890b0158 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs
@@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
             }
             else if (node is AstOperand operand)
             {
-                return context.OperandManager.GetExpression(operand, context.Config);
+                return context.OperandManager.GetExpression(context, operand);
             }
 
             throw new ArgumentException($"Invalid node type \"{node?.GetType().Name ?? "null"}\".");
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs
index 094040013..022e3a444 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs
@@ -205,7 +205,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
             if (src2 is AstOperand operand && operand.Type == OperandType.Constant)
             {
                 int attrOffset = baseAttr.Value + (operand.Value << 2);
-                return OperandManager.GetAttributeName(attrOffset, context.Config, perPatch: false, isOutAttr: false, indexExpr);
+                return OperandManager.GetAttributeName(context, attrOffset, perPatch: false, isOutAttr: false, indexExpr);
             }
             else
             {
@@ -332,7 +332,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
             if (src2 is AstOperand operand && operand.Type == OperandType.Constant)
             {
                 int attrOffset = baseAttr.Value + (operand.Value << 2);
-                attrName = OperandManager.GetAttributeName(attrOffset, context.Config, perPatch: false, isOutAttr: true);
+                attrName = OperandManager.GetAttributeName(context, attrOffset, perPatch: false, isOutAttr: true);
             }
             else
             {
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs
index 67442e5a1..031b1c44c 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs
@@ -103,15 +103,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
             return name;
         }
 
-        public string GetExpression(AstOperand operand, ShaderConfig config)
+        public string GetExpression(CodeGenContext context, AstOperand operand)
         {
             return operand.Type switch
             {
                 OperandType.Argument => GetArgumentName(operand.Value),
-                OperandType.Attribute => GetAttributeName(operand.Value, config, perPatch: false),
-                OperandType.AttributePerPatch => GetAttributeName(operand.Value, config, perPatch: true),
+                OperandType.Attribute => GetAttributeName(context, operand.Value, perPatch: false),
+                OperandType.AttributePerPatch => GetAttributeName(context, operand.Value, perPatch: true),
                 OperandType.Constant => NumberFormatter.FormatInt(operand.Value),
-                OperandType.ConstantBuffer => GetConstantBufferName(operand, config),
+                OperandType.ConstantBuffer => GetConstantBufferName(operand, context.Config),
                 OperandType.LocalVariable => _locals[operand],
                 OperandType.Undefined => DefaultNames.UndefinedName,
                 _ => throw new ArgumentException($"Invalid operand type \"{operand.Type}\".")
@@ -153,13 +153,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
             return GetVec4Indexed(GetUbName(stage, slotExpr) + $"[{offsetExpr} >> 2]", offsetExpr + " & 3", indexElement);
         }
 
-        public static string GetOutAttributeName(int value, ShaderConfig config, bool perPatch)
+        public static string GetOutAttributeName(CodeGenContext context, int value, bool perPatch)
         {
-            return GetAttributeName(value, config, perPatch, isOutAttr: true);
+            return GetAttributeName(context, value, perPatch, isOutAttr: true);
         }
 
-        public static string GetAttributeName(int value, ShaderConfig config, bool perPatch, bool isOutAttr = false, string indexExpr = "0")
+        public static string GetAttributeName(CodeGenContext context, int value, bool perPatch, bool isOutAttr = false, string indexExpr = "0")
         {
+            ShaderConfig config = context.Config;
+
             if ((value & AttributeConsts.LoadOutputMask) != 0)
             {
                 isOutAttr = true;
@@ -192,6 +194,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
             }
             else if (value >= AttributeConsts.UserAttributeBase && value < AttributeConsts.UserAttributeEnd)
             {
+                int attrOffset = value;
                 value -= AttributeConsts.UserAttributeBase;
 
                 string prefix = isOutAttr
@@ -215,14 +218,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
                     ((config.LastInVertexPipeline && isOutAttr) ||
                     (config.Stage == ShaderStage.Fragment && !isOutAttr)))
                 {
-                    string name = $"{prefix}{(value >> 4)}_{swzMask}";
+                    int components = config.LastInPipeline ? context.Info.GetTransformFeedbackOutputComponents(attrOffset) : 1;
+                    string name = components > 1 ? $"{prefix}{(value >> 4)}" : $"{prefix}{(value >> 4)}_{swzMask}";
 
                     if (AttributeInfo.IsArrayAttributeGlsl(config.Stage, isOutAttr))
                     {
                         name += isOutAttr ? "[gl_InvocationID]" : $"[{indexExpr}]";
                     }
 
-                    return name;
+                    return components > 1 ? name + '.' + swzMask : name;
                 }
                 else
                 {
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
index 04c053253..dff5474a1 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
@@ -17,7 +17,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
         private const uint SpirvVersionRevision = 0;
         private const uint SpirvVersionPacked = (SpirvVersionMajor << 16) | (SpirvVersionMinor << 8) | SpirvVersionRevision;
 
-        private readonly StructuredProgramInfo _info;
+        public StructuredProgramInfo Info { get; }
 
         public ShaderConfig Config { get; }
 
@@ -85,7 +85,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             GeneratorPool<Instruction> instPool,
             GeneratorPool<LiteralInteger> integerPool) : base(SpirvVersionPacked, instPool, integerPool)
         {
-            _info = info;
+            Info = info;
             Config = config;
 
             if (config.Stage == ShaderStage.Geometry)
@@ -317,6 +317,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             {
                 attrOffset = attr;
                 type = elemType;
+
+                if (Config.LastInPipeline && isOutAttr)
+                {
+                    int components = Info.GetTransformFeedbackOutputComponents(attr);
+
+                    if (components > 1)
+                    {
+                        attrOffset &= ~0xf;
+                        type = AggregateType.Vector | AggregateType.FP32;
+                        attrInfo = new AttributeInfo(attrOffset, (attr - attrOffset) / 4, components, type, false);
+                    }
+                }
             }
 
             ioVariable = isOutAttr ? Outputs[attrOffset] : Inputs[attrOffset];
@@ -536,18 +548,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             return _functions[funcIndex];
         }
 
-        public TransformFeedbackOutput GetTransformFeedbackOutput(int location, int component)
-        {
-            int index = (AttributeConsts.UserAttributeBase / 4) + location * 4 + component;
-            return _info.TransformFeedbackOutputs[index];
-        }
-
-        public TransformFeedbackOutput GetTransformFeedbackOutput(int location)
-        {
-            int index = location / 4;
-            return _info.TransformFeedbackOutputs[index];
-        }
-
         public Instruction GetType(AggregateType type, int length = 1)
         {
             if (type.HasFlag(AggregateType.Array))
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
index c007b9a20..fafb917db 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
@@ -440,11 +440,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                 {
                     PixelImap iq = PixelImap.Unused;
 
-                    if (context.Config.Stage == ShaderStage.Fragment &&
-                        attr >= AttributeConsts.UserAttributeBase &&
-                        attr < AttributeConsts.UserAttributeEnd)
+                    if (context.Config.Stage == ShaderStage.Fragment)
                     {
-                        iq = context.Config.ImapTypes[(attr - AttributeConsts.UserAttributeBase) / 16].GetFirstUsedType();
+                        if (attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd)
+                        {
+                            iq = context.Config.ImapTypes[(attr - AttributeConsts.UserAttributeBase) / 16].GetFirstUsedType();
+                        }
+                        else
+                        {
+                            AttributeInfo attrInfo = AttributeInfo.From(context.Config, attr, isOutAttr: false);
+                            AggregateType elemType = attrInfo.Type & AggregateType.ElementTypeMask;
+
+                            if (attrInfo.IsBuiltin && (elemType == AggregateType.S32 || elemType == AggregateType.U32))
+                            {
+                                iq = PixelImap.Constant;
+                            }
+                        }
                     }
 
                     DeclareInputOrOutput(context, attr, perPatch, isOutAttr: false, iq);
@@ -516,7 +527,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                 ((isOutAttr && context.Config.LastInVertexPipeline) ||
                 (!isOutAttr && context.Config.Stage == ShaderStage.Fragment)))
             {
-                DeclareInputOrOutput(context, attr, (attr >> 2) & 3, isOutAttr, iq);
+                DeclareTransformFeedbackInputOrOutput(context, attr, isOutAttr, iq);
                 return;
             }
 
@@ -572,7 +583,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
 
                 if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline && isOutAttr)
                 {
-                    var tfOutput = context.GetTransformFeedbackOutput(attrInfo.BaseValue);
+                    var tfOutput = context.Info.GetTransformFeedbackOutput(attrInfo.BaseValue);
                     if (tfOutput.Valid)
                     {
                         context.Decorate(spvVar, Decoration.XfbBuffer, (LiteralInteger)tfOutput.Buffer);
@@ -595,24 +606,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
 
                 context.Decorate(spvVar, Decoration.Location, (LiteralInteger)location);
 
-                if (!isOutAttr)
+                if (!isOutAttr &&
+                    !perPatch &&
+                    (context.Config.PassthroughAttributes & (1 << location)) != 0 &&
+                    context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
                 {
-                    if (!perPatch &&
-                        (context.Config.PassthroughAttributes & (1 << location)) != 0 &&
-                        context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
-                    {
-                        context.Decorate(spvVar, Decoration.PassthroughNV);
-                    }
-
-                    switch (iq)
-                    {
-                        case PixelImap.Constant:
-                            context.Decorate(spvVar, Decoration.Flat);
-                            break;
-                        case PixelImap.ScreenLinear:
-                            context.Decorate(spvVar, Decoration.NoPerspective);
-                            break;
-                    }
+                    context.Decorate(spvVar, Decoration.PassthroughNV);
                 }
             }
             else if (attr >= AttributeConsts.FragmentOutputColorBase && attr < AttributeConsts.FragmentOutputColorEnd)
@@ -621,22 +620,52 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                 context.Decorate(spvVar, Decoration.Location, (LiteralInteger)location);
             }
 
+            if (!isOutAttr)
+            {
+                switch (iq)
+                {
+                    case PixelImap.Constant:
+                        context.Decorate(spvVar, Decoration.Flat);
+                        break;
+                    case PixelImap.ScreenLinear:
+                        context.Decorate(spvVar, Decoration.NoPerspective);
+                        break;
+                }
+            }
+
             context.AddGlobalVariable(spvVar);
             dict.Add(attrInfo.BaseValue, spvVar);
         }
 
-        private static void DeclareInputOrOutput(CodeGenContext context, int attr, int component, bool isOutAttr, PixelImap iq = PixelImap.Unused)
+        private static void DeclareTransformFeedbackInputOrOutput(CodeGenContext context, int attr, bool isOutAttr, PixelImap iq = PixelImap.Unused)
         {
             var dict = isOutAttr ? context.Outputs : context.Inputs;
             var attrInfo = AttributeInfo.From(context.Config, attr, isOutAttr);
 
+            bool hasComponent = true;
+            int component = (attr >> 2) & 3;
+            int components = 1;
+            var type = attrInfo.Type & AggregateType.ElementTypeMask;
+
+            if (context.Config.LastInPipeline && isOutAttr)
+            {
+                components = context.Info.GetTransformFeedbackOutputComponents(attr);
+
+                if (components > 1)
+                {
+                    attr &= ~0xf;
+                    type = AggregateType.Vector | AggregateType.FP32;
+                    hasComponent = false;
+                }
+            }
+
             if (dict.ContainsKey(attr))
             {
                 return;
             }
 
             var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
-            var attrType = context.GetType(attrInfo.Type & AggregateType.ElementTypeMask);
+            var attrType = context.GetType(type, components);
 
             if (AttributeInfo.IsArrayAttributeSpirv(context.Config.Stage, isOutAttr) && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr)))
             {
@@ -656,11 +685,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             int location = (attr - AttributeConsts.UserAttributeBase) / 16;
 
             context.Decorate(spvVar, Decoration.Location, (LiteralInteger)location);
-            context.Decorate(spvVar, Decoration.Component, (LiteralInteger)component);
+
+            if (hasComponent)
+            {
+                context.Decorate(spvVar, Decoration.Component, (LiteralInteger)component);
+            }
 
             if (isOutAttr)
             {
-                var tfOutput = context.GetTransformFeedbackOutput(location, component);
+                var tfOutput = context.Info.GetTransformFeedbackOutput(attr);
                 if (tfOutput.Valid)
                 {
                     context.Decorate(spvVar, Decoration.XfbBuffer, (LiteralInteger)tfOutput.Buffer);
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
index fad7f9b88..69283b0a3 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
@@ -62,10 +62,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                 context.AddCapability(Capability.TransformFeedback);
             }
 
-            if (config.Stage == ShaderStage.Fragment && context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock())
+            if (config.Stage == ShaderStage.Fragment)
             {
-                context.AddCapability(Capability.FragmentShaderPixelInterlockEXT);
-                context.AddExtension("SPV_EXT_fragment_shader_interlock");
+                if (context.Info.Inputs.Contains(AttributeConsts.Layer))
+                {
+                    context.AddCapability(Capability.Geometry);
+                }
+
+                if (context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock())
+                {
+                    context.AddCapability(Capability.FragmentShaderPixelInterlockEXT);
+                    context.AddExtension("SPV_EXT_fragment_shader_interlock");
+                }
             }
             else if (config.Stage == ShaderStage.Geometry)
             {
diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
index 85049abb2..7678a4bf6 100644
--- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
+++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
@@ -71,12 +71,12 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
                     var locations = config.GpuAccessor.QueryTransformFeedbackVaryingLocations(tfbIndex);
                     var stride = config.GpuAccessor.QueryTransformFeedbackStride(tfbIndex);
 
-                    for (int j = 0; j < locations.Length; j++)
+                    for (int i = 0; i < locations.Length; i++)
                     {
-                        byte location = locations[j];
+                        byte location = locations[i];
                         if (location < 0xc0)
                         {
-                            context.Info.TransformFeedbackOutputs[location] = new TransformFeedbackOutput(tfbIndex, j * 4, stride);
+                            context.Info.TransformFeedbackOutputs[location] = new TransformFeedbackOutput(tfbIndex, i * 4, stride);
                         }
                     }
                 }
diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs
index 43bdfaba5..57253148f 100644
--- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs
+++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs
@@ -42,5 +42,40 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
 
             TransformFeedbackOutputs = new TransformFeedbackOutput[0xc0];
         }
+
+        public TransformFeedbackOutput GetTransformFeedbackOutput(int attr)
+        {
+            int index = attr / 4;
+            return TransformFeedbackOutputs[index];
+        }
+
+        public int GetTransformFeedbackOutputComponents(int attr)
+        {
+            int index = attr / 4;
+            int baseIndex = index & ~3;
+
+            int count = 1;
+
+            for (; count < 4; count++)
+            {
+                ref var prev = ref TransformFeedbackOutputs[baseIndex + count - 1];
+                ref var curr = ref TransformFeedbackOutputs[baseIndex + count];
+
+                int prevOffset = prev.Offset;
+                int currOffset = curr.Offset;
+
+                if (!prev.Valid || !curr.Valid || prevOffset + 4 != currOffset)
+                {
+                    break;
+                }
+            }
+
+            if (baseIndex + count <= index)
+            {
+                return 1;
+            }
+
+            return count;
+        }
     }
 }
\ No newline at end of file
diff --git a/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs b/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs
index c70ec16c6..fcf35ce27 100644
--- a/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs
+++ b/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs
@@ -17,6 +17,7 @@ namespace Ryujinx.Graphics.Shader.Translation
         public ShaderStage Stage { get; }
 
         public bool GpPassthrough { get; }
+        public bool LastInPipeline { get; private set; }
         public bool LastInVertexPipeline { get; private set; }
 
         public int ThreadsPerInputPrimitive { get; }
@@ -143,6 +144,7 @@ namespace Ryujinx.Graphics.Shader.Translation
             OmapSampleMask           = header.OmapSampleMask;
             OmapDepth                = header.OmapDepth;
             TransformFeedbackEnabled = gpuAccessor.QueryTransformFeedbackEnabled();
+            LastInPipeline           = true;
             LastInVertexPipeline     = header.Stage < ShaderStage.Fragment;
         }
 
@@ -306,6 +308,8 @@ namespace Ryujinx.Graphics.Shader.Translation
                 config._perPatchAttributeLocations = locationsMap;
             }
 
+            LastInPipeline = false;
+
             // We don't consider geometry shaders using the geometry shader passthrough feature
             // as being the last because when this feature is used, it can't actually modify any of the outputs,
             // so the stage that comes before it is the last one that can do modifications.