diff --git a/src/proto/core/frame.rs b/src/proto/core/frame.rs index 3ff1aa3..8471377 100644 --- a/src/proto/core/frame.rs +++ b/src/proto/core/frame.rs @@ -183,7 +183,9 @@ where self.stream.write_all(bytes.as_ref()).await } - pub fn get_mut_stream(&mut self) -> &mut TcpStream { &mut self.stream } + pub async fn shutdown(&mut self) -> io::Result<()> { + self.stream.shutdown().await + } } mod tests { diff --git a/src/proto/server/client.rs b/src/proto/server/client.rs index 2116039..9e0a5ff 100644 --- a/src/proto/server/client.rs +++ b/src/proto/server/client.rs @@ -4,7 +4,6 @@ use std::io; use futures::stream::{Stream, StreamExt}; use thiserror::Error; -use tokio::io::AsyncWriteExt; use tokio::net; use crate::proto::core::frame::FrameStream; @@ -137,7 +136,7 @@ impl Client { Ok(RunOnceResult::Continue) } else { // Sender has been dropped. Shut down the write half of the stream. - self.frame_stream.get_mut_stream().shutdown().await?; + self.frame_stream.shutdown().await?; Ok(RunOnceResult::Break) } }, @@ -173,6 +172,7 @@ impl Client { mod tests { use futures::stream::StreamExt; use tokio::net; + use tokio::sync::watch; use crate::proto::server::testing::FakeServer; use crate::proto::server::*; @@ -181,10 +181,11 @@ mod tests { #[tokio::test] async fn client_like_grpc() { - // TODO: Check that server does not crash. - let mut server = FakeServer::new().await.unwrap(); + let server = FakeServer::new().await.unwrap(); let address = server.address().unwrap(); - let _server_task = tokio::spawn(async move { server.run().await.unwrap() }); + + let (shutdown_tx, shutdown_rx) = watch::channel(()); + let server_task = tokio::spawn(server.serve(shutdown_rx)); let stream = net::TcpStream::connect(address).await.unwrap(); @@ -196,14 +197,20 @@ mod tests { let client = Client::login(stream, options).await.unwrap(); let outbound = Box::pin(async_stream::stream! { - yield ServerRequest::UserStatusRequest(UserStatusRequest { - user_name: "bob".to_string(), - }); + for _ in 0..2 { + yield ServerRequest::UserStatusRequest(UserStatusRequest { + user_name: "bob".to_string(), + }); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } }); let mut inbound = client.run(outbound); while let Some(result) = inbound.next().await { let _ = dbg!(result); } + + drop(shutdown_tx); + server_task.await.unwrap().unwrap(); } } diff --git a/src/proto/server/testing.rs b/src/proto/server/testing.rs index 6340a8e..dfb838e 100644 --- a/src/proto/server/testing.rs +++ b/src/proto/server/testing.rs @@ -3,52 +3,66 @@ use std::io; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::TcpListener; +use tokio::sync::watch; use crate::proto::core::frame::FrameStream; use crate::proto::server::{LoginResponse, ServerRequest, ServerResponse}; -async fn process( - stream: TcpStream, +struct Handler { + frame_stream: FrameStream, peer_address: SocketAddr, -) -> io::Result<()> { - let mut frame_stream = - FrameStream::::new(stream); - - match frame_stream.read().await? { - ServerRequest::LoginRequest(request) => { - // TODO: Logging. - println!("FakeServer: Received login request: {:?}", request); - } - request => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("expected login request, got: {:?}", request), - )); +} + +impl Handler { + fn ipv4_address(&self) -> Ipv4Addr { + match self.peer_address.ip() { + IpAddr::V4(ipv4_addr) => ipv4_addr, + IpAddr::V6(_) => Ipv4Addr::UNSPECIFIED, } - }; - - let ipv4_addr = match peer_address.ip() { - IpAddr::V4(ipv4_addr) => ipv4_addr, - IpAddr::V6(ipv6_addr) => { - println!( - "FakeServer: peer connected from IPv6 address {}, echoing 0.0.0.0", - ipv6_addr - ); - Ipv4Addr::UNSPECIFIED + } + + async fn run(mut self) -> io::Result<()> { + match self.frame_stream.read().await? { + ServerRequest::LoginRequest(request) => { + // TODO: Logging. + println!("Handler: Received login request: {:?}", request); + } + request => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("expected login request, got: {:?}", request), + )); + } + }; + + let response = ServerResponse::LoginResponse(LoginResponse::LoginOk { + motd: "hi there".to_string(), + ip: self.ipv4_address(), + password_md5_opt: None, + }); + self.frame_stream.write(&response).await?; + + loop { + let request = self.frame_stream.read().await?; + println!("Handler: received request: {:?}", request); } - }; - - let response = ServerResponse::LoginResponse(LoginResponse::LoginOk { - motd: "hi there".to_string(), - ip: ipv4_addr, - password_md5_opt: None, - }); - frame_stream.write(&response).await?; - - loop { - let request = frame_stream.read().await?; - println!("FakeServer: received request: {:?}", request); + } +} + +struct GracefulHandler { + handler: Handler, + shutdown_rx: watch::Receiver<()>, +} + +impl GracefulHandler { + async fn run(mut self) -> io::Result<()> { + tokio::select!( + result = self.handler.run() => result, + // Ignore receive errors - if shutdown_rx's sender is dropped, we take + // that as a signal to shut down too. + _ = self.shutdown_rx.changed() => Ok(()), + ) } } @@ -71,17 +85,55 @@ impl FakeServer { } /// Runs the server: accepts incoming connections and responds to requests. - pub async fn run(&mut self) -> io::Result<()> { + /// + /// Attempts to shut down when `shutdown_rx` receives a change, or when its + /// sender is dropped. + /// + /// Returns an error if: + /// + /// - an error was encountered while listening + /// - an error was encountered while serving a request + /// + pub async fn serve( + self, + mut shutdown_rx: watch::Receiver<()>, + ) -> io::Result<()> { + let mut handler_tasks = vec![]; + loop { - let (socket, peer_address) = self.listener.accept().await?; - tokio::spawn(async move { process(socket, peer_address).await }); + tokio::select!( + result = self.listener.accept() => { + let (stream, peer_address) = result?; + + let handler = GracefulHandler { + handler: Handler { + frame_stream: FrameStream::new(stream), + peer_address, + }, + shutdown_rx: shutdown_rx.clone(), + }; + + handler_tasks.push(tokio::spawn(handler.run())); + }, + // Ignore receive errors - if shutdown_rx's sender is dropped, we take + // that as a signal to shut down too. + _ = shutdown_rx.changed() => break, + ); + } + + // TODO: pass results back instead through an mpsc channel. + for task in handler_tasks { + task.await??; } + + Ok(()) } } #[cfg(test)] mod tests { use tokio::net::TcpStream; + use tokio::sync::watch; use super::FakeServer; @@ -93,11 +145,16 @@ mod tests { #[tokio::test] async fn accepts_incoming_connections() { - let mut server = FakeServer::new().await.unwrap(); + let server = FakeServer::new().await.unwrap(); let address = server.address().unwrap(); - tokio::spawn(async move { server.run().await.unwrap() }); + + let (shutdown_tx, shutdown_rx) = watch::channel(()); + let server_task = tokio::spawn(server.serve(shutdown_rx)); // The connection succeeds. let _ = TcpStream::connect(address).await.unwrap(); + + drop(shutdown_tx); + server_task.await.unwrap().unwrap(); } }