diff --git a/proto/src/server/client.rs b/proto/src/server/client.rs index dbcb410..b3f6d3e 100644 --- a/proto/src/server/client.rs +++ b/proto/src/server/client.rs @@ -2,7 +2,6 @@ use std::io; -use futures::stream::Stream; use log::{debug, info}; use thiserror::Error; use tokio::net::TcpStream; @@ -10,22 +9,18 @@ use tokio::net::TcpStream; use crate::core::channel::{Channel, ChannelError}; use crate::server::{Credentials, LoginResponse, ServerRequest, ServerResponse, Version}; -/// Specifies options for a new `Client`. -pub struct ClientOptions { - pub credentials: Credentials, - pub version: Version, -} - /// A client for the client-server protocol. +#[derive(Debug)] pub struct Client { channel: Channel, + version: Version, } /// An error that arose while logging in to a remote server. #[derive(Debug, Error)] pub enum ClientLoginError { #[error("login failed: {0}")] - LoginFailed(String), + LoginFailed(String, Client), #[error("unexpected response: {0:?}")] UnexpectedResponse(ServerResponse), @@ -40,24 +35,30 @@ impl From for ClientLoginError { } } +/// A `Channel` that sends `ServerRequest`s and receives `ServerResponse`s. +pub type ClientChannel = Channel; + impl Client { - pub async fn login( - tcp_stream: TcpStream, - options: ClientOptions, - ) -> Result { - let mut client = Client { + /// Instantiates a new client + pub fn new(tcp_stream: TcpStream) -> Self { + Client { channel: Channel::new(tcp_stream), - }; - - client.handshake(options).await?; + version: Version::default(), + } + } - Ok(client) + /// Sets a custom version to identify as to the server. + pub fn with_version(mut self, version: Version) -> Self { + self.version = version; + self } - // Performs the login exchange. - // Called this way because `login` is already taken. - async fn handshake(&mut self, options: ClientOptions) -> Result<(), ClientLoginError> { - let login_request = options.credentials.into_login_request(options.version); + /// Performs the login exchange, presenting `credentials` to the server. + pub async fn login( + mut self, + credentials: Credentials, + ) -> Result { + let login_request = credentials.into_login_request(self.version); debug!("Client: sending login request: {:?}", login_request); let request = login_request.into(); @@ -76,21 +77,14 @@ impl Client { info!("Client: Message Of The Day: {}", motd); info!("Client: Public IP address: {}", ip); info!("Client: Password MD5: {:?}", password_md5_opt); - Ok(()) + Ok(self.channel) } ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => { - Err(ClientLoginError::LoginFailed(reason)) + Err(ClientLoginError::LoginFailed(reason, self)) } response => Err(ClientLoginError::UnexpectedResponse(response)), } } - - pub fn run>( - self, - request_stream: S, - ) -> impl Stream> { - self.channel.run(request_stream) - } } #[cfg(test)] @@ -105,23 +99,18 @@ mod tests { }; use crate::UserStatus; - use super::{Client, ClientOptions, Version}; + use super::Client; // Enable capturing logs in tests. fn init() { let _ = env_logger::builder().is_test(true).try_init(); } - // Returns default ClientOptions suitable for testing. - fn client_options() -> ClientOptions { + // Returns default `Credentials` suitable for testing. + fn credentials() -> Credentials { let user_name = "alice".to_string(); let password = "sekrit".to_string(); - let credentials = Credentials::new(user_name, password).unwrap(); - - ClientOptions { - credentials, - version: Version::default(), - } + Credentials::new(user_name, password).unwrap() } #[tokio::test] @@ -133,10 +122,10 @@ mod tests { let stream = TcpStream::connect(handle.address()).await.unwrap(); - let client = Client::login(stream, client_options()).await.unwrap(); + let channel = Client::new(stream).login(credentials()).await.unwrap(); // Send nothing, receive no responses. - let inbound = client.run(empty()); + let inbound = channel.run(empty()); tokio::pin!(inbound); assert!(inbound.next().await.is_none()); @@ -167,7 +156,7 @@ mod tests { let stream = TcpStream::connect(handle.address()).await.unwrap(); - let client = Client::login(stream, client_options()).await.unwrap(); + let channel = Client::new(stream).login(credentials()).await.unwrap(); let outbound = Box::pin(async_stream::stream! { yield ServerRequest::UserStatusRequest(UserStatusRequest { @@ -178,7 +167,7 @@ mod tests { }); }); - let inbound = client.run(outbound); + let inbound = channel.run(outbound); tokio::pin!(inbound); assert_eq!( @@ -200,7 +189,7 @@ mod tests { let stream = TcpStream::connect(handle.address()).await.unwrap(); - let client = Client::login(stream, client_options()).await.unwrap(); + let channel = Client::new(stream).login(credentials()).await.unwrap(); let (_request_tx, mut request_rx) = mpsc::channel(1); let outbound = Box::pin(async_stream::stream! { @@ -209,7 +198,7 @@ mod tests { } }); - let inbound = client.run(outbound); + let inbound = channel.run(outbound); tokio::pin!(inbound); // Server shuts down, closing its connection before the client has had a diff --git a/proto/src/server/mod.rs b/proto/src/server/mod.rs index b762b04..b038604 100644 --- a/proto/src/server/mod.rs +++ b/proto/src/server/mod.rs @@ -7,7 +7,7 @@ mod response; mod testing; mod version; -pub use self::client::{Client, ClientOptions}; +pub use self::client::{Client, ClientChannel, ClientLoginError}; pub use self::credentials::Credentials; pub use self::request::*; pub use self::response::*; diff --git a/proto/src/server/version.rs b/proto/src/server/version.rs index c4c54c7..e0da3c2 100644 --- a/proto/src/server/version.rs +++ b/proto/src/server/version.rs @@ -1,6 +1,7 @@ //! Protocol versioning. /// Specifies a protocol version. +#[derive(Clone, Copy, Debug)] pub struct Version { /// The major version number. pub major: u32, diff --git a/proto/tests/connect.rs b/proto/tests/connect.rs index 59fd983..e750045 100644 --- a/proto/tests/connect.rs +++ b/proto/tests/connect.rs @@ -3,9 +3,7 @@ use tokio::io; use tokio::net; use tokio::sync::mpsc; -use solstice_proto::server::{ - Client, ClientOptions, Credentials, ServerResponse, UserStatusRequest, Version, -}; +use solstice_proto::server::{Client, Credentials, ServerResponse, UserStatusRequest}; // Enable capturing logs in tests. fn init() { @@ -20,14 +18,9 @@ fn make_user_name(test_name: &str) -> String { format!("st_{}", test_name) } -fn client_options(user_name: String) -> ClientOptions { +fn make_credentials(user_name: String) -> Credentials { let password = "abcdefgh".to_string(); - let credentials = Credentials::new(user_name, password).unwrap(); - - ClientOptions { - credentials, - version: Version::default(), - } + Credentials::new(user_name, password).unwrap() } #[tokio::test] @@ -36,10 +29,10 @@ async fn integration_connect() { let stream = connect().await.unwrap(); - let options = client_options(make_user_name("connect")); - let client = Client::login(stream, options).await.unwrap(); + let credentials = make_credentials(make_user_name("connect")); + let channel = Client::new(stream).login(credentials).await.unwrap(); - let inbound = client.run(stream::pending()); + let inbound = channel.run(stream::pending()); tokio::pin!(inbound); assert!(inbound.next().await.is_some()); @@ -52,8 +45,8 @@ async fn integration_check_user_status() { let stream = connect().await.unwrap(); let user_name = make_user_name("check_user_status"); - let options = client_options(user_name.clone()); - let client = Client::login(stream, options).await.unwrap(); + let credentials = make_credentials(user_name.clone()); + let channel = Client::new(stream).login(credentials).await.unwrap(); let (request_tx, mut request_rx) = mpsc::channel(1); @@ -63,7 +56,7 @@ async fn integration_check_user_status() { } }; - let inbound = client.run(outbound); + let inbound = channel.run(outbound); tokio::pin!(inbound); request_tx