CModule: class improvements

* Added 'GetImportedFunction'.
* Remove extraneous std::string copy constructors during construction of 'm_vModuleSections'.
* Added extra constructor using base address.
This commit is contained in:
Kawe Mazidjatari 2023-01-31 23:52:11 +01:00
parent af42dfafe3
commit 3ea7cc1cd4
2 changed files with 82 additions and 4 deletions

View File

@ -13,22 +13,46 @@
//-----------------------------------------------------------------------------
CModule::CModule(const string& svModuleName) : m_svModuleName(svModuleName)
{
const MODULEINFO mInfo = GetModuleInfo(svModuleName.c_str());
m_nModuleSize = static_cast<size_t>(mInfo.SizeOfImage);
m_pModuleBase = reinterpret_cast<uintptr_t>(mInfo.lpBaseOfDll);
m_pModuleBase = reinterpret_cast<uintptr_t>(GetModuleHandleA(svModuleName.c_str()));
Init();
LoadSections();
}
//-----------------------------------------------------------------------------
// Purpose: constructor
// Input : nModuleBase
//-----------------------------------------------------------------------------
CModule::CModule(const uintptr_t nModuleBase, const string& svModuleName) : m_svModuleName(svModuleName), m_pModuleBase(nModuleBase)
{
Init();
LoadSections();
}
//-----------------------------------------------------------------------------
// Purpose: initializes module descriptors
//-----------------------------------------------------------------------------
void CModule::Init()
{
m_pDOSHeader = reinterpret_cast<IMAGE_DOS_HEADER*>(m_pModuleBase);
m_pNTHeaders = reinterpret_cast<IMAGE_NT_HEADERS64*>(m_pModuleBase + m_pDOSHeader->e_lfanew);
m_nModuleSize = static_cast<size_t>(m_pNTHeaders->OptionalHeader.SizeOfImage);
const IMAGE_SECTION_HEADER* hSection = IMAGE_FIRST_SECTION(m_pNTHeaders); // Get first image section.
for (WORD i = 0; i < m_pNTHeaders->FileHeader.NumberOfSections; i++) // Loop through the sections.
{
const IMAGE_SECTION_HEADER& hCurrentSection = hSection[i]; // Get current section.
m_vModuleSections.push_back(ModuleSections_t(string(reinterpret_cast<const char*>(hCurrentSection.Name)),
m_vModuleSections.push_back(ModuleSections_t(reinterpret_cast<const char*>(hCurrentSection.Name),
static_cast<uintptr_t>(m_pModuleBase + hCurrentSection.VirtualAddress), hCurrentSection.SizeOfRawData)); // Push back a struct with the section data.
}
}
//-----------------------------------------------------------------------------
// Purpose: initializes the default executable segments
//-----------------------------------------------------------------------------
void CModule::LoadSections()
{
m_ExecutableCode = GetSectionByName(".text");
m_ExceptionTable = GetSectionByName(".pdata");
m_RunTimeData = GetSectionByName(".data");
@ -297,6 +321,55 @@ CMemory CModule::GetVirtualMethodTable(const string& svTableName, const uint32_t
}
#endif // !PLUGINSDK
CMemory CModule::GetImportedFunction(const string& svModuleName, const string& svFunctionName, const bool bGetFunctionReference) const
{
if (!m_pDOSHeader || m_pDOSHeader->e_magic != IMAGE_DOS_SIGNATURE) // Is dosHeader valid?
return CMemory();
if (!m_pNTHeaders || m_pNTHeaders->Signature != IMAGE_NT_SIGNATURE) // Is ntHeader valid?
return CMemory();
// Get the location of IMAGE_IMPORT_DESCRIPTOR for this module by adding the IMAGE_DIRECTORY_ENTRY_IMPORT relative virtual address onto our module base address.
IMAGE_IMPORT_DESCRIPTOR* pImageImportDescriptors = reinterpret_cast<IMAGE_IMPORT_DESCRIPTOR*>(m_pModuleBase + m_pNTHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress);
if (!pImageImportDescriptors)
return CMemory();
for (IMAGE_IMPORT_DESCRIPTOR* pIID = pImageImportDescriptors; pIID->Name != 0; pIID++)
{
// Get virtual relative Address of the imported module name. Then add module base Address to get the actual location.
string svImportedModuleName = reinterpret_cast<char*>(reinterpret_cast<DWORD*>(m_pModuleBase + pIID->Name));
// Convert all characters to lower case because KERNEL32.DLL sometimes is kernel32.DLL, sometimes KERNEL32.dll.
std::transform(svImportedModuleName.begin(), svImportedModuleName.end(), svImportedModuleName.begin(), static_cast<int (*)(int)>(std::tolower));
if (svImportedModuleName.compare(svModuleName) == 0) // Is this our wanted imported module?.
{
// Original First Thunk to get function name.
IMAGE_THUNK_DATA* pOgFirstThunk = reinterpret_cast<IMAGE_THUNK_DATA*>(m_pModuleBase + pIID->OriginalFirstThunk);
// To get actual function address.
IMAGE_THUNK_DATA* pFirstThunk = reinterpret_cast<IMAGE_THUNK_DATA*>(m_pModuleBase + pIID->FirstThunk);
for (; pOgFirstThunk->u1.AddressOfData; ++pOgFirstThunk, ++pFirstThunk)
{
// Get image import by name.
const IMAGE_IMPORT_BY_NAME* pImageImportByName = reinterpret_cast<IMAGE_IMPORT_BY_NAME*>(m_pModuleBase + pOgFirstThunk->u1.AddressOfData);
// Get import function name.
const string svImportedFunctionName = pImageImportByName->Name;
if (svImportedFunctionName.compare(svFunctionName) == 0) // Is this our wanted imported function?
{
// Grab function address from firstThunk.
uintptr_t* pFunctionAddress = &pFirstThunk->u1.Function;
// Reference or address?
return bGetFunctionReference ? CMemory(pFunctionAddress) : CMemory(*pFunctionAddress); // Return as CMemory class.
}
}
}
}
return CMemory();
}
//-----------------------------------------------------------------------------
// Purpose: get address of exported function in this module
// Input : *svFunctionName -

View File

@ -22,6 +22,10 @@ public:
CModule(void) = default;
CModule(const string& moduleName);
CModule(const uintptr_t nModuleBase, const string& svModuleName);
void Init();
void LoadSections();
#ifndef PLUGINSDK
CMemory FindPatternSIMD(const string& svPattern, const ModuleSections_t* moduleSection = nullptr) const;
CMemory FindString(const string& string, const ptrdiff_t occurrence = 1, bool nullTerminator = false) const;
@ -29,6 +33,7 @@ public:
CMemory GetVirtualMethodTable(const string& svTableName, const uint32_t nRefIndex = 0);
#endif // !PLUGINSDK
CMemory GetImportedFunction(const string& svModuleName, const string& svFunctionName, const bool bGetFunctionReference) const;
CMemory GetExportedFunction(const string& svFunctionName) const;
ModuleSections_t GetSectionByName(const string& svSectionName) const;
uintptr_t GetModuleBase(void) const;