From 146278cd2a299a451b1047ab07b33910c72b6d90 Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Tue, 16 Nov 2021 16:39:09 +0100 Subject: [PATCH] Allow accepting incoming reverse peer connections. --- proto/src/peer/worker.rs | 180 +++++++++++++++++++++++++-------------- 1 file changed, 115 insertions(+), 65 deletions(-) diff --git a/proto/src/peer/worker.rs b/proto/src/peer/worker.rs index fc1d4e2..db54c91 100644 --- a/proto/src/peer/worker.rs +++ b/proto/src/peer/worker.rs @@ -7,7 +7,7 @@ use thiserror::Error; use tokio::net::{TcpListener, TcpStream}; use crate::core::{FrameReader, FrameWriter, Worker}; -use crate::peer::{InitialMessage, Message, PeerConnectionType}; +use crate::peer::{InitialMessage, Message}; // Peer states: // @@ -83,23 +83,33 @@ impl PeerConnection { } } +/// The result of accepting a peer connection. #[derive(Debug)] -pub struct IncomingPeer { - pub user_name: String, - pub address: SocketAddr, - pub connection_type: PeerConnectionType, - pub token: u32, - pub worker: PeerWorker, +pub struct PeerAcceptResult { + /// The remote address of the peer. + address: SocketAddr, + + /// The newly-accepted TCP stream. + stream: TcpStream, } +/// The result of accepting a peer connection and waiting for the initial +/// message. #[derive(Debug)] -pub struct IncomingConnection { - address: SocketAddr, - stream: TcpStream, +pub struct PeerWaitResult { + /// The remote address of the peer. + pub address: SocketAddr, + + /// The initial message received from the peer. + pub initial_message: InitialMessage, + + /// The worker that will handle the connection from here on out. + pub worker: PeerWorker, } +/// Errors that may arise while waiting for the initial message from a peer. #[derive(Debug, Error)] -pub enum IncomingHandshakeError { +pub enum PeerWaitError { #[error("error reading message: {0}")] ReadError(#[source] io::Error), @@ -110,43 +120,53 @@ pub enum IncomingHandshakeError { UnexpectedMessage(Message), } -impl IncomingConnection { - pub async fn handshake(self) -> Result { +impl PeerAcceptResult { + /// Waits for the initial message to be received from the remote peer. + /// + /// The initial message determines the remote peer's identity as well as the + /// type of the connection. + pub async fn wait_for_initial_message( + self, + ) -> Result { let (read_half, write_half) = self.stream.into_split(); let mut reader = FrameReader::new(read_half); let writer = FrameWriter::new(write_half); - let optional_message = reader - .read() - .await - .map_err(IncomingHandshakeError::ReadError)?; - - match optional_message { - Some(Message::PeerInit(peer_init)) => Ok(IncomingPeer { - user_name: peer_init.user_name, - address: self.address, - connection_type: peer_init.connection_type, - token: peer_init.token, - worker: Worker::from_parts(reader, writer), - }), - Some(message) => Err(IncomingHandshakeError::UnexpectedMessage(message)), - None => Err(IncomingHandshakeError::StreamClosed), - } + let optional_message = + reader.read().await.map_err(PeerWaitError::ReadError)?; + + let initial_message = match optional_message { + Some(Message::PeerInit(inner)) => InitialMessage::PeerInit(inner), + Some(Message::PierceFirewall(inner)) => { + InitialMessage::PierceFirewall(inner) + } + Some(message) => return Err(PeerWaitError::UnexpectedMessage(message)), + None => return Err(PeerWaitError::StreamClosed), + }; + + Ok(PeerWaitResult { + address: self.address, + initial_message, + worker: Worker::from_parts(reader, writer), + }) } } -struct PeerListener { +/// A listener for incoming peer connections. +pub struct PeerListener { inner: TcpListener, } impl PeerListener { + /// Wraps the given TCP listener, specializing it to peer connections. pub fn new(listener: TcpListener) -> Self { Self { inner: listener } } - pub async fn accept(&mut self) -> io::Result { + /// Accepts the next peer connection. + pub async fn accept(&mut self) -> io::Result { let (stream, address) = self.inner.accept().await?; - Ok(IncomingConnection { stream, address }) + Ok(PeerAcceptResult { stream, address }) } } @@ -160,7 +180,7 @@ mod tests { use crate::peer::{InitialMessage, Message, PeerConnectionType, PeerInit}; use super::{ - IncomingHandshakeError, PeerConnection, PeerConnectionError, PeerListener, + PeerConnection, PeerConnectionError, PeerListener, PeerWaitError, }; #[tokio::test] @@ -193,13 +213,15 @@ mod tests { .expect("binding listener"); let address = listener.local_addr().expect("getting local address"); + let initial_message = InitialMessage::PeerInit(PeerInit { + connection_type: PeerConnectionType::File, + user_name: "me".to_string(), + token: 1337, + }); + let connection = PeerConnection { address, - initial_message: InitialMessage::PeerInit(PeerInit { - connection_type: PeerConnectionType::File, - user_name: "me".to_string(), - token: 1337, - }), + initial_message: initial_message.clone(), }; let (connect_result, accept_result) = @@ -209,14 +231,7 @@ mod tests { let (stream, _remote_address) = accept_result.expect("accepting"); let message = FrameReader::new(stream).read().await.expect("reading"); - assert_eq!( - message, - Some(Message::PeerInit(PeerInit { - user_name: "me".to_string(), - connection_type: PeerConnectionType::File, - token: 1337, - })) - ); + assert_eq!(message, Some(Message::from(initial_message))); } #[tokio::test] @@ -226,9 +241,11 @@ mod tests { .expect("binding listener"); let address = listener.local_addr().expect("getting local address"); + let initial_message = InitialMessage::PierceFirewall(1337); + let connection = PeerConnection { address, - initial_message: InitialMessage::PierceFirewall(1337), + initial_message: initial_message.clone(), }; let (connect_result, accept_result) = @@ -238,7 +255,7 @@ mod tests { let (stream, _remote_address) = accept_result.expect("accepting"); let message = FrameReader::new(stream).read().await.expect("reading"); - assert_eq!(message, Some(Message::PierceFirewall(1337))); + assert_eq!(message, Some(Message::from(initial_message))); } #[tokio::test] @@ -261,12 +278,12 @@ mod tests { let incoming_connection = accept_result.expect("accepting"); let err = incoming_connection - .handshake() + .wait_for_initial_message() .await - .expect_err("performing handshake"); + .expect_err("waiting for initial message"); match err { - IncomingHandshakeError::StreamClosed => (), + PeerWaitError::StreamClosed => (), _ => panic!("Wrong error: {:?}", err), } } @@ -297,31 +314,66 @@ mod tests { let incoming_connection = accept_result.expect("accepting"); let err = incoming_connection - .handshake() + .wait_for_initial_message() .await - .expect_err("performing handshake"); + .expect_err("waiting for initial message"); match err { - IncomingHandshakeError::ReadError(_) => (), + PeerWaitError::ReadError(_) => (), _ => panic!("Wrong error: {:?}", err), } } #[tokio::test] - async fn peer_accept_success() { + async fn peer_accept_success_direct() { let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding"); let listener_address = tcp_listener.local_addr().expect("getting local address"); let mut peer_listener = PeerListener::new(tcp_listener); + let initial_message = InitialMessage::PeerInit(PeerInit { + connection_type: PeerConnectionType::File, + user_name: "olabode".to_string(), + token: 1337, + }); + let outgoing_connection = PeerConnection { address: listener_address, - initial_message: InitialMessage::PeerInit(PeerInit { - connection_type: PeerConnectionType::File, - user_name: "olabode".to_string(), - token: 1337, - }), + initial_message: initial_message.clone(), + }; + + let accept_task = async { + let incoming_connection = + peer_listener.accept().await.expect("accepting"); + + incoming_connection + .wait_for_initial_message() + .await + .expect("waiting for initial message") + }; + + let (connect_result, incoming_peer) = + tokio::join!(outgoing_connection.connect(), accept_task); + + let _worker = connect_result.expect("connecting"); + + assert_eq!(incoming_peer.initial_message, initial_message); + } + + #[tokio::test] + async fn peer_accept_success_reverse() { + let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding"); + let listener_address = + tcp_listener.local_addr().expect("getting local address"); + + let mut peer_listener = PeerListener::new(tcp_listener); + + let initial_message = InitialMessage::PierceFirewall(1337); + + let outgoing_connection = PeerConnection { + address: listener_address, + initial_message: initial_message.clone(), }; let accept_task = async { @@ -329,9 +381,9 @@ mod tests { peer_listener.accept().await.expect("accepting"); incoming_connection - .handshake() + .wait_for_initial_message() .await - .expect("performing handshake") + .expect("waiting for initial message") }; let (connect_result, incoming_peer) = @@ -339,8 +391,6 @@ mod tests { let _worker = connect_result.expect("connecting"); - assert_eq!(incoming_peer.user_name, "olabode"); - assert_eq!(incoming_peer.connection_type, PeerConnectionType::File); - assert_eq!(incoming_peer.token, 1337); + assert_eq!(incoming_peer.initial_message, initial_message); } }