Browse Source

Simplify ProtoDecoder by removing dependency on the bytes crate.

Improve decoding error messages by annotating them with the position
at which the error was encountered. This could use some improvement
in a follow-up by introducing a specific Error type instead of
piggy-backing onto std::io::Error.

It is easier and more generically useful to accept any byte slice
as a buffer. In addition our zero-copy use of the underlying buffer
does not align particularly well with std::io::Read, so we move away
from std::io::Cursor as well.
wip
Titouan Rigoudy 5 years ago
parent
commit
468c404765
4 changed files with 136 additions and 66 deletions
  1. +133
    -63
      src/proto/base_codec.rs
  2. +1
    -1
      src/proto/peer/message.rs
  3. +1
    -1
      src/proto/server/request.rs
  4. +1
    -1
      src/proto/server/response.rs

+ 133
- 63
src/proto/base_codec.rs View File

@ -16,11 +16,11 @@
use std::fmt; use std::fmt;
use std::io; use std::io;
use std::net; use std::net;
use std::u16;
use bytes::{Buf, BufMut, BytesMut};
use bytes::{BufMut, BytesMut};
use encoding::all::WINDOWS_1252; use encoding::all::WINDOWS_1252;
use encoding::{DecoderTrap, EncoderTrap, Encoding}; use encoding::{DecoderTrap, EncoderTrap, Encoding};
use std::convert::{TryFrom, TryInto};
// Constants // Constants
// --------- // ---------
@ -38,25 +38,40 @@ pub trait Encode<T> {
fn encode(&mut self, value: T) -> io::Result<()>; fn encode(&mut self, value: T) -> io::Result<()>;
} }
/// Builds an EOF error encountered when reading a value of the given type.
fn unexpected_eof_error(type_name: &str) -> io::Error {
io::Error::new(
io::ErrorKind::UnexpectedEof,
format!("reading {}", type_name),
)
// TODO: Define a real DecodeError type using the thiserror crate for better
// error messages without all this io::Error hackery.
/// Builds an UnexpectedEof error with the given message.
fn unexpected_eof_error(message: String) -> io::Error {
io::Error::new(io::ErrorKind::UnexpectedEof, message)
}
/// Builds an InvalidData error with the given message.
fn invalid_data_error(message: String) -> io::Error {
io::Error::new(io::ErrorKind::InvalidData, message)
} }
/// Builds an InvalidData error for the given value of the given type.
fn invalid_data_error<T: fmt::Debug>(type_name: &str, value: T) -> io::Error {
io::Error::new(
io::ErrorKind::InvalidData,
format!("invalid {}: {:?}", type_name, value),
)
/// Annotates the given error with the given context.
fn annotate_error(error: &io::Error, context: &str) -> io::Error {
io::Error::new(error.kind(), format!("{}: {}", context, error))
} }
/// A type for decoding various types of values from protocol messages. /// A type for decoding various types of values from protocol messages.
pub struct ProtoDecoder<'a> { pub struct ProtoDecoder<'a> {
inner: io::Cursor<&'a BytesMut>,
// The buffer we are decoding from.
//
// Invariant: `position <= buffer.len()`.
buffer: &'a [u8],
// Our current position within `buffer`.
//
// We could instead maintain this implicitly in `buffer` by splitting off
// decoded bytes from the start of the buffer, but we would then be unable
// to remember how many bytes we had decoded. This information is useful to
// have in error messages when encountering decoding errors.
//
// Invariant: `position <= buffer.len()`.
position: usize,
} }
/// This trait is implemented by types that can be decoded from messages using /// This trait is implemented by types that can be decoded from messages using
@ -68,65 +83,94 @@ pub trait ProtoDecode: Sized {
impl<'a> ProtoDecoder<'a> { impl<'a> ProtoDecoder<'a> {
/// Wraps the given byte buffer. /// Wraps the given byte buffer.
pub fn new(bytes: &'a BytesMut) -> Self {
pub fn new(buffer: &'a [u8]) -> Self {
Self { Self {
inner: io::Cursor::new(bytes),
buffer: buffer,
position: 0,
} }
} }
/// Returns the number of bytes remaining to decode.
pub fn remaining(&self) -> usize {
self.buffer.len() - self.position
}
/// Returns whether the underlying buffer has remaining bytes to decode. /// Returns whether the underlying buffer has remaining bytes to decode.
///
/// Shorthand for `remaining() > 0`.
pub fn has_remaining(&self) -> bool { pub fn has_remaining(&self) -> bool {
self.inner.has_remaining()
self.remaining() > 0
} }
/// Returns a read-only view of the remaining bytes to decode. /// Returns a read-only view of the remaining bytes to decode.
///
/// The returned slice is of size `remaining()`.
pub fn bytes(&self) -> &[u8] { pub fn bytes(&self) -> &[u8] {
self.inner.bytes()
}
/// Asserts that the buffer contains at least `n` more bytes from which to
/// read a value of the named type.
/// Returns Ok(()) if there are that many bytes, otherwise returns a
/// descriptive error.
fn expect_remaining(&self, type_name: &str, n: usize) -> io::Result<()> {
if self.inner.remaining() < n {
Err(unexpected_eof_error(type_name))
} else {
Ok(())
&self.buffer[self.position..]
}
/// Attempts to consume the next `n` bytes from this buffer.
///
/// Returns a slice of size `n` if successful, in which case this decoder
/// advances its internal position by `n`.
fn consume(&mut self, n: usize) -> io::Result<&[u8]> {
if self.remaining() < n {
return Err(unexpected_eof_error(format!(
"expected {} bytes remaining, found {}",
n,
self.remaining()
)));
} }
// Cannot use bytes() here as it borrows self immutably, which
// prevents us from mutating self.position afterwards.
let end = self.position + n;
let bytes = &self.buffer[self.position..end];
self.position = end;
Ok(bytes)
} }
/// Attempts to decode a u32 value in the context of decoding a value of
/// the named type.
fn decode_u32_generic(&mut self, type_name: &str) -> io::Result<u32> {
self.expect_remaining(type_name, U32_BYTE_LEN)?;
Ok(self.inner.get_u32_le())
/// Attempts to decode a u32 value.
///
/// Note that this method returns a less descriptive error than
/// `self.decode::<u32>()`. It is intended to be a low-level building block
/// for decoding other types.
fn decode_u32(&mut self) -> io::Result<u32> {
let bytes = self.consume(U32_BYTE_LEN)?;
// The conversion from slice to fixed-size array cannot fail, because
// consume() guarantees that its return value is of size n.
let array: [u8; U32_BYTE_LEN] = bytes.try_into().unwrap();
Ok(u32::from_le_bytes(array))
} }
/// Attempts to decode a boolean value. /// Attempts to decode a boolean value.
///
/// Note that this method returns a less descriptive error than
/// `self.decode::<bool>()`. It is intended to be a low-level building block
/// for decoding other types.
fn decode_bool(&mut self) -> io::Result<bool> { fn decode_bool(&mut self) -> io::Result<bool> {
self.expect_remaining("bool", 1)?;
match self.inner.get_u8() {
let bytes = self.consume(1)?;
match bytes[0] {
0 => Ok(false), 0 => Ok(false),
1 => Ok(true), 1 => Ok(true),
n => Err(invalid_data_error("bool", n)),
n => Err(invalid_data_error(format!("invalid bool value {}", n))),
} }
} }
/// Attempts to decode a string value. /// Attempts to decode a string value.
///
/// Note that this method returns a less descriptive error than
/// `self.decode::<String>()`. It is intended to be a low-level building
/// block for decoding other types.
fn decode_string(&mut self) -> io::Result<String> { fn decode_string(&mut self) -> io::Result<String> {
let len = self.decode_u32_generic("string length")? as usize;
self.expect_remaining("string", len)?;
let result = {
let bytes = &self.inner.bytes()[..len];
WINDOWS_1252
.decode(bytes, DecoderTrap::Strict)
.map_err(|err| invalid_data_error("string", (err, bytes)))
};
let length = self.decode_u32()? as usize;
let bytes = self.consume(length)?;
self.inner.advance(len);
result
let result = WINDOWS_1252.decode(bytes, DecoderTrap::Strict);
match result {
Ok(string) => Ok(string),
Err(error) => Err(invalid_data_error(format!("invalid string: {:?}", error))),
}
} }
/// Attempts to decode a value of the given type. /// Attempts to decode a value of the given type.
@ -134,26 +178,33 @@ impl<'a> ProtoDecoder<'a> {
/// Allows easy decoding of complex values using type inference: /// Allows easy decoding of complex values using type inference:
/// ///
/// ``` /// ```
/// let val : Foo = decoder.decode()?;
/// let val: Foo = decoder.decode()?;
/// ``` /// ```
pub fn decode<T: ProtoDecode>(&mut self) -> io::Result<T> { pub fn decode<T: ProtoDecode>(&mut self) -> io::Result<T> {
T::decode_from(self)
let position = self.position;
match T::decode_from(self) {
Ok(value) => Ok(value),
Err(ref error) => Err(annotate_error(
error,
&format!("decoding value at position {}", position),
)),
}
} }
} }
impl ProtoDecode for u32 { impl ProtoDecode for u32 {
fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> {
decoder.decode_u32_generic("u32")
decoder.decode_u32()
} }
} }
impl ProtoDecode for u16 { impl ProtoDecode for u16 {
fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> {
let n = decoder.decode_u32_generic("u16")?;
if n > u16::MAX as u32 {
return Err(invalid_data_error("u16", n));
let n = decoder.decode_u32()?;
match u16::try_from(n) {
Ok(value) => Ok(value),
Err(_) => Err(invalid_data_error(format!("invalid u16 value {}", n))),
} }
Ok(n as u16)
} }
} }
@ -165,7 +216,7 @@ impl ProtoDecode for bool {
impl ProtoDecode for net::Ipv4Addr { impl ProtoDecode for net::Ipv4Addr {
fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> {
let ip = decoder.decode_u32_generic("ipv4 address")?;
let ip = decoder.decode_u32()?;
Ok(net::Ipv4Addr::from(ip)) Ok(net::Ipv4Addr::from(ip))
} }
} }
@ -186,7 +237,7 @@ impl<T: ProtoDecode, U: ProtoDecode> ProtoDecode for (T, U) {
impl<T: ProtoDecode> ProtoDecode for Vec<T> { impl<T: ProtoDecode> ProtoDecode for Vec<T> {
fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> { fn decode_from(decoder: &mut ProtoDecoder) -> io::Result<Self> {
let len = decoder.decode_u32_generic("vector length")? as usize;
let len = decoder.decode_u32()? as usize;
let mut vec = Vec::with_capacity(len); let mut vec = Vec::with_capacity(len);
for _ in 0..len { for _ in 0..len {
let val = decoder.decode()?; let val = decoder.decode()?;
@ -473,7 +524,11 @@ pub mod tests {
let result = decoder.decode::<u32>(); let result = decoder.decode::<u32>();
expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading u32");
expect_io_error(
result,
io::ErrorKind::UnexpectedEof,
"decoding value at position 0: expected 4 bytes remaining, found 1",
);
assert_eq!(decoder.bytes(), &[13]); assert_eq!(decoder.bytes(), &[13]);
} }
@ -516,7 +571,11 @@ pub mod tests {
let result = ProtoDecoder::new(&buffer).decode::<bool>(); let result = ProtoDecoder::new(&buffer).decode::<bool>();
expect_io_error(result, io::ErrorKind::InvalidData, "invalid bool: 42");
expect_io_error(
result,
io::ErrorKind::InvalidData,
"decoding value at position 0: invalid bool value 42",
);
} }
#[test] #[test]
@ -525,7 +584,11 @@ pub mod tests {
let result = ProtoDecoder::new(&buffer).decode::<bool>(); let result = ProtoDecoder::new(&buffer).decode::<bool>();
expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading bool");
expect_io_error(
result,
io::ErrorKind::UnexpectedEof,
"decoding value at position 0: expected 1 bytes remaining, found 0",
);
} }
#[test] #[test]
@ -564,7 +627,10 @@ pub mod tests {
expect_io_error( expect_io_error(
decoder.decode::<u16>(), decoder.decode::<u16>(),
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
&format!("invalid u16: {}", expected_val),
&format!(
"decoding value at position 0: invalid u16 value {}",
expected_val
),
); );
} }
} }
@ -577,7 +643,11 @@ pub mod tests {
let result = decoder.decode::<u16>(); let result = decoder.decode::<u16>();
expect_io_error(result, io::ErrorKind::UnexpectedEof, "reading u16");
expect_io_error(
result,
io::ErrorKind::UnexpectedEof,
"decoding value at position 0: expected 4 bytes remaining, found 0",
);
} }
#[test] #[test]


+ 1
- 1
src/proto/peer/message.rs View File

@ -171,7 +171,7 @@ mod tests {
expect_io_error( expect_io_error(
result, result,
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"unknown peer message code: 1337",
"decoding value at position 0: unknown peer message code: 1337",
); );
} }


+ 1
- 1
src/proto/server/request.rs View File

@ -600,7 +600,7 @@ mod tests {
expect_io_error( expect_io_error(
result, result,
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"unknown server request code: 1337",
"decoding value at position 0: unknown server request code: 1337",
); );
} }


+ 1
- 1
src/proto/server/response.rs View File

@ -1348,7 +1348,7 @@ mod tests {
expect_io_error( expect_io_error(
result, result,
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"unknown server response code: 1337",
"decoding value at position 0: unknown server response code: 1337",
); );
} }


Loading…
Cancel
Save