Browse Source

Implement FakeServer graceful shutdown.

wip
Titouan Rigoudy 4 years ago
parent
commit
51464e2251
3 changed files with 119 additions and 53 deletions
  1. +3
    -1
      src/proto/core/frame.rs
  2. +15
    -8
      src/proto/server/client.rs
  3. +101
    -44
      src/proto/server/testing.rs

+ 3
- 1
src/proto/core/frame.rs View File

@ -183,7 +183,9 @@ where
self.stream.write_all(bytes.as_ref()).await
}
pub fn get_mut_stream(&mut self) -> &mut TcpStream { &mut self.stream }
pub async fn shutdown(&mut self) -> io::Result<()> {
self.stream.shutdown().await
}
}
mod tests {


+ 15
- 8
src/proto/server/client.rs View File

@ -4,7 +4,6 @@ use std::io;
use futures::stream::{Stream, StreamExt};
use thiserror::Error;
use tokio::io::AsyncWriteExt;
use tokio::net;
use crate::proto::core::frame::FrameStream;
@ -137,7 +136,7 @@ impl Client {
Ok(RunOnceResult::Continue)
} else {
// Sender has been dropped. Shut down the write half of the stream.
self.frame_stream.get_mut_stream().shutdown().await?;
self.frame_stream.shutdown().await?;
Ok(RunOnceResult::Break)
}
},
@ -173,6 +172,7 @@ impl Client {
mod tests {
use futures::stream::StreamExt;
use tokio::net;
use tokio::sync::watch;
use crate::proto::server::testing::FakeServer;
use crate::proto::server::*;
@ -181,10 +181,11 @@ mod tests {
#[tokio::test]
async fn client_like_grpc() {
// TODO: Check that server does not crash.
let mut server = FakeServer::new().await.unwrap();
let server = FakeServer::new().await.unwrap();
let address = server.address().unwrap();
let _server_task = tokio::spawn(async move { server.run().await.unwrap() });
let (shutdown_tx, shutdown_rx) = watch::channel(());
let server_task = tokio::spawn(server.serve(shutdown_rx));
let stream = net::TcpStream::connect(address).await.unwrap();
@ -196,14 +197,20 @@ mod tests {
let client = Client::login(stream, options).await.unwrap();
let outbound = Box::pin(async_stream::stream! {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
for _ in 0..2 {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
});
let mut inbound = client.run(outbound);
while let Some(result) = inbound.next().await {
let _ = dbg!(result);
}
drop(shutdown_tx);
server_task.await.unwrap().unwrap();
}
}

+ 101
- 44
src/proto/server/testing.rs View File

@ -3,52 +3,66 @@
use std::io;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpListener;
use tokio::sync::watch;
use crate::proto::core::frame::FrameStream;
use crate::proto::server::{LoginResponse, ServerRequest, ServerResponse};
async fn process(
stream: TcpStream,
struct Handler {
frame_stream: FrameStream<ServerRequest, ServerResponse>,
peer_address: SocketAddr,
) -> io::Result<()> {
let mut frame_stream =
FrameStream::<ServerRequest, ServerResponse>::new(stream);
match frame_stream.read().await? {
ServerRequest::LoginRequest(request) => {
// TODO: Logging.
println!("FakeServer: Received login request: {:?}", request);
}
request => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
impl Handler {
fn ipv4_address(&self) -> Ipv4Addr {
match self.peer_address.ip() {
IpAddr::V4(ipv4_addr) => ipv4_addr,
IpAddr::V6(_) => Ipv4Addr::UNSPECIFIED,
}
};
let ipv4_addr = match peer_address.ip() {
IpAddr::V4(ipv4_addr) => ipv4_addr,
IpAddr::V6(ipv6_addr) => {
println!(
"FakeServer: peer connected from IPv6 address {}, echoing 0.0.0.0",
ipv6_addr
);
Ipv4Addr::UNSPECIFIED
}
async fn run(mut self) -> io::Result<()> {
match self.frame_stream.read().await? {
ServerRequest::LoginRequest(request) => {
// TODO: Logging.
println!("Handler: Received login request: {:?}", request);
}
request => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
};
let response = ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd: "hi there".to_string(),
ip: self.ipv4_address(),
password_md5_opt: None,
});
self.frame_stream.write(&response).await?;
loop {
let request = self.frame_stream.read().await?;
println!("Handler: received request: {:?}", request);
}
};
let response = ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd: "hi there".to_string(),
ip: ipv4_addr,
password_md5_opt: None,
});
frame_stream.write(&response).await?;
loop {
let request = frame_stream.read().await?;
println!("FakeServer: received request: {:?}", request);
}
}
struct GracefulHandler {
handler: Handler,
shutdown_rx: watch::Receiver<()>,
}
impl GracefulHandler {
async fn run(mut self) -> io::Result<()> {
tokio::select!(
result = self.handler.run() => result,
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = self.shutdown_rx.changed() => Ok(()),
)
}
}
@ -71,17 +85,55 @@ impl FakeServer {
}
/// Runs the server: accepts incoming connections and responds to requests.
pub async fn run(&mut self) -> io::Result<()> {
///
/// Attempts to shut down when `shutdown_rx` receives a change, or when its
/// sender is dropped.
///
/// Returns an error if:
///
/// - an error was encountered while listening
/// - an error was encountered while serving a request
///
pub async fn serve(
self,
mut shutdown_rx: watch::Receiver<()>,
) -> io::Result<()> {
let mut handler_tasks = vec![];
loop {
let (socket, peer_address) = self.listener.accept().await?;
tokio::spawn(async move { process(socket, peer_address).await });
tokio::select!(
result = self.listener.accept() => {
let (stream, peer_address) = result?;
let handler = GracefulHandler {
handler: Handler {
frame_stream: FrameStream::new(stream),
peer_address,
},
shutdown_rx: shutdown_rx.clone(),
};
handler_tasks.push(tokio::spawn(handler.run()));
},
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = shutdown_rx.changed() => break,
);
}
// TODO: pass results back instead through an mpsc channel.
for task in handler_tasks {
task.await??;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use tokio::net::TcpStream;
use tokio::sync::watch;
use super::FakeServer;
@ -93,11 +145,16 @@ mod tests {
#[tokio::test]
async fn accepts_incoming_connections() {
let mut server = FakeServer::new().await.unwrap();
let server = FakeServer::new().await.unwrap();
let address = server.address().unwrap();
tokio::spawn(async move { server.run().await.unwrap() });
let (shutdown_tx, shutdown_rx) = watch::channel(());
let server_task = tokio::spawn(server.serve(shutdown_rx));
// The connection succeeds.
let _ = TcpStream::connect(address).await.unwrap();
drop(shutdown_tx);
server_task.await.unwrap().unwrap();
}
}

Loading…
Cancel
Save