From 40566235e04c90d01d2919f779d50ecf53b87b2c Mon Sep 17 00:00:00 2001 From: Kawe Mazidjatari <48657826+Mauler125@users.noreply.github.com> Date: Thu, 3 Nov 2022 02:30:57 +0100 Subject: [PATCH] Server authentication code improvements * Added validity checks for player names, they must be UTF-8 encoded, else game clients will crash attempting to set non-UTF-8 player names in RUI. * CServer::RejectConnection now takes a 'v_netadr_t*', previously it was 'user_creds_s*', which worked as 'v_netadr_t' is the first member in 'user_creds_s'. * Programmer must now manually pass a character buffer to 'v_net_adr::GetAddress(...)'. --- r5dev/engine/server/server.cpp | 43 ++++++++++++-------- r5dev/engine/server/server.h | 12 +++--- r5dev/public/utility/utility.cpp | 69 ++++++++++++++++++++++++++++++-- r5dev/public/utility/utility.h | 1 + r5dev/tier1/NetAdr2.h | 14 ++----- 5 files changed, 103 insertions(+), 36 deletions(-) diff --git a/r5dev/engine/server/server.cpp b/r5dev/engine/server/server.cpp index 4187fa1e..aae73f4f 100644 --- a/r5dev/engine/server/server.cpp +++ b/r5dev/engine/server/server.cpp @@ -64,17 +64,30 @@ int CServer::GetNumFakeClients(void) const //--------------------------------------------------------------------------------- bool CServer::AuthClient(user_creds_s* pChallenge) { - string svIpAddress = pChallenge->m_nAddr.GetAddress(); - if (sv_showconnecting->GetBool()) - DevMsg(eDLL_T::SERVER, "Processing connectionless challenge for '%s' ('%llu')\n", svIpAddress.c_str(), pChallenge->m_nNucleusID); + char pszAddresBuffer[INET6_ADDRSTRLEN]; // Render the client's address. + pChallenge->m_nAddr.GetAddress(pszAddresBuffer, sizeof(pszAddresBuffer)); + + const bool bEnableLogging = sv_showconnecting->GetBool(); + if (bEnableLogging) + DevMsg(eDLL_T::SERVER, "Processing connectionless challenge for '%s' ('%llu')\n", pszAddresBuffer, pChallenge->m_nNucleusID); + + char* pUserID = pChallenge->m_pUserID; + if (!pUserID || !pUserID[0] || !IsValidUTF8(pUserID)) // Only proceed connection if the client's name is valid and UTF-8 encoded. + { + RejectConnection(m_Socket, &pChallenge->m_nAddr, "#Valve_Reject_Invalid_Name"); + if (bEnableLogging) + Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' has an invalid name!)\n", pszAddresBuffer, pChallenge->m_nNucleusID); + + return false; + } if (g_pBanSystem->IsBanListValid()) // Is the banned list vector valid? { - if (g_pBanSystem->IsBanned(svIpAddress, pChallenge->m_nNucleusID)) // Is the client trying to connect banned? + if (g_pBanSystem->IsBanned(pszAddresBuffer, pChallenge->m_nNucleusID)) // Is the client trying to connect banned? { - RejectConnection(m_Socket, pChallenge, "#Valve_Reject_Banned"); // RejectConnection for the client. - if (sv_showconnecting->GetBool()) - Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' is banned from this server!)\n", svIpAddress.c_str(), pChallenge->m_nNucleusID); + RejectConnection(m_Socket, &pChallenge->m_nAddr, "#Valve_Reject_Banned"); // RejectConnection for the client. + if (bEnableLogging) + Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' is banned from this server!)\n", pszAddresBuffer, pChallenge->m_nNucleusID); return false; } @@ -82,7 +95,7 @@ bool CServer::AuthClient(user_creds_s* pChallenge) if (g_bCheckCompBanDB) { - std::thread th(SV_IsClientBanned, svIpAddress, pChallenge->m_nNucleusID); + std::thread th(SV_IsClientBanned, string(pszAddresBuffer), pChallenge->m_nNucleusID); th.detach(); } @@ -98,13 +111,11 @@ bool CServer::AuthClient(user_creds_s* pChallenge) //--------------------------------------------------------------------------------- CClient* CServer::ConnectClient(CServer* pServer, user_creds_s* pChallenge) { - if (pServer->AuthClient(pChallenge)) - { - CClient* pClient = v_CServer_ConnectClient(pServer, pChallenge); - return pClient; - } + if (pServer->m_State < server_state_t::ss_active || !pServer->AuthClient(pChallenge)) + return nullptr; - return nullptr; + CClient* pClient = v_CServer_ConnectClient(pServer, pChallenge); + return pClient; } //--------------------------------------------------------------------------------- @@ -113,9 +124,9 @@ CClient* CServer::ConnectClient(CServer* pServer, user_creds_s* pChallenge) // *pChallenge - // *szMessage - //--------------------------------------------------------------------------------- -void CServer::RejectConnection(int iSocket, user_creds_s* pChallenge, const char* szMessage) +void CServer::RejectConnection(int iSocket, v_netadr_t* pNetAdr, const char* szMessage) { - v_CServer_RejectConnection(this, iSocket, pChallenge, szMessage); + v_CServer_RejectConnection(this, iSocket, pNetAdr, szMessage); } /////////////////////////////////////////////////////////////////////////////// diff --git a/r5dev/engine/server/server.h b/r5dev/engine/server/server.h index ae12f961..3ec032af 100644 --- a/r5dev/engine/server/server.h +++ b/r5dev/engine/server/server.h @@ -23,7 +23,7 @@ struct user_creds_s int32_t m_nChallenge; uint32_t m_nReservation; uint64_t m_nNucleusID; - uint8_t* m_pUserID; + char* m_pUserID; }; class CServer : public IServer @@ -41,7 +41,7 @@ public: bool IsLoading(void) const { return m_State == server_state_t::ss_loading; } bool IsDedicated(void) const { return g_bDedicated; } bool AuthClient(user_creds_s* pChallenge); - void RejectConnection(int iSocket, user_creds_s* pCreds, const char* szMessage); + void RejectConnection(int iSocket, v_netadr_t* pNetAdr, const char* szMessage); static CClient* ConnectClient(CServer* pServer, user_creds_s* pChallenge); #endif // !CLIENT_DLL @@ -91,7 +91,7 @@ inline CMemory p_CServer_Authenticate; inline auto v_CServer_ConnectClient = p_CServer_Authenticate.RCast(); inline CMemory p_CServer_RejectConnection; -inline auto v_CServer_RejectConnection = p_CServer_RejectConnection.RCast(); +inline auto v_CServer_RejectConnection = p_CServer_RejectConnection.RCast(); void CServer_Attach(); void CServer_Detach(); @@ -124,9 +124,9 @@ class VServer : public IDetour #endif p_CServer_RejectConnection = g_GameDll.FindPatternSIMD(reinterpret_cast("\x4C\x89\x4C\x24\x00\x53\x55\x56\x57\x48\x81\xEC\x00\x00\x00\x00\x49\x8B\xD9"), "xxxx?xxxxxxx????xxx"); - v_CServer_Think = p_CServer_Think.RCast(); /*48 89 5C 24 ?? 48 89 74 24 ?? 57 48 81 EC ?? ?? ?? ?? 80 3D ?? ?? ?? ?? ??*/ - v_CServer_ConnectClient = p_CServer_Authenticate.RCast(); /*40 55 57 41 55 41 57 48 8D AC 24 ?? ?? ?? ??*/ - v_CServer_RejectConnection = p_CServer_RejectConnection.RCast(); /*4C 89 4C 24 ?? 53 55 56 57 48 81 EC ?? ?? ?? ?? 49 8B D9*/ + v_CServer_Think = p_CServer_Think.RCast(); /*48 89 5C 24 ?? 48 89 74 24 ?? 57 48 81 EC ?? ?? ?? ?? 80 3D ?? ?? ?? ?? ??*/ + v_CServer_ConnectClient = p_CServer_Authenticate.RCast(); /*40 55 57 41 55 41 57 48 8D AC 24 ?? ?? ?? ??*/ + v_CServer_RejectConnection = p_CServer_RejectConnection.RCast(); /*4C 89 4C 24 ?? 53 55 56 57 48 81 EC ?? ?? ?? ?? 49 8B D9*/ #endif // !CLIENT_DLL } virtual void GetVar(void) const diff --git a/r5dev/public/utility/utility.cpp b/r5dev/public/utility/utility.cpp index aa264200..305e63a6 100644 --- a/r5dev/public/utility/utility.cpp +++ b/r5dev/public/utility/utility.cpp @@ -482,7 +482,7 @@ string Base64Decode(const string& svInput) } /////////////////////////////////////////////////////////////////////////////// -// For encoding data in UTF8. +// For encoding data in UTF-8. string UTF8Encode(const wstring& wsvInput) { string result; @@ -496,7 +496,7 @@ string UTF8Encode(const wstring& wsvInput) } /////////////////////////////////////////////////////////////////////////////// -// For decoding data in UTF8. +// For decoding data in UTF-8. string UTF8Decode(const string& svInput) { //struct destructible_codecvt : public std::codecvt @@ -510,7 +510,7 @@ string UTF8Decode(const string& svInput) } /////////////////////////////////////////////////////////////////////////////// -// For obtaining UTF8 character length. +// For obtaining UTF-8 character length. size_t UTF8CharLength(const uint8_t cInput) { if ((cInput & 0xFE) == 0xFC) @@ -526,6 +526,69 @@ size_t UTF8CharLength(const uint8_t cInput) return 1; } +/////////////////////////////////////////////////////////////////////////////// +// For checking if input string is a valid UTF-8 encoded string. +bool IsValidUTF8(char* pszString) +{ + char v1; // r9 + char* v2; // rdx + char v4; // r10 + int v5; // er8 + + while (true) + { + while (true) + { + v1 = *pszString; + v2 = pszString++; + if (v1 < 0) + { + break; + } + if (!v1) + { + return true; + } + } + + v4 = *pszString; + if ((*pszString & 0xC0) != 0x80) + { + break; + } + + pszString = v2 + 2; + if (v1 >= 0xE0u) + { + v5 = *pszString & 0x3F | ((v4 & 0x3F | ((v1 & 0xF) << 6)) << 6); + if ((*pszString & 0xC0) != 0x80) + { + return false; + } + + pszString = v2 + 3; + if (v1 >= 0xF0u) + { + if ((*pszString & 0xC0) != 0x80 || ((v5 << 6) | *pszString & 0x3Fu) > 0x10FFFF) + { + return false; + } + + pszString = v2 + 4; + } + else if ((v5 - 55296) <= 0x7FF) + { + return false; + } + } + else if (v1 < 0xC2u) + { + return false; + } + } + return false; +} + /////////////////////////////////////////////////////////////////////////////// // For checking if a string is a number. bool StringIsDigit(const string& svInput) diff --git a/r5dev/public/utility/utility.h b/r5dev/public/utility/utility.h index 3ad1644b..181f4f09 100644 --- a/r5dev/public/utility/utility.h +++ b/r5dev/public/utility/utility.h @@ -39,6 +39,7 @@ string Base64Decode(const string& svInput); string UTF8Encode(const wstring& wsvInput); string UTF8Decode(const string& svInput); size_t UTF8CharLength(const uint8_t cInput); +bool IsValidUTF8(char* pszString); bool StringIsDigit(const string& svInput); bool CompareStringAlphabetically(const string& svA, const string& svB); diff --git a/r5dev/tier1/NetAdr2.h b/r5dev/tier1/NetAdr2.h index 0996d66b..6874df3c 100644 --- a/r5dev/tier1/NetAdr2.h +++ b/r5dev/tier1/NetAdr2.h @@ -72,18 +72,10 @@ public: { return this->type; } - inline string GetAddress(void) const + inline void GetAddress(char* pchBuffer, uint32_t nBufferLen) const { - // Select a static buffer - static char s[4][INET6_ADDRSTRLEN]; - static int slot = 0; - int useSlot = (slot++) % 4; - - // Render into it - inet_ntop(AF_INET6, &this->adr, s[useSlot], sizeof(s[0])); - - // Pray the caller uses it before it gets clobbered - return s[useSlot]; + assert(nBufferLen >= INET6_ADDRSTRLEN); + inet_ntop(AF_INET6, &this->adr, pchBuffer, nBufferLen); } inline uint16_t GetPort(void) const {