Solstice client.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

333 lines
9.2 KiB

use std::net::SocketAddr;
use anyhow::Context;
use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::{
Error as WebSocketError, Message as WebSocketMessage,
};
use tokio_tungstenite::WebSocketStream;
use crate::control::request::*;
use crate::control::response::*;
use crate::dispatcher::Message;
/// Forwards control requests read from the websocket to the dispatcher.
///
/// Every data frame received on `incoming` is decoded from JSON into a
/// [`Request`] and sent on `message_tx` wrapped in `Message::ControlRequest`.
/// Returns `Ok(())` once the peer closes the connection or the stream ends.
///
/// # Errors
///
/// Returns an error if reading from the socket fails, if a data frame cannot
/// be read as text or decoded as a JSON [`Request`], or if the receiving end
/// of `message_tx` has been dropped.
async fn forward_incoming(
    mut incoming: SplitStream<WebSocketStream<TcpStream>>,
    message_tx: &mpsc::UnboundedSender<Message>,
) -> anyhow::Result<()> {
    while let Some(result) = incoming.next().await {
        if let Err(WebSocketError::ConnectionClosed) = result {
            break;
        }
        let message = result.context("reading control message")?;

        // Control frames carry no JSON request. Decoding them would make a
        // graceful close or a keep-alive look like a malformed request.
        if message.is_close() {
            break;
        }
        if message.is_ping() || message.is_pong() {
            continue;
        }

        let text = message.to_text().context("non-text control message")?;
        let control_request: Request = serde_json::from_str(text)
            .with_context(|| format!("decoding JSON message {:?}", text))?;

        // Sending only fails when the receiver is gone, in which case nobody
        // is left to act on requests: stop instead of silently dropping them.
        message_tx
            .send(Message::ControlRequest(control_request))
            .map_err(|_| anyhow::anyhow!("control request receiver was dropped"))?;
    }
    Ok(())
}
/// Pumps responses from `response_rx` onto the websocket sink.
///
/// Each response is serialized to JSON and written to `outgoing` as a text
/// message. Completes with `Ok(())` once `response_rx` yields `None`, and
/// fails if a response cannot be encoded or sent.
async fn forward_outgoing(
    response_rx: &mut mpsc::Receiver<Response>,
    outgoing: &mut SplitSink<WebSocketStream<TcpStream>, WebSocketMessage>,
) -> anyhow::Result<()> {
    loop {
        let response = match response_rx.recv().await {
            Some(response) => response,
            // The channel is closed and drained: nothing more to forward.
            None => return Ok(()),
        };

        let payload = serde_json::to_string(&response)
            .context("encoding control response")?;
        let frame = WebSocketMessage::Text(payload);
        outgoing
            .send(frame)
            .await
            .context("sending control response")?;
    }
}
async fn handle(
stream: WebSocketStream<TcpStream>,
remote_address: &SocketAddr,
message_tx: &mpsc::UnboundedSender<Message>,
response_rx: &mut mpsc::Receiver<Response>,
) {
let (mut outgoing, incoming) = stream.split();
tokio::select! {
result = forward_incoming(incoming, message_tx) => match result {
Ok(()) => info!(
"Incoming websocket handler task for {} stopped",
remote_address,
),
Err(err) => error!(
"Error in websocket handler task for {}: {:#}",
remote_address, err,
),
},
result = forward_outgoing(response_rx, &mut outgoing) => match result {
Ok(()) => info!(
"Outgoing websocket handler for {} stopped",
remote_address,
),
Err(err) => warn!(
"Error in outgoing websocket handler for {}: {:#}",
remote_address, err,
),
},
};
match outgoing.close().await {
Ok(()) => info!("Closed websocket for {}", remote_address),
Err(err) => {
error!("Error closing websocket for {}: {}", remote_address, err,)
}
};
}
/// A listener for control connections.
pub struct Listener {
/// Underlying TCP listener on which control connections are accepted.
inner: TcpListener,
/// Local address of `inner`, looked up once when the listener is bound.
address: SocketAddr,
}
impl Listener {
pub async fn bind(address_str: &str) -> anyhow::Result<Self> {
let inner = TcpListener::bind(address_str)
.await
.context("binding listener")?;
let address = inner.local_addr().context("accessing local address")?;
info!("Listening for control connections on {}", address);
Ok(Self { inner, address })
}
pub fn address(&self) -> &SocketAddr {
&self.address
}
/// Starts accepting control connections, one at a time. For each connection,
/// forwards incoming messages from the socket to `message_tx` and outgoing
/// responses from `response_rx` to the socket.
pub async fn run(
&mut self,
message_tx: mpsc::UnboundedSender<Message>,
mut response_rx: mpsc::Receiver<Response>,
) -> anyhow::Result<()> {
loop {
let (raw_stream, remote_address) = tokio::select! {
result = self.inner.accept() => {
result.context("accepting connection")?
},
option = response_rx.recv() => match option {
Some(response) => {
debug!("Dropping control response in absence of connection: {:?}", response);
continue
},
None => {
info!("Stopping control listener: response channel closed");
break
},
},
};
info!("Accepted control connection from {}", remote_address);
// TODO: Continue iterating in case of error.
let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?;
info!("WebSocket connection established from {}", remote_address);
handle(ws_stream, &remote_address, &message_tx, &mut response_rx).await
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::Listener;
    use std::net::SocketAddr;
    use anyhow::Context;
    use futures::{SinkExt, StreamExt};
    use tokio::sync::mpsc;
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
    use crate::control::{Request, Response, RoomLeaveResponse};
    use crate::dispatcher::Message;

    /// Both ends of the two channels a `Listener` is run with.
    struct Channels {
        message_tx: mpsc::UnboundedSender<Message>,
        message_rx: mpsc::UnboundedReceiver<Message>,
        response_tx: mpsc::Sender<Response>,
        response_rx: mpsc::Receiver<Response>,
    }

    impl Default for Channels {
        fn default() -> Self {
            let (message_tx, message_rx) = mpsc::unbounded_channel();
            let (response_tx, response_rx) = mpsc::channel(100);
            Self {
                message_tx,
                message_rx,
                response_tx,
                response_rx,
            }
        }
    }

    // Enable capturing logs in tests.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Binds a listener to a free port on localhost.
    async fn bind_localhost() -> anyhow::Result<Listener> {
        Listener::bind("localhost:0")
            .await
            .context("binding listener")
    }

    /// Drops `response_tx`, which signals the listener that it should stop,
    /// then waits for the listener task and surfaces its result.
    async fn shut_down(
        response_tx: mpsc::Sender<Response>,
        listener_task: tokio::task::JoinHandle<anyhow::Result<()>>,
    ) -> anyhow::Result<()> {
        drop(response_tx);
        listener_task
            .await
            .context("joining listener task")?
            .context("running listener")
    }

    /// The response used to exercise the outgoing direction.
    fn leave_bleep() -> Response {
        Response::RoomLeaveResponse(RoomLeaveResponse {
            room_name: "bleep".to_string(),
        })
    }

    #[tokio::test]
    async fn binds_to_localhost() -> anyhow::Result<()> {
        init();
        let listener = bind_localhost().await?;

        let is_loopback = match listener.address() {
            SocketAddr::V4(v4) => v4.ip().is_loopback(),
            SocketAddr::V6(v6) => v6.ip().is_loopback(),
        };
        assert!(is_loopback);
        Ok(())
    }

    #[tokio::test]
    async fn listens_for_websocket_connections() -> anyhow::Result<()> {
        init();
        let mut listener = bind_localhost().await?;
        let url = format!("ws://{}", listener.address());

        // Take the channel ends apart so that `tokio::spawn` can capture some
        // of them; the unused receiver is kept alive until the test ends.
        let Channels {
            message_tx,
            message_rx: _message_rx,
            response_tx,
            response_rx,
        } = Channels::default();

        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        let (_ws_stream, _response) = connect_async(url).await?;

        shut_down(response_tx, listener_task).await
    }

    #[tokio::test]
    async fn forwards_incoming_requests() -> anyhow::Result<()> {
        init();
        let mut listener = bind_localhost().await?;
        let url = format!("ws://{}", listener.address());

        // Take the channel ends apart so that `tokio::spawn` can capture some
        // of them.
        let Channels {
            message_tx,
            mut message_rx,
            response_tx,
            response_rx,
        } = Channels::default();

        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        let (mut ws_stream, _response) = connect_async(url).await?;

        let payload = serde_json::to_string(&Request::RoomListRequest)
            .context("serializing request")?;
        ws_stream.send(WebSocketMessage::Text(payload)).await?;

        let forwarded = message_rx.recv().await;
        assert_eq!(
            forwarded,
            Some(Message::ControlRequest(Request::RoomListRequest))
        );

        shut_down(response_tx, listener_task).await
    }

    #[tokio::test]
    async fn forwards_outgoing_responses() -> anyhow::Result<()> {
        init();
        let mut listener = bind_localhost().await?;
        let url = format!("ws://{}", listener.address());

        // Take the channel ends apart so that `tokio::spawn` can capture some
        // of them; the unused receiver is kept alive until the test ends.
        let Channels {
            message_tx,
            message_rx: _message_rx,
            response_tx,
            response_rx,
        } = Channels::default();

        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        let (mut ws_stream, _response) = connect_async(url).await?;

        response_tx
            .send(leave_bleep())
            .await
            .context("sending response")?;

        let frame = ws_stream
            .next()
            .await
            .context("unwrapping next response from stream")?
            .context("read result")?;
        let text = frame.to_text().context("non-text control message")?;
        let received: Response = serde_json::from_str(text)
            .with_context(|| format!("decoding JSON message {:?}", text))?;
        assert_eq!(received, leave_bleep());

        shut_down(response_tx, listener_task).await
    }
}