|
|
|
@ -4,7 +4,7 @@ use std::io; |
|
|
|
use std::net::SocketAddr;
|
|
|
|
|
|
|
|
use thiserror::Error;
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
|
|
|
|
use crate::core::{FrameReader, FrameWriter, Worker};
|
|
|
|
use crate::peer::{Message, PeerConnectionType, PeerInit};
|
|
|
|
@ -89,16 +89,85 @@ impl PeerConnection { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct IncomingPeer {
|
|
|
|
pub user_name: String,
|
|
|
|
pub address: SocketAddr,
|
|
|
|
pub connection_type: PeerConnectionType,
|
|
|
|
pub token: u32,
|
|
|
|
pub worker: PeerWorker,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct IncomingConnection {
|
|
|
|
address: SocketAddr,
|
|
|
|
stream: TcpStream,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug, Error)]
|
|
|
|
pub enum IncomingHandshakeError {
|
|
|
|
#[error("error reading message: {0}")]
|
|
|
|
ReadError(#[source] io::Error),
|
|
|
|
|
|
|
|
#[error("stream closed unexpectedly")]
|
|
|
|
StreamClosed,
|
|
|
|
|
|
|
|
#[error("unexpected message: {0:?}")]
|
|
|
|
UnexpectedMessage(Message),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl IncomingConnection {
|
|
|
|
pub async fn handshake(self) -> Result<IncomingPeer, IncomingHandshakeError> {
|
|
|
|
let (read_half, write_half) = self.stream.into_split();
|
|
|
|
let mut reader = FrameReader::new(read_half);
|
|
|
|
let writer = FrameWriter::new(write_half);
|
|
|
|
|
|
|
|
let optional_message = reader
|
|
|
|
.read()
|
|
|
|
.await
|
|
|
|
.map_err(IncomingHandshakeError::ReadError)?;
|
|
|
|
|
|
|
|
match optional_message {
|
|
|
|
Some(Message::PeerInit(peer_init)) => Ok(IncomingPeer {
|
|
|
|
user_name: peer_init.user_name,
|
|
|
|
address: self.address,
|
|
|
|
connection_type: peer_init.connection_type,
|
|
|
|
token: peer_init.token,
|
|
|
|
worker: Worker::from_parts(reader, writer),
|
|
|
|
}),
|
|
|
|
Some(message) => Err(IncomingHandshakeError::UnexpectedMessage(message)),
|
|
|
|
None => Err(IncomingHandshakeError::StreamClosed),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
struct PeerListener {
|
|
|
|
inner: TcpListener,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl PeerListener {
|
|
|
|
pub fn new(listener: TcpListener) -> Self {
|
|
|
|
Self { inner: listener }
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn accept(&mut self) -> io::Result<IncomingConnection> {
|
|
|
|
let (stream, address) = self.inner.accept().await?;
|
|
|
|
Ok(IncomingConnection { stream, address })
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
mod tests {
|
|
|
|
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
|
|
|
|
|
|
|
|
use tokio::net::TcpListener;
|
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
|
|
|
|
use crate::core::FrameReader;
|
|
|
|
use crate::core::{FrameReader, FrameWriter};
|
|
|
|
use crate::peer::{Message, PeerConnectionType, PeerInit};
|
|
|
|
|
|
|
|
use super::{PeerConnection, PeerConnectionError};
|
|
|
|
use super::{
|
|
|
|
IncomingHandshakeError, PeerConnection, PeerConnectionError, PeerListener,
|
|
|
|
};
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn peer_connection_connect_error() {
|
|
|
|
@ -149,4 +218,104 @@ mod tests { |
|
|
|
}))
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn peer_accept_stream_closed() {
|
|
|
|
let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding");
|
|
|
|
let listener_address =
|
|
|
|
tcp_listener.local_addr().expect("getting local address");
|
|
|
|
|
|
|
|
let mut peer_listener = PeerListener::new(tcp_listener);
|
|
|
|
|
|
|
|
let connect_task = async {
|
|
|
|
// Open connection and immediately drop/close it.
|
|
|
|
TcpStream::connect(listener_address)
|
|
|
|
.await
|
|
|
|
.expect("connecting");
|
|
|
|
};
|
|
|
|
|
|
|
|
let ((), accept_result) =
|
|
|
|
tokio::join!(connect_task, peer_listener.accept());
|
|
|
|
|
|
|
|
let incoming_connection = accept_result.expect("accepting");
|
|
|
|
let err = incoming_connection
|
|
|
|
.handshake()
|
|
|
|
.await
|
|
|
|
.expect_err("performing handshake");
|
|
|
|
|
|
|
|
match err {
|
|
|
|
IncomingHandshakeError::StreamClosed => (),
|
|
|
|
_ => panic!("Wrong error: {:?}", err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn peer_accept_unexpected_message() {
|
|
|
|
let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding");
|
|
|
|
let listener_address =
|
|
|
|
tcp_listener.local_addr().expect("getting local address");
|
|
|
|
|
|
|
|
let mut peer_listener = PeerListener::new(tcp_listener);
|
|
|
|
|
|
|
|
let connect_task = async {
|
|
|
|
let stream = TcpStream::connect(listener_address)
|
|
|
|
.await
|
|
|
|
.expect("connecting");
|
|
|
|
|
|
|
|
let mut writer = FrameWriter::new(stream);
|
|
|
|
writer.write("garbage").await.expect("writing");
|
|
|
|
|
|
|
|
// Return writer so that the connection is kept alive until we notice the
|
|
|
|
// error on the receiving side.
|
|
|
|
writer
|
|
|
|
};
|
|
|
|
|
|
|
|
let (_writer, accept_result) =
|
|
|
|
tokio::join!(connect_task, peer_listener.accept());
|
|
|
|
|
|
|
|
let incoming_connection = accept_result.expect("accepting");
|
|
|
|
let err = incoming_connection
|
|
|
|
.handshake()
|
|
|
|
.await
|
|
|
|
.expect_err("performing handshake");
|
|
|
|
|
|
|
|
match err {
|
|
|
|
IncomingHandshakeError::ReadError(_) => (),
|
|
|
|
_ => panic!("Wrong error: {:?}", err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
async fn peer_accept_success() {
|
|
|
|
let tcp_listener = TcpListener::bind("localhost:0").await.expect("binding");
|
|
|
|
let listener_address =
|
|
|
|
tcp_listener.local_addr().expect("getting local address");
|
|
|
|
|
|
|
|
let mut peer_listener = PeerListener::new(tcp_listener);
|
|
|
|
|
|
|
|
let outgoing_connection = PeerConnection {
|
|
|
|
address: listener_address,
|
|
|
|
connection_type: PeerConnectionType::File,
|
|
|
|
our_user_name: "olabode".to_string(),
|
|
|
|
};
|
|
|
|
|
|
|
|
let accept_task = async {
|
|
|
|
let incoming_connection =
|
|
|
|
peer_listener.accept().await.expect("accepting");
|
|
|
|
|
|
|
|
incoming_connection
|
|
|
|
.handshake()
|
|
|
|
.await
|
|
|
|
.expect("performing handshake")
|
|
|
|
};
|
|
|
|
|
|
|
|
let (connect_result, incoming_peer) =
|
|
|
|
tokio::join!(outgoing_connection.connect(), accept_task);
|
|
|
|
|
|
|
|
let _worker = connect_result.expect("connecting");
|
|
|
|
|
|
|
|
assert_eq!(incoming_peer.user_name, "olabode");
|
|
|
|
assert_eq!(incoming_peer.connection_type, PeerConnectionType::File);
|
|
|
|
assert_eq!(incoming_peer.token, 0);
|
|
|
|
}
|
|
|
|
}
|