Ban system refactor

Use CUtlVector, and remove every copy caused by passing vectors by value. CUtlVector does not support copying. Also removed all extraneous std::string copies caused by calling itoa instead of std::to_string, or std::stoll, etc. All features have been tested and work as designed.
This commit is contained in:
Kawe Mazidjatari 2023-08-31 00:16:25 +02:00
parent 2e03139c39
commit cb1a69e82a
10 changed files with 140 additions and 88 deletions

View File

@ -216,7 +216,7 @@ Host_ReloadBanList_f
*/
void Host_ReloadBanList_f(const CCommand& args)
{
g_pBanSystem->Load(); // Reload banned list.
g_pBanSystem->LoadList(); // Reload banned list.
}
/*

View File

@ -250,7 +250,7 @@ void CHostState::Setup(void)
{
g_pHostState->LoadConfig();
#ifndef CLIENT_DLL
g_pBanSystem->Load();
g_pBanSystem->LoadList();
#endif // !CLIENT_DLL
ConVar_PurgeHostNames();
@ -303,12 +303,12 @@ void CHostState::Think(void) const
}
if (sv_autoReloadRate->GetBool())
{
if (g_ServerGlobalVariables->m_flCurTime > sv_autoReloadRate->GetDouble())
if (g_ServerGlobalVariables->m_flCurTime > sv_autoReloadRate->GetFloat())
{
Cbuf_AddText(Cbuf_GetCurrentPlayer(), "reload\n", cmd_source_t::kCommandSrcCode);
}
}
if (statsTimer.GetDurationInProgress().GetSeconds() > sv_statusRefreshRate->GetDouble())
if (statsTimer.GetDurationInProgress().GetSeconds() > sv_statusRefreshRate->GetFloat())
{
SetConsoleTitleA(Format("%s - %d/%d Players (%s on %s) - %d%% Server CPU (%.3f msec on frame %d)",
hostname->GetString(), g_pServer->GetNumClients(),
@ -319,13 +319,13 @@ void CHostState::Think(void) const
statsTimer.Start();
}
if (sv_globalBanlist->GetBool() &&
banListTimer.GetDurationInProgress().GetSeconds() > sv_banlistRefreshRate->GetDouble())
banListTimer.GetDurationInProgress().GetSeconds() > sv_banlistRefreshRate->GetFloat())
{
SV_CheckForBan();
banListTimer.Start();
}
#ifdef DEDICATED
if (pylonTimer.GetDurationInProgress().GetSeconds() > sv_pylonRefreshRate->GetDouble())
if (pylonTimer.GetDurationInProgress().GetSeconds() > sv_pylonRefreshRate->GetFloat())
{
const NetGameServer_t netGameServer
{

View File

@ -19,7 +19,7 @@
// Purpose: checks if particular client is banned on the comp server
//-----------------------------------------------------------------------------
void SV_IsClientBanned(CClient* pClient, const string& svIPAddr,
const uint64_t nNucleusID, const string& svPersonaName, const int nPort)
const NucleusID_t nNucleusID, const string& svPersonaName, const int nPort)
{
Assert(pClient != nullptr);
@ -53,16 +53,22 @@ void SV_IsClientBanned(CClient* pClient, const string& svIPAddr,
//-----------------------------------------------------------------------------
// Purpose: checks if particular client is banned on the master server
//-----------------------------------------------------------------------------
void SV_ProcessBulkCheck(const BannedVec_t& bannedVec)
void SV_ProcessBulkCheck(const CBanSystem::BannedList_t* pBannedVec, const bool bDelete)
{
BannedVec_t outBannedVec;
g_pMasterServer->GetBannedList(bannedVec, outBannedVec);
CBanSystem::BannedList_t* outBannedVec = new CBanSystem::BannedList_t();
g_pMasterServer->GetBannedList(*pBannedVec, *outBannedVec);
// Caller wants to destroy the vector.
if (bDelete)
{
delete pBannedVec;
}
if (!ThreadInMainThread())
{
g_TaskScheduler->Dispatch([outBannedVec]
{
SV_CheckForBan(&outBannedVec);
SV_CheckForBan(outBannedVec, true);
}, 0);
}
}
@ -70,11 +76,12 @@ void SV_ProcessBulkCheck(const BannedVec_t& bannedVec)
//-----------------------------------------------------------------------------
// Purpose: creates a snapshot of the currently connected clients
// Input : *pBannedVec - if passed, will check for bans and kick the clients
// bDelete - if set, will delete the passed in vector
//-----------------------------------------------------------------------------
void SV_CheckForBan(const BannedVec_t* pBannedVec /*= nullptr*/)
void SV_CheckForBan(const CBanSystem::BannedList_t* pBannedVec /*= nullptr*/, const bool bDelete /*= false*/)
{
Assert(ThreadInMainThread());
BannedVec_t bannedVec;
CBanSystem::BannedList_t* bannedVec = new CBanSystem::BannedList_t;
for (int c = 0; c < g_ServerGlobalVariables->m_nMaxClients; c++) // Loop through all possible client instances.
{
@ -93,25 +100,27 @@ void SV_CheckForBan(const BannedVec_t* pBannedVec /*= nullptr*/)
continue;
const char* szIPAddr = pNetChan->GetAddress(true);
const uint64_t nNucleusID = pClient->GetNucleusID();
const NucleusID_t nNucleusID = pClient->GetNucleusID();
// If no banned list was provided, build one with all clients
// on the server. This will be used for bulk checking so live
// bans could be performed, as this function is called periodically.
if (!pBannedVec)
bannedVec.push_back(std::make_pair(string(szIPAddr), nNucleusID));
bannedVec->AddToTail(CBanSystem::Banned_t(szIPAddr, nNucleusID));
else
{
// Check if current client is within provided banned list, and
// prune if so...
for (auto& it : *pBannedVec)
FOR_EACH_VEC(*pBannedVec, i)
{
if (it.second == pClient->GetNucleusID())
const CBanSystem::Banned_t& banned = (*pBannedVec)[i];
if (banned.m_NucleusID == pClient->GetNucleusID())
{
const int nUserID = pClient->GetUserID();
const int nPort = pNetChan->GetPort();
pClient->Disconnect(Reputation_t::REP_MARK_BAD, "%s", it.first.c_str());
pClient->Disconnect(Reputation_t::REP_MARK_BAD, "%s", banned.m_Address.String());
Warning(eDLL_T::SERVER, "Removed client '[%s]:%i' from slot #%i ('%llu' is banned globally!)\n",
szIPAddr, nPort, nUserID, nNucleusID);
}
@ -119,9 +128,19 @@ void SV_CheckForBan(const BannedVec_t* pBannedVec /*= nullptr*/)
}
}
if (!pBannedVec && !bannedVec.empty())
// Caller wants to destroy the vector.
if (bDelete && pBannedVec)
{
std::thread(&SV_ProcessBulkCheck, bannedVec).detach();
delete pBannedVec;
}
if (!pBannedVec && !bannedVec->IsEmpty())
{
std::thread(&SV_ProcessBulkCheck, bannedVec, true).detach();
}
else
{
delete bannedVec;
}
}

View File

@ -40,8 +40,8 @@ void SV_InitGameDLL();
void SV_ShutdownGameDLL();
bool SV_ActivateServer();
void SV_BroadcastVoiceData(CClient* cl, int nBytes, char* data);
void SV_IsClientBanned(CClient* pClient, const string& svIPAddr, const uint64_t nNucleusID, const string& svPersonaName, const int nPort);
void SV_CheckForBan(const BannedVec_t* pBannedVec = nullptr);
void SV_IsClientBanned(CClient* pClient, const string& svIPAddr, const NucleusID_t nNucleusID, const string& svPersonaName, const int nPort);
void SV_CheckForBan(const CBanSystem::BannedList_t* pBannedVec = nullptr, const bool bDelete = false);
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

View File

@ -605,7 +605,7 @@ void CBrowser::HostPanel(void)
{
g_TaskScheduler->Dispatch([]()
{
g_pBanSystem->Load();
g_pBanSystem->LoadList();
}, 0);
}
}

View File

@ -15,10 +15,10 @@
//-----------------------------------------------------------------------------
// Purpose: loads and parses the banned list
//-----------------------------------------------------------------------------
void CBanSystem::Load(void)
void CBanSystem::LoadList(void)
{
if (IsBanListValid())
m_vBanList.clear();
m_BannedList.Purge();
FileHandle_t pFile = FileSystem()->Open("banlist.json", "rt");
if (!pFile)
@ -48,10 +48,11 @@ void CBanSystem::Load(void)
nlohmann::json jsEntry = jsIn[std::to_string(i)];
if (!jsEntry.is_null())
{
const string svIpAddress = jsEntry["ipAddress"].get<string>();
const uint64_t nNucleusID = jsEntry["nucleusId"].get<uint64_t>();
Banned_t banned;
banned.m_Address = jsEntry["ipAddress"].get<string>().c_str();
banned.m_NucleusID = jsEntry["nucleusId"].get<NucleusID_t>();
m_vBanList.push_back(std::make_pair(svIpAddress, nNucleusID));
m_BannedList.AddToTail(banned);
}
}
}
@ -64,7 +65,7 @@ void CBanSystem::Load(void)
//-----------------------------------------------------------------------------
// Purpose: saves the banned list
//-----------------------------------------------------------------------------
void CBanSystem::Save(void) const
void CBanSystem::SaveList(void) const
{
FileHandle_t pFile = FileSystem()->Open("banlist.json", "wt", "PLATFORM");
if (!pFile)
@ -76,13 +77,17 @@ void CBanSystem::Save(void) const
try
{
nlohmann::json jsOut;
for (size_t i = 0; i < m_vBanList.size(); i++)
FOR_EACH_VEC(m_BannedList, i)
{
jsOut[std::to_string(i)]["ipAddress"] = m_vBanList[i].first;
jsOut[std::to_string(i)]["nucleusId"] = m_vBanList[i].second;
const Banned_t& banned = m_BannedList[i];
char idx[64]; itoa(i, idx, 10);
jsOut[idx]["ipAddress"] = banned.m_Address.String();
jsOut[idx]["nucleusId"] = banned.m_NucleusID;
}
jsOut["totalBans"] = m_vBanList.size();
jsOut["totalBans"] = m_BannedList.Count();
string svJsOut = jsOut.dump(4);
FileSystem()->Write(svJsOut.data(), svJsOut.size(), pFile);
@ -100,24 +105,22 @@ void CBanSystem::Save(void) const
// Input : *ipAddress -
// nucleusId -
//-----------------------------------------------------------------------------
bool CBanSystem::AddEntry(const char* ipAddress, const uint64_t nucleusId)
bool CBanSystem::AddEntry(const char* ipAddress, const NucleusID_t nucleusId)
{
Assert(VALID_CHARSTAR(ipAddress));
const auto idPair = std::make_pair(string(ipAddress), nucleusId);
const Banned_t banned(ipAddress, nucleusId);
if (IsBanListValid())
{
auto it = std::find(m_vBanList.begin(), m_vBanList.end(), idPair);
if (it == m_vBanList.end())
if (m_BannedList.Find(banned) == m_BannedList.InvalidIndex())
{
m_vBanList.push_back(idPair);
m_BannedList.AddToTail(banned);
return true;
}
}
else
{
m_vBanList.push_back(idPair);
m_BannedList.AddToTail(banned);
return true;
}
@ -129,23 +132,22 @@ bool CBanSystem::AddEntry(const char* ipAddress, const uint64_t nucleusId)
// Input : *ipAddress -
// nucleusId -
//-----------------------------------------------------------------------------
bool CBanSystem::DeleteEntry(const char* ipAddress, const uint64_t nucleusId)
bool CBanSystem::DeleteEntry(const char* ipAddress, const NucleusID_t nucleusId)
{
Assert(VALID_CHARSTAR(ipAddress));
if (IsBanListValid())
{
auto it = std::find_if(m_vBanList.begin(), m_vBanList.end(),
[&](const pair<const string, const uint64_t>& element)
{
return (strcmp(ipAddress, element.first.c_str()) == NULL
|| element.second == nucleusId);
});
if (it != m_vBanList.end())
FOR_EACH_VEC(m_BannedList, i)
{
m_vBanList.erase(it);
return true;
const Banned_t& banned = m_BannedList[i];
if (banned.m_NucleusID == nucleusId ||
banned.m_Address.IsEqual_CaseInsensitive(ipAddress))
{
m_BannedList.Remove(i);
return true;
}
}
}
@ -158,21 +160,22 @@ bool CBanSystem::DeleteEntry(const char* ipAddress, const uint64_t nucleusId)
// nucleusId -
// Output : true if banned, false if not banned
//-----------------------------------------------------------------------------
bool CBanSystem::IsBanned(const char* ipAddress, const uint64_t nucleusId) const
bool CBanSystem::IsBanned(const char* ipAddress, const NucleusID_t nucleusId) const
{
for (size_t i = 0; i < m_vBanList.size(); i++)
{
const string& bannedIpAddress = m_vBanList[i].first;
const uint64_t bannedNucleusID = m_vBanList[i].second;
if (bannedIpAddress.empty()
|| !bannedNucleusID) // Cannot be null.
FOR_EACH_VEC(m_BannedList, i)
{
const Banned_t& banned = m_BannedList[i];
if (banned.m_NucleusID == NULL ||
banned.m_Address.IsEmpty())
{
// Cannot be NULL.
continue;
}
if (bannedIpAddress.compare(ipAddress) == NULL
|| nucleusId == bannedNucleusID)
if (banned.m_NucleusID == nucleusId ||
banned.m_Address.IsEqual_CaseInsensitive(ipAddress))
{
return true;
}
@ -186,7 +189,7 @@ bool CBanSystem::IsBanned(const char* ipAddress, const uint64_t nucleusId) const
//-----------------------------------------------------------------------------
bool CBanSystem::IsBanListValid(void) const
{
return !m_vBanList.empty();
return !m_BannedList.IsEmpty();
}
//-----------------------------------------------------------------------------
@ -252,7 +255,7 @@ void CBanSystem::UnbanPlayer(const char* criteria)
bool bSave = false;
if (StringIsDigit(criteria)) // Check if we have an ip address or nucleus id.
{
if (DeleteEntry("<<invalid>>", std::stoll(criteria))) // Delete ban entry.
if (DeleteEntry("-<[InVaLiD]>-", atoll(criteria))) // Delete ban entry.
{
bSave = true;
}
@ -267,7 +270,7 @@ void CBanSystem::UnbanPlayer(const char* criteria)
if (bSave)
{
Save(); // Save modified vector to file.
SaveList(); // Save modified vector to file.
Msg(eDLL_T::SERVER, "Removed '%s' from banned list\n", criteria);
}
}
@ -282,7 +285,7 @@ void CBanSystem::UnbanPlayer(const char* criteria)
// Purpose: authors player by given name
// Input : *playerName -
// shouldBan - (only kicks if false)
// *reason -
// *reason -
//-----------------------------------------------------------------------------
void CBanSystem::AuthorPlayerByName(const char* playerName, const bool shouldBan, const char* reason)
{
@ -318,7 +321,7 @@ void CBanSystem::AuthorPlayerByName(const char* playerName, const bool shouldBan
if (bSave)
{
Save();
SaveList();
Msg(eDLL_T::SERVER, "Added '%s' to banned list\n", playerName);
}
else if (bDisconnect)
@ -331,7 +334,7 @@ void CBanSystem::AuthorPlayerByName(const char* playerName, const bool shouldBan
// Purpose: authors player by given nucleus id or ip address
// Input : *playerHandle -
// shouldBan - (only kicks if false)
// *reason -
// *reason -
//-----------------------------------------------------------------------------
void CBanSystem::AuthorPlayerById(const char* playerHandle, const bool shouldBan, const char* reason)
{
@ -358,16 +361,20 @@ void CBanSystem::AuthorPlayerById(const char* playerHandle, const bool shouldBan
if (bOnlyDigits)
{
uint64_t nTargetID = static_cast<uint64_t>(std::stoll(playerHandle));
if (nTargetID > static_cast<uint64_t>(MAX_PLAYERS)) // Is it a possible nucleusID?
char* pEnd = nullptr;
uint64_t nTargetID = strtoull(playerHandle, &pEnd, 10);
if (nTargetID > MAX_PLAYERS) // Is it a possible nucleusID?
{
uint64_t nNucleusID = pClient->GetNucleusID();
NucleusID_t nNucleusID = pClient->GetNucleusID();
if (nNucleusID != nTargetID)
continue;
}
else // If its not try by handle.
{
uint64_t nClientID = static_cast<uint64_t>(pClient->GetHandle());
if (nClientID != nTargetID)
continue;
}
@ -393,7 +400,7 @@ void CBanSystem::AuthorPlayerById(const char* playerHandle, const bool shouldBan
if (bSave)
{
Save();
SaveList();
Msg(eDLL_T::SERVER, "Added '%s' to banned list\n", playerHandle);
}
else if (bDisconnect)

View File

@ -1,6 +1,5 @@
#pragma once
typedef vector<std::pair<string, uint64_t>> BannedVec_t;
#include "ebisusdk/EbisuTypes.h"
enum EKickType
{
@ -13,13 +12,33 @@ enum EKickType
class CBanSystem
{
public:
void Load(void);
void Save(void) const;
struct Banned_t
{
Banned_t(const char* ipAddress = "", NucleusID_t nucleusId = NULL)
: m_Address(ipAddress)
, m_NucleusID(nucleusId)
{}
bool AddEntry(const char* ipAddress, const uint64_t nucleusId);
bool DeleteEntry(const char* ipAddress, const uint64_t nucleusId);
inline bool operator==(const Banned_t& other) const
{
return m_NucleusID == other.m_NucleusID
&& m_Address.IsEqual_CaseInsensitive(other.m_Address);
}
bool IsBanned(const char* ipAddress, const uint64_t nucleusId) const;
NucleusID_t m_NucleusID;
CUtlString m_Address;
};
typedef CUtlVector<Banned_t> BannedList_t;
public:
void LoadList(void);
void SaveList(void) const;
bool AddEntry(const char* ipAddress, const NucleusID_t nucleusId);
bool DeleteEntry(const char* ipAddress, const NucleusID_t nucleusId);
bool IsBanned(const char* ipAddress, const NucleusID_t nucleusId) const;
bool IsBanListValid(void) const;
void KickPlayerByName(const char* playerName, const char* reason = nullptr);
@ -34,7 +53,7 @@ private:
void AuthorPlayerByName(const char* playerName, const bool bBan, const char* reason = nullptr);
void AuthorPlayerById(const char* playerHandle, const bool bBan, const char* reason = nullptr);
BannedVec_t m_vBanList;
BannedList_t m_BannedList;
};
extern CBanSystem* g_pBanSystem;

View File

@ -193,15 +193,19 @@ bool CPylon::PostServerHost(string& outMessage, string& outToken,
// &outBannedVec -
// Output : True on success, false otherwise.
//-----------------------------------------------------------------------------
bool CPylon::GetBannedList(const BannedVec_t& inBannedVec, BannedVec_t& outBannedVec) const
bool CPylon::GetBannedList(const CBanSystem::BannedList_t& inBannedVec,
CBanSystem::BannedList_t& outBannedVec) const
{
nlohmann::json arrayJson = nlohmann::json::array();
for (const auto& bannedPair : inBannedVec)
FOR_EACH_VEC(inBannedVec, i)
{
const CBanSystem::Banned_t& banned = inBannedVec[i];
nlohmann::json player;
player["id"] = bannedPair.second;
player["ip"] = bannedPair.first;
player["id"] = banned.m_NucleusID;
player["ip"] = banned.m_Address;
arrayJson.push_back(player);
}
@ -227,12 +231,11 @@ bool CPylon::GetBannedList(const BannedVec_t& inBannedVec, BannedVec_t& outBanne
{
for (auto& obj : arrayJson["bannedPlayers"])
{
outBannedVec.push_back(
std::make_pair(
obj.value("reason", "#DISCONNECT_BANNED"),
obj.value("id", uint64_t(0))
)
CBanSystem::Banned_t banned(
obj.value("reason", "#DISCONNECT_BANNED").c_str(),
obj.value("id", NucleusID_t(NULL))
);
outBannedVec.AddToTail(banned);
}
return true;

View File

@ -10,7 +10,7 @@ public:
bool GetServerByToken(NetGameServer_t& slOutServer, string& outMessage, const string& svToken) const;
bool PostServerHost(string& outMessage, string& svOutToken, const NetGameServer_t& netGameServer) const;
bool GetBannedList(const BannedVec_t& inBannedVec, BannedVec_t& outBannedVec) const;
bool GetBannedList(const CBanSystem::BannedList_t& inBannedVec, CBanSystem::BannedList_t& outBannedVec) const;
bool CheckForBan(const string& ipAddress, const uint64_t nucleusId, const string& personaName, string& outReason) const;
void ExtractError(const nlohmann::json& resultBody, string& outMessage, CURLINFO status, const char* errorText = nullptr) const;

View File

@ -1,4 +1,8 @@
#pragma once
//-----------------------------------------------------------------------------
// Types
//-----------------------------------------------------------------------------
typedef unsigned __int64 NucleusID_t;
//-----------------------------------------------------------------------------
// General errors