|
|
|
@ -8,7 +8,7 @@ use std::marker::PhantomData; |
|
|
|
|
|
|
|
use bytes::BytesMut;
|
|
|
|
use thiserror::Error;
|
|
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
|
|
|
|
use super::prefix::Prefixer;
|
|
|
|
@ -136,6 +136,76 @@ impl<T: ValueDecode> FrameDecoder<T> { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// An asynchronous sink for frames wrapping around a byte writer.
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct FrameWriter<Frame: ?Sized, Writer> {
|
|
|
|
encoder: FrameEncoder<Frame>,
|
|
|
|
writer: Writer,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<Frame, Writer> FrameWriter<Frame, Writer>
|
|
|
|
where
|
|
|
|
Frame: ValueEncode + ?Sized,
|
|
|
|
Writer: AsyncWrite + Unpin,
|
|
|
|
{
|
|
|
|
pub fn new(writer: Writer) -> Self {
|
|
|
|
FrameWriter {
|
|
|
|
encoder: FrameEncoder::new(),
|
|
|
|
writer,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn write(&mut self, frame: &Frame) -> io::Result<()> {
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
self.encoder.encode_to(frame, &mut bytes)?;
|
|
|
|
self.writer.write_all(bytes.as_ref()).await
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn shutdown(&mut self) -> io::Result<()> {
|
|
|
|
self.writer.shutdown().await
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// An asynchronous stream of frames wrapping around a byte reader.
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct FrameReader<Frame, Reader> {
|
|
|
|
decoder: FrameDecoder<Frame>,
|
|
|
|
reader: Reader,
|
|
|
|
read_buffer: BytesMut,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<Frame, Reader> FrameReader<Frame, Reader>
|
|
|
|
where
|
|
|
|
Frame: ValueDecode,
|
|
|
|
Reader: AsyncRead + Unpin,
|
|
|
|
{
|
|
|
|
pub fn new(reader: Reader) -> Self {
|
|
|
|
FrameReader {
|
|
|
|
decoder: FrameDecoder::new(),
|
|
|
|
reader,
|
|
|
|
read_buffer: BytesMut::new(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Attempts to read the next frame from the underlying byte stream.
|
|
|
|
///
|
|
|
|
/// Returns `Ok(Some(frame))` on success.
|
|
|
|
/// Returns `Ok(None)` if the stream has reached the end-of-file event.
|
|
|
|
///
|
|
|
|
/// Returns an error if reading from the stream returned an error or if an
|
|
|
|
/// invalid frame was received.
|
|
|
|
pub async fn read(&mut self) -> io::Result<Option<Frame>> {
|
|
|
|
loop {
|
|
|
|
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
|
|
|
|
return Ok(Some(frame));
|
|
|
|
}
|
|
|
|
if self.reader.read_buf(&mut self.read_buffer).await? == 0 {
|
|
|
|
return Ok(None);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct FrameStream<ReadFrame, WriteFrame: ?Sized> {
|
|
|
|
stream: TcpStream,
|
|
|
|
@ -194,7 +264,7 @@ mod tests { |
|
|
|
use bytes::BytesMut;
|
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
|
|
|
|
use super::{FrameDecoder, FrameEncoder, FrameStream};
|
|
|
|
use super::{FrameDecoder, FrameEncoder, FrameReader, FrameWriter};
|
|
|
|
|
|
|
|
// Test value: [1, 3, 3, 7] in little-endian.
|
|
|
|
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
|
|
|
|
@ -345,22 +415,26 @@ mod tests { |
|
|
|
let address = listener.local_addr().unwrap();
|
|
|
|
|
|
|
|
let server_task = tokio::spawn(async move {
|
|
|
|
let (stream, _peer_address) = listener.accept().await.unwrap();
|
|
|
|
let mut frame_stream = FrameStream::<String, str>::new(stream);
|
|
|
|
|
|
|
|
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
|
|
|
|
frame_stream.write("pong").await.unwrap();
|
|
|
|
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
|
|
|
|
frame_stream.write("pong").await.unwrap();
|
|
|
|
let (mut stream, _peer_address) = listener.accept().await.unwrap();
|
|
|
|
let (read_half, write_half) = stream.split();
|
|
|
|
let mut reader = FrameReader::new(read_half);
|
|
|
|
let mut writer = FrameWriter::new(write_half);
|
|
|
|
|
|
|
|
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
|
|
|
|
writer.write("pong").await.unwrap();
|
|
|
|
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
|
|
|
|
writer.write("pong").await.unwrap();
|
|
|
|
});
|
|
|
|
|
|
|
|
let stream = TcpStream::connect(address).await.unwrap();
|
|
|
|
let mut frame_stream = FrameStream::<String, str>::new(stream);
|
|
|
|
let mut stream = TcpStream::connect(address).await.unwrap();
|
|
|
|
let (read_half, write_half) = stream.split();
|
|
|
|
let mut reader = FrameReader::new(read_half);
|
|
|
|
let mut writer = FrameWriter::new(write_half);
|
|
|
|
|
|
|
|
frame_stream.write("ping").await.unwrap();
|
|
|
|
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
|
|
|
|
frame_stream.write("ping").await.unwrap();
|
|
|
|
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
|
|
|
|
writer.write("ping").await.unwrap();
|
|
|
|
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
|
|
|
|
writer.write("ping").await.unwrap();
|
|
|
|
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
|
|
|
|
|
|
|
|
server_task.await.unwrap();
|
|
|
|
}
|
|
|
|
@ -371,18 +445,22 @@ mod tests { |
|
|
|
let address = listener.local_addr().unwrap();
|
|
|
|
|
|
|
|
let server_task = tokio::spawn(async move {
|
|
|
|
let (stream, _peer_address) = listener.accept().await.unwrap();
|
|
|
|
let mut frame_stream = FrameStream::<String, Vec<u32>>::new(stream);
|
|
|
|
let (mut stream, _peer_address) = listener.accept().await.unwrap();
|
|
|
|
let (read_half, write_half) = stream.split();
|
|
|
|
let mut reader = FrameReader::new(read_half);
|
|
|
|
let mut writer = FrameWriter::new(write_half);
|
|
|
|
|
|
|
|
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
|
|
|
|
frame_stream.write(&vec![0; 10 * 4096]).await.unwrap();
|
|
|
|
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
|
|
|
|
writer.write(&vec![0u32; 10 * 4096]).await.unwrap();
|
|
|
|
});
|
|
|
|
|
|
|
|
let stream = TcpStream::connect(address).await.unwrap();
|
|
|
|
let mut frame_stream = FrameStream::<Vec<u32>, str>::new(stream);
|
|
|
|
let mut stream = TcpStream::connect(address).await.unwrap();
|
|
|
|
let (read_half, write_half) = stream.split();
|
|
|
|
let mut reader = FrameReader::new(read_half);
|
|
|
|
let mut writer = FrameWriter::new(write_half);
|
|
|
|
|
|
|
|
frame_stream.write("ping").await.unwrap();
|
|
|
|
assert_eq!(frame_stream.read().await.unwrap(), Some(vec![0; 10 * 4096]));
|
|
|
|
writer.write("ping").await.unwrap();
|
|
|
|
assert_eq!(reader.read().await.unwrap(), Some(vec![0u32; 10 * 4096]));
|
|
|
|
|
|
|
|
server_task.await.unwrap();
|
|
|
|
}
|
|
|
|
|