diff --git a/proto/src/core/worker.rs b/proto/src/core/worker.rs index b96275b..c3ddae7 100644 --- a/proto/src/core/worker.rs +++ b/proto/src/core/worker.rs @@ -96,30 +96,138 @@ where #[cfg(test)] mod tests { - use tokio::io::AsyncReadExt; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use crate::core::frame::{FrameReader, FrameWriter}; - use super::Worker; + use super::{Worker, WorkerError}; // Enable capturing logs in tests. fn init() { let _ = env_logger::builder().is_test(true).try_init(); } - // TODO: test for all 3 error conditions. + #[tokio::test] + async fn stops_on_read_error() { + init(); + + let listener = TcpListener::bind("localhost:0").await.expect("binding listener"); + let address = listener.local_addr().expect("getting local address"); + + let listener_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.expect("accepting"); + + let junk = [ + 1, 0, 0, 0, // Length: 1 byte (big-endian) + 0, // This is not enough for a u32, encoded as 4 bytes. + ]; + stream.write_all(&junk).await.expect("writing frame"); + stream.shutdown().await.expect("shutting down"); + }); + + let stream = TcpStream::connect(address).await.expect("connecting"); + + let (_request_tx, request_rx) = mpsc::channel::(100); + let (response_tx, _response_rx) = mpsc::channel::(100); + let worker = Worker::new(stream, response_tx, request_rx); + + let err = worker.run().await.expect_err("running worker"); + if let WorkerError::ReadError(_) = err { + // Ok! + } else { + panic!("Wrong error: {:?}", err); + } + + listener_task.await.expect("joining listener"); + } + + #[tokio::test] + async fn stops_on_write_error() { + init(); + + let listener = TcpListener::bind("localhost:0").await.expect("binding"); + let address = listener.local_addr().expect("getting local address"); + + let listener_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.expect("accepting"); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await.expect("waiting for eof"); + assert_eq!(buf, Vec::::new()); + }); + + let mut stream = TcpStream::connect(address).await.expect("connecting"); + + // Shut down the stream before running the worker. Writing again will fail. + stream.shutdown().await.expect("shutting down"); + + let (request_tx, request_rx) = mpsc::channel::(100); + let (response_tx, _response_rx) = mpsc::channel::(100); + let worker = Worker::new(stream, response_tx, request_rx); + + // Queue a frame before we run the worker. + request_tx.send(42).await.expect("sending frame"); + + let err = worker.run().await.expect_err("running worker"); + if let WorkerError::WriteError(_) = err { + // Ok! + } else { + panic!("Wrong error: {:?}", err); + } + + listener_task.await.expect("joining listener"); + } + + #[tokio::test] + async fn stops_on_incoming_channel_closed() { + init(); + + let listener = TcpListener::bind("localhost:0").await.expect("binding"); + let address = listener.local_addr().expect("getting local address"); + + let listener_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.expect("accepting"); + let (mut read_half, write_half) = stream.split(); + let mut writer = FrameWriter::new(write_half); + + // Write a response that the worker will try to send to `response_tx`. + writer.write(&42u32).await.expect("writing frame"); + + let mut buf = Vec::new(); + read_half.read_to_end(&mut buf).await.expect("waiting for eof"); + assert_eq!(buf, Vec::::new()); + }); + + let stream = TcpStream::connect(address).await.expect("connecting"); + + let (_request_tx, request_rx) = mpsc::channel::(100); + let (response_tx, response_rx) = mpsc::channel::(100); + let worker = Worker::new(stream, response_tx, request_rx); + + // Drop the receiver before the worker can send anything. + drop(response_rx); + + let err = worker.run().await.expect_err("running worker"); + if let WorkerError::IncomingChannelClosed = err { + // Ok! + } else { + panic!("Wrong error: {:?}", err); + } + + listener_task.await.expect("joining listener"); + } #[tokio::test] async fn forwards_incoming_frames() { init(); - let listener = TcpListener::bind("localhost:0").await.unwrap(); - let address = listener.local_addr().unwrap(); + let listener = TcpListener::bind("localhost:0").await.expect("binding"); + let address = listener.local_addr().expect("getting local address"); let listener_task = tokio::spawn(async move { - let (mut stream, _) = listener.accept().await.unwrap(); + let (mut stream, _) = listener.accept().await.expect("accepting"); let (mut read_half, write_half) = stream.split(); let mut writer = FrameWriter::new(write_half); @@ -153,11 +261,11 @@ mod tests { async fn forwards_outgoing_frames() { init(); - let listener = TcpListener::bind("localhost:0").await.unwrap(); - let address = listener.local_addr().unwrap(); + let listener = TcpListener::bind("localhost:0").await.expect("binding"); + let address = listener.local_addr().expect("getting local address"); let listener_task = tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); + let (stream, _) = listener.accept().await.expect("accepting"); let mut reader = FrameReader::new(stream); let frame = reader.read().await.expect("reading frame");