|
|
|
@ -5,17 +5,16 @@ use std::io; |
|
|
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
use anyhow::{bail, Context};
|
|
|
|
use log::{debug, info, warn};
|
|
|
|
use parking_lot::Mutex;
|
|
|
|
use tokio::net::{
|
|
|
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
|
|
|
TcpListener, TcpStream,
|
|
|
|
};
|
|
|
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
use tokio::sync::mpsc;
|
|
|
|
use tokio::sync::oneshot;
|
|
|
|
use tokio::sync::watch;
|
|
|
|
|
|
|
|
use crate::core::{FrameReader, FrameWriter};
|
|
|
|
use crate::core::{FrameReader, FrameWriter, Worker};
|
|
|
|
use crate::server::{
|
|
|
|
LoginResponse, ServerRequest, ServerResponse, UserStatusRequest,
|
|
|
|
UserStatusResponse,
|
|
|
|
@ -38,14 +37,64 @@ impl UserStatusMap { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
struct Handler {
|
|
|
|
/// TODO: document.
|
|
|
|
pub trait Handler: Clone + Send {
|
|
|
|
/// TODO: document.
|
|
|
|
fn run(
|
|
|
|
self,
|
|
|
|
request_rx: mpsc::Receiver<ServerRequest>,
|
|
|
|
response_tx: mpsc::Sender<ServerResponse>,
|
|
|
|
) -> anyhow::Result<()>;
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone, Default)]
|
|
|
|
pub struct UserStatusHandler {
|
|
|
|
pub user_status_map: Arc<Mutex<UserStatusMap>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl UserStatusHandler {
|
|
|
|
pub fn new(user_status_map: UserStatusMap) -> Self {
|
|
|
|
Self {
|
|
|
|
user_status_map: Arc::new(Mutex::new(user_status_map)),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Handler for UserStatusHandler {
|
|
|
|
fn run(
|
|
|
|
self,
|
|
|
|
mut request_rx: mpsc::Receiver<ServerRequest>,
|
|
|
|
response_tx: mpsc::Sender<ServerResponse>,
|
|
|
|
) -> anyhow::Result<()> {
|
|
|
|
while let Some(request) = request_rx.blocking_recv() {
|
|
|
|
match request {
|
|
|
|
ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => {
|
|
|
|
let entry = self.user_status_map.lock().get(&user_name);
|
|
|
|
if let Some(response) = entry {
|
|
|
|
response_tx
|
|
|
|
.blocking_send(ServerResponse::UserStatusResponse(response))
|
|
|
|
.context("sending response")?;
|
|
|
|
} else {
|
|
|
|
warn!("Received UserStatusRequest for unknown user {}", user_name);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
_ => {
|
|
|
|
warn!("Unhandled request: {:?}", request);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
struct LoginHandler<H> {
|
|
|
|
reader: FrameReader<ServerRequest, OwnedReadHalf>,
|
|
|
|
writer: FrameWriter<ServerResponse, OwnedWriteHalf>,
|
|
|
|
peer_address: SocketAddr,
|
|
|
|
user_status_map: Arc<Mutex<UserStatusMap>>,
|
|
|
|
inner: H,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Handler {
|
|
|
|
impl<H: Handler + 'static> LoginHandler<H> {
|
|
|
|
fn ipv4_address(&self) -> Ipv4Addr {
|
|
|
|
match self.peer_address.ip() {
|
|
|
|
IpAddr::V4(ipv4_addr) => ipv4_addr,
|
|
|
|
@ -61,22 +110,16 @@ impl Handler { |
|
|
|
self.writer.write(response).await
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn handle_login(&mut self) -> io::Result<()> {
|
|
|
|
match self.reader.read().await? {
|
|
|
|
async fn handle_login(&mut self) -> anyhow::Result<()> {
|
|
|
|
match self.reader.read().await.context("reading first request")? {
|
|
|
|
Some(ServerRequest::LoginRequest(request)) => {
|
|
|
|
info!("Received login request: {:?}", request);
|
|
|
|
}
|
|
|
|
Some(request) => {
|
|
|
|
return Err(io::Error::new(
|
|
|
|
io::ErrorKind::InvalidData,
|
|
|
|
format!("expected login request, got: {:?}", request),
|
|
|
|
));
|
|
|
|
bail!("expected login request, got: {:?}", request);
|
|
|
|
}
|
|
|
|
None => {
|
|
|
|
return Err(io::Error::new(
|
|
|
|
io::ErrorKind::UnexpectedEof,
|
|
|
|
"expected login request".to_string(),
|
|
|
|
));
|
|
|
|
bail!("expected login request, got eof");
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
@ -85,72 +128,65 @@ impl Handler { |
|
|
|
ip: self.ipv4_address(),
|
|
|
|
password_md5_opt: None,
|
|
|
|
});
|
|
|
|
self.send_response(&response).await
|
|
|
|
self
|
|
|
|
.send_response(&response)
|
|
|
|
.await
|
|
|
|
.context("sending login response")
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn handle_request(&mut self, request: ServerRequest) -> io::Result<()> {
|
|
|
|
debug!("Received request: {:?}", request);
|
|
|
|
|
|
|
|
match request {
|
|
|
|
ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => {
|
|
|
|
let entry = self.user_status_map.lock().get(&user_name);
|
|
|
|
if let Some(response) = entry {
|
|
|
|
let response = ServerResponse::UserStatusResponse(response);
|
|
|
|
self.send_response(&response).await?;
|
|
|
|
} else {
|
|
|
|
warn!("Received UserStatusRequest for unknown user {}", user_name);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
_ => {
|
|
|
|
warn!("Unhandled request: {:?}", request);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
async fn run(mut self) -> anyhow::Result<()> {
|
|
|
|
self.handle_login().await?;
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
let mut worker = Worker::from_parts(self.reader, self.writer);
|
|
|
|
|
|
|
|
async fn run(mut self) -> io::Result<()> {
|
|
|
|
self.handle_login().await?;
|
|
|
|
let (request_tx, request_rx) = mpsc::channel(100);
|
|
|
|
let (response_tx, response_rx) = mpsc::channel(100);
|
|
|
|
|
|
|
|
while let Some(request) = self.reader.read().await? {
|
|
|
|
self.handle_request(request).await?;
|
|
|
|
}
|
|
|
|
let worker_task = tokio::spawn(async move {
|
|
|
|
worker.run(request_tx, response_rx).await
|
|
|
|
});
|
|
|
|
let inner = self.inner;
|
|
|
|
tokio::task::spawn_blocking(move || inner.run(request_rx, response_tx))
|
|
|
|
.await
|
|
|
|
.context("joining handler")?
|
|
|
|
.context("running handler")?;
|
|
|
|
worker_task.await.context("joining worker")?.context("running worker")?;
|
|
|
|
|
|
|
|
info!("Client disconnecting, shutting down");
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
struct GracefulHandler {
|
|
|
|
handler: Handler,
|
|
|
|
struct GracefulHandler<H> {
|
|
|
|
handler: LoginHandler<H>,
|
|
|
|
shutdown_rx: watch::Receiver<()>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl GracefulHandler {
|
|
|
|
async fn run(mut self) -> io::Result<()> {
|
|
|
|
impl<H: Handler + 'static> GracefulHandler<H> {
|
|
|
|
async fn run(mut self) -> anyhow::Result<()> {
|
|
|
|
tokio::select!(
|
|
|
|
result = self.handler.run() => {
|
|
|
|
if let Err(ref error) = result {
|
|
|
|
warn!("Handler returned error {:?}", error);
|
|
|
|
warn!("LoginHandler returned error {:?}", error);
|
|
|
|
}
|
|
|
|
result
|
|
|
|
},
|
|
|
|
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
|
|
|
|
// that as a signal to shut down too.
|
|
|
|
_ = self.shutdown_rx.changed() => {
|
|
|
|
info!("Handler shutting down.");
|
|
|
|
info!("GracefulHandler shutting down.");
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
struct SenderHandler {
|
|
|
|
handler: GracefulHandler,
|
|
|
|
result_tx: mpsc::Sender<io::Result<()>>,
|
|
|
|
struct SenderHandler<H> {
|
|
|
|
handler: GracefulHandler<H>,
|
|
|
|
result_tx: mpsc::Sender<anyhow::Result<()>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl SenderHandler {
|
|
|
|
impl<H: Handler + 'static> SenderHandler<H> {
|
|
|
|
async fn run(self) {
|
|
|
|
let result = self.handler.run().await;
|
|
|
|
let _ = self.result_tx.send(result).await;
|
|
|
|
@ -159,28 +195,21 @@ impl SenderHandler { |
|
|
|
|
|
|
|
/// A builder for Server instances.
|
|
|
|
#[derive(Default)]
|
|
|
|
pub struct ServerBuilder {
|
|
|
|
user_status_map: Option<Arc<Mutex<UserStatusMap>>>,
|
|
|
|
pub struct ServerBuilder<H> {
|
|
|
|
handler: H,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl ServerBuilder {
|
|
|
|
/// Sets the UserStatusMap which the server will use to respond to
|
|
|
|
/// UserStatusRequest messages.
|
|
|
|
pub fn with_user_status_map(mut self, map: UserStatusMap) -> Self {
|
|
|
|
self.user_status_map = Some(Arc::new(Mutex::new(map)));
|
|
|
|
self
|
|
|
|
impl<H: Handler> ServerBuilder<H> {
|
|
|
|
/// `handler` will be used to handle incoming connections.
|
|
|
|
pub fn new(handler: H) -> Self {
|
|
|
|
ServerBuilder { handler }
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Binds to a localhost port, then returns a server and its handle.
|
|
|
|
pub async fn bind(self) -> io::Result<(Server, ServerHandle)> {
|
|
|
|
pub async fn bind(self) -> io::Result<(Server<H>, ServerHandle)> {
|
|
|
|
let listener = TcpListener::bind("localhost:0").await?;
|
|
|
|
let address = listener.local_addr()?;
|
|
|
|
|
|
|
|
let user_status_map = match self.user_status_map {
|
|
|
|
Some(user_status_map) => user_status_map,
|
|
|
|
None => Arc::new(Mutex::new(UserStatusMap::default())),
|
|
|
|
};
|
|
|
|
|
|
|
|
let (shutdown_tx, shutdown_rx) = oneshot::channel();
|
|
|
|
let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(());
|
|
|
|
let (result_tx, result_rx) = mpsc::channel(1);
|
|
|
|
@ -193,7 +222,7 @@ impl ServerBuilder { |
|
|
|
handler_shutdown_rx,
|
|
|
|
result_tx,
|
|
|
|
result_rx,
|
|
|
|
user_status_map,
|
|
|
|
handler: self.handler,
|
|
|
|
},
|
|
|
|
ServerHandle {
|
|
|
|
shutdown_tx,
|
|
|
|
@ -213,7 +242,7 @@ pub enum ShutdownType { |
|
|
|
}
|
|
|
|
|
|
|
|
/// A simple server for connecting to in tests.
|
|
|
|
pub struct Server {
|
|
|
|
pub struct Server<H> {
|
|
|
|
// Listener for new connections.
|
|
|
|
listener: TcpListener,
|
|
|
|
|
|
|
|
@ -225,11 +254,11 @@ pub struct Server { |
|
|
|
handler_shutdown_rx: watch::Receiver<()>,
|
|
|
|
|
|
|
|
// Channel for receiving results back from handlers.
|
|
|
|
result_tx: mpsc::Sender<io::Result<()>>,
|
|
|
|
result_rx: mpsc::Receiver<io::Result<()>>,
|
|
|
|
result_tx: mpsc::Sender<anyhow::Result<()>>,
|
|
|
|
result_rx: mpsc::Receiver<anyhow::Result<()>>,
|
|
|
|
|
|
|
|
// Shared state for handlers to use when serving responses.
|
|
|
|
user_status_map: Arc<Mutex<UserStatusMap>>,
|
|
|
|
// Handler used for incoming connections.
|
|
|
|
handler: H,
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Allows interacting with a running `Server`.
|
|
|
|
@ -252,7 +281,7 @@ impl ServerHandle { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Server {
|
|
|
|
impl<H: Handler + 'static> Server<H> {
|
|
|
|
/// Returns the address to which this server is bound.
|
|
|
|
/// This is always localhost and a random port chosen by the OS.
|
|
|
|
pub fn address(&self) -> io::Result<SocketAddr> {
|
|
|
|
@ -265,11 +294,11 @@ impl Server { |
|
|
|
|
|
|
|
let handler = SenderHandler {
|
|
|
|
handler: GracefulHandler {
|
|
|
|
handler: Handler {
|
|
|
|
handler: LoginHandler {
|
|
|
|
reader: FrameReader::new(read_half),
|
|
|
|
writer: FrameWriter::new(write_half),
|
|
|
|
peer_address,
|
|
|
|
user_status_map: self.user_status_map.clone(),
|
|
|
|
inner: self.handler.clone(),
|
|
|
|
},
|
|
|
|
shutdown_rx: self.handler_shutdown_rx.clone(),
|
|
|
|
},
|
|
|
|
@ -296,11 +325,11 @@ impl Server { |
|
|
|
/// - an error was encountered while listening
|
|
|
|
/// - an error was encountered while serving a request
|
|
|
|
///
|
|
|
|
pub async fn serve(mut self) -> io::Result<()> {
|
|
|
|
pub async fn serve(mut self) -> anyhow::Result<()> {
|
|
|
|
loop {
|
|
|
|
tokio::select!(
|
|
|
|
result = self.listener.accept() => {
|
|
|
|
let (stream, peer_address) = result?;
|
|
|
|
let (stream, peer_address) = result.context("accepting stream")?;
|
|
|
|
self.spawn_handler(stream, peer_address);
|
|
|
|
},
|
|
|
|
result = &mut self.shutdown_rx => {
|
|
|
|
@ -343,7 +372,7 @@ mod tests { |
|
|
|
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
|
|
|
|
use super::{ServerBuilder, ShutdownType};
|
|
|
|
use super::{ServerBuilder, ShutdownType, UserStatusHandler};
|
|
|
|
|
|
|
|
// Enable capturing logs in tests.
|
|
|
|
fn init() {
|
|
|
|
@ -354,7 +383,7 @@ mod tests { |
|
|
|
async fn new_binds_to_localhost() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
|
|
|
|
let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
|
|
|
|
assert!(server.address().unwrap().ip().is_loopback());
|
|
|
|
assert_eq!(server.address().unwrap(), handle.address());
|
|
|
|
}
|
|
|
|
@ -363,7 +392,7 @@ mod tests { |
|
|
|
async fn accepts_incoming_connections() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
|
|
|
|
let (server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
|
|
|
|
let server_task = tokio::spawn(server.serve());
|
|
|
|
|
|
|
|
// The connection succeeds.
|
|
|
|
@ -382,7 +411,7 @@ mod tests { |
|
|
|
async fn serve_yields_handler_error() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (mut server, handle) = ServerBuilder::default().bind().await.unwrap();
|
|
|
|
let (mut server, handle) = ServerBuilder::new(UserStatusHandler::default()).bind().await.unwrap();
|
|
|
|
|
|
|
|
// The connection is accepted, then immediately closed.
|
|
|
|
let address = handle.address();
|
|
|
|
@ -397,7 +426,6 @@ mod tests { |
|
|
|
handle.shutdown(ShutdownType::LameDuck);
|
|
|
|
|
|
|
|
// Drain outstanding requests, encountering the error.
|
|
|
|
let error = server.serve().await.unwrap_err();
|
|
|
|
assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
|
|
|
|
server.serve().await.unwrap_err();
|
|
|
|
}
|
|
|
|
}
|