diff --git a/r5dev/engine/client/cl_rcon.cpp b/r5dev/engine/client/cl_rcon.cpp index 795343f1..b04a2c7f 100644 --- a/r5dev/engine/client/cl_rcon.cpp +++ b/r5dev/engine/client/cl_rcon.cpp @@ -44,7 +44,7 @@ void CRConClient::Init(void) //----------------------------------------------------------------------------- void CRConClient::Shutdown(void) { - Disconnect(); + Disconnect("shutdown"); } //----------------------------------------------------------------------------- @@ -54,8 +54,13 @@ void CRConClient::RunFrame(void) { if (IsInitialized() && IsConnected()) { - CConnectedNetConsoleData* pData = m_Socket.GetAcceptedSocketData(0); - Recv(pData); + CConnectedNetConsoleData* pData = GetData(); + Assert(pData != nullptr); + + if (pData) + { + Recv(pData); + } } } @@ -66,11 +71,12 @@ void CRConClient::Disconnect(const char* szReason) { if (IsConnected()) { - if (szReason) + if (!szReason) { - DevMsg(eDLL_T::CLIENT, "%s", szReason); + szReason = "unknown reason"; } + DevMsg(eDLL_T::CLIENT, "Disconnect: (%s)\n", szReason); m_Socket.CloseAcceptedSocket(0); } } @@ -80,7 +86,7 @@ void CRConClient::Disconnect(const char* szReason) // Input : *pMsgBug - // nMsgLen - //----------------------------------------------------------------------------- -bool CRConClient::ProcessMessage(const char* pMsgBuf, int nMsgLen) +bool CRConClient::ProcessMessage(const char* pMsgBuf, const int nMsgLen) { sv_rcon::response response; bool bSuccess = Decode(&response, pMsgBuf, nMsgLen); @@ -99,26 +105,26 @@ bool CRConClient::ProcessMessage(const char* pMsgBuf, int nMsgLen) { const long i = strtol(response.responseval().c_str(), NULL, NULL); const bool bLocalHost = (g_pNetAdr->ComparePort(m_Address) && g_pNetAdr->CompareAdr(m_Address)); + const char* szEnable = nullptr; const SocketHandle_t hSocket = GetSocket(); if (!i) // sv_rcon_sendlogs is not set. { if (!bLocalHost && cl_rcon_request_sendlogs->GetBool()) { - vector vecMsg; - bool ret = Serialize(vecMsg, "", "1", cl_rcon::request_t::SERVERDATA_REQUEST_SEND_CONSOLE_LOG); - - if (ret && !Send(hSocket, vecMsg.data(), int(vecMsg.size()))) - { - Error(eDLL_T::CLIENT, NO_ERROR, "Failed to send RCON message: (%s)\n", "SOCKET_ERROR"); - } + szEnable = "1"; } } else if (bLocalHost) { // Don't send logs to local host, it already gets logged to the same console. + szEnable = "0"; + } + + if (szEnable) + { vector vecMsg; - bool ret = Serialize(vecMsg, "", "0", cl_rcon::request_t::SERVERDATA_REQUEST_SEND_CONSOLE_LOG); + bool ret = Serialize(vecMsg, "", szEnable, cl_rcon::request_t::SERVERDATA_REQUEST_SEND_CONSOLE_LOG); if (ret && !Send(hSocket, vecMsg.data(), int(vecMsg.size()))) { @@ -159,6 +165,15 @@ bool CRConClient::Serialize(vector& vecBuf, const char* szReqBuf, return CL_NetConSerialize(this, vecBuf, szReqBuf, szReqVal, requestType); } +//----------------------------------------------------------------------------- +// Purpose: retrieves the remote socket +// Output : SOCKET_ERROR (-1) on failure +//----------------------------------------------------------------------------- +CConnectedNetConsoleData* CRConClient::GetData(void) +{ + return SH_GetNetConData(this, 0); +} + //----------------------------------------------------------------------------- // Purpose: retrieves the remote socket // Output : SOCKET_ERROR (-1) on failure @@ -168,7 +183,6 @@ SocketHandle_t CRConClient::GetSocket(void) return SH_GetNetConSocketHandle(this, 0); } - //----------------------------------------------------------------------------- // Purpose: checks if client rcon is initialized //----------------------------------------------------------------------------- @@ -190,4 +204,4 @@ CRConClient g_RCONClient; CRConClient* RCONClient() // Singleton RCON Client. { return &g_RCONClient; -} \ No newline at end of file +} diff --git a/r5dev/engine/client/cl_rcon.h b/r5dev/engine/client/cl_rcon.h index 513f1a35..404040fd 100644 --- a/r5dev/engine/client/cl_rcon.h +++ b/r5dev/engine/client/cl_rcon.h @@ -18,7 +18,7 @@ public: virtual void Disconnect(const char* szReason = nullptr) override; - virtual bool ProcessMessage(const char* pMsgBuf, int nMsgLen) override; + virtual bool ProcessMessage(const char* pMsgBuf, const int nMsgLen) override; bool Serialize(vector& vecBuf, const char* szReqBuf, const char* szReqVal, const cl_rcon::request_t requestType) const; @@ -26,6 +26,7 @@ public: bool IsInitialized(void) const; bool IsConnected(void); + CConnectedNetConsoleData* GetData(void); SocketHandle_t GetSocket(void); private: diff --git a/r5dev/engine/server/sv_rcon.cpp b/r5dev/engine/server/sv_rcon.cpp index 994198f0..42114793 100644 --- a/r5dev/engine/server/sv_rcon.cpp +++ b/r5dev/engine/server/sv_rcon.cpp @@ -93,7 +93,7 @@ void CRConServer::Think(void) const CConnectedNetConsoleData* pData = m_Socket.GetAcceptedSocketData(m_nConnIndex); if (!pData->m_bAuthorized) { - Disconnect(); + Disconnect("redundant"); } } } @@ -119,7 +119,7 @@ bool CRConServer::SetPassword(const char* pszPassword) m_bInitialized = false; m_Socket.CloseAllAcceptedSockets(); - const size_t nLen = std::strlen(pszPassword); + const size_t nLen = strlen(pszPassword); if (nLen < RCON_MIN_PASSWORD_LEN) { if (nLen > NULL) @@ -168,8 +168,8 @@ void CRConServer::RunFrame(void) { SendEncode(pData->m_hSocket, s_BannedMessage, "", sv_rcon::response_t::SERVERDATA_RESPONSE_AUTH, int(eDLL_T::NETCON)); - Disconnect(); + Disconnect("banned"); continue; } @@ -184,9 +184,9 @@ void CRConServer::RunFrame(void) // nMsgLen - // Output: true on success, false otherwise //----------------------------------------------------------------------------- -bool CRConServer::SendToAll(const char* pMsgBuf, int nMsgLen) const +bool CRConServer::SendToAll(const char* pMsgBuf, const int nMsgLen) const { - std::ostringstream sendbuf; + ostringstream sendbuf; const u_long nLen = htonl(u_long(nMsgLen)); bool bSuccess = true; @@ -274,7 +274,8 @@ bool CRConServer::SendEncode(const SocketHandle_t hSocket, const char* pResponse //----------------------------------------------------------------------------- // Purpose: serializes input -// Input : *responseMsg - +// Input : &vecBuf - +// *responseMsg - // *responseVal - // responseType - // nMessageId - @@ -306,10 +307,10 @@ bool CRConServer::Serialize(vector& vecBuf, const char* pResponseMsg, cons //----------------------------------------------------------------------------- // Purpose: authenticate new connections -// Input : *cl_request - +// Input : &request - // *pData - //----------------------------------------------------------------------------- -void CRConServer::Authenticate(const cl_rcon::request& cl_request, CConnectedNetConsoleData* pData) +void CRConServer::Authenticate(const cl_rcon::request& request, CConnectedNetConsoleData* pData) { if (pData->m_bAuthorized) { @@ -317,7 +318,7 @@ void CRConServer::Authenticate(const cl_rcon::request& cl_request, CConnectedNet } else // Authorize. { - if (Comparator(cl_request.requestmsg())) + if (Comparator(request.requestmsg())) { pData->m_bAuthorized = true; if (++m_nAuthConnections >= sv_rcon_maxconnections->GetInt()) @@ -349,20 +350,20 @@ void CRConServer::Authenticate(const cl_rcon::request& cl_request, CConnectedNet //----------------------------------------------------------------------------- // Purpose: sha256 hashed password comparison -// Input : svCompare - +// Input : &svPassword - // Output : true if matches, false otherwise //----------------------------------------------------------------------------- -bool CRConServer::Comparator(std::string svPassword) const +bool CRConServer::Comparator(const string& svPassword) const { - svPassword = sha256(svPassword); + string passwordHash = sha256(svPassword); if (sv_rcon_debug->GetBool()) { DevMsg(eDLL_T::SERVER, "+---------------------------------------------------------------------------+\n"); DevMsg(eDLL_T::SERVER, "[ Server: '%s']\n", m_svPasswordHash.c_str()); - DevMsg(eDLL_T::SERVER, "[ Client: '%s']\n", svPassword.c_str()); + DevMsg(eDLL_T::SERVER, "[ Client: '%s']\n", passwordHash.c_str()); DevMsg(eDLL_T::SERVER, "+---------------------------------------------------------------------------+\n"); } - if (std::memcmp(svPassword.data(), m_svPasswordHash.data(), SHA256::DIGEST_SIZE) == 0) + if (memcmp(passwordHash.data(), m_svPasswordHash.data(), SHA256::DIGEST_SIZE) == 0) { return true; } @@ -375,7 +376,7 @@ bool CRConServer::Comparator(std::string svPassword) const // nMsgLen - // Output : true on success, false otherwise //----------------------------------------------------------------------------- -bool CRConServer::ProcessMessage(const char* pMsgBuf, int nMsgLen) +bool CRConServer::ProcessMessage(const char* pMsgBuf, const int nMsgLen) { CConnectedNetConsoleData* pData = m_Socket.GetAcceptedSocketData(m_nConnIndex); @@ -439,22 +440,22 @@ bool CRConServer::ProcessMessage(const char* pMsgBuf, int nMsgLen) //----------------------------------------------------------------------------- // Purpose: execute commands issued from net console -// Input : *cl_request - +// Input : *request - // bConVar - //----------------------------------------------------------------------------- -void CRConServer::Execute(const cl_rcon::request& cl_request, const bool bConVar) const +void CRConServer::Execute(const cl_rcon::request& request, const bool bConVar) const { if (bConVar) { - ConVar* pConVar = g_pCVar->FindVar(cl_request.requestmsg().c_str()); + ConVar* pConVar = g_pCVar->FindVar(request.requestmsg().c_str()); if (pConVar) // Only run if this is a ConVar. { - pConVar->SetValue(cl_request.requestval().c_str()); + pConVar->SetValue(request.requestval().c_str()); } } else // Execute command with "". { - Cbuf_AddText(Cbuf_GetCurrentPlayer(), cl_request.requestmsg().c_str(), cmd_source_t::kCommandSrcCode); + Cbuf_AddText(Cbuf_GetCurrentPlayer(), request.requestmsg().c_str(), cmd_source_t::kCommandSrcCode); } } @@ -527,15 +528,19 @@ bool CRConServer::CheckForBan(CConnectedNetConsoleData* pData) //----------------------------------------------------------------------------- // Purpose: close specific connection //----------------------------------------------------------------------------- -void CRConServer::Disconnect(const char* /*szReason*/) // NETMGR +void CRConServer::Disconnect(const char* szReason) // NETMGR { CConnectedNetConsoleData* pData = m_Socket.GetAcceptedSocketData(m_nConnIndex); - if (pData->m_bAuthorized) + if (pData->m_bAuthorized || sv_rcon_debug->GetBool()) { // Inform server owner when authenticated connection has been closed. netadr_t netAdr = m_Socket.GetAcceptedSocketAddress(m_nConnIndex); - DevMsg(eDLL_T::SERVER, "Net console '%s' closed RCON connection\n", netAdr.ToString()); + if (!szReason) + { + szReason = "unknown reason"; + } + DevMsg(eDLL_T::SERVER, "Connection to '%s' closed (%s)\n", netAdr.ToString(), szReason); m_nAuthConnections--; } diff --git a/r5dev/engine/server/sv_rcon.h b/r5dev/engine/server/sv_rcon.h index 03fd9b7c..bc256891 100644 --- a/r5dev/engine/server/sv_rcon.h +++ b/r5dev/engine/server/sv_rcon.h @@ -23,22 +23,27 @@ public: void Think(void); void RunFrame(void); - bool SendEncode(const char* pResponseMsg, const char* pResponseVal, const sv_rcon::response_t responseType, - const int nMessageId = static_cast(eDLL_T::NETCON), const int nMessageType = static_cast(LogType_t::LOG_NET)) const; - bool SendEncode(const SocketHandle_t hSocket, const char* pResponseMsg, const char* pResponseVal, const sv_rcon::response_t responseType, - const int nMessageId = static_cast(eDLL_T::NETCON), const int nMessageType = static_cast(LogType_t::LOG_NET)) const; + bool SendEncode(const char* pResponseMsg, const char* pResponseVal, + const sv_rcon::response_t responseType, + const int nMessageId = static_cast(eDLL_T::NETCON), + const int nMessageType = static_cast(LogType_t::LOG_NET)) const; - bool SendToAll(const char* pMsgBuf, int nMsgLen) const; + bool SendEncode(const SocketHandle_t hSocket, const char* pResponseMsg, + const char* pResponseVal, const sv_rcon::response_t responseType, + const int nMessageId = static_cast(eDLL_T::NETCON), + const int nMessageType = static_cast(LogType_t::LOG_NET)) const; + + bool SendToAll(const char* pMsgBuf, const int nMsgLen) const; bool Serialize(vector& vecBuf, const char* pResponseMsg, const char* pResponseVal, const sv_rcon::response_t responseType, const int nMessageId = static_cast(eDLL_T::NETCON), const int nMessageType = static_cast(LogType_t::LOG_NET)) const; - void Authenticate(const cl_rcon::request& cl_request, CConnectedNetConsoleData* pData); - bool Comparator(std::string svPassword) const; + void Authenticate(const cl_rcon::request& request, CConnectedNetConsoleData* pData); + bool Comparator(const string& svPassword) const; - virtual bool ProcessMessage(const char* pMsgBug, int nMsgLen) override; + virtual bool ProcessMessage(const char* pMsgBuf, const int nMsgLen) override; - void Execute(const cl_rcon::request& cl_request, const bool bConVar) const; + void Execute(const cl_rcon::request& request, const bool bConVar) const; bool CheckForBan(CConnectedNetConsoleData* pData); virtual void Disconnect(const char* szReason = nullptr) override; diff --git a/r5dev/engine/shared/base_rcon.cpp b/r5dev/engine/shared/base_rcon.cpp index 426f0ac7..819ac43c 100644 --- a/r5dev/engine/shared/base_rcon.cpp +++ b/r5dev/engine/shared/base_rcon.cpp @@ -15,7 +15,8 @@ // nMsgLen - // Output : true on success, false otherwise //----------------------------------------------------------------------------- -bool CNetConBase::Encode(google::protobuf::MessageLite* pMsg, char* pMsgBuf, size_t nMsgLen) const +bool CNetConBase::Encode(google::protobuf::MessageLite* pMsg, + char* pMsgBuf, const size_t nMsgLen) const { return pMsg->SerializeToArray(pMsgBuf, int(nMsgLen)); } @@ -27,7 +28,8 @@ bool CNetConBase::Encode(google::protobuf::MessageLite* pMsg, char* pMsgBuf, siz // nMsgLen - // Output : true on success, false otherwise //----------------------------------------------------------------------------- -bool CNetConBase::Decode(google::protobuf::MessageLite* pMsg, const char* pMsgBuf, size_t nMsgLen) const +bool CNetConBase::Decode(google::protobuf::MessageLite* pMsg, + const char* pMsgBuf, const size_t nMsgLen) const { return pMsg->ParseFromArray(pMsgBuf, int(nMsgLen)); } @@ -40,12 +42,7 @@ bool CNetConBase::Decode(google::protobuf::MessageLite* pMsg, const char* pMsgBu //----------------------------------------------------------------------------- bool CNetConBase::Connect(const char* pHostName, const int nPort) { - if (CL_NetConConnect(this, pHostName, nPort)) - { - return true; - } - - return false; + return CL_NetConConnect(this, pHostName, nPort); } //----------------------------------------------------------------------------- @@ -55,7 +52,8 @@ bool CNetConBase::Connect(const char* pHostName, const int nPort) // nMsgLen - // Output: true on success, false otherwise //----------------------------------------------------------------------------- -bool CNetConBase::Send(const SocketHandle_t hSocket, const char* pMsgBuf, int nMsgLen) const +bool CNetConBase::Send(const SocketHandle_t hSocket, const char* pMsgBuf, + const int nMsgLen) const { std::ostringstream sendbuf; const u_long nLen = htonl(u_long(nMsgLen)); @@ -63,7 +61,9 @@ bool CNetConBase::Send(const SocketHandle_t hSocket, const char* pMsgBuf, int nM sendbuf.write(reinterpret_cast(&nLen), sizeof(u_long)); sendbuf.write(pMsgBuf, nMsgLen); - int ret = ::send(hSocket, sendbuf.str().data(), int(sendbuf.str().size()), MSG_NOSIGNAL); + int ret = ::send(hSocket, sendbuf.str().data(), int(sendbuf.str().size()), + MSG_NOSIGNAL); + return (ret != SOCKET_ERROR); } @@ -71,9 +71,17 @@ bool CNetConBase::Send(const SocketHandle_t hSocket, const char* pMsgBuf, int nM // Purpose: receive message // Input : *pData - // nMaxLen - +// Output: true on success, false otherwise //----------------------------------------------------------------------------- void CNetConBase::Recv(CConnectedNetConsoleData* pData, const int nMaxLen) { + if (!pData) + { + Error(eDLL_T::ENGINE, NO_ERROR, "RCON Cmd: invalid input data\n"); + Assert(0); + return; + } + static char szRecvBuf[1024]; {////////////////////////////////////////////// @@ -84,20 +92,28 @@ void CNetConBase::Recv(CConnectedNetConsoleData* pData, const int nMaxLen) } if (nPendingLen <= 0) // EOF or error. { - Disconnect("Server closed RCON connection\n"); + Disconnect("unexpected EOF or error"); return; } }////////////////////////////////////////////// int nReadLen = 0; // Find out how much we have to read. - ::ioctlsocket(pData->m_hSocket, FIONREAD, reinterpret_cast(&nReadLen)); + int iResult = ::ioctlsocket(pData->m_hSocket, FIONREAD, reinterpret_cast(&nReadLen)); + + if (iResult == SOCKET_ERROR) + { + Error(eDLL_T::ENGINE, NO_ERROR, "RCON Cmd: ioctl(%s) error (%s)\n", "FIONREAD", NET_ErrorString(WSAGetLastError())); + return; + } + + bool bSuccess = true; while (nReadLen > 0) { const int nRecvLen = ::recv(pData->m_hSocket, szRecvBuf, MIN(sizeof(szRecvBuf), nReadLen), MSG_NOSIGNAL); if (nRecvLen == 0) // Socket was closed. { - Disconnect("Server closed RCON connection\n"); + Disconnect("socket closed unexpectedly"); break; } if (nRecvLen < 0 && !m_Socket.IsSocketBlocking()) @@ -107,8 +123,13 @@ void CNetConBase::Recv(CConnectedNetConsoleData* pData, const int nMaxLen) } nReadLen -= nRecvLen; // Process what we've got. - ProcessBuffer(pData, szRecvBuf, nRecvLen, nMaxLen); + if (!ProcessBuffer(pData, szRecvBuf, nRecvLen, nMaxLen) && bSuccess) + { + bSuccess = false; + } } + + return; } //----------------------------------------------------------------------------- @@ -118,8 +139,11 @@ void CNetConBase::Recv(CConnectedNetConsoleData* pData, const int nMaxLen) // *pData - // Output: true on success, false otherwise //----------------------------------------------------------------------------- -bool CNetConBase::ProcessBuffer(CConnectedNetConsoleData* pData, const char* pRecvBuf, int nRecvLen, int nMaxLen) +bool CNetConBase::ProcessBuffer(CConnectedNetConsoleData* pData, + const char* pRecvBuf, int nRecvLen, const int nMaxLen) { + bool bSuccess = true; + while (nRecvLen > 0) { if (pData->m_nPayloadLen) @@ -133,8 +157,12 @@ bool CNetConBase::ProcessBuffer(CConnectedNetConsoleData* pData, const char* pRe } if (pData->m_nPayloadRead == pData->m_nPayloadLen) { - ProcessMessage( - reinterpret_cast(pData->m_RecvBuffer.data()), pData->m_nPayloadLen); + if (!ProcessMessage( + reinterpret_cast(pData->m_RecvBuffer.data()), pData->m_nPayloadLen) + && bSuccess) + { + bSuccess = false; + } pData->m_nPayloadLen = 0; pData->m_nPayloadRead = 0; @@ -156,7 +184,7 @@ bool CNetConBase::ProcessBuffer(CConnectedNetConsoleData* pData, const char* pRe { if (pData->m_nPayloadLen > nMaxLen) { - Disconnect(); // Sending large messages while not authenticated. + Disconnect("overflow"); // Sending large messages while not authenticated. return false; } } @@ -165,7 +193,7 @@ bool CNetConBase::ProcessBuffer(CConnectedNetConsoleData* pData, const char* pRe pData->m_nPayloadLen > pData->m_RecvBuffer.max_size()) { Error(eDLL_T::ENGINE, NO_ERROR, "RCON Cmd: sync error (%d)\n", pData->m_nPayloadLen); - Disconnect(); // Out of sync (irrecoverable). + Disconnect("desync"); // Out of sync (irrecoverable). return false; } @@ -176,5 +204,5 @@ bool CNetConBase::ProcessBuffer(CConnectedNetConsoleData* pData, const char* pRe } } - return true; + return bSuccess; } diff --git a/r5dev/engine/shared/base_rcon.h b/r5dev/engine/shared/base_rcon.h index df33905f..35283ac4 100644 --- a/r5dev/engine/shared/base_rcon.h +++ b/r5dev/engine/shared/base_rcon.h @@ -11,16 +11,16 @@ public: CNetConBase(void) {} - virtual bool Encode(google::protobuf::MessageLite* pMsg, char* pMsgBuf, size_t nMsgLen) const; - virtual bool Decode(google::protobuf::MessageLite* pMsg, const char* pMsgBuf, size_t nMsgLen) const; + virtual bool Encode(google::protobuf::MessageLite* pMsg, char* pMsgBuf, const size_t nMsgLen) const; + virtual bool Decode(google::protobuf::MessageLite* pMsg, const char* pMsgBuf, const size_t nMsgLen) const; virtual bool Connect(const char* pHostAdr, const int nHostPort = SOCKET_ERROR); virtual void Disconnect(const char* szReason = nullptr) { NOTE_UNUSED(szReason); }; - virtual bool Send(const SocketHandle_t hSocket, const char* pMsgBuf, int nMsgLen) const; - virtual void Recv(CConnectedNetConsoleData* pData, const int nMaxLen = -1); + virtual bool Send(const SocketHandle_t hSocket, const char* pMsgBuf, const int nMsgLen) const; + virtual void Recv(CConnectedNetConsoleData* pData, const int nMaxLen = SOCKET_ERROR); - virtual bool ProcessBuffer(CConnectedNetConsoleData* pData, const char* pRecvBuf, int nRecvLen, int nMaxLen = -1); + virtual bool ProcessBuffer(CConnectedNetConsoleData* pData, const char* pRecvBuf, int nRecvLen, const int nMaxLen = SOCKET_ERROR); virtual bool ProcessMessage(const char* /*pMsgBuf*/, int /*nMsgLen*/) { return true; }; CSocketCreator* GetSocketCreator(void) { return &m_Socket; } diff --git a/r5dev/engine/shared/shared_rcon.cpp b/r5dev/engine/shared/shared_rcon.cpp index f3c51450..65f89a9a 100644 --- a/r5dev/engine/shared/shared_rcon.cpp +++ b/r5dev/engine/shared/shared_rcon.cpp @@ -89,7 +89,7 @@ bool CL_NetConConnect(CNetConBase* pBase, const char* pHostAdr, const int nHostP // iSocket - // Output : nullptr on failure //----------------------------------------------------------------------------- -CConnectedNetConsoleData* SH_GetNetConData(CNetConBase* pBase, int iSocket) +CConnectedNetConsoleData* SH_GetNetConData(CNetConBase* pBase, const int iSocket) { const CSocketCreator* pCreator = pBase->GetSocketCreator(); Assert(iSocket >= 0 && iSocket < pCreator->GetAcceptedSocketCount()); @@ -108,7 +108,7 @@ CConnectedNetConsoleData* SH_GetNetConData(CNetConBase* pBase, int iSocket) // iSocket - // Output : SOCKET_ERROR (-1) on failure //----------------------------------------------------------------------------- -SocketHandle_t SH_GetNetConSocketHandle(CNetConBase* pBase, int iSocket) +SocketHandle_t SH_GetNetConSocketHandle(CNetConBase* pBase, const int iSocket) { const CConnectedNetConsoleData* pData = SH_GetNetConData(pBase, iSocket); if (!pData) diff --git a/r5dev/engine/shared/shared_rcon.h b/r5dev/engine/shared/shared_rcon.h index 18cf5d46..394a8f9d 100644 --- a/r5dev/engine/shared/shared_rcon.h +++ b/r5dev/engine/shared/shared_rcon.h @@ -8,7 +8,7 @@ bool CL_NetConSerialize(const CNetConBase* pBase, vector& vecBuf, const ch const char* szReqVal, const cl_rcon::request_t requestType); bool CL_NetConConnect(CNetConBase* pBase, const char* pHostAdr, const int nHostPort); -CConnectedNetConsoleData* SH_GetNetConData(CNetConBase* pBase, int iSocket); -SocketHandle_t SH_GetNetConSocketHandle(CNetConBase* pBase, int iSocket); +CConnectedNetConsoleData* SH_GetNetConData(CNetConBase* pBase, const int iSocket); +SocketHandle_t SH_GetNetConSocketHandle(CNetConBase* pBase, const int iSocket); #endif // SHARED_RCON_H diff --git a/r5dev/netconsole/netconsole.cpp b/r5dev/netconsole/netconsole.cpp index e5f561a9..d31cd081 100644 --- a/r5dev/netconsole/netconsole.cpp +++ b/r5dev/netconsole/netconsole.cpp @@ -84,7 +84,8 @@ bool CNetCon::Shutdown(void) } else // WSACleanup() failed. { - Error(eDLL_T::CLIENT, NO_ERROR, "%s - Failed to stop Winsock: (%s)\n", __FUNCTION__, NET_ErrorString(WSAGetLastError())); + Error(eDLL_T::CLIENT, NO_ERROR, "%s - Failed to stop Winsock: (%s)\n", + __FUNCTION__, NET_ErrorString(WSAGetLastError())); } SpdLog_Shutdown(); @@ -123,11 +124,11 @@ void CNetCon::UserInput(void) { if (m_Input.compare("disconnect") == 0) { - Disconnect(); + Disconnect("user closed connection"); return; } - const std::vector vSubStrings = StringSplit(m_Input, ' ', 2); + const vector vSubStrings = StringSplit(m_Input, ' ', 2); vector vecMsg; const SocketHandle_t hSocket = GetSocket(); @@ -169,7 +170,7 @@ void CNetCon::UserInput(void) } else // Setup connection from input. { - const std::vector vSubStrings = StringSplit(m_Input, ' ', 2); + const vector vSubStrings = StringSplit(m_Input, ' ', 2); if (vSubStrings.size() > 1) { const string::size_type nPos = m_Input.find(' '); @@ -177,8 +178,8 @@ void CNetCon::UserInput(void) && nPos < m_Input.size() && nPos != m_Input.size()) { - std::string svInPort = m_Input.substr(nPos + 1); - std::string svInAdr = m_Input.erase(m_Input.find(' ')); + string svInPort = m_Input.substr(nPos + 1); + string svInAdr = m_Input.erase(m_Input.find(' ')); if (svInPort.empty() || svInAdr.empty()) { @@ -246,16 +247,18 @@ bool CNetCon::ShouldQuit(void) const //----------------------------------------------------------------------------- // Purpose: disconnect from current session +// Input : *szReason - //----------------------------------------------------------------------------- void CNetCon::Disconnect(const char* szReason) { if (IsConnected()) { - if (szReason) + if (!szReason) { - DevMsg(eDLL_T::CLIENT, "%s", szReason); + szReason = "unknown reason"; } + DevMsg(eDLL_T::CLIENT, "Disconnect: (%s)\n", szReason); m_Socket.CloseAcceptedSocket(0); } @@ -264,9 +267,11 @@ void CNetCon::Disconnect(const char* szReason) //----------------------------------------------------------------------------- // Purpose: processes received message -// Input : *sv_response - +// Input : *pMsgBuf - +// nMsgLen - +// Output : true on success, false otherwise //----------------------------------------------------------------------------- -bool CNetCon::ProcessMessage(const char* pMsgBuf, int nMsgLen) +bool CNetCon::ProcessMessage(const char* pMsgBuf, const int nMsgLen) { sv_rcon::response response; bool bSuccess = Decode(&response, pMsgBuf, nMsgLen); diff --git a/r5dev/netconsole/netconsole.h b/r5dev/netconsole/netconsole.h index 6d49cae3..45a1bb01 100644 --- a/r5dev/netconsole/netconsole.h +++ b/r5dev/netconsole/netconsole.h @@ -27,7 +27,7 @@ public: bool ShouldQuit(void) const; virtual void Disconnect(const char* szReason = nullptr); - virtual bool ProcessMessage(const char* pMsgBuf, int nMsgLen) override; + virtual bool ProcessMessage(const char* pMsgBuf, const int nMsgLen) override; bool Serialize(vector& vecBuf, const char* szReqBuf, const char* szReqVal, const cl_rcon::request_t requestType) const; diff --git a/r5dev/public/tier2/socketcreator.h b/r5dev/public/tier2/socketcreator.h index 69c732d0..91480d10 100644 --- a/r5dev/public/tier2/socketcreator.h +++ b/r5dev/public/tier2/socketcreator.h @@ -18,7 +18,8 @@ public: void CloseListenSocket(void); int ConnectSocket(const netadr_t& netAdr, bool bSingleSocket); - void DisconnectSocket(void); + void DisconnectSocket(SocketHandle_t hSocket); + void DisconnectSockets(void); bool ConfigureSocket(SocketHandle_t hSocket, bool bDualStack = true); int OnSocketAccepted(SocketHandle_t hSocket, const netadr_t& netAdr); diff --git a/r5dev/tier2/socketcreator.cpp b/r5dev/tier2/socketcreator.cpp index 0ece69a0..44fc0527 100644 --- a/r5dev/tier2/socketcreator.cpp +++ b/r5dev/tier2/socketcreator.cpp @@ -25,7 +25,7 @@ CSocketCreator::CSocketCreator(void) //----------------------------------------------------------------------------- CSocketCreator::~CSocketCreator(void) { - DisconnectSocket(); + DisconnectSockets(); } //----------------------------------------------------------------------------- @@ -58,7 +58,7 @@ void CSocketCreator::ProcessAccept(void) if (!ConfigureSocket(newSocket, false)) { - ::closesocket(newSocket); + DisconnectSocket(newSocket); return; } @@ -118,7 +118,7 @@ void CSocketCreator::CloseListenSocket(void) { if (m_hListenSocket != SOCKET_ERROR) { - ::closesocket(m_hListenSocket); + DisconnectSocket(m_hListenSocket); m_hListenSocket = SOCKET_ERROR; } } @@ -145,7 +145,7 @@ int CSocketCreator::ConnectSocket(const netadr_t& netAdr, bool bSingleSocket) if (!ConfigureSocket(hSocket)) { - ::closesocket(hSocket); + DisconnectSocket(hSocket); return SOCKET_ERROR; } @@ -159,7 +159,7 @@ int CSocketCreator::ConnectSocket(const netadr_t& netAdr, bool bSingleSocket) { Warning(eDLL_T::ENGINE, "Socket connection failed (%s)\n", NET_ErrorString(WSAGetLastError())); - ::closesocket(hSocket); + DisconnectSocket(hSocket); return SOCKET_ERROR; } @@ -175,7 +175,8 @@ int CSocketCreator::ConnectSocket(const netadr_t& netAdr, bool bSingleSocket) if (::select(hSocket + 1, NULL, &writefds, NULL, &tv) < 1) // block for at most 1 second. { Warning(eDLL_T::ENGINE, "Socket connection timed out\n"); - ::closesocket(hSocket); // took too long to connect to, give up. + DisconnectSocket(hSocket); // took too long to connect to, give up. + return SOCKET_ERROR; } } @@ -186,10 +187,23 @@ int CSocketCreator::ConnectSocket(const netadr_t& netAdr, bool bSingleSocket) return nIndex; } +//----------------------------------------------------------------------------- +// Purpose: closes specific open sockets (listen + accepted) +//----------------------------------------------------------------------------- +void CSocketCreator::DisconnectSocket(SocketHandle_t hSocket) +{ + Assert(hSocket != SOCKET_ERROR); + if (::closesocket(hSocket) == SOCKET_ERROR) + { + Error(eDLL_T::ENGINE, NO_ERROR, "Unable to close socket (%s)\n", + NET_ErrorString(WSAGetLastError())); + } +} + //----------------------------------------------------------------------------- // Purpose: closes all open sockets (listen + accepted) //----------------------------------------------------------------------------- -void CSocketCreator::DisconnectSocket(void) +void CSocketCreator::DisconnectSockets(void) { CloseListenSocket(); CloseAllAcceptedSockets(); @@ -278,7 +292,7 @@ void CSocketCreator::CloseAcceptedSocket(int nIndex) } AcceptedSocket_t& connected = m_hAcceptedSockets[nIndex]; - ::closesocket(connected.m_hSocket); + DisconnectSocket(connected.m_hSocket); delete connected.m_pData; m_hAcceptedSockets.erase(m_hAcceptedSockets.begin() + nIndex); @@ -292,7 +306,7 @@ void CSocketCreator::CloseAllAcceptedSockets(void) for (size_t i = 0; i < m_hAcceptedSockets.size(); ++i) { AcceptedSocket_t& connected = m_hAcceptedSockets[i]; - ::closesocket(connected.m_hSocket); + DisconnectSocket(connected.m_hSocket); delete connected.m_pData; } diff --git a/r5dev/vstdlib/callback.cpp b/r5dev/vstdlib/callback.cpp index 89a492b7..10451c54 100644 --- a/r5dev/vstdlib/callback.cpp +++ b/r5dev/vstdlib/callback.cpp @@ -888,7 +888,7 @@ void RCON_CmdQuery_f(const CCommand& args) if (bSuccess) { - RCONClient()->Send(hSocket, vecMsg.data(), vecMsg.size()); + RCONClient()->Send(hSocket, vecMsg.data(), int(vecMsg.size())); } return; @@ -902,7 +902,7 @@ void RCON_CmdQuery_f(const CCommand& args) bSuccess = RCONClient()->Serialize(vecMsg, args.ArgS(), "", cl_rcon::request_t::SERVERDATA_REQUEST_EXECCOMMAND); if (bSuccess) { - RCONClient()->Send(hSocket, vecMsg.data(), vecMsg.size()); + RCONClient()->Send(hSocket, vecMsg.data(), int(vecMsg.size())); } return; }