diff --git a/client/src/control/ws.rs b/client/src/control/ws.rs index 0c379ed..9dc583b 100644 --- a/client/src/control/ws.rs +++ b/client/src/control/ws.rs @@ -183,23 +183,52 @@ mod tests { use crate::control::{Request, Response, RoomLeaveResponse}; use crate::dispatcher::Message; - struct Channels { + // A bound `Listener` packaged with the channels it needs to run. + // Convenient for tests. + struct RunnableListener { + inner: Listener, message_tx: mpsc::UnboundedSender, - message_rx: mpsc::UnboundedReceiver, - response_tx: mpsc::Sender, response_rx: mpsc::Receiver, } - impl Default for Channels { - fn default() -> Self { + impl RunnableListener { + async fn run(mut self) -> anyhow::Result<()> { + self.inner.run(self.message_tx, self.response_rx).await + } + } + + // Packages together all the things needed to test `Listener`. + struct ListenerBundle { + pub listener: RunnableListener, + pub address: SocketAddr, + pub websocket_address: String, + pub message_rx: mpsc::UnboundedReceiver, + pub response_tx: mpsc::Sender, + } + + impl ListenerBundle { + // Binds to a random port on localhost. + async fn bind() -> anyhow::Result { + let inner = Listener::bind("localhost:0").await?; + let address = inner.address().clone(); + let websocket_address = format!("ws://{}", address); + let (message_tx, message_rx) = mpsc::unbounded_channel(); let (response_tx, response_rx) = mpsc::channel(100); - Self { + + let listener = RunnableListener { + inner, message_tx, + response_rx, + }; + + Ok(Self { + listener, + address, + websocket_address, message_rx, response_tx, - response_rx, - } + }) } } @@ -212,13 +241,13 @@ mod tests { async fn binds_to_localhost() -> anyhow::Result<()> { init(); - let listener = Listener::bind("localhost:0") - .await - .context("binding listener")?; - match listener.address() { + let bundle = ListenerBundle::bind().await.context("binding listener")?; + + match bundle.address { SocketAddr::V4(address) => assert!(address.ip().is_loopback()), SocketAddr::V6(address) => assert!(address.ip().is_loopback()), }; + Ok(()) } @@ -226,23 +255,16 @@ mod tests { async fn listens_for_websocket_connections() -> anyhow::Result<()> { init(); - let mut listener = Listener::bind("localhost:0") - .await - .context("binding listener")?; + let bundle = ListenerBundle::bind().await.context("binding listener")?; - let address = format!("ws://{}", listener.address()); - let channels = Channels::default(); + let listener_task = tokio::spawn(bundle.listener.run()); - // Move individual fields out of `channels`, for capture by `tokio::spawn`. - let message_tx = channels.message_tx; - let response_rx = channels.response_rx; - let listener_task = - tokio::spawn(async move { listener.run(message_tx, response_rx).await }); - - let (_ws_stream, _response) = connect_async(address).await?; + let (_ws_stream, _response) = connect_async(bundle.websocket_address) + .await + .context("connecting")?; - // dropping this sender signals to the listener that it should stop. - drop(channels.response_tx); + // Dropping this sender signals to the listener that it should stop. + drop(bundle.response_tx); let () = listener_task .await @@ -256,20 +278,13 @@ mod tests { async fn keeps_listening_after_failed_handshake() -> anyhow::Result<()> { init(); - let mut listener = Listener::bind("localhost:0") - .await - .context("binding listener")?; + let bundle = ListenerBundle::bind().await.context("binding listener")?; - let address = listener.address().clone(); - let channels = Channels::default(); + let listener_task = tokio::spawn(bundle.listener.run()); - // Move individual fields out of `channels`, for capture by `tokio::spawn`. - let message_tx = channels.message_tx; - let response_rx = channels.response_rx; - let listener_task = - tokio::spawn(async move { listener.run(message_tx, response_rx).await }); - - let mut stream = TcpStream::connect(address).await.context("connecting")?; + let mut stream = TcpStream::connect(bundle.address) + .await + .context("connecting tcp")?; // Write some invalid data, causing the listener to drop the connection. stream.write_all(&[0]).await.context("writing")?; @@ -280,11 +295,12 @@ mod tests { assert_eq!(buf, Vec::::new()); // Connect a second time. - let (_ws_stream, _response) = - connect_async(format!("ws://{}", address)).await?; + let (_ws_stream, _response) = connect_async(bundle.websocket_address) + .await + .context("connecting websocket")?; // Dropping this sender signals to the listener that it should stop. - drop(channels.response_tx); + drop(bundle.response_tx); let () = listener_task .await @@ -298,32 +314,29 @@ mod tests { async fn forwards_incoming_requests() -> anyhow::Result<()> { init(); - let mut listener = Listener::bind("localhost:0") - .await - .context("binding listener")?; - - let address = format!("ws://{}", listener.address()); - let mut channels = Channels::default(); + let mut bundle = + ListenerBundle::bind().await.context("binding listener")?; - // Move individual fields out of `channels`, for capture by `tokio::spawn`. - let message_tx = channels.message_tx; - let response_rx = channels.response_rx; - let listener_task = - tokio::spawn(async move { listener.run(message_tx, response_rx).await }); + let listener_task = tokio::spawn(bundle.listener.run()); - let (mut ws_stream, _response) = connect_async(address).await?; + let (mut ws_stream, _response) = connect_async(bundle.websocket_address) + .await + .context("connecting")?; let request = serde_json::to_string(&Request::RoomListRequest) .context("serializing request")?; - ws_stream.send(WebSocketMessage::Text(request)).await?; + ws_stream + .send(WebSocketMessage::Text(request)) + .await + .context("sending request")?; assert_eq!( - channels.message_rx.recv().await, + bundle.message_rx.recv().await, Some(Message::ControlRequest(Request::RoomListRequest)) ); // Dropping this sender signals to the listener that it should stop. - drop(channels.response_tx); + drop(bundle.response_tx); let () = listener_task .await @@ -337,22 +350,15 @@ mod tests { async fn forwards_outgoing_responses() -> anyhow::Result<()> { init(); - let mut listener = Listener::bind("localhost:0") - .await - .context("binding listener")?; + let bundle = ListenerBundle::bind().await.context("binding listener")?; - let address = format!("ws://{}", listener.address()); - let channels = Channels::default(); + let listener_task = tokio::spawn(bundle.listener.run()); - // Move individual fields out of `channels`, for capture by `tokio::spawn`. - let message_tx = channels.message_tx; - let response_rx = channels.response_rx; - let listener_task = - tokio::spawn(async move { listener.run(message_tx, response_rx).await }); - - let (mut ws_stream, _response) = connect_async(address).await?; + let (mut ws_stream, _response) = connect_async(bundle.websocket_address) + .await + .context("connecting")?; - channels + bundle .response_tx .send(Response::RoomLeaveResponse(RoomLeaveResponse { room_name: "bleep".to_string(), @@ -379,7 +385,7 @@ mod tests { ); // Dropping this sender signals to the listener that it should stop. - drop(channels.response_tx); + drop(bundle.response_tx); let () = listener_task .await