| @ -0,0 +1,185 @@ | |||
| use std::fmt::Debug; | |||
| use std::io; | |||
| use log::debug; | |||
| use thiserror::Error; | |||
| use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; | |||
| use tokio::net::TcpStream; | |||
| use tokio::sync::mpsc; | |||
| use crate::core::frame::{FrameReader, FrameWriter}; | |||
| use crate::core::value::{ValueDecode, ValueEncode}; | |||
| /// An error that arose while exchanging messages over a `Channel`. | |||
| #[derive(Debug, Error)] | |||
| pub enum WorkerError { | |||
| #[error("read error: {0}")] | |||
| ReadError(io::Error), | |||
| #[error("write error: {0}")] | |||
| WriteError(io::Error), | |||
| #[error("incoming frame channel is unexpectedly closed")] | |||
| IncomingChannelClosed, | |||
| } | |||
| async fn forward_incoming<ReadFrame: ValueDecode + Debug>( | |||
| mut reader: FrameReader<ReadFrame, OwnedReadHalf>, | |||
| incoming_tx: mpsc::Sender<ReadFrame>, | |||
| ) -> Result<(), WorkerError> { | |||
| while let Some(frame) = reader.read().await.map_err(WorkerError::ReadError)? { | |||
| debug!("Channel: received frame: {:?}", frame); | |||
| if let Err(_) = incoming_tx.send(frame).await { | |||
| return Err(WorkerError::IncomingChannelClosed); | |||
| } | |||
| } | |||
| debug!("Stopping incoming handler: frame reader is closed"); | |||
| Ok(()) | |||
| } | |||
| async fn forward_outgoing<WriteFrame: ValueEncode + Debug>( | |||
| mut outgoing_rx: mpsc::Receiver<WriteFrame>, | |||
| mut writer: FrameWriter<WriteFrame, OwnedWriteHalf>, | |||
| ) -> Result<(), WorkerError> { | |||
| while let Some(frame) = outgoing_rx.recv().await { | |||
| debug!("Channel: sending frame: {:?}", frame); | |||
| writer | |||
| .write(&frame) | |||
| .await | |||
| .map_err(WorkerError::WriteError)?; | |||
| } | |||
| debug!("Stopping outgoing handler: channel is closed"); | |||
| Ok(()) | |||
| } | |||
| /// A worker that operates a full-duplex connection exchanging frames over TCP. | |||
| pub struct Worker<ReadFrame, WriteFrame> { | |||
| reader: FrameReader<ReadFrame, OwnedReadHalf>, | |||
| writer: FrameWriter<WriteFrame, OwnedWriteHalf>, | |||
| incoming_tx: mpsc::Sender<ReadFrame>, | |||
| outgoing_rx: mpsc::Receiver<WriteFrame>, | |||
| } | |||
| impl<ReadFrame, WriteFrame> Worker<ReadFrame, WriteFrame> | |||
| where | |||
| ReadFrame: ValueDecode + Debug, | |||
| WriteFrame: ValueEncode + Debug, | |||
| { | |||
| fn new( | |||
| stream: TcpStream, | |||
| incoming_tx: mpsc::Sender<ReadFrame>, | |||
| outgoing_rx: mpsc::Receiver<WriteFrame>, | |||
| ) -> Self { | |||
| let (read_half, write_half) = stream.into_split(); | |||
| let reader = FrameReader::new(read_half); | |||
| let writer = FrameWriter::new(write_half); | |||
| Self { | |||
| reader, | |||
| writer, | |||
| incoming_tx, | |||
| outgoing_rx, | |||
| } | |||
| } | |||
| async fn run(self) -> Result<(), WorkerError> { | |||
| tokio::select! { | |||
| result = forward_incoming(self.reader, self.incoming_tx) => { | |||
| debug!("{:?}", result); | |||
| }, | |||
| result = forward_outgoing(self.outgoing_rx, self.writer) => { | |||
| debug!("{:?}", result); | |||
| }, | |||
| }; | |||
| Ok(()) | |||
| } | |||
| } | |||
| #[cfg(test)] | |||
| mod tests { | |||
| use tokio::io::AsyncReadExt; | |||
| use tokio::net::{TcpListener, TcpStream}; | |||
| use tokio::sync::mpsc; | |||
| use crate::core::frame::{FrameReader, FrameWriter}; | |||
| use super::Worker; | |||
| // Enable capturing logs in tests. | |||
| fn init() { | |||
| let _ = env_logger::builder().is_test(true).try_init(); | |||
| } | |||
| #[tokio::test] | |||
| async fn forwards_incoming_frames() { | |||
| init(); | |||
| let listener = TcpListener::bind("localhost:0").await.unwrap(); | |||
| let address = listener.local_addr().unwrap(); | |||
| let listener_task = tokio::spawn(async move { | |||
| let (mut stream, _) = listener.accept().await.unwrap(); | |||
| let (mut read_half, write_half) = stream.split(); | |||
| let mut writer = FrameWriter::new(write_half); | |||
| let frame: u32 = 42; | |||
| writer.write(&frame).await.expect("writing frame"); | |||
| let mut buf = Vec::new(); | |||
| read_half.read_to_end(&mut buf).await.expect("waiting for eof"); | |||
| assert_eq!(buf, Vec::<u8>::new()); | |||
| }); | |||
| let stream = TcpStream::connect(address).await.expect("connecting"); | |||
| let (request_tx, request_rx) = mpsc::channel::<u32>(100); | |||
| let (response_tx, mut response_rx) = mpsc::channel::<u32>(100); | |||
| let worker = Worker::new(stream, response_tx, request_rx); | |||
| let worker_task = tokio::spawn(worker.run()); | |||
| let frame = response_rx.recv().await.expect("receiving frame"); | |||
| assert_eq!(frame, 42); | |||
| // Signal to the worker that it should stop running. | |||
| drop(request_tx); | |||
| worker_task.await.expect("joining worker").expect("running worker"); | |||
| listener_task.await.expect("joining listener"); | |||
| } | |||
| #[tokio::test] | |||
| async fn forwards_outgoing_frames() { | |||
| init(); | |||
| let listener = TcpListener::bind("localhost:0").await.unwrap(); | |||
| let address = listener.local_addr().unwrap(); | |||
| let listener_task = tokio::spawn(async move { | |||
| let (stream, _) = listener.accept().await.unwrap(); | |||
| let mut reader = FrameReader::new(stream); | |||
| let frame = reader.read().await.expect("reading frame"); | |||
| assert_eq!(frame, Some(42u32)); | |||
| }); | |||
| let stream = TcpStream::connect(address).await.expect("connecting"); | |||
| let (request_tx, request_rx) = mpsc::channel::<u32>(100); | |||
| let (response_tx, _response_rx) = mpsc::channel::<u32>(100); | |||
| let worker = Worker::new(stream, response_tx, request_rx); | |||
| let worker_task = tokio::spawn(worker.run()); | |||
| request_tx.send(42).await.expect("sending frame"); | |||
| // Signal to the worker that it should stop running. | |||
| drop(request_tx); | |||
| worker_task.await.expect("joining worker").expect("running worker"); | |||
| listener_task.await.expect("joining listener"); | |||
| } | |||
| } | |||