Server authentication code improvements

* Added validity checks for player names, they must be UTF-8 encoded, else game clients will crash attempting to set non-UTF-8 player names in RUI.
* CServer::RejectConnection now takes a 'v_netadr_t*', previously it was 'user_creds_s*', which worked as 'v_netadr_t' is the first member in 'user_creds_s'.
* Programmer must now manually pass a character buffer to 'v_net_adr::GetAddress(...)'.
This commit is contained in:
Kawe Mazidjatari 2022-11-03 02:30:57 +01:00
parent f27ff874c2
commit 40566235e0
5 changed files with 103 additions and 36 deletions

View File

@ -64,17 +64,30 @@ int CServer::GetNumFakeClients(void) const
//---------------------------------------------------------------------------------
bool CServer::AuthClient(user_creds_s* pChallenge)
{
string svIpAddress = pChallenge->m_nAddr.GetAddress();
if (sv_showconnecting->GetBool())
DevMsg(eDLL_T::SERVER, "Processing connectionless challenge for '%s' ('%llu')\n", svIpAddress.c_str(), pChallenge->m_nNucleusID);
char pszAddresBuffer[INET6_ADDRSTRLEN]; // Render the client's address.
pChallenge->m_nAddr.GetAddress(pszAddresBuffer, sizeof(pszAddresBuffer));
const bool bEnableLogging = sv_showconnecting->GetBool();
if (bEnableLogging)
DevMsg(eDLL_T::SERVER, "Processing connectionless challenge for '%s' ('%llu')\n", pszAddresBuffer, pChallenge->m_nNucleusID);
char* pUserID = pChallenge->m_pUserID;
if (!pUserID || !pUserID[0] || !IsValidUTF8(pUserID)) // Only proceed connection if the client's name is valid and UTF-8 encoded.
{
RejectConnection(m_Socket, &pChallenge->m_nAddr, "#Valve_Reject_Invalid_Name");
if (bEnableLogging)
Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' has an invalid name!)\n", pszAddresBuffer, pChallenge->m_nNucleusID);
return false;
}
if (g_pBanSystem->IsBanListValid()) // Is the banned list vector valid?
{
if (g_pBanSystem->IsBanned(svIpAddress, pChallenge->m_nNucleusID)) // Is the client trying to connect banned?
if (g_pBanSystem->IsBanned(pszAddresBuffer, pChallenge->m_nNucleusID)) // Is the client trying to connect banned?
{
RejectConnection(m_Socket, pChallenge, "#Valve_Reject_Banned"); // RejectConnection for the client.
if (sv_showconnecting->GetBool())
Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' is banned from this server!)\n", svIpAddress.c_str(), pChallenge->m_nNucleusID);
RejectConnection(m_Socket, &pChallenge->m_nAddr, "#Valve_Reject_Banned"); // RejectConnection for the client.
if (bEnableLogging)
Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' is banned from this server!)\n", pszAddresBuffer, pChallenge->m_nNucleusID);
return false;
}
@ -82,7 +95,7 @@ bool CServer::AuthClient(user_creds_s* pChallenge)
if (g_bCheckCompBanDB)
{
std::thread th(SV_IsClientBanned, svIpAddress, pChallenge->m_nNucleusID);
std::thread th(SV_IsClientBanned, string(pszAddresBuffer), pChallenge->m_nNucleusID);
th.detach();
}
@ -98,13 +111,11 @@ bool CServer::AuthClient(user_creds_s* pChallenge)
//---------------------------------------------------------------------------------
CClient* CServer::ConnectClient(CServer* pServer, user_creds_s* pChallenge)
{
if (pServer->AuthClient(pChallenge))
{
CClient* pClient = v_CServer_ConnectClient(pServer, pChallenge);
return pClient;
}
if (pServer->m_State < server_state_t::ss_active || !pServer->AuthClient(pChallenge))
return nullptr;
return nullptr;
CClient* pClient = v_CServer_ConnectClient(pServer, pChallenge);
return pClient;
}
//---------------------------------------------------------------------------------
@ -113,9 +124,9 @@ CClient* CServer::ConnectClient(CServer* pServer, user_creds_s* pChallenge)
// *pChallenge -
// *szMessage -
//---------------------------------------------------------------------------------
void CServer::RejectConnection(int iSocket, user_creds_s* pChallenge, const char* szMessage)
void CServer::RejectConnection(int iSocket, v_netadr_t* pNetAdr, const char* szMessage)
{
v_CServer_RejectConnection(this, iSocket, pChallenge, szMessage);
v_CServer_RejectConnection(this, iSocket, pNetAdr, szMessage);
}
///////////////////////////////////////////////////////////////////////////////

View File

@ -23,7 +23,7 @@ struct user_creds_s
int32_t m_nChallenge;
uint32_t m_nReservation;
uint64_t m_nNucleusID;
uint8_t* m_pUserID;
char* m_pUserID;
};
class CServer : public IServer
@ -41,7 +41,7 @@ public:
bool IsLoading(void) const { return m_State == server_state_t::ss_loading; }
bool IsDedicated(void) const { return g_bDedicated; }
bool AuthClient(user_creds_s* pChallenge);
void RejectConnection(int iSocket, user_creds_s* pCreds, const char* szMessage);
void RejectConnection(int iSocket, v_netadr_t* pNetAdr, const char* szMessage);
static CClient* ConnectClient(CServer* pServer, user_creds_s* pChallenge);
#endif // !CLIENT_DLL
@ -91,7 +91,7 @@ inline CMemory p_CServer_Authenticate;
inline auto v_CServer_ConnectClient = p_CServer_Authenticate.RCast<CClient* (*)(CServer* pServer, user_creds_s* pCreds)>();
inline CMemory p_CServer_RejectConnection;
inline auto v_CServer_RejectConnection = p_CServer_RejectConnection.RCast<void* (*)(CServer* pServer, int iSocket, user_creds_s* pCreds, const char* szMessage)>();
inline auto v_CServer_RejectConnection = p_CServer_RejectConnection.RCast<void* (*)(CServer* pServer, int iSocket, v_netadr_t* pNetAdr, const char* szMessage)>();
void CServer_Attach();
void CServer_Detach();
@ -124,9 +124,9 @@ class VServer : public IDetour
#endif
p_CServer_RejectConnection = g_GameDll.FindPatternSIMD(reinterpret_cast<rsig_t>("\x4C\x89\x4C\x24\x00\x53\x55\x56\x57\x48\x81\xEC\x00\x00\x00\x00\x49\x8B\xD9"), "xxxx?xxxxxxx????xxx");
v_CServer_Think = p_CServer_Think.RCast<void (*)(bool, bool)>(); /*48 89 5C 24 ?? 48 89 74 24 ?? 57 48 81 EC ?? ?? ?? ?? 80 3D ?? ?? ?? ?? ??*/
v_CServer_ConnectClient = p_CServer_Authenticate.RCast<CClient* (*)(CServer*, user_creds_s*)>(); /*40 55 57 41 55 41 57 48 8D AC 24 ?? ?? ?? ??*/
v_CServer_RejectConnection = p_CServer_RejectConnection.RCast<void* (*)(CServer*, int, user_creds_s*, const char*)>(); /*4C 89 4C 24 ?? 53 55 56 57 48 81 EC ?? ?? ?? ?? 49 8B D9*/
v_CServer_Think = p_CServer_Think.RCast<void (*)(bool, bool)>(); /*48 89 5C 24 ?? 48 89 74 24 ?? 57 48 81 EC ?? ?? ?? ?? 80 3D ?? ?? ?? ?? ??*/
v_CServer_ConnectClient = p_CServer_Authenticate.RCast<CClient* (*)(CServer*, user_creds_s*)>(); /*40 55 57 41 55 41 57 48 8D AC 24 ?? ?? ?? ??*/
v_CServer_RejectConnection = p_CServer_RejectConnection.RCast<void* (*)(CServer*, int, v_netadr_t*, const char*)>(); /*4C 89 4C 24 ?? 53 55 56 57 48 81 EC ?? ?? ?? ?? 49 8B D9*/
#endif // !CLIENT_DLL
}
virtual void GetVar(void) const

View File

@ -482,7 +482,7 @@ string Base64Decode(const string& svInput)
}
///////////////////////////////////////////////////////////////////////////////
// For encoding data in UTF8.
// For encoding data in UTF-8.
string UTF8Encode(const wstring& wsvInput)
{
string result;
@ -496,7 +496,7 @@ string UTF8Encode(const wstring& wsvInput)
}
///////////////////////////////////////////////////////////////////////////////
// For decoding data in UTF8.
// For decoding data in UTF-8.
string UTF8Decode(const string& svInput)
{
//struct destructible_codecvt : public std::codecvt<char32_t, char, std::mbstate_t>
@ -510,7 +510,7 @@ string UTF8Decode(const string& svInput)
}
///////////////////////////////////////////////////////////////////////////////
// For obtaining UTF8 character length.
// For obtaining UTF-8 character length.
size_t UTF8CharLength(const uint8_t cInput)
{
if ((cInput & 0xFE) == 0xFC)
@ -526,6 +526,69 @@ size_t UTF8CharLength(const uint8_t cInput)
return 1;
}
///////////////////////////////////////////////////////////////////////////////
// For checking if input string is a valid UTF-8 encoded string.
bool IsValidUTF8(char* pszString)
{
char v1; // r9
char* v2; // rdx
char v4; // r10
int v5; // er8
while (true)
{
while (true)
{
v1 = *pszString;
v2 = pszString++;
if (v1 < 0)
{
break;
}
if (!v1)
{
return true;
}
}
v4 = *pszString;
if ((*pszString & 0xC0) != 0x80)
{
break;
}
pszString = v2 + 2;
if (v1 >= 0xE0u)
{
v5 = *pszString & 0x3F | ((v4 & 0x3F | ((v1 & 0xF) << 6)) << 6);
if ((*pszString & 0xC0) != 0x80)
{
return false;
}
pszString = v2 + 3;
if (v1 >= 0xF0u)
{
if ((*pszString & 0xC0) != 0x80 || ((v5 << 6) | *pszString & 0x3Fu) > 0x10FFFF)
{
return false;
}
pszString = v2 + 4;
}
else if ((v5 - 55296) <= 0x7FF)
{
return false;
}
}
else if (v1 < 0xC2u)
{
return false;
}
}
return false;
}
///////////////////////////////////////////////////////////////////////////////
// For checking if a string is a number.
bool StringIsDigit(const string& svInput)

View File

@ -39,6 +39,7 @@ string Base64Decode(const string& svInput);
string UTF8Encode(const wstring& wsvInput);
string UTF8Decode(const string& svInput);
size_t UTF8CharLength(const uint8_t cInput);
bool IsValidUTF8(char* pszString);
bool StringIsDigit(const string& svInput);
bool CompareStringAlphabetically(const string& svA, const string& svB);

View File

@ -72,18 +72,10 @@ public:
{
return this->type;
}
inline string GetAddress(void) const
inline void GetAddress(char* pchBuffer, uint32_t nBufferLen) const
{
// Select a static buffer
static char s[4][INET6_ADDRSTRLEN];
static int slot = 0;
int useSlot = (slot++) % 4;
// Render into it
inet_ntop(AF_INET6, &this->adr, s[useSlot], sizeof(s[0]));
// Pray the caller uses it before it gets clobbered
return s[useSlot];
assert(nBufferLen >= INET6_ADDRSTRLEN);
inet_ntop(AF_INET6, &this->adr, pchBuffer, nBufferLen);
}
inline uint16_t GetPort(void) const
{