From 468c404765c21f601565feff645bf650190776f2 Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Sun, 14 Jun 2020 00:14:53 +0000 Subject: [PATCH] Simplify ProtoDecoder by removing dependency on the bytes crate. Improve decoding error messages by annotating them with the position at which the error was encountered. This could use some improvement in a follow-up by introducing a specific Error type instead of piggy-backing onto std::io::Error. It is easier and more generically useful to accept any byte slice as a buffer. In addition our zero-copy use of the underlying buffer does not align particularly well with std::io::Read, so we move away from std::io::Cursor as well. --- src/proto/base_codec.rs | 196 ++++++++++++++++++++++++----------- src/proto/peer/message.rs | 2 +- src/proto/server/request.rs | 2 +- src/proto/server/response.rs | 2 +- 4 files changed, 136 insertions(+), 66 deletions(-) diff --git a/src/proto/base_codec.rs b/src/proto/base_codec.rs index d44c16d..b50cb1a 100644 --- a/src/proto/base_codec.rs +++ b/src/proto/base_codec.rs @@ -16,11 +16,11 @@ use std::fmt; use std::io; use std::net; -use std::u16; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{BufMut, BytesMut}; use encoding::all::WINDOWS_1252; use encoding::{DecoderTrap, EncoderTrap, Encoding}; +use std::convert::{TryFrom, TryInto}; // Constants // --------- @@ -38,25 +38,40 @@ pub trait Encode { fn encode(&mut self, value: T) -> io::Result<()>; } -/// Builds an EOF error encountered when reading a value of the given type. -fn unexpected_eof_error(type_name: &str) -> io::Error { - io::Error::new( - io::ErrorKind::UnexpectedEof, - format!("reading {}", type_name), - ) +// TODO: Define a real DecodeError type using the thiserror crate for better +// error messages without all this io::Error hackery. + +/// Builds an UnexpectedEof error with the given message. +fn unexpected_eof_error(message: String) -> io::Error { + io::Error::new(io::ErrorKind::UnexpectedEof, message) +} + +/// Builds an InvalidData error with the given message. +fn invalid_data_error(message: String) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, message) } -/// Builds an InvalidData error for the given value of the given type. -fn invalid_data_error(type_name: &str, value: T) -> io::Error { - io::Error::new( - io::ErrorKind::InvalidData, - format!("invalid {}: {:?}", type_name, value), - ) +/// Annotates the given error with the given context. +fn annotate_error(error: &io::Error, context: &str) -> io::Error { + io::Error::new(error.kind(), format!("{}: {}", context, error)) } /// A type for decoding various types of values from protocol messages. pub struct ProtoDecoder<'a> { - inner: io::Cursor<&'a BytesMut>, + // The buffer we are decoding from. + // + // Invariant: `position <= buffer.len()`. + buffer: &'a [u8], + + // Our current position within `buffer`. + // + // We could instead maintain this implicitly in `buffer` by splitting off + // decoded bytes from the start of the buffer, but we would then be unable + // to remember how many bytes we had decoded. This information is useful to + // have in error messages when encountering decoding errors. + // + // Invariant: `position <= buffer.len()`. + position: usize, } /// This trait is implemented by types that can be decoded from messages using @@ -68,65 +83,94 @@ pub trait ProtoDecode: Sized { impl<'a> ProtoDecoder<'a> { /// Wraps the given byte buffer. - pub fn new(bytes: &'a BytesMut) -> Self { + pub fn new(buffer: &'a [u8]) -> Self { Self { - inner: io::Cursor::new(bytes), + buffer: buffer, + position: 0, } } + /// Returns the number of bytes remaining to decode. + pub fn remaining(&self) -> usize { + self.buffer.len() - self.position + } + /// Returns whether the underlying buffer has remaining bytes to decode. + /// + /// Shorthand for `remaining() > 0`. pub fn has_remaining(&self) -> bool { - self.inner.has_remaining() + self.remaining() > 0 } /// Returns a read-only view of the remaining bytes to decode. + /// + /// The returned slice is of size `remaining()`. pub fn bytes(&self) -> &[u8] { - self.inner.bytes() - } - - /// Asserts that the buffer contains at least `n` more bytes from which to - /// read a value of the named type. - /// Returns Ok(()) if there are that many bytes, otherwise returns a - /// descriptive error. - fn expect_remaining(&self, type_name: &str, n: usize) -> io::Result<()> { - if self.inner.remaining() < n { - Err(unexpected_eof_error(type_name)) - } else { - Ok(()) + &self.buffer[self.position..] + } + + /// Attempts to consume the next `n` bytes from this buffer. + /// + /// Returns a slice of size `n` if successful, in which case this decoder + /// advances its internal position by `n`. + fn consume(&mut self, n: usize) -> io::Result<&[u8]> { + if self.remaining() < n { + return Err(unexpected_eof_error(format!( + "expected {} bytes remaining, found {}", + n, + self.remaining() + ))); } + + // Cannot use bytes() here as it borrows self immutably, which + // prevents us from mutating self.position afterwards. + let end = self.position + n; + let bytes = &self.buffer[self.position..end]; + self.position = end; + Ok(bytes) } - /// Attempts to decode a u32 value in the context of decoding a value of - /// the named type. - fn decode_u32_generic(&mut self, type_name: &str) -> io::Result { - self.expect_remaining(type_name, U32_BYTE_LEN)?; - Ok(self.inner.get_u32_le()) + /// Attempts to decode a u32 value. + /// + /// Note that this method returns a less descriptive error than + /// `self.decode::()`. It is intended to be a low-level building block + /// for decoding other types. + fn decode_u32(&mut self) -> io::Result { + let bytes = self.consume(U32_BYTE_LEN)?; + // The conversion from slice to fixed-size array cannot fail, because + // consume() guarantees that its return value is of size n. + let array: [u8; U32_BYTE_LEN] = bytes.try_into().unwrap(); + Ok(u32::from_le_bytes(array)) } /// Attempts to decode a boolean value. + /// + /// Note that this method returns a less descriptive error than + /// `self.decode::()`. It is intended to be a low-level building block + /// for decoding other types. fn decode_bool(&mut self) -> io::Result { - self.expect_remaining("bool", 1)?; - match self.inner.get_u8() { + let bytes = self.consume(1)?; + match bytes[0] { 0 => Ok(false), 1 => Ok(true), - n => Err(invalid_data_error("bool", n)), + n => Err(invalid_data_error(format!("invalid bool value {}", n))), } } /// Attempts to decode a string value. + /// + /// Note that this method returns a less descriptive error than + /// `self.decode::()`. It is intended to be a low-level building + /// block for decoding other types. fn decode_string(&mut self) -> io::Result { - let len = self.decode_u32_generic("string length")? as usize; - self.expect_remaining("string", len)?; - - let result = { - let bytes = &self.inner.bytes()[..len]; - WINDOWS_1252 - .decode(bytes, DecoderTrap::Strict) - .map_err(|err| invalid_data_error("string", (err, bytes))) - }; + let length = self.decode_u32()? as usize; + let bytes = self.consume(length)?; - self.inner.advance(len); - result + let result = WINDOWS_1252.decode(bytes, DecoderTrap::Strict); + match result { + Ok(string) => Ok(string), + Err(error) => Err(invalid_data_error(format!("invalid string: {:?}", error))), + } } /// Attempts to decode a value of the given type. @@ -134,26 +178,33 @@ impl<'a> ProtoDecoder<'a> { /// Allows easy decoding of complex values using type inference: /// /// ``` - /// let val : Foo = decoder.decode()?; + /// let val: Foo = decoder.decode()?; /// ``` pub fn decode(&mut self) -> io::Result { - T::decode_from(self) + let position = self.position; + match T::decode_from(self) { + Ok(value) => Ok(value), + Err(ref error) => Err(annotate_error( + error, + &format!("decoding value at position {}", position), + )), + } } } impl ProtoDecode for u32 { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - decoder.decode_u32_generic("u32") + decoder.decode_u32() } } impl ProtoDecode for u16 { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - let n = decoder.decode_u32_generic("u16")?; - if n > u16::MAX as u32 { - return Err(invalid_data_error("u16", n)); + let n = decoder.decode_u32()?; + match u16::try_from(n) { + Ok(value) => Ok(value), + Err(_) => Err(invalid_data_error(format!("invalid u16 value {}", n))), } - Ok(n as u16) } } @@ -165,7 +216,7 @@ impl ProtoDecode for bool { impl ProtoDecode for net::Ipv4Addr { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - let ip = decoder.decode_u32_generic("ipv4 address")?; + let ip = decoder.decode_u32()?; Ok(net::Ipv4Addr::from(ip)) } } @@ -186,7 +237,7 @@ impl ProtoDecode for (T, U) { impl ProtoDecode for Vec { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result { - let len = decoder.decode_u32_generic("vector length")? as usize; + let len = decoder.decode_u32()? as usize; let mut vec = Vec::with_capacity(len); for _ in 0..len { let val = decoder.decode()?; @@ -473,7 +524,11 @@ pub mod tests { let result = decoder.decode::(); - expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading u32"); + expect_io_error( + result, + io::ErrorKind::UnexpectedEof, + "decoding value at position 0: expected 4 bytes remaining, found 1", + ); assert_eq!(decoder.bytes(), &[13]); } @@ -516,7 +571,11 @@ pub mod tests { let result = ProtoDecoder::new(&buffer).decode::(); - expect_io_error(result, io::ErrorKind::InvalidData, "invalid bool: 42"); + expect_io_error( + result, + io::ErrorKind::InvalidData, + "decoding value at position 0: invalid bool value 42", + ); } #[test] @@ -525,7 +584,11 @@ pub mod tests { let result = ProtoDecoder::new(&buffer).decode::(); - expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading bool"); + expect_io_error( + result, + io::ErrorKind::UnexpectedEof, + "decoding value at position 0: expected 1 bytes remaining, found 0", + ); } #[test] @@ -564,7 +627,10 @@ pub mod tests { expect_io_error( decoder.decode::(), io::ErrorKind::InvalidData, - &format!("invalid u16: {}", expected_val), + &format!( + "decoding value at position 0: invalid u16 value {}", + expected_val + ), ); } } @@ -577,7 +643,11 @@ pub mod tests { let result = decoder.decode::(); - expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading u16"); + expect_io_error( + result, + io::ErrorKind::UnexpectedEof, + "decoding value at position 0: expected 4 bytes remaining, found 0", + ); } #[test] diff --git a/src/proto/peer/message.rs b/src/proto/peer/message.rs index 81c9581..7d595b2 100644 --- a/src/proto/peer/message.rs +++ b/src/proto/peer/message.rs @@ -171,7 +171,7 @@ mod tests { expect_io_error( result, io::ErrorKind::InvalidData, - "unknown peer message code: 1337", + "decoding value at position 0: unknown peer message code: 1337", ); } diff --git a/src/proto/server/request.rs b/src/proto/server/request.rs index f25a21c..5545966 100644 --- a/src/proto/server/request.rs +++ b/src/proto/server/request.rs @@ -600,7 +600,7 @@ mod tests { expect_io_error( result, io::ErrorKind::InvalidData, - "unknown server request code: 1337", + "decoding value at position 0: unknown server request code: 1337", ); } diff --git a/src/proto/server/response.rs b/src/proto/server/response.rs index 872e4cc..dec6db0 100644 --- a/src/proto/server/response.rs +++ b/src/proto/server/response.rs @@ -1348,7 +1348,7 @@ mod tests { expect_io_error( result, io::ErrorKind::InvalidData, - "unknown server response code: 1337", + "decoding value at position 0: unknown server response code: 1337", ); }