Browse Source

Introduce FrameReader and FrameWriter.

wip
Titouan Rigoudy 4 years ago
parent
commit
aa13cfd253
1 changed files with 101 additions and 23 deletions
  1. +101
    -23
      proto/src/core/frame.rs

+ 101
- 23
proto/src/core/frame.rs View File

@ -8,7 +8,7 @@ use std::marker::PhantomData;
use bytes::BytesMut;
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use super::prefix::Prefixer;
@ -136,6 +136,76 @@ impl<T: ValueDecode> FrameDecoder<T> {
}
}
/// An asynchronous sink for frames wrapping around a byte writer.
#[derive(Debug)]
pub struct FrameWriter<Frame: ?Sized, Writer> {
encoder: FrameEncoder<Frame>,
writer: Writer,
}
impl<Frame, Writer> FrameWriter<Frame, Writer>
where
Frame: ValueEncode + ?Sized,
Writer: AsyncWrite + Unpin,
{
pub fn new(writer: Writer) -> Self {
FrameWriter {
encoder: FrameEncoder::new(),
writer,
}
}
pub async fn write(&mut self, frame: &Frame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.writer.write_all(bytes.as_ref()).await
}
pub async fn shutdown(&mut self) -> io::Result<()> {
self.writer.shutdown().await
}
}
/// An asynchronous stream of frames wrapping around a byte reader.
#[derive(Debug)]
pub struct FrameReader<Frame, Reader> {
decoder: FrameDecoder<Frame>,
reader: Reader,
read_buffer: BytesMut,
}
impl<Frame, Reader> FrameReader<Frame, Reader>
where
Frame: ValueDecode,
Reader: AsyncRead + Unpin,
{
pub fn new(reader: Reader) -> Self {
FrameReader {
decoder: FrameDecoder::new(),
reader,
read_buffer: BytesMut::new(),
}
}
/// Attempts to read the next frame from the underlying byte stream.
///
/// Returns `Ok(Some(frame))` on success.
/// Returns `Ok(None)` if the stream has reached the end-of-file event.
///
/// Returns an error if reading from the stream returned an error or if an
/// invalid frame was received.
pub async fn read(&mut self) -> io::Result<Option<Frame>> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(Some(frame));
}
if self.reader.read_buf(&mut self.read_buffer).await? == 0 {
return Ok(None);
}
}
}
}
#[derive(Debug)]
pub struct FrameStream<ReadFrame, WriteFrame: ?Sized> {
stream: TcpStream,
@ -194,7 +264,7 @@ mod tests {
use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder, FrameStream};
use super::{FrameDecoder, FrameEncoder, FrameReader, FrameWriter};
// Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
@ -345,22 +415,26 @@ mod tests {
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
let (mut stream, _peer_address) = listener.accept().await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write("pong").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write("pong").await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
let mut stream = TcpStream::connect(address).await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
server_task.await.unwrap();
}
@ -371,18 +445,22 @@ mod tests {
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, Vec<u32>>::new(stream);
let (mut stream, _peer_address) = listener.accept().await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write(&vec![0; 10 * 4096]).await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write(&vec![0u32; 10 * 4096]).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<Vec<u32>, str>::new(stream);
let mut stream = TcpStream::connect(address).await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some(vec![0; 10 * 4096]));
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some(vec![0u32; 10 * 4096]));
server_task.await.unwrap();
}


Loading…
Cancel
Save