Browse Source

Introduce factory pattern for client handlers.

wip
Titouan Rigoudy 4 years ago
parent
commit
3d2eae62d1
2 changed files with 107 additions and 59 deletions
  1. +13
    -9
      proto/src/server/client.rs
  2. +94
    -50
      proto/src/server/testing.rs

+ 13
- 9
proto/src/server/client.rs View File

@ -124,7 +124,9 @@ mod tests {
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap, UserStatusHandler};
use crate::server::testing::{
ServerBuilder, ShutdownType, UserStatusHandlerFactory, UserStatusMap,
};
use crate::server::{
Credentials, ServerRequest, ServerResponse, UserStatusRequest,
UserStatusResponse,
@ -157,10 +159,11 @@ mod tests {
async fn login_success() {
init();
let (server, handle) = ServerBuilder::new(UserStatusHandler::default())
.bind()
.await
.expect("binding server");
let (server, handle) =
ServerBuilder::new(UserStatusHandlerFactory::default())
.bind()
.await
.expect("binding server");
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address())
@ -194,10 +197,11 @@ mod tests {
let mut user_status_map = UserStatusMap::default();
user_status_map.insert(response.clone());
let (server, handle) = ServerBuilder::new(UserStatusHandler::new(user_status_map))
.bind()
.await
.expect("binding server");
let (server, handle) =
ServerBuilder::new(UserStatusHandlerFactory::new(user_status_map))
.bind()
.await
.expect("binding server");
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address())


+ 94
- 50
proto/src/server/testing.rs View File

@ -6,7 +6,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use anyhow::{bail, Context};
use log::{debug, info, warn};
use log::{info, warn};
use parking_lot::Mutex;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpStream};
@ -37,9 +37,9 @@ impl UserStatusMap {
}
}
/// TODO: document.
pub trait Handler: Clone + Send {
/// TODO: document.
/// A handler for a single client connected to the server.
pub trait ClientHandler: Send {
/// Runs this handler against the given incoming stream of requests.
fn run(
self,
request_rx: mpsc::Receiver<ServerRequest>,
@ -47,20 +47,24 @@ pub trait Handler: Clone + Send {
) -> anyhow::Result<()>;
}
#[derive(Clone, Default)]
pub struct UserStatusHandler {
pub user_status_map: Arc<Mutex<UserStatusMap>>,
/// A factory for creating client handlers as needed.
pub trait ClientHandlerFactory {
/// The type of handler created by this factory.
type Handler: ClientHandler + 'static;
/// Creates a handler for a new client.
fn make(&self) -> Self::Handler;
}
impl UserStatusHandler {
pub fn new(user_status_map: UserStatusMap) -> Self {
Self {
user_status_map: Arc::new(Mutex::new(user_status_map)),
}
}
/// A handler that can serve responses to user status requests.
/// Responses are sent only for users who appear in `user_status_map`.
/// All other requests are ignored.
#[derive(Clone, Default)]
pub struct UserStatusHandler {
user_status_map: Arc<Mutex<UserStatusMap>>,
}
impl Handler for UserStatusHandler {
impl ClientHandler for UserStatusHandler {
fn run(
self,
mut request_rx: mpsc::Receiver<ServerRequest>,
@ -87,6 +91,36 @@ impl Handler for UserStatusHandler {
}
}
/// A factory for `UserStatusHandler`. All handlers share `user_status_map`.
#[derive(Default)]
pub struct UserStatusHandlerFactory {
/// The status map from which responses are served.
///
/// Testing code may wish to retain a copy of this in order to mutate the map
/// concurrently with requests being handled.
pub user_status_map: Arc<Mutex<UserStatusMap>>,
}
impl ClientHandlerFactory for UserStatusHandlerFactory {
type Handler = UserStatusHandler;
fn make(&self) -> Self::Handler {
Self::Handler {
user_status_map: self.user_status_map.clone(),
}
}
}
impl UserStatusHandlerFactory {
/// Convenience function to create a new factory wrapping the given `map`.
pub fn new(map: UserStatusMap) -> Self {
Self {
user_status_map: Arc::new(Mutex::new(map)),
}
}
}
/// Handles the login exchange, then hands off the connection to `inner`.
struct LoginHandler<H> {
reader: FrameReader<ServerRequest, OwnedReadHalf>,
writer: FrameWriter<ServerResponse, OwnedWriteHalf>,
@ -94,7 +128,7 @@ struct LoginHandler<H> {
inner: H,
}
impl<H: Handler + 'static> LoginHandler<H> {
impl<H: ClientHandler + 'static> LoginHandler<H> {
fn ipv4_address(&self) -> Ipv4Addr {
match self.peer_address.ip() {
IpAddr::V4(ipv4_addr) => ipv4_addr,
@ -102,14 +136,6 @@ impl<H: Handler + 'static> LoginHandler<H> {
}
}
async fn send_response(
&mut self,
response: &ServerResponse,
) -> io::Result<()> {
debug!("Sending response: {:?}", response);
self.writer.write(response).await
}
async fn handle_login(&mut self) -> anyhow::Result<()> {
match self.reader.read().await.context("reading first request")? {
Some(ServerRequest::LoginRequest(request)) => {
@ -129,7 +155,8 @@ impl<H: Handler + 'static> LoginHandler<H> {
password_md5_opt: None,
});
self
.send_response(&response)
.writer
.write(&response)
.await
.context("sending login response")
}
@ -142,27 +169,30 @@ impl<H: Handler + 'static> LoginHandler<H> {
let (request_tx, request_rx) = mpsc::channel(100);
let (response_tx, response_rx) = mpsc::channel(100);
let worker_task = tokio::spawn(async move {
worker.run(request_tx, response_rx).await
});
let worker_task =
tokio::spawn(async move { worker.run(request_tx, response_rx).await });
let inner = self.inner;
tokio::task::spawn_blocking(move || inner.run(request_rx, response_tx))
.await
.context("joining handler")?
.context("running handler")?;
worker_task.await.context("joining worker")?.context("running worker")?;
worker_task
.await
.context("joining worker")?
.context("running worker")?;
info!("Client disconnecting, shutting down");
Ok(())
}
}
/// Wraps a `LoginHandler` and shuts down gracefully upon receipt of a signal.
struct GracefulHandler<H> {
handler: LoginHandler<H>,
shutdown_rx: watch::Receiver<()>,
}
impl<H: Handler + 'static> GracefulHandler<H> {
impl<H: ClientHandler + 'static> GracefulHandler<H> {
async fn run(mut self) -> anyhow::Result<()> {
tokio::select!(
result = self.handler.run() => {
@ -181,12 +211,16 @@ impl<H: Handler + 'static> GracefulHandler<H> {
}
}
/// Wraps `GracefulHandler` and sends the result of `run()` on a channel.
///
/// Defined separately from `GracefulHandler` to enable partial moves: we want
/// to move `self.handler` separely from `self.result_tx`.
struct SenderHandler<H> {
handler: GracefulHandler<H>,
result_tx: mpsc::Sender<anyhow::Result<()>>,
}
impl<H: Handler + 'static> SenderHandler<H> {
impl<H: ClientHandler + 'static> SenderHandler<H> {
async fn run(self) {
let result = self.handler.run().await;
let _ = self.result_tx.send(result).await;
@ -195,20 +229,20 @@ impl<H: Handler + 'static> SenderHandler<H> {
/// A builder for Server instances.
#[derive(Default)]
pub struct ServerBuilder<H> {
handler: H,
pub struct ServerBuilder<F> {
factory: F,
}
impl<H: Handler> ServerBuilder<H> {
impl<F: ClientHandlerFactory> ServerBuilder<F> {
/// `handler` will be used to handle incoming connections.
pub fn new(handler: H) -> Self {
ServerBuilder { handler }
pub fn new(factory: F) -> Self {
ServerBuilder { factory }
}
/// Binds to a localhost port, then returns a server and its handle.
pub async fn bind(self) -> io::Result<(Server<H>, ServerHandle)> {
let listener = TcpListener::bind("localhost:0").await?;
let address = listener.local_addr()?;
pub async fn bind(self) -> anyhow::Result<(Server<F>, ServerHandle)> {
let listener = TcpListener::bind("localhost:0").await.context("binding")?;
let address = listener.local_addr().context("getting local address")?;
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(());
@ -222,7 +256,7 @@ impl<H: Handler> ServerBuilder<H> {
handler_shutdown_rx,
result_tx,
result_rx,
handler: self.handler,
factory: self.factory,
},
ServerHandle {
shutdown_tx,
@ -242,7 +276,7 @@ pub enum ShutdownType {
}
/// A simple server for connecting to in tests.
pub struct Server<H> {
pub struct Server<F> {
// Listener for new connections.
listener: TcpListener,
@ -257,8 +291,8 @@ pub struct Server<H> {
result_tx: mpsc::Sender<anyhow::Result<()>>,
result_rx: mpsc::Receiver<anyhow::Result<()>>,
// Handler used for incoming connections.
handler: H,
// Factory used to create handlers for incoming connections.
factory: F,
}
/// Allows interacting with a running `Server`.
@ -281,7 +315,7 @@ impl ServerHandle {
}
}
impl<H: Handler + 'static> Server<H> {
impl<F: ClientHandlerFactory> Server<F> {
/// Returns the address to which this server is bound.
/// This is always localhost and a random port chosen by the OS.
pub fn address(&self) -> io::Result<SocketAddr> {
@ -298,7 +332,7 @@ impl<H: Handler + 'static> Server<H> {
reader: FrameReader::new(read_half),
writer: FrameWriter::new(write_half),
peer_address,
inner: self.handler.clone(),
inner: self.factory.make(),
},
shutdown_rx: self.handler_shutdown_rx.clone(),
},
@ -368,11 +402,9 @@ impl<H: Handler + 'static> Server<H> {
#[cfg(test)]
mod tests {
use std::io;
use tokio::net::TcpStream;
use super::{ServerBuilder, ShutdownType, UserStatusHandler};
use super::{ServerBuilder, ShutdownType, UserStatusHandlerFactory};
// Enable capturing logs in tests.
fn init() {
@ -383,7 +415,11 @@ mod tests {
async fn new_binds_to_localhost() {
init();
let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
let (server, handle) =
ServerBuilder::new(UserStatusHandlerFactory::default())
.bind()
.await
.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
assert_eq!(server.address().unwrap(), handle.address());
}
@ -392,7 +428,11 @@ mod tests {
async fn accepts_incoming_connections() {
init();
let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
let (server, handle) =
ServerBuilder::new(UserStatusHandlerFactory::default())
.bind()
.await
.unwrap();
let server_task = tokio::spawn(server.serve());
// The connection succeeds.
@ -411,7 +451,11 @@ mod tests {
async fn serve_yields_handler_error() {
init();
let (mut server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
let (mut server, handle) =
ServerBuilder::new(UserStatusHandlerFactory::default())
.bind()
.await
.unwrap();
// The connection is accepted, then immediately closed.
let address = handle.address();


Loading…
Cancel
Save