//! A client interface for remote servers. use log::{debug, info}; use thiserror::Error; use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::core::{Worker, WorkerError}; use crate::server::{ Credentials, LoginResponse, ServerRequest, ServerResponse, Version, }; /// A `Worker` that sends `ServerRequest`s and receives `ServerResponse`s. pub type ClientWorker = Worker; /// A client for the client-server protocol. #[derive(Debug)] pub struct Client { stream: TcpStream, version: Version, } /// An error that arose while logging in to a remote server. #[derive(Debug, Error)] pub enum ClientLoginError { #[error("login failed: {0}")] LoginFailed(String, ClientWorker), #[error("unexpected response: {0:?}")] UnexpectedResponse(ServerResponse), #[error("send error: {0}")] SendError(#[from] mpsc::error::SendError), #[error("worker error: {0}")] WorkerError(#[from] WorkerError), #[error("stream closed unexpectedly")] StreamClosed, } impl Client { /// Instantiates a new client pub fn new(stream: TcpStream) -> Self { Client { stream, version: Version::default(), } } /// Sets a custom version to identify as to the server. pub fn with_version(mut self, version: Version) -> Self { self.version = version; self } /// Performs the login exchange, presenting `credentials` to the server. pub async fn login( self, credentials: Credentials, ) -> Result { let mut worker = ClientWorker::new(self.stream); let (request_tx, request_rx) = mpsc::channel(1); let (response_tx, mut response_rx) = mpsc::channel(1); let worker_task = tokio::spawn(async move { worker .run(response_tx, request_rx) .await .map(move |()| worker) }); let login_request = credentials.into_login_request(self.version); debug!("Sending login request: {:?}", login_request); request_tx.send(login_request.into()).await?; let optional_response = response_rx.recv().await; // Join the worker even if we received `None`, in case it failed. // Panic in case of join error, as if we had run the worker itself. drop(request_tx); let worker = worker_task.await.expect("joining worker")?; let response = match optional_response { None => return Err(ClientLoginError::StreamClosed), Some(response) => response, }; debug!("Received first response: {:?}", response); match response { ServerResponse::LoginResponse(LoginResponse::LoginOk { motd, ip, password_md5_opt, }) => { info!("Login: success!"); info!("Login: Message Of The Day: {}", motd); info!("Login: Public IP address: {}", ip); info!("Login: Password MD5: {:?}", password_md5_opt); Ok(worker) } ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => { Err(ClientLoginError::LoginFailed(reason, worker)) } response => Err(ClientLoginError::UnexpectedResponse(response)), } } } #[cfg(test)] mod tests { use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap}; use crate::server::{ Credentials, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse, }; use crate::UserStatus; use super::Client; // Enable capturing logs in tests. fn init() { let _ = env_logger::builder().is_test(true).try_init(); } // Returns default `Credentials` suitable for testing. fn credentials() -> Credentials { let user_name = "alice".to_string(); let password = "sekrit".to_string(); Credentials::new(user_name, password).expect("building credentials") } // TODO: Tests for all login error conditions: // // - login failed // - unexpected response // - read error // - write error // - stream closed #[tokio::test] async fn login_success() { init(); let (server, handle) = ServerBuilder::default() .bind() .await .expect("binding server"); let server_task = tokio::spawn(server.serve()); let stream = TcpStream::connect(handle.address()) .await .expect("connecting"); let worker = Client::new(stream) .login(credentials()) .await .expect("logging in"); drop(worker); handle.shutdown(ShutdownType::LameDuck); server_task .await .expect("joining server") .expect("running server"); } #[tokio::test] async fn simple_exchange() { init(); let response = UserStatusResponse { user_name: "shruti".to_string(), status: UserStatus::Online, is_privileged: false, }; let mut user_status_map = UserStatusMap::default(); user_status_map.insert(response.clone()); let (server, handle) = ServerBuilder::default() .with_user_status_map(user_status_map) .bind() .await .expect("binding server"); let server_task = tokio::spawn(server.serve()); let stream = TcpStream::connect(handle.address()) .await .expect("connecting"); let mut worker = Client::new(stream) .login(credentials()) .await .expect("logging in"); let (request_tx, request_rx) = mpsc::channel(100); let (response_tx, mut response_rx) = mpsc::channel(100); request_tx .send(ServerRequest::UserStatusRequest(UserStatusRequest { user_name: "shruti".to_string(), })) .await .expect("sending shruti"); request_tx .send(ServerRequest::UserStatusRequest(UserStatusRequest { user_name: "karandeep".to_string(), })) .await .expect("sending karandeep"); let worker_task = tokio::spawn(async move { worker.run(response_tx, request_rx).await }); assert_eq!( response_rx.recv().await, Some(ServerResponse::UserStatusResponse(response)) ); drop(request_tx); worker_task .await .expect("joining worker") .expect("running worker"); handle.shutdown(ShutdownType::LameDuck); server_task .await .expect("joining server") .expect("running server"); } }