Sat, 13 Dec 2014 04:50:33 +0200
- to hell with that 'ok' field. now throws an exception if attempts to read past the end
--- a/sources/network/bytestream.cpp Sat Dec 13 04:32:15 2014 +0200 +++ b/sources/network/bytestream.cpp Sat Dec 13 04:50:33 2014 +0200 @@ -31,8 +31,6 @@ #include "bytestream.h" #include <string.h> -bool Bytestream::sink; - // ------------------------------------------------------------------------------------------------- // Bytestream::Bytestream (unsigned long length) : @@ -122,22 +120,29 @@ // ------------------------------------------------------------------------------------------------- // -char Bytestream::read_byte (bool* ok) +METHOD +Bytestream::ensure_read_space (unsigned int bytes) -> void { - *ok = bytes_left() > 0; - return *ok ? *m_cursor++ : -1; + if (bytes_left() < bytes) + { + throw IOError (format ("attempted to read %1 byte(s) past the end of bytestream", + bytes - bytes_left())); + } } // ------------------------------------------------------------------------------------------------- // -short int Bytestream::read_short (bool* ok) +char Bytestream::read_byte() { - if (bytes_left() < 2) - { - *ok = false; - return false; - } + ensure_read_space (1); + return *m_cursor++; +} +// ------------------------------------------------------------------------------------------------- +// +short int Bytestream::read_short() +{ + ensure_read_space (2); short int result = 0; for (int i = 0; i < 2; ++i) @@ -149,14 +154,9 @@ // ------------------------------------------------------------------------------------------------- // -long int Bytestream::read_long (bool* ok) +long int Bytestream::read_long() { - if (bytes_left() < 4) - { - *ok = false; - return -1; - } - + ensure_read_space (4); long int result = 0; for (int i = 0; i < 4; ++i) @@ -168,19 +168,15 @@ // ------------------------------------------------------------------------------------------------- // -float Bytestream::read_float (bool* ok) +float Bytestream::read_float() { - int value = read_long (ok); - - if (*ok == false) - return -1.0f; - + int value = read_long(); return reinterpret_cast<float&> (value); } // ------------------------------------------------------------------------------------------------- // -String Bytestream::read_string (bool* ok) +String Bytestream::read_string() { // Zandronum sends strings of maximum 2048 characters, though it only // reads 2047-character long ones so I guess we can follow up and do @@ -194,11 +190,8 @@ for (stringEnd = m_cursor; *stringEnd != '\0'; ++stringEnd) { if (stringEnd == end) - { // past the end of the buffer! Argh! - *ok = false; - return ""; - } + throw IOError ("unterminated string in packet"); } m_cursor = stringEnd + 1; @@ -217,14 +210,9 @@ // ------------------------------------------------------------------------------------------------- // METHOD -Bytestream::read (unsigned char* buffer, unsigned long length, bool* ok) -> void +Bytestream::read (unsigned char* buffer, unsigned long length) -> void { - if (bytes_left() < length) - { - *ok = false; - return; - } - + ensure_read_space (length); memcpy (buffer, m_cursor, length); m_cursor += length; }
--- a/sources/network/bytestream.h Sat Dec 13 04:32:15 2014 +0200 +++ b/sources/network/bytestream.h Sat Dec 13 04:50:33 2014 +0200 @@ -29,6 +29,7 @@ */ #pragma once +#include <stdexcept> #include "../main.h" class String; @@ -39,7 +40,20 @@ class Bytestream { public: - static bool sink; + class IOError : public std::exception + { + String m_message; + + public: + IOError (String message) : + m_message (message) {} + + inline METHOD + what() const throw() -> const char* + { + return m_message.chars(); + } + }; Bytestream (unsigned long length = 0x800); Bytestream (const unsigned char* data, unsigned long length); @@ -54,12 +68,12 @@ inline METHOD data() const -> const unsigned char*; METHOD grow_to_fit (unsigned long bytes) -> void; inline METHOD position() const -> unsigned long; - METHOD read (unsigned char* buffer, unsigned long length, bool* ok = &sink) -> void; - METHOD read_byte (bool* ok = &sink) -> char; - METHOD read_short (bool* ok = &sink) -> short int; - METHOD read_long (bool* ok = &sink) -> long int; - METHOD read_string (bool* ok = &sink) -> String; - METHOD read_float (bool* ok = &sink) -> float; + METHOD read (unsigned char* buffer, unsigned long length) -> void; + METHOD read_byte() -> char; + METHOD read_short() -> short int; + METHOD read_long() -> long int; + METHOD read_string() -> String; + METHOD read_float() -> float; METHOD resize (unsigned long length) -> void; inline METHOD rewind() -> void; inline METHOD seek (unsigned long pos) -> void; @@ -87,6 +101,7 @@ METHOD init (const unsigned char* data, unsigned long length) -> void; METHOD write (unsigned char val) -> void; + METHOD ensure_read_space (unsigned int bytes) -> void; inline METHOD space_left() const -> unsigned long; };
--- a/sources/network/rconsession.cpp Sat Dec 13 04:32:15 2014 +0200 +++ b/sources/network/rconsession.cpp Sat Dec 13 04:50:33 2014 +0200 @@ -88,73 +88,75 @@ RCONSession::handle_packet (Bytestream& packet, const IPAddress& from) -> void { print ("Processing packet of %1 bytes\n", packet.written_length()); - bool ok = true; - - while (packet.bytes_left() > 0) - { - int header = packet.read_byte (&ok); - print ("Recieved packet with header %1\n", header); - switch (ServerResponse (header)) + try + { + while (packet.bytes_left() > 0) { - case SVRC_OLDPROTOCOL: - print ("wrong version\n"); - m_state = RCON_DISCONNECTED; - break; + int header = packet.read_byte(); + print ("Recieved packet with header %1\n", header); + + switch (ServerResponse (header)) + { + case SVRC_OLDPROTOCOL: + print ("wrong version\n"); + m_state = RCON_DISCONNECTED; + break; - case SVRC_BANNED: - print ("you're banned\n"); - m_state = RCON_DISCONNECTED; - break; + case SVRC_BANNED: + print ("you're banned\n"); + m_state = RCON_DISCONNECTED; + break; - case SVRC_SALT: - { - String salt = packet.read_string(); - m_salt = salt; - m_state = RCON_AUTHENTICATING; - send_password(); - } - break; + case SVRC_SALT: + { + String salt = packet.read_string(); + m_salt = salt; + m_state = RCON_AUTHENTICATING; + send_password(); + } + break; - case SVRC_INVALIDPASSWORD: - print ("bad password\n"); - m_state = RCON_DISCONNECTED; - break; + case SVRC_INVALIDPASSWORD: + print ("bad password\n"); + m_state = RCON_DISCONNECTED; + break; - case SVRC_MESSAGE: - { - String message = packet.read_string(); - if (message.ends_with ("\n")) - message.remove_from_end (1); - print ("message: %1\n", message); - } - break; + case SVRC_MESSAGE: + { + String message = packet.read_string(); + if (message.ends_with ("\n")) + message.remove_from_end (1); + print ("message: %1\n", message); + } + break; - case SVRC_LOGGEDIN: - print ("login successful\n"); - m_serverProtocol = packet.read_byte(); - m_hostname = packet.read_string(); - m_state = RCON_CONNECTED; + case SVRC_LOGGEDIN: + print ("login successful\n"); + m_serverProtocol = packet.read_byte(); + m_hostname = packet.read_string(); + m_state = RCON_CONNECTED; + + for (int i = packet.read_byte(); i > 0; --i) + process_server_updates (packet); - for (int i = packet.read_byte(); i > 0; --i) + for (int i = packet.read_byte(); i > 0; --i) + { + String message = packet.read_string(); + message.normalize(); + print ("--- %1\n", message); + } + break; + + case SVRC_UPDATE: process_server_updates (packet); - - for (int i = packet.read_byte(); i > 0; --i) - { - String message = packet.read_string(); - message.normalize(); - print ("--- %1\n", message); + break; } - - break; - - case SVRC_UPDATE: - process_server_updates (packet); - break; } - - if (not ok) - print ("error while reading packet\n"); + } + catch (std::exception& e) + { + print ("error while reading packet: %1\n", e.what()); } }