From e06b3c86b63aa6a87c1e49acfbde3c23fafb026e Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Thu, 15 Jul 2021 16:31:18 -0400 Subject: [PATCH] Define message-type-agnostic Channel type. --- proto/src/core/channel.rs | 253 +++++++++++++++++++++++++++++++++++++ proto/src/core/mod.rs | 1 + proto/src/server/client.rs | 153 +++++----------------- proto/tests/connect.rs | 6 +- 4 files changed, 287 insertions(+), 126 deletions(-) create mode 100644 proto/src/core/channel.rs diff --git a/proto/src/core/channel.rs b/proto/src/core/channel.rs new file mode 100644 index 0000000..f0379b5 --- /dev/null +++ b/proto/src/core/channel.rs @@ -0,0 +1,253 @@ +use std::fmt::Debug; +use std::future::Future; +use std::io; + +use futures::stream::{Stream, StreamExt}; +use log::debug; +use thiserror::Error; +use tokio::net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, +}; + +use crate::core::frame::{FrameReader, FrameWriter}; +use crate::core::value::{ValueDecode, ValueEncode}; + +/// An error that arose while exchanging messages over a `Channel`. +#[derive(Debug, Error)] +pub enum ChannelError { + #[error("unexpected end of file")] + UnexpectedEof, + + #[error("i/o error: {0}")] + IOError(#[from] io::Error), +} + +impl ChannelError { + #[cfg(test)] + pub fn is_unexpected_eof(&self) -> bool { + match self { + ChannelError::UnexpectedEof => true, + _ => false, + } + } +} + +/// An asynchronous bidirectional message channel over TCP. +#[derive(Debug)] +pub struct Channel { + reader: FrameReader, + writer: FrameWriter, +} + +impl Channel +where + ReadFrame: ValueDecode + Debug, + WriteFrame: ValueEncode + Debug, +{ + /// Wraps the given `stream` to yield a message channel. + pub fn new(stream: TcpStream) -> Self { + let (read_half, write_half) = stream.into_split(); + Self { + reader: FrameReader::new(read_half), + writer: FrameWriter::new(write_half), + } + } + + async fn read( + reader: &mut FrameReader, + ) -> io::Result> { + reader.read().await.map(|frame| { + debug!("Channel: received frame: {:?}", frame); + frame + }) + } + + async fn read_strict( + reader: &mut FrameReader, + ) -> Result { + match reader.read().await? { + None => Err(ChannelError::UnexpectedEof), + Some(frame) => Ok(frame), + } + } + + // This future sends all the requests from `request_stream` through `writer` + // until the stream is finished, then resolves. + async fn send>( + writer: &mut FrameWriter, + send_stream: S, + ) -> io::Result<()> { + tokio::pin!(send_stream); + while let Some(frame) = send_stream.next().await { + debug!("Channel: sending frame: {:?}", frame); + writer.write(&frame).await?; + } + Ok(()) + } + + // It would be easier to inline this `select!` call inside `run()`, but that + // fails due to some weird, undiagnosed error due to the interaction of + // `async_stream::try_stream!`, `select!` and the `?` operator. + async fn run_once>>( + send_task: S, + reader: &mut FrameReader, + ) -> Result, ChannelError> { + tokio::select! { + send_result = send_task => { + send_result?; + Ok(None) + }, + read_result = Self::read_strict(reader) => read_result.map(Some), + } + } + + /// Attempts to read a single frame from the underlying stream. + pub async fn read_once(&mut self) -> Result { + Self::read_strict(&mut self.reader).await + } + + /// Attempts to write a single frame to the underlying stream. + pub async fn write_once(&mut self, frame: &WriteFrame) -> io::Result<()> { + self.writer.write(frame).await + } + + /// Sends the given stream of frames while receiving frames in return. + /// Once `send_stream` is exhausted, shuts down the underlying TCP stream, + /// drains incoming frames, then terminates. + pub fn run>( + mut self, + send_stream: S, + ) -> impl Stream> { + async_stream::try_stream! { + // Drive the main loop: send requests and receive responses. + // + // We make a big future out of the operation of waiting for requests + // to send and from `request_stream` and sending them out through + // `self.writer`, that we can then poll repeatedly and concurrently + // with polling for responses. This allows us to concurrently write + // and read from the underlying `TcpStream` in full duplex mode. + { + let send_task = Self::send(&mut self.writer, send_stream); + tokio::pin!(send_task); + + while let Some(frame) = + Self::run_once(&mut send_task, &mut self.reader).await? { + yield frame; + } + } + + debug!("Channel: shutting down writer"); + self.writer.shutdown().await?; + + // Drain the receiving end of the connection. + while let Some(frame) = Self::read(&mut self.reader).await? { + yield frame; + } + } + } +} + +#[cfg(test)] +mod tests { + use futures::stream::{self, StreamExt}; + use tokio::net::{TcpListener, TcpStream}; + use tokio::sync::oneshot; + + use super::Channel; + + // Enable capturing logs in tests. + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn open_close() { + init(); + + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let listener_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let channel = Channel::::new(stream); + + // Wait forever, receive no messages. + let inbound = channel.run(stream::pending()); + tokio::pin!(inbound); + + // The server observes an unexpected EOF error when the client + // decides to close the channel. + let error = inbound.next().await.unwrap().unwrap_err(); + assert!(error.is_unexpected_eof()); + }); + + let stream = TcpStream::connect(address).await.unwrap(); + let channel = Channel::::new(stream); + + // Stop immediately, receive no messages. + let inbound = channel.run(stream::empty()); + tokio::pin!(inbound); + + // The channel is closed cleanly from the client's point of view. + assert!(inbound.next().await.is_none()); + + listener_task.await.unwrap(); + } + + #[tokio::test] + async fn simple_exchange() { + init(); + + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let listener_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let channel = Channel::::new(stream); + + let (tx, rx) = oneshot::channel::(); + + // Send one message, then wait forever. + let outbound = stream::once(async move { rx.await.unwrap() }).chain(stream::pending()); + let inbound = channel.run(outbound); + tokio::pin!(inbound); + + // The server receives the client's message. + assert_eq!(inbound.next().await.unwrap().unwrap(), 1); + + // Server responds. + tx.send(1001).unwrap(); + + // The server observes an unexpected EOF error when the client + // decides to close the channel. + let error = inbound.next().await.unwrap().unwrap_err(); + assert!(error.is_unexpected_eof()); + }); + + let stream = TcpStream::connect(address).await.unwrap(); + let channel = Channel::::new(stream); + + // Send a message, then close the channel. + let (tx, rx) = oneshot::channel::<()>(); + + // Send one message then wait for a reply. If we did not wait, we might + // shut down the channel before the server has a chance to respond. + let outbound = async_stream::stream! { + yield 1; + rx.await.unwrap(); + }; + let inbound = channel.run(outbound); + tokio::pin!(inbound); + + // The client receives the server's message. + assert_eq!(inbound.next().await.unwrap().unwrap(), 1001); + + // Signal to the client that we should shut down. + tx.send(()).unwrap(); + + assert!(inbound.next().await.is_none()); + + listener_task.await.unwrap(); + } +} diff --git a/proto/src/core/mod.rs b/proto/src/core/mod.rs index 9d2870d..0e336b3 100644 --- a/proto/src/core/mod.rs +++ b/proto/src/core/mod.rs @@ -1,3 +1,4 @@ +pub mod channel; pub mod constants; pub mod frame; mod prefix; diff --git a/proto/src/server/client.rs b/proto/src/server/client.rs index 3b1a772..39fbd10 100644 --- a/proto/src/server/client.rs +++ b/proto/src/server/client.rs @@ -1,17 +1,13 @@ //! A client interface for remote servers. -use std::future::Future; use std::io; -use futures::stream::{Stream, StreamExt}; +use futures::stream::Stream; use log::{debug, info}; use thiserror::Error; -use tokio::net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpStream, -}; +use tokio::net::TcpStream; -use crate::core::frame::{FrameReader, FrameWriter}; +use crate::core::channel::{Channel, ChannelError}; use crate::server::{Credentials, LoginResponse, ServerRequest, ServerResponse, Version}; /// Specifies options for a new `Client`. @@ -22,8 +18,7 @@ pub struct ClientOptions { /// A client for the client-server protocol. pub struct Client { - reader: FrameReader, - writer: FrameWriter, + channel: Channel, } /// An error that arose while logging in to a remote server. @@ -35,30 +30,13 @@ pub enum ClientLoginError { #[error("unexpected response: {0:?}")] UnexpectedResponse(ServerResponse), - #[error("unexpected end of file")] - UnexpectedEof, - - #[error("i/o error: {0}")] - IOError(#[from] io::Error), -} - -/// An error that arose while running the client. -#[derive(Debug, Error)] -pub enum ClientRunError { - #[error("underlying stream was closed unexpectedly")] - StreamClosed, - - #[error("i/o error: {0}")] - IOError(#[from] io::Error), + #[error("channel error: {0}")] + ChannelError(#[from] ChannelError), } -impl ClientRunError { - #[cfg(test)] - fn is_stream_closed(&self) -> bool { - match self { - ClientRunError::StreamClosed => true, - _ => false, - } +impl From for ClientLoginError { + fn from(error: io::Error) -> Self { + ClientLoginError::from(ChannelError::from(error)) } } @@ -67,10 +45,8 @@ impl Client { tcp_stream: TcpStream, options: ClientOptions, ) -> Result { - let (read_half, write_half) = tcp_stream.into_split(); let mut client = Client { - reader: FrameReader::new(read_half), - writer: FrameWriter::new(write_half), + channel: Channel::new(tcp_stream), }; client.handshake(options).await?; @@ -85,111 +61,35 @@ impl Client { debug!("Client: sending login request: {:?}", login_request); let request = login_request.into(); - self.writer.write(&request).await?; + self.channel.write_once(&request).await?; - let response = self.reader.read().await?; + let response = self.channel.read_once().await?; debug!("Client: received first response: {:?}", response); match response { - Some(ServerResponse::LoginResponse(LoginResponse::LoginOk { + ServerResponse::LoginResponse(LoginResponse::LoginOk { motd, ip, password_md5_opt, - })) => { + }) => { info!("Client: Logged in successfully!"); info!("Client: Message Of The Day: {}", motd); info!("Client: Public IP address: {}", ip); info!("Client: Password MD5: {:?}", password_md5_opt); Ok(()) } - Some(ServerResponse::LoginResponse(LoginResponse::LoginFail { reason })) => { + ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => { Err(ClientLoginError::LoginFailed(reason)) } - Some(response) => Err(ClientLoginError::UnexpectedResponse(response)), - None => Err(ClientLoginError::UnexpectedEof), - } - } - - async fn read( - reader: &mut FrameReader, - ) -> Result { - match reader.read().await? { - Some(response) => { - debug!("Client: received response: {:?}", response); - Ok(response) - } - None => Err(ClientRunError::StreamClosed), - } - } - - // This future sends all the requests from `request_stream` through `writer` - // until the stream is finished, then resolves. - async fn send( - writer: &mut FrameWriter, - mut request_stream: S, - ) -> io::Result<()> - where - S: Stream + Unpin, - { - while let Some(request) = request_stream.next().await { - debug!("Client: sending request: {:?}", request); - writer.write(&request).await?; - } - Ok(()) - } - - // It would be easier to inline this `select!` call inside `run()`, but that - // fails due to some weird, undiagnosed error due to the interaction of - // `async_stream::try_stream!`, `select!` and the `?` operator. - async fn run_once( - send: impl Future>, - reader: &mut FrameReader, - ) -> Result, ClientRunError> { - tokio::select! { - send_result = send => { - send_result?; - Ok(None) - }, - read_result = Self::read(reader) => { - let response = read_result?; - Ok(Some(response)) - }, + response => Err(ClientLoginError::UnexpectedResponse(response)), } } - pub fn run( - mut self, + pub fn run>( + self, request_stream: S, - ) -> impl Stream> + Unpin - where - S: Stream + Unpin, - { - Box::pin(async_stream::try_stream! { - // Drive the main loop: send requests and receive responses. - // - // We make a big future out of the operation of waiting for requests - // to send and from `request_stream` and sending them out through - // `self.writer`, that we can then poll repeatedly and concurrently - // with polling for responses. This allows us to concurrently write - // and read from the underlying `TcpStream` in full duplex mode. - { - let send = Self::send(&mut self.writer, request_stream); - tokio::pin!(send); - - while let Some(response) = Self::run_once(&mut send, &mut self.reader).await? { - yield response; - } - } - - debug!("Client: shutting down outbound stream"); - self.writer.shutdown().await?; - - // Drain the receiving end of the connection. - while let Some(response) = self.reader.read().await? { - debug!("Client: received response: {:?}", response); - yield response; - } - }) + ) -> impl Stream> { + self.channel.run(request_stream) } } @@ -236,7 +136,9 @@ mod tests { let client = Client::login(stream, client_options()).await.unwrap(); // Send nothing, receive no responses. - let mut inbound = client.run(empty()); + let inbound = client.run(empty()); + tokio::pin!(inbound); + assert!(inbound.next().await.is_none()); handle.shutdown(ShutdownType::LameDuck); @@ -276,7 +178,9 @@ mod tests { }); }); - let mut inbound = client.run(outbound); + let inbound = client.run(outbound); + tokio::pin!(inbound); + assert_eq!( inbound.next().await.unwrap().unwrap(), ServerResponse::UserStatusResponse(response) @@ -305,7 +209,8 @@ mod tests { } }); - let mut inbound = client.run(outbound); + let inbound = client.run(outbound); + tokio::pin!(inbound); // Server shuts down, closing its connection before the client has had a // chance to send all of `outbound`. @@ -320,7 +225,7 @@ mod tests { .await .unwrap() .unwrap_err() - .is_stream_closed()); + .is_unexpected_eof()); assert!(inbound.next().await.is_none()); } } diff --git a/proto/tests/connect.rs b/proto/tests/connect.rs index aef0d0a..a0b8204 100644 --- a/proto/tests/connect.rs +++ b/proto/tests/connect.rs @@ -39,7 +39,8 @@ async fn integration_connect() { let options = client_options(make_user_name("connect")); let client = Client::login(stream, options).await.unwrap(); - let mut inbound = client.run(stream::pending()); + let inbound = client.run(stream::pending()); + tokio::pin!(inbound); assert!(inbound.next().await.is_some()); } @@ -62,7 +63,8 @@ async fn integration_check_user_status() { } }); - let mut inbound = client.run(outbound); + let inbound = client.run(outbound); + tokio::pin!(inbound); request_tx .send(