//! A client interface for remote servers. use std::io; use log::{debug, info}; use thiserror::Error; use tokio::net::TcpStream; use crate::core::channel::{Channel, ChannelError}; use crate::server::{ Credentials, LoginResponse, ServerRequest, ServerResponse, Version, }; /// A client for the client-server protocol. #[derive(Debug)] pub struct Client { channel: Channel, version: Version, } /// An error that arose while logging in to a remote server. #[derive(Debug, Error)] pub enum ClientLoginError { #[error("login failed: {0}")] LoginFailed(String, Client), #[error("unexpected response: {0:?}")] UnexpectedResponse(ServerResponse), #[error("channel error: {0}")] ChannelError(#[from] ChannelError), } impl From for ClientLoginError { fn from(error: io::Error) -> Self { ClientLoginError::from(ChannelError::from(error)) } } /// A `Channel` that sends `ServerRequest`s and receives `ServerResponse`s. pub type ClientChannel = Channel; impl Client { /// Instantiates a new client pub fn new(tcp_stream: TcpStream) -> Self { Client { channel: Channel::new(tcp_stream), version: Version::default(), } } /// Sets a custom version to identify as to the server. pub fn with_version(mut self, version: Version) -> Self { self.version = version; self } /// Performs the login exchange, presenting `credentials` to the server. pub async fn login( mut self, credentials: Credentials, ) -> Result { let login_request = credentials.into_login_request(self.version); debug!("Client: sending login request: {:?}", login_request); let request = login_request.into(); self.channel.write(&request).await?; let response = self.channel.read().await?; debug!("Client: received first response: {:?}", response); match response { ServerResponse::LoginResponse(LoginResponse::LoginOk { motd, ip, password_md5_opt, }) => { info!("Client: Logged in successfully!"); info!("Client: Message Of The Day: {}", motd); info!("Client: Public IP address: {}", ip); info!("Client: Password MD5: {:?}", password_md5_opt); Ok(self.channel) } ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => { Err(ClientLoginError::LoginFailed(reason, self)) } response => Err(ClientLoginError::UnexpectedResponse(response)), } } } #[cfg(test)] mod tests { use futures::stream::{empty, StreamExt}; use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap}; use crate::server::{ Credentials, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse, }; use crate::UserStatus; use super::Client; // Enable capturing logs in tests. fn init() { let _ = env_logger::builder().is_test(true).try_init(); } // Returns default `Credentials` suitable for testing. fn credentials() -> Credentials { let user_name = "alice".to_string(); let password = "sekrit".to_string(); Credentials::new(user_name, password).unwrap() } #[tokio::test] async fn login() { init(); let (server, handle) = ServerBuilder::default().bind().await.unwrap(); let server_task = tokio::spawn(server.serve()); let stream = TcpStream::connect(handle.address()).await.unwrap(); let channel = Client::new(stream).login(credentials()).await.unwrap(); // Send nothing, receive no responses. let inbound = channel.run(empty()); tokio::pin!(inbound); assert!(inbound.next().await.is_none()); handle.shutdown(ShutdownType::LameDuck); server_task.await.unwrap().unwrap(); } #[tokio::test] async fn simple_exchange() { init(); let response = UserStatusResponse { user_name: "alice".to_string(), status: UserStatus::Online, is_privileged: false, }; let mut user_status_map = UserStatusMap::default(); user_status_map.insert(response.clone()); let (server, handle) = ServerBuilder::default() .with_user_status_map(user_status_map) .bind() .await .unwrap(); let server_task = tokio::spawn(server.serve()); let stream = TcpStream::connect(handle.address()).await.unwrap(); let channel = Client::new(stream).login(credentials()).await.unwrap(); let outbound = Box::pin(async_stream::stream! { yield ServerRequest::UserStatusRequest(UserStatusRequest { user_name: "bob".to_string(), }); yield ServerRequest::UserStatusRequest(UserStatusRequest { user_name: "alice".to_string(), }); }); let inbound = channel.run(outbound); tokio::pin!(inbound); assert_eq!( inbound.next().await.unwrap().unwrap(), ServerResponse::UserStatusResponse(response) ); assert!(inbound.next().await.is_none()); handle.shutdown(ShutdownType::LameDuck); server_task.await.unwrap().unwrap(); } #[tokio::test] async fn stream_closed() { init(); let (server, handle) = ServerBuilder::default().bind().await.unwrap(); let server_task = tokio::spawn(server.serve()); let stream = TcpStream::connect(handle.address()).await.unwrap(); let channel = Client::new(stream).login(credentials()).await.unwrap(); let (_request_tx, mut request_rx) = mpsc::channel(1); let outbound = Box::pin(async_stream::stream! { while let Some(request) = request_rx.recv().await { yield request; } }); let inbound = channel.run(outbound); tokio::pin!(inbound); // Server shuts down, closing its connection before the client has had a // chance to send all of `outbound`. handle.shutdown(ShutdownType::Immediate); // Wait for the server to terminate, to avoid race conditions. server_task.await.unwrap().unwrap(); // Check that the client returns the correct error, then stops running. assert!(inbound .next() .await .unwrap() .unwrap_err() .is_unexpected_eof()); assert!(inbound.next().await.is_none()); } }