diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePageProtections.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePageProtections.cs new file mode 100644 index 000000000..3f4cd2ff4 --- /dev/null +++ b/src/Ryujinx.Cpu/Jit/AddressSpacePageProtections.cs @@ -0,0 +1,340 @@ +using Ryujinx.Common; +using Ryujinx.Common.Collections; +using Ryujinx.Memory; +using System; +using System.Diagnostics; + +namespace Ryujinx.Cpu.Jit +{ + class AddressSpacePageProtections : IDisposable + { + private const ulong GuestPageSize = 0x1000; + + class PageProtection : IntrusiveRedBlackTreeNode, IComparable + { + public readonly AddressSpacePartitionAllocation Memory; + public readonly ulong Offset; + public readonly ulong Address; + public readonly ulong Size; + + private MemoryBlock _viewBlock; + + public bool IsMapped => _viewBlock != null; + + public PageProtection(AddressSpacePartitionAllocation memory, ulong offset, ulong address, ulong size) + { + Memory = memory; + Offset = offset; + Address = address; + Size = size; + } + + public void SetViewBlock(MemoryBlock block) + { + _viewBlock = block; + } + + public void Unmap() + { + if (_viewBlock != null) + { + Memory.UnmapView(_viewBlock, Offset, MemoryBlock.GetPageSize()); + _viewBlock = null; + } + } + + public bool OverlapsWith(ulong va, ulong size) + { + return Address < va + size && va < Address + Size; + } + + public int CompareTo(PageProtection other) + { + if (OverlapsWith(other.Address, other.Size)) + { + return 0; + } + else if (Address < other.Address) + { + return -1; + } + else + { + return 1; + } + } + } + + private readonly IntrusiveRedBlackTree _protectionTree; + + public AddressSpacePageProtections() + { + _protectionTree = new(); + } + + public void Reprotect( + AddressSpacePartitionAllocator asAllocator, + AddressSpacePartitioned addressSpace, + AddressSpacePartition partition, + ulong va, + ulong endVa, + MemoryPermission protection, + Action updatePtCallback) + { + while (va < endVa) + { + ReprotectPage(asAllocator, addressSpace, partition, va, protection, updatePtCallback); + + va += GuestPageSize; + } + } + + private void ReprotectPage( + AddressSpacePartitionAllocator asAllocator, + AddressSpacePartitioned addressSpace, + AddressSpacePartition partition, + ulong va, + MemoryPermission protection, + Action updatePtCallback) + { + ulong pageSize = MemoryBlock.GetPageSize(); + + PageProtection pageProtection = _protectionTree.GetNode(new PageProtection(default, 0, va, 1)); + + if (pageProtection == null) + { + ulong firstPage = BitUtils.AlignDown(va, pageSize); + ulong lastPage = BitUtils.AlignUp(va + GuestPageSize, pageSize) - GuestPageSize; + + AddressSpacePartitionAllocation block; + PageProtection adjPageProtection = null; + ulong blockOffset = 0; + + if (va == firstPage && va > partition.Address) + { + block = asAllocator.AllocatePage(firstPage - pageSize, pageSize * 2); + + MapView(addressSpace, partition, block, 0, pageSize, va - GuestPageSize, out MemoryBlock adjMemory); + + adjPageProtection = new PageProtection(block, 0, va - GuestPageSize, GuestPageSize); + adjPageProtection.SetViewBlock(adjMemory); + blockOffset = pageSize; + } + else if (va == lastPage) + { + block = asAllocator.AllocatePage(firstPage, pageSize * 2); + + MapView(addressSpace, partition, block, pageSize, pageSize, va + GuestPageSize, out MemoryBlock adjMemory); + + adjPageProtection = new PageProtection(block, pageSize, va + GuestPageSize, GuestPageSize); + adjPageProtection.SetViewBlock(adjMemory); + } + else + { + block = asAllocator.AllocatePage(firstPage, pageSize); + } + + if (!MapView(addressSpace, partition, block, blockOffset, pageSize, va, out MemoryBlock viewMemory)) + { + block.Dispose(); + + return; + } + + pageProtection = new PageProtection(block, blockOffset, va, GuestPageSize); + pageProtection.SetViewBlock(viewMemory); + _protectionTree.Add(pageProtection); + + if (adjPageProtection != null) + { + Debug.Assert(_protectionTree.GetNode(adjPageProtection) == null); + _protectionTree.Add(adjPageProtection); + } + } + + Debug.Assert(pageProtection.IsMapped || partition.GetPrivateAllocation(va).Memory == null); + + pageProtection.Memory.Reprotect(pageProtection.Offset, pageSize, protection, false); + + updatePtCallback(va, pageProtection.Memory.GetPointer(pageProtection.Offset + (va & (pageSize - 1)), GuestPageSize), GuestPageSize); + } + + public void UpdateMappings(AddressSpacePartition partition, ulong va, ulong size) + { + ulong pageSize = MemoryBlock.GetPageSize(); + + PageProtection pageProtection = GetLowestOverlap(va, size); + + while (pageProtection != null) + { + if (pageProtection.Address >= va + size) + { + break; + } + + bool mapped = MapView( + partition, + pageProtection.Memory, + pageProtection.Offset, + pageSize, + pageProtection.Address, + out MemoryBlock memory); + + Debug.Assert(mapped); + + pageProtection.SetViewBlock(memory); + pageProtection = pageProtection.Successor; + } + } + + public void Remove(ulong va, ulong size) + { + ulong pageSize = MemoryBlock.GetPageSize(); + + PageProtection pageProtection = GetLowestOverlap(va, size); + + while (pageProtection != null) + { + if (pageProtection.Address >= va + size) + { + break; + } + + ulong firstPage = BitUtils.AlignDown(pageProtection.Address, pageSize); + ulong lastPage = BitUtils.AlignUp(pageProtection.Address + GuestPageSize, pageSize) - GuestPageSize; + + bool canDelete; + + if (pageProtection.Address == firstPage) + { + canDelete = pageProtection.Predecessor == null || + pageProtection.Predecessor.Address + pageProtection.Predecessor.Size != pageProtection.Address || + !pageProtection.Predecessor.IsMapped; + } + else if (pageProtection.Address == lastPage) + { + canDelete = pageProtection.Successor == null || + pageProtection.Address + pageProtection.Size != pageProtection.Successor.Address || + !pageProtection.Successor.IsMapped; + } + else + { + canDelete = true; + } + + PageProtection successor = pageProtection.Successor; + + if (canDelete) + { + if (pageProtection.Address == firstPage && + pageProtection.Predecessor != null && + pageProtection.Predecessor.Address + pageProtection.Predecessor.Size == pageProtection.Address) + { + _protectionTree.Remove(pageProtection.Predecessor); + } + else if (pageProtection.Address == lastPage && + pageProtection.Successor != null && + pageProtection.Address + pageProtection.Size == pageProtection.Successor.Address) + { + successor = successor.Successor; + _protectionTree.Remove(pageProtection.Successor); + } + + _protectionTree.Remove(pageProtection); + pageProtection.Memory.Dispose(); + } + else + { + pageProtection.Unmap(); + } + + pageProtection = successor; + } + } + + private static bool MapView( + AddressSpacePartitioned addressSpace, + AddressSpacePartition partition, + AddressSpacePartitionAllocation dstBlock, + ulong dstOffset, + ulong size, + ulong va, + out MemoryBlock memory) + { + PrivateRange privateRange; + + if (va >= partition.Address && va < partition.EndAddress) + { + privateRange = partition.GetPrivateAllocation(va); + } + else + { + privateRange = addressSpace.GetPrivateAllocation(va); + } + + memory = privateRange.Memory; + + if (privateRange.Memory == null) + { + return false; + } + + dstBlock.MapView(privateRange.Memory, privateRange.Offset & ~(MemoryBlock.GetPageSize() - 1), dstOffset, size); + + return true; + } + + private static bool MapView( + AddressSpacePartition partition, + AddressSpacePartitionAllocation dstBlock, + ulong dstOffset, + ulong size, + ulong va, + out MemoryBlock memory) + { + Debug.Assert(va >= partition.Address && va < partition.EndAddress); + + PrivateRange privateRange = partition.GetPrivateAllocation(va); + + memory = privateRange.Memory; + + if (privateRange.Memory == null) + { + return false; + } + + dstBlock.MapView(privateRange.Memory, privateRange.Offset & ~(size - 1), dstOffset, size); + + return true; + } + + private PageProtection GetLowestOverlap(ulong va, ulong size) + { + PageProtection pageProtection = _protectionTree.GetNode(new PageProtection(default, 0, va, size)); + + if (pageProtection == null) + { + return null; + } + + while (pageProtection.Predecessor != null && pageProtection.Predecessor.OverlapsWith(va, size)) + { + pageProtection = pageProtection.Predecessor; + } + + return pageProtection; + } + + protected virtual void Dispose(bool disposing) + { + Remove(0, ulong.MaxValue); + Debug.Assert(_protectionTree.Count == 0); + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + } +} \ No newline at end of file diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs index 309439dfb..7401a2784 100644 --- a/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs +++ b/src/Ryujinx.Cpu/Jit/AddressSpacePartition.cs @@ -98,10 +98,6 @@ namespace Ryujinx.Cpu.Jit public PrivateMapping(ulong address, ulong size, PrivateMemoryAllocation privateAllocation) { - if (size == 0) - { - throw new Exception("huh? size is 0"); - } Address = address; Size = size; PrivateAllocation = privateAllocation; @@ -169,6 +165,7 @@ namespace Ryujinx.Cpu.Jit private readonly PrivateMemoryAllocator _privateMemoryAllocator; private readonly IntrusiveRedBlackTree _mappingTree; private readonly IntrusiveRedBlackTree _privateTree; + private readonly AddressSpacePageProtections _pageProtections; private readonly ReaderWriterLockSlim _treeLock; @@ -190,6 +187,7 @@ namespace Ryujinx.Cpu.Jit _privateMemoryAllocator = new PrivateMemoryAllocator(DefaultBlockAlignment, MemoryAllocationFlags.Mirrorable); _mappingTree = new IntrusiveRedBlackTree(); _privateTree = new IntrusiveRedBlackTree(); + _pageProtections = new AddressSpacePageProtections(); _treeLock = new ReaderWriterLockSlim(); _mappingTree.Add(new Mapping(address, size, MappingType.None)); @@ -240,6 +238,8 @@ namespace Ryujinx.Cpu.Jit } Update(va, pa, size, MappingType.Private); + + _pageProtections.UpdateMappings(this, va, size); } public void Unmap(ulong va, ulong size) @@ -258,9 +258,11 @@ namespace Ryujinx.Cpu.Jit } Update(va, 0UL, size, MappingType.None); + + _pageProtections.Remove(va, size); } - public void Reprotect(ulong va, ulong size, MemoryPermission protection) + public void ReprotectAligned(ulong va, ulong size, MemoryPermission protection) { Debug.Assert(va >= Address); Debug.Assert(va + size <= EndAddress); @@ -282,6 +284,19 @@ namespace Ryujinx.Cpu.Jit } } + public void Reprotect( + ulong va, + ulong size, + MemoryPermission protection, + AddressSpacePartitionAllocator asAllocator, + AddressSpacePartitioned addressSpace, + Action updatePtCallback) + { + ulong endVa = va + size; + + _pageProtections.Reprotect(asAllocator, addressSpace, this, va, endVa, protection, updatePtCallback); + } + public IntPtr GetPointer(ulong va, ulong size) { Debug.Assert(va >= Address); @@ -315,6 +330,8 @@ namespace Ryujinx.Cpu.Jit updatePtCallback(EndAddress - _hostPageSize, _baseMemory.GetPointer(Size, _hostPageSize), _hostPageSize); _hasBridgeAtEnd = true; + + _pageProtections.UpdateMappings(partitionAfter, EndAddress, GuestPageSize); } else { @@ -326,6 +343,8 @@ namespace Ryujinx.Cpu.Jit } _hasBridgeAtEnd = false; + + _pageProtections.Remove(EndAddress, GuestPageSize); } _cachedFirstPagePa = firstPagePa; @@ -346,6 +365,8 @@ namespace Ryujinx.Cpu.Jit _cachedLastPagePa = ulong.MaxValue; _hasBridgeAtEnd = false; + + _pageProtections.Remove(EndAddress, GuestPageSize); } private (MemoryBlock, ulong) GetFirstPageMemoryAndOffset() @@ -392,6 +413,27 @@ namespace Ryujinx.Cpu.Jit return (_backingMemory, _lastPagePa.Value & ~(_hostPageSize - 1)); } + public PrivateRange GetPrivateAllocation(ulong va) + { + _treeLock.EnterReadLock(); + + try + { + PrivateMapping map = _privateTree.GetNode(new PrivateMapping(va, 1UL, default)); + + if (map != null && map.PrivateAllocation.IsValid) + { + return new(map.PrivateAllocation.Memory, map.PrivateAllocation.Offset + (va - map.Address), map.Size - (va - map.Address)); + } + } + finally + { + _treeLock.ExitReadLock(); + } + + return PrivateRange.Empty; + } + private void Update(ulong va, ulong pa, ulong size, MappingType type) { _treeLock.EnterWriteLock(); @@ -654,7 +696,8 @@ namespace Ryujinx.Cpu.Jit { GC.SuppressFinalize(this); - _privateMemoryAllocator?.Dispose(); + _privateMemoryAllocator.Dispose(); + _pageProtections.Dispose(); _baseMemory.Dispose(); } } diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs index 561b92af4..e897d05da 100644 --- a/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs +++ b/src/Ryujinx.Cpu/Jit/AddressSpacePartitionAllocator.cs @@ -2,7 +2,6 @@ using Ryujinx.Common.Collections; using Ryujinx.Memory; using Ryujinx.Memory.Tracking; using System; -using System.Threading; namespace Ryujinx.Cpu.Jit { @@ -167,6 +166,14 @@ namespace Ryujinx.Cpu.Jit return allocation; } + public AddressSpacePartitionAllocation AllocatePage(ulong va, ulong size) + { + AddressSpacePartitionAllocation allocation = new(this, Allocate(size, MemoryBlock.GetPageSize(), CreateBlock)); + allocation.RegisterMapping(va, va + size, 0); + + return allocation; + } + private Block CreateBlock(MemoryBlock memory, ulong size) { return new Block(_tracking, memory, size, _lock); diff --git a/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs b/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs index ad082c96f..c7ff755c6 100644 --- a/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs +++ b/src/Ryujinx.Cpu/Jit/AddressSpacePartitioned.cs @@ -9,6 +9,8 @@ namespace Ryujinx.Cpu.Jit { class AddressSpacePartitioned : IDisposable { + public static readonly bool Use4KBProtection = false; + private const int PartitionBits = 25; private const ulong PartitionSize = 1UL << PartitionBits; @@ -89,27 +91,65 @@ namespace Ryujinx.Cpu.Jit } } - public void Reprotect(ulong va, ulong size, MemoryPermission protection, MemoryTracking tracking) + public void Reprotect(ulong va, ulong size, MemoryPermission protection) { ulong endVa = va + size; - while (va < endVa) + if (Use4KBProtection) { - AddressSpacePartition partition = FindPartition(va); - - if (partition == null) + lock (_partitions) { - va += PartitionSize - (va & (PartitionSize - 1)); + while (va < endVa) + { + AddressSpacePartition partition = FindPartition(va); - continue; + if (partition == null) + { + va += PartitionSize - (va & (PartitionSize - 1)); + + continue; + } + + (ulong clampedVa, ulong clampedEndVa) = ClampRange(partition, va, endVa); + + partition.Reprotect(clampedVa, clampedEndVa - clampedVa, protection, _asAllocator, this, _updatePtCallback); + + va += clampedEndVa - clampedVa; + } } - - (ulong clampedVa, ulong clampedEndVa) = ClampRange(partition, va, endVa); - - partition.Reprotect(clampedVa, clampedEndVa - clampedVa, protection); - - va += clampedEndVa - clampedVa; } + else + { + while (va < endVa) + { + AddressSpacePartition partition = FindPartition(va); + + if (partition == null) + { + va += PartitionSize - (va & (PartitionSize - 1)); + + continue; + } + + (ulong clampedVa, ulong clampedEndVa) = ClampRange(partition, va, endVa); + + partition.ReprotectAligned(clampedVa, clampedEndVa - clampedVa, protection); + + va += clampedEndVa - clampedVa; + } + } + } + + public PrivateRange GetPrivateAllocation(ulong va) + { + AddressSpacePartition partition = FindPartition(va); + + if (partition == null) + { + return PrivateRange.Empty; + } + + return partition.GetPrivateAllocation(va); } public PrivateRange GetFirstPrivateAllocation(ulong va, ulong size, out ulong nextVa) @@ -226,7 +266,7 @@ namespace Ryujinx.Cpu.Jit private int FindPartitionIndexLocked(ulong va) { int left = 0; - int middle = 0; + int middle; int right = _partitions.Count - 1; while (left <= right) diff --git a/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs b/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs index 83b3e4094..03befa088 100644 --- a/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs +++ b/src/Ryujinx.Cpu/Jit/MemoryManagerHostTracked.cs @@ -71,7 +71,7 @@ namespace Ryujinx.Cpu.Jit /// Optional function to handle invalid memory accesses public MemoryManagerHostTracked(MemoryBlock backingMemory, ulong addressSpaceSize, InvalidAccessHandler invalidAccessHandler) { - Tracking = new MemoryTracking(this, (int)MemoryBlock.GetPageSize(), invalidAccessHandler); + Tracking = new MemoryTracking(this, AddressSpacePartitioned.Use4KBProtection ? PageSize : (int)MemoryBlock.GetPageSize(), invalidAccessHandler); _backingMemory = backingMemory; _pageTable = new PageTable(); @@ -990,7 +990,7 @@ namespace Ryujinx.Cpu.Jit _ => MemoryPermission.None, }; - _addressSpace.Reprotect(va, size, protection, Tracking); + _addressSpace.Reprotect(va, size, protection); } ///