Browse Source

Allow accepting incoming reverse peer connections.

wip
Titouan Rigoudy 4 years ago
parent
commit
146278cd2a
1 changed files with 115 additions and 65 deletions
  1. +115
    -65
      proto/src/peer/worker.rs

+ 115
- 65
proto/src/peer/worker.rs View File

@ -7,7 +7,7 @@ use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use crate::core::{FrameReader, FrameWriter, Worker};
use crate::peer::{InitialMessage, Message, PeerConnectionType};
use crate::peer::{InitialMessage, Message};
// Peer states:
//
@ -83,23 +83,33 @@ impl PeerConnection {
}
}
/// The result of accepting a peer connection.
#[derive(Debug)]
pub struct IncomingPeer {
pub user_name: String,
pub address: SocketAddr,
pub connection_type: PeerConnectionType,
pub token: u32,
pub worker: PeerWorker,
pub struct PeerAcceptResult {
/// The remote address of the peer.
address: SocketAddr,
/// The newly-accepted TCP stream.
stream: TcpStream,
}
/// The result of accepting a peer connection and waiting for the initial
/// message.
#[derive(Debug)]
pub struct IncomingConnection {
address: SocketAddr,
stream: TcpStream,
pub struct PeerWaitResult {
/// The remote address of the peer.
pub address: SocketAddr,
/// The initial message received from the peer.
pub initial_message: InitialMessage,
/// The worker that will handle the connection from here on out.
pub worker: PeerWorker,
}
/// Errors that may arise while waiting for the initial message from a peer.
#[derive(Debug, Error)]
pub enum IncomingHandshakeError {
pub enum PeerWaitError {
#[error("error reading message: {0}")]
ReadError(#[source] io::Error),
@ -110,43 +120,53 @@ pub enum IncomingHandshakeError {
UnexpectedMessage(Message),
}
impl IncomingConnection {
pub async fn handshake(self) -> Result<IncomingPeer, IncomingHandshakeError> {
impl PeerAcceptResult {
/// Waits for the initial message to be received from the remote peer.
///
/// The initial message determines the remote peer's identity as well as the
/// type of the connection.
pub async fn wait_for_initial_message(
self,
) -> Result<PeerWaitResult, PeerWaitError> {
let (read_half, write_half) = self.stream.into_split();
let mut reader = FrameReader::new(read_half);
let writer = FrameWriter::new(write_half);
let optional_message = reader
.read()
.await
.map_err(IncomingHandshakeError::ReadError)?;
match optional_message {
Some(Message::PeerInit(peer_init)) => Ok(IncomingPeer {
user_name: peer_init.user_name,
address: self.address,
connection_type: peer_init.connection_type,
token: peer_init.token,
worker: Worker::from_parts(reader, writer),
}),
Some(message) => Err(IncomingHandshakeError::UnexpectedMessage(message)),
None => Err(IncomingHandshakeError::StreamClosed),
}
let optional_message =
reader.read().await.map_err(PeerWaitError::ReadError)?;
let initial_message = match optional_message {
Some(Message::PeerInit(inner)) => InitialMessage::PeerInit(inner),
Some(Message::PierceFirewall(inner)) => {
InitialMessage::PierceFirewall(inner)
}
Some(message) => return Err(PeerWaitError::UnexpectedMessage(message)),
None => return Err(PeerWaitError::StreamClosed),
};
Ok(PeerWaitResult {
address: self.address,
initial_message,
worker: Worker::from_parts(reader, writer),
})
}
}
struct PeerListener {
/// A listener for incoming peer connections.
pub struct PeerListener {
inner: TcpListener,
}
impl PeerListener {
/// Wraps the given TCP listener, specializing it to peer connections.
pub fn new(listener: TcpListener) -> Self {
Self { inner: listener }
}
pub async fn accept(&mut self) -> io::Result<IncomingConnection> {
/// Accepts the next peer connection.
pub async fn accept(&mut self) -> io::Result<PeerAcceptResult> {
let (stream, address) = self.inner.accept().await?;
Ok(IncomingConnection { stream, address })
Ok(PeerAcceptResult { stream, address })
}
}
@ -160,7 +180,7 @@ mod tests {
use crate::peer::{InitialMessage, Message, PeerConnectionType, PeerInit};
use super::{
IncomingHandshakeError, PeerConnection, PeerConnectionError, PeerListener,
PeerConnection, PeerConnectionError, PeerListener, PeerWaitError,
};
#[tokio::test]
@ -193,13 +213,15 @@ mod tests {
.expect("binding listener");
let address = listener.local_addr().expect("getting local address");
let initial_message = InitialMessage::PeerInit(PeerInit {
connection_type: PeerConnectionType::File,
user_name: "me".to_string(),
token: 1337,
});
let connection = PeerConnection {
address,
initial_message: InitialMessage::PeerInit(PeerInit {
connection_type: PeerConnectionType::File,
user_name: "me".to_string(),
token: 1337,
}),
initial_message: initial_message.clone(),
};
let (connect_result, accept_result) =
@ -209,14 +231,7 @@ mod tests {
let (stream, _remote_address) = accept_result.expect("accepting");
let message = FrameReader::new(stream).read().await.expect("reading");
assert_eq!(
message,
Some(Message::PeerInit(PeerInit {
user_name: "me".to_string(),
connection_type: PeerConnectionType::File,
token: 1337,
}))
);
assert_eq!(message, Some(Message::from(initial_message)));
}
#[tokio::test]
@ -226,9 +241,11 @@ mod tests {
.expect("binding listener");
let address = listener.local_addr().expect("getting local address");
let initial_message = InitialMessage::PierceFirewall(1337);
let connection = PeerConnection {
address,
initial_message: InitialMessage::PierceFirewall(1337),
initial_message: initial_message.clone(),
};
let (connect_result, accept_result) =
@ -238,7 +255,7 @@ mod tests {
let (stream, _remote_address) = accept_result.expect("accepting");
let message = FrameReader::new(stream).read().await.expect("reading");
assert_eq!(message, Some(Message::PierceFirewall(1337)));
assert_eq!(message, Some(Message::from(initial_message)));
}
#[tokio::test]
@ -261,12 +278,12 @@ mod tests {
let incoming_connection = accept_result.expect("accepting");
let err = incoming_connection
.handshake()
.wait_for_initial_message()
.await
.expect_err("performing handshake");
.expect_err("waiting for initial message");
match err {
IncomingHandshakeError::StreamClosed => (),
PeerWaitError::StreamClosed => (),
_ => panic!("Wrong error: {:?}", err),
}
}
@ -297,31 +314,66 @@ mod tests {
let incoming_connection = accept_result.expect("accepting");
let err = incoming_connection
.handshake()
.wait_for_initial_message()
.await
.expect_err("performing handshake");
.expect_err("waiting for initial message");
match err {
IncomingHandshakeError::ReadError(_) => (),
PeerWaitError::ReadError(_) => (),
_ => panic!("Wrong error: {:?}", err),
}
}
#[tokio::test]
async fn peer_accept_success() {
async fn peer_accept_success_direct() {
let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding");
let listener_address =
tcp_listener.local_addr().expect("getting local address");
let mut peer_listener = PeerListener::new(tcp_listener);
let initial_message = InitialMessage::PeerInit(PeerInit {
connection_type: PeerConnectionType::File,
user_name: "olabode".to_string(),
token: 1337,
});
let outgoing_connection = PeerConnection {
address: listener_address,
initial_message: InitialMessage::PeerInit(PeerInit {
connection_type: PeerConnectionType::File,
user_name: "olabode".to_string(),
token: 1337,
}),
initial_message: initial_message.clone(),
};
let accept_task = async {
let incoming_connection =
peer_listener.accept().await.expect("accepting");
incoming_connection
.wait_for_initial_message()
.await
.expect("waiting for initial message")
};
let (connect_result, incoming_peer) =
tokio::join!(outgoing_connection.connect(), accept_task);
let _worker = connect_result.expect("connecting");
assert_eq!(incoming_peer.initial_message, initial_message);
}
#[tokio::test]
async fn peer_accept_success_reverse() {
let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding");
let listener_address =
tcp_listener.local_addr().expect("getting local address");
let mut peer_listener = PeerListener::new(tcp_listener);
let initial_message = InitialMessage::PierceFirewall(1337);
let outgoing_connection = PeerConnection {
address: listener_address,
initial_message: initial_message.clone(),
};
let accept_task = async {
@ -329,9 +381,9 @@ mod tests {
peer_listener.accept().await.expect("accepting");
incoming_connection
.handshake()
.wait_for_initial_message()
.await
.expect("performing handshake")
.expect("waiting for initial message")
};
let (connect_result, incoming_peer) =
@ -339,8 +391,6 @@ mod tests {
let _worker = connect_result.expect("connecting");
assert_eq!(incoming_peer.user_name, "olabode");
assert_eq!(incoming_peer.connection_type, PeerConnectionType::File);
assert_eq!(incoming_peer.token, 1337);
assert_eq!(incoming_peer.initial_message, initial_message);
}
}

Loading…
Cancel
Save