Solstice client.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

223 lines
6.1 KiB

//! A client interface for remote servers.
use std::io;
use log::{debug, info};
use thiserror::Error;
use tokio::net::TcpStream;
use crate::core::channel::{Channel, ChannelError};
use crate::server::{
Credentials, LoginResponse, ServerRequest, ServerResponse, Version,
};
/// A client for the client-server protocol.
///
/// Bundles the connection to the server with the protocol version this client
/// announces when it logs in.
#[derive(Debug)]
pub struct Client {
    /// Message channel over the server connection; it sends `ServerRequest`s
    /// and receives `ServerResponse`s.
    channel: Channel<ServerResponse, ServerRequest>,

    /// Version passed into the login request sent by `Client::login`.
    version: Version,
}
/// An error that arose while logging in to a remote server.
#[derive(Debug, Error)]
pub enum ClientLoginError {
    /// The server answered with a login failure.
    ///
    /// Carries the reason string supplied by the server, and hands back the
    /// `Client` that made the attempt instead of dropping it.
    #[error("login failed: {0}")]
    LoginFailed(String, Client),

    /// The server's first reply to the login request was not a login response.
    #[error("unexpected response: {0:?}")]
    UnexpectedResponse(ServerResponse),

    /// Sending the request or receiving the reply failed in the underlying
    /// channel.
    #[error("channel error: {0}")]
    ChannelError(
        #[from]
        ChannelError,
    ),
}
impl From<io::Error> for ClientLoginError {
fn from(error: io::Error) -> Self {
ClientLoginError::from(ChannelError::from(error))
}
}
/// The channel a [`Client`] hands back after a successful login.
///
/// It writes `ServerRequest`s to the server and reads `ServerResponse`s
/// from it.
pub type ClientChannel = Channel<ServerResponse, ServerRequest>;
impl Client {
    /// Instantiates a new client speaking the protocol over `tcp_stream`.
    ///
    /// The client identifies itself with `Version::default()` unless a
    /// different version is supplied through [`Client::with_version`].
    pub fn new(tcp_stream: TcpStream) -> Self {
        Client {
            channel: Channel::new(tcp_stream),
            version: Version::default(),
        }
    }

    /// Sets a custom version to identify as to the server.
    ///
    /// Consumes and returns `self`, so it can be chained after
    /// [`Client::new`].
    pub fn with_version(mut self, version: Version) -> Self {
        self.version = version;
        self
    }

    /// Performs the login exchange, presenting `credentials` to the server.
    ///
    /// The method proceeds in two steps:
    /// 1. It builds a login request from `credentials` and the configured
    ///    version, and writes it to the server.
    /// 2. It reads exactly one response.
    ///
    /// On success, the underlying channel is returned, ready for further
    /// request/response traffic.
    ///
    /// # Errors
    ///
    /// - [`ClientLoginError::LoginFailed`]: the server rejected the login. The
    ///   error carries the server's reason and gives this `Client` back.
    /// - [`ClientLoginError::UnexpectedResponse`]: the first response was not
    ///   a login response. The client is dropped in this case.
    /// - [`ClientLoginError::ChannelError`]: writing the request or reading the
    ///   response failed.
    pub async fn login(
        mut self,
        credentials: Credentials,
    ) -> Result<ClientChannel, ClientLoginError> {
        let login_request = credentials.into_login_request(self.version);
        // NOTE(review): the login request is derived from `credentials` and may
        // embed the password or a digest of it. Confirm its `Debug` output is
        // safe to emit, even at debug level.
        debug!("Client: sending login request: {:?}", login_request);

        let request = login_request.into();
        self.channel.write(&request).await?;

        let response = self.channel.read().await?;
        debug!("Client: received first response: {:?}", response);

        match response {
            ServerResponse::LoginResponse(LoginResponse::LoginOk {
                motd,
                ip,
                password_md5_opt,
            }) => {
                info!("Client: Logged in successfully!");
                info!("Client: Message Of The Day: {}", motd);
                info!("Client: Public IP address: {}", ip);
                // The password digest is credential-derived material. It is
                // logged only at debug level, never alongside the info-level
                // login summary.
                debug!("Client: Password MD5: {:?}", password_md5_opt);
                Ok(self.channel)
            }
            ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => {
                // Return the client to the caller instead of dropping it.
                Err(ClientLoginError::LoginFailed(reason, self))
            }
            response => Err(ClientLoginError::UnexpectedResponse(response)),
        }
    }
}
#[cfg(test)]
mod tests {
    use futures::stream::{empty, StreamExt};
    use tokio::net::TcpStream;
    use tokio::sync::mpsc;
    use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap};
    use crate::server::{
        Credentials, ServerRequest, ServerResponse, UserStatusRequest,
        UserStatusResponse,
    };
    use crate::UserStatus;
    use super::Client;

    // Enable capturing logs in tests.
    //
    // Every test calls this, so the result of `try_init` is deliberately
    // ignored: it fails when a logger has already been installed in this
    // process by an earlier test.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    // Returns default `Credentials` suitable for testing.
    fn credentials() -> Credentials {
        let user_name = "alice".to_string();
        let password = "sekrit".to_string();
        Credentials::new(user_name, password).unwrap()
    }

    // Logging in to a default test server succeeds. The resulting channel
    // yields no responses when nothing is sent on it.
    #[tokio::test]
    async fn login() {
        init();
        let (server, handle) = ServerBuilder::default().bind().await.unwrap();
        let server_task = tokio::spawn(server.serve());
        let stream = TcpStream::connect(handle.address()).await.unwrap();
        let channel = Client::new(stream).login(credentials()).await.unwrap();
        // Send nothing, receive no responses.
        let inbound = channel.run(empty());
        tokio::pin!(inbound);
        assert!(inbound.next().await.is_none());
        // Shut the server down and check that its task finished without error.
        handle.shutdown(ShutdownType::LameDuck);
        server_task.await.unwrap().unwrap();
    }

    // After login, status requests for two users are sent. Only `alice` is in
    // the server's user status map, and only her request produces a response.
    #[tokio::test]
    async fn simple_exchange() {
        init();
        let response = UserStatusResponse {
            user_name: "alice".to_string(),
            status: UserStatus::Online,
            is_privileged: false,
        };
        let mut user_status_map = UserStatusMap::default();
        // Cloned so `response` stays available for the assertion below.
        user_status_map.insert(response.clone());
        let (server, handle) = ServerBuilder::default()
            .with_user_status_map(user_status_map)
            .bind()
            .await
            .unwrap();
        let server_task = tokio::spawn(server.serve());
        let stream = TcpStream::connect(handle.address()).await.unwrap();
        let channel = Client::new(stream).login(credentials()).await.unwrap();
        let outbound = Box::pin(async_stream::stream! {
            yield ServerRequest::UserStatusRequest(UserStatusRequest {
                user_name: "bob".to_string(),
            });
            yield ServerRequest::UserStatusRequest(UserStatusRequest {
                user_name: "alice".to_string(),
            });
        });
        let inbound = channel.run(outbound);
        tokio::pin!(inbound);
        // The single response received is alice's status.
        assert_eq!(
            inbound.next().await.unwrap().unwrap(),
            ServerResponse::UserStatusResponse(response)
        );
        // Nothing else arrives; in particular, no response is produced for
        // `bob`.
        assert!(inbound.next().await.is_none());
        handle.shutdown(ShutdownType::LameDuck);
        server_task.await.unwrap().unwrap();
    }

    // If the server closes the connection while the client still has an open
    // outbound stream, the client surfaces an unexpected-EOF error and then
    // stops.
    #[tokio::test]
    async fn stream_closed() {
        init();
        let (server, handle) = ServerBuilder::default().bind().await.unwrap();
        let server_task = tokio::spawn(server.serve());
        let stream = TcpStream::connect(handle.address()).await.unwrap();
        let channel = Client::new(stream).login(credentials()).await.unwrap();
        // `_request_tx` is bound to a named variable rather than `_`, so the
        // sender lives until the end of the test. As a result, `outbound`
        // never ends on its own, and the client can only stop because the
        // connection closes.
        let (_request_tx, mut request_rx) = mpsc::channel(1);
        let outbound = Box::pin(async_stream::stream! {
            while let Some(request) = request_rx.recv().await {
                yield request;
            }
        });
        let inbound = channel.run(outbound);
        tokio::pin!(inbound);
        // Server shuts down, closing its connection before the client has had a
        // chance to send all of `outbound`.
        handle.shutdown(ShutdownType::Immediate);
        // Wait for the server to terminate, to avoid race conditions.
        server_task.await.unwrap().unwrap();
        // Check that the client returns the correct error, then stops running.
        assert!(inbound
            .next()
            .await
            .unwrap()
            .unwrap_err()
            .is_unexpected_eof());
        assert!(inbound.next().await.is_none());
    }
}