From 45e1b5fa247c0d5d318f7f98775a8ad99058a1cb Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Fri, 13 Aug 2021 19:23:34 +0200 Subject: [PATCH] Refactor IncomingHandler. --- client/src/control/ws.rs | 102 +++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 59 deletions(-) diff --git a/client/src/control/ws.rs b/client/src/control/ws.rs index 702a56c..a7011b9 100644 --- a/client/src/control/ws.rs +++ b/client/src/control/ws.rs @@ -4,71 +4,41 @@ use futures::{SinkExt, StreamExt}; use solstice_proto::config; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; -use tokio_tungstenite::tungstenite::Message as WebSocketMessage; +use tokio_tungstenite::tungstenite::{ + Error as WebSocketError, Message as WebSocketMessage, +}; use tokio_tungstenite::WebSocketStream; use crate::control::request::*; use crate::control::response::*; use crate::dispatcher::Message; -struct IncomingHandler<'a> { - address: &'a str, - client_tx: &'a mpsc::UnboundedSender, -} +async fn forward_incoming( + mut incoming: SplitStream>, + client_tx: &mpsc::UnboundedSender, +) -> anyhow::Result<()> { + // TODO: close incoming on error? notify forward_outgoing somehow. + while let Some(result) = incoming.next().await { + if let Err(WebSocketError::ConnectionClosed) = result { + break; + } -impl<'a> IncomingHandler<'a> { - async fn run(&self, mut incoming: SplitStream>) { - while let Some(result) = incoming.next().await { - let message = match result { - Ok(message) => message, - Err(err) => { - warn!( - "Error reading control message from {}: {}", - self.address, err - ); - break; - } - }; - - let text = match message.to_text() { - Ok(text) => text, - Err(err) => { - warn!( - "Received non-text control message from {}: {}", - self.address, err - ); - break; - } - }; + let message = result.context("reading control message")?; - debug!("Received a text message from {}: {}", self.address, text); + let text = message.to_text().context("non-text control message")?; - let control_request: Request = match serde_json::from_str(&text) { - Ok(control_request) => control_request, - Err(err) => { - warn!( - "Received invalid JSON message from {}: {}", - self.address, err - ); - break; - } - }; - - debug!( - "Received control request from {}: {:?}", - self.address, control_request - ); - - self - .client_tx - .send(Message::ControlRequest(control_request)); - } + let control_request: Request = + serde_json::from_str(&text).context("decoding JSON message")?; + + client_tx.send(Message::ControlRequest(control_request)); } + + Ok(()) } -async fn forward_outgoing( +async fn forward_outgoing_inner( client_rx: &mut mpsc::Receiver, - mut outgoing: SplitSink, WebSocketMessage>, + outgoing: &mut SplitSink, WebSocketMessage>, ) -> anyhow::Result<()> { while let Some(response) = client_rx.recv().await { let text = @@ -81,6 +51,16 @@ async fn forward_outgoing( Ok(()) } +async fn forward_outgoing( + client_rx: &mut mpsc::Receiver, + mut outgoing: SplitSink, WebSocketMessage>, +) -> anyhow::Result<()> { + let result = forward_outgoing_inner(client_rx, &mut outgoing).await; + // TODO: handle error. + outgoing.close().await; + result +} + /// Start listening on the socket address stored in configuration, and send /// control notifications to the client through the given channel. pub async fn listen( @@ -92,12 +72,6 @@ pub async fn listen( info!("Listening for control connections on {}", address); - let incoming_handler = IncomingHandler { - // TODO: This should be the address of the remote endpoint, obtained below. - address: &address, - client_tx: &client_tx, - }; - while let Ok((raw_stream, remote_address)) = listener.accept().await { info!("Accepted control connection from {}", remote_address); @@ -106,8 +80,18 @@ pub async fn listen( let (outgoing, incoming) = ws_stream.split(); + // Instead of selecting, spawn a task and wait for both to resolve. This + // works because tunstenite communicates the "closed" state to both ends. tokio::select! { - () = incoming_handler.run(incoming) => (), + result = forward_incoming(incoming, &client_tx) => match result { + Ok(()) => (), + Err(err) => { + warn!( + "Error in incoming websocket handler for {}: {}", + remote_address, err + ) + }, + }, result = forward_outgoing(&mut client_rx, outgoing) => match result { Ok(()) => (), Err(err) => {