diff --git a/src/proto/handler.rs b/src/proto/handler.rs index 17f5d57..0a586a9 100644 --- a/src/proto/handler.rs +++ b/src/proto/handler.rs @@ -7,9 +7,46 @@ use mio; use config; -use super::{Packet, PacketStream, Request, Response}; +use super::{PacketStream, Request, Response}; use super::server::*; +/// A struct used for writing bytes to a TryWrite sink. +struct OutBuf { + cursor: usize, + bytes: Vec +} + +impl From> for OutBuf { + fn from(bytes: Vec) -> Self { + OutBuf { + cursor: 0, + bytes: bytes + } + } +} + +impl OutBuf { + #[inline] + fn remaining(&self) -> usize { + self.bytes.len() - self.cursor + } + + #[inline] + fn has_remaining(&self) -> bool { + self.remaining() > 0 + } + + fn try_write_to(&mut self, mut writer: T) -> io::Result> + where T: mio::TryWrite + { + let result = writer.try_write(&self.bytes[self.cursor..]); + if let Ok(Some(bytes_written)) = result { + self.cursor += bytes_written; + } + result + } +} + /// This struct provides a simple way to generate different tokens. struct TokenCounter { counter: usize, @@ -35,7 +72,7 @@ struct Handler { server_token: mio::Token, server_stream: PacketStream, - server_queue: VecDeque, + server_queue: VecDeque, client_tx: mpsc::Sender, } @@ -118,15 +155,20 @@ impl Handler { fn write_server(&mut self) { loop { - let mut packet = match self.server_queue.pop_front() { - Some(packet) => packet, + let mut outbuf = match self.server_queue.pop_front() { + Some(outbuf) => outbuf, None => break }; - match self.server_stream.try_write(&mut packet) { - Ok(Some(())) => (), // continue looping + match outbuf.try_write_to(&mut self.server_stream) { + Ok(Some(_)) => { + if outbuf.has_remaining() { + self.server_queue.push_front(outbuf) + } + // Continue looping + }, Ok(None) => { - self.server_queue.push_front(packet); + self.server_queue.push_front(outbuf); break }, Err(e) => { @@ -140,7 +182,7 @@ impl Handler { fn notify_server(&mut self, request: ServerRequest) -> io::Result<()> { debug!("Sending server request: {:?}", request); let packet = try!(request.to_packet()); - self.server_queue.push_back(packet); + self.server_queue.push_back(OutBuf::from(packet.into_bytes())); Ok(()) } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 41e2c04..ca8f8b2 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -4,13 +4,7 @@ pub mod server; pub use self::handler::*; -pub use self::packet::{ - Packet, - PacketReadError, - PacketStream, - ReadFromPacket, - WriteToPacket -}; +pub use self::packet::*; use self::server::{ServerRequest, ServerResponse}; diff --git a/src/proto/packet.rs b/src/proto/packet.rs index f921282..986a18a 100644 --- a/src/proto/packet.rs +++ b/src/proto/packet.rs @@ -10,7 +10,7 @@ use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt}; use encoding::{Encoding, DecoderTrap, EncoderTrap}; use encoding::all::ISO_8859_1; use mio::{ - Evented, EventLoop, EventSet, Handler, PollOpt, Token, TryRead, TryWrite + Evented, EventLoop, EventSet, Handler, PollOpt, Token, TryRead }; const MAX_PACKET_SIZE: usize = 1 << 20; // 1 MiB @@ -19,88 +19,114 @@ const MAX_MESSAGE_SIZE: usize = MAX_PACKET_SIZE - U32_SIZE; const MAX_PORT: u32 = (1 << 16) - 1; -/*========* - * PACKET * - *========*/ +/*==================* + * READ-ONLY PACKET * + *==================*/ #[derive(Debug)] pub struct Packet { + /// The packet code. + code: u32, + /// The current read position in the byte buffer. cursor: usize, + /// The underlying bytes. bytes: Vec, } impl io::Read for Packet { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut slice = &self.bytes[self.cursor..]; - let result = slice.read(buf); - if let Ok(num_bytes_read) = result { - self.cursor += num_bytes_read - } - result - } -} - -impl io::Write for Packet { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.bytes.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.bytes.flush() + let bytes_read = { + let mut slice = &self.bytes[self.cursor..]; + try!(slice.read(buf)) + }; + self.cursor += bytes_read; + Ok(bytes_read) } } impl Packet { - /// Returns an empty packet with the given packet code. - pub fn new(code: u32) -> Self { - let mut bytes = Vec::new(); - bytes.write_u32::(0).unwrap(); - bytes.write_u32::(code).unwrap(); + /// Returns a readable packet struct from the wire representation of a + /// packet. + fn from_bytes(bytes: Vec) -> Self { + // Check that the packet is long enough to contain at least a code. + assert!(bytes.len() >= 2*U32_SIZE); + // Read the purported length of the packet. + let size = LittleEndian::read_u32(&bytes[0..U32_SIZE]) as usize; + // Check that the packet has the right length. + assert!(size + U32_SIZE == bytes.len()); + // Read the packet code. + let code = LittleEndian::read_u32(&bytes[U32_SIZE..2*U32_SIZE]); Packet { + code: code, cursor: 2*U32_SIZE, - bytes: bytes, + bytes: bytes, } } - /// Returns a new packet struct, constructed from the wire representation - /// of a packet. - fn from_raw_parts(bytes: Vec) -> Self { - let size = LittleEndian::read_u32(&bytes[..U32_SIZE]) as usize; - assert!(size + U32_SIZE == bytes.len()); - Packet { - cursor: U32_SIZE, - bytes: bytes, - } + /// Returns the packet code. + pub fn code(&self) -> u32 { + self.code } /// Provides the main way to read data out of a binary packet. - pub fn read_value(&mut self) - -> Result + pub fn read_value(&mut self) -> Result + where T: ReadFromPacket { T::read_from_packet(self) } - /// Provides the main way to write data into a binary packet. - pub fn write_value(&mut self, val: T) - -> io::Result<()> - { - val.write_to_packet(self) - } - /// Returns the number of unread bytes remaining in the packet. pub fn bytes_remaining(&self) -> usize { self.bytes.len() - self.cursor } +} - /// Returns a slice pointing to the entire underlying byte array, including - /// the length prefix. - pub fn as_slice(&mut self) -> &[u8] { - let bytes_len = (self.bytes.len() - U32_SIZE) as u32; +/*===================* + * WRITE-ONLY PACKET * + *===================*/ + +#[derive(Debug)] +pub struct MutPacket { + bytes: Vec, +} + +impl MutPacket { + /// Returns an empty packet with the given packet code. + pub fn new(code: u32) -> Self { + // Leave space for the eventual size of the packet. + let mut bytes = vec![0; U32_SIZE]; + // Write the code. + bytes.write_u32::(code).unwrap(); + MutPacket { + bytes: bytes, + } + } + + /// Provides the main way to write data into a binary packet. + pub fn write_value(&mut self, val: T) -> io::Result<()> + where T: WriteToPacket + { + val.write_to_packet(self) + } + + /// Consumes the mutable packet and returns its wire representation. + pub fn into_bytes(mut self) -> Vec { + let length = (self.bytes.len() - U32_SIZE) as u32; { let mut first_word = &mut self.bytes[..U32_SIZE]; - first_word.write_u32::(bytes_len).unwrap(); + first_word.write_u32::(length).unwrap(); } - &self.bytes + self.bytes + } +} + +impl io::Write for MutPacket { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.bytes.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.bytes.flush() } } @@ -263,21 +289,21 @@ impl ReadFromPacket for Vec { *=================*/ /// This trait is implemented by types that can be serialized to a binary -/// Packet. +/// MutPacket. pub trait WriteToPacket { - fn write_to_packet(self, &mut Packet) -> io::Result<()>; + fn write_to_packet(self, &mut MutPacket) -> io::Result<()>; } /// 32-bit integers are serialized in 4 bytes, little-endian. impl WriteToPacket for u32 { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { packet.write_u32::(self) } } /// Booleans are serialized as single bytes, containing either 0 or 1. impl WriteToPacket for bool { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { try!(packet.write(&[self as u8])); Ok(()) } @@ -285,7 +311,7 @@ impl WriteToPacket for bool { /// 16-bit integers are serialized as 32-bit integers. impl WriteToPacket for u16 { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { (self as u32).write_to_packet(packet) } } @@ -293,7 +319,7 @@ impl WriteToPacket for u16 { /// Strings are serialized as a length-prefixed array of ISO-8859-1 encoded /// characters. impl<'a> WriteToPacket for &'a str { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) { Ok(bytes) => bytes, Err(_) => { @@ -309,7 +335,7 @@ impl<'a> WriteToPacket for &'a str { /// Deref coercion does not happen for trait methods apparently. impl<'a> WriteToPacket for &'a String { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { packet.write_value::<&'a str>(self) } } @@ -403,22 +429,11 @@ impl PacketStream { self.num_bytes_left = U32_SIZE; let new_buffer = vec![0;U32_SIZE]; let old_buffer = mem::replace(&mut self.buffer, new_buffer); - Ok(Some(Packet::from_raw_parts(old_buffer))) + Ok(Some(Packet::from_bytes(old_buffer))) } } } - /// Tries to write a given packet to the underlying byte stream. - /// TODO: If the packet is not entirely written in the first call, this - /// will send garbage along the wire. Instead we should track how far we - /// are in sending the given packet? - pub fn try_write(&mut self, packet: &mut Packet) -> io::Result> { - match try!(self.stream.try_write(packet.as_slice())) { - None => Ok(None), - Some(_) => Ok(Some(())) - } - } - /// Register the packet stream with the given mio event loop. pub fn register( &self, event_loop: &mut EventLoop, token: Token, @@ -437,3 +452,13 @@ impl PacketStream { event_loop.reregister(&self.stream, token, event_set, poll_opt) } } + +impl io::Write for PacketStream { + fn write(&mut self, bytes: &[u8]) -> io::Result { + self.stream.write(bytes) + } + + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } +} diff --git a/src/proto/server/request.rs b/src/proto/server/request.rs index dc06d20..a8f87d1 100644 --- a/src/proto/server/request.rs +++ b/src/proto/server/request.rs @@ -4,7 +4,7 @@ use crypto::md5::Md5; use crypto::digest::Digest; use super::constants::*; -use super::super::packet::{Packet, WriteToPacket}; +use super::super::packet::{MutPacket, WriteToPacket}; /*================* * SERVER REQUEST * @@ -25,7 +25,7 @@ pub enum ServerRequest { macro_rules! try_to_packet { ($code: ident, $request:ident) => { { - let mut packet = Packet::new($code); + let mut packet = MutPacket::new($code); try!($request.write_to_packet(&mut packet)); Ok(packet) } @@ -33,7 +33,7 @@ macro_rules! try_to_packet { } impl ServerRequest { - pub fn to_packet(&self) -> io::Result { + pub fn to_packet(&self) -> io::Result { match *self { ServerRequest::LoginRequest(ref request) => try_to_packet!(CODE_LOGIN, request), @@ -48,7 +48,7 @@ impl ServerRequest { try_to_packet!(CODE_ROOM_LEAVE, request), ServerRequest::RoomListRequest => - Ok(Packet::new(CODE_ROOM_LIST)), + Ok(MutPacket::new(CODE_ROOM_LIST)), ServerRequest::RoomMessageRequest(ref request) => try_to_packet!(CODE_ROOM_MESSAGE, request), @@ -97,7 +97,7 @@ impl LoginRequest { } impl<'a> WriteToPacket for &'a LoginRequest { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { let userpass = String::new() + &self.username + &self.password; let userpass_md5 = md5_str(&userpass); @@ -121,7 +121,7 @@ pub struct PeerAddressRequest { } impl<'a> WriteToPacket for &'a PeerAddressRequest { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { try!(packet.write_value(&self.username)); Ok(()) } @@ -137,7 +137,7 @@ pub struct RoomJoinRequest { } impl<'a> WriteToPacket for &'a RoomJoinRequest { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { try!(packet.write_value(&self.room_name)); Ok(()) } @@ -153,7 +153,7 @@ pub struct RoomLeaveRequest { } impl<'a> WriteToPacket for &'a RoomLeaveRequest { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { try!(packet.write_value(&self.room_name)); Ok(()) } @@ -170,7 +170,7 @@ pub struct RoomMessageRequest { } impl<'a> WriteToPacket for &'a RoomMessageRequest { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { try!(packet.write_value(&self.room_name)); try!(packet.write_value(&self.message)); Ok(()) @@ -187,7 +187,7 @@ pub struct SetListenPortRequest { } impl<'a> WriteToPacket for &'a SetListenPortRequest { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { try!(packet.write_value(self.port)); Ok(()) } @@ -203,7 +203,7 @@ pub struct UserStatusRequest { } impl<'a> WriteToPacket for &'a UserStatusRequest { - fn write_to_packet(self, packet: &mut Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut MutPacket) -> io::Result<()> { try!(packet.write_value(&self.user_name)); Ok(()) } diff --git a/src/proto/server/response.rs b/src/proto/server/response.rs index 786fe6a..f471635 100644 --- a/src/proto/server/response.rs +++ b/src/proto/server/response.rs @@ -42,8 +42,7 @@ macro_rules! try_read_from_packet { impl ReadFromPacket for ServerResponse { fn read_from_packet(packet: &mut Packet) -> Result { - let code = try!(packet.read_value()); - let resp = match code { + let resp = match packet.code() { CODE_CONNECT_TO_PEER => try_read_from_packet!(ConnectToPeerResponse, packet), @@ -94,7 +93,7 @@ impl ReadFromPacket for ServerResponse { let bytes_remaining = packet.bytes_remaining(); if bytes_remaining > 0 { warn!("Packet with code {} contains {} extra bytes", - code, bytes_remaining) + packet.code(), bytes_remaining) } Ok(resp) } diff --git a/src/user.rs b/src/user.rs index e3b736d..58d3913 100644 --- a/src/user.rs +++ b/src/user.rs @@ -37,7 +37,7 @@ impl proto::ReadFromPacket for Status { } impl<'a> proto::WriteToPacket for &'a Status { - fn write_to_packet(self, packet: &mut proto::Packet) -> io::Result<()> { + fn write_to_packet(self, packet: &mut proto::MutPacket) -> io::Result<()> { let n = match *self { Status::Offline => STATUS_OFFLINE, Status::Away => STATUS_AWAY,