mirror of
https://github.com/Mauler125/r5sdk.git
synced 2025-02-09 19:15:03 +01:00
Engine: improve CClient::Authenticate performance
- Removed 2 std::string copy constructions - Removed 32 sprintf calls per token auth request. - Fixed a bug where we format the NucleusID as s64 instead of u64. - Added additional hardening for when token/sessionId stitching fails, this will now always reject the connection. - Improved the macro to make sure we always free the JWT claims if it has been allocated.
This commit is contained in:
parent
bc5e14643c
commit
6afb5fe593
@ -246,6 +246,11 @@ target_compile_definitions( ${PROJECT_NAME} PRIVATE
|
||||
)
|
||||
endif()
|
||||
|
||||
target_include_directories( ${PROJECT_NAME} PRIVATE
|
||||
"${THIRDPARTY_SOURCE_DIR}/recast/"
|
||||
"${THIRDPARTY_SOURCE_DIR}/mbedtls/include"
|
||||
)
|
||||
|
||||
endmacro()
|
||||
|
||||
add_engine_project( "engine" )
|
||||
|
@ -11,11 +11,11 @@
|
||||
#include "core/stdafx.h"
|
||||
#include "tier1/cvar.h"
|
||||
#include "tier1/strtools.h"
|
||||
#include "mathlib/sha256.h"
|
||||
#include "engine/server/server.h"
|
||||
#include "engine/client/client.h"
|
||||
#ifndef CLIENT_DLL
|
||||
#include "jwt/include/decode.h"
|
||||
#include "mbedtls/include/mbedtls/sha256.h"
|
||||
#endif
|
||||
|
||||
// Absolute max string cmd length, any character past this will be NULLED.
|
||||
@ -78,34 +78,47 @@ bool CClient::Authenticate(const char* const playerName, char* const reasonBuf,
|
||||
if (IsFakeClient() || GetNetChan()->GetRemoteAddress().IsLoopback())
|
||||
return true;
|
||||
|
||||
#define FORMAT_ERROR_REASON(fmt, ...) V_snprintf(reasonBuf, reasonBufLen, fmt, ##__VA_ARGS__);
|
||||
l8w8jwt_claim* claims = nullptr;
|
||||
size_t numClaims = 0;
|
||||
|
||||
// formats the error reason, and frees the claims and returns
|
||||
#define ERROR_AND_RETURN(fmt, ...) \
|
||||
do {\
|
||||
V_snprintf(reasonBuf, reasonBufLen, fmt, ##__VA_ARGS__); \
|
||||
if (claims) {\
|
||||
l8w8jwt_free_claims(claims, numClaims); \
|
||||
}\
|
||||
return false; \
|
||||
} while(0)\
|
||||
|
||||
KeyValues* const cl_onlineAuthTokenKv = this->m_ConVars->FindKey("cl_onlineAuthToken");
|
||||
KeyValues* const cl_onlineAuthTokenSignature1Kv = this->m_ConVars->FindKey("cl_onlineAuthTokenSignature1");
|
||||
KeyValues* const cl_onlineAuthTokenSignature2Kv = this->m_ConVars->FindKey("cl_onlineAuthTokenSignature2");
|
||||
|
||||
if (!cl_onlineAuthTokenKv || !cl_onlineAuthTokenSignature1Kv)
|
||||
{
|
||||
FORMAT_ERROR_REASON("Missing token");
|
||||
return false;
|
||||
}
|
||||
ERROR_AND_RETURN("Missing token");
|
||||
|
||||
const char* const onlineAuthToken = cl_onlineAuthTokenKv->GetString();
|
||||
const char* const onlineAuthTokenSignature1 = cl_onlineAuthTokenSignature1Kv->GetString();
|
||||
const char* const onlineAuthTokenSignature2 = cl_onlineAuthTokenSignature2Kv->GetString();
|
||||
|
||||
const std::string fullToken = Format("%s.%s%s", onlineAuthToken, onlineAuthTokenSignature1, onlineAuthTokenSignature2);
|
||||
char fullToken[1024]; // enough buffer for 3x255, which is cvar count * userinfo str limit.
|
||||
const int tokenLen = snprintf(fullToken, sizeof(fullToken), "%s.%s%s",
|
||||
onlineAuthToken, onlineAuthTokenSignature1, onlineAuthTokenSignature2);
|
||||
|
||||
if (tokenLen < 0)
|
||||
ERROR_AND_RETURN("Token stitching failed");
|
||||
|
||||
struct l8w8jwt_decoding_params params;
|
||||
l8w8jwt_decoding_params_init(¶ms);
|
||||
|
||||
params.alg = L8W8JWT_ALG_RS256;
|
||||
|
||||
params.jwt = (char*)fullToken.c_str();
|
||||
params.jwt_length = fullToken.length();
|
||||
params.jwt = (char*)fullToken;
|
||||
params.jwt_length = tokenLen;
|
||||
|
||||
params.verification_key = (unsigned char*)JWT_PUBLIC_KEY;
|
||||
params.verification_key_length = strlen(JWT_PUBLIC_KEY);
|
||||
params.verification_key_length = sizeof(JWT_PUBLIC_KEY);
|
||||
|
||||
params.validate_exp = sv_onlineAuthValidateExpiry->GetBool();
|
||||
params.exp_tolerance_seconds = (uint8_t)sv_onlineAuthExpiryTolerance->GetInt();
|
||||
@ -113,29 +126,18 @@ bool CClient::Authenticate(const char* const playerName, char* const reasonBuf,
|
||||
params.validate_iat = sv_onlineAuthValidateIssuedAt->GetBool();
|
||||
params.iat_tolerance_seconds = (uint8_t)sv_onlineAuthIssuedAtTolerance->GetInt();
|
||||
|
||||
l8w8jwt_claim* claims = nullptr;
|
||||
size_t numClaims = 0;
|
||||
|
||||
enum l8w8jwt_validation_result validation_result;
|
||||
const int r = l8w8jwt_decode(¶ms, &validation_result, &claims, &numClaims);
|
||||
|
||||
if (r != L8W8JWT_SUCCESS)
|
||||
{
|
||||
FORMAT_ERROR_REASON("Code %i", r);
|
||||
l8w8jwt_free_claims(claims, numClaims);
|
||||
|
||||
return false;
|
||||
}
|
||||
ERROR_AND_RETURN("Code %i", r);
|
||||
|
||||
if (validation_result != L8W8JWT_VALID)
|
||||
{
|
||||
char reasonBuffer[256];
|
||||
l8w8jwt_get_validation_result_desc(validation_result, reasonBuffer, sizeof(reasonBuffer));
|
||||
|
||||
FORMAT_ERROR_REASON("%s", reasonBuffer);
|
||||
l8w8jwt_free_claims(claims, numClaims);
|
||||
|
||||
return false;
|
||||
ERROR_AND_RETURN("%s", reasonBuffer);
|
||||
}
|
||||
|
||||
bool foundSessionId = false;
|
||||
@ -146,39 +148,34 @@ bool CClient::Authenticate(const char* const playerName, char* const reasonBuf,
|
||||
{
|
||||
const char* const sessionId = claims[i].value;
|
||||
|
||||
const std::string newId = Format(
|
||||
"%lld-%s-%s",
|
||||
this->m_DataBlock.userData,
|
||||
char newId[256];
|
||||
const int idLen = snprintf(newId, sizeof(newId), "%llu-%s-%s",
|
||||
(NucleusID_t)this->m_DataBlock.userData,
|
||||
playerName,
|
||||
g_MasterServer.GetHostIP().c_str()
|
||||
);
|
||||
g_MasterServer.GetHostIP().c_str());
|
||||
|
||||
DevMsg(eDLL_T::SERVER, "%s: newId=%s\n", __FUNCTION__, newId.c_str());
|
||||
const std::string hashedNewId = sha256(newId);
|
||||
if (idLen < 0)
|
||||
ERROR_AND_RETURN("Session ID stitching failed");
|
||||
|
||||
if (hashedNewId.compare(sessionId) != 0)
|
||||
{
|
||||
FORMAT_ERROR_REASON("Token is not authorized for the connecting client");
|
||||
l8w8jwt_free_claims(claims, numClaims);
|
||||
uint8_t sessionHash[32]; // hash decoded from JWT token
|
||||
V_hextobinary(sessionId, strlen(sessionId), sessionHash, sizeof(sessionHash));
|
||||
|
||||
return false;
|
||||
}
|
||||
uint8_t oobHash[32]; // hash of data collected from out of band packet
|
||||
const int shRet = mbedtls_sha256((const uint8_t*)newId, idLen, oobHash, NULL);
|
||||
|
||||
if (memcmp(oobHash, sessionHash, sizeof(sessionHash)) != 0)
|
||||
ERROR_AND_RETURN("Token is not authorized for the connecting client");
|
||||
|
||||
foundSessionId = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundSessionId)
|
||||
{
|
||||
FORMAT_ERROR_REASON("No session ID");
|
||||
l8w8jwt_free_claims(claims, numClaims);
|
||||
|
||||
return false;
|
||||
}
|
||||
ERROR_AND_RETURN("No session ID");
|
||||
|
||||
l8w8jwt_free_claims(claims, numClaims);
|
||||
|
||||
#undef REJECT_CONNECTION
|
||||
#undef ERROR_AND_RETURN
|
||||
#endif // !CLIENT_DLL
|
||||
|
||||
return true;
|
||||
@ -220,8 +217,8 @@ bool CClient::Connect(const char* szName, CNetChan* pNetChan, bool bFakePlayer,
|
||||
{
|
||||
const char* const netAdr = pNetChan ? pNetChan->GetAddress() : "<unknown>";
|
||||
|
||||
Warning(eDLL_T::SERVER, "Connection rejected for '%s' ('%llu' failed online authentication!)\n",
|
||||
netAdr, m_nNucleusID);
|
||||
Warning(eDLL_T::SERVER, "Client '%s' ('%llu') failed online authentication! [%s]\n",
|
||||
netAdr, (NucleusID_t)m_DataBlock.userData, authFailReason);
|
||||
}
|
||||
|
||||
return false;
|
||||
@ -477,7 +474,7 @@ bool CClient::VProcessSetConVar(CClient* pClient, NET_SetConVar* pMsg)
|
||||
bool bFunky = false;
|
||||
for (const char* s = name; *s != '\0'; ++s)
|
||||
{
|
||||
if (!isalnum(*s) && *s != '_')
|
||||
if (!V_isalnum(*s) && *s != '_')
|
||||
{
|
||||
bFunky = true;
|
||||
break;
|
||||
|
@ -83,6 +83,7 @@ inline bool V_iswdigit(int c)
|
||||
return (((uint)(c - '0')) < 10);
|
||||
}
|
||||
|
||||
void V_hextobinary(char const* in, size_t numchars, byte* out, size_t maxoutputbytes);
|
||||
void V_binarytohex(const byte* in, size_t inputbytes, char* out, size_t outsize);
|
||||
ssize_t V_vsnprintfRet(char* pDest, size_t maxLen, const char* pFormat, va_list params, bool* pbTruncated);
|
||||
|
||||
|
@ -257,6 +257,66 @@ bool V_IsAllDigit(const char* pString)
|
||||
return true;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Purpose: Returns the 4 bit nibble for a hex character
|
||||
// Input : c -
|
||||
// Output : unsigned char
|
||||
//-----------------------------------------------------------------------------
|
||||
static unsigned char V_nibble(char c)
|
||||
{
|
||||
if ((c >= '0') &&
|
||||
(c <= '9'))
|
||||
{
|
||||
return (unsigned char)(c - '0');
|
||||
}
|
||||
|
||||
if ((c >= 'A') &&
|
||||
(c <= 'F'))
|
||||
{
|
||||
return (unsigned char)(c - 'A' + 0x0a);
|
||||
}
|
||||
|
||||
if ((c >= 'a') &&
|
||||
(c <= 'f'))
|
||||
{
|
||||
return (unsigned char)(c - 'a' + 0x0a);
|
||||
}
|
||||
|
||||
return '0';
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Purpose:
|
||||
// Input : *in -
|
||||
// numchars -
|
||||
// *out -
|
||||
// maxoutputbytes -
|
||||
//-----------------------------------------------------------------------------
|
||||
void V_hextobinary(char const* in, size_t numchars, byte* out, size_t maxoutputbytes)
|
||||
{
|
||||
size_t len = V_strlen(in);
|
||||
numchars = Min(len, numchars);
|
||||
// Make sure it's even
|
||||
numchars = (numchars) & ~0x1;
|
||||
|
||||
// Must be an even # of input characters (two chars per output byte)
|
||||
Assert(numchars >= 2);
|
||||
|
||||
memset(out, 0x00, maxoutputbytes);
|
||||
|
||||
byte* p;
|
||||
size_t i;
|
||||
|
||||
p = out;
|
||||
for (i = 0;
|
||||
(i < numchars) &&
|
||||
((size_t)(p - out) < maxoutputbytes);
|
||||
i += 2, p++)
|
||||
{
|
||||
*p = (V_nibble(in[i]) << 4) | V_nibble(in[i + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Purpose:
|
||||
// Input : *in -
|
||||
@ -1163,7 +1223,7 @@ void V_StripExtension(const char* in, char* out, size_t outSize)
|
||||
|
||||
if (end > 0 && !PATHSEPARATOR(in[end]) && end < outSize)
|
||||
{
|
||||
size_t nChars = MIN(end, outSize - 1);
|
||||
size_t nChars = Min(end, outSize - 1);
|
||||
if (out != in)
|
||||
{
|
||||
memcpy(out, in, nChars);
|
||||
@ -1286,7 +1346,7 @@ void V_FileBase(const char* in, char* out, size_t maxlen)
|
||||
// Length of new sting
|
||||
len = end - start + 1;
|
||||
|
||||
size_t maxcopy = MIN(len + 1, maxlen);
|
||||
size_t maxcopy = Min(len + 1, maxlen);
|
||||
|
||||
// Copy partial string
|
||||
V_strncpy(out, &in[start], maxcopy);
|
||||
|
Loading…
x
Reference in New Issue
Block a user