From 8b2e6aeae2ba11808f4b06fc8a5238052d81087a Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Sat, 25 Sep 2021 10:20:49 +0200 Subject: [PATCH] Introduce testing server for client tests. Make proto testing server abstract over connection handlers. --- Cargo.lock | 1 + client/src/main.rs | 2 + client/src/testing/mod.rs | 1 + client/src/testing/server.rs | 26 +++++ proto/Cargo.toml | 1 + proto/src/server/client.rs | 7 +- proto/src/server/testing.rs | 196 ++++++++++++++++++++--------------- 7 files changed, 146 insertions(+), 88 deletions(-) create mode 100644 client/src/testing/mod.rs create mode 100644 client/src/testing/server.rs diff --git a/Cargo.lock b/Cargo.lock index 27fd4ad..a0283f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,6 +760,7 @@ dependencies = [ name = "solstice-proto" version = "0.1.0" dependencies = [ + "anyhow", "async-stream", "bytes", "encoding_rs", diff --git a/client/src/main.rs b/client/src/main.rs index 02f8185..d4080ee 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -20,6 +20,8 @@ mod handlers; mod login; mod message_handler; mod room; +#[cfg(test)] +mod testing; mod user; use config::Config; diff --git a/client/src/testing/mod.rs b/client/src/testing/mod.rs new file mode 100644 index 0000000..df4adf9 --- /dev/null +++ b/client/src/testing/mod.rs @@ -0,0 +1 @@ +mod server; diff --git a/client/src/testing/server.rs b/client/src/testing/server.rs new file mode 100644 index 0000000..46cbf3e --- /dev/null +++ b/client/src/testing/server.rs @@ -0,0 +1,26 @@ +use std::net::SocketAddr; + +use anyhow::Context; +use tokio::net::TcpListener; +use tokio::sync::oneshot; + +struct ServerHandle { + address: SocketAddr, + shutdown_tx: oneshot::Sender<()>, +} + +struct Server { + listener: TcpListener, + shutdown_rx: oneshot::Receiver<()>, +} + +struct ServerBuilder { + // TODO +} + +impl ServerBuilder { + pub async fn bind(self) -> anyhow::Result<(Server, ServerHandle)> { + let listener = TcpListener::bind("localhost:0").await.context("binding")?; + unimplemented!(); + } +} diff --git a/proto/Cargo.toml b/proto/Cargo.toml index 0edf7a2..a7e3224 100644 --- a/proto/Cargo.toml +++ b/proto/Cargo.toml @@ -5,6 +5,7 @@ authors = ["letitz"] edition = "2018" [dependencies] +anyhow = "^1.0" async-stream = "^0.3" bytes = "^1.0" encoding_rs = "^0.8" diff --git a/proto/src/server/client.rs b/proto/src/server/client.rs index e8aa174..bd28b4c 100644 --- a/proto/src/server/client.rs +++ b/proto/src/server/client.rs @@ -124,7 +124,7 @@ mod tests { use tokio::net::TcpStream; use tokio::sync::mpsc; - use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap}; + use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap, UserStatusHandler}; use crate::server::{ Credentials, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse, @@ -157,7 +157,7 @@ mod tests { async fn login_success() { init(); - let (server, handle) = ServerBuilder::default() + let (server, handle) = ServerBuilder::new(UserStatusHandler::default()) .bind() .await .expect("binding server"); @@ -194,8 +194,7 @@ mod tests { let mut user_status_map = UserStatusMap::default(); user_status_map.insert(response.clone()); - let (server, handle) = ServerBuilder::default() - .with_user_status_map(user_status_map) + let (server, handle) = ServerBuilder::new(UserStatusHandler::new(user_status_map)) .bind() .await .expect("binding server"); diff --git a/proto/src/server/testing.rs b/proto/src/server/testing.rs index cf752af..2b35fa4 100644 --- a/proto/src/server/testing.rs +++ b/proto/src/server/testing.rs @@ -5,17 +5,16 @@ use std::io; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; +use anyhow::{bail, Context}; use log::{debug, info, warn}; use parking_lot::Mutex; -use tokio::net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpListener, TcpStream, -}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::sync::watch; -use crate::core::{FrameReader, FrameWriter}; +use crate::core::{FrameReader, FrameWriter, Worker}; use crate::server::{ LoginResponse, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse, @@ -38,14 +37,64 @@ impl UserStatusMap { } } -struct Handler { +/// TODO: document. +pub trait Handler: Clone + Send { + /// TODO: document. + fn run( + self, + request_rx: mpsc::Receiver, + response_tx: mpsc::Sender, + ) -> anyhow::Result<()>; +} + +#[derive(Clone, Default)] +pub struct UserStatusHandler { + pub user_status_map: Arc>, +} + +impl UserStatusHandler { + pub fn new(user_status_map: UserStatusMap) -> Self { + Self { + user_status_map: Arc::new(Mutex::new(user_status_map)), + } + } +} + +impl Handler for UserStatusHandler { + fn run( + self, + mut request_rx: mpsc::Receiver, + response_tx: mpsc::Sender, + ) -> anyhow::Result<()> { + while let Some(request) = request_rx.blocking_recv() { + match request { + ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => { + let entry = self.user_status_map.lock().get(&user_name); + if let Some(response) = entry { + response_tx + .blocking_send(ServerResponse::UserStatusResponse(response)) + .context("sending response")?; + } else { + warn!("Received UserStatusRequest for unknown user {}", user_name); + } + } + _ => { + warn!("Unhandled request: {:?}", request); + } + } + } + Ok(()) + } +} + +struct LoginHandler { reader: FrameReader, writer: FrameWriter, peer_address: SocketAddr, - user_status_map: Arc>, + inner: H, } -impl Handler { +impl LoginHandler { fn ipv4_address(&self) -> Ipv4Addr { match self.peer_address.ip() { IpAddr::V4(ipv4_addr) => ipv4_addr, @@ -61,22 +110,16 @@ impl Handler { self.writer.write(response).await } - async fn handle_login(&mut self) -> io::Result<()> { - match self.reader.read().await? { + async fn handle_login(&mut self) -> anyhow::Result<()> { + match self.reader.read().await.context("reading first request")? { Some(ServerRequest::LoginRequest(request)) => { info!("Received login request: {:?}", request); } Some(request) => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("expected login request, got: {:?}", request), - )); + bail!("expected login request, got: {:?}", request); } None => { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "expected login request".to_string(), - )); + bail!("expected login request, got eof"); } }; @@ -85,72 +128,65 @@ impl Handler { ip: self.ipv4_address(), password_md5_opt: None, }); - self.send_response(&response).await + self + .send_response(&response) + .await + .context("sending login response") } - async fn handle_request(&mut self, request: ServerRequest) -> io::Result<()> { - debug!("Received request: {:?}", request); - - match request { - ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => { - let entry = self.user_status_map.lock().get(&user_name); - if let Some(response) = entry { - let response = ServerResponse::UserStatusResponse(response); - self.send_response(&response).await?; - } else { - warn!("Received UserStatusRequest for unknown user {}", user_name); - } - } - _ => { - warn!("Unhandled request: {:?}", request); - } - } + async fn run(mut self) -> anyhow::Result<()> { + self.handle_login().await?; - Ok(()) - } + let mut worker = Worker::from_parts(self.reader, self.writer); - async fn run(mut self) -> io::Result<()> { - self.handle_login().await?; + let (request_tx, request_rx) = mpsc::channel(100); + let (response_tx, response_rx) = mpsc::channel(100); - while let Some(request) = self.reader.read().await? { - self.handle_request(request).await?; - } + let worker_task = tokio::spawn(async move { + worker.run(request_tx, response_rx).await + }); + let inner = self.inner; + tokio::task::spawn_blocking(move || inner.run(request_rx, response_tx)) + .await + .context("joining handler")? + .context("running handler")?; + worker_task.await.context("joining worker")?.context("running worker")?; info!("Client disconnecting, shutting down"); Ok(()) } } -struct GracefulHandler { - handler: Handler, +struct GracefulHandler { + handler: LoginHandler, shutdown_rx: watch::Receiver<()>, } -impl GracefulHandler { - async fn run(mut self) -> io::Result<()> { +impl GracefulHandler { + async fn run(mut self) -> anyhow::Result<()> { tokio::select!( result = self.handler.run() => { if let Err(ref error) = result { - warn!("Handler returned error {:?}", error); + warn!("LoginHandler returned error {:?}", error); } result }, // Ignore receive errors - if shutdown_rx's sender is dropped, we take // that as a signal to shut down too. _ = self.shutdown_rx.changed() => { - info!("Handler shutting down."); + info!("GracefulHandler shutting down."); Ok(()) }, ) } } -struct SenderHandler { - handler: GracefulHandler, - result_tx: mpsc::Sender>, +struct SenderHandler { + handler: GracefulHandler, + result_tx: mpsc::Sender>, } -impl SenderHandler { +impl SenderHandler { async fn run(self) { let result = self.handler.run().await; let _ = self.result_tx.send(result).await; @@ -159,28 +195,21 @@ impl SenderHandler { /// A builder for Server instances. #[derive(Default)] -pub struct ServerBuilder { - user_status_map: Option>>, +pub struct ServerBuilder { + handler: H, } -impl ServerBuilder { - /// Sets the UserStatusMap which the server will use to respond to - /// UserStatusRequest messages. - pub fn with_user_status_map(mut self, map: UserStatusMap) -> Self { - self.user_status_map = Some(Arc::new(Mutex::new(map))); - self +impl ServerBuilder { + /// `handler` will be used to handle incoming connections. + pub fn new(handler: H) -> Self { + ServerBuilder { handler } } /// Binds to a localhost port, then returns a server and its handle. - pub async fn bind(self) -> io::Result<(Server, ServerHandle)> { + pub async fn bind(self) -> io::Result<(Server, ServerHandle)> { let listener = TcpListener::bind("localhost:0").await?; let address = listener.local_addr()?; - let user_status_map = match self.user_status_map { - Some(user_status_map) => user_status_map, - None => Arc::new(Mutex::new(UserStatusMap::default())), - }; - let (shutdown_tx, shutdown_rx) = oneshot::channel(); let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(()); let (result_tx, result_rx) = mpsc::channel(1); @@ -193,7 +222,7 @@ impl ServerBuilder { handler_shutdown_rx, result_tx, result_rx, - user_status_map, + handler: self.handler, }, ServerHandle { shutdown_tx, @@ -213,7 +242,7 @@ pub enum ShutdownType { } /// A simple server for connecting to in tests. -pub struct Server { +pub struct Server { // Listener for new connections. listener: TcpListener, @@ -225,11 +254,11 @@ pub struct Server { handler_shutdown_rx: watch::Receiver<()>, // Channel for receiving results back from handlers. - result_tx: mpsc::Sender>, - result_rx: mpsc::Receiver>, + result_tx: mpsc::Sender>, + result_rx: mpsc::Receiver>, - // Shared state for handlers to use when serving responses. - user_status_map: Arc>, + // Handler used for incoming connections. + handler: H, } /// Allows interacting with a running `Server`. @@ -252,7 +281,7 @@ impl ServerHandle { } } -impl Server { +impl Server { /// Returns the address to which this server is bound. /// This is always localhost and a random port chosen by the OS. pub fn address(&self) -> io::Result { @@ -265,11 +294,11 @@ impl Server { let handler = SenderHandler { handler: GracefulHandler { - handler: Handler { + handler: LoginHandler { reader: FrameReader::new(read_half), writer: FrameWriter::new(write_half), peer_address, - user_status_map: self.user_status_map.clone(), + inner: self.handler.clone(), }, shutdown_rx: self.handler_shutdown_rx.clone(), }, @@ -296,11 +325,11 @@ impl Server { /// - an error was encountered while listening /// - an error was encountered while serving a request /// - pub async fn serve(mut self) -> io::Result<()> { + pub async fn serve(mut self) -> anyhow::Result<()> { loop { tokio::select!( result = self.listener.accept() => { - let (stream, peer_address) = result?; + let (stream, peer_address) = result.context("accepting stream")?; self.spawn_handler(stream, peer_address); }, result = &mut self.shutdown_rx => { @@ -343,7 +372,7 @@ mod tests { use tokio::net::TcpStream; - use super::{ServerBuilder, ShutdownType}; + use super::{ServerBuilder, ShutdownType, UserStatusHandler}; // Enable capturing logs in tests. fn init() { @@ -354,7 +383,7 @@ mod tests { async fn new_binds_to_localhost() { init(); - let (server, handle) = ServerBuilder::default().bind().await.unwrap(); + let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap(); assert!(server.address().unwrap().ip().is_loopback()); assert_eq!(server.address().unwrap(), handle.address()); } @@ -363,7 +392,7 @@ mod tests { async fn accepts_incoming_connections() { init(); - let (server, handle) = ServerBuilder::default().bind().await.unwrap(); + let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap(); let server_task = tokio::spawn(server.serve()); // The connection succeeds. @@ -382,7 +411,7 @@ mod tests { async fn serve_yields_handler_error() { init(); - let (mut server, handle) = ServerBuilder::default().bind().await.unwrap(); + let (mut server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap(); // The connection is accepted, then immediately closed. let address = handle.address(); @@ -397,7 +426,6 @@ mod tests { handle.shutdown(ShutdownType::LameDuck); // Drain outstanding requests, encountering the error. - let error = server.serve().await.unwrap_err(); - assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof); + server.serve().await.unwrap_err(); } }