diff --git a/src/proto/core/frame.rs b/src/proto/core/frame.rs index 8471377..0327cde 100644 --- a/src/proto/core/frame.rs +++ b/src/proto/core/frame.rs @@ -168,12 +168,21 @@ where } } - pub async fn read(&mut self) -> io::Result { + /// Attempts to read the next frame from the underlying byte stream. + /// + /// Returns `Ok(Some(frame))` on success. + /// Returns `Ok(None)` if the stream has reached the end-of-file event. + /// + /// Returns an error if reading from the stream returned an error or if an + /// invalid frame was received. + pub async fn read(&mut self) -> io::Result> { loop { if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? { - return Ok(frame); + return Ok(Some(frame)); + } + if self.stream.read_buf(&mut self.read_buffer).await? == 0 { + return Ok(None); } - self.stream.read_buf(&mut self.read_buffer).await?; } } @@ -348,9 +357,9 @@ mod tests { let (stream, _peer_address) = listener.accept().await.unwrap(); let mut frame_stream = FrameStream::::new(stream); - assert_eq!(frame_stream.read().await.unwrap(), "ping"); + assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string())); frame_stream.write("pong").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), "ping"); + assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string())); frame_stream.write("pong").await.unwrap(); }); @@ -358,9 +367,9 @@ mod tests { let mut frame_stream = FrameStream::::new(stream); frame_stream.write("ping").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), "pong"); + assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string())); frame_stream.write("ping").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), "pong"); + assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string())); server_task.await.unwrap(); } @@ -374,7 +383,7 @@ mod tests { let (stream, _peer_address) = listener.accept().await.unwrap(); let mut frame_stream = FrameStream::>::new(stream); - assert_eq!(frame_stream.read().await.unwrap(), "ping"); + assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string())); frame_stream.write(&vec![0; 10 * 4096]).await.unwrap(); }); @@ -382,7 +391,7 @@ mod tests { let mut frame_stream = FrameStream::, str>::new(stream); frame_stream.write("ping").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), vec![0; 10 * 4096]); + assert_eq!(frame_stream.read().await.unwrap(), Some(vec![0; 10 * 4096])); server_task.await.unwrap(); } diff --git a/src/proto/server/client.rs b/src/proto/server/client.rs index dec892c..9604b01 100644 --- a/src/proto/server/client.rs +++ b/src/proto/server/client.rs @@ -3,6 +3,7 @@ use std::io; use futures::stream::{Stream, StreamExt}; +use log::{debug, info}; use thiserror::Error; use tokio::net; @@ -54,6 +55,9 @@ pub enum ClientLoginError { #[error("unexpected response: {0:?}")] UnexpectedResponse(ServerResponse), + #[error("unexpected end of file")] + UnexpectedEof, + #[error("i/o error: {0}")] IOError(#[from] io::Error), } @@ -104,21 +108,22 @@ impl Client { let response = self.frame_stream.read().await?; match response { - ServerResponse::LoginResponse(LoginResponse::LoginOk { + Some(ServerResponse::LoginResponse(LoginResponse::LoginOk { motd, ip, password_md5_opt, - }) => { - println!("Logged in successfully!"); - println!("Message Of The Day: {}", motd); - println!("Public IP address: {}", ip); - println!("Password MD5: {:?}", password_md5_opt); + })) => { + info!("Logged in successfully!"); + info!("Message Of The Day: {}", motd); + info!("Public IP address: {}", ip); + info!("Password MD5: {:?}", password_md5_opt); Ok(()) } - ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => { - Err(ClientLoginError::LoginFailed(reason)) - } - response @ _ => Err(ClientLoginError::UnexpectedResponse(response)), + Some(ServerResponse::LoginResponse(LoginResponse::LoginFail { + reason, + })) => Err(ClientLoginError::LoginFailed(reason)), + Some(response) => Err(ClientLoginError::UnexpectedResponse(response)), + None => Err(ClientLoginError::UnexpectedEof), } } @@ -135,14 +140,16 @@ impl Client { self.frame_stream.write(&request).await?; Ok(RunOnceResult::Continue) } else { - // Sender has been dropped. Shut down the write half of the stream. - self.frame_stream.shutdown().await?; + // Sender has been dropped. Ok(RunOnceResult::Break) } }, read_result = self.frame_stream.read() => { - let response = read_result?; - Ok(RunOnceResult::Response(response)) + match read_result? { + Some(response) => Ok(RunOnceResult::Response(response)), + // TODO: Consider returning error here. + None => Ok(RunOnceResult::Break), + } }, ) } @@ -164,6 +171,14 @@ impl Client { RunOnceResult::Response(response) => yield response, } } + + debug!("Client: shutting down outbound stream"); + self.frame_stream.shutdown().await?; + + // Drain the receiving end of the connection. + while let Some(response) = self.frame_stream.read().await? { + yield response; + } }) } } @@ -178,8 +193,15 @@ mod tests { use super::{Client, ClientOptions, Version}; + // Enable capturing logs in tests. + fn init() { + env_logger::builder().is_test(true).try_init().unwrap(); + } + #[tokio::test] async fn client_like_grpc() { + init(); + let (server, handle) = fake_server().await.unwrap(); let server_task = tokio::spawn(server.serve()); @@ -197,7 +219,6 @@ mod tests { yield ServerRequest::UserStatusRequest(UserStatusRequest { user_name: "bob".to_string(), }); - tokio::time::sleep(std::time::Duration::from_secs(1)).await; } }); diff --git a/src/proto/server/testing.rs b/src/proto/server/testing.rs index 809ad78..6cf7c51 100644 --- a/src/proto/server/testing.rs +++ b/src/proto/server/testing.rs @@ -3,6 +3,7 @@ use std::io; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use log::{info, warn}; use tokio::net::TcpListener; use tokio::sync::watch; @@ -24,16 +25,21 @@ impl Handler { async fn run(mut self) -> io::Result<()> { match self.frame_stream.read().await? { - ServerRequest::LoginRequest(request) => { - // TODO: Logging. - println!("Handler: Received login request: {:?}", request); + Some(ServerRequest::LoginRequest(request)) => { + info!("Handler: Received login request: {:?}", request); } - request => { + Some(request) => { return Err(io::Error::new( io::ErrorKind::InvalidData, format!("expected login request, got: {:?}", request), )); } + None => { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "expected login request".to_string(), + )); + } }; let response = ServerResponse::LoginResponse(LoginResponse::LoginOk { @@ -43,10 +49,12 @@ impl Handler { }); self.frame_stream.write(&response).await?; - loop { - let request = self.frame_stream.read().await?; - println!("Handler: received request: {:?}", request); + while let Some(request) = self.frame_stream.read().await? { + info!("Handler: received request: {:?}", request); } + + info!("Handler: client disconnecting, shutting down"); + Ok(()) } } @@ -58,10 +66,18 @@ struct GracefulHandler { impl GracefulHandler { async fn run(mut self) -> io::Result<()> { tokio::select!( - result = self.handler.run() => result, + result = self.handler.run() => { + if let Err(ref error) = result { + warn!("GracefulHandler: handler returned error {:?}", error); + } + result + }, // Ignore receive errors - if shutdown_rx's sender is dropped, we take // that as a signal to shut down too. - _ = self.shutdown_rx.changed() => Ok(()), + _ = self.shutdown_rx.changed() => { + info!("GracefulHandler: shutting down."); + Ok(()) + }, ) } } @@ -142,6 +158,8 @@ impl FakeServer { ); } + info!("FakeServer: shutting down"); + // TODO: pass results back instead through an mpsc channel. for task in handler_tasks { task.await??; @@ -172,14 +190,23 @@ mod tests { use super::fake_server; + // Enable capturing logs in tests. + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + #[tokio::test] async fn new_binds_to_localhost() { + init(); + let (server, _handle) = fake_server().await.unwrap(); assert!(server.address().unwrap().ip().is_loopback()); } #[tokio::test] async fn accepts_incoming_connections() { + init(); + let (server, handle) = fake_server().await.unwrap(); let server_task = tokio::spawn(server.serve());