/*****************************************************************************
 * Copyright (c) 2014-2025 OpenRCT2 developers
 *
 * For a complete list of all authors, please refer to contributors.md
 * Interested in contributing? Visit https://github.com/OpenRCT2/OpenRCT2
 *
 * OpenRCT2 is licensed under the GNU General Public License version 3.
 *****************************************************************************/

#ifndef DISABLE_NETWORK

    #include "NetworkConnection.h"

    #include "../core/String.hpp"
    #include "../localisation/Formatting.h"
    #include "../platform/Platform.h"
    #include "Network.h"
    #include "Socket.h"

    #include <sfl/small_vector.hpp>

namespace OpenRCT2::Network
{
    // Capacity of the scratch buffer used to format a localised disconnect reason.
    static constexpr size_t kDisconnectReasonBufSize = 256;
    // Maximum packet size: 0xFFFF bytes (64 KiB minus one), the largest value the 16-bit size field can hold.
    static constexpr size_t kBufferSize = 0xFFFF;
    #ifndef DEBUG
    // Seconds without a complete inbound packet before ReceivedPacketRecently() reports false (release builds only).
    static constexpr size_t kNoDataTimeout = 20;
    #endif

    static_assert(kBufferSize <= std::numeric_limits<uint16_t>::max(), "kBufferSize too big, uint16_t is max.");

    // A freshly created connection starts its "last packet" clock now, so the no-data timeout
    // is measured from construction rather than from an arbitrary earlier tick value.
    Connection::Connection() noexcept
    {
        _lastPacketTime = Platform::GetTicks();
    }

    // Incrementally reads one packet from Socket into InboundPacket. It is meant to be called repeatedly.
    // Progress is tracked in InboundPacket.BytesTransferred (header bytes + body bytes received so far).
    // Returns:
    //   ReadPacket::success  - the whole packet has arrived (its statistics have been recorded);
    //   ReadPacket::moreData - the header or body is still incomplete;
    //   any other non-success status reported by Socket->ReceiveData, passed through unchanged.
    ReadPacket Connection::readPacket()
    {
        size_t bytesRead = 0;

        // Read packet header.
        auto& header = InboundPacket.Header;
        if (InboundPacket.BytesTransferred < sizeof(header))
        {
            const size_t missingLength = sizeof(header) - InboundPacket.BytesTransferred;

            // Resume writing right after the header bytes that were already received. A stream socket may
            // deliver the header in several pieces. Writing every piece at offset 0 would overwrite the
            // earlier bytes and corrupt Size/Id.
            uint8_t* buffer = reinterpret_cast<uint8_t*>(&header) + InboundPacket.BytesTransferred;

            ReadPacket status = Socket->ReceiveData(buffer, missingLength, &bytesRead);
            if (status != ReadPacket::success)
            {
                return status;
            }

            InboundPacket.BytesTransferred += bytesRead;
            if (InboundPacket.BytesTransferred < sizeof(header))
            {
                // If still not enough data for header, keep waiting.
                return ReadPacket::moreData;
            }

            // Normalise values.
            header.Size = Convert::NetworkToHost(header.Size);
            header.Id = ByteSwapBE(header.Id);

            // NOTE: For compatibility reasons for the master server we need to remove sizeof(Header.Id) from the size.
            // Previously the Id field was not part of the header rather part of the body.
            header.Size -= std::min<uint16_t>(header.Size, sizeof(header.Id));

            // Fall-through: Read rest of packet.
        }

        // Read packet body.
        {
            // NOTE: BytesTransferred includes the header length, this will not underflow.
            const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header));

            // Scratch space for one chunk of body data; each read is capped at kBufferSize bytes.
            uint8_t buffer[kBufferSize];

            if (missingLength > 0)
            {
                ReadPacket status = Socket->ReceiveData(buffer, std::min(missingLength, kBufferSize), &bytesRead);
                if (status != ReadPacket::success)
                {
                    return status;
                }

                InboundPacket.BytesTransferred += bytesRead;
                InboundPacket.Write(buffer, bytesRead);
            }

            if (InboundPacket.Data.size() == header.Size)
            {
                // Received complete packet.
                _lastPacketTime = Platform::GetTicks();

                RecordPacketStats(InboundPacket, false);

                return ReadPacket::success;
            }
        }

        return ReadPacket::moreData;
    }

    // Produces the on-wire bytes for `packet`. The output is a copy of its header, with Size/Id converted
    // via Convert::HostToNetwork / ByteSwapBE, followed by the packet body verbatim.
    static sfl::small_vector<uint8_t, 512> serializePacket(const Packet& packet)
    {
        // Legacy wire format: the advertised size also covers the Id field, because Id used to live in the
        // body rather than the header. The master server still expects this.
        const auto wireBodyLength = sizeof(packet.Header.Id) + packet.Data.size();
        Guard::Assert(wireBodyLength <= std::numeric_limits<uint16_t>::max(), "Packet size too large");

        auto wireHeader = packet.Header;
        wireHeader.Size = static_cast<uint16_t>(wireBodyLength);
        wireHeader.Size = Convert::HostToNetwork(wireHeader.Size);
        wireHeader.Id = ByteSwapBE(wireHeader.Id);

        const auto* headerFirst = reinterpret_cast<const uint8_t*>(&wireHeader);
        const auto* headerLast = headerFirst + sizeof(wireHeader);

        sfl::small_vector<uint8_t, 512> wireBytes;
        wireBytes.reserve(sizeof(wireHeader) + packet.Data.size());
        wireBytes.insert(wireBytes.end(), headerFirst, headerLast);
        wireBytes.insert(wireBytes.end(), packet.Data.begin(), packet.Data.end());
        return wireBytes;
    }

    // Serialises `packet` into the outbound byte queue. It is placed at the head when `front` is true and
    // at the tail otherwise, and the queued bytes are recorded in the "sent" statistics.
    // Packets whose command requires authentication are dropped without notice until AuthStatus is Auth::ok.
    // NOTE(review): SendQueuedData() erases only the bytes actually sent, so the queue head may be the
    // remainder of a partially-sent packet. Inserting at begin() in that case would splice this packet into
    // the middle of the other one on the wire. Tracking the unsent remainder needs extra state; confirm
    // whether front-queued packets can race with a partial send.
    void Connection::QueuePacket(const Packet& packet, bool front)
    {
        const bool mayQueue = AuthStatus == Auth::ok || !packet.CommandRequiresAuth();
        if (!mayQueue)
        {
            return;
        }

        const auto payload = serializePacket(packet);
        const auto insertAt = front ? _outboundBuffer.begin() : _outboundBuffer.end();
        _outboundBuffer.insert(insertAt, payload.begin(), payload.end());

        RecordPacketStats(packet, true);
    }

    void Connection::Disconnect() noexcept
    {
        ShouldDisconnect = true;
    }

    // A connection is usable while no disconnect has been requested and the socket reports it is connected.
    bool Connection::IsValid() const
    {
        if (ShouldDisconnect)
        {
            return false;
        }
        return Socket->GetStatus() == SocketStatus::connected;
    }

    // Pushes as much of the outbound byte queue to the socket as it will accept, then drops the bytes that
    // were actually sent. Anything left over stays queued for the next call.
    void Connection::SendQueuedData()
    {
        if (_outboundBuffer.empty())
        {
            return;
        }

        const auto sentCount = Socket->SendData(_outboundBuffer.data(), _outboundBuffer.size());
        if (sentCount <= 0)
        {
            // Nothing went out this time; keep the whole queue.
            return;
        }

        const auto first = _outboundBuffer.begin();
        _outboundBuffer.erase(first, first + sentCount);
    }

    void Connection::ResetLastPacketTime() noexcept
    {
        _lastPacketTime = Platform::GetTicks();
    }

    // Returns false once more than kNoDataTimeout seconds have elapsed since the last complete packet
    // (or since the last ResetLastPacketTime()). Debug builds never time out and always return true.
    bool Connection::ReceivedPacketRecently() const noexcept
    {
    #ifndef DEBUG
        constexpr auto kTimeoutMs = kNoDataTimeout * 1000;
        // Compare the elapsed time rather than `now > last + timeout`. With an unsigned tick counter the
        // subtraction stays correct across counter wrap-around. The sum form instead either never times out
        // after a wrap, or overflows itself when the sum type is as narrow as the tick type.
        // NOTE(review): assumes GetTicks()/_lastPacketTime are unsigned and share a clock — confirm.
        const auto elapsedMs = Platform::GetTicks() - _lastPacketTime;
        if (elapsedMs > kTimeoutMs)
        {
            return false;
        }
    #endif
        return true;
    }

    // Returns the stored disconnect reason as a C string. The pointer stays valid only until the reason is
    // next changed via SetLastDisconnectReason().
    const utf8* Connection::GetLastDisconnectReason() const noexcept
    {
        const auto& reason = _lastDisconnectReason;
        return reason.c_str();
    }

    // Stores a copy of `src` as the disconnect reason, replacing any previous one.
    void Connection::SetLastDisconnectReason(std::string_view src)
    {
        _lastDisconnectReason.assign(src.data(), src.size());
    }

    void Connection::SetLastDisconnectReason(const StringId string_id, void* args)
    {
        char buffer[kDisconnectReasonBufSize];
        FormatStringLegacy(buffer, kDisconnectReasonBufSize, string_id, args);
        SetLastDisconnectReason(buffer);
    }

    void Connection::RecordPacketStats(const Packet& packet, bool sending)
    {
        uint32_t packetSize = static_cast<uint32_t>(packet.BytesTransferred);
        StatisticsGroup trafficGroup;

        switch (packet.GetCommand())
        {
            case Command::gameAction:
                trafficGroup = StatisticsGroup::Commands;
                break;
            case Command::map:
                trafficGroup = StatisticsGroup::MapData;
                break;
            default:
                trafficGroup = StatisticsGroup::Base;
                break;
        }

        if (sending)
        {
            stats.bytesSent[EnumValue(trafficGroup)] += packetSize;
            stats.bytesSent[EnumValue(StatisticsGroup::Total)] += packetSize;
        }
        else
        {
            stats.bytesReceived[EnumValue(trafficGroup)] += packetSize;
            stats.bytesReceived[EnumValue(StatisticsGroup::Total)] += packetSize;
        }
    }
} // namespace OpenRCT2::Network

#endif
