diff --git a/proto/src/peer/mod.rs b/proto/src/peer/mod.rs index ca31c63..e81126b 100644 --- a/proto/src/peer/mod.rs +++ b/proto/src/peer/mod.rs @@ -1,3 +1,5 @@ mod message; +mod worker; pub use self::message::*; +pub use self::worker::*; diff --git a/proto/src/peer/worker.rs b/proto/src/peer/worker.rs new file mode 100644 index 0000000..6cca3c0 --- /dev/null +++ b/proto/src/peer/worker.rs @@ -0,0 +1,152 @@ +//! Defines a worker model for peer connections. + +use std::io; +use std::net::SocketAddr; + +use thiserror::Error; +use tokio::net::TcpStream; + +use crate::core::{FrameReader, FrameWriter, Worker}; +use crate::peer::{Message, PeerConnectionType, PeerInit}; + +// Peer states: +// +// - closed +// - open +// - waiting for pierce firewall +// - cannot connect +// +// Transitions: +// +// - closed: +// -> open: +// connect +// send peerinit +// -> waiting for pierce firewall: +// connect failed, or send peerinit failed +// send connecttopeer to server +// -> open: +// accept connection +// receive peerinit +// - open: +// -> closed: +// connection closed +// - waiting for pierce firewall: +// -> open: +// accept connection +// receive pierce firewall +// ???send peer init??? +// -> cannot connect: +// receive cannot connect +// + +/// A peer to connect to. +#[derive(Debug)] +pub struct PeerConnection { + /// The address of the peer. + pub address: SocketAddr, + + /// The type of connection to establish. + pub connection_type: PeerConnectionType, + + /// The user name as which to identify ourselves to the peer. + pub our_user_name: String, +} + +/// An error that arose while establishing a connection to a peer. +#[derive(Debug, Error)] +pub enum PeerConnectionError { + #[error("error establishing network connection: {0}")] + ConnectError(#[source] io::Error), + + #[error("error sending initial message: {0}")] + WriteError(#[source] io::Error), +} + +/// A `Worker` that handles an open connection to a peer. +type PeerWorker = Worker; + +impl PeerConnection { + pub async fn connect(self) -> Result { + let stream = TcpStream::connect(self.address) + .await + .map_err(PeerConnectionError::ConnectError)?; + + let (read_half, write_half) = stream.into_split(); + let reader = FrameReader::new(read_half); + let mut writer = FrameWriter::new(write_half); + + writer + .write(&Message::PeerInit(PeerInit { + user_name: self.our_user_name, + connection_type: self.connection_type, + token: 0, + })) + .await + .map_err(PeerConnectionError::WriteError)?; + + Ok(Worker::from_parts(reader, writer)) + } +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + + use tokio::net::TcpListener; + + use crate::core::FrameReader; + use crate::peer::{Message, PeerConnectionType, PeerInit}; + + use super::{PeerConnection, PeerConnectionError}; + + #[tokio::test] + async fn peer_connection_connect_error() { + let connection = PeerConnection { + address: SocketAddr::V4(SocketAddrV4::new( + // TODO: use example IP instead, ensuring this fails. + Ipv4Addr::new(0, 0, 0, 1), + 42, + )), + connection_type: PeerConnectionType::File, + our_user_name: "me".to_string(), + }; + + let err = connection.connect().await.unwrap_err(); + + match err { + PeerConnectionError::ConnectError(_) => (), + _ => panic!("Wrong error: {:?}", err), + } + } + + #[tokio::test] + async fn peer_connection_success() { + let listener = TcpListener::bind("localhost:0") + .await + .expect("binding listener"); + let address = listener.local_addr().expect("getting local address"); + + let connection = PeerConnection { + address, + connection_type: PeerConnectionType::File, + our_user_name: "me".to_string(), + }; + + let (connect_result, accept_result) = + tokio::join!(connection.connect(), listener.accept()); + + let _worker = connect_result.expect("connecting"); + let (stream, _remote_address) = accept_result.expect("accepting"); + + let message = FrameReader::new(stream).read().await.expect("reading"); + assert_eq!( + message, + Some(Message::PeerInit(PeerInit { + user_name: "me".to_string(), + connection_type: PeerConnectionType::File, + token: 0, + })) + ); + } +}