From 749f76e6fe35debb26288328764421af717af955 Mon Sep 17 00:00:00 2001
From: bunnei <bunneidev@gmail.com>
Date: Wed, 2 Mar 2022 17:59:54 -0800
Subject: [PATCH] hle: kernel: KPageTable: Improve implementations of
 MapCodeMemory and UnmapCodeMemory.

- This makes these functions more accurate to the real HOS implementations.
- Fixes memory access issues in Super Smash Bros. Ultimate that occur when un/mapping NROs.
---
 src/core/hle/kernel/k_page_table.cpp | 161 +++++++++++++++++++--------
 src/core/hle/kernel/k_page_table.h   |   4 +-
 2 files changed, 117 insertions(+), 48 deletions(-)

diff --git a/src/core/hle/kernel/k_page_table.cpp b/src/core/hle/kernel/k_page_table.cpp
index dfea0b6e26..0602de1f76 100644
--- a/src/core/hle/kernel/k_page_table.cpp
+++ b/src/core/hle/kernel/k_page_table.cpp
@@ -285,72 +285,141 @@ ResultCode KPageTable::MapProcessCode(VAddr addr, std::size_t num_pages, KMemory
     return ResultSuccess;
 }
 
-ResultCode KPageTable::MapCodeMemory(VAddr dst_addr, VAddr src_addr, std::size_t size) {
+ResultCode KPageTable::MapCodeMemory(VAddr dst_address, VAddr src_address, std::size_t size) {
+    // Validate the mapping request.
+    R_UNLESS(this->CanContain(dst_address, size, KMemoryState::AliasCode),
+             ResultInvalidMemoryRegion);
+
+    // Lock the table.
     KScopedLightLock lk(general_lock);
 
-    const std::size_t num_pages{size / PageSize};
+    // Verify that the source memory is normal heap.
+    KMemoryState src_state{};
+    KMemoryPermission src_perm{};
+    std::size_t num_src_allocator_blocks{};
+    R_TRY(this->CheckMemoryState(&src_state, &src_perm, nullptr, &num_src_allocator_blocks,
+                                 src_address, size, KMemoryState::All, KMemoryState::Normal,
+                                 KMemoryPermission::All, KMemoryPermission::UserReadWrite,
+                                 KMemoryAttribute::All, KMemoryAttribute::None));
 
-    KMemoryState state{};
-    KMemoryPermission perm{};
-    CASCADE_CODE(CheckMemoryState(&state, &perm, nullptr, nullptr, src_addr, size,
-                                  KMemoryState::All, KMemoryState::Normal, KMemoryPermission::All,
-                                  KMemoryPermission::UserReadWrite, KMemoryAttribute::Mask,
-                                  KMemoryAttribute::None, KMemoryAttribute::IpcAndDeviceMapped));
-
-    if (IsRegionMapped(dst_addr, size)) {
-        return ResultInvalidCurrentMemory;
-    }
-
-    KPageLinkedList page_linked_list;
-    AddRegionToPages(src_addr, num_pages, page_linked_list);
+    // Verify that the destination memory is unmapped.
+    std::size_t num_dst_allocator_blocks{};
+    R_TRY(this->CheckMemoryState(&num_dst_allocator_blocks, dst_address, size, KMemoryState::All,
+                                 KMemoryState::Free, KMemoryPermission::None,
+                                 KMemoryPermission::None, KMemoryAttribute::None,
+                                 KMemoryAttribute::None));
 
+    // Map the code memory.
     {
-        auto block_guard = detail::ScopeExit(
-            [&] { Operate(src_addr, num_pages, perm, OperationType::ChangePermissions); });
+        // Determine the number of pages being operated on.
+        const std::size_t num_pages = size / PageSize;
 
-        CASCADE_CODE(Operate(src_addr, num_pages, KMemoryPermission::None,
-                             OperationType::ChangePermissions));
-        CASCADE_CODE(MapPages(dst_addr, page_linked_list, KMemoryPermission::None));
+        // Create page groups for the memory being mapped.
+        KPageLinkedList pg;
+        AddRegionToPages(src_address, num_pages, pg);
 
-        block_guard.Cancel();
+        // Reprotect the source as kernel-read/not mapped.
+        const auto new_perm = static_cast<KMemoryPermission>(KMemoryPermission::KernelRead |
+                                                             KMemoryPermission::NotMapped);
+        R_TRY(Operate(src_address, num_pages, new_perm, OperationType::ChangePermissions));
+
+        // Ensure that we unprotect the source pages on failure.
+        auto unprot_guard = SCOPE_GUARD({
+            ASSERT(this->Operate(src_address, num_pages, src_perm, OperationType::ChangePermissions)
+                       .IsSuccess());
+        });
+
+        // Map the alias pages.
+        R_TRY(MapPages(dst_address, pg, new_perm));
+
+        // We successfully mapped the alias pages, so we don't need to unprotect the src pages on
+        // failure.
+        unprot_guard.Cancel();
+
+        // Apply the memory block updates.
+        block_manager->Update(src_address, num_pages, src_state, new_perm,
+                              KMemoryAttribute::Locked);
+        block_manager->Update(dst_address, num_pages, KMemoryState::AliasCode, new_perm,
+                              KMemoryAttribute::None);
     }
 
-    block_manager->Update(src_addr, num_pages, state, KMemoryPermission::None,
-                          KMemoryAttribute::Locked);
-    block_manager->Update(dst_addr, num_pages, KMemoryState::AliasCode);
-
     return ResultSuccess;
 }
 
-ResultCode KPageTable::UnmapCodeMemory(VAddr dst_addr, VAddr src_addr, std::size_t size) {
+ResultCode KPageTable::UnmapCodeMemory(VAddr dst_address, VAddr src_address, std::size_t size) {
+    // Validate the mapping request.
+    R_UNLESS(this->CanContain(dst_address, size, KMemoryState::AliasCode),
+             ResultInvalidMemoryRegion);
+
+    // Lock the table.
     KScopedLightLock lk(general_lock);
 
-    if (!size) {
-        return ResultSuccess;
+    // Verify that the source memory is locked normal heap.
+    std::size_t num_src_allocator_blocks{};
+    R_TRY(this->CheckMemoryState(std::addressof(num_src_allocator_blocks), src_address, size,
+                                 KMemoryState::All, KMemoryState::Normal, KMemoryPermission::None,
+                                 KMemoryPermission::None, KMemoryAttribute::All,
+                                 KMemoryAttribute::Locked));
+
+    // Verify that the destination memory is aliasable code.
+    std::size_t num_dst_allocator_blocks{};
+    R_TRY(this->CheckMemoryStateContiguous(
+        std::addressof(num_dst_allocator_blocks), dst_address, size, KMemoryState::FlagCanCodeAlias,
+        KMemoryState::FlagCanCodeAlias, KMemoryPermission::None, KMemoryPermission::None,
+        KMemoryAttribute::All, KMemoryAttribute::None));
+
+    // Determine whether any pages being unmapped are code.
+    bool any_code_pages = false;
+    {
+        KMemoryBlockManager::const_iterator it = block_manager->FindIterator(dst_address);
+        while (true) {
+            // Get the memory info.
+            const KMemoryInfo info = it->GetMemoryInfo();
+
+            // Check if the memory has code flag.
+            if ((info.GetState() & KMemoryState::FlagCode) != KMemoryState::None) {
+                any_code_pages = true;
+                break;
+            }
+
+            // Check if we're done.
+            if (dst_address + size - 1 <= info.GetLastAddress()) {
+                break;
+            }
+
+            // Advance.
+            ++it;
+        }
     }
 
-    const std::size_t num_pages{size / PageSize};
+    // Ensure that we maintain the instruction cache.
+    bool reprotected_pages = false;
+    SCOPE_EXIT({
+        if (reprotected_pages && any_code_pages) {
+            system.InvalidateCpuInstructionCacheRange(dst_address, size);
+        }
+    });
 
-    CASCADE_CODE(CheckMemoryState(nullptr, nullptr, nullptr, nullptr, src_addr, size,
-                                  KMemoryState::All, KMemoryState::Normal, KMemoryPermission::None,
-                                  KMemoryPermission::None, KMemoryAttribute::Mask,
-                                  KMemoryAttribute::Locked, KMemoryAttribute::IpcAndDeviceMapped));
+    // Unmap.
+    {
+        // Determine the number of pages being operated on.
+        const std::size_t num_pages = size / PageSize;
 
-    KMemoryState state{};
-    CASCADE_CODE(CheckMemoryState(
-        &state, nullptr, nullptr, nullptr, dst_addr, PageSize, KMemoryState::FlagCanCodeAlias,
-        KMemoryState::FlagCanCodeAlias, KMemoryPermission::None, KMemoryPermission::None,
-        KMemoryAttribute::Mask, KMemoryAttribute::None, KMemoryAttribute::IpcAndDeviceMapped));
-    CASCADE_CODE(CheckMemoryState(dst_addr, size, KMemoryState::All, state, KMemoryPermission::None,
-                                  KMemoryPermission::None, KMemoryAttribute::Mask,
-                                  KMemoryAttribute::None));
-    CASCADE_CODE(Operate(dst_addr, num_pages, KMemoryPermission::None, OperationType::Unmap));
+        // Unmap the aliased copy of the pages.
+        R_TRY(Operate(dst_address, num_pages, KMemoryPermission::None, OperationType::Unmap));
 
-    block_manager->Update(dst_addr, num_pages, KMemoryState::Free);
-    block_manager->Update(src_addr, num_pages, KMemoryState::Normal,
-                          KMemoryPermission::UserReadWrite);
+        // Try to set the permissions for the source pages back to what they should be.
+        R_TRY(Operate(src_address, num_pages, KMemoryPermission::UserReadWrite,
+                      OperationType::ChangePermissions));
 
-    system.InvalidateCpuInstructionCacheRange(dst_addr, size);
+        // Apply the memory block updates.
+        block_manager->Update(dst_address, num_pages, KMemoryState::None);
+        block_manager->Update(src_address, num_pages, KMemoryState::Normal,
+                              KMemoryPermission::UserReadWrite);
+
+        // Note that we reprotected pages.
+        reprotected_pages = true;
+    }
 
     return ResultSuccess;
 }
diff --git a/src/core/hle/kernel/k_page_table.h b/src/core/hle/kernel/k_page_table.h
index 194177332a..aea1b8f631 100644
--- a/src/core/hle/kernel/k_page_table.h
+++ b/src/core/hle/kernel/k_page_table.h
@@ -36,8 +36,8 @@ public:
                                     KMemoryManager::Pool pool);
     ResultCode MapProcessCode(VAddr addr, std::size_t pages_count, KMemoryState state,
                               KMemoryPermission perm);
-    ResultCode MapCodeMemory(VAddr dst_addr, VAddr src_addr, std::size_t size);
-    ResultCode UnmapCodeMemory(VAddr dst_addr, VAddr src_addr, std::size_t size);
+    ResultCode MapCodeMemory(VAddr dst_address, VAddr src_address, std::size_t size);
+    ResultCode UnmapCodeMemory(VAddr dst_address, VAddr src_address, std::size_t size);
     ResultCode UnmapProcessMemory(VAddr dst_addr, std::size_t size, KPageTable& src_page_table,
                                   VAddr src_addr);
     ResultCode MapPhysicalMemory(VAddr addr, std::size_t size);