- to hell with that 'ok' field. now throws an exception if attempts to read past the end

Sat, 13 Dec 2014 04:50:33 +0200

author
Teemu Piippo <crimsondusk64@gmail.com>
date
Sat, 13 Dec 2014 04:50:33 +0200
changeset 13
09dcaeaa216b
parent 12
8d0d1b368de0
child 14
33b8f428bacb

- to hell with that 'ok' field. now throws an exception if attempts to read past the end

sources/network/bytestream.cpp file | annotate | diff | comparison | revisions
sources/network/bytestream.h file | annotate | diff | comparison | revisions
sources/network/rconsession.cpp file | annotate | diff | comparison | revisions
--- a/sources/network/bytestream.cpp	Sat Dec 13 04:32:15 2014 +0200
+++ b/sources/network/bytestream.cpp	Sat Dec 13 04:50:33 2014 +0200
@@ -31,8 +31,6 @@
 #include "bytestream.h"
 #include <string.h>
 
-bool Bytestream::sink;
-
 // -------------------------------------------------------------------------------------------------
 //
 Bytestream::Bytestream (unsigned long length) :
@@ -122,22 +120,29 @@
 
 // -------------------------------------------------------------------------------------------------
 //
-char Bytestream::read_byte (bool* ok)
+METHOD
+Bytestream::ensure_read_space (unsigned int bytes) -> void
 {
-	*ok = bytes_left() > 0;
-	return *ok ? *m_cursor++ : -1;
+	if (bytes_left() < bytes)
+	{
+		throw IOError (format ("attempted to read %1 byte(s) past the end of bytestream",
+			bytes - bytes_left()));
+	}
 }
 
 // -------------------------------------------------------------------------------------------------
 //
-short int Bytestream::read_short (bool* ok)
+char Bytestream::read_byte()
 {
-	if (bytes_left() < 2)
-	{
-		*ok = false;
-		return false;
-	}
+	ensure_read_space (1);
+	return *m_cursor++;
+}
 
+// -------------------------------------------------------------------------------------------------
+//
+short int Bytestream::read_short()
+{
+	ensure_read_space (2);
 	short int result = 0;
 
 	for (int i = 0; i < 2; ++i)
@@ -149,14 +154,9 @@
 
 // -------------------------------------------------------------------------------------------------
 //
-long int Bytestream::read_long (bool* ok)
+long int Bytestream::read_long()
 {
-	if (bytes_left() < 4)
-	{
-		*ok = false;
-		return -1;
-	}
-
+	ensure_read_space (4);
 	long int result = 0;
 
 	for (int i = 0; i < 4; ++i)
@@ -168,19 +168,15 @@
 
 // -------------------------------------------------------------------------------------------------
 //
-float Bytestream::read_float (bool* ok)
+float Bytestream::read_float()
 {
-	int value = read_long (ok);
-
-	if (*ok == false)
-		return -1.0f;
-
+	int value = read_long();
 	return reinterpret_cast<float&> (value);
 }
 
 // -------------------------------------------------------------------------------------------------
 //
-String Bytestream::read_string (bool* ok)
+String Bytestream::read_string()
 {
 	// Zandronum sends strings of maximum 2048 characters, though it only
 	// reads 2047-character long ones so I guess we can follow up and do
@@ -194,11 +190,8 @@
 	for (stringEnd = m_cursor; *stringEnd != '\0'; ++stringEnd)
 	{
 		if (stringEnd == end)
-		{
 			// past the end of the buffer! Argh!
-			*ok = false;
-			return "";
-		}
+			throw IOError ("unterminated string in packet");
 	}
 
 	m_cursor = stringEnd + 1;
@@ -217,14 +210,9 @@
 // -------------------------------------------------------------------------------------------------
 //
 METHOD
-Bytestream::read (unsigned char* buffer, unsigned long length, bool* ok) -> void
+Bytestream::read (unsigned char* buffer, unsigned long length) -> void
 {
-	if (bytes_left() < length)
-	{
-		*ok = false;
-		return;
-	}
-
+	ensure_read_space (length);
 	memcpy (buffer, m_cursor, length);
 	m_cursor += length;
 }
--- a/sources/network/bytestream.h	Sat Dec 13 04:32:15 2014 +0200
+++ b/sources/network/bytestream.h	Sat Dec 13 04:50:33 2014 +0200
@@ -29,6 +29,7 @@
 */
 
 #pragma once
+#include <stdexcept>
 #include "../main.h"
 
 class String;
@@ -39,7 +40,20 @@
 class Bytestream
 {
 public:
-	static bool sink;
+	class IOError : public std::exception
+	{
+		String m_message;
+
+	public:
+		IOError (String message) :
+			m_message (message) {}
+
+		inline METHOD
+		what() const throw() -> const char*
+		{
+			return m_message.chars();
+		}
+	};
 
 	Bytestream (unsigned long length = 0x800);
 	Bytestream (const unsigned char* data, unsigned long length);
@@ -54,12 +68,12 @@
 	inline METHOD data() const -> const unsigned char*;
 	       METHOD grow_to_fit (unsigned long bytes) -> void;
 	inline METHOD position() const -> unsigned long;
-	       METHOD read (unsigned char* buffer, unsigned long length, bool* ok = &sink) -> void;
-	       METHOD read_byte (bool* ok = &sink) -> char;
-	       METHOD read_short (bool* ok = &sink) -> short int;
-	       METHOD read_long (bool* ok = &sink) -> long int;
-	       METHOD read_string (bool* ok = &sink) -> String;
-	       METHOD read_float (bool* ok = &sink) -> float;
+	       METHOD read (unsigned char* buffer, unsigned long length) -> void;
+	       METHOD read_byte() -> char;
+	       METHOD read_short() -> short int;
+	       METHOD read_long() -> long int;
+	       METHOD read_string() -> String;
+	       METHOD read_float() -> float;
 	       METHOD resize (unsigned long length) -> void;
 	inline METHOD rewind() -> void;
 	inline METHOD seek (unsigned long pos) -> void;
@@ -87,6 +101,7 @@
 
 	METHOD init (const unsigned char* data, unsigned long length) -> void;
 	METHOD write (unsigned char val) -> void;
+	METHOD ensure_read_space (unsigned int bytes) -> void;
 	inline METHOD space_left() const -> unsigned long;
 };
 
--- a/sources/network/rconsession.cpp	Sat Dec 13 04:32:15 2014 +0200
+++ b/sources/network/rconsession.cpp	Sat Dec 13 04:50:33 2014 +0200
@@ -88,73 +88,75 @@
 RCONSession::handle_packet (Bytestream& packet, const IPAddress& from) -> void
 {
 	print ("Processing packet of %1 bytes\n", packet.written_length());
-	bool ok = true;
-
-	while (packet.bytes_left() > 0)
-	{
-		int header = packet.read_byte (&ok);
-		print ("Recieved packet with header %1\n", header);
 
-		switch (ServerResponse (header))
+	try
+	{
+		while (packet.bytes_left() > 0)
 		{
-		case SVRC_OLDPROTOCOL:
-			print ("wrong version\n");
-			m_state = RCON_DISCONNECTED;
-			break;
+			int header = packet.read_byte();
+			print ("Recieved packet with header %1\n", header);
+
+			switch (ServerResponse (header))
+			{
+			case SVRC_OLDPROTOCOL:
+				print ("wrong version\n");
+				m_state = RCON_DISCONNECTED;
+				break;
 
-		case SVRC_BANNED:
-			print ("you're banned\n");
-			m_state = RCON_DISCONNECTED;
-			break;
+			case SVRC_BANNED:
+				print ("you're banned\n");
+				m_state = RCON_DISCONNECTED;
+				break;
 
-		case SVRC_SALT:
-			{
-				String salt = packet.read_string();
-				m_salt = salt;
-				m_state = RCON_AUTHENTICATING;
-				send_password();
-			}
-			break;
+			case SVRC_SALT:
+				{
+					String salt = packet.read_string();
+					m_salt = salt;
+					m_state = RCON_AUTHENTICATING;
+					send_password();
+				}
+				break;
 
-		case SVRC_INVALIDPASSWORD:
-			print ("bad password\n");
-			m_state = RCON_DISCONNECTED;
-			break;
+			case SVRC_INVALIDPASSWORD:
+				print ("bad password\n");
+				m_state = RCON_DISCONNECTED;
+				break;
 
-		case SVRC_MESSAGE:
-			{
-				String message = packet.read_string();
-				if (message.ends_with ("\n"))
-					message.remove_from_end (1);
-				print ("message: %1\n", message);
-			}
-			break;
+			case SVRC_MESSAGE:
+				{
+					String message = packet.read_string();
+					if (message.ends_with ("\n"))
+						message.remove_from_end (1);
+					print ("message: %1\n", message);
+				}
+				break;
 
-		case SVRC_LOGGEDIN:
-			print ("login successful\n");
-			m_serverProtocol = packet.read_byte();
-			m_hostname = packet.read_string();
-			m_state = RCON_CONNECTED;
+			case SVRC_LOGGEDIN:
+				print ("login successful\n");
+				m_serverProtocol = packet.read_byte();
+				m_hostname = packet.read_string();
+				m_state = RCON_CONNECTED;
+
+				for (int i = packet.read_byte(); i > 0; --i)
+					process_server_updates (packet);
 
-			for (int i = packet.read_byte(); i > 0; --i)
+				for (int i = packet.read_byte(); i > 0; --i)
+				{
+					String message = packet.read_string();
+					message.normalize();
+					print ("--- %1\n", message);
+				}
+				break;
+
+			case SVRC_UPDATE:
 				process_server_updates (packet);
-
-			for (int i = packet.read_byte(); i > 0; --i)
-			{
-				String message = packet.read_string();
-				message.normalize();
-				print ("--- %1\n", message);
+				break;
 			}
-
-			break;
-
-		case SVRC_UPDATE:
-			process_server_updates (packet);
-			break;
 		}
-
-		if (not ok)
-			print ("error while reading packet\n");
+	}
+	catch (std::exception& e)
+	{
+		print ("error while reading packet: %1\n", e.what());
 	}
 }
 

mercurial