Browse Source

Write first test for control listening code.

wip
Titouan Rigoudy 4 years ago
parent
commit
2dcd7ceca3
3 changed files with 101 additions and 25 deletions
  1. +1
    -1
      client/src/control/mod.rs
  2. +83
    -16
      client/src/control/ws.rs
  3. +17
    -8
      client/src/main.rs

+ 1
- 1
client/src/control/mod.rs View File

@ -4,4 +4,4 @@ mod ws;
pub use self::request::*; pub use self::request::*;
pub use self::response::*; pub use self::response::*;
pub use self::ws::listen;
pub use self::ws::Listener;

+ 83
- 16
client/src/control/ws.rs View File

@ -1,9 +1,9 @@
use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use anyhow::Context; use anyhow::Context;
use futures::stream::{SplitSink, SplitStream}; use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use solstice_proto::config;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::{ use tokio_tungstenite::tungstenite::{
@ -93,25 +93,92 @@ async fn handle(
}; };
} }
/// Start listening on the socket address stored in configuration, and send
/// control notifications to the client through the given channel.
pub async fn listen(
message_tx: mpsc::UnboundedSender<Message>,
mut response_rx: mpsc::Receiver<Response>,
) -> anyhow::Result<()> {
let address = format!("{}:{}", config::CONTROL_HOST, config::CONTROL_PORT);
let listener = TcpListener::bind(&address).await?;
/// A listener for control connections.
pub struct Listener {
inner: TcpListener,
address: SocketAddr,
}
info!("Listening for control connections on {}", address);
impl Listener {
pub async fn bind(address_str: &str) -> io::Result<Self> {
let inner = TcpListener::bind(address_str).await?;
let address = inner.local_addr()?;
while let Ok((raw_stream, remote_address)) = listener.accept().await {
info!("Accepted control connection from {}", remote_address);
info!("Listening for control connections on {}", address);
let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?;
info!("WebSocket connection established from {}", remote_address);
Ok(Self { inner, address })
}
handle(ws_stream, &remote_address, &message_tx, &mut response_rx).await
pub fn address(&self) -> &SocketAddr {
&self.address
} }
Ok(())
/// Starts accepting control connections, one at a time. For each connection,
/// forwards incoming messages from the socket to `message_tx` and outgoing
/// responses from `response_rx` to the socket.
pub async fn run(
&mut self,
message_tx: mpsc::UnboundedSender<Message>,
mut response_rx: mpsc::Receiver<Response>,
) -> anyhow::Result<()> {
loop {
// TODO: Select from response_rx too, and stop looping when it is closed.
let (raw_stream, remote_address) =
self.inner.accept().await.context("accepting connection")?;
info!("Accepted control connection from {}", remote_address);
// TODO: Continue iterating in case of error.
let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?;
info!("WebSocket connection established from {}", remote_address);
// TODO: Stop gracefully once `response_rx` is closed.
handle(ws_stream, &remote_address, &message_tx, &mut response_rx).await
}
}
}
#[cfg(test)]
mod tests {
use super::Listener;
use std::net::SocketAddr;
use anyhow::Context;
use tokio::sync::mpsc;
use tokio_tungstenite::connect_async;
use crate::control::Response;
use crate::dispatcher::Message;
struct Channels {
message_tx: mpsc::UnboundedSender<Message>,
message_rx: mpsc::UnboundedReceiver<Message>,
response_tx: mpsc::Sender<Response>,
response_rx: mpsc::Receiver<Response>,
}
impl Default for Channels {
fn default() -> Self {
let (message_tx, message_rx) = mpsc::unbounded_channel();
let (response_tx, response_rx) = mpsc::channel(100);
Self {
message_tx,
message_rx,
response_tx,
response_rx,
}
}
}
#[tokio::test]
async fn binds_to_localhost() -> anyhow::Result<()> {
let listener = Listener::bind("localhost:0")
.await
.context("binding listener")?;
match listener.address() {
SocketAddr::V4(address) => assert!(address.ip().is_loopback()),
SocketAddr::V6(address) => assert!(address.ip().is_loopback()),
};
Ok(())
}
} }

+ 17
- 8
client/src/main.rs View File

@ -4,12 +4,13 @@ extern crate log;
use std::thread; use std::thread;
use anyhow::Context;
use clap::{App, Arg}; use clap::{App, Arg};
use crossbeam_channel; use crossbeam_channel;
use env_logger; use env_logger;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use log::info; use log::info;
use solstice_proto;
use solstice_proto::config;
use tokio::net::TcpStream; use tokio::net::TcpStream;
mod client; mod client;
@ -108,7 +109,7 @@ async fn run_client(
Ok(()) Ok(())
} }
async fn async_main() {
async fn async_main() -> anyhow::Result<()> {
let bundle = ContextBundle::default(); let bundle = ContextBundle::default();
let (dispatcher_tx, mut dispatcher_rx) = let (dispatcher_tx, mut dispatcher_rx) =
@ -120,7 +121,13 @@ async fn async_main() {
let dispatcher = Dispatcher::new(); let dispatcher = Dispatcher::new();
let executor = Executor::new(bundle.context); let executor = Executor::new(bundle.context);
let control_task = control::listen(dispatcher_tx, bundle.control_response_rx);
let control_address =
format!("{}:{}", config::CONTROL_HOST, config::CONTROL_PORT);
let mut control_listener = control::Listener::bind(&control_address)
.await
.context("binding control listener")?;
let control_task =
control_listener.run(dispatcher_tx, bundle.control_response_rx);
let dispatch_task = async move { let dispatch_task = async move {
while let Some(message) = dispatcher_rx.recv().await { while let Some(message) = dispatcher_rx.recv().await {
@ -145,12 +152,12 @@ async fn async_main() {
client_task client_task
.await .await
.expect("Client task join error")
.expect("Client error");
.context("joining client task")?
.context("running client task")
} }
#[tokio::main] #[tokio::main]
async fn main() {
async fn main() -> anyhow::Result<()> {
env_logger::init(); env_logger::init();
let matches = App::new("solstice-client") let matches = App::new("solstice-client")
@ -165,9 +172,11 @@ async fn main() {
if matches.is_present("async") { if matches.is_present("async") {
info!("Running in asynchronous mode."); info!("Running in asynchronous mode.");
async_main().await;
async_main().await.context("running asynchronous main")
} else { } else {
info!("Running in synchronous mode."); info!("Running in synchronous mode.");
tokio::task::spawn_blocking(old_main).await.unwrap();
tokio::task::spawn_blocking(old_main)
.await
.context("running synchronous main")
} }
} }

Loading…
Cancel
Save