diff --git a/client/src/control/mod.rs b/client/src/control/mod.rs index 7850598..9d1af8d 100644 --- a/client/src/control/mod.rs +++ b/client/src/control/mod.rs @@ -4,4 +4,4 @@ mod ws; pub use self::request::*; pub use self::response::*; -pub use self::ws::listen; +pub use self::ws::Listener; diff --git a/client/src/control/ws.rs b/client/src/control/ws.rs index be3e65e..fd40774 100644 --- a/client/src/control/ws.rs +++ b/client/src/control/ws.rs @@ -1,9 +1,9 @@ +use std::io; use std::net::SocketAddr; use anyhow::Context; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; -use solstice_proto::config; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use tokio_tungstenite::tungstenite::{ @@ -93,25 +93,92 @@ async fn handle( }; } -/// Start listening on the socket address stored in configuration, and send -/// control notifications to the client through the given channel. -pub async fn listen( - message_tx: mpsc::UnboundedSender, - mut response_rx: mpsc::Receiver, -) -> anyhow::Result<()> { - let address = format!("{}:{}", config::CONTROL_HOST, config::CONTROL_PORT); - let listener = TcpListener::bind(&address).await?; +/// A listener for control connections. +pub struct Listener { + inner: TcpListener, + address: SocketAddr, +} - info!("Listening for control connections on {}", address); +impl Listener { + pub async fn bind(address_str: &str) -> io::Result { + let inner = TcpListener::bind(address_str).await?; + let address = inner.local_addr()?; - while let Ok((raw_stream, remote_address)) = listener.accept().await { - info!("Accepted control connection from {}", remote_address); + info!("Listening for control connections on {}", address); - let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?; - info!("WebSocket connection established from {}", remote_address); + Ok(Self { inner, address }) + } - handle(ws_stream, &remote_address, &message_tx, &mut response_rx).await + pub fn address(&self) -> &SocketAddr { + &self.address } - Ok(()) + /// Starts accepting control connections, one at a time. For each connection, + /// forwards incoming messages from the socket to `message_tx` and outgoing + /// responses from `response_rx` to the socket. + pub async fn run( + &mut self, + message_tx: mpsc::UnboundedSender, + mut response_rx: mpsc::Receiver, + ) -> anyhow::Result<()> { + loop { + // TODO: Select from response_rx too, and stop looping when it is closed. + let (raw_stream, remote_address) = + self.inner.accept().await.context("accepting connection")?; + info!("Accepted control connection from {}", remote_address); + + // TODO: Continue iterating in case of error. + let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?; + info!("WebSocket connection established from {}", remote_address); + + // TODO: Stop gracefully once `response_rx` is closed. + handle(ws_stream, &remote_address, &message_tx, &mut response_rx).await + } + } +} + +#[cfg(test)] +mod tests { + use super::Listener; + + use std::net::SocketAddr; + + use anyhow::Context; + use tokio::sync::mpsc; + use tokio_tungstenite::connect_async; + + use crate::control::Response; + use crate::dispatcher::Message; + + struct Channels { + message_tx: mpsc::UnboundedSender, + message_rx: mpsc::UnboundedReceiver, + response_tx: mpsc::Sender, + response_rx: mpsc::Receiver, + } + + impl Default for Channels { + fn default() -> Self { + let (message_tx, message_rx) = mpsc::unbounded_channel(); + let (response_tx, response_rx) = mpsc::channel(100); + Self { + message_tx, + message_rx, + response_tx, + response_rx, + } + } + } + + #[tokio::test] + async fn binds_to_localhost() -> anyhow::Result<()> { + let listener = Listener::bind("localhost:0") + .await + .context("binding listener")?; + match listener.address() { + SocketAddr::V4(address) => assert!(address.ip().is_loopback()), + SocketAddr::V6(address) => assert!(address.ip().is_loopback()), + }; + Ok(()) + } } diff --git a/client/src/main.rs b/client/src/main.rs index 2be787e..850dab9 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -4,12 +4,13 @@ extern crate log; use std::thread; +use anyhow::Context; use clap::{App, Arg}; use crossbeam_channel; use env_logger; use futures::stream::StreamExt; use log::info; -use solstice_proto; +use solstice_proto::config; use tokio::net::TcpStream; mod client; @@ -108,7 +109,7 @@ async fn run_client( Ok(()) } -async fn async_main() { +async fn async_main() -> anyhow::Result<()> { let bundle = ContextBundle::default(); let (dispatcher_tx, mut dispatcher_rx) = @@ -120,7 +121,13 @@ async fn async_main() { let dispatcher = Dispatcher::new(); let executor = Executor::new(bundle.context); - let control_task = control::listen(dispatcher_tx, bundle.control_response_rx); + let control_address = + format!("{}:{}", config::CONTROL_HOST, config::CONTROL_PORT); + let mut control_listener = control::Listener::bind(&control_address) + .await + .context("binding control listener")?; + let control_task = + control_listener.run(dispatcher_tx, bundle.control_response_rx); let dispatch_task = async move { while let Some(message) = dispatcher_rx.recv().await { @@ -145,12 +152,12 @@ async fn async_main() { client_task .await - .expect("Client task join error") - .expect("Client error"); + .context("joining client task")? + .context("running client task") } #[tokio::main] -async fn main() { +async fn main() -> anyhow::Result<()> { env_logger::init(); let matches = App::new("solstice-client") @@ -165,9 +172,11 @@ async fn main() { if matches.is_present("async") { info!("Running in asynchronous mode."); - async_main().await; + async_main().await.context("running asynchronous main") } else { info!("Running in synchronous mode."); - tokio::task::spawn_blocking(old_main).await.unwrap(); + tokio::task::spawn_blocking(old_main) + .await + .context("running synchronous main") } }