diff --git a/client/src/control/ws.rs b/client/src/control/ws.rs index fd40774..271d185 100644 --- a/client/src/control/ws.rs +++ b/client/src/control/ws.rs @@ -1,4 +1,3 @@ -use std::io; use std::net::SocketAddr; use anyhow::Context; @@ -100,9 +99,11 @@ pub struct Listener { } impl Listener { - pub async fn bind(address_str: &str) -> io::Result { - let inner = TcpListener::bind(address_str).await?; - let address = inner.local_addr()?; + pub async fn bind(address_str: &str) -> anyhow::Result { + let inner = TcpListener::bind(address_str) + .await + .context("binding listener")?; + let address = inner.local_addr().context("accessing local address")?; info!("Listening for control connections on {}", address); @@ -122,18 +123,32 @@ impl Listener { mut response_rx: mpsc::Receiver, ) -> anyhow::Result<()> { loop { - // TODO: Select from response_rx too, and stop looping when it is closed. - let (raw_stream, remote_address) = - self.inner.accept().await.context("accepting connection")?; + let (raw_stream, remote_address) = tokio::select! { + result = self.inner.accept() => { + result.context("accepting connection")? + }, + option = response_rx.recv() => match option { + Some(response) => { + debug!("Dropping control response in absence of connection: {:?}", response); + continue + }, + None => { + info!("Stopping control listener: response channel closed"); + break + }, + }, + }; + info!("Accepted control connection from {}", remote_address); // TODO: Continue iterating in case of error. let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?; info!("WebSocket connection established from {}", remote_address); - // TODO: Stop gracefully once `response_rx` is closed. handle(ws_stream, &remote_address, &message_tx, &mut response_rx).await } + + Ok(()) } } @@ -170,8 +185,15 @@ mod tests { } } + // Enable capturing logs in tests. + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + #[tokio::test] async fn binds_to_localhost() -> anyhow::Result<()> { + init(); + let listener = Listener::bind("localhost:0") .await .context("binding listener")?; @@ -181,4 +203,35 @@ mod tests { }; Ok(()) } + + #[tokio::test] + async fn listens_for_websocket_connections() -> anyhow::Result<()> { + init(); + + let mut listener = Listener::bind("localhost:0") + .await + .context("binding listener")?; + + let address = format!("ws://{}", listener.address()); + + let channels = Channels::default(); + + // Move individual fields out of `channels`, for capture by `tokio::spawn`. + let message_tx = channels.message_tx; + let response_rx = channels.response_rx; + let listener_task = + tokio::spawn(async move { listener.run(message_tx, response_rx).await }); + + let (_ws_stream, _response) = connect_async(address).await?; + + // Dropping this sender signals to the listener that it should stop. + drop(channels.response_tx); + + let () = listener_task + .await + .context("joining listener task")? + .context("running listener")?; + + Ok(()) + } }