use std::net::SocketAddr;
|
|
|
|
use anyhow::Context;
|
|
use futures::stream::{SplitSink, SplitStream};
|
|
use futures::{SinkExt, StreamExt};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::mpsc;
|
|
use tokio_tungstenite::tungstenite::{
|
|
Error as WebSocketError, Message as WebSocketMessage,
|
|
};
|
|
use tokio_tungstenite::WebSocketStream;
|
|
|
|
use crate::control::request::*;
|
|
use crate::control::response::*;
|
|
use crate::dispatcher::Message;
|
|
|
|
async fn forward_incoming(
|
|
mut incoming: SplitStream<WebSocketStream<TcpStream>>,
|
|
message_tx: &mpsc::UnboundedSender<Message>,
|
|
) -> anyhow::Result<()> {
|
|
while let Some(result) = incoming.next().await {
|
|
if let Err(WebSocketError::ConnectionClosed) = result {
|
|
break;
|
|
}
|
|
|
|
let message = result.context("reading control message")?;
|
|
|
|
let text = message.to_text().context("non-text control message")?;
|
|
|
|
let control_request: Request = serde_json::from_str(&text)
|
|
.with_context(|| format!("decoding JSON message {:?}", text))?;
|
|
|
|
message_tx.send(Message::ControlRequest(control_request));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn forward_outgoing(
|
|
response_rx: &mut mpsc::Receiver<Response>,
|
|
outgoing: &mut SplitSink<WebSocketStream<TcpStream>, WebSocketMessage>,
|
|
) -> anyhow::Result<()> {
|
|
while let Some(response) = response_rx.recv().await {
|
|
let text =
|
|
serde_json::to_string(&response).context("encoding control response")?;
|
|
|
|
outgoing
|
|
.send(WebSocketMessage::Text(text))
|
|
.await
|
|
.context("sending control response")?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle(
|
|
stream: TcpStream,
|
|
remote_address: &SocketAddr,
|
|
message_tx: &mpsc::UnboundedSender<Message>,
|
|
response_rx: &mut mpsc::Receiver<Response>,
|
|
) {
|
|
let ws_stream = match tokio_tungstenite::accept_async(stream).await {
|
|
Ok(ws_stream) => ws_stream,
|
|
Err(err) => {
|
|
warn!(
|
|
"Failed to accept WebSocket connection from {}: {}",
|
|
remote_address, err
|
|
);
|
|
return;
|
|
}
|
|
};
|
|
|
|
info!("WebSocket connection established from {}", remote_address);
|
|
|
|
let (mut outgoing, incoming) = ws_stream.split();
|
|
|
|
tokio::select! {
|
|
result = forward_incoming(incoming, message_tx) => match result {
|
|
Ok(()) => info!(
|
|
"Incoming WebSocket handler task for {} stopped",
|
|
remote_address,
|
|
),
|
|
Err(err) => error!(
|
|
"Error in WebSocket handler task for {}: {:#}",
|
|
remote_address, err,
|
|
),
|
|
},
|
|
result = forward_outgoing(response_rx, &mut outgoing) => match result {
|
|
Ok(()) => info!(
|
|
"Outgoing WebSocket handler for {} stopped",
|
|
remote_address,
|
|
),
|
|
Err(err) => warn!(
|
|
"Error in outgoing WebSocket handler for {}: {:#}",
|
|
remote_address, err,
|
|
),
|
|
},
|
|
};
|
|
|
|
match outgoing.close().await {
|
|
Ok(()) => info!("Closed WebSocket for {}", remote_address),
|
|
Err(err) => {
|
|
error!("Error closing WebSocket for {}: {}", remote_address, err,)
|
|
}
|
|
};
|
|
}
|
|
|
|
/// A listener for control connections.
|
|
pub struct Listener {
|
|
inner: TcpListener,
|
|
address: SocketAddr,
|
|
}
|
|
|
|
impl Listener {
|
|
pub async fn bind(address_str: &str) -> anyhow::Result<Self> {
|
|
let inner = TcpListener::bind(address_str)
|
|
.await
|
|
.context("binding listener")?;
|
|
let address = inner.local_addr().context("accessing local address")?;
|
|
|
|
info!("Bound listener for control connections to {}", address);
|
|
|
|
Ok(Self { inner, address })
|
|
}
|
|
|
|
pub fn address(&self) -> &SocketAddr {
|
|
&self.address
|
|
}
|
|
|
|
/// Starts accepting control connections, one at a time. For each connection,
|
|
/// forwards incoming messages from the socket to `message_tx` and outgoing
|
|
/// responses from `response_rx` to the socket.
|
|
pub async fn run(
|
|
&mut self,
|
|
message_tx: mpsc::UnboundedSender<Message>,
|
|
mut response_rx: mpsc::Receiver<Response>,
|
|
) -> anyhow::Result<()> {
|
|
info!("Accepting control connections on {}", self.address());
|
|
|
|
loop {
|
|
let (stream, remote_address) = tokio::select! {
|
|
result = self.inner.accept() => {
|
|
result.context("accepting connection")?
|
|
},
|
|
option = response_rx.recv() => match option {
|
|
Some(response) => {
|
|
debug!(
|
|
"Dropping control response in absence of connection: {:?}",
|
|
response,
|
|
);
|
|
continue
|
|
},
|
|
None => {
|
|
info!("Stopping control listener: response channel closed");
|
|
break
|
|
},
|
|
},
|
|
};
|
|
|
|
info!("Accepted control connection from {}", remote_address);
|
|
|
|
handle(stream, &remote_address, &message_tx, &mut response_rx).await
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::Listener;

    use std::net::SocketAddr;

    use anyhow::Context;
    use futures::{SinkExt, StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::sync::mpsc;
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message as WebSocketMessage;

    use crate::control::{Request, Response, RoomLeaveResponse};
    use crate::dispatcher::Message;

    /// Handle to a spawned task running `Listener::run`.
    type ListenerTask = tokio::task::JoinHandle<anyhow::Result<()>>;

    /// Both ends of the two channels a `Listener` communicates over.
    struct Channels {
        message_tx: mpsc::UnboundedSender<Message>,
        message_rx: mpsc::UnboundedReceiver<Message>,
        response_tx: mpsc::Sender<Response>,
        response_rx: mpsc::Receiver<Response>,
    }

    impl Default for Channels {
        fn default() -> Self {
            let (message_tx, message_rx) = mpsc::unbounded_channel();
            let (response_tx, response_rx) = mpsc::channel(100);
            Self {
                message_tx,
                message_rx,
                response_tx,
                response_rx,
            }
        }
    }

    // Enable capturing logs in tests.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Binds a listener to an ephemeral localhost port and spawns its `run`
    /// loop, returning the bound address and the task handle.
    async fn start_listener(
        message_tx: mpsc::UnboundedSender<Message>,
        response_rx: mpsc::Receiver<Response>,
    ) -> anyhow::Result<(SocketAddr, ListenerTask)> {
        let mut listener = Listener::bind("localhost:0")
            .await
            .context("binding listener")?;

        // `SocketAddr` is `Copy`: dereference instead of cloning.
        let address = *listener.address();

        let listener_task =
            tokio::spawn(async move { listener.run(message_tx, response_rx).await });

        Ok((address, listener_task))
    }

    /// Signals the listener to stop by dropping the response sender, then
    /// waits for the listener task and propagates its result.
    async fn stop_listener(
        response_tx: mpsc::Sender<Response>,
        listener_task: ListenerTask,
    ) -> anyhow::Result<()> {
        // Dropping this sender signals to the listener that it should stop.
        drop(response_tx);

        listener_task
            .await
            .context("joining listener task")?
            .context("running listener")
    }

    #[tokio::test]
    async fn binds_to_localhost() -> anyhow::Result<()> {
        init();

        let listener = Listener::bind("localhost:0")
            .await
            .context("binding listener")?;

        assert!(listener.address().ip().is_loopback());

        Ok(())
    }

    #[tokio::test]
    async fn listens_for_websocket_connections() -> anyhow::Result<()> {
        init();

        // Individual fields are moved out of `channels`; the fields that stay
        // behind (e.g. `message_rx`) remain alive until the end of the test.
        let channels = Channels::default();
        let (address, listener_task) =
            start_listener(channels.message_tx, channels.response_rx).await?;

        let (_ws_stream, _response) = connect_async(format!("ws://{}", address)).await?;

        stop_listener(channels.response_tx, listener_task).await
    }

    #[tokio::test]
    async fn keeps_listening_after_failed_handshake() -> anyhow::Result<()> {
        init();

        let channels = Channels::default();
        let (address, listener_task) =
            start_listener(channels.message_tx, channels.response_rx).await?;

        let mut stream = TcpStream::connect(address).await.context("connecting")?;

        // Write some invalid data, causing the listener to drop the connection.
        stream.write_all(&[0]).await.context("writing")?;

        // Expect that the stream immediately closes.
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.context("reading")?;
        assert_eq!(buf, Vec::<u8>::new());

        // Connect a second time.
        let (_ws_stream, _response) = connect_async(format!("ws://{}", address)).await?;

        stop_listener(channels.response_tx, listener_task).await
    }

    #[tokio::test]
    async fn forwards_incoming_requests() -> anyhow::Result<()> {
        init();

        let mut channels = Channels::default();
        let (address, listener_task) =
            start_listener(channels.message_tx, channels.response_rx).await?;

        let (mut ws_stream, _response) = connect_async(format!("ws://{}", address)).await?;

        let request = serde_json::to_string(&Request::RoomListRequest)
            .context("serializing request")?;
        ws_stream.send(WebSocketMessage::Text(request)).await?;

        assert_eq!(
            channels.message_rx.recv().await,
            Some(Message::ControlRequest(Request::RoomListRequest))
        );

        stop_listener(channels.response_tx, listener_task).await
    }

    #[tokio::test]
    async fn forwards_outgoing_responses() -> anyhow::Result<()> {
        init();

        let channels = Channels::default();
        let (address, listener_task) =
            start_listener(channels.message_tx, channels.response_rx).await?;

        let (mut ws_stream, _response) = connect_async(format!("ws://{}", address)).await?;

        channels
            .response_tx
            .send(Response::RoomLeaveResponse(RoomLeaveResponse {
                room_name: "bleep".to_string(),
            }))
            .await
            .context("sending response")?;

        let message = ws_stream
            .next()
            .await
            .context("unwrapping next response from stream")?
            .context("read result")?;

        let text = message.to_text().context("non-text control message")?;

        let response: Response = serde_json::from_str(text)
            .with_context(|| format!("decoding JSON message {:?}", text))?;

        assert_eq!(
            response,
            Response::RoomLeaveResponse(RoomLeaveResponse {
                room_name: "bleep".to_string(),
            })
        );

        stop_listener(channels.response_tx, listener_task).await
    }
}
|