Browse Source

Move Connection into codec.rs.

wip
Titouan Rigoudy 4 years ago
parent
commit
12f85fbaa9
3 changed files with 100 additions and 114 deletions
  1. +99
    -6
      src/proto/codec.rs
  2. +0
    -105
      src/proto/connection.rs
  3. +1
    -3
      src/proto/mod.rs

+ 99
- 6
src/proto/codec.rs View File

@ -4,10 +4,12 @@
use std::convert::TryInto; use std::convert::TryInto;
use std::io; use std::io;
use std::marker;
use std::marker::PhantomData;
use bytes::BytesMut; use bytes::BytesMut;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use super::prefix::Prefixer; use super::prefix::Prefixer;
use super::u32::{decode_u32, U32_BYTE_LEN}; use super::u32::{decode_u32, U32_BYTE_LEN};
@ -35,14 +37,15 @@ impl From<FrameEncodeError> for io::Error {
} }
/// Encodes entire protocol frames containing values of type `T`. /// Encodes entire protocol frames containing values of type `T`.
#[derive(Debug)]
pub struct FrameEncoder<T: ?Sized> { pub struct FrameEncoder<T: ?Sized> {
phantom: marker::PhantomData<T>,
phantom: PhantomData<T>,
} }
impl<T: ValueEncode + ?Sized> FrameEncoder<T> { impl<T: ValueEncode + ?Sized> FrameEncoder<T> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
phantom: marker::PhantomData,
phantom: PhantomData,
} }
} }
@ -66,15 +69,16 @@ impl<T: ValueEncode + ?Sized> FrameEncoder<T> {
} }
/// Decodes entire protocol frames containing values of type `T`. /// Decodes entire protocol frames containing values of type `T`.
#[derive(Debug)]
pub struct FrameDecoder<T> { pub struct FrameDecoder<T> {
// Only here to enable parameterizing `Decoder` by `T`. // Only here to enable parameterizing `Decoder` by `T`.
phantom: marker::PhantomData<T>,
phantom: PhantomData<T>,
} }
impl<T: ValueDecode> FrameDecoder<T> { impl<T: ValueDecode> FrameDecoder<T> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
phantom: marker::PhantomData,
phantom: PhantomData,
} }
} }
@ -140,10 +144,51 @@ impl<T: ValueDecode> FrameDecoder<T> {
} }
} }
#[derive(Debug)]
pub struct Connection<ReadFrame, WriteFrame: ?Sized> {
stream: TcpStream,
read_buffer: BytesMut,
decoder: FrameDecoder<ReadFrame>,
encoder: FrameEncoder<WriteFrame>,
}
impl<ReadFrame, WriteFrame> Connection<ReadFrame, WriteFrame>
where
ReadFrame: ValueDecode,
WriteFrame: ValueEncode + ?Sized,
{
pub fn new(stream: TcpStream) -> Self {
Connection {
stream,
read_buffer: BytesMut::new(),
decoder: FrameDecoder::new(),
encoder: FrameEncoder::new(),
}
}
pub async fn read(&mut self) -> io::Result<ReadFrame> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(frame);
}
self.stream.read_buf(&mut self.read_buffer).await?;
}
}
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.stream.write_all(bytes.as_ref()).await
}
}
mod tests { mod tests {
use bytes::BytesMut; use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder};
use super::{Connection, FrameDecoder, FrameEncoder};
// Test value: [1, 3, 3, 7] in little-endian. // Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24); const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
@ -289,4 +334,52 @@ mod tests {
assert_eq!(decoded, Some(value)); assert_eq!(decoded, Some(value));
assert_eq!(buffer, vec![]); assert_eq!(buffer, vec![]);
} }
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut connection = Connection::<String, str>::new(stream);
assert_eq!(connection.read().await.unwrap(), "ping");
connection.write("pong").await.unwrap();
assert_eq!(connection.read().await.unwrap(), "ping");
connection.write("pong").await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut connection = Connection::<String, str>::new(stream);
connection.write("ping").await.unwrap();
assert_eq!(connection.read().await.unwrap(), "pong");
connection.write("ping").await.unwrap();
assert_eq!(connection.read().await.unwrap(), "pong");
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut connection = Connection::<String, Vec<u32>>::new(stream);
assert_eq!(connection.read().await.unwrap(), "ping");
connection.write(&vec![0; 10 * 4096]).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut connection = Connection::<Vec<u32>, str>::new(stream);
connection.write("ping").await.unwrap();
assert_eq!(connection.read().await.unwrap(), vec![0; 10 * 4096]);
server_task.await.unwrap();
}
} }

+ 0
- 105
src/proto/connection.rs View File

@ -1,105 +0,0 @@
use std::io;
use std::marker::PhantomData;
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::proto::{FrameDecoder, FrameEncoder, ValueDecode, ValueEncode};
#[derive(Debug)]
pub struct Connection<ReadFrame, WriteFrame: ?Sized> {
stream: TcpStream,
read_buffer: BytesMut,
phantom_read: PhantomData<ReadFrame>,
phantom_write: PhantomData<WriteFrame>,
}
impl<ReadFrame, WriteFrame> Connection<ReadFrame, WriteFrame>
where
ReadFrame: ValueDecode,
WriteFrame: ValueEncode + ?Sized,
{
pub fn new(stream: TcpStream) -> Self {
Connection {
stream,
read_buffer: BytesMut::new(),
phantom_read: PhantomData,
phantom_write: PhantomData,
}
}
pub async fn read(&mut self) -> io::Result<ReadFrame> {
let mut decoder = FrameDecoder::new();
loop {
if let Some(frame) = decoder.decode_from(&mut self.read_buffer)? {
return Ok(frame);
}
self.stream.read_buf(&mut self.read_buffer).await?;
}
}
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
let mut bytes = BytesMut::new();
FrameEncoder::new().encode_to(frame, &mut bytes)?;
self.stream.write_all(bytes.as_ref()).await
}
}
#[cfg(test)]
mod tests {
use tokio::net::{TcpListener, TcpStream};
use super::Connection;
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut connection = Connection::<String, str>::new(stream);
assert_eq!(connection.read().await.unwrap(), "ping");
connection.write("pong").await.unwrap();
assert_eq!(connection.read().await.unwrap(), "ping");
connection.write("pong").await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut connection = Connection::<String, str>::new(stream);
connection.write("ping").await.unwrap();
assert_eq!(connection.read().await.unwrap(), "pong");
connection.write("ping").await.unwrap();
assert_eq!(connection.read().await.unwrap(), "pong");
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut connection = Connection::<String, Vec<u32>>::new(stream);
assert_eq!(connection.read().await.unwrap(), "ping");
connection.write(&vec![0; 10 * 4096]).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut connection = Connection::<Vec<u32>, str>::new(stream);
connection.write("ping").await.unwrap();
assert_eq!(connection.read().await.unwrap(), vec![0; 10 * 4096]);
server_task.await.unwrap();
}
}

+ 1
- 3
src/proto/mod.rs View File

@ -1,5 +1,4 @@
mod codec; mod codec;
mod connection;
mod constants; mod constants;
mod handler; mod handler;
mod packet; mod packet;
@ -13,8 +12,7 @@ pub mod u32;
mod user; mod user;
mod value_codec; mod value_codec;
pub use self::codec::*;
pub use self::connection::Connection;
pub use self::codec::Connection;
pub use self::handler::*; pub use self::handler::*;
pub use self::packet::*; pub use self::packet::*;
pub use self::server::{ServerRequest, ServerResponse}; pub use self::server::{ServerRequest, ServerResponse};


Loading…
Cancel
Save