|
|
|
@ -4,10 +4,12 @@ |
|
|
|
|
|
|
|
use std::convert::TryInto;
|
|
|
|
use std::io;
|
|
|
|
use std::marker;
|
|
|
|
use std::marker::PhantomData;
|
|
|
|
|
|
|
|
use bytes::BytesMut;
|
|
|
|
use thiserror::Error;
|
|
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
|
|
|
|
use super::prefix::Prefixer;
|
|
|
|
use super::u32::{decode_u32, U32_BYTE_LEN};
|
|
|
|
@ -35,14 +37,15 @@ impl From<FrameEncodeError> for io::Error { |
|
|
|
}
|
|
|
|
|
|
|
|
/// Encodes entire protocol frames containing values of type `T`.
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct FrameEncoder<T: ?Sized> {
|
|
|
|
phantom: marker::PhantomData<T>,
|
|
|
|
phantom: PhantomData<T>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T: ValueEncode + ?Sized> FrameEncoder<T> {
|
|
|
|
pub fn new() -> Self {
|
|
|
|
Self {
|
|
|
|
phantom: marker::PhantomData,
|
|
|
|
phantom: PhantomData,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
@ -66,15 +69,16 @@ impl<T: ValueEncode + ?Sized> FrameEncoder<T> { |
|
|
|
}
|
|
|
|
|
|
|
|
/// Decodes entire protocol frames containing values of type `T`.
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct FrameDecoder<T> {
|
|
|
|
// Only here to enable parameterizing `Decoder` by `T`.
|
|
|
|
phantom: marker::PhantomData<T>,
|
|
|
|
phantom: PhantomData<T>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T: ValueDecode> FrameDecoder<T> {
|
|
|
|
pub fn new() -> Self {
|
|
|
|
Self {
|
|
|
|
phantom: marker::PhantomData,
|
|
|
|
phantom: PhantomData,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
@ -140,10 +144,51 @@ impl<T: ValueDecode> FrameDecoder<T> { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct Connection<ReadFrame, WriteFrame: ?Sized> {
|
|
|
|
stream: TcpStream,
|
|
|
|
|
|
|
|
read_buffer: BytesMut,
|
|
|
|
|
|
|
|
decoder: FrameDecoder<ReadFrame>,
|
|
|
|
encoder: FrameEncoder<WriteFrame>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<ReadFrame, WriteFrame> Connection<ReadFrame, WriteFrame>
|
|
|
|
where
|
|
|
|
ReadFrame: ValueDecode,
|
|
|
|
WriteFrame: ValueEncode + ?Sized,
|
|
|
|
{
|
|
|
|
pub fn new(stream: TcpStream) -> Self {
|
|
|
|
Connection {
|
|
|
|
stream,
|
|
|
|
read_buffer: BytesMut::new(),
|
|
|
|
decoder: FrameDecoder::new(),
|
|
|
|
encoder: FrameEncoder::new(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn read(&mut self) -> io::Result<ReadFrame> {
|
|
|
|
loop {
|
|
|
|
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
|
|
|
|
return Ok(frame);
|
|
|
|
}
|
|
|
|
self.stream.read_buf(&mut self.read_buffer).await?;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
self.encoder.encode_to(frame, &mut bytes)?;
|
|
|
|
self.stream.write_all(bytes.as_ref()).await
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
mod tests {
|
|
|
|
use bytes::BytesMut;
|
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
|
|
|
|
use super::{FrameDecoder, FrameEncoder};
|
|
|
|
use super::{Connection, FrameDecoder, FrameEncoder};
|
|
|
|
|
|
|
|
// Test value: [1, 3, 3, 7] in little-endian.
|
|
|
|
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
|
|
|
|
@ -289,4 +334,52 @@ mod tests { |
|
|
|
assert_eq!(decoded, Some(value));
|
|
|
|
assert_eq!(buffer, vec![]);
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn ping_pong() {
|
|
|
|
let listener = TcpListener::bind("localhost:0").await.unwrap();
|
|
|
|
let address = listener.local_addr().unwrap();
|
|
|
|
|
|
|
|
let server_task = tokio::spawn(async move {
|
|
|
|
let (stream, _peer_address) = listener.accept().await.unwrap();
|
|
|
|
let mut connection = Connection::<String, str>::new(stream);
|
|
|
|
|
|
|
|
assert_eq!(connection.read().await.unwrap(), "ping");
|
|
|
|
connection.write("pong").await.unwrap();
|
|
|
|
assert_eq!(connection.read().await.unwrap(), "ping");
|
|
|
|
connection.write("pong").await.unwrap();
|
|
|
|
});
|
|
|
|
|
|
|
|
let stream = TcpStream::connect(address).await.unwrap();
|
|
|
|
let mut connection = Connection::<String, str>::new(stream);
|
|
|
|
|
|
|
|
connection.write("ping").await.unwrap();
|
|
|
|
assert_eq!(connection.read().await.unwrap(), "pong");
|
|
|
|
connection.write("ping").await.unwrap();
|
|
|
|
assert_eq!(connection.read().await.unwrap(), "pong");
|
|
|
|
|
|
|
|
server_task.await.unwrap();
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn very_large_message() {
|
|
|
|
let listener = TcpListener::bind("localhost:0").await.unwrap();
|
|
|
|
let address = listener.local_addr().unwrap();
|
|
|
|
|
|
|
|
let server_task = tokio::spawn(async move {
|
|
|
|
let (stream, _peer_address) = listener.accept().await.unwrap();
|
|
|
|
let mut connection = Connection::<String, Vec<u32>>::new(stream);
|
|
|
|
|
|
|
|
assert_eq!(connection.read().await.unwrap(), "ping");
|
|
|
|
connection.write(&vec![0; 10 * 4096]).await.unwrap();
|
|
|
|
});
|
|
|
|
|
|
|
|
let stream = TcpStream::connect(address).await.unwrap();
|
|
|
|
let mut connection = Connection::<Vec<u32>, str>::new(stream);
|
|
|
|
|
|
|
|
connection.write("ping").await.unwrap();
|
|
|
|
assert_eq!(connection.read().await.unwrap(), vec![0; 10 * 4096]);
|
|
|
|
|
|
|
|
server_task.await.unwrap();
|
|
|
|
}
|
|
|
|
}
|