diff --git a/src/proto/base_codec.rs b/src/proto/base_codec.rs index d44c16d..b50cb1a 100644 --- a/src/proto/base_codec.rs +++ b/src/proto/base_codec.rs @@ -16,11 +16,11 @@ use std::fmt; use std::io; use std::net; -use std::u16; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{BufMut, BytesMut}; use encoding::all::WINDOWS_1252; use encoding::{DecoderTrap, EncoderTrap, Encoding}; +use std::convert::{TryFrom, TryInto}; // Constants // --------- @@ -38,25 +38,40 @@ pub trait Encode { fn encode(&mut self, value: T) -> io::Result<()>; } -/// Builds an EOF error encountered when reading a value of the given type. -fn unexpected_eof_error(type_name: &str) -> io::Error { - io::Error::new( - io::ErrorKind::UnexpectedEof, - format!("reading {}", type_name), - ) +// TODO: Define a real DecodeError type using the thiserror crate for better +// error messages without all this io::Error hackery. + +/// Builds an UnexpectedEof error with the given message. +fn unexpected_eof_error(message: String) -> io::Error { + io::Error::new(io::ErrorKind::UnexpectedEof, message) +} + +/// Builds an InvalidData error with the given message. +fn invalid_data_error(message: String) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, message) } -/// Builds an InvalidData error for the given value of the given type. -fn invalid_data_error(type_name: &str, value: T) -> io::Error { - io::Error::new( - io::ErrorKind::InvalidData, - format!("invalid {}: {:?}", type_name, value), - ) +/// Annotates the given error with the given context. +fn annotate_error(error: &io::Error, context: &str) -> io::Error { + io::Error::new(error.kind(), format!("{}: {}", context, error)) } /// A type for decoding various types of values from protocol messages. pub struct ProtoDecoder<'a> { - inner: io::Cursor<&'a BytesMut>, + // The buffer we are decoding from. + // + // Invariant: `position <= buffer.len()`. + buffer: &'a [u8], + + // Our current position within `buffer`. + // + // We could instead maintain this implicitly in `buffer` by splitting off + // decoded bytes from the start of the buffer, but we would then be unable + // to remember how many bytes we had decoded. This information is useful to + // have in error messages when encountering decoding errors. + // + // Invariant: `position <= buffer.len()`. + position: usize, } /// This trait is implemented by types that can be decoded from messages using @@ -68,65 +83,94 @@ pub trait ProtoDecode: Sized { impl<'a> ProtoDecoder<'a> { /// Wraps the given byte buffer. - pub fn new(bytes: &'a BytesMut) -> Self { + pub fn new(buffer: &'a [u8]) -> Self { Self { - inner: io::Cursor::new(bytes), + buffer: buffer, + position: 0, } } + /// Returns the number of bytes remaining to decode. + pub fn remaining(&self) -> usize { + self.buffer.len() - self.position + } + /// Returns whether the underlying buffer has remaining bytes to decode. + /// + /// Shorthand for `remaining() > 0`. pub fn has_remaining(&self) -> bool { - self.inner.has_remaining() + self.remaining() > 0 } /// Returns a read-only view of the remaining bytes to decode. + /// + /// The returned slice is of size `remaining()`. pub fn bytes(&self) -> &[u8] { - self.inner.bytes() - } - - /// Asserts that the buffer contains at least `n` more bytes from which to - /// read a value of the named type. - /// Returns Ok(()) if there are that many bytes, otherwise returns a - /// descriptive error. - fn expect_remaining(&self, type_name: &str, n: usize) -> io::Result<()> { - if self.inner.remaining() < n { - Err(unexpected_eof_error(type_name)) - } else { - Ok(()) + &self.buffer[self.position..] + } + + /// Attempts to consume the next `n` bytes from this buffer. + /// + /// Returns a slice of size `n` if successful, in which case this decoder + /// advances its internal position by `n`. + fn consume(&mut self, n: usize) -> io::Result<&[u8]> { + if self.remaining() < n { + return Err(unexpected_eof_error(format!( + "expected {} bytes remaining, found {}", + n, + self.remaining() + ))); } + + // Cannot use bytes() here as it borrows self immutably, which + // prevents us from mutating self.position afterwards. + let end = self.position + n; + let bytes = &self.buffer[self.position..end]; + self.position = end; + Ok(bytes) } - /// Attempts to decode a u32 value in the context of decoding a value of - /// the named type. - fn decode_u32_generic(&mut self, type_name: &str) -> io::Result { - self.expect_remaining(type_name, U32_BYTE_LEN)?; - Ok(self.inner.get_u32_le()) + /// Attempts to decode a u32 value. + /// + /// Note that this method returns a less descriptive error than + /// `self.decode::()`. It is intended to be a low-level building block + /// for decoding other types. + fn decode_u32(&mut self) -> io::Result { + let bytes = self.consume(U32_BYTE_LEN)?; + // The conversion from slice to fixed-size array cannot fail, because + // consume() guarantees that its return value is of size n. + let array: [u8; U32_BYTE_LEN] = bytes.try_into().unwrap(); + Ok(u32::from_le_bytes(array)) } /// Attempts to decode a boolean value. + /// + /// Note that this method returns a less descriptive error than + /// `self.decode::()`. It is intended to be a low-level building block + /// for decoding other types. fn decode_bool(&mut self) -> io::Result { - self.expect_remaining("bool", 1)?; - match self.inner.get_u8() { + let bytes = self.consume(1)?; + match bytes[0] { 0 => Ok(false), 1 => Ok(true), - n => Err(invalid_data_error("bool", n)), + n => Err(invalid_data_error(format!("invalid bool value {}", n))), } } /// Attempts to decode a string value. + /// + /// Note that this method returns a less descriptive error than + /// `self.decode::()`. It is intended to be a low-level building + /// block for decoding other types. fn decode_string(&mut self) -> io::Result { - let len = self.decode_u32_generic("string length")? as usize; - self.expect_remaining("string", len)?; - - let result = { - let bytes = &self.inner.bytes()[..len]; - WINDOWS_1252 - .decode(bytes, DecoderTrap::Strict) - .map_err(|err| invalid_data_error("string", (err, bytes))) - }; + let length = self.decode_u32()? as usize; + let bytes = self.consume(length)?; - self.inner.advance(len); - result + let result = WINDOWS_1252.decode(bytes, DecoderTrap::Strict); + match result { + Ok(string) => Ok(string), + Err(error) => Err(invalid_data_error(format!("invalid string: {:?}", error))), + } } /// Attempts to decode a value of the given type. @@ -134,26 +178,33 @@ impl<'a> ProtoDecoder<'a> { /// Allows easy decoding of complex values using type inference: /// /// ``` - /// let val : Foo = decoder.decode()?; + /// let val: Foo = decoder.decode()?; /// ``` pub fn decode(&mut self) -> io::Result { - T::decode_from(self) + let position = self.position; + match T::decode_from(self) { + Ok(value) => Ok(value), + Err(ref error) => Err(annotate_error( + error, + &format!("decoding value at position {}", position), + )), + } } } impl ProtoDecode for u32 { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - decoder.decode_u32_generic("u32") + decoder.decode_u32() } } impl ProtoDecode for u16 { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - let n = decoder.decode_u32_generic("u16")?; - if n > u16::MAX as u32 { - return Err(invalid_data_error("u16", n)); + let n = decoder.decode_u32()?; + match u16::try_from(n) { + Ok(value) => Ok(value), + Err(_) => Err(invalid_data_error(format!("invalid u16 value {}", n))), } - Ok(n as u16) } } @@ -165,7 +216,7 @@ impl ProtoDecode for bool { impl ProtoDecode for net::Ipv4Addr { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - let ip = decoder.decode_u32_generic("ipv4 address")?; + let ip = decoder.decode_u32()?; Ok(net::Ipv4Addr::from(ip)) } } @@ -186,7 +237,7 @@ impl ProtoDecode for (T, U) { impl ProtoDecode for Vec { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - let len = decoder.decode_u32_generic("vector length")? as usize; + let len = decoder.decode_u32()? as usize; let mut vec = Vec::with_capacity(len); for _ in 0..len { let val = decoder.decode()?; @@ -473,7 +524,11 @@ pub mod tests { let result = decoder.decode::(); - expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading u32"); + expect_io_error( + result, + io::ErrorKind::UnexpectedEof, + "decoding value at position 0: expected 4 bytes remaining, found 1", + ); assert_eq!(decoder.bytes(), &[13]); } @@ -516,7 +571,11 @@ pub mod tests { let result = ProtoDecoder::new(&buffer).decode::(); - expect_io_error(result, io::ErrorKind::InvalidData, "invalid bool: 42"); + expect_io_error( + result, + io::ErrorKind::InvalidData, + "decoding value at position 0: invalid bool value 42", + ); } #[test] @@ -525,7 +584,11 @@ pub mod tests { let result = ProtoDecoder::new(&buffer).decode::(); - expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading bool"); + expect_io_error( + result, + io::ErrorKind::UnexpectedEof, + "decoding value at position 0: expected 1 bytes remaining, found 0", + ); } #[test] @@ -564,7 +627,10 @@ pub mod tests { expect_io_error( decoder.decode::(), io::ErrorKind::InvalidData, - &format!("invalid u16: {}", expected_val), + &format!( + "decoding value at position 0: invalid u16 value {}", + expected_val + ), ); } } @@ -577,7 +643,11 @@ pub mod tests { let result = decoder.decode::(); - expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading u16"); + expect_io_error( + result, + io::ErrorKind::UnexpectedEof, + "decoding value at position 0: expected 4 bytes remaining, found 0", + ); } #[test] diff --git a/src/proto/peer/message.rs b/src/proto/peer/message.rs index 81c9581..7d595b2 100644 --- a/src/proto/peer/message.rs +++ b/src/proto/peer/message.rs @@ -171,7 +171,7 @@ mod tests { expect_io_error( result, io::ErrorKind::InvalidData, - "unknown peer message code: 1337", + "decoding value at position 0: unknown peer message code: 1337", ); } diff --git a/src/proto/server/request.rs b/src/proto/server/request.rs index f25a21c..5545966 100644 --- a/src/proto/server/request.rs +++ b/src/proto/server/request.rs @@ -600,7 +600,7 @@ mod tests { expect_io_error( result, io::ErrorKind::InvalidData, - "unknown server request code: 1337", + "decoding value at position 0: unknown server request code: 1337", ); } diff --git a/src/proto/server/response.rs b/src/proto/server/response.rs index 872e4cc..dec6db0 100644 --- a/src/proto/server/response.rs +++ b/src/proto/server/response.rs @@ -1348,7 +1348,7 @@ mod tests { expect_io_error( result, io::ErrorKind::InvalidData, - "unknown server response code: 1337", + "decoding value at position 0: unknown server response code: 1337", ); }