Browse Source

Format and remove warnings from proto.

wip
Titouan Rigoudy 4 years ago
parent
commit
d432d58bb3
16 changed files with 4454 additions and 4689 deletions
  1. +323
    -332
      proto/src/core/frame.rs
  2. +88
    -88
      proto/src/core/prefix.rs
  3. +2
    -2
      proto/src/core/u32.rs
  4. +66
    -84
      proto/src/core/user.rs
  5. +675
    -713
      proto/src/core/value.rs
  6. +272
    -281
      proto/src/handler.rs
  7. +251
    -251
      proto/src/packet.rs
  8. +146
    -155
      proto/src/peer/message.rs
  9. +237
    -247
      proto/src/server/client.rs
  10. +105
    -106
      proto/src/server/credentials.rs
  11. +599
    -633
      proto/src/server/request.rs
  12. +1187
    -1288
      proto/src/server/response.rs
  13. +283
    -287
      proto/src/server/testing.rs
  14. +9
    -9
      proto/src/server/version.rs
  15. +165
    -166
      proto/src/stream.rs
  16. +46
    -47
      proto/tests/connect.rs

+ 323
- 332
proto/src/core/frame.rs View File

@ -14,385 +14,376 @@ use tokio::net::TcpStream;
use super::prefix::Prefixer;
use super::u32::{decode_u32, U32_BYTE_LEN};
use super::value::{
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
};
#[derive(Debug, Error, PartialEq)]
pub enum FrameEncodeError {
#[error("encoded value length {length} is too large")]
ValueTooLarge {
/// The length of the encoded value.
length: usize,
},
#[error("failed to encode value: {0}")]
ValueEncodeError(#[from] ValueEncodeError),
#[error("encoded value length {length} is too large")]
ValueTooLarge {
/// The length of the encoded value.
length: usize,
},
#[error("failed to encode value: {0}")]
ValueEncodeError(#[from] ValueEncodeError),
}
impl From<FrameEncodeError> for io::Error {
fn from(error: FrameEncodeError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, format!("{}", error))
}
fn from(error: FrameEncodeError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, format!("{}", error))
}
}
/// Encodes entire protocol frames containing values of type `T`.
#[derive(Debug)]
pub struct FrameEncoder<T: ?Sized> {
phantom: PhantomData<T>,
phantom: PhantomData<T>,
}
impl<T: ValueEncode + ?Sized> FrameEncoder<T> {
pub fn new() -> Self {
Self {
phantom: PhantomData,
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
}
}
pub fn encode_to(
&mut self,
value: &T,
buffer: &mut BytesMut,
) -> Result<(), FrameEncodeError> {
let mut prefixer = Prefixer::new(buffer);
pub fn encode_to(&mut self, value: &T, buffer: &mut BytesMut) -> Result<(), FrameEncodeError> {
let mut prefixer = Prefixer::new(buffer);
ValueEncoder::new(prefixer.suffix_mut()).encode(value)?;
ValueEncoder::new(prefixer.suffix_mut()).encode(value)?;
if let Err(prefixer) = prefixer.finalize() {
return Err(FrameEncodeError::ValueTooLarge {
length: prefixer.suffix().len(),
});
}
if let Err(prefixer) = prefixer.finalize() {
return Err(FrameEncodeError::ValueTooLarge {
length: prefixer.suffix().len(),
});
}
Ok(())
}
Ok(())
}
}
/// Decodes entire protocol frames containing values of type `T`.
#[derive(Debug)]
pub struct FrameDecoder<T> {
// Only here to enable parameterizing `Decoder` by `T`.
phantom: PhantomData<T>,
// Only here to enable parameterizing `Decoder` by `T`.
phantom: PhantomData<T>,
}
impl<T: ValueDecode> FrameDecoder<T> {
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
}
/// Attempts to decode an entire frame from the given buffer.
///
/// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes
/// have been split off from the left of `bytes`.
///
/// Returns `Ok(None)` if not enough bytes are available to decode an entire
/// frame yet, in which case `bytes` is untouched.
///
/// Returns an error if the length prefix or the framed value are malformed,
/// in which case `bytes` is untouched.
pub fn decode_from(
&mut self,
bytes: &mut BytesMut,
) -> Result<Option<T>, ValueDecodeError> {
if bytes.len() < U32_BYTE_LEN {
return Ok(None); // Not enough bytes yet.
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
}
// Split the prefix off. After this:
//
// | bytes (len 4) | suffix |
//
// NOTE: This method would be simpler if we could use split_to() instead
// here such that `bytes` contained the suffix. At the end, we would not
// have to replace `bytes` with `suffix`. However, that would require
// calling `prefix.unsplit(*bytes)`, and that does not work since
// `bytes` is only borrowed, and unsplit() takes its argument by value.
let mut suffix = bytes.split_off(U32_BYTE_LEN);
// unwrap() cannot panic because `bytes` is of the exact right length.
let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap();
let length = decode_u32(array) as usize;
if suffix.len() < length {
// Re-assemble `bytes` as it first was.
bytes.unsplit(suffix);
return Ok(None); // Not enough bytes yet.
/// Attempts to decode an entire frame from the given buffer.
///
/// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes
/// have been split off from the left of `bytes`.
///
/// Returns `Ok(None)` if not enough bytes are available to decode an entire
/// frame yet, in which case `bytes` is untouched.
///
/// Returns an error if the length prefix or the framed value are malformed,
/// in which case `bytes` is untouched.
pub fn decode_from(&mut self, bytes: &mut BytesMut) -> Result<Option<T>, ValueDecodeError> {
if bytes.len() < U32_BYTE_LEN {
return Ok(None); // Not enough bytes yet.
}
// Split the prefix off. After this:
//
// | bytes (len 4) | suffix |
//
// NOTE: This method would be simpler if we could use split_to() instead
// here such that `bytes` contained the suffix. At the end, we would not
// have to replace `bytes` with `suffix`. However, that would require
// calling `prefix.unsplit(*bytes)`, and that does not work since
// `bytes` is only borrowed, and unsplit() takes its argument by value.
let mut suffix = bytes.split_off(U32_BYTE_LEN);
// unwrap() cannot panic because `bytes` is of the exact right length.
let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap();
let length = decode_u32(array) as usize;
if suffix.len() < length {
// Re-assemble `bytes` as it first was.
bytes.unsplit(suffix);
return Ok(None); // Not enough bytes yet.
}
// Split off the right amount of bytes from the buffer. After this:
//
// | bytes (len 4) | contents | suffix |
//
let mut contents = suffix.split_to(length);
// Attempt to decode the value.
let item = match ValueDecoder::new(&contents).decode() {
Ok(item) => item,
Err(error) => {
// Re-assemble `bytes` as it first was.
contents.unsplit(suffix);
bytes.unsplit(contents);
return Err(error);
}
};
// Remove the decoded bytes from the left of `bytes`.
*bytes = suffix;
Ok(Some(item))
}
// Split off the right amount of bytes from the buffer. After this:
//
// | bytes (len 4) | contents | suffix |
//
let mut contents = suffix.split_to(length);
// Attempt to decode the value.
let item = match ValueDecoder::new(&contents).decode() {
Ok(item) => item,
Err(error) => {
// Re-assemble `bytes` as it first was.
contents.unsplit(suffix);
bytes.unsplit(contents);
return Err(error);
}
};
// Remove the decoded bytes from the left of `bytes`.
*bytes = suffix;
Ok(Some(item))
}
}
#[derive(Debug)]
pub struct FrameStream<ReadFrame, WriteFrame: ?Sized> {
stream: TcpStream,
stream: TcpStream,
read_buffer: BytesMut,
read_buffer: BytesMut,
decoder: FrameDecoder<ReadFrame>,
encoder: FrameEncoder<WriteFrame>,
decoder: FrameDecoder<ReadFrame>,
encoder: FrameEncoder<WriteFrame>,
}
impl<ReadFrame, WriteFrame> FrameStream<ReadFrame, WriteFrame>
where
ReadFrame: ValueDecode,
WriteFrame: ValueEncode + ?Sized,
ReadFrame: ValueDecode,
WriteFrame: ValueEncode + ?Sized,
{
pub fn new(stream: TcpStream) -> Self {
FrameStream {
stream,
read_buffer: BytesMut::new(),
decoder: FrameDecoder::new(),
encoder: FrameEncoder::new(),
pub fn new(stream: TcpStream) -> Self {
FrameStream {
stream,
read_buffer: BytesMut::new(),
decoder: FrameDecoder::new(),
encoder: FrameEncoder::new(),
}
}
}
/// Attempts to read the next frame from the underlying byte stream.
///
/// Returns `Ok(Some(frame))` on success.
/// Returns `Ok(None)` if the stream has reached the end-of-file event.
///
/// Returns an error if reading from the stream returned an error or if an
/// invalid frame was received.
pub async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(Some(frame));
}
if self.stream.read_buf(&mut self.read_buffer).await? == 0 {
return Ok(None);
}
/// Attempts to read the next frame from the underlying byte stream.
///
/// Returns `Ok(Some(frame))` on success.
/// Returns `Ok(None)` if the stream has reached the end-of-file event.
///
/// Returns an error if reading from the stream returned an error or if an
/// invalid frame was received.
pub async fn read(&mut self) -> io::Result<Option<ReadFrame>> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(Some(frame));
}
if self.stream.read_buf(&mut self.read_buffer).await? == 0 {
return Ok(None);
}
}
}
}
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.stream.write_all(bytes.as_ref()).await
}
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.stream.write_all(bytes.as_ref()).await
}
pub async fn shutdown(&mut self) -> io::Result<()> {
self.stream.shutdown().await
}
pub async fn shutdown(&mut self) -> io::Result<()> {
self.stream.shutdown().await
}
}
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder, FrameStream};
// Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
#[test]
fn encode_u32() {
let mut bytes = BytesMut::new();
FrameEncoder::new()
.encode_to(&U32_1337, &mut bytes)
.unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
]
);
}
#[test]
fn encode_appends() {
let mut bytes = BytesMut::new();
let mut encoder = FrameEncoder::new();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 0, 0, 0, // Repeated.
1, 3, 3, 7,
]
);
}
#[test]
fn encode_vec() {
let v: Vec<u32> = vec![1, 3, 3, 7];
let mut bytes = BytesMut::new();
FrameEncoder::new().encode_to(&v, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
]
);
}
#[test]
fn decode_not_enough_data_for_prefix() {
let initial_bytes = vec![
4, 0, 0, // Incomplete 32-bit length prefix.
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_not_enough_data_for_contents() {
let initial_bytes = vec![
4, 0, 0, 0, // Length 4.
1, 2, 3, // But there are only 3 bytes!
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_u32() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, Some(U32_1337));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn decode_vec() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
let expected_value: Vec<u32> = vec![1, 3, 3, 7];
assert_eq!(value, Some(expected_value));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn roundtrip() {
let value: Vec<String> = vec![
"apples".to_string(), //
"bananas".to_string(), //
"oranges".to_string(), //
"and cheese!".to_string(), //
];
let mut buffer = BytesMut::new();
FrameEncoder::new().encode_to(&value, &mut buffer).unwrap();
let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap();
assert_eq!(decoded, Some(value));
assert_eq!(buffer, vec![]);
}
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, Vec<u32>>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write(&vec![0; 10 * 4096]).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<Vec<u32>, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some(vec![0; 10 * 4096]));
server_task.await.unwrap();
}
use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder, FrameStream};
// Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
#[test]
fn encode_u32() {
let mut bytes = BytesMut::new();
FrameEncoder::new()
.encode_to(&U32_1337, &mut bytes)
.unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
]
);
}
#[test]
fn encode_appends() {
let mut bytes = BytesMut::new();
let mut encoder = FrameEncoder::new();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 0, 0, 0, // Repeated.
1, 3, 3, 7,
]
);
}
#[test]
fn encode_vec() {
let v: Vec<u32> = vec![1, 3, 3, 7];
let mut bytes = BytesMut::new();
FrameEncoder::new().encode_to(&v, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
]
);
}
#[test]
fn decode_not_enough_data_for_prefix() {
let initial_bytes = vec![
4, 0, 0, // Incomplete 32-bit length prefix.
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_not_enough_data_for_contents() {
let initial_bytes = vec![
4, 0, 0, 0, // Length 4.
1, 2, 3, // But there are only 3 bytes!
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_u32() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, Some(U32_1337));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn decode_vec() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
let expected_value: Vec<u32> = vec![1, 3, 3, 7];
assert_eq!(value, Some(expected_value));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn roundtrip() {
let value: Vec<String> = vec![
"apples".to_string(), //
"bananas".to_string(), //
"oranges".to_string(), //
"and cheese!".to_string(), //
];
let mut buffer = BytesMut::new();
FrameEncoder::new().encode_to(&value, &mut buffer).unwrap();
let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap();
assert_eq!(decoded, Some(value));
assert_eq!(buffer, vec![]);
}
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write("pong").await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some("pong".to_string()));
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, Vec<u32>>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), Some("ping".to_string()));
frame_stream.write(&vec![0; 10 * 4096]).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<Vec<u32>, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), Some(vec![0; 10 * 4096]));
server_task.await.unwrap();
}
}

+ 88
- 88
proto/src/core/prefix.rs View File

@ -11,111 +11,111 @@ use crate::core::u32::{encode_u32, U32_BYTE_LEN};
/// know the length ahead of encoding time.
#[derive(Debug)]
pub struct Prefixer<'a> {
/// The prefix buffer.
///
/// The length of the suffix buffer is written to the end of this buffer
/// when the prefixer is finalized.
///
/// Contains any bytes with which this prefixer was constructed.
prefix: &'a mut BytesMut,
/// The suffix buffer.
///
/// This is the buffer into which data is written before finalization.
suffix: BytesMut,
/// The prefix buffer.
///
/// The length of the suffix buffer is written to the end of this buffer
/// when the prefixer is finalized.
///
/// Contains any bytes with which this prefixer was constructed.
prefix: &'a mut BytesMut,
/// The suffix buffer.
///
/// This is the buffer into which data is written before finalization.
suffix: BytesMut,
}
impl Prefixer<'_> {
/// Constructs a prefixer for easily appending a length prefixed value to
/// the given buffer.
pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> {
// Reserve some space fot the prefix, but don't write it yet.
buffer.reserve(U32_BYTE_LEN);
// Split off the suffix, into which bytes will be written.
let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN);
Prefixer {
prefix: buffer,
suffix: suffix,
/// Constructs a prefixer for easily appending a length prefixed value to
/// the given buffer.
pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> {
// Reserve some space fot the prefix, but don't write it yet.
buffer.reserve(U32_BYTE_LEN);
// Split off the suffix, into which bytes will be written.
let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN);
Prefixer {
prefix: buffer,
suffix: suffix,
}
}
/// Returns a reference to the buffer into which data is written.
pub fn suffix(&self) -> &BytesMut {
&self.suffix
}
/// Returns a mutable reference to a buffer into which data can be written.
pub fn suffix_mut(&mut self) -> &mut BytesMut {
&mut self.suffix
}
/// Returns a buffer containing the original data passed at construction
/// time, to which a length-prefixed value is appended. The value itself is
/// the data written into the buffer returned by `get_mut()`.
///
/// Returns `Ok(length)` if successful, in which case the length of the
/// suffix is `length`.
///
/// Returns `Err(self)` if the length of the suffix is too large to store as
/// a prefix.
pub fn finalize(self) -> Result<u32, Self> {
// Check that the suffix's length is not too large.
let length = self.suffix.len();
let length_u32 = match u32::try_from(length) {
Ok(value) => value,
Err(_) => return Err(self),
};
// Write the prefix.
self.prefix.extend_from_slice(&encode_u32(length_u32));
// Join the prefix and suffix back again. Because `self.prefix` is
// private, we are sure that this is O(1).
self.prefix.unsplit(self.suffix);
Ok(length_u32)
}
}
/// Returns a reference to the buffer into which data is written.
pub fn suffix(&self) -> &BytesMut {
&self.suffix
}
/// Returns a mutable reference to a buffer into which data can be written.
pub fn suffix_mut(&mut self) -> &mut BytesMut {
&mut self.suffix
}
/// Returns a buffer containing the original data passed at construction
/// time, to which a length-prefixed value is appended. The value itself is
/// the data written into the buffer returned by `get_mut()`.
///
/// Returns `Ok(length)` if successful, in which case the length of the
/// suffix is `length`.
///
/// Returns `Err(self)` if the length of the suffix is too large to store as
/// a prefix.
pub fn finalize(self) -> Result<u32, Self> {
// Check that the suffix's length is not too large.
let length = self.suffix.len();
let length_u32 = match u32::try_from(length) {
Ok(value) => value,
Err(_) => return Err(self),
};
// Write the prefix.
self.prefix.extend_from_slice(&encode_u32(length_u32));
// Join the prefix and suffix back again. Because `self.prefix` is
// private, we are sure that this is O(1).
self.prefix.unsplit(self.suffix);
Ok(length_u32)
}
}
#[cfg(test)]
mod tests {
use super::Prefixer;
use super::Prefixer;
use std::convert::TryInto;
use std::convert::TryInto;
use bytes::{BufMut, BytesMut};
use bytes::{BufMut, BytesMut};
use crate::core::u32::{decode_u32, U32_BYTE_LEN};
use crate::core::u32::{decode_u32, U32_BYTE_LEN};
#[test]
fn finalize_empty() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
#[test]
fn finalize_empty() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
Prefixer::new(&mut buffer).finalize().unwrap();
Prefixer::new(&mut buffer).finalize().unwrap();
assert_eq!(buffer.len(), U32_BYTE_LEN + 1);
let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap();
assert_eq!(decode_u32(array), 0);
}
assert_eq!(buffer.len(), U32_BYTE_LEN + 1);
let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap();
assert_eq!(decode_u32(array), 0);
}
#[test]
fn finalize_ok() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
#[test]
fn finalize_ok() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
let mut prefixer = Prefixer::new(&mut buffer);
let mut prefixer = Prefixer::new(&mut buffer);
prefixer.suffix_mut().extend_from_slice(&[0; 42]);
prefixer.suffix_mut().extend_from_slice(&[0; 42]);
prefixer.finalize().unwrap();
prefixer.finalize().unwrap();
// 1 junk prefix byte, length prefix, 42 bytes of value.
assert_eq!(buffer.len(), U32_BYTE_LEN + 43);
let prefix = &buffer[1..U32_BYTE_LEN + 1];
let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap();
assert_eq!(decode_u32(array), 42);
}
// 1 junk prefix byte, length prefix, 42 bytes of value.
assert_eq!(buffer.len(), U32_BYTE_LEN + 43);
let prefix = &buffer[1..U32_BYTE_LEN + 1];
let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap();
assert_eq!(decode_u32(array), 42);
}
}

+ 2
- 2
proto/src/core/u32.rs View File

@ -8,10 +8,10 @@ pub const U32_BYTE_LEN: usize = 4;
/// Returns the byte representatio of the given integer value.
pub fn encode_u32(value: u32) -> [u8; U32_BYTE_LEN] {
value.to_le_bytes()
value.to_le_bytes()
}
/// Returns the integer value corresponding to the given bytes.
pub fn decode_u32(bytes: [u8; U32_BYTE_LEN]) -> u32 {
u32::from_le_bytes(bytes)
u32::from_le_bytes(bytes)
}

+ 66
- 84
proto/src/core/user.rs View File

@ -1,114 +1,96 @@
use std::io;
use crate::core::value::{
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
};
use crate::{
MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket,
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
};
use crate::{MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket};
const STATUS_OFFLINE: u32 = 1;
const STATUS_AWAY: u32 = 2;
const STATUS_ONLINE: u32 = 3;
/// This enumeration is the list of possible user statuses.
#[derive(
Clone,
Copy,
Debug,
Eq,
Ord,
PartialEq,
PartialOrd,
RustcDecodable,
RustcEncodable,
)]
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable)]
pub enum UserStatus {
/// The user if offline.
Offline,
/// The user is connected, but AFK.
Away,
/// The user is present.
Online,
/// The user if offline.
Offline,
/// The user is connected, but AFK.
Away,
/// The user is present.
Online,
}
impl ReadFromPacket for UserStatus {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n: u32 = packet.read_value()?;
match n {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(PacketReadError::InvalidUserStatusError(n)),
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n: u32 = packet.read_value()?;
match n {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(PacketReadError::InvalidUserStatusError(n)),
}
}
}
}
impl WriteToPacket for UserStatus {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
let n = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
packet.write_value(&n)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
let n = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
packet.write_value(&n)?;
Ok(())
}
}
impl ValueEncode for UserStatus {
fn encode_to(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
let value = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
encoder.encode_u32(value)
}
fn encode_to(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
let value = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
encoder.encode_u32(value)
}
}
impl ValueDecode for UserStatus {
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let value: u32 = decoder.decode()?;
match value {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(ValueDecodeError::InvalidData {
value_name: "user status".to_string(),
cause: format!("unknown value {}", value),
position: position,
}),
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let value: u32 = decoder.decode()?;
match value {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(ValueDecodeError::InvalidData {
value_name: "user status".to_string(),
cause: format!("unknown value {}", value),
position: position,
}),
}
}
}
}
/// This structure contains the last known information about a fellow user.
#[derive(
Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable,
)]
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable)]
pub struct User {
/// The name of the user.
pub name: String,
/// The last known status of the user.
pub status: UserStatus,
/// The average upload speed of the user.
pub average_speed: usize,
/// ??? Nicotine calls it downloadnum.
pub num_downloads: usize,
/// ??? Unknown field.
pub unknown: usize,
/// The number of files this user shares.
pub num_files: usize,
/// The number of folders this user shares.
pub num_folders: usize,
/// The number of free download slots of this user.
pub num_free_slots: usize,
/// The user's country code.
pub country: String,
/// The name of the user.
pub name: String,
/// The last known status of the user.
pub status: UserStatus,
/// The average upload speed of the user.
pub average_speed: usize,
/// ??? Nicotine calls it downloadnum.
pub num_downloads: usize,
/// ??? Unknown field.
pub unknown: usize,
/// The number of files this user shares.
pub num_files: usize,
/// The number of folders this user shares.
pub num_folders: usize,
/// The number of free download slots of this user.
pub num_free_slots: usize,
/// The user's country code.
pub country: String,
}

+ 675
- 713
proto/src/core/value.rs
File diff suppressed because it is too large
View File


+ 272
- 281
proto/src/handler.rs View File

@ -31,17 +31,17 @@ const LISTEN_TOKEN: usize = config::MAX_PEERS + 1;
#[derive(Debug)]
pub enum Request {
PeerConnect(usize, net::Ipv4Addr, u16),
PeerMessage(usize, peer::Message),
ServerRequest(ServerRequest),
PeerConnect(usize, net::Ipv4Addr, u16),
PeerMessage(usize, peer::Message),
ServerRequest(ServerRequest),
}
#[derive(Debug)]
pub enum Response {
PeerConnectionClosed(usize),
PeerConnectionOpen(usize),
PeerMessage(usize, peer::Message),
ServerResponse(ServerResponse),
PeerConnectionClosed(usize),
PeerConnectionOpen(usize),
PeerMessage(usize, peer::Message),
ServerResponse(ServerResponse),
}
/*========================*
@ -51,16 +51,16 @@ pub enum Response {
pub struct ServerResponseSender(crossbeam_channel::Sender<Response>);
impl SendPacket for ServerResponseSender {
type Value = ServerResponse;
type Error = crossbeam_channel::SendError<Response>;
type Value = ServerResponse;
type Error = crossbeam_channel::SendError<Response>;
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.0.send(Response::ServerResponse(value))
}
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.0.send(Response::ServerResponse(value))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
Ok(())
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
Ok(())
}
}
/*======================*
@ -68,21 +68,21 @@ impl SendPacket for ServerResponseSender {
*======================*/
pub struct PeerResponseSender {
sender: crossbeam_channel::Sender<Response>,
peer_id: usize,
sender: crossbeam_channel::Sender<Response>,
peer_id: usize,
}
impl SendPacket for PeerResponseSender {
type Value = peer::Message;
type Error = crossbeam_channel::SendError<Response>;
type Value = peer::Message;
type Error = crossbeam_channel::SendError<Response>;
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.sender.send(Response::PeerMessage(self.peer_id, value))
}
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.sender.send(Response::PeerMessage(self.peer_id, value))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
self.sender.send(Response::PeerConnectionOpen(self.peer_id))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
self.sender.send(Response::PeerConnectionOpen(self.peer_id))
}
}
/*=========*
@ -92,302 +92,293 @@ impl SendPacket for PeerResponseSender {
/// This struct handles all the soulseek connections, to the server and to
/// peers.
struct Handler {
server_stream: Stream<ServerResponseSender>,
server_stream: Stream<ServerResponseSender>,
peer_streams: slab::Slab<Stream<PeerResponseSender>, usize>,
peer_streams: slab::Slab<Stream<PeerResponseSender>, usize>,
listener: mio::tcp::TcpListener,
listener: mio::tcp::TcpListener,
client_tx: crossbeam_channel::Sender<Response>,
client_tx: crossbeam_channel::Sender<Response>,
}
fn listener_bind<U>(addr_spec: U) -> io::Result<mio::tcp::TcpListener>
where
U: ToSocketAddrs + fmt::Debug,
U: ToSocketAddrs + fmt::Debug,
{
for socket_addr in addr_spec.to_socket_addrs()? {
if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) {
return Ok(listener);
for socket_addr in addr_spec.to_socket_addrs()? {
if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) {
return Ok(listener);
}
}
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot bind to {:?}", addr_spec),
))
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot bind to {:?}", addr_spec),
))
}
impl Handler {
#[allow(deprecated)]
fn new(
client_tx: crossbeam_channel::Sender<Response>,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> io::Result<Self> {
let host = config::SERVER_HOST;
let port = config::SERVER_PORT;
let server_stream =
Stream::new((host, port), ServerResponseSender(client_tx.clone()))?;
info!("Connected to server at {}:{}", host, port);
let listener = listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?;
info!(
"Listening for connections on {}:{}",
config::LISTEN_HOST,
config::LISTEN_PORT
);
event_loop.register(
server_stream.evented(),
mio::Token(SERVER_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
event_loop.register(
&listener,
mio::Token(LISTEN_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
Ok(Handler {
server_stream: server_stream,
peer_streams: slab::Slab::new(config::MAX_PEERS),
listener: listener,
client_tx: client_tx,
})
}
#[allow(deprecated)]
fn connect_to_peer(
&mut self,
peer_id: usize,
ip: net::Ipv4Addr,
port: u16,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> Result<(), String> {
let vacant_entry = match self.peer_streams.entry(peer_id) {
None => return Err("id out of range".to_string()),
Some(slab::Entry::Occupied(_occupied_entry)) => {
return Err("id already taken".to_string());
}
Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry,
};
info!("Opening peer connection {} to {}:{}", peer_id, ip, port);
let sender = PeerResponseSender {
sender: self.client_tx.clone(),
peer_id: peer_id,
};
let peer_stream = match Stream::new((ip, port), sender) {
Ok(peer_stream) => peer_stream,
Err(err) => return Err(format!("i/o error: {}", err)),
};
event_loop
.register(
peer_stream.evented(),
mio::Token(peer_id),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
vacant_entry.insert(peer_stream);
Ok(())
}
#[allow(deprecated)]
fn process_server_intent(
&mut self,
intent: Intent,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
error!("Server connection closed");
// TODO notify client and shut down
}
Intent::Continue(event_set) => {
event_loop
.reregister(
self.server_stream.evented(),
#[allow(deprecated)]
fn new(
client_tx: crossbeam_channel::Sender<Response>,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> io::Result<Self> {
let host = config::SERVER_HOST;
let port = config::SERVER_PORT;
let server_stream = Stream::new((host, port), ServerResponseSender(client_tx.clone()))?;
info!("Connected to server at {}:{}", host, port);
let listener = listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?;
info!(
"Listening for connections on {}:{}",
config::LISTEN_HOST,
config::LISTEN_PORT
);
event_loop.register(
server_stream.evented(),
mio::Token(SERVER_TOKEN),
event_set,
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
)?;
event_loop.register(
&listener,
mio::Token(LISTEN_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
Ok(Handler {
server_stream: server_stream,
peer_streams: slab::Slab::new(config::MAX_PEERS),
listener: listener,
client_tx: client_tx,
})
}
}
#[allow(deprecated)]
fn process_peer_intent(
&mut self,
intent: Intent,
token: mio::Token,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
self.peer_streams.remove(token.0);
self
.client_tx
.send(Response::PeerConnectionClosed(token.0))
.unwrap();
}
Intent::Continue(event_set) => {
if let Some(peer_stream) = self.peer_streams.get_mut(token.0) {
event_loop
.reregister(
peer_stream.evented(),
token,
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
#[allow(deprecated)]
fn connect_to_peer(
&mut self,
peer_id: usize,
ip: net::Ipv4Addr,
port: u16,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> Result<(), String> {
let vacant_entry = match self.peer_streams.entry(peer_id) {
None => return Err("id out of range".to_string()),
Some(slab::Entry::Occupied(_occupied_entry)) => {
return Err("id already taken".to_string());
}
Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry,
};
info!("Opening peer connection {} to {}:{}", peer_id, ip, port);
let sender = PeerResponseSender {
sender: self.client_tx.clone(),
peer_id: peer_id,
};
let peer_stream = match Stream::new((ip, port), sender) {
Ok(peer_stream) => peer_stream,
Err(err) => return Err(format!("i/o error: {}", err)),
};
event_loop
.register(
peer_stream.evented(),
mio::Token(peer_id),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
vacant_entry.insert(peer_stream);
Ok(())
}
#[allow(deprecated)]
fn process_server_intent(
&mut self,
intent: Intent,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
error!("Server connection closed");
// TODO notify client and shut down
}
Intent::Continue(event_set) => {
event_loop
.reregister(
self.server_stream.evented(),
mio::Token(SERVER_TOKEN),
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
}
}
#[allow(deprecated)]
fn process_peer_intent(
&mut self,
intent: Intent,
token: mio::Token,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
self.peer_streams.remove(token.0);
self.client_tx
.send(Response::PeerConnectionClosed(token.0))
.unwrap();
}
Intent::Continue(event_set) => {
if let Some(peer_stream) = self.peer_streams.get_mut(token.0) {
event_loop
.reregister(
peer_stream.evented(),
token,
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
}
}
}
}
}
}
#[allow(deprecated)]
impl mio::deprecated::Handler for Handler {
type Timeout = ();
type Message = Request;
fn ready(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
token: mio::Token,
event_set: mio::Ready,
) {
match token {
mio::Token(LISTEN_TOKEN) => {
if event_set.is_readable() {
// A peer wants to connect to us.
match self.listener.accept() {
Ok((_sock, addr)) => {
// TODO add it to peer streams
info!("Peer connection accepted from {}", addr);
type Timeout = ();
type Message = Request;
fn ready(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
token: mio::Token,
event_set: mio::Ready,
) {
match token {
mio::Token(LISTEN_TOKEN) => {
if event_set.is_readable() {
// A peer wants to connect to us.
match self.listener.accept() {
Ok((_sock, addr)) => {
// TODO add it to peer streams
info!("Peer connection accepted from {}", addr);
}
Err(err) => {
error!("Cannot accept peer connection: {}", err);
}
}
}
event_loop
.reregister(
&self.listener,
token,
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
Err(err) => {
error!("Cannot accept peer connection: {}", err);
mio::Token(SERVER_TOKEN) => {
let intent = self.server_stream.on_ready(event_set);
self.process_server_intent(intent, event_loop);
}
mio::Token(peer_id) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_ready(event_set),
None => unreachable!("Unknown peer {} is ready", peer_id),
};
self.process_peer_intent(intent, token, event_loop);
}
}
}
event_loop
.reregister(
&self.listener,
token,
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
}
mio::Token(SERVER_TOKEN) => {
let intent = self.server_stream.on_ready(event_set);
self.process_server_intent(intent, event_loop);
}
fn notify(&mut self, event_loop: &mut mio::deprecated::EventLoop<Self>, request: Request) {
match request {
Request::PeerConnect(peer_id, ip, port) => {
if let Err(err) = self.connect_to_peer(peer_id, ip, port, event_loop) {
error!(
"Cannot open peer connection {} to {}:{}: {}",
peer_id, ip, port, err
);
self.client_tx
.send(Response::PeerConnectionClosed(peer_id))
.unwrap();
}
}
mio::Token(peer_id) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_ready(event_set),
Request::PeerMessage(peer_id, message) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_notify(&message),
None => {
error!(
"Cannot send peer message {:?}: unknown id {}",
message, peer_id
);
return;
}
};
self.process_peer_intent(intent, mio::Token(peer_id), event_loop);
}
None => unreachable!("Unknown peer {} is ready", peer_id),
};
self.process_peer_intent(intent, token, event_loop);
}
}
}
fn notify(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
request: Request,
) {
match request {
Request::PeerConnect(peer_id, ip, port) => {
if let Err(err) = self.connect_to_peer(peer_id, ip, port, event_loop) {
error!(
"Cannot open peer connection {} to {}:{}: {}",
peer_id, ip, port, err
);
self
.client_tx
.send(Response::PeerConnectionClosed(peer_id))
.unwrap();
Request::ServerRequest(server_request) => {
let intent = self.server_stream.on_notify(&server_request);
self.process_server_intent(intent, event_loop);
}
}
}
Request::PeerMessage(peer_id, message) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_notify(&message),
None => {
error!(
"Cannot send peer message {:?}: unknown id {}",
message, peer_id
);
return;
}
};
self.process_peer_intent(intent, mio::Token(peer_id), event_loop);
}
Request::ServerRequest(server_request) => {
let intent = self.server_stream.on_notify(&server_request);
self.process_server_intent(intent, event_loop);
}
}
}
}
#[allow(deprecated)]
pub type Sender = mio::deprecated::Sender<Request>;
pub struct Agent {
#[allow(deprecated)]
event_loop: mio::deprecated::EventLoop<Handler>,
handler: Handler,
#[allow(deprecated)]
event_loop: mio::deprecated::EventLoop<Handler>,
handler: Handler,
}
impl Agent {
pub fn new(
client_tx: crossbeam_channel::Sender<Response>,
) -> io::Result<Self> {
// Create the event loop.
#[allow(deprecated)]
let mut event_loop = mio::deprecated::EventLoop::new()?;
// Create the handler for the event loop and register the handler's
// sockets with the event loop.
let handler = Handler::new(client_tx, &mut event_loop)?;
Ok(Agent {
event_loop: event_loop,
handler: handler,
})
}
pub fn channel(&self) -> Sender {
#[allow(deprecated)]
self.event_loop.channel()
}
pub fn new(client_tx: crossbeam_channel::Sender<Response>) -> io::Result<Self> {
// Create the event loop.
#[allow(deprecated)]
let mut event_loop = mio::deprecated::EventLoop::new()?;
// Create the handler for the event loop and register the handler's
// sockets with the event loop.
let handler = Handler::new(client_tx, &mut event_loop)?;
Ok(Agent {
event_loop: event_loop,
handler: handler,
})
}
pub fn run(&mut self) -> io::Result<()> {
#[allow(deprecated)]
self.event_loop.run(&mut self.handler)
}
pub fn channel(&self) -> Sender {
#[allow(deprecated)]
self.event_loop.channel()
}
pub fn run(&mut self) -> io::Result<()> {
#[allow(deprecated)]
self.event_loop.run(&mut self.handler)
}
}

+ 251
- 251
proto/src/packet.rs View File

@ -19,46 +19,46 @@ use crate::core::constants::*;
#[derive(Debug)]
pub struct Packet {
/// The current read position in the byte buffer.
cursor: usize,
/// The underlying bytes.
bytes: Vec<u8>,
/// The current read position in the byte buffer.
cursor: usize,
/// The underlying bytes.
bytes: Vec<u8>,
}
impl io::Read for Packet {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let bytes_read = {
let mut slice = &self.bytes[self.cursor..];
slice.read(buf)?
};
self.cursor += bytes_read;
Ok(bytes_read)
}
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let bytes_read = {
let mut slice = &self.bytes[self.cursor..];
slice.read(buf)?
};
self.cursor += bytes_read;
Ok(bytes_read)
}
}
impl Packet {
/// Returns a readable packet struct from the wire representation of a
/// packet.
/// Assumes that the given vector is a valid length-prefixed packet.
fn from_wire(bytes: Vec<u8>) -> Self {
Packet {
cursor: U32_SIZE,
bytes: bytes,
/// Returns a readable packet struct from the wire representation of a
/// packet.
/// Assumes that the given vector is a valid length-prefixed packet.
fn from_wire(bytes: Vec<u8>) -> Self {
Packet {
cursor: U32_SIZE,
bytes: bytes,
}
}
/// Provides the main way to read data out of a binary packet.
pub fn read_value<T>(&mut self) -> Result<T, PacketReadError>
where
T: ReadFromPacket,
{
T::read_from_packet(self)
}
/// Returns the number of unread bytes remaining in the packet.
pub fn bytes_remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
}
/// Provides the main way to read data out of a binary packet.
pub fn read_value<T>(&mut self) -> Result<T, PacketReadError>
where
T: ReadFromPacket,
{
T::read_from_packet(self)
}
/// Returns the number of unread bytes remaining in the packet.
pub fn bytes_remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
}
/*===================*
@ -67,45 +67,45 @@ impl Packet {
#[derive(Debug)]
pub struct MutPacket {
bytes: Vec<u8>,
bytes: Vec<u8>,
}
impl MutPacket {
/// Returns an empty packet with the given packet code.
pub fn new() -> Self {
// Leave space for the eventual size of the packet.
MutPacket {
bytes: vec![0; U32_SIZE],
/// Returns an empty packet with the given packet code.
pub fn new() -> Self {
// Leave space for the eventual size of the packet.
MutPacket {
bytes: vec![0; U32_SIZE],
}
}
}
/// Provides the main way to write data into a binary packet.
pub fn write_value<T>(&mut self, val: &T) -> io::Result<()>
where
T: WriteToPacket + ?Sized,
{
val.write_to_packet(self)
}
/// Consumes the mutable packet and returns its wire representation.
pub fn into_bytes(mut self) -> Vec<u8> {
let length = (self.bytes.len() - U32_SIZE) as u32;
/// Provides the main way to write data into a binary packet.
pub fn write_value<T>(&mut self, val: &T) -> io::Result<()>
where
T: WriteToPacket + ?Sized,
{
let mut first_word = &mut self.bytes[..U32_SIZE];
first_word.write_u32::<LittleEndian>(length).unwrap();
val.write_to_packet(self)
}
/// Consumes the mutable packet and returns its wire representation.
pub fn into_bytes(mut self) -> Vec<u8> {
let length = (self.bytes.len() - U32_SIZE) as u32;
{
let mut first_word = &mut self.bytes[..U32_SIZE];
first_word.write_u32::<LittleEndian>(length).unwrap();
}
self.bytes
}
self.bytes
}
}
impl io::Write for MutPacket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.bytes.write(buf)
}
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.bytes.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.bytes.flush()
}
fn flush(&mut self) -> io::Result<()> {
self.bytes.flush()
}
}
/*===================*
@ -115,68 +115,68 @@ impl io::Write for MutPacket {
/// This enum contains an error that arose when reading data out of a Packet.
#[derive(Debug)]
pub enum PacketReadError {
/// Attempted to read a boolean, but the value was not 0 nor 1.
InvalidBoolError(u8),
/// Attempted to read an unsigned 16-bit integer, but the value was too
/// large.
InvalidU16Error(u32),
/// Attempted to read a string, but a character was invalid.
InvalidStringError(Vec<u8>),
/// Attempted to read a user::Status, but the value was not a valid
/// representation of an enum variant.
InvalidUserStatusError(u32),
/// Encountered an I/O error while reading.
IOError(io::Error),
/// Attempted to read a boolean, but the value was not 0 nor 1.
InvalidBoolError(u8),
/// Attempted to read an unsigned 16-bit integer, but the value was too
/// large.
InvalidU16Error(u32),
/// Attempted to read a string, but a character was invalid.
InvalidStringError(Vec<u8>),
/// Attempted to read a user::Status, but the value was not a valid
/// representation of an enum variant.
InvalidUserStatusError(u32),
/// Encountered an I/O error while reading.
IOError(io::Error),
}
impl fmt::Display for PacketReadError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
PacketReadError::InvalidBoolError(n) => {
write!(fmt, "InvalidBoolError: {}", n)
}
PacketReadError::InvalidU16Error(n) => {
write!(fmt, "InvalidU16Error: {}", n)
}
PacketReadError::InvalidStringError(ref bytes) => {
write!(fmt, "InvalidStringError: {:?}", bytes)
}
PacketReadError::InvalidUserStatusError(n) => {
write!(fmt, "InvalidUserStatusError: {}", n)
}
PacketReadError::IOError(ref err) => {
write!(fmt, "IOError: {}", err)
}
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
PacketReadError::InvalidBoolError(n) => {
write!(fmt, "InvalidBoolError: {}", n)
}
PacketReadError::InvalidU16Error(n) => {
write!(fmt, "InvalidU16Error: {}", n)
}
PacketReadError::InvalidStringError(ref bytes) => {
write!(fmt, "InvalidStringError: {:?}", bytes)
}
PacketReadError::InvalidUserStatusError(n) => {
write!(fmt, "InvalidUserStatusError: {}", n)
}
PacketReadError::IOError(ref err) => {
write!(fmt, "IOError: {}", err)
}
}
}
}
}
impl error::Error for PacketReadError {
fn description(&self) -> &str {
match *self {
PacketReadError::InvalidBoolError(_) => "InvalidBoolError",
PacketReadError::InvalidU16Error(_) => "InvalidU16Error",
PacketReadError::InvalidStringError(_) => "InvalidStringError",
PacketReadError::InvalidUserStatusError(_) => "InvalidUserStatusError",
PacketReadError::IOError(_) => "IOError",
fn description(&self) -> &str {
match *self {
PacketReadError::InvalidBoolError(_) => "InvalidBoolError",
PacketReadError::InvalidU16Error(_) => "InvalidU16Error",
PacketReadError::InvalidStringError(_) => "InvalidStringError",
PacketReadError::InvalidUserStatusError(_) => "InvalidUserStatusError",
PacketReadError::IOError(_) => "IOError",
}
}
}
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
PacketReadError::InvalidBoolError(_) => None,
PacketReadError::InvalidU16Error(_) => None,
PacketReadError::InvalidStringError(_) => None,
PacketReadError::InvalidUserStatusError(_) => None,
PacketReadError::IOError(ref err) => Some(err),
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
PacketReadError::InvalidBoolError(_) => None,
PacketReadError::InvalidU16Error(_) => None,
PacketReadError::InvalidStringError(_) => None,
PacketReadError::InvalidUserStatusError(_) => None,
PacketReadError::IOError(ref err) => Some(err),
}
}
}
}
impl From<io::Error> for PacketReadError {
fn from(err: io::Error) -> Self {
PacketReadError::IOError(err)
}
fn from(err: io::Error) -> Self {
PacketReadError::IOError(err)
}
}
/*==================*
@ -186,81 +186,81 @@ impl From<io::Error> for PacketReadError {
/// This trait is implemented by types that can be deserialized from binary
/// Packets.
pub trait ReadFromPacket: Sized {
fn read_from_packet(_: &mut Packet) -> Result<Self, PacketReadError>;
fn read_from_packet(_: &mut Packet) -> Result<Self, PacketReadError>;
}
/// 32-bit integers are serialized in 4 bytes, little-endian.
impl ReadFromPacket for u32 {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(packet.read_u32::<LittleEndian>()?)
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(packet.read_u32::<LittleEndian>()?)
}
}
/// For convenience, usize's are deserialized as u32's then casted.
impl ReadFromPacket for usize {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(u32::read_from_packet(packet)? as usize)
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(u32::read_from_packet(packet)? as usize)
}
}
/// Booleans are serialized as single bytes, containing either 0 or 1.
impl ReadFromPacket for bool {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
match packet.read_u8()? {
0 => Ok(false),
1 => Ok(true),
n => Err(PacketReadError::InvalidBoolError(n)),
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
match packet.read_u8()? {
0 => Ok(false),
1 => Ok(true),
n => Err(PacketReadError::InvalidBoolError(n)),
}
}
}
}
/// 16-bit integers are serialized as 32-bit integers.
impl ReadFromPacket for u16 {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n = u32::read_from_packet(packet)?;
if n > MAX_PORT {
return Err(PacketReadError::InvalidU16Error(n));
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n = u32::read_from_packet(packet)?;
if n > MAX_PORT {
return Err(PacketReadError::InvalidU16Error(n));
}
Ok(n as u16)
}
Ok(n as u16)
}
}
/// IPv4 addresses are serialized directly as 32-bit integers.
impl ReadFromPacket for net::Ipv4Addr {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let ip = u32::read_from_packet(packet)?;
Ok(net::Ipv4Addr::from(ip))
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let ip = u32::read_from_packet(packet)?;
Ok(net::Ipv4Addr::from(ip))
}
}
/// Strings are serialized as length-prefixed arrays of ISO-8859-1 encoded
/// characters.
impl ReadFromPacket for String {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
let mut buffer = vec![0; len];
packet.read_exact(&mut buffer)?;
let mut buffer = vec![0; len];
packet.read_exact(&mut buffer)?;
match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) {
Ok(string) => Ok(string),
Err(_) => Err(PacketReadError::InvalidStringError(buffer)),
match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) {
Ok(string) => Ok(string),
Err(_) => Err(PacketReadError::InvalidStringError(buffer)),
}
}
}
}
/// Vectors are serialized as length-prefixed arrays of values.
impl<T: ReadFromPacket> ReadFromPacket for Vec<T> {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
let mut vec = Vec::new();
for _ in 0..len {
vec.push(T::read_from_packet(packet)?);
}
let mut vec = Vec::new();
for _ in 0..len {
vec.push(T::read_from_packet(packet)?);
}
Ok(vec)
}
Ok(vec)
}
}
/*=================*
@ -270,55 +270,55 @@ impl<T: ReadFromPacket> ReadFromPacket for Vec<T> {
/// This trait is implemented by types that can be serialized to a binary
/// MutPacket.
pub trait WriteToPacket {
fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>;
fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>;
}
/// 32-bit integers are serialized in 4 bytes, little-endian.
impl WriteToPacket for u32 {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u32::<LittleEndian>(*self)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u32::<LittleEndian>(*self)
}
}
/// Booleans are serialized as single bytes, containing either 0 or 1.
impl WriteToPacket for bool {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u8(*self as u8)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u8(*self as u8)?;
Ok(())
}
}
/// 16-bit integers are serialized as 32-bit integers.
impl WriteToPacket for u16 {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(*self as u32).write_to_packet(packet)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(*self as u32).write_to_packet(packet)
}
}
/// Strings are serialized as a length-prefixed array of ISO-8859-1 encoded
/// characters.
impl WriteToPacket for str {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
// Encode the string.
let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) {
Ok(bytes) => bytes,
Err(_) => {
let copy = self.to_string();
return Err(io::Error::new(io::ErrorKind::Other, copy));
}
};
// Then write the bytes to the packet.
(bytes.len() as u32).write_to_packet(packet)?;
packet.write(&bytes)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
// Encode the string.
let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) {
Ok(bytes) => bytes,
Err(_) => {
let copy = self.to_string();
return Err(io::Error::new(io::ErrorKind::Other, copy));
}
};
// Then write the bytes to the packet.
(bytes.len() as u32).write_to_packet(packet)?;
packet.write(&bytes)?;
Ok(())
}
}
/// Deref coercion does not happen for trait methods apparently.
impl WriteToPacket for String {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(self as &str).write_to_packet(packet)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(self as &str).write_to_packet(packet)
}
}
/*========*
@ -328,86 +328,86 @@ impl WriteToPacket for String {
/// This enum defines the possible states of a packet parser state machine.
#[derive(Debug, Clone, Copy)]
enum State {
/// The parser is waiting to read enough bytes to determine the
/// length of the following packet.
ReadingLength,
/// The parser is waiting to read enough bytes to form the entire
/// packet.
ReadingPacket,
/// The parser is waiting to read enough bytes to determine the
/// length of the following packet.
ReadingLength,
/// The parser is waiting to read enough bytes to form the entire
/// packet.
ReadingPacket,
}
#[derive(Debug)]
pub struct Parser {
state: State,
num_bytes_left: usize,
buffer: Vec<u8>,
state: State,
num_bytes_left: usize,
buffer: Vec<u8>,
}
impl Parser {
pub fn new() -> Self {
Parser {
state: State::ReadingLength,
num_bytes_left: U32_SIZE,
buffer: vec![0; U32_SIZE],
}
}
/// Attemps to read a packet in a non-blocking fashion.
/// If enough bytes can be read from the given byte stream to form a
/// complete packet `p`, returns `Ok(Some(p))`.
/// If not enough bytes are available, returns `Ok(None)`.
/// If an I/O error `e` arises when trying to read the underlying stream,
/// returns `Err(e)`.
/// Note: as long as this function returns `Ok(Some(p))`, the caller is
/// responsible for calling it once more to ensure that all packets are
/// read as soon as possible.
pub fn try_read<U>(&mut self, stream: &mut U) -> io::Result<Option<Packet>>
where
U: io::Read,
{
// Try to read as many bytes as we currently need from the underlying
// byte stream.
let offset = self.buffer.len() - self.num_bytes_left;
#[allow(deprecated)]
match stream.try_read(&mut self.buffer[offset..])? {
None => (),
Some(num_bytes_read) => {
self.num_bytes_left -= num_bytes_read;
}
pub fn new() -> Self {
Parser {
state: State::ReadingLength,
num_bytes_left: U32_SIZE,
buffer: vec![0; U32_SIZE],
}
}
// If we haven't read enough bytes, return.
if self.num_bytes_left > 0 {
return Ok(None);
}
// Otherwise, the behavior depends on what state we were in.
match self.state {
State::ReadingLength => {
// If we have finished reading the length prefix, then
// deserialize it, switch states and try to read the packet
// bytes.
let message_len = LittleEndian::read_u32(&mut self.buffer) as usize;
if message_len > MAX_MESSAGE_SIZE {
unimplemented!();
};
self.state = State::ReadingPacket;
self.num_bytes_left = message_len;
self.buffer.resize(message_len + U32_SIZE, 0);
self.try_read(stream)
}
State::ReadingPacket => {
// If we have finished reading the packet, swap the full buffer
// out and return the packet made from the full buffer.
self.state = State::ReadingLength;
self.num_bytes_left = U32_SIZE;
let new_buffer = vec![0; U32_SIZE];
let old_buffer = mem::replace(&mut self.buffer, new_buffer);
Ok(Some(Packet::from_wire(old_buffer)))
}
/// Attemps to read a packet in a non-blocking fashion.
/// If enough bytes can be read from the given byte stream to form a
/// complete packet `p`, returns `Ok(Some(p))`.
/// If not enough bytes are available, returns `Ok(None)`.
/// If an I/O error `e` arises when trying to read the underlying stream,
/// returns `Err(e)`.
/// Note: as long as this function returns `Ok(Some(p))`, the caller is
/// responsible for calling it once more to ensure that all packets are
/// read as soon as possible.
pub fn try_read<U>(&mut self, stream: &mut U) -> io::Result<Option<Packet>>
where
U: io::Read,
{
// Try to read as many bytes as we currently need from the underlying
// byte stream.
let offset = self.buffer.len() - self.num_bytes_left;
#[allow(deprecated)]
match stream.try_read(&mut self.buffer[offset..])? {
None => (),
Some(num_bytes_read) => {
self.num_bytes_left -= num_bytes_read;
}
}
// If we haven't read enough bytes, return.
if self.num_bytes_left > 0 {
return Ok(None);
}
// Otherwise, the behavior depends on what state we were in.
match self.state {
State::ReadingLength => {
// If we have finished reading the length prefix, then
// deserialize it, switch states and try to read the packet
// bytes.
let message_len = LittleEndian::read_u32(&mut self.buffer) as usize;
if message_len > MAX_MESSAGE_SIZE {
unimplemented!();
};
self.state = State::ReadingPacket;
self.num_bytes_left = message_len;
self.buffer.resize(message_len + U32_SIZE, 0);
self.try_read(stream)
}
State::ReadingPacket => {
// If we have finished reading the packet, swap the full buffer
// out and return the packet made from the full buffer.
self.state = State::ReadingLength;
self.num_bytes_left = U32_SIZE;
let new_buffer = vec![0; U32_SIZE];
let old_buffer = mem::replace(&mut self.buffer, new_buffer);
Ok(Some(Packet::from_wire(old_buffer)))
}
}
}
}
}

+ 146
- 155
proto/src/peer/message.rs View File

@ -3,13 +3,10 @@ use std::io;
use log::warn;
use crate::core::value::{
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
};
use crate::peer::constants::*;
use crate::{
MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket,
};
use crate::{MutPacket, Packet, PacketReadError, ReadFromPacket, WriteToPacket};
/*=========*
* MESSAGE *
@ -18,189 +15,183 @@ use crate::{
/// This enum contains all the possible messages peers can exchange.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message {
PierceFirewall(u32),
PeerInit(PeerInit),
Unknown(u32),
PierceFirewall(u32),
PeerInit(PeerInit),
Unknown(u32),
}
impl ReadFromPacket for Message {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let code: u32 = packet.read_value()?;
let message = match code {
CODE_PIERCE_FIREWALL => Message::PierceFirewall(packet.read_value()?),
CODE_PEER_INIT => Message::PeerInit(packet.read_value()?),
code => Message::Unknown(code),
};
let bytes_remaining = packet.bytes_remaining();
if bytes_remaining > 0 {
warn!(
"Peer message with code {} contains {} extra bytes",
code, bytes_remaining
)
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let code: u32 = packet.read_value()?;
let message = match code {
CODE_PIERCE_FIREWALL => Message::PierceFirewall(packet.read_value()?),
CODE_PEER_INIT => Message::PeerInit(packet.read_value()?),
Ok(message)
}
code => Message::Unknown(code),
};
let bytes_remaining = packet.bytes_remaining();
if bytes_remaining > 0 {
warn!(
"Peer message with code {} contains {} extra bytes",
code, bytes_remaining
)
}
Ok(message)
}
}
impl ValueDecode for Message {
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let code: u32 = decoder.decode()?;
let message = match code {
CODE_PIERCE_FIREWALL => {
let val = decoder.decode()?;
Message::PierceFirewall(val)
}
CODE_PEER_INIT => {
let peer_init = decoder.decode()?;
Message::PeerInit(peer_init)
}
_ => {
return Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: format!("unknown value {}", code),
position: position,
})
}
};
Ok(message)
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let code: u32 = decoder.decode()?;
let message = match code {
CODE_PIERCE_FIREWALL => {
let val = decoder.decode()?;
Message::PierceFirewall(val)
}
CODE_PEER_INIT => {
let peer_init = decoder.decode()?;
Message::PeerInit(peer_init)
}
_ => {
return Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: format!("unknown value {}", code),
position: position,
})
}
};
Ok(message)
}
}
impl ValueEncode for Message {
fn encode_to(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
match *self {
Message::PierceFirewall(token) => {
encoder.encode_u32(CODE_PIERCE_FIREWALL)?;
encoder.encode_u32(token)?;
}
Message::PeerInit(ref request) => {
encoder.encode_u32(CODE_PEER_INIT)?;
request.encode_to(encoder)?;
}
Message::Unknown(_) => unreachable!(),
fn encode_to(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
match *self {
Message::PierceFirewall(token) => {
encoder.encode_u32(CODE_PIERCE_FIREWALL)?;
encoder.encode_u32(token)?;
}
Message::PeerInit(ref request) => {
encoder.encode_u32(CODE_PEER_INIT)?;
request.encode_to(encoder)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
}
Ok(())
}
}
impl WriteToPacket for Message {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
match *self {
Message::PierceFirewall(ref token) => {
packet.write_value(&CODE_PIERCE_FIREWALL)?;
packet.write_value(token)?;
}
Message::PeerInit(ref request) => {
packet.write_value(&CODE_PEER_INIT)?;
packet.write_value(request)?;
}
Message::Unknown(_) => unreachable!(),
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
match *self {
Message::PierceFirewall(ref token) => {
packet.write_value(&CODE_PIERCE_FIREWALL)?;
packet.write_value(token)?;
}
Message::PeerInit(ref request) => {
packet.write_value(&CODE_PEER_INIT)?;
packet.write_value(request)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
}
Ok(())
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PeerInit {
pub user_name: String,
pub connection_type: String,
pub token: u32,
pub user_name: String,
pub connection_type: String,
pub token: u32,
}
impl ReadFromPacket for PeerInit {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let user_name = packet.read_value()?;
let connection_type = packet.read_value()?;
let token = packet.read_value()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let user_name = packet.read_value()?;
let connection_type = packet.read_value()?;
let token = packet.read_value()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
}
impl WriteToPacket for PeerInit {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_value(&self.user_name)?;
packet.write_value(&self.connection_type)?;
packet.write_value(&self.token)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_value(&self.user_name)?;
packet.write_value(&self.connection_type)?;
packet.write_value(&self.token)?;
Ok(())
}
}
impl ValueEncode for PeerInit {
fn encode_to(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
encoder.encode_string(&self.user_name)?;
encoder.encode_string(&self.connection_type)?;
encoder.encode_u32(self.token)?;
Ok(())
}
fn encode_to(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
encoder.encode_string(&self.user_name)?;
encoder.encode_string(&self.connection_type)?;
encoder.encode_u32(self.token)?;
Ok(())
}
}
impl ValueDecode for PeerInit {
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let user_name = decoder.decode()?;
let connection_type = decoder.decode()?;
let token = decoder.decode()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let user_name = decoder.decode()?;
let connection_type = decoder.decode()?;
let token = decoder.decode()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
}
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use crate::core::value::tests::roundtrip;
use crate::core::value::{ValueDecodeError, ValueDecoder};
use super::*;
#[test]
fn invalid_code() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[57, 5, 0, 0]);
let result = ValueDecoder::new(&bytes).decode::<Message>();
assert_eq!(
result,
Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: "unknown value 1337".to_string(),
position: 0,
})
);
}
#[test]
fn roundtrip_pierce_firewall() {
roundtrip(Message::PierceFirewall(1337))
}
#[test]
fn roundtrip_peer_init() {
roundtrip(Message::PeerInit(PeerInit {
user_name: "alice".to_string(),
connection_type: "P".to_string(),
token: 1337,
}));
}
use bytes::BytesMut;
use crate::core::value::tests::roundtrip;
use crate::core::value::{ValueDecodeError, ValueDecoder};
use super::*;
#[test]
fn invalid_code() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[57, 5, 0, 0]);
let result = ValueDecoder::new(&bytes).decode::<Message>();
assert_eq!(
result,
Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: "unknown value 1337".to_string(),
position: 0,
})
);
}
#[test]
fn roundtrip_pierce_firewall() {
roundtrip(Message::PierceFirewall(1337))
}
#[test]
fn roundtrip_peer_init() {
roundtrip(Message::PeerInit(PeerInit {
user_name: "alice".to_string(),
connection_type: "P".to_string(),
token: 1337,
}));
}
}

+ 237
- 247
proto/src/server/client.rs View File

@ -8,301 +8,291 @@ use thiserror::Error;
use tokio::net;
use crate::core::frame::FrameStream;
use crate::server::{
Credentials, LoginResponse, ServerRequest, ServerResponse, Version,
};
use crate::server::{Credentials, LoginResponse, ServerRequest, ServerResponse, Version};
/// Specifies options for a new `Client`.
pub struct ClientOptions {
pub credentials: Credentials,
pub version: Version,
pub credentials: Credentials,
pub version: Version,
}
/// A client for the client-server protocol.
pub struct Client {
frame_stream: FrameStream<ServerResponse, ServerRequest>,
frame_stream: FrameStream<ServerResponse, ServerRequest>,
}
/// An error that arose while logging in to a remote server.
#[derive(Debug, Error)]
pub enum ClientLoginError {
#[error("login failed: {0}")]
LoginFailed(String),
#[error("login failed: {0}")]
LoginFailed(String),
#[error("unexpected response: {0:?}")]
UnexpectedResponse(ServerResponse),
#[error("unexpected response: {0:?}")]
UnexpectedResponse(ServerResponse),
#[error("unexpected end of file")]
UnexpectedEof,
#[error("unexpected end of file")]
UnexpectedEof,
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
}
/// An error that arose while running the client.
#[derive(Debug, Error)]
pub enum ClientRunError {
#[error("underlying stream was closed unexpectedly")]
StreamClosed,
#[error("underlying stream was closed unexpectedly")]
StreamClosed,
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
#[error("i/o error: {0}")]
IOError(#[from] io::Error),
}
impl ClientRunError {
fn is_stream_closed(&self) -> bool {
match self {
ClientRunError::StreamClosed => true,
_ => false,
#[cfg(test)]
fn is_stream_closed(&self) -> bool {
match self {
ClientRunError::StreamClosed => true,
_ => false,
}
}
}
}
enum RunOnceResult {
Break,
Continue,
Response(ServerResponse),
Break,
Continue,
Response(ServerResponse),
}
impl Client {
pub async fn login(
tcp_stream: net::TcpStream,
options: ClientOptions,
) -> Result<Client, ClientLoginError> {
let mut client = Client {
frame_stream: FrameStream::new(tcp_stream),
};
pub async fn login(
tcp_stream: net::TcpStream,
options: ClientOptions,
) -> Result<Client, ClientLoginError> {
let mut client = Client {
frame_stream: FrameStream::new(tcp_stream),
};
client.handshake(options).await?;
client.handshake(options).await?;
Ok(client)
}
// Performs the login exchange.
// Called this way because `login` is already taken.
async fn handshake(
&mut self,
options: ClientOptions,
) -> Result<(), ClientLoginError> {
let login_request = options.credentials.into_login_request(options.version);
debug!("Client: sending login request: {:?}", login_request);
let request = login_request.into();
self.frame_stream.write(&request).await?;
let response = self.frame_stream.read().await?;
debug!("Client: received first response: {:?}", response);
match response {
Some(ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd,
ip,
password_md5_opt,
})) => {
info!("Client: Logged in successfully!");
info!("Client: Message Of The Day: {}", motd);
info!("Client: Public IP address: {}", ip);
info!("Client: Password MD5: {:?}", password_md5_opt);
Ok(())
}
Some(ServerResponse::LoginResponse(LoginResponse::LoginFail {
reason,
})) => Err(ClientLoginError::LoginFailed(reason)),
Some(response) => Err(ClientLoginError::UnexpectedResponse(response)),
None => Err(ClientLoginError::UnexpectedEof),
Ok(client)
}
}
async fn run_once<S>(
&mut self,
request_stream: &mut S,
) -> Result<RunOnceResult, ClientRunError>
where
S: Stream<Item = ServerRequest> + Unpin,
{
tokio::select!(
maybe_request = request_stream.next() => {
if let Some(request) = maybe_request {
debug!("Client: sending request: {:?}", request);
self.frame_stream.write(&request).await?;
Ok(RunOnceResult::Continue)
} else {
// Sender has been dropped.
Ok(RunOnceResult::Break)
// Performs the login exchange.
// Called this way because `login` is already taken.
async fn handshake(&mut self, options: ClientOptions) -> Result<(), ClientLoginError> {
let login_request = options.credentials.into_login_request(options.version);
debug!("Client: sending login request: {:?}", login_request);
let request = login_request.into();
self.frame_stream.write(&request).await?;
let response = self.frame_stream.read().await?;
debug!("Client: received first response: {:?}", response);
match response {
Some(ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd,
ip,
password_md5_opt,
})) => {
info!("Client: Logged in successfully!");
info!("Client: Message Of The Day: {}", motd);
info!("Client: Public IP address: {}", ip);
info!("Client: Password MD5: {:?}", password_md5_opt);
Ok(())
}
Some(ServerResponse::LoginResponse(LoginResponse::LoginFail { reason })) => {
Err(ClientLoginError::LoginFailed(reason))
}
Some(response) => Err(ClientLoginError::UnexpectedResponse(response)),
None => Err(ClientLoginError::UnexpectedEof),
}
},
read_result = self.frame_stream.read() => {
match read_result? {
Some(response) => {
}
async fn run_once<S>(&mut self, request_stream: &mut S) -> Result<RunOnceResult, ClientRunError>
where
S: Stream<Item = ServerRequest> + Unpin,
{
tokio::select!(
maybe_request = request_stream.next() => {
if let Some(request) = maybe_request {
debug!("Client: sending request: {:?}", request);
self.frame_stream.write(&request).await?;
Ok(RunOnceResult::Continue)
} else {
// Sender has been dropped.
Ok(RunOnceResult::Break)
}
},
read_result = self.frame_stream.read() => {
match read_result? {
Some(response) => {
debug!("Client: received response: {:?}", response);
Ok(RunOnceResult::Response(response))
}
None => Err(ClientRunError::StreamClosed),
}
},
)
}
pub fn run<S>(
mut self,
mut request_stream: S,
) -> impl Stream<Item = Result<ServerResponse, ClientRunError>> + Unpin
where
S: Stream<Item = ServerRequest> + Unpin,
{
Box::pin(async_stream::try_stream! {
// Drive the main loop: send requests and receive responses.
loop {
match self.run_once(&mut request_stream).await? {
RunOnceResult::Break => break,
RunOnceResult::Continue => continue,
RunOnceResult::Response(response) => yield response,
}
}
debug!("Client: shutting down outbound stream");
self.frame_stream.shutdown().await?;
// Drain the receiving end of the connection.
while let Some(response) = self.frame_stream.read().await? {
debug!("Client: received response: {:?}", response);
Ok(RunOnceResult::Response(response))
yield response;
}
None => Err(ClientRunError::StreamClosed),
}
},
)
}
pub fn run<S>(
mut self,
mut request_stream: S,
) -> impl Stream<Item = Result<ServerResponse, ClientRunError>> + Unpin
where
S: Stream<Item = ServerRequest> + Unpin,
{
Box::pin(async_stream::try_stream! {
// Drive the main loop: send requests and receive responses.
loop {
match self.run_once(&mut request_stream).await? {
RunOnceResult::Break => break,
RunOnceResult::Continue => continue,
RunOnceResult::Response(response) => yield response,
}
}
debug!("Client: shutting down outbound stream");
self.frame_stream.shutdown().await?;
// Drain the receiving end of the connection.
while let Some(response) = self.frame_stream.read().await? {
debug!("Client: received response: {:?}", response);
yield response;
}
})
}
})
}
}
#[cfg(test)]
mod tests {
use futures::stream::{empty, StreamExt};
use tokio::net;
use tokio::sync::mpsc;
use crate::server::testing::{
ServerBuilder, ShutdownType, UserStatusMap,
};
use crate::server::{
Credentials, ServerRequest, ServerResponse, UserStatusRequest,
UserStatusResponse,
};
use crate::UserStatus;
use super::{Client, ClientOptions, Version};
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
// Returns default ClientOptions suitable for testing.
fn client_options() -> ClientOptions {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
ClientOptions {
credentials,
version: Version::default(),
use futures::stream::{empty, StreamExt};
use tokio::net;
use tokio::sync::mpsc;
use crate::server::testing::{ServerBuilder, ShutdownType, UserStatusMap};
use crate::server::{
Credentials, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse,
};
use crate::UserStatus;
use super::{Client, ClientOptions, Version};
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
}
#[tokio::test]
async fn login() {
init();
// Returns default ClientOptions suitable for testing.
fn client_options() -> ClientOptions {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
ClientOptions {
credentials,
version: Version::default(),
}
}
let stream = net::TcpStream::connect(handle.address()).await.unwrap();
#[tokio::test]
async fn login() {
init();
let client = Client::login(stream, client_options()).await.unwrap();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
// Send nothing, receive no responses.
let mut inbound = client.run(empty());
assert!(inbound.next().await.is_none());
let stream = net::TcpStream::connect(handle.address()).await.unwrap();
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
let client = Client::login(stream, client_options()).await.unwrap();
#[tokio::test]
async fn simple_exchange() {
init();
// Send nothing, receive no responses.
let mut inbound = client.run(empty());
assert!(inbound.next().await.is_none());
let response = UserStatusResponse {
user_name: "alice".to_string(),
status: UserStatus::Online,
is_privileged: false,
};
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
#[tokio::test]
async fn simple_exchange() {
init();
let response = UserStatusResponse {
user_name: "alice".to_string(),
status: UserStatus::Online,
is_privileged: false,
};
let mut user_status_map = UserStatusMap::default();
user_status_map.insert(response.clone());
let (server, handle) = ServerBuilder::default()
.with_user_status_map(user_status_map)
.bind()
.await
.unwrap();
let server_task = tokio::spawn(server.serve());
let stream = net::TcpStream::connect(handle.address()).await.unwrap();
let client = Client::login(stream, client_options()).await.unwrap();
let outbound = Box::pin(async_stream::stream! {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "alice".to_string(),
});
});
let mut inbound = client.run(outbound);
assert_eq!(
inbound.next().await.unwrap().unwrap(),
ServerResponse::UserStatusResponse(response)
);
assert!(inbound.next().await.is_none());
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
#[tokio::test]
async fn stream_closed() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
let mut user_status_map = UserStatusMap::default();
user_status_map.insert(response.clone());
let (server, handle) = ServerBuilder::default()
.with_user_status_map(user_status_map)
.bind()
.await
.unwrap();
let server_task = tokio::spawn(server.serve());
let stream = net::TcpStream::connect(handle.address()).await.unwrap();
let client = Client::login(stream, client_options()).await.unwrap();
let outbound = Box::pin(async_stream::stream! {
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "bob".to_string(),
});
yield ServerRequest::UserStatusRequest(UserStatusRequest {
user_name: "alice".to_string(),
});
});
let mut inbound = client.run(outbound);
assert_eq!(
inbound.next().await.unwrap().unwrap(),
ServerResponse::UserStatusResponse(response)
);
assert!(inbound.next().await.is_none());
handle.shutdown(ShutdownType::LameDuck);
server_task.await.unwrap().unwrap();
}
#[tokio::test]
async fn stream_closed() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
let stream = net::TcpStream::connect(handle.address()).await.unwrap();
let client = Client::login(stream, client_options()).await.unwrap();
let (_request_tx, mut request_rx) = mpsc::channel(1);
let outbound = Box::pin(async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
});
let mut inbound = client.run(outbound);
// Server shuts down, closing its connection before the client has had a
// chance to send all of `outbound`.
handle.shutdown(ShutdownType::Immediate);
// Wait for the server to terminate, to avoid race conditions.
server_task.await.unwrap().unwrap();
// Check that the client returns the correct error, then stops running.
assert!(inbound
.next()
.await
.unwrap()
.unwrap_err()
.is_stream_closed());
assert!(inbound.next().await.is_none());
}
let stream = net::TcpStream::connect(handle.address()).await.unwrap();
let client = Client::login(stream, client_options()).await.unwrap();
let (_request_tx, mut request_rx) = mpsc::channel(1);
let outbound = Box::pin(async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
});
let mut inbound = client.run(outbound);
// Server shuts down, closing its connection before the client has had a
// chance to send all of `outbound`.
handle.shutdown(ShutdownType::Immediate);
// Wait for the server to terminate, to avoid race conditions.
server_task.await.unwrap().unwrap();
// Check that the client returns the correct error, then stops running.
assert!(inbound
.next()
.await
.unwrap()
.unwrap_err()
.is_stream_closed());
assert!(inbound.next().await.is_none());
}
}

+ 105
- 106
proto/src/server/credentials.rs View File

@ -8,120 +8,119 @@ use crate::server::{LoginRequest, Version};
/// Credentials for logging in a client to a server.
#[derive(Debug, Eq, PartialEq)]
pub struct Credentials {
user_name: String,
password: String,
digest: String,
user_name: String,
password: String,
digest: String,
}
impl Credentials {
/// Attempts to build credentials.
///
/// Returns `None` if `password` is empty, `Some(credentials)` otherwise.
pub fn new(user_name: String, password: String) -> Option<Credentials> {
if password.is_empty() {
return None;
/// Attempts to build credentials.
///
/// Returns `None` if `password` is empty, `Some(credentials)` otherwise.
pub fn new(user_name: String, password: String) -> Option<Credentials> {
if password.is_empty() {
return None;
}
let mut hasher = Md5::new();
hasher.input_str(&user_name);
hasher.input_str(&password);
let digest = hasher.result_str();
Some(Credentials {
user_name,
password,
digest,
})
}
let mut hasher = Md5::new();
hasher.input_str(&user_name);
hasher.input_str(&password);
let digest = hasher.result_str();
Some(Credentials {
user_name,
password,
digest,
})
}
/// The user name to log in as.
pub fn user_name(&self) -> &str {
&self.user_name
}
/// The password to log in with.
pub fn password(&self) -> &str {
&self.password
}
/// Returns md5(user_name + password).
pub fn digest(&self) -> &str {
&self.digest
}
pub fn into_login_request(self, version: Version) -> LoginRequest {
LoginRequest {
user_name: self.user_name,
password: self.password,
digest: self.digest,
major: version.major,
minor: version.minor,
/// The user name to log in as.
pub fn user_name(&self) -> &str {
&self.user_name
}
/// The password to log in with.
pub fn password(&self) -> &str {
&self.password
}
/// Returns md5(user_name + password).
pub fn digest(&self) -> &str {
&self.digest
}
pub fn into_login_request(self, version: Version) -> LoginRequest {
LoginRequest {
user_name: self.user_name,
password: self.password,
digest: self.digest,
major: version.major,
minor: version.minor,
}
}
}
}
#[cfg(test)]
mod tests {
use crate::server::{LoginRequest, Version};
use super::Credentials;
#[test]
fn empty_password() {
assert_eq!(Credentials::new("alice".to_string(), String::new()), None);
}
#[test]
fn new_success() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
assert_eq!(credentials.user_name(), "alice");
assert_eq!(credentials.password(), "sekrit");
}
#[test]
fn digest() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
// To generate the expected value on the command line;
//
// ```sh
// $ user_name="alice"
// $ password="sekrit"
// $ echo -e "${user_name}${password}" | md5sum
// ```
assert_eq!(credentials.digest(), "286da88eb442032bdd3913979af76e8a");
}
#[test]
fn into_login_request() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials =
Credentials::new(user_name.clone(), password.clone()).unwrap();
let digest = credentials.digest().to_string();
let version = Version {
major: 13,
minor: 37,
};
assert_eq!(
credentials.into_login_request(version),
LoginRequest {
user_name,
password,
digest,
major: 13,
minor: 37,
}
);
}
use crate::server::{LoginRequest, Version};
use super::Credentials;
#[test]
fn empty_password() {
assert_eq!(Credentials::new("alice".to_string(), String::new()), None);
}
#[test]
fn new_success() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
assert_eq!(credentials.user_name(), "alice");
assert_eq!(credentials.password(), "sekrit");
}
#[test]
fn digest() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
// To generate the expected value on the command line;
//
// ```sh
// $ user_name="alice"
// $ password="sekrit"
// $ echo -e "${user_name}${password}" | md5sum
// ```
assert_eq!(credentials.digest(), "286da88eb442032bdd3913979af76e8a");
}
#[test]
fn into_login_request() {
let user_name = "alice".to_string();
let password = "sekrit".to_string();
let credentials = Credentials::new(user_name.clone(), password.clone()).unwrap();
let digest = credentials.digest().to_string();
let version = Version {
major: 13,
minor: 37,
};
assert_eq!(
credentials.into_login_request(version),
LoginRequest {
user_name,
password,
digest,
major: 13,
minor: 37,
}
);
}
}

+ 599
- 633
proto/src/server/request.rs
File diff suppressed because it is too large
View File


+ 1187
- 1288
proto/src/server/response.rs
File diff suppressed because it is too large
View File


+ 283
- 287
proto/src/server/testing.rs View File

@ -14,381 +14,377 @@ use tokio::sync::watch;
use crate::core::frame::FrameStream;
use crate::server::{
LoginResponse, ServerRequest, ServerResponse, UserStatusRequest,
UserStatusResponse,
LoginResponse, ServerRequest, ServerResponse, UserStatusRequest, UserStatusResponse,
};
#[derive(Debug, Default)]
pub struct UserStatusMap {
map: HashMap<String, UserStatusResponse>,
map: HashMap<String, UserStatusResponse>,
}
// IDEA: impl FromIterator<UserStatusResponse> for UserStatusMap.
impl UserStatusMap {
pub fn insert(&mut self, response: UserStatusResponse) {
self.map.insert(response.user_name.clone(), response);
}
pub fn insert(&mut self, response: UserStatusResponse) {
self.map.insert(response.user_name.clone(), response);
}
pub fn get(&self, user_name: &str) -> Option<UserStatusResponse> {
self.map.get(user_name).map(|response| response.clone())
}
pub fn get(&self, user_name: &str) -> Option<UserStatusResponse> {
self.map.get(user_name).map(|response| response.clone())
}
}
struct Handler {
frame_stream: FrameStream<ServerRequest, ServerResponse>,
peer_address: SocketAddr,
user_status_map: Arc<Mutex<UserStatusMap>>,
frame_stream: FrameStream<ServerRequest, ServerResponse>,
peer_address: SocketAddr,
user_status_map: Arc<Mutex<UserStatusMap>>,
}
impl Handler {
fn ipv4_address(&self) -> Ipv4Addr {
match self.peer_address.ip() {
IpAddr::V4(ipv4_addr) => ipv4_addr,
IpAddr::V6(_) => Ipv4Addr::UNSPECIFIED,
}
}
async fn send_response(
&mut self,
response: &ServerResponse,
) -> io::Result<()> {
debug!("Handler: sending response: {:?}", response);
self.frame_stream.write(response).await
}
async fn handle_login(&mut self) -> io::Result<()> {
match self.frame_stream.read().await? {
Some(ServerRequest::LoginRequest(request)) => {
info!("Handler: Received login request: {:?}", request);
}
Some(request) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
None => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"expected login request".to_string(),
));
}
};
let response = ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd: "hi there".to_string(),
ip: self.ipv4_address(),
password_md5_opt: None,
});
self.send_response(&response).await
}
async fn handle_request(&mut self, request: ServerRequest) -> io::Result<()> {
debug!("Handler: received request: {:?}", request);
match request {
ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => {
let entry = self.user_status_map.lock().get(&user_name);
if let Some(response) = entry {
let response = ServerResponse::UserStatusResponse(response);
self.send_response(&response).await?;
fn ipv4_address(&self) -> Ipv4Addr {
match self.peer_address.ip() {
IpAddr::V4(ipv4_addr) => ipv4_addr,
IpAddr::V6(_) => Ipv4Addr::UNSPECIFIED,
}
}
_ => {
warn!("Handler: unhandled request: {:?}", request);
}
}
Ok(())
}
async fn send_response(&mut self, response: &ServerResponse) -> io::Result<()> {
debug!("Handler: sending response: {:?}", response);
self.frame_stream.write(response).await
}
async fn handle_login(&mut self) -> io::Result<()> {
match self.frame_stream.read().await? {
Some(ServerRequest::LoginRequest(request)) => {
info!("Handler: Received login request: {:?}", request);
}
Some(request) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
None => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"expected login request".to_string(),
));
}
};
let response = ServerResponse::LoginResponse(LoginResponse::LoginOk {
motd: "hi there".to_string(),
ip: self.ipv4_address(),
password_md5_opt: None,
});
self.send_response(&response).await
}
async fn run(mut self) -> io::Result<()> {
self.handle_login().await?;
async fn handle_request(&mut self, request: ServerRequest) -> io::Result<()> {
debug!("Handler: received request: {:?}", request);
while let Some(request) = self.frame_stream.read().await? {
self.handle_request(request).await?;
match request {
ServerRequest::UserStatusRequest(UserStatusRequest { user_name }) => {
let entry = self.user_status_map.lock().get(&user_name);
if let Some(response) = entry {
let response = ServerResponse::UserStatusResponse(response);
self.send_response(&response).await?;
}
}
_ => {
warn!("Handler: unhandled request: {:?}", request);
}
}
Ok(())
}
info!("Handler: client disconnecting, shutting down");
Ok(())
}
async fn run(mut self) -> io::Result<()> {
self.handle_login().await?;
while let Some(request) = self.frame_stream.read().await? {
self.handle_request(request).await?;
}
info!("Handler: client disconnecting, shutting down");
Ok(())
}
}
struct GracefulHandler {
handler: Handler,
shutdown_rx: watch::Receiver<()>,
handler: Handler,
shutdown_rx: watch::Receiver<()>,
}
impl GracefulHandler {
async fn run(mut self) -> io::Result<()> {
tokio::select!(
result = self.handler.run() => {
if let Err(ref error) = result {
warn!("GracefulHandler: handler returned error {:?}", error);
}
result
},
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = self.shutdown_rx.changed() => {
info!("GracefulHandler: shutting down.");
Ok(())
},
)
}
async fn run(mut self) -> io::Result<()> {
tokio::select!(
result = self.handler.run() => {
if let Err(ref error) = result {
warn!("GracefulHandler: handler returned error {:?}", error);
}
result
},
// Ignore receive errors - if shutdown_rx's sender is dropped, we take
// that as a signal to shut down too.
_ = self.shutdown_rx.changed() => {
info!("GracefulHandler: shutting down.");
Ok(())
},
)
}
}
struct SenderHandler {
handler: GracefulHandler,
result_tx: mpsc::Sender<io::Result<()>>,
handler: GracefulHandler,
result_tx: mpsc::Sender<io::Result<()>>,
}
impl SenderHandler {
async fn run(self) {
let result = self.handler.run().await;
let _ = self.result_tx.send(result).await;
}
async fn run(self) {
let result = self.handler.run().await;
let _ = self.result_tx.send(result).await;
}
}
/// A builder for Server instances.
#[derive(Default)]
pub struct ServerBuilder {
user_status_map: Option<Arc<Mutex<UserStatusMap>>>,
user_status_map: Option<Arc<Mutex<UserStatusMap>>>,
}
impl ServerBuilder {
/// Sets the UserStatusMap which the server will use to respond to
/// UserStatusRequest messages.
pub fn with_user_status_map(mut self, map: UserStatusMap) -> Self {
self.user_status_map = Some(Arc::new(Mutex::new(map)));
self
}
/// Binds to a localhost port, then returns a server and its handle.
pub async fn bind(self) -> io::Result<(Server, ServerHandle)> {
let listener = TcpListener::bind("localhost:0").await?;
let address = listener.local_addr()?;
let user_status_map = match self.user_status_map {
Some(user_status_map) => user_status_map,
None => Arc::new(Mutex::new(UserStatusMap::default())),
};
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(());
let (result_tx, result_rx) = mpsc::channel(1);
Ok((
Server {
listener,
shutdown_rx,
handler_shutdown_tx,
handler_shutdown_rx,
result_tx,
result_rx,
user_status_map,
},
ServerHandle {
shutdown_tx,
address,
},
))
}
/// Sets the UserStatusMap which the server will use to respond to
/// UserStatusRequest messages.
pub fn with_user_status_map(mut self, map: UserStatusMap) -> Self {
self.user_status_map = Some(Arc::new(Mutex::new(map)));
self
}
/// Binds to a localhost port, then returns a server and its handle.
pub async fn bind(self) -> io::Result<(Server, ServerHandle)> {
let listener = TcpListener::bind("localhost:0").await?;
let address = listener.local_addr()?;
let user_status_map = match self.user_status_map {
Some(user_status_map) => user_status_map,
None => Arc::new(Mutex::new(UserStatusMap::default())),
};
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (handler_shutdown_tx, handler_shutdown_rx) = watch::channel(());
let (result_tx, result_rx) = mpsc::channel(1);
Ok((
Server {
listener,
shutdown_rx,
handler_shutdown_tx,
handler_shutdown_rx,
result_tx,
result_rx,
user_status_map,
},
ServerHandle {
shutdown_tx,
address,
},
))
}
}
/// Specifies how to shut down a server.
pub enum ShutdownType {
/// Shut down immediately, sever open connections.
Immediate,
/// Shut down immediately, sever open connections.
Immediate,
/// Stop accepting new connections, wait for open connections to close.
LameDuck,
/// Stop accepting new connections, wait for open connections to close.
LameDuck,
}
/// A simple server for connecting to in tests.
pub struct Server {
// Listener for new connections.
listener: TcpListener,
// Listener for new connections.
listener: TcpListener,
// Receiver for ServerHandle shutdown() notification.
shutdown_rx: oneshot::Receiver<ShutdownType>,
// Receiver for ServerHandle shutdown() notification.
shutdown_rx: oneshot::Receiver<ShutdownType>,
// Watch channel for signalling immediate termination to handlers.
handler_shutdown_tx: watch::Sender<()>,
handler_shutdown_rx: watch::Receiver<()>,
// Watch channel for signalling immediate termination to handlers.
handler_shutdown_tx: watch::Sender<()>,
handler_shutdown_rx: watch::Receiver<()>,
// Channel for receiving results back from handlers.
result_tx: mpsc::Sender<io::Result<()>>,
result_rx: mpsc::Receiver<io::Result<()>>,
// Channel for receiving results back from handlers.
result_tx: mpsc::Sender<io::Result<()>>,
result_rx: mpsc::Receiver<io::Result<()>>,
// Shared state for handlers to use when serving responses.
user_status_map: Arc<Mutex<UserStatusMap>>,
// Shared state for handlers to use when serving responses.
user_status_map: Arc<Mutex<UserStatusMap>>,
}
/// Allows interacting with a running `Server`.
pub struct ServerHandle {
shutdown_tx: oneshot::Sender<ShutdownType>,
address: SocketAddr,
shutdown_tx: oneshot::Sender<ShutdownType>,
address: SocketAddr,
}
impl ServerHandle {
/// Returns the address on which the server is accepting connections.
pub fn address(&self) -> SocketAddr {
self.address
}
/// Starts shutting down the server.
/// Does nothing if the server is already shutting down or even dropped.
pub fn shutdown(self, how: ShutdownType) {
// Ignore send errors, which mean that the server has been dropped.
let _ = self.shutdown_tx.send(how);
}
/// Returns the address on which the server is accepting connections.
pub fn address(&self) -> SocketAddr {
self.address
}
/// Starts shutting down the server.
/// Does nothing if the server is already shutting down or even dropped.
pub fn shutdown(self, how: ShutdownType) {
// Ignore send errors, which mean that the server has been dropped.
let _ = self.shutdown_tx.send(how);
}
}
impl Server {
/// Returns the address to which this server is bound.
/// This is always localhost and a random port chosen by the OS.
pub fn address(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
/// Spawns a handler for the given new stream, initiated by a remote peer.
fn spawn_handler(&mut self, stream: TcpStream, peer_address: SocketAddr) {
let handler = SenderHandler {
handler: GracefulHandler {
handler: Handler {
frame_stream: FrameStream::new(stream),
peer_address,
user_status_map: self.user_status_map.clone(),
},
shutdown_rx: self.handler_shutdown_rx.clone(),
},
result_tx: self.result_tx.clone(),
};
tokio::spawn(handler.run());
}
/// Accepts a single connection, spawns a handler for it and returns.
///
/// Useful for tests who need to guarantee a handler is spawned before the
/// server shuts down.
async fn accept(&mut self) -> io::Result<()> {
let (stream, peer_address) = self.listener.accept().await?;
self.spawn_handler(stream, peer_address);
Ok(())
}
/// Runs the server: accepts incoming connections and responds to requests.
///
/// Returns an error if:
///
/// - an error was encountered while listening
/// - an error was encountered while serving a request
///
pub async fn serve(mut self) -> io::Result<()> {
loop {
tokio::select!(
result = self.listener.accept() => {
let (stream, peer_address) = result?;
self.spawn_handler(stream, peer_address);
},
result = &mut self.shutdown_rx => {
// If shutdown_rx's sender is dropped and we receive an error, we take
// that as a signal to shut down immediately too.
match result.unwrap_or(ShutdownType::Immediate) {
ShutdownType::LameDuck => break,
ShutdownType::Immediate => {
// Send errors cannot happen, since we hold onto a receiver in
// `self.handler_shutdown_rx`.
self.handler_shutdown_tx.send(()).unwrap();
break
}
}
},
optional_result = self.result_rx.recv() => {
// We can never exhaust the result channel because we hold onto a
// sender in `self.result_tx`.
let result = optional_result.unwrap();
// Return an error if a handler returns an error.
result?;
}
);
/// Returns the address to which this server is bound.
/// This is always localhost and a random port chosen by the OS.
pub fn address(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
/// Spawns a handler for the given new stream, initiated by a remote peer.
fn spawn_handler(&mut self, stream: TcpStream, peer_address: SocketAddr) {
let handler = SenderHandler {
handler: GracefulHandler {
handler: Handler {
frame_stream: FrameStream::new(stream),
peer_address,
user_status_map: self.user_status_map.clone(),
},
shutdown_rx: self.handler_shutdown_rx.clone(),
},
result_tx: self.result_tx.clone(),
};
tokio::spawn(handler.run());
}
info!("Server: shutting down");
drop(self.result_tx);
while let Some(result) = self.result_rx.recv().await {
result?;
/// Accepts a single connection, spawns a handler for it and returns.
///
/// Useful for tests who need to guarantee a handler is spawned before the
/// server shuts down.
async fn accept(&mut self) -> io::Result<()> {
let (stream, peer_address) = self.listener.accept().await?;
self.spawn_handler(stream, peer_address);
Ok(())
}
Ok(())
}
/// Runs the server: accepts incoming connections and responds to requests.
///
/// Returns an error if:
///
/// - an error was encountered while listening
/// - an error was encountered while serving a request
///
pub async fn serve(mut self) -> io::Result<()> {
loop {
tokio::select!(
result = self.listener.accept() => {
let (stream, peer_address) = result?;
self.spawn_handler(stream, peer_address);
},
result = &mut self.shutdown_rx => {
// If shutdown_rx's sender is dropped and we receive an error, we take
// that as a signal to shut down immediately too.
match result.unwrap_or(ShutdownType::Immediate) {
ShutdownType::LameDuck => break,
ShutdownType::Immediate => {
// Send errors cannot happen, since we hold onto a receiver in
// `self.handler_shutdown_rx`.
self.handler_shutdown_tx.send(()).unwrap();
break
}
}
},
optional_result = self.result_rx.recv() => {
// We can never exhaust the result channel because we hold onto a
// sender in `self.result_tx`.
let result = optional_result.unwrap();
// Return an error if a handler returns an error.
result?;
}
);
}
info!("Server: shutting down");
drop(self.result_tx);
while let Some(result) = self.result_rx.recv().await {
result?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::io;
use std::io;
use tokio::net::TcpStream;
use tokio::net::TcpStream;
use super::{ServerBuilder, ShutdownType};
use super::{ServerBuilder, ShutdownType};
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn new_binds_to_localhost() {
init();
#[tokio::test]
async fn new_binds_to_localhost() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
assert_eq!(server.address().unwrap(), handle.address());
}
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
assert_eq!(server.address().unwrap(), handle.address());
}
#[tokio::test]
async fn accepts_incoming_connections() {
init();
#[tokio::test]
async fn accepts_incoming_connections() {
init();
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
let (server, handle) = ServerBuilder::default().bind().await.unwrap();
let server_task = tokio::spawn(server.serve());
// The connection succeeds.
let _ = TcpStream::connect(handle.address()).await.unwrap();
// The connection succeeds.
let _ = TcpStream::connect(handle.address()).await.unwrap();
handle.shutdown(ShutdownType::Immediate);
handle.shutdown(ShutdownType::Immediate);
// Ignore errors, which can happen when the handler task is spawned right
// before we call `handle.shutdown()`. See `serve_yields_handler_error`.
let _ = server_task.await.unwrap();
}
// Ignore errors, which can happen when the handler task is spawned right
// before we call `handle.shutdown()`. See `serve_yields_handler_error`.
let _ = server_task.await.unwrap();
}
// This test verifies that when a handler encounters an error, it is
// reflected in `Server::serve()`'s return value.
#[tokio::test]
async fn serve_yields_handler_error() {
init();
// This test verifies that when a handler encounters an error, it is
// reflected in `Server::serve()`'s return value.
#[tokio::test]
async fn serve_yields_handler_error() {
init();
let (mut server, handle) = ServerBuilder::default().bind().await.unwrap();
let (mut server, handle) = ServerBuilder::default().bind().await.unwrap();
// The connection is accepted, then immediately closed.
let address = handle.address();
tokio::spawn(async move {
let _ = TcpStream::connect(address).await.unwrap();
});
// The connection is accepted, then immediately closed.
let address = handle.address();
tokio::spawn(async move {
let _ = TcpStream::connect(address).await.unwrap();
});
// Accept the connection on the server and spawn a handler for it.
server.accept().await.unwrap();
// Accept the connection on the server and spawn a handler for it.
server.accept().await.unwrap();
// Signal that the server should stop accepting incoming connections.
handle.shutdown(ShutdownType::LameDuck);
// Signal that the server should stop accepting incoming connections.
handle.shutdown(ShutdownType::LameDuck);
// Drain outstanding requests, encountering the error.
let error = server.serve().await.unwrap_err();
assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
}
// Drain outstanding requests, encountering the error.
let error = server.serve().await.unwrap_err();
assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
}
}

+ 9
- 9
proto/src/server/version.rs View File

@ -2,18 +2,18 @@
/// Specifies a protocol version.
pub struct Version {
/// The major version number.
pub major: u32,
/// The major version number.
pub major: u32,
/// The minor version number.
pub minor: u32,
/// The minor version number.
pub minor: u32,
}
impl Default for Version {
fn default() -> Self {
Self {
major: 181,
minor: 100,
fn default() -> Self {
Self {
major: 181,
minor: 100,
}
}
}
}

+ 165
- 166
proto/src/stream.rs View File

@ -4,7 +4,7 @@ use std::fmt;
use std::io;
use std::net::ToSocketAddrs;
use log::{error, info};
use log::error;
use mio;
use super::packet::{MutPacket, Parser, ReadFromPacket, WriteToPacket};
@ -16,41 +16,41 @@ use super::packet::{MutPacket, Parser, ReadFromPacket, WriteToPacket};
/// A struct used for writing bytes to a TryWrite sink.
#[derive(Debug)]
struct OutBuf {
cursor: usize,
bytes: Vec<u8>,
cursor: usize,
bytes: Vec<u8>,
}
impl From<Vec<u8>> for OutBuf {
fn from(bytes: Vec<u8>) -> Self {
OutBuf {
cursor: 0,
bytes: bytes,
fn from(bytes: Vec<u8>) -> Self {
OutBuf {
cursor: 0,
bytes: bytes,
}
}
}
}
impl OutBuf {
#[inline]
fn remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
#[inline]
fn has_remaining(&self) -> bool {
self.remaining() > 0
}
#[allow(deprecated)]
fn try_write_to<T>(&mut self, mut writer: T) -> io::Result<Option<usize>>
where
T: mio::deprecated::TryWrite,
{
let result = writer.try_write(&self.bytes[self.cursor..]);
if let Ok(Some(bytes_written)) = result {
self.cursor += bytes_written;
#[inline]
fn remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
#[inline]
fn has_remaining(&self) -> bool {
self.remaining() > 0
}
#[allow(deprecated)]
fn try_write_to<T>(&mut self, mut writer: T) -> io::Result<Option<usize>>
where
T: mio::deprecated::TryWrite,
{
let result = writer.try_write(&self.bytes[self.cursor..]);
if let Ok(Some(bytes_written)) = result {
self.cursor += bytes_written;
}
result
}
result
}
}
/*========*
@ -60,171 +60,170 @@ impl OutBuf {
/// This trait is implemented by packet sinks to which a stream can forward
/// the packets it reads.
pub trait SendPacket {
type Value: ReadFromPacket;
type Error: error::Error;
type Value: ReadFromPacket;
type Error: error::Error;
fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>;
fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>;
fn notify_open(&mut self) -> Result<(), Self::Error>;
fn notify_open(&mut self) -> Result<(), Self::Error>;
}
/// This enum defines the possible actions the stream wants to take after
/// processing an event.
#[derive(Debug, Clone, Copy)]
pub enum Intent {
/// The stream is done, the event loop handler can drop it.
Done,
/// The stream wants to wait for the next event matching the given
/// `EventSet`.
Continue(mio::Ready),
/// The stream is done, the event loop handler can drop it.
Done,
/// The stream wants to wait for the next event matching the given
/// `EventSet`.
Continue(mio::Ready),
}
/// This struct wraps around an mio tcp stream and handles packet reads and
/// writes.
#[derive(Debug)]
pub struct Stream<T: SendPacket> {
parser: Parser,
queue: VecDeque<OutBuf>,
sender: T,
stream: mio::tcp::TcpStream,
parser: Parser,
queue: VecDeque<OutBuf>,
sender: T,
stream: mio::tcp::TcpStream,
is_connected: bool,
is_connected: bool,
}
impl<T: SendPacket> Stream<T> {
/// Returns a new stream, asynchronously connected to the given address,
/// which forwards incoming packets to the given sender.
/// If an error occurs when connecting, returns an error.
pub fn new<U>(addr_spec: U, sender: T) -> io::Result<Self>
where
U: ToSocketAddrs + fmt::Debug,
{
for sock_addr in addr_spec.to_socket_addrs()? {
if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) {
return Ok(Stream {
parser: Parser::new(),
queue: VecDeque::new(),
sender: sender,
stream: stream,
is_connected: false,
});
}
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot connect to {:?}", addr_spec),
))
}
/// Returns a reference to the underlying byte stream, to allow it to be
/// registered with an event loop.
pub fn evented(&self) -> &mio::tcp::TcpStream {
&self.stream
}
/// The stream is ready to be read from.
fn on_readable(&mut self) -> Result<(), String> {
loop {
let mut packet = match self.parser.try_read(&mut self.stream) {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(e) => return Err(format!("Error reading stream: {}", e)),
};
let value = match packet.read_value() {
Ok(value) => value,
Err(e) => return Err(format!("Error parsing packet: {}", e)),
};
if let Err(e) = self.sender.send_packet(value) {
return Err(format!("Error sending parsed packet: {}", e));
}
}
Ok(())
}
/// The stream is ready to be written to.
fn on_writable(&mut self) -> io::Result<()> {
loop {
let mut outbuf = match self.queue.pop_front() {
Some(outbuf) => outbuf,
None => break,
};
let option = outbuf.try_write_to(&mut self.stream)?;
match option {
Some(_) => {
if outbuf.has_remaining() {
self.queue.push_front(outbuf)
}
// Continue looping
}
None => {
self.queue.push_front(outbuf);
break;
/// Returns a new stream, asynchronously connected to the given address,
/// which forwards incoming packets to the given sender.
/// If an error occurs when connecting, returns an error.
pub fn new<U>(addr_spec: U, sender: T) -> io::Result<Self>
where
U: ToSocketAddrs + fmt::Debug,
{
for sock_addr in addr_spec.to_socket_addrs()? {
if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) {
return Ok(Stream {
parser: Parser::new(),
queue: VecDeque::new(),
sender: sender,
stream: stream,
is_connected: false,
});
}
}
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot connect to {:?}", addr_spec),
))
}
Ok(())
}
/// The stream is ready to read, write, or both.
pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent {
#[allow(deprecated)]
if event_set.is_hup() || event_set.is_error() {
return Intent::Done;
}
if event_set.is_readable() {
let result = self.on_readable();
if let Err(e) = result {
error!("Stream input error: {}", e);
return Intent::Done;
}
/// Returns a reference to the underlying byte stream, to allow it to be
/// registered with an event loop.
pub fn evented(&self) -> &mio::tcp::TcpStream {
&self.stream
}
if event_set.is_writable() {
let result = self.on_writable();
if let Err(e) = result {
error!("Stream output error: {}", e);
return Intent::Done;
}
/// The stream is ready to be read from.
fn on_readable(&mut self) -> Result<(), String> {
loop {
let mut packet = match self.parser.try_read(&mut self.stream) {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(e) => return Err(format!("Error reading stream: {}", e)),
};
let value = match packet.read_value() {
Ok(value) => value,
Err(e) => return Err(format!("Error parsing packet: {}", e)),
};
if let Err(e) = self.sender.send_packet(value) {
return Err(format!("Error sending parsed packet: {}", e));
}
}
Ok(())
}
// We must have read or written something succesfully if we're here,
// so the stream must be connected.
if !self.is_connected {
// If we weren't already connected, notify the sink.
if let Err(err) = self.sender.notify_open() {
error!("Cannot notify client that stream is open: {}", err);
return Intent::Done;
}
// And record the fact that we are now connected.
self.is_connected = true;
/// The stream is ready to be written to.
fn on_writable(&mut self) -> io::Result<()> {
loop {
let mut outbuf = match self.queue.pop_front() {
Some(outbuf) => outbuf,
None => break,
};
let option = outbuf.try_write_to(&mut self.stream)?;
match option {
Some(_) => {
if outbuf.has_remaining() {
self.queue.push_front(outbuf)
}
// Continue looping
}
None => {
self.queue.push_front(outbuf);
break;
}
}
}
Ok(())
}
// We're always interested in reading more.
#[allow(deprecated)]
let mut event_set =
mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error();
// If there is still stuff to write in the queue, we're interested in
// the socket becoming writable too.
if self.queue.len() > 0 {
event_set = event_set | mio::Ready::writable();
/// The stream is ready to read, write, or both.
pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent {
#[allow(deprecated)]
if event_set.is_hup() || event_set.is_error() {
return Intent::Done;
}
if event_set.is_readable() {
let result = self.on_readable();
if let Err(e) = result {
error!("Stream input error: {}", e);
return Intent::Done;
}
}
if event_set.is_writable() {
let result = self.on_writable();
if let Err(e) = result {
error!("Stream output error: {}", e);
return Intent::Done;
}
}
// We must have read or written something succesfully if we're here,
// so the stream must be connected.
if !self.is_connected {
// If we weren't already connected, notify the sink.
if let Err(err) = self.sender.notify_open() {
error!("Cannot notify client that stream is open: {}", err);
return Intent::Done;
}
// And record the fact that we are now connected.
self.is_connected = true;
}
// We're always interested in reading more.
#[allow(deprecated)]
let mut event_set = mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error();
// If there is still stuff to write in the queue, we're interested in
// the socket becoming writable too.
if self.queue.len() > 0 {
event_set = event_set | mio::Ready::writable();
}
Intent::Continue(event_set)
}
Intent::Continue(event_set)
}
/// The stream has been notified.
pub fn on_notify<V>(&mut self, payload: &V) -> Intent
where
V: WriteToPacket,
{
let mut packet = MutPacket::new();
let result = packet.write_value(payload);
if let Err(e) = result {
error!("Error writing payload to packet: {}", e);
return Intent::Done;
/// The stream has been notified.
pub fn on_notify<V>(&mut self, payload: &V) -> Intent
where
V: WriteToPacket,
{
let mut packet = MutPacket::new();
let result = packet.write_value(payload);
if let Err(e) = result {
error!("Error writing payload to packet: {}", e);
return Intent::Done;
}
self.queue.push_back(OutBuf::from(packet.into_bytes()));
Intent::Continue(mio::Ready::readable() | mio::Ready::writable())
}
self.queue.push_back(OutBuf::from(packet.into_bytes()));
Intent::Continue(mio::Ready::readable() | mio::Ready::writable())
}
}

+ 46
- 47
proto/tests/connect.rs View File

@ -4,84 +4,83 @@ use tokio::net;
use tokio::sync::mpsc;
use solstice_proto::server::{
Client, ClientOptions, Credentials, ServerRequest, ServerResponse,
UserStatusRequest, Version,
Client, ClientOptions, Credentials, ServerResponse, UserStatusRequest, Version,
};
// Enable capturing logs in tests.
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
let _ = env_logger::builder().is_test(true).try_init();
}
async fn connect() -> io::Result<net::TcpStream> {
net::TcpStream::connect("server.slsknet.org:2242").await
net::TcpStream::connect("server.slsknet.org:2242").await
}
fn make_user_name(test_name: &str) -> String {
format!("st_{}", test_name)
format!("st_{}", test_name)
}
fn client_options(user_name: String) -> ClientOptions {
let password = "abcdefgh".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
let password = "abcdefgh".to_string();
let credentials = Credentials::new(user_name, password).unwrap();
ClientOptions {
credentials,
version: Version::default(),
}
ClientOptions {
credentials,
version: Version::default(),
}
}
#[tokio::test]
async fn integration_connect() {
init();
init();
let stream = connect().await.unwrap();
let stream = connect().await.unwrap();
let options = client_options(make_user_name("connect"));
let client = Client::login(stream, options).await.unwrap();
let options = client_options(make_user_name("connect"));
let client = Client::login(stream, options).await.unwrap();
let mut inbound = client.run(stream::pending());
let mut inbound = client.run(stream::pending());
assert!(inbound.next().await.is_some());
assert!(inbound.next().await.is_some());
}
#[tokio::test]
async fn integration_check_user_status() {
init();
let stream = connect().await.unwrap();
init();
let user_name = make_user_name("check_user_status");
let options = client_options(user_name.clone());
let client = Client::login(stream, options).await.unwrap();
let stream = connect().await.unwrap();
let (request_tx, mut request_rx) = mpsc::channel(1);
let outbound = Box::pin(async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
});
let user_name = make_user_name("check_user_status");
let options = client_options(user_name.clone());
let client = Client::login(stream, options).await.unwrap();
let mut inbound = client.run(outbound);
let (request_tx, mut request_rx) = mpsc::channel(1);
request_tx
.send(
UserStatusRequest {
user_name: user_name.clone(),
let outbound = Box::pin(async_stream::stream! {
while let Some(request) = request_rx.recv().await {
yield request;
}
.into(),
)
.await
.unwrap();
while let Some(result) = inbound.next().await {
let response = result.unwrap();
if let ServerResponse::UserStatusResponse(response) = response {
assert_eq!(response.user_name, user_name);
return;
});
let mut inbound = client.run(outbound);
request_tx
.send(
UserStatusRequest {
user_name: user_name.clone(),
}
.into(),
)
.await
.unwrap();
while let Some(result) = inbound.next().await {
let response = result.unwrap();
if let ServerResponse::UserStatusResponse(response) = response {
assert_eq!(response.user_name, user_name);
return;
}
}
}
unreachable!();
unreachable!();
}

Loading…
Cancel
Save