Browse Source

Apply same rustfmt.toml to proto/.

wip
Titouan Rigoudy 4 years ago
parent
commit
6f8d18f283
18 changed files with 4898 additions and 4666 deletions
  1. +213
    -212
      proto/src/core/channel.rs
  2. +347
    -337
      proto/src/core/frame.rs
  3. +88
    -88
      proto/src/core/prefix.rs
  4. +2
    -2
      proto/src/core/u32.rs
  5. +84
    -66
      proto/src/core/user.rs
  6. +714
    -675
      proto/src/core/value.rs
  7. +281
    -272
      proto/src/handler.rs
  8. +251
    -251
      proto/src/packet.rs
  9. +155
    -146
      proto/src/peer/message.rs
  10. +173
    -170
      proto/src/server/client.rs
  11. +106
    -105
      proto/src/server/credentials.rs
  12. +633
    -599
      proto/src/server/request.rs
  13. +1338
    -1237
      proto/src/server/response.rs
  14. +293
    -289
      proto/src/server/testing.rs
  15. +9
    -9
      proto/src/server/version.rs
  16. +165
    -164
      proto/src/stream.rs
  17. +46
    -44
      proto/tests/connect.rs
  18. +0
    -0
      rustfmt.toml

+ 213
- 212
proto/src/core/channel.rs View File

@ -6,8 +6,8 @@ use futures::stream::{Stream, StreamExt};
use log::debug; use log::debug;
use thiserror::Error; use thiserror::Error;
use tokio::net::{ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
}; };
use crate::core::frame::{FrameReader, FrameWriter}; use crate::core::frame::{FrameReader, FrameWriter};
@ -16,21 +16,21 @@ use crate::core::value::{ValueDecode, ValueEncode};
/// An error that arose while exchanging messages over a `Channel`. /// An error that arose while exchanging messages over a `Channel`.
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ChannelError { pub enum ChannelError {
#[error("unexpected end of file")]
UnexpectedEof,
#[error("unexpected end of file")]
UnexpectedEof,
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
} }
impl ChannelError { impl ChannelError {
#[cfg(test)]
pub fn is_unexpected_eof(&self) -> bool {
match self {
ChannelError::UnexpectedEof => true,
_ => false,
}
#[cfg(test)]
pub fn is_unexpected_eof(&self) -> bool {
match self {
ChannelError::UnexpectedEof => true,
_ => false,
} }
}
} }
/// A wrapper around a frame reader. Logically a part of `Channel`. /// A wrapper around a frame reader. Logically a part of `Channel`.
@ -40,270 +40,271 @@ impl ChannelError {
/// at the same time in `Channel` without resorting to static methods. /// at the same time in `Channel` without resorting to static methods.
#[derive(Debug)] #[derive(Debug)]
struct ChannelReader<ReadFrame> { struct ChannelReader<ReadFrame> {
inner: FrameReader<ReadFrame, OwnedReadHalf>,
inner: FrameReader<ReadFrame, OwnedReadHalf>,
} }
impl<ReadFrame> ChannelReader<ReadFrame> impl<ReadFrame> ChannelReader<ReadFrame>
where where
ReadFrame: ValueDecode + Debug,
ReadFrame: ValueDecode + Debug,
{ {
async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
self.inner.read().await.map(|frame| {
debug!("Channel: received frame: {:?}", frame);
frame
})
}
async fn read_strict(&mut self) -> Result<ReadFrame, ChannelError> {
match self.read().await? {
None => Err(ChannelError::UnexpectedEof),
Some(frame) => Ok(frame),
}
async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
self.inner.read().await.map(|frame| {
debug!("Channel: received frame: {:?}", frame);
frame
})
}
async fn read_strict(&mut self) -> Result<ReadFrame, ChannelError> {
match self.read().await? {
None => Err(ChannelError::UnexpectedEof),
Some(frame) => Ok(frame),
} }
}
} }
/// An asynchronous bidirectional message channel over TCP. /// An asynchronous bidirectional message channel over TCP.
#[derive(Debug)] #[derive(Debug)]
pub struct Channel<ReadFrame, WriteFrame> { pub struct Channel<ReadFrame, WriteFrame> {
reader: ChannelReader<ReadFrame>,
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
reader: ChannelReader<ReadFrame>,
writer: FrameWriter<WriteFrame, OwnedWriteHalf>,
} }
impl<ReadFrame, WriteFrame> Channel<ReadFrame, WriteFrame> impl<ReadFrame, WriteFrame> Channel<ReadFrame, WriteFrame>
where where
ReadFrame: ValueDecode + Debug,
WriteFrame: ValueEncode + Debug,
ReadFrame: ValueDecode + Debug,
WriteFrame: ValueEncode + Debug,
{ {
/// Wraps the given `stream` to yield a message channel.
pub fn new(stream: TcpStream) -> Self {
let (read_half, write_half) = stream.into_split();
Self {
reader: ChannelReader {
inner: FrameReader::new(read_half),
},
writer: FrameWriter::new(write_half),
}
}
// This future sends all the requests from `request_stream` through `writer`
// until the stream is finished, then resolves.
async fn send<S: Stream<Item = WriteFrame>>(
writer: &mut FrameWriter<WriteFrame, OwnedWriteHalf>,
send_stream: S,
) -> io::Result<()> {
tokio::pin!(send_stream);
while let Some(frame) = send_stream.next().await {
debug!("Channel: sending frame: {:?}", frame);
writer.write(&frame).await?;
}
Ok(())
/// Wraps the given `stream` to yield a message channel.
pub fn new(stream: TcpStream) -> Self {
let (read_half, write_half) = stream.into_split();
Self {
reader: ChannelReader {
inner: FrameReader::new(read_half),
},
writer: FrameWriter::new(write_half),
} }
// It would be easier to inline this `select!` call inside `run()`, but that
// fails due to some weird, undiagnosed error due to the interaction of
// `async_stream::try_stream!`, `select!` and the `?` operator.
async fn run_once<S: Future<Output = io::Result<()>>>(
send_task: S,
reader: &mut ChannelReader<ReadFrame>,
) -> Result<Option<ReadFrame>, ChannelError> {
tokio::select! {
send_result = send_task => {
send_result?;
Ok(None)
},
read_result = reader.read_strict() => read_result.map(Some),
}
}
/// Attempts to read a single frame from the underlying stream.
pub async fn read(&mut self) -> Result<ReadFrame, ChannelError> {
self.reader.read_strict().await
}
// This future sends all the requests from `request_stream` through `writer`
// until the stream is finished, then resolves.
async fn send<S: Stream<Item = WriteFrame>>(
writer: &mut FrameWriter<WriteFrame, OwnedWriteHalf>,
send_stream: S,
) -> io::Result<()> {
tokio::pin!(send_stream);
while let Some(frame) = send_stream.next().await {
debug!("Channel: sending frame: {:?}", frame);
writer.write(&frame).await?;
} }
/// Attempts to write a single frame to the underlying stream.
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
self.writer.write(frame).await
Ok(())
}
// It would be easier to inline this `select!` call inside `run()`, but that
// fails due to some weird, undiagnosed error due to the interaction of
// `async_stream::try_stream!`, `select!` and the `?` operator.
async fn run_once<S: Future<Output = io::Result<()>>>(
send_task: S,
reader: &mut ChannelReader<ReadFrame>,
) -> Result<Option<ReadFrame>, ChannelError> {
tokio::select! {
send_result = send_task => {
send_result?;
Ok(None)
},
read_result = reader.read_strict() => read_result.map(Some),
} }
/// Sends the given stream of frames while receiving frames in return.
/// Once `send_stream` is exhausted, shuts down the underlying TCP stream,
/// drains incoming frames, then terminates.
pub fn run<S: Stream<Item = WriteFrame>>(
mut self,
send_stream: S,
) -> impl Stream<Item = Result<ReadFrame, ChannelError>> {
async_stream::try_stream! {
// Drive the main loop: send requests and receive responses.
//
// We make a big future out of the operation of waiting for requests
// to send and from `request_stream` and sending them out through
// `self.writer`, that we can then poll repeatedly and concurrently
// with polling for responses. This allows us to concurrently write
// and read from the underlying `TcpStream` in full duplex mode.
{
let send_task = Self::send(&mut self.writer, send_stream);
tokio::pin!(send_task);
while let Some(frame) =
Self::run_once(&mut send_task, &mut self.reader).await? {
yield frame;
}
}
/// Attempts to read a single frame from the underlying stream.
pub async fn read(&mut self) -> Result<ReadFrame, ChannelError> {
self.reader.read_strict().await
}
/// Attempts to write a single frame to the underlying stream.
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
self.writer.write(frame).await
}
/// Sends the given stream of frames while receiving frames in return.
/// Once `send_stream` is exhausted, shuts down the underlying TCP stream,
/// drains incoming frames, then terminates.
pub fn run<S: Stream<Item = WriteFrame>>(
mut self,
send_stream: S,
) -> impl Stream<Item = Result<ReadFrame, ChannelError>> {
async_stream::try_stream! {
// Drive the main loop: send requests and receive responses.
//
// We make a big future out of the operation of waiting for requests
// to send and from `request_stream` and sending them out through
// `self.writer`, that we can then poll repeatedly and concurrently
// with polling for responses. This allows us to concurrently write
// and read from the underlying `TcpStream` in full duplex mode.
{
let send_task = Self::send(&mut self.writer, send_stream);
tokio::pin!(send_task);
while let Some(frame) =
Self::run_once(&mut send_task, &mut self.reader).await? {
yield frame;
} }
}
debug!("Channel: shutting down writer");
self.writer.shutdown().await?;
debug!("Channel: shutting down writer");
self.writer.shutdown().await?;
// Drain the receiving end of the connection.
while let Some(frame) = self.reader.read().await? {
yield frame;
}
}
// Drain the receiving end of the connection.
while let Some(frame) = self.reader.read().await? {
yield frame;
}
} }
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::stream::{self, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use futures::stream::{self, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use super::Channel;
use super::Channel;
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn read_write() {
init();
#[tokio::test]
async fn read_write() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
assert_eq!(channel.read().await.unwrap(), 1);
channel.write(&2).await.unwrap();
});
assert_eq!(channel.read().await.unwrap(), 1);
channel.write(&2).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
let stream = TcpStream::connect(address).await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
channel.write(&1).await.unwrap();
assert_eq!(channel.read().await.unwrap(), 2);
channel.write(&1).await.unwrap();
assert_eq!(channel.read().await.unwrap(), 2);
listener_task.await.unwrap();
}
listener_task.await.unwrap();
}
#[tokio::test]
async fn read_eof() {
init();
#[tokio::test]
async fn read_eof() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
// Accept the stream and immediately drop/close it.
listener.accept().await.unwrap();
});
let listener_task = tokio::spawn(async move {
// Accept the stream and immediately drop/close it.
listener.accept().await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
let stream = TcpStream::connect(address).await.unwrap();
let mut channel = Channel::<u32, u32>::new(stream);
assert!(channel.read().await.unwrap_err().is_unexpected_eof());
assert!(channel.read().await.unwrap_err().is_unexpected_eof());
listener_task.await.unwrap();
}
listener_task.await.unwrap();
}
#[tokio::test]
async fn open_close() {
init();
#[tokio::test]
async fn open_close() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
// Wait forever, receive no messages.
let inbound = channel.run(stream::pending());
tokio::pin!(inbound);
// Wait forever, receive no messages.
let inbound = channel.run(stream::pending());
tokio::pin!(inbound);
// The server observes an unexpected EOF error when the client
// decides to close the channel.
let error = inbound.next().await.unwrap().unwrap_err();
assert!(error.is_unexpected_eof());
});
// The server observes an unexpected EOF error when the client
// decides to close the channel.
let error = inbound.next().await.unwrap().unwrap_err();
assert!(error.is_unexpected_eof());
});
let stream = TcpStream::connect(address).await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let stream = TcpStream::connect(address).await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
// Stop immediately, receive no messages.
let inbound = channel.run(stream::empty());
tokio::pin!(inbound);
// Stop immediately, receive no messages.
let inbound = channel.run(stream::empty());
tokio::pin!(inbound);
// The channel is closed cleanly from the client's point of view.
assert!(inbound.next().await.is_none());
// The channel is closed cleanly from the client's point of view.
assert!(inbound.next().await.is_none());
listener_task.await.unwrap();
}
listener_task.await.unwrap();
}
#[tokio::test]
async fn simple_exchange() {
init();
#[tokio::test]
async fn simple_exchange() {
init();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let listener_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let (tx, rx) = oneshot::channel::<u32>();
let (tx, rx) = oneshot::channel::<u32>();
// Send one message, then wait forever.
let outbound = stream::once(async move { rx.await.unwrap() }).chain(stream::pending());
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// Send one message, then wait forever.
let outbound =
stream::once(async move { rx.await.unwrap() }).chain(stream::pending());
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// The server receives the client's message.
assert_eq!(inbound.next().await.unwrap().unwrap(), 1);
// The server receives the client's message.
assert_eq!(inbound.next().await.unwrap().unwrap(), 1);
// Server responds.
tx.send(1001).unwrap();
// Server responds.
tx.send(1001).unwrap();
// The server observes an unexpected EOF error when the client
// decides to close the channel.
let error = inbound.next().await.unwrap().unwrap_err();
assert!(error.is_unexpected_eof());
});
// The server observes an unexpected EOF error when the client
// decides to close the channel.
let error = inbound.next().await.unwrap().unwrap_err();
assert!(error.is_unexpected_eof());
});
let stream = TcpStream::connect(address).await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let stream = TcpStream::connect(address).await.unwrap();
let channel = Channel::<u32, u32>::new(stream);
let (tx, rx) = oneshot::channel::<()>();
let (tx, rx) = oneshot::channel::<()>();
// Send one message then wait for a reply. If we did not wait, we might
// shut down the channel before the server has a chance to respond.
let outbound = async_stream::stream! {
yield 1;
rx.await.unwrap();
};
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// Send one message then wait for a reply. If we did not wait, we might
// shut down the channel before the server has a chance to respond.
let outbound = async_stream::stream! {
yield 1;
rx.await.unwrap();
};
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// The client receives the server's message.
assert_eq!(inbound.next().await.unwrap().unwrap(), 1001);
// The client receives the server's message.
assert_eq!(inbound.next().await.unwrap().unwrap(), 1001);
// Signal to the client that we should shut down.
tx.send(()).unwrap();
// Signal to the client that we should shut down.
tx.send(()).unwrap();
assert!(inbound.next().await.is_none());
assert!(inbound.next().await.is_none());
listener_task.await.unwrap();
}
listener_task.await.unwrap();
}
} }

+ 347
- 337
proto/src/core/frame.rs View File

@ -13,401 +13,411 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use super::prefix::Prefixer; use super::prefix::Prefixer;
use super::u32::{decode_u32, U32_BYTE_LEN}; use super::u32::{decode_u32, U32_BYTE_LEN};
use super::value::{ use super::value::{
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
}; };
#[derive(Debug, Error, PartialEq)] #[derive(Debug, Error, PartialEq)]
pub enum FrameEncodeError { pub enum FrameEncodeError {
#[error("encoded value length {length} is too large")]
ValueTooLarge {
/// The length of the encoded value.
length: usize,
},
#[error("failed to encode value: {0}")]
ValueEncodeError(#[from] ValueEncodeError),
#[error("encoded value length {length} is too large")]
ValueTooLarge {
/// The length of the encoded value.
length: usize,
},
#[error("failed to encode value: {0}")]
ValueEncodeError(#[from] ValueEncodeError),
} }
impl From<FrameEncodeError> for io::Error { impl From<FrameEncodeError> for io::Error {
fn from(error: FrameEncodeError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, format!("{}", error))
}
fn from(error: FrameEncodeError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, format!("{}", error))
}
} }
/// Encodes entire protocol frames containing values of type `T`. /// Encodes entire protocol frames containing values of type `T`.
#[derive(Debug)] #[derive(Debug)]
pub struct FrameEncoder<T: ?Sized> { pub struct FrameEncoder<T: ?Sized> {
phantom: PhantomData<T>,
phantom: PhantomData<T>,
} }
impl<T: ValueEncode + ?Sized> FrameEncoder<T> { impl<T: ValueEncode + ?Sized> FrameEncoder<T> {
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
pub fn new() -> Self {
Self {
phantom: PhantomData,
} }
}
pub fn encode_to(&mut self, value: &T, buffer: &mut BytesMut) -> Result<(), FrameEncodeError> {
let mut prefixer = Prefixer::new(buffer);
ValueEncoder::new(prefixer.suffix_mut()).encode(value)?;
pub fn encode_to(
&mut self,
value: &T,
buffer: &mut BytesMut,
) -> Result<(), FrameEncodeError> {
let mut prefixer = Prefixer::new(buffer);
if let Err(prefixer) = prefixer.finalize() {
return Err(FrameEncodeError::ValueTooLarge {
length: prefixer.suffix().len(),
});
}
ValueEncoder::new(prefixer.suffix_mut()).encode(value)?;
Ok(())
if let Err(prefixer) = prefixer.finalize() {
return Err(FrameEncodeError::ValueTooLarge {
length: prefixer.suffix().len(),
});
} }
Ok(())
}
} }
/// Decodes entire protocol frames containing values of type `T`. /// Decodes entire protocol frames containing values of type `T`.
#[derive(Debug)] #[derive(Debug)]
pub struct FrameDecoder<T> { pub struct FrameDecoder<T> {
// Only here to enable parameterizing `Decoder` by `T`.
phantom: PhantomData<T>,
// Only here to enable parameterizing `Decoder` by `T`.
phantom: PhantomData<T>,
} }
impl<T: ValueDecode> FrameDecoder<T> { impl<T: ValueDecode> FrameDecoder<T> {
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
}
/// Attempts to decode an entire frame from the given buffer.
///
/// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes
/// have been split off from the left of `bytes`.
///
/// Returns `Ok(None)` if not enough bytes are available to decode an entire
/// frame yet, in which case `bytes` is untouched.
///
/// Returns an error if the length prefix or the framed value are malformed,
/// in which case `bytes` is untouched.
pub fn decode_from(
&mut self,
bytes: &mut BytesMut,
) -> Result<Option<T>, ValueDecodeError> {
if bytes.len() < U32_BYTE_LEN {
return Ok(None); // Not enough bytes yet.
} }
/// Attempts to decode an entire frame from the given buffer.
///
/// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes
/// have been split off from the left of `bytes`.
///
/// Returns `Ok(None)` if not enough bytes are available to decode an entire
/// frame yet, in which case `bytes` is untouched.
///
/// Returns an error if the length prefix or the framed value are malformed,
/// in which case `bytes` is untouched.
pub fn decode_from(&mut self, bytes: &mut BytesMut) -> Result<Option<T>, ValueDecodeError> {
if bytes.len() < U32_BYTE_LEN {
return Ok(None); // Not enough bytes yet.
}
// Split the prefix off. After this:
//
// | bytes (len 4) | suffix |
//
// NOTE: This method would be simpler if we could use split_to() instead
// here such that `bytes` contained the suffix. At the end, we would not
// have to replace `bytes` with `suffix`. However, that would require
// calling `prefix.unsplit(*bytes)`, and that does not work since
// `bytes` is only borrowed, and unsplit() takes its argument by value.
let mut suffix = bytes.split_off(U32_BYTE_LEN);
// unwrap() cannot panic because `bytes` is of the exact right length.
let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap();
let length = decode_u32(array) as usize;
if suffix.len() < length {
// Re-assemble `bytes` as it first was.
bytes.unsplit(suffix);
return Ok(None); // Not enough bytes yet.
}
// Split off the right amount of bytes from the buffer. After this:
//
// | bytes (len 4) | contents | suffix |
//
let mut contents = suffix.split_to(length);
// Attempt to decode the value.
let item = match ValueDecoder::new(&contents).decode() {
Ok(item) => item,
Err(error) => {
// Re-assemble `bytes` as it first was.
contents.unsplit(suffix);
bytes.unsplit(contents);
return Err(error);
}
};
// Remove the decoded bytes from the left of `bytes`.
*bytes = suffix;
Ok(Some(item))
// Split the prefix off. After this:
//
// | bytes (len 4) | suffix |
//
// NOTE: This method would be simpler if we could use split_to() instead
// here such that `bytes` contained the suffix. At the end, we would not
// have to replace `bytes` with `suffix`. However, that would require
// calling `prefix.unsplit(*bytes)`, and that does not work since
// `bytes` is only borrowed, and unsplit() takes its argument by value.
let mut suffix = bytes.split_off(U32_BYTE_LEN);
// unwrap() cannot panic because `bytes` is of the exact right length.
let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap();
let length = decode_u32(array) as usize;
if suffix.len() < length {
// Re-assemble `bytes` as it first was.
bytes.unsplit(suffix);
return Ok(None); // Not enough bytes yet.
} }
// Split off the right amount of bytes from the buffer. After this:
//
// | bytes (len 4) | contents | suffix |
//
let mut contents = suffix.split_to(length);
// Attempt to decode the value.
let item = match ValueDecoder::new(&contents).decode() {
Ok(item) => item,
Err(error) => {
// Re-assemble `bytes` as it first was.
contents.unsplit(suffix);
bytes.unsplit(contents);
return Err(error);
}
};
// Remove the decoded bytes from the left of `bytes`.
*bytes = suffix;
Ok(Some(item))
}
} }
/// An asynchronous sink for frames wrapping around a byte writer. /// An asynchronous sink for frames wrapping around a byte writer.
#[derive(Debug)] #[derive(Debug)]
pub struct FrameWriter<Frame: ?Sized, Writer> { pub struct FrameWriter<Frame: ?Sized, Writer> {
encoder: FrameEncoder<Frame>,
writer: Writer,
encoder: FrameEncoder<Frame>,
writer: Writer,
} }
impl<Frame, Writer> FrameWriter<Frame, Writer> impl<Frame, Writer> FrameWriter<Frame, Writer>
where where
Frame: ValueEncode + ?Sized,
Writer: AsyncWrite + Unpin,
Frame: ValueEncode + ?Sized,
Writer: AsyncWrite + Unpin,
{ {
pub fn new(writer: Writer) -> Self {
FrameWriter {
encoder: FrameEncoder::new(),
writer,
}
pub fn new(writer: Writer) -> Self {
FrameWriter {
encoder: FrameEncoder::new(),
writer,
} }
}
pub async fn write(&mut self, frame: &Frame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.writer.write_all(bytes.as_ref()).await
}
pub async fn write(&mut self, frame: &Frame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.writer.write_all(bytes.as_ref()).await
}
pub async fn shutdown(&mut self) -> io::Result<()> {
self.writer.shutdown().await
}
pub async fn shutdown(&mut self) -> io::Result<()> {
self.writer.shutdown().await
}
} }
/// An asynchronous stream of frames wrapping around a byte reader. /// An asynchronous stream of frames wrapping around a byte reader.
#[derive(Debug)] #[derive(Debug)]
pub struct FrameReader<Frame, Reader> { pub struct FrameReader<Frame, Reader> {
decoder: FrameDecoder<Frame>,
reader: Reader,
read_buffer: BytesMut,
decoder: FrameDecoder<Frame>,
reader: Reader,
read_buffer: BytesMut,
} }
impl<Frame, Reader> FrameReader<Frame, Reader> impl<Frame, Reader> FrameReader<Frame, Reader>
where where
Frame: ValueDecode,
Reader: AsyncRead + Unpin,
Frame: ValueDecode,
Reader: AsyncRead + Unpin,
{ {
pub fn new(reader: Reader) -> Self {
FrameReader {
decoder: FrameDecoder::new(),
reader,
read_buffer: BytesMut::new(),
}
pub fn new(reader: Reader) -> Self {
FrameReader {
decoder: FrameDecoder::new(),
reader,
read_buffer: BytesMut::new(),
} }
/// Attempts to read the next frame from the underlying byte stream.
///
/// Returns `Ok(Some(frame))` on success.
/// Returns `Ok(None)` if the stream has reached the end-of-file event.
///
/// Returns an error if reading from the stream returned an error or if an
/// invalid frame was received.
pub async fn read(&mut self) -> io::Result<Option<Frame>> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(Some(frame));
}
if self.reader.read_buf(&mut self.read_buffer).await? == 0 {
return Ok(None);
}
}
}
/// Attempts to read the next frame from the underlying byte stream.
///
/// Returns `Ok(Some(frame))` on success.
/// Returns `Ok(None)` if the stream has reached the end-of-file event.
///
/// Returns an error if reading from the stream returned an error or if an
/// invalid frame was received.
pub async fn read(&mut self) -> io::Result<Option<Frame>> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(Some(frame));
}
if self.reader.read_buf(&mut self.read_buffer).await? == 0 {
return Ok(None);
}
} }
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder, FrameReader, FrameWriter};
// Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
#[test]
fn encode_u32() {
let mut bytes = BytesMut::new();
FrameEncoder::new()
.encode_to(&U32_1337, &mut bytes)
.unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
]
);
}
#[test]
fn encode_appends() {
let mut bytes = BytesMut::new();
let mut encoder = FrameEncoder::new();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 0, 0, 0, // Repeated.
1, 3, 3, 7,
]
);
}
#[test]
fn encode_vec() {
let v: Vec<u32> = vec![1, 3, 3, 7];
let mut bytes = BytesMut::new();
FrameEncoder::new().encode_to(&v, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
]
);
}
#[test]
fn decode_not_enough_data_for_prefix() {
let initial_bytes = vec![
4, 0, 0, // Incomplete 32-bit length prefix.
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_not_enough_data_for_contents() {
let initial_bytes = vec![
4, 0, 0, 0, // Length 4.
1, 2, 3, // But there are only 3 bytes!
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_u32() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, Some(U32_1337));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn decode_vec() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
let expected_value: Vec<u32> = vec![1, 3, 3, 7];
assert_eq!(value, Some(expected_value));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn roundtrip() {
let value: Vec<String> = vec![
"apples".to_string(), //
"bananas".to_string(), //
"oranges".to_string(), //
"and cheese!".to_string(), //
];
let mut buffer = BytesMut::new();
FrameEncoder::new().encode_to(&value, &mut buffer).unwrap();
let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap();
assert_eq!(decoded, Some(value));
assert_eq!(buffer, vec![]);
}
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (mut stream, _peer_address) = listener.accept().await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write("pong").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write("pong").await.unwrap();
});
let mut stream = TcpStream::connect(address).await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (mut stream, _peer_address) = listener.accept().await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write(&vec![0u32; 10 * 4096]).await.unwrap();
});
let mut stream = TcpStream::connect(address).await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some(vec![0u32; 10 * 4096]));
server_task.await.unwrap();
}
use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder, FrameReader, FrameWriter};
// Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
#[test]
fn encode_u32() {
let mut bytes = BytesMut::new();
FrameEncoder::new()
.encode_to(&U32_1337, &mut bytes)
.unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
]
);
}
#[test]
fn encode_appends() {
let mut bytes = BytesMut::new();
let mut encoder = FrameEncoder::new();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 0, 0, 0, // Repeated.
1, 3, 3, 7,
]
);
}
#[test]
fn encode_vec() {
let v: Vec<u32> = vec![1, 3, 3, 7];
let mut bytes = BytesMut::new();
FrameEncoder::new().encode_to(&v, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
]
);
}
#[test]
fn decode_not_enough_data_for_prefix() {
let initial_bytes = vec![
4, 0, 0, // Incomplete 32-bit length prefix.
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_not_enough_data_for_contents() {
let initial_bytes = vec![
4, 0, 0, 0, // Length 4.
1, 2, 3, // But there are only 3 bytes!
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_u32() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, Some(U32_1337));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn decode_vec() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
let expected_value: Vec<u32> = vec![1, 3, 3, 7];
assert_eq!(value, Some(expected_value));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn roundtrip() {
let value: Vec<String> = vec![
"apples".to_string(), //
"bananas".to_string(), //
"oranges".to_string(), //
"and cheese!".to_string(), //
];
let mut buffer = BytesMut::new();
FrameEncoder::new().encode_to(&value, &mut buffer).unwrap();
let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap();
assert_eq!(decoded, Some(value));
assert_eq!(buffer, vec![]);
}
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (mut stream, _peer_address) = listener.accept().await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write("pong").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write("pong").await.unwrap();
});
let mut stream = TcpStream::connect(address).await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some("pong".to_string()));
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (mut stream, _peer_address) = listener.accept().await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
assert_eq!(reader.read().await.unwrap(), Some("ping".to_string()));
writer.write(&vec![0u32; 10 * 4096]).await.unwrap();
});
let mut stream = TcpStream::connect(address).await.unwrap();
let (read_half, write_half) = stream.split();
let mut reader = FrameReader::new(read_half);
let mut writer = FrameWriter::new(write_half);
writer.write("ping").await.unwrap();
assert_eq!(reader.read().await.unwrap(), Some(vec![0u32; 10 * 4096]));
server_task.await.unwrap();
}
} }

+ 88
- 88
proto/src/core/prefix.rs View File

@ -11,111 +11,111 @@ use crate::core::u32::{encode_u32, U32_BYTE_LEN};
/// know the length ahead of encoding time. /// know the length ahead of encoding time.
#[derive(Debug)] #[derive(Debug)]
pub struct Prefixer<'a> { pub struct Prefixer<'a> {
/// The prefix buffer.
///
/// The length of the suffix buffer is written to the end of this buffer
/// when the prefixer is finalized.
///
/// Contains any bytes with which this prefixer was constructed.
prefix: &'a mut BytesMut,
/// The suffix buffer.
///
/// This is the buffer into which data is written before finalization.
suffix: BytesMut,
/// The prefix buffer.
///
/// The length of the suffix buffer is written to the end of this buffer
/// when the prefixer is finalized.
///
/// Contains any bytes with which this prefixer was constructed.
prefix: &'a mut BytesMut,
/// The suffix buffer.
///
/// This is the buffer into which data is written before finalization.
suffix: BytesMut,
} }
impl Prefixer<'_> { impl Prefixer<'_> {
/// Constructs a prefixer for easily appending a length prefixed value to
/// the given buffer.
pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> {
// Reserve some space fot the prefix, but don't write it yet.
buffer.reserve(U32_BYTE_LEN);
// Split off the suffix, into which bytes will be written.
let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN);
Prefixer {
prefix: buffer,
suffix: suffix,
}
}
/// Returns a reference to the buffer into which data is written.
pub fn suffix(&self) -> &BytesMut {
&self.suffix
}
/// Returns a mutable reference to a buffer into which data can be written.
pub fn suffix_mut(&mut self) -> &mut BytesMut {
&mut self.suffix
}
/// Returns a buffer containing the original data passed at construction
/// time, to which a length-prefixed value is appended. The value itself is
/// the data written into the buffer returned by `get_mut()`.
///
/// Returns `Ok(length)` if successful, in which case the length of the
/// suffix is `length`.
///
/// Returns `Err(self)` if the length of the suffix is too large to store as
/// a prefix.
pub fn finalize(self) -> Result<u32, Self> {
// Check that the suffix's length is not too large.
let length = self.suffix.len();
let length_u32 = match u32::try_from(length) {
Ok(value) => value,
Err(_) => return Err(self),
};
// Write the prefix.
self.prefix.extend_from_slice(&encode_u32(length_u32));
// Join the prefix and suffix back again. Because `self.prefix` is
// private, we are sure that this is O(1).
self.prefix.unsplit(self.suffix);
Ok(length_u32)
/// Constructs a prefixer for easily appending a length prefixed value to
/// the given buffer.
pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> {
// Reserve some space fot the prefix, but don't write it yet.
buffer.reserve(U32_BYTE_LEN);
// Split off the suffix, into which bytes will be written.
let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN);
Prefixer {
prefix: buffer,
suffix: suffix,
} }
}
/// Returns a reference to the buffer into which data is written.
pub fn suffix(&self) -> &BytesMut {
&self.suffix
}
/// Returns a mutable reference to a buffer into which data can be written.
pub fn suffix_mut(&mut self) -> &mut BytesMut {
&mut self.suffix
}
/// Returns a buffer containing the original data passed at construction
/// time, to which a length-prefixed value is appended. The value itself is
/// the data written into the buffer returned by `get_mut()`.
///
/// Returns `Ok(length)` if successful, in which case the length of the
/// suffix is `length`.
///
/// Returns `Err(self)` if the length of the suffix is too large to store as
/// a prefix.
pub fn finalize(self) -> Result<u32, Self> {
// Check that the suffix's length is not too large.
let length = self.suffix.len();
let length_u32 = match u32::try_from(length) {
Ok(value) => value,
Err(_) => return Err(self),
};
// Write the prefix.
self.prefix.extend_from_slice(&encode_u32(length_u32));
// Join the prefix and suffix back again. Because `self.prefix` is
// private, we are sure that this is O(1).
self.prefix.unsplit(self.suffix);
Ok(length_u32)
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Prefixer;
use super::Prefixer;
use std::convert::TryInto;
use std::convert::TryInto;
use bytes::{BufMut, BytesMut};
use bytes::{BufMut, BytesMut};
use crate::core::u32::{decode_u32, U32_BYTE_LEN};
use crate::core::u32::{decode_u32, U32_BYTE_LEN};
#[test]
fn finalize_empty() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
#[test]
fn finalize_empty() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
Prefixer::new(&mut buffer).finalize().unwrap();
Prefixer::new(&mut buffer).finalize().unwrap();
assert_eq!(buffer.len(), U32_BYTE_LEN + 1);
let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap();
assert_eq!(decode_u32(array), 0);
}
assert_eq!(buffer.len(), U32_BYTE_LEN + 1);
let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap();
assert_eq!(decode_u32(array), 0);
}
#[test]
fn finalize_ok() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
#[test]
fn finalize_ok() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
let mut prefixer = Prefixer::new(&mut buffer);
let mut prefixer = Prefixer::new(&mut buffer);
prefixer.suffix_mut().extend_from_slice(&[0; 42]);
prefixer.suffix_mut().extend_from_slice(&[0; 42]);
prefixer.finalize().unwrap();
prefixer.finalize().unwrap();
// 1 junk prefix byte, length prefix, 42 bytes of value.
assert_eq!(buffer.len(), U32_BYTE_LEN + 43);
let prefix = &buffer[1..U32_BYTE_LEN + 1];
let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap();
assert_eq!(decode_u32(array), 42);
}
// 1 junk prefix byte, length prefix, 42 bytes of value.
assert_eq!(buffer.len(), U32_BYTE_LEN + 43);
let prefix = &buffer[1..U32_BYTE_LEN + 1];
let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap();
assert_eq!(decode_u32(array), 42);
}
} }

+ 2
- 2
proto/src/core/u32.rs View File

@ -8,10 +8,10 @@ pub const U32_BYTE_LEN: usize = 4;
/// Returns the byte representatio of the given integer value. /// Returns the byte representatio of the given integer value.
pub fn encode_u32(value: u32) -> [u8; U32_BYTE_LEN] { pub fn encode_u32(value: u32) -> [u8; U32_BYTE_LEN] {
value.to_le_bytes()
value.to_le_bytes()
} }
/// Returns the integer value corresponding to the given bytes. /// Returns the integer value corresponding to the given bytes.
pub fn decode_u32(bytes: [u8; U32_BYTE_LEN]) -> u32 { pub fn decode_u32(bytes: [u8; U32_BYTE_LEN]) -> u32 {
u32::from_le_bytes(bytes)
u32::from_le_bytes(bytes)
} }

+ 84
- 66
proto/src/core/user.rs View File

@ -1,96 +1,114 @@
use std::io; use std::io;
use crate::core::value::{ use crate::core::value::{
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
};
use crate::{
MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket,
}; };
use crate::{MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket};
const STATUS_OFFLINE: u32 = 1; const STATUS_OFFLINE: u32 = 1;
const STATUS_AWAY: u32 = 2; const STATUS_AWAY: u32 = 2;
const STATUS_ONLINE: u32 = 3; const STATUS_ONLINE: u32 = 3;
/// This enumeration is the list of possible user statuses. /// This enumeration is the list of possible user statuses.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable)]
#[derive(
Clone,
Copy,
Debug,
Eq,
Ord,
PartialEq,
PartialOrd,
RustcDecodable,
RustcEncodable,
)]
pub enum UserStatus { pub enum UserStatus {
/// The user if offline.
Offline,
/// The user is connected, but AFK.
Away,
/// The user is present.
Online,
/// The user if offline.
Offline,
/// The user is connected, but AFK.
Away,
/// The user is present.
Online,
} }
impl ReadFromPacket for UserStatus { impl ReadFromPacket for UserStatus {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n: u32 = packet.read_value()?;
match n {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(PacketReadError::InvalidUserStatusError(n)),
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n: u32 = packet.read_value()?;
match n {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(PacketReadError::InvalidUserStatusError(n)),
} }
}
} }
impl WriteToPacket for UserStatus { impl WriteToPacket for UserStatus {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
let n = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
packet.write_value(&n)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
let n = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
packet.write_value(&n)?;
Ok(())
}
} }
impl ValueEncode for UserStatus { impl ValueEncode for UserStatus {
fn encode_to(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
let value = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
encoder.encode_u32(value)
}
fn encode_to(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
let value = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
encoder.encode_u32(value)
}
} }
impl ValueDecode for UserStatus { impl ValueDecode for UserStatus {
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let value: u32 = decoder.decode()?;
match value {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(ValueDecodeError::InvalidData {
value_name: "user status".to_string(),
cause: format!("unknown value {}", value),
position: position,
}),
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let value: u32 = decoder.decode()?;
match value {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(ValueDecodeError::InvalidData {
value_name: "user status".to_string(),
cause: format!("unknown value {}", value),
position: position,
}),
} }
}
} }
/// This structure contains the last known information about a fellow user. /// This structure contains the last known information about a fellow user.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable)]
#[derive(
Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable,
)]
pub struct User { pub struct User {
/// The name of the user.
pub name: String,
/// The last known status of the user.
pub status: UserStatus,
/// The average upload speed of the user.
pub average_speed: usize,
/// ??? Nicotine calls it downloadnum.
pub num_downloads: usize,
/// ??? Unknown field.
pub unknown: usize,
/// The number of files this user shares.
pub num_files: usize,
/// The number of folders this user shares.
pub num_folders: usize,
/// The number of free download slots of this user.
pub num_free_slots: usize,
/// The user's country code.
pub country: String,
/// The name of the user.
pub name: String,
/// The last known status of the user.
pub status: UserStatus,
/// The average upload speed of the user.
pub average_speed: usize,
/// ??? Nicotine calls it downloadnum.
pub num_downloads: usize,
/// ??? Unknown field.
pub unknown: usize,
/// The number of files this user shares.
pub num_files: usize,
/// The number of folders this user shares.
pub num_folders: usize,
/// The number of free download slots of this user.
pub num_free_slots: usize,
/// The user's country code.
pub country: String,
} }

+ 714
- 675
proto/src/core/value.rs
File diff suppressed because it is too large
View File


+ 281
- 272
proto/src/handler.rs View File

@ -31,17 +31,17 @@ const LISTEN_TOKEN: usize = config::MAX_PEERS + 1;
#[derive(Debug)] #[derive(Debug)]
pub enum Request { pub enum Request {
PeerConnect(usize, net::Ipv4Addr, u16),
PeerMessage(usize, peer::Message),
ServerRequest(ServerRequest),
PeerConnect(usize, net::Ipv4Addr, u16),
PeerMessage(usize, peer::Message),
ServerRequest(ServerRequest),
} }
#[derive(Debug)] #[derive(Debug)]
pub enum Response { pub enum Response {
PeerConnectionClosed(usize),
PeerConnectionOpen(usize),
PeerMessage(usize, peer::Message),
ServerResponse(ServerResponse),
PeerConnectionClosed(usize),
PeerConnectionOpen(usize),
PeerMessage(usize, peer::Message),
ServerResponse(ServerResponse),
} }
/*========================* /*========================*
@ -51,16 +51,16 @@ pub enum Response {
pub struct ServerResponseSender(crossbeam_channel::Sender<Response>); pub struct ServerResponseSender(crossbeam_channel::Sender<Response>);
impl SendPacket for ServerResponseSender { impl SendPacket for ServerResponseSender {
type Value = ServerResponse;
type Error = crossbeam_channel::SendError<Response>;
type Value = ServerResponse;
type Error = crossbeam_channel::SendError<Response>;
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.0.send(Response::ServerResponse(value))
}
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.0.send(Response::ServerResponse(value))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
Ok(())
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
Ok(())
}
} }
/*======================* /*======================*
@ -68,21 +68,21 @@ impl SendPacket for ServerResponseSender {
*======================*/ *======================*/
pub struct PeerResponseSender { pub struct PeerResponseSender {
sender: crossbeam_channel::Sender<Response>,
peer_id: usize,
sender: crossbeam_channel::Sender<Response>,
peer_id: usize,
} }
impl SendPacket for PeerResponseSender { impl SendPacket for PeerResponseSender {
type Value = peer::Message;
type Error = crossbeam_channel::SendError<Response>;
type Value = peer::Message;
type Error = crossbeam_channel::SendError<Response>;
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.sender.send(Response::PeerMessage(self.peer_id, value))
}
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.sender.send(Response::PeerMessage(self.peer_id, value))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
self.sender.send(Response::PeerConnectionOpen(self.peer_id))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
self.sender.send(Response::PeerConnectionOpen(self.peer_id))
}
} }
/*=========* /*=========*
@ -92,293 +92,302 @@ impl SendPacket for PeerResponseSender {
/// This struct handles all the soulseek connections, to the server and to /// This struct handles all the soulseek connections, to the server and to
/// peers. /// peers.
struct Handler { struct Handler {
server_stream: Stream<ServerResponseSender>,
server_stream: Stream<ServerResponseSender>,
peer_streams: slab::Slab<Stream<PeerResponseSender>, usize>,
peer_streams: slab::Slab<Stream<PeerResponseSender>, usize>,
listener: mio::tcp::TcpListener,
listener: mio::tcp::TcpListener,
client_tx: crossbeam_channel::Sender<Response>,
client_tx: crossbeam_channel::Sender<Response>,
} }
fn listener_bind<U>(addr_spec: U) -> io::Result<mio::tcp::TcpListener> fn listener_bind<U>(addr_spec: U) -> io::Result<mio::tcp::TcpListener>
where where
U: ToSocketAddrs + fmt::Debug,
U: ToSocketAddrs + fmt::Debug,
{ {
for socket_addr in addr_spec.to_socket_addrs()? {
if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) {
return Ok(listener);
}
for socket_addr in addr_spec.to_socket_addrs()? {
if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) {
return Ok(listener);
} }
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot bind to {:?}", addr_spec),
))
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot bind to {:?}", addr_spec),
))
} }
impl Handler { impl Handler {
#[allow(deprecated)]
fn new(
client_tx: crossbeam_channel::Sender<Response>,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> io::Result<Self> {
let host = config::SERVER_HOST;
let port = config::SERVER_PORT;
let server_stream = Stream::new((host, port), ServerResponseSender(client_tx.clone()))?;
info!("Connected to server at {}:{}", host, port);
let listener = listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?;
info!(
"Listening for connections on {}:{}",
config::LISTEN_HOST,
config::LISTEN_PORT
);
event_loop.register(
server_stream.evented(),
#[allow(deprecated)]
fn new(
client_tx: crossbeam_channel::Sender<Response>,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> io::Result<Self> {
let host = config::SERVER_HOST;
let port = config::SERVER_PORT;
let server_stream =
Stream::new((host, port), ServerResponseSender(client_tx.clone()))?;
info!("Connected to server at {}:{}", host, port);
let listener = listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?;
info!(
"Listening for connections on {}:{}",
config::LISTEN_HOST,
config::LISTEN_PORT
);
event_loop.register(
server_stream.evented(),
mio::Token(SERVER_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
event_loop.register(
&listener,
mio::Token(LISTEN_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
Ok(Handler {
server_stream: server_stream,
peer_streams: slab::Slab::new(config::MAX_PEERS),
listener: listener,
client_tx: client_tx,
})
}
#[allow(deprecated)]
fn connect_to_peer(
&mut self,
peer_id: usize,
ip: net::Ipv4Addr,
port: u16,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> Result<(), String> {
let vacant_entry = match self.peer_streams.entry(peer_id) {
None => return Err("id out of range".to_string()),
Some(slab::Entry::Occupied(_occupied_entry)) => {
return Err("id already taken".to_string());
}
Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry,
};
info!("Opening peer connection {} to {}:{}", peer_id, ip, port);
let sender = PeerResponseSender {
sender: self.client_tx.clone(),
peer_id: peer_id,
};
let peer_stream = match Stream::new((ip, port), sender) {
Ok(peer_stream) => peer_stream,
Err(err) => return Err(format!("i/o error: {}", err)),
};
event_loop
.register(
peer_stream.evented(),
mio::Token(peer_id),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
vacant_entry.insert(peer_stream);
Ok(())
}
#[allow(deprecated)]
fn process_server_intent(
&mut self,
intent: Intent,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
error!("Server connection closed");
// TODO notify client and shut down
}
Intent::Continue(event_set) => {
event_loop
.reregister(
self.server_stream.evented(),
mio::Token(SERVER_TOKEN), mio::Token(SERVER_TOKEN),
mio::Ready::all(),
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(), mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
event_loop.register(
&listener,
mio::Token(LISTEN_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
Ok(Handler {
server_stream: server_stream,
peer_streams: slab::Slab::new(config::MAX_PEERS),
listener: listener,
client_tx: client_tx,
})
)
.unwrap();
}
} }
#[allow(deprecated)]
fn connect_to_peer(
&mut self,
peer_id: usize,
ip: net::Ipv4Addr,
port: u16,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> Result<(), String> {
let vacant_entry = match self.peer_streams.entry(peer_id) {
None => return Err("id out of range".to_string()),
Some(slab::Entry::Occupied(_occupied_entry)) => {
return Err("id already taken".to_string());
}
Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry,
};
info!("Opening peer connection {} to {}:{}", peer_id, ip, port);
let sender = PeerResponseSender {
sender: self.client_tx.clone(),
peer_id: peer_id,
};
let peer_stream = match Stream::new((ip, port), sender) {
Ok(peer_stream) => peer_stream,
Err(err) => return Err(format!("i/o error: {}", err)),
};
event_loop
.register(
peer_stream.evented(),
mio::Token(peer_id),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
}
#[allow(deprecated)]
fn process_peer_intent(
&mut self,
intent: Intent,
token: mio::Token,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
self.peer_streams.remove(token.0);
self
.client_tx
.send(Response::PeerConnectionClosed(token.0))
.unwrap();
}
Intent::Continue(event_set) => {
if let Some(peer_stream) = self.peer_streams.get_mut(token.0) {
event_loop
.reregister(
peer_stream.evented(),
token,
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
) )
.unwrap(); .unwrap();
vacant_entry.insert(peer_stream);
Ok(())
}
#[allow(deprecated)]
fn process_server_intent(
&mut self,
intent: Intent,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
error!("Server connection closed");
// TODO notify client and shut down
}
Intent::Continue(event_set) => {
event_loop
.reregister(
self.server_stream.evented(),
mio::Token(SERVER_TOKEN),
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
}
}
#[allow(deprecated)]
fn process_peer_intent(
&mut self,
intent: Intent,
token: mio::Token,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
self.peer_streams.remove(token.0);
self.client_tx
.send(Response::PeerConnectionClosed(token.0))
.unwrap();
}
Intent::Continue(event_set) => {
if let Some(peer_stream) = self.peer_streams.get_mut(token.0) {
event_loop
.reregister(
peer_stream.evented(),
token,
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
}
} }
}
} }
}
} }
#[allow(deprecated)] #[allow(deprecated)]
impl mio::deprecated::Handler for Handler { impl mio::deprecated::Handler for Handler {
type Timeout = ();
type Message = Request;
fn ready(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
token: mio::Token,
event_set: mio::Ready,
) {
match token {
mio::Token(LISTEN_TOKEN) => {
if event_set.is_readable() {
// A peer wants to connect to us.
match self.listener.accept() {
Ok((_sock, addr)) => {
// TODO add it to peer streams
info!("Peer connection accepted from {}", addr);
}
Err(err) => {
error!("Cannot accept peer connection: {}", err);
}
}
}
event_loop
.reregister(
&self.listener,
token,
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
type Timeout = ();
type Message = Request;
fn ready(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
token: mio::Token,
event_set: mio::Ready,
) {
match token {
mio::Token(LISTEN_TOKEN) => {
if event_set.is_readable() {
// A peer wants to connect to us.
match self.listener.accept() {
Ok((_sock, addr)) => {
// TODO add it to peer streams
info!("Peer connection accepted from {}", addr);
} }
mio::Token(SERVER_TOKEN) => {
let intent = self.server_stream.on_ready(event_set);
self.process_server_intent(intent, event_loop);
}
mio::Token(peer_id) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_ready(event_set),
None => unreachable!("Unknown peer {} is ready", peer_id),
};
self.process_peer_intent(intent, token, event_loop);
Err(err) => {
error!("Cannot accept peer connection: {}", err);
} }
}
} }
}
event_loop
.reregister(
&self.listener,
token,
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
fn notify(&mut self, event_loop: &mut mio::deprecated::EventLoop<Self>, request: Request) {
match request {
Request::PeerConnect(peer_id, ip, port) => {
if let Err(err) = self.connect_to_peer(peer_id, ip, port, event_loop) {
error!(
"Cannot open peer connection {} to {}:{}: {}",
peer_id, ip, port, err
);
self.client_tx
.send(Response::PeerConnectionClosed(peer_id))
.unwrap();
}
}
mio::Token(SERVER_TOKEN) => {
let intent = self.server_stream.on_ready(event_set);
self.process_server_intent(intent, event_loop);
}
Request::PeerMessage(peer_id, message) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_notify(&message),
None => {
error!(
"Cannot send peer message {:?}: unknown id {}",
message, peer_id
);
return;
}
};
self.process_peer_intent(intent, mio::Token(peer_id), event_loop);
}
mio::Token(peer_id) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_ready(event_set),
Request::ServerRequest(server_request) => {
let intent = self.server_stream.on_notify(&server_request);
self.process_server_intent(intent, event_loop);
}
None => unreachable!("Unknown peer {} is ready", peer_id),
};
self.process_peer_intent(intent, token, event_loop);
}
}
}
fn notify(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
request: Request,
) {
match request {
Request::PeerConnect(peer_id, ip, port) => {
if let Err(err) = self.connect_to_peer(peer_id, ip, port, event_loop) {
error!(
"Cannot open peer connection {} to {}:{}: {}",
peer_id, ip, port, err
);
self
.client_tx
.send(Response::PeerConnectionClosed(peer_id))
.unwrap();
} }
}
Request::PeerMessage(peer_id, message) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_notify(&message),
None => {
error!(
"Cannot send peer message {:?}: unknown id {}",
message, peer_id
);
return;
}
};
self.process_peer_intent(intent, mio::Token(peer_id), event_loop);
}
Request::ServerRequest(server_request) => {
let intent = self.server_stream.on_notify(&server_request);
self.process_server_intent(intent, event_loop);
}
} }
}
} }
#[allow(deprecated)] #[allow(deprecated)]
pub type Sender = mio::deprecated::Sender<Request>; pub type Sender = mio::deprecated::Sender<Request>;
pub struct Agent { pub struct Agent {
#[allow(deprecated)]
event_loop: mio::deprecated::EventLoop<Handler>,
handler: Handler,
#[allow(deprecated)]
event_loop: mio::deprecated::EventLoop<Handler>,
handler: Handler,
} }
impl Agent { impl Agent {
pub fn new(client_tx: crossbeam_channel::Sender<Response>) -> io::Result<Self> {
// Create the event loop.
#[allow(deprecated)]
let mut event_loop = mio::deprecated::EventLoop::new()?;
// Create the handler for the event loop and register the handler's
// sockets with the event loop.
let handler = Handler::new(client_tx, &mut event_loop)?;
Ok(Agent {
event_loop: event_loop,
handler: handler,
})
}
pub fn channel(&self) -> Sender {
#[allow(deprecated)]
self.event_loop.channel()
}
pub fn new(
client_tx: crossbeam_channel::Sender<Response>,
) -> io::Result<Self> {
// Create the event loop.
#[allow(deprecated)]
let mut event_loop = mio::deprecated::EventLoop::new()?;
// Create the handler for the event loop and register the handler's
// sockets with the event loop.
let handler = Handler::new(client_tx, &mut event_loop)?;
Ok(Agent {
event_loop: event_loop,
handler: handler,
})
}
pub fn channel(&self) -> Sender {
#[allow(deprecated)]
self.event_loop.channel()
}
pub fn run(&mut self) -> io::Result<()> {
#[allow(deprecated)]
self.event_loop.run(&mut self.handler)
}
pub fn run(&mut self) -> io::Result<()> {
#[allow(deprecated)]
self.event_loop.run(&mut self.handler)
}
} }

+ 251
- 251
proto/src/packet.rs View File

@ -19,46 +19,46 @@ use crate::core::constants::*;
#[derive(Debug)] #[derive(Debug)]
pub struct Packet { pub struct Packet {
/// The current read position in the byte buffer.
cursor: usize,
/// The underlying bytes.
bytes: Vec<u8>,
/// The current read position in the byte buffer.
cursor: usize,
/// The underlying bytes.
bytes: Vec<u8>,
} }
impl io::Read for Packet { impl io::Read for Packet {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let bytes_read = {
let mut slice = &self.bytes[self.cursor..];
slice.read(buf)?
};
self.cursor += bytes_read;
Ok(bytes_read)
}
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let bytes_read = {
let mut slice = &self.bytes[self.cursor..];
slice.read(buf)?
};
self.cursor += bytes_read;
Ok(bytes_read)
}
} }
impl Packet { impl Packet {
/// Returns a readable packet struct from the wire representation of a
/// packet.
/// Assumes that the given vector is a valid length-prefixed packet.
fn from_wire(bytes: Vec<u8>) -> Self {
Packet {
cursor: U32_SIZE,
bytes: bytes,
}
}
/// Provides the main way to read data out of a binary packet.
pub fn read_value<T>(&mut self) -> Result<T, PacketReadError>
where
T: ReadFromPacket,
{
T::read_from_packet(self)
}
/// Returns the number of unread bytes remaining in the packet.
pub fn bytes_remaining(&self) -> usize {
self.bytes.len() - self.cursor
/// Returns a readable packet struct from the wire representation of a
/// packet.
/// Assumes that the given vector is a valid length-prefixed packet.
fn from_wire(bytes: Vec<u8>) -> Self {
Packet {
cursor: U32_SIZE,
bytes: bytes,
} }
}
/// Provides the main way to read data out of a binary packet.
pub fn read_value<T>(&mut self) -> Result<T, PacketReadError>
where
T: ReadFromPacket,
{
T::read_from_packet(self)
}
/// Returns the number of unread bytes remaining in the packet.
pub fn bytes_remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
} }
/*===================* /*===================*
@ -67,45 +67,45 @@ impl Packet {
#[derive(Debug)] #[derive(Debug)]
pub struct MutPacket { pub struct MutPacket {
bytes: Vec<u8>,
bytes: Vec<u8>,
} }
impl MutPacket { impl MutPacket {
/// Returns an empty packet with the given packet code.
pub fn new() -> Self {
// Leave space for the eventual size of the packet.
MutPacket {
bytes: vec![0; U32_SIZE],
}
/// Returns an empty packet with the given packet code.
pub fn new() -> Self {
// Leave space for the eventual size of the packet.
MutPacket {
bytes: vec![0; U32_SIZE],
} }
/// Provides the main way to write data into a binary packet.
pub fn write_value<T>(&mut self, val: &T) -> io::Result<()>
where
T: WriteToPacket + ?Sized,
}
/// Provides the main way to write data into a binary packet.
pub fn write_value<T>(&mut self, val: &T) -> io::Result<()>
where
T: WriteToPacket + ?Sized,
{
val.write_to_packet(self)
}
/// Consumes the mutable packet and returns its wire representation.
pub fn into_bytes(mut self) -> Vec<u8> {
let length = (self.bytes.len() - U32_SIZE) as u32;
{ {
val.write_to_packet(self)
}
/// Consumes the mutable packet and returns its wire representation.
pub fn into_bytes(mut self) -> Vec<u8> {
let length = (self.bytes.len() - U32_SIZE) as u32;
{
let mut first_word = &mut self.bytes[..U32_SIZE];
first_word.write_u32::<LittleEndian>(length).unwrap();
}
self.bytes
let mut first_word = &mut self.bytes[..U32_SIZE];
first_word.write_u32::<LittleEndian>(length).unwrap();
} }
self.bytes
}
} }
impl io::Write for MutPacket { impl io::Write for MutPacket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.bytes.write(buf)
}
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.bytes.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.bytes.flush()
}
fn flush(&mut self) -> io::Result<()> {
self.bytes.flush()
}
} }
/*===================* /*===================*
@ -115,68 +115,68 @@ impl io::Write for MutPacket {
/// This enum contains an error that arose when reading data out of a Packet. /// This enum contains an error that arose when reading data out of a Packet.
#[derive(Debug)] #[derive(Debug)]
pub enum PacketReadError { pub enum PacketReadError {
/// Attempted to read a boolean, but the value was not 0 nor 1.
InvalidBoolError(u8),
/// Attempted to read an unsigned 16-bit integer, but the value was too
/// large.
InvalidU16Error(u32),
/// Attempted to read a string, but a character was invalid.
InvalidStringError(Vec<u8>),
/// Attempted to read a user::Status, but the value was not a valid
/// representation of an enum variant.
InvalidUserStatusError(u32),
/// Encountered an I/O error while reading.
IOError(io::Error),
/// Attempted to read a boolean, but the value was not 0 nor 1.
InvalidBoolError(u8),
/// Attempted to read an unsigned 16-bit integer, but the value was too
/// large.
InvalidU16Error(u32),
/// Attempted to read a string, but a character was invalid.
InvalidStringError(Vec<u8>),
/// Attempted to read a user::Status, but the value was not a valid
/// representation of an enum variant.
InvalidUserStatusError(u32),
/// Encountered an I/O error while reading.
IOError(io::Error),
} }
impl fmt::Display for PacketReadError { impl fmt::Display for PacketReadError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
PacketReadError::InvalidBoolError(n) => {
write!(fmt, "InvalidBoolError: {}", n)
}
PacketReadError::InvalidU16Error(n) => {
write!(fmt, "InvalidU16Error: {}", n)
}
PacketReadError::InvalidStringError(ref bytes) => {
write!(fmt, "InvalidStringError: {:?}", bytes)
}
PacketReadError::InvalidUserStatusError(n) => {
write!(fmt, "InvalidUserStatusError: {}", n)
}
PacketReadError::IOError(ref err) => {
write!(fmt, "IOError: {}", err)
}
}
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
PacketReadError::InvalidBoolError(n) => {
write!(fmt, "InvalidBoolError: {}", n)
}
PacketReadError::InvalidU16Error(n) => {
write!(fmt, "InvalidU16Error: {}", n)
}
PacketReadError::InvalidStringError(ref bytes) => {
write!(fmt, "InvalidStringError: {:?}", bytes)
}
PacketReadError::InvalidUserStatusError(n) => {
write!(fmt, "InvalidUserStatusError: {}", n)
}
PacketReadError::IOError(ref err) => {
write!(fmt, "IOError: {}", err)
}
} }
}
} }
impl error::Error for PacketReadError { impl error::Error for PacketReadError {
fn description(&self) -> &str {
match *self {
PacketReadError::InvalidBoolError(_) => "InvalidBoolError",
PacketReadError::InvalidU16Error(_) => "InvalidU16Error",
PacketReadError::InvalidStringError(_) => "InvalidStringError",
PacketReadError::InvalidUserStatusError(_) => "InvalidUserStatusError",
PacketReadError::IOError(_) => "IOError",
}
fn description(&self) -> &str {
match *self {
PacketReadError::InvalidBoolError(_) => "InvalidBoolError",
PacketReadError::InvalidU16Error(_) => "InvalidU16Error",
PacketReadError::InvalidStringError(_) => "InvalidStringError",
PacketReadError::InvalidUserStatusError(_) => "InvalidUserStatusError",
PacketReadError::IOError(_) => "IOError",
} }
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
PacketReadError::InvalidBoolError(_) => None,
PacketReadError::InvalidU16Error(_) => None,
PacketReadError::InvalidStringError(_) => None,
PacketReadError::InvalidUserStatusError(_) => None,
PacketReadError::IOError(ref err) => Some(err),
}
}
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
PacketReadError::InvalidBoolError(_) => None,
PacketReadError::InvalidU16Error(_) => None,
PacketReadError::InvalidStringError(_) => None,
PacketReadError::InvalidUserStatusError(_) => None,
PacketReadError::IOError(ref err) => Some(err),
} }
}
} }
impl From<io::Error> for PacketReadError { impl From<io::Error> for PacketReadError {
fn from(err: io::Error) -> Self {
PacketReadError::IOError(err)
}
fn from(err: io::Error) -> Self {
PacketReadError::IOError(err)
}
} }
/*==================* /*==================*
@ -186,81 +186,81 @@ impl From<io::Error> for PacketReadError {
/// This trait is implemented by types that can be deserialized from binary /// This trait is implemented by types that can be deserialized from binary
/// Packets. /// Packets.
pub trait ReadFromPacket: Sized { pub trait ReadFromPacket: Sized {
fn read_from_packet(_: &mut Packet) -> Result<Self, PacketReadError>;
fn read_from_packet(_: &mut Packet) -> Result<Self, PacketReadError>;
} }
/// 32-bit integers are serialized in 4 bytes, little-endian. /// 32-bit integers are serialized in 4 bytes, little-endian.
impl ReadFromPacket for u32 { impl ReadFromPacket for u32 {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(packet.read_u32::<LittleEndian>()?)
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(packet.read_u32::<LittleEndian>()?)
}
} }
/// For convenience, usize's are deserialized as u32's then casted. /// For convenience, usize's are deserialized as u32's then casted.
impl ReadFromPacket for usize { impl ReadFromPacket for usize {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(u32::read_from_packet(packet)? as usize)
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(u32::read_from_packet(packet)? as usize)
}
} }
/// Booleans are serialized as single bytes, containing either 0 or 1. /// Booleans are serialized as single bytes, containing either 0 or 1.
impl ReadFromPacket for bool { impl ReadFromPacket for bool {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
match packet.read_u8()? {
0 => Ok(false),
1 => Ok(true),
n => Err(PacketReadError::InvalidBoolError(n)),
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
match packet.read_u8()? {
0 => Ok(false),
1 => Ok(true),
n => Err(PacketReadError::InvalidBoolError(n)),
} }
}
} }
/// 16-bit integers are serialized as 32-bit integers. /// 16-bit integers are serialized as 32-bit integers.
impl ReadFromPacket for u16 { impl ReadFromPacket for u16 {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n = u32::read_from_packet(packet)?;
if n > MAX_PORT {
return Err(PacketReadError::InvalidU16Error(n));
}
Ok(n as u16)
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n = u32::read_from_packet(packet)?;
if n > MAX_PORT {
return Err(PacketReadError::InvalidU16Error(n));
} }
Ok(n as u16)
}
} }
/// IPv4 addresses are serialized directly as 32-bit integers. /// IPv4 addresses are serialized directly as 32-bit integers.
impl ReadFromPacket for net::Ipv4Addr { impl ReadFromPacket for net::Ipv4Addr {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let ip = u32::read_from_packet(packet)?;
Ok(net::Ipv4Addr::from(ip))
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let ip = u32::read_from_packet(packet)?;
Ok(net::Ipv4Addr::from(ip))
}
} }
/// Strings are serialized as length-prefixed arrays of ISO-8859-1 encoded /// Strings are serialized as length-prefixed arrays of ISO-8859-1 encoded
/// characters. /// characters.
impl ReadFromPacket for String { impl ReadFromPacket for String {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
let mut buffer = vec![0; len];
packet.read_exact(&mut buffer)?;
let mut buffer = vec![0; len];
packet.read_exact(&mut buffer)?;
match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) {
Ok(string) => Ok(string),
Err(_) => Err(PacketReadError::InvalidStringError(buffer)),
}
match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) {
Ok(string) => Ok(string),
Err(_) => Err(PacketReadError::InvalidStringError(buffer)),
} }
}
} }
/// Vectors are serialized as length-prefixed arrays of values. /// Vectors are serialized as length-prefixed arrays of values.
impl<T: ReadFromPacket> ReadFromPacket for Vec<T> { impl<T: ReadFromPacket> ReadFromPacket for Vec<T> {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
let mut vec = Vec::new();
for _ in 0..len {
vec.push(T::read_from_packet(packet)?);
}
Ok(vec)
let mut vec = Vec::new();
for _ in 0..len {
vec.push(T::read_from_packet(packet)?);
} }
Ok(vec)
}
} }
/*=================* /*=================*
@ -270,55 +270,55 @@ impl<T: ReadFromPacket> ReadFromPacket for Vec<T> {
/// This trait is implemented by types that can be serialized to a binary /// This trait is implemented by types that can be serialized to a binary
/// MutPacket. /// MutPacket.
pub trait WriteToPacket { pub trait WriteToPacket {
fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>;
fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>;
} }
/// 32-bit integers are serialized in 4 bytes, little-endian. /// 32-bit integers are serialized in 4 bytes, little-endian.
impl WriteToPacket for u32 { impl WriteToPacket for u32 {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u32::<LittleEndian>(*self)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u32::<LittleEndian>(*self)
}
} }
/// Booleans are serialized as single bytes, containing either 0 or 1. /// Booleans are serialized as single bytes, containing either 0 or 1.
impl WriteToPacket for bool { impl WriteToPacket for bool {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u8(*self as u8)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u8(*self as u8)?;
Ok(())
}
} }
/// 16-bit integers are serialized as 32-bit integers. /// 16-bit integers are serialized as 32-bit integers.
impl WriteToPacket for u16 { impl WriteToPacket for u16 {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(*self as u32).write_to_packet(packet)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(*self as u32).write_to_packet(packet)
}
} }
/// Strings are serialized as a length-prefixed array of ISO-8859-1 encoded /// Strings are serialized as a length-prefixed array of ISO-8859-1 encoded
/// characters. /// characters.
impl WriteToPacket for str { impl WriteToPacket for str {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
// Encode the string.
let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) {
Ok(bytes) => bytes,
Err(_) => {
let copy = self.to_string();
return Err(io::Error::new(io::ErrorKind::Other, copy));
}
};
// Then write the bytes to the packet.
(bytes.len() as u32).write_to_packet(packet)?;
packet.write(&bytes)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
// Encode the string.
let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) {
Ok(bytes) => bytes,
Err(_) => {
let copy = self.to_string();
return Err(io::Error::new(io::ErrorKind::Other, copy));
}
};
// Then write the bytes to the packet.
(bytes.len() as u32).write_to_packet(packet)?;
packet.write(&bytes)?;
Ok(())
}
} }
/// Deref coercion does not happen for trait methods apparently. /// Deref coercion does not happen for trait methods apparently.
impl WriteToPacket for String { impl WriteToPacket for String {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(self as &str).write_to_packet(packet)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(self as &str).write_to_packet(packet)
}
} }
/*========* /*========*
@ -328,86 +328,86 @@ impl WriteToPacket for String {
/// This enum defines the possible states of a packet parser state machine. /// This enum defines the possible states of a packet parser state machine.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
enum State { enum State {
/// The parser is waiting to read enough bytes to determine the
/// length of the following packet.
ReadingLength,
/// The parser is waiting to read enough bytes to form the entire
/// packet.
ReadingPacket,
/// The parser is waiting to read enough bytes to determine the
/// length of the following packet.
ReadingLength,
/// The parser is waiting to read enough bytes to form the entire
/// packet.
ReadingPacket,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Parser { pub struct Parser {
state: State,
num_bytes_left: usize,
buffer: Vec<u8>,
state: State,
num_bytes_left: usize,
buffer: Vec<u8>,
} }
impl Parser { impl Parser {
pub fn new() -> Self {
Parser {
state: State::ReadingLength,
num_bytes_left: U32_SIZE,
buffer: vec![0; U32_SIZE],
}
pub fn new() -> Self {
Parser {
state: State::ReadingLength,
num_bytes_left: U32_SIZE,
buffer: vec![0; U32_SIZE],
}
}
/// Attemps to read a packet in a non-blocking fashion.
/// If enough bytes can be read from the given byte stream to form a
/// complete packet `p`, returns `Ok(Some(p))`.
/// If not enough bytes are available, returns `Ok(None)`.
/// If an I/O error `e` arises when trying to read the underlying stream,
/// returns `Err(e)`.
/// Note: as long as this function returns `Ok(Some(p))`, the caller is
/// responsible for calling it once more to ensure that all packets are
/// read as soon as possible.
pub fn try_read<U>(&mut self, stream: &mut U) -> io::Result<Option<Packet>>
where
U: io::Read,
{
// Try to read as many bytes as we currently need from the underlying
// byte stream.
let offset = self.buffer.len() - self.num_bytes_left;
#[allow(deprecated)]
match stream.try_read(&mut self.buffer[offset..])? {
None => (),
Some(num_bytes_read) => {
self.num_bytes_left -= num_bytes_read;
}
} }
/// Attemps to read a packet in a non-blocking fashion.
/// If enough bytes can be read from the given byte stream to form a
/// complete packet `p`, returns `Ok(Some(p))`.
/// If not enough bytes are available, returns `Ok(None)`.
/// If an I/O error `e` arises when trying to read the underlying stream,
/// returns `Err(e)`.
/// Note: as long as this function returns `Ok(Some(p))`, the caller is
/// responsible for calling it once more to ensure that all packets are
/// read as soon as possible.
pub fn try_read<U>(&mut self, stream: &mut U) -> io::Result<Option<Packet>>
where
U: io::Read,
{
// Try to read as many bytes as we currently need from the underlying
// byte stream.
let offset = self.buffer.len() - self.num_bytes_left;
#[allow(deprecated)]
match stream.try_read(&mut self.buffer[offset..])? {
None => (),
Some(num_bytes_read) => {
self.num_bytes_left -= num_bytes_read;
}
}
// If we haven't read enough bytes, return.
if self.num_bytes_left > 0 {
return Ok(None);
}
// Otherwise, the behavior depends on what state we were in.
match self.state {
State::ReadingLength => {
// If we have finished reading the length prefix, then
// deserialize it, switch states and try to read the packet
// bytes.
let message_len = LittleEndian::read_u32(&mut self.buffer) as usize;
if message_len > MAX_MESSAGE_SIZE {
unimplemented!();
};
self.state = State::ReadingPacket;
self.num_bytes_left = message_len;
self.buffer.resize(message_len + U32_SIZE, 0);
self.try_read(stream)
}
State::ReadingPacket => {
// If we have finished reading the packet, swap the full buffer
// out and return the packet made from the full buffer.
self.state = State::ReadingLength;
self.num_bytes_left = U32_SIZE;
let new_buffer = vec![0; U32_SIZE];
let old_buffer = mem::replace(&mut self.buffer, new_buffer);
Ok(Some(Packet::from_wire(old_buffer)))
}
}
// If we haven't read enough bytes, return.
if self.num_bytes_left > 0 {
return Ok(None);
}
// Otherwise, the behavior depends on what state we were in.
match self.state {
State::ReadingLength => {
// If we have finished reading the length prefix, then
// deserialize it, switch states and try to read the packet
// bytes.
let message_len = LittleEndian::read_u32(&mut self.buffer) as usize;
if message_len > MAX_MESSAGE_SIZE {
unimplemented!();
};
self.state = State::ReadingPacket;
self.num_bytes_left = message_len;
self.buffer.resize(message_len + U32_SIZE, 0);
self.try_read(stream)
}
State::ReadingPacket => {
// If we have finished reading the packet, swap the full buffer
// out and return the packet made from the full buffer.
self.state = State::ReadingLength;
self.num_bytes_left = U32_SIZE;
let new_buffer = vec![0; U32_SIZE];
let old_buffer = mem::replace(&mut self.buffer, new_buffer);
Ok(Some(Packet::from_wire(old_buffer)))
}
} }
}
} }

+ 155
- 146
proto/src/peer/message.rs View File

@ -3,10 +3,13 @@ use std::io;
use log::warn; use log::warn;
use crate::core::value::{ use crate::core::value::{
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
}; };
use crate::peer::constants::*; use crate::peer::constants::*;
use crate::{MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket};
use crate::{
MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket,
};
/*=========* /*=========*
* MESSAGE * * MESSAGE *
@ -15,183 +18,189 @@ use crate::{MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket};
/// This enum contains all the possible messages peers can exchange. /// This enum contains all the possible messages peers can exchange.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message { pub enum Message {
PierceFirewall(u32),
PeerInit(PeerInit),
Unknown(u32),
PierceFirewall(u32),
PeerInit(PeerInit),
Unknown(u32),
} }
impl ReadFromPacket for Message { impl ReadFromPacket for Message {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let code: u32 = packet.read_value()?;
let message = match code {
CODE_PIERCE_FIREWALL => Message::PierceFirewall(packet.read_value()?),
CODE_PEER_INIT => Message::PeerInit(packet.read_value()?),
code => Message::Unknown(code),
};
let bytes_remaining = packet.bytes_remaining();
if bytes_remaining > 0 {
warn!(
"Peer message with code {} contains {} extra bytes",
code, bytes_remaining
)
}
Ok(message)
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let code: u32 = packet.read_value()?;
let message = match code {
CODE_PIERCE_FIREWALL => Message::PierceFirewall(packet.read_value()?),
CODE_PEER_INIT => Message::PeerInit(packet.read_value()?),
code => Message::Unknown(code),
};
let bytes_remaining = packet.bytes_remaining();
if bytes_remaining > 0 {
warn!(
"Peer message with code {} contains {} extra bytes",
code, bytes_remaining
)
} }
Ok(message)
}
} }
impl ValueDecode for Message { impl ValueDecode for Message {
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let code: u32 = decoder.decode()?;
let message = match code {
CODE_PIERCE_FIREWALL => {
let val = decoder.decode()?;
Message::PierceFirewall(val)
}
CODE_PEER_INIT => {
let peer_init = decoder.decode()?;
Message::PeerInit(peer_init)
}
_ => {
return Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: format!("unknown value {}", code),
position: position,
})
}
};
Ok(message)
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let code: u32 = decoder.decode()?;
let message = match code {
CODE_PIERCE_FIREWALL => {
let val = decoder.decode()?;
Message::PierceFirewall(val)
}
CODE_PEER_INIT => {
let peer_init = decoder.decode()?;
Message::PeerInit(peer_init)
}
_ => {
return Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: format!("unknown value {}", code),
position: position,
})
}
};
Ok(message)
}
} }
impl ValueEncode for Message { impl ValueEncode for Message {
fn encode_to(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
match *self {
Message::PierceFirewall(token) => {
encoder.encode_u32(CODE_PIERCE_FIREWALL)?;
encoder.encode_u32(token)?;
}
Message::PeerInit(ref request) => {
encoder.encode_u32(CODE_PEER_INIT)?;
request.encode_to(encoder)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
fn encode_to(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
match *self {
Message::PierceFirewall(token) => {
encoder.encode_u32(CODE_PIERCE_FIREWALL)?;
encoder.encode_u32(token)?;
}
Message::PeerInit(ref request) => {
encoder.encode_u32(CODE_PEER_INIT)?;
request.encode_to(encoder)?;
}
Message::Unknown(_) => unreachable!(),
} }
Ok(())
}
} }
impl WriteToPacket for Message { impl WriteToPacket for Message {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
match *self {
Message::PierceFirewall(ref token) => {
packet.write_value(&CODE_PIERCE_FIREWALL)?;
packet.write_value(token)?;
}
Message::PeerInit(ref request) => {
packet.write_value(&CODE_PEER_INIT)?;
packet.write_value(request)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
match *self {
Message::PierceFirewall(ref token) => {
packet.write_value(&CODE_PIERCE_FIREWALL)?;
packet.write_value(token)?;
}
Message::PeerInit(ref request) => {
packet.write_value(&CODE_PEER_INIT)?;
packet.write_value(request)?;
}
Message::Unknown(_) => unreachable!(),
} }
Ok(())
}
} }
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct PeerInit { pub struct PeerInit {
pub user_name: String,
pub connection_type: String,
pub token: u32,
pub user_name: String,
pub connection_type: String,
pub token: u32,
} }
impl ReadFromPacket for PeerInit { impl ReadFromPacket for PeerInit {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let user_name = packet.read_value()?;
let connection_type = packet.read_value()?;
let token = packet.read_value()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let user_name = packet.read_value()?;
let connection_type = packet.read_value()?;
let token = packet.read_value()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
} }
impl WriteToPacket for PeerInit { impl WriteToPacket for PeerInit {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_value(&self.user_name)?;
packet.write_value(&self.connection_type)?;
packet.write_value(&self.token)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_value(&self.user_name)?;
packet.write_value(&self.connection_type)?;
packet.write_value(&self.token)?;
Ok(())
}
} }
impl ValueEncode for PeerInit { impl ValueEncode for PeerInit {
fn encode_to(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
encoder.encode_string(&self.user_name)?;
encoder.encode_string(&self.connection_type)?;
encoder.encode_u32(self.token)?;
Ok(())
}
fn encode_to(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
encoder.encode_string(&self.user_name)?;
encoder.encode_string(&self.connection_type)?;
encoder.encode_u32(self.token)?;
Ok(())
}
} }
impl ValueDecode for PeerInit { impl ValueDecode for PeerInit {
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let user_name = decoder.decode()?;
let connection_type = decoder.decode()?;
let token = decoder.decode()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let user_name = decoder.decode()?;
let connection_type = decoder.decode()?;
let token = decoder.decode()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use bytes::BytesMut;
use crate::core::value::tests::roundtrip;
use crate::core::value::{ValueDecodeError, ValueDecoder};
use super::*;
#[test]
fn invalid_code() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[57, 5, 0, 0]);
let result = ValueDecoder::new(&bytes).decode::<Message>();
assert_eq!(
result,
Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: "unknown value 1337".to_string(),
position: 0,
})
);
}
#[test]
fn roundtrip_pierce_firewall() {
roundtrip(Message::PierceFirewall(1337))
}
#[test]
fn roundtrip_peer_init() {
roundtrip(Message::PeerInit(PeerInit {
user_name: "alice".to_string(),
connection_type: "P".to_string(),
token: 1337,
}));
}
use bytes::BytesMut;
use crate::core::value::tests::roundtrip;
use crate::core::value::{ValueDecodeError, ValueDecoder};
use super::*;
#[test]
fn invalid_code() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[57, 5, 0, 0]);
let result = ValueDecoder::new(&bytes).decode::<Message>();
assert_eq!(
result,
Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: "unknown value 1337".to_string(),
position: 0,
})
);
}
#[test]
fn roundtrip_pierce_firewall() {
roundtrip(Message::PierceFirewall(1337))
}
#[test]
fn roundtrip_peer_init() {
roundtrip(Message::PeerInit(PeerInit {
user_name: "alice".to_string(),
connection_type: "P".to_string(),
token: 1337,
}));
}
} }

+ 173
- 170
proto/src/server/client.rs View File

@ -7,214 +7,217 @@ use thiserror::Error;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use crate::core::channel::{Channel, ChannelError}; use crate::core::channel::{Channel, ChannelError};
use crate::server::{Credentials, LoginResponse, ServerRequest, ServerResponse, Version};
use crate::server::{
Credentials, LoginResponse, ServerRequest, ServerResponse, Version,
};
/// A client for the client-server protocol. /// A client for the client-server protocol.
#[derive(Debug)] #[derive(Debug)]
pub struct Client { pub struct Client {
channel: Channel<ServerResponse, ServerRequest>,
version: Version,
channel: Channel<ServerResponse, ServerRequest>,
version: Version,
} }
/// An error that arose while logging in to a remote server. /// An error that arose while logging in to a remote server.
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ClientLoginError { pub enum ClientLoginError {
#[error("login failed: {0}")]
LoginFailed(String, Client),
#[error("login failed: {0}")]
LoginFailed(String, Client),
#[error("unexpected response: {0:?}")]
UnexpectedResponse(ServerResponse),
#[error("unexpected response: {0:?}")]
UnexpectedResponse(ServerResponse),
#[error("channel error: {0}")]
ChannelError(#[from] ChannelError),
#[error("channel error: {0}")]
ChannelError(#[from] ChannelError),
} }
impl From<io::Error> for ClientLoginError { impl From<io::Error> for ClientLoginError {
fn from(error: io::Error) -> Self {
ClientLoginError::from(ChannelError::from(error))
}
fn from(error: io::Error) -> Self {
ClientLoginError::from(ChannelError::from(error))
}
} }
/// A `Channel` that sends `ServerRequest`s and receives `ServerResponse`s. /// A `Channel` that sends `ServerRequest`s and receives `ServerResponse`s.
pub type ClientChannel = Channel<ServerResponse, ServerRequest>; pub type ClientChannel = Channel<ServerResponse, ServerRequest>;
impl Client { impl Client {
/// Instantiates a new client
pub fn new(tcp_stream: TcpStream) -> Self {
Client {
channel: Channel::new(tcp_stream),
version: Version::default(),
}
/// Instantiates a new client
pub fn new(tcp_stream: TcpStream) -> Self {
Client {
channel: Channel::new(tcp_stream),
version: Version::default(),
} }
/// Sets a custom version to identify as to the server.
pub fn with_version(mut self, version: Version) -> Self {
self.version = version;
self
}
/// Performs the login exchange, presenting `credentials` to the server.
pub async fn login(
mut self,
credentials: Credentials,
) -> Result<ClientChannel, ClientLoginError> {
let login_request = credentials.into_login_request(self.version);
debug!("Client: sending login request: {:?}", login_request);
let request = login_request.into();
self.channel.write(&request).await?;
let response = self.channel.read().await?;
debug!("Client: received first response: {:?}", response);
match response {
ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd,
ip,
password_md5_opt,
}) => {
info!("Client: Logged in successfully!");
info!("Client: Message Of The Day: {}", motd);
info!("Client: Public IP address: {}", ip);
info!("Client: Password MD5: {:?}", password_md5_opt);
Ok(self.channel)
}
ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => {
Err(ClientLoginError::LoginFailed(reason, self))
}
response => Err(ClientLoginError::UnexpectedResponse(response)),
}
}
/// Sets a custom version to identify as to the server.
pub fn with_version(mut self, version: Version) -> Self {
self.version = version;
self
}
/// Performs the login exchange, presenting `credentials` to the server.
pub async fn login(
mut self,
credentials: Credentials,
) -> Result<ClientChannel, ClientLoginError> {
let login_request = credentials.into_login_request(self.version);
debug!("Client: sending login request: {:?}", login_request);
let request = login_request.into();
self.channel.write(&request).await?;
let response = self.channel.read().await?;
debug!("Client: received first response: {:?}", response);
match response {
ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd,
ip,
password_md5_opt,
}) => {
info!("Client: Logged in successfully!");
info!("Client: Message Of The Day: {}", motd);
info!("Client: Public IP address: {}", ip);
info!("Client: Password MD5: {:?}", password_md5_opt);
Ok(self.channel)
}
ServerResponse::LoginResponse(LoginResponse::LoginFail { reason }) => {
Err(ClientLoginError::LoginFailed(reason, self))
}
response => Err(ClientLoginError::UnexpectedResponse(response)),
} }
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::stream::{empty, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap};
use crate::server::{
Credentials, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse,
};
use crate::UserStatus;
use super::Client;
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
use futures::stream::{empty, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
// Returns default `Credentials` suitable for testing.
fn credentials() -> Credentials {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
Credentials::new(user_name, password).unwrap()
}
#[tokio::test]
async fn login() {
init();
use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap};
use crate::server::{
Credentials, ServerRequest, ServerResponse, UserStatusRequest,
UserStatusResponse,
};
use crate::UserStatus;
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
use super::Client;
let stream = TcpStream::connect(handle.address()).await.unwrap();
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
let channel = Client::new(stream).login(credentials()).await.unwrap();
// Returns default `Credentials` suitable for testing.
fn credentials() -> Credentials {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
Credentials::new(user_name, password).unwrap()
}
// Send nothing, receive no responses.
let inbound = channel.run(empty());
tokio::pin!(inbound);
#[tokio::test]
async fn login() {
init();
assert!(inbound.next().await.is_none());
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
#[tokio::test]
async fn simple_exchange() {
init();
let response = UserStatusResponse {
user_name: "alice".to_string(),
status: UserStatus::Online,
is_privileged: false,
};
let mut user_status_map = UserStatusMap::default();
user_status_map.insert(response.clone());
let (server, handle) = ServerBuilder::default()
.with_user_status_map(user_status_map)
.bind()
.await
.unwrap();
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address()).await.unwrap();
let channel = Client::new(stream).login(credentials()).await.unwrap();
let outbound = Box::pin(async_stream::stream! {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "alice".to_string(),
});
});
let inbound = channel.run(outbound);
tokio::pin!(inbound);
assert_eq!(
inbound.next().await.unwrap().unwrap(),
ServerResponse::UserStatusResponse(response)
);
assert!(inbound.next().await.is_none());
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
let stream = TcpStream::connect(handle.address()).await.unwrap();
#[tokio::test]
async fn stream_closed() {
init();
let channel = Client::new(stream).login(credentials()).await.unwrap();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
// Send nothing, receive no responses.
let inbound = channel.run(empty());
tokio::pin!(inbound);
let stream = TcpStream::connect(handle.address()).await.unwrap();
assert!(inbound.next().await.is_none());
let channel = Client::new(stream).login(credentials()).await.unwrap();
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
let (_request_tx, mut request_rx) = mpsc::channel(1);
let outbound = Box::pin(async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
});
#[tokio::test]
async fn simple_exchange() {
init();
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// Server shuts down, closing its connection before the client has had a
// chance to send all of `outbound`.
handle.shutdown(ShutdownType::Immediate);
// Wait for the server to terminate, to avoid race conditions.
server_task.await.unwrap().unwrap();
let response = UserStatusResponse {
user_name: "alice".to_string(),
status: UserStatus::Online,
is_privileged: false,
};
// Check that the client returns the correct error, then stops running.
assert!(inbound
.next()
.await
.unwrap()
.unwrap_err()
.is_unexpected_eof());
assert!(inbound.next().await.is_none());
}
let mut user_status_map = UserStatusMap::default();
user_status_map.insert(response.clone());
let (server, handle) = ServerBuilder::default()
.with_user_status_map(user_status_map)
.bind()
.await
.unwrap();
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address()).await.unwrap();
let channel = Client::new(stream).login(credentials()).await.unwrap();
let outbound = Box::pin(async_stream::stream! {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "alice".to_string(),
});
});
let inbound = channel.run(outbound);
tokio::pin!(inbound);
assert_eq!(
inbound.next().await.unwrap().unwrap(),
ServerResponse::UserStatusResponse(response)
);
assert!(inbound.next().await.is_none());
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
#[tokio::test]
async fn stream_closed() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
let stream = TcpStream::connect(handle.address()).await.unwrap();
let channel = Client::new(stream).login(credentials()).await.unwrap();
let (_request_tx, mut request_rx) = mpsc::channel(1);
let outbound = Box::pin(async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
});
let inbound = channel.run(outbound);
tokio::pin!(inbound);
// Server shuts down, closing its connection before the client has had a
// chance to send all of `outbound`.
handle.shutdown(ShutdownType::Immediate);
// Wait for the server to terminate, to avoid race conditions.
server_task.await.unwrap().unwrap();
// Check that the client returns the correct error, then stops running.
assert!(inbound
.next()
.await
.unwrap()
.unwrap_err()
.is_unexpected_eof());
assert!(inbound.next().await.is_none());
}
} }

+ 106
- 105
proto/src/server/credentials.rs View File

@ -8,119 +8,120 @@ use crate::server::{LoginRequest, Version};
/// Credentials for logging in a client to a server. /// Credentials for logging in a client to a server.
#[derive(Debug, Eq, PartialEq)] #[derive(Debug, Eq, PartialEq)]
pub struct Credentials { pub struct Credentials {
user_name: String,
password: String,
digest: String,
user_name: String,
password: String,
digest: String,
} }
impl Credentials { impl Credentials {
/// Attempts to build credentials.
///
/// Returns `None` if `password` is empty, `Some(credentials)` otherwise.
pub fn new(user_name: String, password: String) -> Option<Credentials> {
if password.is_empty() {
return None;
}
let mut hasher = Md5::new();
hasher.input_str(&user_name);
hasher.input_str(&password);
let digest = hasher.result_str();
Some(Credentials {
user_name,
password,
digest,
})
/// Attempts to build credentials.
///
/// Returns `None` if `password` is empty, `Some(credentials)` otherwise.
pub fn new(user_name: String, password: String) -> Option<Credentials> {
if password.is_empty() {
return None;
} }
/// The user name to log in as.
pub fn user_name(&self) -> &str {
&self.user_name
}
/// The password to log in with.
pub fn password(&self) -> &str {
&self.password
}
/// Returns md5(user_name + password).
pub fn digest(&self) -> &str {
&self.digest
}
pub fn into_login_request(self, version: Version) -> LoginRequest {
LoginRequest {
user_name: self.user_name,
password: self.password,
digest: self.digest,
major: version.major,
minor: version.minor,
}
let mut hasher = Md5::new();
hasher.input_str(&user_name);
hasher.input_str(&password);
let digest = hasher.result_str();
Some(Credentials {
user_name,
password,
digest,
})
}
/// The user name to log in as.
pub fn user_name(&self) -> &str {
&self.user_name
}
/// The password to log in with.
pub fn password(&self) -> &str {
&self.password
}
/// Returns md5(user_name + password).
pub fn digest(&self) -> &str {
&self.digest
}
pub fn into_login_request(self, version: Version) -> LoginRequest {
LoginRequest {
user_name: self.user_name,
password: self.password,
digest: self.digest,
major: version.major,
minor: version.minor,
} }
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::server::{LoginRequest, Version};
use super::Credentials;
#[test]
fn empty_password() {
assert_eq!(Credentials::new("alice".to_string(), String::new()), None);
}
#[test]
fn new_success() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
assert_eq!(credentials.user_name(), "alice");
assert_eq!(credentials.password(), "sekrit");
}
#[test]
fn digest() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
// To generate the expected value on the command line;
//
// ```sh
// $ user_name="alice"
// $ password="sekrit"
// $ echo -e "${user_name}${password}" | md5sum
// ```
assert_eq!(credentials.digest(), "286da88eb442032bdd3913979af76e8a");
}
#[test]
fn into_login_request() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name.clone(), password.clone()).unwrap();
let digest = credentials.digest().to_string();
let version = Version {
major: 13,
minor: 37,
};
assert_eq!(
credentials.into_login_request(version),
LoginRequest {
user_name,
password,
digest,
major: 13,
minor: 37,
}
);
}
use crate::server::{LoginRequest, Version};
use super::Credentials;
#[test]
fn empty_password() {
assert_eq!(Credentials::new("alice".to_string(), String::new()), None);
}
#[test]
fn new_success() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
assert_eq!(credentials.user_name(), "alice");
assert_eq!(credentials.password(), "sekrit");
}
#[test]
fn digest() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
// To generate the expected value on the command line;
//
// ```sh
// $ user_name="alice"
// $ password="sekrit"
// $ echo -e "${user_name}${password}" | md5sum
// ```
assert_eq!(credentials.digest(), "286da88eb442032bdd3913979af76e8a");
}
#[test]
fn into_login_request() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials =
Credentials::new(user_name.clone(), password.clone()).unwrap();
let digest = credentials.digest().to_string();
let version = Version {
major: 13,
minor: 37,
};
assert_eq!(
credentials.into_login_request(version),
LoginRequest {
user_name,
password,
digest,
major: 13,
minor: 37,
}
);
}
} }

+ 633
- 599
proto/src/server/request.rs
File diff suppressed because it is too large
View File


+ 1338
- 1237
proto/src/server/response.rs
File diff suppressed because it is too large
View File


+ 293
- 289
proto/src/server/testing.rs View File

@ -8,8 +8,8 @@ use std::sync::Arc;
use log::{debug, info, warn}; use log::{debug, info, warn};
use parking_lot::Mutex; use parking_lot::Mutex;
use tokio::net::{ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
}; };
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
@ -17,381 +17,385 @@ use tokio::sync::watch;
use crate::core::frame::{FrameReader, FrameWriter}; use crate::core::frame::{FrameReader, FrameWriter};
use crate::server::{ use crate::server::{
LoginResponse, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse,
LoginResponse, ServerRequest, ServerResponse, UserStatusRequest,
UserStatusResponse,
}; };
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct UserStatusMap { pub struct UserStatusMap {
map: HashMap<String, UserStatusResponse>,
map: HashMap<String, UserStatusResponse>,
} }
// IDEA: impl FromIterator<UserStatusResponse> for UserStatusMap. // IDEA: impl FromIterator<UserStatusResponse> for UserStatusMap.
impl UserStatusMap { impl UserStatusMap {
pub fn insert(&mut self, response: UserStatusResponse) {
self.map.insert(response.user_name.clone(), response);
}
pub fn insert(&mut self, response: UserStatusResponse) {
self.map.insert(response.user_name.clone(), response);
}
pub fn get(&self, user_name: &str) -> Option<UserStatusResponse> {
self.map.get(user_name).map(|response| response.clone())
}
pub fn get(&self, user_name: &str) -> Option<UserStatusResponse> {
self.map.get(user_name).map(|response| response.clone())
}
} }
struct Handler { struct Handler {
reader: FrameReader<ServerRequest, OwnedReadHalf>,
writer: FrameWriter<ServerResponse, OwnedWriteHalf>,
peer_address: SocketAddr,
user_status_map: Arc<Mutex<UserStatusMap>>,
reader: FrameReader<ServerRequest, OwnedReadHalf>,
writer: FrameWriter<ServerResponse, OwnedWriteHalf>,
peer_address: SocketAddr,
user_status_map: Arc<Mutex<UserStatusMap>>,
} }
impl Handler { impl Handler {
fn ipv4_address(&self) -> Ipv4Addr {
match self.peer_address.ip() {
IpAddr::V4(ipv4_addr) => ipv4_addr,
IpAddr::V6(_) => Ipv4Addr::UNSPECIFIED,
}
}
async fn send_response(&mut self, response: &ServerResponse) -> io::Result<()> {
debug!("Handler: sending response: {:?}", response);
self.writer.write(response).await
}
async fn handle_login(&mut self) -> io::Result<()> {
match self.reader.read().await? {
Some(ServerRequest::LoginRequest(request)) => {
info!("Handler: Received login request: {:?}", request);
}
Some(request) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
None => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"expected login request".to_string(),
));
}
};
let response = ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd: "hi there".to_string(),
ip: self.ipv4_address(),
password_md5_opt: None,
});
self.send_response(&response).await
fn ipv4_address(&self) -> Ipv4Addr {
match self.peer_address.ip() {
IpAddr::V4(ipv4_addr) => ipv4_addr,
IpAddr::V6(_) => Ipv4Addr::UNSPECIFIED,
} }
async fn handle_request(&mut self, request: ServerRequest) -> io::Result<()> {
debug!("Handler: received request: {:?}", request);
match request {
ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => {
let entry = self.user_status_map.lock().get(&user_name);
if let Some(response) = entry {
let response = ServerResponse::UserStatusResponse(response);
self.send_response(&response).await?;
}
}
_ => {
warn!("Handler: unhandled request: {:?}", request);
}
}
async fn send_response(
&mut self,
response: &ServerResponse,
) -> io::Result<()> {
debug!("Handler: sending response: {:?}", response);
self.writer.write(response).await
}
async fn handle_login(&mut self) -> io::Result<()> {
match self.reader.read().await? {
Some(ServerRequest::LoginRequest(request)) => {
info!("Handler: Received login request: {:?}", request);
}
Some(request) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
None => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"expected login request".to_string(),
));
}
};
let response = ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd: "hi there".to_string(),
ip: self.ipv4_address(),
password_md5_opt: None,
});
self.send_response(&response).await
}
async fn handle_request(&mut self, request: ServerRequest) -> io::Result<()> {
debug!("Handler: received request: {:?}", request);
match request {
ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => {
let entry = self.user_status_map.lock().get(&user_name);
if let Some(response) = entry {
let response = ServerResponse::UserStatusResponse(response);
self.send_response(&response).await?;
} }
Ok(())
}
_ => {
warn!("Handler: unhandled request: {:?}", request);
}
} }
async fn run(mut self) -> io::Result<()> {
self.handle_login().await?;
Ok(())
}
while let Some(request) = self.reader.read().await? {
self.handle_request(request).await?;
}
async fn run(mut self) -> io::Result<()> {
self.handle_login().await?;
info!("Handler: client disconnecting, shutting down");
Ok(())
while let Some(request) = self.reader.read().await? {
self.handle_request(request).await?;
} }
info!("Handler: client disconnecting, shutting down");
Ok(())
}
} }
struct GracefulHandler { struct GracefulHandler {
handler: Handler,
shutdown_rx: watch::Receiver<()>,
handler: Handler,
shutdown_rx: watch::Receiver<()>,
} }
impl GracefulHandler { impl GracefulHandler {
async fn run(mut self) -> io::Result<()> {
tokio::select!(
result = self.handler.run() => {
if let Err(ref error) = result {
warn!("GracefulHandler: handler returned error {:?}", error);
}
result
},
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = self.shutdown_rx.changed() => {
info!("GracefulHandler: shutting down.");
Ok(())
},
)
}
async fn run(mut self) -> io::Result<()> {
tokio::select!(
result = self.handler.run() => {
if let Err(ref error) = result {
warn!("GracefulHandler: handler returned error {:?}", error);
}
result
},
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = self.shutdown_rx.changed() => {
info!("GracefulHandler: shutting down.");
Ok(())
},
)
}
} }
struct SenderHandler { struct SenderHandler {
handler: GracefulHandler,
result_tx: mpsc::Sender<io::Result<()>>,
handler: GracefulHandler,
result_tx: mpsc::Sender<io::Result<()>>,
} }
impl SenderHandler { impl SenderHandler {
async fn run(self) {
let result = self.handler.run().await;
let _ = self.result_tx.send(result).await;
}
async fn run(self) {
let result = self.handler.run().await;
let _ = self.result_tx.send(result).await;
}
} }
/// A builder for Server instances. /// A builder for Server instances.
#[derive(Default)] #[derive(Default)]
pub struct ServerBuilder { pub struct ServerBuilder {
user_status_map: Option<Arc<Mutex<UserStatusMap>>>,
user_status_map: Option<Arc<Mutex<UserStatusMap>>>,
} }
impl ServerBuilder { impl ServerBuilder {
/// Sets the UserStatusMap which the server will use to respond to
/// UserStatusRequest messages.
pub fn with_user_status_map(mut self, map: UserStatusMap) -> Self {
self.user_status_map = Some(Arc::new(Mutex::new(map)));
self
}
/// Binds to a localhost port, then returns a server and its handle.
pub async fn bind(self) -> io::Result<(Server, ServerHandle)> {
let listener = TcpListener::bind("localhost:0").await?;
let address = listener.local_addr()?;
let user_status_map = match self.user_status_map {
Some(user_status_map) => user_status_map,
None => Arc::new(Mutex::new(UserStatusMap::default())),
};
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(());
let (result_tx, result_rx) = mpsc::channel(1);
Ok((
Server {
listener,
shutdown_rx,
handler_shutdown_tx,
handler_shutdown_rx,
result_tx,
result_rx,
user_status_map,
},
ServerHandle {
shutdown_tx,
address,
},
))
}
/// Sets the UserStatusMap which the server will use to respond to
/// UserStatusRequest messages.
pub fn with_user_status_map(mut self, map: UserStatusMap) -> Self {
self.user_status_map = Some(Arc::new(Mutex::new(map)));
self
}
/// Binds to a localhost port, then returns a server and its handle.
pub async fn bind(self) -> io::Result<(Server, ServerHandle)> {
let listener = TcpListener::bind("localhost:0").await?;
let address = listener.local_addr()?;
let user_status_map = match self.user_status_map {
Some(user_status_map) => user_status_map,
None => Arc::new(Mutex::new(UserStatusMap::default())),
};
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(());
let (result_tx, result_rx) = mpsc::channel(1);
Ok((
Server {
listener,
shutdown_rx,
handler_shutdown_tx,
handler_shutdown_rx,
result_tx,
result_rx,
user_status_map,
},
ServerHandle {
shutdown_tx,
address,
},
))
}
} }
/// Specifies how to shut down a server. /// Specifies how to shut down a server.
pub enum ShutdownType { pub enum ShutdownType {
/// Shut down immediately, sever open connections.
Immediate,
/// Shut down immediately, sever open connections.
Immediate,
/// Stop accepting new connections, wait for open connections to close.
LameDuck,
/// Stop accepting new connections, wait for open connections to close.
LameDuck,
} }
/// A simple server for connecting to in tests. /// A simple server for connecting to in tests.
pub struct Server { pub struct Server {
// Listener for new connections.
listener: TcpListener,
// Listener for new connections.
listener: TcpListener,
// Receiver for ServerHandle shutdown() notification.
shutdown_rx: oneshot::Receiver<ShutdownType>,
// Receiver for ServerHandle shutdown() notification.
shutdown_rx: oneshot::Receiver<ShutdownType>,
// Watch channel for signalling immediate termination to handlers.
handler_shutdown_tx: watch::Sender<()>,
handler_shutdown_rx: watch::Receiver<()>,
// Watch channel for signalling immediate termination to handlers.
handler_shutdown_tx: watch::Sender<()>,
handler_shutdown_rx: watch::Receiver<()>,
// Channel for receiving results back from handlers.
result_tx: mpsc::Sender<io::Result<()>>,
result_rx: mpsc::Receiver<io::Result<()>>,
// Channel for receiving results back from handlers.
result_tx: mpsc::Sender<io::Result<()>>,
result_rx: mpsc::Receiver<io::Result<()>>,
// Shared state for handlers to use when serving responses.
user_status_map: Arc<Mutex<UserStatusMap>>,
// Shared state for handlers to use when serving responses.
user_status_map: Arc<Mutex<UserStatusMap>>,
} }
/// Allows interacting with a running `Server`. /// Allows interacting with a running `Server`.
pub struct ServerHandle { pub struct ServerHandle {
shutdown_tx: oneshot::Sender<ShutdownType>,
address: SocketAddr,
shutdown_tx: oneshot::Sender<ShutdownType>,
address: SocketAddr,
} }
impl ServerHandle { impl ServerHandle {
/// Returns the address on which the server is accepting connections.
pub fn address(&self) -> SocketAddr {
self.address
}
/// Starts shutting down the server.
/// Does nothing if the server is already shutting down or even dropped.
pub fn shutdown(self, how: ShutdownType) {
// Ignore send errors, which mean that the server has been dropped.
let _ = self.shutdown_tx.send(how);
}
/// Returns the address on which the server is accepting connections.
pub fn address(&self) -> SocketAddr {
self.address
}
/// Starts shutting down the server.
/// Does nothing if the server is already shutting down or even dropped.
pub fn shutdown(self, how: ShutdownType) {
// Ignore send errors, which mean that the server has been dropped.
let _ = self.shutdown_tx.send(how);
}
} }
impl Server { impl Server {
/// Returns the address to which this server is bound.
/// This is always localhost and a random port chosen by the OS.
pub fn address(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
/// Spawns a handler for the given new stream, initiated by a remote peer.
fn spawn_handler(&mut self, stream: TcpStream, peer_address: SocketAddr) {
let (read_half, write_half) = stream.into_split();
let handler = SenderHandler {
handler: GracefulHandler {
handler: Handler {
reader: FrameReader::new(read_half),
writer: FrameWriter::new(write_half),
peer_address,
user_status_map: self.user_status_map.clone(),
},
shutdown_rx: self.handler_shutdown_rx.clone(),
},
result_tx: self.result_tx.clone(),
};
tokio::spawn(handler.run());
/// Returns the address to which this server is bound.
/// This is always localhost and a random port chosen by the OS.
pub fn address(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
/// Spawns a handler for the given new stream, initiated by a remote peer.
fn spawn_handler(&mut self, stream: TcpStream, peer_address: SocketAddr) {
let (read_half, write_half) = stream.into_split();
let handler = SenderHandler {
handler: GracefulHandler {
handler: Handler {
reader: FrameReader::new(read_half),
writer: FrameWriter::new(write_half),
peer_address,
user_status_map: self.user_status_map.clone(),
},
shutdown_rx: self.handler_shutdown_rx.clone(),
},
result_tx: self.result_tx.clone(),
};
tokio::spawn(handler.run());
}
/// Accepts a single connection, spawns a handler for it and returns.
///
/// Useful for tests who need to guarantee a handler is spawned before the
/// server shuts down.
async fn accept(&mut self) -> io::Result<()> {
let (stream, peer_address) = self.listener.accept().await?;
self.spawn_handler(stream, peer_address);
Ok(())
}
/// Runs the server: accepts incoming connections and responds to requests.
///
/// Returns an error if:
///
/// - an error was encountered while listening
/// - an error was encountered while serving a request
///
pub async fn serve(mut self) -> io::Result<()> {
loop {
tokio::select!(
result = self.listener.accept() => {
let (stream, peer_address) = result?;
self.spawn_handler(stream, peer_address);
},
result = &mut self.shutdown_rx => {
// If shutdown_rx's sender is dropped and we receive an error, we take
// that as a signal to shut down immediately too.
match result.unwrap_or(ShutdownType::Immediate) {
ShutdownType::LameDuck => break,
ShutdownType::Immediate => {
// Send errors cannot happen, since we hold onto a receiver in
// `self.handler_shutdown_rx`.
self.handler_shutdown_tx.send(()).unwrap();
break
}
}
},
optional_result = self.result_rx.recv() => {
// We can never exhaust the result channel because we hold onto a
// sender in `self.result_tx`.
let result = optional_result.unwrap();
// Return an error if a handler returns an error.
result?;
}
);
} }
/// Accepts a single connection, spawns a handler for it and returns.
///
/// Useful for tests who need to guarantee a handler is spawned before the
/// server shuts down.
async fn accept(&mut self) -> io::Result<()> {
let (stream, peer_address) = self.listener.accept().await?;
self.spawn_handler(stream, peer_address);
Ok(())
info!("Server: shutting down");
drop(self.result_tx);
while let Some(result) = self.result_rx.recv().await {
result?;
} }
/// Runs the server: accepts incoming connections and responds to requests.
///
/// Returns an error if:
///
/// - an error was encountered while listening
/// - an error was encountered while serving a request
///
pub async fn serve(mut self) -> io::Result<()> {
loop {
tokio::select!(
result = self.listener.accept() => {
let (stream, peer_address) = result?;
self.spawn_handler(stream, peer_address);
},
result = &mut self.shutdown_rx => {
// If shutdown_rx's sender is dropped and we receive an error, we take
// that as a signal to shut down immediately too.
match result.unwrap_or(ShutdownType::Immediate) {
ShutdownType::LameDuck => break,
ShutdownType::Immediate => {
// Send errors cannot happen, since we hold onto a receiver in
// `self.handler_shutdown_rx`.
self.handler_shutdown_tx.send(()).unwrap();
break
}
}
},
optional_result = self.result_rx.recv() => {
// We can never exhaust the result channel because we hold onto a
// sender in `self.result_tx`.
let result = optional_result.unwrap();
// Return an error if a handler returns an error.
result?;
}
);
}
info!("Server: shutting down");
drop(self.result_tx);
while let Some(result) = self.result_rx.recv().await {
result?;
}
Ok(())
}
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::io;
use std::io;
use tokio::net::TcpStream;
use tokio::net::TcpStream;
use super::{ServerBuilder, ShutdownType};
use super::{ServerBuilder, ShutdownType};
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn new_binds_to_localhost() {
init();
#[tokio::test]
async fn new_binds_to_localhost() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
assert_eq!(server.address().unwrap(), handle.address());
}
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
assert_eq!(server.address().unwrap(), handle.address());
}
#[tokio::test]
async fn accepts_incoming_connections() {
init();
#[tokio::test]
async fn accepts_incoming_connections() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
// The connection succeeds.
let _ = TcpStream::connect(handle.address()).await.unwrap();
// The connection succeeds.
let _ = TcpStream::connect(handle.address()).await.unwrap();
handle.shutdown(ShutdownType::Immediate);
handle.shutdown(ShutdownType::Immediate);
// Ignore errors, which can happen when the handler task is spawned right
// before we call `handle.shutdown()`. See `serve_yields_handler_error`.
let _ = server_task.await.unwrap();
}
// Ignore errors, which can happen when the handler task is spawned right
// before we call `handle.shutdown()`. See `serve_yields_handler_error`.
let _ = server_task.await.unwrap();
}
// This test verifies that when a handler encounters an error, it is
// reflected in `Server::serve()`'s return value.
#[tokio::test]
async fn serve_yields_handler_error() {
init();
// This test verifies that when a handler encounters an error, it is
// reflected in `Server::serve()`'s return value.
#[tokio::test]
async fn serve_yields_handler_error() {
init();
let (mut server, handle) = ServerBuilder::default().bind().await.unwrap();
let (mut server, handle) = ServerBuilder::default().bind().await.unwrap();
// The connection is accepted, then immediately closed.
let address = handle.address();
tokio::spawn(async move {
let _ = TcpStream::connect(address).await.unwrap();
});
// The connection is accepted, then immediately closed.
let address = handle.address();
tokio::spawn(async move {
let _ = TcpStream::connect(address).await.unwrap();
});
// Accept the connection on the server and spawn a handler for it.
server.accept().await.unwrap();
// Accept the connection on the server and spawn a handler for it.
server.accept().await.unwrap();
// Signal that the server should stop accepting incoming connections.
handle.shutdown(ShutdownType::LameDuck);
// Signal that the server should stop accepting incoming connections.
handle.shutdown(ShutdownType::LameDuck);
// Drain outstanding requests, encountering the error.
let error = server.serve().await.unwrap_err();
assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
}
// Drain outstanding requests, encountering the error.
let error = server.serve().await.unwrap_err();
assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
}
} }

+ 9
- 9
proto/src/server/version.rs View File

@ -3,18 +3,18 @@
/// Specifies a protocol version. /// Specifies a protocol version.
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct Version { pub struct Version {
/// The major version number.
pub major: u32,
/// The major version number.
pub major: u32,
/// The minor version number.
pub minor: u32,
/// The minor version number.
pub minor: u32,
} }
impl Default for Version { impl Default for Version {
fn default() -> Self {
Self {
major: 181,
minor: 100,
}
fn default() -> Self {
Self {
major: 181,
minor: 100,
} }
}
} }

+ 165
- 164
proto/src/stream.rs View File

@ -16,41 +16,41 @@ use super::packet::{MutPacket, Parser, ReadFromPacket, WriteToPacket};
/// A struct used for writing bytes to a TryWrite sink. /// A struct used for writing bytes to a TryWrite sink.
#[derive(Debug)] #[derive(Debug)]
struct OutBuf { struct OutBuf {
cursor: usize,
bytes: Vec<u8>,
cursor: usize,
bytes: Vec<u8>,
} }
impl From<Vec<u8>> for OutBuf { impl From<Vec<u8>> for OutBuf {
fn from(bytes: Vec<u8>) -> Self {
OutBuf {
cursor: 0,
bytes: bytes,
}
fn from(bytes: Vec<u8>) -> Self {
OutBuf {
cursor: 0,
bytes: bytes,
} }
}
} }
impl OutBuf { impl OutBuf {
#[inline]
fn remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
#[inline]
fn has_remaining(&self) -> bool {
self.remaining() > 0
}
#[allow(deprecated)]
fn try_write_to<T>(&mut self, mut writer: T) -> io::Result<Option<usize>>
where
T: mio::deprecated::TryWrite,
{
let result = writer.try_write(&self.bytes[self.cursor..]);
if let Ok(Some(bytes_written)) = result {
self.cursor += bytes_written;
}
result
#[inline]
fn remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
#[inline]
fn has_remaining(&self) -> bool {
self.remaining() > 0
}
#[allow(deprecated)]
fn try_write_to<T>(&mut self, mut writer: T) -> io::Result<Option<usize>>
where
T: mio::deprecated::TryWrite,
{
let result = writer.try_write(&self.bytes[self.cursor..]);
if let Ok(Some(bytes_written)) = result {
self.cursor += bytes_written;
} }
result
}
} }
/*========* /*========*
@ -60,170 +60,171 @@ impl OutBuf {
/// This trait is implemented by packet sinks to which a stream can forward /// This trait is implemented by packet sinks to which a stream can forward
/// the packets it reads. /// the packets it reads.
pub trait SendPacket { pub trait SendPacket {
type Value: ReadFromPacket;
type Error: error::Error;
type Value: ReadFromPacket;
type Error: error::Error;
fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>;
fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>;
fn notify_open(&mut self) -> Result<(), Self::Error>;
fn notify_open(&mut self) -> Result<(), Self::Error>;
} }
/// This enum defines the possible actions the stream wants to take after /// This enum defines the possible actions the stream wants to take after
/// processing an event. /// processing an event.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum Intent { pub enum Intent {
/// The stream is done, the event loop handler can drop it.
Done,
/// The stream wants to wait for the next event matching the given
/// `EventSet`.
Continue(mio::Ready),
/// The stream is done, the event loop handler can drop it.
Done,
/// The stream wants to wait for the next event matching the given
/// `EventSet`.
Continue(mio::Ready),
} }
/// This struct wraps around an mio tcp stream and handles packet reads and /// This struct wraps around an mio tcp stream and handles packet reads and
/// writes. /// writes.
#[derive(Debug)] #[derive(Debug)]
pub struct Stream<T: SendPacket> { pub struct Stream<T: SendPacket> {
parser: Parser,
queue: VecDeque<OutBuf>,
sender: T,
stream: mio::tcp::TcpStream,
parser: Parser,
queue: VecDeque<OutBuf>,
sender: T,
stream: mio::tcp::TcpStream,
is_connected: bool,
is_connected: bool,
} }
impl<T: SendPacket> Stream<T> { impl<T: SendPacket> Stream<T> {
/// Returns a new stream, asynchronously connected to the given address,
/// which forwards incoming packets to the given sender.
/// If an error occurs when connecting, returns an error.
pub fn new<U>(addr_spec: U, sender: T) -> io::Result<Self>
where
U: ToSocketAddrs + fmt::Debug,
{
for sock_addr in addr_spec.to_socket_addrs()? {
if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) {
return Ok(Stream {
parser: Parser::new(),
queue: VecDeque::new(),
sender: sender,
stream: stream,
is_connected: false,
});
}
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot connect to {:?}", addr_spec),
))
/// Returns a new stream, asynchronously connected to the given address,
/// which forwards incoming packets to the given sender.
/// If an error occurs when connecting, returns an error.
pub fn new<U>(addr_spec: U, sender: T) -> io::Result<Self>
where
U: ToSocketAddrs + fmt::Debug,
{
for sock_addr in addr_spec.to_socket_addrs()? {
if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) {
return Ok(Stream {
parser: Parser::new(),
queue: VecDeque::new(),
sender: sender,
stream: stream,
is_connected: false,
});
}
} }
/// Returns a reference to the underlying byte stream, to allow it to be
/// registered with an event loop.
pub fn evented(&self) -> &mio::tcp::TcpStream {
&self.stream
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot connect to {:?}", addr_spec),
))
}
/// Returns a reference to the underlying byte stream, to allow it to be
/// registered with an event loop.
pub fn evented(&self) -> &mio::tcp::TcpStream {
&self.stream
}
/// The stream is ready to be read from.
fn on_readable(&mut self) -> Result<(), String> {
loop {
let mut packet = match self.parser.try_read(&mut self.stream) {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(e) => return Err(format!("Error reading stream: {}", e)),
};
let value = match packet.read_value() {
Ok(value) => value,
Err(e) => return Err(format!("Error parsing packet: {}", e)),
};
if let Err(e) = self.sender.send_packet(value) {
return Err(format!("Error sending parsed packet: {}", e));
}
} }
/// The stream is ready to be read from.
fn on_readable(&mut self) -> Result<(), String> {
loop {
let mut packet = match self.parser.try_read(&mut self.stream) {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(e) => return Err(format!("Error reading stream: {}", e)),
};
let value = match packet.read_value() {
Ok(value) => value,
Err(e) => return Err(format!("Error parsing packet: {}", e)),
};
if let Err(e) = self.sender.send_packet(value) {
return Err(format!("Error sending parsed packet: {}", e));
}
Ok(())
}
/// The stream is ready to be written to.
fn on_writable(&mut self) -> io::Result<()> {
loop {
let mut outbuf = match self.queue.pop_front() {
Some(outbuf) => outbuf,
None => break,
};
let option = outbuf.try_write_to(&mut self.stream)?;
match option {
Some(_) => {
if outbuf.has_remaining() {
self.queue.push_front(outbuf)
}
// Continue looping
} }
Ok(())
}
/// The stream is ready to be written to.
fn on_writable(&mut self) -> io::Result<()> {
loop {
let mut outbuf = match self.queue.pop_front() {
Some(outbuf) => outbuf,
None => break,
};
let option = outbuf.try_write_to(&mut self.stream)?;
match option {
Some(_) => {
if outbuf.has_remaining() {
self.queue.push_front(outbuf)
}
// Continue looping
}
None => {
self.queue.push_front(outbuf);
break;
}
}
None => {
self.queue.push_front(outbuf);
break;
} }
Ok(())
}
} }
Ok(())
}
/// The stream is ready to read, write, or both.
pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent {
#[allow(deprecated)]
if event_set.is_hup() || event_set.is_error() {
return Intent::Done;
}
if event_set.is_readable() {
let result = self.on_readable();
if let Err(e) = result {
error!("Stream input error: {}", e);
return Intent::Done;
}
}
if event_set.is_writable() {
let result = self.on_writable();
if let Err(e) = result {
error!("Stream output error: {}", e);
return Intent::Done;
}
}
// We must have read or written something succesfully if we're here,
// so the stream must be connected.
if !self.is_connected {
// If we weren't already connected, notify the sink.
if let Err(err) = self.sender.notify_open() {
error!("Cannot notify client that stream is open: {}", err);
return Intent::Done;
}
// And record the fact that we are now connected.
self.is_connected = true;
}
/// The stream is ready to read, write, or both.
pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent {
#[allow(deprecated)]
if event_set.is_hup() || event_set.is_error() {
return Intent::Done;
}
if event_set.is_readable() {
let result = self.on_readable();
if let Err(e) = result {
error!("Stream input error: {}", e);
return Intent::Done;
}
}
if event_set.is_writable() {
let result = self.on_writable();
if let Err(e) = result {
error!("Stream output error: {}", e);
return Intent::Done;
}
}
// We're always interested in reading more.
#[allow(deprecated)]
let mut event_set = mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error();
// If there is still stuff to write in the queue, we're interested in
// the socket becoming writable too.
if self.queue.len() > 0 {
event_set = event_set | mio::Ready::writable();
}
// We must have read or written something succesfully if we're here,
// so the stream must be connected.
if !self.is_connected {
// If we weren't already connected, notify the sink.
if let Err(err) = self.sender.notify_open() {
error!("Cannot notify client that stream is open: {}", err);
return Intent::Done;
}
// And record the fact that we are now connected.
self.is_connected = true;
}
Intent::Continue(event_set)
// We're always interested in reading more.
#[allow(deprecated)]
let mut event_set =
mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error();
// If there is still stuff to write in the queue, we're interested in
// the socket becoming writable too.
if self.queue.len() > 0 {
event_set = event_set | mio::Ready::writable();
} }
/// The stream has been notified.
pub fn on_notify<V>(&mut self, payload: &V) -> Intent
where
V: WriteToPacket,
{
let mut packet = MutPacket::new();
let result = packet.write_value(payload);
if let Err(e) = result {
error!("Error writing payload to packet: {}", e);
return Intent::Done;
}
self.queue.push_back(OutBuf::from(packet.into_bytes()));
Intent::Continue(mio::Ready::readable() | mio::Ready::writable())
Intent::Continue(event_set)
}
/// The stream has been notified.
pub fn on_notify<V>(&mut self, payload: &V) -> Intent
where
V: WriteToPacket,
{
let mut packet = MutPacket::new();
let result = packet.write_value(payload);
if let Err(e) = result {
error!("Error writing payload to packet: {}", e);
return Intent::Done;
} }
self.queue.push_back(OutBuf::from(packet.into_bytes()));
Intent::Continue(mio::Ready::readable() | mio::Ready::writable())
}
} }

+ 46
- 44
proto/tests/connect.rs View File

@ -3,79 +3,81 @@ use tokio::io;
use tokio::net; use tokio::net;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use solstice_proto::server::{Client, Credentials, ServerResponse, UserStatusRequest};
use solstice_proto::server::{
Client, Credentials, ServerResponse, UserStatusRequest,
};
// Enable capturing logs in tests. // Enable capturing logs in tests.
fn init() { fn init() {
let _ = env_logger::builder().is_test(true).try_init();
let _ = env_logger::builder().is_test(true).try_init();
} }
async fn connect() -> io::Result<net::TcpStream> { async fn connect() -> io::Result<net::TcpStream> {
net::TcpStream::connect("server.slsknet.org:2242").await
net::TcpStream::connect("server.slsknet.org:2242").await
} }
fn make_user_name(test_name: &str) -> String { fn make_user_name(test_name: &str) -> String {
format!("st_{}", test_name)
format!("st_{}", test_name)
} }
fn make_credentials(user_name: String) -> Credentials { fn make_credentials(user_name: String) -> Credentials {
let password = "abcdefgh".to_string();
Credentials::new(user_name, password).unwrap()
let password = "abcdefgh".to_string();
Credentials::new(user_name, password).unwrap()
} }
#[tokio::test] #[tokio::test]
async fn integration_connect() { async fn integration_connect() {
init();
init();
let stream = connect().await.unwrap();
let stream = connect().await.unwrap();
let credentials = make_credentials(make_user_name("connect"));
let channel = Client::new(stream).login(credentials).await.unwrap();
let credentials = make_credentials(make_user_name("connect"));
let channel = Client::new(stream).login(credentials).await.unwrap();
let inbound = channel.run(stream::pending());
tokio::pin!(inbound);
let inbound = channel.run(stream::pending());
tokio::pin!(inbound);
assert!(inbound.next().await.is_some());
assert!(inbound.next().await.is_some());
} }
#[tokio::test] #[tokio::test]
async fn integration_check_user_status() { async fn integration_check_user_status() {
init();
init();
let stream = connect().await.unwrap();
let stream = connect().await.unwrap();
let user_name = make_user_name("check_user_status");
let credentials = make_credentials(user_name.clone());
let channel = Client::new(stream).login(credentials).await.unwrap();
let user_name = make_user_name("check_user_status");
let credentials = make_credentials(user_name.clone());
let channel = Client::new(stream).login(credentials).await.unwrap();
let (request_tx, mut request_rx) = mpsc::channel(1);
let (request_tx, mut request_rx) = mpsc::channel(1);
let outbound = async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
let outbound = async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
};
let inbound = channel.run(outbound);
tokio::pin!(inbound);
request_tx
.send(
UserStatusRequest {
user_name: user_name.clone(),
} }
};
let inbound = channel.run(outbound);
tokio::pin!(inbound);
request_tx
.send(
UserStatusRequest {
user_name: user_name.clone(),
}
.into(),
)
.await
.unwrap();
while let Some(result) = inbound.next().await {
let response = result.unwrap();
if let ServerResponse::UserStatusResponse(response) = response {
assert_eq!(response.user_name, user_name);
return;
}
.into(),
)
.await
.unwrap();
while let Some(result) = inbound.next().await {
let response = result.unwrap();
if let ServerResponse::UserStatusResponse(response) = response {
assert_eq!(response.user_name, user_name);
return;
} }
}
unreachable!();
unreachable!();
} }

client/rustfmt.toml → rustfmt.toml View File


Loading…
Cancel
Save