|
|
|
@ -1,20 +1,22 @@ |
|
|
|
//! A client interface for remote servers.
|
|
|
|
|
|
|
|
use std::io;
|
|
|
|
|
|
|
|
use log::{debug, info};
|
|
|
|
use thiserror::Error;
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
use tokio::sync::mpsc;
|
|
|
|
|
|
|
|
use crate::core::channel::{Channel, ChannelError};
|
|
|
|
use crate::core::{Worker, WorkerError};
|
|
|
|
use crate::server::{
|
|
|
|
Credentials, LoginResponse, ServerRequest, ServerResponse, Version,
|
|
|
|
};
|
|
|
|
|
|
|
|
/// A `Worker` that sends `ServerRequest`s and receives `ServerResponse`s.
|
|
|
|
pub type ClientWorker = Worker<ServerResponse, ServerRequest>;
|
|
|
|
|
|
|
|
/// A client for the client-server protocol.
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct Client {
|
|
|
|
channel: Channel<ServerResponse, ServerRequest>,
|
|
|
|
stream: TcpStream,
|
|
|
|
version: Version,
|
|
|
|
}
|
|
|
|
|
|
|
|
@ -22,29 +24,26 @@ pub struct Client { |
|
|
|
#[derive(Debug, Error)]
|
|
|
|
pub enum ClientLoginError {
|
|
|
|
#[error("login failed: {0}")]
|
|
|
|
LoginFailed(String, Client),
|
|
|
|
LoginFailed(String, ClientWorker),
|
|
|
|
|
|
|
|
#[error("unexpected response: {0:?}")]
|
|
|
|
UnexpectedResponse(ServerResponse),
|
|
|
|
|
|
|
|
#[error("channel error: {0}")]
|
|
|
|
ChannelError(#[from] ChannelError),
|
|
|
|
}
|
|
|
|
#[error("send error: {0}")]
|
|
|
|
SendError(#[from] mpsc::error::SendError<ServerRequest>),
|
|
|
|
|
|
|
|
impl From<io::Error> for ClientLoginError {
|
|
|
|
fn from(error: io::Error) -> Self {
|
|
|
|
ClientLoginError::from(ChannelError::from(error))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#[error("worker error: {0}")]
|
|
|
|
WorkerError(#[from] WorkerError),
|
|
|
|
|
|
|
|
/// A `Channel` that sends `ServerRequest`s and receives `ServerResponse`s.
|
|
|
|
pub type ClientChannel = Channel<ServerResponse, ServerRequest>;
|
|
|
|
#[error("stream closed unexpectedly")]
|
|
|
|
StreamClosed,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Client {
|
|
|
|
/// Instantiates a new client
|
|
|
|
pub fn new(tcp_stream: TcpStream) -> Self {
|
|
|
|
pub fn new(stream: TcpStream) -> Self {
|
|
|
|
Client {
|
|
|
|
channel: Channel::new(tcp_stream),
|
|
|
|
stream,
|
|
|
|
version: Version::default(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -57,32 +56,52 @@ impl Client { |
|
|
|
|
|
|
|
/// Performs the login exchange, presenting `credentials` to the server.
|
|
|
|
pub async fn login(
|
|
|
|
mut self,
|
|
|
|
self,
|
|
|
|
credentials: Credentials,
|
|
|
|
) -> Result<ClientChannel, ClientLoginError> {
|
|
|
|
) -> Result<ClientWorker, ClientLoginError> {
|
|
|
|
let mut worker = ClientWorker::new(self.stream);
|
|
|
|
let (request_tx, request_rx) = mpsc::channel(1);
|
|
|
|
let (response_tx, mut response_rx) = mpsc::channel(1);
|
|
|
|
|
|
|
|
let worker_task = tokio::spawn(async move {
|
|
|
|
worker
|
|
|
|
.run(response_tx, request_rx)
|
|
|
|
.await
|
|
|
|
.map(move |()| worker)
|
|
|
|
});
|
|
|
|
|
|
|
|
let login_request = credentials.into_login_request(self.version);
|
|
|
|
debug!("Client: sending login request: {:?}", login_request);
|
|
|
|
debug!("Sending login request: {:?}", login_request);
|
|
|
|
|
|
|
|
let request = login_request.into();
|
|
|
|
self.channel.write(&request).await?;
|
|
|
|
request_tx.send(login_request.into()).await?;
|
|
|
|
let optional_response = response_rx.recv().await;
|
|
|
|
|
|
|
|
let response = self.channel.read().await?;
|
|
|
|
debug!("Client: received first response: {:?}", response);
|
|
|
|
// Join the worker even if we received `None`, in case it failed.
|
|
|
|
// Panic in case of join error, as if we had run the worker itself.
|
|
|
|
drop(request_tx);
|
|
|
|
let worker = worker_task.await.expect("joining worker")?;
|
|
|
|
|
|
|
|
let response = match optional_response {
|
|
|
|
None => return Err(ClientLoginError::StreamClosed),
|
|
|
|
Some(response) => response,
|
|
|
|
};
|
|
|
|
|
|
|
|
debug!("Received first response: {:?}", response);
|
|
|
|
match response {
|
|
|
|
ServerResponse::LoginResponse(LoginResponse::LoginOk {
|
|
|
|
motd,
|
|
|
|
ip,
|
|
|
|
password_md5_opt,
|
|
|
|
}) => {
|
|
|
|
info!("Client: Logged in successfully!");
|
|
|
|
info!("Client: Message Of The Day: {}", motd);
|
|
|
|
info!("Client: Public IP address: {}", ip);
|
|
|
|
info!("Client: Password MD5: {:?}", password_md5_opt);
|
|
|
|
Ok(self.channel)
|
|
|
|
info!("Login: success!");
|
|
|
|
info!("Login: Message Of The Day: {}", motd);
|
|
|
|
info!("Login: Public IP address: {}", ip);
|
|
|
|
info!("Login: Password MD5: {:?}", password_md5_opt);
|
|
|
|
|
|
|
|
Ok(worker)
|
|
|
|
}
|
|
|
|
ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => {
|
|
|
|
Err(ClientLoginError::LoginFailed(reason, self))
|
|
|
|
Err(ClientLoginError::LoginFailed(reason, worker))
|
|
|
|
}
|
|
|
|
response => Err(ClientLoginError::UnexpectedResponse(response)),
|
|
|
|
}
|
|
|
|
@ -91,7 +110,6 @@ impl Client { |
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
mod tests {
|
|
|
|
use futures::stream::{empty, StreamExt};
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
use tokio::sync::mpsc;
|
|
|
|
|
|
|
|
@ -113,28 +131,43 @@ mod tests { |
|
|
|
fn credentials() -> Credentials {
|
|
|
|
let user_name = "alice".to_string();
|
|
|
|
let password = "sekrit".to_string();
|
|
|
|
Credentials::new(user_name, password).unwrap()
|
|
|
|
Credentials::new(user_name, password).expect("building credentials")
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: Tests for all login error conditions:
|
|
|
|
//
|
|
|
|
// - login failed
|
|
|
|
// - unexpected response
|
|
|
|
// - read error
|
|
|
|
// - write error
|
|
|
|
// - stream closed
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn login() {
|
|
|
|
async fn login_success() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
|
|
|
|
let (server, handle) = ServerBuilder::default()
|
|
|
|
.bind()
|
|
|
|
.await
|
|
|
|
.expect("binding server");
|
|
|
|
let server_task = tokio::spawn(server.serve());
|
|
|
|
|
|
|
|
let stream = TcpStream::connect(handle.address()).await.unwrap();
|
|
|
|
|
|
|
|
let channel = Client::new(stream).login(credentials()).await.unwrap();
|
|
|
|
let stream = TcpStream::connect(handle.address())
|
|
|
|
.await
|
|
|
|
.expect("connecting");
|
|
|
|
|
|
|
|
// Send nothing, receive no responses.
|
|
|
|
let inbound = channel.run(empty());
|
|
|
|
tokio::pin!(inbound);
|
|
|
|
let worker = Client::new(stream)
|
|
|
|
.login(credentials())
|
|
|
|
.await
|
|
|
|
.expect("logging in");
|
|
|
|
|
|
|
|
assert!(inbound.next().await.is_none());
|
|
|
|
drop(worker);
|
|
|
|
|
|
|
|
handle.shutdown(ShutdownType::LameDuck);
|
|
|
|
server_task.await.unwrap().unwrap();
|
|
|
|
server_task
|
|
|
|
.await
|
|
|
|
.expect("joining server")
|
|
|
|
.expect("running server");
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
@ -142,7 +175,7 @@ mod tests { |
|
|
|
init();
|
|
|
|
|
|
|
|
let response = UserStatusResponse {
|
|
|
|
user_name: "alice".to_string(),
|
|
|
|
user_name: "shruti".to_string(),
|
|
|
|
status: UserStatus::Online,
|
|
|
|
is_privileged: false,
|
|
|
|
};
|
|
|
|
@ -154,70 +187,52 @@ mod tests { |
|
|
|
.with_user_status_map(user_status_map)
|
|
|
|
.bind()
|
|
|
|
.await
|
|
|
|
.unwrap();
|
|
|
|
.expect("binding server");
|
|
|
|
let server_task = tokio::spawn(server.serve());
|
|
|
|
|
|
|
|
let stream = TcpStream::connect(handle.address()).await.unwrap();
|
|
|
|
let stream = TcpStream::connect(handle.address())
|
|
|
|
.await
|
|
|
|
.expect("connecting");
|
|
|
|
|
|
|
|
let mut worker = Client::new(stream)
|
|
|
|
.login(credentials())
|
|
|
|
.await
|
|
|
|
.expect("logging in");
|
|
|
|
|
|
|
|
let channel = Client::new(stream).login(credentials()).await.unwrap();
|
|
|
|
let (request_tx, request_rx) = mpsc::channel(100);
|
|
|
|
let (response_tx, mut response_rx) = mpsc::channel(100);
|
|
|
|
|
|
|
|
let outbound = Box::pin(async_stream::stream! {
|
|
|
|
yield ServerRequest::UserStatusRequest(UserStatusRequest {
|
|
|
|
user_name: "bob".to_string(),
|
|
|
|
});
|
|
|
|
yield ServerRequest::UserStatusRequest(UserStatusRequest {
|
|
|
|
user_name: "alice".to_string(),
|
|
|
|
});
|
|
|
|
});
|
|
|
|
request_tx
|
|
|
|
.send(ServerRequest::UserStatusRequest(UserStatusRequest {
|
|
|
|
user_name: "shruti".to_string(),
|
|
|
|
}))
|
|
|
|
.await
|
|
|
|
.expect("sending shruti");
|
|
|
|
request_tx
|
|
|
|
.send(ServerRequest::UserStatusRequest(UserStatusRequest {
|
|
|
|
user_name: "karandeep".to_string(),
|
|
|
|
}))
|
|
|
|
.await
|
|
|
|
.expect("sending karandeep");
|
|
|
|
|
|
|
|
let inbound = channel.run(outbound);
|
|
|
|
tokio::pin!(inbound);
|
|
|
|
let worker_task =
|
|
|
|
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
|
|
|
|
|
|
|
|
assert_eq!(
|
|
|
|
inbound.next().await.unwrap().unwrap(),
|
|
|
|
ServerResponse::UserStatusResponse(response)
|
|
|
|
response_rx.recv().await,
|
|
|
|
Some(ServerResponse::UserStatusResponse(response))
|
|
|
|
);
|
|
|
|
assert!(inbound.next().await.is_none());
|
|
|
|
|
|
|
|
handle.shutdown(ShutdownType::LameDuck);
|
|
|
|
server_task.await.unwrap().unwrap();
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn stream_closed() {
|
|
|
|
init();
|
|
|
|
|
|
|
|
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
|
|
|
|
let server_task = tokio::spawn(server.serve());
|
|
|
|
|
|
|
|
let stream = TcpStream::connect(handle.address()).await.unwrap();
|
|
|
|
|
|
|
|
let channel = Client::new(stream).login(credentials()).await.unwrap();
|
|
|
|
|
|
|
|
let (_request_tx, mut request_rx) = mpsc::channel(1);
|
|
|
|
let outbound = Box::pin(async_stream::stream! {
|
|
|
|
while let Some(request) = request_rx.recv().await {
|
|
|
|
yield request;
|
|
|
|
}
|
|
|
|
});
|
|
|
|
|
|
|
|
let inbound = channel.run(outbound);
|
|
|
|
tokio::pin!(inbound);
|
|
|
|
|
|
|
|
// Server shuts down, closing its connection before the client has had a
|
|
|
|
// chance to send all of `outbound`.
|
|
|
|
handle.shutdown(ShutdownType::Immediate);
|
|
|
|
|
|
|
|
// Wait for the server to terminate, to avoid race conditions.
|
|
|
|
server_task.await.unwrap().unwrap();
|
|
|
|
drop(request_tx);
|
|
|
|
worker_task
|
|
|
|
.await
|
|
|
|
.expect("joining worker")
|
|
|
|
.expect("running worker");
|
|
|
|
|
|
|
|
// Check that the client returns the correct error, then stops running.
|
|
|
|
assert!(inbound
|
|
|
|
.next()
|
|
|
|
handle.shutdown(ShutdownType::LameDuck);
|
|
|
|
server_task
|
|
|
|
.await
|
|
|
|
.unwrap()
|
|
|
|
.unwrap_err()
|
|
|
|
.is_unexpected_eof());
|
|
|
|
assert!(inbound.next().await.is_none());
|
|
|
|
.expect("joining server")
|
|
|
|
.expect("running server");
|
|
|
|
}
|
|
|
|
}
|