From de7025edc89fd8389245e2fcdedc6882059d0cbe Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Tue, 26 Dec 2017 18:19:49 -0500 Subject: [PATCH] Implement ProtoDecode and ProtoEncode for peer messages. --- src/proto/codec.rs | 39 ++++++++++++------------ src/proto/mod.rs | 1 + src/proto/peer/message.rs | 62 +++++++++++++++++++++++++++++++++++++-- src/proto/transport.rs | 41 +++++++++++++++++--------- 4 files changed, 108 insertions(+), 35 deletions(-) diff --git a/src/proto/codec.rs b/src/proto/codec.rs index 42038a5..ee5def8 100644 --- a/src/proto/codec.rs +++ b/src/proto/codec.rs @@ -116,23 +116,23 @@ pub trait ProtoEncode { // messages. pub struct ProtoDecoder<'a> { // If bytes::Buf was object-safe we would just store &'a Buf. We work - // around this limitation by storing the cursor itself. + // around this limitation by explicitly naming the implementing type. inner: &'a mut io::Cursor, } impl<'a> ProtoDecoder<'a> { - fn new(cursor: &'a mut io::Cursor) -> ProtoDecoder<'a> { - ProtoDecoder { inner: cursor } + pub fn new(inner: &'a mut io::Cursor) -> Self { + ProtoDecoder { inner: inner } } - fn decode_u32(&mut self) -> Result { + pub fn decode_u32(&mut self) -> Result { if self.inner.remaining() < U32_BYTE_LEN { return Err(unexpected_eof_error("u32")); } Ok(self.inner.get_u32::()) } - fn decode_u16(&mut self) -> Result { + pub fn decode_u16(&mut self) -> Result { let n = self.decode_u32()?; if n > u16::MAX as u32 { return Err(DecodeError::InvalidU16Error(n)); @@ -140,7 +140,7 @@ impl<'a> ProtoDecoder<'a> { Ok(n as u16) } - fn decode_bool(&mut self) -> Result { + pub fn decode_bool(&mut self) -> Result { if self.inner.remaining() < 1 { return Err(unexpected_eof_error("bool")); } @@ -151,12 +151,12 @@ impl<'a> ProtoDecoder<'a> { } } - fn decode_ipv4_addr(&mut self) -> Result { + pub fn decode_ipv4_addr(&mut self) -> Result { let ip = self.decode_u32()?; Ok(net::Ipv4Addr::from(ip)) } - fn decode_string(&mut self) -> Result { + pub fn decode_string(&mut self) -> Result { let len = self.decode_u32()? as usize; if self.inner.remaining() < len { return Err(unexpected_eof_error("string")); @@ -173,7 +173,7 @@ impl<'a> ProtoDecoder<'a> { result } - fn decode_vec(&mut self) -> Result, DecodeError> { + pub fn decode_vec(&mut self) -> Result, DecodeError> { let len = self.decode_u32()? as usize; let mut vec = Vec::with_capacity(len); for _ in 0..len { @@ -187,17 +187,18 @@ impl<'a> ProtoDecoder<'a> { // A `ProtoEncoder` knows how to encode various types of values into protocol // messages. pub struct ProtoEncoder<'a> { - // If bytes::BufMut was object-safe we would store an &'a BufMut. We work - // around this limiation by using BytesMut directly. + // We would like to store an &'a BufMut instead, but not only is it not + // object-safe yet, it does not grow the buffer on writes either... So we + // don't want to template this struct like ProtoDecoder either. inner: &'a mut BytesMut, } impl<'a> ProtoEncoder<'a> { - fn new(buf: &'a mut BytesMut) -> ProtoEncoder { - ProtoEncoder { inner: buf } + pub fn new(inner: &'a mut BytesMut) -> Self { + ProtoEncoder { inner: inner } } - fn encode_u32(&mut self, val: u32) -> io::Result<()> { + pub fn encode_u32(&mut self, val: u32) -> io::Result<()> { if self.inner.remaining_mut() < U32_BYTE_LEN { self.inner.reserve(U32_BYTE_LEN); } @@ -205,11 +206,11 @@ impl<'a> ProtoEncoder<'a> { Ok(()) } - fn encode_u16(&mut self, val: u16) -> io::Result<()> { + pub fn encode_u16(&mut self, val: u16) -> io::Result<()> { self.encode_u32(val as u32) } - fn encode_bool(&mut self, val: bool) -> io::Result<()> { + pub fn encode_bool(&mut self, val: bool) -> io::Result<()> { if !self.inner.has_remaining_mut() { self.inner.reserve(1); } @@ -217,14 +218,14 @@ impl<'a> ProtoEncoder<'a> { Ok(()) } - fn encode_ipv4_addr(&mut self, addr: net::Ipv4Addr) -> io::Result<()> { + pub fn encode_ipv4_addr(&mut self, addr: net::Ipv4Addr) -> io::Result<()> { let mut octets = addr.octets(); octets.reverse(); // Little endian. self.inner.extend(&octets); Ok(()) } - fn encode_string(&mut self, val: &str) -> io::Result<()> { + pub fn encode_string(&mut self, val: &str) -> io::Result<()> { // Encode the string. let bytes = match WINDOWS_1252.encode(val, EncoderTrap::Strict) { Ok(bytes) => bytes, @@ -238,7 +239,7 @@ impl<'a> ProtoEncoder<'a> { Ok(()) } - fn encode_vec(&mut self, vec: &[T]) -> io::Result<()> { + pub fn encode_vec(&mut self, vec: &[T]) -> io::Result<()> { self.encode_u32(vec.len() as u32)?; for ref item in vec { item.encode(self)?; diff --git a/src/proto/mod.rs b/src/proto/mod.rs index d2b4054..97198d1 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -12,3 +12,4 @@ pub use self::packet::*; pub use self::stream::*; pub use self::server::{ServerResponse, ServerRequest}; pub use self::transport::{PeerTransport, ServerTransport}; +pub use self::codec::{DecodeError, ProtoDecode, ProtoDecoder, ProtoEncode, ProtoEncoder}; diff --git a/src/proto/peer/message.rs b/src/proto/peer/message.rs index ceeea2d..de47ff2 100644 --- a/src/proto/peer/message.rs +++ b/src/proto/peer/message.rs @@ -1,7 +1,8 @@ use std::io; -use super::super::{MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket}; -use super::constants::*; +use proto::{DecodeError, MutPacket, Packet, PacketReadError, ProtoDecode, ProtoDecoder, + ProtoEncode, ProtoEncoder, ReadFromPacket, WriteToPacket}; +use proto::peer::constants::*; /*=========* * MESSAGE * @@ -39,6 +40,41 @@ impl ReadFromPacket for Message { } } +impl ProtoDecode for Message { + fn decode(decoder: &mut ProtoDecoder) -> Result { + let code = decoder.decode_u32()?; + let message = match code { + CODE_PIERCE_FIREWALL => { + let val = decoder.decode_u32()?; + Message::PierceFirewall(val) + } + CODE_PEER_INIT => { + let peer_init = PeerInit::decode(decoder)?; + Message::PeerInit(peer_init) + } + code => Message::Unknown(code), + }; + Ok(message) + } +} + +impl ProtoEncode for Message { + fn encode(&self, encoder: &mut ProtoEncoder) -> io::Result<()> { + match *self { + Message::PierceFirewall(token) => { + encoder.encode_u32(CODE_PIERCE_FIREWALL)?; + encoder.encode_u32(token)?; + } + Message::PeerInit(ref request) => { + encoder.encode_u32(CODE_PEER_INIT)?; + request.encode(encoder)?; + } + Message::Unknown(_) => unreachable!(), + } + Ok(()) + } +} + impl WriteToPacket for Message { fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { match *self { @@ -86,3 +122,25 @@ impl WriteToPacket for PeerInit { Ok(()) } } + +impl ProtoEncode for PeerInit { + fn encode(&self, encoder: &mut ProtoEncoder) -> io::Result<()> { + encoder.encode_string(&self.user_name)?; + encoder.encode_string(&self.connection_type)?; + encoder.encode_u32(self.token)?; + Ok(()) + } +} + +impl ProtoDecode for PeerInit { + fn decode(decoder: &mut ProtoDecoder) -> Result { + let user_name = decoder.decode_string()?; + let connection_type = decoder.decode_string()?; + let token = decoder.decode_u32()?; + Ok(PeerInit { + user_name: user_name, + connection_type: connection_type, + token: token, + }) + } +} diff --git a/src/proto/transport.rs b/src/proto/transport.rs index 4dda433..b4e5935 100644 --- a/src/proto/transport.rs +++ b/src/proto/transport.rs @@ -1,13 +1,13 @@ use std::io; -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use futures::{Async, AsyncSink, Poll, Sink, StartSend, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_io::codec::{Decoder, Encoder, length_delimited}; +use tokio_io::codec::length_delimited; -use super::peer; -use super::codec::DecodeError; -use super::{ServerResponse, ServerRequest}; +use proto::peer; +use proto::{DecodeError, ProtoDecode, ProtoDecoder, ProtoEncode, ProtoEncoder, ServerResponse, + ServerRequest}; /* ------- * * Helpers * @@ -24,16 +24,27 @@ fn decode_server_response(bytes: &mut BytesMut) -> Result BytesMut { +fn encode_server_request(request: &ServerRequest) -> Result { unimplemented!(); } -fn decode_peer_message(bytes: &mut BytesMut) -> Result { - unimplemented!(); +fn decode_peer_message(bytes: BytesMut) -> Result { + let mut cursor = io::Cursor::new(bytes); + let message = peer::Message::decode(&mut ProtoDecoder::new(&mut cursor))?; + if cursor.has_remaining() { + warn!( + "Received peer message with trailing bytes. Message:\n{:?}Bytes:{:?}", + message, + cursor.bytes() + ); + } + Ok(message) } -fn encode_peer_message(message: &peer::Message) -> BytesMut { - unimplemented!(); +fn encode_peer_message(message: &peer::Message) -> Result { + let mut bytes = BytesMut::new(); + message.encode(&mut ProtoEncoder::new(&mut bytes))?; + Ok(bytes) } /* --------------- * @@ -72,7 +83,8 @@ impl Sink for ServerTransport { type SinkError = io::Error; fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - match self.framed.start_send(encode_server_request(&item)) { + let bytes = encode_server_request(&item)?; + match self.framed.start_send(bytes) { Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(item)), Err(err) => Err(err), @@ -104,8 +116,8 @@ impl Stream for PeerTransport { fn poll(&mut self) -> Poll, Self::Error> { match self.framed.poll() { - Ok(Async::Ready(Some(mut bytes))) => { - let message = decode_peer_message(&mut bytes)?; + Ok(Async::Ready(Some(bytes))) => { + let message = decode_peer_message(bytes)?; Ok(Async::Ready(Some(message))) } Ok(Async::Ready(None)) => Ok(Async::Ready(None)), @@ -120,7 +132,8 @@ impl Sink for PeerTransport { type SinkError = io::Error; fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - match self.framed.start_send(encode_peer_message(&item)) { + let bytes = encode_peer_message(&item)?; + match self.framed.start_send(bytes) { Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(item)), Err(err) => Err(err),