diff --git a/src/proto/codec.rs b/src/proto/codec.rs index d1b8b8e..877d527 100644 --- a/src/proto/codec.rs +++ b/src/proto/codec.rs @@ -4,10 +4,12 @@ use std::convert::TryInto; use std::io; -use std::marker; +use std::marker::PhantomData; use bytes::BytesMut; use thiserror::Error; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; use super::prefix::Prefixer; use super::u32::{decode_u32, U32_BYTE_LEN}; @@ -35,14 +37,15 @@ impl From for io::Error { } /// Encodes entire protocol frames containing values of type `T`. +#[derive(Debug)] pub struct FrameEncoder { - phantom: marker::PhantomData, + phantom: PhantomData, } impl FrameEncoder { pub fn new() -> Self { Self { - phantom: marker::PhantomData, + phantom: PhantomData, } } @@ -66,15 +69,16 @@ impl FrameEncoder { } /// Decodes entire protocol frames containing values of type `T`. +#[derive(Debug)] pub struct FrameDecoder { // Only here to enable parameterizing `Decoder` by `T`. - phantom: marker::PhantomData, + phantom: PhantomData, } impl FrameDecoder { pub fn new() -> Self { Self { - phantom: marker::PhantomData, + phantom: PhantomData, } } @@ -140,10 +144,51 @@ impl FrameDecoder { } } +#[derive(Debug)] +pub struct Connection { + stream: TcpStream, + + read_buffer: BytesMut, + + decoder: FrameDecoder, + encoder: FrameEncoder, +} + +impl Connection +where + ReadFrame: ValueDecode, + WriteFrame: ValueEncode + ?Sized, +{ + pub fn new(stream: TcpStream) -> Self { + Connection { + stream, + read_buffer: BytesMut::new(), + decoder: FrameDecoder::new(), + encoder: FrameEncoder::new(), + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? { + return Ok(frame); + } + self.stream.read_buf(&mut self.read_buffer).await?; + } + } + + pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> { + let mut bytes = BytesMut::new(); + self.encoder.encode_to(frame, &mut bytes)?; + self.stream.write_all(bytes.as_ref()).await + } +} + mod tests { use bytes::BytesMut; + use tokio::net::{TcpListener, TcpStream}; - use super::{FrameDecoder, FrameEncoder}; + use super::{Connection, FrameDecoder, FrameEncoder}; // Test value: [1, 3, 3, 7] in little-endian. const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24); @@ -289,4 +334,52 @@ mod tests { assert_eq!(decoded, Some(value)); assert_eq!(buffer, vec![]); } + + #[tokio::test] + async fn ping_pong() { + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (stream, _peer_address) = listener.accept().await.unwrap(); + let mut connection = Connection::::new(stream); + + assert_eq!(connection.read().await.unwrap(), "ping"); + connection.write("pong").await.unwrap(); + assert_eq!(connection.read().await.unwrap(), "ping"); + connection.write("pong").await.unwrap(); + }); + + let stream = TcpStream::connect(address).await.unwrap(); + let mut connection = Connection::::new(stream); + + connection.write("ping").await.unwrap(); + assert_eq!(connection.read().await.unwrap(), "pong"); + connection.write("ping").await.unwrap(); + assert_eq!(connection.read().await.unwrap(), "pong"); + + server_task.await.unwrap(); + } + + #[tokio::test] + async fn very_large_message() { + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (stream, _peer_address) = listener.accept().await.unwrap(); + let mut connection = Connection::>::new(stream); + + assert_eq!(connection.read().await.unwrap(), "ping"); + connection.write(&vec![0; 10 * 4096]).await.unwrap(); + }); + + let stream = TcpStream::connect(address).await.unwrap(); + let mut connection = Connection::, str>::new(stream); + + connection.write("ping").await.unwrap(); + assert_eq!(connection.read().await.unwrap(), vec![0; 10 * 4096]); + + server_task.await.unwrap(); + } } diff --git a/src/proto/connection.rs b/src/proto/connection.rs deleted file mode 100644 index 0a7ee58..0000000 --- a/src/proto/connection.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::io; -use std::marker::PhantomData; - -use bytes::BytesMut; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; - -use crate::proto::{FrameDecoder, FrameEncoder, ValueDecode, ValueEncode}; - -#[derive(Debug)] -pub struct Connection { - stream: TcpStream, - - read_buffer: BytesMut, - - phantom_read: PhantomData, - phantom_write: PhantomData, -} - -impl Connection -where - ReadFrame: ValueDecode, - WriteFrame: ValueEncode + ?Sized, -{ - pub fn new(stream: TcpStream) -> Self { - Connection { - stream, - read_buffer: BytesMut::new(), - phantom_read: PhantomData, - phantom_write: PhantomData, - } - } - - pub async fn read(&mut self) -> io::Result { - let mut decoder = FrameDecoder::new(); - - loop { - if let Some(frame) = decoder.decode_from(&mut self.read_buffer)? { - return Ok(frame); - } - self.stream.read_buf(&mut self.read_buffer).await?; - } - } - - pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> { - let mut bytes = BytesMut::new(); - FrameEncoder::new().encode_to(frame, &mut bytes)?; - self.stream.write_all(bytes.as_ref()).await - } -} - -#[cfg(test)] -mod tests { - use tokio::net::{TcpListener, TcpStream}; - - use super::Connection; - - #[tokio::test] - async fn ping_pong() { - let listener = TcpListener::bind("localhost:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let server_task = tokio::spawn(async move { - let (stream, _peer_address) = listener.accept().await.unwrap(); - let mut connection = Connection::::new(stream); - - assert_eq!(connection.read().await.unwrap(), "ping"); - connection.write("pong").await.unwrap(); - assert_eq!(connection.read().await.unwrap(), "ping"); - connection.write("pong").await.unwrap(); - }); - - let stream = TcpStream::connect(address).await.unwrap(); - let mut connection = Connection::::new(stream); - - connection.write("ping").await.unwrap(); - assert_eq!(connection.read().await.unwrap(), "pong"); - connection.write("ping").await.unwrap(); - assert_eq!(connection.read().await.unwrap(), "pong"); - - server_task.await.unwrap(); - } - - #[tokio::test] - async fn very_large_message() { - let listener = TcpListener::bind("localhost:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let server_task = tokio::spawn(async move { - let (stream, _peer_address) = listener.accept().await.unwrap(); - let mut connection = Connection::>::new(stream); - - assert_eq!(connection.read().await.unwrap(), "ping"); - connection.write(&vec![0; 10 * 4096]).await.unwrap(); - }); - - let stream = TcpStream::connect(address).await.unwrap(); - let mut connection = Connection::, str>::new(stream); - - connection.write("ping").await.unwrap(); - assert_eq!(connection.read().await.unwrap(), vec![0; 10 * 4096]); - - server_task.await.unwrap(); - } -} diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 4daa4ac..bcf37bc 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -1,5 +1,4 @@ mod codec; -mod connection; mod constants; mod handler; mod packet; @@ -13,8 +12,7 @@ pub mod u32; mod user; mod value_codec; -pub use self::codec::*; -pub use self::connection::Connection; +pub use self::codec::Connection; pub use self::handler::*; pub use self::packet::*; pub use self::server::{ServerRequest, ServerResponse};