|
|
|
@ -6,7 +6,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; |
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
use anyhow::{bail, Context};
|
|
|
|
use log::{debug, info, warn};
|
|
|
|
use log::{info, warn};
|
|
|
|
use parking_lot::Mutex;
|
|
|
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
@ -37,9 +37,9 @@ impl UserStatusMap { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// TODO: document.
|
|
|
|
pub trait Handler: Clone + Send {
|
|
|
|
/// TODO: document.
|
|
|
|
/// A handler for a single client connected to the server.
|
|
|
|
pub trait ClientHandler: Send {
|
|
|
|
/// Runs this handler against the given incoming stream of requests.
|
|
|
|
fn run(
|
|
|
|
self,
|
|
|
|
request_rx: mpsc::Receiver<ServerRequest>,
|
|
|
|
@ -47,20 +47,24 @@ pub trait Handler: Clone + Send { |
|
|
|
) -> anyhow::Result<()>;
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone, Default)]
|
|
|
|
pub struct UserStatusHandler {
|
|
|
|
pub user_status_map: Arc<Mutex<UserStatusMap>>,
|
|
|
|
/// A factory for creating client handlers as needed.
|
|
|
|
pub trait ClientHandlerFactory {
|
|
|
|
/// The type of handler created by this factory.
|
|
|
|
type Handler: ClientHandler + 'static;
|
|
|
|
|
|
|
|
/// Creates a handler for a new client.
|
|
|
|
fn make(&self) -> Self::Handler;
|
|
|
|
}
|
|
|
|
|
|
|
|
impl UserStatusHandler {
|
|
|
|
pub fn new(user_status_map: UserStatusMap) -> Self {
|
|
|
|
Self {
|
|
|
|
user_status_map: Arc::new(Mutex::new(user_status_map)),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
/// A handler that can serve responses to user status requests.
|
|
|
|
/// Responses are sent only for users who appear in `user_status_map`.
|
|
|
|
/// All other requests are ignored.
|
|
|
|
#[derive(Clone, Default)]
|
|
|
|
pub struct UserStatusHandler {
|
|
|
|
user_status_map: Arc<Mutex<UserStatusMap>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Handler for UserStatusHandler {
|
|
|
|
impl ClientHandler for UserStatusHandler {
|
|
|
|
fn run(
|
|
|
|
self,
|
|
|
|
mut request_rx: mpsc::Receiver<ServerRequest>,
|
|
|
|
@ -87,6 +91,36 @@ impl Handler for UserStatusHandler { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// A factory for `UserStatusHandler`. All handlers share `user_status_map`.
|
|
|
|
#[derive(Default)]
|
|
|
|
pub struct UserStatusHandlerFactory {
|
|
|
|
/// The status map from which responses are served.
|
|
|
|
///
|
|
|
|
/// Testing code may wish to retain a copy of this in order to mutate the map
|
|
|
|
/// concurrently with requests being handled.
|
|
|
|
pub user_status_map: Arc<Mutex<UserStatusMap>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl ClientHandlerFactory for UserStatusHandlerFactory {
|
|
|
|
type Handler = UserStatusHandler;
|
|
|
|
|
|
|
|
fn make(&self) -> Self::Handler {
|
|
|
|
Self::Handler {
|
|
|
|
user_status_map: self.user_status_map.clone(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl UserStatusHandlerFactory {
|
|
|
|
/// Convenience function to create a new factory wrapping the given `map`.
|
|
|
|
pub fn new(map: UserStatusMap) -> Self {
|
|
|
|
Self {
|
|
|
|
user_status_map: Arc::new(Mutex::new(map)),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Handles the login exchange, then hands off the connection to `inner`.
|
|
|
|
struct LoginHandler<H> {
|
|
|
|
reader: FrameReader<ServerRequest, OwnedReadHalf>,
|
|
|
|
writer: FrameWriter<ServerResponse, OwnedWriteHalf>,
|
|
|
|
@ -94,7 +128,7 @@ struct LoginHandler<H> { |
|
|
|
inner: H,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<H: Handler + 'static> LoginHandler<H> {
|
|
|
|
impl<H: ClientHandler + 'static> LoginHandler<H> {
|
|
|
|
fn ipv4_address(&self) -> Ipv4Addr {
|
|
|
|
match self.peer_address.ip() {
|
|
|
|
IpAddr::V4(ipv4_addr) => ipv4_addr,
|
|
|
|
@ -102,14 +136,6 @@ impl<H: Handler + 'static> LoginHandler<H> { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn send_response(
|
|
|
|
&mut self,
|
|
|
|
response: &ServerResponse,
|
|
|
|
) -> io::Result<()> {
|
|
|
|
debug!("Sending response: {:?}", response);
|
|
|
|
self.writer.write(response).await
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn handle_login(&mut self) -> anyhow::Result<()> {
|
|
|
|
match self.reader.read().await.context("reading first request")? {
|
|
|
|
Some(ServerRequest::LoginRequest(request)) => {
|
|
|
|
@ -129,7 +155,8 @@ impl<H: Handler + 'static> LoginHandler<H> { |
|
|
|
password_md5_opt: None,
|
|
|
|
});
|
|
|
|
self
|
|
|
|
.send_response(&response)
|
|
|
|
.writer
|
|
|
|
.write(&response)
|
|
|
|
.await
|
|
|
|
.context("sending login response")
|
|
|
|
}
|
|
|
|
@ -142,27 +169,30 @@ impl<H: Handler + 'static> LoginHandler<H> { |
|
|
|
let (request_tx, request_rx) = mpsc::channel(100);
|
|
|
|
let (response_tx, response_rx) = mpsc::channel(100);
|
|
|
|
|
|
|
|
let worker_task = tokio::spawn(async move {
|
|
|
|
worker.run(request_tx, response_rx).await
|
|
|
|
});
|
|
|
|
let worker_task =
|
|
|
|
tokio::spawn(async move { worker.run(request_tx, response_rx).await });
|
|
|
|
let inner = self.inner;
|
|
|
|
tokio::task::spawn_blocking(move || inner.run(request_rx, response_tx))
|
|
|
|
.await
|
|
|
|
.context("joining handler")?
|
|
|
|
.context("running handler")?;
|
|
|
|
worker_task.await.context("joining worker")?.context("running worker")?;
|
|
|
|
worker_task
|
|
|
|
.await
|
|
|
|
.context("joining worker")?
|
|
|
|
.context("running worker")?;
|
|
|
|
|
|
|
|
info!("Client disconnecting, shutting down");
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Wraps a `LoginHandler` and shuts down gracefully upon receipt of a signal.
|
|
|
|
struct GracefulHandler<H> {
|
|
|
|
handler: LoginHandler<H>,
|
|
|
|
shutdown_rx: watch::Receiver<()>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<H: Handler + 'static> GracefulHandler<H> {
|
|
|
|
impl<H: ClientHandler + 'static> GracefulHandler<H> {
|
|
|
|
async fn run(mut self) -> anyhow::Result<()> {
|
|
|
|
tokio::select!(
|
|
|
|
result = self.handler.run() => {
|
|
|
|
@ -181,12 +211,16 @@ impl<H: Handler + 'static> GracefulHandler<H> { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Wraps `GracefulHandler` and sends the result of `run()` on a channel.
|
|
|
|
///
|
|
|
|
/// Defined separately from `GracefulHandler` to enable partial moves: we want
|
|
|
|
/// to move `self.handler` separely from `self.result_tx`.
|
|
|
|
struct SenderHandler<H> {
|
|
|
|
handler: GracefulHandler<H>,
|
|
|
|
result_tx: mpsc::Sender<anyhow::Result<()>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<H: Handler + 'static> SenderHandler<H> {
|
|
|
|
impl<H: ClientHandler + 'static> SenderHandler<H> {
|
|
|
|
async fn run(self) {
|
|
|
|
let result = self.handler.run().await;
|
|
|
|
let _ = self.result_tx.send(result).await;
|
|
|
|
@ -195,20 +229,20 @@ impl<H: Handler + 'static> SenderHandler<H> { |
|
|
|
|
|
|
|
/// A builder for Server instances.
|
|
|
|
#[derive(Default)]
|
|
|
|
pub struct ServerBuilder<H> {
|
|
|
|
handler: H,
|
|
|
|
pub struct ServerBuilder<F> {
|
|
|
|
factory: F,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<H: Handler> ServerBuilder<H> {
|
|
|
|
impl<F: ClientHandlerFactory> ServerBuilder<F> {
|
|
|
|
/// `handler` will be used to handle incoming connections.
|
|
|
|
pub fn new(handler: H) -> Self {
|
|
|
|
ServerBuilder { handler }
|
|
|
|
pub fn new(factory: F) -> Self {
|
|
|
|
ServerBuilder { factory }
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Binds to a localhost port, then returns a server and its handle.
|
|
|
|
pub async fn bind(self) -> io::Result<(Server<H>, ServerHandle)> {
|
|
|
|
let listener = TcpListener::bind("localhost:0").await?;
|
|
|
|
let address = listener.local_addr()?;
|
|
|
|
pub async fn bind(self) -> anyhow::Result<(Server<F>, ServerHandle)> {
|
|
|
|
let listener = TcpListener::bind("localhost:0").await.context("binding")?;
|
|
|
|
let address = listener.local_addr().context("getting local address")?;
|
|
|
|
|
|
|
|
let (shutdown_tx, shutdown_rx) = oneshot::channel();
|
|
|
|
let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(());
|
|
|
|
@ -222,7 +256,7 @@ impl<H: Handler> ServerBuilder<H> { |
|
|
|
handler_shutdown_rx,
|
|
|
|
result_tx,
|
|
|
|
result_rx,
|
|
|
|
handler: self.handler,
|
|
|
|
factory: self.factory,
|
|
|
|
},
|
|
|
|
ServerHandle {
|
|
|
|
shutdown_tx,
|
|
|
|
@ -242,7 +276,7 @@ pub enum ShutdownType { |
|
|
|
}
|
|
|
|
|
|
|
|
/// A simple server for connecting to in tests.
|
|
|
|
pub struct Server<H> {
|
|
|
|
pub struct Server<F> {
|
|
|
|
// Listener for new connections.
|
|
|
|
listener: TcpListener,
|
|
|
|
|
|
|
|
@ -257,8 +291,8 @@ pub struct Server<H> { |
|
|
|
result_tx: mpsc::Sender<anyhow::Result<()>>,
|
|
|
|
result_rx: mpsc::Receiver<anyhow::Result<()>>,
|
|
|
|
|
|
|
|
// Handler used for incoming connections.
|
|
|
|
handler: H,
|
|
|
|
// Factory used to create handlers for incoming connections.
|
|
|
|
factory: F,
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Allows interacting with a running `Server`.
|
|
|
|
@ -281,7 +315,7 @@ impl ServerHandle { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<H: Handler + 'static> Server<H> {
|
|
|
|
impl<F: ClientHandlerFactory> Server<F> {
|
|
|
|
/// Returns the address to which this server is bound.
|
|
|
|
/// This is always localhost and a random port chosen by the OS.
|
|
|
|
pub fn address(&self) -> io::Result<SocketAddr> {
|
|
|
|
@ -298,7 +332,7 @@ impl<H: Handler + 'static> Server<H> { |
|
|
|
reader: FrameReader::new(read_half),
|
|
|
|
writer: FrameWriter::new(write_half),
|
|
|
|
peer_address,
|
|
|
|
inner: self.handler.clone(),
|
|
|
|
inner: self.factory.make(),
|
|
|
|
},
|
|
|
|
shutdown_rx: self.handler_shutdown_rx.clone(),
|
|
|
|
},
|
|
|
|
@ -368,11 +402,9 @@ impl<H: Handler + 'static> Server<H> { |
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
mod tests {
|
|
|
|
use std::io;
|
|
|
|
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
|
|
|
|
use super::{ServerBuilder, ShutdownType, UserStatusHandler};
|
|
|
|
use super::{ServerBuilder, ShutdownType, UserStatusHandlerFactory};
|
|
|
|
|
|
|
|
// Enable capturing logs in tests.
|
|
|
|
fn init() {
|
|
|
|
@ -383,7 +415,11 @@ mod tests { |
|
|
|
async fn new_binds_to_localhost() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
|
|
|
|
let (server, handle) =
|
|
|
|
ServerBuilder::new(UserStatusHandlerFactory::default())
|
|
|
|
.bind()
|
|
|
|
.await
|
|
|
|
.unwrap();
|
|
|
|
assert!(server.address().unwrap().ip().is_loopback());
|
|
|
|
assert_eq!(server.address().unwrap(), handle.address());
|
|
|
|
}
|
|
|
|
@ -392,7 +428,11 @@ mod tests { |
|
|
|
async fn accepts_incoming_connections() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
|
|
|
|
let (server, handle) =
|
|
|
|
ServerBuilder::new(UserStatusHandlerFactory::default())
|
|
|
|
.bind()
|
|
|
|
.await
|
|
|
|
.unwrap();
|
|
|
|
let server_task = tokio::spawn(server.serve());
|
|
|
|
|
|
|
|
// The connection succeeds.
|
|
|
|
@ -411,7 +451,11 @@ mod tests { |
|
|
|
async fn serve_yields_handler_error() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (mut server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
|
|
|
|
let (mut server, handle) =
|
|
|
|
ServerBuilder::new(UserStatusHandlerFactory::default())
|
|
|
|
.bind()
|
|
|
|
.await
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
// The connection is accepted, then immediately closed.
|
|
|
|
let address = handle.address();
|
|
|
|
|