From b76f4aa3bd06a40a8c6962db13147b2d245b55b4 Mon Sep 17 00:00:00 2001 From: Kawe Mazidjatari <48657826+Mauler125@users.noreply.github.com> Date: Sun, 25 Jun 2023 10:29:42 +0200 Subject: [PATCH] CModule class improvements *Use unordered_map to get mpdule sections instead, as this is more performant than comparing strings. * Removed 'm_SectionName' field from ModuleSections_t, as the unordered map now keeps track of them. * Removed all extraneous module section copies. * Renamed 'GetImportedFunction' to 'GetImportedSymbol'. * Renamed 'GetExportedFunction' to 'GetExportedSymbol'. *Made a static version of 'GetImportedSymbol' and 'GetExportedSymbol', so it could be used on raw module base addresses. *Created inlines for getting the DOS and NT headers. *Improved formatting so the code could be read more easily on a vertical monitor. --- r5dev/codecs/Miles/miles_impl.h | 4 +- r5dev/codecs/bink/bink_impl.h | 6 +- r5dev/launcher/launcher.h | 2 +- r5dev/pluginsdk/pluginsdk.cpp | 2 +- r5dev/pluginsystem/pluginsystem.cpp | 14 +- r5dev/public/tier0/module.h | 65 ++--- r5dev/tier0/crashhandler.cpp | 2 +- r5dev/tier0/module.cpp | 369 +++++++++++++++++----------- 8 files changed, 276 insertions(+), 188 deletions(-) diff --git a/r5dev/codecs/Miles/miles_impl.h b/r5dev/codecs/Miles/miles_impl.h index b5889c73..6a645614 100644 --- a/r5dev/codecs/Miles/miles_impl.h +++ b/r5dev/codecs/Miles/miles_impl.h @@ -34,10 +34,10 @@ class MilesCore : public IDetour #endif // !(GAMEDLL_S0) || !(GAMEDLL_S1) || !(GAMEDLL_S2) v_Miles_Initialize = p_Miles_Initialize.RCast(); - p_MilesQueueEventRun = g_RadAudioSystemDll.GetExportedFunction("MilesQueueEventRun"); + p_MilesQueueEventRun = g_RadAudioSystemDll.GetExportedSymbol("MilesQueueEventRun"); v_MilesQueueEventRun = p_MilesQueueEventRun.RCast(); - p_MilesBankPatch = g_RadAudioSystemDll.GetExportedFunction("MilesBankPatch"); + p_MilesBankPatch = g_RadAudioSystemDll.GetExportedSymbol("MilesBankPatch"); v_MilesBankPatch = p_MilesBankPatch.RCast(); } diff --git a/r5dev/codecs/bink/bink_impl.h b/r5dev/codecs/bink/bink_impl.h index ceef7a27..dddad5a0 100644 --- a/r5dev/codecs/bink/bink_impl.h +++ b/r5dev/codecs/bink/bink_impl.h @@ -20,11 +20,11 @@ class BinkCore : public IDetour } virtual void GetFun(void) const { - p_BinkOpen = g_RadVideoToolsDll.GetExportedFunction("BinkOpen"); + p_BinkOpen = g_RadVideoToolsDll.GetExportedSymbol("BinkOpen"); v_BinkOpen = p_BinkOpen.RCast(); - p_BinkClose = g_RadVideoToolsDll.GetExportedFunction("BinkClose"); + p_BinkClose = g_RadVideoToolsDll.GetExportedSymbol("BinkClose"); v_BinkClose = p_BinkClose.RCast(); - p_BinkGetError = g_RadVideoToolsDll.GetExportedFunction("BinkGetError"); + p_BinkGetError = g_RadVideoToolsDll.GetExportedSymbol("BinkGetError"); v_BinkGetError = p_BinkGetError.RCast(); } virtual void GetVar(void) const { } diff --git a/r5dev/launcher/launcher.h b/r5dev/launcher/launcher.h index 8426fafb..eab08ae2 100644 --- a/r5dev/launcher/launcher.h +++ b/r5dev/launcher/launcher.h @@ -37,7 +37,7 @@ class VLauncher : public IDetour p_WinMain = g_GameDll.FindPatternSIMD("48 89 5C 24 ?? 48 89 6C 24 ?? 48 89 74 24 ?? 57 48 83 EC 20 41 8B D9 49 8B F8"); v_WinMain = p_WinMain.RCast(); - p_LauncherMain = g_GameDll.GetExportedFunction("LauncherMain"); + p_LauncherMain = g_GameDll.GetExportedSymbol("LauncherMain"); v_LauncherMain = p_LauncherMain.RCast(); p_TopLevelExceptionFilter = g_GameDll.FindPatternSIMD("40 53 48 83 EC 20 48 8B 05 ?? ?? ?? ?? 48 8B D9 48 85 C0 74 06"); diff --git a/r5dev/pluginsdk/pluginsdk.cpp b/r5dev/pluginsdk/pluginsdk.cpp index d5bed641..9889e888 100644 --- a/r5dev/pluginsdk/pluginsdk.cpp +++ b/r5dev/pluginsdk/pluginsdk.cpp @@ -36,7 +36,7 @@ CPluginSDK::~CPluginSDK() //--------------------------------------------------------------------------------- bool CPluginSDK::InitSDK() { - auto getFactorySystemFn = m_SDKModule.GetExportedFunction("GetFactorySystem").RCast(); + auto getFactorySystemFn = m_SDKModule.GetExportedSymbol("GetFactorySystem").RCast(); Assert(getFactorySystemFn, "Could not find GetFactorySystem export from gamesdk.dll"); if (!getFactorySystemFn) diff --git a/r5dev/pluginsystem/pluginsystem.cpp b/r5dev/pluginsystem/pluginsystem.cpp index ecf72224..fb32f674 100644 --- a/r5dev/pluginsystem/pluginsystem.cpp +++ b/r5dev/pluginsystem/pluginsystem.cpp @@ -56,8 +56,12 @@ bool CPluginSystem::LoadPluginInstance(PluginInstance_t& pluginInst) CModule pluginModule = CModule(pluginInst.m_svPluginName.c_str()); - // Pass selfModule here on load function, we have to do this because local listen/dedi/client dll's are called different, refer to a comment on the pluginsdk. - auto onLoadFn = pluginModule.GetExportedFunction("PluginInstance_OnLoad").RCast(); + // Pass selfModule here on load function, we have to do + // this because local listen/dedi/client dll's are called + // different, refer to a comment on the pluginsdk. + PluginInstance_t::OnLoad onLoadFn = pluginModule.GetExportedSymbol( + "PluginInstance_OnLoad").RCast(); + Assert(onLoadFn); if (!onLoadFn(pluginInst.m_svPluginName.c_str(), g_SDKDll.GetModuleName().c_str())) @@ -67,7 +71,6 @@ bool CPluginSystem::LoadPluginInstance(PluginInstance_t& pluginInst) } pluginInst.m_hModule = pluginModule; - return pluginInst.m_bIsLoaded = true; } @@ -81,7 +84,10 @@ bool CPluginSystem::UnloadPluginInstance(PluginInstance_t& pluginInst) if (!pluginInst.m_bIsLoaded) return false; - auto onUnloadFn = pluginInst.m_hModule.GetExportedFunction("PluginInstance_OnUnload").RCast(); + PluginInstance_t::OnUnload onUnloadFn = + pluginInst.m_hModule.GetExportedSymbol( + "PluginInstance_OnUnload").RCast(); + Assert(onUnloadFn); if (onUnloadFn) diff --git a/r5dev/public/tier0/module.h b/r5dev/public/tier0/module.h index 2e4cdbf0..a4b65f06 100644 --- a/r5dev/public/tier0/module.h +++ b/r5dev/public/tier0/module.h @@ -1,5 +1,6 @@ #ifndef MODULE_H #define MODULE_H +#include "windows/tebpeb64.h" class CModule { @@ -7,58 +8,66 @@ public: struct ModuleSections_t { ModuleSections_t(void) = default; - ModuleSections_t(const char* sectionName, uintptr_t pSectionBase, size_t nSectionSize) : - m_SectionName(sectionName), m_pSectionBase(pSectionBase), m_nSectionSize(nSectionSize) {} + ModuleSections_t(QWORD pSectionBase, size_t nSectionSize) : + m_pSectionBase(pSectionBase), m_nSectionSize(nSectionSize) {} inline bool IsSectionValid(void) const { return m_nSectionSize != 0; } - string m_SectionName; // Name of section. - uintptr_t m_pSectionBase; // Start address of section. - size_t m_nSectionSize; // Size of section. + QWORD m_pSectionBase; // Start address of section. + size_t m_nSectionSize; // Size of section. }; + typedef unordered_map ModuleSectionsMap_t; CModule(void) = default; - CModule(const char* szModuleName); - CModule(const char* szModuleName, const uintptr_t nModuleBase); + CModule(const char* szModuleName, const bool bDynamicInit = true); + CModule(const char* szModuleName, const QWORD nModuleBase, const bool bDynamicInit = true); - void Init(); + void Init(const bool bInitSections); void LoadSections(); -#ifndef PLUGINSDK + CMemory FindPatternSIMD(const char* szPattern, const ModuleSections_t* moduleSection = nullptr) const; CMemory FindString(const char* szString, const ptrdiff_t occurrence = 1, bool nullTerminator = false) const; CMemory FindStringReadOnly(const char* szString, bool nullTerminator) const; CMemory FindFreeDataPage(const size_t nSize) const; CMemory GetVirtualMethodTable(const char* szTableName, const size_t nRefIndex = 0); -#endif // !PLUGINSDK - CMemory GetImportedFunction(const char* szModuleName, const char* szFunctionName, const bool bGetFunctionReference) const; - CMemory GetExportedFunction(const char* szFunctionName) const; - ModuleSections_t GetSectionByName(const char* szSectionName) const; - inline const vector& GetSections() const { return m_ModuleSections; } - inline uintptr_t GetModuleBase(void) const { return m_pModuleBase; } + static CMemory GetImportedSymbol(QWORD pModuleBase, const char* szModuleName, const char* szSymbolName, const bool bGetSymbolReference); + inline CMemory GetImportedSymbol(const char* szModuleName, const char* szSymbolName, const bool bGetSymbolReference) const + { return GetImportedSymbol(m_pModuleBase, szModuleName, szSymbolName, bGetSymbolReference); } + + static CMemory GetExportedSymbol(QWORD pModuleBase, const char* szSymbolName); + inline CMemory GetExportedSymbol(const char* szSymbolName) const + { return GetExportedSymbol(m_pModuleBase, szSymbolName); } + + inline const CModule::ModuleSections_t& GetSectionByName(const char* szSectionName) const + { return m_ModuleSections.at(szSectionName); } + + inline const ModuleSectionsMap_t& GetSections() const { return m_ModuleSections; } + inline QWORD GetModuleBase(void) const { return m_pModuleBase; } inline DWORD GetModuleSize(void) const { return m_nModuleSize; } inline const string& GetModuleName(void) const { return m_ModuleName; } - inline uintptr_t GetRVA(const uintptr_t nAddress) const { return (nAddress - GetModuleBase()); } + inline QWORD GetRVA(const QWORD nAddress) const { return (nAddress - GetModuleBase()); } + + + inline IMAGE_DOS_HEADER* GetDOSHeader() const { return GetDOSHeader(m_pModuleBase); } + inline static IMAGE_DOS_HEADER* GetDOSHeader(QWORD pModuleBase) + { return reinterpret_cast(pModuleBase); } + + inline IMAGE_NT_HEADERS64* GetNTHeaders() const { return GetNTHeaders(m_pModuleBase); } + inline static IMAGE_NT_HEADERS64* GetNTHeaders(QWORD pModuleBase) + { return reinterpret_cast(pModuleBase + GetDOSHeader(pModuleBase)->e_lfanew); } void UnlinkFromPEB(void) const; - IMAGE_NT_HEADERS64* m_pNTHeaders; - IMAGE_DOS_HEADER* m_pDOSHeader; - - ModuleSections_t m_ExecutableCode; - ModuleSections_t m_ExceptionTable; - ModuleSections_t m_RunTimeData; - ModuleSections_t m_ReadOnlyData; - private: CMemory FindPatternSIMD(const uint8_t* pPattern, const char* szMask, const ModuleSections_t* moduleSection = nullptr, const size_t nOccurrence = 0) const; - string m_ModuleName; - uintptr_t m_pModuleBase; - DWORD m_nModuleSize; - vector m_ModuleSections; + QWORD m_pModuleBase; + DWORD m_nModuleSize; + string m_ModuleName; + ModuleSectionsMap_t m_ModuleSections; }; #endif // MODULE_H \ No newline at end of file diff --git a/r5dev/tier0/crashhandler.cpp b/r5dev/tier0/crashhandler.cpp index 00170cae..225be733 100644 --- a/r5dev/tier0/crashhandler.cpp +++ b/r5dev/tier0/crashhandler.cpp @@ -205,7 +205,7 @@ void CCrashHandler::FormatSystemInfo() //----------------------------------------------------------------------------- void CCrashHandler::FormatBuildInfo() { - m_svBuffer.append(Format("build_id: %u\n", g_SDKDll.m_pNTHeaders->FileHeader.TimeDateStamp)); + m_svBuffer.append(Format("build_id: %u\n", g_SDKDll.GetNTHeaders()->FileHeader.TimeDateStamp)); } //----------------------------------------------------------------------------- diff --git a/r5dev/tier0/module.cpp b/r5dev/tier0/module.cpp index cbcba990..e2a3244c 100644 --- a/r5dev/tier0/module.cpp +++ b/r5dev/tier0/module.cpp @@ -5,19 +5,17 @@ //===========================================================================// #include "tier0/memaddr.h" #include "tier0/sigcache.h" -#include "windows/tebpeb64.h" //----------------------------------------------------------------------------- // Purpose: constructor // Input : *szModuleName - +// bDynamicInit - set this to false if there is no memory allocator //----------------------------------------------------------------------------- -CModule::CModule(const char* szModuleName) +CModule::CModule(const char* szModuleName, const bool bDynamicInit) : m_ModuleName(szModuleName) { - m_pModuleBase = reinterpret_cast(GetModuleHandleA(szModuleName)); - - Init(); - LoadSections(); + m_pModuleBase = reinterpret_cast(GetModuleHandleA(szModuleName)); + Init(bDynamicInit); } //----------------------------------------------------------------------------- @@ -25,30 +23,24 @@ CModule::CModule(const char* szModuleName) // Input : *szModuleName - // nModuleBase - //----------------------------------------------------------------------------- -CModule::CModule(const char* szModuleName, const uintptr_t nModuleBase) +CModule::CModule(const char* szModuleName, const QWORD nModuleBase, + const bool bDynamicInit) : m_ModuleName(szModuleName) , m_pModuleBase(nModuleBase) { - Init(); - LoadSections(); + Init(bDynamicInit); } //----------------------------------------------------------------------------- // Purpose: initializes module descriptors //----------------------------------------------------------------------------- -void CModule::Init() +void CModule::Init(const bool bInitSections) { - m_pDOSHeader = reinterpret_cast(m_pModuleBase); - m_pNTHeaders = reinterpret_cast(m_pModuleBase + m_pDOSHeader->e_lfanew); - m_nModuleSize = static_cast(m_pNTHeaders->OptionalHeader.SizeOfImage); + m_nModuleSize = static_cast(GetNTHeaders()->OptionalHeader.SizeOfImage); - const IMAGE_SECTION_HEADER* hSection = IMAGE_FIRST_SECTION(m_pNTHeaders); // Get first image section. - - for (WORD i = 0; i < m_pNTHeaders->FileHeader.NumberOfSections; i++) // Loop through the sections. + if (bInitSections) { - const IMAGE_SECTION_HEADER& hCurrentSection = hSection[i]; // Get current section. - m_ModuleSections.push_back(ModuleSections_t(reinterpret_cast(hCurrentSection.Name), - static_cast(m_pModuleBase + hCurrentSection.VirtualAddress), hCurrentSection.SizeOfRawData)); // Push back a struct with the section data. + LoadSections(); } } @@ -57,13 +49,19 @@ void CModule::Init() //----------------------------------------------------------------------------- void CModule::LoadSections() { - m_ExecutableCode = GetSectionByName(".text"); - m_ExceptionTable = GetSectionByName(".pdata"); - m_RunTimeData = GetSectionByName(".data"); - m_ReadOnlyData = GetSectionByName(".rdata"); + const IMAGE_NT_HEADERS64* pNTHeaders = GetNTHeaders(); + const IMAGE_SECTION_HEADER* hSection = IMAGE_FIRST_SECTION(pNTHeaders); + + for (WORD i = 0; i < pNTHeaders->FileHeader.NumberOfSections; i++) + { + // Capture each section. + const IMAGE_SECTION_HEADER& hCurrentSection = hSection[i]; + m_ModuleSections.emplace(reinterpret_cast(hCurrentSection.Name), + ModuleSections_t(static_cast + (m_pModuleBase + hCurrentSection.VirtualAddress), hCurrentSection.SizeOfRawData)); + } } -#ifndef PLUGINSDK //----------------------------------------------------------------------------- // Purpose: find array of bytes in process memory using SIMD instructions // Input : *pPattern - @@ -75,13 +73,17 @@ void CModule::LoadSections() CMemory CModule::FindPatternSIMD(const uint8_t* pPattern, const char* szMask, const ModuleSections_t* moduleSection, const size_t nOccurrence) const { - if (!m_ExecutableCode.IsSectionValid()) - return CMemory(); + const ModuleSections_t& executableCode = GetSectionByName(".text"); + + if (!executableCode.IsSectionValid()) + return nullptr; const bool bSectionValid = moduleSection ? moduleSection->IsSectionValid() : false; - const uintptr_t nBase = bSectionValid ? moduleSection->m_pSectionBase : m_ExecutableCode.m_pSectionBase; - const uintptr_t nSize = bSectionValid ? moduleSection->m_nSectionSize : m_ExecutableCode.m_nSectionSize; + const QWORD nBase = bSectionValid ? + moduleSection->m_pSectionBase : executableCode.m_pSectionBase; + const QWORD nSize = bSectionValid ? + moduleSection->m_nSectionSize : executableCode.m_nSectionSize; const size_t nMaskLen = strlen(szMask); const uint8_t* pData = reinterpret_cast(nBase); @@ -141,7 +143,7 @@ CMemory CModule::FindPatternSIMD(const uint8_t* pPattern, const char* szMask, } }cont:; } - return CMemory(); + return nullptr; } //----------------------------------------------------------------------------- @@ -150,7 +152,8 @@ CMemory CModule::FindPatternSIMD(const uint8_t* pPattern, const char* szMask, // *moduleSection - // Output : CMemory //----------------------------------------------------------------------------- -CMemory CModule::FindPatternSIMD(const char* szPattern, const ModuleSections_t* moduleSection) const +CMemory CModule::FindPatternSIMD(const char* szPattern, + const ModuleSections_t* moduleSection) const { uint64_t nRVA; if (g_SigCache.FindEntry(szPattern, nRVA)) @@ -158,8 +161,11 @@ CMemory CModule::FindPatternSIMD(const char* szPattern, const ModuleSections_t* return CMemory(nRVA + GetModuleBase()); } - const pair, string> patternInfo = PatternToMaskedBytes(szPattern); - const CMemory memory = FindPatternSIMD(patternInfo.first.data(), patternInfo.second.c_str(), moduleSection); + const pair, string> + patternInfo = PatternToMaskedBytes(szPattern); + + const CMemory memory = FindPatternSIMD(patternInfo.first.data(), + patternInfo.second.c_str(), moduleSection); g_SigCache.AddEntry(szPattern, GetRVA(memory.GetPtr())); return memory; @@ -171,10 +177,13 @@ CMemory CModule::FindPatternSIMD(const char* szPattern, const ModuleSections_t* // bNullTerminator - // Output : CMemory //----------------------------------------------------------------------------- -CMemory CModule::FindString(const char* szString, const ptrdiff_t nOccurrence, bool bNullTerminator) const +CMemory CModule::FindString(const char* szString, const ptrdiff_t nOccurrence, + bool bNullTerminator) const { - if (!m_ExecutableCode.IsSectionValid()) - return CMemory(); + const ModuleSections_t& executableCode = GetSectionByName(".text"); + + if (!executableCode.IsSectionValid()) + return nullptr; uint64_t nRVA; string svPackedString = szString + std::to_string(nOccurrence); @@ -184,25 +193,33 @@ CMemory CModule::FindString(const char* szString, const ptrdiff_t nOccurrence, b return CMemory(nRVA + GetModuleBase()); } - const CMemory stringAddress = FindStringReadOnly(szString, bNullTerminator); // Get Address for the string in the .rdata section. + // Get Address for the string in the .rdata section. + const CMemory stringAddress = FindStringReadOnly(szString, bNullTerminator); if (!stringAddress) - return CMemory(); + return nullptr; + // Get the start of the .text section. + uint8_t* pTextStart = reinterpret_cast(executableCode.m_pSectionBase); uint8_t* pLatestOccurrence = nullptr; - uint8_t* pTextStart = reinterpret_cast(m_ExecutableCode.m_pSectionBase); // Get the start of the .text section. ptrdiff_t dOccurrencesFound = 0; CMemory resultAddress; - for (size_t i = 0ull; i < m_ExecutableCode.m_nSectionSize - 0x5; i++) + for (size_t i = 0ull; i < executableCode.m_nSectionSize - 0x5; i++) { byte byte = pTextStart[i]; if (byte == LEA) { - const CMemory skipOpCode = CMemory(reinterpret_cast(&pTextStart[i])).OffsetSelf(0x2); // Skip next 2 opcodes, those being the instruction and the register. - const int32_t relativeAddress = skipOpCode.GetValue(); // Get 4-byte long string relative Address - const uintptr_t nextInstruction = skipOpCode.Offset(0x4).GetPtr(); // Get location of next instruction. - const CMemory potentialLocation = CMemory(nextInstruction + relativeAddress); // Get potential string location. + // Skip next 2 opcodes, those being the instruction and the register. + const CMemory skipOpCode = CMemory(reinterpret_cast< + QWORD>(&pTextStart[i])).OffsetSelf(0x2); + + // Get 4-byte long string relative Address + const int32_t relativeAddress = skipOpCode.GetValue(); + // Get location of next instruction. + const QWORD nextInstruction = skipOpCode.Offset(0x4).GetPtr(); + // Get potential string location. + const CMemory potentialLocation = CMemory(nextInstruction + relativeAddress); if (potentialLocation == stringAddress) { @@ -234,8 +251,10 @@ CMemory CModule::FindString(const char* szString, const ptrdiff_t nOccurrence, b //----------------------------------------------------------------------------- CMemory CModule::FindStringReadOnly(const char* szString, bool bNullTerminator) const { - if (!m_ReadOnlyData.IsSectionValid()) - return CMemory(); + const ModuleSections_t& readOnlyData = GetSectionByName(".rdata"); + + if (!readOnlyData.IsSectionValid()) + return nullptr; uint64_t nRVA; if (g_SigCache.FindEntry(szString, nRVA)) @@ -243,17 +262,21 @@ CMemory CModule::FindStringReadOnly(const char* szString, bool bNullTerminator) return CMemory(nRVA + GetModuleBase()); } - const vector vBytes = StringToBytes(szString, bNullTerminator); // Convert our string to a byte array. - const pair bytesInfo = std::make_pair(vBytes.size(), vBytes.data()); // Get the size and data of our bytes. + // Convert our string to a byte array. + const vector vBytes = StringToBytes(szString, bNullTerminator); + const pair bytesInfo = std::make_pair< + size_t, const int*>(vBytes.size(), vBytes.data()); // Get the size and data of our bytes. - const uint8_t* pBase = reinterpret_cast(m_ReadOnlyData.m_pSectionBase); // Get start of .rdata section. + // Get start of .rdata section. + const uint8_t* pBase = reinterpret_cast(readOnlyData.m_pSectionBase); - for (size_t i = 0ull; i < m_ReadOnlyData.m_nSectionSize - bytesInfo.first; i++) + for (size_t i = 0ull; i < readOnlyData.m_nSectionSize - bytesInfo.first; i++) { bool bFound = true; - // If either the current byte equals to the byte in our pattern or our current byte in the pattern is a wildcard - // our if clause will be false. + // If either the current byte equals to the byte in + // our pattern or our current byte in the pattern is + // a wildcard our if clause will be false. for (size_t j = 0ull; j < bytesInfo.first; j++) { if (pBase[i + j] != bytesInfo.second[j] && bytesInfo.second[j] != -1) @@ -272,7 +295,7 @@ CMemory CModule::FindStringReadOnly(const char* szString, bool bNullTerminator) } } - return CMemory(); + return nullptr; } //----------------------------------------------------------------------------- @@ -288,29 +311,44 @@ CMemory CModule::FindFreeDataPage(const size_t nSize) const VirtualQuery(address, &membInfo, sizeof(membInfo)); - if (membInfo.AllocationBase && membInfo.BaseAddress && membInfo.State == MEM_COMMIT && !(membInfo.Protect & PAGE_GUARD) && membInfo.Protect != PAGE_NOACCESS) + if (membInfo.AllocationBase && membInfo.BaseAddress && + membInfo.State == MEM_COMMIT && + !(membInfo.Protect & PAGE_GUARD) && + membInfo.Protect != PAGE_NOACCESS) { - if ((membInfo.Protect & (PAGE_EXECUTE_READWRITE | PAGE_READWRITE)) && membInfo.RegionSize >= size) + if ((membInfo.Protect & + (PAGE_EXECUTE_READWRITE | PAGE_READWRITE)) && + membInfo.RegionSize >= size) { - return ((membInfo.Protect & (PAGE_EXECUTE_READWRITE | PAGE_READWRITE)) && membInfo.RegionSize >= size) ? true : false; + return ((membInfo.Protect & + (PAGE_EXECUTE_READWRITE | PAGE_READWRITE)) + && membInfo.RegionSize >= size) ? true : false; } } return false; }; - // This is very unstable, this doesn't check for the actual 'page' sizes. - // Also can be optimized to search per 'section'. - const uintptr_t endOfModule = m_pModuleBase + m_pNTHeaders->OptionalHeader.SizeOfImage - sizeof(uintptr_t); - for (uintptr_t currAddr = endOfModule; m_pModuleBase < currAddr; currAddr -= sizeof(uintptr_t)) + // This is very unstable, this doesn't check + // for the actual 'page' sizes. Also can be + // optimized to search per 'section'. + const QWORD endOfModule = m_pModuleBase + + GetNTHeaders()->OptionalHeader.SizeOfImage - sizeof(QWORD); + + for (QWORD currAddr = endOfModule; + m_pModuleBase < currAddr; currAddr -= sizeof(QWORD)) { - if (*reinterpret_cast(currAddr) == 0 && checkDataSection(reinterpret_cast(currAddr), nSize)) + if (*reinterpret_cast(currAddr) == 0 && + checkDataSection(reinterpret_cast(currAddr), nSize)) { bool bIsGoodPage = true; uint32_t nPageCount = 0; - for (; nPageCount < nSize && bIsGoodPage; nPageCount += sizeof(uintptr_t)) + for (; nPageCount < nSize && bIsGoodPage; + nPageCount += sizeof(QWORD)) { - const uintptr_t pageData = *reinterpret_cast(currAddr + nPageCount); + const QWORD pageData =*reinterpret_cast< + QWORD*>(currAddr + nPageCount); + if (pageData != 0) bIsGoodPage = false; } @@ -320,7 +358,7 @@ CMemory CModule::FindFreeDataPage(const size_t nSize) const } } - return CMemory(); + return nullptr; } //----------------------------------------------------------------------------- @@ -331,7 +369,9 @@ CMemory CModule::FindFreeDataPage(const size_t nSize) const //----------------------------------------------------------------------------- CMemory CModule::GetVirtualMethodTable(const char* szTableName, const size_t nRefIndex) { - uint64_t nRVA; // Packed together as we can have multiple VFTable searches, but with different ref indexes. + // Packed together as we can have multiple VFTable + // searches, but with different ref indices. + uint64_t nRVA; string svPackedTableName = szTableName + std::to_string(nRefIndex); if (g_SigCache.FindEntry(svPackedTableName.c_str(), nRVA)) @@ -339,177 +379,210 @@ CMemory CModule::GetVirtualMethodTable(const char* szTableName, const size_t nRe return CMemory(nRVA + GetModuleBase()); } - ModuleSections_t moduleSection(".data", m_RunTimeData.m_pSectionBase, m_RunTimeData.m_nSectionSize); + const ModuleSections_t& readOnlyData = GetSectionByName(".rdata"); + const ModuleSections_t& dataSection = GetSectionByName(".data"); const auto tableNameInfo = StringToMaskedBytes(szTableName, false); - CMemory rttiTypeDescriptor = FindPatternSIMD(tableNameInfo.first.data(), tableNameInfo.second.c_str(), &moduleSection).OffsetSelf(-0x10); + CMemory rttiTypeDescriptor = FindPatternSIMD(tableNameInfo.first.data(), + tableNameInfo.second.c_str(), &dataSection).OffsetSelf(-0x10); + if (!rttiTypeDescriptor) - return CMemory(); + return nullptr; - uintptr_t scanStart = m_ReadOnlyData.m_pSectionBase; // Get the start address of our scan. + QWORD scanStart = readOnlyData.m_pSectionBase; + const QWORD scanEnd = (readOnlyData.m_pSectionBase + readOnlyData.m_nSectionSize) - 0x4; + + // The RTTI gets referenced by a 4-Byte RVA + // address, we need to scan for that address. + const QWORD rttiTDRva = rttiTypeDescriptor.GetPtr() - m_pModuleBase; + ModuleSections_t moduleSection; - const uintptr_t scanEnd = (m_ReadOnlyData.m_pSectionBase + m_ReadOnlyData.m_nSectionSize) - 0x4; // Calculate the end of our scan. - const uintptr_t rttiTDRva = rttiTypeDescriptor.GetPtr() - m_pModuleBase; // The RTTI gets referenced by a 4-Byte RVA address. We need to scan for that address. while (scanStart < scanEnd) { - moduleSection = { ".rdata", scanStart, m_ReadOnlyData.m_nSectionSize }; - CMemory reference = FindPatternSIMD(reinterpret_cast(&rttiTDRva), "xxxx", &moduleSection, nRefIndex); + moduleSection = { scanStart, readOnlyData.m_nSectionSize }; + + CMemory reference = FindPatternSIMD(reinterpret_cast( + &rttiTDRva), "xxxx", &moduleSection, nRefIndex); + if (!reference) break; CMemory referenceOffset = reference.Offset(-0xC); - if (referenceOffset.GetValue() != 1) // Check if we got a RTTI Object Locator for this reference by checking if -0xC is 1, which is the 'signature' field which is always 1 on x64. + + // Check if we got a RTTI Object Locator for this + // reference by checking if -0xC is 1, which is the + // 'signature' field which is always 1 on x64. + if (referenceOffset.GetValue() != 1) { - scanStart = reference.Offset(0x4).GetPtr(); // Set location to current reference + 0x4 so we avoid pushing it back again into the vector. + // Set location to current reference + 0x4 so we + // avoid pushing it back again into the vector. + scanStart = reference.Offset(0x4).GetPtr(); continue; } - moduleSection = { ".rdata", m_ReadOnlyData.m_pSectionBase, m_ReadOnlyData.m_nSectionSize }; - CMemory vfTable = FindPatternSIMD(reinterpret_cast(&referenceOffset), "xxxxxxxx", &moduleSection).OffsetSelf(0x8); - g_SigCache.AddEntry(svPackedTableName.c_str(), GetRVA(vfTable.GetPtr())); + moduleSection = {readOnlyData.m_pSectionBase, readOnlyData.m_nSectionSize }; + CMemory vfTable = FindPatternSIMD(reinterpret_cast( + &referenceOffset), "xxxxxxxx", &moduleSection).OffsetSelf(0x8); + + g_SigCache.AddEntry(svPackedTableName.c_str(), GetRVA(vfTable.GetPtr())); return vfTable; } return CMemory(); } -#endif // !PLUGINSDK //----------------------------------------------------------------------------- // Purpose: get address of imported function in target module -// Input : *szModuleName - -// *szFunctionName - -// bGetFunctionReference - +// Input : *szModuleName - +// *szSymbolName - +// bGetSymbolReference - // Output : CMemory //----------------------------------------------------------------------------- -CMemory CModule::GetImportedFunction(const char* szModuleName, const char* szFunctionName, const bool bGetFunctionReference) const +CMemory CModule::GetImportedSymbol(QWORD pModuleBase, const char* szModuleName, + const char* szSymbolName, const bool bGetSymbolReference) { - if (!m_pDOSHeader || m_pDOSHeader->e_magic != IMAGE_DOS_SIGNATURE) // Is dosHeader valid? - return CMemory(); + IMAGE_DOS_HEADER* pDOSHeader = GetDOSHeader(pModuleBase); + IMAGE_NT_HEADERS64* pNTHeaders = GetNTHeaders(pModuleBase); - if (!m_pNTHeaders || m_pNTHeaders->Signature != IMAGE_NT_SIGNATURE) // Is ntHeader valid? - return CMemory(); + if (!pDOSHeader || pDOSHeader->e_magic != IMAGE_DOS_SIGNATURE) + return nullptr; + + if (!pNTHeaders || pNTHeaders->Signature != IMAGE_NT_SIGNATURE) + return nullptr; + + // Get the location of IMAGE_IMPORT_DESCRIPTOR for this + // module by adding the IMAGE_DIRECTORY_ENTRY_IMPORT + // relative virtual address onto our module base address. + IMAGE_IMPORT_DESCRIPTOR* pImageImportDescriptors = reinterpret_cast + (pModuleBase + pNTHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress); - // Get the location of IMAGE_IMPORT_DESCRIPTOR for this module by adding the IMAGE_DIRECTORY_ENTRY_IMPORT relative virtual address onto our module base address. - IMAGE_IMPORT_DESCRIPTOR* pImageImportDescriptors = reinterpret_cast(m_pModuleBase + m_pNTHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress); if (!pImageImportDescriptors) - return CMemory(); + return nullptr; for (IMAGE_IMPORT_DESCRIPTOR* pIID = pImageImportDescriptors; pIID->Name != 0; pIID++) { - // Get virtual relative Address of the imported module name. Then add module base Address to get the actual location. - const char* szImportedModuleName = reinterpret_cast(reinterpret_cast(m_pModuleBase + pIID->Name)); + // Get virtual relative Address of the imported module name. + // Then add module base Address to get the actual location. + const char* szImportedModuleName = reinterpret_cast(reinterpret_cast(pModuleBase + pIID->Name)); - if (stricmp(szImportedModuleName, szModuleName) == 0) // Is this our wanted imported module?. + if (stricmp(szImportedModuleName, szModuleName) == NULL) { - // Original First Thunk to get function name. - IMAGE_THUNK_DATA* pOgFirstThunk = reinterpret_cast(m_pModuleBase + pIID->OriginalFirstThunk); + IMAGE_THUNK_DATA* pOgFirstThunk = reinterpret_cast(pModuleBase + pIID->OriginalFirstThunk); // To get actual function address. - IMAGE_THUNK_DATA* pFirstThunk = reinterpret_cast(m_pModuleBase + pIID->FirstThunk); + IMAGE_THUNK_DATA* pFirstThunk = reinterpret_cast(pModuleBase + pIID->FirstThunk); for (; pOgFirstThunk->u1.AddressOfData; ++pOgFirstThunk, ++pFirstThunk) { // Get image import by name. - const IMAGE_IMPORT_BY_NAME* pImageImportByName = reinterpret_cast(m_pModuleBase + pOgFirstThunk->u1.AddressOfData); + const IMAGE_IMPORT_BY_NAME* pImageImportByName = reinterpret_cast( + pModuleBase + pOgFirstThunk->u1.AddressOfData); - if (strcmp(pImageImportByName->Name, szFunctionName) == 0) // Is this our wanted imported function? + if (strcmp(pImageImportByName->Name, szSymbolName) == NULL) { // Grab function address from firstThunk. - uintptr_t* pFunctionAddress = &pFirstThunk->u1.Function; + QWORD* pFunctionAddress = &pFirstThunk->u1.Function; // Reference or address? - return bGetFunctionReference ? CMemory(pFunctionAddress) : CMemory(*pFunctionAddress); // Return as CMemory class. + return bGetSymbolReference ? CMemory(pFunctionAddress) : CMemory(*pFunctionAddress); } } } } - return CMemory(); + return nullptr; } //----------------------------------------------------------------------------- -// Purpose: get address of exported function in this module -// Input : *szFunctionName - -// bNullTerminator - +// Purpose: get address of exported symbol in this module +// Input : *pModuleBase - +// szSymbolName - // Output : CMemory //----------------------------------------------------------------------------- -CMemory CModule::GetExportedFunction(const char* szFunctionName) const +CMemory CModule::GetExportedSymbol(QWORD pModuleBase, const char* szSymbolName) { - if (!m_pDOSHeader || m_pDOSHeader->e_magic != IMAGE_DOS_SIGNATURE) // Is dosHeader valid? - return CMemory(); + IMAGE_DOS_HEADER* pDOSHeader = GetDOSHeader(pModuleBase); + IMAGE_NT_HEADERS64* pNTHeaders = GetNTHeaders(pModuleBase); - if (!m_pNTHeaders || m_pNTHeaders->Signature != IMAGE_NT_SIGNATURE) // Is ntHeader valid? - return CMemory(); + if (!pDOSHeader || pDOSHeader->e_magic != IMAGE_DOS_SIGNATURE) + return nullptr; + + if (!pNTHeaders || pNTHeaders->Signature != IMAGE_NT_SIGNATURE) + return nullptr; + + // Get the location of IMAGE_EXPORT_DIRECTORY for this + // module by adding the IMAGE_DIRECTORY_ENTRY_EXPORT + // relative virtual address onto our module base address. + const IMAGE_EXPORT_DIRECTORY* pImageExportDirectory = + reinterpret_cast(pModuleBase + + pNTHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT].VirtualAddress); - // Get the location of IMAGE_EXPORT_DIRECTORY for this module by adding the IMAGE_DIRECTORY_ENTRY_EXPORT relative virtual address onto our module base address. - const IMAGE_EXPORT_DIRECTORY* pImageExportDirectory = reinterpret_cast(m_pModuleBase + m_pNTHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT].VirtualAddress); if (!pImageExportDirectory) - return CMemory(); + return nullptr; - // Are there any exported functions? if (!pImageExportDirectory->NumberOfFunctions) - return CMemory(); + return nullptr; + + // Get the location of the functions. + const DWORD* pAddressOfFunctions = reinterpret_cast(pModuleBase + + pImageExportDirectory->AddressOfFunctions); - // Get the location of the functions via adding the relative virtual address from the struct into our module base address. - const DWORD* pAddressOfFunctions = reinterpret_cast(m_pModuleBase + pImageExportDirectory->AddressOfFunctions); if (!pAddressOfFunctions) - return CMemory(); + return nullptr; + + // Get the names of the functions. + const DWORD* pAddressOfName = reinterpret_cast(pModuleBase + + pImageExportDirectory->AddressOfNames); - // Get the names of the functions via adding the relative virtual address from the struct into our module base Address. - const DWORD* pAddressOfName = reinterpret_cast(m_pModuleBase + pImageExportDirectory->AddressOfNames); if (!pAddressOfName) - return CMemory(); + return nullptr; + + // Get the ordinals of the functions. + DWORD* pAddressOfOrdinals = reinterpret_cast(pModuleBase + + pImageExportDirectory->AddressOfNameOrdinals); - // Get the ordinals of the functions via adding the relative virtual Address from the struct into our module base address. - DWORD* pAddressOfOrdinals = reinterpret_cast(m_pModuleBase + pImageExportDirectory->AddressOfNameOrdinals); if (!pAddressOfOrdinals) - return CMemory(); + return nullptr; - for (DWORD i = 0; i < pImageExportDirectory->NumberOfFunctions; i++) // Iterate through all the functions. + for (DWORD i = 0; i < pImageExportDirectory->NumberOfFunctions; i++) { - // Get virtual relative Address of the function name. Then add module base Address to get the actual location. - const char* ExportFunctionName = reinterpret_cast(reinterpret_cast(m_pModuleBase + pAddressOfName[i])); + // Get virtual relative Address of the function name, + // then add module base Address to get the actual location. + const char* ExportFunctionName = + reinterpret_cast(reinterpret_cast( + pModuleBase + pAddressOfName[i])); - if (strcmp(ExportFunctionName, szFunctionName) == 0) // Is this our wanted exported function? + if (strcmp(ExportFunctionName, szSymbolName) == NULL) { - // Get the function ordinal. Then grab the relative virtual address of our wanted function. Then add module base address so we get the actual location. - return CMemory(m_pModuleBase + pAddressOfFunctions[reinterpret_cast(pAddressOfOrdinals)[i]]); // Return as CMemory class. + // Get the function ordinal, then grab the relative + // virtual address of our wanted function. Then add + // module base address so we get the actual location. + return pModuleBase + + pAddressOfFunctions[reinterpret_cast(pAddressOfOrdinals)[i]]; } } - return CMemory(); -} - -//----------------------------------------------------------------------------- -// Purpose: get the module section by name (example: '.rdata', '.text') -// Input : *szSectionName - -// Output : ModuleSections_t -//----------------------------------------------------------------------------- -CModule::ModuleSections_t CModule::GetSectionByName(const char* szSectionName) const -{ - for (const ModuleSections_t& section : m_ModuleSections) - { - if (section.m_SectionName.compare(szSectionName) == 0) - return section; - } - - return ModuleSections_t(); + return nullptr; } //----------------------------------------------------------------------------- // Purpose: unlink module from peb +// Disclaimer: This does not bypass GetMappedFileName. That function calls +// NtQueryVirtualMemory which does a syscall to ntoskrnl for getting info +// on a section. //----------------------------------------------------------------------------- -void CModule::UnlinkFromPEB() const // Disclaimer: This does not bypass GetMappedFileName. That function calls NtQueryVirtualMemory which does a syscall to ntoskrnl for getting info on a section. +void CModule::UnlinkFromPEB() const { #define UNLINK_FROM_PEB(entry) \ (entry).Flink->Blink = (entry).Blink; \ (entry).Blink->Flink = (entry).Flink; - const PEB64* processEnvBlock = reinterpret_cast(__readgsqword(0x60)); // https://en.wikipedia.org/wiki/Win32_Thread_Information_Block + // https://en.wikipedia.org/wiki/Win32_Thread_Information_Block + const PEB64* processEnvBlock = reinterpret_cast(__readgsqword(0x60)); const LIST_ENTRY* inLoadOrderList = &processEnvBlock->Ldr->InLoadOrderModuleList; for (LIST_ENTRY* entry = inLoadOrderList->Flink; entry != inLoadOrderList; entry = entry->Flink) { const PLDR_DATA_TABLE_ENTRY pldrEntry = reinterpret_cast(entry->Flink); - const std::uintptr_t baseAddr = reinterpret_cast(pldrEntry->DllBase); + const QWORD baseAddr = reinterpret_cast(pldrEntry->DllBase); if (baseAddr != m_pModuleBase) continue;