@ -16,11 +16,11 @@
use std ::fmt ;
use std ::io ;
use std ::net ;
use std ::u16 ;
use bytes ::{ Buf , Buf Mut, BytesMut } ;
use bytes ::{ BufMut , BytesMut } ;
use encoding ::all ::WINDOWS_1252 ;
use encoding ::{ DecoderTrap , EncoderTrap , Encoding } ;
use std ::convert ::{ TryFrom , TryInto } ;
// Constants
// ---------
@ -38,25 +38,40 @@ pub trait Encode<T> {
fn encode ( & mut self , value : T ) -> io ::Result < ( ) > ;
}
/// Builds an EOF error encountered when reading a value of the given type.
fn unexpected_eof_error ( type_name : & str ) -> io ::Error {
io ::Error ::new (
io ::ErrorKind ::UnexpectedEof ,
format ! ( "reading {}" , type_name ) ,
)
// TODO: Define a real DecodeError type using the thiserror crate for better
// error messages without all this io::Error hackery.
/// Builds an UnexpectedEof error with the given message.
fn unexpected_eof_error ( message : String ) -> io ::Error {
io ::Error ::new ( io ::ErrorKind ::UnexpectedEof , message )
}
/// Builds an InvalidData error with the given message.
fn invalid_data_error ( message : String ) -> io ::Error {
io ::Error ::new ( io ::ErrorKind ::InvalidData , message )
}
/// Builds an InvalidData error for the given value of the given type.
fn invalid_data_error < T : fmt ::Debug > ( type_name : & str , value : T ) -> io ::Error {
io ::Error ::new (
io ::ErrorKind ::InvalidData ,
format ! ( "invalid {}: {:?}" , type_name , value ) ,
)
/// Annotates the given error with the given context.
fn annotate_error ( error : & io ::Error , context : & str ) -> io ::Error {
io ::Error ::new ( error . kind ( ) , format ! ( "{}: {}" , context , error ) )
}
/// A type for decoding various types of values from protocol messages.
pub struct ProtoDecoder < 'a > {
inner : io ::Cursor < & 'a BytesMut > ,
// The buffer we are decoding from.
//
// Invariant: `position <= buffer.len()`.
buffer : & 'a [ u8 ] ,
// Our current position within `buffer`.
//
// We could instead maintain this implicitly in `buffer` by splitting off
// decoded bytes from the start of the buffer, but we would then be unable
// to remember how many bytes we had decoded. This information is useful to
// have in error messages when encountering decoding errors.
//
// Invariant: `position <= buffer.len()`.
position : usize ,
}
/// This trait is implemented by types that can be decoded from messages using
@ -68,65 +83,94 @@ pub trait ProtoDecode: Sized {
impl < 'a > ProtoDecoder < 'a > {
/// Wraps the given byte buffer.
pub fn new ( bytes : & 'a BytesMut ) -> Self {
pub fn new ( buffer : & 'a [ u8 ] ) -> Self {
Self {
inner : io ::Cursor ::new ( bytes ) ,
buffer : buffer ,
position : 0 ,
}
}
/// Returns the number of bytes remaining to decode.
pub fn remaining ( & self ) -> usize {
self . buffer . len ( ) - self . position
}
/// Returns whether the underlying buffer has remaining bytes to decode.
///
/// Shorthand for `remaining() > 0`.
pub fn has_remaining ( & self ) -> bool {
self . inner . has_remaining ( )
self . remaining ( ) > 0
}
/// Returns a read-only view of the remaining bytes to decode.
///
/// The returned slice is of size `remaining()`.
pub fn bytes ( & self ) -> & [ u8 ] {
self . inner . bytes ( )
}
/// Asserts that the buffer contains at least `n` more bytes from which to
/// read a value of the named type.
/// Returns Ok(()) if there are that many bytes, otherwise returns a
/// descriptive error.
fn expect_remaining ( & self , type_name : & str , n : usize ) -> io ::Result < ( ) > {
if self . inner . remaining ( ) < n {
Err ( unexpected_eof_error ( type_name ) )
} else {
Ok ( ( ) )
& self . buffer [ self . position . . ]
}
/// Attempts to consume the next `n` bytes from this buffer.
///
/// Returns a slice of size `n` if successful, in which case this decoder
/// advances its internal position by `n`.
fn consume ( & mut self , n : usize ) -> io ::Result < & [ u8 ] > {
if self . remaining ( ) < n {
return Err ( unexpected_eof_error ( format ! (
"expected {} bytes remaining, found {}" ,
n ,
self . remaining ( )
) ) ) ;
}
// Cannot use bytes() here as it borrows self immutably, which
// prevents us from mutating self.position afterwards.
let end = self . position + n ;
let bytes = & self . buffer [ self . position . . end ] ;
self . position = end ;
Ok ( bytes )
}
/// Attempts to decode a u32 value in the context of decoding a value of
/// the named type.
fn decode_u32_generic ( & mut self , type_name : & str ) -> io ::Result < u32 > {
self . expect_remaining ( type_name , U32_BYTE_LEN ) ? ;
Ok ( self . inner . get_u32_le ( ) )
/// Attempts to decode a u32 value.
///
/// Note that this method returns a less descriptive error than
/// `self.decode::<u32>()`. It is intended to be a low-level building block
/// for decoding other types.
fn decode_u32 ( & mut self ) -> io ::Result < u32 > {
let bytes = self . consume ( U32_BYTE_LEN ) ? ;
// The conversion from slice to fixed-size array cannot fail, because
// consume() guarantees that its return value is of size n.
let array : [ u8 ; U32_BYTE_LEN ] = bytes . try_into ( ) . unwrap ( ) ;
Ok ( u32 ::from_le_bytes ( array ) )
}
/// Attempts to decode a boolean value.
///
/// Note that this method returns a less descriptive error than
/// `self.decode::<bool>()`. It is intended to be a low-level building block
/// for decoding other types.
fn decode_bool ( & mut self ) -> io ::Result < bool > {
self . expect_remaining ( "bool" , 1 ) ? ;
match self . inner . get_u8 ( ) {
let bytes = self . consume ( 1 ) ? ;
match bytes [ 0 ] {
0 = > Ok ( false ) ,
1 = > Ok ( true ) ,
n = > Err ( invalid_data_error ( "bool" , n ) ) ,
n = > Err ( invalid_data_error ( format ! ( "invalid bool value {} " , n ) ) ) ,
}
}
/// Attempts to decode a string value.
///
/// Note that this method returns a less descriptive error than
/// `self.decode::<String>()`. It is intended to be a low-level building
/// block for decoding other types.
fn decode_string ( & mut self ) -> io ::Result < String > {
let len = self . decode_u32_generic ( "string length" ) ? as usize ;
self . expect_remaining ( "string" , len ) ? ;
let result = {
let bytes = & self . inner . bytes ( ) [ . . len ] ;
WINDOWS_1252
. decode ( bytes , DecoderTrap ::Strict )
. map_err ( | err | invalid_data_error ( "string" , ( err , bytes ) ) )
} ;
let length = self . decode_u32 ( ) ? as usize ;
let bytes = self . consume ( length ) ? ;
self . inner . advance ( len ) ;
result
let result = WINDOWS_1252 . decode ( bytes , DecoderTrap ::Strict ) ;
match result {
Ok ( string ) = > Ok ( string ) ,
Err ( error ) = > Err ( invalid_data_error ( format ! ( "invalid string: {:?}" , error ) ) ) ,
}
}
/// Attempts to decode a value of the given type.
@ -134,26 +178,33 @@ impl<'a> ProtoDecoder<'a> {
/// Allows easy decoding of complex values using type inference:
///
/// ```
/// let val : Foo = decoder.decode()?;
/// let val: Foo = decoder.decode()?;
/// ```
pub fn decode < T : ProtoDecode > ( & mut self ) -> io ::Result < T > {
T ::decode_from ( self )
let position = self . position ;
match T ::decode_from ( self ) {
Ok ( value ) = > Ok ( value ) ,
Err ( ref error ) = > Err ( annotate_error (
error ,
& format ! ( "decoding value at position {}" , position ) ,
) ) ,
}
}
}
impl ProtoDecode for u32 {
fn decode_from ( decoder : & mut ProtoDecoder ) -> io ::Result < Self > {
decoder . decode_u32_generic ( "u32" )
decoder . decode_u32 ( )
}
}
impl ProtoDecode for u16 {
fn decode_from ( decoder : & mut ProtoDecoder ) -> io ::Result < Self > {
let n = decoder . decode_u32_generic ( "u16" ) ? ;
if n > u16 ::MAX as u32 {
return Err ( invalid_data_error ( "u16" , n ) ) ;
let n = decoder . decode_u32 ( ) ? ;
match u16 ::try_from ( n ) {
Ok ( value ) = > Ok ( value ) ,
Err ( _ ) = > Err ( invalid_data_error ( format ! ( "invalid u16 value {}" , n ) ) ) ,
}
Ok ( n as u16 )
}
}
@ -165,7 +216,7 @@ impl ProtoDecode for bool {
impl ProtoDecode for net ::Ipv4Addr {
fn decode_from ( decoder : & mut ProtoDecoder ) -> io ::Result < Self > {
let ip = decoder . decode_u32_generic ( "ipv4 address" ) ? ;
let ip = decoder . decode_u32 ( ) ? ;
Ok ( net ::Ipv4Addr ::from ( ip ) )
}
}
@ -186,7 +237,7 @@ impl<T: ProtoDecode, U: ProtoDecode> ProtoDecode for (T, U) {
impl < T : ProtoDecode > ProtoDecode for Vec < T > {
fn decode_from ( decoder : & mut ProtoDecoder ) -> io ::Result < Self > {
let len = decoder . decode_u32_generic ( "vector length" ) ? as usize ;
let len = decoder . decode_u32 ( ) ? as usize ;
let mut vec = Vec ::with_capacity ( len ) ;
for _ in 0 . . len {
let val = decoder . decode ( ) ? ;
@ -473,7 +524,11 @@ pub mod tests {
let result = decoder . decode ::< u32 > ( ) ;
expect_io_error ( result , io ::ErrorKind ::UnexpectedEof , "reading u32" ) ;
expect_io_error (
result ,
io ::ErrorKind ::UnexpectedEof ,
"decoding value at position 0: expected 4 bytes remaining, found 1" ,
) ;
assert_eq ! ( decoder . bytes ( ) , & [ 13 ] ) ;
}
@ -516,7 +571,11 @@ pub mod tests {
let result = ProtoDecoder ::new ( & buffer ) . decode ::< bool > ( ) ;
expect_io_error ( result , io ::ErrorKind ::InvalidData , "invalid bool: 42" ) ;
expect_io_error (
result ,
io ::ErrorKind ::InvalidData ,
"decoding value at position 0: invalid bool value 42" ,
) ;
}
#[ test ]
@ -525,7 +584,11 @@ pub mod tests {
let result = ProtoDecoder ::new ( & buffer ) . decode ::< bool > ( ) ;
expect_io_error ( result , io ::ErrorKind ::UnexpectedEof , "reading bool" ) ;
expect_io_error (
result ,
io ::ErrorKind ::UnexpectedEof ,
"decoding value at position 0: expected 1 bytes remaining, found 0" ,
) ;
}
#[ test ]
@ -564,7 +627,10 @@ pub mod tests {
expect_io_error (
decoder . decode ::< u16 > ( ) ,
io ::ErrorKind ::InvalidData ,
& format ! ( "invalid u16: {}" , expected_val ) ,
& format ! (
"decoding value at position 0: invalid u16 value {}" ,
expected_val
) ,
) ;
}
}
@ -577,7 +643,11 @@ pub mod tests {
let result = decoder . decode ::< u16 > ( ) ;
expect_io_error ( result , io ::ErrorKind ::UnexpectedEof , "reading u16" ) ;
expect_io_error (
result ,
io ::ErrorKind ::UnexpectedEof ,
"decoding value at position 0: expected 4 bytes remaining, found 0" ,
) ;
}
#[ test ]