Browse Source

Set indent width to 2.

wip
Titouan Rigoudy 4 years ago
parent
commit
381ec05247
28 changed files with 5396 additions and 5694 deletions
  1. +1
    -0
      rustfmt.toml
  2. +695
    -713
      src/client.rs
  3. +21
    -21
      src/context.rs
  4. +16
    -16
      src/control/request.rs
  5. +46
    -46
      src/control/response.rs
  6. +134
    -136
      src/control/ws.rs
  7. +55
    -64
      src/dispatcher.rs
  8. +103
    -103
      src/executor.rs
  9. +29
    -33
      src/handlers/login_handler.rs
  10. +34
    -34
      src/handlers/set_privileged_users_handler.rs
  11. +10
    -10
      src/login.rs
  12. +31
    -32
      src/main.rs
  13. +4
    -4
      src/message_handler.rs
  14. +320
    -322
      src/proto/frame.rs
  15. +281
    -286
      src/proto/handler.rs
  16. +2
    -2
      src/proto/mod.rs
  17. +251
    -254
      src/proto/packet.rs
  18. +147
    -159
      src/proto/peer/message.rs
  19. +88
    -88
      src/proto/prefix.rs
  20. +486
    -542
      src/proto/server/request.rs
  21. +1237
    -1346
      src/proto/server/response.rs
  22. +165
    -165
      src/proto/stream.rs
  23. +44
    -44
      src/proto/testing.rs
  24. +2
    -2
      src/proto/u32.rs
  25. +75
    -80
      src/proto/user.rs
  26. +687
    -745
      src/proto/value_codec.rs
  27. +298
    -312
      src/room.rs
  28. +134
    -135
      src/user.rs

+ 1
- 0
rustfmt.toml View File

@ -1 +1,2 @@
tab_spaces = 2
max_width = 80

+ 695
- 713
src/client.rs
File diff suppressed because it is too large
View File


+ 21
- 21
src/context.rs View File

@ -12,35 +12,35 @@ use crate::user::UserMap;
/// Implements `Sync`.
#[derive(Debug)]
pub struct Context {
pub login: Mutex<LoginStatus>,
pub rooms: Mutex<RoomMap>,
pub users: Mutex<UserMap>,
pub login: Mutex<LoginStatus>,
pub rooms: Mutex<RoomMap>,
pub users: Mutex<UserMap>,
}
impl Context {
/// Creates a new empty context.
pub fn new() -> Self {
Self {
login: Mutex::new(LoginStatus::Todo),
rooms: Mutex::new(RoomMap::new()),
users: Mutex::new(UserMap::new()),
}
/// Creates a new empty context.
pub fn new() -> Self {
Self {
login: Mutex::new(LoginStatus::Todo),
rooms: Mutex::new(RoomMap::new()),
users: Mutex::new(UserMap::new()),
}
}
}
#[cfg(test)]
mod tests {
use super::Context;
use super::Context;
#[test]
fn new_context_is_empty() {
let context = Context::new();
assert_eq!(context.rooms.lock().get_room_list(), vec![]);
assert_eq!(context.users.lock().get_list(), vec![]);
}
#[test]
fn new_context_is_empty() {
let context = Context::new();
assert_eq!(context.rooms.lock().get_room_list(), vec![]);
assert_eq!(context.users.lock().get_list(), vec![]);
}
#[test]
fn context_is_sync() {
let _sync: &dyn Sync = &Context::new();
}
#[test]
fn context_is_sync() {
let _sync: &dyn Sync = &Context::new();
}
}

+ 16
- 16
src/control/request.rs View File

@ -2,25 +2,25 @@
/// controller client to the client.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub enum Request {
/// The controller wants to join a room. Contains the room name.
RoomJoinRequest(String),
/// The controller wants to leave a rom. Contains the room name.
RoomLeaveRequest(String),
/// The controller wants to know what the login status is.
LoginStatusRequest,
/// The controller wants to know the list of visible chat rooms.
RoomListRequest,
/// The controller wants to send a message to a chat room.
RoomMessageRequest(RoomMessageRequest),
/// The controller wants to know the list of known users.
UserListRequest,
/// The controller wants to join a room. Contains the room name.
RoomJoinRequest(String),
/// The controller wants to leave a rom. Contains the room name.
RoomLeaveRequest(String),
/// The controller wants to know what the login status is.
LoginStatusRequest,
/// The controller wants to know the list of visible chat rooms.
RoomListRequest,
/// The controller wants to send a message to a chat room.
RoomMessageRequest(RoomMessageRequest),
/// The controller wants to know the list of known users.
UserListRequest,
}
/// This structure contains the chat room message request from the controller.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub struct RoomMessageRequest {
/// The name of the chat room in which to send the message.
pub room_name: String,
/// The message to be said.
pub message: String,
/// The name of the chat room in which to send the message.
pub room_name: String,
/// The message to be said.
pub message: String,
}

+ 46
- 46
src/control/response.rs View File

@ -5,98 +5,98 @@ use crate::room;
/// to the controller.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub enum Response {
LoginStatusResponse(LoginStatusResponse),
RoomJoinResponse(RoomJoinResponse),
RoomLeaveResponse(RoomLeaveResponse),
RoomListResponse(RoomListResponse),
RoomMessageResponse(RoomMessageResponse),
RoomUserJoinedResponse(RoomUserJoinedResponse),
RoomUserLeftResponse(RoomUserLeftResponse),
UserInfoResponse(UserInfoResponse),
UserListResponse(UserListResponse),
LoginStatusResponse(LoginStatusResponse),
RoomJoinResponse(RoomJoinResponse),
RoomLeaveResponse(RoomLeaveResponse),
RoomListResponse(RoomListResponse),
RoomMessageResponse(RoomMessageResponse),
RoomUserJoinedResponse(RoomUserJoinedResponse),
RoomUserLeftResponse(RoomUserLeftResponse),
UserInfoResponse(UserInfoResponse),
UserListResponse(UserListResponse),
}
#[derive(Debug, RustcEncodable, RustcDecodable)]
pub struct RoomJoinResponse {
pub room_name: String,
pub room_name: String,
}
#[derive(Debug, RustcEncodable, RustcDecodable)]
pub struct RoomLeaveResponse {
pub room_name: String,
pub room_name: String,
}
/// This enumeration is the list of possible login states, and the associated
/// information.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub enum LoginStatusResponse {
/// The login request has been sent to the server, but the response hasn't
/// been received yet.
Pending {
/// The username used to log in.
username: String,
},
/// The login request has been sent to the server, but the response hasn't
/// been received yet.
Pending {
/// The username used to log in.
username: String,
},
/// Login was successful.
Success {
/// The username used to log in.
username: String,
/// The message of the day sent by the server.
motd: String,
},
/// Login was successful.
Success {
/// The username used to log in.
username: String,
/// The message of the day sent by the server.
motd: String,
},
/// Login failed.
Failure {
/// The username used to log in.
username: String,
/// The reason the server gave for refusing the login request.
reason: String,
},
/// Login failed.
Failure {
/// The username used to log in.
username: String,
/// The reason the server gave for refusing the login request.
reason: String,
},
}
/// This structure contains the list of all visible rooms, and their associated
/// data.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub struct RoomListResponse {
/// The list of (room name, room data) pairs.
pub rooms: Vec<(String, room::Room)>,
/// The list of (room name, room data) pairs.
pub rooms: Vec<(String, room::Room)>,
}
/// This structure contains a message said in a chat room the user is a member
/// of.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub struct RoomMessageResponse {
/// The name of the room in which the message was said.
pub room_name: String,
/// The name of the user who said the message.
pub user_name: String,
/// The message itself.
pub message: String,
/// The name of the room in which the message was said.
pub room_name: String,
/// The name of the user who said the message.
pub user_name: String,
/// The message itself.
pub message: String,
}
/// This struct describes the fact that the given user joined the given room.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub struct RoomUserJoinedResponse {
pub room_name: String,
pub user_name: String,
pub room_name: String,
pub user_name: String,
}
/// This struct describes the fact that the given user left the given room.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub struct RoomUserLeftResponse {
pub room_name: String,
pub user_name: String,
pub room_name: String,
pub user_name: String,
}
/// This struct contains the last known information about a given user.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub struct UserInfoResponse {
pub user_name: String,
pub user_info: User,
pub user_name: String,
pub user_info: User,
}
/// This stuct contains the last known information about every user.
#[derive(Debug, RustcDecodable, RustcEncodable)]
pub struct UserListResponse {
pub user_list: Vec<(String, User)>,
pub user_list: Vec<(String, User)>,
}

+ 134
- 136
src/control/ws.rs View File

@ -14,65 +14,65 @@ use super::response::*;
/// send to the client.
#[derive(Debug)]
pub enum Notification {
/// A new controller has connected: control messages can now be sent on the
/// given channel.
Connected(Sender),
/// The controller has disconnected.
Disconnected,
/// An irretrievable error has arisen.
Error(String),
/// The controller has sent a request.
Request(Request),
/// A new controller has connected: control messages can now be sent on the
/// given channel.
Connected(Sender),
/// The controller has disconnected.
Disconnected,
/// An irretrievable error has arisen.
Error(String),
/// The controller has sent a request.
Request(Request),
}
/// This error is returned when a `Sender` fails to send a control request.
#[derive(Debug)]
pub enum SendError {
/// Error encoding the control request.
JSONEncoderError(json::EncoderError),
/// Error sending the encoded control request to the websocket.
WebSocketError(ws::Error),
/// Error encoding the control request.
JSONEncoderError(json::EncoderError),
/// Error sending the encoded control request to the websocket.
WebSocketError(ws::Error),
}
impl fmt::Display for SendError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
SendError::JSONEncoderError(ref err) => {
write!(fmt, "JSONEncoderError: {}", err)
}
SendError::WebSocketError(ref err) => {
write!(fmt, "WebSocketError: {}", err)
}
}
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
SendError::JSONEncoderError(ref err) => {
write!(fmt, "JSONEncoderError: {}", err)
}
SendError::WebSocketError(ref err) => {
write!(fmt, "WebSocketError: {}", err)
}
}
}
}
impl error::Error for SendError {
fn description(&self) -> &str {
match *self {
SendError::JSONEncoderError(_) => "JSONEncoderError",
SendError::WebSocketError(_) => "WebSocketError",
}
fn description(&self) -> &str {
match *self {
SendError::JSONEncoderError(_) => "JSONEncoderError",
SendError::WebSocketError(_) => "WebSocketError",
}
}
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
SendError::JSONEncoderError(ref err) => Some(err),
SendError::WebSocketError(ref err) => Some(err),
}
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
SendError::JSONEncoderError(ref err) => Some(err),
SendError::WebSocketError(ref err) => Some(err),
}
}
}
impl From<json::EncoderError> for SendError {
fn from(err: json::EncoderError) -> Self {
SendError::JSONEncoderError(err)
}
fn from(err: json::EncoderError) -> Self {
SendError::JSONEncoderError(err)
}
}
impl From<ws::Error> for SendError {
fn from(err: ws::Error) -> Self {
SendError::WebSocketError(err)
}
fn from(err: ws::Error) -> Self {
SendError::WebSocketError(err)
}
}
/// This struct is used to send control responses to the controller.
@ -80,125 +80,123 @@ impl From<ws::Error> for SendError {
/// the underlying implementation.
#[derive(Clone, Debug)]
pub struct Sender {
sender: ws::Sender,
sender: ws::Sender,
}
impl Sender {
/// Queues up a control response to be sent to the controller.
pub fn send(&mut self, response: Response) -> Result<(), SendError> {
let encoded = json::encode(&response)?;
self.sender.send(encoded)?;
Ok(())
}
/// Queues up a control response to be sent to the controller.
pub fn send(&mut self, response: Response) -> Result<(), SendError> {
let encoded = json::encode(&response)?;
self.sender.send(encoded)?;
Ok(())
}
}
/// This struct handles a single websocket connection.
#[derive(Debug)]
struct Handler {
/// The channel on which to send notifications to the client.
client_tx: crossbeam_channel::Sender<Notification>,
/// The channel on which to send messages to the controller.
socket_tx: ws::Sender,
/// The channel on which to send notifications to the client.
client_tx: crossbeam_channel::Sender<Notification>,
/// The channel on which to send messages to the controller.
socket_tx: ws::Sender,
}
impl Handler {
fn send_to_client(&self, notification: Notification) -> ws::Result<()> {
match self.client_tx.send(notification) {
Ok(()) => Ok(()),
Err(e) => {
error!("Error sending notification to client: {}", e);
Err(ws::Error::new(ws::ErrorKind::Internal, ""))
}
}
fn send_to_client(&self, notification: Notification) -> ws::Result<()> {
match self.client_tx.send(notification) {
Ok(()) => Ok(()),
Err(e) => {
error!("Error sending notification to client: {}", e);
Err(ws::Error::new(ws::ErrorKind::Internal, ""))
}
}
}
}
impl ws::Handler for Handler {
fn on_open(&mut self, _: ws::Handshake) -> ws::Result<()> {
info!("Websocket open");
self.send_to_client(Notification::Connected(Sender {
sender: self.socket_tx.clone(),
}))
}
fn on_open(&mut self, _: ws::Handshake) -> ws::Result<()> {
info!("Websocket open");
self.send_to_client(Notification::Connected(Sender {
sender: self.socket_tx.clone(),
}))
}
fn on_close(&mut self, code: ws::CloseCode, reason: &str) {
info!("Websocket closed: code: {:?}, reason: {:?}", code, reason);
self
.send_to_client(Notification::Disconnected)
.unwrap_or(())
}
fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
// Get the payload string.
let payload = match msg {
ws::Message::Text(payload) => payload,
ws::Message::Binary(_) => {
error!("Received binary websocket message from controller");
return Err(ws::Error::new(
ws::ErrorKind::Protocol,
"Binary message not supported",
));
}
};
fn on_close(&mut self, code: ws::CloseCode, reason: &str) {
info!("Websocket closed: code: {:?}, reason: {:?}", code, reason);
self.send_to_client(Notification::Disconnected)
.unwrap_or(())
}
// Decode the json control request.
let control_request = match json::decode(&payload) {
Ok(control_request) => control_request,
Err(e) => {
error!("Received invalid JSON message from controller: {}", e);
return Err(ws::Error::new(ws::ErrorKind::Protocol, "Invalid JSON"));
}
};
fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
// Get the payload string.
let payload = match msg {
ws::Message::Text(payload) => payload,
ws::Message::Binary(_) => {
error!("Received binary websocket message from controller");
return Err(ws::Error::new(
ws::ErrorKind::Protocol,
"Binary message not supported",
));
}
};
// Decode the json control request.
let control_request = match json::decode(&payload) {
Ok(control_request) => control_request,
Err(e) => {
error!("Received invalid JSON message from controller: {}", e);
return Err(ws::Error::new(
ws::ErrorKind::Protocol,
"Invalid JSON",
));
}
};
debug!("Received control request: {:?}", control_request);
// Send the control request to the client.
self.send_to_client(Notification::Request(control_request))
}
debug!("Received control request: {:?}", control_request);
// Send the control request to the client.
self.send_to_client(Notification::Request(control_request))
}
}
/// Start listening on the socket address stored in configuration, and send
/// control notifications to the client through the given channel.
pub fn listen(client_tx: crossbeam_channel::Sender<Notification>) {
let websocket_result = ws::Builder::new()
.with_settings(ws::Settings {
max_connections: 1,
..ws::Settings::default()
})
.build(|socket_tx| Handler {
client_tx: client_tx.clone(),
socket_tx: socket_tx,
});
let websocket = match websocket_result {
Ok(websocket) => websocket,
Err(e) => {
error!("Unable to build websocket: {}", e);
client_tx
.send(Notification::Error(format!(
"Unable to build websocket: {}",
e
)))
.unwrap();
return;
}
};
let listen_result =
websocket.listen((config::CONTROL_HOST, config::CONTROL_PORT));
match listen_result {
Ok(_) => (),
Err(e) => {
error!("Unable to listen on websocket: {}", e);
client_tx
.send(Notification::Error(format!(
"Unable to listen on websocket: {}",
e
)))
.unwrap();
}
let websocket_result = ws::Builder::new()
.with_settings(ws::Settings {
max_connections: 1,
..ws::Settings::default()
})
.build(|socket_tx| Handler {
client_tx: client_tx.clone(),
socket_tx: socket_tx,
});
let websocket = match websocket_result {
Ok(websocket) => websocket,
Err(e) => {
error!("Unable to build websocket: {}", e);
client_tx
.send(Notification::Error(format!(
"Unable to build websocket: {}",
e
)))
.unwrap();
return;
}
};
let listen_result =
websocket.listen((config::CONTROL_HOST, config::CONTROL_PORT));
match listen_result {
Ok(_) => (),
Err(e) => {
error!("Unable to listen on websocket: {}", e);
client_tx
.send(Notification::Error(format!(
"Unable to listen on websocket: {}",
e
)))
.unwrap();
}
}
}

+ 55
- 64
src/dispatcher.rs View File

@ -11,97 +11,88 @@ use crate::proto::server::ServerResponse;
/// The type of messages dispatched by a dispatcher.
#[derive(Debug)]
pub enum Message {
ServerResponse(ServerResponse),
ServerResponse(ServerResponse),
}
/// Pairs together a message and its handler as chosen by the dispatcher.
/// Implements Job so as to be scheduled on an executor.
struct DispatchedMessage<M, H> {
message: M,
handler: H,
message: M,
handler: H,
}
impl<M, H> DispatchedMessage<M, H> {
fn new(message: M, handler: H) -> Self {
Self { message, handler }
}
fn new(message: M, handler: H) -> Self {
Self { message, handler }
}
}
impl<M, H> Job for DispatchedMessage<M, H>
where
M: Debug + Send,
H: MessageHandler<M> + Send,
M: Debug + Send,
H: MessageHandler<M> + Send,
{
fn execute(self: Box<Self>, context: &Context) {
if let Err(error) = self.handler.run(context, &self.message) {
error!(
"Error in handler {}: {:?}\nMessage: {:?}",
H::name(),
error,
&self.message
);
}
fn execute(self: Box<Self>, context: &Context) {
if let Err(error) = self.handler.run(context, &self.message) {
error!(
"Error in handler {}: {:?}\nMessage: {:?}",
H::name(),
error,
&self.message
);
}
}
}
/// The Dispatcher is in charge of mapping messages to their handlers.
pub struct Dispatcher;
impl Dispatcher {
/// Returns a new dispatcher.
pub fn new() -> Self {
Self {}
}
/// Returns a new dispatcher.
pub fn new() -> Self {
Self {}
}
/// Dispatches the given message by wrapping it with a handler.
pub fn dispatch(&self, message: Message) -> Box<dyn Job> {
match message {
Message::ServerResponse(ServerResponse::LoginResponse(
response,
)) => Box::new(DispatchedMessage::new(
response,
LoginHandler::default(),
)),
Message::ServerResponse(
ServerResponse::PrivilegedUsersResponse(response),
) => Box::new(DispatchedMessage::new(
response,
SetPrivilegedUsersHandler::default(),
)),
_ => panic!("Unimplemented"),
}
/// Dispatches the given message by wrapping it with a handler.
pub fn dispatch(&self, message: Message) -> Box<dyn Job> {
match message {
Message::ServerResponse(ServerResponse::LoginResponse(response)) => {
Box::new(DispatchedMessage::new(response, LoginHandler::default()))
}
Message::ServerResponse(ServerResponse::PrivilegedUsersResponse(
response,
)) => Box::new(DispatchedMessage::new(
response,
SetPrivilegedUsersHandler::default(),
)),
_ => panic!("Unimplemented"),
}
}
}
#[cfg(test)]
mod tests {
use crate::proto::server;
use crate::proto::server;
use super::*;
use super::*;
#[test]
fn dispatcher_privileged_users_response() {
Dispatcher::new().dispatch(Message::ServerResponse(
server::ServerResponse::PrivilegedUsersResponse(
server::PrivilegedUsersResponse {
users: vec![
"foo".to_string(),
"bar".to_string(),
"baz".to_string(),
],
},
),
));
}
#[test]
fn dispatcher_privileged_users_response() {
Dispatcher::new().dispatch(Message::ServerResponse(
server::ServerResponse::PrivilegedUsersResponse(
server::PrivilegedUsersResponse {
users: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
},
),
));
}
#[test]
fn dispatcher_login_response() {
Dispatcher::new().dispatch(Message::ServerResponse(
server::ServerResponse::LoginResponse(
server::LoginResponse::LoginFail {
reason: "bleep bloop".to_string(),
},
),
));
}
#[test]
fn dispatcher_login_response() {
Dispatcher::new().dispatch(Message::ServerResponse(
server::ServerResponse::LoginResponse(server::LoginResponse::LoginFail {
reason: "bleep bloop".to_string(),
}),
));
}
}

+ 103
- 103
src/executor.rs View File

@ -15,133 +15,133 @@ const NUM_THREADS: usize = 8;
/// The trait of objects that can be run by an Executor.
pub trait Job: Send {
/// Runs this job against the given context.
fn execute(self: Box<Self>, context: &Context);
/// Runs this job against the given context.
fn execute(self: Box<Self>, context: &Context);
}
/// A concurrent job execution engine.
pub struct Executor {
/// The context against which jobs are executed.
context: Arc<Context>,
/// The context against which jobs are executed.
context: Arc<Context>,
/// Executes the jobs.
pool: threadpool::ThreadPool,
/// Executes the jobs.
pool: threadpool::ThreadPool,
}
impl Executor {
/// Builds a new executor against the given context.
pub fn new(context: Context) -> Self {
Self {
context: Arc::new(context),
pool: threadpool::Builder::new()
.num_threads(NUM_THREADS)
.thread_name("Executor".to_string())
.build(),
}
}
/// Schedules execution of the given job on this executor.
pub fn schedule(&self, job: Box<dyn Job>) {
let context = self.context.clone();
self.pool.execute(move || job.execute(&*context));
}
/// Blocks until all scheduled jobs are executed, then returns the context.
pub fn join(self) -> Context {
self.pool.join();
// The only copies of the Arc are passed to the closures executed on
// the threadpool. Once the pool is join()ed, there cannot exist any
// other copies than ours, so we are safe to unwrap() the Arc.
Arc::try_unwrap(self.context).unwrap()
/// Builds a new executor against the given context.
pub fn new(context: Context) -> Self {
Self {
context: Arc::new(context),
pool: threadpool::Builder::new()
.num_threads(NUM_THREADS)
.thread_name("Executor".to_string())
.build(),
}
}
/// Schedules execution of the given job on this executor.
pub fn schedule(&self, job: Box<dyn Job>) {
let context = self.context.clone();
self.pool.execute(move || job.execute(&*context));
}
/// Blocks until all scheduled jobs are executed, then returns the context.
pub fn join(self) -> Context {
self.pool.join();
// The only copies of the Arc are passed to the closures executed on
// the threadpool. Once the pool is join()ed, there cannot exist any
// other copies than ours, so we are safe to unwrap() the Arc.
Arc::try_unwrap(self.context).unwrap()
}
}
#[cfg(test)]
mod tests {
use std::sync::{Arc, Barrier};
use std::sync::{Arc, Barrier};
use crate::proto::{User, UserStatus};
use crate::proto::{User, UserStatus};
use super::{Context, Executor, Job};
use super::{Context, Executor, Job};
#[test]
fn immediate_join_returns_empty_context() {
let context = Executor::new(Context::new()).join();
assert_eq!(context.users.lock().get_list(), vec![]);
assert_eq!(context.rooms.lock().get_room_list(), vec![]);
}
#[test]
fn immediate_join_returns_empty_context() {
let context = Executor::new(Context::new()).join();
assert_eq!(context.users.lock().get_list(), vec![]);
assert_eq!(context.rooms.lock().get_room_list(), vec![]);
}
struct Waiter {
barrier: Arc<Barrier>,
}
struct Waiter {
barrier: Arc<Barrier>,
}
impl Job for Waiter {
fn execute(self: Box<Self>, _context: &Context) {
self.barrier.wait();
}
impl Job for Waiter {
fn execute(self: Box<Self>, _context: &Context) {
self.barrier.wait();
}
}
#[test]
fn join_waits_for_all_jobs() {
let executor = Executor::new(Context::new());
#[test]
fn join_waits_for_all_jobs() {
let executor = Executor::new(Context::new());
let barrier = Arc::new(Barrier::new(2));
let barrier = Arc::new(Barrier::new(2));
executor.schedule(Box::new(Waiter {
barrier: barrier.clone(),
}));
executor.schedule(Box::new(Waiter {
barrier: barrier.clone(),
}));
executor.schedule(Box::new(Waiter {
barrier: barrier.clone(),
}));
executor.schedule(Box::new(Waiter {
barrier: barrier.clone(),
}));
executor.join();
}
executor.join();
}
struct UserAdder {
pub user: User,
}
impl Job for UserAdder {
fn execute(self: Box<Self>, context: &Context) {
context.users.lock().insert(self.user);
}
}
struct UserAdder {
pub user: User,
}
#[test]
fn jobs_access_context() {
let executor = Executor::new(Context::new());
let user1 = User {
name: "potato".to_string(),
status: UserStatus::Offline,
average_speed: 0,
num_downloads: 0,
unknown: 0,
num_files: 0,
num_folders: 0,
num_free_slots: 0,
country: "YO".to_string(),
};
let mut user2 = user1.clone();
user2.name = "rutabaga".to_string();
executor.schedule(Box::new(UserAdder {
user: user1.clone(),
}));
executor.schedule(Box::new(UserAdder {
user: user2.clone(),
}));
let context = executor.join();
let expected_users =
vec![(user1.name.clone(), user1), (user2.name.clone(), user2)];
let mut users = context.users.lock().get_list();
users.sort();
assert_eq!(users, expected_users);
impl Job for UserAdder {
fn execute(self: Box<Self>, context: &Context) {
context.users.lock().insert(self.user);
}
}
#[test]
fn jobs_access_context() {
let executor = Executor::new(Context::new());
let user1 = User {
name: "potato".to_string(),
status: UserStatus::Offline,
average_speed: 0,
num_downloads: 0,
unknown: 0,
num_files: 0,
num_folders: 0,
num_free_slots: 0,
country: "YO".to_string(),
};
let mut user2 = user1.clone();
user2.name = "rutabaga".to_string();
executor.schedule(Box::new(UserAdder {
user: user1.clone(),
}));
executor.schedule(Box::new(UserAdder {
user: user2.clone(),
}));
let context = executor.join();
let expected_users =
vec![(user1.name.clone(), user1), (user2.name.clone(), user2)];
let mut users = context.users.lock().get_list();
users.sort();
assert_eq!(users, expected_users);
}
}

+ 29
- 33
src/handlers/login_handler.rs View File

@ -9,44 +9,40 @@ use crate::proto::server::LoginResponse;
pub struct LoginHandler;
impl MessageHandler<LoginResponse> for LoginHandler {
fn run(
self,
context: &Context,
_message: &LoginResponse,
) -> io::Result<()> {
let lock = context.login.lock();
match *lock {
LoginStatus::AwaitingResponse => (),
_ => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("unexpected login response, status = {:?}", *lock),
));
}
};
unimplemented!();
}
fn name() -> String {
"LoginHandler".to_string()
}
fn run(self, context: &Context, _message: &LoginResponse) -> io::Result<()> {
let lock = context.login.lock();
match *lock {
LoginStatus::AwaitingResponse => (),
_ => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("unexpected login response, status = {:?}", *lock),
));
}
};
unimplemented!();
}
fn name() -> String {
"LoginHandler".to_string()
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::*;
#[test]
#[should_panic]
fn run_fails_on_wrong_status() {
let context = Context::new();
#[test]
#[should_panic]
fn run_fails_on_wrong_status() {
let context = Context::new();
let response = LoginResponse::LoginFail {
reason: "bleep bloop".to_string(),
};
let response = LoginResponse::LoginFail {
reason: "bleep bloop".to_string(),
};
LoginHandler::default().run(&context, &response).unwrap();
}
LoginHandler::default().run(&context, &response).unwrap();
}
}

+ 34
- 34
src/handlers/set_privileged_users_handler.rs View File

@ -8,48 +8,48 @@ use crate::proto::server::PrivilegedUsersResponse;
pub struct SetPrivilegedUsersHandler;
impl MessageHandler<PrivilegedUsersResponse> for SetPrivilegedUsersHandler {
fn run(
self,
context: &Context,
message: &PrivilegedUsersResponse,
) -> io::Result<()> {
let users = message.users.clone();
context.users.lock().set_all_privileged(users);
Ok(())
}
fn name() -> String {
"SetPrivilegedUsersHandler".to_string()
}
fn run(
self,
context: &Context,
message: &PrivilegedUsersResponse,
) -> io::Result<()> {
let users = message.users.clone();
context.users.lock().set_all_privileged(users);
Ok(())
}
fn name() -> String {
"SetPrivilegedUsersHandler".to_string()
}
}
#[cfg(test)]
mod tests {
use crate::context::Context;
use crate::message_handler::MessageHandler;
use crate::proto::server::PrivilegedUsersResponse;
use crate::context::Context;
use crate::message_handler::MessageHandler;
use crate::proto::server::PrivilegedUsersResponse;
use super::SetPrivilegedUsersHandler;
use super::SetPrivilegedUsersHandler;
#[test]
fn run_sets_privileged_users() {
let context = Context::new();
#[test]
fn run_sets_privileged_users() {
let context = Context::new();
let response = PrivilegedUsersResponse {
users: vec![
"aomame".to_string(),
"billybob".to_string(),
"carlos".to_string(),
],
};
let response = PrivilegedUsersResponse {
users: vec![
"aomame".to_string(),
"billybob".to_string(),
"carlos".to_string(),
],
};
SetPrivilegedUsersHandler::default()
.run(&context, &response)
.unwrap();
SetPrivilegedUsersHandler::default()
.run(&context, &response)
.unwrap();
let mut privileged = context.users.lock().get_all_privileged();
privileged.sort();
let mut privileged = context.users.lock().get_all_privileged();
privileged.sort();
assert_eq!(privileged, response.users);
}
assert_eq!(privileged, response.users);
}
}

+ 10
- 10
src/login.rs View File

@ -6,17 +6,17 @@
/// successfully logged in, the client can interact with the server.
#[derive(Clone, Debug)]
pub enum LoginStatus {
/// Request not sent yet.
Todo,
/// Request not sent yet.
Todo,
/// Sent request, awaiting response.
AwaitingResponse,
/// Sent request, awaiting response.
AwaitingResponse,
/// Logged in.
/// Stores the MOTD as received from the server.
Success(String),
/// Logged in.
/// Stores the MOTD as received from the server.
Success(String),
/// Failed to log in.
/// Stores the error message as received from the server.
Failure(String),
/// Failed to log in.
/// Stores the error message as received from the server.
Failure(String),
}

+ 31
- 32
src/main.rs View File

@ -21,36 +21,35 @@ mod room;
mod user;
fn main() {
match env_logger::init() {
Ok(()) => (),
Err(err) => {
error!("Error initializing logger: {}", err);
return;
}
};
let (proto_to_client_tx, proto_to_client_rx) =
crossbeam_channel::unbounded();
let mut proto_agent = match proto::Agent::new(proto_to_client_tx) {
Ok(agent) => agent,
Err(err) => {
error!("Error initializing protocol agent: {}", err);
return;
}
};
let client_to_proto_tx = proto_agent.channel();
let (control_to_client_tx, control_to_client_rx) =
crossbeam_channel::unbounded();
let mut client = client::Client::new(
client_to_proto_tx,
proto_to_client_rx,
control_to_client_rx,
);
thread::spawn(move || control::listen(control_to_client_tx));
thread::spawn(move || proto_agent.run().unwrap());
client.run();
match env_logger::init() {
Ok(()) => (),
Err(err) => {
error!("Error initializing logger: {}", err);
return;
}
};
let (proto_to_client_tx, proto_to_client_rx) = crossbeam_channel::unbounded();
let mut proto_agent = match proto::Agent::new(proto_to_client_tx) {
Ok(agent) => agent,
Err(err) => {
error!("Error initializing protocol agent: {}", err);
return;
}
};
let client_to_proto_tx = proto_agent.channel();
let (control_to_client_tx, control_to_client_rx) =
crossbeam_channel::unbounded();
let mut client = client::Client::new(
client_to_proto_tx,
proto_to_client_rx,
control_to_client_rx,
);
thread::spawn(move || control::listen(control_to_client_tx));
thread::spawn(move || proto_agent.run().unwrap());
client.run();
}

+ 4
- 4
src/message_handler.rs View File

@ -7,9 +7,9 @@ use crate::context::Context;
/// Message types are mapped to handler types by Dispatcher.
/// This trait is intended to allow composing handler logic.
pub trait MessageHandler<Message> {
/// Attempts to handle the given message against the given context.
fn run(self, context: &Context, message: &Message) -> io::Result<()>;
/// Attempts to handle the given message against the given context.
fn run(self, context: &Context, message: &Message) -> io::Result<()>;
/// Returns the name of this handler type.
fn name() -> String;
/// Returns the name of this handler type.
fn name() -> String;
}

+ 320
- 322
src/proto/frame.rs View File

@ -14,374 +14,372 @@ use tokio::net::TcpStream;
use super::prefix::Prefixer;
use super::u32::{decode_u32, U32_BYTE_LEN};
use super::value_codec::{
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder,
};
#[derive(Debug, Error, PartialEq)]
pub enum FrameEncodeError {
#[error("encoded value length {length} is too large")]
ValueTooLarge {
/// The length of the encoded value.
length: usize,
},
#[error("failed to encode value: {0}")]
ValueEncodeError(#[from] ValueEncodeError),
#[error("encoded value length {length} is too large")]
ValueTooLarge {
/// The length of the encoded value.
length: usize,
},
#[error("failed to encode value: {0}")]
ValueEncodeError(#[from] ValueEncodeError),
}
impl From<FrameEncodeError> for io::Error {
fn from(error: FrameEncodeError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, format!("{}", error))
}
fn from(error: FrameEncodeError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, format!("{}", error))
}
}
/// Encodes entire protocol frames containing values of type `T`.
#[derive(Debug)]
pub struct FrameEncoder<T: ?Sized> {
phantom: PhantomData<T>,
phantom: PhantomData<T>,
}
impl<T: ValueEncode + ?Sized> FrameEncoder<T> {
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
}
pub fn encode_to(
&mut self,
value: &T,
buffer: &mut BytesMut,
) -> Result<(), FrameEncodeError> {
let mut prefixer = Prefixer::new(buffer);
ValueEncoder::new(prefixer.suffix_mut()).encode(value)?;
pub fn encode_to(
&mut self,
value: &T,
buffer: &mut BytesMut,
) -> Result<(), FrameEncodeError> {
let mut prefixer = Prefixer::new(buffer);
if let Err(prefixer) = prefixer.finalize() {
return Err(FrameEncodeError::ValueTooLarge {
length: prefixer.suffix().len(),
});
}
ValueEncoder::new(prefixer.suffix_mut()).encode(value)?;
Ok(())
if let Err(prefixer) = prefixer.finalize() {
return Err(FrameEncodeError::ValueTooLarge {
length: prefixer.suffix().len(),
});
}
Ok(())
}
}
/// Decodes entire protocol frames containing values of type `T`.
#[derive(Debug)]
pub struct FrameDecoder<T> {
// Only here to enable parameterizing `Decoder` by `T`.
phantom: PhantomData<T>,
// Only here to enable parameterizing `Decoder` by `T`.
phantom: PhantomData<T>,
}
impl<T: ValueDecode> FrameDecoder<T> {
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
pub fn new() -> Self {
Self {
phantom: PhantomData,
}
}
/// Attempts to decode an entire frame from the given buffer.
///
/// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes
/// have been split off from the left of `bytes`.
///
/// Returns `Ok(None)` if not enough bytes are available to decode an entire
/// frame yet, in which case `bytes` is untouched.
///
/// Returns an error if the length prefix or the framed value are malformed,
/// in which case `bytes` is untouched.
pub fn decode_from(
&mut self,
bytes: &mut BytesMut,
) -> Result<Option<T>, ValueDecodeError> {
if bytes.len() < U32_BYTE_LEN {
return Ok(None); // Not enough bytes yet.
}
/// Attempts to decode an entire frame from the given buffer.
///
/// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes
/// have been split off from the left of `bytes`.
///
/// Returns `Ok(None)` if not enough bytes are available to decode an entire
/// frame yet, in which case `bytes` is untouched.
///
/// Returns an error if the length prefix or the framed value are malformed,
/// in which case `bytes` is untouched.
pub fn decode_from(
&mut self,
bytes: &mut BytesMut,
) -> Result<Option<T>, ValueDecodeError> {
if bytes.len() < U32_BYTE_LEN {
return Ok(None); // Not enough bytes yet.
}
// Split the prefix off. After this:
//
// | bytes (len 4) | suffix |
//
// NOTE: This method would be simpler if we could use split_to() instead
// here such that `bytes` contained the suffix. At the end, we would not
// have to replace `bytes` with `suffix`. However, that would require
// calling `prefix.unsplit(*bytes)`, and that does not work since
// `bytes` is only borrowed, and unsplit() takes its argument by value.
let mut suffix = bytes.split_off(U32_BYTE_LEN);
// unwrap() cannot panic because `bytes` is of the exact right length.
let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap();
let length = decode_u32(array) as usize;
if suffix.len() < length {
// Re-assemble `bytes` as it first was.
bytes.unsplit(suffix);
return Ok(None); // Not enough bytes yet.
}
// Split off the right amount of bytes from the buffer. After this:
//
// | bytes (len 4) | contents | suffix |
//
let mut contents = suffix.split_to(length);
// Attempt to decode the value.
let item = match ValueDecoder::new(&contents).decode() {
Ok(item) => item,
Err(error) => {
// Re-assemble `bytes` as it first was.
contents.unsplit(suffix);
bytes.unsplit(contents);
return Err(error);
}
};
// Remove the decoded bytes from the left of `bytes`.
*bytes = suffix;
Ok(Some(item))
// Split the prefix off. After this:
//
// | bytes (len 4) | suffix |
//
// NOTE: This method would be simpler if we could use split_to() instead
// here such that `bytes` contained the suffix. At the end, we would not
// have to replace `bytes` with `suffix`. However, that would require
// calling `prefix.unsplit(*bytes)`, and that does not work since
// `bytes` is only borrowed, and unsplit() takes its argument by value.
let mut suffix = bytes.split_off(U32_BYTE_LEN);
// unwrap() cannot panic because `bytes` is of the exact right length.
let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap();
let length = decode_u32(array) as usize;
if suffix.len() < length {
// Re-assemble `bytes` as it first was.
bytes.unsplit(suffix);
return Ok(None); // Not enough bytes yet.
}
// Split off the right amount of bytes from the buffer. After this:
//
// | bytes (len 4) | contents | suffix |
//
let mut contents = suffix.split_to(length);
// Attempt to decode the value.
let item = match ValueDecoder::new(&contents).decode() {
Ok(item) => item,
Err(error) => {
// Re-assemble `bytes` as it first was.
contents.unsplit(suffix);
bytes.unsplit(contents);
return Err(error);
}
};
// Remove the decoded bytes from the left of `bytes`.
*bytes = suffix;
Ok(Some(item))
}
}
#[derive(Debug)]
pub struct FrameStream<ReadFrame, WriteFrame: ?Sized> {
stream: TcpStream,
stream: TcpStream,
read_buffer: BytesMut,
read_buffer: BytesMut,
decoder: FrameDecoder<ReadFrame>,
encoder: FrameEncoder<WriteFrame>,
decoder: FrameDecoder<ReadFrame>,
encoder: FrameEncoder<WriteFrame>,
}
impl<ReadFrame, WriteFrame> FrameStream<ReadFrame, WriteFrame>
where
ReadFrame: ValueDecode,
WriteFrame: ValueEncode + ?Sized,
ReadFrame: ValueDecode,
WriteFrame: ValueEncode + ?Sized,
{
pub fn new(stream: TcpStream) -> Self {
FrameStream {
stream,
read_buffer: BytesMut::new(),
decoder: FrameDecoder::new(),
encoder: FrameEncoder::new(),
}
pub fn new(stream: TcpStream) -> Self {
FrameStream {
stream,
read_buffer: BytesMut::new(),
decoder: FrameDecoder::new(),
encoder: FrameEncoder::new(),
}
pub async fn read(&mut self) -> io::Result<ReadFrame> {
loop {
if let Some(frame) =
self.decoder.decode_from(&mut self.read_buffer)?
{
return Ok(frame);
}
self.stream.read_buf(&mut self.read_buffer).await?;
}
}
pub async fn read(&mut self) -> io::Result<ReadFrame> {
loop {
if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? {
return Ok(frame);
}
self.stream.read_buf(&mut self.read_buffer).await?;
}
}
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.stream.write_all(bytes.as_ref()).await
}
pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> {
let mut bytes = BytesMut::new();
self.encoder.encode_to(frame, &mut bytes)?;
self.stream.write_all(bytes.as_ref()).await
}
}
mod tests {
use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder, FrameStream};
// Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
#[test]
fn encode_u32() {
let mut bytes = BytesMut::new();
FrameEncoder::new()
.encode_to(&U32_1337, &mut bytes)
.unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
]
);
}
#[test]
fn encode_appends() {
let mut bytes = BytesMut::new();
let mut encoder = FrameEncoder::new();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 0, 0, 0, // Repeated.
1, 3, 3, 7,
]
);
}
#[test]
fn encode_vec() {
let v: Vec<u32> = vec![1, 3, 3, 7];
let mut bytes = BytesMut::new();
FrameEncoder::new().encode_to(&v, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
]
);
}
#[test]
fn decode_not_enough_data_for_prefix() {
let initial_bytes = vec![
4, 0, 0, // Incomplete 32-bit length prefix.
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_not_enough_data_for_contents() {
let initial_bytes = vec![
4, 0, 0, 0, // Length 4.
1, 2, 3, // But there are only 3 bytes!
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_u32() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, Some(U32_1337));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn decode_vec() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
let expected_value: Vec<u32> = vec![1, 3, 3, 7];
assert_eq!(value, Some(expected_value));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn roundtrip() {
let value: Vec<String> = vec![
"apples".to_string(), //
"bananas".to_string(), //
"oranges".to_string(), //
"and cheese!".to_string(), //
];
let mut buffer = BytesMut::new();
FrameEncoder::new().encode_to(&value, &mut buffer).unwrap();
let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap();
assert_eq!(decoded, Some(value));
assert_eq!(buffer, vec![]);
}
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), "ping");
frame_stream.write("pong").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "ping");
frame_stream.write("pong").await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "pong");
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "pong");
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, Vec<u32>>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), "ping");
frame_stream.write(&vec![0; 10 * 4096]).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<Vec<u32>, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), vec![0; 10 * 4096]);
server_task.await.unwrap();
}
use bytes::BytesMut;
use tokio::net::{TcpListener, TcpStream};
use super::{FrameDecoder, FrameEncoder, FrameStream};
// Test value: [1, 3, 3, 7] in little-endian.
const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24);
#[test]
fn encode_u32() {
let mut bytes = BytesMut::new();
FrameEncoder::new()
.encode_to(&U32_1337, &mut bytes)
.unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
]
);
}
#[test]
fn encode_appends() {
let mut bytes = BytesMut::new();
let mut encoder = FrameEncoder::new();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
encoder.encode_to(&U32_1337, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 0, 0, 0, // Repeated.
1, 3, 3, 7,
]
);
}
#[test]
fn encode_vec() {
let v: Vec<u32> = vec![1, 3, 3, 7];
let mut bytes = BytesMut::new();
FrameEncoder::new().encode_to(&v, &mut bytes).unwrap();
assert_eq!(
bytes,
vec![
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
]
);
}
#[test]
fn decode_not_enough_data_for_prefix() {
let initial_bytes = vec![
4, 0, 0, // Incomplete 32-bit length prefix.
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_not_enough_data_for_contents() {
let initial_bytes = vec![
4, 0, 0, 0, // Length 4.
1, 2, 3, // But there are only 3 bytes!
];
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&initial_bytes);
let value: Option<u32> =
FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, None);
assert_eq!(bytes, initial_bytes); // Untouched.
}
#[test]
fn decode_u32() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
4, 0, 0, 0, // 1 32-bit integer = 4 bytes.
1, 3, 3, 7, // Little-endian integer.
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
assert_eq!(value, Some(U32_1337));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn decode_vec() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[
20, 0, 0, 0, // 5 32-bit integers = 20 bytes.
4, 0, 0, 0, // 4 elements in the vector.
1, 0, 0, 0, // Little-endian vector elements.
3, 0, 0, 0, //
3, 0, 0, 0, //
7, 0, 0, 0, //
4, 2, // Trailing bytes.
]);
let value = FrameDecoder::new().decode_from(&mut bytes).unwrap();
let expected_value: Vec<u32> = vec![1, 3, 3, 7];
assert_eq!(value, Some(expected_value));
assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off.
}
#[test]
fn roundtrip() {
let value: Vec<String> = vec![
"apples".to_string(), //
"bananas".to_string(), //
"oranges".to_string(), //
"and cheese!".to_string(), //
];
let mut buffer = BytesMut::new();
FrameEncoder::new().encode_to(&value, &mut buffer).unwrap();
let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap();
assert_eq!(decoded, Some(value));
assert_eq!(buffer, vec![]);
}
#[tokio::test]
async fn ping_pong() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), "ping");
frame_stream.write("pong").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "ping");
frame_stream.write("pong").await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<String, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "pong");
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), "pong");
server_task.await.unwrap();
}
#[tokio::test]
async fn very_large_message() {
let listener = TcpListener::bind("localhost:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let (stream, _peer_address) = listener.accept().await.unwrap();
let mut frame_stream = FrameStream::<String, Vec<u32>>::new(stream);
assert_eq!(frame_stream.read().await.unwrap(), "ping");
frame_stream.write(&vec![0; 10 * 4096]).await.unwrap();
});
let stream = TcpStream::connect(address).await.unwrap();
let mut frame_stream = FrameStream::<Vec<u32>, str>::new(stream);
frame_stream.write("ping").await.unwrap();
assert_eq!(frame_stream.read().await.unwrap(), vec![0; 10 * 4096]);
server_task.await.unwrap();
}
}

+ 281
- 286
src/proto/handler.rs View File

@ -30,17 +30,17 @@ const LISTEN_TOKEN: usize = config::MAX_PEERS + 1;
#[derive(Debug)]
pub enum Request {
PeerConnect(usize, net::Ipv4Addr, u16),
PeerMessage(usize, peer::Message),
ServerRequest(ServerRequest),
PeerConnect(usize, net::Ipv4Addr, u16),
PeerMessage(usize, peer::Message),
ServerRequest(ServerRequest),
}
#[derive(Debug)]
pub enum Response {
PeerConnectionClosed(usize),
PeerConnectionOpen(usize),
PeerMessage(usize, peer::Message),
ServerResponse(ServerResponse),
PeerConnectionClosed(usize),
PeerConnectionOpen(usize),
PeerMessage(usize, peer::Message),
ServerResponse(ServerResponse),
}
/*========================*
@ -50,16 +50,16 @@ pub enum Response {
pub struct ServerResponseSender(crossbeam_channel::Sender<Response>);
impl SendPacket for ServerResponseSender {
type Value = ServerResponse;
type Error = crossbeam_channel::SendError<Response>;
type Value = ServerResponse;
type Error = crossbeam_channel::SendError<Response>;
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.0.send(Response::ServerResponse(value))
}
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.0.send(Response::ServerResponse(value))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
Ok(())
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
Ok(())
}
}
/*======================*
@ -67,21 +67,21 @@ impl SendPacket for ServerResponseSender {
*======================*/
pub struct PeerResponseSender {
sender: crossbeam_channel::Sender<Response>,
peer_id: usize,
sender: crossbeam_channel::Sender<Response>,
peer_id: usize,
}
impl SendPacket for PeerResponseSender {
type Value = peer::Message;
type Error = crossbeam_channel::SendError<Response>;
type Value = peer::Message;
type Error = crossbeam_channel::SendError<Response>;
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.sender.send(Response::PeerMessage(self.peer_id, value))
}
fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> {
self.sender.send(Response::PeerMessage(self.peer_id, value))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
self.sender.send(Response::PeerConnectionOpen(self.peer_id))
}
fn notify_open(&mut self) -> Result<(), Self::Error> {
self.sender.send(Response::PeerConnectionOpen(self.peer_id))
}
}
/*=========*
@ -91,307 +91,302 @@ impl SendPacket for PeerResponseSender {
/// This struct handles all the soulseek connections, to the server and to
/// peers.
struct Handler {
server_stream: Stream<ServerResponseSender>,
server_stream: Stream<ServerResponseSender>,
peer_streams: slab::Slab<Stream<PeerResponseSender>, usize>,
peer_streams: slab::Slab<Stream<PeerResponseSender>, usize>,
listener: mio::tcp::TcpListener,
listener: mio::tcp::TcpListener,
client_tx: crossbeam_channel::Sender<Response>,
client_tx: crossbeam_channel::Sender<Response>,
}
fn listener_bind<U>(addr_spec: U) -> io::Result<mio::tcp::TcpListener>
where
U: ToSocketAddrs + fmt::Debug,
U: ToSocketAddrs + fmt::Debug,
{
for socket_addr in addr_spec.to_socket_addrs()? {
if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) {
return Ok(listener);
}
for socket_addr in addr_spec.to_socket_addrs()? {
if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) {
return Ok(listener);
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot bind to {:?}", addr_spec),
))
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot bind to {:?}", addr_spec),
))
}
impl Handler {
#[allow(deprecated)]
fn new(
client_tx: crossbeam_channel::Sender<Response>,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> io::Result<Self> {
let host = config::SERVER_HOST;
let port = config::SERVER_PORT;
let server_stream =
Stream::new((host, port), ServerResponseSender(client_tx.clone()))?;
info!("Connected to server at {}:{}", host, port);
let listener =
listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?;
info!(
"Listening for connections on {}:{}",
config::LISTEN_HOST,
config::LISTEN_PORT
);
event_loop.register(
server_stream.evented(),
#[allow(deprecated)]
fn new(
client_tx: crossbeam_channel::Sender<Response>,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> io::Result<Self> {
let host = config::SERVER_HOST;
let port = config::SERVER_PORT;
let server_stream =
Stream::new((host, port), ServerResponseSender(client_tx.clone()))?;
info!("Connected to server at {}:{}", host, port);
let listener = listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?;
info!(
"Listening for connections on {}:{}",
config::LISTEN_HOST,
config::LISTEN_PORT
);
event_loop.register(
server_stream.evented(),
mio::Token(SERVER_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
event_loop.register(
&listener,
mio::Token(LISTEN_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
Ok(Handler {
server_stream: server_stream,
peer_streams: slab::Slab::new(config::MAX_PEERS),
listener: listener,
client_tx: client_tx,
})
}
#[allow(deprecated)]
fn connect_to_peer(
&mut self,
peer_id: usize,
ip: net::Ipv4Addr,
port: u16,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> Result<(), String> {
let vacant_entry = match self.peer_streams.entry(peer_id) {
None => return Err("id out of range".to_string()),
Some(slab::Entry::Occupied(_occupied_entry)) => {
return Err("id already taken".to_string());
}
Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry,
};
info!("Opening peer connection {} to {}:{}", peer_id, ip, port);
let sender = PeerResponseSender {
sender: self.client_tx.clone(),
peer_id: peer_id,
};
let peer_stream = match Stream::new((ip, port), sender) {
Ok(peer_stream) => peer_stream,
Err(err) => return Err(format!("i/o error: {}", err)),
};
event_loop
.register(
peer_stream.evented(),
mio::Token(peer_id),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
vacant_entry.insert(peer_stream);
Ok(())
}
#[allow(deprecated)]
fn process_server_intent(
&mut self,
intent: Intent,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
error!("Server connection closed");
// TODO notify client and shut down
}
Intent::Continue(event_set) => {
event_loop
.reregister(
self.server_stream.evented(),
mio::Token(SERVER_TOKEN),
mio::Ready::all(),
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
event_loop.register(
&listener,
mio::Token(LISTEN_TOKEN),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)?;
Ok(Handler {
server_stream: server_stream,
peer_streams: slab::Slab::new(config::MAX_PEERS),
listener: listener,
client_tx: client_tx,
})
)
.unwrap();
}
}
#[allow(deprecated)]
fn connect_to_peer(
&mut self,
peer_id: usize,
ip: net::Ipv4Addr,
port: u16,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) -> Result<(), String> {
let vacant_entry = match self.peer_streams.entry(peer_id) {
None => return Err("id out of range".to_string()),
Some(slab::Entry::Occupied(_occupied_entry)) => {
return Err("id already taken".to_string());
}
Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry,
};
info!("Opening peer connection {} to {}:{}", peer_id, ip, port);
let sender = PeerResponseSender {
sender: self.client_tx.clone(),
peer_id: peer_id,
};
let peer_stream = match Stream::new((ip, port), sender) {
Ok(peer_stream) => peer_stream,
Err(err) => return Err(format!("i/o error: {}", err)),
};
event_loop
.register(
peer_stream.evented(),
mio::Token(peer_id),
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
}
#[allow(deprecated)]
fn process_peer_intent(
&mut self,
intent: Intent,
token: mio::Token,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
self.peer_streams.remove(token.0);
self
.client_tx
.send(Response::PeerConnectionClosed(token.0))
.unwrap();
}
Intent::Continue(event_set) => {
if let Some(peer_stream) = self.peer_streams.get_mut(token.0) {
event_loop
.reregister(
peer_stream.evented(),
token,
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
vacant_entry.insert(peer_stream);
Ok(())
}
#[allow(deprecated)]
fn process_server_intent(
&mut self,
intent: Intent,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
error!("Server connection closed");
// TODO notify client and shut down
}
Intent::Continue(event_set) => {
event_loop
.reregister(
self.server_stream.evented(),
mio::Token(SERVER_TOKEN),
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
}
}
#[allow(deprecated)]
fn process_peer_intent(
&mut self,
intent: Intent,
token: mio::Token,
event_loop: &mut mio::deprecated::EventLoop<Self>,
) {
match intent {
Intent::Done => {
self.peer_streams.remove(token.0);
self.client_tx
.send(Response::PeerConnectionClosed(token.0))
.unwrap();
}
Intent::Continue(event_set) => {
if let Some(peer_stream) = self.peer_streams.get_mut(token.0) {
event_loop
.reregister(
peer_stream.evented(),
token,
event_set,
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
}
}
}
}
}
}
#[allow(deprecated)]
impl mio::deprecated::Handler for Handler {
type Timeout = ();
type Message = Request;
fn ready(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
token: mio::Token,
event_set: mio::Ready,
) {
match token {
mio::Token(LISTEN_TOKEN) => {
if event_set.is_readable() {
// A peer wants to connect to us.
match self.listener.accept() {
Ok((_sock, addr)) => {
// TODO add it to peer streams
info!("Peer connection accepted from {}", addr);
}
Err(err) => {
error!("Cannot accept peer connection: {}", err);
}
}
}
event_loop
.reregister(
&self.listener,
token,
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
type Timeout = ();
type Message = Request;
fn ready(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
token: mio::Token,
event_set: mio::Ready,
) {
match token {
mio::Token(LISTEN_TOKEN) => {
if event_set.is_readable() {
// A peer wants to connect to us.
match self.listener.accept() {
Ok((_sock, addr)) => {
// TODO add it to peer streams
info!("Peer connection accepted from {}", addr);
}
mio::Token(SERVER_TOKEN) => {
let intent = self.server_stream.on_ready(event_set);
self.process_server_intent(intent, event_loop);
}
mio::Token(peer_id) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_ready(event_set),
None => unreachable!("Unknown peer {} is ready", peer_id),
};
self.process_peer_intent(intent, token, event_loop);
Err(err) => {
error!("Cannot accept peer connection: {}", err);
}
}
}
}
event_loop
.reregister(
&self.listener,
token,
mio::Ready::all(),
mio::PollOpt::edge() | mio::PollOpt::oneshot(),
)
.unwrap();
}
fn notify(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
request: Request,
) {
match request {
Request::PeerConnect(peer_id, ip, port) => {
if let Err(err) =
self.connect_to_peer(peer_id, ip, port, event_loop)
{
error!(
"Cannot open peer connection {} to {}:{}: {}",
peer_id, ip, port, err
);
self.client_tx
.send(Response::PeerConnectionClosed(peer_id))
.unwrap();
}
}
mio::Token(SERVER_TOKEN) => {
let intent = self.server_stream.on_ready(event_set);
self.process_server_intent(intent, event_loop);
}
Request::PeerMessage(peer_id, message) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_notify(&message),
None => {
error!(
"Cannot send peer message {:?}: unknown id {}",
message, peer_id
);
return;
}
};
self.process_peer_intent(
intent,
mio::Token(peer_id),
event_loop,
);
}
mio::Token(peer_id) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_ready(event_set),
Request::ServerRequest(server_request) => {
let intent = self.server_stream.on_notify(&server_request);
self.process_server_intent(intent, event_loop);
}
None => unreachable!("Unknown peer {} is ready", peer_id),
};
self.process_peer_intent(intent, token, event_loop);
}
}
}
fn notify(
&mut self,
event_loop: &mut mio::deprecated::EventLoop<Self>,
request: Request,
) {
match request {
Request::PeerConnect(peer_id, ip, port) => {
if let Err(err) = self.connect_to_peer(peer_id, ip, port, event_loop) {
error!(
"Cannot open peer connection {} to {}:{}: {}",
peer_id, ip, port, err
);
self
.client_tx
.send(Response::PeerConnectionClosed(peer_id))
.unwrap();
}
}
Request::PeerMessage(peer_id, message) => {
let intent = match self.peer_streams.get_mut(peer_id) {
Some(peer_stream) => peer_stream.on_notify(&message),
None => {
error!(
"Cannot send peer message {:?}: unknown id {}",
message, peer_id
);
return;
}
};
self.process_peer_intent(intent, mio::Token(peer_id), event_loop);
}
Request::ServerRequest(server_request) => {
let intent = self.server_stream.on_notify(&server_request);
self.process_server_intent(intent, event_loop);
}
}
}
}
#[allow(deprecated)]
pub type Sender = mio::deprecated::Sender<Request>;
pub struct Agent {
#[allow(deprecated)]
event_loop: mio::deprecated::EventLoop<Handler>,
handler: Handler,
#[allow(deprecated)]
event_loop: mio::deprecated::EventLoop<Handler>,
handler: Handler,
}
impl Agent {
pub fn new(
client_tx: crossbeam_channel::Sender<Response>,
) -> io::Result<Self> {
// Create the event loop.
#[allow(deprecated)]
let mut event_loop = mio::deprecated::EventLoop::new()?;
// Create the handler for the event loop and register the handler's
// sockets with the event loop.
let handler = Handler::new(client_tx, &mut event_loop)?;
Ok(Agent {
event_loop: event_loop,
handler: handler,
})
}
pub fn channel(&self) -> Sender {
#[allow(deprecated)]
self.event_loop.channel()
}
pub fn new(
client_tx: crossbeam_channel::Sender<Response>,
) -> io::Result<Self> {
// Create the event loop.
#[allow(deprecated)]
let mut event_loop = mio::deprecated::EventLoop::new()?;
// Create the handler for the event loop and register the handler's
// sockets with the event loop.
let handler = Handler::new(client_tx, &mut event_loop)?;
Ok(Agent {
event_loop: event_loop,
handler: handler,
})
}
pub fn channel(&self) -> Sender {
#[allow(deprecated)]
self.event_loop.channel()
}
pub fn run(&mut self) -> io::Result<()> {
#[allow(deprecated)]
self.event_loop.run(&mut self.handler)
}
pub fn run(&mut self) -> io::Result<()> {
#[allow(deprecated)]
self.event_loop.run(&mut self.handler)
}
}

+ 2
- 2
src/proto/mod.rs View File

@ -19,6 +19,6 @@ pub use self::server::{ServerRequest, ServerResponse};
pub use self::stream::*;
pub use self::user::{User, UserStatus};
pub use self::value_codec::{
Decode, ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode,
ValueEncodeError, ValueEncoder,
Decode, ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode,
ValueEncodeError, ValueEncoder,
};

+ 251
- 254
src/proto/packet.rs View File

@ -19,46 +19,46 @@ use super::constants::*;
#[derive(Debug)]
pub struct Packet {
/// The current read position in the byte buffer.
cursor: usize,
/// The underlying bytes.
bytes: Vec<u8>,
/// The current read position in the byte buffer.
cursor: usize,
/// The underlying bytes.
bytes: Vec<u8>,
}
impl io::Read for Packet {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let bytes_read = {
let mut slice = &self.bytes[self.cursor..];
slice.read(buf)?
};
self.cursor += bytes_read;
Ok(bytes_read)
}
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let bytes_read = {
let mut slice = &self.bytes[self.cursor..];
slice.read(buf)?
};
self.cursor += bytes_read;
Ok(bytes_read)
}
}
impl Packet {
/// Returns a readable packet struct from the wire representation of a
/// packet.
/// Assumes that the given vector is a valid length-prefixed packet.
fn from_wire(bytes: Vec<u8>) -> Self {
Packet {
cursor: U32_SIZE,
bytes: bytes,
}
}
/// Provides the main way to read data out of a binary packet.
pub fn read_value<T>(&mut self) -> Result<T, PacketReadError>
where
T: ReadFromPacket,
{
T::read_from_packet(self)
}
/// Returns the number of unread bytes remaining in the packet.
pub fn bytes_remaining(&self) -> usize {
self.bytes.len() - self.cursor
/// Returns a readable packet struct from the wire representation of a
/// packet.
/// Assumes that the given vector is a valid length-prefixed packet.
fn from_wire(bytes: Vec<u8>) -> Self {
Packet {
cursor: U32_SIZE,
bytes: bytes,
}
}
/// Provides the main way to read data out of a binary packet.
pub fn read_value<T>(&mut self) -> Result<T, PacketReadError>
where
T: ReadFromPacket,
{
T::read_from_packet(self)
}
/// Returns the number of unread bytes remaining in the packet.
pub fn bytes_remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
}
/*===================*
@ -67,45 +67,45 @@ impl Packet {
#[derive(Debug)]
pub struct MutPacket {
bytes: Vec<u8>,
bytes: Vec<u8>,
}
impl MutPacket {
/// Returns an empty packet with the given packet code.
pub fn new() -> Self {
// Leave space for the eventual size of the packet.
MutPacket {
bytes: vec![0; U32_SIZE],
}
/// Returns an empty packet with the given packet code.
pub fn new() -> Self {
// Leave space for the eventual size of the packet.
MutPacket {
bytes: vec![0; U32_SIZE],
}
/// Provides the main way to write data into a binary packet.
pub fn write_value<T>(&mut self, val: &T) -> io::Result<()>
where
T: WriteToPacket,
}
/// Provides the main way to write data into a binary packet.
pub fn write_value<T>(&mut self, val: &T) -> io::Result<()>
where
T: WriteToPacket,
{
val.write_to_packet(self)
}
/// Consumes the mutable packet and returns its wire representation.
pub fn into_bytes(mut self) -> Vec<u8> {
let length = (self.bytes.len() - U32_SIZE) as u32;
{
val.write_to_packet(self)
}
/// Consumes the mutable packet and returns its wire representation.
pub fn into_bytes(mut self) -> Vec<u8> {
let length = (self.bytes.len() - U32_SIZE) as u32;
{
let mut first_word = &mut self.bytes[..U32_SIZE];
first_word.write_u32::<LittleEndian>(length).unwrap();
}
self.bytes
let mut first_word = &mut self.bytes[..U32_SIZE];
first_word.write_u32::<LittleEndian>(length).unwrap();
}
self.bytes
}
}
impl io::Write for MutPacket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.bytes.write(buf)
}
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.bytes.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.bytes.flush()
}
fn flush(&mut self) -> io::Result<()> {
self.bytes.flush()
}
}
/*===================*
@ -115,70 +115,68 @@ impl io::Write for MutPacket {
/// This enum contains an error that arose when reading data out of a Packet.
#[derive(Debug)]
pub enum PacketReadError {
/// Attempted to read a boolean, but the value was not 0 nor 1.
InvalidBoolError(u8),
/// Attempted to read an unsigned 16-bit integer, but the value was too
/// large.
InvalidU16Error(u32),
/// Attempted to read a string, but a character was invalid.
InvalidStringError(Vec<u8>),
/// Attempted to read a user::Status, but the value was not a valid
/// representation of an enum variant.
InvalidUserStatusError(u32),
/// Encountered an I/O error while reading.
IOError(io::Error),
/// Attempted to read a boolean, but the value was not 0 nor 1.
InvalidBoolError(u8),
/// Attempted to read an unsigned 16-bit integer, but the value was too
/// large.
InvalidU16Error(u32),
/// Attempted to read a string, but a character was invalid.
InvalidStringError(Vec<u8>),
/// Attempted to read a user::Status, but the value was not a valid
/// representation of an enum variant.
InvalidUserStatusError(u32),
/// Encountered an I/O error while reading.
IOError(io::Error),
}
impl fmt::Display for PacketReadError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
PacketReadError::InvalidBoolError(n) => {
write!(fmt, "InvalidBoolError: {}", n)
}
PacketReadError::InvalidU16Error(n) => {
write!(fmt, "InvalidU16Error: {}", n)
}
PacketReadError::InvalidStringError(ref bytes) => {
write!(fmt, "InvalidStringError: {:?}", bytes)
}
PacketReadError::InvalidUserStatusError(n) => {
write!(fmt, "InvalidUserStatusError: {}", n)
}
PacketReadError::IOError(ref err) => {
write!(fmt, "IOError: {}", err)
}
}
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
PacketReadError::InvalidBoolError(n) => {
write!(fmt, "InvalidBoolError: {}", n)
}
PacketReadError::InvalidU16Error(n) => {
write!(fmt, "InvalidU16Error: {}", n)
}
PacketReadError::InvalidStringError(ref bytes) => {
write!(fmt, "InvalidStringError: {:?}", bytes)
}
PacketReadError::InvalidUserStatusError(n) => {
write!(fmt, "InvalidUserStatusError: {}", n)
}
PacketReadError::IOError(ref err) => {
write!(fmt, "IOError: {}", err)
}
}
}
}
impl error::Error for PacketReadError {
fn description(&self) -> &str {
match *self {
PacketReadError::InvalidBoolError(_) => "InvalidBoolError",
PacketReadError::InvalidU16Error(_) => "InvalidU16Error",
PacketReadError::InvalidStringError(_) => "InvalidStringError",
PacketReadError::InvalidUserStatusError(_) => {
"InvalidUserStatusError"
}
PacketReadError::IOError(_) => "IOError",
}
fn description(&self) -> &str {
match *self {
PacketReadError::InvalidBoolError(_) => "InvalidBoolError",
PacketReadError::InvalidU16Error(_) => "InvalidU16Error",
PacketReadError::InvalidStringError(_) => "InvalidStringError",
PacketReadError::InvalidUserStatusError(_) => "InvalidUserStatusError",
PacketReadError::IOError(_) => "IOError",
}
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
PacketReadError::InvalidBoolError(_) => None,
PacketReadError::InvalidU16Error(_) => None,
PacketReadError::InvalidStringError(_) => None,
PacketReadError::InvalidUserStatusError(_) => None,
PacketReadError::IOError(ref err) => Some(err),
}
}
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
PacketReadError::InvalidBoolError(_) => None,
PacketReadError::InvalidU16Error(_) => None,
PacketReadError::InvalidStringError(_) => None,
PacketReadError::InvalidUserStatusError(_) => None,
PacketReadError::IOError(ref err) => Some(err),
}
}
}
impl From<io::Error> for PacketReadError {
fn from(err: io::Error) -> Self {
PacketReadError::IOError(err)
}
fn from(err: io::Error) -> Self {
PacketReadError::IOError(err)
}
}
/*==================*
@ -188,81 +186,81 @@ impl From<io::Error> for PacketReadError {
/// This trait is implemented by types that can be deserialized from binary
/// Packets.
pub trait ReadFromPacket: Sized {
fn read_from_packet(_: &mut Packet) -> Result<Self, PacketReadError>;
fn read_from_packet(_: &mut Packet) -> Result<Self, PacketReadError>;
}
/// 32-bit integers are serialized in 4 bytes, little-endian.
impl ReadFromPacket for u32 {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(packet.read_u32::<LittleEndian>()?)
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(packet.read_u32::<LittleEndian>()?)
}
}
/// For convenience, usize's are deserialized as u32's then casted.
impl ReadFromPacket for usize {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(u32::read_from_packet(packet)? as usize)
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
Ok(u32::read_from_packet(packet)? as usize)
}
}
/// Booleans are serialized as single bytes, containing either 0 or 1.
impl ReadFromPacket for bool {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
match packet.read_u8()? {
0 => Ok(false),
1 => Ok(true),
n => Err(PacketReadError::InvalidBoolError(n)),
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
match packet.read_u8()? {
0 => Ok(false),
1 => Ok(true),
n => Err(PacketReadError::InvalidBoolError(n)),
}
}
}
/// 16-bit integers are serialized as 32-bit integers.
impl ReadFromPacket for u16 {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n = u32::read_from_packet(packet)?;
if n > MAX_PORT {
return Err(PacketReadError::InvalidU16Error(n));
}
Ok(n as u16)
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n = u32::read_from_packet(packet)?;
if n > MAX_PORT {
return Err(PacketReadError::InvalidU16Error(n));
}
Ok(n as u16)
}
}
/// IPv4 addresses are serialized directly as 32-bit integers.
impl ReadFromPacket for net::Ipv4Addr {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let ip = u32::read_from_packet(packet)?;
Ok(net::Ipv4Addr::from(ip))
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let ip = u32::read_from_packet(packet)?;
Ok(net::Ipv4Addr::from(ip))
}
}
/// Strings are serialized as length-prefixed arrays of ISO-8859-1 encoded
/// characters.
impl ReadFromPacket for String {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
let mut buffer = vec![0; len];
packet.read_exact(&mut buffer)?;
let mut buffer = vec![0; len];
packet.read_exact(&mut buffer)?;
match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) {
Ok(string) => Ok(string),
Err(_) => Err(PacketReadError::InvalidStringError(buffer)),
}
match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) {
Ok(string) => Ok(string),
Err(_) => Err(PacketReadError::InvalidStringError(buffer)),
}
}
}
/// Vectors are serialized as length-prefixed arrays of values.
impl<T: ReadFromPacket> ReadFromPacket for Vec<T> {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let len = usize::read_from_packet(packet)?;
let mut vec = Vec::new();
for _ in 0..len {
vec.push(T::read_from_packet(packet)?);
}
Ok(vec)
let mut vec = Vec::new();
for _ in 0..len {
vec.push(T::read_from_packet(packet)?);
}
Ok(vec)
}
}
/*=================*
@ -272,55 +270,55 @@ impl<T: ReadFromPacket> ReadFromPacket for Vec<T> {
/// This trait is implemented by types that can be serialized to a binary
/// MutPacket.
pub trait WriteToPacket {
fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>;
fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>;
}
/// 32-bit integers are serialized in 4 bytes, little-endian.
impl WriteToPacket for u32 {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u32::<LittleEndian>(*self)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u32::<LittleEndian>(*self)
}
}
/// Booleans are serialized as single bytes, containing either 0 or 1.
impl WriteToPacket for bool {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u8(*self as u8)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_u8(*self as u8)?;
Ok(())
}
}
/// 16-bit integers are serialized as 32-bit integers.
impl WriteToPacket for u16 {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(*self as u32).write_to_packet(packet)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(*self as u32).write_to_packet(packet)
}
}
/// Strings are serialized as a length-prefixed array of ISO-8859-1 encoded
/// characters.
impl WriteToPacket for str {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
// Encode the string.
let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) {
Ok(bytes) => bytes,
Err(_) => {
let copy = self.to_string();
return Err(io::Error::new(io::ErrorKind::Other, copy));
}
};
// Then write the bytes to the packet.
(bytes.len() as u32).write_to_packet(packet)?;
packet.write(&bytes)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
// Encode the string.
let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) {
Ok(bytes) => bytes,
Err(_) => {
let copy = self.to_string();
return Err(io::Error::new(io::ErrorKind::Other, copy));
}
};
// Then write the bytes to the packet.
(bytes.len() as u32).write_to_packet(packet)?;
packet.write(&bytes)?;
Ok(())
}
}
/// Deref coercion does not happen for trait methods apparently.
impl WriteToPacket for String {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(self as &str).write_to_packet(packet)
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
(self as &str).write_to_packet(packet)
}
}
/*========*
@ -330,87 +328,86 @@ impl WriteToPacket for String {
/// This enum defines the possible states of a packet parser state machine.
#[derive(Debug, Clone, Copy)]
enum State {
/// The parser is waiting to read enough bytes to determine the
/// length of the following packet.
ReadingLength,
/// The parser is waiting to read enough bytes to form the entire
/// packet.
ReadingPacket,
/// The parser is waiting to read enough bytes to determine the
/// length of the following packet.
ReadingLength,
/// The parser is waiting to read enough bytes to form the entire
/// packet.
ReadingPacket,
}
#[derive(Debug)]
pub struct Parser {
state: State,
num_bytes_left: usize,
buffer: Vec<u8>,
state: State,
num_bytes_left: usize,
buffer: Vec<u8>,
}
impl Parser {
pub fn new() -> Self {
Parser {
state: State::ReadingLength,
num_bytes_left: U32_SIZE,
buffer: vec![0; U32_SIZE],
}
pub fn new() -> Self {
Parser {
state: State::ReadingLength,
num_bytes_left: U32_SIZE,
buffer: vec![0; U32_SIZE],
}
}
/// Attemps to read a packet in a non-blocking fashion.
/// If enough bytes can be read from the given byte stream to form a
/// complete packet `p`, returns `Ok(Some(p))`.
/// If not enough bytes are available, returns `Ok(None)`.
/// If an I/O error `e` arises when trying to read the underlying stream,
/// returns `Err(e)`.
/// Note: as long as this function returns `Ok(Some(p))`, the caller is
/// responsible for calling it once more to ensure that all packets are
/// read as soon as possible.
pub fn try_read<U>(&mut self, stream: &mut U) -> io::Result<Option<Packet>>
where
U: io::Read,
{
// Try to read as many bytes as we currently need from the underlying
// byte stream.
let offset = self.buffer.len() - self.num_bytes_left;
#[allow(deprecated)]
match stream.try_read(&mut self.buffer[offset..])? {
None => (),
Some(num_bytes_read) => {
self.num_bytes_left -= num_bytes_read;
}
}
/// Attemps to read a packet in a non-blocking fashion.
/// If enough bytes can be read from the given byte stream to form a
/// complete packet `p`, returns `Ok(Some(p))`.
/// If not enough bytes are available, returns `Ok(None)`.
/// If an I/O error `e` arises when trying to read the underlying stream,
/// returns `Err(e)`.
/// Note: as long as this function returns `Ok(Some(p))`, the caller is
/// responsible for calling it once more to ensure that all packets are
/// read as soon as possible.
pub fn try_read<U>(&mut self, stream: &mut U) -> io::Result<Option<Packet>>
where
U: io::Read,
{
// Try to read as many bytes as we currently need from the underlying
// byte stream.
let offset = self.buffer.len() - self.num_bytes_left;
#[allow(deprecated)]
match stream.try_read(&mut self.buffer[offset..])? {
None => (),
Some(num_bytes_read) => {
self.num_bytes_left -= num_bytes_read;
}
}
// If we haven't read enough bytes, return.
if self.num_bytes_left > 0 {
return Ok(None);
}
// Otherwise, the behavior depends on what state we were in.
match self.state {
State::ReadingLength => {
// If we have finished reading the length prefix, then
// deserialize it, switch states and try to read the packet
// bytes.
let message_len =
LittleEndian::read_u32(&mut self.buffer) as usize;
if message_len > MAX_MESSAGE_SIZE {
unimplemented!();
};
self.state = State::ReadingPacket;
self.num_bytes_left = message_len;
self.buffer.resize(message_len + U32_SIZE, 0);
self.try_read(stream)
}
State::ReadingPacket => {
// If we have finished reading the packet, swap the full buffer
// out and return the packet made from the full buffer.
self.state = State::ReadingLength;
self.num_bytes_left = U32_SIZE;
let new_buffer = vec![0; U32_SIZE];
let old_buffer = mem::replace(&mut self.buffer, new_buffer);
Ok(Some(Packet::from_wire(old_buffer)))
}
}
// If we haven't read enough bytes, return.
if self.num_bytes_left > 0 {
return Ok(None);
}
// Otherwise, the behavior depends on what state we were in.
match self.state {
State::ReadingLength => {
// If we have finished reading the length prefix, then
// deserialize it, switch states and try to read the packet
// bytes.
let message_len = LittleEndian::read_u32(&mut self.buffer) as usize;
if message_len > MAX_MESSAGE_SIZE {
unimplemented!();
};
self.state = State::ReadingPacket;
self.num_bytes_left = message_len;
self.buffer.resize(message_len + U32_SIZE, 0);
self.try_read(stream)
}
State::ReadingPacket => {
// If we have finished reading the packet, swap the full buffer
// out and return the packet made from the full buffer.
self.state = State::ReadingLength;
self.num_bytes_left = U32_SIZE;
let new_buffer = vec![0; U32_SIZE];
let old_buffer = mem::replace(&mut self.buffer, new_buffer);
Ok(Some(Packet::from_wire(old_buffer)))
}
}
}
}

+ 147
- 159
src/proto/peer/message.rs View File

@ -2,9 +2,9 @@ use std::io;
use crate::proto::peer::constants::*;
use crate::proto::{
MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode,
ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder, WriteToPacket,
MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode,
ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
WriteToPacket,
};
/*=========*
@ -14,195 +14,183 @@ use crate::proto::{
/// This enum contains all the possible messages peers can exchange.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message {
PierceFirewall(u32),
PeerInit(PeerInit),
Unknown(u32),
PierceFirewall(u32),
PeerInit(PeerInit),
Unknown(u32),
}
impl ReadFromPacket for Message {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let code: u32 = packet.read_value()?;
let message = match code {
CODE_PIERCE_FIREWALL => {
Message::PierceFirewall(packet.read_value()?)
}
CODE_PEER_INIT => Message::PeerInit(packet.read_value()?),
code => Message::Unknown(code),
};
let bytes_remaining = packet.bytes_remaining();
if bytes_remaining > 0 {
warn!(
"Peer message with code {} contains {} extra bytes",
code, bytes_remaining
)
}
Ok(message)
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let code: u32 = packet.read_value()?;
let message = match code {
CODE_PIERCE_FIREWALL => Message::PierceFirewall(packet.read_value()?),
CODE_PEER_INIT => Message::PeerInit(packet.read_value()?),
code => Message::Unknown(code),
};
let bytes_remaining = packet.bytes_remaining();
if bytes_remaining > 0 {
warn!(
"Peer message with code {} contains {} extra bytes",
code, bytes_remaining
)
}
Ok(message)
}
}
impl ValueDecode for Message {
fn decode_from(
decoder: &mut ValueDecoder,
) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let code: u32 = decoder.decode()?;
let message = match code {
CODE_PIERCE_FIREWALL => {
let val = decoder.decode()?;
Message::PierceFirewall(val)
}
CODE_PEER_INIT => {
let peer_init = decoder.decode()?;
Message::PeerInit(peer_init)
}
_ => {
return Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: format!("unknown value {}", code),
position: position,
})
}
};
Ok(message)
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let code: u32 = decoder.decode()?;
let message = match code {
CODE_PIERCE_FIREWALL => {
let val = decoder.decode()?;
Message::PierceFirewall(val)
}
CODE_PEER_INIT => {
let peer_init = decoder.decode()?;
Message::PeerInit(peer_init)
}
_ => {
return Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: format!("unknown value {}", code),
position: position,
})
}
};
Ok(message)
}
}
impl ValueEncode for Message {
fn encode(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
match *self {
Message::PierceFirewall(token) => {
encoder.encode_u32(CODE_PIERCE_FIREWALL)?;
encoder.encode_u32(token)?;
}
Message::PeerInit(ref request) => {
encoder.encode_u32(CODE_PEER_INIT)?;
request.encode(encoder)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
match *self {
Message::PierceFirewall(token) => {
encoder.encode_u32(CODE_PIERCE_FIREWALL)?;
encoder.encode_u32(token)?;
}
Message::PeerInit(ref request) => {
encoder.encode_u32(CODE_PEER_INIT)?;
request.encode(encoder)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
}
}
impl WriteToPacket for Message {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
match *self {
Message::PierceFirewall(ref token) => {
packet.write_value(&CODE_PIERCE_FIREWALL)?;
packet.write_value(token)?;
}
Message::PeerInit(ref request) => {
packet.write_value(&CODE_PEER_INIT)?;
packet.write_value(request)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
match *self {
Message::PierceFirewall(ref token) => {
packet.write_value(&CODE_PIERCE_FIREWALL)?;
packet.write_value(token)?;
}
Message::PeerInit(ref request) => {
packet.write_value(&CODE_PEER_INIT)?;
packet.write_value(request)?;
}
Message::Unknown(_) => unreachable!(),
}
Ok(())
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PeerInit {
pub user_name: String,
pub connection_type: String,
pub token: u32,
pub user_name: String,
pub connection_type: String,
pub token: u32,
}
impl ReadFromPacket for PeerInit {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let user_name = packet.read_value()?;
let connection_type = packet.read_value()?;
let token = packet.read_value()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let user_name = packet.read_value()?;
let connection_type = packet.read_value()?;
let token = packet.read_value()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
}
impl WriteToPacket for PeerInit {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_value(&self.user_name)?;
packet.write_value(&self.connection_type)?;
packet.write_value(&self.token)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
packet.write_value(&self.user_name)?;
packet.write_value(&self.connection_type)?;
packet.write_value(&self.token)?;
Ok(())
}
}
impl ValueEncode for PeerInit {
fn encode(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
encoder.encode_string(&self.user_name)?;
encoder.encode_string(&self.connection_type)?;
encoder.encode_u32(self.token)?;
Ok(())
}
fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
encoder.encode_string(&self.user_name)?;
encoder.encode_string(&self.connection_type)?;
encoder.encode_u32(self.token)?;
Ok(())
}
}
impl ValueDecode for PeerInit {
fn decode_from(
decoder: &mut ValueDecoder,
) -> Result<Self, ValueDecodeError> {
let user_name = decoder.decode()?;
let connection_type = decoder.decode()?;
let token = decoder.decode()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let user_name = decoder.decode()?;
let connection_type = decoder.decode()?;
let token = decoder.decode()?;
Ok(PeerInit {
user_name,
connection_type,
token,
})
}
}
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use crate::proto::value_codec::tests::roundtrip;
use crate::proto::{ValueDecodeError, ValueDecoder};
use super::*;
#[test]
fn invalid_code() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[57, 5, 0, 0]);
let result = ValueDecoder::new(&bytes).decode::<Message>();
assert_eq!(
result,
Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: "unknown value 1337".to_string(),
position: 0,
})
);
}
#[test]
fn roundtrip_pierce_firewall() {
roundtrip(Message::PierceFirewall(1337))
}
#[test]
fn roundtrip_peer_init() {
roundtrip(Message::PeerInit(PeerInit {
user_name: "alice".to_string(),
connection_type: "P".to_string(),
token: 1337,
}));
}
use bytes::BytesMut;
use crate::proto::value_codec::tests::roundtrip;
use crate::proto::{ValueDecodeError, ValueDecoder};
use super::*;
#[test]
fn invalid_code() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[57, 5, 0, 0]);
let result = ValueDecoder::new(&bytes).decode::<Message>();
assert_eq!(
result,
Err(ValueDecodeError::InvalidData {
value_name: "peer message code".to_string(),
cause: "unknown value 1337".to_string(),
position: 0,
})
);
}
#[test]
fn roundtrip_pierce_firewall() {
roundtrip(Message::PierceFirewall(1337))
}
#[test]
fn roundtrip_peer_init() {
roundtrip(Message::PeerInit(PeerInit {
user_name: "alice".to_string(),
connection_type: "P".to_string(),
token: 1337,
}));
}
}

+ 88
- 88
src/proto/prefix.rs View File

@ -11,111 +11,111 @@ use crate::proto::u32::{encode_u32, U32_BYTE_LEN};
/// know the length ahead of encoding time.
#[derive(Debug)]
pub struct Prefixer<'a> {
/// The prefix buffer.
///
/// The length of the suffix buffer is written to the end of this buffer
/// when the prefixer is finalized.
///
/// Contains any bytes with which this prefixer was constructed.
prefix: &'a mut BytesMut,
/// The suffix buffer.
///
/// This is the buffer into which data is written before finalization.
suffix: BytesMut,
/// The prefix buffer.
///
/// The length of the suffix buffer is written to the end of this buffer
/// when the prefixer is finalized.
///
/// Contains any bytes with which this prefixer was constructed.
prefix: &'a mut BytesMut,
/// The suffix buffer.
///
/// This is the buffer into which data is written before finalization.
suffix: BytesMut,
}
impl Prefixer<'_> {
/// Constructs a prefixer for easily appending a length prefixed value to
/// the given buffer.
pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> {
// Reserve some space fot the prefix, but don't write it yet.
buffer.reserve(U32_BYTE_LEN);
// Split off the suffix, into which bytes will be written.
let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN);
Prefixer {
prefix: buffer,
suffix: suffix,
}
}
/// Returns a reference to the buffer into which data is written.
pub fn suffix(&self) -> &BytesMut {
&self.suffix
}
/// Returns a mutable reference to a buffer into which data can be written.
pub fn suffix_mut(&mut self) -> &mut BytesMut {
&mut self.suffix
}
/// Returns a buffer containing the original data passed at construction
/// time, to which a length-prefixed value is appended. The value itself is
/// the data written into the buffer returned by `get_mut()`.
///
/// Returns `Ok(length)` if successful, in which case the length of the
/// suffix is `length`.
///
/// Returns `Err(self)` if the length of the suffix is too large to store as
/// a prefix.
pub fn finalize(self) -> Result<u32, Self> {
// Check that the suffix's length is not too large.
let length = self.suffix.len();
let length_u32 = match u32::try_from(length) {
Ok(value) => value,
Err(_) => return Err(self),
};
// Write the prefix.
self.prefix.extend_from_slice(&encode_u32(length_u32));
// Join the prefix and suffix back again. Because `self.prefix` is
// private, we are sure that this is O(1).
self.prefix.unsplit(self.suffix);
Ok(length_u32)
/// Constructs a prefixer for easily appending a length prefixed value to
/// the given buffer.
pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> {
// Reserve some space fot the prefix, but don't write it yet.
buffer.reserve(U32_BYTE_LEN);
// Split off the suffix, into which bytes will be written.
let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN);
Prefixer {
prefix: buffer,
suffix: suffix,
}
}
/// Returns a reference to the buffer into which data is written.
pub fn suffix(&self) -> &BytesMut {
&self.suffix
}
/// Returns a mutable reference to a buffer into which data can be written.
pub fn suffix_mut(&mut self) -> &mut BytesMut {
&mut self.suffix
}
/// Returns a buffer containing the original data passed at construction
/// time, to which a length-prefixed value is appended. The value itself is
/// the data written into the buffer returned by `get_mut()`.
///
/// Returns `Ok(length)` if successful, in which case the length of the
/// suffix is `length`.
///
/// Returns `Err(self)` if the length of the suffix is too large to store as
/// a prefix.
pub fn finalize(self) -> Result<u32, Self> {
// Check that the suffix's length is not too large.
let length = self.suffix.len();
let length_u32 = match u32::try_from(length) {
Ok(value) => value,
Err(_) => return Err(self),
};
// Write the prefix.
self.prefix.extend_from_slice(&encode_u32(length_u32));
// Join the prefix and suffix back again. Because `self.prefix` is
// private, we are sure that this is O(1).
self.prefix.unsplit(self.suffix);
Ok(length_u32)
}
}
#[cfg(test)]
mod tests {
use super::Prefixer;
use super::Prefixer;
use std::convert::TryInto;
use std::convert::TryInto;
use bytes::{BufMut, BytesMut};
use bytes::{BufMut, BytesMut};
use crate::proto::u32::{decode_u32, U32_BYTE_LEN};
use crate::proto::u32::{decode_u32, U32_BYTE_LEN};
#[test]
fn finalize_empty() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
#[test]
fn finalize_empty() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
Prefixer::new(&mut buffer).finalize().unwrap();
Prefixer::new(&mut buffer).finalize().unwrap();
assert_eq!(buffer.len(), U32_BYTE_LEN + 1);
let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap();
assert_eq!(decode_u32(array), 0);
}
assert_eq!(buffer.len(), U32_BYTE_LEN + 1);
let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap();
assert_eq!(decode_u32(array), 0);
}
#[test]
fn finalize_ok() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
#[test]
fn finalize_ok() {
let mut buffer = BytesMut::new();
buffer.put_u8(13);
let mut prefixer = Prefixer::new(&mut buffer);
let mut prefixer = Prefixer::new(&mut buffer);
prefixer.suffix_mut().extend_from_slice(&[0; 42]);
prefixer.suffix_mut().extend_from_slice(&[0; 42]);
prefixer.finalize().unwrap();
prefixer.finalize().unwrap();
// 1 junk prefix byte, length prefix, 42 bytes of value.
assert_eq!(buffer.len(), U32_BYTE_LEN + 43);
let prefix = &buffer[1..U32_BYTE_LEN + 1];
let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap();
assert_eq!(decode_u32(array), 42);
}
// 1 junk prefix byte, length prefix, 42 bytes of value.
assert_eq!(buffer.len(), U32_BYTE_LEN + 43);
let prefix = &buffer[1..U32_BYTE_LEN + 1];
let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap();
assert_eq!(decode_u32(array), 42);
}
}

+ 486
- 542
src/proto/server/request.rs
File diff suppressed because it is too large
View File


+ 1237
- 1346
src/proto/server/response.rs
File diff suppressed because it is too large
View File


+ 165
- 165
src/proto/stream.rs View File

@ -15,41 +15,41 @@ use super::packet::{MutPacket, Parser, ReadFromPacket, WriteToPacket};
/// A struct used for writing bytes to a TryWrite sink.
#[derive(Debug)]
struct OutBuf {
cursor: usize,
bytes: Vec<u8>,
cursor: usize,
bytes: Vec<u8>,
}
impl From<Vec<u8>> for OutBuf {
fn from(bytes: Vec<u8>) -> Self {
OutBuf {
cursor: 0,
bytes: bytes,
}
fn from(bytes: Vec<u8>) -> Self {
OutBuf {
cursor: 0,
bytes: bytes,
}
}
}
impl OutBuf {
#[inline]
fn remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
#[inline]
fn has_remaining(&self) -> bool {
self.remaining() > 0
}
#[allow(deprecated)]
fn try_write_to<T>(&mut self, mut writer: T) -> io::Result<Option<usize>>
where
T: mio::deprecated::TryWrite,
{
let result = writer.try_write(&self.bytes[self.cursor..]);
if let Ok(Some(bytes_written)) = result {
self.cursor += bytes_written;
}
result
#[inline]
fn remaining(&self) -> usize {
self.bytes.len() - self.cursor
}
#[inline]
fn has_remaining(&self) -> bool {
self.remaining() > 0
}
#[allow(deprecated)]
fn try_write_to<T>(&mut self, mut writer: T) -> io::Result<Option<usize>>
where
T: mio::deprecated::TryWrite,
{
let result = writer.try_write(&self.bytes[self.cursor..]);
if let Ok(Some(bytes_written)) = result {
self.cursor += bytes_written;
}
result
}
}
/*========*
@ -59,171 +59,171 @@ impl OutBuf {
/// This trait is implemented by packet sinks to which a stream can forward
/// the packets it reads.
pub trait SendPacket {
type Value: ReadFromPacket;
type Error: error::Error;
type Value: ReadFromPacket;
type Error: error::Error;
fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>;
fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>;
fn notify_open(&mut self) -> Result<(), Self::Error>;
fn notify_open(&mut self) -> Result<(), Self::Error>;
}
/// This enum defines the possible actions the stream wants to take after
/// processing an event.
#[derive(Debug, Clone, Copy)]
pub enum Intent {
/// The stream is done, the event loop handler can drop it.
Done,
/// The stream wants to wait for the next event matching the given
/// `EventSet`.
Continue(mio::Ready),
/// The stream is done, the event loop handler can drop it.
Done,
/// The stream wants to wait for the next event matching the given
/// `EventSet`.
Continue(mio::Ready),
}
/// This struct wraps around an mio tcp stream and handles packet reads and
/// writes.
#[derive(Debug)]
pub struct Stream<T: SendPacket> {
parser: Parser,
queue: VecDeque<OutBuf>,
sender: T,
stream: mio::tcp::TcpStream,
parser: Parser,
queue: VecDeque<OutBuf>,
sender: T,
stream: mio::tcp::TcpStream,
is_connected: bool,
is_connected: bool,
}
impl<T: SendPacket> Stream<T> {
/// Returns a new stream, asynchronously connected to the given address,
/// which forwards incoming packets to the given sender.
/// If an error occurs when connecting, returns an error.
pub fn new<U>(addr_spec: U, sender: T) -> io::Result<Self>
where
U: ToSocketAddrs + fmt::Debug,
{
for sock_addr in addr_spec.to_socket_addrs()? {
if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) {
return Ok(Stream {
parser: Parser::new(),
queue: VecDeque::new(),
sender: sender,
stream: stream,
is_connected: false,
});
}
}
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot connect to {:?}", addr_spec),
))
/// Returns a new stream, asynchronously connected to the given address,
/// which forwards incoming packets to the given sender.
/// If an error occurs when connecting, returns an error.
pub fn new<U>(addr_spec: U, sender: T) -> io::Result<Self>
where
U: ToSocketAddrs + fmt::Debug,
{
for sock_addr in addr_spec.to_socket_addrs()? {
if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) {
return Ok(Stream {
parser: Parser::new(),
queue: VecDeque::new(),
sender: sender,
stream: stream,
is_connected: false,
});
}
}
/// Returns a reference to the underlying byte stream, to allow it to be
/// registered with an event loop.
pub fn evented(&self) -> &mio::tcp::TcpStream {
&self.stream
Err(io::Error::new(
io::ErrorKind::Other,
format!("Cannot connect to {:?}", addr_spec),
))
}
/// Returns a reference to the underlying byte stream, to allow it to be
/// registered with an event loop.
pub fn evented(&self) -> &mio::tcp::TcpStream {
&self.stream
}
/// The stream is ready to be read from.
fn on_readable(&mut self) -> Result<(), String> {
loop {
let mut packet = match self.parser.try_read(&mut self.stream) {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(e) => return Err(format!("Error reading stream: {}", e)),
};
let value = match packet.read_value() {
Ok(value) => value,
Err(e) => return Err(format!("Error parsing packet: {}", e)),
};
if let Err(e) = self.sender.send_packet(value) {
return Err(format!("Error sending parsed packet: {}", e));
}
}
/// The stream is ready to be read from.
fn on_readable(&mut self) -> Result<(), String> {
loop {
let mut packet = match self.parser.try_read(&mut self.stream) {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(e) => return Err(format!("Error reading stream: {}", e)),
};
let value = match packet.read_value() {
Ok(value) => value,
Err(e) => return Err(format!("Error parsing packet: {}", e)),
};
if let Err(e) = self.sender.send_packet(value) {
return Err(format!("Error sending parsed packet: {}", e));
}
Ok(())
}
/// The stream is ready to be written to.
fn on_writable(&mut self) -> io::Result<()> {
loop {
let mut outbuf = match self.queue.pop_front() {
Some(outbuf) => outbuf,
None => break,
};
let option = outbuf.try_write_to(&mut self.stream)?;
match option {
Some(_) => {
if outbuf.has_remaining() {
self.queue.push_front(outbuf)
}
// Continue looping
}
Ok(())
}
/// The stream is ready to be written to.
fn on_writable(&mut self) -> io::Result<()> {
loop {
let mut outbuf = match self.queue.pop_front() {
Some(outbuf) => outbuf,
None => break,
};
let option = outbuf.try_write_to(&mut self.stream)?;
match option {
Some(_) => {
if outbuf.has_remaining() {
self.queue.push_front(outbuf)
}
// Continue looping
}
None => {
self.queue.push_front(outbuf);
break;
}
}
None => {
self.queue.push_front(outbuf);
break;
}
Ok(())
}
}
Ok(())
}
/// The stream is ready to read, write, or both.
pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent {
#[allow(deprecated)]
if event_set.is_hup() || event_set.is_error() {
return Intent::Done;
}
if event_set.is_readable() {
let result = self.on_readable();
if let Err(e) = result {
error!("Stream input error: {}", e);
return Intent::Done;
}
}
if event_set.is_writable() {
let result = self.on_writable();
if let Err(e) = result {
error!("Stream output error: {}", e);
return Intent::Done;
}
}
// We must have read or written something succesfully if we're here,
// so the stream must be connected.
if !self.is_connected {
// If we weren't already connected, notify the sink.
if let Err(err) = self.sender.notify_open() {
error!("Cannot notify client that stream is open: {}", err);
return Intent::Done;
}
// And record the fact that we are now connected.
self.is_connected = true;
}
/// The stream is ready to read, write, or both.
pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent {
#[allow(deprecated)]
if event_set.is_hup() || event_set.is_error() {
return Intent::Done;
}
if event_set.is_readable() {
let result = self.on_readable();
if let Err(e) = result {
error!("Stream input error: {}", e);
return Intent::Done;
}
}
if event_set.is_writable() {
let result = self.on_writable();
if let Err(e) = result {
error!("Stream output error: {}", e);
return Intent::Done;
}
}
// We're always interested in reading more.
#[allow(deprecated)]
let mut event_set =
mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error();
// If there is still stuff to write in the queue, we're interested in
// the socket becoming writable too.
if self.queue.len() > 0 {
event_set = event_set | mio::Ready::writable();
}
// We must have read or written something succesfully if we're here,
// so the stream must be connected.
if !self.is_connected {
// If we weren't already connected, notify the sink.
if let Err(err) = self.sender.notify_open() {
error!("Cannot notify client that stream is open: {}", err);
return Intent::Done;
}
// And record the fact that we are now connected.
self.is_connected = true;
}
Intent::Continue(event_set)
// We're always interested in reading more.
#[allow(deprecated)]
let mut event_set =
mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error();
// If there is still stuff to write in the queue, we're interested in
// the socket becoming writable too.
if self.queue.len() > 0 {
event_set = event_set | mio::Ready::writable();
}
/// The stream has been notified.
pub fn on_notify<V>(&mut self, payload: &V) -> Intent
where
V: WriteToPacket,
{
let mut packet = MutPacket::new();
let result = packet.write_value(payload);
if let Err(e) = result {
error!("Error writing payload to packet: {}", e);
return Intent::Done;
}
self.queue.push_back(OutBuf::from(packet.into_bytes()));
Intent::Continue(mio::Ready::readable() | mio::Ready::writable())
Intent::Continue(event_set)
}
/// The stream has been notified.
pub fn on_notify<V>(&mut self, payload: &V) -> Intent
where
V: WriteToPacket,
{
let mut packet = MutPacket::new();
let result = packet.write_value(payload);
if let Err(e) = result {
error!("Error writing payload to packet: {}", e);
return Intent::Done;
}
self.queue.push_back(OutBuf::from(packet.into_bytes()));
Intent::Continue(mio::Ready::readable() | mio::Ready::writable())
}
}

+ 44
- 44
src/proto/testing.rs View File

@ -8,68 +8,68 @@ use tokio::net::{TcpListener, TcpStream};
use crate::proto::{FrameStream, ServerRequest, ServerResponse};
async fn process(stream: TcpStream) -> io::Result<()> {
let mut connection =
FrameStream::<ServerRequest, ServerResponse>::new(stream);
let mut connection =
FrameStream::<ServerRequest, ServerResponse>::new(stream);
let _request = match connection.read().await? {
ServerRequest::LoginRequest(request) => request,
request => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
};
let _request = match connection.read().await? {
ServerRequest::LoginRequest(request) => request,
request => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected login request, got: {:?}", request),
));
}
};
Ok(())
Ok(())
}
/// A fake server for connecting to in tests.
pub struct FakeServer {
listener: TcpListener,
listener: TcpListener,
}
impl FakeServer {
/// Creates a new fake server and binds it to a port on localhost.
pub async fn new() -> io::Result<Self> {
let listener = TcpListener::bind("localhost:0").await?;
Ok(FakeServer { listener })
}
/// Creates a new fake server and binds it to a port on localhost.
pub async fn new() -> io::Result<Self> {
let listener = TcpListener::bind("localhost:0").await?;
Ok(FakeServer { listener })
}
/// Returns the address to which this server is bound.
/// This is always localhost and a random port chosen by the OS.
pub fn address(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
/// Returns the address to which this server is bound.
/// This is always localhost and a random port chosen by the OS.
pub fn address(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
/// Runs the server: accepts incoming connections and responds to requests.
pub async fn run(&mut self) -> io::Result<()> {
loop {
let (socket, _peer_address) = self.listener.accept().await?;
tokio::spawn(async move { process(socket).await });
}
/// Runs the server: accepts incoming connections and responds to requests.
pub async fn run(&mut self) -> io::Result<()> {
loop {
let (socket, _peer_address) = self.listener.accept().await?;
tokio::spawn(async move { process(socket).await });
}
}
}
#[cfg(test)]
mod tests {
use tokio::net::TcpStream;
use tokio::net::TcpStream;
use super::FakeServer;
use super::FakeServer;
#[tokio::test]
async fn new_binds_to_localhost() {
let server = FakeServer::new().await.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
}
#[tokio::test]
async fn new_binds_to_localhost() {
let server = FakeServer::new().await.unwrap();
assert!(server.address().unwrap().ip().is_loopback());
}
#[tokio::test]
async fn accepts_incoming_connections() {
let mut server = FakeServer::new().await.unwrap();
let address = server.address().unwrap();
tokio::spawn(async move { server.run().await.unwrap() });
#[tokio::test]
async fn accepts_incoming_connections() {
let mut server = FakeServer::new().await.unwrap();
let address = server.address().unwrap();
tokio::spawn(async move { server.run().await.unwrap() });
// The connection succeeds.
let _ = TcpStream::connect(address).await.unwrap();
}
// The connection succeeds.
let _ = TcpStream::connect(address).await.unwrap();
}
}

+ 2
- 2
src/proto/u32.rs View File

@ -8,10 +8,10 @@ pub const U32_BYTE_LEN: usize = 4;
/// Returns the byte representatio of the given integer value.
pub fn encode_u32(value: u32) -> [u8; U32_BYTE_LEN] {
value.to_le_bytes()
value.to_le_bytes()
}
/// Returns the integer value corresponding to the given bytes.
pub fn decode_u32(bytes: [u8; U32_BYTE_LEN]) -> u32 {
u32::from_le_bytes(bytes)
u32::from_le_bytes(bytes)
}

+ 75
- 80
src/proto/user.rs View File

@ -1,9 +1,9 @@
use std::io;
use crate::proto::{
MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode,
ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError,
ValueEncoder, WriteToPacket,
MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode,
ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder,
WriteToPacket,
};
const STATUS_OFFLINE: u32 = 1;
@ -12,103 +12,98 @@ const STATUS_ONLINE: u32 = 3;
/// This enumeration is the list of possible user statuses.
#[derive(
Clone,
Copy,
Debug,
Eq,
Ord,
PartialEq,
PartialOrd,
RustcDecodable,
RustcEncodable,
Clone,
Copy,
Debug,
Eq,
Ord,
PartialEq,
PartialOrd,
RustcDecodable,
RustcEncodable,
)]
pub enum UserStatus {
/// The user if offline.
Offline,
/// The user is connected, but AFK.
Away,
/// The user is present.
Online,
/// The user if offline.
Offline,
/// The user is connected, but AFK.
Away,
/// The user is present.
Online,
}
impl ReadFromPacket for UserStatus {
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n: u32 = packet.read_value()?;
match n {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(PacketReadError::InvalidUserStatusError(n)),
}
fn read_from_packet(packet: &mut Packet) -> Result<Self, PacketReadError> {
let n: u32 = packet.read_value()?;
match n {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(PacketReadError::InvalidUserStatusError(n)),
}
}
}
impl WriteToPacket for UserStatus {
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
let n = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
packet.write_value(&n)?;
Ok(())
}
fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> {
let n = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
packet.write_value(&n)?;
Ok(())
}
}
impl ValueEncode for UserStatus {
fn encode(
&self,
encoder: &mut ValueEncoder,
) -> Result<(), ValueEncodeError> {
let value = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
encoder.encode_u32(value)
}
fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> {
let value = match *self {
UserStatus::Offline => STATUS_OFFLINE,
UserStatus::Away => STATUS_AWAY,
UserStatus::Online => STATUS_ONLINE,
};
encoder.encode_u32(value)
}
}
impl ValueDecode for UserStatus {
fn decode_from(
decoder: &mut ValueDecoder,
) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let value: u32 = decoder.decode()?;
match value {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(ValueDecodeError::InvalidData {
value_name: "user status".to_string(),
cause: format!("unknown value {}", value),
position: position,
}),
}
fn decode_from(decoder: &mut ValueDecoder) -> Result<Self, ValueDecodeError> {
let position = decoder.position();
let value: u32 = decoder.decode()?;
match value {
STATUS_OFFLINE => Ok(UserStatus::Offline),
STATUS_AWAY => Ok(UserStatus::Away),
STATUS_ONLINE => Ok(UserStatus::Online),
_ => Err(ValueDecodeError::InvalidData {
value_name: "user status".to_string(),
cause: format!("unknown value {}", value),
position: position,
}),
}
}
}
/// This structure contains the last known information about a fellow user.
#[derive(
Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable,
Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable,
)]
pub struct User {
/// The name of the user.
pub name: String,
/// The last known status of the user.
pub status: UserStatus,
/// The average upload speed of the user.
pub average_speed: usize,
/// ??? Nicotine calls it downloadnum.
pub num_downloads: usize,
/// ??? Unknown field.
pub unknown: usize,
/// The number of files this user shares.
pub num_files: usize,
/// The number of folders this user shares.
pub num_folders: usize,
/// The number of free download slots of this user.
pub num_free_slots: usize,
/// The user's country code.
pub country: String,
/// The name of the user.
pub name: String,
/// The last known status of the user.
pub status: UserStatus,
/// The average upload speed of the user.
pub average_speed: usize,
/// ??? Nicotine calls it downloadnum.
pub num_downloads: usize,
/// ??? Unknown field.
pub unknown: usize,
/// The number of files this user shares.
pub num_files: usize,
/// The number of folders this user shares.
pub num_folders: usize,
/// The number of free download slots of this user.
pub num_free_slots: usize,
/// The user's country code.
pub country: String,
}

+ 687
- 745
src/proto/value_codec.rs
File diff suppressed because it is too large
View File


+ 298
- 312
src/room.rs View File

@ -8,35 +8,35 @@ use crate::proto::{server, User};
/// This enumeration is the list of possible membership states for a chat room.
#[derive(Clone, Copy, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)]
pub enum Membership {
/// The user is not a member of this room.
NonMember,
/// The user has requested to join the room, but hasn't heard back from the
/// server yet.
Joining,
/// The user is a member of the room.
Member,
/// The user has request to leave the room, but hasn't heard back from the
/// server yet.
Leaving,
/// The user is not a member of this room.
NonMember,
/// The user has requested to join the room, but hasn't heard back from the
/// server yet.
Joining,
/// The user is a member of the room.
Member,
/// The user has request to leave the room, but hasn't heard back from the
/// server yet.
Leaving,
}
/// This enumeration is the list of visibility types for rooms that the user is
/// a member of.
#[derive(Clone, Copy, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)]
pub enum Visibility {
/// This room is visible to any user.
Public,
/// This room is visible only to members, and the user owns it.
PrivateOwned,
/// This room is visible only to members, and someone else owns it.
PrivateOther,
/// This room is visible to any user.
Public,
/// This room is visible only to members, and the user owns it.
PrivateOwned,
/// This room is visible only to members, and someone else owns it.
PrivateOther,
}
/// This structure contains a chat room message.
#[derive(Clone, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)]
pub struct Message {
pub user_name: String,
pub message: String,
pub user_name: String,
pub message: String,
}
/// This structure contains the last known information about a chat room.
@ -44,356 +44,342 @@ pub struct Message {
/// room hash table.
#[derive(Clone, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)]
pub struct Room {
/// The membership state of the user for the room.
pub membership: Membership,
/// The visibility of the room.
pub visibility: Visibility,
/// True if the user is one of the room's operators, False if the user is a
/// regular member.
pub operated: bool,
/// The number of users that are members of the room.
pub user_count: usize,
/// The name of the room's owner, if any.
pub owner: Option<String>,
/// The names of the room's operators.
pub operators: collections::HashSet<String>,
/// The names of the room's members.
pub members: collections::HashSet<String>,
/// The messages sent to this chat room, in chronological order.
pub messages: Vec<Message>,
/// The tickers displayed in this room.
pub tickers: Vec<(String, String)>,
/// The membership state of the user for the room.
pub membership: Membership,
/// The visibility of the room.
pub visibility: Visibility,
/// True if the user is one of the room's operators, False if the user is a
/// regular member.
pub operated: bool,
/// The number of users that are members of the room.
pub user_count: usize,
/// The name of the room's owner, if any.
pub owner: Option<String>,
/// The names of the room's operators.
pub operators: collections::HashSet<String>,
/// The names of the room's members.
pub members: collections::HashSet<String>,
/// The messages sent to this chat room, in chronological order.
pub messages: Vec<Message>,
/// The tickers displayed in this room.
pub tickers: Vec<(String, String)>,
}
impl Room {
/// Creates a new room with the given visibility and user count.
fn new(visibility: Visibility, user_count: usize) -> Self {
Room {
membership: Membership::NonMember,
visibility: visibility,
operated: false,
user_count: user_count,
owner: None,
operators: collections::HashSet::new(),
members: collections::HashSet::new(),
messages: Vec::new(),
tickers: Vec::new(),
}
/// Creates a new room with the given visibility and user count.
fn new(visibility: Visibility, user_count: usize) -> Self {
Room {
membership: Membership::NonMember,
visibility: visibility,
operated: false,
user_count: user_count,
owner: None,
operators: collections::HashSet::new(),
members: collections::HashSet::new(),
messages: Vec::new(),
tickers: Vec::new(),
}
}
}
/// The error returned by RoomMap functions.
#[derive(Debug)]
pub enum Error {
RoomNotFound(String),
MembershipChangeInvalid(Membership, Membership),
RoomNotFound(String),
MembershipChangeInvalid(Membership, Membership),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::RoomNotFound(ref room_name) => {
write!(f, "room {:?} not found", room_name)
}
Error::MembershipChangeInvalid(old_membership, new_membership) => {
write!(
f,
"cannot change membership from {:?} to {:?}",
old_membership, new_membership
)
}
}
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::RoomNotFound(ref room_name) => {
write!(f, "room {:?} not found", room_name)
}
Error::MembershipChangeInvalid(old_membership, new_membership) => {
write!(
f,
"cannot change membership from {:?} to {:?}",
old_membership, new_membership
)
}
}
}
}
impl error::Error for Error {
fn description(&self) -> &str {
match *self {
Error::RoomNotFound(_) => "room not found",
Error::MembershipChangeInvalid(_, _) => "cannot change membership",
}
fn description(&self) -> &str {
match *self {
Error::RoomNotFound(_) => "room not found",
Error::MembershipChangeInvalid(_, _) => "cannot change membership",
}
}
}
/// Contains the mapping from room names to room data and provides a clean
/// interface to interact with it.
#[derive(Debug)]
pub struct RoomMap {
/// The actual map from room names to room data.
map: collections::HashMap<String, Room>,
/// The actual map from room names to room data.
map: collections::HashMap<String, Room>,
}
impl RoomMap {
/// Creates an empty mapping.
pub fn new() -> Self {
RoomMap {
map: collections::HashMap::new(),
}
/// Creates an empty mapping.
pub fn new() -> Self {
RoomMap {
map: collections::HashMap::new(),
}
}
/// Looks up the given room name in the map, returning an immutable
/// reference to the associated data if found, or an error if not found.
fn get_strict(&self, room_name: &str) -> Result<&Room, Error> {
match self.map.get(room_name) {
Some(room) => Ok(room),
None => Err(Error::RoomNotFound(room_name.to_string())),
}
}
/// Looks up the given room name in the map, returning a mutable
/// reference to the associated data if found, or an error if not found.
fn get_mut_strict(&mut self, room_name: &str) -> Result<&mut Room, Error> {
match self.map.get_mut(room_name) {
Some(room) => Ok(room),
None => Err(Error::RoomNotFound(room_name.to_string())),
}
}
/// Updates one room in the map based on the information received in
/// a RoomListResponse and the potential previously stored information.
fn update_one(
&mut self,
name: String,
visibility: Visibility,
user_count: u32,
old_map: &mut collections::HashMap<String, Room>,
) {
let room = match old_map.remove(&name) {
None => Room::new(Visibility::Public, user_count as usize),
Some(mut room) => {
room.visibility = visibility;
room.user_count = user_count as usize;
room
}
};
if let Some(_) = self.map.insert(name, room) {
error!("Room present twice in room list response");
}
}
/// Looks up the given room name in the map, returning an immutable
/// reference to the associated data if found, or an error if not found.
fn get_strict(&self, room_name: &str) -> Result<&Room, Error> {
match self.map.get(room_name) {
Some(room) => Ok(room),
None => Err(Error::RoomNotFound(room_name.to_string())),
}
/// Updates the map to reflect the information contained in the given
/// server response.
pub fn set_room_list(&mut self, mut response: server::RoomListResponse) {
// Replace the old mapping with an empty one.
let mut old_map = mem::replace(&mut self.map, collections::HashMap::new());
// Add all public rooms.
for (name, user_count) in response.rooms.drain(..) {
self.update_one(name, Visibility::Public, user_count, &mut old_map);
}
/// Looks up the given room name in the map, returning a mutable
/// reference to the associated data if found, or an error if not found.
fn get_mut_strict(&mut self, room_name: &str) -> Result<&mut Room, Error> {
match self.map.get_mut(room_name) {
Some(room) => Ok(room),
None => Err(Error::RoomNotFound(room_name.to_string())),
}
// Add all private, owned, rooms.
for (name, user_count) in response.owned_private_rooms.drain(..) {
self.update_one(name, Visibility::PrivateOwned, user_count, &mut old_map);
}
/// Updates one room in the map based on the information received in
/// a RoomListResponse and the potential previously stored information.
fn update_one(
&mut self,
name: String,
visibility: Visibility,
user_count: u32,
old_map: &mut collections::HashMap<String, Room>,
) {
let room = match old_map.remove(&name) {
None => Room::new(Visibility::Public, user_count as usize),
Some(mut room) => {
room.visibility = visibility;
room.user_count = user_count as usize;
room
}
};
if let Some(_) = self.map.insert(name, room) {
error!("Room present twice in room list response");
}
// Add all private, unowned, rooms.
for (name, user_count) in response.other_private_rooms.drain(..) {
self.update_one(name, Visibility::PrivateOther, user_count, &mut old_map);
}
/// Updates the map to reflect the information contained in the given
/// server response.
pub fn set_room_list(&mut self, mut response: server::RoomListResponse) {
// Replace the old mapping with an empty one.
let mut old_map =
mem::replace(&mut self.map, collections::HashMap::new());
// Add all public rooms.
for (name, user_count) in response.rooms.drain(..) {
self.update_one(name, Visibility::Public, user_count, &mut old_map);
}
// Add all private, owned, rooms.
for (name, user_count) in response.owned_private_rooms.drain(..) {
self.update_one(
name,
Visibility::PrivateOwned,
user_count,
&mut old_map,
);
}
// Add all private, unowned, rooms.
for (name, user_count) in response.other_private_rooms.drain(..) {
self.update_one(
name,
Visibility::PrivateOther,
user_count,
&mut old_map,
);
}
// Mark all operated rooms as necessary.
for name in response.operated_private_room_names.iter() {
match self.map.get_mut(name) {
Some(room) => room.operated = true,
None => error!("Room {} is operated but does not exist", name),
}
}
// Mark all operated rooms as necessary.
for name in response.operated_private_room_names.iter() {
match self.map.get_mut(name) {
Some(room) => room.operated = true,
None => error!("Room {} is operated but does not exist", name),
}
}
}
/// Returns the list of (room name, room data) representing all known rooms.
pub fn get_room_list(&self) -> Vec<(String, Room)> {
let mut rooms = Vec::new();
for (room_name, room) in self.map.iter() {
rooms.push((room_name.clone(), room.clone()));
}
rooms
/// Returns the list of (room name, room data) representing all known rooms.
pub fn get_room_list(&self) -> Vec<(String, Room)> {
let mut rooms = Vec::new();
for (room_name, room) in self.map.iter() {
rooms.push((room_name.clone(), room.clone()));
}
rooms
}
/// Records that we are now trying to join the given room.
/// If the room is not found, or if its membership is not `NonMember`,
/// returns an error.
pub fn start_joining(&mut self, room_name: &str) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
match room.membership {
Membership::NonMember => {
room.membership = Membership::Joining;
Ok(())
}
/// Records that we are now trying to join the given room.
/// If the room is not found, or if its membership is not `NonMember`,
/// returns an error.
pub fn start_joining(&mut self, room_name: &str) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
match room.membership {
Membership::NonMember => {
room.membership = Membership::Joining;
Ok(())
}
membership => Err(Error::MembershipChangeInvalid(
membership,
Membership::Joining,
)),
}
membership => Err(Error::MembershipChangeInvalid(
membership,
Membership::Joining,
)),
}
}
/// Records that we are now a member of the given room and updates the room
/// information.
pub fn join(
&mut self,
room_name: &str,
owner: Option<String>,
mut operators: Vec<String>,
members: &[User],
) -> Result<(), Error> {
// First look up the room struct.
let room = self.get_mut_strict(room_name)?;
// Log what's happening.
if let Membership::Joining = room.membership {
info!("Joined room {:?}", room_name);
} else {
warn!(
"Joined room {:?} but membership was already {:?}",
room_name, room.membership
);
}
/// Records that we are now a member of the given room and updates the room
/// information.
pub fn join(
&mut self,
room_name: &str,
owner: Option<String>,
mut operators: Vec<String>,
members: &[User],
) -> Result<(), Error> {
// First look up the room struct.
let room = self.get_mut_strict(room_name)?;
// Log what's happening.
if let Membership::Joining = room.membership {
info!("Joined room {:?}", room_name);
} else {
warn!(
"Joined room {:?} but membership was already {:?}",
room_name, room.membership
);
}
// Update the room struct.
room.membership = Membership::Member;
room.user_count = members.len();
room.owner = owner;
room.operators.clear();
for user_name in operators.drain(..) {
room.operators.insert(user_name);
}
room.members.clear();
for user in members {
room.members.insert(user.name.clone());
}
// Update the room struct.
room.membership = Membership::Member;
room.user_count = members.len();
room.owner = owner;
Ok(())
room.operators.clear();
for user_name in operators.drain(..) {
room.operators.insert(user_name);
}
/// Records that we are now trying to leave the given room.
/// If the room is not found, or if its membership status is not `Member`,
/// returns an error.
pub fn start_leaving(&mut self, room_name: &str) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
match room.membership {
Membership::Member => {
room.membership = Membership::Leaving;
Ok(())
}
membership => Err(Error::MembershipChangeInvalid(
membership,
Membership::Leaving,
)),
}
room.members.clear();
for user in members {
room.members.insert(user.name.clone());
}
/// Records that we have now left the given room.
pub fn leave(&mut self, room_name: &str) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
match room.membership {
Membership::Leaving => info!("Left room {:?}", room_name),
Ok(())
}
membership => warn!(
"Left room {:?} with wrong membership: {:?}",
room_name, membership
),
}
/// Records that we are now trying to leave the given room.
/// If the room is not found, or if its membership status is not `Member`,
/// returns an error.
pub fn start_leaving(&mut self, room_name: &str) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.membership = Membership::NonMember;
match room.membership {
Membership::Member => {
room.membership = Membership::Leaving;
Ok(())
}
}
/// Saves the given message as the last one in the given room.
pub fn add_message(
&mut self,
room_name: &str,
message: Message,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.messages.push(message);
Ok(())
membership => Err(Error::MembershipChangeInvalid(
membership,
Membership::Leaving,
)),
}
}
/// Inserts the given user in the given room's set of members.
/// Returns an error if the room is not found.
pub fn insert_member(
&mut self,
room_name: &str,
user_name: String,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.members.insert(user_name);
Ok(())
}
/// Records that we have now left the given room.
pub fn leave(&mut self, room_name: &str) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
/// Removes the given user from the given room's set of members.
/// Returns an error if the room is not found.
pub fn remove_member(
&mut self,
room_name: &str,
user_name: &str,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.members.remove(user_name);
Ok(())
}
match room.membership {
Membership::Leaving => info!("Left room {:?}", room_name),
/*---------*
* Tickers *
*---------*/
pub fn set_tickers(
&mut self,
room_name: &str,
tickers: Vec<(String, String)>,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.tickers = tickers;
Ok(())
membership => warn!(
"Left room {:?} with wrong membership: {:?}",
room_name, membership
),
}
room.membership = Membership::NonMember;
Ok(())
}
/// Saves the given message as the last one in the given room.
pub fn add_message(
&mut self,
room_name: &str,
message: Message,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.messages.push(message);
Ok(())
}
/// Inserts the given user in the given room's set of members.
/// Returns an error if the room is not found.
pub fn insert_member(
&mut self,
room_name: &str,
user_name: String,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.members.insert(user_name);
Ok(())
}
/// Removes the given user from the given room's set of members.
/// Returns an error if the room is not found.
pub fn remove_member(
&mut self,
room_name: &str,
user_name: &str,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.members.remove(user_name);
Ok(())
}
/*---------*
* Tickers *
*---------*/
pub fn set_tickers(
&mut self,
room_name: &str,
tickers: Vec<(String, String)>,
) -> Result<(), Error> {
let room = self.get_mut_strict(room_name)?;
room.tickers = tickers;
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::proto::server::RoomListResponse;
use super::{Room, RoomMap, Visibility};
#[test]
fn room_map_new_is_empty() {
assert_eq!(RoomMap::new().get_room_list(), vec![]);
}
#[test]
fn room_map_get_strict() {
let mut rooms = RoomMap::new();
rooms.set_room_list(RoomListResponse {
rooms: vec![
("room a".to_string(), 42),
("room b".to_string(), 1337),
],
owned_private_rooms: vec![],
other_private_rooms: vec![],
operated_private_room_names: vec![],
});
assert_eq!(
rooms.get_strict("room a").unwrap(),
&Room::new(Visibility::Public, 42)
);
}
use crate::proto::server::RoomListResponse;
use super::{Room, RoomMap, Visibility};
#[test]
fn room_map_new_is_empty() {
assert_eq!(RoomMap::new().get_room_list(), vec![]);
}
#[test]
fn room_map_get_strict() {
let mut rooms = RoomMap::new();
rooms.set_room_list(RoomListResponse {
rooms: vec![("room a".to_string(), 42), ("room b".to_string(), 1337)],
owned_private_rooms: vec![],
other_private_rooms: vec![],
operated_private_room_names: vec![],
});
assert_eq!(
rooms.get_strict("room a").unwrap(),
&Room::new(Visibility::Public, 42)
);
}
}

+ 134
- 135
src/user.rs View File

@ -7,171 +7,170 @@ use crate::proto::{User, UserStatus};
/// The error returned when a user name was not found in the user map.
#[derive(Debug)]
pub struct UserNotFoundError {
/// The name of the user that wasn't found.
user_name: String,
/// The name of the user that wasn't found.
user_name: String,
}
impl fmt::Display for UserNotFoundError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "user \"{}\" not found", self.user_name)
}
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "user \"{}\" not found", self.user_name)
}
}
impl error::Error for UserNotFoundError {
fn description(&self) -> &str {
"user not found"
}
fn description(&self) -> &str {
"user not found"
}
}
/// Contains the mapping from user names to user data and provides a clean
/// interface to interact with it.
#[derive(Debug)]
pub struct UserMap {
/// The actual map from user names to user data and privileged status.
map: collections::HashMap<String, User>,
/// The set of privileged users.
privileged: collections::HashSet<String>,
/// The actual map from user names to user data and privileged status.
map: collections::HashMap<String, User>,
/// The set of privileged users.
privileged: collections::HashSet<String>,
}
impl UserMap {
/// Creates an empty mapping.
pub fn new() -> Self {
UserMap {
map: collections::HashMap::new(),
privileged: collections::HashSet::new(),
}
}
/// Looks up the given user name in the map, returning an immutable
/// reference to the associated data if found.
pub fn get(&self, user_name: &str) -> Option<&User> {
self.map.get(user_name)
}
/// Looks up the given user name in the map, returning a mutable reference
/// to the associated data if found, or an error if not found.
pub fn get_mut_strict(
&mut self,
user_name: &str,
) -> Result<&mut User, UserNotFoundError> {
match self.map.get_mut(user_name) {
Some(user) => Ok(user),
None => Err(UserNotFoundError {
user_name: user_name.to_string(),
}),
}
}
/// Inserts the given user info for the given user name in the mapping.
/// If there is already data under that name, it is replaced.
pub fn insert(&mut self, user: User) {
self.map.insert(user.name.clone(), user);
}
/// Sets the given user's status to the given value, if such a user exists.
pub fn set_status(
&mut self,
user_name: &str,
status: UserStatus,
) -> Result<(), UserNotFoundError> {
let user = self.get_mut_strict(user_name)?;
user.status = status;
Ok(())
}
/// Returns the list of (user name, user data) representing all known users.
pub fn get_list(&self) -> Vec<(String, User)> {
let mut users = Vec::new();
for (user_name, user_data) in self.map.iter() {
users.push((user_name.clone(), user_data.clone()));
}
users
}
/// Sets the set of privileged users to the given list.
pub fn set_all_privileged(&mut self, mut users: Vec<String>) {
self.privileged.clear();
for user_name in users.drain(..) {
self.privileged.insert(user_name);
}
}
/// Returns a copy of the set of privileged users.
pub fn get_all_privileged(&self) -> Vec<String> {
self.privileged.iter().map(|s| s.to_string()).collect()
}
/// Marks the given user as privileged.
pub fn insert_privileged(&mut self, user_name: String) {
self.privileged.insert(user_name);
}
/// Marks the given user as not privileged.
pub fn remove_privileged(&mut self, user_name: &str) {
self.privileged.remove(user_name);
}
/// Checks if the given user is privileged.
pub fn is_privileged(&self, user_name: &str) -> bool {
self.privileged.contains(user_name)
}
/// Creates an empty mapping.
pub fn new() -> Self {
UserMap {
map: collections::HashMap::new(),
privileged: collections::HashSet::new(),
}
}
/// Looks up the given user name in the map, returning an immutable
/// reference to the associated data if found.
pub fn get(&self, user_name: &str) -> Option<&User> {
self.map.get(user_name)
}
/// Looks up the given user name in the map, returning a mutable reference
/// to the associated data if found, or an error if not found.
pub fn get_mut_strict(
&mut self,
user_name: &str,
) -> Result<&mut User, UserNotFoundError> {
match self.map.get_mut(user_name) {
Some(user) => Ok(user),
None => Err(UserNotFoundError {
user_name: user_name.to_string(),
}),
}
}
/// Inserts the given user info for the given user name in the mapping.
/// If there is already data under that name, it is replaced.
pub fn insert(&mut self, user: User) {
self.map.insert(user.name.clone(), user);
}
/// Sets the given user's status to the given value, if such a user exists.
pub fn set_status(
&mut self,
user_name: &str,
status: UserStatus,
) -> Result<(), UserNotFoundError> {
let user = self.get_mut_strict(user_name)?;
user.status = status;
Ok(())
}
/// Returns the list of (user name, user data) representing all known users.
pub fn get_list(&self) -> Vec<(String, User)> {
let mut users = Vec::new();
for (user_name, user_data) in self.map.iter() {
users.push((user_name.clone(), user_data.clone()));
}
users
}
/// Sets the set of privileged users to the given list.
pub fn set_all_privileged(&mut self, mut users: Vec<String>) {
self.privileged.clear();
for user_name in users.drain(..) {
self.privileged.insert(user_name);
}
}
/// Returns a copy of the set of privileged users.
pub fn get_all_privileged(&self) -> Vec<String> {
self.privileged.iter().map(|s| s.to_string()).collect()
}
/// Marks the given user as privileged.
pub fn insert_privileged(&mut self, user_name: String) {
self.privileged.insert(user_name);
}
/// Marks the given user as not privileged.
pub fn remove_privileged(&mut self, user_name: &str) {
self.privileged.remove(user_name);
}
/// Checks if the given user is privileged.
pub fn is_privileged(&self, user_name: &str) -> bool {
self.privileged.contains(user_name)
}
}
#[cfg(test)]
mod tests {
use super::UserMap;
use super::UserMap;
#[test]
fn new_is_empty() {
let users = UserMap::new();
#[test]
fn new_is_empty() {
let users = UserMap::new();
assert_eq!(users.get_list(), vec![]);
assert_eq!(users.get("bleep"), None);
}
assert_eq!(users.get_list(), vec![]);
assert_eq!(users.get("bleep"), None);
}
#[test]
fn set_get_all_privileged() {
let mut users = UserMap::new();
#[test]
fn set_get_all_privileged() {
let mut users = UserMap::new();
users
.set_all_privileged(vec!["bleep".to_string(), "bloop".to_string()]);
users.set_all_privileged(vec!["bleep".to_string(), "bloop".to_string()]);
let mut privileged = users.get_all_privileged();
privileged.sort();
assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]);
}
let mut privileged = users.get_all_privileged();
privileged.sort();
assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]);
}
#[test]
fn insert_privileged() {
let mut users = UserMap::new();
#[test]
fn insert_privileged() {
let mut users = UserMap::new();
users.insert_privileged("bleep".to_string());
users.insert_privileged("bleep".to_string());
users.insert_privileged("bloop".to_string());
users.insert_privileged("bleep".to_string());
users.insert_privileged("bleep".to_string());
users.insert_privileged("bloop".to_string());
let mut privileged = users.get_all_privileged();
privileged.sort();
assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]);
}
let mut privileged = users.get_all_privileged();
privileged.sort();
assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]);
}
#[test]
fn remove_privileged() {
let mut users = UserMap::new();
users.insert_privileged("bleep".to_string());
users.insert_privileged("bloop".to_string());
#[test]
fn remove_privileged() {
let mut users = UserMap::new();
users.insert_privileged("bleep".to_string());
users.insert_privileged("bloop".to_string());
users.remove_privileged("bleep");
users.remove_privileged("bleep");
let privileged = users.get_all_privileged();
assert_eq!(privileged, vec!["bloop".to_string()]);
}
let privileged = users.get_all_privileged();
assert_eq!(privileged, vec!["bloop".to_string()]);
}
#[test]
fn is_privileged() {
let mut users = UserMap::new();
users.insert_privileged("bleep".to_string());
#[test]
fn is_privileged() {
let mut users = UserMap::new();
users.insert_privileged("bleep".to_string());
assert!(users.is_privileged("bleep"));
assert!(!users.is_privileged("bloop"));
}
assert!(users.is_privileged("bleep"));
assert!(!users.is_privileged("bloop"));
}
}

Loading…
Cancel
Save