Browse Source

Have Client return a Worker.

wip
Titouan Rigoudy 4 years ago
parent
commit
c1d3f30918
10 changed files with 273 additions and 522 deletions
  1. +10
    -7
      client/src/control/ws.rs
  2. +34
    -42
      client/src/main.rs
  3. +0
    -310
      proto/src/core/channel.rs
  4. +14
    -2
      proto/src/core/frame.rs
  5. +3
    -2
      proto/src/core/mod.rs
  6. +65
    -35
      proto/src/core/worker.rs
  7. +112
    -97
      proto/src/server/client.rs
  8. +1
    -1
      proto/src/server/mod.rs
  9. +9
    -7
      proto/src/server/testing.rs
  10. +25
    -19
      proto/tests/connect.rs

+ 10
- 7
client/src/control/ws.rs View File

@ -16,7 +16,7 @@ use crate::dispatcher::Message;
async fn forward_incoming(
mut incoming: SplitStream<WebSocketStream<TcpStream>>,
message_tx: &mpsc::UnboundedSender<Message>,
message_tx: &mpsc::Sender<Message>,
) -> anyhow::Result<()> {
while let Some(result) = incoming.next().await {
if let Err(WebSocketError::ConnectionClosed) = result {
@ -30,7 +30,10 @@ async fn forward_incoming(
let control_request: Request = serde_json::from_str(&text)
.with_context(|| format!("decoding JSON message {:?}", text))?;
message_tx.send(Message::ControlRequest(control_request));
message_tx
.send(Message::ControlRequest(control_request))
.await
.context("dispatcher channel closed")?;
}
Ok(())
@ -56,7 +59,7 @@ async fn forward_outgoing(
async fn handle(
stream: TcpStream,
remote_address: &SocketAddr,
message_tx: &mpsc::UnboundedSender<Message>,
message_tx: &mpsc::Sender<Message>,
response_rx: &mut mpsc::Receiver<Response>,
) {
let ws_stream = match tokio_tungstenite::accept_async(stream).await {
@ -132,7 +135,7 @@ impl Listener {
/// responses from `response_rx` to the socket.
pub async fn run(
&mut self,
message_tx: mpsc::UnboundedSender<Message>,
message_tx: mpsc::Sender<Message>,
mut response_rx: mpsc::Receiver<Response>,
) -> anyhow::Result<()> {
info!("Accepting control connections on {}", self.address());
@ -187,7 +190,7 @@ mod tests {
// Convenient for tests.
struct RunnableListener {
inner: Listener,
message_tx: mpsc::UnboundedSender<Message>,
message_tx: mpsc::Sender<Message>,
response_rx: mpsc::Receiver<Response>,
}
@ -202,7 +205,7 @@ mod tests {
pub listener: RunnableListener,
pub address: SocketAddr,
pub websocket_address: String,
pub message_rx: mpsc::UnboundedReceiver<Message>,
pub message_rx: mpsc::Receiver<Message>,
pub response_tx: mpsc::Sender<Response>,
}
@ -213,7 +216,7 @@ mod tests {
let address = inner.address().clone();
let websocket_address = format!("ws://{}", address);
let (message_tx, message_rx) = mpsc::unbounded_channel();
let (message_tx, message_rx) = mpsc::channel(100);
let (response_tx, response_rx) = mpsc::channel(100);
let listener = RunnableListener {


+ 34
- 42
client/src/main.rs View File

@ -8,10 +8,10 @@ use anyhow::Context;
use clap::{App, Arg};
use crossbeam_channel;
use env_logger;
use futures::stream::StreamExt;
use log::info;
use solstice_proto::config;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
mod client;
mod context;
@ -48,28 +48,9 @@ fn old_main() {
client.run();
}
// There is a risk of deadlock if we use two bounded channels here:
//
// - client task is blocked trying to send response to dispatcher
// - all dispatcher threads are blocked trying to send requests to client
//
// This stems from the fact that requests are only read from the channel and
// sent to the server when `inbound` is being polled, which is mutually
// exclusive with `dispatcher_tx.send()` being polled.
//
// This could be fixed in one of two ways, at least:
//
// - write `Client` interface in terms of channels, not streams
// - this would allow both receiving and sending tasks to run concurrently
// inside `Client`
// - in other words, sending the response on the dispatcher channel would
// run concurrently with sending requests to the server
// - use `FrameReader` / `FrameWriter` directly instead, and synchronize their
// behavior manually
//
async fn run_client(
mut request_rx: tokio::sync::mpsc::Receiver<solstice_proto::ServerRequest>,
dispatcher_tx: tokio::sync::mpsc::UnboundedSender<dispatcher::Message>,
request_rx: mpsc::Receiver<solstice_proto::ServerRequest>,
dispatcher_tx: mpsc::Sender<dispatcher::Message>,
) -> anyhow::Result<()> {
let address = format!(
"{}:{}",
@ -77,43 +58,54 @@ async fn run_client(
solstice_proto::config::SERVER_PORT
);
info!("Connecting to server at {}.", address);
let stream = TcpStream::connect(address).await?;
let stream = TcpStream::connect(address)
.await
.context("connecting to server")?;
info!("Connection successful.");
let credentials = solstice_proto::server::Credentials::new(
solstice_proto::config::USERNAME.to_string(),
solstice_proto::config::PASSWORD.to_string(),
)
.expect("Invalid credentials");
.context("validating credentials")?;
info!("Logging in to server.");
let client = solstice_proto::server::Client::new(stream)
let mut worker = solstice_proto::server::Client::new(stream)
.login(credentials)
.await?;
let outbound = async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
};
.await
.context("logging in")?;
info!("Login successful.");
// TODO: Define a constant for this, or something.
let (response_tx, mut response_rx) = mpsc::channel(100);
let forwarder_task = tokio::spawn(async move {
while let Some(response) = response_rx.recv().await {
dispatcher_tx
.send(dispatcher::Message::ServerResponse(response))
.await
.expect("dispatcher channel closed");
}
});
info!("Running client.");
let inbound = client.run(outbound);
tokio::pin!(inbound);
info!("Running client worker.");
worker
.run(response_tx, request_rx)
.await
.context("running worker")?;
info!("Client worker finished running.");
while let Some(result) = inbound.next().await {
let response = result?;
dispatcher_tx.send(response.into())?;
}
info!("Joining forwarder.");
forwarder_task.await.context("joining forwarder")?;
info!("Joined forwarder.");
info!("Client finished running.");
Ok(())
}
async fn async_main() -> anyhow::Result<()> {
let bundle = ContextBundle::default();
let (dispatcher_tx, mut dispatcher_rx) =
tokio::sync::mpsc::unbounded_channel();
// TODO: Define a constant for this, or something.
let (dispatcher_tx, mut dispatcher_rx) = mpsc::channel(100);
let client_task =
tokio::spawn(run_client(bundle.server_request_rx, dispatcher_tx.clone()));


+ 0
- 310
proto/src/core/channel.rs View File

@ -1,310 +0,0 @@
use std::fmt::Debug;
use std::future::Future;
use std::io;
use futures::stream::{Stream, StreamExt};
use log::debug;
use thiserror::Error;
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use crate::core::frame::{FrameReader, FrameWriter};
use crate::core::value::{ValueDecode, ValueEncode};
/// An error that arose while exchanging messages over a `Channel`.
#[derive(Debug, Error)]
pub enum ChannelError {
#[error("unexpected end of file")]
UnexpectedEof,
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
}
impl ChannelError {
#[cfg(test)]
pub fn is_unexpected_eof(&self) -> bool {
match self {
ChannelError::UnexpectedEof => true,
_ => false,
}
}
}
/// A wrapper around a frame reader. Logically a part of `Channel`.
///
/// Exists to provide `Channel` functionality that requires borrowing the
/// channel's reader only. This allows borrowing both the reader and the writer
/// at the same time in `Channel` without resorting to static methods.
#[derive(Debug)]
struct ChannelReader<ReadFrame> {
inner: FrameReader<ReadFrame, OwnedReadHalf>,
}
impl<ReadFrame> ChannelReader<ReadFrame>
where
ReadFrame: ValueDecode + Debug,
{
async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
self.inner.read().await.map(|frame| {
debug!("Channel: received frame: {:?}", frame);
frame
})
}
async fn read_strict(&mut self) -> Result<ReadFrame, ChannelError> {
match self.read().await? {
None => Err(ChannelError::UnexpectedEof),
Some(frame) => Ok(frame),
}
}
}
/// An asynchronous bidirectional message channel over TCP.
#[derive(Debug)]
pub struct Channel<ReadFrame, WriteFrame> {
reader: ChannelReader<ReadFrame>,
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
}
impl<ReadFrame, WriteFrame> Channel<ReadFrame, WriteFrame>
where
ReadFrame: ValueDecode + Debug,
WriteFrame: ValueEncode + Debug,
{
/// Wraps the given `stream` to yield a message channel.
pub fn new(stream: TcpStream) -> Self {
let (read_half, write_half) = stream.into_split();
Self {
reader: ChannelReader {
inner: FrameReader::new(read_half),
},
writer: FrameWriter::new(write_half),
}
}
// This future sends all the requests from `request_stream` through `writer`
// until the stream is finished, then resolves.
async fn send<S: Stream<Item = WriteFrame>>(
writer: &mut FrameWriter<WriteFrame, OwnedWriteHalf>,
send_stream: S,
) -> io::Result<()> {
tokio::pin!(send_stream);
while let Some(frame) = send_stream.next().await {
debug!("Channel: sending frame: {:?}", frame);
writer.write(&frame).await?;
}
Ok(())
}
// It would be easier to inline this `select!` call inside `run()`, but that
// fails due to some weird, undiagnosed error due to the interaction of
// `async_stream::try_stream!`, `select!` and the `?` operator.
async fn run_once<S: Future<Output = io::Result<()>>>(
send_task: S,
reader: &mut ChannelReader<ReadFrame>,
) -> Result<Option<ReadFrame>, ChannelError> {
tokio::select! {
send_result = send_task => {
send_result?;
Ok(None)
},
read_result = reader.read_strict() => read_result.map(Some),
}
}
/// Attempts to read a single frame from the underlying stream.
pub async fn read(&mut self) -> Result<ReadFrame, ChannelError> {
self.reader.read_strict().await
}
/// Attempts to write a single frame to the underlying stream.
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
self.writer.write(frame).await
}
/// Sends the given stream of frames while receiving frames in return.
/// Once `send_stream` is exhausted, shuts down the underlying TCP stream,
/// drains incoming frames, then terminates.
pub fn run<S: Stream<Item = WriteFrame>>(
mut self,
send_stream: S,
) -> impl Stream<Item = Result<ReadFrame, ChannelError>> {
async_stream::try_stream! {
// Drive the main loop: send requests and receive responses.
//
// We make a big future out of the operation of waiting for requests
// to send and from `request_stream` and sending them out through
// `self.writer`, that we can then poll repeatedly and concurrently
// with polling for responses. This allows us to concurrently write
// and read from the underlying `TcpStream` in full duplex mode.
{
let send_task = Self::send(&mut self.writer, send_stream);
tokio::pin!(send_task);
while let Some(frame) =
Self::run_once(&mut send_task, &mut self.reader).await? {
yield frame;
}
}
debug!("Channel: shutting down writer");
self.writer.shutdown().await?;
// Drain the receiving end of the connection.
while let Some(frame) = self.reader.read().await? {
yield frame;
}
}
}
}
#[cfg(test)]
mod tests {
use futures::stream::{self, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use super::Channel;
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn read_write() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
assert_eq!(channel.read().await.unwrap(), 1);
channel.write(&2).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
channel.write(&1).await.unwrap();
assert_eq!(channel.read().await.unwrap(), 2);
listener_task.await.unwrap();
}
#[tokio::test]
async fn read_eof() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
// Accept the stream and immediately drop/close it.
listener.accept().await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
assert!(channel.read().await.unwrap_err().is_unexpected_eof());
listener_task.await.unwrap();
}
#[tokio::test]
async fn open_close() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
// Wait forever, receive no messages.
let inbound = channel.run(stream::pending());
tokio::pin!(inbound);
// The server observes an unexpected EOF error when the client
// decides to close the channel.
let error = inbound.next().await.unwrap().unwrap_err();
assert!(error.is_unexpected_eof());
});
let stream = TcpStream::connect(address).await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
// Stop immediately, receive no messages.
let inbound = channel.run(stream::empty());
tokio::pin!(inbound);
// The channel is closed cleanly from the client's point of view.
assert!(inbound.next().await.is_none());
listener_task.await.unwrap();
}
#[tokio::test]
async fn simple_exchange() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let (tx, rx) = oneshot::channel::<u32>();
// Send one message, then wait forever.
let outbound =
stream::once(async move { rx.await.unwrap() }).chain(stream::pending());
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// The server receives the client's message.
assert_eq!(inbound.next().await.unwrap().unwrap(), 1);
// Server responds.
tx.send(1001).unwrap();
// The server observes an unexpected EOF error when the client
// decides to close the channel.
let error = inbound.next().await.unwrap().unwrap_err();
assert!(error.is_unexpected_eof());
});
let stream = TcpStream::connect(address).await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let (tx, rx) = oneshot::channel::<()>();
// Send one message then wait for a reply. If we did not wait, we might
// shut down the channel before the server has a chance to respond.
let outbound = async_stream::stream! {
yield 1;
rx.await.unwrap();
};
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// The client receives the server's message.
assert_eq!(inbound.next().await.unwrap().unwrap(), 1001);
// Signal to the client that we should shut down.
tx.send(()).unwrap();
assert!(inbound.next().await.is_none());
listener_task.await.unwrap();
}
}

+ 14
- 2
proto/src/core/frame.rs View File

@ -155,6 +155,7 @@ where
Frame: ValueEncode + ?Sized,
Writer: AsyncWrite + Unpin,
{
/// Wraps the given `writer`.
pub fn new(writer: Writer) -> Self {
FrameWriter {
encoder: FrameEncoder::new(),
@ -162,14 +163,19 @@ where
}
}
/// Attempts to write the given `frame` to the underlying byte sink.
///
/// Returns an error if encoding `frame` failed or if writing the encoded
/// bytes failed.
pub async fn write(&mut self, frame: &Frame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.writer.write_all(bytes.as_ref()).await
}
pub async fn shutdown(&mut self) -> io::Result<()> {
self.writer.shutdown().await
/// Returns the underlying byte sink.
pub fn into_inner(self) -> Writer {
self.writer
}
}
@ -186,6 +192,7 @@ where
Frame: ValueDecode,
Reader: AsyncRead + Unpin,
{
/// Wraps the given `reader`.
pub fn new(reader: Reader) -> Self {
FrameReader {
decoder: FrameDecoder::new(),
@ -211,6 +218,11 @@ where
}
}
}
/// Returns the underlying reader. Discards any buffered data already read.
pub fn into_inner(self) -> Reader {
return self.reader;
}
}
#[cfg(test)]


+ 3
- 2
proto/src/core/mod.rs View File

@ -1,5 +1,5 @@
pub mod channel;
pub mod constants;
// TODO: Remove `pub` qualifier, depend on re-exports.
pub mod frame;
mod prefix;
mod u32;
@ -8,6 +8,7 @@ mod user;
pub mod value;
mod worker;
pub use frame::{FrameReader, FrameWriter};
pub use user::{User, UserStatus};
pub use worker::{Worker, WorkerError};
pub use value::{ValueDecode, ValueEncode};
pub use worker::{Worker, WorkerError};

+ 65
- 35
proto/src/core/worker.rs View File

@ -24,7 +24,7 @@ pub enum WorkerError {
}
async fn forward_incoming<ReadFrame: ValueDecode + Debug>(
mut reader: FrameReader<ReadFrame, OwnedReadHalf>,
reader: &mut FrameReader<ReadFrame, OwnedReadHalf>,
incoming_tx: mpsc::Sender<ReadFrame>,
) -> Result<(), WorkerError> {
while let Some(frame) = reader.read().await.map_err(WorkerError::ReadError)? {
@ -41,7 +41,7 @@ async fn forward_incoming<ReadFrame: ValueDecode + Debug>(
async fn forward_outgoing<WriteFrame: ValueEncode + Debug>(
mut outgoing_rx: mpsc::Receiver<WriteFrame>,
mut writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
writer: &mut FrameWriter<WriteFrame, OwnedWriteHalf>,
) -> Result<(), WorkerError> {
while let Some(frame) = outgoing_rx.recv().await {
debug!("Sending frame: {:?}", frame);
@ -56,11 +56,10 @@ async fn forward_outgoing<WriteFrame: ValueEncode + Debug>(
}
/// A worker that operates a full-duplex connection exchanging frames over TCP.
#[derive(Debug)]
pub struct Worker<ReadFrame, WriteFrame> {
reader: FrameReader<ReadFrame, OwnedReadHalf>,
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
incoming_tx: mpsc::Sender<ReadFrame>,
outgoing_rx: mpsc::Receiver<WriteFrame>,
}
impl<ReadFrame, WriteFrame> Worker<ReadFrame, WriteFrame>
@ -68,30 +67,33 @@ where
ReadFrame: ValueDecode + Debug,
WriteFrame: ValueEncode + Debug,
{
fn new(
stream: TcpStream,
incoming_tx: mpsc::Sender<ReadFrame>,
outgoing_rx: mpsc::Receiver<WriteFrame>,
) -> Self {
pub fn new(stream: TcpStream) -> Self {
let (read_half, write_half) = stream.into_split();
let reader = FrameReader::new(read_half);
let writer = FrameWriter::new(write_half);
Self {
reader,
writer,
incoming_tx,
outgoing_rx,
}
Self { reader, writer }
}
async fn run(self) -> Result<(), WorkerError> {
pub async fn run(
&mut self,
incoming_tx: mpsc::Sender<ReadFrame>,
outgoing_rx: mpsc::Receiver<WriteFrame>,
) -> Result<(), WorkerError> {
tokio::select! {
result = forward_incoming(self.reader, self.incoming_tx) => result?,
result = forward_outgoing(self.outgoing_rx, self.writer) => result?,
result = forward_incoming(&mut self.reader, incoming_tx) => result?,
result = forward_outgoing(outgoing_rx, &mut self.writer) => result?,
};
Ok(())
}
pub fn into_inner(self) -> TcpStream {
let read_half = self.reader.into_inner();
let write_half = self.writer.into_inner();
read_half
.reunite(write_half)
.expect("reuniting tcp stream halves")
}
}
#[cfg(test)]
@ -113,15 +115,17 @@ mod tests {
async fn stops_on_read_error() {
init();
let listener = TcpListener::bind("localhost:0").await.expect("binding listener");
let listener = TcpListener::bind("localhost:0")
.await
.expect("binding listener");
let address = listener.local_addr().expect("getting local address");
let listener_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.expect("accepting");
let junk = [
1, 0, 0, 0, // Length: 1 byte (big-endian)
0, // This is not enough for a u32, encoded as 4 bytes.
1, 0, 0, 0, // Length: 1 byte (big-endian)
0, // This is not enough for a u32, encoded as 4 bytes.
];
stream.write_all(&junk).await.expect("writing frame");
stream.shutdown().await.expect("shutting down");
@ -131,9 +135,12 @@ mod tests {
let (_request_tx, request_rx) = mpsc::channel::<u32>(100);
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
let worker = Worker::new(stream, response_tx, request_rx);
let mut worker = Worker::new(stream);
let err = worker.run().await.expect_err("running worker");
let err = worker
.run(response_tx, request_rx)
.await
.expect_err("running worker");
if let WorkerError::ReadError(_) = err {
// Ok!
} else {
@ -165,12 +172,15 @@ mod tests {
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
let worker = Worker::new(stream, response_tx, request_rx);
let mut worker = Worker::new(stream);
// Queue a frame before we run the worker.
request_tx.send(42).await.expect("sending frame");
let err = worker.run().await.expect_err("running worker");
let err = worker
.run(response_tx, request_rx)
.await
.expect_err("running worker");
if let WorkerError::WriteError(_) = err {
// Ok!
} else {
@ -196,7 +206,10 @@ mod tests {
writer.write(&42u32).await.expect("writing frame");
let mut buf = Vec::new();
read_half.read_to_end(&mut buf).await.expect("waiting for eof");
read_half
.read_to_end(&mut buf)
.await
.expect("waiting for eof");
assert_eq!(buf, Vec::<u8>::new());
});
@ -204,18 +217,24 @@ mod tests {
let (_request_tx, request_rx) = mpsc::channel::<u32>(100);
let (response_tx, response_rx) = mpsc::channel::<u32>(100);
let worker = Worker::new(stream, response_tx, request_rx);
let mut worker = Worker::new(stream);
// Drop the receiver before the worker can send anything.
drop(response_rx);
let err = worker.run().await.expect_err("running worker");
let err = worker
.run(response_tx, request_rx)
.await
.expect_err("running worker");
if let WorkerError::IncomingChannelClosed = err {
// Ok!
} else {
panic!("Wrong error: {:?}", err);
}
// Drop the worker, and the underlying connection, to stop the listener.
drop(worker);
listener_task.await.expect("joining listener");
}
@ -235,7 +254,10 @@ mod tests {
writer.write(&frame).await.expect("writing frame");
let mut buf = Vec::new();
read_half.read_to_end(&mut buf).await.expect("waiting for eof");
read_half
.read_to_end(&mut buf)
.await
.expect("waiting for eof");
assert_eq!(buf, Vec::<u8>::new());
});
@ -243,9 +265,10 @@ mod tests {
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
let (response_tx, mut response_rx) = mpsc::channel::<u32>(100);
let worker = Worker::new(stream, response_tx, request_rx);
let mut worker = Worker::new(stream);
let worker_task = tokio::spawn(worker.run());
let worker_task =
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
let frame = response_rx.recv().await.expect("receiving frame");
assert_eq!(frame, 42);
@ -253,7 +276,10 @@ mod tests {
// Signal to the worker that it should stop running.
drop(request_tx);
worker_task.await.expect("joining worker").expect("running worker");
worker_task
.await
.expect("joining worker")
.expect("running worker");
listener_task.await.expect("joining listener");
}
@ -276,16 +302,20 @@ mod tests {
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
let worker = Worker::new(stream, response_tx, request_rx);
let mut worker = Worker::new(stream);
let worker_task = tokio::spawn(worker.run());
let worker_task =
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
request_tx.send(42).await.expect("sending frame");
// Signal to the worker that it should stop running.
drop(request_tx);
worker_task.await.expect("joining worker").expect("running worker");
worker_task
.await
.expect("joining worker")
.expect("running worker");
listener_task.await.expect("joining listener");
}
}

+ 112
- 97
proto/src/server/client.rs View File

@ -1,20 +1,22 @@
//! A client interface for remote servers.
use std::io;
use log::{debug, info};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::core::channel::{Channel, ChannelError};
use crate::core::{Worker, WorkerError};
use crate::server::{
Credentials, LoginResponse, ServerRequest, ServerResponse, Version,
};
/// A `Worker` that sends `ServerRequest`s and receives `ServerResponse`s.
pub type ClientWorker = Worker<ServerResponse, ServerRequest>;
/// A client for the client-server protocol.
#[derive(Debug)]
pub struct Client {
channel: Channel<ServerResponse, ServerRequest>,
stream: TcpStream,
version: Version,
}
@ -22,29 +24,26 @@ pub struct Client {
#[derive(Debug, Error)]
pub enum ClientLoginError {
#[error("login failed: {0}")]
LoginFailed(String, Client),
LoginFailed(String, ClientWorker),
#[error("unexpected response: {0:?}")]
UnexpectedResponse(ServerResponse),
#[error("channel error: {0}")]
ChannelError(#[from] ChannelError),
}
#[error("send error: {0}")]
SendError(#[from] mpsc::error::SendError<ServerRequest>),
impl From<io::Error> for ClientLoginError {
fn from(error: io::Error) -> Self {
ClientLoginError::from(ChannelError::from(error))
}
}
#[error("worker error: {0}")]
WorkerError(#[from] WorkerError),
/// A `Channel` that sends `ServerRequest`s and receives `ServerResponse`s.
pub type ClientChannel = Channel<ServerResponse, ServerRequest>;
#[error("stream closed unexpectedly")]
StreamClosed,
}
impl Client {
/// Instantiates a new client
pub fn new(tcp_stream: TcpStream) -> Self {
pub fn new(stream: TcpStream) -> Self {
Client {
channel: Channel::new(tcp_stream),
stream,
version: Version::default(),
}
}
@ -57,32 +56,52 @@ impl Client {
/// Performs the login exchange, presenting `credentials` to the server.
pub async fn login(
mut self,
self,
credentials: Credentials,
) -> Result<ClientChannel, ClientLoginError> {
) -> Result<ClientWorker, ClientLoginError> {
let mut worker = ClientWorker::new(self.stream);
let (request_tx, request_rx) = mpsc::channel(1);
let (response_tx, mut response_rx) = mpsc::channel(1);
let worker_task = tokio::spawn(async move {
worker
.run(response_tx, request_rx)
.await
.map(move |()| worker)
});
let login_request = credentials.into_login_request(self.version);
debug!("Client: sending login request: {:?}", login_request);
debug!("Sending login request: {:?}", login_request);
let request = login_request.into();
self.channel.write(&request).await?;
request_tx.send(login_request.into()).await?;
let optional_response = response_rx.recv().await;
let response = self.channel.read().await?;
debug!("Client: received first response: {:?}", response);
// Join the worker even if we received `None`, in case it failed.
// Panic in case of join error, as if we had run the worker itself.
drop(request_tx);
let worker = worker_task.await.expect("joining worker")?;
let response = match optional_response {
None => return Err(ClientLoginError::StreamClosed),
Some(response) => response,
};
debug!("Received first response: {:?}", response);
match response {
ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd,
ip,
password_md5_opt,
}) => {
info!("Client: Logged in successfully!");
info!("Client: Message Of The Day: {}", motd);
info!("Client: Public IP address: {}", ip);
info!("Client: Password MD5: {:?}", password_md5_opt);
Ok(self.channel)
info!("Login: success!");
info!("Login: Message Of The Day: {}", motd);
info!("Login: Public IP address: {}", ip);
info!("Login: Password MD5: {:?}", password_md5_opt);
Ok(worker)
}
ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => {
Err(ClientLoginError::LoginFailed(reason, self))
Err(ClientLoginError::LoginFailed(reason, worker))
}
response => Err(ClientLoginError::UnexpectedResponse(response)),
}
@ -91,7 +110,6 @@ impl Client {
#[cfg(test)]
mod tests {
use futures::stream::{empty, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
@ -113,28 +131,43 @@ mod tests {
fn credentials() -> Credentials {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
Credentials::new(user_name, password).unwrap()
Credentials::new(user_name, password).expect("building credentials")
}
// TODO: Tests for all login error conditions:
//
// - login failed
// - unexpected response
// - read error
// - write error
// - stream closed
#[tokio::test]
async fn login() {
async fn login_success() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let (server, handle) = ServerBuilder::default()
.bind()
.await
.expect("binding server");
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address()).await.unwrap();
let channel = Client::new(stream).login(credentials()).await.unwrap();
let stream = TcpStream::connect(handle.address())
.await
.expect("connecting");
// Send nothing, receive no responses.
let inbound = channel.run(empty());
tokio::pin!(inbound);
let worker = Client::new(stream)
.login(credentials())
.await
.expect("logging in");
assert!(inbound.next().await.is_none());
drop(worker);
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
server_task
.await
.expect("joining server")
.expect("running server");
}
#[tokio::test]
@ -142,7 +175,7 @@ mod tests {
init();
let response = UserStatusResponse {
user_name: "alice".to_string(),
user_name: "shruti".to_string(),
status: UserStatus::Online,
is_privileged: false,
};
@ -154,70 +187,52 @@ mod tests {
.with_user_status_map(user_status_map)
.bind()
.await
.unwrap();
.expect("binding server");
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address()).await.unwrap();
let stream = TcpStream::connect(handle.address())
.await
.expect("connecting");
let mut worker = Client::new(stream)
.login(credentials())
.await
.expect("logging in");
let channel = Client::new(stream).login(credentials()).await.unwrap();
let (request_tx, request_rx) = mpsc::channel(100);
let (response_tx, mut response_rx) = mpsc::channel(100);
let outbound = Box::pin(async_stream::stream! {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "alice".to_string(),
});
});
request_tx
.send(ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "shruti".to_string(),
}))
.await
.expect("sending shruti");
request_tx
.send(ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "karandeep".to_string(),
}))
.await
.expect("sending karandeep");
let inbound = channel.run(outbound);
tokio::pin!(inbound);
let worker_task =
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
assert_eq!(
inbound.next().await.unwrap().unwrap(),
ServerResponse::UserStatusResponse(response)
response_rx.recv().await,
Some(ServerResponse::UserStatusResponse(response))
);
assert!(inbound.next().await.is_none());
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
#[tokio::test]
async fn stream_closed() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address()).await.unwrap();
let channel = Client::new(stream).login(credentials()).await.unwrap();
let (_request_tx, mut request_rx) = mpsc::channel(1);
let outbound = Box::pin(async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
});
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// Server shuts down, closing its connection before the client has had a
// chance to send all of `outbound`.
handle.shutdown(ShutdownType::Immediate);
// Wait for the server to terminate, to avoid race conditions.
server_task.await.unwrap().unwrap();
drop(request_tx);
worker_task
.await
.expect("joining worker")
.expect("running worker");
// Check that the client returns the correct error, then stops running.
assert!(inbound
.next()
handle.shutdown(ShutdownType::LameDuck);
server_task
.await
.unwrap()
.unwrap_err()
.is_unexpected_eof());
assert!(inbound.next().await.is_none());
.expect("joining server")
.expect("running server");
}
}

+ 1
- 1
proto/src/server/mod.rs View File

@ -7,7 +7,7 @@ mod response;
mod testing;
mod version;
pub use self::client::{Client, ClientChannel, ClientLoginError};
pub use self::client::{Client, ClientLoginError, ClientWorker};
pub use self::credentials::Credentials;
pub use self::request::*;
pub use self::response::*;


+ 9
- 7
proto/src/server/testing.rs View File

@ -57,14 +57,14 @@ impl Handler {
&mut self,
response: &ServerResponse,
) -> io::Result<()> {
debug!("Handler: sending response: {:?}", response);
debug!("Sending response: {:?}", response);
self.writer.write(response).await
}
async fn handle_login(&mut self) -> io::Result<()> {
match self.reader.read().await? {
Some(ServerRequest::LoginRequest(request)) => {
info!("Handler: Received login request: {:?}", request);
info!("Received login request: {:?}", request);
}
Some(request) => {
return Err(io::Error::new(
@ -89,7 +89,7 @@ impl Handler {
}
async fn handle_request(&mut self, request: ServerRequest) -> io::Result<()> {
debug!("Handler: received request: {:?}", request);
debug!("Received request: {:?}", request);
match request {
ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => {
@ -97,10 +97,12 @@ impl Handler {
if let Some(response) = entry {
let response = ServerResponse::UserStatusResponse(response);
self.send_response(&response).await?;
} else {
warn!("Received UserStatusRequest for unknown user {}", user_name);
}
}
_ => {
warn!("Handler: unhandled request: {:?}", request);
warn!("Unhandled request: {:?}", request);
}
}
@ -114,7 +116,7 @@ impl Handler {
self.handle_request(request).await?;
}
info!("Handler: client disconnecting, shutting down");
info!("Client disconnecting, shutting down");
Ok(())
}
}
@ -129,14 +131,14 @@ impl GracefulHandler {
tokio::select!(
result = self.handler.run() => {
if let Err(ref error) = result {
warn!("GracefulHandler: handler returned error {:?}", error);
warn!("Handler returned error {:?}", error);
}
result
},
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = self.shutdown_rx.changed() => {
info!("GracefulHandler: shutting down.");
info!("Handler shutting down.");
Ok(())
},
)


+ 25
- 19
proto/tests/connect.rs View File

@ -1,4 +1,3 @@
use futures::stream::{self, StreamExt};
use tokio::io;
use tokio::net;
use tokio::sync::mpsc;
@ -32,12 +31,19 @@ async fn integration_connect() {
let stream = connect().await.unwrap();
let credentials = make_credentials(make_user_name("connect"));
let channel = Client::new(stream).login(credentials).await.unwrap();
let mut worker = Client::new(stream).login(credentials).await.unwrap();
let inbound = channel.run(stream::pending());
tokio::pin!(inbound);
let (request_tx, request_rx) = mpsc::channel(100);
let (response_tx, _response_rx) = mpsc::channel(100);
assert!(inbound.next().await.is_some());
let worker_task =
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
drop(request_tx);
worker_task
.await
.expect("joining worker")
.expect("running worker");
}
#[tokio::test]
@ -48,18 +54,10 @@ async fn integration_check_user_status() {
let user_name = make_user_name("check_user_status");
let credentials = make_credentials(user_name.clone());
let channel = Client::new(stream).login(credentials).await.unwrap();
let (request_tx, mut request_rx) = mpsc::channel(1);
let outbound = async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
};
let mut worker = Client::new(stream).login(credentials).await.unwrap();
let inbound = channel.run(outbound);
tokio::pin!(inbound);
let (request_tx, request_rx) = mpsc::channel(100);
let (response_tx, mut response_rx) = mpsc::channel(100);
request_tx
.send(
@ -69,15 +67,23 @@ async fn integration_check_user_status() {
.into(),
)
.await
.unwrap();
.expect("sending request");
let worker_task =
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
while let Some(result) = inbound.next().await {
let response = result.unwrap();
log::info!("Waiting for responses.");
while let Some(response) = response_rx.recv().await {
if let ServerResponse::UserStatusResponse(response) = response {
assert_eq!(response.user_name, user_name);
return;
}
}
worker_task
.await
.expect("joining worker")
.expect("running worker");
unreachable!();
}

Loading…
Cancel
Save