Solstice client.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

391 lines
11 KiB

use std::net::SocketAddr;
use anyhow::Context;
use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::{
Error as WebSocketError, Message as WebSocketMessage,
};
use tokio_tungstenite::WebSocketStream;
use crate::control::request::*;
use crate::control::response::*;
use crate::dispatcher::Message;
/// Reads control requests from the WebSocket and forwards them to the dispatcher.
///
/// Each data frame is decoded as a JSON `Request` and sent to `message_tx`
/// wrapped in `Message::ControlRequest`. Ping/pong frames are skipped, and a
/// close frame (or a `ConnectionClosed` read error) ends the loop cleanly.
///
/// # Errors
///
/// Returns an error if reading from the socket fails, a frame is not valid
/// text, a frame does not decode as a `Request`, or the receiving end of
/// `message_tx` has been dropped.
async fn forward_incoming(
    mut incoming: SplitStream<WebSocketStream<TcpStream>>,
    message_tx: &mpsc::UnboundedSender<Message>,
) -> anyhow::Result<()> {
    while let Some(result) = incoming.next().await {
        if let Err(WebSocketError::ConnectionClosed) = result {
            break;
        }
        let message = result.context("reading control message")?;

        // A close frame ends the conversation. Its payload is not a JSON
        // request, so decoding it would report a spurious error.
        if message.is_close() {
            break;
        }
        // Ping/pong frames carry no request; keep reading.
        if message.is_ping() || message.is_pong() {
            continue;
        }

        let text = message.to_text().context("non-text control message")?;
        let control_request: Request = serde_json::from_str(text)
            .with_context(|| format!("decoding JSON message {:?}", text))?;

        // `send` only fails when the receiver is gone; surface that instead
        // of silently dropping every subsequent request.
        if message_tx
            .send(Message::ControlRequest(control_request))
            .is_err()
        {
            anyhow::bail!("forwarding control request: dispatcher channel closed");
        }
    }
    Ok(())
}
/// Drains `response_rx`, writing each response to `outgoing` as a JSON text frame.
///
/// Returns `Ok(())` once `response_rx.recv()` yields `None` (the channel is
/// closed). Otherwise it runs until the first encoding or send failure, which
/// is returned as an error.
async fn forward_outgoing(
    response_rx: &mut mpsc::Receiver<Response>,
    outgoing: &mut SplitSink<WebSocketStream<TcpStream>, WebSocketMessage>,
) -> anyhow::Result<()> {
    loop {
        let next_response = match response_rx.recv().await {
            Some(next_response) => next_response,
            // Channel closed: no more responses will ever arrive.
            None => return Ok(()),
        };
        let encoded =
            serde_json::to_string(&next_response).context("encoding control response")?;
        let frame = WebSocketMessage::Text(encoded);
        outgoing
            .send(frame)
            .await
            .context("sending control response")?;
    }
}
async fn handle(
stream: TcpStream,
remote_address: &SocketAddr,
message_tx: &mpsc::UnboundedSender<Message>,
response_rx: &mut mpsc::Receiver<Response>,
) {
let ws_stream = match tokio_tungstenite::accept_async(stream).await {
Ok(ws_stream) => ws_stream,
Err(err) => {
warn!(
"Failed to accept WebSocket connection from {}: {}",
remote_address, err
);
return;
}
};
info!("WebSocket connection established from {}", remote_address);
let (mut outgoing, incoming) = ws_stream.split();
tokio::select! {
result = forward_incoming(incoming, message_tx) => match result {
Ok(()) => info!(
"Incoming WebSocket handler task for {} stopped",
remote_address,
),
Err(err) => error!(
"Error in WebSocket handler task for {}: {:#}",
remote_address, err,
),
},
result = forward_outgoing(response_rx, &mut outgoing) => match result {
Ok(()) => info!(
"Outgoing WebSocket handler for {} stopped",
remote_address,
),
Err(err) => warn!(
"Error in outgoing WebSocket handler for {}: {:#}",
remote_address, err,
),
},
};
match outgoing.close().await {
Ok(()) => info!("Closed WebSocket for {}", remote_address),
Err(err) => {
error!("Error closing WebSocket for {}: {}", remote_address, err,)
}
};
}
/// A listener for control connections.
///
/// Holds the bound TCP listener together with the local address it actually
/// ended up bound to (which differs from the requested one when binding to
/// port 0).
pub struct Listener {
    inner: TcpListener,
    address: SocketAddr,
}

impl Listener {
    /// Binds a TCP listener to `address_str` (e.g. `"localhost:0"`) and
    /// records the resulting local address.
    ///
    /// # Errors
    ///
    /// Fails if binding the socket or querying its local address fails.
    pub async fn bind(address_str: &str) -> anyhow::Result<Self> {
        let listener = TcpListener::bind(address_str)
            .await
            .context("binding listener")?;
        let local = listener
            .local_addr()
            .context("accessing local address")?;
        info!("Bound listener for control connections to {}", local);
        Ok(Listener {
            inner: listener,
            address: local,
        })
    }

    /// Returns the local address this listener is bound to.
    pub fn address(&self) -> &SocketAddr {
        let Self { address, .. } = self;
        address
    }

    /// Accepts control connections one at a time and serves each in turn.
    ///
    /// For each connection, incoming messages from the socket are forwarded
    /// to `message_tx` and outgoing responses from `response_rx` are written
    /// to the socket. Responses that arrive while no connection is being
    /// served are dropped (logged at debug level).
    ///
    /// Returns `Ok(())` once `response_rx` reports that the response channel
    /// is closed.
    ///
    /// # Errors
    ///
    /// Returns an error if accepting a connection fails.
    pub async fn run(
        &mut self,
        message_tx: mpsc::UnboundedSender<Message>,
        mut response_rx: mpsc::Receiver<Response>,
    ) -> anyhow::Result<()> {
        info!("Accepting control connections on {}", self.address());
        loop {
            // `None` means "nothing to serve this round, wait again".
            let connection = tokio::select! {
                accepted = self.inner.accept() => {
                    Some(accepted.context("accepting connection")?)
                },
                maybe_response = response_rx.recv() => {
                    if let Some(response) = maybe_response {
                        debug!(
                            "Dropping control response in absence of connection: {:?}",
                            response,
                        );
                        None
                    } else {
                        info!("Stopping control listener: response channel closed");
                        break;
                    }
                },
            };
            if let Some((stream, remote_address)) = connection {
                info!("Accepted control connection from {}", remote_address);
                handle(stream, &remote_address, &message_tx, &mut response_rx).await;
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::Listener;

    use std::net::SocketAddr;

    use anyhow::Context;
    use futures::{SinkExt, StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::sync::mpsc;
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message as WebSocketMessage;

    use crate::control::{Request, Response, RoomLeaveResponse};
    use crate::dispatcher::Message;

    /// Both ends of the two channels a `Listener` is run with.
    struct Channels {
        message_tx: mpsc::UnboundedSender<Message>,
        message_rx: mpsc::UnboundedReceiver<Message>,
        response_tx: mpsc::Sender<Response>,
        response_rx: mpsc::Receiver<Response>,
    }

    impl Default for Channels {
        fn default() -> Self {
            let (message_tx, message_rx) = mpsc::unbounded_channel();
            let (response_tx, response_rx) = mpsc::channel(100);
            Self {
                message_tx,
                message_rx,
                response_tx,
                response_rx,
            }
        }
    }

    // Enable capturing logs in tests.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Binding to `localhost` yields a loopback address.
    #[tokio::test]
    async fn binds_to_localhost() -> anyhow::Result<()> {
        init();
        let listener = Listener::bind("localhost:0")
            .await
            .context("binding listener")?;
        match listener.address() {
            SocketAddr::V4(address) => assert!(address.ip().is_loopback()),
            SocketAddr::V6(address) => assert!(address.ip().is_loopback()),
        };
        Ok(())
    }

    /// A WebSocket client can complete the handshake with a running listener.
    #[tokio::test]
    async fn listens_for_websocket_connections() -> anyhow::Result<()> {
        init();
        let mut listener = Listener::bind("localhost:0")
            .await
            .context("binding listener")?;
        let address = format!("ws://{}", listener.address());

        let channels = Channels::default();
        // Move individual fields out of `channels`, for capture by `tokio::spawn`.
        let message_tx = channels.message_tx;
        let response_rx = channels.response_rx;
        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        let (_ws_stream, _response) = connect_async(address).await?;

        // Dropping this sender signals to the listener that it should stop.
        drop(channels.response_tx);
        let () = listener_task
            .await
            .context("joining listener task")?
            .context("running listener")?;
        Ok(())
    }

    /// A connection that fails the handshake does not stop the listener.
    #[tokio::test]
    async fn keeps_listening_after_failed_handshake() -> anyhow::Result<()> {
        init();
        let mut listener = Listener::bind("localhost:0")
            .await
            .context("binding listener")?;
        // `SocketAddr` is `Copy`; copy it out before `listener` is moved.
        let address = *listener.address();

        let channels = Channels::default();
        // Move individual fields out of `channels`, for capture by `tokio::spawn`.
        let message_tx = channels.message_tx;
        let response_rx = channels.response_rx;
        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        let mut stream = TcpStream::connect(address).await.context("connecting")?;

        // Write some invalid data, causing the listener to drop the connection.
        stream.write_all(&[0]).await.context("writing")?;

        // Expect that the stream immediately closes.
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.context("reading")?;
        assert_eq!(buf, Vec::<u8>::new());

        // Connect a second time.
        let (_ws_stream, _response) =
            connect_async(format!("ws://{}", address)).await?;

        // Dropping this sender signals to the listener that it should stop.
        drop(channels.response_tx);
        let () = listener_task
            .await
            .context("joining listener task")?
            .context("running listener")?;
        Ok(())
    }

    /// JSON requests sent over the socket reach the dispatcher channel.
    #[tokio::test]
    async fn forwards_incoming_requests() -> anyhow::Result<()> {
        init();
        let mut listener = Listener::bind("localhost:0")
            .await
            .context("binding listener")?;
        let address = format!("ws://{}", listener.address());

        let mut channels = Channels::default();
        // Move individual fields out of `channels`, for capture by `tokio::spawn`.
        let message_tx = channels.message_tx;
        let response_rx = channels.response_rx;
        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        let (mut ws_stream, _response) = connect_async(address).await?;

        let request = serde_json::to_string(&Request::RoomListRequest)
            .context("serializing request")?;
        ws_stream.send(WebSocketMessage::Text(request)).await?;

        assert_eq!(
            channels.message_rx.recv().await,
            Some(Message::ControlRequest(Request::RoomListRequest))
        );

        // Dropping this sender signals to the listener that it should stop.
        drop(channels.response_tx);
        let () = listener_task
            .await
            .context("joining listener task")?
            .context("running listener")?;
        Ok(())
    }

    /// Responses sent on the response channel are written to the socket as JSON.
    #[tokio::test]
    async fn forwards_outgoing_responses() -> anyhow::Result<()> {
        init();
        let mut listener = Listener::bind("localhost:0")
            .await
            .context("binding listener")?;
        let address = format!("ws://{}", listener.address());

        let channels = Channels::default();
        // Move individual fields out of `channels`, for capture by `tokio::spawn`.
        let message_tx = channels.message_tx;
        let response_rx = channels.response_rx;
        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        let (mut ws_stream, _response) = connect_async(address).await?;

        channels
            .response_tx
            .send(Response::RoomLeaveResponse(RoomLeaveResponse {
                room_name: "bleep".to_string(),
            }))
            .await
            .context("sending response")?;

        let message = ws_stream
            .next()
            .await
            .context("unwrapping next response from stream")?
            .context("read result")?;
        let text = message.to_text().context("non-text control message")?;
        let response: Response = serde_json::from_str(text)
            .with_context(|| format!("decoding JSON message {:?}", text))?;
        assert_eq!(
            response,
            Response::RoomLeaveResponse(RoomLeaveResponse {
                room_name: "bleep".to_string(),
            })
        );

        // Dropping this sender signals to the listener that it should stop.
        drop(channels.response_tx);
        let () = listener_task
            .await
            .context("joining listener task")?
            .context("running listener")?;
        Ok(())
    }
}