From 698ac0413bea8b8cbd9177ff3e4d2ed73ba801ad Mon Sep 17 00:00:00 2001
From: Gabriel A <gab.dark.100@gmail.com>
Date: Mon, 21 Aug 2023 14:01:45 -0300
Subject: [PATCH] Support ballot operations with divergent control flow on
 Adreno

---
 src/Ryujinx.Graphics.GAL/Capabilities.cs      |  3 ++
 .../Shader/GpuAccessorBase.cs                 |  2 +
 src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs |  1 +
 .../CodeGen/Spirv/Instructions.cs             | 40 ++++++++++++++++---
 .../CodeGen/Spirv/SpirvGenerator.cs           |  8 +++-
 src/Ryujinx.Graphics.Shader/IGpuAccessor.cs   |  9 +++++
 .../StructuredIr/HelperFunctionsMask.cs       |  1 +
 .../StructuredIr/StructuredProgram.cs         |  3 ++
 .../Translation/HostCapabilities.cs           |  3 ++
 .../Translation/TranslatorContext.cs          |  1 +
 src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs |  1 +
 11 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/src/Ryujinx.Graphics.GAL/Capabilities.cs b/src/Ryujinx.Graphics.GAL/Capabilities.cs
index dc927eaba..cdd7da8cc 100644
--- a/src/Ryujinx.Graphics.GAL/Capabilities.cs
+++ b/src/Ryujinx.Graphics.GAL/Capabilities.cs
@@ -37,6 +37,7 @@ namespace Ryujinx.Graphics.GAL
         public readonly bool SupportsCubemapView;
         public readonly bool SupportsNonConstantTextureOffset;
         public readonly bool SupportsShaderBallot;
+        public readonly bool SupportsShaderBallotDivergence;
         public readonly bool SupportsShaderBarrierDivergence;
         public readonly bool SupportsShaderFloat64;
         public readonly bool SupportsTextureGatherOffsets;
@@ -93,6 +94,7 @@ namespace Ryujinx.Graphics.GAL
             bool supportsCubemapView,
             bool supportsNonConstantTextureOffset,
             bool supportsShaderBallot,
+            bool supportsShaderBallotDivergence,
             bool supportsShaderBarrierDivergence,
             bool supportsShaderFloat64,
             bool supportsTextureGatherOffsets,
@@ -145,6 +147,7 @@ namespace Ryujinx.Graphics.GAL
             SupportsCubemapView = supportsCubemapView;
             SupportsNonConstantTextureOffset = supportsNonConstantTextureOffset;
             SupportsShaderBallot = supportsShaderBallot;
+            SupportsShaderBallotDivergence = supportsShaderBallotDivergence;
             SupportsShaderBarrierDivergence = supportsShaderBarrierDivergence;
             SupportsShaderFloat64 = supportsShaderFloat64;
             SupportsTextureGatherOffsets = supportsTextureGatherOffsets;
diff --git a/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs b/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs
index a5b31363b..fd0c48c40 100644
--- a/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs
+++ b/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs
@@ -180,6 +180,8 @@ namespace Ryujinx.Graphics.Gpu.Shader
 
         public bool QueryHostSupportsShaderBallot() => _context.Capabilities.SupportsShaderBallot;
 
+        public bool QueryHostSupportsShaderBallotDivergence() => _context.Capabilities.SupportsShaderBallotDivergence;
+
         public bool QueryHostSupportsShaderBarrierDivergence() => _context.Capabilities.SupportsShaderBarrierDivergence;
 
         public bool QueryHostSupportsShaderFloat64() => _context.Capabilities.SupportsShaderFloat64;
diff --git a/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs b/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs
index 64ba4e3ee..9390271a4 100644
--- a/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs
+++ b/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs
@@ -167,6 +167,7 @@ namespace Ryujinx.Graphics.OpenGL
                 supportsNonConstantTextureOffset: HwCapabilities.SupportsNonConstantTextureOffset,
                 supportsScaledVertexFormats: true,
                 supportsShaderBallot: HwCapabilities.SupportsShaderBallot,
+                supportsShaderBallotDivergence: true,
                 supportsShaderBarrierDivergence: !(intelWindows || intelUnix),
                 supportsShaderFloat64: true,
                 supportsTextureGatherOffsets: true,
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
index 601753cb0..34bc91f82 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
@@ -227,14 +227,44 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
         private static OperationResult GenerateBallot(CodeGenContext context, AstOperation operation)
         {
             var source = operation.GetSource(0);
+            var predicate = context.Get(AggregateType.Bool, source);
 
-            var uvec4Type = context.TypeVector(context.TypeU32(), 4);
-            var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
+            if (!context.HostCapabilities.SupportsShaderBallotDivergence &&
+                (context.CurrentBlock.Type != AstBlockType.Main || context.MayHaveReturned || !context.IsMainFunction))
+            {
+                // If divergent ballot is not supported, we can emulate it with a subgroupAdd operation. Each invocation
+                // adds a mask that has only its own bit set when its predicate is true, so the sum equals the bitwise OR.
 
-            var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, context.Get(AggregateType.Bool, source));
-            var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)operation.Index);
+                var bit = context.Select(
+                    context.TypeU32(),
+                    predicate,
+                    context.Constant(context.TypeU32(), 1),
+                    context.Constant(context.TypeU32(), 0));
 
-            return new OperationResult(AggregateType.U32, mask);
+                var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
+                var threadIdLow = context.BitwiseAnd(context.TypeU32(), threadId, context.Constant(context.TypeU32(), 0x1f));
+                var threadIdHigh = context.ShiftRightLogical(context.TypeU32(), threadId, context.Constant(context.TypeU32(), 5));
+                var bitMask = context.ShiftLeftLogical(context.TypeU32(), bit, threadIdLow);
+                var isGroup = context.IEqual(context.TypeBool(), threadIdHigh, context.Constant(context.TypeU32(), operation.Index));
+                bitMask = context.Select(context.TypeU32(), isGroup, bitMask, context.Constant(context.TypeU32(), 0));
+                var mask = context.GroupNonUniformIAdd(
+                    context.TypeU32(),
+                    context.Constant(context.TypeU32(), Scope.Subgroup),
+                    GroupOperation.Reduce,
+                    bitMask);
+
+                return new OperationResult(AggregateType.U32, mask);
+            }
+            else
+            {
+                var uvec4Type = context.TypeVector(context.TypeU32(), 4);
+                var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
+
+                var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, predicate);
+                var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)operation.Index);
+
+                return new OperationResult(AggregateType.U32, mask);
+            }
         }
 
         private static OperationResult GenerateBarrier(CodeGenContext context, AstOperation operation)
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
index ccfdc46d0..5438119ed 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
@@ -28,7 +28,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             _poolLock = new object();
         }
 
-        private const HelperFunctionsMask NeedsInvocationIdMask = HelperFunctionsMask.SwizzleAdd;
+        private const HelperFunctionsMask NeedsInvocationIdMask = HelperFunctionsMask.SwizzleAdd | HelperFunctionsMask.Ballot;
 
         public static byte[] Generate(StructuredProgramInfo info, CodeGenParameters parameters)
         {
@@ -51,6 +51,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             context.AddCapability(Capability.ImageQuery);
             context.AddCapability(Capability.SampledBuffer);
 
+            if (info.HelperFunctionsMask.HasFlag(HelperFunctionsMask.Ballot) && !context.HostCapabilities.SupportsShaderBallotDivergence)
+            {
+                // Without divergent ballot support, ballots may be emulated with subgroupAdd, which requires this capability.
+                context.AddCapability(Capability.GroupNonUniformArithmetic);
+            }
+
             if (parameters.Definitions.TransformFeedbackEnabled && parameters.Definitions.LastInVertexPipeline)
             {
                 context.AddCapability(Capability.TransformFeedback);
diff --git a/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs b/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs
index df6d29dc5..22f0ba611 100644
--- a/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs
+++ b/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs
@@ -312,6 +312,15 @@ namespace Ryujinx.Graphics.Shader
             return true;
         }
 
+        /// <summary>
+        /// Queries host GPU shader support for ballot instructions on divergent control flow paths.
+        /// </summary>
+        /// <returns>True if the GPU supports ballot instructions on divergent control flow paths, false otherwise</returns>
+        bool QueryHostSupportsShaderBallotDivergence()
+        {
+            return true;
+        }
+
         /// <summary>
         /// Queries host GPU shader support for barrier instructions on divergent control flow paths.
         /// </summary>
diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs
index 2a3d65e75..f7ecbe4be 100644
--- a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs
+++ b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs
@@ -9,5 +9,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
         MultiplyHighU32 = 1 << 3,
         SwizzleAdd = 1 << 10,
         FSI = 1 << 11,
+        Ballot = 1 << 12,
     }
 }
diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
index 2e2df7546..70c343592 100644
--- a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
+++ b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
@@ -328,6 +328,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
                 case Instruction.FSIEnd:
                     context.Info.HelperFunctionsMask |= HelperFunctionsMask.FSI;
                     break;
+                case Instruction.Ballot:
+                    context.Info.HelperFunctionsMask |= HelperFunctionsMask.Ballot;
+                    break;
             }
         }
 
diff --git a/src/Ryujinx.Graphics.Shader/Translation/HostCapabilities.cs b/src/Ryujinx.Graphics.Shader/Translation/HostCapabilities.cs
index 2523272b0..431200033 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/HostCapabilities.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/HostCapabilities.cs
@@ -7,6 +7,7 @@ namespace Ryujinx.Graphics.Shader.Translation
         public readonly bool SupportsFragmentShaderOrderingIntel;
         public readonly bool SupportsGeometryShaderPassthrough;
         public readonly bool SupportsShaderBallot;
+        public readonly bool SupportsShaderBallotDivergence;
         public readonly bool SupportsShaderBarrierDivergence;
         public readonly bool SupportsTextureShadowLod;
         public readonly bool SupportsViewportMask;
@@ -17,6 +18,7 @@ namespace Ryujinx.Graphics.Shader.Translation
             bool supportsFragmentShaderOrderingIntel,
             bool supportsGeometryShaderPassthrough,
             bool supportsShaderBallot,
+            bool supportsShaderBallotDivergence,
             bool supportsShaderBarrierDivergence,
             bool supportsTextureShadowLod,
             bool supportsViewportMask)
@@ -26,6 +28,7 @@ namespace Ryujinx.Graphics.Shader.Translation
             SupportsFragmentShaderOrderingIntel = supportsFragmentShaderOrderingIntel;
             SupportsGeometryShaderPassthrough = supportsGeometryShaderPassthrough;
             SupportsShaderBallot = supportsShaderBallot;
+            SupportsShaderBallotDivergence = supportsShaderBallotDivergence;
             SupportsShaderBarrierDivergence = supportsShaderBarrierDivergence;
             SupportsTextureShadowLod = supportsTextureShadowLod;
             SupportsViewportMask = supportsViewportMask;
diff --git a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
index a193ab3c4..c55cd11f9 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
@@ -361,6 +361,7 @@ namespace Ryujinx.Graphics.Shader.Translation
                 GpuAccessor.QueryHostSupportsFragmentShaderOrderingIntel(),
                 GpuAccessor.QueryHostSupportsGeometryShaderPassthrough(),
                 GpuAccessor.QueryHostSupportsShaderBallot(),
+                GpuAccessor.QueryHostSupportsShaderBallotDivergence(),
                 GpuAccessor.QueryHostSupportsShaderBarrierDivergence(),
                 GpuAccessor.QueryHostSupportsTextureShadowLod(),
                 GpuAccessor.QueryHostSupportsViewportMask());
diff --git a/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs b/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs
index 1db104f83..674335c2e 100644
--- a/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs
+++ b/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs
@@ -616,6 +616,7 @@ namespace Ryujinx.Graphics.Vulkan
                 supportsNonConstantTextureOffset: false,
                 supportsScaledVertexFormats: FormatCapabilities.SupportsScaledVertexFormats(),
                 supportsShaderBallot: false,
+                supportsShaderBallotDivergence: Vendor != Vendor.Qualcomm,
                 supportsShaderBarrierDivergence: Vendor != Vendor.Intel,
                 supportsShaderFloat64: Capabilities.SupportsShaderFloat64,
                 supportsTextureGatherOffsets: features2.Features.ShaderImageGatherExtended && !IsMoltenVk,