From f18e48774dd78114e1ad103a3db46a8cdae92a9d Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Tue, 16 Nov 2021 16:07:20 +0100 Subject: [PATCH] Allow establishing reverse peer connections. --- proto/src/peer/message.rs | 19 +++++++++++ proto/src/peer/worker.rs | 69 ++++++++++++++++++++++++++------------- 2 files changed, 66 insertions(+), 22 deletions(-) diff --git a/proto/src/peer/message.rs b/proto/src/peer/message.rs index 3309422..fda9162 100644 --- a/proto/src/peer/message.rs +++ b/proto/src/peer/message.rs @@ -10,6 +10,16 @@ use crate::core::{ const CODE_PIERCE_FIREWALL: u32 = 0; const CODE_PEER_INIT: u32 = 1; +/// A subset of `Message` sent by peers upon establishing a connection. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum InitialMessage { + /// Sent when a peer establishes a connection on its own behalf. + PeerInit(PeerInit), + + /// Sent when a peer establishes connection on the remote peer's behalf. + PierceFirewall(u32), +} + /// This enum contains all the possible messages peers can exchange. #[derive(Clone, Debug, Eq, PartialEq)] pub enum Message { @@ -18,6 +28,15 @@ pub enum Message { Unknown(u32), } +impl From for Message { + fn from(message: InitialMessage) -> Self { + match message { + InitialMessage::PierceFirewall(inner) => Message::PierceFirewall(inner), + InitialMessage::PeerInit(inner) => Message::PeerInit(inner), + } + } +} + impl ValueDecode for Message { fn decode_from(decoder: &mut ValueDecoder) -> Result { let position = decoder.position(); diff --git a/proto/src/peer/worker.rs b/proto/src/peer/worker.rs index ee07fb0..fc1d4e2 100644 --- a/proto/src/peer/worker.rs +++ b/proto/src/peer/worker.rs @@ -7,7 +7,7 @@ use thiserror::Error; use tokio::net::{TcpListener, TcpStream}; use crate::core::{FrameReader, FrameWriter, Worker}; -use crate::peer::{Message, PeerConnectionType, PeerInit}; +use crate::peer::{InitialMessage, Message, PeerConnectionType}; // Peer states: // @@ -46,11 +46,8 @@ pub struct PeerConnection { /// The address of the peer. pub address: SocketAddr, - /// The type of connection to establish. - pub connection_type: PeerConnectionType, - - /// The user name as which to identify ourselves to the peer. - pub our_user_name: String, + /// The initial message to send to the peer. + pub initial_message: InitialMessage, } /// An error that arose while establishing a connection to a peer. @@ -68,7 +65,7 @@ type PeerWorker = Worker; impl PeerConnection { pub async fn connect(self) -> Result { - let stream = TcpStream::connect(self.address) + let stream = TcpStream::connect(&self.address) .await .map_err(PeerConnectionError::ConnectError)?; @@ -76,12 +73,9 @@ impl PeerConnection { let reader = FrameReader::new(read_half); let mut writer = FrameWriter::new(write_half); + let message = Message::from(self.initial_message); writer - .write(&Message::PeerInit(PeerInit { - user_name: self.our_user_name, - connection_type: self.connection_type, - token: 0, - })) + .write(&message) .await .map_err(PeerConnectionError::WriteError)?; @@ -163,7 +157,7 @@ mod tests { use tokio::net::{TcpListener, TcpStream}; use crate::core::{FrameReader, FrameWriter}; - use crate::peer::{Message, PeerConnectionType, PeerInit}; + use crate::peer::{InitialMessage, Message, PeerConnectionType, PeerInit}; use super::{ IncomingHandshakeError, PeerConnection, PeerConnectionError, PeerListener, @@ -177,8 +171,11 @@ mod tests { Ipv4Addr::new(0, 0, 0, 1), 42, )), - connection_type: PeerConnectionType::File, - our_user_name: "me".to_string(), + initial_message: InitialMessage::PeerInit(PeerInit { + connection_type: PeerConnectionType::File, + user_name: "me".to_string(), + token: 1337, + }), }; let err = connection.connect().await.unwrap_err(); @@ -190,7 +187,7 @@ mod tests { } #[tokio::test] - async fn peer_connection_success() { + async fn peer_connection_success_direct() { let listener = TcpListener::bind("localhost:0") .await .expect("binding listener"); @@ -198,8 +195,11 @@ mod tests { let connection = PeerConnection { address, - connection_type: PeerConnectionType::File, - our_user_name: "me".to_string(), + initial_message: InitialMessage::PeerInit(PeerInit { + connection_type: PeerConnectionType::File, + user_name: "me".to_string(), + token: 1337, + }), }; let (connect_result, accept_result) = @@ -214,11 +214,33 @@ mod tests { Some(Message::PeerInit(PeerInit { user_name: "me".to_string(), connection_type: PeerConnectionType::File, - token: 0, + token: 1337, })) ); } + #[tokio::test] + async fn peer_connection_success_reverse() { + let listener = TcpListener::bind("localhost:0") + .await + .expect("binding listener"); + let address = listener.local_addr().expect("getting local address"); + + let connection = PeerConnection { + address, + initial_message: InitialMessage::PierceFirewall(1337), + }; + + let (connect_result, accept_result) = + tokio::join!(connection.connect(), listener.accept()); + + let _worker = connect_result.expect("connecting"); + let (stream, _remote_address) = accept_result.expect("accepting"); + + let message = FrameReader::new(stream).read().await.expect("reading"); + assert_eq!(message, Some(Message::PierceFirewall(1337))); + } + #[tokio::test] async fn peer_accept_stream_closed() { let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding"); @@ -295,8 +317,11 @@ mod tests { let outgoing_connection = PeerConnection { address: listener_address, - connection_type: PeerConnectionType::File, - our_user_name: "olabode".to_string(), + initial_message: InitialMessage::PeerInit(PeerInit { + connection_type: PeerConnectionType::File, + user_name: "olabode".to_string(), + token: 1337, + }), }; let accept_task = async { @@ -316,6 +341,6 @@ mod tests { assert_eq!(incoming_peer.user_name, "olabode"); assert_eq!(incoming_peer.connection_type, PeerConnectionType::File); - assert_eq!(incoming_peer.token, 0); + assert_eq!(incoming_peer.token, 1337); } }