use std::fmt::Debug; use std::future::Future; use std::io; use futures::stream::{Stream, StreamExt}; use log::debug; use thiserror::Error; use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpStream, }; use crate::core::frame::{FrameReader, FrameWriter}; use crate::core::value::{ValueDecode, ValueEncode}; /// An error that arose while exchanging messages over a `Channel`. #[derive(Debug, Error)] pub enum ChannelError { #[error("unexpected end of file")] UnexpectedEof, #[error("i/o error: {0}")] IOError(#[from] io::Error), } impl ChannelError { #[cfg(test)] pub fn is_unexpected_eof(&self) -> bool { match self { ChannelError::UnexpectedEof => true, _ => false, } } } /// A wrapper around a frame reader. Logically a part of `Channel`. /// /// Exists to provide `Channel` functionality that requires borrowing the /// channel's reader only. This allows borrowing both the reader and the writer /// at the same time in `Channel` without resorting to static methods. #[derive(Debug)] struct ChannelReader { inner: FrameReader, } impl ChannelReader where ReadFrame: ValueDecode + Debug, { async fn read(&mut self) -> io::Result> { self.inner.read().await.map(|frame| { debug!("Channel: received frame: {:?}", frame); frame }) } async fn read_strict(&mut self) -> Result { match self.read().await? { None => Err(ChannelError::UnexpectedEof), Some(frame) => Ok(frame), } } } /// An asynchronous bidirectional message channel over TCP. #[derive(Debug)] pub struct Channel { reader: ChannelReader, writer: FrameWriter, } impl Channel where ReadFrame: ValueDecode + Debug, WriteFrame: ValueEncode + Debug, { /// Wraps the given `stream` to yield a message channel. pub fn new(stream: TcpStream) -> Self { let (read_half, write_half) = stream.into_split(); Self { reader: ChannelReader { inner: FrameReader::new(read_half), }, writer: FrameWriter::new(write_half), } } // This future sends all the requests from `request_stream` through `writer` // until the stream is finished, then resolves. async fn send>( writer: &mut FrameWriter, send_stream: S, ) -> io::Result<()> { tokio::pin!(send_stream); while let Some(frame) = send_stream.next().await { debug!("Channel: sending frame: {:?}", frame); writer.write(&frame).await?; } Ok(()) } // It would be easier to inline this `select!` call inside `run()`, but that // fails due to some weird, undiagnosed error due to the interaction of // `async_stream::try_stream!`, `select!` and the `?` operator. async fn run_once>>( send_task: S, reader: &mut ChannelReader, ) -> Result, ChannelError> { tokio::select! { send_result = send_task => { send_result?; Ok(None) }, read_result = reader.read_strict() => read_result.map(Some), } } /// Attempts to read a single frame from the underlying stream. pub async fn read(&mut self) -> Result { self.reader.read_strict().await } /// Attempts to write a single frame to the underlying stream. pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> { self.writer.write(frame).await } /// Sends the given stream of frames while receiving frames in return. /// Once `send_stream` is exhausted, shuts down the underlying TCP stream, /// drains incoming frames, then terminates. pub fn run>( mut self, send_stream: S, ) -> impl Stream> { async_stream::try_stream! { // Drive the main loop: send requests and receive responses. // // We make a big future out of the operation of waiting for requests // to send and from `request_stream` and sending them out through // `self.writer`, that we can then poll repeatedly and concurrently // with polling for responses. This allows us to concurrently write // and read from the underlying `TcpStream` in full duplex mode. { let send_task = Self::send(&mut self.writer, send_stream); tokio::pin!(send_task); while let Some(frame) = Self::run_once(&mut send_task, &mut self.reader).await? { yield frame; } } debug!("Channel: shutting down writer"); self.writer.shutdown().await?; // Drain the receiving end of the connection. while let Some(frame) = self.reader.read().await? { yield frame; } } } } #[cfg(test)] mod tests { use futures::stream::{self, StreamExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; use super::Channel; // Enable capturing logs in tests. fn init() { let _ = env_logger::builder().is_test(true).try_init(); } #[tokio::test] async fn read_write() { init(); let listener = TcpListener::bind("localhost:0").await.unwrap(); let address = listener.local_addr().unwrap(); let listener_task = tokio::spawn(async move { let (stream, _) = listener.accept().await.unwrap(); let mut channel = Channel::::new(stream); assert_eq!(channel.read().await.unwrap(), 1); channel.write(&2).await.unwrap(); }); let stream = TcpStream::connect(address).await.unwrap(); let mut channel = Channel::::new(stream); channel.write(&1).await.unwrap(); assert_eq!(channel.read().await.unwrap(), 2); listener_task.await.unwrap(); } #[tokio::test] async fn read_eof() { init(); let listener = TcpListener::bind("localhost:0").await.unwrap(); let address = listener.local_addr().unwrap(); let listener_task = tokio::spawn(async move { // Accept the stream and immediately drop/close it. listener.accept().await.unwrap(); }); let stream = TcpStream::connect(address).await.unwrap(); let mut channel = Channel::::new(stream); assert!(channel.read().await.unwrap_err().is_unexpected_eof()); listener_task.await.unwrap(); } #[tokio::test] async fn open_close() { init(); let listener = TcpListener::bind("localhost:0").await.unwrap(); let address = listener.local_addr().unwrap(); let listener_task = tokio::spawn(async move { let (stream, _) = listener.accept().await.unwrap(); let channel = Channel::::new(stream); // Wait forever, receive no messages. let inbound = channel.run(stream::pending()); tokio::pin!(inbound); // The server observes an unexpected EOF error when the client // decides to close the channel. let error = inbound.next().await.unwrap().unwrap_err(); assert!(error.is_unexpected_eof()); }); let stream = TcpStream::connect(address).await.unwrap(); let channel = Channel::::new(stream); // Stop immediately, receive no messages. let inbound = channel.run(stream::empty()); tokio::pin!(inbound); // The channel is closed cleanly from the client's point of view. assert!(inbound.next().await.is_none()); listener_task.await.unwrap(); } #[tokio::test] async fn simple_exchange() { init(); let listener = TcpListener::bind("localhost:0").await.unwrap(); let address = listener.local_addr().unwrap(); let listener_task = tokio::spawn(async move { let (stream, _) = listener.accept().await.unwrap(); let channel = Channel::::new(stream); let (tx, rx) = oneshot::channel::(); // Send one message, then wait forever. let outbound = stream::once(async move { rx.await.unwrap() }).chain(stream::pending()); let inbound = channel.run(outbound); tokio::pin!(inbound); // The server receives the client's message. assert_eq!(inbound.next().await.unwrap().unwrap(), 1); // Server responds. tx.send(1001).unwrap(); // The server observes an unexpected EOF error when the client // decides to close the channel. let error = inbound.next().await.unwrap().unwrap_err(); assert!(error.is_unexpected_eof()); }); let stream = TcpStream::connect(address).await.unwrap(); let channel = Channel::::new(stream); let (tx, rx) = oneshot::channel::<()>(); // Send one message then wait for a reply. If we did not wait, we might // shut down the channel before the server has a chance to respond. let outbound = async_stream::stream! { yield 1; rx.await.unwrap(); }; let inbound = channel.run(outbound); tokio::pin!(inbound); // The client receives the server's message. assert_eq!(inbound.next().await.unwrap().unwrap(), 1001); // Signal to the client that we should shut down. tx.send(()).unwrap(); assert!(inbound.next().await.is_none()); listener_task.await.unwrap(); } }