| @ -0,0 +1,185 @@ | |||||
| use std::fmt::Debug; | |||||
| use std::io; | |||||
| use log::debug; | |||||
| use thiserror::Error; | |||||
| use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; | |||||
| use tokio::net::TcpStream; | |||||
| use tokio::sync::mpsc; | |||||
| use crate::core::frame::{FrameReader, FrameWriter}; | |||||
| use crate::core::value::{ValueDecode, ValueEncode}; | |||||
| /// An error that arose while exchanging messages over a `Channel`. | |||||
| #[derive(Debug, Error)] | |||||
| pub enum WorkerError { | |||||
| #[error("read error: {0}")] | |||||
| ReadError(io::Error), | |||||
| #[error("write error: {0}")] | |||||
| WriteError(io::Error), | |||||
| #[error("incoming frame channel is unexpectedly closed")] | |||||
| IncomingChannelClosed, | |||||
| } | |||||
| async fn forward_incoming<ReadFrame: ValueDecode + Debug>( | |||||
| mut reader: FrameReader<ReadFrame, OwnedReadHalf>, | |||||
| incoming_tx: mpsc::Sender<ReadFrame>, | |||||
| ) -> Result<(), WorkerError> { | |||||
| while let Some(frame) = reader.read().await.map_err(WorkerError::ReadError)? { | |||||
| debug!("Channel: received frame: {:?}", frame); | |||||
| if let Err(_) = incoming_tx.send(frame).await { | |||||
| return Err(WorkerError::IncomingChannelClosed); | |||||
| } | |||||
| } | |||||
| debug!("Stopping incoming handler: frame reader is closed"); | |||||
| Ok(()) | |||||
| } | |||||
| async fn forward_outgoing<WriteFrame: ValueEncode + Debug>( | |||||
| mut outgoing_rx: mpsc::Receiver<WriteFrame>, | |||||
| mut writer: FrameWriter<WriteFrame, OwnedWriteHalf>, | |||||
| ) -> Result<(), WorkerError> { | |||||
| while let Some(frame) = outgoing_rx.recv().await { | |||||
| debug!("Channel: sending frame: {:?}", frame); | |||||
| writer | |||||
| .write(&frame) | |||||
| .await | |||||
| .map_err(WorkerError::WriteError)?; | |||||
| } | |||||
| debug!("Stopping outgoing handler: channel is closed"); | |||||
| Ok(()) | |||||
| } | |||||
| /// A worker that operates a full-duplex connection exchanging frames over TCP. | |||||
| pub struct Worker<ReadFrame, WriteFrame> { | |||||
| reader: FrameReader<ReadFrame, OwnedReadHalf>, | |||||
| writer: FrameWriter<WriteFrame, OwnedWriteHalf>, | |||||
| incoming_tx: mpsc::Sender<ReadFrame>, | |||||
| outgoing_rx: mpsc::Receiver<WriteFrame>, | |||||
| } | |||||
| impl<ReadFrame, WriteFrame> Worker<ReadFrame, WriteFrame> | |||||
| where | |||||
| ReadFrame: ValueDecode + Debug, | |||||
| WriteFrame: ValueEncode + Debug, | |||||
| { | |||||
| fn new( | |||||
| stream: TcpStream, | |||||
| incoming_tx: mpsc::Sender<ReadFrame>, | |||||
| outgoing_rx: mpsc::Receiver<WriteFrame>, | |||||
| ) -> Self { | |||||
| let (read_half, write_half) = stream.into_split(); | |||||
| let reader = FrameReader::new(read_half); | |||||
| let writer = FrameWriter::new(write_half); | |||||
| Self { | |||||
| reader, | |||||
| writer, | |||||
| incoming_tx, | |||||
| outgoing_rx, | |||||
| } | |||||
| } | |||||
| async fn run(self) -> Result<(), WorkerError> { | |||||
| tokio::select! { | |||||
| result = forward_incoming(self.reader, self.incoming_tx) => { | |||||
| debug!("{:?}", result); | |||||
| }, | |||||
| result = forward_outgoing(self.outgoing_rx, self.writer) => { | |||||
| debug!("{:?}", result); | |||||
| }, | |||||
| }; | |||||
| Ok(()) | |||||
| } | |||||
| } | |||||
| #[cfg(test)] | |||||
| mod tests { | |||||
| use tokio::io::AsyncReadExt; | |||||
| use tokio::net::{TcpListener, TcpStream}; | |||||
| use tokio::sync::mpsc; | |||||
| use crate::core::frame::{FrameReader, FrameWriter}; | |||||
| use super::Worker; | |||||
| // Enable capturing logs in tests. | |||||
| fn init() { | |||||
| let _ = env_logger::builder().is_test(true).try_init(); | |||||
| } | |||||
| #[tokio::test] | |||||
| async fn forwards_incoming_frames() { | |||||
| init(); | |||||
| let listener = TcpListener::bind("localhost:0").await.unwrap(); | |||||
| let address = listener.local_addr().unwrap(); | |||||
| let listener_task = tokio::spawn(async move { | |||||
| let (mut stream, _) = listener.accept().await.unwrap(); | |||||
| let (mut read_half, write_half) = stream.split(); | |||||
| let mut writer = FrameWriter::new(write_half); | |||||
| let frame: u32 = 42; | |||||
| writer.write(&frame).await.expect("writing frame"); | |||||
| let mut buf = Vec::new(); | |||||
| read_half.read_to_end(&mut buf).await.expect("waiting for eof"); | |||||
| assert_eq!(buf, Vec::<u8>::new()); | |||||
| }); | |||||
| let stream = TcpStream::connect(address).await.expect("connecting"); | |||||
| let (request_tx, request_rx) = mpsc::channel::<u32>(100); | |||||
| let (response_tx, mut response_rx) = mpsc::channel::<u32>(100); | |||||
| let worker = Worker::new(stream, response_tx, request_rx); | |||||
| let worker_task = tokio::spawn(worker.run()); | |||||
| let frame = response_rx.recv().await.expect("receiving frame"); | |||||
| assert_eq!(frame, 42); | |||||
| // Signal to the worker that it should stop running. | |||||
| drop(request_tx); | |||||
| worker_task.await.expect("joining worker").expect("running worker"); | |||||
| listener_task.await.expect("joining listener"); | |||||
| } | |||||
| #[tokio::test] | |||||
| async fn forwards_outgoing_frames() { | |||||
| init(); | |||||
| let listener = TcpListener::bind("localhost:0").await.unwrap(); | |||||
| let address = listener.local_addr().unwrap(); | |||||
| let listener_task = tokio::spawn(async move { | |||||
| let (stream, _) = listener.accept().await.unwrap(); | |||||
| let mut reader = FrameReader::new(stream); | |||||
| let frame = reader.read().await.expect("reading frame"); | |||||
| assert_eq!(frame, Some(42u32)); | |||||
| }); | |||||
| let stream = TcpStream::connect(address).await.expect("connecting"); | |||||
| let (request_tx, request_rx) = mpsc::channel::<u32>(100); | |||||
| let (response_tx, _response_rx) = mpsc::channel::<u32>(100); | |||||
| let worker = Worker::new(stream, response_tx, request_rx); | |||||
| let worker_task = tokio::spawn(worker.run()); | |||||
| request_tx.send(42).await.expect("sending frame"); | |||||
| // Signal to the worker that it should stop running. | |||||
| drop(request_tx); | |||||
| worker_task.await.expect("joining worker").expect("running worker"); | |||||
| listener_task.await.expect("joining listener"); | |||||
| } | |||||
| } | |||||