|
|
@ -24,7 +24,7 @@ pub enum WorkerError { |
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
async fn forward_incoming<ReadFrame: ValueDecode + Debug>(
|
|
|
async fn forward_incoming<ReadFrame: ValueDecode + Debug>(
|
|
|
mut reader: FrameReader<ReadFrame, OwnedReadHalf>,
|
|
|
|
|
|
|
|
|
reader: &mut FrameReader<ReadFrame, OwnedReadHalf>,
|
|
|
incoming_tx: mpsc::Sender<ReadFrame>,
|
|
|
incoming_tx: mpsc::Sender<ReadFrame>,
|
|
|
) -> Result<(), WorkerError> {
|
|
|
) -> Result<(), WorkerError> {
|
|
|
while let Some(frame) = reader.read().await.map_err(WorkerError::ReadError)? {
|
|
|
while let Some(frame) = reader.read().await.map_err(WorkerError::ReadError)? {
|
|
|
@ -41,7 +41,7 @@ async fn forward_incoming<ReadFrame: ValueDecode + Debug>( |
|
|
|
|
|
|
|
|
async fn forward_outgoing<WriteFrame: ValueEncode + Debug>(
|
|
|
async fn forward_outgoing<WriteFrame: ValueEncode + Debug>(
|
|
|
mut outgoing_rx: mpsc::Receiver<WriteFrame>,
|
|
|
mut outgoing_rx: mpsc::Receiver<WriteFrame>,
|
|
|
mut writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
|
|
|
|
|
|
|
|
|
writer: &mut FrameWriter<WriteFrame, OwnedWriteHalf>,
|
|
|
) -> Result<(), WorkerError> {
|
|
|
) -> Result<(), WorkerError> {
|
|
|
while let Some(frame) = outgoing_rx.recv().await {
|
|
|
while let Some(frame) = outgoing_rx.recv().await {
|
|
|
debug!("Sending frame: {:?}", frame);
|
|
|
debug!("Sending frame: {:?}", frame);
|
|
|
@ -56,11 +56,10 @@ async fn forward_outgoing<WriteFrame: ValueEncode + Debug>( |
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
/// A worker that operates a full-duplex connection exchanging frames over TCP.
|
|
|
/// A worker that operates a full-duplex connection exchanging frames over TCP.
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
pub struct Worker<ReadFrame, WriteFrame> {
|
|
|
pub struct Worker<ReadFrame, WriteFrame> {
|
|
|
reader: FrameReader<ReadFrame, OwnedReadHalf>,
|
|
|
reader: FrameReader<ReadFrame, OwnedReadHalf>,
|
|
|
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
|
|
|
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
|
|
|
incoming_tx: mpsc::Sender<ReadFrame>,
|
|
|
|
|
|
outgoing_rx: mpsc::Receiver<WriteFrame>,
|
|
|
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
impl<ReadFrame, WriteFrame> Worker<ReadFrame, WriteFrame>
|
|
|
impl<ReadFrame, WriteFrame> Worker<ReadFrame, WriteFrame>
|
|
|
@ -68,30 +67,33 @@ where |
|
|
ReadFrame: ValueDecode + Debug,
|
|
|
ReadFrame: ValueDecode + Debug,
|
|
|
WriteFrame: ValueEncode + Debug,
|
|
|
WriteFrame: ValueEncode + Debug,
|
|
|
{
|
|
|
{
|
|
|
fn new(
|
|
|
|
|
|
stream: TcpStream,
|
|
|
|
|
|
incoming_tx: mpsc::Sender<ReadFrame>,
|
|
|
|
|
|
outgoing_rx: mpsc::Receiver<WriteFrame>,
|
|
|
|
|
|
) -> Self {
|
|
|
|
|
|
|
|
|
pub fn new(stream: TcpStream) -> Self {
|
|
|
let (read_half, write_half) = stream.into_split();
|
|
|
let (read_half, write_half) = stream.into_split();
|
|
|
let reader = FrameReader::new(read_half);
|
|
|
let reader = FrameReader::new(read_half);
|
|
|
let writer = FrameWriter::new(write_half);
|
|
|
let writer = FrameWriter::new(write_half);
|
|
|
Self {
|
|
|
|
|
|
reader,
|
|
|
|
|
|
writer,
|
|
|
|
|
|
incoming_tx,
|
|
|
|
|
|
outgoing_rx,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
Self { reader, writer }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
async fn run(self) -> Result<(), WorkerError> {
|
|
|
|
|
|
|
|
|
pub async fn run(
|
|
|
|
|
|
&mut self,
|
|
|
|
|
|
incoming_tx: mpsc::Sender<ReadFrame>,
|
|
|
|
|
|
outgoing_rx: mpsc::Receiver<WriteFrame>,
|
|
|
|
|
|
) -> Result<(), WorkerError> {
|
|
|
tokio::select! {
|
|
|
tokio::select! {
|
|
|
result = forward_incoming(self.reader, self.incoming_tx) => result?,
|
|
|
|
|
|
result = forward_outgoing(self.outgoing_rx, self.writer) => result?,
|
|
|
|
|
|
|
|
|
result = forward_incoming(&mut self.reader, incoming_tx) => result?,
|
|
|
|
|
|
result = forward_outgoing(outgoing_rx, &mut self.writer) => result?,
|
|
|
};
|
|
|
};
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
Ok(())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub fn into_inner(self) -> TcpStream {
|
|
|
|
|
|
let read_half = self.reader.into_inner();
|
|
|
|
|
|
let write_half = self.writer.into_inner();
|
|
|
|
|
|
read_half
|
|
|
|
|
|
.reunite(write_half)
|
|
|
|
|
|
.expect("reuniting tcp stream halves")
|
|
|
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
#[cfg(test)]
|
|
|
@ -113,15 +115,17 @@ mod tests { |
|
|
async fn stops_on_read_error() {
|
|
|
async fn stops_on_read_error() {
|
|
|
init();
|
|
|
init();
|
|
|
|
|
|
|
|
|
let listener = TcpListener::bind("localhost:0").await.expect("binding listener");
|
|
|
|
|
|
|
|
|
let listener = TcpListener::bind("localhost:0")
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect("binding listener");
|
|
|
let address = listener.local_addr().expect("getting local address");
|
|
|
let address = listener.local_addr().expect("getting local address");
|
|
|
|
|
|
|
|
|
let listener_task = tokio::spawn(async move {
|
|
|
let listener_task = tokio::spawn(async move {
|
|
|
let (mut stream, _) = listener.accept().await.expect("accepting");
|
|
|
let (mut stream, _) = listener.accept().await.expect("accepting");
|
|
|
|
|
|
|
|
|
let junk = [
|
|
|
let junk = [
|
|
|
1, 0, 0, 0, // Length: 1 byte (big-endian)
|
|
|
|
|
|
0, // This is not enough for a u32, encoded as 4 bytes.
|
|
|
|
|
|
|
|
|
1, 0, 0, 0, // Length: 1 byte (big-endian)
|
|
|
|
|
|
0, // This is not enough for a u32, encoded as 4 bytes.
|
|
|
];
|
|
|
];
|
|
|
stream.write_all(&junk).await.expect("writing frame");
|
|
|
stream.write_all(&junk).await.expect("writing frame");
|
|
|
stream.shutdown().await.expect("shutting down");
|
|
|
stream.shutdown().await.expect("shutting down");
|
|
|
@ -131,9 +135,12 @@ mod tests { |
|
|
|
|
|
|
|
|
let (_request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (_request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
|
|
|
let worker = Worker::new(stream, response_tx, request_rx);
|
|
|
|
|
|
|
|
|
let mut worker = Worker::new(stream);
|
|
|
|
|
|
|
|
|
let err = worker.run().await.expect_err("running worker");
|
|
|
|
|
|
|
|
|
let err = worker
|
|
|
|
|
|
.run(response_tx, request_rx)
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect_err("running worker");
|
|
|
if let WorkerError::ReadError(_) = err {
|
|
|
if let WorkerError::ReadError(_) = err {
|
|
|
// Ok!
|
|
|
// Ok!
|
|
|
} else {
|
|
|
} else {
|
|
|
@ -165,12 +172,15 @@ mod tests { |
|
|
|
|
|
|
|
|
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
|
|
|
let worker = Worker::new(stream, response_tx, request_rx);
|
|
|
|
|
|
|
|
|
let mut worker = Worker::new(stream);
|
|
|
|
|
|
|
|
|
// Queue a frame before we run the worker.
|
|
|
// Queue a frame before we run the worker.
|
|
|
request_tx.send(42).await.expect("sending frame");
|
|
|
request_tx.send(42).await.expect("sending frame");
|
|
|
|
|
|
|
|
|
let err = worker.run().await.expect_err("running worker");
|
|
|
|
|
|
|
|
|
let err = worker
|
|
|
|
|
|
.run(response_tx, request_rx)
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect_err("running worker");
|
|
|
if let WorkerError::WriteError(_) = err {
|
|
|
if let WorkerError::WriteError(_) = err {
|
|
|
// Ok!
|
|
|
// Ok!
|
|
|
} else {
|
|
|
} else {
|
|
|
@ -196,7 +206,10 @@ mod tests { |
|
|
writer.write(&42u32).await.expect("writing frame");
|
|
|
writer.write(&42u32).await.expect("writing frame");
|
|
|
|
|
|
|
|
|
let mut buf = Vec::new();
|
|
|
let mut buf = Vec::new();
|
|
|
read_half.read_to_end(&mut buf).await.expect("waiting for eof");
|
|
|
|
|
|
|
|
|
read_half
|
|
|
|
|
|
.read_to_end(&mut buf)
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect("waiting for eof");
|
|
|
assert_eq!(buf, Vec::<u8>::new());
|
|
|
assert_eq!(buf, Vec::<u8>::new());
|
|
|
});
|
|
|
});
|
|
|
|
|
|
|
|
|
@ -204,18 +217,24 @@ mod tests { |
|
|
|
|
|
|
|
|
let (_request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (_request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, response_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, response_rx) = mpsc::channel::<u32>(100);
|
|
|
let worker = Worker::new(stream, response_tx, request_rx);
|
|
|
|
|
|
|
|
|
let mut worker = Worker::new(stream);
|
|
|
|
|
|
|
|
|
// Drop the receiver before the worker can send anything.
|
|
|
// Drop the receiver before the worker can send anything.
|
|
|
drop(response_rx);
|
|
|
drop(response_rx);
|
|
|
|
|
|
|
|
|
let err = worker.run().await.expect_err("running worker");
|
|
|
|
|
|
|
|
|
let err = worker
|
|
|
|
|
|
.run(response_tx, request_rx)
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect_err("running worker");
|
|
|
if let WorkerError::IncomingChannelClosed = err {
|
|
|
if let WorkerError::IncomingChannelClosed = err {
|
|
|
// Ok!
|
|
|
// Ok!
|
|
|
} else {
|
|
|
} else {
|
|
|
panic!("Wrong error: {:?}", err);
|
|
|
panic!("Wrong error: {:?}", err);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Drop the worker, and the underlying connection, to stop the listener.
|
|
|
|
|
|
drop(worker);
|
|
|
|
|
|
|
|
|
listener_task.await.expect("joining listener");
|
|
|
listener_task.await.expect("joining listener");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -235,7 +254,10 @@ mod tests { |
|
|
writer.write(&frame).await.expect("writing frame");
|
|
|
writer.write(&frame).await.expect("writing frame");
|
|
|
|
|
|
|
|
|
let mut buf = Vec::new();
|
|
|
let mut buf = Vec::new();
|
|
|
read_half.read_to_end(&mut buf).await.expect("waiting for eof");
|
|
|
|
|
|
|
|
|
read_half
|
|
|
|
|
|
.read_to_end(&mut buf)
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect("waiting for eof");
|
|
|
assert_eq!(buf, Vec::<u8>::new());
|
|
|
assert_eq!(buf, Vec::<u8>::new());
|
|
|
});
|
|
|
});
|
|
|
|
|
|
|
|
|
@ -243,9 +265,10 @@ mod tests { |
|
|
|
|
|
|
|
|
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, mut response_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, mut response_rx) = mpsc::channel::<u32>(100);
|
|
|
let worker = Worker::new(stream, response_tx, request_rx);
|
|
|
|
|
|
|
|
|
let mut worker = Worker::new(stream);
|
|
|
|
|
|
|
|
|
let worker_task = tokio::spawn(worker.run());
|
|
|
|
|
|
|
|
|
let worker_task =
|
|
|
|
|
|
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
|
|
|
|
|
|
|
|
|
let frame = response_rx.recv().await.expect("receiving frame");
|
|
|
let frame = response_rx.recv().await.expect("receiving frame");
|
|
|
assert_eq!(frame, 42);
|
|
|
assert_eq!(frame, 42);
|
|
|
@ -253,7 +276,10 @@ mod tests { |
|
|
// Signal to the worker that it should stop running.
|
|
|
// Signal to the worker that it should stop running.
|
|
|
drop(request_tx);
|
|
|
drop(request_tx);
|
|
|
|
|
|
|
|
|
worker_task.await.expect("joining worker").expect("running worker");
|
|
|
|
|
|
|
|
|
worker_task
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect("joining worker")
|
|
|
|
|
|
.expect("running worker");
|
|
|
listener_task.await.expect("joining listener");
|
|
|
listener_task.await.expect("joining listener");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -276,16 +302,20 @@ mod tests { |
|
|
|
|
|
|
|
|
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (request_tx, request_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
|
|
|
let (response_tx, _response_rx) = mpsc::channel::<u32>(100);
|
|
|
let worker = Worker::new(stream, response_tx, request_rx);
|
|
|
|
|
|
|
|
|
let mut worker = Worker::new(stream);
|
|
|
|
|
|
|
|
|
let worker_task = tokio::spawn(worker.run());
|
|
|
|
|
|
|
|
|
let worker_task =
|
|
|
|
|
|
tokio::spawn(async move { worker.run(response_tx, request_rx).await });
|
|
|
|
|
|
|
|
|
request_tx.send(42).await.expect("sending frame");
|
|
|
request_tx.send(42).await.expect("sending frame");
|
|
|
|
|
|
|
|
|
// Signal to the worker that it should stop running.
|
|
|
// Signal to the worker that it should stop running.
|
|
|
drop(request_tx);
|
|
|
drop(request_tx);
|
|
|
|
|
|
|
|
|
worker_task.await.expect("joining worker").expect("running worker");
|
|
|
|
|
|
|
|
|
worker_task
|
|
|
|
|
|
.await
|
|
|
|
|
|
.expect("joining worker")
|
|
|
|
|
|
.expect("running worker");
|
|
|
listener_task.await.expect("joining listener");
|
|
|
listener_task.await.expect("joining listener");
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|