Browse Source

Introduce proto Worker.

wip
Titouan Rigoudy 4 years ago
parent
commit
0aa0c48591
2 changed files with 189 additions and 0 deletions
  1. +4
    -0
      proto/src/core/mod.rs
  2. +185
    -0
      proto/src/core/worker.rs

+ 4
- 0
proto/src/core/mod.rs View File

@ -4,6 +4,10 @@ pub mod frame;
mod prefix;
mod u32;
mod user;
// TODO: Remove `pub` qualifier, depend on re-exports.
pub mod value;
mod worker;
pub use user::{User, UserStatus};
pub use worker::{Worker, WorkerError};
pub use value::{ValueDecode, ValueEncode};

+ 185
- 0
proto/src/core/worker.rs View File

@ -0,0 +1,185 @@
use std::fmt::Debug;
use std::io;
use log::debug;
use thiserror::Error;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::core::frame::{FrameReader, FrameWriter};
use crate::core::value::{ValueDecode, ValueEncode};
/// An error that arose while exchanging messages over a `Channel`.
#[derive(Debug, Error)]
pub enum WorkerError {
#[error("read error: {0}")]
ReadError(io::Error),
#[error("write error: {0}")]
WriteError(io::Error),
#[error("incoming frame channel is unexpectedly closed")]
IncomingChannelClosed,
}
async fn forward_incoming<ReadFrame: ValueDecode + Debug>(
mut reader: FrameReader<ReadFrame, OwnedReadHalf>,
incoming_tx: mpsc::Sender<ReadFrame>,
) -> Result<(), WorkerError> {
while let Some(frame) = reader.read().await.map_err(WorkerError::ReadError)? {
debug!("Channel: received frame: {:?}", frame);
if let Err(_) = incoming_tx.send(frame).await {
return Err(WorkerError::IncomingChannelClosed);
}
}
debug!("Stopping incoming handler: frame reader is closed");
Ok(())
}
async fn forward_outgoing<WriteFrame: ValueEncode + Debug>(
mut outgoing_rx: mpsc::Receiver<WriteFrame>,
mut writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
) -> Result<(), WorkerError> {
while let Some(frame) = outgoing_rx.recv().await {
debug!("Channel: sending frame: {:?}", frame);
writer
.write(&frame)
.await
.map_err(WorkerError::WriteError)?;
}
debug!("Stopping outgoing handler: channel is closed");
Ok(())
}
/// A worker that operates a full-duplex connection exchanging frames over TCP.
pub struct Worker<ReadFrame, WriteFrame> {
reader: FrameReader<ReadFrame, OwnedReadHalf>,
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
incoming_tx: mpsc::Sender<ReadFrame>,
outgoing_rx: mpsc::Receiver<WriteFrame>,
}
impl<ReadFrame, WriteFrame> Worker<ReadFrame, WriteFrame>
where
ReadFrame: ValueDecode + Debug,
WriteFrame: ValueEncode + Debug,
{
fn new(
stream: TcpStream,
incoming_tx: mpsc::Sender<ReadFrame>,
outgoing_rx: mpsc::Receiver<WriteFrame>,
) -> Self {
let (read_half, write_half) = stream.into_split();
let reader = FrameReader::new(read_half);
let writer = FrameWriter::new(write_half);
Self {
reader,
writer,
incoming_tx,
outgoing_rx,
}
}
async fn run(self) -> Result<(), WorkerError> {
tokio::select! {
result = forward_incoming(self.reader, self.incoming_tx) => {
debug!("{:?}", result);
},
result = forward_outgoing(self.outgoing_rx, self.writer) => {
debug!("{:?}", result);
},
};
Ok(())
}
}
#[cfg(test)]
mod tests {
use tokio::io::AsyncReadExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use crate::core::frame::{FrameReader, FrameWriter};
use super::Worker;
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn forwards_incoming_frames() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let (mut read_half, write_half) = stream.split();
let mut writer = FrameWriter::new(write_half);
let frame: u32 = 42;
writer.write(&frame).await.expect("writing frame");
let mut buf = Vec::new();
read_half.read_to_end(&mut buf).await.expect("waiting for eof");
assert_eq!(buf, Vec::<u8>::new());
});
let stream = TcpStream::connect(address).await.expect("connecting");
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
let (response_tx, mut response_rx) = mpsc::channel::<u32>(100);
let worker = Worker::new(stream, response_tx, request_rx);
let worker_task = tokio::spawn(worker.run());
let frame = response_rx.recv().await.expect("receiving frame");
assert_eq!(frame, 42);
// Signal to the worker that it should stop running.
drop(request_tx);
worker_task.await.expect("joining worker").expect("running worker");
listener_task.await.expect("joining listener");
}
#[tokio::test]
async fn forwards_outgoing_frames() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let mut reader = FrameReader::new(stream);
let frame = reader.read().await.expect("reading frame");
assert_eq!(frame, Some(42u32));
});
let stream = TcpStream::connect(address).await.expect("connecting");
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
let worker = Worker::new(stream, response_tx, request_rx);
let worker_task = tokio::spawn(worker.run());
request_tx.send(42).await.expect("sending frame");
// Signal to the worker that it should stop running.
drop(request_tx);
worker_task.await.expect("joining worker").expect("running worker");
listener_task.await.expect("joining listener");
}
}

Loading…
Cancel
Save