diff --git a/proto/src/core/frame.rs b/proto/src/core/frame.rs index 214ec92..850d946 100644 --- a/proto/src/core/frame.rs +++ b/proto/src/core/frame.rs @@ -9,7 +9,6 @@ use std::marker::PhantomData; use bytes::BytesMut; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::net::TcpStream; use super::prefix::Prefixer; use super::u32::{decode_u32, U32_BYTE_LEN}; @@ -206,59 +205,6 @@ where } } -#[derive(Debug)] -pub struct FrameStream { - stream: TcpStream, - - read_buffer: BytesMut, - - decoder: FrameDecoder, - encoder: FrameEncoder, -} - -impl FrameStream -where - ReadFrame: ValueDecode, - WriteFrame: ValueEncode + ?Sized, -{ - pub fn new(stream: TcpStream) -> Self { - FrameStream { - stream, - read_buffer: BytesMut::new(), - decoder: FrameDecoder::new(), - encoder: FrameEncoder::new(), - } - } - - /// Attempts to read the next frame from the underlying byte stream. - /// - /// Returns `Ok(Some(frame))` on success. - /// Returns `Ok(None)` if the stream has reached the end-of-file event. - /// - /// Returns an error if reading from the stream returned an error or if an - /// invalid frame was received. - pub async fn read(&mut self) -> io::Result> { - loop { - if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? { - return Ok(Some(frame)); - } - if self.stream.read_buf(&mut self.read_buffer).await? == 0 { - return Ok(None); - } - } - } - - pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> { - let mut bytes = BytesMut::new(); - self.encoder.encode_to(frame, &mut bytes)?; - self.stream.write_all(bytes.as_ref()).await - } - - pub async fn shutdown(&mut self) -> io::Result<()> { - self.stream.shutdown().await - } -} - #[cfg(test)] mod tests { use bytes::BytesMut; diff --git a/proto/src/server/client.rs b/proto/src/server/client.rs index c3586d6..617d177 100644 --- a/proto/src/server/client.rs +++ b/proto/src/server/client.rs @@ -5,9 +5,12 @@ use std::io; use futures::stream::{Stream, StreamExt}; use log::{debug, info}; use thiserror::Error; -use tokio::net; +use tokio::net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, +}; -use crate::core::frame::FrameStream; +use crate::core::frame::{FrameReader, FrameWriter}; use crate::server::{Credentials, LoginResponse, ServerRequest, ServerResponse, Version}; /// Specifies options for a new `Client`. @@ -18,7 +21,8 @@ pub struct ClientOptions { /// A client for the client-server protocol. pub struct Client { - frame_stream: FrameStream, + reader: FrameReader, + writer: FrameWriter, } /// An error that arose while logging in to a remote server. @@ -65,11 +69,13 @@ enum RunOnceResult { impl Client { pub async fn login( - tcp_stream: net::TcpStream, + tcp_stream: TcpStream, options: ClientOptions, ) -> Result { + let (read_half, write_half) = tcp_stream.into_split(); let mut client = Client { - frame_stream: FrameStream::new(tcp_stream), + reader: FrameReader::new(read_half), + writer: FrameWriter::new(write_half), }; client.handshake(options).await?; @@ -84,9 +90,9 @@ impl Client { debug!("Client: sending login request: {:?}", login_request); let request = login_request.into(); - self.frame_stream.write(&request).await?; + self.writer.write(&request).await?; - let response = self.frame_stream.read().await?; + let response = self.reader.read().await?; debug!("Client: received first response: {:?}", response); match response { @@ -117,14 +123,14 @@ impl Client { maybe_request = request_stream.next() => { if let Some(request) = maybe_request { debug!("Client: sending request: {:?}", request); - self.frame_stream.write(&request).await?; + self.writer.write(&request).await?; Ok(RunOnceResult::Continue) } else { // Sender has been dropped. Ok(RunOnceResult::Break) } }, - read_result = self.frame_stream.read() => { + read_result = self.reader.read() => { match read_result? { Some(response) => { debug!("Client: received response: {:?}", response); @@ -154,10 +160,10 @@ impl Client { } debug!("Client: shutting down outbound stream"); - self.frame_stream.shutdown().await?; + self.writer.shutdown().await?; // Drain the receiving end of the connection. - while let Some(response) = self.frame_stream.read().await? { + while let Some(response) = self.reader.read().await? { debug!("Client: received response: {:?}", response); yield response; } @@ -168,7 +174,7 @@ impl Client { #[cfg(test)] mod tests { use futures::stream::{empty, StreamExt}; - use tokio::net; + use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap}; @@ -203,7 +209,7 @@ mod tests { let (server, handle) = ServerBuilder::default().bind().await.unwrap(); let server_task = tokio::spawn(server.serve()); - let stream = net::TcpStream::connect(handle.address()).await.unwrap(); + let stream = TcpStream::connect(handle.address()).await.unwrap(); let client = Client::login(stream, client_options()).await.unwrap(); @@ -235,7 +241,7 @@ mod tests { .unwrap(); let server_task = tokio::spawn(server.serve()); - let stream = net::TcpStream::connect(handle.address()).await.unwrap(); + let stream = TcpStream::connect(handle.address()).await.unwrap(); let client = Client::login(stream, client_options()).await.unwrap(); @@ -266,7 +272,7 @@ mod tests { let (server, handle) = ServerBuilder::default().bind().await.unwrap(); let server_task = tokio::spawn(server.serve()); - let stream = net::TcpStream::connect(handle.address()).await.unwrap(); + let stream = TcpStream::connect(handle.address()).await.unwrap(); let client = Client::login(stream, client_options()).await.unwrap(); diff --git a/proto/src/server/testing.rs b/proto/src/server/testing.rs index 7cd7513..2dd967e 100644 --- a/proto/src/server/testing.rs +++ b/proto/src/server/testing.rs @@ -7,12 +7,15 @@ use std::sync::Arc; use log::{debug, info, warn}; use parking_lot::Mutex; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpListener, TcpStream, +}; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::sync::watch; -use crate::core::frame::FrameStream; +use crate::core::frame::{FrameReader, FrameWriter}; use crate::server::{ LoginResponse, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse, }; @@ -35,7 +38,8 @@ impl UserStatusMap { } struct Handler { - frame_stream: FrameStream, + reader: FrameReader, + writer: FrameWriter, peer_address: SocketAddr, user_status_map: Arc>, } @@ -50,11 +54,11 @@ impl Handler { async fn send_response(&mut self, response: &ServerResponse) -> io::Result<()> { debug!("Handler: sending response: {:?}", response); - self.frame_stream.write(response).await + self.writer.write(response).await } async fn handle_login(&mut self) -> io::Result<()> { - match self.frame_stream.read().await? { + match self.reader.read().await? { Some(ServerRequest::LoginRequest(request)) => { info!("Handler: Received login request: {:?}", request); } @@ -102,7 +106,7 @@ impl Handler { async fn run(mut self) -> io::Result<()> { self.handle_login().await?; - while let Some(request) = self.frame_stream.read().await? { + while let Some(request) = self.reader.read().await? { self.handle_request(request).await?; } @@ -251,10 +255,13 @@ impl Server { /// Spawns a handler for the given new stream, initiated by a remote peer. fn spawn_handler(&mut self, stream: TcpStream, peer_address: SocketAddr) { + let (read_half, write_half) = stream.into_split(); + let handler = SenderHandler { handler: GracefulHandler { handler: Handler { - frame_stream: FrameStream::new(stream), + reader: FrameReader::new(read_half), + writer: FrameWriter::new(write_half), peer_address, user_status_map: self.user_status_map.clone(), },