Solstice client.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

310 lines
9.1 KiB

use std::fmt::Debug;
use std::future::Future;
use std::io;
use futures::stream::{Stream, StreamExt};
use log::debug;
use thiserror::Error;
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use crate::core::frame::{FrameReader, FrameWriter};
use crate::core::value::{ValueDecode, ValueEncode};
/// An error that arose while exchanging messages over a `Channel`.
#[derive(Debug, Error)]
pub enum ChannelError {
    /// The incoming side of the channel ended where a frame was still
    /// expected.
    #[error("unexpected end of file")]
    UnexpectedEof,

    /// An underlying I/O operation failed.
    #[error("i/o error: {0}")]
    IOError(#[from] io::Error),
}

impl ChannelError {
    /// Returns `true` if this error is [`ChannelError::UnexpectedEof`].
    #[cfg(test)]
    pub fn is_unexpected_eof(&self) -> bool {
        // `matches!` replaces a two-arm `match` returning `true`/`false`
        // (Clippy `match_like_matches_macro`).
        matches!(self, ChannelError::UnexpectedEof)
    }
}
/// The read side of a `Channel`, wrapping a frame reader.
///
/// This is a separate type so that `Channel` can borrow its reader and its
/// writer at the same time without resorting to static methods that take
/// each half explicitly.
#[derive(Debug)]
struct ChannelReader<ReadFrame> {
    inner: FrameReader<ReadFrame, OwnedReadHalf>,
}

impl<ReadFrame> ChannelReader<ReadFrame>
where
    ReadFrame: ValueDecode + Debug,
{
    /// Reads the next frame from the underlying frame reader and logs it at
    /// debug level.
    ///
    /// I/O errors are passed through without logging. `read_strict` treats
    /// an `Ok(None)` result as end of file.
    async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
        let received = self.inner.read().await?;
        debug!("Channel: received frame: {:?}", received);
        Ok(received)
    }

    /// Like `read`, but reports end of file as
    /// `ChannelError::UnexpectedEof` instead of returning `None`.
    async fn read_strict(&mut self) -> Result<ReadFrame, ChannelError> {
        self.read().await?.ok_or(ChannelError::UnexpectedEof)
    }
}
/// An asynchronous bidirectional message channel over TCP.
///
/// Incoming `ReadFrame`s are decoded from the read half of a `TcpStream`.
/// Outgoing `WriteFrame`s are encoded onto its write half.
#[derive(Debug)]
pub struct Channel<ReadFrame, WriteFrame> {
    // Read half. It has its own type so it can be borrowed alongside `writer`.
    reader: ChannelReader<ReadFrame>,
    // Write half of the split stream.
    writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
}

impl<ReadFrame, WriteFrame> Channel<ReadFrame, WriteFrame>
where
    ReadFrame: ValueDecode + Debug,
    WriteFrame: ValueEncode + Debug,
{
    /// Wraps the given `stream` to yield a message channel.
    ///
    /// The stream is split into owned read and write halves, which lets
    /// `run` read and write concurrently.
    pub fn new(stream: TcpStream) -> Self {
        let (read_half, write_half) = stream.into_split();
        Self {
            reader: ChannelReader {
                inner: FrameReader::new(read_half),
            },
            writer: FrameWriter::new(write_half),
        }
    }

    // Writes every frame yielded by `send_stream` through `writer`, logging
    // each one, and resolves with `Ok(())` once the stream is exhausted.
    // The first write error is returned immediately, and any remaining frames
    // are not sent.
    //
    // This takes `writer` rather than `&mut self`. That way `run` can hold
    // this future, which borrows only the writer, while also borrowing
    // `self.reader`.
    async fn send<S: Stream<Item = WriteFrame>>(
        writer: &mut FrameWriter<WriteFrame, OwnedWriteHalf>,
        send_stream: S,
    ) -> io::Result<()> {
        // `StreamExt::next` requires an `Unpin` stream, so pin it on the stack.
        tokio::pin!(send_stream);
        while let Some(frame) = send_stream.next().await {
            debug!("Channel: sending frame: {:?}", frame);
            writer.write(&frame).await?;
        }
        Ok(())
    }

    // Waits for whichever happens first: `send_task` completing, or a frame
    // arriving on `reader`.
    //
    // Returns one of:
    // - `Ok(None)` if `send_task` completed successfully;
    // - `Ok(Some(frame))` if a frame was read;
    // - an error if the winning branch failed.
    // A read that reaches end of file is reported (via `read_strict`) as
    // `ChannelError::UnexpectedEof`.
    //
    // `run` passes a `&mut` to a pinned future. So when the read wins, only
    // that reference is dropped, and the send future resumes on the next call.
    //
    // NOTE(review): when `send_task` wins, `tokio::select!` drops the pending
    // `read_strict()` future. That is only lossless if `FrameReader::read` is
    // cancellation-safe (no partially read frame is discarded). Confirm.
    //
    // It would be easier to inline this `select!` call inside `run()`, but that
    // fails due to some weird, undiagnosed error due to the interaction of
    // `async_stream::try_stream!`, `select!` and the `?` operator.
    async fn run_once<S: Future<Output = io::Result<()>>>(
        send_task: S,
        reader: &mut ChannelReader<ReadFrame>,
    ) -> Result<Option<ReadFrame>, ChannelError> {
        tokio::select! {
            send_result = send_task => {
                send_result?;
                Ok(None)
            },
            read_result = reader.read_strict() => read_result.map(Some),
        }
    }

    /// Attempts to read a single frame from the underlying stream.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::UnexpectedEof` if the stream ends before a frame
    /// is read, or `ChannelError::IOError` on an I/O failure.
    pub async fn read(&mut self) -> Result<ReadFrame, ChannelError> {
        self.reader.read_strict().await
    }

    /// Attempts to write a single frame to the underlying stream.
    pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
        self.writer.write(frame).await
    }

    /// Sends the given stream of frames while receiving frames in return.
    ///
    /// Frames from `send_stream` are written while incoming frames are read
    /// concurrently; the returned stream yields the incoming frames. Once
    /// `send_stream` is exhausted, this:
    /// 1. shuts down the writer (the write half of the underlying TCP stream);
    /// 2. drains the remaining incoming frames until end of file;
    /// 3. terminates.
    ///
    /// # Errors
    ///
    /// Yields an `Err` item in these cases:
    /// - a write or read fails;
    /// - the incoming side reaches end of file before `send_stream` is
    ///   exhausted (`ChannelError::UnexpectedEof`).
    ///
    /// End of file during the final drain phase ends the stream without an
    /// error.
    pub fn run<S: Stream<Item = WriteFrame>>(
        mut self,
        send_stream: S,
    ) -> impl Stream<Item = Result<ReadFrame, ChannelError>> {
        async_stream::try_stream! {
            // Drive the main loop: send outgoing frames and receive incoming
            // frames.
            //
            // We build one big future that waits for frames from `send_stream`
            // and sends them out through `self.writer`. We can then poll it
            // repeatedly, concurrently with polling for incoming frames. This
            // lets us write to and read from the underlying `TcpStream` at the
            // same time, in full duplex mode.
            //
            // The inner scope ends `send_task`'s mutable borrow of
            // `self.writer`, so the writer can be shut down afterwards.
            {
                let send_task = Self::send(&mut self.writer, send_stream);
                tokio::pin!(send_task);
                while let Some(frame) =
                    Self::run_once(&mut send_task, &mut self.reader).await? {
                    yield frame;
                }
            }
            debug!("Channel: shutting down writer");
            self.writer.shutdown().await?;
            // Drain the receiving end of the connection. Here, end of file
            // (`None`) is the expected way to finish.
            while let Some(frame) = self.reader.read().await? {
                yield frame;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use futures::stream::{self, StreamExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::oneshot;

    use super::Channel;

    /// Turns on log capture for the current test, then binds a listener to
    /// an ephemeral port on localhost.
    async fn listen() -> TcpListener {
        // Ignore the result: another test may already have installed a logger.
        let _ = env_logger::builder().is_test(true).try_init();
        TcpListener::bind("localhost:0").await.unwrap()
    }

    #[tokio::test]
    async fn read_write() {
        let listener = listen().await;
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut server_channel = Channel::<u32, u32>::new(socket);
            assert_eq!(server_channel.read().await.unwrap(), 1);
            server_channel.write(&2).await.unwrap();
        });

        let socket = TcpStream::connect(address).await.unwrap();
        let mut client_channel = Channel::<u32, u32>::new(socket);
        client_channel.write(&1).await.unwrap();
        assert_eq!(client_channel.read().await.unwrap(), 2);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn read_eof() {
        let listener = listen().await;
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            // Accept the connection and close it straight away.
            drop(listener.accept().await.unwrap());
        });

        let socket = TcpStream::connect(address).await.unwrap();
        let mut client_channel = Channel::<u32, u32>::new(socket);
        assert!(client_channel.read().await.unwrap_err().is_unexpected_eof());

        server.await.unwrap();
    }

    #[tokio::test]
    async fn open_close() {
        let listener = listen().await;
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            // Never send anything; just wait for the client to hang up.
            let inbound = Channel::<u32, u32>::new(socket).run(stream::pending());
            tokio::pin!(inbound);
            // The client closing its side shows up here as an unexpected EOF.
            let failure = inbound.next().await.unwrap().unwrap_err();
            assert!(failure.is_unexpected_eof());
        });

        let socket = TcpStream::connect(address).await.unwrap();
        // Nothing to send: the outbound stream ends immediately.
        let inbound = Channel::<u32, u32>::new(socket).run(stream::empty());
        tokio::pin!(inbound);
        // From the client's point of view the channel closes cleanly.
        assert!(inbound.next().await.is_none());

        server.await.unwrap();
    }

    #[tokio::test]
    async fn simple_exchange() {
        let listener = listen().await;
        let address = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let server_channel = Channel::<u32, u32>::new(socket);
            let (reply_tx, reply_rx) = oneshot::channel::<u32>();
            // A single reply, released by `reply_tx`, followed by silence.
            let outbound = stream::once(async move { reply_rx.await.unwrap() })
                .chain(stream::pending());
            let inbound = server_channel.run(outbound);
            tokio::pin!(inbound);

            // The client's message arrives first...
            assert_eq!(inbound.next().await.unwrap().unwrap(), 1);
            // ...then the server releases its reply.
            reply_tx.send(1001).unwrap();

            // Once the client shuts down, the server sees an unexpected EOF.
            let failure = inbound.next().await.unwrap().unwrap_err();
            assert!(failure.is_unexpected_eof());
        });

        let socket = TcpStream::connect(address).await.unwrap();
        let client_channel = Channel::<u32, u32>::new(socket);
        let (done_tx, done_rx) = oneshot::channel::<()>();
        // Send one message, then hold the outbound stream open until told to
        // stop. Otherwise the client could shut down before the reply arrives.
        let outbound = async_stream::stream! {
            yield 1;
            done_rx.await.unwrap();
        };
        let inbound = client_channel.run(outbound);
        tokio::pin!(inbound);

        assert_eq!(inbound.next().await.unwrap().unwrap(), 1001);
        // Let the outbound stream finish so the client shuts down.
        done_tx.send(()).unwrap();
        assert!(inbound.next().await.is_none());

        server.await.unwrap();
    }
}