From 902a5d1b0ed66d5e65ae7adc89ab607468ac84ae Mon Sep 17 00:00:00 2001 From: Titouan Rigoudy Date: Tue, 16 Nov 2021 14:03:59 +0100 Subject: [PATCH] Implement receiving peer connections. --- proto/src/peer/worker.rs | 177 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 173 insertions(+), 4 deletions(-) diff --git a/proto/src/peer/worker.rs b/proto/src/peer/worker.rs index 6cca3c0..ee07fb0 100644 --- a/proto/src/peer/worker.rs +++ b/proto/src/peer/worker.rs @@ -4,7 +4,7 @@ use std::io; use std::net::SocketAddr; use thiserror::Error; -use tokio::net::TcpStream; +use tokio::net::{TcpListener, TcpStream}; use crate::core::{FrameReader, FrameWriter, Worker}; use crate::peer::{Message, PeerConnectionType, PeerInit}; @@ -89,16 +89,85 @@ impl PeerConnection { } } +#[derive(Debug)] +pub struct IncomingPeer { + pub user_name: String, + pub address: SocketAddr, + pub connection_type: PeerConnectionType, + pub token: u32, + pub worker: PeerWorker, +} + +#[derive(Debug)] +pub struct IncomingConnection { + address: SocketAddr, + stream: TcpStream, +} + +#[derive(Debug, Error)] +pub enum IncomingHandshakeError { + #[error("error reading message: {0}")] + ReadError(#[source] io::Error), + + #[error("stream closed unexpectedly")] + StreamClosed, + + #[error("unexpected message: {0:?}")] + UnexpectedMessage(Message), +} + +impl IncomingConnection { + pub async fn handshake(self) -> Result { + let (read_half, write_half) = self.stream.into_split(); + let mut reader = FrameReader::new(read_half); + let writer = FrameWriter::new(write_half); + + let optional_message = reader + .read() + .await + .map_err(IncomingHandshakeError::ReadError)?; + + match optional_message { + Some(Message::PeerInit(peer_init)) => Ok(IncomingPeer { + user_name: peer_init.user_name, + address: self.address, + connection_type: peer_init.connection_type, + token: peer_init.token, + worker: Worker::from_parts(reader, writer), + }), + Some(message) => Err(IncomingHandshakeError::UnexpectedMessage(message)), + None => Err(IncomingHandshakeError::StreamClosed), + } + } +} + +struct PeerListener { + inner: TcpListener, +} + +impl PeerListener { + pub fn new(listener: TcpListener) -> Self { + Self { inner: listener } + } + + pub async fn accept(&mut self) -> io::Result { + let (stream, address) = self.inner.accept().await?; + Ok(IncomingConnection { stream, address }) + } +} + #[cfg(test)] mod tests { use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; - use tokio::net::TcpListener; + use tokio::net::{TcpListener, TcpStream}; - use crate::core::FrameReader; + use crate::core::{FrameReader, FrameWriter}; use crate::peer::{Message, PeerConnectionType, PeerInit}; - use super::{PeerConnection, PeerConnectionError}; + use super::{ + IncomingHandshakeError, PeerConnection, PeerConnectionError, PeerListener, + }; #[tokio::test] async fn peer_connection_connect_error() { @@ -149,4 +218,104 @@ mod tests { })) ); } + + #[tokio::test] + async fn peer_accept_stream_closed() { + let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding"); + let listener_address = + tcp_listener.local_addr().expect("getting local address"); + + let mut peer_listener = PeerListener::new(tcp_listener); + + let connect_task = async { + // Open connection and immediately drop/close it. + TcpStream::connect(listener_address) + .await + .expect("connecting"); + }; + + let ((), accept_result) = + tokio::join!(connect_task, peer_listener.accept()); + + let incoming_connection = accept_result.expect("accepting"); + let err = incoming_connection + .handshake() + .await + .expect_err("performing handshake"); + + match err { + IncomingHandshakeError::StreamClosed => (), + _ => panic!("Wrong error: {:?}", err), + } + } + + #[tokio::test] + async fn peer_accept_unexpected_message() { + let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding"); + let listener_address = + tcp_listener.local_addr().expect("getting local address"); + + let mut peer_listener = PeerListener::new(tcp_listener); + + let connect_task = async { + let stream = TcpStream::connect(listener_address) + .await + .expect("connecting"); + + let mut writer = FrameWriter::new(stream); + writer.write("garbage").await.expect("writing"); + + // Return writer so that the connection is kept alive until we notice the + // error on the receiving side. + writer + }; + + let (_writer, accept_result) = + tokio::join!(connect_task, peer_listener.accept()); + + let incoming_connection = accept_result.expect("accepting"); + let err = incoming_connection + .handshake() + .await + .expect_err("performing handshake"); + + match err { + IncomingHandshakeError::ReadError(_) => (), + _ => panic!("Wrong error: {:?}", err), + } + } + + #[tokio::test] + async fn peer_accept_success() { + let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding"); + let listener_address = + tcp_listener.local_addr().expect("getting local address"); + + let mut peer_listener = PeerListener::new(tcp_listener); + + let outgoing_connection = PeerConnection { + address: listener_address, + connection_type: PeerConnectionType::File, + our_user_name: "olabode".to_string(), + }; + + let accept_task = async { + let incoming_connection = + peer_listener.accept().await.expect("accepting"); + + incoming_connection + .handshake() + .await + .expect("performing handshake") + }; + + let (connect_result, incoming_peer) = + tokio::join!(outgoing_connection.connect(), accept_task); + + let _worker = connect_result.expect("connecting"); + + assert_eq!(incoming_peer.user_name, "olabode"); + assert_eq!(incoming_peer.connection_type, PeerConnectionType::File); + assert_eq!(incoming_peer.token, 0); + } }