Browse Source

Support graceful stream shutdown in client and server.

This requires amending the FrameStream::read() interface to report
when EOF has been reached.
wip
Titouan Rigoudy 4 years ago
parent
commit
27b65eec60
3 changed files with 90 additions and 33 deletions
  1. +18
    -9
      src/proto/core/frame.rs
  2. +36
    -15
      src/proto/server/client.rs
  3. +36
    -9
      src/proto/server/testing.rs

+ 18
- 9
src/proto/core/frame.rs View File

@ -168,12 +168,21 @@ where
}
}
pub async fn read(&mut self) -> io::Result<ReadFrame> {
/// Attempts to read the next frame from the underlying byte stream.
///
/// Returns `Ok(Some(frame))` on success.
/// Returns `Ok(None)` if the stream has reached the end-of-file event.
///
/// Returns an error if reading from the stream returned an error or if an
/// invalid frame was received.
pub async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(frame);
return Ok(Some(frame));
}
if self.stream.read_buf(&mut self.read_buffer).await? == 0 {
return Ok(None);
}
self.stream.read_buf(&mut self.read_buffer).await?;
}
}
@ -348,9 +357,9 @@ mod tests {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), "ping");
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "ping");
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
});
@ -358,9 +367,9 @@ mod tests {
let mut frame_stream = FrameStream::<String, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "pong");
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "pong");
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
server_task.await.unwrap();
}
@ -374,7 +383,7 @@ mod tests {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, Vec<u32>>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), "ping");
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write(&vec![0; 10 * 4096]).await.unwrap();
});
@ -382,7 +391,7 @@ mod tests {
let mut frame_stream = FrameStream::<Vec<u32>, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), vec![0; 10 * 4096]);
assert_eq!(frame_stream.read().await.unwrap(), Some(vec![0; 10 * 4096]));
server_task.await.unwrap();
}


+ 36
- 15
src/proto/server/client.rs View File

@ -3,6 +3,7 @@
use std::io;
use futures::stream::{Stream, StreamExt};
use log::{debug, info};
use thiserror::Error;
use tokio::net;
@ -54,6 +55,9 @@ pub enum ClientLoginError {
#[error("unexpected response: {0:?}")]
UnexpectedResponse(ServerResponse),
#[error("unexpected end of file")]
UnexpectedEof,
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
}
@ -104,21 +108,22 @@ impl Client {
let response = self.frame_stream.read().await?;
match response {
ServerResponse::LoginResponse(LoginResponse::LoginOk {
Some(ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd,
ip,
password_md5_opt,
}) => {
println!("Logged in successfully!");
println!("Message Of The Day: {}", motd);
println!("Public IP address: {}", ip);
println!("Password MD5: {:?}", password_md5_opt);
})) => {
info!("Logged in successfully!");
info!("Message Of The Day: {}", motd);
info!("Public IP address: {}", ip);
info!("Password MD5: {:?}", password_md5_opt);
Ok(())
}
ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => {
Err(ClientLoginError::LoginFailed(reason))
}
response @ _ => Err(ClientLoginError::UnexpectedResponse(response)),
Some(ServerResponse::LoginResponse(LoginResponse::LoginFail {
reason,
})) => Err(ClientLoginError::LoginFailed(reason)),
Some(response) => Err(ClientLoginError::UnexpectedResponse(response)),
None => Err(ClientLoginError::UnexpectedEof),
}
}
@ -135,14 +140,16 @@ impl Client {
self.frame_stream.write(&request).await?;
Ok(RunOnceResult::Continue)
} else {
// Sender has been dropped. Shut down the write half of the stream.
self.frame_stream.shutdown().await?;
// Sender has been dropped.
Ok(RunOnceResult::Break)
}
},
read_result = self.frame_stream.read() => {
let response = read_result?;
Ok(RunOnceResult::Response(response))
match read_result? {
Some(response) => Ok(RunOnceResult::Response(response)),
// TODO: Consider returning error here.
None => Ok(RunOnceResult::Break),
}
},
)
}
@ -164,6 +171,14 @@ impl Client {
RunOnceResult::Response(response) => yield response,
}
}
debug!("Client: shutting down outbound stream");
self.frame_stream.shutdown().await?;
// Drain the receiving end of the connection.
while let Some(response) = self.frame_stream.read().await? {
yield response;
}
})
}
}
@ -178,8 +193,15 @@ mod tests {
use super::{Client, ClientOptions, Version};
// Enable capturing logs in tests.
fn init() {
env_logger::builder().is_test(true).try_init().unwrap();
}
#[tokio::test]
async fn client_like_grpc() {
init();
let (server, handle) = fake_server().await.unwrap();
let server_task = tokio::spawn(server.serve());
@ -197,7 +219,6 @@ mod tests {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
});


+ 36
- 9
src/proto/server/testing.rs View File

@ -3,6 +3,7 @@
use std::io;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use log::{info, warn};
use tokio::net::TcpListener;
use tokio::sync::watch;
@ -24,16 +25,21 @@ impl Handler {
async fn run(mut self) -> io::Result<()> {
match self.frame_stream.read().await? {
ServerRequest::LoginRequest(request) => {
// TODO: Logging.
println!("Handler: Received login request: {:?}", request);
Some(ServerRequest::LoginRequest(request)) => {
info!("Handler: Received login request: {:?}", request);
}
request => {
Some(request) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
None => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"expected login request".to_string(),
));
}
};
let response = ServerResponse::LoginResponse(LoginResponse::LoginOk {
@ -43,10 +49,12 @@ impl Handler {
});
self.frame_stream.write(&response).await?;
loop {
let request = self.frame_stream.read().await?;
println!("Handler: received request: {:?}", request);
while let Some(request) = self.frame_stream.read().await? {
info!("Handler: received request: {:?}", request);
}
info!("Handler: client disconnecting, shutting down");
Ok(())
}
}
@ -58,10 +66,18 @@ struct GracefulHandler {
impl GracefulHandler {
async fn run(mut self) -> io::Result<()> {
tokio::select!(
result = self.handler.run() => result,
result = self.handler.run() => {
if let Err(ref error) = result {
warn!("GracefulHandler: handler returned error {:?}", error);
}
result
},
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = self.shutdown_rx.changed() => Ok(()),
_ = self.shutdown_rx.changed() => {
info!("GracefulHandler: shutting down.");
Ok(())
},
)
}
}
@ -142,6 +158,8 @@ impl FakeServer {
);
}
info!("FakeServer: shutting down");
// TODO: pass results back instead through an mpsc channel.
for task in handler_tasks {
task.await??;
@ -172,14 +190,23 @@ mod tests {
use super::fake_server;
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn new_binds_to_localhost() {
init();
let (server, _handle) = fake_server().await.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
}
#[tokio::test]
async fn accepts_incoming_connections() {
init();
let (server, handle) = fake_server().await.unwrap();
let server_task = tokio::spawn(server.serve());


Loading…
Cancel
Save