diff --git a/client/src/control/ws.rs b/client/src/control/ws.rs index a7011b9..7f4a428 100644 --- a/client/src/control/ws.rs +++ b/client/src/control/ws.rs @@ -1,3 +1,5 @@ +use std::net::SocketAddr; + use anyhow::Context; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; @@ -17,7 +19,6 @@ async fn forward_incoming( mut incoming: SplitStream>, client_tx: &mpsc::UnboundedSender, ) -> anyhow::Result<()> { - // TODO: close incoming on error? notify forward_outgoing somehow. while let Some(result) = incoming.next().await { if let Err(WebSocketError::ConnectionClosed) = result { break; @@ -36,29 +37,60 @@ async fn forward_incoming( Ok(()) } -async fn forward_outgoing_inner( +async fn forward_outgoing( client_rx: &mut mpsc::Receiver, outgoing: &mut SplitSink, WebSocketMessage>, ) -> anyhow::Result<()> { while let Some(response) = client_rx.recv().await { let text = - serde_json::to_string(&response).context("Encoding control response")?; + serde_json::to_string(&response).context("encoding control response")?; + outgoing .send(WebSocketMessage::Text(text)) .await - .context("Sending control response to {}")?; + .context("sending control response")?; } + Ok(()) } -async fn forward_outgoing( +async fn handle( + stream: WebSocketStream, + remote_address: &SocketAddr, + client_tx: &mpsc::UnboundedSender, client_rx: &mut mpsc::Receiver, - mut outgoing: SplitSink, WebSocketMessage>, -) -> anyhow::Result<()> { - let result = forward_outgoing_inner(client_rx, &mut outgoing).await; - // TODO: handle error. - outgoing.close().await; - result +) { + let (mut outgoing, incoming) = stream.split(); + + tokio::select! { + result = forward_incoming(incoming, client_tx) => match result { + Ok(()) => info!( + "Incoming websocket handler task for {} stopped", + remote_address, + ), + Err(err) => error!( + "Error in websocket handler task for {}: {:#}", + remote_address, err, + ), + }, + result = forward_outgoing(client_rx, &mut outgoing) => match result { + Ok(()) => info!( + "Outgoing websocket handler for {} stopped", + remote_address, + ), + Err(err) => warn!( + "Error in outgoing websocket handler for {}: {:#}", + remote_address, err, + ), + }, + }; + + match outgoing.close().await { + Ok(()) => info!("Closed websocket for {}", remote_address), + Err(err) => { + error!("Error closing websocket for {}: {}", remote_address, err,) + } + }; } /// Start listening on the socket address stored in configuration, and send @@ -78,32 +110,7 @@ pub async fn listen( let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?; info!("WebSocket connection established from {}", remote_address); - let (outgoing, incoming) = ws_stream.split(); - - // Instead of selecting, spawn a task and wait for both to resolve. This - // works because tunstenite communicates the "closed" state to both ends. - tokio::select! { - result = forward_incoming(incoming, &client_tx) => match result { - Ok(()) => (), - Err(err) => { - warn!( - "Error in incoming websocket handler for {}: {}", - remote_address, err - ) - }, - }, - result = forward_outgoing(&mut client_rx, outgoing) => match result { - Ok(()) => (), - Err(err) => { - warn!( - "Error in outgoing websocket handler for {}: {}", - remote_address, err - ) - }, - }, - }; - - info!("WebSocket connection from {} closed", remote_address); + handle(ws_stream, &remote_address, &client_tx, &mut client_rx).await } Ok(())