use std::fmt::Debug;
|
|
use std::future::Future;
|
|
use std::io;
|
|
|
|
use futures::stream::{Stream, StreamExt};
|
|
use log::debug;
|
|
use thiserror::Error;
|
|
use tokio::net::{
|
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
|
TcpStream,
|
|
};
|
|
|
|
use crate::core::frame::{FrameReader, FrameWriter};
|
|
use crate::core::value::{ValueDecode, ValueEncode};
|
|
|
|
/// An error that arose while exchanging messages over a `Channel`.
|
|
#[derive(Debug, Error)]
|
|
pub enum ChannelError {
|
|
#[error("unexpected end of file")]
|
|
UnexpectedEof,
|
|
|
|
#[error("i/o error: {0}")]
|
|
IOError(#[from] io::Error),
|
|
}
|
|
|
|
impl ChannelError {
|
|
#[cfg(test)]
|
|
pub fn is_unexpected_eof(&self) -> bool {
|
|
match self {
|
|
ChannelError::UnexpectedEof => true,
|
|
_ => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A wrapper around a frame reader. Logically a part of `Channel`.
|
|
///
|
|
/// Exists to provide `Channel` functionality that requires borrowing the
|
|
/// channel's reader only. This allows borrowing both the reader and the writer
|
|
/// at the same time in `Channel` without resorting to static methods.
|
|
#[derive(Debug)]
|
|
struct ChannelReader<ReadFrame> {
|
|
inner: FrameReader<ReadFrame, OwnedReadHalf>,
|
|
}
|
|
|
|
impl<ReadFrame> ChannelReader<ReadFrame>
|
|
where
|
|
ReadFrame: ValueDecode + Debug,
|
|
{
|
|
async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
|
|
self.inner.read().await.map(|frame| {
|
|
debug!("Channel: received frame: {:?}", frame);
|
|
frame
|
|
})
|
|
}
|
|
|
|
async fn read_strict(&mut self) -> Result<ReadFrame, ChannelError> {
|
|
match self.read().await? {
|
|
None => Err(ChannelError::UnexpectedEof),
|
|
Some(frame) => Ok(frame),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// An asynchronous bidirectional message channel over TCP.
|
|
#[derive(Debug)]
|
|
pub struct Channel<ReadFrame, WriteFrame> {
|
|
reader: ChannelReader<ReadFrame>,
|
|
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
|
|
}
|
|
|
|
impl<ReadFrame, WriteFrame> Channel<ReadFrame, WriteFrame>
|
|
where
|
|
ReadFrame: ValueDecode + Debug,
|
|
WriteFrame: ValueEncode + Debug,
|
|
{
|
|
/// Wraps the given `stream` to yield a message channel.
|
|
pub fn new(stream: TcpStream) -> Self {
|
|
let (read_half, write_half) = stream.into_split();
|
|
Self {
|
|
reader: ChannelReader {
|
|
inner: FrameReader::new(read_half),
|
|
},
|
|
writer: FrameWriter::new(write_half),
|
|
}
|
|
}
|
|
|
|
// This future sends all the requests from `request_stream` through `writer`
|
|
// until the stream is finished, then resolves.
|
|
async fn send<S: Stream<Item = WriteFrame>>(
|
|
writer: &mut FrameWriter<WriteFrame, OwnedWriteHalf>,
|
|
send_stream: S,
|
|
) -> io::Result<()> {
|
|
tokio::pin!(send_stream);
|
|
while let Some(frame) = send_stream.next().await {
|
|
debug!("Channel: sending frame: {:?}", frame);
|
|
writer.write(&frame).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// It would be easier to inline this `select!` call inside `run()`, but that
|
|
// fails due to some weird, undiagnosed error due to the interaction of
|
|
// `async_stream::try_stream!`, `select!` and the `?` operator.
|
|
async fn run_once<S: Future<Output = io::Result<()>>>(
|
|
send_task: S,
|
|
reader: &mut ChannelReader<ReadFrame>,
|
|
) -> Result<Option<ReadFrame>, ChannelError> {
|
|
tokio::select! {
|
|
send_result = send_task => {
|
|
send_result?;
|
|
Ok(None)
|
|
},
|
|
read_result = reader.read_strict() => read_result.map(Some),
|
|
}
|
|
}
|
|
|
|
/// Attempts to read a single frame from the underlying stream.
|
|
pub async fn read(&mut self) -> Result<ReadFrame, ChannelError> {
|
|
self.reader.read_strict().await
|
|
}
|
|
|
|
/// Attempts to write a single frame to the underlying stream.
|
|
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
|
|
self.writer.write(frame).await
|
|
}
|
|
|
|
/// Sends the given stream of frames while receiving frames in return.
|
|
/// Once `send_stream` is exhausted, shuts down the underlying TCP stream,
|
|
/// drains incoming frames, then terminates.
|
|
pub fn run<S: Stream<Item = WriteFrame>>(
|
|
mut self,
|
|
send_stream: S,
|
|
) -> impl Stream<Item = Result<ReadFrame, ChannelError>> {
|
|
async_stream::try_stream! {
|
|
// Drive the main loop: send requests and receive responses.
|
|
//
|
|
// We make a big future out of the operation of waiting for requests
|
|
// to send and from `request_stream` and sending them out through
|
|
// `self.writer`, that we can then poll repeatedly and concurrently
|
|
// with polling for responses. This allows us to concurrently write
|
|
// and read from the underlying `TcpStream` in full duplex mode.
|
|
{
|
|
let send_task = Self::send(&mut self.writer, send_stream);
|
|
tokio::pin!(send_task);
|
|
|
|
while let Some(frame) =
|
|
Self::run_once(&mut send_task, &mut self.reader).await? {
|
|
yield frame;
|
|
}
|
|
}
|
|
|
|
debug!("Channel: shutting down writer");
|
|
self.writer.shutdown().await?;
|
|
|
|
// Drain the receiving end of the connection.
|
|
while let Some(frame) = self.reader.read().await? {
|
|
yield frame;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use futures::stream::{self, StreamExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::oneshot;

    use super::Channel;

    /// Every test exchanges plain `u32` frames in both directions.
    type U32Channel = Channel<u32, u32>;

    // Sets up log capturing for tests. Only the first initialization can
    // succeed, so the result is deliberately ignored.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    // One frame travels in each direction using the single-frame API.
    #[tokio::test]
    async fn read_write() {
        init();

        let listener = TcpListener::bind("localhost:0").await.unwrap();
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (accepted, _) = listener.accept().await.unwrap();
            let mut server_channel = U32Channel::new(accepted);

            let request = server_channel.read().await.unwrap();
            assert_eq!(request, 1);
            server_channel.write(&2).await.unwrap();
        });

        let connection = TcpStream::connect(address).await.unwrap();
        let mut client_channel = U32Channel::new(connection);

        client_channel.write(&1).await.unwrap();
        let response = client_channel.read().await.unwrap();
        assert_eq!(response, 2);

        server.await.unwrap();
    }

    // Reading from a connection the peer has already closed reports an
    // unexpected end of file.
    #[tokio::test]
    async fn read_eof() {
        init();

        let listener = TcpListener::bind("localhost:0").await.unwrap();
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            // Accept the connection, then close it right away by dropping it.
            drop(listener.accept().await.unwrap());
        });

        let connection = TcpStream::connect(address).await.unwrap();
        let mut client_channel = U32Channel::new(connection);

        let error = client_channel.read().await.unwrap_err();
        assert!(error.is_unexpected_eof());

        server.await.unwrap();
    }

    // The client has nothing to send and closes immediately; the server never
    // sends anything.
    #[tokio::test]
    async fn open_close() {
        init();

        let listener = TcpListener::bind("localhost:0").await.unwrap();
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (accepted, _) = listener.accept().await.unwrap();
            let server_channel = U32Channel::new(accepted);

            // The outbound stream never produces anything and never ends.
            let received = server_channel.run(stream::pending());
            tokio::pin!(received);

            // From the server's side, the client closing the channel shows
            // up as an unexpected EOF error.
            let outcome = received.next().await.unwrap();
            assert!(outcome.unwrap_err().is_unexpected_eof());
        });

        let connection = TcpStream::connect(address).await.unwrap();
        let client_channel = U32Channel::new(connection);

        // The outbound stream is empty, so the client stops right away.
        let received = client_channel.run(stream::empty());
        tokio::pin!(received);

        // From the client's side, the channel closes cleanly.
        assert!(received.next().await.is_none());

        server.await.unwrap();
    }

    // The client sends one frame, the server answers with one frame, then the
    // client closes the channel.
    #[tokio::test]
    async fn simple_exchange() {
        init();

        let listener = TcpListener::bind("localhost:0").await.unwrap();
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (accepted, _) = listener.accept().await.unwrap();
            let server_channel = U32Channel::new(accepted);

            let (reply_tx, reply_rx) = oneshot::channel::<u32>();

            // Outbound: a single frame, supplied later through `reply_tx`,
            // followed by silence forever.
            let replies = stream::once(async move { reply_rx.await.unwrap() })
                .chain(stream::pending());
            let received = server_channel.run(replies);
            tokio::pin!(received);

            // The client's frame arrives first.
            let request = received.next().await.unwrap().unwrap();
            assert_eq!(request, 1);

            // Now let the server answer.
            reply_tx.send(1001).unwrap();

            // From the server's side, the client closing the channel shows
            // up as an unexpected EOF error.
            let outcome = received.next().await.unwrap();
            assert!(outcome.unwrap_err().is_unexpected_eof());
        });

        let connection = TcpStream::connect(address).await.unwrap();
        let client_channel = U32Channel::new(connection);

        let (done_tx, done_rx) = oneshot::channel::<()>();

        // Outbound: one frame, then hold the stream open until told to stop.
        // Ending the stream right after the frame could shut the channel down
        // before the server has had a chance to respond.
        let requests = async_stream::stream! {
            yield 1;
            done_rx.await.unwrap();
        };
        let received = client_channel.run(requests);
        tokio::pin!(received);

        // The server's answer arrives.
        let response = received.next().await.unwrap().unwrap();
        assert_eq!(response, 1001);

        // Let the outbound stream finish, which makes the client shut down.
        done_tx.send(()).unwrap();

        assert!(received.next().await.is_none());

        server.await.unwrap();
    }
}
|