From f16abe6dac72350f286c37899a3e410ae9495eed Mon Sep 17 00:00:00 2001
From: Gabriel A <gab.dark.100@gmail.com>
Date: Wed, 17 Jan 2024 22:25:51 -0300
Subject: [PATCH] Experimental 4KB tracking mode

---
 .../Jit/AddressSpacePageProtections.cs        | 340 ++++++++++++++++++
 src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs  |  55 ++-
 .../Jit/AddressSpacePartitionAllocator.cs     |   9 +-
 .../Jit/AddressSpacePartitioned.cs            |  68 +++-
 .../Jit/MemoryManagerHostTracked.cs           |   4 +-
 5 files changed, 453 insertions(+), 23 deletions(-)
 create mode 100644 src/Ryujinx.Cpu/Jit/AddressSpacePageProtections.cs

diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePageProtections.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePageProtections.cs
new file mode 100644
index 000000000..3f4cd2ff4
--- /dev/null
+++ b/src/Ryujinx.Cpu/Jit/AddressSpacePageProtections.cs
@@ -0,0 +1,340 @@
+using Ryujinx.Common;
+using Ryujinx.Common.Collections;
+using Ryujinx.Memory;
+using System;
+using System.Diagnostics;
+
+namespace Ryujinx.Cpu.Jit
+{
+    class AddressSpacePageProtections : IDisposable
+    {
+        private const ulong GuestPageSize = 0x1000;
+
+        class PageProtection : IntrusiveRedBlackTreeNode<PageProtection>, IComparable<PageProtection>
+        {
+            public readonly AddressSpacePartitionAllocation Memory;
+            public readonly ulong Offset;
+            public readonly ulong Address;
+            public readonly ulong Size;
+
+            private MemoryBlock _viewBlock;
+
+            public bool IsMapped => _viewBlock != null;
+
+            public PageProtection(AddressSpacePartitionAllocation memory, ulong offset, ulong address, ulong size)
+            {
+                Memory = memory;
+                Offset = offset;
+                Address = address;
+                Size = size;
+            }
+
+            public void SetViewBlock(MemoryBlock block)
+            {
+                _viewBlock = block;
+            }
+
+            public void Unmap()
+            {
+                if (_viewBlock != null)
+                {
+                    Memory.UnmapView(_viewBlock, Offset, MemoryBlock.GetPageSize());
+                    _viewBlock = null;
+                }
+            }
+
+            public bool OverlapsWith(ulong va, ulong size)
+            {
+                return Address < va + size && va < Address + Size;
+            }
+
+            public int CompareTo(PageProtection other)
+            {
+                if (OverlapsWith(other.Address, other.Size))
+                {
+                    return 0;
+                }
+                else if (Address < other.Address)
+                {
+                    return -1;
+                }
+                else
+                {
+                    return 1;
+                }
+            }
+        }
+
+        private readonly IntrusiveRedBlackTree<PageProtection> _protectionTree;
+
+        public AddressSpacePageProtections()
+        {
+            _protectionTree = new();
+        }
+
+        public void Reprotect(
+            AddressSpacePartitionAllocator asAllocator,
+            AddressSpacePartitioned addressSpace,
+            AddressSpacePartition partition,
+            ulong va,
+            ulong endVa,
+            MemoryPermission protection,
+            Action<ulong, IntPtr, ulong> updatePtCallback)
+        {
+            while (va < endVa)
+            {
+                ReprotectPage(asAllocator, addressSpace, partition, va, protection, updatePtCallback);
+
+                va += GuestPageSize;
+            }
+        }
+
+        private void ReprotectPage(
+            AddressSpacePartitionAllocator asAllocator,
+            AddressSpacePartitioned addressSpace,
+            AddressSpacePartition partition,
+            ulong va,
+            MemoryPermission protection,
+            Action<ulong, IntPtr, ulong> updatePtCallback)
+        {
+            ulong pageSize = MemoryBlock.GetPageSize();
+
+            PageProtection pageProtection = _protectionTree.GetNode(new PageProtection(default, 0, va, 1));
+
+            if (pageProtection == null)
+            {
+                ulong firstPage = BitUtils.AlignDown(va, pageSize);
+                ulong lastPage = BitUtils.AlignUp(va + GuestPageSize, pageSize) - GuestPageSize;
+
+                AddressSpacePartitionAllocation block;
+                PageProtection adjPageProtection = null;
+                ulong blockOffset = 0;
+
+                if (va == firstPage && va > partition.Address)
+                {
+                    block = asAllocator.AllocatePage(firstPage - pageSize, pageSize * 2);
+
+                    MapView(addressSpace, partition, block, 0, pageSize, va - GuestPageSize, out MemoryBlock adjMemory);
+
+                    adjPageProtection = new PageProtection(block, 0, va - GuestPageSize, GuestPageSize);
+                    adjPageProtection.SetViewBlock(adjMemory);
+                    blockOffset = pageSize;
+                }
+                else if (va == lastPage)
+                {
+                    block = asAllocator.AllocatePage(firstPage, pageSize * 2);
+
+                    MapView(addressSpace, partition, block, pageSize, pageSize, va + GuestPageSize, out MemoryBlock adjMemory);
+
+                    adjPageProtection = new PageProtection(block, pageSize, va + GuestPageSize, GuestPageSize);
+                    adjPageProtection.SetViewBlock(adjMemory);
+                }
+                else
+                {
+                    block = asAllocator.AllocatePage(firstPage, pageSize);
+                }
+
+                if (!MapView(addressSpace, partition, block, blockOffset, pageSize, va, out MemoryBlock viewMemory))
+                {
+                    block.Dispose();
+
+                    return;
+                }
+
+                pageProtection = new PageProtection(block, blockOffset, va, GuestPageSize);
+                pageProtection.SetViewBlock(viewMemory);
+                _protectionTree.Add(pageProtection);
+
+                if (adjPageProtection != null)
+                {
+                    Debug.Assert(_protectionTree.GetNode(adjPageProtection) == null);
+                    _protectionTree.Add(adjPageProtection);
+                }
+            }
+
+            Debug.Assert(pageProtection.IsMapped || partition.GetPrivateAllocation(va).Memory == null);
+
+            pageProtection.Memory.Reprotect(pageProtection.Offset, pageSize, protection, false);
+
+            updatePtCallback(va, pageProtection.Memory.GetPointer(pageProtection.Offset + (va & (pageSize - 1)), GuestPageSize), GuestPageSize);
+        }
+
+        public void UpdateMappings(AddressSpacePartition partition, ulong va, ulong size)
+        {
+            ulong pageSize = MemoryBlock.GetPageSize();
+
+            PageProtection pageProtection = GetLowestOverlap(va, size);
+
+            while (pageProtection != null)
+            {
+                if (pageProtection.Address >= va + size)
+                {
+                    break;
+                }
+
+                bool mapped = MapView(
+                    partition,
+                    pageProtection.Memory,
+                    pageProtection.Offset,
+                    pageSize,
+                    pageProtection.Address,
+                    out MemoryBlock memory);
+
+                Debug.Assert(mapped);
+
+                pageProtection.SetViewBlock(memory);
+                pageProtection = pageProtection.Successor;
+            }
+        }
+
+        public void Remove(ulong va, ulong size)
+        {
+            ulong pageSize = MemoryBlock.GetPageSize();
+
+            PageProtection pageProtection = GetLowestOverlap(va, size);
+
+            while (pageProtection != null)
+            {
+                if (pageProtection.Address >= va + size)
+                {
+                    break;
+                }
+
+                ulong firstPage = BitUtils.AlignDown(pageProtection.Address, pageSize);
+                ulong lastPage = BitUtils.AlignUp(pageProtection.Address + GuestPageSize, pageSize) - GuestPageSize;
+
+                bool canDelete;
+
+                if (pageProtection.Address == firstPage)
+                {
+                    canDelete = pageProtection.Predecessor == null ||
+                        pageProtection.Predecessor.Address + pageProtection.Predecessor.Size != pageProtection.Address ||
+                        !pageProtection.Predecessor.IsMapped;
+                }
+                else if (pageProtection.Address == lastPage)
+                {
+                    canDelete = pageProtection.Successor == null ||
+                        pageProtection.Address + pageProtection.Size != pageProtection.Successor.Address ||
+                        !pageProtection.Successor.IsMapped;
+                }
+                else
+                {
+                    canDelete = true;
+                }
+
+                PageProtection successor = pageProtection.Successor;
+
+                if (canDelete)
+                {
+                    if (pageProtection.Address == firstPage &&
+                        pageProtection.Predecessor != null &&
+                        pageProtection.Predecessor.Address + pageProtection.Predecessor.Size == pageProtection.Address)
+                    {
+                        _protectionTree.Remove(pageProtection.Predecessor);
+                    }
+                    else if (pageProtection.Address == lastPage &&
+                        pageProtection.Successor != null &&
+                        pageProtection.Address + pageProtection.Size == pageProtection.Successor.Address)
+                    {
+                        successor = successor.Successor;
+                        _protectionTree.Remove(pageProtection.Successor);
+                    }
+
+                    _protectionTree.Remove(pageProtection);
+                    pageProtection.Memory.Dispose();
+                }
+                else
+                {
+                    pageProtection.Unmap();
+                }
+
+                pageProtection = successor;
+            }
+        }
+
+        private static bool MapView(
+            AddressSpacePartitioned addressSpace,
+            AddressSpacePartition partition,
+            AddressSpacePartitionAllocation dstBlock,
+            ulong dstOffset,
+            ulong size,
+            ulong va,
+            out MemoryBlock memory)
+        {
+            PrivateRange privateRange;
+
+            if (va >= partition.Address && va < partition.EndAddress)
+            {
+                privateRange = partition.GetPrivateAllocation(va);
+            }
+            else
+            {
+                privateRange = addressSpace.GetPrivateAllocation(va);
+            }
+
+            memory = privateRange.Memory;
+
+            if (privateRange.Memory == null)
+            {
+                return false;
+            }
+
+            dstBlock.MapView(privateRange.Memory, privateRange.Offset & ~(MemoryBlock.GetPageSize() - 1), dstOffset, size);
+
+            return true;
+        }
+
+        private static bool MapView(
+            AddressSpacePartition partition,
+            AddressSpacePartitionAllocation dstBlock,
+            ulong dstOffset,
+            ulong size,
+            ulong va,
+            out MemoryBlock memory)
+        {
+            Debug.Assert(va >= partition.Address && va < partition.EndAddress);
+
+            PrivateRange privateRange = partition.GetPrivateAllocation(va);
+
+            memory = privateRange.Memory;
+
+            if (privateRange.Memory == null)
+            {
+                return false;
+            }
+
+            dstBlock.MapView(privateRange.Memory, privateRange.Offset & ~(size - 1), dstOffset, size);
+
+            return true;
+        }
+
+        private PageProtection GetLowestOverlap(ulong va, ulong size)
+        {
+            PageProtection pageProtection = _protectionTree.GetNode(new PageProtection(default, 0, va, size));
+
+            if (pageProtection == null)
+            {
+                return null;
+            }
+
+            while (pageProtection.Predecessor != null && pageProtection.Predecessor.OverlapsWith(va, size))
+            {
+                pageProtection = pageProtection.Predecessor;
+            }
+
+            return pageProtection;
+        }
+
+        protected virtual void Dispose(bool disposing)
+        {
+            Remove(0, ulong.MaxValue);
+            Debug.Assert(_protectionTree.Count == 0);
+        }
+
+        public void Dispose()
+        {
+            Dispose(disposing: true);
+            GC.SuppressFinalize(this);
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs
index 309439dfb..7401a2784 100644
--- a/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs
+++ b/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs
@@ -98,10 +98,6 @@ namespace Ryujinx.Cpu.Jit
 
             public PrivateMapping(ulong address, ulong size, PrivateMemoryAllocation privateAllocation)
             {
-                if (size == 0)
-                {
-                    throw new Exception("huh? size is 0");
-                }
                 Address = address;
                 Size = size;
                 PrivateAllocation = privateAllocation;
@@ -169,6 +165,7 @@ namespace Ryujinx.Cpu.Jit
         private readonly PrivateMemoryAllocator _privateMemoryAllocator;
         private readonly IntrusiveRedBlackTree<Mapping> _mappingTree;
         private readonly IntrusiveRedBlackTree<PrivateMapping> _privateTree;
+        private readonly AddressSpacePageProtections _pageProtections;
 
         private readonly ReaderWriterLockSlim _treeLock;
 
@@ -190,6 +187,7 @@ namespace Ryujinx.Cpu.Jit
             _privateMemoryAllocator = new PrivateMemoryAllocator(DefaultBlockAlignment, MemoryAllocationFlags.Mirrorable);
             _mappingTree = new IntrusiveRedBlackTree<Mapping>();
             _privateTree = new IntrusiveRedBlackTree<PrivateMapping>();
+            _pageProtections = new AddressSpacePageProtections();
             _treeLock = new ReaderWriterLockSlim();
 
             _mappingTree.Add(new Mapping(address, size, MappingType.None));
@@ -240,6 +238,8 @@ namespace Ryujinx.Cpu.Jit
             }
 
             Update(va, pa, size, MappingType.Private);
+
+            _pageProtections.UpdateMappings(this, va, size);
         }
 
         public void Unmap(ulong va, ulong size)
@@ -258,9 +258,11 @@ namespace Ryujinx.Cpu.Jit
             }
 
             Update(va, 0UL, size, MappingType.None);
+
+            _pageProtections.Remove(va, size);
         }
 
-        public void Reprotect(ulong va, ulong size, MemoryPermission protection)
+        public void ReprotectAligned(ulong va, ulong size, MemoryPermission protection)
         {
             Debug.Assert(va >= Address);
             Debug.Assert(va + size <= EndAddress);
@@ -282,6 +284,19 @@ namespace Ryujinx.Cpu.Jit
             }
         }
 
+        public void Reprotect(
+            ulong va,
+            ulong size,
+            MemoryPermission protection,
+            AddressSpacePartitionAllocator asAllocator,
+            AddressSpacePartitioned addressSpace,
+            Action<ulong, IntPtr, ulong> updatePtCallback)
+        {
+            ulong endVa = va + size;
+
+            _pageProtections.Reprotect(asAllocator, addressSpace, this, va, endVa, protection, updatePtCallback);
+        }
+
         public IntPtr GetPointer(ulong va, ulong size)
         {
             Debug.Assert(va >= Address);
@@ -315,6 +330,8 @@ namespace Ryujinx.Cpu.Jit
                     updatePtCallback(EndAddress - _hostPageSize, _baseMemory.GetPointer(Size, _hostPageSize), _hostPageSize);
 
                     _hasBridgeAtEnd = true;
+
+                    _pageProtections.UpdateMappings(partitionAfter, EndAddress, GuestPageSize);
                 }
                 else
                 {
@@ -326,6 +343,8 @@ namespace Ryujinx.Cpu.Jit
                     }
 
                     _hasBridgeAtEnd = false;
+
+                    _pageProtections.Remove(EndAddress, GuestPageSize);
                 }
 
                 _cachedFirstPagePa = firstPagePa;
@@ -346,6 +365,8 @@ namespace Ryujinx.Cpu.Jit
             _cachedLastPagePa = ulong.MaxValue;
 
             _hasBridgeAtEnd = false;
+
+            _pageProtections.Remove(EndAddress, GuestPageSize);
         }
 
         private (MemoryBlock, ulong) GetFirstPageMemoryAndOffset()
@@ -392,6 +413,27 @@ namespace Ryujinx.Cpu.Jit
             return (_backingMemory, _lastPagePa.Value & ~(_hostPageSize - 1));
         }
 
+        public PrivateRange GetPrivateAllocation(ulong va)
+        {
+            _treeLock.EnterReadLock();
+
+            try
+            {
+                PrivateMapping map = _privateTree.GetNode(new PrivateMapping(va, 1UL, default));
+
+                if (map != null && map.PrivateAllocation.IsValid)
+                {
+                    return new(map.PrivateAllocation.Memory, map.PrivateAllocation.Offset + (va - map.Address), map.Size - (va - map.Address));
+                }
+            }
+            finally
+            {
+                _treeLock.ExitReadLock();
+            }
+
+            return PrivateRange.Empty;
+        }
+
         private void Update(ulong va, ulong pa, ulong size, MappingType type)
         {
             _treeLock.EnterWriteLock();
@@ -654,7 +696,8 @@ namespace Ryujinx.Cpu.Jit
         {
             GC.SuppressFinalize(this);
 
-            _privateMemoryAllocator?.Dispose();
+            _privateMemoryAllocator.Dispose();
+            _pageProtections.Dispose();
             _baseMemory.Dispose();
         }
     }
diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs
index 561b92af4..e897d05da 100644
--- a/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs
+++ b/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs
@@ -2,7 +2,6 @@ using Ryujinx.Common.Collections;
 using Ryujinx.Memory;
 using Ryujinx.Memory.Tracking;
 using System;
-using System.Threading;
 
 namespace Ryujinx.Cpu.Jit
 {
@@ -167,6 +166,14 @@ namespace Ryujinx.Cpu.Jit
             return allocation;
         }
 
+        public AddressSpacePartitionAllocation AllocatePage(ulong va, ulong size)
+        {
+            AddressSpacePartitionAllocation allocation = new(this, Allocate(size, MemoryBlock.GetPageSize(), CreateBlock));
+            allocation.RegisterMapping(va, va + size, 0);
+
+            return allocation;
+        }
+
         private Block CreateBlock(MemoryBlock memory, ulong size)
         {
             return new Block(_tracking, memory, size, _lock);
diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs
index ad082c96f..c7ff755c6 100644
--- a/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs
+++ b/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs
@@ -9,6 +9,8 @@ namespace Ryujinx.Cpu.Jit
 {
     class AddressSpacePartitioned : IDisposable
     {
+        public static readonly bool Use4KBProtection = false;
+
         private const int PartitionBits = 25;
         private const ulong PartitionSize = 1UL << PartitionBits;
 
@@ -89,27 +91,65 @@ namespace Ryujinx.Cpu.Jit
             }
         }
 
-        public void Reprotect(ulong va, ulong size, MemoryPermission protection, MemoryTracking tracking)
+        public void Reprotect(ulong va, ulong size, MemoryPermission protection)
         {
             ulong endVa = va + size;
 
-            while (va < endVa)
+            if (Use4KBProtection)
             {
-                AddressSpacePartition partition = FindPartition(va);
-
-                if (partition == null)
+                lock (_partitions)
                 {
-                    va += PartitionSize - (va & (PartitionSize - 1));
+                    while (va < endVa)
+                    {
+                        AddressSpacePartition partition = FindPartition(va);
 
-                    continue;
+                        if (partition == null)
+                        {
+                            va += PartitionSize - (va & (PartitionSize - 1));
+
+                            continue;
+                        }
+
+                        (ulong clampedVa, ulong clampedEndVa) = ClampRange(partition, va, endVa);
+
+                        partition.Reprotect(clampedVa, clampedEndVa - clampedVa, protection, _asAllocator, this, _updatePtCallback);
+
+                        va += clampedEndVa - clampedVa;
+                    }
                 }
-
-                (ulong clampedVa, ulong clampedEndVa) = ClampRange(partition, va, endVa);
-
-                partition.Reprotect(clampedVa, clampedEndVa - clampedVa, protection);
-
-                va += clampedEndVa - clampedVa;
             }
+            else
+            {
+                while (va < endVa)
+                {
+                    AddressSpacePartition partition = FindPartition(va);
+
+                    if (partition == null)
+                    {
+                        va += PartitionSize - (va & (PartitionSize - 1));
+
+                        continue;
+                    }
+
+                    (ulong clampedVa, ulong clampedEndVa) = ClampRange(partition, va, endVa);
+
+                    partition.ReprotectAligned(clampedVa, clampedEndVa - clampedVa, protection);
+
+                    va += clampedEndVa - clampedVa;
+                }
+            }
+        }
+
+        public PrivateRange GetPrivateAllocation(ulong va)
+        {
+            AddressSpacePartition partition = FindPartition(va);
+
+            if (partition == null)
+            {
+                return PrivateRange.Empty;
+            }
+
+            return partition.GetPrivateAllocation(va);
         }
 
         public PrivateRange GetFirstPrivateAllocation(ulong va, ulong size, out ulong nextVa)
@@ -226,7 +266,7 @@ namespace Ryujinx.Cpu.Jit
         private int FindPartitionIndexLocked(ulong va)
         {
             int left = 0;
-            int middle = 0;
+            int middle;
             int right = _partitions.Count - 1;
 
             while (left <= right)
diff --git a/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs b/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs
index 83b3e4094..03befa088 100644
--- a/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs
+++ b/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs
@@ -71,7 +71,7 @@ namespace Ryujinx.Cpu.Jit
         /// <param name="invalidAccessHandler">Optional function to handle invalid memory accesses</param>
         public MemoryManagerHostTracked(MemoryBlock backingMemory, ulong addressSpaceSize, InvalidAccessHandler invalidAccessHandler)
         {
-            Tracking = new MemoryTracking(this, (int)MemoryBlock.GetPageSize(), invalidAccessHandler);
+            Tracking = new MemoryTracking(this, AddressSpacePartitioned.Use4KBProtection ? PageSize : (int)MemoryBlock.GetPageSize(), invalidAccessHandler);
 
             _backingMemory = backingMemory;
             _pageTable = new PageTable<ulong>();
@@ -990,7 +990,7 @@ namespace Ryujinx.Cpu.Jit
                 _ => MemoryPermission.None,
             };
 
-            _addressSpace.Reprotect(va, size, protection, Tracking);
+            _addressSpace.Reprotect(va, size, protection);
         }
 
         /// <summary>