From 0aa0c48591a18e80808c59110eac1c6e373af95f Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Tue, 24 Aug 2021 18:06:37 +0200 Subject: [PATCH] Introduce proto Worker. --- proto/src/core/mod.rs | 4 + proto/src/core/worker.rs | 185 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+) create mode 100644 proto/src/core/worker.rs diff --git a/proto/src/core/mod.rs b/proto/src/core/mod.rs index 0e336b3..8e642c1 100644 --- a/proto/src/core/mod.rs +++ b/proto/src/core/mod.rs @@ -4,6 +4,10 @@ pub mod frame; mod prefix; mod u32; mod user; +// TODO: Remove `pub` qualifier, depend on re-exports. pub mod value; +mod worker; pub use user::{User, UserStatus}; +pub use worker::{Worker, WorkerError}; +pub use value::{ValueDecode, ValueEncode}; diff --git a/proto/src/core/worker.rs b/proto/src/core/worker.rs new file mode 100644 index 0000000..7eebf67 --- /dev/null +++ b/proto/src/core/worker.rs @@ -0,0 +1,185 @@ +use std::fmt::Debug; +use std::io; + +use log::debug; +use thiserror::Error; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::TcpStream; +use tokio::sync::mpsc; + +use crate::core::frame::{FrameReader, FrameWriter}; +use crate::core::value::{ValueDecode, ValueEncode}; + +/// An error that arose while exchanging messages over a `Channel`. +#[derive(Debug, Error)] +pub enum WorkerError { + #[error("read error: {0}")] + ReadError(io::Error), + + #[error("write error: {0}")] + WriteError(io::Error), + + #[error("incoming frame channel is unexpectedly closed")] + IncomingChannelClosed, +} + +async fn forward_incoming( + mut reader: FrameReader, + incoming_tx: mpsc::Sender, +) -> Result<(), WorkerError> { + while let Some(frame) = reader.read().await.map_err(WorkerError::ReadError)? { + debug!("Channel: received frame: {:?}", frame); + + if let Err(_) = incoming_tx.send(frame).await { + return Err(WorkerError::IncomingChannelClosed); + } + } + + debug!("Stopping incoming handler: frame reader is closed"); + Ok(()) +} + +async fn forward_outgoing( + mut outgoing_rx: mpsc::Receiver, + mut writer: FrameWriter, +) -> Result<(), WorkerError> { + while let Some(frame) = outgoing_rx.recv().await { + debug!("Channel: sending frame: {:?}", frame); + writer + .write(&frame) + .await + .map_err(WorkerError::WriteError)?; + } + + debug!("Stopping outgoing handler: channel is closed"); + Ok(()) +} + +/// A worker that operates a full-duplex connection exchanging frames over TCP. +pub struct Worker { + reader: FrameReader, + writer: FrameWriter, + incoming_tx: mpsc::Sender, + outgoing_rx: mpsc::Receiver, +} + +impl Worker +where + ReadFrame: ValueDecode + Debug, + WriteFrame: ValueEncode + Debug, +{ + fn new( + stream: TcpStream, + incoming_tx: mpsc::Sender, + outgoing_rx: mpsc::Receiver, + ) -> Self { + let (read_half, write_half) = stream.into_split(); + let reader = FrameReader::new(read_half); + let writer = FrameWriter::new(write_half); + Self { + reader, + writer, + incoming_tx, + outgoing_rx, + } + } + + async fn run(self) -> Result<(), WorkerError> { + tokio::select! { + result = forward_incoming(self.reader, self.incoming_tx) => { + debug!("{:?}", result); + }, + result = forward_outgoing(self.outgoing_rx, self.writer) => { + debug!("{:?}", result); + }, + }; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use tokio::io::AsyncReadExt; + use tokio::net::{TcpListener, TcpStream}; + use tokio::sync::mpsc; + + use crate::core::frame::{FrameReader, FrameWriter}; + + use super::Worker; + + // Enable capturing logs in tests. + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn forwards_incoming_frames() { + init(); + + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let listener_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let (mut read_half, write_half) = stream.split(); + let mut writer = FrameWriter::new(write_half); + + let frame: u32 = 42; + writer.write(&frame).await.expect("writing frame"); + + let mut buf = Vec::new(); + read_half.read_to_end(&mut buf).await.expect("waiting for eof"); + assert_eq!(buf, Vec::::new()); + }); + + let stream = TcpStream::connect(address).await.expect("connecting"); + + let (request_tx, request_rx) = mpsc::channel::(100); + let (response_tx, mut response_rx) = mpsc::channel::(100); + let worker = Worker::new(stream, response_tx, request_rx); + + let worker_task = tokio::spawn(worker.run()); + + let frame = response_rx.recv().await.expect("receiving frame"); + assert_eq!(frame, 42); + + // Signal to the worker that it should stop running. + drop(request_tx); + + worker_task.await.expect("joining worker").expect("running worker"); + listener_task.await.expect("joining listener"); + } + + #[tokio::test] + async fn forwards_outgoing_frames() { + init(); + + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let listener_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = FrameReader::new(stream); + + let frame = reader.read().await.expect("reading frame"); + assert_eq!(frame, Some(42u32)); + }); + + let stream = TcpStream::connect(address).await.expect("connecting"); + + let (request_tx, request_rx) = mpsc::channel::(100); + let (response_tx, _response_rx) = mpsc::channel::(100); + let worker = Worker::new(stream, response_tx, request_rx); + + let worker_task = tokio::spawn(worker.run()); + + request_tx.send(42).await.expect("sending frame"); + + // Signal to the worker that it should stop running. + drop(request_tx); + + worker_task.await.expect("joining worker").expect("running worker"); + listener_task.await.expect("joining listener"); + } +}