Engine: improve CClient::Authenticate performance

- Removed 2 std::string copy constructions
- Removed 32 sprintf calls per token auth request.
- Fixed a bug where we format the NucleusID as s64 instead of u64.
- Added additional hardening for when token/sessionId stitching fails, this will now always reject the connection.
- Improved the macro to make sure we always free the JWT claims if it has been allocated.
This commit is contained in:
Kawe Mazidjatari 2024-02-22 00:59:00 +01:00
parent bc5e14643c
commit 6afb5fe593
4 changed files with 111 additions and 48 deletions

View File

@ -246,6 +246,11 @@ target_compile_definitions( ${PROJECT_NAME} PRIVATE
)
endif()
target_include_directories( ${PROJECT_NAME} PRIVATE
"${THIRDPARTY_SOURCE_DIR}/recast/"
"${THIRDPARTY_SOURCE_DIR}/mbedtls/include"
)
endmacro()
add_engine_project( "engine" )

View File

@ -11,11 +11,11 @@
#include "core/stdafx.h"
#include "tier1/cvar.h"
#include "tier1/strtools.h"
#include "mathlib/sha256.h"
#include "engine/server/server.h"
#include "engine/client/client.h"
#ifndef CLIENT_DLL
#include "jwt/include/decode.h"
#include "mbedtls/include/mbedtls/sha256.h"
#endif
// Absolute max string cmd length, any character past this will be NULLED.
@ -78,34 +78,47 @@ bool CClient::Authenticate(const char* const playerName, char* const reasonBuf,
if (IsFakeClient() || GetNetChan()->GetRemoteAddress().IsLoopback())
return true;
#define FORMAT_ERROR_REASON(fmt, ...) V_snprintf(reasonBuf, reasonBufLen, fmt, ##__VA_ARGS__);
l8w8jwt_claim* claims = nullptr;
size_t numClaims = 0;
// formats the error reason, and frees the claims and returns
#define ERROR_AND_RETURN(fmt, ...) \
do {\
V_snprintf(reasonBuf, reasonBufLen, fmt, ##__VA_ARGS__); \
if (claims) {\
l8w8jwt_free_claims(claims, numClaims); \
}\
return false; \
} while(0)\
KeyValues* const cl_onlineAuthTokenKv = this->m_ConVars->FindKey("cl_onlineAuthToken");
KeyValues* const cl_onlineAuthTokenSignature1Kv = this->m_ConVars->FindKey("cl_onlineAuthTokenSignature1");
KeyValues* const cl_onlineAuthTokenSignature2Kv = this->m_ConVars->FindKey("cl_onlineAuthTokenSignature2");
if (!cl_onlineAuthTokenKv || !cl_onlineAuthTokenSignature1Kv)
{
FORMAT_ERROR_REASON("Missing token");
return false;
}
ERROR_AND_RETURN("Missing token");
const char* const onlineAuthToken = cl_onlineAuthTokenKv->GetString();
const char* const onlineAuthTokenSignature1 = cl_onlineAuthTokenSignature1Kv->GetString();
const char* const onlineAuthTokenSignature2 = cl_onlineAuthTokenSignature2Kv->GetString();
const std::string fullToken = Format("%s.%s%s", onlineAuthToken, onlineAuthTokenSignature1, onlineAuthTokenSignature2);
char fullToken[1024]; // enough buffer for 3x255, which is cvar count * userinfo str limit.
const int tokenLen = snprintf(fullToken, sizeof(fullToken), "%s.%s%s",
onlineAuthToken, onlineAuthTokenSignature1, onlineAuthTokenSignature2);
if (tokenLen < 0)
ERROR_AND_RETURN("Token stitching failed");
struct l8w8jwt_decoding_params params;
l8w8jwt_decoding_params_init(&params);
params.alg = L8W8JWT_ALG_RS256;
params.jwt = (char*)fullToken.c_str();
params.jwt_length = fullToken.length();
params.jwt = (char*)fullToken;
params.jwt_length = tokenLen;
params.verification_key = (unsigned char*)JWT_PUBLIC_KEY;
params.verification_key_length = strlen(JWT_PUBLIC_KEY);
params.verification_key_length = sizeof(JWT_PUBLIC_KEY);
params.validate_exp = sv_onlineAuthValidateExpiry->GetBool();
params.exp_tolerance_seconds = (uint8_t)sv_onlineAuthExpiryTolerance->GetInt();
@ -113,29 +126,18 @@ bool CClient::Authenticate(const char* const playerName, char* const reasonBuf,
params.validate_iat = sv_onlineAuthValidateIssuedAt->GetBool();
params.iat_tolerance_seconds = (uint8_t)sv_onlineAuthIssuedAtTolerance->GetInt();
l8w8jwt_claim* claims = nullptr;
size_t numClaims = 0;
enum l8w8jwt_validation_result validation_result;
const int r = l8w8jwt_decode(&params, &validation_result, &claims, &numClaims);
if (r != L8W8JWT_SUCCESS)
{
FORMAT_ERROR_REASON("Code %i", r);
l8w8jwt_free_claims(claims, numClaims);
return false;
}
ERROR_AND_RETURN("Code %i", r);
if (validation_result != L8W8JWT_VALID)
{
char reasonBuffer[256];
l8w8jwt_get_validation_result_desc(validation_result, reasonBuffer, sizeof(reasonBuffer));
FORMAT_ERROR_REASON("%s", reasonBuffer);
l8w8jwt_free_claims(claims, numClaims);
return false;
ERROR_AND_RETURN("%s", reasonBuffer);
}
bool foundSessionId = false;
@ -146,39 +148,34 @@ bool CClient::Authenticate(const char* const playerName, char* const reasonBuf,
{
const char* const sessionId = claims[i].value;
const std::string newId = Format(
"%lld-%s-%s",
this->m_DataBlock.userData,
char newId[256];
const int idLen = snprintf(newId, sizeof(newId), "%llu-%s-%s",
(NucleusID_t)this->m_DataBlock.userData,
playerName,
g_MasterServer.GetHostIP().c_str()
);
g_MasterServer.GetHostIP().c_str());
DevMsg(eDLL_T::SERVER, "%s: newId=%s\n", __FUNCTION__, newId.c_str());
const std::string hashedNewId = sha256(newId);
if (idLen < 0)
ERROR_AND_RETURN("Session ID stitching failed");
if (hashedNewId.compare(sessionId) != 0)
{
FORMAT_ERROR_REASON("Token is not authorized for the connecting client");
l8w8jwt_free_claims(claims, numClaims);
uint8_t sessionHash[32]; // hash decoded from JWT token
V_hextobinary(sessionId, strlen(sessionId), sessionHash, sizeof(sessionHash));
return false;
}
uint8_t oobHash[32]; // hash of data collected from out of band packet
const int shRet = mbedtls_sha256((const uint8_t*)newId, idLen, oobHash, NULL);
if (memcmp(oobHash, sessionHash, sizeof(sessionHash)) != 0)
ERROR_AND_RETURN("Token is not authorized for the connecting client");
foundSessionId = true;
}
}
if (!foundSessionId)
{
FORMAT_ERROR_REASON("No session ID");
l8w8jwt_free_claims(claims, numClaims);
return false;
}
ERROR_AND_RETURN("No session ID");
l8w8jwt_free_claims(claims, numClaims);
#undef REJECT_CONNECTION
#undef ERROR_AND_RETURN
#endif // !CLIENT_DLL
return true;
@ -220,8 +217,8 @@ bool CClient::Connect(const char* szName, CNetChan* pNetChan, bool bFakePlayer,
{
const char* const netAdr = pNetChan ? pNetChan->GetAddress() : "<unknown>";
Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' failed online authentication!)\n",
netAdr, m_nNucleusID);
Warning(eDLL_T::SERVER, "Client '%s' ('%llu') failed online authentication! [%s]\n",
netAdr, (NucleusID_t)m_DataBlock.userData, authFailReason);
}
return false;
@ -477,7 +474,7 @@ bool CClient::VProcessSetConVar(CClient* pClient, NET_SetConVar* pMsg)
bool bFunky = false;
for (const char* s = name; *s != '\0'; ++s)
{
if (!isalnum(*s) && *s != '_')
if (!V_isalnum(*s) && *s != '_')
{
bFunky = true;
break;

View File

@ -83,6 +83,7 @@ inline bool V_iswdigit(int c)
return (((uint)(c - '0')) < 10);
}
void V_hextobinary(char const* in, size_t numchars, byte* out, size_t maxoutputbytes);
void V_binarytohex(const byte* in, size_t inputbytes, char* out, size_t outsize);
ssize_t V_vsnprintfRet(char* pDest, size_t maxLen, const char* pFormat, va_list params, bool* pbTruncated);

View File

@ -257,6 +257,66 @@ bool V_IsAllDigit(const char* pString)
return true;
}
//-----------------------------------------------------------------------------
// Purpose: Returns the 4 bit nibble for a hex character
// Input : c -
// Output : unsigned char
//-----------------------------------------------------------------------------
static unsigned char V_nibble(char c)
{
if ((c >= '0') &&
(c <= '9'))
{
return (unsigned char)(c - '0');
}
if ((c >= 'A') &&
(c <= 'F'))
{
return (unsigned char)(c - 'A' + 0x0a);
}
if ((c >= 'a') &&
(c <= 'f'))
{
return (unsigned char)(c - 'a' + 0x0a);
}
return '0';
}
//-----------------------------------------------------------------------------
// Purpose:
// Input : *in -
// numchars -
// *out -
// maxoutputbytes -
//-----------------------------------------------------------------------------
void V_hextobinary(char const* in, size_t numchars, byte* out, size_t maxoutputbytes)
{
size_t len = V_strlen(in);
numchars = Min(len, numchars);
// Make sure it's even
numchars = (numchars) & ~0x1;
// Must be an even # of input characters (two chars per output byte)
Assert(numchars >= 2);
memset(out, 0x00, maxoutputbytes);
byte* p;
size_t i;
p = out;
for (i = 0;
(i < numchars) &&
((size_t)(p - out) < maxoutputbytes);
i += 2, p++)
{
*p = (V_nibble(in[i]) << 4) | V_nibble(in[i + 1]);
}
}
//-----------------------------------------------------------------------------
// Purpose:
// Input : *in -
@ -1163,7 +1223,7 @@ void V_StripExtension(const char* in, char* out, size_t outSize)
if (end > 0 && !PATHSEPARATOR(in[end]) && end < outSize)
{
size_t nChars = MIN(end, outSize - 1);
size_t nChars = Min(end, outSize - 1);
if (out != in)
{
memcpy(out, in, nChars);
@ -1286,7 +1346,7 @@ void V_FileBase(const char* in, char* out, size_t maxlen)
// Length of new sting
len = end - start + 1;
size_t maxcopy = MIN(len + 1, maxlen);
size_t maxcopy = Min(len + 1, maxlen);
// Copy partial string
V_strncpy(out, &in[start], maxcopy);