diff --git a/rustfmt.toml b/rustfmt.toml index df99c69..205c72c 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1,2 @@ +tab_spaces = 2 max_width = 80 diff --git a/src/client.rs b/src/client.rs index b2af730..8ecf14a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,798 +15,780 @@ use crate::user; #[derive(Debug)] enum IncomingMessage { - Proto(proto::Response), + Proto(proto::Response), - #[allow(dead_code)] - ControlNotification(control::Notification), + #[allow(dead_code)] + ControlNotification(control::Notification), } #[derive(Debug)] enum PeerState { - /// We are trying to establish a direct connection. - #[allow(dead_code)] - Opening, + /// We are trying to establish a direct connection. + #[allow(dead_code)] + Opening, - /// We are trying to establish a reverse connection. - OpeningFirewalled, + /// We are trying to establish a reverse connection. + OpeningFirewalled, - /// We are waiting for a reverse connection to be established to us. - WaitingFirewalled, + /// We are waiting for a reverse connection to be established to us. + WaitingFirewalled, - /// The connection is open. - Open, + /// The connection is open. + Open, } #[derive(Debug)] struct Peer { - user_name: String, - ip: net::Ipv4Addr, - port: u16, - connection_type: String, - token: u32, - state: PeerState, + user_name: String, + ip: net::Ipv4Addr, + port: u16, + connection_type: String, + token: u32, + state: PeerState, } pub struct Client { - #[allow(deprecated)] - proto_tx: mio::deprecated::Sender, - proto_rx: crossbeam_channel::Receiver, + #[allow(deprecated)] + proto_tx: mio::deprecated::Sender, + proto_rx: crossbeam_channel::Receiver, - control_tx: Option, + control_tx: Option, - #[allow(dead_code)] - control_rx: crossbeam_channel::Receiver, + #[allow(dead_code)] + control_rx: crossbeam_channel::Receiver, - login_status: LoginStatus, + login_status: LoginStatus, - rooms: room::RoomMap, - users: user::UserMap, + rooms: room::RoomMap, + users: user::UserMap, - peers: slab::Slab, + peers: slab::Slab, } impl Client { - /// Returns a new client that will communicate with the protocol agent - /// through `proto_tx` and `proto_rx`, and with the controller agent - /// through `control_rx`. - #[allow(deprecated)] - pub fn new( - proto_tx: mio::deprecated::Sender, - proto_rx: crossbeam_channel::Receiver, - control_rx: crossbeam_channel::Receiver, - ) -> Self { - Client { - proto_tx: proto_tx, - proto_rx: proto_rx, + /// Returns a new client that will communicate with the protocol agent + /// through `proto_tx` and `proto_rx`, and with the controller agent + /// through `control_rx`. + #[allow(deprecated)] + pub fn new( + proto_tx: mio::deprecated::Sender, + proto_rx: crossbeam_channel::Receiver, + control_rx: crossbeam_channel::Receiver, + ) -> Self { + Client { + proto_tx: proto_tx, + proto_rx: proto_rx, - control_tx: None, - control_rx: control_rx, + control_tx: None, + control_rx: control_rx, - login_status: LoginStatus::Todo, + login_status: LoginStatus::Todo, - rooms: room::RoomMap::new(), - users: user::UserMap::new(), + rooms: room::RoomMap::new(), + users: user::UserMap::new(), - peers: slab::Slab::new(config::MAX_PEERS), - } + peers: slab::Slab::new(config::MAX_PEERS), } - - /// Runs the client, potentially forever. - pub fn run(&mut self) { - info!("Logging in..."); - self.send_to_server(server::ServerRequest::LoginRequest( - server::LoginRequest::new( - config::USERNAME, - config::PASSWORD, - config::VER_MAJOR, - config::VER_MINOR, - ) - .unwrap(), - )); - - self.login_status = LoginStatus::AwaitingResponse; - - self.send_to_server(server::ServerRequest::SetListenPortRequest( - server::SetListenPortRequest { - port: config::LISTEN_PORT, - }, - )); - - loop { - match self.recv() { - IncomingMessage::Proto(response) => { - self.handle_proto_response(response) - } - - IncomingMessage::ControlNotification(notif) => { - self.handle_control_notification(notif) - } - } + } + + /// Runs the client, potentially forever. + pub fn run(&mut self) { + info!("Logging in..."); + self.send_to_server(server::ServerRequest::LoginRequest( + server::LoginRequest::new( + config::USERNAME, + config::PASSWORD, + config::VER_MAJOR, + config::VER_MINOR, + ) + .unwrap(), + )); + + self.login_status = LoginStatus::AwaitingResponse; + + self.send_to_server(server::ServerRequest::SetListenPortRequest( + server::SetListenPortRequest { + port: config::LISTEN_PORT, + }, + )); + + loop { + match self.recv() { + IncomingMessage::Proto(response) => { + self.handle_proto_response(response) } - } - - // Necessary to break out in different function because self cannot be - // borrowed in the select arms due to *macro things*. - fn recv(&mut self) -> IncomingMessage { - IncomingMessage::Proto(self.proto_rx.recv().unwrap()) - } - /// Send a request to the server. - fn send_to_server(&self, request: server::ServerRequest) { - #[allow(deprecated)] - self.proto_tx - .send(proto::Request::ServerRequest(request)) - .unwrap(); - } - - /// Send a message to a peer. - fn send_to_peer(&self, peer_id: usize, message: peer::Message) { - #[allow(deprecated)] - self.proto_tx - .send(proto::Request::PeerMessage(peer_id, message)) - .unwrap(); - } - - /// Send a response to the controller client. - fn send_to_controller(&mut self, response: control::Response) { - #[allow(deprecated)] - let result = match self.control_tx { - None => { - // Silently drop control requests when controller is - // disconnected. - return; - } - Some(ref mut control_tx) => control_tx.send(response), - }; - // If we failed to send, we assume it means that the other end of the - // channel has been dropped, i.e. the controller has disconnected. - // It may be that mio has died on us, in which case we will never see - // a controller again. If that happens, there would have probably been - // a panic anyway, so we might never hit this corner case. - if let Err(_) = result { - info!("Controller has disconnected."); - self.control_tx = None; + IncomingMessage::ControlNotification(notif) => { + self.handle_control_notification(notif) } + } } + } - /*===============================* - * CONTROL NOTIFICATION HANDLING * - *===============================*/ + // Necessary to break out in different function because self cannot be + // borrowed in the select arms due to *macro things*. + fn recv(&mut self) -> IncomingMessage { + IncomingMessage::Proto(self.proto_rx.recv().unwrap()) + } - fn handle_control_notification(&mut self, notif: control::Notification) { - match notif { - control::Notification::Connected(tx) => { - self.control_tx = Some(tx); - } - - control::Notification::Disconnected => { - self.control_tx = None; - } - - control::Notification::Error(e) => { - debug!("Control loop error: {}", e); - self.control_tx = None; - } - - control::Notification::Request(req) => { - self.handle_control_request(req) - } - } + /// Send a request to the server. + fn send_to_server(&self, request: server::ServerRequest) { + #[allow(deprecated)] + self + .proto_tx + .send(proto::Request::ServerRequest(request)) + .unwrap(); + } + + /// Send a message to a peer. + fn send_to_peer(&self, peer_id: usize, message: peer::Message) { + #[allow(deprecated)] + self + .proto_tx + .send(proto::Request::PeerMessage(peer_id, message)) + .unwrap(); + } + + /// Send a response to the controller client. + fn send_to_controller(&mut self, response: control::Response) { + #[allow(deprecated)] + let result = match self.control_tx { + None => { + // Silently drop control requests when controller is + // disconnected. + return; + } + Some(ref mut control_tx) => control_tx.send(response), + }; + // If we failed to send, we assume it means that the other end of the + // channel has been dropped, i.e. the controller has disconnected. + // It may be that mio has died on us, in which case we will never see + // a controller again. If that happens, there would have probably been + // a panic anyway, so we might never hit this corner case. + if let Err(_) = result { + info!("Controller has disconnected."); + self.control_tx = None; } + } - /*==========================* - * CONTROL REQUEST HANDLING * - *==========================*/ + /*===============================* + * CONTROL NOTIFICATION HANDLING * + *===============================*/ - fn handle_control_request(&mut self, request: control::Request) { - match request { - control::Request::LoginStatusRequest => { - self.handle_login_status_request() - } + fn handle_control_notification(&mut self, notif: control::Notification) { + match notif { + control::Notification::Connected(tx) => { + self.control_tx = Some(tx); + } - control::Request::RoomJoinRequest(room_name) => { - self.handle_room_join_request(room_name) - } + control::Notification::Disconnected => { + self.control_tx = None; + } - control::Request::RoomLeaveRequest(room_name) => { - self.handle_room_leave_request(room_name) - } + control::Notification::Error(e) => { + debug!("Control loop error: {}", e); + self.control_tx = None; + } - control::Request::RoomListRequest => { - self.handle_room_list_request() - } - - control::Request::RoomMessageRequest(request) => { - self.handle_room_message_request(request) - } - - control::Request::UserListRequest => { - self.handle_user_list_request() - } /* - _ =>{ - error!("Unhandled control request: {:?}", request); - }, - */ - } + control::Notification::Request(req) => self.handle_control_request(req), } + } - fn handle_login_status_request(&mut self) { - let username = config::USERNAME.to_string(); + /*==========================* + * CONTROL REQUEST HANDLING * + *==========================*/ - let response = match self.login_status { - LoginStatus::Todo => { - control::LoginStatusResponse::Pending { username: username } - } - LoginStatus::AwaitingResponse => { - control::LoginStatusResponse::Pending { username: username } - } + fn handle_control_request(&mut self, request: control::Request) { + match request { + control::Request::LoginStatusRequest => { + self.handle_login_status_request() + } - LoginStatus::Success(ref motd) => { - control::LoginStatusResponse::Success { - username: username, - motd: motd.clone(), - } - } + control::Request::RoomJoinRequest(room_name) => { + self.handle_room_join_request(room_name) + } - LoginStatus::Failure(ref reason) => { - control::LoginStatusResponse::Failure { - username: username, - reason: reason.clone(), - } - } - }; - self.send_to_controller(control::Response::LoginStatusResponse( - response, - )); - } + control::Request::RoomLeaveRequest(room_name) => { + self.handle_room_leave_request(room_name) + } - fn handle_room_join_request(&mut self, room_name: String) { - match self.rooms.start_joining(&room_name) { - Ok(()) => { - info!("Requesting to join room {:?}", room_name); - self.send_to_server(server::ServerRequest::RoomJoinRequest( - server::RoomJoinRequest { - room_name: room_name, - }, - )); - } + control::Request::RoomListRequest => self.handle_room_list_request(), - Err(err) => error!("RoomLeaveRequest: {}", err), - } - } - - fn handle_room_leave_request(&mut self, room_name: String) { - match self.rooms.start_leaving(&room_name) { - Ok(()) => { - info!("Requesting to leave room {:?}", room_name); - self.send_to_server(server::ServerRequest::RoomLeaveRequest( - server::RoomLeaveRequest { - room_name: room_name, - }, - )); - } + control::Request::RoomMessageRequest(request) => { + self.handle_room_message_request(request) + } - Err(err) => error!("RoomLeaveRequest: {}", err), - } + control::Request::UserListRequest => self.handle_user_list_request(), /* + _ =>{ + error!("Unhandled control request: {:?}", request); + }, + */ } - - fn handle_room_list_request(&mut self) { - // First send the controller client what we have in memory. - let rooms = self.rooms.get_room_list(); - self.send_to_controller(control::Response::RoomListResponse( - control::RoomListResponse { rooms: rooms }, + } + + fn handle_login_status_request(&mut self) { + let username = config::USERNAME.to_string(); + + let response = match self.login_status { + LoginStatus::Todo => { + control::LoginStatusResponse::Pending { username: username } + } + LoginStatus::AwaitingResponse => { + control::LoginStatusResponse::Pending { username: username } + } + + LoginStatus::Success(ref motd) => control::LoginStatusResponse::Success { + username: username, + motd: motd.clone(), + }, + + LoginStatus::Failure(ref reason) => { + control::LoginStatusResponse::Failure { + username: username, + reason: reason.clone(), + } + } + }; + self.send_to_controller(control::Response::LoginStatusResponse(response)); + } + + fn handle_room_join_request(&mut self, room_name: String) { + match self.rooms.start_joining(&room_name) { + Ok(()) => { + info!("Requesting to join room {:?}", room_name); + self.send_to_server(server::ServerRequest::RoomJoinRequest( + server::RoomJoinRequest { + room_name: room_name, + }, )); - // Then ask the server for an updated version, which will be forwarded - // to the controller client once received. - self.send_to_server(server::ServerRequest::RoomListRequest); - } + } - fn handle_room_message_request( - &mut self, - request: control::RoomMessageRequest, - ) { - self.send_to_server(server::ServerRequest::RoomMessageRequest( - server::RoomMessageRequest { - room_name: request.room_name, - message: request.message, - }, - )); + Err(err) => error!("RoomLeaveRequest: {}", err), } - - fn handle_user_list_request(&mut self) { - // Send the controller client what we have in memory. - let user_list = self.users.get_list(); - self.send_to_controller(control::Response::UserListResponse( - control::UserListResponse { - user_list: user_list, - }, + } + + fn handle_room_leave_request(&mut self, room_name: String) { + match self.rooms.start_leaving(&room_name) { + Ok(()) => { + info!("Requesting to leave room {:?}", room_name); + self.send_to_server(server::ServerRequest::RoomLeaveRequest( + server::RoomLeaveRequest { + room_name: room_name, + }, )); - } - - /*=========================* - * PROTO RESPONSE HANDLING * - *=========================*/ - - fn handle_proto_response(&mut self, response: proto::Response) { - match response { - proto::Response::ServerResponse(server_response) => { - self.handle_server_response(server_response) - } - - proto::Response::PeerConnectionOpen(peer_id) => { - self.handle_peer_connection_open(peer_id) - } - - proto::Response::PeerConnectionClosed(peer_id) => { - self.handle_peer_connection_closed(peer_id) - } - - _ => { - warn!("Unhandled proto response: {:?}", response); - } - } - } - - fn handle_peer_connection_closed(&mut self, peer_id: usize) { - let mut occupied_entry = match self.peers.entry(peer_id) { - None | Some(slab::Entry::Vacant(_)) => { - error!("Unknown peer connection {} has closed", peer_id); - return; - } - - Some(slab::Entry::Occupied(occupied_entry)) => occupied_entry, - }; - - match occupied_entry.get_mut().state { - PeerState::Open => { - info!("Peer connection {} has closed", peer_id); - occupied_entry.remove(); - } - - PeerState::WaitingFirewalled => { - error!( - "Peer connection {} has closed, was waiting: inconsistent", - peer_id - ); - occupied_entry.remove(); - } - - PeerState::Opening => { - info!( - "Peer connection {} has been refused, trying reverse", - peer_id - ); - - let peer = occupied_entry.get_mut(); - peer.state = PeerState::WaitingFirewalled; - - #[allow(deprecated)] - self.proto_tx - .send(proto::Request::ServerRequest( - server::ServerRequest::ConnectToPeerRequest( - server::ConnectToPeerRequest { - token: peer.token, - user_name: peer.user_name.clone(), - connection_type: peer.connection_type.clone(), - }, - ), - )) - .unwrap(); - } + } - PeerState::OpeningFirewalled => { - info!( - "Peer connection {} has been refused, cannot connect", - peer_id - ); - - let (peer, _) = occupied_entry.remove(); - #[allow(deprecated)] - self.proto_tx - .send(proto::Request::ServerRequest( - server::ServerRequest::CannotConnectRequest( - server::CannotConnectRequest { - token: peer.token, - user_name: peer.user_name, - }, - ), - )) - .unwrap(); - } - } + Err(err) => error!("RoomLeaveRequest: {}", err), } - - fn handle_peer_connection_open(&mut self, peer_id: usize) { - let message = match self.peers.get_mut(peer_id) { - None => { - error!("Unknown peer connection {} is open", peer_id); - return; - } - - Some( - peer - @ - &mut Peer { - state: PeerState::Open, - .. - }, - ) => { - error!( - "Peer connection {} was already open: {:?}", - peer_id, peer - ); - return; - } - - Some( - peer - @ - &mut Peer { - state: PeerState::WaitingFirewalled, - .. - }, - ) => { - error!("Peer connection {} was waiting: {:?}", peer_id, peer); - return; - } - - Some( - peer - @ - &mut Peer { - state: PeerState::Opening, - .. - }, - ) => { - info!("Peer connection {} is now open: {:?}", peer_id, peer); - // Mark it as open. - peer.state = PeerState::Open; - // Send a PeerInit. - peer::Message::PeerInit(peer::PeerInit { - user_name: config::USERNAME.to_string(), - connection_type: peer.connection_type.clone(), - token: peer.token, - }) - } - - Some( - peer - @ - &mut Peer { - state: PeerState::OpeningFirewalled, - .. - }, - ) => { - info!("Peer connection {} is now open: {:?}", peer_id, peer); - // Mark it as open. - peer.state = PeerState::Open; - // Send a PierceFirewall. - peer::Message::PierceFirewall(peer.token) - } - }; - - self.send_to_peer(peer_id, message); + } + + fn handle_room_list_request(&mut self) { + // First send the controller client what we have in memory. + let rooms = self.rooms.get_room_list(); + self.send_to_controller(control::Response::RoomListResponse( + control::RoomListResponse { rooms: rooms }, + )); + // Then ask the server for an updated version, which will be forwarded + // to the controller client once received. + self.send_to_server(server::ServerRequest::RoomListRequest); + } + + fn handle_room_message_request( + &mut self, + request: control::RoomMessageRequest, + ) { + self.send_to_server(server::ServerRequest::RoomMessageRequest( + server::RoomMessageRequest { + room_name: request.room_name, + message: request.message, + }, + )); + } + + fn handle_user_list_request(&mut self) { + // Send the controller client what we have in memory. + let user_list = self.users.get_list(); + self.send_to_controller(control::Response::UserListResponse( + control::UserListResponse { + user_list: user_list, + }, + )); + } + + /*=========================* + * PROTO RESPONSE HANDLING * + *=========================*/ + + fn handle_proto_response(&mut self, response: proto::Response) { + match response { + proto::Response::ServerResponse(server_response) => { + self.handle_server_response(server_response) + } + + proto::Response::PeerConnectionOpen(peer_id) => { + self.handle_peer_connection_open(peer_id) + } + + proto::Response::PeerConnectionClosed(peer_id) => { + self.handle_peer_connection_closed(peer_id) + } + + _ => { + warn!("Unhandled proto response: {:?}", response); + } } + } + + fn handle_peer_connection_closed(&mut self, peer_id: usize) { + let mut occupied_entry = match self.peers.entry(peer_id) { + None | Some(slab::Entry::Vacant(_)) => { + error!("Unknown peer connection {} has closed", peer_id); + return; + } + + Some(slab::Entry::Occupied(occupied_entry)) => occupied_entry, + }; + + match occupied_entry.get_mut().state { + PeerState::Open => { + info!("Peer connection {} has closed", peer_id); + occupied_entry.remove(); + } + + PeerState::WaitingFirewalled => { + error!( + "Peer connection {} has closed, was waiting: inconsistent", + peer_id + ); + occupied_entry.remove(); + } - /*==========================* - * SERVER RESPONSE HANDLING * - *==========================*/ - - fn handle_server_response(&mut self, response: server::ServerResponse) { - match response { - server::ServerResponse::ConnectToPeerResponse(response) => { - self.handle_connect_to_peer_response(response) - } - - server::ServerResponse::LoginResponse(response) => { - self.handle_login_response(response) - } - - server::ServerResponse::PrivilegedUsersResponse(response) => { - self.handle_privileged_users_response(response) - } - - server::ServerResponse::RoomJoinResponse(response) => { - self.handle_room_join_response(response) - } - - server::ServerResponse::RoomLeaveResponse(response) => { - self.handle_room_leave_response(response) - } - - server::ServerResponse::RoomListResponse(response) => { - self.handle_room_list_response(response) - } - - server::ServerResponse::RoomMessageResponse(response) => { - self.handle_room_message_response(response) - } - - server::ServerResponse::RoomTickersResponse(response) => { - self.handle_room_tickers_response(response) - } - - server::ServerResponse::RoomUserJoinedResponse(response) => { - self.handle_room_user_joined_response(response) - } - - server::ServerResponse::RoomUserLeftResponse(response) => { - self.handle_room_user_left_response(response) - } - - server::ServerResponse::UserInfoResponse(response) => { - self.handle_user_info_response(response) - } - - server::ServerResponse::UserStatusResponse(response) => { - self.handle_user_status_response(response) - } - - server::ServerResponse::UnknownResponse(code) => { - warn!("Unknown response: code {}", code) - } - - response => warn!("Unhandled response: {:?}", response), - } - } + PeerState::Opening => { + info!( + "Peer connection {} has been refused, trying reverse", + peer_id + ); - fn handle_connect_to_peer_response( - &mut self, - response: server::ConnectToPeerResponse, - ) { - let peer = Peer { - user_name: response.user_name, - ip: response.ip, - port: response.port, - connection_type: response.connection_type, - token: response.token, - state: PeerState::OpeningFirewalled, - }; - - match self.peers.insert(peer) { - Ok(peer_id) => { - info!( - "Opening peer connection {} to {}:{} to pierce firewall", - peer_id, response.ip, response.port - ); - #[allow(deprecated)] - self.proto_tx - .send(proto::Request::PeerConnect( - peer_id, - response.ip, - response.port, - )) - .unwrap(); - } + let peer = occupied_entry.get_mut(); + peer.state = PeerState::WaitingFirewalled; - Err(peer) => { - warn!( - "Cannot open peer connection {:?}: too many already open", - peer - ); - } - } - } + #[allow(deprecated)] + self + .proto_tx + .send(proto::Request::ServerRequest( + server::ServerRequest::ConnectToPeerRequest( + server::ConnectToPeerRequest { + token: peer.token, + user_name: peer.user_name.clone(), + connection_type: peer.connection_type.clone(), + }, + ), + )) + .unwrap(); + } + + PeerState::OpeningFirewalled => { + info!( + "Peer connection {} has been refused, cannot connect", + peer_id + ); - fn handle_login_response(&mut self, login: server::LoginResponse) { - if let LoginStatus::AwaitingResponse = self.login_status { - match login { - server::LoginResponse::LoginOk { - motd, - ip, - password_md5_opt, - } => { - info!("Login successful!"); - info!("MOTD: \"{}\"", motd); - info!("External IP address: {}", ip); - - match password_md5_opt { - Some(_) => { - info!(concat!( - "Connected to official server ", - "as official client" - )); - } - None => info!(concat!( - "Connected to official server ", - "as unofficial client" - )), - } - self.login_status = LoginStatus::Success(motd); - } - - server::LoginResponse::LoginFail { reason } => { - error!("Login failed: \"{}\"", reason); - self.login_status = LoginStatus::Failure(reason); - } - } - } else { - error!( - "Received unexpected login response, status = {:?}", - self.login_status - ); - } + let (peer, _) = occupied_entry.remove(); + #[allow(deprecated)] + self + .proto_tx + .send(proto::Request::ServerRequest( + server::ServerRequest::CannotConnectRequest( + server::CannotConnectRequest { + token: peer.token, + user_name: peer.user_name, + }, + ), + )) + .unwrap(); + } } - - fn handle_privileged_users_response( - &mut self, - response: server::PrivilegedUsersResponse, - ) { - self.users.set_all_privileged(response.users); + } + + fn handle_peer_connection_open(&mut self, peer_id: usize) { + let message = match self.peers.get_mut(peer_id) { + None => { + error!("Unknown peer connection {} is open", peer_id); + return; + } + + Some( + peer + @ &mut Peer { + state: PeerState::Open, + .. + }, + ) => { + error!("Peer connection {} was already open: {:?}", peer_id, peer); + return; + } + + Some( + peer + @ + &mut Peer { + state: PeerState::WaitingFirewalled, + .. + }, + ) => { + error!("Peer connection {} was waiting: {:?}", peer_id, peer); + return; + } + + Some( + peer + @ + &mut Peer { + state: PeerState::Opening, + .. + }, + ) => { + info!("Peer connection {} is now open: {:?}", peer_id, peer); + // Mark it as open. + peer.state = PeerState::Open; + // Send a PeerInit. + peer::Message::PeerInit(peer::PeerInit { + user_name: config::USERNAME.to_string(), + connection_type: peer.connection_type.clone(), + token: peer.token, + }) + } + + Some( + peer + @ + &mut Peer { + state: PeerState::OpeningFirewalled, + .. + }, + ) => { + info!("Peer connection {} is now open: {:?}", peer_id, peer); + // Mark it as open. + peer.state = PeerState::Open; + // Send a PierceFirewall. + peer::Message::PierceFirewall(peer.token) + } + }; + + self.send_to_peer(peer_id, message); + } + + /*==========================* + * SERVER RESPONSE HANDLING * + *==========================*/ + + fn handle_server_response(&mut self, response: server::ServerResponse) { + match response { + server::ServerResponse::ConnectToPeerResponse(response) => { + self.handle_connect_to_peer_response(response) + } + + server::ServerResponse::LoginResponse(response) => { + self.handle_login_response(response) + } + + server::ServerResponse::PrivilegedUsersResponse(response) => { + self.handle_privileged_users_response(response) + } + + server::ServerResponse::RoomJoinResponse(response) => { + self.handle_room_join_response(response) + } + + server::ServerResponse::RoomLeaveResponse(response) => { + self.handle_room_leave_response(response) + } + + server::ServerResponse::RoomListResponse(response) => { + self.handle_room_list_response(response) + } + + server::ServerResponse::RoomMessageResponse(response) => { + self.handle_room_message_response(response) + } + + server::ServerResponse::RoomTickersResponse(response) => { + self.handle_room_tickers_response(response) + } + + server::ServerResponse::RoomUserJoinedResponse(response) => { + self.handle_room_user_joined_response(response) + } + + server::ServerResponse::RoomUserLeftResponse(response) => { + self.handle_room_user_left_response(response) + } + + server::ServerResponse::UserInfoResponse(response) => { + self.handle_user_info_response(response) + } + + server::ServerResponse::UserStatusResponse(response) => { + self.handle_user_status_response(response) + } + + server::ServerResponse::UnknownResponse(code) => { + warn!("Unknown response: code {}", code) + } + + response => warn!("Unhandled response: {:?}", response), } - - fn handle_room_join_response( - &mut self, - mut response: server::RoomJoinResponse, - ) { - // Join the room and store the received information. - let result = self.rooms.join( - &response.room_name, - response.owner, - response.operators, - &response.users, + } + + fn handle_connect_to_peer_response( + &mut self, + response: server::ConnectToPeerResponse, + ) { + let peer = Peer { + user_name: response.user_name, + ip: response.ip, + port: response.port, + connection_type: response.connection_type, + token: response.token, + state: PeerState::OpeningFirewalled, + }; + + match self.peers.insert(peer) { + Ok(peer_id) => { + info!( + "Opening peer connection {} to {}:{} to pierce firewall", + peer_id, response.ip, response.port + ); + #[allow(deprecated)] + self + .proto_tx + .send(proto::Request::PeerConnect( + peer_id, + response.ip, + response.port, + )) + .unwrap(); + } + + Err(peer) => { + warn!( + "Cannot open peer connection {:?}: too many already open", + peer ); - if let Err(err) = result { - error!("RoomJoinResponse: {}", err); - return; + } + } + } + + fn handle_login_response(&mut self, login: server::LoginResponse) { + if let LoginStatus::AwaitingResponse = self.login_status { + match login { + server::LoginResponse::LoginOk { + motd, + ip, + password_md5_opt, + } => { + info!("Login successful!"); + info!("MOTD: \"{}\"", motd); + info!("External IP address: {}", ip); + + match password_md5_opt { + Some(_) => { + info!(concat!( + "Connected to official server ", + "as official client" + )); + } + None => info!(concat!( + "Connected to official server ", + "as unofficial client" + )), + } + self.login_status = LoginStatus::Success(motd); } - // Then update the user structs based on the info we just got. - for user in response.users.drain(..) { - self.users.insert(user); + server::LoginResponse::LoginFail { reason } => { + error!("Login failed: \"{}\"", reason); + self.login_status = LoginStatus::Failure(reason); } - - let control_response = control::RoomJoinResponse { - room_name: response.room_name, - }; - self.send_to_controller(control::Response::RoomJoinResponse( - control_response, - )); + } + } else { + error!( + "Received unexpected login response, status = {:?}", + self.login_status + ); } - - fn handle_room_leave_response( - &mut self, - response: server::RoomLeaveResponse, - ) { - if let Err(err) = self.rooms.leave(&response.room_name) { - error!("RoomLeaveResponse: {}", err); - } - - self.send_to_controller(control::Response::RoomLeaveResponse( - control::RoomLeaveResponse { - room_name: response.room_name, - }, - )); + } + + fn handle_privileged_users_response( + &mut self, + response: server::PrivilegedUsersResponse, + ) { + self.users.set_all_privileged(response.users); + } + + fn handle_room_join_response( + &mut self, + mut response: server::RoomJoinResponse, + ) { + // Join the room and store the received information. + let result = self.rooms.join( + &response.room_name, + response.owner, + response.operators, + &response.users, + ); + if let Err(err) = result { + error!("RoomJoinResponse: {}", err); + return; } - fn handle_room_list_response( - &mut self, - response: server::RoomListResponse, - ) { - // Update the room map in memory. - self.rooms.set_room_list(response); - // Send the updated version to the controller. - let rooms = self.rooms.get_room_list(); - self.send_to_controller(control::Response::RoomListResponse( - control::RoomListResponse { rooms: rooms }, - )); + // Then update the user structs based on the info we just got. + for user in response.users.drain(..) { + self.users.insert(user); } - fn handle_room_message_response( - &mut self, - response: server::RoomMessageResponse, - ) { - let result = self.rooms.add_message( - &response.room_name, - room::Message { - user_name: response.user_name.clone(), - message: response.message.clone(), - }, - ); - if let Err(err) = result { - error!("RoomMessageResponse: {}", err); - return; - } - - self.send_to_controller(control::Response::RoomMessageResponse( - control::RoomMessageResponse { - room_name: response.room_name, - user_name: response.user_name, - message: response.message, - }, - )); + let control_response = control::RoomJoinResponse { + room_name: response.room_name, + }; + self.send_to_controller(control::Response::RoomJoinResponse( + control_response, + )); + } + + fn handle_room_leave_response( + &mut self, + response: server::RoomLeaveResponse, + ) { + if let Err(err) = self.rooms.leave(&response.room_name) { + error!("RoomLeaveResponse: {}", err); } - fn handle_room_tickers_response( - &mut self, - response: server::RoomTickersResponse, - ) { - let result = self - .rooms - .set_tickers(&response.room_name, response.tickers); - if let Err(e) = result { - error!("RoomTickersResponse: {}", e); - } + self.send_to_controller(control::Response::RoomLeaveResponse( + control::RoomLeaveResponse { + room_name: response.room_name, + }, + )); + } + + fn handle_room_list_response(&mut self, response: server::RoomListResponse) { + // Update the room map in memory. + self.rooms.set_room_list(response); + // Send the updated version to the controller. + let rooms = self.rooms.get_room_list(); + self.send_to_controller(control::Response::RoomListResponse( + control::RoomListResponse { rooms: rooms }, + )); + } + + fn handle_room_message_response( + &mut self, + response: server::RoomMessageResponse, + ) { + let result = self.rooms.add_message( + &response.room_name, + room::Message { + user_name: response.user_name.clone(), + message: response.message.clone(), + }, + ); + if let Err(err) = result { + error!("RoomMessageResponse: {}", err); + return; } - fn handle_room_user_joined_response( - &mut self, - response: server::RoomUserJoinedResponse, - ) { - let result = self - .rooms - .insert_member(&response.room_name, response.user.name.clone()); - if let Err(err) = result { - error!("RoomUserJoinedResponse: {}", err); - return; - } - self.send_to_controller(control::Response::RoomUserJoinedResponse( - control::RoomUserJoinedResponse { - room_name: response.room_name, - user_name: response.user.name, - }, - )); + self.send_to_controller(control::Response::RoomMessageResponse( + control::RoomMessageResponse { + room_name: response.room_name, + user_name: response.user_name, + message: response.message, + }, + )); + } + + fn handle_room_tickers_response( + &mut self, + response: server::RoomTickersResponse, + ) { + let result = self + .rooms + .set_tickers(&response.room_name, response.tickers); + if let Err(e) = result { + error!("RoomTickersResponse: {}", e); } - - fn handle_room_user_left_response( - &mut self, - response: server::RoomUserLeftResponse, - ) { - let result = self - .rooms - .remove_member(&response.room_name, &response.user_name); - if let Err(err) = result { - error!("RoomUserLeftResponse: {}", err); - return; - } - self.send_to_controller(control::Response::RoomUserLeftResponse( - control::RoomUserLeftResponse { - room_name: response.room_name, - user_name: response.user_name, - }, - )); + } + + fn handle_room_user_joined_response( + &mut self, + response: server::RoomUserJoinedResponse, + ) { + let result = self + .rooms + .insert_member(&response.room_name, response.user.name.clone()); + if let Err(err) = result { + error!("RoomUserJoinedResponse: {}", err); + return; } - - fn handle_user_info_response( - &mut self, - response: server::UserInfoResponse, - ) { - let c_response = match self.users.get_mut_strict(&response.user_name) { - Ok(user) => { - user.average_speed = response.average_speed; - user.num_downloads = response.num_downloads; - user.num_files = response.num_files; - user.num_folders = response.num_folders; - control::UserInfoResponse { - user_name: response.user_name, - user_info: user.clone(), - } - } - Err(err) => { - error!("UserInfoResponse: {}", err); - return; - } - }; - self.send_to_controller(control::Response::UserInfoResponse( - c_response, - )); + self.send_to_controller(control::Response::RoomUserJoinedResponse( + control::RoomUserJoinedResponse { + room_name: response.room_name, + user_name: response.user.name, + }, + )); + } + + fn handle_room_user_left_response( + &mut self, + response: server::RoomUserLeftResponse, + ) { + let result = self + .rooms + .remove_member(&response.room_name, &response.user_name); + if let Err(err) = result { + error!("RoomUserLeftResponse: {}", err); + return; } - - fn handle_user_status_response( - &mut self, - response: server::UserStatusResponse, - ) { - let result = - self.users.set_status(&response.user_name, response.status); - if let Err(err) = result { - error!("UserStatusResponse: {}", err); - return; + self.send_to_controller(control::Response::RoomUserLeftResponse( + control::RoomUserLeftResponse { + room_name: response.room_name, + user_name: response.user_name, + }, + )); + } + + fn handle_user_info_response(&mut self, response: server::UserInfoResponse) { + let c_response = match self.users.get_mut_strict(&response.user_name) { + Ok(user) => { + user.average_speed = response.average_speed; + user.num_downloads = response.num_downloads; + user.num_files = response.num_files; + user.num_folders = response.num_folders; + control::UserInfoResponse { + user_name: response.user_name, + user_info: user.clone(), } + } + Err(err) => { + error!("UserInfoResponse: {}", err); + return; + } + }; + self.send_to_controller(control::Response::UserInfoResponse(c_response)); + } + + fn handle_user_status_response( + &mut self, + response: server::UserStatusResponse, + ) { + let result = self.users.set_status(&response.user_name, response.status); + if let Err(err) = result { + error!("UserStatusResponse: {}", err); + return; + } - if response.is_privileged { - self.users.insert_privileged(response.user_name); - } else { - self.users.remove_privileged(&response.user_name); - } + if response.is_privileged { + self.users.insert_privileged(response.user_name); + } else { + self.users.remove_privileged(&response.user_name); } + } } diff --git a/src/context.rs b/src/context.rs index 1bb4599..68ce50a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -12,35 +12,35 @@ use crate::user::UserMap; /// Implements `Sync`. #[derive(Debug)] pub struct Context { - pub login: Mutex, - pub rooms: Mutex, - pub users: Mutex, + pub login: Mutex, + pub rooms: Mutex, + pub users: Mutex, } impl Context { - /// Creates a new empty context. - pub fn new() -> Self { - Self { - login: Mutex::new(LoginStatus::Todo), - rooms: Mutex::new(RoomMap::new()), - users: Mutex::new(UserMap::new()), - } + /// Creates a new empty context. + pub fn new() -> Self { + Self { + login: Mutex::new(LoginStatus::Todo), + rooms: Mutex::new(RoomMap::new()), + users: Mutex::new(UserMap::new()), } + } } #[cfg(test)] mod tests { - use super::Context; + use super::Context; - #[test] - fn new_context_is_empty() { - let context = Context::new(); - assert_eq!(context.rooms.lock().get_room_list(), vec![]); - assert_eq!(context.users.lock().get_list(), vec![]); - } + #[test] + fn new_context_is_empty() { + let context = Context::new(); + assert_eq!(context.rooms.lock().get_room_list(), vec![]); + assert_eq!(context.users.lock().get_list(), vec![]); + } - #[test] - fn context_is_sync() { - let _sync: &dyn Sync = &Context::new(); - } + #[test] + fn context_is_sync() { + let _sync: &dyn Sync = &Context::new(); + } } diff --git a/src/control/request.rs b/src/control/request.rs index 5518cba..67b1aa2 100644 --- a/src/control/request.rs +++ b/src/control/request.rs @@ -2,25 +2,25 @@ /// controller client to the client. #[derive(Debug, RustcDecodable, RustcEncodable)] pub enum Request { - /// The controller wants to join a room. Contains the room name. - RoomJoinRequest(String), - /// The controller wants to leave a rom. Contains the room name. - RoomLeaveRequest(String), - /// The controller wants to know what the login status is. - LoginStatusRequest, - /// The controller wants to know the list of visible chat rooms. - RoomListRequest, - /// The controller wants to send a message to a chat room. - RoomMessageRequest(RoomMessageRequest), - /// The controller wants to know the list of known users. - UserListRequest, + /// The controller wants to join a room. Contains the room name. + RoomJoinRequest(String), + /// The controller wants to leave a rom. Contains the room name. + RoomLeaveRequest(String), + /// The controller wants to know what the login status is. + LoginStatusRequest, + /// The controller wants to know the list of visible chat rooms. + RoomListRequest, + /// The controller wants to send a message to a chat room. + RoomMessageRequest(RoomMessageRequest), + /// The controller wants to know the list of known users. + UserListRequest, } /// This structure contains the chat room message request from the controller. #[derive(Debug, RustcDecodable, RustcEncodable)] pub struct RoomMessageRequest { - /// The name of the chat room in which to send the message. - pub room_name: String, - /// The message to be said. - pub message: String, + /// The name of the chat room in which to send the message. + pub room_name: String, + /// The message to be said. + pub message: String, } diff --git a/src/control/response.rs b/src/control/response.rs index a0588e0..770ebec 100644 --- a/src/control/response.rs +++ b/src/control/response.rs @@ -5,98 +5,98 @@ use crate::room; /// to the controller. #[derive(Debug, RustcDecodable, RustcEncodable)] pub enum Response { - LoginStatusResponse(LoginStatusResponse), - RoomJoinResponse(RoomJoinResponse), - RoomLeaveResponse(RoomLeaveResponse), - RoomListResponse(RoomListResponse), - RoomMessageResponse(RoomMessageResponse), - RoomUserJoinedResponse(RoomUserJoinedResponse), - RoomUserLeftResponse(RoomUserLeftResponse), - UserInfoResponse(UserInfoResponse), - UserListResponse(UserListResponse), + LoginStatusResponse(LoginStatusResponse), + RoomJoinResponse(RoomJoinResponse), + RoomLeaveResponse(RoomLeaveResponse), + RoomListResponse(RoomListResponse), + RoomMessageResponse(RoomMessageResponse), + RoomUserJoinedResponse(RoomUserJoinedResponse), + RoomUserLeftResponse(RoomUserLeftResponse), + UserInfoResponse(UserInfoResponse), + UserListResponse(UserListResponse), } #[derive(Debug, RustcEncodable, RustcDecodable)] pub struct RoomJoinResponse { - pub room_name: String, + pub room_name: String, } #[derive(Debug, RustcEncodable, RustcDecodable)] pub struct RoomLeaveResponse { - pub room_name: String, + pub room_name: String, } /// This enumeration is the list of possible login states, and the associated /// information. #[derive(Debug, RustcDecodable, RustcEncodable)] pub enum LoginStatusResponse { - /// The login request has been sent to the server, but the response hasn't - /// been received yet. - Pending { - /// The username used to log in. - username: String, - }, + /// The login request has been sent to the server, but the response hasn't + /// been received yet. + Pending { + /// The username used to log in. + username: String, + }, - /// Login was successful. - Success { - /// The username used to log in. - username: String, - /// The message of the day sent by the server. - motd: String, - }, + /// Login was successful. + Success { + /// The username used to log in. + username: String, + /// The message of the day sent by the server. + motd: String, + }, - /// Login failed. - Failure { - /// The username used to log in. - username: String, - /// The reason the server gave for refusing the login request. - reason: String, - }, + /// Login failed. + Failure { + /// The username used to log in. + username: String, + /// The reason the server gave for refusing the login request. + reason: String, + }, } /// This structure contains the list of all visible rooms, and their associated /// data. #[derive(Debug, RustcDecodable, RustcEncodable)] pub struct RoomListResponse { - /// The list of (room name, room data) pairs. - pub rooms: Vec<(String, room::Room)>, + /// The list of (room name, room data) pairs. + pub rooms: Vec<(String, room::Room)>, } /// This structure contains a message said in a chat room the user is a member /// of. #[derive(Debug, RustcDecodable, RustcEncodable)] pub struct RoomMessageResponse { - /// The name of the room in which the message was said. - pub room_name: String, - /// The name of the user who said the message. - pub user_name: String, - /// The message itself. - pub message: String, + /// The name of the room in which the message was said. + pub room_name: String, + /// The name of the user who said the message. + pub user_name: String, + /// The message itself. + pub message: String, } /// This struct describes the fact that the given user joined the given room. #[derive(Debug, RustcDecodable, RustcEncodable)] pub struct RoomUserJoinedResponse { - pub room_name: String, - pub user_name: String, + pub room_name: String, + pub user_name: String, } /// This struct describes the fact that the given user left the given room. #[derive(Debug, RustcDecodable, RustcEncodable)] pub struct RoomUserLeftResponse { - pub room_name: String, - pub user_name: String, + pub room_name: String, + pub user_name: String, } /// This struct contains the last known information about a given user. #[derive(Debug, RustcDecodable, RustcEncodable)] pub struct UserInfoResponse { - pub user_name: String, - pub user_info: User, + pub user_name: String, + pub user_info: User, } /// This stuct contains the last known information about every user. #[derive(Debug, RustcDecodable, RustcEncodable)] pub struct UserListResponse { - pub user_list: Vec<(String, User)>, + pub user_list: Vec<(String, User)>, } diff --git a/src/control/ws.rs b/src/control/ws.rs index 483b4f4..00881c6 100644 --- a/src/control/ws.rs +++ b/src/control/ws.rs @@ -14,65 +14,65 @@ use super::response::*; /// send to the client. #[derive(Debug)] pub enum Notification { - /// A new controller has connected: control messages can now be sent on the - /// given channel. - Connected(Sender), - /// The controller has disconnected. - Disconnected, - /// An irretrievable error has arisen. - Error(String), - /// The controller has sent a request. - Request(Request), + /// A new controller has connected: control messages can now be sent on the + /// given channel. + Connected(Sender), + /// The controller has disconnected. + Disconnected, + /// An irretrievable error has arisen. + Error(String), + /// The controller has sent a request. + Request(Request), } /// This error is returned when a `Sender` fails to send a control request. #[derive(Debug)] pub enum SendError { - /// Error encoding the control request. - JSONEncoderError(json::EncoderError), - /// Error sending the encoded control request to the websocket. - WebSocketError(ws::Error), + /// Error encoding the control request. + JSONEncoderError(json::EncoderError), + /// Error sending the encoded control request to the websocket. + WebSocketError(ws::Error), } impl fmt::Display for SendError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match *self { - SendError::JSONEncoderError(ref err) => { - write!(fmt, "JSONEncoderError: {}", err) - } - SendError::WebSocketError(ref err) => { - write!(fmt, "WebSocketError: {}", err) - } - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + SendError::JSONEncoderError(ref err) => { + write!(fmt, "JSONEncoderError: {}", err) + } + SendError::WebSocketError(ref err) => { + write!(fmt, "WebSocketError: {}", err) + } } + } } impl error::Error for SendError { - fn description(&self) -> &str { - match *self { - SendError::JSONEncoderError(_) => "JSONEncoderError", - SendError::WebSocketError(_) => "WebSocketError", - } + fn description(&self) -> &str { + match *self { + SendError::JSONEncoderError(_) => "JSONEncoderError", + SendError::WebSocketError(_) => "WebSocketError", } + } - fn cause(&self) -> Option<&dyn error::Error> { - match *self { - SendError::JSONEncoderError(ref err) => Some(err), - SendError::WebSocketError(ref err) => Some(err), - } + fn cause(&self) -> Option<&dyn error::Error> { + match *self { + SendError::JSONEncoderError(ref err) => Some(err), + SendError::WebSocketError(ref err) => Some(err), } + } } impl From for SendError { - fn from(err: json::EncoderError) -> Self { - SendError::JSONEncoderError(err) - } + fn from(err: json::EncoderError) -> Self { + SendError::JSONEncoderError(err) + } } impl From for SendError { - fn from(err: ws::Error) -> Self { - SendError::WebSocketError(err) - } + fn from(err: ws::Error) -> Self { + SendError::WebSocketError(err) + } } /// This struct is used to send control responses to the controller. @@ -80,125 +80,123 @@ impl From for SendError { /// the underlying implementation. #[derive(Clone, Debug)] pub struct Sender { - sender: ws::Sender, + sender: ws::Sender, } impl Sender { - /// Queues up a control response to be sent to the controller. - pub fn send(&mut self, response: Response) -> Result<(), SendError> { - let encoded = json::encode(&response)?; - self.sender.send(encoded)?; - Ok(()) - } + /// Queues up a control response to be sent to the controller. + pub fn send(&mut self, response: Response) -> Result<(), SendError> { + let encoded = json::encode(&response)?; + self.sender.send(encoded)?; + Ok(()) + } } /// This struct handles a single websocket connection. #[derive(Debug)] struct Handler { - /// The channel on which to send notifications to the client. - client_tx: crossbeam_channel::Sender, - /// The channel on which to send messages to the controller. - socket_tx: ws::Sender, + /// The channel on which to send notifications to the client. + client_tx: crossbeam_channel::Sender, + /// The channel on which to send messages to the controller. + socket_tx: ws::Sender, } impl Handler { - fn send_to_client(&self, notification: Notification) -> ws::Result<()> { - match self.client_tx.send(notification) { - Ok(()) => Ok(()), - Err(e) => { - error!("Error sending notification to client: {}", e); - Err(ws::Error::new(ws::ErrorKind::Internal, "")) - } - } + fn send_to_client(&self, notification: Notification) -> ws::Result<()> { + match self.client_tx.send(notification) { + Ok(()) => Ok(()), + Err(e) => { + error!("Error sending notification to client: {}", e); + Err(ws::Error::new(ws::ErrorKind::Internal, "")) + } } + } } impl ws::Handler for Handler { - fn on_open(&mut self, _: ws::Handshake) -> ws::Result<()> { - info!("Websocket open"); - self.send_to_client(Notification::Connected(Sender { - sender: self.socket_tx.clone(), - })) - } + fn on_open(&mut self, _: ws::Handshake) -> ws::Result<()> { + info!("Websocket open"); + self.send_to_client(Notification::Connected(Sender { + sender: self.socket_tx.clone(), + })) + } + + fn on_close(&mut self, code: ws::CloseCode, reason: &str) { + info!("Websocket closed: code: {:?}, reason: {:?}", code, reason); + self + .send_to_client(Notification::Disconnected) + .unwrap_or(()) + } + + fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> { + // Get the payload string. + let payload = match msg { + ws::Message::Text(payload) => payload, + ws::Message::Binary(_) => { + error!("Received binary websocket message from controller"); + return Err(ws::Error::new( + ws::ErrorKind::Protocol, + "Binary message not supported", + )); + } + }; - fn on_close(&mut self, code: ws::CloseCode, reason: &str) { - info!("Websocket closed: code: {:?}, reason: {:?}", code, reason); - self.send_to_client(Notification::Disconnected) - .unwrap_or(()) - } + // Decode the json control request. + let control_request = match json::decode(&payload) { + Ok(control_request) => control_request, + Err(e) => { + error!("Received invalid JSON message from controller: {}", e); + return Err(ws::Error::new(ws::ErrorKind::Protocol, "Invalid JSON")); + } + }; - fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> { - // Get the payload string. - let payload = match msg { - ws::Message::Text(payload) => payload, - ws::Message::Binary(_) => { - error!("Received binary websocket message from controller"); - return Err(ws::Error::new( - ws::ErrorKind::Protocol, - "Binary message not supported", - )); - } - }; - - // Decode the json control request. - let control_request = match json::decode(&payload) { - Ok(control_request) => control_request, - Err(e) => { - error!("Received invalid JSON message from controller: {}", e); - return Err(ws::Error::new( - ws::ErrorKind::Protocol, - "Invalid JSON", - )); - } - }; - - debug!("Received control request: {:?}", control_request); - - // Send the control request to the client. - self.send_to_client(Notification::Request(control_request)) - } + debug!("Received control request: {:?}", control_request); + + // Send the control request to the client. + self.send_to_client(Notification::Request(control_request)) + } } /// Start listening on the socket address stored in configuration, and send /// control notifications to the client through the given channel. pub fn listen(client_tx: crossbeam_channel::Sender) { - let websocket_result = ws::Builder::new() - .with_settings(ws::Settings { - max_connections: 1, - ..ws::Settings::default() - }) - .build(|socket_tx| Handler { - client_tx: client_tx.clone(), - socket_tx: socket_tx, - }); - - let websocket = match websocket_result { - Ok(websocket) => websocket, - Err(e) => { - error!("Unable to build websocket: {}", e); - client_tx - .send(Notification::Error(format!( - "Unable to build websocket: {}", - e - ))) - .unwrap(); - return; - } - }; - - let listen_result = - websocket.listen((config::CONTROL_HOST, config::CONTROL_PORT)); - - match listen_result { - Ok(_) => (), - Err(e) => { - error!("Unable to listen on websocket: {}", e); - client_tx - .send(Notification::Error(format!( - "Unable to listen on websocket: {}", - e - ))) - .unwrap(); - } + let websocket_result = ws::Builder::new() + .with_settings(ws::Settings { + max_connections: 1, + ..ws::Settings::default() + }) + .build(|socket_tx| Handler { + client_tx: client_tx.clone(), + socket_tx: socket_tx, + }); + + let websocket = match websocket_result { + Ok(websocket) => websocket, + Err(e) => { + error!("Unable to build websocket: {}", e); + client_tx + .send(Notification::Error(format!( + "Unable to build websocket: {}", + e + ))) + .unwrap(); + return; + } + }; + + let listen_result = + websocket.listen((config::CONTROL_HOST, config::CONTROL_PORT)); + + match listen_result { + Ok(_) => (), + Err(e) => { + error!("Unable to listen on websocket: {}", e); + client_tx + .send(Notification::Error(format!( + "Unable to listen on websocket: {}", + e + ))) + .unwrap(); } + } } diff --git a/src/dispatcher.rs b/src/dispatcher.rs index 025ef1c..ebae154 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -11,97 +11,88 @@ use crate::proto::server::ServerResponse; /// The type of messages dispatched by a dispatcher. #[derive(Debug)] pub enum Message { - ServerResponse(ServerResponse), + ServerResponse(ServerResponse), } /// Pairs together a message and its handler as chosen by the dispatcher. /// Implements Job so as to be scheduled on an executor. struct DispatchedMessage { - message: M, - handler: H, + message: M, + handler: H, } impl DispatchedMessage { - fn new(message: M, handler: H) -> Self { - Self { message, handler } - } + fn new(message: M, handler: H) -> Self { + Self { message, handler } + } } impl Job for DispatchedMessage where - M: Debug + Send, - H: MessageHandler + Send, + M: Debug + Send, + H: MessageHandler + Send, { - fn execute(self: Box, context: &Context) { - if let Err(error) = self.handler.run(context, &self.message) { - error!( - "Error in handler {}: {:?}\nMessage: {:?}", - H::name(), - error, - &self.message - ); - } + fn execute(self: Box, context: &Context) { + if let Err(error) = self.handler.run(context, &self.message) { + error!( + "Error in handler {}: {:?}\nMessage: {:?}", + H::name(), + error, + &self.message + ); } + } } /// The Dispatcher is in charge of mapping messages to their handlers. pub struct Dispatcher; impl Dispatcher { - /// Returns a new dispatcher. - pub fn new() -> Self { - Self {} - } + /// Returns a new dispatcher. + pub fn new() -> Self { + Self {} + } - /// Dispatches the given message by wrapping it with a handler. - pub fn dispatch(&self, message: Message) -> Box { - match message { - Message::ServerResponse(ServerResponse::LoginResponse( - response, - )) => Box::new(DispatchedMessage::new( - response, - LoginHandler::default(), - )), - Message::ServerResponse( - ServerResponse::PrivilegedUsersResponse(response), - ) => Box::new(DispatchedMessage::new( - response, - SetPrivilegedUsersHandler::default(), - )), - _ => panic!("Unimplemented"), - } + /// Dispatches the given message by wrapping it with a handler. + pub fn dispatch(&self, message: Message) -> Box { + match message { + Message::ServerResponse(ServerResponse::LoginResponse(response)) => { + Box::new(DispatchedMessage::new(response, LoginHandler::default())) + } + Message::ServerResponse(ServerResponse::PrivilegedUsersResponse( + response, + )) => Box::new(DispatchedMessage::new( + response, + SetPrivilegedUsersHandler::default(), + )), + _ => panic!("Unimplemented"), } + } } #[cfg(test)] mod tests { - use crate::proto::server; + use crate::proto::server; - use super::*; + use super::*; - #[test] - fn dispatcher_privileged_users_response() { - Dispatcher::new().dispatch(Message::ServerResponse( - server::ServerResponse::PrivilegedUsersResponse( - server::PrivilegedUsersResponse { - users: vec![ - "foo".to_string(), - "bar".to_string(), - "baz".to_string(), - ], - }, - ), - )); - } + #[test] + fn dispatcher_privileged_users_response() { + Dispatcher::new().dispatch(Message::ServerResponse( + server::ServerResponse::PrivilegedUsersResponse( + server::PrivilegedUsersResponse { + users: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()], + }, + ), + )); + } - #[test] - fn dispatcher_login_response() { - Dispatcher::new().dispatch(Message::ServerResponse( - server::ServerResponse::LoginResponse( - server::LoginResponse::LoginFail { - reason: "bleep bloop".to_string(), - }, - ), - )); - } + #[test] + fn dispatcher_login_response() { + Dispatcher::new().dispatch(Message::ServerResponse( + server::ServerResponse::LoginResponse(server::LoginResponse::LoginFail { + reason: "bleep bloop".to_string(), + }), + )); + } } diff --git a/src/executor.rs b/src/executor.rs index a551d98..b45ca69 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -15,133 +15,133 @@ const NUM_THREADS: usize = 8; /// The trait of objects that can be run by an Executor. pub trait Job: Send { - /// Runs this job against the given context. - fn execute(self: Box, context: &Context); + /// Runs this job against the given context. + fn execute(self: Box, context: &Context); } /// A concurrent job execution engine. pub struct Executor { - /// The context against which jobs are executed. - context: Arc, + /// The context against which jobs are executed. + context: Arc, - /// Executes the jobs. - pool: threadpool::ThreadPool, + /// Executes the jobs. + pool: threadpool::ThreadPool, } impl Executor { - /// Builds a new executor against the given context. - pub fn new(context: Context) -> Self { - Self { - context: Arc::new(context), - pool: threadpool::Builder::new() - .num_threads(NUM_THREADS) - .thread_name("Executor".to_string()) - .build(), - } - } - - /// Schedules execution of the given job on this executor. - pub fn schedule(&self, job: Box) { - let context = self.context.clone(); - self.pool.execute(move || job.execute(&*context)); - } - - /// Blocks until all scheduled jobs are executed, then returns the context. - pub fn join(self) -> Context { - self.pool.join(); - - // The only copies of the Arc are passed to the closures executed on - // the threadpool. Once the pool is join()ed, there cannot exist any - // other copies than ours, so we are safe to unwrap() the Arc. - Arc::try_unwrap(self.context).unwrap() + /// Builds a new executor against the given context. + pub fn new(context: Context) -> Self { + Self { + context: Arc::new(context), + pool: threadpool::Builder::new() + .num_threads(NUM_THREADS) + .thread_name("Executor".to_string()) + .build(), } + } + + /// Schedules execution of the given job on this executor. + pub fn schedule(&self, job: Box) { + let context = self.context.clone(); + self.pool.execute(move || job.execute(&*context)); + } + + /// Blocks until all scheduled jobs are executed, then returns the context. + pub fn join(self) -> Context { + self.pool.join(); + + // The only copies of the Arc are passed to the closures executed on + // the threadpool. Once the pool is join()ed, there cannot exist any + // other copies than ours, so we are safe to unwrap() the Arc. + Arc::try_unwrap(self.context).unwrap() + } } #[cfg(test)] mod tests { - use std::sync::{Arc, Barrier}; + use std::sync::{Arc, Barrier}; - use crate::proto::{User, UserStatus}; + use crate::proto::{User, UserStatus}; - use super::{Context, Executor, Job}; + use super::{Context, Executor, Job}; - #[test] - fn immediate_join_returns_empty_context() { - let context = Executor::new(Context::new()).join(); - assert_eq!(context.users.lock().get_list(), vec![]); - assert_eq!(context.rooms.lock().get_room_list(), vec![]); - } + #[test] + fn immediate_join_returns_empty_context() { + let context = Executor::new(Context::new()).join(); + assert_eq!(context.users.lock().get_list(), vec![]); + assert_eq!(context.rooms.lock().get_room_list(), vec![]); + } - struct Waiter { - barrier: Arc, - } + struct Waiter { + barrier: Arc, + } - impl Job for Waiter { - fn execute(self: Box, _context: &Context) { - self.barrier.wait(); - } + impl Job for Waiter { + fn execute(self: Box, _context: &Context) { + self.barrier.wait(); } + } - #[test] - fn join_waits_for_all_jobs() { - let executor = Executor::new(Context::new()); + #[test] + fn join_waits_for_all_jobs() { + let executor = Executor::new(Context::new()); - let barrier = Arc::new(Barrier::new(2)); + let barrier = Arc::new(Barrier::new(2)); - executor.schedule(Box::new(Waiter { - barrier: barrier.clone(), - })); - executor.schedule(Box::new(Waiter { - barrier: barrier.clone(), - })); + executor.schedule(Box::new(Waiter { + barrier: barrier.clone(), + })); + executor.schedule(Box::new(Waiter { + barrier: barrier.clone(), + })); - executor.join(); - } + executor.join(); + } - struct UserAdder { - pub user: User, - } - - impl Job for UserAdder { - fn execute(self: Box, context: &Context) { - context.users.lock().insert(self.user); - } - } + struct UserAdder { + pub user: User, + } - #[test] - fn jobs_access_context() { - let executor = Executor::new(Context::new()); - - let user1 = User { - name: "potato".to_string(), - status: UserStatus::Offline, - average_speed: 0, - num_downloads: 0, - unknown: 0, - num_files: 0, - num_folders: 0, - num_free_slots: 0, - country: "YO".to_string(), - }; - - let mut user2 = user1.clone(); - user2.name = "rutabaga".to_string(); - - executor.schedule(Box::new(UserAdder { - user: user1.clone(), - })); - executor.schedule(Box::new(UserAdder { - user: user2.clone(), - })); - - let context = executor.join(); - - let expected_users = - vec![(user1.name.clone(), user1), (user2.name.clone(), user2)]; - - let mut users = context.users.lock().get_list(); - users.sort(); - - assert_eq!(users, expected_users); + impl Job for UserAdder { + fn execute(self: Box, context: &Context) { + context.users.lock().insert(self.user); } + } + + #[test] + fn jobs_access_context() { + let executor = Executor::new(Context::new()); + + let user1 = User { + name: "potato".to_string(), + status: UserStatus::Offline, + average_speed: 0, + num_downloads: 0, + unknown: 0, + num_files: 0, + num_folders: 0, + num_free_slots: 0, + country: "YO".to_string(), + }; + + let mut user2 = user1.clone(); + user2.name = "rutabaga".to_string(); + + executor.schedule(Box::new(UserAdder { + user: user1.clone(), + })); + executor.schedule(Box::new(UserAdder { + user: user2.clone(), + })); + + let context = executor.join(); + + let expected_users = + vec![(user1.name.clone(), user1), (user2.name.clone(), user2)]; + + let mut users = context.users.lock().get_list(); + users.sort(); + + assert_eq!(users, expected_users); + } } diff --git a/src/handlers/login_handler.rs b/src/handlers/login_handler.rs index 2cb4b9d..848ba3a 100644 --- a/src/handlers/login_handler.rs +++ b/src/handlers/login_handler.rs @@ -9,44 +9,40 @@ use crate::proto::server::LoginResponse; pub struct LoginHandler; impl MessageHandler for LoginHandler { - fn run( - self, - context: &Context, - _message: &LoginResponse, - ) -> io::Result<()> { - let lock = context.login.lock(); - - match *lock { - LoginStatus::AwaitingResponse => (), - _ => { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("unexpected login response, status = {:?}", *lock), - )); - } - }; - - unimplemented!(); - } - - fn name() -> String { - "LoginHandler".to_string() - } + fn run(self, context: &Context, _message: &LoginResponse) -> io::Result<()> { + let lock = context.login.lock(); + + match *lock { + LoginStatus::AwaitingResponse => (), + _ => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("unexpected login response, status = {:?}", *lock), + )); + } + }; + + unimplemented!(); + } + + fn name() -> String { + "LoginHandler".to_string() + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - #[should_panic] - fn run_fails_on_wrong_status() { - let context = Context::new(); + #[test] + #[should_panic] + fn run_fails_on_wrong_status() { + let context = Context::new(); - let response = LoginResponse::LoginFail { - reason: "bleep bloop".to_string(), - }; + let response = LoginResponse::LoginFail { + reason: "bleep bloop".to_string(), + }; - LoginHandler::default().run(&context, &response).unwrap(); - } + LoginHandler::default().run(&context, &response).unwrap(); + } } diff --git a/src/handlers/set_privileged_users_handler.rs b/src/handlers/set_privileged_users_handler.rs index 260559f..404d73b 100644 --- a/src/handlers/set_privileged_users_handler.rs +++ b/src/handlers/set_privileged_users_handler.rs @@ -8,48 +8,48 @@ use crate::proto::server::PrivilegedUsersResponse; pub struct SetPrivilegedUsersHandler; impl MessageHandler for SetPrivilegedUsersHandler { - fn run( - self, - context: &Context, - message: &PrivilegedUsersResponse, - ) -> io::Result<()> { - let users = message.users.clone(); - context.users.lock().set_all_privileged(users); - Ok(()) - } - - fn name() -> String { - "SetPrivilegedUsersHandler".to_string() - } + fn run( + self, + context: &Context, + message: &PrivilegedUsersResponse, + ) -> io::Result<()> { + let users = message.users.clone(); + context.users.lock().set_all_privileged(users); + Ok(()) + } + + fn name() -> String { + "SetPrivilegedUsersHandler".to_string() + } } #[cfg(test)] mod tests { - use crate::context::Context; - use crate::message_handler::MessageHandler; - use crate::proto::server::PrivilegedUsersResponse; + use crate::context::Context; + use crate::message_handler::MessageHandler; + use crate::proto::server::PrivilegedUsersResponse; - use super::SetPrivilegedUsersHandler; + use super::SetPrivilegedUsersHandler; - #[test] - fn run_sets_privileged_users() { - let context = Context::new(); + #[test] + fn run_sets_privileged_users() { + let context = Context::new(); - let response = PrivilegedUsersResponse { - users: vec![ - "aomame".to_string(), - "billybob".to_string(), - "carlos".to_string(), - ], - }; + let response = PrivilegedUsersResponse { + users: vec![ + "aomame".to_string(), + "billybob".to_string(), + "carlos".to_string(), + ], + }; - SetPrivilegedUsersHandler::default() - .run(&context, &response) - .unwrap(); + SetPrivilegedUsersHandler::default() + .run(&context, &response) + .unwrap(); - let mut privileged = context.users.lock().get_all_privileged(); - privileged.sort(); + let mut privileged = context.users.lock().get_all_privileged(); + privileged.sort(); - assert_eq!(privileged, response.users); - } + assert_eq!(privileged, response.users); + } } diff --git a/src/login.rs b/src/login.rs index 65fe6b4..8333318 100644 --- a/src/login.rs +++ b/src/login.rs @@ -6,17 +6,17 @@ /// successfully logged in, the client can interact with the server. #[derive(Clone, Debug)] pub enum LoginStatus { - /// Request not sent yet. - Todo, + /// Request not sent yet. + Todo, - /// Sent request, awaiting response. - AwaitingResponse, + /// Sent request, awaiting response. + AwaitingResponse, - /// Logged in. - /// Stores the MOTD as received from the server. - Success(String), + /// Logged in. + /// Stores the MOTD as received from the server. + Success(String), - /// Failed to log in. - /// Stores the error message as received from the server. - Failure(String), + /// Failed to log in. + /// Stores the error message as received from the server. + Failure(String), } diff --git a/src/main.rs b/src/main.rs index f5095a0..6ecf2d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,36 +21,35 @@ mod room; mod user; fn main() { - match env_logger::init() { - Ok(()) => (), - Err(err) => { - error!("Error initializing logger: {}", err); - return; - } - }; - - let (proto_to_client_tx, proto_to_client_rx) = - crossbeam_channel::unbounded(); - - let mut proto_agent = match proto::Agent::new(proto_to_client_tx) { - Ok(agent) => agent, - Err(err) => { - error!("Error initializing protocol agent: {}", err); - return; - } - }; - - let client_to_proto_tx = proto_agent.channel(); - let (control_to_client_tx, control_to_client_rx) = - crossbeam_channel::unbounded(); - - let mut client = client::Client::new( - client_to_proto_tx, - proto_to_client_rx, - control_to_client_rx, - ); - - thread::spawn(move || control::listen(control_to_client_tx)); - thread::spawn(move || proto_agent.run().unwrap()); - client.run(); + match env_logger::init() { + Ok(()) => (), + Err(err) => { + error!("Error initializing logger: {}", err); + return; + } + }; + + let (proto_to_client_tx, proto_to_client_rx) = crossbeam_channel::unbounded(); + + let mut proto_agent = match proto::Agent::new(proto_to_client_tx) { + Ok(agent) => agent, + Err(err) => { + error!("Error initializing protocol agent: {}", err); + return; + } + }; + + let client_to_proto_tx = proto_agent.channel(); + let (control_to_client_tx, control_to_client_rx) = + crossbeam_channel::unbounded(); + + let mut client = client::Client::new( + client_to_proto_tx, + proto_to_client_rx, + control_to_client_rx, + ); + + thread::spawn(move || control::listen(control_to_client_tx)); + thread::spawn(move || proto_agent.run().unwrap()); + client.run(); } diff --git a/src/message_handler.rs b/src/message_handler.rs index be2543c..470ee2b 100644 --- a/src/message_handler.rs +++ b/src/message_handler.rs @@ -7,9 +7,9 @@ use crate::context::Context; /// Message types are mapped to handler types by Dispatcher. /// This trait is intended to allow composing handler logic. pub trait MessageHandler { - /// Attempts to handle the given message against the given context. - fn run(self, context: &Context, message: &Message) -> io::Result<()>; + /// Attempts to handle the given message against the given context. + fn run(self, context: &Context, message: &Message) -> io::Result<()>; - /// Returns the name of this handler type. - fn name() -> String; + /// Returns the name of this handler type. + fn name() -> String; } diff --git a/src/proto/frame.rs b/src/proto/frame.rs index 0f23271..7aee1b1 100644 --- a/src/proto/frame.rs +++ b/src/proto/frame.rs @@ -14,374 +14,372 @@ use tokio::net::TcpStream; use super::prefix::Prefixer; use super::u32::{decode_u32, U32_BYTE_LEN}; use super::value_codec::{ - ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, - ValueEncoder, + ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, + ValueEncoder, }; #[derive(Debug, Error, PartialEq)] pub enum FrameEncodeError { - #[error("encoded value length {length} is too large")] - ValueTooLarge { - /// The length of the encoded value. - length: usize, - }, - - #[error("failed to encode value: {0}")] - ValueEncodeError(#[from] ValueEncodeError), + #[error("encoded value length {length} is too large")] + ValueTooLarge { + /// The length of the encoded value. + length: usize, + }, + + #[error("failed to encode value: {0}")] + ValueEncodeError(#[from] ValueEncodeError), } impl From for io::Error { - fn from(error: FrameEncodeError) -> Self { - io::Error::new(io::ErrorKind::InvalidData, format!("{}", error)) - } + fn from(error: FrameEncodeError) -> Self { + io::Error::new(io::ErrorKind::InvalidData, format!("{}", error)) + } } /// Encodes entire protocol frames containing values of type `T`. #[derive(Debug)] pub struct FrameEncoder { - phantom: PhantomData, + phantom: PhantomData, } impl FrameEncoder { - pub fn new() -> Self { - Self { - phantom: PhantomData, - } + pub fn new() -> Self { + Self { + phantom: PhantomData, } + } - pub fn encode_to( - &mut self, - value: &T, - buffer: &mut BytesMut, - ) -> Result<(), FrameEncodeError> { - let mut prefixer = Prefixer::new(buffer); - - ValueEncoder::new(prefixer.suffix_mut()).encode(value)?; + pub fn encode_to( + &mut self, + value: &T, + buffer: &mut BytesMut, + ) -> Result<(), FrameEncodeError> { + let mut prefixer = Prefixer::new(buffer); - if let Err(prefixer) = prefixer.finalize() { - return Err(FrameEncodeError::ValueTooLarge { - length: prefixer.suffix().len(), - }); - } + ValueEncoder::new(prefixer.suffix_mut()).encode(value)?; - Ok(()) + if let Err(prefixer) = prefixer.finalize() { + return Err(FrameEncodeError::ValueTooLarge { + length: prefixer.suffix().len(), + }); } + + Ok(()) + } } /// Decodes entire protocol frames containing values of type `T`. #[derive(Debug)] pub struct FrameDecoder { - // Only here to enable parameterizing `Decoder` by `T`. - phantom: PhantomData, + // Only here to enable parameterizing `Decoder` by `T`. + phantom: PhantomData, } impl FrameDecoder { - pub fn new() -> Self { - Self { - phantom: PhantomData, - } + pub fn new() -> Self { + Self { + phantom: PhantomData, + } + } + + /// Attempts to decode an entire frame from the given buffer. + /// + /// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes + /// have been split off from the left of `bytes`. + /// + /// Returns `Ok(None)` if not enough bytes are available to decode an entire + /// frame yet, in which case `bytes` is untouched. + /// + /// Returns an error if the length prefix or the framed value are malformed, + /// in which case `bytes` is untouched. + pub fn decode_from( + &mut self, + bytes: &mut BytesMut, + ) -> Result, ValueDecodeError> { + if bytes.len() < U32_BYTE_LEN { + return Ok(None); // Not enough bytes yet. } - /// Attempts to decode an entire frame from the given buffer. - /// - /// Returns `Ok(Some(frame))` if successful, in which case the frame's bytes - /// have been split off from the left of `bytes`. - /// - /// Returns `Ok(None)` if not enough bytes are available to decode an entire - /// frame yet, in which case `bytes` is untouched. - /// - /// Returns an error if the length prefix or the framed value are malformed, - /// in which case `bytes` is untouched. - pub fn decode_from( - &mut self, - bytes: &mut BytesMut, - ) -> Result, ValueDecodeError> { - if bytes.len() < U32_BYTE_LEN { - return Ok(None); // Not enough bytes yet. - } - - // Split the prefix off. After this: - // - // | bytes (len 4) | suffix | - // - // NOTE: This method would be simpler if we could use split_to() instead - // here such that `bytes` contained the suffix. At the end, we would not - // have to replace `bytes` with `suffix`. However, that would require - // calling `prefix.unsplit(*bytes)`, and that does not work since - // `bytes` is only borrowed, and unsplit() takes its argument by value. - let mut suffix = bytes.split_off(U32_BYTE_LEN); - - // unwrap() cannot panic because `bytes` is of the exact right length. - let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap(); - let length = decode_u32(array) as usize; - - if suffix.len() < length { - // Re-assemble `bytes` as it first was. - bytes.unsplit(suffix); - return Ok(None); // Not enough bytes yet. - } - - // Split off the right amount of bytes from the buffer. After this: - // - // | bytes (len 4) | contents | suffix | - // - let mut contents = suffix.split_to(length); - - // Attempt to decode the value. - let item = match ValueDecoder::new(&contents).decode() { - Ok(item) => item, - Err(error) => { - // Re-assemble `bytes` as it first was. - contents.unsplit(suffix); - bytes.unsplit(contents); - return Err(error); - } - }; - - // Remove the decoded bytes from the left of `bytes`. - *bytes = suffix; - Ok(Some(item)) + // Split the prefix off. After this: + // + // | bytes (len 4) | suffix | + // + // NOTE: This method would be simpler if we could use split_to() instead + // here such that `bytes` contained the suffix. At the end, we would not + // have to replace `bytes` with `suffix`. However, that would require + // calling `prefix.unsplit(*bytes)`, and that does not work since + // `bytes` is only borrowed, and unsplit() takes its argument by value. + let mut suffix = bytes.split_off(U32_BYTE_LEN); + + // unwrap() cannot panic because `bytes` is of the exact right length. + let array: [u8; U32_BYTE_LEN] = bytes.as_ref().try_into().unwrap(); + let length = decode_u32(array) as usize; + + if suffix.len() < length { + // Re-assemble `bytes` as it first was. + bytes.unsplit(suffix); + return Ok(None); // Not enough bytes yet. } + + // Split off the right amount of bytes from the buffer. After this: + // + // | bytes (len 4) | contents | suffix | + // + let mut contents = suffix.split_to(length); + + // Attempt to decode the value. + let item = match ValueDecoder::new(&contents).decode() { + Ok(item) => item, + Err(error) => { + // Re-assemble `bytes` as it first was. + contents.unsplit(suffix); + bytes.unsplit(contents); + return Err(error); + } + }; + + // Remove the decoded bytes from the left of `bytes`. + *bytes = suffix; + Ok(Some(item)) + } } #[derive(Debug)] pub struct FrameStream { - stream: TcpStream, + stream: TcpStream, - read_buffer: BytesMut, + read_buffer: BytesMut, - decoder: FrameDecoder, - encoder: FrameEncoder, + decoder: FrameDecoder, + encoder: FrameEncoder, } impl FrameStream where - ReadFrame: ValueDecode, - WriteFrame: ValueEncode + ?Sized, + ReadFrame: ValueDecode, + WriteFrame: ValueEncode + ?Sized, { - pub fn new(stream: TcpStream) -> Self { - FrameStream { - stream, - read_buffer: BytesMut::new(), - decoder: FrameDecoder::new(), - encoder: FrameEncoder::new(), - } + pub fn new(stream: TcpStream) -> Self { + FrameStream { + stream, + read_buffer: BytesMut::new(), + decoder: FrameDecoder::new(), + encoder: FrameEncoder::new(), } - - pub async fn read(&mut self) -> io::Result { - loop { - if let Some(frame) = - self.decoder.decode_from(&mut self.read_buffer)? - { - return Ok(frame); - } - self.stream.read_buf(&mut self.read_buffer).await?; - } + } + + pub async fn read(&mut self) -> io::Result { + loop { + if let Some(frame) = self.decoder.decode_from(&mut self.read_buffer)? { + return Ok(frame); + } + self.stream.read_buf(&mut self.read_buffer).await?; } + } - pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> { - let mut bytes = BytesMut::new(); - self.encoder.encode_to(frame, &mut bytes)?; - self.stream.write_all(bytes.as_ref()).await - } + pub async fn write(&mut self, frame: &WriteFrame) -> io::Result<()> { + let mut bytes = BytesMut::new(); + self.encoder.encode_to(frame, &mut bytes)?; + self.stream.write_all(bytes.as_ref()).await + } } mod tests { - use bytes::BytesMut; - use tokio::net::{TcpListener, TcpStream}; - - use super::{FrameDecoder, FrameEncoder, FrameStream}; - - // Test value: [1, 3, 3, 7] in little-endian. - const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24); - - #[test] - fn encode_u32() { - let mut bytes = BytesMut::new(); - - FrameEncoder::new() - .encode_to(&U32_1337, &mut bytes) - .unwrap(); - - assert_eq!( - bytes, - vec![ - 4, 0, 0, 0, // 1 32-bit integer = 4 bytes. - 1, 3, 3, 7, // Little-endian integer. - ] - ); - } - - #[test] - fn encode_appends() { - let mut bytes = BytesMut::new(); - - let mut encoder = FrameEncoder::new(); - encoder.encode_to(&U32_1337, &mut bytes).unwrap(); - encoder.encode_to(&U32_1337, &mut bytes).unwrap(); - - assert_eq!( - bytes, - vec![ - 4, 0, 0, 0, // 1 32-bit integer = 4 bytes. - 1, 3, 3, 7, // Little-endian integer. - 4, 0, 0, 0, // Repeated. - 1, 3, 3, 7, - ] - ); - } - - #[test] - fn encode_vec() { - let v: Vec = vec![1, 3, 3, 7]; - - let mut bytes = BytesMut::new(); - FrameEncoder::new().encode_to(&v, &mut bytes).unwrap(); - - assert_eq!( - bytes, - vec![ - 20, 0, 0, 0, // 5 32-bit integers = 20 bytes. - 4, 0, 0, 0, // 4 elements in the vector. - 1, 0, 0, 0, // Little-endian vector elements. - 3, 0, 0, 0, // - 3, 0, 0, 0, // - 7, 0, 0, 0, // - ] - ); - } - - #[test] - fn decode_not_enough_data_for_prefix() { - let initial_bytes = vec![ - 4, 0, 0, // Incomplete 32-bit length prefix. - ]; - - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&initial_bytes); - - let value: Option = - FrameDecoder::new().decode_from(&mut bytes).unwrap(); - - assert_eq!(value, None); - assert_eq!(bytes, initial_bytes); // Untouched. - } - - #[test] - fn decode_not_enough_data_for_contents() { - let initial_bytes = vec![ - 4, 0, 0, 0, // Length 4. - 1, 2, 3, // But there are only 3 bytes! - ]; - - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&initial_bytes); - - let value: Option = - FrameDecoder::new().decode_from(&mut bytes).unwrap(); - - assert_eq!(value, None); - assert_eq!(bytes, initial_bytes); // Untouched. - } - - #[test] - fn decode_u32() { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&[ - 4, 0, 0, 0, // 1 32-bit integer = 4 bytes. - 1, 3, 3, 7, // Little-endian integer. - 4, 2, // Trailing bytes. - ]); - - let value = FrameDecoder::new().decode_from(&mut bytes).unwrap(); - - assert_eq!(value, Some(U32_1337)); - assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off. - } - - #[test] - fn decode_vec() { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&[ - 20, 0, 0, 0, // 5 32-bit integers = 20 bytes. - 4, 0, 0, 0, // 4 elements in the vector. - 1, 0, 0, 0, // Little-endian vector elements. - 3, 0, 0, 0, // - 3, 0, 0, 0, // - 7, 0, 0, 0, // - 4, 2, // Trailing bytes. - ]); - - let value = FrameDecoder::new().decode_from(&mut bytes).unwrap(); - - let expected_value: Vec = vec![1, 3, 3, 7]; - assert_eq!(value, Some(expected_value)); - assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off. - } - - #[test] - fn roundtrip() { - let value: Vec = vec![ - "apples".to_string(), // - "bananas".to_string(), // - "oranges".to_string(), // - "and cheese!".to_string(), // - ]; - - let mut buffer = BytesMut::new(); - - FrameEncoder::new().encode_to(&value, &mut buffer).unwrap(); - let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap(); - - assert_eq!(decoded, Some(value)); - assert_eq!(buffer, vec![]); - } - - #[tokio::test] - async fn ping_pong() { - let listener = TcpListener::bind("localhost:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let server_task = tokio::spawn(async move { - let (stream, _peer_address) = listener.accept().await.unwrap(); - let mut frame_stream = FrameStream::::new(stream); - - assert_eq!(frame_stream.read().await.unwrap(), "ping"); - frame_stream.write("pong").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), "ping"); - frame_stream.write("pong").await.unwrap(); - }); - - let stream = TcpStream::connect(address).await.unwrap(); - let mut frame_stream = FrameStream::::new(stream); - - frame_stream.write("ping").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), "pong"); - frame_stream.write("ping").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), "pong"); - - server_task.await.unwrap(); - } - - #[tokio::test] - async fn very_large_message() { - let listener = TcpListener::bind("localhost:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let server_task = tokio::spawn(async move { - let (stream, _peer_address) = listener.accept().await.unwrap(); - let mut frame_stream = FrameStream::>::new(stream); - - assert_eq!(frame_stream.read().await.unwrap(), "ping"); - frame_stream.write(&vec![0; 10 * 4096]).await.unwrap(); - }); - - let stream = TcpStream::connect(address).await.unwrap(); - let mut frame_stream = FrameStream::, str>::new(stream); - - frame_stream.write("ping").await.unwrap(); - assert_eq!(frame_stream.read().await.unwrap(), vec![0; 10 * 4096]); - - server_task.await.unwrap(); - } + use bytes::BytesMut; + use tokio::net::{TcpListener, TcpStream}; + + use super::{FrameDecoder, FrameEncoder, FrameStream}; + + // Test value: [1, 3, 3, 7] in little-endian. + const U32_1337: u32 = 1 + (3 << 8) + (3 << 16) + (7 << 24); + + #[test] + fn encode_u32() { + let mut bytes = BytesMut::new(); + + FrameEncoder::new() + .encode_to(&U32_1337, &mut bytes) + .unwrap(); + + assert_eq!( + bytes, + vec![ + 4, 0, 0, 0, // 1 32-bit integer = 4 bytes. + 1, 3, 3, 7, // Little-endian integer. + ] + ); + } + + #[test] + fn encode_appends() { + let mut bytes = BytesMut::new(); + + let mut encoder = FrameEncoder::new(); + encoder.encode_to(&U32_1337, &mut bytes).unwrap(); + encoder.encode_to(&U32_1337, &mut bytes).unwrap(); + + assert_eq!( + bytes, + vec![ + 4, 0, 0, 0, // 1 32-bit integer = 4 bytes. + 1, 3, 3, 7, // Little-endian integer. + 4, 0, 0, 0, // Repeated. + 1, 3, 3, 7, + ] + ); + } + + #[test] + fn encode_vec() { + let v: Vec = vec![1, 3, 3, 7]; + + let mut bytes = BytesMut::new(); + FrameEncoder::new().encode_to(&v, &mut bytes).unwrap(); + + assert_eq!( + bytes, + vec![ + 20, 0, 0, 0, // 5 32-bit integers = 20 bytes. + 4, 0, 0, 0, // 4 elements in the vector. + 1, 0, 0, 0, // Little-endian vector elements. + 3, 0, 0, 0, // + 3, 0, 0, 0, // + 7, 0, 0, 0, // + ] + ); + } + + #[test] + fn decode_not_enough_data_for_prefix() { + let initial_bytes = vec![ + 4, 0, 0, // Incomplete 32-bit length prefix. + ]; + + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&initial_bytes); + + let value: Option = + FrameDecoder::new().decode_from(&mut bytes).unwrap(); + + assert_eq!(value, None); + assert_eq!(bytes, initial_bytes); // Untouched. + } + + #[test] + fn decode_not_enough_data_for_contents() { + let initial_bytes = vec![ + 4, 0, 0, 0, // Length 4. + 1, 2, 3, // But there are only 3 bytes! + ]; + + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&initial_bytes); + + let value: Option = + FrameDecoder::new().decode_from(&mut bytes).unwrap(); + + assert_eq!(value, None); + assert_eq!(bytes, initial_bytes); // Untouched. + } + + #[test] + fn decode_u32() { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&[ + 4, 0, 0, 0, // 1 32-bit integer = 4 bytes. + 1, 3, 3, 7, // Little-endian integer. + 4, 2, // Trailing bytes. + ]); + + let value = FrameDecoder::new().decode_from(&mut bytes).unwrap(); + + assert_eq!(value, Some(U32_1337)); + assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off. + } + + #[test] + fn decode_vec() { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&[ + 20, 0, 0, 0, // 5 32-bit integers = 20 bytes. + 4, 0, 0, 0, // 4 elements in the vector. + 1, 0, 0, 0, // Little-endian vector elements. + 3, 0, 0, 0, // + 3, 0, 0, 0, // + 7, 0, 0, 0, // + 4, 2, // Trailing bytes. + ]); + + let value = FrameDecoder::new().decode_from(&mut bytes).unwrap(); + + let expected_value: Vec = vec![1, 3, 3, 7]; + assert_eq!(value, Some(expected_value)); + assert_eq!(bytes, vec![4, 2]); // Decoded bytes were split off. + } + + #[test] + fn roundtrip() { + let value: Vec = vec![ + "apples".to_string(), // + "bananas".to_string(), // + "oranges".to_string(), // + "and cheese!".to_string(), // + ]; + + let mut buffer = BytesMut::new(); + + FrameEncoder::new().encode_to(&value, &mut buffer).unwrap(); + let decoded = FrameDecoder::new().decode_from(&mut buffer).unwrap(); + + assert_eq!(decoded, Some(value)); + assert_eq!(buffer, vec![]); + } + + #[tokio::test] + async fn ping_pong() { + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (stream, _peer_address) = listener.accept().await.unwrap(); + let mut frame_stream = FrameStream::::new(stream); + + assert_eq!(frame_stream.read().await.unwrap(), "ping"); + frame_stream.write("pong").await.unwrap(); + assert_eq!(frame_stream.read().await.unwrap(), "ping"); + frame_stream.write("pong").await.unwrap(); + }); + + let stream = TcpStream::connect(address).await.unwrap(); + let mut frame_stream = FrameStream::::new(stream); + + frame_stream.write("ping").await.unwrap(); + assert_eq!(frame_stream.read().await.unwrap(), "pong"); + frame_stream.write("ping").await.unwrap(); + assert_eq!(frame_stream.read().await.unwrap(), "pong"); + + server_task.await.unwrap(); + } + + #[tokio::test] + async fn very_large_message() { + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (stream, _peer_address) = listener.accept().await.unwrap(); + let mut frame_stream = FrameStream::>::new(stream); + + assert_eq!(frame_stream.read().await.unwrap(), "ping"); + frame_stream.write(&vec![0; 10 * 4096]).await.unwrap(); + }); + + let stream = TcpStream::connect(address).await.unwrap(); + let mut frame_stream = FrameStream::, str>::new(stream); + + frame_stream.write("ping").await.unwrap(); + assert_eq!(frame_stream.read().await.unwrap(), vec![0; 10 * 4096]); + + server_task.await.unwrap(); + } } diff --git a/src/proto/handler.rs b/src/proto/handler.rs index db48c49..75986a4 100644 --- a/src/proto/handler.rs +++ b/src/proto/handler.rs @@ -30,17 +30,17 @@ const LISTEN_TOKEN: usize = config::MAX_PEERS + 1; #[derive(Debug)] pub enum Request { - PeerConnect(usize, net::Ipv4Addr, u16), - PeerMessage(usize, peer::Message), - ServerRequest(ServerRequest), + PeerConnect(usize, net::Ipv4Addr, u16), + PeerMessage(usize, peer::Message), + ServerRequest(ServerRequest), } #[derive(Debug)] pub enum Response { - PeerConnectionClosed(usize), - PeerConnectionOpen(usize), - PeerMessage(usize, peer::Message), - ServerResponse(ServerResponse), + PeerConnectionClosed(usize), + PeerConnectionOpen(usize), + PeerMessage(usize, peer::Message), + ServerResponse(ServerResponse), } /*========================* @@ -50,16 +50,16 @@ pub enum Response { pub struct ServerResponseSender(crossbeam_channel::Sender); impl SendPacket for ServerResponseSender { - type Value = ServerResponse; - type Error = crossbeam_channel::SendError; + type Value = ServerResponse; + type Error = crossbeam_channel::SendError; - fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> { - self.0.send(Response::ServerResponse(value)) - } + fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> { + self.0.send(Response::ServerResponse(value)) + } - fn notify_open(&mut self) -> Result<(), Self::Error> { - Ok(()) - } + fn notify_open(&mut self) -> Result<(), Self::Error> { + Ok(()) + } } /*======================* @@ -67,21 +67,21 @@ impl SendPacket for ServerResponseSender { *======================*/ pub struct PeerResponseSender { - sender: crossbeam_channel::Sender, - peer_id: usize, + sender: crossbeam_channel::Sender, + peer_id: usize, } impl SendPacket for PeerResponseSender { - type Value = peer::Message; - type Error = crossbeam_channel::SendError; + type Value = peer::Message; + type Error = crossbeam_channel::SendError; - fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> { - self.sender.send(Response::PeerMessage(self.peer_id, value)) - } + fn send_packet(&mut self, value: Self::Value) -> Result<(), Self::Error> { + self.sender.send(Response::PeerMessage(self.peer_id, value)) + } - fn notify_open(&mut self) -> Result<(), Self::Error> { - self.sender.send(Response::PeerConnectionOpen(self.peer_id)) - } + fn notify_open(&mut self) -> Result<(), Self::Error> { + self.sender.send(Response::PeerConnectionOpen(self.peer_id)) + } } /*=========* @@ -91,307 +91,302 @@ impl SendPacket for PeerResponseSender { /// This struct handles all the soulseek connections, to the server and to /// peers. struct Handler { - server_stream: Stream, + server_stream: Stream, - peer_streams: slab::Slab, usize>, + peer_streams: slab::Slab, usize>, - listener: mio::tcp::TcpListener, + listener: mio::tcp::TcpListener, - client_tx: crossbeam_channel::Sender, + client_tx: crossbeam_channel::Sender, } fn listener_bind(addr_spec: U) -> io::Result where - U: ToSocketAddrs + fmt::Debug, + U: ToSocketAddrs + fmt::Debug, { - for socket_addr in addr_spec.to_socket_addrs()? { - if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) { - return Ok(listener); - } + for socket_addr in addr_spec.to_socket_addrs()? { + if let Ok(listener) = mio::tcp::TcpListener::bind(&socket_addr) { + return Ok(listener); } - Err(io::Error::new( - io::ErrorKind::Other, - format!("Cannot bind to {:?}", addr_spec), - )) + } + Err(io::Error::new( + io::ErrorKind::Other, + format!("Cannot bind to {:?}", addr_spec), + )) } impl Handler { - #[allow(deprecated)] - fn new( - client_tx: crossbeam_channel::Sender, - event_loop: &mut mio::deprecated::EventLoop, - ) -> io::Result { - let host = config::SERVER_HOST; - let port = config::SERVER_PORT; - let server_stream = - Stream::new((host, port), ServerResponseSender(client_tx.clone()))?; - - info!("Connected to server at {}:{}", host, port); - - let listener = - listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?; - info!( - "Listening for connections on {}:{}", - config::LISTEN_HOST, - config::LISTEN_PORT - ); - - event_loop.register( - server_stream.evented(), + #[allow(deprecated)] + fn new( + client_tx: crossbeam_channel::Sender, + event_loop: &mut mio::deprecated::EventLoop, + ) -> io::Result { + let host = config::SERVER_HOST; + let port = config::SERVER_PORT; + let server_stream = + Stream::new((host, port), ServerResponseSender(client_tx.clone()))?; + + info!("Connected to server at {}:{}", host, port); + + let listener = listener_bind((config::LISTEN_HOST, config::LISTEN_PORT))?; + info!( + "Listening for connections on {}:{}", + config::LISTEN_HOST, + config::LISTEN_PORT + ); + + event_loop.register( + server_stream.evented(), + mio::Token(SERVER_TOKEN), + mio::Ready::all(), + mio::PollOpt::edge() | mio::PollOpt::oneshot(), + )?; + + event_loop.register( + &listener, + mio::Token(LISTEN_TOKEN), + mio::Ready::all(), + mio::PollOpt::edge() | mio::PollOpt::oneshot(), + )?; + + Ok(Handler { + server_stream: server_stream, + + peer_streams: slab::Slab::new(config::MAX_PEERS), + + listener: listener, + + client_tx: client_tx, + }) + } + + #[allow(deprecated)] + fn connect_to_peer( + &mut self, + peer_id: usize, + ip: net::Ipv4Addr, + port: u16, + event_loop: &mut mio::deprecated::EventLoop, + ) -> Result<(), String> { + let vacant_entry = match self.peer_streams.entry(peer_id) { + None => return Err("id out of range".to_string()), + + Some(slab::Entry::Occupied(_occupied_entry)) => { + return Err("id already taken".to_string()); + } + + Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry, + }; + + info!("Opening peer connection {} to {}:{}", peer_id, ip, port); + + let sender = PeerResponseSender { + sender: self.client_tx.clone(), + peer_id: peer_id, + }; + + let peer_stream = match Stream::new((ip, port), sender) { + Ok(peer_stream) => peer_stream, + + Err(err) => return Err(format!("i/o error: {}", err)), + }; + + event_loop + .register( + peer_stream.evented(), + mio::Token(peer_id), + mio::Ready::all(), + mio::PollOpt::edge() | mio::PollOpt::oneshot(), + ) + .unwrap(); + + vacant_entry.insert(peer_stream); + + Ok(()) + } + + #[allow(deprecated)] + fn process_server_intent( + &mut self, + intent: Intent, + event_loop: &mut mio::deprecated::EventLoop, + ) { + match intent { + Intent::Done => { + error!("Server connection closed"); + // TODO notify client and shut down + } + Intent::Continue(event_set) => { + event_loop + .reregister( + self.server_stream.evented(), mio::Token(SERVER_TOKEN), - mio::Ready::all(), + event_set, mio::PollOpt::edge() | mio::PollOpt::oneshot(), - )?; - - event_loop.register( - &listener, - mio::Token(LISTEN_TOKEN), - mio::Ready::all(), - mio::PollOpt::edge() | mio::PollOpt::oneshot(), - )?; - - Ok(Handler { - server_stream: server_stream, - - peer_streams: slab::Slab::new(config::MAX_PEERS), - - listener: listener, - - client_tx: client_tx, - }) + ) + .unwrap(); + } } - - #[allow(deprecated)] - fn connect_to_peer( - &mut self, - peer_id: usize, - ip: net::Ipv4Addr, - port: u16, - event_loop: &mut mio::deprecated::EventLoop, - ) -> Result<(), String> { - let vacant_entry = match self.peer_streams.entry(peer_id) { - None => return Err("id out of range".to_string()), - - Some(slab::Entry::Occupied(_occupied_entry)) => { - return Err("id already taken".to_string()); - } - - Some(slab::Entry::Vacant(vacant_entry)) => vacant_entry, - }; - - info!("Opening peer connection {} to {}:{}", peer_id, ip, port); - - let sender = PeerResponseSender { - sender: self.client_tx.clone(), - peer_id: peer_id, - }; - - let peer_stream = match Stream::new((ip, port), sender) { - Ok(peer_stream) => peer_stream, - - Err(err) => return Err(format!("i/o error: {}", err)), - }; - - event_loop - .register( - peer_stream.evented(), - mio::Token(peer_id), - mio::Ready::all(), - mio::PollOpt::edge() | mio::PollOpt::oneshot(), + } + + #[allow(deprecated)] + fn process_peer_intent( + &mut self, + intent: Intent, + token: mio::Token, + event_loop: &mut mio::deprecated::EventLoop, + ) { + match intent { + Intent::Done => { + self.peer_streams.remove(token.0); + self + .client_tx + .send(Response::PeerConnectionClosed(token.0)) + .unwrap(); + } + + Intent::Continue(event_set) => { + if let Some(peer_stream) = self.peer_streams.get_mut(token.0) { + event_loop + .reregister( + peer_stream.evented(), + token, + event_set, + mio::PollOpt::edge() | mio::PollOpt::oneshot(), ) .unwrap(); - - vacant_entry.insert(peer_stream); - - Ok(()) - } - - #[allow(deprecated)] - fn process_server_intent( - &mut self, - intent: Intent, - event_loop: &mut mio::deprecated::EventLoop, - ) { - match intent { - Intent::Done => { - error!("Server connection closed"); - // TODO notify client and shut down - } - Intent::Continue(event_set) => { - event_loop - .reregister( - self.server_stream.evented(), - mio::Token(SERVER_TOKEN), - event_set, - mio::PollOpt::edge() | mio::PollOpt::oneshot(), - ) - .unwrap(); - } - } - } - - #[allow(deprecated)] - fn process_peer_intent( - &mut self, - intent: Intent, - token: mio::Token, - event_loop: &mut mio::deprecated::EventLoop, - ) { - match intent { - Intent::Done => { - self.peer_streams.remove(token.0); - self.client_tx - .send(Response::PeerConnectionClosed(token.0)) - .unwrap(); - } - - Intent::Continue(event_set) => { - if let Some(peer_stream) = self.peer_streams.get_mut(token.0) { - event_loop - .reregister( - peer_stream.evented(), - token, - event_set, - mio::PollOpt::edge() | mio::PollOpt::oneshot(), - ) - .unwrap(); - } - } } + } } + } } #[allow(deprecated)] impl mio::deprecated::Handler for Handler { - type Timeout = (); - type Message = Request; - - fn ready( - &mut self, - event_loop: &mut mio::deprecated::EventLoop, - token: mio::Token, - event_set: mio::Ready, - ) { - match token { - mio::Token(LISTEN_TOKEN) => { - if event_set.is_readable() { - // A peer wants to connect to us. - match self.listener.accept() { - Ok((_sock, addr)) => { - // TODO add it to peer streams - info!("Peer connection accepted from {}", addr); - } - - Err(err) => { - error!("Cannot accept peer connection: {}", err); - } - } - } - event_loop - .reregister( - &self.listener, - token, - mio::Ready::all(), - mio::PollOpt::edge() | mio::PollOpt::oneshot(), - ) - .unwrap(); + type Timeout = (); + type Message = Request; + + fn ready( + &mut self, + event_loop: &mut mio::deprecated::EventLoop, + token: mio::Token, + event_set: mio::Ready, + ) { + match token { + mio::Token(LISTEN_TOKEN) => { + if event_set.is_readable() { + // A peer wants to connect to us. + match self.listener.accept() { + Ok((_sock, addr)) => { + // TODO add it to peer streams + info!("Peer connection accepted from {}", addr); } - mio::Token(SERVER_TOKEN) => { - let intent = self.server_stream.on_ready(event_set); - self.process_server_intent(intent, event_loop); - } - - mio::Token(peer_id) => { - let intent = match self.peer_streams.get_mut(peer_id) { - Some(peer_stream) => peer_stream.on_ready(event_set), - - None => unreachable!("Unknown peer {} is ready", peer_id), - }; - self.process_peer_intent(intent, token, event_loop); + Err(err) => { + error!("Cannot accept peer connection: {}", err); } + } } - } + event_loop + .reregister( + &self.listener, + token, + mio::Ready::all(), + mio::PollOpt::edge() | mio::PollOpt::oneshot(), + ) + .unwrap(); + } - fn notify( - &mut self, - event_loop: &mut mio::deprecated::EventLoop, - request: Request, - ) { - match request { - Request::PeerConnect(peer_id, ip, port) => { - if let Err(err) = - self.connect_to_peer(peer_id, ip, port, event_loop) - { - error!( - "Cannot open peer connection {} to {}:{}: {}", - peer_id, ip, port, err - ); - self.client_tx - .send(Response::PeerConnectionClosed(peer_id)) - .unwrap(); - } - } + mio::Token(SERVER_TOKEN) => { + let intent = self.server_stream.on_ready(event_set); + self.process_server_intent(intent, event_loop); + } - Request::PeerMessage(peer_id, message) => { - let intent = match self.peer_streams.get_mut(peer_id) { - Some(peer_stream) => peer_stream.on_notify(&message), - None => { - error!( - "Cannot send peer message {:?}: unknown id {}", - message, peer_id - ); - return; - } - }; - self.process_peer_intent( - intent, - mio::Token(peer_id), - event_loop, - ); - } + mio::Token(peer_id) => { + let intent = match self.peer_streams.get_mut(peer_id) { + Some(peer_stream) => peer_stream.on_ready(event_set), - Request::ServerRequest(server_request) => { - let intent = self.server_stream.on_notify(&server_request); - self.process_server_intent(intent, event_loop); - } + None => unreachable!("Unknown peer {} is ready", peer_id), + }; + self.process_peer_intent(intent, token, event_loop); + } + } + } + + fn notify( + &mut self, + event_loop: &mut mio::deprecated::EventLoop, + request: Request, + ) { + match request { + Request::PeerConnect(peer_id, ip, port) => { + if let Err(err) = self.connect_to_peer(peer_id, ip, port, event_loop) { + error!( + "Cannot open peer connection {} to {}:{}: {}", + peer_id, ip, port, err + ); + self + .client_tx + .send(Response::PeerConnectionClosed(peer_id)) + .unwrap(); } + } + + Request::PeerMessage(peer_id, message) => { + let intent = match self.peer_streams.get_mut(peer_id) { + Some(peer_stream) => peer_stream.on_notify(&message), + None => { + error!( + "Cannot send peer message {:?}: unknown id {}", + message, peer_id + ); + return; + } + }; + self.process_peer_intent(intent, mio::Token(peer_id), event_loop); + } + + Request::ServerRequest(server_request) => { + let intent = self.server_stream.on_notify(&server_request); + self.process_server_intent(intent, event_loop); + } } + } } #[allow(deprecated)] pub type Sender = mio::deprecated::Sender; pub struct Agent { - #[allow(deprecated)] - event_loop: mio::deprecated::EventLoop, - handler: Handler, + #[allow(deprecated)] + event_loop: mio::deprecated::EventLoop, + handler: Handler, } impl Agent { - pub fn new( - client_tx: crossbeam_channel::Sender, - ) -> io::Result { - // Create the event loop. - #[allow(deprecated)] - let mut event_loop = mio::deprecated::EventLoop::new()?; - // Create the handler for the event loop and register the handler's - // sockets with the event loop. - let handler = Handler::new(client_tx, &mut event_loop)?; - - Ok(Agent { - event_loop: event_loop, - handler: handler, - }) - } - - pub fn channel(&self) -> Sender { - #[allow(deprecated)] - self.event_loop.channel() - } + pub fn new( + client_tx: crossbeam_channel::Sender, + ) -> io::Result { + // Create the event loop. + #[allow(deprecated)] + let mut event_loop = mio::deprecated::EventLoop::new()?; + // Create the handler for the event loop and register the handler's + // sockets with the event loop. + let handler = Handler::new(client_tx, &mut event_loop)?; + + Ok(Agent { + event_loop: event_loop, + handler: handler, + }) + } + + pub fn channel(&self) -> Sender { + #[allow(deprecated)] + self.event_loop.channel() + } - pub fn run(&mut self) -> io::Result<()> { - #[allow(deprecated)] - self.event_loop.run(&mut self.handler) - } + pub fn run(&mut self) -> io::Result<()> { + #[allow(deprecated)] + self.event_loop.run(&mut self.handler) + } } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 24dcfb7..843b448 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -19,6 +19,6 @@ pub use self::server::{ServerRequest, ServerResponse}; pub use self::stream::*; pub use self::user::{User, UserStatus}; pub use self::value_codec::{ - Decode, ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, - ValueEncodeError, ValueEncoder, + Decode, ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, + ValueEncodeError, ValueEncoder, }; diff --git a/src/proto/packet.rs b/src/proto/packet.rs index c57d12c..ddb1044 100644 --- a/src/proto/packet.rs +++ b/src/proto/packet.rs @@ -19,46 +19,46 @@ use super::constants::*; #[derive(Debug)] pub struct Packet { - /// The current read position in the byte buffer. - cursor: usize, - /// The underlying bytes. - bytes: Vec, + /// The current read position in the byte buffer. + cursor: usize, + /// The underlying bytes. + bytes: Vec, } impl io::Read for Packet { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let bytes_read = { - let mut slice = &self.bytes[self.cursor..]; - slice.read(buf)? - }; - self.cursor += bytes_read; - Ok(bytes_read) - } + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let bytes_read = { + let mut slice = &self.bytes[self.cursor..]; + slice.read(buf)? + }; + self.cursor += bytes_read; + Ok(bytes_read) + } } impl Packet { - /// Returns a readable packet struct from the wire representation of a - /// packet. - /// Assumes that the given vector is a valid length-prefixed packet. - fn from_wire(bytes: Vec) -> Self { - Packet { - cursor: U32_SIZE, - bytes: bytes, - } - } - - /// Provides the main way to read data out of a binary packet. - pub fn read_value(&mut self) -> Result - where - T: ReadFromPacket, - { - T::read_from_packet(self) - } - - /// Returns the number of unread bytes remaining in the packet. - pub fn bytes_remaining(&self) -> usize { - self.bytes.len() - self.cursor + /// Returns a readable packet struct from the wire representation of a + /// packet. + /// Assumes that the given vector is a valid length-prefixed packet. + fn from_wire(bytes: Vec) -> Self { + Packet { + cursor: U32_SIZE, + bytes: bytes, } + } + + /// Provides the main way to read data out of a binary packet. + pub fn read_value(&mut self) -> Result + where + T: ReadFromPacket, + { + T::read_from_packet(self) + } + + /// Returns the number of unread bytes remaining in the packet. + pub fn bytes_remaining(&self) -> usize { + self.bytes.len() - self.cursor + } } /*===================* @@ -67,45 +67,45 @@ impl Packet { #[derive(Debug)] pub struct MutPacket { - bytes: Vec, + bytes: Vec, } impl MutPacket { - /// Returns an empty packet with the given packet code. - pub fn new() -> Self { - // Leave space for the eventual size of the packet. - MutPacket { - bytes: vec![0; U32_SIZE], - } + /// Returns an empty packet with the given packet code. + pub fn new() -> Self { + // Leave space for the eventual size of the packet. + MutPacket { + bytes: vec![0; U32_SIZE], } - - /// Provides the main way to write data into a binary packet. - pub fn write_value(&mut self, val: &T) -> io::Result<()> - where - T: WriteToPacket, + } + + /// Provides the main way to write data into a binary packet. + pub fn write_value(&mut self, val: &T) -> io::Result<()> + where + T: WriteToPacket, + { + val.write_to_packet(self) + } + + /// Consumes the mutable packet and returns its wire representation. + pub fn into_bytes(mut self) -> Vec { + let length = (self.bytes.len() - U32_SIZE) as u32; { - val.write_to_packet(self) - } - - /// Consumes the mutable packet and returns its wire representation. - pub fn into_bytes(mut self) -> Vec { - let length = (self.bytes.len() - U32_SIZE) as u32; - { - let mut first_word = &mut self.bytes[..U32_SIZE]; - first_word.write_u32::(length).unwrap(); - } - self.bytes + let mut first_word = &mut self.bytes[..U32_SIZE]; + first_word.write_u32::(length).unwrap(); } + self.bytes + } } impl io::Write for MutPacket { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.bytes.write(buf) - } + fn write(&mut self, buf: &[u8]) -> io::Result { + self.bytes.write(buf) + } - fn flush(&mut self) -> io::Result<()> { - self.bytes.flush() - } + fn flush(&mut self) -> io::Result<()> { + self.bytes.flush() + } } /*===================* @@ -115,70 +115,68 @@ impl io::Write for MutPacket { /// This enum contains an error that arose when reading data out of a Packet. #[derive(Debug)] pub enum PacketReadError { - /// Attempted to read a boolean, but the value was not 0 nor 1. - InvalidBoolError(u8), - /// Attempted to read an unsigned 16-bit integer, but the value was too - /// large. - InvalidU16Error(u32), - /// Attempted to read a string, but a character was invalid. - InvalidStringError(Vec), - /// Attempted to read a user::Status, but the value was not a valid - /// representation of an enum variant. - InvalidUserStatusError(u32), - /// Encountered an I/O error while reading. - IOError(io::Error), + /// Attempted to read a boolean, but the value was not 0 nor 1. + InvalidBoolError(u8), + /// Attempted to read an unsigned 16-bit integer, but the value was too + /// large. + InvalidU16Error(u32), + /// Attempted to read a string, but a character was invalid. + InvalidStringError(Vec), + /// Attempted to read a user::Status, but the value was not a valid + /// representation of an enum variant. + InvalidUserStatusError(u32), + /// Encountered an I/O error while reading. + IOError(io::Error), } impl fmt::Display for PacketReadError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match *self { - PacketReadError::InvalidBoolError(n) => { - write!(fmt, "InvalidBoolError: {}", n) - } - PacketReadError::InvalidU16Error(n) => { - write!(fmt, "InvalidU16Error: {}", n) - } - PacketReadError::InvalidStringError(ref bytes) => { - write!(fmt, "InvalidStringError: {:?}", bytes) - } - PacketReadError::InvalidUserStatusError(n) => { - write!(fmt, "InvalidUserStatusError: {}", n) - } - PacketReadError::IOError(ref err) => { - write!(fmt, "IOError: {}", err) - } - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + PacketReadError::InvalidBoolError(n) => { + write!(fmt, "InvalidBoolError: {}", n) + } + PacketReadError::InvalidU16Error(n) => { + write!(fmt, "InvalidU16Error: {}", n) + } + PacketReadError::InvalidStringError(ref bytes) => { + write!(fmt, "InvalidStringError: {:?}", bytes) + } + PacketReadError::InvalidUserStatusError(n) => { + write!(fmt, "InvalidUserStatusError: {}", n) + } + PacketReadError::IOError(ref err) => { + write!(fmt, "IOError: {}", err) + } } + } } impl error::Error for PacketReadError { - fn description(&self) -> &str { - match *self { - PacketReadError::InvalidBoolError(_) => "InvalidBoolError", - PacketReadError::InvalidU16Error(_) => "InvalidU16Error", - PacketReadError::InvalidStringError(_) => "InvalidStringError", - PacketReadError::InvalidUserStatusError(_) => { - "InvalidUserStatusError" - } - PacketReadError::IOError(_) => "IOError", - } + fn description(&self) -> &str { + match *self { + PacketReadError::InvalidBoolError(_) => "InvalidBoolError", + PacketReadError::InvalidU16Error(_) => "InvalidU16Error", + PacketReadError::InvalidStringError(_) => "InvalidStringError", + PacketReadError::InvalidUserStatusError(_) => "InvalidUserStatusError", + PacketReadError::IOError(_) => "IOError", } - - fn cause(&self) -> Option<&dyn error::Error> { - match *self { - PacketReadError::InvalidBoolError(_) => None, - PacketReadError::InvalidU16Error(_) => None, - PacketReadError::InvalidStringError(_) => None, - PacketReadError::InvalidUserStatusError(_) => None, - PacketReadError::IOError(ref err) => Some(err), - } + } + + fn cause(&self) -> Option<&dyn error::Error> { + match *self { + PacketReadError::InvalidBoolError(_) => None, + PacketReadError::InvalidU16Error(_) => None, + PacketReadError::InvalidStringError(_) => None, + PacketReadError::InvalidUserStatusError(_) => None, + PacketReadError::IOError(ref err) => Some(err), } + } } impl From for PacketReadError { - fn from(err: io::Error) -> Self { - PacketReadError::IOError(err) - } + fn from(err: io::Error) -> Self { + PacketReadError::IOError(err) + } } /*==================* @@ -188,81 +186,81 @@ impl From for PacketReadError { /// This trait is implemented by types that can be deserialized from binary /// Packets. pub trait ReadFromPacket: Sized { - fn read_from_packet(_: &mut Packet) -> Result; + fn read_from_packet(_: &mut Packet) -> Result; } /// 32-bit integers are serialized in 4 bytes, little-endian. impl ReadFromPacket for u32 { - fn read_from_packet(packet: &mut Packet) -> Result { - Ok(packet.read_u32::()?) - } + fn read_from_packet(packet: &mut Packet) -> Result { + Ok(packet.read_u32::()?) + } } /// For convenience, usize's are deserialized as u32's then casted. impl ReadFromPacket for usize { - fn read_from_packet(packet: &mut Packet) -> Result { - Ok(u32::read_from_packet(packet)? as usize) - } + fn read_from_packet(packet: &mut Packet) -> Result { + Ok(u32::read_from_packet(packet)? as usize) + } } /// Booleans are serialized as single bytes, containing either 0 or 1. impl ReadFromPacket for bool { - fn read_from_packet(packet: &mut Packet) -> Result { - match packet.read_u8()? { - 0 => Ok(false), - 1 => Ok(true), - n => Err(PacketReadError::InvalidBoolError(n)), - } + fn read_from_packet(packet: &mut Packet) -> Result { + match packet.read_u8()? { + 0 => Ok(false), + 1 => Ok(true), + n => Err(PacketReadError::InvalidBoolError(n)), } + } } /// 16-bit integers are serialized as 32-bit integers. impl ReadFromPacket for u16 { - fn read_from_packet(packet: &mut Packet) -> Result { - let n = u32::read_from_packet(packet)?; - if n > MAX_PORT { - return Err(PacketReadError::InvalidU16Error(n)); - } - Ok(n as u16) + fn read_from_packet(packet: &mut Packet) -> Result { + let n = u32::read_from_packet(packet)?; + if n > MAX_PORT { + return Err(PacketReadError::InvalidU16Error(n)); } + Ok(n as u16) + } } /// IPv4 addresses are serialized directly as 32-bit integers. impl ReadFromPacket for net::Ipv4Addr { - fn read_from_packet(packet: &mut Packet) -> Result { - let ip = u32::read_from_packet(packet)?; - Ok(net::Ipv4Addr::from(ip)) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let ip = u32::read_from_packet(packet)?; + Ok(net::Ipv4Addr::from(ip)) + } } /// Strings are serialized as length-prefixed arrays of ISO-8859-1 encoded /// characters. impl ReadFromPacket for String { - fn read_from_packet(packet: &mut Packet) -> Result { - let len = usize::read_from_packet(packet)?; + fn read_from_packet(packet: &mut Packet) -> Result { + let len = usize::read_from_packet(packet)?; - let mut buffer = vec![0; len]; - packet.read_exact(&mut buffer)?; + let mut buffer = vec![0; len]; + packet.read_exact(&mut buffer)?; - match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) { - Ok(string) => Ok(string), - Err(_) => Err(PacketReadError::InvalidStringError(buffer)), - } + match ISO_8859_1.decode(&buffer, DecoderTrap::Strict) { + Ok(string) => Ok(string), + Err(_) => Err(PacketReadError::InvalidStringError(buffer)), } + } } /// Vectors are serialized as length-prefixed arrays of values. impl ReadFromPacket for Vec { - fn read_from_packet(packet: &mut Packet) -> Result { - let len = usize::read_from_packet(packet)?; + fn read_from_packet(packet: &mut Packet) -> Result { + let len = usize::read_from_packet(packet)?; - let mut vec = Vec::new(); - for _ in 0..len { - vec.push(T::read_from_packet(packet)?); - } - - Ok(vec) + let mut vec = Vec::new(); + for _ in 0..len { + vec.push(T::read_from_packet(packet)?); } + + Ok(vec) + } } /*=================* @@ -272,55 +270,55 @@ impl ReadFromPacket for Vec { /// This trait is implemented by types that can be serialized to a binary /// MutPacket. pub trait WriteToPacket { - fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>; + fn write_to_packet(&self, _: &mut MutPacket) -> io::Result<()>; } /// 32-bit integers are serialized in 4 bytes, little-endian. impl WriteToPacket for u32 { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_u32::(*self) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_u32::(*self) + } } /// Booleans are serialized as single bytes, containing either 0 or 1. impl WriteToPacket for bool { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_u8(*self as u8)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_u8(*self as u8)?; + Ok(()) + } } /// 16-bit integers are serialized as 32-bit integers. impl WriteToPacket for u16 { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - (*self as u32).write_to_packet(packet) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + (*self as u32).write_to_packet(packet) + } } /// Strings are serialized as a length-prefixed array of ISO-8859-1 encoded /// characters. impl WriteToPacket for str { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - // Encode the string. - let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) { - Ok(bytes) => bytes, - Err(_) => { - let copy = self.to_string(); - return Err(io::Error::new(io::ErrorKind::Other, copy)); - } - }; - // Then write the bytes to the packet. - (bytes.len() as u32).write_to_packet(packet)?; - packet.write(&bytes)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + // Encode the string. + let bytes = match ISO_8859_1.encode(self, EncoderTrap::Strict) { + Ok(bytes) => bytes, + Err(_) => { + let copy = self.to_string(); + return Err(io::Error::new(io::ErrorKind::Other, copy)); + } + }; + // Then write the bytes to the packet. + (bytes.len() as u32).write_to_packet(packet)?; + packet.write(&bytes)?; + Ok(()) + } } /// Deref coercion does not happen for trait methods apparently. impl WriteToPacket for String { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - (self as &str).write_to_packet(packet) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + (self as &str).write_to_packet(packet) + } } /*========* @@ -330,87 +328,86 @@ impl WriteToPacket for String { /// This enum defines the possible states of a packet parser state machine. #[derive(Debug, Clone, Copy)] enum State { - /// The parser is waiting to read enough bytes to determine the - /// length of the following packet. - ReadingLength, - /// The parser is waiting to read enough bytes to form the entire - /// packet. - ReadingPacket, + /// The parser is waiting to read enough bytes to determine the + /// length of the following packet. + ReadingLength, + /// The parser is waiting to read enough bytes to form the entire + /// packet. + ReadingPacket, } #[derive(Debug)] pub struct Parser { - state: State, - num_bytes_left: usize, - buffer: Vec, + state: State, + num_bytes_left: usize, + buffer: Vec, } impl Parser { - pub fn new() -> Self { - Parser { - state: State::ReadingLength, - num_bytes_left: U32_SIZE, - buffer: vec![0; U32_SIZE], - } + pub fn new() -> Self { + Parser { + state: State::ReadingLength, + num_bytes_left: U32_SIZE, + buffer: vec![0; U32_SIZE], + } + } + + /// Attemps to read a packet in a non-blocking fashion. + /// If enough bytes can be read from the given byte stream to form a + /// complete packet `p`, returns `Ok(Some(p))`. + /// If not enough bytes are available, returns `Ok(None)`. + /// If an I/O error `e` arises when trying to read the underlying stream, + /// returns `Err(e)`. + /// Note: as long as this function returns `Ok(Some(p))`, the caller is + /// responsible for calling it once more to ensure that all packets are + /// read as soon as possible. + pub fn try_read(&mut self, stream: &mut U) -> io::Result> + where + U: io::Read, + { + // Try to read as many bytes as we currently need from the underlying + // byte stream. + let offset = self.buffer.len() - self.num_bytes_left; + + #[allow(deprecated)] + match stream.try_read(&mut self.buffer[offset..])? { + None => (), + + Some(num_bytes_read) => { + self.num_bytes_left -= num_bytes_read; + } } - /// Attemps to read a packet in a non-blocking fashion. - /// If enough bytes can be read from the given byte stream to form a - /// complete packet `p`, returns `Ok(Some(p))`. - /// If not enough bytes are available, returns `Ok(None)`. - /// If an I/O error `e` arises when trying to read the underlying stream, - /// returns `Err(e)`. - /// Note: as long as this function returns `Ok(Some(p))`, the caller is - /// responsible for calling it once more to ensure that all packets are - /// read as soon as possible. - pub fn try_read(&mut self, stream: &mut U) -> io::Result> - where - U: io::Read, - { - // Try to read as many bytes as we currently need from the underlying - // byte stream. - let offset = self.buffer.len() - self.num_bytes_left; - - #[allow(deprecated)] - match stream.try_read(&mut self.buffer[offset..])? { - None => (), - - Some(num_bytes_read) => { - self.num_bytes_left -= num_bytes_read; - } - } - - // If we haven't read enough bytes, return. - if self.num_bytes_left > 0 { - return Ok(None); - } - - // Otherwise, the behavior depends on what state we were in. - match self.state { - State::ReadingLength => { - // If we have finished reading the length prefix, then - // deserialize it, switch states and try to read the packet - // bytes. - let message_len = - LittleEndian::read_u32(&mut self.buffer) as usize; - if message_len > MAX_MESSAGE_SIZE { - unimplemented!(); - }; - self.state = State::ReadingPacket; - self.num_bytes_left = message_len; - self.buffer.resize(message_len + U32_SIZE, 0); - self.try_read(stream) - } - - State::ReadingPacket => { - // If we have finished reading the packet, swap the full buffer - // out and return the packet made from the full buffer. - self.state = State::ReadingLength; - self.num_bytes_left = U32_SIZE; - let new_buffer = vec![0; U32_SIZE]; - let old_buffer = mem::replace(&mut self.buffer, new_buffer); - Ok(Some(Packet::from_wire(old_buffer))) - } - } + // If we haven't read enough bytes, return. + if self.num_bytes_left > 0 { + return Ok(None); + } + + // Otherwise, the behavior depends on what state we were in. + match self.state { + State::ReadingLength => { + // If we have finished reading the length prefix, then + // deserialize it, switch states and try to read the packet + // bytes. + let message_len = LittleEndian::read_u32(&mut self.buffer) as usize; + if message_len > MAX_MESSAGE_SIZE { + unimplemented!(); + }; + self.state = State::ReadingPacket; + self.num_bytes_left = message_len; + self.buffer.resize(message_len + U32_SIZE, 0); + self.try_read(stream) + } + + State::ReadingPacket => { + // If we have finished reading the packet, swap the full buffer + // out and return the packet made from the full buffer. + self.state = State::ReadingLength; + self.num_bytes_left = U32_SIZE; + let new_buffer = vec![0; U32_SIZE]; + let old_buffer = mem::replace(&mut self.buffer, new_buffer); + Ok(Some(Packet::from_wire(old_buffer))) + } } + } } diff --git a/src/proto/peer/message.rs b/src/proto/peer/message.rs index cf269c6..51520d5 100644 --- a/src/proto/peer/message.rs +++ b/src/proto/peer/message.rs @@ -2,9 +2,9 @@ use std::io; use crate::proto::peer::constants::*; use crate::proto::{ - MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode, - ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, - ValueEncoder, WriteToPacket, + MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode, + ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder, + WriteToPacket, }; /*=========* @@ -14,195 +14,183 @@ use crate::proto::{ /// This enum contains all the possible messages peers can exchange. #[derive(Clone, Debug, Eq, PartialEq)] pub enum Message { - PierceFirewall(u32), - PeerInit(PeerInit), - Unknown(u32), + PierceFirewall(u32), + PeerInit(PeerInit), + Unknown(u32), } impl ReadFromPacket for Message { - fn read_from_packet(packet: &mut Packet) -> Result { - let code: u32 = packet.read_value()?; - let message = match code { - CODE_PIERCE_FIREWALL => { - Message::PierceFirewall(packet.read_value()?) - } - - CODE_PEER_INIT => Message::PeerInit(packet.read_value()?), - - code => Message::Unknown(code), - }; - - let bytes_remaining = packet.bytes_remaining(); - if bytes_remaining > 0 { - warn!( - "Peer message with code {} contains {} extra bytes", - code, bytes_remaining - ) - } - - Ok(message) + fn read_from_packet(packet: &mut Packet) -> Result { + let code: u32 = packet.read_value()?; + let message = match code { + CODE_PIERCE_FIREWALL => Message::PierceFirewall(packet.read_value()?), + + CODE_PEER_INIT => Message::PeerInit(packet.read_value()?), + + code => Message::Unknown(code), + }; + + let bytes_remaining = packet.bytes_remaining(); + if bytes_remaining > 0 { + warn!( + "Peer message with code {} contains {} extra bytes", + code, bytes_remaining + ) } + + Ok(message) + } } impl ValueDecode for Message { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let position = decoder.position(); - let code: u32 = decoder.decode()?; - let message = match code { - CODE_PIERCE_FIREWALL => { - let val = decoder.decode()?; - Message::PierceFirewall(val) - } - CODE_PEER_INIT => { - let peer_init = decoder.decode()?; - Message::PeerInit(peer_init) - } - _ => { - return Err(ValueDecodeError::InvalidData { - value_name: "peer message code".to_string(), - cause: format!("unknown value {}", code), - position: position, - }) - } - }; - Ok(message) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let position = decoder.position(); + let code: u32 = decoder.decode()?; + let message = match code { + CODE_PIERCE_FIREWALL => { + let val = decoder.decode()?; + Message::PierceFirewall(val) + } + CODE_PEER_INIT => { + let peer_init = decoder.decode()?; + Message::PeerInit(peer_init) + } + _ => { + return Err(ValueDecodeError::InvalidData { + value_name: "peer message code".to_string(), + cause: format!("unknown value {}", code), + position: position, + }) + } + }; + Ok(message) + } } impl ValueEncode for Message { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - match *self { - Message::PierceFirewall(token) => { - encoder.encode_u32(CODE_PIERCE_FIREWALL)?; - encoder.encode_u32(token)?; - } - Message::PeerInit(ref request) => { - encoder.encode_u32(CODE_PEER_INIT)?; - request.encode(encoder)?; - } - Message::Unknown(_) => unreachable!(), - } - Ok(()) + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + match *self { + Message::PierceFirewall(token) => { + encoder.encode_u32(CODE_PIERCE_FIREWALL)?; + encoder.encode_u32(token)?; + } + Message::PeerInit(ref request) => { + encoder.encode_u32(CODE_PEER_INIT)?; + request.encode(encoder)?; + } + Message::Unknown(_) => unreachable!(), } + Ok(()) + } } impl WriteToPacket for Message { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - match *self { - Message::PierceFirewall(ref token) => { - packet.write_value(&CODE_PIERCE_FIREWALL)?; - packet.write_value(token)?; - } - - Message::PeerInit(ref request) => { - packet.write_value(&CODE_PEER_INIT)?; - packet.write_value(request)?; - } - - Message::Unknown(_) => unreachable!(), - } - Ok(()) + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + match *self { + Message::PierceFirewall(ref token) => { + packet.write_value(&CODE_PIERCE_FIREWALL)?; + packet.write_value(token)?; + } + + Message::PeerInit(ref request) => { + packet.write_value(&CODE_PEER_INIT)?; + packet.write_value(request)?; + } + + Message::Unknown(_) => unreachable!(), } + Ok(()) + } } #[derive(Clone, Debug, Eq, PartialEq)] pub struct PeerInit { - pub user_name: String, - pub connection_type: String, - pub token: u32, + pub user_name: String, + pub connection_type: String, + pub token: u32, } impl ReadFromPacket for PeerInit { - fn read_from_packet(packet: &mut Packet) -> Result { - let user_name = packet.read_value()?; - let connection_type = packet.read_value()?; - let token = packet.read_value()?; - Ok(PeerInit { - user_name, - connection_type, - token, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let user_name = packet.read_value()?; + let connection_type = packet.read_value()?; + let token = packet.read_value()?; + Ok(PeerInit { + user_name, + connection_type, + token, + }) + } } impl WriteToPacket for PeerInit { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.user_name)?; - packet.write_value(&self.connection_type)?; - packet.write_value(&self.token)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.user_name)?; + packet.write_value(&self.connection_type)?; + packet.write_value(&self.token)?; + Ok(()) + } } impl ValueEncode for PeerInit { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.user_name)?; - encoder.encode_string(&self.connection_type)?; - encoder.encode_u32(self.token)?; - Ok(()) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.user_name)?; + encoder.encode_string(&self.connection_type)?; + encoder.encode_u32(self.token)?; + Ok(()) + } } impl ValueDecode for PeerInit { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let user_name = decoder.decode()?; - let connection_type = decoder.decode()?; - let token = decoder.decode()?; - Ok(PeerInit { - user_name, - connection_type, - token, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let user_name = decoder.decode()?; + let connection_type = decoder.decode()?; + let token = decoder.decode()?; + Ok(PeerInit { + user_name, + connection_type, + token, + }) + } } #[cfg(test)] mod tests { - use bytes::BytesMut; - - use crate::proto::value_codec::tests::roundtrip; - use crate::proto::{ValueDecodeError, ValueDecoder}; - - use super::*; - - #[test] - fn invalid_code() { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&[57, 5, 0, 0]); - - let result = ValueDecoder::new(&bytes).decode::(); - - assert_eq!( - result, - Err(ValueDecodeError::InvalidData { - value_name: "peer message code".to_string(), - cause: "unknown value 1337".to_string(), - position: 0, - }) - ); - } - - #[test] - fn roundtrip_pierce_firewall() { - roundtrip(Message::PierceFirewall(1337)) - } - - #[test] - fn roundtrip_peer_init() { - roundtrip(Message::PeerInit(PeerInit { - user_name: "alice".to_string(), - connection_type: "P".to_string(), - token: 1337, - })); - } + use bytes::BytesMut; + + use crate::proto::value_codec::tests::roundtrip; + use crate::proto::{ValueDecodeError, ValueDecoder}; + + use super::*; + + #[test] + fn invalid_code() { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&[57, 5, 0, 0]); + + let result = ValueDecoder::new(&bytes).decode::(); + + assert_eq!( + result, + Err(ValueDecodeError::InvalidData { + value_name: "peer message code".to_string(), + cause: "unknown value 1337".to_string(), + position: 0, + }) + ); + } + + #[test] + fn roundtrip_pierce_firewall() { + roundtrip(Message::PierceFirewall(1337)) + } + + #[test] + fn roundtrip_peer_init() { + roundtrip(Message::PeerInit(PeerInit { + user_name: "alice".to_string(), + connection_type: "P".to_string(), + token: 1337, + })); + } } diff --git a/src/proto/prefix.rs b/src/proto/prefix.rs index c2a756b..00b40ca 100644 --- a/src/proto/prefix.rs +++ b/src/proto/prefix.rs @@ -11,111 +11,111 @@ use crate::proto::u32::{encode_u32, U32_BYTE_LEN}; /// know the length ahead of encoding time. #[derive(Debug)] pub struct Prefixer<'a> { - /// The prefix buffer. - /// - /// The length of the suffix buffer is written to the end of this buffer - /// when the prefixer is finalized. - /// - /// Contains any bytes with which this prefixer was constructed. - prefix: &'a mut BytesMut, - - /// The suffix buffer. - /// - /// This is the buffer into which data is written before finalization. - suffix: BytesMut, + /// The prefix buffer. + /// + /// The length of the suffix buffer is written to the end of this buffer + /// when the prefixer is finalized. + /// + /// Contains any bytes with which this prefixer was constructed. + prefix: &'a mut BytesMut, + + /// The suffix buffer. + /// + /// This is the buffer into which data is written before finalization. + suffix: BytesMut, } impl Prefixer<'_> { - /// Constructs a prefixer for easily appending a length prefixed value to - /// the given buffer. - pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> { - // Reserve some space fot the prefix, but don't write it yet. - buffer.reserve(U32_BYTE_LEN); - - // Split off the suffix, into which bytes will be written. - let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN); - - Prefixer { - prefix: buffer, - suffix: suffix, - } - } - - /// Returns a reference to the buffer into which data is written. - pub fn suffix(&self) -> &BytesMut { - &self.suffix - } - - /// Returns a mutable reference to a buffer into which data can be written. - pub fn suffix_mut(&mut self) -> &mut BytesMut { - &mut self.suffix - } - - /// Returns a buffer containing the original data passed at construction - /// time, to which a length-prefixed value is appended. The value itself is - /// the data written into the buffer returned by `get_mut()`. - /// - /// Returns `Ok(length)` if successful, in which case the length of the - /// suffix is `length`. - /// - /// Returns `Err(self)` if the length of the suffix is too large to store as - /// a prefix. - pub fn finalize(self) -> Result { - // Check that the suffix's length is not too large. - let length = self.suffix.len(); - let length_u32 = match u32::try_from(length) { - Ok(value) => value, - Err(_) => return Err(self), - }; - - // Write the prefix. - self.prefix.extend_from_slice(&encode_u32(length_u32)); - - // Join the prefix and suffix back again. Because `self.prefix` is - // private, we are sure that this is O(1). - self.prefix.unsplit(self.suffix); - - Ok(length_u32) + /// Constructs a prefixer for easily appending a length prefixed value to + /// the given buffer. + pub fn new<'a>(buffer: &'a mut BytesMut) -> Prefixer<'a> { + // Reserve some space fot the prefix, but don't write it yet. + buffer.reserve(U32_BYTE_LEN); + + // Split off the suffix, into which bytes will be written. + let suffix = buffer.split_off(buffer.len() + U32_BYTE_LEN); + + Prefixer { + prefix: buffer, + suffix: suffix, } + } + + /// Returns a reference to the buffer into which data is written. + pub fn suffix(&self) -> &BytesMut { + &self.suffix + } + + /// Returns a mutable reference to a buffer into which data can be written. + pub fn suffix_mut(&mut self) -> &mut BytesMut { + &mut self.suffix + } + + /// Returns a buffer containing the original data passed at construction + /// time, to which a length-prefixed value is appended. The value itself is + /// the data written into the buffer returned by `get_mut()`. + /// + /// Returns `Ok(length)` if successful, in which case the length of the + /// suffix is `length`. + /// + /// Returns `Err(self)` if the length of the suffix is too large to store as + /// a prefix. + pub fn finalize(self) -> Result { + // Check that the suffix's length is not too large. + let length = self.suffix.len(); + let length_u32 = match u32::try_from(length) { + Ok(value) => value, + Err(_) => return Err(self), + }; + + // Write the prefix. + self.prefix.extend_from_slice(&encode_u32(length_u32)); + + // Join the prefix and suffix back again. Because `self.prefix` is + // private, we are sure that this is O(1). + self.prefix.unsplit(self.suffix); + + Ok(length_u32) + } } #[cfg(test)] mod tests { - use super::Prefixer; + use super::Prefixer; - use std::convert::TryInto; + use std::convert::TryInto; - use bytes::{BufMut, BytesMut}; + use bytes::{BufMut, BytesMut}; - use crate::proto::u32::{decode_u32, U32_BYTE_LEN}; + use crate::proto::u32::{decode_u32, U32_BYTE_LEN}; - #[test] - fn finalize_empty() { - let mut buffer = BytesMut::new(); - buffer.put_u8(13); + #[test] + fn finalize_empty() { + let mut buffer = BytesMut::new(); + buffer.put_u8(13); - Prefixer::new(&mut buffer).finalize().unwrap(); + Prefixer::new(&mut buffer).finalize().unwrap(); - assert_eq!(buffer.len(), U32_BYTE_LEN + 1); - let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap(); - assert_eq!(decode_u32(array), 0); - } + assert_eq!(buffer.len(), U32_BYTE_LEN + 1); + let array: [u8; U32_BYTE_LEN] = buffer[1..].try_into().unwrap(); + assert_eq!(decode_u32(array), 0); + } - #[test] - fn finalize_ok() { - let mut buffer = BytesMut::new(); - buffer.put_u8(13); + #[test] + fn finalize_ok() { + let mut buffer = BytesMut::new(); + buffer.put_u8(13); - let mut prefixer = Prefixer::new(&mut buffer); + let mut prefixer = Prefixer::new(&mut buffer); - prefixer.suffix_mut().extend_from_slice(&[0; 42]); + prefixer.suffix_mut().extend_from_slice(&[0; 42]); - prefixer.finalize().unwrap(); + prefixer.finalize().unwrap(); - // 1 junk prefix byte, length prefix, 42 bytes of value. - assert_eq!(buffer.len(), U32_BYTE_LEN + 43); - let prefix = &buffer[1..U32_BYTE_LEN + 1]; - let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap(); - assert_eq!(decode_u32(array), 42); - } + // 1 junk prefix byte, length prefix, 42 bytes of value. + assert_eq!(buffer.len(), U32_BYTE_LEN + 43); + let prefix = &buffer[1..U32_BYTE_LEN + 1]; + let array: [u8; U32_BYTE_LEN] = prefix.try_into().unwrap(); + assert_eq!(decode_u32(array), 42); + } } diff --git a/src/proto/server/request.rs b/src/proto/server/request.rs index f43145e..38fba4d 100644 --- a/src/proto/server/request.rs +++ b/src/proto/server/request.rs @@ -6,8 +6,8 @@ use crypto::md5::Md5; use crate::proto::packet::{MutPacket, WriteToPacket}; use crate::proto::server::constants::*; use crate::proto::{ - ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, - ValueEncoder, + ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, + ValueEncoder, }; /* ------- * @@ -15,9 +15,9 @@ use crate::proto::{ * ------- */ fn md5_str(string: &str) -> String { - let mut hasher = Md5::new(); - hasher.input_str(string); - hasher.result_str() + let mut hasher = Md5::new(); + hasher.input_str(string); + hasher.result_str() } /*================* @@ -26,192 +26,187 @@ fn md5_str(string: &str) -> String { #[derive(Debug, Eq, PartialEq)] pub enum ServerRequest { - CannotConnectRequest(CannotConnectRequest), - ConnectToPeerRequest(ConnectToPeerRequest), - FileSearchRequest(FileSearchRequest), - LoginRequest(LoginRequest), - PeerAddressRequest(PeerAddressRequest), - RoomJoinRequest(RoomJoinRequest), - RoomLeaveRequest(RoomLeaveRequest), - RoomListRequest, - RoomMessageRequest(RoomMessageRequest), - SetListenPortRequest(SetListenPortRequest), - UserStatusRequest(UserStatusRequest), + CannotConnectRequest(CannotConnectRequest), + ConnectToPeerRequest(ConnectToPeerRequest), + FileSearchRequest(FileSearchRequest), + LoginRequest(LoginRequest), + PeerAddressRequest(PeerAddressRequest), + RoomJoinRequest(RoomJoinRequest), + RoomLeaveRequest(RoomLeaveRequest), + RoomListRequest, + RoomMessageRequest(RoomMessageRequest), + SetListenPortRequest(SetListenPortRequest), + UserStatusRequest(UserStatusRequest), } impl WriteToPacket for ServerRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - match *self { - ServerRequest::CannotConnectRequest(ref request) => { - packet.write_value(&CODE_CANNOT_CONNECT)?; - packet.write_value(request)?; - } - - ServerRequest::ConnectToPeerRequest(ref request) => { - packet.write_value(&CODE_CONNECT_TO_PEER)?; - packet.write_value(request)?; - } - - ServerRequest::FileSearchRequest(ref request) => { - packet.write_value(&CODE_FILE_SEARCH)?; - packet.write_value(request)?; - } - - ServerRequest::LoginRequest(ref request) => { - packet.write_value(&CODE_LOGIN)?; - packet.write_value(request)?; - } - - ServerRequest::PeerAddressRequest(ref request) => { - packet.write_value(&CODE_PEER_ADDRESS)?; - packet.write_value(request)?; - } - - ServerRequest::RoomJoinRequest(ref request) => { - packet.write_value(&CODE_ROOM_JOIN)?; - packet.write_value(request)?; - } - - ServerRequest::RoomLeaveRequest(ref request) => { - packet.write_value(&CODE_ROOM_LEAVE)?; - packet.write_value(request)?; - } - - ServerRequest::RoomListRequest => { - packet.write_value(&CODE_ROOM_LIST)?; - } - - ServerRequest::RoomMessageRequest(ref request) => { - packet.write_value(&CODE_ROOM_MESSAGE)?; - packet.write_value(request)?; - } - - ServerRequest::SetListenPortRequest(ref request) => { - packet.write_value(&CODE_SET_LISTEN_PORT)?; - packet.write_value(request)?; - } - - ServerRequest::UserStatusRequest(ref request) => { - packet.write_value(&CODE_USER_STATUS)?; - packet.write_value(request)?; - } - } - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + match *self { + ServerRequest::CannotConnectRequest(ref request) => { + packet.write_value(&CODE_CANNOT_CONNECT)?; + packet.write_value(request)?; + } + + ServerRequest::ConnectToPeerRequest(ref request) => { + packet.write_value(&CODE_CONNECT_TO_PEER)?; + packet.write_value(request)?; + } + + ServerRequest::FileSearchRequest(ref request) => { + packet.write_value(&CODE_FILE_SEARCH)?; + packet.write_value(request)?; + } + + ServerRequest::LoginRequest(ref request) => { + packet.write_value(&CODE_LOGIN)?; + packet.write_value(request)?; + } + + ServerRequest::PeerAddressRequest(ref request) => { + packet.write_value(&CODE_PEER_ADDRESS)?; + packet.write_value(request)?; + } + + ServerRequest::RoomJoinRequest(ref request) => { + packet.write_value(&CODE_ROOM_JOIN)?; + packet.write_value(request)?; + } + + ServerRequest::RoomLeaveRequest(ref request) => { + packet.write_value(&CODE_ROOM_LEAVE)?; + packet.write_value(request)?; + } + + ServerRequest::RoomListRequest => { + packet.write_value(&CODE_ROOM_LIST)?; + } + + ServerRequest::RoomMessageRequest(ref request) => { + packet.write_value(&CODE_ROOM_MESSAGE)?; + packet.write_value(request)?; + } + + ServerRequest::SetListenPortRequest(ref request) => { + packet.write_value(&CODE_SET_LISTEN_PORT)?; + packet.write_value(request)?; + } + + ServerRequest::UserStatusRequest(ref request) => { + packet.write_value(&CODE_USER_STATUS)?; + packet.write_value(request)?; + } + } + Ok(()) + } } impl ValueEncode for ServerRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - match *self { - ServerRequest::CannotConnectRequest(ref request) => { - encoder.encode_u32(CODE_CANNOT_CONNECT)?; - request.encode(encoder)?; - } - ServerRequest::ConnectToPeerRequest(ref request) => { - encoder.encode_u32(CODE_CONNECT_TO_PEER)?; - request.encode(encoder)?; - } - ServerRequest::FileSearchRequest(ref request) => { - encoder.encode_u32(CODE_FILE_SEARCH)?; - request.encode(encoder)?; - } - ServerRequest::LoginRequest(ref request) => { - encoder.encode_u32(CODE_LOGIN)?; - request.encode(encoder)?; - } - ServerRequest::PeerAddressRequest(ref request) => { - encoder.encode_u32(CODE_PEER_ADDRESS)?; - request.encode(encoder)?; - } - ServerRequest::RoomJoinRequest(ref request) => { - encoder.encode_u32(CODE_ROOM_JOIN)?; - request.encode(encoder)?; - } - ServerRequest::RoomLeaveRequest(ref request) => { - encoder.encode_u32(CODE_ROOM_LEAVE)?; - request.encode(encoder)?; - } - ServerRequest::RoomListRequest => { - encoder.encode_u32(CODE_ROOM_LIST)?; - } - ServerRequest::RoomMessageRequest(ref request) => { - encoder.encode_u32(CODE_ROOM_MESSAGE)?; - request.encode(encoder)?; - } - ServerRequest::SetListenPortRequest(ref request) => { - encoder.encode_u32(CODE_SET_LISTEN_PORT)?; - request.encode(encoder)?; - } - ServerRequest::UserStatusRequest(ref request) => { - encoder.encode_u32(CODE_USER_STATUS)?; - request.encode(encoder)?; - } - } - Ok(()) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + match *self { + ServerRequest::CannotConnectRequest(ref request) => { + encoder.encode_u32(CODE_CANNOT_CONNECT)?; + request.encode(encoder)?; + } + ServerRequest::ConnectToPeerRequest(ref request) => { + encoder.encode_u32(CODE_CONNECT_TO_PEER)?; + request.encode(encoder)?; + } + ServerRequest::FileSearchRequest(ref request) => { + encoder.encode_u32(CODE_FILE_SEARCH)?; + request.encode(encoder)?; + } + ServerRequest::LoginRequest(ref request) => { + encoder.encode_u32(CODE_LOGIN)?; + request.encode(encoder)?; + } + ServerRequest::PeerAddressRequest(ref request) => { + encoder.encode_u32(CODE_PEER_ADDRESS)?; + request.encode(encoder)?; + } + ServerRequest::RoomJoinRequest(ref request) => { + encoder.encode_u32(CODE_ROOM_JOIN)?; + request.encode(encoder)?; + } + ServerRequest::RoomLeaveRequest(ref request) => { + encoder.encode_u32(CODE_ROOM_LEAVE)?; + request.encode(encoder)?; + } + ServerRequest::RoomListRequest => { + encoder.encode_u32(CODE_ROOM_LIST)?; + } + ServerRequest::RoomMessageRequest(ref request) => { + encoder.encode_u32(CODE_ROOM_MESSAGE)?; + request.encode(encoder)?; + } + ServerRequest::SetListenPortRequest(ref request) => { + encoder.encode_u32(CODE_SET_LISTEN_PORT)?; + request.encode(encoder)?; + } + ServerRequest::UserStatusRequest(ref request) => { + encoder.encode_u32(CODE_USER_STATUS)?; + request.encode(encoder)?; + } + } + Ok(()) + } } impl ValueDecode for ServerRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let position = decoder.position(); - let code: u32 = decoder.decode()?; - let request = match code { - CODE_CANNOT_CONNECT => { - let request = decoder.decode()?; - ServerRequest::CannotConnectRequest(request) - } - CODE_CONNECT_TO_PEER => { - let request = decoder.decode()?; - ServerRequest::ConnectToPeerRequest(request) - } - CODE_FILE_SEARCH => { - let request = decoder.decode()?; - ServerRequest::FileSearchRequest(request) - } - CODE_LOGIN => { - let request = decoder.decode()?; - ServerRequest::LoginRequest(request) - } - CODE_PEER_ADDRESS => { - let request = decoder.decode()?; - ServerRequest::PeerAddressRequest(request) - } - CODE_ROOM_JOIN => { - let request = decoder.decode()?; - ServerRequest::RoomJoinRequest(request) - } - CODE_ROOM_LEAVE => { - let request = decoder.decode()?; - ServerRequest::RoomLeaveRequest(request) - } - CODE_ROOM_LIST => ServerRequest::RoomListRequest, - CODE_ROOM_MESSAGE => { - let request = decoder.decode()?; - ServerRequest::RoomMessageRequest(request) - } - CODE_SET_LISTEN_PORT => { - let request = decoder.decode()?; - ServerRequest::SetListenPortRequest(request) - } - CODE_USER_STATUS => { - let request = decoder.decode()?; - ServerRequest::UserStatusRequest(request) - } - _ => { - return Err(ValueDecodeError::InvalidData { - value_name: "server request code".to_string(), - cause: format!("unknown value {}", code), - position: position, - }) - } - }; - Ok(request) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let position = decoder.position(); + let code: u32 = decoder.decode()?; + let request = match code { + CODE_CANNOT_CONNECT => { + let request = decoder.decode()?; + ServerRequest::CannotConnectRequest(request) + } + CODE_CONNECT_TO_PEER => { + let request = decoder.decode()?; + ServerRequest::ConnectToPeerRequest(request) + } + CODE_FILE_SEARCH => { + let request = decoder.decode()?; + ServerRequest::FileSearchRequest(request) + } + CODE_LOGIN => { + let request = decoder.decode()?; + ServerRequest::LoginRequest(request) + } + CODE_PEER_ADDRESS => { + let request = decoder.decode()?; + ServerRequest::PeerAddressRequest(request) + } + CODE_ROOM_JOIN => { + let request = decoder.decode()?; + ServerRequest::RoomJoinRequest(request) + } + CODE_ROOM_LEAVE => { + let request = decoder.decode()?; + ServerRequest::RoomLeaveRequest(request) + } + CODE_ROOM_LIST => ServerRequest::RoomListRequest, + CODE_ROOM_MESSAGE => { + let request = decoder.decode()?; + ServerRequest::RoomMessageRequest(request) + } + CODE_SET_LISTEN_PORT => { + let request = decoder.decode()?; + ServerRequest::SetListenPortRequest(request) + } + CODE_USER_STATUS => { + let request = decoder.decode()?; + ServerRequest::UserStatusRequest(request) + } + _ => { + return Err(ValueDecodeError::InvalidData { + value_name: "server request code".to_string(), + cause: format!("unknown value {}", code), + position: position, + }) + } + }; + Ok(request) + } } /*================* @@ -220,36 +215,31 @@ impl ValueDecode for ServerRequest { #[derive(Debug, Eq, PartialEq)] pub struct CannotConnectRequest { - pub token: u32, - pub user_name: String, + pub token: u32, + pub user_name: String, } impl WriteToPacket for CannotConnectRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.token)?; - packet.write_value(&self.user_name)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.token)?; + packet.write_value(&self.user_name)?; + Ok(()) + } } impl ValueEncode for CannotConnectRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.token)?; - encoder.encode_string(&self.user_name) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.token)?; + encoder.encode_string(&self.user_name) + } } impl ValueDecode for CannotConnectRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let token = decoder.decode()?; - let user_name = decoder.decode()?; - Ok(CannotConnectRequest { token, user_name }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let token = decoder.decode()?; + let user_name = decoder.decode()?; + Ok(CannotConnectRequest { token, user_name }) + } } /*=================* @@ -258,44 +248,39 @@ impl ValueDecode for CannotConnectRequest { #[derive(Debug, Eq, PartialEq)] pub struct ConnectToPeerRequest { - pub token: u32, - pub user_name: String, - pub connection_type: String, + pub token: u32, + pub user_name: String, + pub connection_type: String, } impl WriteToPacket for ConnectToPeerRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.token)?; - packet.write_value(&self.user_name)?; - packet.write_value(&self.connection_type)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.token)?; + packet.write_value(&self.user_name)?; + packet.write_value(&self.connection_type)?; + Ok(()) + } } impl ValueEncode for ConnectToPeerRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.token)?; - encoder.encode_string(&self.user_name)?; - encoder.encode_string(&self.connection_type) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.token)?; + encoder.encode_string(&self.user_name)?; + encoder.encode_string(&self.connection_type) + } } impl ValueDecode for ConnectToPeerRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let token = decoder.decode()?; - let user_name = decoder.decode()?; - let connection_type = decoder.decode()?; - Ok(ConnectToPeerRequest { - token, - user_name, - connection_type, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let token = decoder.decode()?; + let user_name = decoder.decode()?; + let connection_type = decoder.decode()?; + Ok(ConnectToPeerRequest { + token, + user_name, + connection_type, + }) + } } /*=============* @@ -304,36 +289,31 @@ impl ValueDecode for ConnectToPeerRequest { #[derive(Debug, Eq, PartialEq)] pub struct FileSearchRequest { - pub ticket: u32, - pub query: String, + pub ticket: u32, + pub query: String, } impl WriteToPacket for FileSearchRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.ticket)?; - packet.write_value(&self.query)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.ticket)?; + packet.write_value(&self.query)?; + Ok(()) + } } impl ValueEncode for FileSearchRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.ticket)?; - encoder.encode_string(&self.query) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.ticket)?; + encoder.encode_string(&self.query) + } } impl ValueDecode for FileSearchRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let ticket = decoder.decode()?; - let query = decoder.decode()?; - Ok(FileSearchRequest { ticket, query }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let ticket = decoder.decode()?; + let query = decoder.decode()?; + Ok(FileSearchRequest { ticket, query }) + } } /*=======* @@ -342,84 +322,79 @@ impl ValueDecode for FileSearchRequest { #[derive(Debug, Eq, PartialEq)] pub struct LoginRequest { - username: String, - password: String, - digest: String, - major: u32, - minor: u32, + username: String, + password: String, + digest: String, + major: u32, + minor: u32, } fn userpass_md5(username: &str, password: &str) -> String { - let userpass = String::new() + username + password; - md5_str(&userpass) + let userpass = String::new() + username + password; + md5_str(&userpass) } impl LoginRequest { - pub fn new( - username: &str, - password: &str, - major: u32, - minor: u32, - ) -> Result { - if password.len() > 0 { - Ok(LoginRequest { - username: username.to_string(), - password: password.to_string(), - digest: userpass_md5(username, password), - major, - minor, - }) - } else { - Err("Empty password") - } + pub fn new( + username: &str, + password: &str, + major: u32, + minor: u32, + ) -> Result { + if password.len() > 0 { + Ok(LoginRequest { + username: username.to_string(), + password: password.to_string(), + digest: userpass_md5(username, password), + major, + minor, + }) + } else { + Err("Empty password") } + } - fn has_correct_digest(&self) -> bool { - self.digest == userpass_md5(&self.username, &self.password) - } + fn has_correct_digest(&self) -> bool { + self.digest == userpass_md5(&self.username, &self.password) + } } impl WriteToPacket for LoginRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.username)?; - packet.write_value(&self.password)?; - packet.write_value(&self.major)?; - packet.write_value(&self.digest)?; - packet.write_value(&self.minor)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.username)?; + packet.write_value(&self.password)?; + packet.write_value(&self.major)?; + packet.write_value(&self.digest)?; + packet.write_value(&self.minor)?; + Ok(()) + } } impl ValueEncode for LoginRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.username)?; - encoder.encode_string(&self.password)?; - encoder.encode_u32(self.major)?; - encoder.encode_string(&self.digest)?; - encoder.encode_u32(self.minor) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.username)?; + encoder.encode_string(&self.password)?; + encoder.encode_u32(self.major)?; + encoder.encode_string(&self.digest)?; + encoder.encode_u32(self.minor) + } } impl ValueDecode for LoginRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let username = decoder.decode()?; - let password = decoder.decode()?; - let major = decoder.decode()?; - let digest = decoder.decode()?; - let minor = decoder.decode()?; - Ok(LoginRequest { - username, - password, - digest, - major, - minor, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let username = decoder.decode()?; + let password = decoder.decode()?; + let major = decoder.decode()?; + let digest = decoder.decode()?; + let minor = decoder.decode()?; + Ok(LoginRequest { + username, + password, + digest, + major, + minor, + }) + } } /*==============* @@ -428,32 +403,27 @@ impl ValueDecode for LoginRequest { #[derive(Debug, Eq, PartialEq)] pub struct PeerAddressRequest { - pub username: String, + pub username: String, } impl WriteToPacket for PeerAddressRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.username)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.username)?; + Ok(()) + } } impl ValueEncode for PeerAddressRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.username) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.username) + } } impl ValueDecode for PeerAddressRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let username = decoder.decode()?; - Ok(PeerAddressRequest { username: username }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let username = decoder.decode()?; + Ok(PeerAddressRequest { username: username }) + } } /*===========* @@ -462,34 +432,29 @@ impl ValueDecode for PeerAddressRequest { #[derive(Debug, Eq, PartialEq)] pub struct RoomJoinRequest { - pub room_name: String, + pub room_name: String, } impl WriteToPacket for RoomJoinRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.room_name)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.room_name)?; + Ok(()) + } } impl ValueEncode for RoomJoinRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name) + } } impl ValueDecode for RoomJoinRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - Ok(RoomJoinRequest { - room_name: room_name, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + Ok(RoomJoinRequest { + room_name: room_name, + }) + } } /*============* @@ -498,34 +463,29 @@ impl ValueDecode for RoomJoinRequest { #[derive(Debug, Eq, PartialEq)] pub struct RoomLeaveRequest { - pub room_name: String, + pub room_name: String, } impl WriteToPacket for RoomLeaveRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.room_name)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.room_name)?; + Ok(()) + } } impl ValueEncode for RoomLeaveRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name) + } } impl ValueDecode for RoomLeaveRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - Ok(RoomLeaveRequest { - room_name: room_name, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + Ok(RoomLeaveRequest { + room_name: room_name, + }) + } } /*==============* @@ -534,36 +494,31 @@ impl ValueDecode for RoomLeaveRequest { #[derive(Debug, Eq, PartialEq)] pub struct RoomMessageRequest { - pub room_name: String, - pub message: String, + pub room_name: String, + pub message: String, } impl WriteToPacket for RoomMessageRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.room_name)?; - packet.write_value(&self.message)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.room_name)?; + packet.write_value(&self.message)?; + Ok(()) + } } impl ValueEncode for RoomMessageRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name)?; - encoder.encode_string(&self.message) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name)?; + encoder.encode_string(&self.message) + } } impl ValueDecode for RoomMessageRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - let message = decoder.decode()?; - Ok(RoomMessageRequest { room_name, message }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + let message = decoder.decode()?; + Ok(RoomMessageRequest { room_name, message }) + } } /*=================* @@ -572,32 +527,27 @@ impl ValueDecode for RoomMessageRequest { #[derive(Debug, Eq, PartialEq)] pub struct SetListenPortRequest { - pub port: u16, + pub port: u16, } impl WriteToPacket for SetListenPortRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.port)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.port)?; + Ok(()) + } } impl ValueEncode for SetListenPortRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode(&self.port) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode(&self.port) + } } impl ValueDecode for SetListenPortRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let port = decoder.decode()?; - Ok(SetListenPortRequest { port: port }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let port = decoder.decode()?; + Ok(SetListenPortRequest { port: port }) + } } /*=============* @@ -606,34 +556,29 @@ impl ValueDecode for SetListenPortRequest { #[derive(Debug, Eq, PartialEq)] pub struct UserStatusRequest { - pub user_name: String, + pub user_name: String, } impl WriteToPacket for UserStatusRequest { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - packet.write_value(&self.user_name)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + packet.write_value(&self.user_name)?; + Ok(()) + } } impl ValueEncode for UserStatusRequest { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.user_name) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.user_name) + } } impl ValueDecode for UserStatusRequest { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let user_name = decoder.decode()?; - Ok(UserStatusRequest { - user_name: user_name, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let user_name = decoder.decode()?; + Ok(UserStatusRequest { + user_name: user_name, + }) + } } /*=======* @@ -642,120 +587,119 @@ impl ValueDecode for UserStatusRequest { #[cfg(test)] mod tests { - use bytes::BytesMut; - - use crate::proto::value_codec::tests::roundtrip; - use crate::proto::{ValueDecodeError, ValueDecoder}; - - use super::*; - - #[test] - fn invalid_code() { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&[57, 5, 0, 0]); - - let result = ValueDecoder::new(&bytes).decode::(); - - assert_eq!( - result, - Err(ValueDecodeError::InvalidData { - value_name: "server request code".to_string(), - cause: "unknown value 1337".to_string(), - position: 0, - }) - ); - } - - #[test] - fn roundtrip_cannot_connect_request() { - roundtrip(ServerRequest::CannotConnectRequest(CannotConnectRequest { - token: 1337, - user_name: "alice".to_string(), - })) - } - - #[test] - fn roundtrip_connect_to_peer_request() { - roundtrip(ServerRequest::ConnectToPeerRequest(ConnectToPeerRequest { - token: 1337, - user_name: "alice".to_string(), - connection_type: "P".to_string(), - })) - } - - #[test] - fn roundtrip_file_search_request() { - roundtrip(ServerRequest::FileSearchRequest(FileSearchRequest { - ticket: 1337, - query: "foo.txt".to_string(), - })) - } - - #[test] - #[should_panic] - fn new_login_request_with_empty_password() { - LoginRequest::new("alice", "", 1337, 42).unwrap(); - } - - #[test] - fn new_login_request_has_correct_digest() { - let request = - LoginRequest::new("alice", "password1234", 1337, 42).unwrap(); - assert!(request.has_correct_digest()); - } - - #[test] - fn roundtrip_login_request() { - roundtrip(ServerRequest::LoginRequest( - LoginRequest::new("alice", "password1234", 1337, 42).unwrap(), - )) - } - - #[test] - fn roundtrip_peer_address_request() { - roundtrip(ServerRequest::PeerAddressRequest(PeerAddressRequest { - username: "alice".to_string(), - })) - } - - #[test] - fn roundtrip_room_join_request() { - roundtrip(ServerRequest::RoomJoinRequest(RoomJoinRequest { - room_name: "best room ever".to_string(), - })) - } - - #[test] - fn roundtrip_room_leave_request() { - roundtrip(ServerRequest::RoomLeaveRequest(RoomLeaveRequest { - room_name: "best room ever".to_string(), - })) - } - - #[test] - fn roundtrip_room_list_request() { - roundtrip(ServerRequest::RoomListRequest) - } - - #[test] - fn roundtrip_room_message_request() { - roundtrip(ServerRequest::RoomMessageRequest(RoomMessageRequest { - room_name: "best room ever".to_string(), - message: "hello world!".to_string(), - })) - } - - #[test] - fn roundtrip_set_listen_port_request() { - roundtrip(ServerRequest::SetListenPortRequest(SetListenPortRequest { - port: 1337, - })) - } - - #[test] - fn roundtrip_user_status_request() { - roundtrip(ServerRequest::UserStatusRequest(UserStatusRequest { - user_name: "alice".to_string(), - })) - } + use bytes::BytesMut; + + use crate::proto::value_codec::tests::roundtrip; + use crate::proto::{ValueDecodeError, ValueDecoder}; + + use super::*; + + #[test] + fn invalid_code() { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&[57, 5, 0, 0]); + + let result = ValueDecoder::new(&bytes).decode::(); + + assert_eq!( + result, + Err(ValueDecodeError::InvalidData { + value_name: "server request code".to_string(), + cause: "unknown value 1337".to_string(), + position: 0, + }) + ); + } + + #[test] + fn roundtrip_cannot_connect_request() { + roundtrip(ServerRequest::CannotConnectRequest(CannotConnectRequest { + token: 1337, + user_name: "alice".to_string(), + })) + } + + #[test] + fn roundtrip_connect_to_peer_request() { + roundtrip(ServerRequest::ConnectToPeerRequest(ConnectToPeerRequest { + token: 1337, + user_name: "alice".to_string(), + connection_type: "P".to_string(), + })) + } + + #[test] + fn roundtrip_file_search_request() { + roundtrip(ServerRequest::FileSearchRequest(FileSearchRequest { + ticket: 1337, + query: "foo.txt".to_string(), + })) + } + + #[test] + #[should_panic] + fn new_login_request_with_empty_password() { + LoginRequest::new("alice", "", 1337, 42).unwrap(); + } + + #[test] + fn new_login_request_has_correct_digest() { + let request = LoginRequest::new("alice", "password1234", 1337, 42).unwrap(); + assert!(request.has_correct_digest()); + } + + #[test] + fn roundtrip_login_request() { + roundtrip(ServerRequest::LoginRequest( + LoginRequest::new("alice", "password1234", 1337, 42).unwrap(), + )) + } + + #[test] + fn roundtrip_peer_address_request() { + roundtrip(ServerRequest::PeerAddressRequest(PeerAddressRequest { + username: "alice".to_string(), + })) + } + + #[test] + fn roundtrip_room_join_request() { + roundtrip(ServerRequest::RoomJoinRequest(RoomJoinRequest { + room_name: "best room ever".to_string(), + })) + } + + #[test] + fn roundtrip_room_leave_request() { + roundtrip(ServerRequest::RoomLeaveRequest(RoomLeaveRequest { + room_name: "best room ever".to_string(), + })) + } + + #[test] + fn roundtrip_room_list_request() { + roundtrip(ServerRequest::RoomListRequest) + } + + #[test] + fn roundtrip_room_message_request() { + roundtrip(ServerRequest::RoomMessageRequest(RoomMessageRequest { + room_name: "best room ever".to_string(), + message: "hello world!".to_string(), + })) + } + + #[test] + fn roundtrip_set_listen_port_request() { + roundtrip(ServerRequest::SetListenPortRequest(SetListenPortRequest { + port: 1337, + })) + } + + #[test] + fn roundtrip_user_status_request() { + roundtrip(ServerRequest::UserStatusRequest(UserStatusRequest { + user_name: "alice".to_string(), + })) + } } diff --git a/src/proto/server/response.rs b/src/proto/server/response.rs index c5ae637..68eab33 100644 --- a/src/proto/server/response.rs +++ b/src/proto/server/response.rs @@ -3,8 +3,8 @@ use std::net; use crate::proto::packet::{Packet, PacketReadError, ReadFromPacket}; use crate::proto::server::constants::*; use crate::proto::{ - User, UserStatus, ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, - ValueEncodeError, ValueEncoder, + User, UserStatus, ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, + ValueEncodeError, ValueEncoder, }; /*=================* @@ -13,279 +13,268 @@ use crate::proto::{ #[derive(Debug, Eq, PartialEq)] pub enum ServerResponse { - ConnectToPeerResponse(ConnectToPeerResponse), - FileSearchResponse(FileSearchResponse), - LoginResponse(LoginResponse), - PeerAddressResponse(PeerAddressResponse), - PrivilegedUsersResponse(PrivilegedUsersResponse), - RoomJoinResponse(RoomJoinResponse), - RoomLeaveResponse(RoomLeaveResponse), - RoomListResponse(RoomListResponse), - RoomMessageResponse(RoomMessageResponse), - RoomTickersResponse(RoomTickersResponse), - RoomUserJoinedResponse(RoomUserJoinedResponse), - RoomUserLeftResponse(RoomUserLeftResponse), - UserInfoResponse(UserInfoResponse), - UserStatusResponse(UserStatusResponse), - WishlistIntervalResponse(WishlistIntervalResponse), - - // Unknown purpose - ParentMinSpeedResponse(ParentMinSpeedResponse), - ParentSpeedRatioResponse(ParentSpeedRatioResponse), - - UnknownResponse(u32), + ConnectToPeerResponse(ConnectToPeerResponse), + FileSearchResponse(FileSearchResponse), + LoginResponse(LoginResponse), + PeerAddressResponse(PeerAddressResponse), + PrivilegedUsersResponse(PrivilegedUsersResponse), + RoomJoinResponse(RoomJoinResponse), + RoomLeaveResponse(RoomLeaveResponse), + RoomListResponse(RoomListResponse), + RoomMessageResponse(RoomMessageResponse), + RoomTickersResponse(RoomTickersResponse), + RoomUserJoinedResponse(RoomUserJoinedResponse), + RoomUserLeftResponse(RoomUserLeftResponse), + UserInfoResponse(UserInfoResponse), + UserStatusResponse(UserStatusResponse), + WishlistIntervalResponse(WishlistIntervalResponse), + + // Unknown purpose + ParentMinSpeedResponse(ParentMinSpeedResponse), + ParentSpeedRatioResponse(ParentSpeedRatioResponse), + + UnknownResponse(u32), } impl ReadFromPacket for ServerResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let code: u32 = packet.read_value()?; - let resp = match code { - CODE_CONNECT_TO_PEER => { - ServerResponse::ConnectToPeerResponse(packet.read_value()?) - } - - CODE_FILE_SEARCH => { - ServerResponse::FileSearchResponse(packet.read_value()?) - } - - CODE_LOGIN => ServerResponse::LoginResponse(packet.read_value()?), - - CODE_PEER_ADDRESS => { - ServerResponse::PeerAddressResponse(packet.read_value()?) - } - - CODE_PRIVILEGED_USERS => { - ServerResponse::PrivilegedUsersResponse(packet.read_value()?) - } - - CODE_ROOM_JOIN => { - ServerResponse::RoomJoinResponse(packet.read_value()?) - } - - CODE_ROOM_LEAVE => { - ServerResponse::RoomLeaveResponse(packet.read_value()?) - } - - CODE_ROOM_LIST => { - ServerResponse::RoomListResponse(packet.read_value()?) - } - - CODE_ROOM_MESSAGE => { - ServerResponse::RoomMessageResponse(packet.read_value()?) - } - - CODE_ROOM_TICKERS => { - ServerResponse::RoomTickersResponse(packet.read_value()?) - } - - CODE_ROOM_USER_JOINED => { - ServerResponse::RoomUserJoinedResponse(packet.read_value()?) - } - - CODE_ROOM_USER_LEFT => { - ServerResponse::RoomUserLeftResponse(packet.read_value()?) - } - - CODE_USER_INFO => { - ServerResponse::UserInfoResponse(packet.read_value()?) - } - - CODE_USER_STATUS => { - ServerResponse::UserStatusResponse(packet.read_value()?) - } - - CODE_WISHLIST_INTERVAL => { - ServerResponse::WishlistIntervalResponse(packet.read_value()?) - } - - CODE_PARENT_MIN_SPEED => { - ServerResponse::ParentMinSpeedResponse(packet.read_value()?) - } - - CODE_PARENT_SPEED_RATIO => { - ServerResponse::ParentSpeedRatioResponse(packet.read_value()?) - } - - code => ServerResponse::UnknownResponse(code), - }; - let bytes_remaining = packet.bytes_remaining(); - if bytes_remaining > 0 { - warn!( - "Packet with code {} contains {} extra bytes", - code, bytes_remaining - ) - } - Ok(resp) + fn read_from_packet(packet: &mut Packet) -> Result { + let code: u32 = packet.read_value()?; + let resp = match code { + CODE_CONNECT_TO_PEER => { + ServerResponse::ConnectToPeerResponse(packet.read_value()?) + } + + CODE_FILE_SEARCH => { + ServerResponse::FileSearchResponse(packet.read_value()?) + } + + CODE_LOGIN => ServerResponse::LoginResponse(packet.read_value()?), + + CODE_PEER_ADDRESS => { + ServerResponse::PeerAddressResponse(packet.read_value()?) + } + + CODE_PRIVILEGED_USERS => { + ServerResponse::PrivilegedUsersResponse(packet.read_value()?) + } + + CODE_ROOM_JOIN => ServerResponse::RoomJoinResponse(packet.read_value()?), + + CODE_ROOM_LEAVE => { + ServerResponse::RoomLeaveResponse(packet.read_value()?) + } + + CODE_ROOM_LIST => ServerResponse::RoomListResponse(packet.read_value()?), + + CODE_ROOM_MESSAGE => { + ServerResponse::RoomMessageResponse(packet.read_value()?) + } + + CODE_ROOM_TICKERS => { + ServerResponse::RoomTickersResponse(packet.read_value()?) + } + + CODE_ROOM_USER_JOINED => { + ServerResponse::RoomUserJoinedResponse(packet.read_value()?) + } + + CODE_ROOM_USER_LEFT => { + ServerResponse::RoomUserLeftResponse(packet.read_value()?) + } + + CODE_USER_INFO => ServerResponse::UserInfoResponse(packet.read_value()?), + + CODE_USER_STATUS => { + ServerResponse::UserStatusResponse(packet.read_value()?) + } + + CODE_WISHLIST_INTERVAL => { + ServerResponse::WishlistIntervalResponse(packet.read_value()?) + } + + CODE_PARENT_MIN_SPEED => { + ServerResponse::ParentMinSpeedResponse(packet.read_value()?) + } + + CODE_PARENT_SPEED_RATIO => { + ServerResponse::ParentSpeedRatioResponse(packet.read_value()?) + } + + code => ServerResponse::UnknownResponse(code), + }; + let bytes_remaining = packet.bytes_remaining(); + if bytes_remaining > 0 { + warn!( + "Packet with code {} contains {} extra bytes", + code, bytes_remaining + ) } + Ok(resp) + } } impl ValueEncode for ServerResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - match *self { - ServerResponse::ConnectToPeerResponse(ref response) => { - encoder.encode_u32(CODE_CONNECT_TO_PEER)?; - response.encode(encoder)?; - } - ServerResponse::FileSearchResponse(ref response) => { - encoder.encode_u32(CODE_FILE_SEARCH)?; - response.encode(encoder)?; - } - ServerResponse::LoginResponse(ref response) => { - encoder.encode_u32(CODE_LOGIN)?; - response.encode(encoder)?; - } - ServerResponse::ParentMinSpeedResponse(ref response) => { - encoder.encode_u32(CODE_PARENT_MIN_SPEED)?; - response.encode(encoder)?; - } - ServerResponse::ParentSpeedRatioResponse(ref response) => { - encoder.encode_u32(CODE_PARENT_SPEED_RATIO)?; - response.encode(encoder)?; - } - ServerResponse::PeerAddressResponse(ref response) => { - encoder.encode_u32(CODE_PEER_ADDRESS)?; - response.encode(encoder)?; - } - ServerResponse::PrivilegedUsersResponse(ref response) => { - encoder.encode_u32(CODE_PRIVILEGED_USERS)?; - response.encode(encoder)?; - } - ServerResponse::RoomJoinResponse(ref response) => { - encoder.encode_u32(CODE_ROOM_JOIN)?; - response.encode(encoder)?; - } - ServerResponse::RoomLeaveResponse(ref response) => { - encoder.encode_u32(CODE_ROOM_LEAVE)?; - response.encode(encoder)?; - } - ServerResponse::RoomListResponse(ref response) => { - encoder.encode_u32(CODE_ROOM_LIST)?; - response.encode(encoder)?; - } - ServerResponse::RoomMessageResponse(ref response) => { - encoder.encode_u32(CODE_ROOM_MESSAGE)?; - response.encode(encoder)?; - } - ServerResponse::RoomTickersResponse(ref response) => { - encoder.encode_u32(CODE_ROOM_TICKERS)?; - response.encode(encoder)?; - } - ServerResponse::RoomUserJoinedResponse(ref response) => { - encoder.encode_u32(CODE_ROOM_USER_JOINED)?; - response.encode(encoder)?; - } - ServerResponse::RoomUserLeftResponse(ref response) => { - encoder.encode_u32(CODE_ROOM_USER_LEFT)?; - response.encode(encoder)?; - } - ServerResponse::UserInfoResponse(ref response) => { - encoder.encode_u32(CODE_USER_INFO)?; - response.encode(encoder)?; - } - ServerResponse::UserStatusResponse(ref response) => { - encoder.encode_u32(CODE_USER_STATUS)?; - response.encode(encoder)?; - } - ServerResponse::WishlistIntervalResponse(ref response) => { - encoder.encode_u32(CODE_WISHLIST_INTERVAL)?; - response.encode(encoder)?; - } - _ => { - unimplemented!(); - } - }; - Ok(()) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + match *self { + ServerResponse::ConnectToPeerResponse(ref response) => { + encoder.encode_u32(CODE_CONNECT_TO_PEER)?; + response.encode(encoder)?; + } + ServerResponse::FileSearchResponse(ref response) => { + encoder.encode_u32(CODE_FILE_SEARCH)?; + response.encode(encoder)?; + } + ServerResponse::LoginResponse(ref response) => { + encoder.encode_u32(CODE_LOGIN)?; + response.encode(encoder)?; + } + ServerResponse::ParentMinSpeedResponse(ref response) => { + encoder.encode_u32(CODE_PARENT_MIN_SPEED)?; + response.encode(encoder)?; + } + ServerResponse::ParentSpeedRatioResponse(ref response) => { + encoder.encode_u32(CODE_PARENT_SPEED_RATIO)?; + response.encode(encoder)?; + } + ServerResponse::PeerAddressResponse(ref response) => { + encoder.encode_u32(CODE_PEER_ADDRESS)?; + response.encode(encoder)?; + } + ServerResponse::PrivilegedUsersResponse(ref response) => { + encoder.encode_u32(CODE_PRIVILEGED_USERS)?; + response.encode(encoder)?; + } + ServerResponse::RoomJoinResponse(ref response) => { + encoder.encode_u32(CODE_ROOM_JOIN)?; + response.encode(encoder)?; + } + ServerResponse::RoomLeaveResponse(ref response) => { + encoder.encode_u32(CODE_ROOM_LEAVE)?; + response.encode(encoder)?; + } + ServerResponse::RoomListResponse(ref response) => { + encoder.encode_u32(CODE_ROOM_LIST)?; + response.encode(encoder)?; + } + ServerResponse::RoomMessageResponse(ref response) => { + encoder.encode_u32(CODE_ROOM_MESSAGE)?; + response.encode(encoder)?; + } + ServerResponse::RoomTickersResponse(ref response) => { + encoder.encode_u32(CODE_ROOM_TICKERS)?; + response.encode(encoder)?; + } + ServerResponse::RoomUserJoinedResponse(ref response) => { + encoder.encode_u32(CODE_ROOM_USER_JOINED)?; + response.encode(encoder)?; + } + ServerResponse::RoomUserLeftResponse(ref response) => { + encoder.encode_u32(CODE_ROOM_USER_LEFT)?; + response.encode(encoder)?; + } + ServerResponse::UserInfoResponse(ref response) => { + encoder.encode_u32(CODE_USER_INFO)?; + response.encode(encoder)?; + } + ServerResponse::UserStatusResponse(ref response) => { + encoder.encode_u32(CODE_USER_STATUS)?; + response.encode(encoder)?; + } + ServerResponse::WishlistIntervalResponse(ref response) => { + encoder.encode_u32(CODE_WISHLIST_INTERVAL)?; + response.encode(encoder)?; + } + _ => { + unimplemented!(); + } + }; + Ok(()) + } } impl ValueDecode for ServerResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let position = decoder.position(); - let code: u32 = decoder.decode()?; - let response = match code { - CODE_CONNECT_TO_PEER => { - let response = decoder.decode()?; - ServerResponse::ConnectToPeerResponse(response) - } - CODE_FILE_SEARCH => { - let response = decoder.decode()?; - ServerResponse::FileSearchResponse(response) - } - CODE_LOGIN => { - let response = decoder.decode()?; - ServerResponse::LoginResponse(response) - } - CODE_PARENT_MIN_SPEED => { - let response = decoder.decode()?; - ServerResponse::ParentMinSpeedResponse(response) - } - CODE_PARENT_SPEED_RATIO => { - let response = decoder.decode()?; - ServerResponse::ParentSpeedRatioResponse(response) - } - CODE_PEER_ADDRESS => { - let response = decoder.decode()?; - ServerResponse::PeerAddressResponse(response) - } - CODE_PRIVILEGED_USERS => { - let response = decoder.decode()?; - ServerResponse::PrivilegedUsersResponse(response) - } - CODE_ROOM_JOIN => { - let response = decoder.decode()?; - ServerResponse::RoomJoinResponse(response) - } - CODE_ROOM_LEAVE => { - let response = decoder.decode()?; - ServerResponse::RoomLeaveResponse(response) - } - CODE_ROOM_LIST => { - let response = decoder.decode()?; - ServerResponse::RoomListResponse(response) - } - CODE_ROOM_MESSAGE => { - let response = decoder.decode()?; - ServerResponse::RoomMessageResponse(response) - } - CODE_ROOM_TICKERS => { - let response = decoder.decode()?; - ServerResponse::RoomTickersResponse(response) - } - CODE_ROOM_USER_JOINED => { - let response = decoder.decode()?; - ServerResponse::RoomUserJoinedResponse(response) - } - CODE_ROOM_USER_LEFT => { - let response = decoder.decode()?; - ServerResponse::RoomUserLeftResponse(response) - } - CODE_USER_INFO => { - let response = decoder.decode()?; - ServerResponse::UserInfoResponse(response) - } - CODE_USER_STATUS => { - let response = decoder.decode()?; - ServerResponse::UserStatusResponse(response) - } - CODE_WISHLIST_INTERVAL => { - let response = decoder.decode()?; - ServerResponse::WishlistIntervalResponse(response) - } - _ => { - return Err(ValueDecodeError::InvalidData { - value_name: "server response code".to_string(), - cause: format!("unknown value {}", code), - position: position, - }); - } - }; - Ok(response) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let position = decoder.position(); + let code: u32 = decoder.decode()?; + let response = match code { + CODE_CONNECT_TO_PEER => { + let response = decoder.decode()?; + ServerResponse::ConnectToPeerResponse(response) + } + CODE_FILE_SEARCH => { + let response = decoder.decode()?; + ServerResponse::FileSearchResponse(response) + } + CODE_LOGIN => { + let response = decoder.decode()?; + ServerResponse::LoginResponse(response) + } + CODE_PARENT_MIN_SPEED => { + let response = decoder.decode()?; + ServerResponse::ParentMinSpeedResponse(response) + } + CODE_PARENT_SPEED_RATIO => { + let response = decoder.decode()?; + ServerResponse::ParentSpeedRatioResponse(response) + } + CODE_PEER_ADDRESS => { + let response = decoder.decode()?; + ServerResponse::PeerAddressResponse(response) + } + CODE_PRIVILEGED_USERS => { + let response = decoder.decode()?; + ServerResponse::PrivilegedUsersResponse(response) + } + CODE_ROOM_JOIN => { + let response = decoder.decode()?; + ServerResponse::RoomJoinResponse(response) + } + CODE_ROOM_LEAVE => { + let response = decoder.decode()?; + ServerResponse::RoomLeaveResponse(response) + } + CODE_ROOM_LIST => { + let response = decoder.decode()?; + ServerResponse::RoomListResponse(response) + } + CODE_ROOM_MESSAGE => { + let response = decoder.decode()?; + ServerResponse::RoomMessageResponse(response) + } + CODE_ROOM_TICKERS => { + let response = decoder.decode()?; + ServerResponse::RoomTickersResponse(response) + } + CODE_ROOM_USER_JOINED => { + let response = decoder.decode()?; + ServerResponse::RoomUserJoinedResponse(response) + } + CODE_ROOM_USER_LEFT => { + let response = decoder.decode()?; + ServerResponse::RoomUserLeftResponse(response) + } + CODE_USER_INFO => { + let response = decoder.decode()?; + ServerResponse::UserInfoResponse(response) + } + CODE_USER_STATUS => { + let response = decoder.decode()?; + ServerResponse::UserStatusResponse(response) + } + CODE_WISHLIST_INTERVAL => { + let response = decoder.decode()?; + ServerResponse::WishlistIntervalResponse(response) + } + _ => { + return Err(ValueDecodeError::InvalidData { + value_name: "server response code".to_string(), + cause: format!("unknown value {}", code), + position: position, + }); + } + }; + Ok(response) + } } /*=================* @@ -294,68 +283,63 @@ impl ValueDecode for ServerResponse { #[derive(Debug, Eq, PartialEq)] pub struct ConnectToPeerResponse { - pub user_name: String, - pub connection_type: String, - pub ip: net::Ipv4Addr, - pub port: u16, - pub token: u32, - pub is_privileged: bool, + pub user_name: String, + pub connection_type: String, + pub ip: net::Ipv4Addr, + pub port: u16, + pub token: u32, + pub is_privileged: bool, } impl ReadFromPacket for ConnectToPeerResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let user_name = packet.read_value()?; - let connection_type = packet.read_value()?; - let ip = packet.read_value()?; - let port = packet.read_value()?; - let token = packet.read_value()?; - let is_privileged = packet.read_value()?; - - Ok(ConnectToPeerResponse { - user_name, - connection_type, - ip, - port, - token, - is_privileged, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let user_name = packet.read_value()?; + let connection_type = packet.read_value()?; + let ip = packet.read_value()?; + let port = packet.read_value()?; + let token = packet.read_value()?; + let is_privileged = packet.read_value()?; + + Ok(ConnectToPeerResponse { + user_name, + connection_type, + ip, + port, + token, + is_privileged, + }) + } } impl ValueEncode for ConnectToPeerResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode(&self.user_name)?; - encoder.encode(&self.connection_type)?; - encoder.encode(&self.ip)?; - encoder.encode_u16(self.port)?; - encoder.encode_u32(self.token)?; - encoder.encode_bool(self.is_privileged) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode(&self.user_name)?; + encoder.encode(&self.connection_type)?; + encoder.encode(&self.ip)?; + encoder.encode_u16(self.port)?; + encoder.encode_u32(self.token)?; + encoder.encode_bool(self.is_privileged) + } } impl ValueDecode for ConnectToPeerResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let user_name = decoder.decode()?; - let connection_type = decoder.decode()?; - let ip = decoder.decode()?; - let port = decoder.decode()?; - let token = decoder.decode()?; - let is_privileged = decoder.decode()?; - - Ok(ConnectToPeerResponse { - user_name, - connection_type, - ip, - port, - token, - is_privileged, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let user_name = decoder.decode()?; + let connection_type = decoder.decode()?; + let ip = decoder.decode()?; + let port = decoder.decode()?; + let token = decoder.decode()?; + let is_privileged = decoder.decode()?; + + Ok(ConnectToPeerResponse { + user_name, + connection_type, + ip, + port, + token, + is_privileged, + }) + } } /*=============* @@ -364,50 +348,45 @@ impl ValueDecode for ConnectToPeerResponse { #[derive(Debug, Eq, PartialEq)] pub struct FileSearchResponse { - pub user_name: String, - pub ticket: u32, - pub query: String, + pub user_name: String, + pub ticket: u32, + pub query: String, } impl ReadFromPacket for FileSearchResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let user_name = packet.read_value()?; - let ticket = packet.read_value()?; - let query = packet.read_value()?; - - Ok(FileSearchResponse { - user_name, - ticket, - query, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let user_name = packet.read_value()?; + let ticket = packet.read_value()?; + let query = packet.read_value()?; + + Ok(FileSearchResponse { + user_name, + ticket, + query, + }) + } } impl ValueEncode for FileSearchResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.user_name)?; - encoder.encode_u32(self.ticket)?; - encoder.encode_string(&self.query) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.user_name)?; + encoder.encode_u32(self.ticket)?; + encoder.encode_string(&self.query) + } } impl ValueDecode for FileSearchResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let user_name = decoder.decode()?; - let ticket = decoder.decode()?; - let query = decoder.decode()?; - - Ok(FileSearchResponse { - user_name, - ticket, - query, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let user_name = decoder.decode()?; + let ticket = decoder.decode()?; + let query = decoder.decode()?; + + Ok(FileSearchResponse { + user_name, + ticket, + query, + }) + } } /*=======* @@ -416,90 +395,85 @@ impl ValueDecode for FileSearchResponse { #[derive(Debug, Eq, PartialEq)] pub enum LoginResponse { - LoginOk { - motd: String, - ip: net::Ipv4Addr, - password_md5_opt: Option, - }, - LoginFail { - reason: String, - }, + LoginOk { + motd: String, + ip: net::Ipv4Addr, + password_md5_opt: Option, + }, + LoginFail { + reason: String, + }, } impl ReadFromPacket for LoginResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let ok = packet.read_value()?; - if ok { - let motd = packet.read_value()?; - let ip = packet.read_value()?; - - match packet.read_value::() { - Ok(value) => debug!("LoginResponse last field: {}", value), - Err(e) => debug!("Error reading LoginResponse field: {:?}", e), - } - - Ok(LoginResponse::LoginOk { - motd, - ip, - password_md5_opt: None, - }) - } else { - Ok(LoginResponse::LoginFail { - reason: packet.read_value()?, - }) - } - } + fn read_from_packet(packet: &mut Packet) -> Result { + let ok = packet.read_value()?; + if ok { + let motd = packet.read_value()?; + let ip = packet.read_value()?; + + match packet.read_value::() { + Ok(value) => debug!("LoginResponse last field: {}", value), + Err(e) => debug!("Error reading LoginResponse field: {:?}", e), + } + + Ok(LoginResponse::LoginOk { + motd, + ip, + password_md5_opt: None, + }) + } else { + Ok(LoginResponse::LoginFail { + reason: packet.read_value()?, + }) + } + } } impl ValueEncode for LoginResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - match *self { - LoginResponse::LoginOk { - ref motd, - ip, - password_md5_opt: _, - } => { - encoder.encode_bool(true)?; - encoder.encode(motd)?; - encoder.encode(&ip)?; - } - LoginResponse::LoginFail { ref reason } => { - encoder.encode_bool(false)?; - encoder.encode(reason)?; - } - }; - Ok(()) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + match *self { + LoginResponse::LoginOk { + ref motd, + ip, + password_md5_opt: _, + } => { + encoder.encode_bool(true)?; + encoder.encode(motd)?; + encoder.encode(&ip)?; + } + LoginResponse::LoginFail { ref reason } => { + encoder.encode_bool(false)?; + encoder.encode(reason)?; + } + }; + Ok(()) + } } impl ValueDecode for LoginResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let ok: bool = decoder.decode()?; - if !ok { - let reason = decoder.decode()?; - return Ok(LoginResponse::LoginFail { reason }); - } - - let motd = decoder.decode()?; - let ip = decoder.decode()?; + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let ok: bool = decoder.decode()?; + if !ok { + let reason = decoder.decode()?; + return Ok(LoginResponse::LoginFail { reason }); + } - let result = decoder.decode::(); - match result { - Ok(value) => debug!("LoginResponse last field: {}", value), - Err(e) => debug!("Error reading LoginResponse field: {:?}", e), - } + let motd = decoder.decode()?; + let ip = decoder.decode()?; - Ok(LoginResponse::LoginOk { - motd, - ip, - password_md5_opt: None, - }) + let result = decoder.decode::(); + match result { + Ok(value) => debug!("LoginResponse last field: {}", value), + Err(e) => debug!("Error reading LoginResponse field: {:?}", e), } + + Ok(LoginResponse::LoginOk { + motd, + ip, + password_md5_opt: None, + }) + } } /*==================* @@ -508,32 +482,27 @@ impl ValueDecode for LoginResponse { #[derive(Debug, Eq, PartialEq)] pub struct ParentMinSpeedResponse { - pub value: u32, + pub value: u32, } impl ReadFromPacket for ParentMinSpeedResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let value = packet.read_value()?; - Ok(ParentMinSpeedResponse { value }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let value = packet.read_value()?; + Ok(ParentMinSpeedResponse { value }) + } } impl ValueEncode for ParentMinSpeedResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.value) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.value) + } } impl ValueDecode for ParentMinSpeedResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let value = decoder.decode()?; - Ok(ParentMinSpeedResponse { value }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let value = decoder.decode()?; + Ok(ParentMinSpeedResponse { value }) + } } /*====================* @@ -542,32 +511,27 @@ impl ValueDecode for ParentMinSpeedResponse { #[derive(Debug, Eq, PartialEq)] pub struct ParentSpeedRatioResponse { - pub value: u32, + pub value: u32, } impl ReadFromPacket for ParentSpeedRatioResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let value = packet.read_value()?; - Ok(ParentSpeedRatioResponse { value }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let value = packet.read_value()?; + Ok(ParentSpeedRatioResponse { value }) + } } impl ValueEncode for ParentSpeedRatioResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.value) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.value) + } } impl ValueDecode for ParentSpeedRatioResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let value = decoder.decode()?; - Ok(ParentSpeedRatioResponse { value }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let value = decoder.decode()?; + Ok(ParentSpeedRatioResponse { value }) + } } /*==============* @@ -576,41 +540,36 @@ impl ValueDecode for ParentSpeedRatioResponse { #[derive(Debug, Eq, PartialEq)] pub struct PeerAddressResponse { - username: String, - ip: net::Ipv4Addr, - port: u16, + username: String, + ip: net::Ipv4Addr, + port: u16, } impl ReadFromPacket for PeerAddressResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let username = packet.read_value()?; - let ip = packet.read_value()?; - let port = packet.read_value()?; + fn read_from_packet(packet: &mut Packet) -> Result { + let username = packet.read_value()?; + let ip = packet.read_value()?; + let port = packet.read_value()?; - Ok(PeerAddressResponse { username, ip, port }) - } + Ok(PeerAddressResponse { username, ip, port }) + } } impl ValueEncode for PeerAddressResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode(&self.username)?; - encoder.encode(&self.ip)?; - encoder.encode_u16(self.port) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode(&self.username)?; + encoder.encode(&self.ip)?; + encoder.encode_u16(self.port) + } } impl ValueDecode for PeerAddressResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let username = decoder.decode()?; - let ip = decoder.decode()?; - let port = decoder.decode()?; - Ok(PeerAddressResponse { username, ip, port }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let username = decoder.decode()?; + let ip = decoder.decode()?; + let port = decoder.decode()?; + Ok(PeerAddressResponse { username, ip, port }) + } } /*==================* @@ -619,32 +578,27 @@ impl ValueDecode for PeerAddressResponse { #[derive(Debug, Eq, PartialEq)] pub struct PrivilegedUsersResponse { - pub users: Vec, + pub users: Vec, } impl ReadFromPacket for PrivilegedUsersResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let users = packet.read_value()?; - Ok(PrivilegedUsersResponse { users }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let users = packet.read_value()?; + Ok(PrivilegedUsersResponse { users }) + } } impl ValueEncode for PrivilegedUsersResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode(&self.users) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode(&self.users) + } } impl ValueDecode for PrivilegedUsersResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let users = decoder.decode()?; - Ok(PrivilegedUsersResponse { users }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let users = decoder.decode()?; + Ok(PrivilegedUsersResponse { users }) + } } /*===========* @@ -653,108 +607,104 @@ impl ValueDecode for PrivilegedUsersResponse { #[derive(Debug, Eq, PartialEq)] pub struct RoomJoinResponse { - pub room_name: String, - pub users: Vec, - pub owner: Option, - pub operators: Vec, + pub room_name: String, + pub users: Vec, + pub owner: Option, + pub operators: Vec, } impl ReadFromPacket for RoomJoinResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let mut response = RoomJoinResponse { - room_name: packet.read_value()?, - users: Vec::new(), - owner: None, - operators: Vec::new(), - }; - - let num_users: usize = packet.read_value()?; - for _ in 0..num_users { - let name: String = packet.read_value()?; - let user = User { - name, - status: UserStatus::Offline, - average_speed: 0, - num_downloads: 0, - unknown: 0, - num_files: 0, - num_folders: 0, - num_free_slots: 0, - country: String::new(), - }; - response.users.push(user); - } - - response.read_user_infos(packet)?; + fn read_from_packet(packet: &mut Packet) -> Result { + let mut response = RoomJoinResponse { + room_name: packet.read_value()?, + users: Vec::new(), + owner: None, + operators: Vec::new(), + }; + + let num_users: usize = packet.read_value()?; + for _ in 0..num_users { + let name: String = packet.read_value()?; + let user = User { + name, + status: UserStatus::Offline, + average_speed: 0, + num_downloads: 0, + unknown: 0, + num_files: 0, + num_folders: 0, + num_free_slots: 0, + country: String::new(), + }; + response.users.push(user); + } - if packet.bytes_remaining() > 0 { - response.owner = Some(packet.read_value()?); + response.read_user_infos(packet)?; - let num_operators: usize = packet.read_value()?; - for _ in 0..num_operators { - response.operators.push(packet.read_value()?); - } - } + if packet.bytes_remaining() > 0 { + response.owner = Some(packet.read_value()?); - Ok(response) + let num_operators: usize = packet.read_value()?; + for _ in 0..num_operators { + response.operators.push(packet.read_value()?); + } } + + Ok(response) + } } impl RoomJoinResponse { - fn read_user_infos( - &mut self, - packet: &mut Packet, - ) -> Result<(), PacketReadError> { - let num_statuses: usize = packet.read_value()?; - for i in 0..num_statuses { - if let Some(user) = self.users.get_mut(i) { - user.status = packet.read_value()?; - } - } - - let num_infos: usize = packet.read_value()?; - for i in 0..num_infos { - if let Some(user) = self.users.get_mut(i) { - user.average_speed = packet.read_value()?; - user.num_downloads = packet.read_value()?; - user.unknown = packet.read_value()?; - user.num_files = packet.read_value()?; - user.num_folders = packet.read_value()?; - } - } - - let num_free_slots: usize = packet.read_value()?; - for i in 0..num_free_slots { - if let Some(user) = self.users.get_mut(i) { - user.num_free_slots = packet.read_value()?; - } - } - - let num_countries: usize = packet.read_value()?; - for i in 0..num_countries { - if let Some(user) = self.users.get_mut(i) { - user.country = packet.read_value()?; - } - } - - let num_users = self.users.len(); - if num_users != num_statuses - || num_users != num_infos - || num_users != num_free_slots - || num_users != num_countries - { - warn!( - "RoomJoinResponse: mismatched vector sizes {}, {}, {}, {}, {}", - num_users, - num_statuses, - num_infos, - num_free_slots, - num_countries - ); - } - - Ok(()) - } + fn read_user_infos( + &mut self, + packet: &mut Packet, + ) -> Result<(), PacketReadError> { + let num_statuses: usize = packet.read_value()?; + for i in 0..num_statuses { + if let Some(user) = self.users.get_mut(i) { + user.status = packet.read_value()?; + } + } + + let num_infos: usize = packet.read_value()?; + for i in 0..num_infos { + if let Some(user) = self.users.get_mut(i) { + user.average_speed = packet.read_value()?; + user.num_downloads = packet.read_value()?; + user.unknown = packet.read_value()?; + user.num_files = packet.read_value()?; + user.num_folders = packet.read_value()?; + } + } + + let num_free_slots: usize = packet.read_value()?; + for i in 0..num_free_slots { + if let Some(user) = self.users.get_mut(i) { + user.num_free_slots = packet.read_value()?; + } + } + + let num_countries: usize = packet.read_value()?; + for i in 0..num_countries { + if let Some(user) = self.users.get_mut(i) { + user.country = packet.read_value()?; + } + } + + let num_users = self.users.len(); + if num_users != num_statuses + || num_users != num_infos + || num_users != num_free_slots + || num_users != num_countries + { + warn!( + "RoomJoinResponse: mismatched vector sizes {}, {}, {}, {}, {}", + num_users, num_statuses, num_infos, num_free_slots, num_countries + ); + } + + Ok(()) + } } // This struct is defined to enable decoding a vector of such values for @@ -762,176 +712,162 @@ impl RoomJoinResponse { // For details about individual fields, see said `User` struct. #[derive(Debug, Eq, PartialEq)] struct UserInfo { - average_speed: u32, - num_downloads: u32, - unknown: u32, - num_files: u32, - num_folders: u32, + average_speed: u32, + num_downloads: u32, + unknown: u32, + num_files: u32, + num_folders: u32, } impl UserInfo { - fn from_user(user: &User) -> Self { - Self { - average_speed: user.average_speed as u32, - num_downloads: user.num_downloads as u32, - unknown: user.unknown as u32, - num_files: user.num_files as u32, - num_folders: user.num_folders as u32, - } + fn from_user(user: &User) -> Self { + Self { + average_speed: user.average_speed as u32, + num_downloads: user.num_downloads as u32, + unknown: user.unknown as u32, + num_files: user.num_files as u32, + num_folders: user.num_folders as u32, } + } } fn build_user( - name: String, - status: UserStatus, - info: UserInfo, - num_free_slots: u32, - country: String, + name: String, + status: UserStatus, + info: UserInfo, + num_free_slots: u32, + country: String, ) -> User { - User { - name, - status, - average_speed: info.average_speed as usize, - num_downloads: info.num_downloads as usize, - unknown: info.unknown as usize, - num_files: info.num_files as usize, - num_folders: info.num_folders as usize, - num_free_slots: num_free_slots as usize, - country, - } + User { + name, + status, + average_speed: info.average_speed as usize, + num_downloads: info.num_downloads as usize, + unknown: info.unknown as usize, + num_files: info.num_files as usize, + num_folders: info.num_folders as usize, + num_free_slots: num_free_slots as usize, + country, + } } impl ValueEncode for UserInfo { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.average_speed)?; - encoder.encode_u32(self.num_downloads)?; - encoder.encode_u32(self.unknown)?; - encoder.encode_u32(self.num_files)?; - encoder.encode_u32(self.num_folders) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.average_speed)?; + encoder.encode_u32(self.num_downloads)?; + encoder.encode_u32(self.unknown)?; + encoder.encode_u32(self.num_files)?; + encoder.encode_u32(self.num_folders) + } } impl ValueDecode for UserInfo { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let average_speed = decoder.decode()?; - let num_downloads = decoder.decode()?; - let unknown = decoder.decode()?; - let num_files = decoder.decode()?; - let num_folders = decoder.decode()?; - Ok(UserInfo { - average_speed, - num_downloads, - unknown, - num_files, - num_folders, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let average_speed = decoder.decode()?; + let num_downloads = decoder.decode()?; + let unknown = decoder.decode()?; + let num_files = decoder.decode()?; + let num_folders = decoder.decode()?; + Ok(UserInfo { + average_speed, + num_downloads, + unknown, + num_files, + num_folders, + }) + } } impl ValueEncode for RoomJoinResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - let mut user_names = vec![]; - let mut user_statuses = vec![]; - let mut user_infos = vec![]; - let mut user_free_slots = vec![]; - let mut user_countries = vec![]; - for user in &self.users { - user_names.push(&user.name); - user_statuses.push(user.status); - user_infos.push(UserInfo::from_user(user)); - user_free_slots.push(user.num_free_slots as u32); - user_countries.push(&user.country); - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + let mut user_names = vec![]; + let mut user_statuses = vec![]; + let mut user_infos = vec![]; + let mut user_free_slots = vec![]; + let mut user_countries = vec![]; + for user in &self.users { + user_names.push(&user.name); + user_statuses.push(user.status); + user_infos.push(UserInfo::from_user(user)); + user_free_slots.push(user.num_free_slots as u32); + user_countries.push(&user.country); + } - encoder.encode_string(&self.room_name)?; - encoder.encode(&user_names)?; - encoder.encode(&user_statuses)?; - encoder.encode(&user_infos)?; - encoder.encode(&user_free_slots)?; - encoder.encode(&user_countries)?; + encoder.encode_string(&self.room_name)?; + encoder.encode(&user_names)?; + encoder.encode(&user_statuses)?; + encoder.encode(&user_infos)?; + encoder.encode(&user_free_slots)?; + encoder.encode(&user_countries)?; - if let Some(ref owner) = self.owner { - encoder.encode_string(owner)?; - encoder.encode(&self.operators)?; - } - - Ok(()) + if let Some(ref owner) = self.owner { + encoder.encode_string(owner)?; + encoder.encode(&self.operators)?; } + + Ok(()) + } } fn build_users( - mut names: Vec, - mut statuses: Vec, - mut infos: Vec, - mut free_slots: Vec, - mut countries: Vec, + mut names: Vec, + mut statuses: Vec, + mut infos: Vec, + mut free_slots: Vec, + mut countries: Vec, ) -> Vec { - let mut users = vec![]; + let mut users = vec![]; - loop { - let name_opt = names.pop(); - let status_opt = statuses.pop(); - let info_opt = infos.pop(); - let slots_opt = free_slots.pop(); - let country_opt = countries.pop(); - - match (name_opt, status_opt, info_opt, slots_opt, country_opt) { - ( - Some(name), - Some(status), - Some(info), - Some(slots), - Some(country), - ) => users.push(build_user(name, status, info, slots, country)), - _ => break, - } + loop { + let name_opt = names.pop(); + let status_opt = statuses.pop(); + let info_opt = infos.pop(); + let slots_opt = free_slots.pop(); + let country_opt = countries.pop(); + + match (name_opt, status_opt, info_opt, slots_opt, country_opt) { + (Some(name), Some(status), Some(info), Some(slots), Some(country)) => { + users.push(build_user(name, status, info, slots, country)) + } + _ => break, } + } - users.reverse(); - users + users.reverse(); + users } impl ValueDecode for RoomJoinResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - let user_names = decoder.decode()?; - let user_statuses = decoder.decode()?; - let user_infos = decoder.decode()?; - let user_free_slots = decoder.decode()?; - let user_countries = decoder.decode()?; - - let mut owner = None; - let mut operators = vec![]; - if decoder.has_remaining() { - owner = Some(decoder.decode()?); - operators = decoder.decode()?; - } - - let users = build_users( - user_names, - user_statuses, - user_infos, - user_free_slots, - user_countries, - ); - - Ok(RoomJoinResponse { - room_name, - users, - owner, - operators, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + let user_names = decoder.decode()?; + let user_statuses = decoder.decode()?; + let user_infos = decoder.decode()?; + let user_free_slots = decoder.decode()?; + let user_countries = decoder.decode()?; + + let mut owner = None; + let mut operators = vec![]; + if decoder.has_remaining() { + owner = Some(decoder.decode()?); + operators = decoder.decode()?; + } + + let users = build_users( + user_names, + user_statuses, + user_infos, + user_free_slots, + user_countries, + ); + + Ok(RoomJoinResponse { + room_name, + users, + owner, + operators, + }) + } } /*============* @@ -940,33 +876,28 @@ impl ValueDecode for RoomJoinResponse { #[derive(Debug, Eq, PartialEq)] pub struct RoomLeaveResponse { - pub room_name: String, + pub room_name: String, } impl ReadFromPacket for RoomLeaveResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - Ok(RoomLeaveResponse { - room_name: packet.read_value()?, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + Ok(RoomLeaveResponse { + room_name: packet.read_value()?, + }) + } } impl ValueEncode for RoomLeaveResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name) + } } impl ValueDecode for RoomLeaveResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - Ok(RoomLeaveResponse { room_name }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + Ok(RoomLeaveResponse { room_name }) + } } /*===========* @@ -975,142 +906,137 @@ impl ValueDecode for RoomLeaveResponse { #[derive(Debug, Eq, PartialEq)] pub struct RoomListResponse { - pub rooms: Vec<(String, u32)>, - pub owned_private_rooms: Vec<(String, u32)>, - pub other_private_rooms: Vec<(String, u32)>, - pub operated_private_room_names: Vec, + pub rooms: Vec<(String, u32)>, + pub owned_private_rooms: Vec<(String, u32)>, + pub other_private_rooms: Vec<(String, u32)>, + pub operated_private_room_names: Vec, } impl ReadFromPacket for RoomListResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let rooms = Self::read_rooms(packet)?; - let owned_private_rooms = Self::read_rooms(packet)?; - let other_private_rooms = Self::read_rooms(packet)?; - let operated_private_room_names = packet.read_value()?; - Ok(RoomListResponse { - rooms, - owned_private_rooms, - other_private_rooms, - operated_private_room_names, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let rooms = Self::read_rooms(packet)?; + let owned_private_rooms = Self::read_rooms(packet)?; + let other_private_rooms = Self::read_rooms(packet)?; + let operated_private_room_names = packet.read_value()?; + Ok(RoomListResponse { + rooms, + owned_private_rooms, + other_private_rooms, + operated_private_room_names, + }) + } } impl RoomListResponse { - fn read_rooms( - packet: &mut Packet, - ) -> Result, PacketReadError> { - let num_rooms: usize = packet.read_value()?; - let mut rooms = Vec::new(); - for _ in 0..num_rooms { - let room_name = packet.read_value()?; - rooms.push((room_name, 0)); - } - - let num_user_counts: usize = packet.read_value()?; - for i in 0..num_user_counts { - if let Some(&mut (_, ref mut count)) = rooms.get_mut(i) { - *count = packet.read_value()?; - } - } + fn read_rooms( + packet: &mut Packet, + ) -> Result, PacketReadError> { + let num_rooms: usize = packet.read_value()?; + let mut rooms = Vec::new(); + for _ in 0..num_rooms { + let room_name = packet.read_value()?; + rooms.push((room_name, 0)); + } - if num_rooms != num_user_counts { - warn!( - "Numbers of rooms and user counts do not match: {} != {}", - num_rooms, num_user_counts - ); - } + let num_user_counts: usize = packet.read_value()?; + for i in 0..num_user_counts { + if let Some(&mut (_, ref mut count)) = rooms.get_mut(i) { + *count = packet.read_value()?; + } + } - Ok(rooms) + if num_rooms != num_user_counts { + warn!( + "Numbers of rooms and user counts do not match: {} != {}", + num_rooms, num_user_counts + ); } - fn build_rooms( - mut room_names: Vec, - mut user_counts: Vec, - ) -> Vec<(String, u32)> { - let mut rooms = vec![]; - - loop { - let room_name_opt = room_names.pop(); - let user_count_opt = user_counts.pop(); - - match (room_name_opt, user_count_opt) { - (Some(room_name), Some(user_count)) => { - rooms.push((room_name, user_count)) - } - _ => break, - } - } + Ok(rooms) + } - if !room_names.is_empty() { - warn!( - "Unmatched room names in room list response: {:?}", - room_names - ) - } - if !user_counts.is_empty() { - warn!( - "Unmatched user counts in room list response: {:?}", - user_counts - ) - } + fn build_rooms( + mut room_names: Vec, + mut user_counts: Vec, + ) -> Vec<(String, u32)> { + let mut rooms = vec![]; + + loop { + let room_name_opt = room_names.pop(); + let user_count_opt = user_counts.pop(); - rooms.reverse(); - rooms + match (room_name_opt, user_count_opt) { + (Some(room_name), Some(user_count)) => { + rooms.push((room_name, user_count)) + } + _ => break, + } } - fn decode_rooms( - decoder: &mut ValueDecoder, - ) -> Result, ValueDecodeError> { - let room_names = decoder.decode()?; - let user_counts = decoder.decode()?; - Ok(Self::build_rooms(room_names, user_counts)) + if !room_names.is_empty() { + warn!( + "Unmatched room names in room list response: {:?}", + room_names + ) + } + if !user_counts.is_empty() { + warn!( + "Unmatched user counts in room list response: {:?}", + user_counts + ) } - fn encode_rooms( - rooms: &[(String, u32)], - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - let mut room_names = vec![]; - let mut user_counts = vec![]; + rooms.reverse(); + rooms + } - for &(ref room_name, user_count) in rooms { - room_names.push(room_name); - user_counts.push(user_count); - } + fn decode_rooms( + decoder: &mut ValueDecoder, + ) -> Result, ValueDecodeError> { + let room_names = decoder.decode()?; + let user_counts = decoder.decode()?; + Ok(Self::build_rooms(room_names, user_counts)) + } + + fn encode_rooms( + rooms: &[(String, u32)], + encoder: &mut ValueEncoder, + ) -> Result<(), ValueEncodeError> { + let mut room_names = vec![]; + let mut user_counts = vec![]; - encoder.encode(&room_names)?; - encoder.encode(&user_counts) + for &(ref room_name, user_count) in rooms { + room_names.push(room_name); + user_counts.push(user_count); } + + encoder.encode(&room_names)?; + encoder.encode(&user_counts) + } } impl ValueEncode for RoomListResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - Self::encode_rooms(&self.rooms, encoder)?; - Self::encode_rooms(&self.owned_private_rooms, encoder)?; - Self::encode_rooms(&self.other_private_rooms, encoder)?; - encoder.encode(&self.operated_private_room_names) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + Self::encode_rooms(&self.rooms, encoder)?; + Self::encode_rooms(&self.owned_private_rooms, encoder)?; + Self::encode_rooms(&self.other_private_rooms, encoder)?; + encoder.encode(&self.operated_private_room_names) + } } impl ValueDecode for RoomListResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let rooms = RoomListResponse::decode_rooms(decoder)?; - let owned_private_rooms = RoomListResponse::decode_rooms(decoder)?; - let other_private_rooms = RoomListResponse::decode_rooms(decoder)?; - let operated_private_room_names = decoder.decode()?; - Ok(RoomListResponse { - rooms, - owned_private_rooms, - other_private_rooms, - operated_private_room_names, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let rooms = RoomListResponse::decode_rooms(decoder)?; + let owned_private_rooms = RoomListResponse::decode_rooms(decoder)?; + let other_private_rooms = RoomListResponse::decode_rooms(decoder)?; + let operated_private_room_names = decoder.decode()?; + Ok(RoomListResponse { + rooms, + owned_private_rooms, + other_private_rooms, + operated_private_room_names, + }) + } } /*==============* @@ -1119,48 +1045,43 @@ impl ValueDecode for RoomListResponse { #[derive(Debug, Eq, PartialEq)] pub struct RoomMessageResponse { - pub room_name: String, - pub user_name: String, - pub message: String, + pub room_name: String, + pub user_name: String, + pub message: String, } impl ReadFromPacket for RoomMessageResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let room_name = packet.read_value()?; - let user_name = packet.read_value()?; - let message = packet.read_value()?; - Ok(RoomMessageResponse { - room_name, - user_name, - message, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let room_name = packet.read_value()?; + let user_name = packet.read_value()?; + let message = packet.read_value()?; + Ok(RoomMessageResponse { + room_name, + user_name, + message, + }) + } } impl ValueEncode for RoomMessageResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name)?; - encoder.encode_string(&self.user_name)?; - encoder.encode_string(&self.message) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name)?; + encoder.encode_string(&self.user_name)?; + encoder.encode_string(&self.message) + } } impl ValueDecode for RoomMessageResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - let user_name = decoder.decode()?; - let message = decoder.decode()?; - Ok(RoomMessageResponse { - room_name, - user_name, - message, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + let user_name = decoder.decode()?; + let message = decoder.decode()?; + Ok(RoomMessageResponse { + room_name, + user_name, + message, + }) + } } /*==============* @@ -1169,44 +1090,39 @@ impl ValueDecode for RoomMessageResponse { #[derive(Debug, Eq, PartialEq)] pub struct RoomTickersResponse { - pub room_name: String, - pub tickers: Vec<(String, String)>, + pub room_name: String, + pub tickers: Vec<(String, String)>, } impl ReadFromPacket for RoomTickersResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let room_name = packet.read_value()?; - - let num_tickers: usize = packet.read_value()?; - let mut tickers = Vec::new(); - for _ in 0..num_tickers { - let user_name = packet.read_value()?; - let message = packet.read_value()?; - tickers.push((user_name, message)) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let room_name = packet.read_value()?; - Ok(RoomTickersResponse { room_name, tickers }) + let num_tickers: usize = packet.read_value()?; + let mut tickers = Vec::new(); + for _ in 0..num_tickers { + let user_name = packet.read_value()?; + let message = packet.read_value()?; + tickers.push((user_name, message)) } + + Ok(RoomTickersResponse { room_name, tickers }) + } } impl ValueEncode for RoomTickersResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name)?; - encoder.encode(&self.tickers) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name)?; + encoder.encode(&self.tickers) + } } impl ValueDecode for RoomTickersResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - let tickers = decoder.decode()?; - Ok(RoomTickersResponse { room_name, tickers }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + let tickers = decoder.decode()?; + Ok(RoomTickersResponse { room_name, tickers }) + } } /*==================* @@ -1215,72 +1131,67 @@ impl ValueDecode for RoomTickersResponse { #[derive(Debug, Eq, PartialEq)] pub struct RoomUserJoinedResponse { - pub room_name: String, - pub user: User, + pub room_name: String, + pub user: User, } impl ReadFromPacket for RoomUserJoinedResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let room_name = packet.read_value()?; - let user_name = packet.read_value()?; - - let status = packet.read_value()?; - - let average_speed = packet.read_value()?; - let num_downloads = packet.read_value()?; - let unknown = packet.read_value()?; - let num_files = packet.read_value()?; - let num_folders = packet.read_value()?; - let num_free_slots = packet.read_value()?; - - let country = packet.read_value()?; - - Ok(RoomUserJoinedResponse { - room_name, - user: User { - name: user_name, - status, - average_speed, - num_downloads, - unknown, - num_files, - num_folders, - num_free_slots, - country, - }, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let room_name = packet.read_value()?; + let user_name = packet.read_value()?; + + let status = packet.read_value()?; + + let average_speed = packet.read_value()?; + let num_downloads = packet.read_value()?; + let unknown = packet.read_value()?; + let num_files = packet.read_value()?; + let num_folders = packet.read_value()?; + let num_free_slots = packet.read_value()?; + + let country = packet.read_value()?; + + Ok(RoomUserJoinedResponse { + room_name, + user: User { + name: user_name, + status, + average_speed, + num_downloads, + unknown, + num_files, + num_folders, + num_free_slots, + country, + }, + }) + } } impl ValueEncode for RoomUserJoinedResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name)?; - encoder.encode_string(&self.user.name)?; - self.user.status.encode(encoder)?; - UserInfo::from_user(&self.user).encode(encoder)?; - encoder.encode_u32(self.user.num_free_slots as u32)?; - encoder.encode_string(&self.user.country) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name)?; + encoder.encode_string(&self.user.name)?; + self.user.status.encode(encoder)?; + UserInfo::from_user(&self.user).encode(encoder)?; + encoder.encode_u32(self.user.num_free_slots as u32)?; + encoder.encode_string(&self.user.country) + } } impl ValueDecode for RoomUserJoinedResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - let user_name = decoder.decode()?; - let status = decoder.decode()?; - let info = decoder.decode()?; - let num_free_slots = decoder.decode()?; - let country = decoder.decode()?; - Ok(RoomUserJoinedResponse { - room_name, - user: build_user(user_name, status, info, num_free_slots, country), - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + let user_name = decoder.decode()?; + let status = decoder.decode()?; + let info = decoder.decode()?; + let num_free_slots = decoder.decode()?; + let country = decoder.decode()?; + Ok(RoomUserJoinedResponse { + room_name, + user: build_user(user_name, status, info, num_free_slots, country), + }) + } } /*================* @@ -1289,42 +1200,37 @@ impl ValueDecode for RoomUserJoinedResponse { #[derive(Debug, Eq, PartialEq)] pub struct RoomUserLeftResponse { - pub room_name: String, - pub user_name: String, + pub room_name: String, + pub user_name: String, } impl ReadFromPacket for RoomUserLeftResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let room_name = packet.read_value()?; - let user_name = packet.read_value()?; - Ok(RoomUserLeftResponse { - room_name, - user_name, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let room_name = packet.read_value()?; + let user_name = packet.read_value()?; + Ok(RoomUserLeftResponse { + room_name, + user_name, + }) + } } impl ValueEncode for RoomUserLeftResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.room_name)?; - encoder.encode_string(&self.user_name) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.room_name)?; + encoder.encode_string(&self.user_name) + } } impl ValueDecode for RoomUserLeftResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let room_name = decoder.decode()?; - let user_name = decoder.decode()?; - Ok(RoomUserLeftResponse { - room_name, - user_name, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let room_name = decoder.decode()?; + let user_name = decoder.decode()?; + Ok(RoomUserLeftResponse { + room_name, + user_name, + }) + } } /*===========* @@ -1333,60 +1239,55 @@ impl ValueDecode for RoomUserLeftResponse { #[derive(Debug, Eq, PartialEq)] pub struct UserInfoResponse { - pub user_name: String, - pub average_speed: usize, - pub num_downloads: usize, - pub num_files: usize, - pub num_folders: usize, + pub user_name: String, + pub average_speed: usize, + pub num_downloads: usize, + pub num_files: usize, + pub num_folders: usize, } impl ReadFromPacket for UserInfoResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let user_name = packet.read_value()?; - let average_speed = packet.read_value()?; - let num_downloads = packet.read_value()?; - let num_files = packet.read_value()?; - let num_folders = packet.read_value()?; - Ok(UserInfoResponse { - user_name, - average_speed, - num_downloads, - num_files, - num_folders, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let user_name = packet.read_value()?; + let average_speed = packet.read_value()?; + let num_downloads = packet.read_value()?; + let num_files = packet.read_value()?; + let num_folders = packet.read_value()?; + Ok(UserInfoResponse { + user_name, + average_speed, + num_downloads, + num_files, + num_folders, + }) + } } impl ValueEncode for UserInfoResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.user_name)?; - encoder.encode_u32(self.average_speed as u32)?; - encoder.encode_u32(self.num_downloads as u32)?; - encoder.encode_u32(self.num_files as u32)?; - encoder.encode_u32(self.num_folders as u32) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.user_name)?; + encoder.encode_u32(self.average_speed as u32)?; + encoder.encode_u32(self.num_downloads as u32)?; + encoder.encode_u32(self.num_files as u32)?; + encoder.encode_u32(self.num_folders as u32) + } } impl ValueDecode for UserInfoResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let user_name = decoder.decode()?; - let average_speed: u32 = decoder.decode()?; - let num_downloads: u32 = decoder.decode()?; - let num_files: u32 = decoder.decode()?; - let num_folders: u32 = decoder.decode()?; - Ok(UserInfoResponse { - user_name, - average_speed: average_speed as usize, - num_downloads: num_downloads as usize, - num_files: num_files as usize, - num_folders: num_folders as usize, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let user_name = decoder.decode()?; + let average_speed: u32 = decoder.decode()?; + let num_downloads: u32 = decoder.decode()?; + let num_files: u32 = decoder.decode()?; + let num_folders: u32 = decoder.decode()?; + Ok(UserInfoResponse { + user_name, + average_speed: average_speed as usize, + num_downloads: num_downloads as usize, + num_files: num_files as usize, + num_folders: num_folders as usize, + }) + } } /*=============* @@ -1395,48 +1296,43 @@ impl ValueDecode for UserInfoResponse { #[derive(Debug, Eq, PartialEq)] pub struct UserStatusResponse { - pub user_name: String, - pub status: UserStatus, - pub is_privileged: bool, + pub user_name: String, + pub status: UserStatus, + pub is_privileged: bool, } impl ReadFromPacket for UserStatusResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let user_name = packet.read_value()?; - let status = packet.read_value()?; - let is_privileged = packet.read_value()?; - Ok(UserStatusResponse { - user_name, - status, - is_privileged, - }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let user_name = packet.read_value()?; + let status = packet.read_value()?; + let is_privileged = packet.read_value()?; + Ok(UserStatusResponse { + user_name, + status, + is_privileged, + }) + } } impl ValueEncode for UserStatusResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(&self.user_name)?; - self.status.encode(encoder)?; - encoder.encode_bool(self.is_privileged) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(&self.user_name)?; + self.status.encode(encoder)?; + encoder.encode_bool(self.is_privileged) + } } impl ValueDecode for UserStatusResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let user_name = decoder.decode()?; - let status = decoder.decode()?; - let is_privileged = decoder.decode()?; - Ok(UserStatusResponse { - user_name, - status, - is_privileged, - }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let user_name = decoder.decode()?; + let status = decoder.decode()?; + let is_privileged = decoder.decode()?; + Ok(UserStatusResponse { + user_name, + status, + is_privileged, + }) + } } /*===================* @@ -1445,32 +1341,27 @@ impl ValueDecode for UserStatusResponse { #[derive(Debug, Eq, PartialEq)] pub struct WishlistIntervalResponse { - pub seconds: u32, + pub seconds: u32, } impl ReadFromPacket for WishlistIntervalResponse { - fn read_from_packet(packet: &mut Packet) -> Result { - let seconds = packet.read_value()?; - Ok(WishlistIntervalResponse { seconds }) - } + fn read_from_packet(packet: &mut Packet) -> Result { + let seconds = packet.read_value()?; + Ok(WishlistIntervalResponse { seconds }) + } } impl ValueEncode for WishlistIntervalResponse { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.seconds) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.seconds) + } } impl ValueDecode for WishlistIntervalResponse { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let seconds = decoder.decode()?; - Ok(WishlistIntervalResponse { seconds }) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let seconds = decoder.decode()?; + Ok(WishlistIntervalResponse { seconds }) + } } /*=======* @@ -1479,249 +1370,249 @@ impl ValueDecode for WishlistIntervalResponse { #[cfg(test)] mod tests { - use std::net; - - use bytes::BytesMut; - - use crate::proto::value_codec::tests::roundtrip; - use crate::proto::{ValueDecodeError, ValueDecoder}; - - use super::*; - - #[test] - fn invalid_code() { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&[57, 5, 0, 0]); - - let result = ValueDecoder::new(&bytes).decode::(); - - assert_eq!( - result, - Err(ValueDecodeError::InvalidData { - value_name: "server response code".to_string(), - cause: "unknown value 1337".to_string(), - position: 0, - }) - ); - } - - #[test] - fn roundtrip_connect_to_peer() { - roundtrip(ServerResponse::ConnectToPeerResponse( - ConnectToPeerResponse { - user_name: "alice".to_string(), - connection_type: "P".to_string(), - ip: net::Ipv4Addr::new(192, 168, 254, 1), - port: 1337, - token: 42, - is_privileged: true, - }, - )) - } - - #[test] - fn roundtrip_file_search() { - roundtrip(ServerResponse::FileSearchResponse(FileSearchResponse { - user_name: "alice".to_string(), - ticket: 1337, - query: "foo.txt".to_string(), - })) - } - - #[test] - fn roundtrip_login_ok() { - roundtrip(ServerResponse::LoginResponse(LoginResponse::LoginOk { - motd: "welcome one welcome all!".to_string(), - ip: net::Ipv4Addr::new(127, 0, 0, 1), - password_md5_opt: None, - })) - } - - #[test] - fn roundtrip_login_fail() { - roundtrip(ServerResponse::LoginResponse(LoginResponse::LoginFail { - reason: "I just don't like you".to_string(), - })) - } - - #[test] - fn roundtrip_parent_min_speed() { - roundtrip(ServerResponse::ParentMinSpeedResponse( - ParentMinSpeedResponse { value: 1337 }, - )) - } - - #[test] - fn roundtrip_parent_speed_ratio() { - roundtrip(ServerResponse::ParentSpeedRatioResponse( - ParentSpeedRatioResponse { value: 1337 }, - )) - } - - #[test] - fn roundtrip_peer_address() { - roundtrip(ServerResponse::PeerAddressResponse(PeerAddressResponse { - username: "alice".to_string(), - ip: net::Ipv4Addr::new(127, 0, 0, 1), - port: 1337, - })) - } - - #[test] - fn roundtrip_privileged_users() { - roundtrip(ServerResponse::PrivilegedUsersResponse( - PrivilegedUsersResponse { - users: vec![ - "alice".to_string(), - "bob".to_string(), - "chris".to_string(), - "dory".to_string(), - ], - }, - )) - } - - #[test] - fn roundtrip_room_join() { - roundtrip(ServerResponse::RoomJoinResponse(RoomJoinResponse { - room_name: "red".to_string(), - users: vec![ - User { - name: "alice".to_string(), - status: UserStatus::Online, - average_speed: 1000, - num_downloads: 1001, - unknown: 1002, - num_files: 1003, - num_folders: 1004, - num_free_slots: 1005, - country: "US".to_string(), - }, - User { - name: "barbara".to_string(), - status: UserStatus::Away, - average_speed: 2000, - num_downloads: 2001, - unknown: 2002, - num_files: 2003, - num_folders: 2004, - num_free_slots: 2005, - country: "DE".to_string(), - }, - ], - owner: Some("carol".to_string()), - operators: vec!["deirdre".to_string(), "erica".to_string()], - })) - } - - #[test] - fn roundtrip_room_join_no_owner() { - roundtrip(ServerResponse::RoomJoinResponse(RoomJoinResponse { - room_name: "red".to_string(), - users: vec![], - owner: None, - operators: vec![], - })) - } - - #[test] - fn roundtrip_room_leave() { - roundtrip(ServerResponse::RoomLeaveResponse(RoomLeaveResponse { - room_name: "red".to_string(), - })) - } - - #[test] - fn roundtrip_room_list() { - roundtrip(ServerResponse::RoomListResponse(RoomListResponse { - rooms: vec![("red".to_string(), 12), ("blue".to_string(), 13)], - owned_private_rooms: vec![ - ("green".to_string(), 14), - ("purple".to_string(), 15), - ], - other_private_rooms: vec![ - ("yellow".to_string(), 16), - ("orange".to_string(), 17), - ], - operated_private_room_names: vec![ - "brown".to_string(), - "pink".to_string(), - ], - })) - } - - #[test] - fn roundtrip_room_message() { - roundtrip(ServerResponse::RoomMessageResponse(RoomMessageResponse { - room_name: "red".to_string(), - user_name: "alice".to_string(), - message: "hello world!".to_string(), - })) - } - - #[test] - fn roundtrip_room_tickers() { - roundtrip(ServerResponse::RoomTickersResponse(RoomTickersResponse { - room_name: "red".to_string(), - tickers: vec![ - ("alice".to_string(), "hello world!".to_string()), - ("bob".to_string(), "hi alice :)".to_string()), - ], - })) - } - - #[test] - fn roundtrip_room_user_joined() { - roundtrip(ServerResponse::RoomUserJoinedResponse( - RoomUserJoinedResponse { - room_name: "red".to_string(), - user: User { - name: "alice".to_string(), - status: UserStatus::Online, - average_speed: 1000, - num_downloads: 1001, - unknown: 1002, - num_files: 1003, - num_folders: 1004, - num_free_slots: 1005, - country: "AR".to_string(), - }, - }, - )) - } - - #[test] - fn roundtrip_room_user_left() { - roundtrip(ServerResponse::RoomUserLeftResponse(RoomUserLeftResponse { - room_name: "red".to_string(), - user_name: "alice".to_string(), - })) - } - - #[test] - fn roundtrip_user_info() { - roundtrip(ServerResponse::UserInfoResponse(UserInfoResponse { - user_name: "alice".to_string(), - average_speed: 1000, - num_downloads: 1001, - num_files: 1002, - num_folders: 1003, - })) - } - - #[test] - fn roundtrip_user_status() { - roundtrip(ServerResponse::UserStatusResponse(UserStatusResponse { - user_name: "alice".to_string(), - status: UserStatus::Offline, - is_privileged: true, - })) - } - - #[test] - fn roundtrip_wishlist_interval() { - roundtrip(ServerResponse::WishlistIntervalResponse( - WishlistIntervalResponse { seconds: 1337 }, - )) - } + use std::net; + + use bytes::BytesMut; + + use crate::proto::value_codec::tests::roundtrip; + use crate::proto::{ValueDecodeError, ValueDecoder}; + + use super::*; + + #[test] + fn invalid_code() { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&[57, 5, 0, 0]); + + let result = ValueDecoder::new(&bytes).decode::(); + + assert_eq!( + result, + Err(ValueDecodeError::InvalidData { + value_name: "server response code".to_string(), + cause: "unknown value 1337".to_string(), + position: 0, + }) + ); + } + + #[test] + fn roundtrip_connect_to_peer() { + roundtrip(ServerResponse::ConnectToPeerResponse( + ConnectToPeerResponse { + user_name: "alice".to_string(), + connection_type: "P".to_string(), + ip: net::Ipv4Addr::new(192, 168, 254, 1), + port: 1337, + token: 42, + is_privileged: true, + }, + )) + } + + #[test] + fn roundtrip_file_search() { + roundtrip(ServerResponse::FileSearchResponse(FileSearchResponse { + user_name: "alice".to_string(), + ticket: 1337, + query: "foo.txt".to_string(), + })) + } + + #[test] + fn roundtrip_login_ok() { + roundtrip(ServerResponse::LoginResponse(LoginResponse::LoginOk { + motd: "welcome one welcome all!".to_string(), + ip: net::Ipv4Addr::new(127, 0, 0, 1), + password_md5_opt: None, + })) + } + + #[test] + fn roundtrip_login_fail() { + roundtrip(ServerResponse::LoginResponse(LoginResponse::LoginFail { + reason: "I just don't like you".to_string(), + })) + } + + #[test] + fn roundtrip_parent_min_speed() { + roundtrip(ServerResponse::ParentMinSpeedResponse( + ParentMinSpeedResponse { value: 1337 }, + )) + } + + #[test] + fn roundtrip_parent_speed_ratio() { + roundtrip(ServerResponse::ParentSpeedRatioResponse( + ParentSpeedRatioResponse { value: 1337 }, + )) + } + + #[test] + fn roundtrip_peer_address() { + roundtrip(ServerResponse::PeerAddressResponse(PeerAddressResponse { + username: "alice".to_string(), + ip: net::Ipv4Addr::new(127, 0, 0, 1), + port: 1337, + })) + } + + #[test] + fn roundtrip_privileged_users() { + roundtrip(ServerResponse::PrivilegedUsersResponse( + PrivilegedUsersResponse { + users: vec![ + "alice".to_string(), + "bob".to_string(), + "chris".to_string(), + "dory".to_string(), + ], + }, + )) + } + + #[test] + fn roundtrip_room_join() { + roundtrip(ServerResponse::RoomJoinResponse(RoomJoinResponse { + room_name: "red".to_string(), + users: vec![ + User { + name: "alice".to_string(), + status: UserStatus::Online, + average_speed: 1000, + num_downloads: 1001, + unknown: 1002, + num_files: 1003, + num_folders: 1004, + num_free_slots: 1005, + country: "US".to_string(), + }, + User { + name: "barbara".to_string(), + status: UserStatus::Away, + average_speed: 2000, + num_downloads: 2001, + unknown: 2002, + num_files: 2003, + num_folders: 2004, + num_free_slots: 2005, + country: "DE".to_string(), + }, + ], + owner: Some("carol".to_string()), + operators: vec!["deirdre".to_string(), "erica".to_string()], + })) + } + + #[test] + fn roundtrip_room_join_no_owner() { + roundtrip(ServerResponse::RoomJoinResponse(RoomJoinResponse { + room_name: "red".to_string(), + users: vec![], + owner: None, + operators: vec![], + })) + } + + #[test] + fn roundtrip_room_leave() { + roundtrip(ServerResponse::RoomLeaveResponse(RoomLeaveResponse { + room_name: "red".to_string(), + })) + } + + #[test] + fn roundtrip_room_list() { + roundtrip(ServerResponse::RoomListResponse(RoomListResponse { + rooms: vec![("red".to_string(), 12), ("blue".to_string(), 13)], + owned_private_rooms: vec![ + ("green".to_string(), 14), + ("purple".to_string(), 15), + ], + other_private_rooms: vec![ + ("yellow".to_string(), 16), + ("orange".to_string(), 17), + ], + operated_private_room_names: vec![ + "brown".to_string(), + "pink".to_string(), + ], + })) + } + + #[test] + fn roundtrip_room_message() { + roundtrip(ServerResponse::RoomMessageResponse(RoomMessageResponse { + room_name: "red".to_string(), + user_name: "alice".to_string(), + message: "hello world!".to_string(), + })) + } + + #[test] + fn roundtrip_room_tickers() { + roundtrip(ServerResponse::RoomTickersResponse(RoomTickersResponse { + room_name: "red".to_string(), + tickers: vec![ + ("alice".to_string(), "hello world!".to_string()), + ("bob".to_string(), "hi alice :)".to_string()), + ], + })) + } + + #[test] + fn roundtrip_room_user_joined() { + roundtrip(ServerResponse::RoomUserJoinedResponse( + RoomUserJoinedResponse { + room_name: "red".to_string(), + user: User { + name: "alice".to_string(), + status: UserStatus::Online, + average_speed: 1000, + num_downloads: 1001, + unknown: 1002, + num_files: 1003, + num_folders: 1004, + num_free_slots: 1005, + country: "AR".to_string(), + }, + }, + )) + } + + #[test] + fn roundtrip_room_user_left() { + roundtrip(ServerResponse::RoomUserLeftResponse(RoomUserLeftResponse { + room_name: "red".to_string(), + user_name: "alice".to_string(), + })) + } + + #[test] + fn roundtrip_user_info() { + roundtrip(ServerResponse::UserInfoResponse(UserInfoResponse { + user_name: "alice".to_string(), + average_speed: 1000, + num_downloads: 1001, + num_files: 1002, + num_folders: 1003, + })) + } + + #[test] + fn roundtrip_user_status() { + roundtrip(ServerResponse::UserStatusResponse(UserStatusResponse { + user_name: "alice".to_string(), + status: UserStatus::Offline, + is_privileged: true, + })) + } + + #[test] + fn roundtrip_wishlist_interval() { + roundtrip(ServerResponse::WishlistIntervalResponse( + WishlistIntervalResponse { seconds: 1337 }, + )) + } } diff --git a/src/proto/stream.rs b/src/proto/stream.rs index 1e2926b..a829690 100644 --- a/src/proto/stream.rs +++ b/src/proto/stream.rs @@ -15,41 +15,41 @@ use super::packet::{MutPacket, Parser, ReadFromPacket, WriteToPacket}; /// A struct used for writing bytes to a TryWrite sink. #[derive(Debug)] struct OutBuf { - cursor: usize, - bytes: Vec, + cursor: usize, + bytes: Vec, } impl From> for OutBuf { - fn from(bytes: Vec) -> Self { - OutBuf { - cursor: 0, - bytes: bytes, - } + fn from(bytes: Vec) -> Self { + OutBuf { + cursor: 0, + bytes: bytes, } + } } impl OutBuf { - #[inline] - fn remaining(&self) -> usize { - self.bytes.len() - self.cursor - } - - #[inline] - fn has_remaining(&self) -> bool { - self.remaining() > 0 - } - - #[allow(deprecated)] - fn try_write_to(&mut self, mut writer: T) -> io::Result> - where - T: mio::deprecated::TryWrite, - { - let result = writer.try_write(&self.bytes[self.cursor..]); - if let Ok(Some(bytes_written)) = result { - self.cursor += bytes_written; - } - result + #[inline] + fn remaining(&self) -> usize { + self.bytes.len() - self.cursor + } + + #[inline] + fn has_remaining(&self) -> bool { + self.remaining() > 0 + } + + #[allow(deprecated)] + fn try_write_to(&mut self, mut writer: T) -> io::Result> + where + T: mio::deprecated::TryWrite, + { + let result = writer.try_write(&self.bytes[self.cursor..]); + if let Ok(Some(bytes_written)) = result { + self.cursor += bytes_written; } + result + } } /*========* @@ -59,171 +59,171 @@ impl OutBuf { /// This trait is implemented by packet sinks to which a stream can forward /// the packets it reads. pub trait SendPacket { - type Value: ReadFromPacket; - type Error: error::Error; + type Value: ReadFromPacket; + type Error: error::Error; - fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>; + fn send_packet(&mut self, _: Self::Value) -> Result<(), Self::Error>; - fn notify_open(&mut self) -> Result<(), Self::Error>; + fn notify_open(&mut self) -> Result<(), Self::Error>; } /// This enum defines the possible actions the stream wants to take after /// processing an event. #[derive(Debug, Clone, Copy)] pub enum Intent { - /// The stream is done, the event loop handler can drop it. - Done, - /// The stream wants to wait for the next event matching the given - /// `EventSet`. - Continue(mio::Ready), + /// The stream is done, the event loop handler can drop it. + Done, + /// The stream wants to wait for the next event matching the given + /// `EventSet`. + Continue(mio::Ready), } /// This struct wraps around an mio tcp stream and handles packet reads and /// writes. #[derive(Debug)] pub struct Stream { - parser: Parser, - queue: VecDeque, - sender: T, - stream: mio::tcp::TcpStream, + parser: Parser, + queue: VecDeque, + sender: T, + stream: mio::tcp::TcpStream, - is_connected: bool, + is_connected: bool, } impl Stream { - /// Returns a new stream, asynchronously connected to the given address, - /// which forwards incoming packets to the given sender. - /// If an error occurs when connecting, returns an error. - pub fn new(addr_spec: U, sender: T) -> io::Result - where - U: ToSocketAddrs + fmt::Debug, - { - for sock_addr in addr_spec.to_socket_addrs()? { - if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) { - return Ok(Stream { - parser: Parser::new(), - queue: VecDeque::new(), - sender: sender, - stream: stream, - - is_connected: false, - }); - } - } - Err(io::Error::new( - io::ErrorKind::Other, - format!("Cannot connect to {:?}", addr_spec), - )) + /// Returns a new stream, asynchronously connected to the given address, + /// which forwards incoming packets to the given sender. + /// If an error occurs when connecting, returns an error. + pub fn new(addr_spec: U, sender: T) -> io::Result + where + U: ToSocketAddrs + fmt::Debug, + { + for sock_addr in addr_spec.to_socket_addrs()? { + if let Ok(stream) = mio::tcp::TcpStream::connect(&sock_addr) { + return Ok(Stream { + parser: Parser::new(), + queue: VecDeque::new(), + sender: sender, + stream: stream, + + is_connected: false, + }); + } } - - /// Returns a reference to the underlying byte stream, to allow it to be - /// registered with an event loop. - pub fn evented(&self) -> &mio::tcp::TcpStream { - &self.stream + Err(io::Error::new( + io::ErrorKind::Other, + format!("Cannot connect to {:?}", addr_spec), + )) + } + + /// Returns a reference to the underlying byte stream, to allow it to be + /// registered with an event loop. + pub fn evented(&self) -> &mio::tcp::TcpStream { + &self.stream + } + + /// The stream is ready to be read from. + fn on_readable(&mut self) -> Result<(), String> { + loop { + let mut packet = match self.parser.try_read(&mut self.stream) { + Ok(Some(packet)) => packet, + Ok(None) => break, + Err(e) => return Err(format!("Error reading stream: {}", e)), + }; + let value = match packet.read_value() { + Ok(value) => value, + Err(e) => return Err(format!("Error parsing packet: {}", e)), + }; + if let Err(e) = self.sender.send_packet(value) { + return Err(format!("Error sending parsed packet: {}", e)); + } } - - /// The stream is ready to be read from. - fn on_readable(&mut self) -> Result<(), String> { - loop { - let mut packet = match self.parser.try_read(&mut self.stream) { - Ok(Some(packet)) => packet, - Ok(None) => break, - Err(e) => return Err(format!("Error reading stream: {}", e)), - }; - let value = match packet.read_value() { - Ok(value) => value, - Err(e) => return Err(format!("Error parsing packet: {}", e)), - }; - if let Err(e) = self.sender.send_packet(value) { - return Err(format!("Error sending parsed packet: {}", e)); - } + Ok(()) + } + + /// The stream is ready to be written to. + fn on_writable(&mut self) -> io::Result<()> { + loop { + let mut outbuf = match self.queue.pop_front() { + Some(outbuf) => outbuf, + None => break, + }; + + let option = outbuf.try_write_to(&mut self.stream)?; + match option { + Some(_) => { + if outbuf.has_remaining() { + self.queue.push_front(outbuf) + } + // Continue looping } - Ok(()) - } - - /// The stream is ready to be written to. - fn on_writable(&mut self) -> io::Result<()> { - loop { - let mut outbuf = match self.queue.pop_front() { - Some(outbuf) => outbuf, - None => break, - }; - - let option = outbuf.try_write_to(&mut self.stream)?; - match option { - Some(_) => { - if outbuf.has_remaining() { - self.queue.push_front(outbuf) - } - // Continue looping - } - None => { - self.queue.push_front(outbuf); - break; - } - } + None => { + self.queue.push_front(outbuf); + break; } - Ok(()) + } } + Ok(()) + } - /// The stream is ready to read, write, or both. - pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent { - #[allow(deprecated)] - if event_set.is_hup() || event_set.is_error() { - return Intent::Done; - } - if event_set.is_readable() { - let result = self.on_readable(); - if let Err(e) = result { - error!("Stream input error: {}", e); - return Intent::Done; - } - } - if event_set.is_writable() { - let result = self.on_writable(); - if let Err(e) = result { - error!("Stream output error: {}", e); - return Intent::Done; - } - } - - // We must have read or written something succesfully if we're here, - // so the stream must be connected. - if !self.is_connected { - // If we weren't already connected, notify the sink. - if let Err(err) = self.sender.notify_open() { - error!("Cannot notify client that stream is open: {}", err); - return Intent::Done; - } - // And record the fact that we are now connected. - self.is_connected = true; - } + /// The stream is ready to read, write, or both. + pub fn on_ready(&mut self, event_set: mio::Ready) -> Intent { + #[allow(deprecated)] + if event_set.is_hup() || event_set.is_error() { + return Intent::Done; + } + if event_set.is_readable() { + let result = self.on_readable(); + if let Err(e) = result { + error!("Stream input error: {}", e); + return Intent::Done; + } + } + if event_set.is_writable() { + let result = self.on_writable(); + if let Err(e) = result { + error!("Stream output error: {}", e); + return Intent::Done; + } + } - // We're always interested in reading more. - #[allow(deprecated)] - let mut event_set = - mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error(); - // If there is still stuff to write in the queue, we're interested in - // the socket becoming writable too. - if self.queue.len() > 0 { - event_set = event_set | mio::Ready::writable(); - } + // We must have read or written something succesfully if we're here, + // so the stream must be connected. + if !self.is_connected { + // If we weren't already connected, notify the sink. + if let Err(err) = self.sender.notify_open() { + error!("Cannot notify client that stream is open: {}", err); + return Intent::Done; + } + // And record the fact that we are now connected. + self.is_connected = true; + } - Intent::Continue(event_set) + // We're always interested in reading more. + #[allow(deprecated)] + let mut event_set = + mio::Ready::readable() | mio::Ready::hup() | mio::Ready::error(); + // If there is still stuff to write in the queue, we're interested in + // the socket becoming writable too. + if self.queue.len() > 0 { + event_set = event_set | mio::Ready::writable(); } - /// The stream has been notified. - pub fn on_notify(&mut self, payload: &V) -> Intent - where - V: WriteToPacket, - { - let mut packet = MutPacket::new(); - let result = packet.write_value(payload); - if let Err(e) = result { - error!("Error writing payload to packet: {}", e); - return Intent::Done; - } - self.queue.push_back(OutBuf::from(packet.into_bytes())); - Intent::Continue(mio::Ready::readable() | mio::Ready::writable()) + Intent::Continue(event_set) + } + + /// The stream has been notified. + pub fn on_notify(&mut self, payload: &V) -> Intent + where + V: WriteToPacket, + { + let mut packet = MutPacket::new(); + let result = packet.write_value(payload); + if let Err(e) = result { + error!("Error writing payload to packet: {}", e); + return Intent::Done; } + self.queue.push_back(OutBuf::from(packet.into_bytes())); + Intent::Continue(mio::Ready::readable() | mio::Ready::writable()) + } } diff --git a/src/proto/testing.rs b/src/proto/testing.rs index 618afaa..69e0183 100644 --- a/src/proto/testing.rs +++ b/src/proto/testing.rs @@ -8,68 +8,68 @@ use tokio::net::{TcpListener, TcpStream}; use crate::proto::{FrameStream, ServerRequest, ServerResponse}; async fn process(stream: TcpStream) -> io::Result<()> { - let mut connection = - FrameStream::::new(stream); + let mut connection = + FrameStream::::new(stream); - let _request = match connection.read().await? { - ServerRequest::LoginRequest(request) => request, - request => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("expected login request, got: {:?}", request), - )); - } - }; + let _request = match connection.read().await? { + ServerRequest::LoginRequest(request) => request, + request => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("expected login request, got: {:?}", request), + )); + } + }; - Ok(()) + Ok(()) } /// A fake server for connecting to in tests. pub struct FakeServer { - listener: TcpListener, + listener: TcpListener, } impl FakeServer { - /// Creates a new fake server and binds it to a port on localhost. - pub async fn new() -> io::Result { - let listener = TcpListener::bind("localhost:0").await?; - Ok(FakeServer { listener }) - } + /// Creates a new fake server and binds it to a port on localhost. + pub async fn new() -> io::Result { + let listener = TcpListener::bind("localhost:0").await?; + Ok(FakeServer { listener }) + } - /// Returns the address to which this server is bound. - /// This is always localhost and a random port chosen by the OS. - pub fn address(&self) -> io::Result { - self.listener.local_addr() - } + /// Returns the address to which this server is bound. + /// This is always localhost and a random port chosen by the OS. + pub fn address(&self) -> io::Result { + self.listener.local_addr() + } - /// Runs the server: accepts incoming connections and responds to requests. - pub async fn run(&mut self) -> io::Result<()> { - loop { - let (socket, _peer_address) = self.listener.accept().await?; - tokio::spawn(async move { process(socket).await }); - } + /// Runs the server: accepts incoming connections and responds to requests. + pub async fn run(&mut self) -> io::Result<()> { + loop { + let (socket, _peer_address) = self.listener.accept().await?; + tokio::spawn(async move { process(socket).await }); } + } } #[cfg(test)] mod tests { - use tokio::net::TcpStream; + use tokio::net::TcpStream; - use super::FakeServer; + use super::FakeServer; - #[tokio::test] - async fn new_binds_to_localhost() { - let server = FakeServer::new().await.unwrap(); - assert!(server.address().unwrap().ip().is_loopback()); - } + #[tokio::test] + async fn new_binds_to_localhost() { + let server = FakeServer::new().await.unwrap(); + assert!(server.address().unwrap().ip().is_loopback()); + } - #[tokio::test] - async fn accepts_incoming_connections() { - let mut server = FakeServer::new().await.unwrap(); - let address = server.address().unwrap(); - tokio::spawn(async move { server.run().await.unwrap() }); + #[tokio::test] + async fn accepts_incoming_connections() { + let mut server = FakeServer::new().await.unwrap(); + let address = server.address().unwrap(); + tokio::spawn(async move { server.run().await.unwrap() }); - // The connection succeeds. - let _ = TcpStream::connect(address).await.unwrap(); - } + // The connection succeeds. + let _ = TcpStream::connect(address).await.unwrap(); + } } diff --git a/src/proto/u32.rs b/src/proto/u32.rs index bc6999c..cc51f71 100644 --- a/src/proto/u32.rs +++ b/src/proto/u32.rs @@ -8,10 +8,10 @@ pub const U32_BYTE_LEN: usize = 4; /// Returns the byte representatio of the given integer value. pub fn encode_u32(value: u32) -> [u8; U32_BYTE_LEN] { - value.to_le_bytes() + value.to_le_bytes() } /// Returns the integer value corresponding to the given bytes. pub fn decode_u32(bytes: [u8; U32_BYTE_LEN]) -> u32 { - u32::from_le_bytes(bytes) + u32::from_le_bytes(bytes) } diff --git a/src/proto/user.rs b/src/proto/user.rs index 26dabde..1a752ec 100644 --- a/src/proto/user.rs +++ b/src/proto/user.rs @@ -1,9 +1,9 @@ use std::io; use crate::proto::{ - MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode, - ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, - ValueEncoder, WriteToPacket, + MutPacket, Packet, PacketReadError, ReadFromPacket, ValueDecode, + ValueDecodeError, ValueDecoder, ValueEncode, ValueEncodeError, ValueEncoder, + WriteToPacket, }; const STATUS_OFFLINE: u32 = 1; @@ -12,103 +12,98 @@ const STATUS_ONLINE: u32 = 3; /// This enumeration is the list of possible user statuses. #[derive( - Clone, - Copy, - Debug, - Eq, - Ord, - PartialEq, - PartialOrd, - RustcDecodable, - RustcEncodable, + Clone, + Copy, + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + RustcDecodable, + RustcEncodable, )] pub enum UserStatus { - /// The user if offline. - Offline, - /// The user is connected, but AFK. - Away, - /// The user is present. - Online, + /// The user if offline. + Offline, + /// The user is connected, but AFK. + Away, + /// The user is present. + Online, } impl ReadFromPacket for UserStatus { - fn read_from_packet(packet: &mut Packet) -> Result { - let n: u32 = packet.read_value()?; - match n { - STATUS_OFFLINE => Ok(UserStatus::Offline), - STATUS_AWAY => Ok(UserStatus::Away), - STATUS_ONLINE => Ok(UserStatus::Online), - _ => Err(PacketReadError::InvalidUserStatusError(n)), - } + fn read_from_packet(packet: &mut Packet) -> Result { + let n: u32 = packet.read_value()?; + match n { + STATUS_OFFLINE => Ok(UserStatus::Offline), + STATUS_AWAY => Ok(UserStatus::Away), + STATUS_ONLINE => Ok(UserStatus::Online), + _ => Err(PacketReadError::InvalidUserStatusError(n)), } + } } impl WriteToPacket for UserStatus { - fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { - let n = match *self { - UserStatus::Offline => STATUS_OFFLINE, - UserStatus::Away => STATUS_AWAY, - UserStatus::Online => STATUS_ONLINE, - }; - packet.write_value(&n)?; - Ok(()) - } + fn write_to_packet(&self, packet: &mut MutPacket) -> io::Result<()> { + let n = match *self { + UserStatus::Offline => STATUS_OFFLINE, + UserStatus::Away => STATUS_AWAY, + UserStatus::Online => STATUS_ONLINE, + }; + packet.write_value(&n)?; + Ok(()) + } } impl ValueEncode for UserStatus { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - let value = match *self { - UserStatus::Offline => STATUS_OFFLINE, - UserStatus::Away => STATUS_AWAY, - UserStatus::Online => STATUS_ONLINE, - }; - encoder.encode_u32(value) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + let value = match *self { + UserStatus::Offline => STATUS_OFFLINE, + UserStatus::Away => STATUS_AWAY, + UserStatus::Online => STATUS_ONLINE, + }; + encoder.encode_u32(value) + } } impl ValueDecode for UserStatus { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let position = decoder.position(); - let value: u32 = decoder.decode()?; - match value { - STATUS_OFFLINE => Ok(UserStatus::Offline), - STATUS_AWAY => Ok(UserStatus::Away), - STATUS_ONLINE => Ok(UserStatus::Online), - _ => Err(ValueDecodeError::InvalidData { - value_name: "user status".to_string(), - cause: format!("unknown value {}", value), - position: position, - }), - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let position = decoder.position(); + let value: u32 = decoder.decode()?; + match value { + STATUS_OFFLINE => Ok(UserStatus::Offline), + STATUS_AWAY => Ok(UserStatus::Away), + STATUS_ONLINE => Ok(UserStatus::Online), + _ => Err(ValueDecodeError::InvalidData { + value_name: "user status".to_string(), + cause: format!("unknown value {}", value), + position: position, + }), } + } } /// This structure contains the last known information about a fellow user. #[derive( - Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable, + Clone, Debug, Eq, Ord, PartialEq, PartialOrd, RustcDecodable, RustcEncodable, )] pub struct User { - /// The name of the user. - pub name: String, - /// The last known status of the user. - pub status: UserStatus, - /// The average upload speed of the user. - pub average_speed: usize, - /// ??? Nicotine calls it downloadnum. - pub num_downloads: usize, - /// ??? Unknown field. - pub unknown: usize, - /// The number of files this user shares. - pub num_files: usize, - /// The number of folders this user shares. - pub num_folders: usize, - /// The number of free download slots of this user. - pub num_free_slots: usize, - /// The user's country code. - pub country: String, + /// The name of the user. + pub name: String, + /// The last known status of the user. + pub status: UserStatus, + /// The average upload speed of the user. + pub average_speed: usize, + /// ??? Nicotine calls it downloadnum. + pub num_downloads: usize, + /// ??? Unknown field. + pub unknown: usize, + /// The number of files this user shares. + pub num_files: usize, + /// The number of folders this user shares. + pub num_folders: usize, + /// The number of free download slots of this user. + pub num_free_slots: usize, + /// The user's country code. + pub country: String, } diff --git a/src/proto/value_codec.rs b/src/proto/value_codec.rs index 942314c..41ffdd7 100644 --- a/src/proto/value_codec.rs +++ b/src/proto/value_codec.rs @@ -27,432 +27,398 @@ use super::prefix::Prefixer; use super::u32::{decode_u32, encode_u32, U32_BYTE_LEN}; pub trait Decode { - /// Attempts to decode an instance of `T` from `self`. - fn decode(&mut self) -> io::Result; + /// Attempts to decode an instance of `T` from `self`. + fn decode(&mut self) -> io::Result; } pub trait Encode { - /// Attempts to encode `value` into `self`. - fn encode(&mut self, value: T) -> io::Result<()>; + /// Attempts to encode `value` into `self`. + fn encode(&mut self, value: T) -> io::Result<()>; } // TODO: Add backtrace fields to each enum variant once std::backtrace is // stabilized. #[derive(PartialEq, Error, Debug)] pub enum ValueDecodeError { - #[error("at position {position}: not enough bytes to decode: expected {expected}, found {remaining}")] - NotEnoughData { - /// The number of bytes the decoder expected to read. - /// - /// Invariant: `remaining < expected`. - expected: usize, - - /// The number of bytes remaining in the input buffer. - /// - /// Invariant: `remaining < expected`. - remaining: usize, - - /// The decoder's position in the input buffer. - position: usize, - }, - #[error("at position {position}: invalid boolean value: {value}")] - InvalidBool { - /// The invalid value. Never equal to 0 nor 1. - value: u8, - - /// The decoder's position in the input buffer. - position: usize, - }, - #[error("at position {position}: invalid u16 value: {value}")] - InvalidU16 { - /// The invalid value. Always greater than u16::max_value(). - value: u32, - - /// The decoder's position in the input buffer. - position: usize, - }, - #[error("at position {position}: failed to decode string: {cause}")] - InvalidString { - /// The cause of the error, as reported by the encoding library. - cause: String, - - /// The decoder's position in the input buffer. - position: usize, - }, - #[error("at position {position}: invalid {value_name}: {cause}")] - InvalidData { - /// The name of the value the decoder failed to decode. - value_name: String, - - /// The cause of the error. - cause: String, - - /// The decoder's position in the input buffer. - position: usize, - }, + #[error("at position {position}: not enough bytes to decode: expected {expected}, found {remaining}")] + NotEnoughData { + /// The number of bytes the decoder expected to read. + /// + /// Invariant: `remaining < expected`. + expected: usize, + + /// The number of bytes remaining in the input buffer. + /// + /// Invariant: `remaining < expected`. + remaining: usize, + + /// The decoder's position in the input buffer. + position: usize, + }, + #[error("at position {position}: invalid boolean value: {value}")] + InvalidBool { + /// The invalid value. Never equal to 0 nor 1. + value: u8, + + /// The decoder's position in the input buffer. + position: usize, + }, + #[error("at position {position}: invalid u16 value: {value}")] + InvalidU16 { + /// The invalid value. Always greater than u16::max_value(). + value: u32, + + /// The decoder's position in the input buffer. + position: usize, + }, + #[error("at position {position}: failed to decode string: {cause}")] + InvalidString { + /// The cause of the error, as reported by the encoding library. + cause: String, + + /// The decoder's position in the input buffer. + position: usize, + }, + #[error("at position {position}: invalid {value_name}: {cause}")] + InvalidData { + /// The name of the value the decoder failed to decode. + value_name: String, + + /// The cause of the error. + cause: String, + + /// The decoder's position in the input buffer. + position: usize, + }, } impl From for io::Error { - fn from(error: ValueDecodeError) -> Self { - let kind = match &error { - &ValueDecodeError::NotEnoughData { .. } => { - io::ErrorKind::UnexpectedEof - } - _ => io::ErrorKind::InvalidData, - }; - let message = format!("{}", &error); - io::Error::new(kind, message) - } + fn from(error: ValueDecodeError) -> Self { + let kind = match &error { + &ValueDecodeError::NotEnoughData { .. } => io::ErrorKind::UnexpectedEof, + _ => io::ErrorKind::InvalidData, + }; + let message = format!("{}", &error); + io::Error::new(kind, message) + } } /// A type for decoding various types of values from protocol messages. pub struct ValueDecoder<'a> { - // The buffer we are decoding from. - // - // Invariant: `position <= buffer.len()`. - buffer: &'a [u8], - - // Our current position within `buffer`. - // - // We could instead maintain this implicitly in `buffer` by splitting off - // decoded bytes from the start of the buffer, but we would then be unable - // to remember how many bytes we had decoded. This information is useful to - // have in error messages when encountering decoding errors. - // - // Invariant: `position <= buffer.len()`. - position: usize, + // The buffer we are decoding from. + // + // Invariant: `position <= buffer.len()`. + buffer: &'a [u8], + + // Our current position within `buffer`. + // + // We could instead maintain this implicitly in `buffer` by splitting off + // decoded bytes from the start of the buffer, but we would then be unable + // to remember how many bytes we had decoded. This information is useful to + // have in error messages when encountering decoding errors. + // + // Invariant: `position <= buffer.len()`. + position: usize, } /// This trait is implemented by types that can be decoded from messages using /// a `ValueDecoder`. pub trait ValueDecode: Sized { - /// Attempts to decode a value of this type with the given decoder. - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result; + /// Attempts to decode a value of this type with the given decoder. + fn decode_from(decoder: &mut ValueDecoder) -> Result; } impl<'a> ValueDecoder<'a> { - /// Wraps the given byte buffer. - pub fn new(buffer: &'a [u8]) -> Self { - Self { - buffer: buffer, - position: 0, - } - } - - /// The current position of this decoder in the input buffer. - pub fn position(&self) -> usize { - self.position - } - - /// Returns the number of bytes remaining to decode. - pub fn remaining(&self) -> usize { - self.buffer.len() - self.position - } - - /// Returns whether the underlying buffer has remaining bytes to decode. - /// - /// Shorthand for `remaining() > 0`. - pub fn has_remaining(&self) -> bool { - self.remaining() > 0 - } - - /// Returns a read-only view of the remaining bytes to decode. - /// - /// The returned slice is of size `remaining()`. - pub fn bytes(&self) -> &[u8] { - &self.buffer[self.position..] - } - - /// Attempts to consume the next `n` bytes from this buffer. - /// - /// Returns a slice of size `n` if successful, in which case this decoder - /// advances its internal position by `n`. - fn consume(&mut self, n: usize) -> Result<&[u8], ValueDecodeError> { - if self.remaining() < n { - return Err(ValueDecodeError::NotEnoughData { - expected: n, - remaining: self.remaining(), - position: self.position, - }); - } - - // Cannot use bytes() here as it borrows self immutably, which - // prevents us from mutating self.position afterwards. - let end = self.position + n; - let bytes = &self.buffer[self.position..end]; - self.position = end; - Ok(bytes) - } - - /// Attempts to decode a u32 value. - fn decode_u32(&mut self) -> Result { - let bytes = self.consume(U32_BYTE_LEN)?; - // The conversion from slice to fixed-size array cannot fail, because - // consume() guarantees that its return value is of size n. - let array: [u8; U32_BYTE_LEN] = bytes.try_into().unwrap(); - Ok(decode_u32(array)) - } - - /// Attempts to decode a u16 value. - fn decode_u16(&mut self) -> Result { - let position = self.position; - let n = self.decode_u32()?; - match u16::try_from(n) { - Ok(value) => Ok(value), - Err(_) => Err(ValueDecodeError::InvalidU16 { - value: n, - position: position, - }), - } - } - - /// Attempts to decode a boolean value. - fn decode_bool(&mut self) -> Result { - let position = self.position; - let bytes = self.consume(1)?; - match bytes[0] { - 0 => Ok(false), - 1 => Ok(true), - n => Err(ValueDecodeError::InvalidBool { - value: n, - position: position, - }), - } - } - - /// Attempts to decode a string value. - fn decode_string(&mut self) -> Result { - let length = self.decode_u32()? as usize; - - let position = self.position; - let bytes = self.consume(length)?; - - let result = WINDOWS_1252 - .decode_without_bom_handling_and_without_replacement(bytes); - match result { - Some(string) => Ok(string.into_owned()), - None => Err(ValueDecodeError::InvalidString { - cause: "malformed sequence in Windows-1252-encoded string" - .to_string(), - position: position, - }), - } - } - - /// Attempts to decode a value of the given type. - /// - /// Allows easy decoding of complex values using type inference: - /// - /// ``` - /// let val: Foo = decoder.decode()?; - /// ``` - pub fn decode(&mut self) -> Result { - T::decode_from(self) - } + /// Wraps the given byte buffer. + pub fn new(buffer: &'a [u8]) -> Self { + Self { + buffer: buffer, + position: 0, + } + } + + /// The current position of this decoder in the input buffer. + pub fn position(&self) -> usize { + self.position + } + + /// Returns the number of bytes remaining to decode. + pub fn remaining(&self) -> usize { + self.buffer.len() - self.position + } + + /// Returns whether the underlying buffer has remaining bytes to decode. + /// + /// Shorthand for `remaining() > 0`. + pub fn has_remaining(&self) -> bool { + self.remaining() > 0 + } + + /// Returns a read-only view of the remaining bytes to decode. + /// + /// The returned slice is of size `remaining()`. + pub fn bytes(&self) -> &[u8] { + &self.buffer[self.position..] + } + + /// Attempts to consume the next `n` bytes from this buffer. + /// + /// Returns a slice of size `n` if successful, in which case this decoder + /// advances its internal position by `n`. + fn consume(&mut self, n: usize) -> Result<&[u8], ValueDecodeError> { + if self.remaining() < n { + return Err(ValueDecodeError::NotEnoughData { + expected: n, + remaining: self.remaining(), + position: self.position, + }); + } + + // Cannot use bytes() here as it borrows self immutably, which + // prevents us from mutating self.position afterwards. + let end = self.position + n; + let bytes = &self.buffer[self.position..end]; + self.position = end; + Ok(bytes) + } + + /// Attempts to decode a u32 value. + fn decode_u32(&mut self) -> Result { + let bytes = self.consume(U32_BYTE_LEN)?; + // The conversion from slice to fixed-size array cannot fail, because + // consume() guarantees that its return value is of size n. + let array: [u8; U32_BYTE_LEN] = bytes.try_into().unwrap(); + Ok(decode_u32(array)) + } + + /// Attempts to decode a u16 value. + fn decode_u16(&mut self) -> Result { + let position = self.position; + let n = self.decode_u32()?; + match u16::try_from(n) { + Ok(value) => Ok(value), + Err(_) => Err(ValueDecodeError::InvalidU16 { + value: n, + position: position, + }), + } + } + + /// Attempts to decode a boolean value. + fn decode_bool(&mut self) -> Result { + let position = self.position; + let bytes = self.consume(1)?; + match bytes[0] { + 0 => Ok(false), + 1 => Ok(true), + n => Err(ValueDecodeError::InvalidBool { + value: n, + position: position, + }), + } + } + + /// Attempts to decode a string value. + fn decode_string(&mut self) -> Result { + let length = self.decode_u32()? as usize; + + let position = self.position; + let bytes = self.consume(length)?; + + let result = + WINDOWS_1252.decode_without_bom_handling_and_without_replacement(bytes); + match result { + Some(string) => Ok(string.into_owned()), + None => Err(ValueDecodeError::InvalidString { + cause: "malformed sequence in Windows-1252-encoded string".to_string(), + position: position, + }), + } + } + + /// Attempts to decode a value of the given type. + /// + /// Allows easy decoding of complex values using type inference: + /// + /// ``` + /// let val: Foo = decoder.decode()?; + /// ``` + pub fn decode(&mut self) -> Result { + T::decode_from(self) + } } impl ValueDecode for u32 { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - decoder.decode_u32() - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + decoder.decode_u32() + } } impl ValueDecode for u16 { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - decoder.decode_u16() - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + decoder.decode_u16() + } } impl ValueDecode for bool { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - decoder.decode_bool() - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + decoder.decode_bool() + } } impl ValueDecode for net::Ipv4Addr { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let ip = decoder.decode_u32()?; - Ok(net::Ipv4Addr::from(ip)) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let ip = decoder.decode_u32()?; + Ok(net::Ipv4Addr::from(ip)) + } } impl ValueDecode for String { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - decoder.decode_string() - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + decoder.decode_string() + } } impl ValueDecode for (T, U) { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let first = decoder.decode()?; - let second = decoder.decode()?; - Ok((first, second)) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let first = decoder.decode()?; + let second = decoder.decode()?; + Ok((first, second)) + } } impl ValueDecode for Vec { - fn decode_from( - decoder: &mut ValueDecoder, - ) -> Result { - let len = decoder.decode_u32()? as usize; - let mut vec = Vec::with_capacity(len); - for _ in 0..len { - let val = decoder.decode()?; - vec.push(val); - } - Ok(vec) - } + fn decode_from(decoder: &mut ValueDecoder) -> Result { + let len = decoder.decode_u32()? as usize; + let mut vec = Vec::with_capacity(len); + for _ in 0..len { + let val = decoder.decode()?; + vec.push(val); + } + Ok(vec) + } } #[derive(Debug, Error, PartialEq)] pub enum ValueEncodeError { - #[error("encoded string length {length} is too large: {string:?}")] - StringTooLong { - /// The string that is too long to encode. - string: String, - - /// The length of `string` in the Windows-1252 encoding. - /// Always larger than `u32::max_value()`. - length: usize, - }, + #[error("encoded string length {length} is too large: {string:?}")] + StringTooLong { + /// The string that is too long to encode. + string: String, + + /// The length of `string` in the Windows-1252 encoding. + /// Always larger than `u32::max_value()`. + length: usize, + }, } impl From for io::Error { - fn from(error: ValueEncodeError) -> Self { - io::Error::new(io::ErrorKind::InvalidData, format!("{}", error)) - } + fn from(error: ValueEncodeError) -> Self { + io::Error::new(io::ErrorKind::InvalidData, format!("{}", error)) + } } /// A type for encoding various types of values into protocol messages. pub struct ValueEncoder<'a> { - /// The buffer to which the encoder appends encoded bytes. - buffer: &'a mut BytesMut, + /// The buffer to which the encoder appends encoded bytes. + buffer: &'a mut BytesMut, } /// This trait is implemented by types that can be encoded into messages using /// a `ValueEncoder`. pub trait ValueEncode { - // TODO: Rename to encode_to(). - /// Attempts to encode `self` with the given encoder. - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError>; + // TODO: Rename to encode_to(). + /// Attempts to encode `self` with the given encoder. + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError>; } impl<'a> ValueEncoder<'a> { - /// Wraps the given buffer for encoding values into. - /// - /// Encoded bytes are appended. The buffer is not pre-cleared. - pub fn new(buffer: &'a mut BytesMut) -> Self { - ValueEncoder { buffer: buffer } - } - - /// Encodes the given u32 value into the underlying buffer. - pub fn encode_u32(&mut self, val: u32) -> Result<(), ValueEncodeError> { - self.buffer.extend_from_slice(&encode_u32(val)); - Ok(()) - } - - /// Encodes the given u16 value into the underlying buffer. - pub fn encode_u16(&mut self, val: u16) -> Result<(), ValueEncodeError> { - self.encode_u32(val as u32) - } - - /// Encodes the given boolean value into the underlying buffer. - pub fn encode_bool(&mut self, val: bool) -> Result<(), ValueEncodeError> { - self.buffer.put_u8(val as u8); - Ok(()) - } - - /// Encodes the given string into the underlying buffer. - pub fn encode_string(&mut self, val: &str) -> Result<(), ValueEncodeError> { - // Reserve space for the length prefix. - let mut prefixer = Prefixer::new(self.buffer); - - // Encode the string. This cannot fail because any unmappable characters - // are replaced. - let (bytes, encoding, _did_replace) = WINDOWS_1252.encode(val); - - // Encodings in full generality can have a different "output encoding" - // but that is not the case of Windows-1252. - assert_eq!(encoding, WINDOWS_1252); - - prefixer.suffix_mut().extend_from_slice(&bytes); - - // Write the length prefix in the space we initially reserved for it. - if let Err(prefixer) = prefixer.finalize() { - return Err(ValueEncodeError::StringTooLong { - string: val.to_string(), - length: prefixer.suffix().len(), - }); - } - - Ok(()) - } - - /// Encodes the given value into the underlying buffer. - /// - /// Allows for easy encoding with type inference: - /// ``` - /// encoder.encode(&Foo::new(bar))?; - /// ``` - pub fn encode( - &mut self, - val: &T, - ) -> Result<(), ValueEncodeError> { - val.encode(self) - } + /// Wraps the given buffer for encoding values into. + /// + /// Encoded bytes are appended. The buffer is not pre-cleared. + pub fn new(buffer: &'a mut BytesMut) -> Self { + ValueEncoder { buffer: buffer } + } + + /// Encodes the given u32 value into the underlying buffer. + pub fn encode_u32(&mut self, val: u32) -> Result<(), ValueEncodeError> { + self.buffer.extend_from_slice(&encode_u32(val)); + Ok(()) + } + + /// Encodes the given u16 value into the underlying buffer. + pub fn encode_u16(&mut self, val: u16) -> Result<(), ValueEncodeError> { + self.encode_u32(val as u32) + } + + /// Encodes the given boolean value into the underlying buffer. + pub fn encode_bool(&mut self, val: bool) -> Result<(), ValueEncodeError> { + self.buffer.put_u8(val as u8); + Ok(()) + } + + /// Encodes the given string into the underlying buffer. + pub fn encode_string(&mut self, val: &str) -> Result<(), ValueEncodeError> { + // Reserve space for the length prefix. + let mut prefixer = Prefixer::new(self.buffer); + + // Encode the string. This cannot fail because any unmappable characters + // are replaced. + let (bytes, encoding, _did_replace) = WINDOWS_1252.encode(val); + + // Encodings in full generality can have a different "output encoding" + // but that is not the case of Windows-1252. + assert_eq!(encoding, WINDOWS_1252); + + prefixer.suffix_mut().extend_from_slice(&bytes); + + // Write the length prefix in the space we initially reserved for it. + if let Err(prefixer) = prefixer.finalize() { + return Err(ValueEncodeError::StringTooLong { + string: val.to_string(), + length: prefixer.suffix().len(), + }); + } + + Ok(()) + } + + /// Encodes the given value into the underlying buffer. + /// + /// Allows for easy encoding with type inference: + /// ``` + /// encoder.encode(&Foo::new(bar))?; + /// ``` + pub fn encode( + &mut self, + val: &T, + ) -> Result<(), ValueEncodeError> { + val.encode(self) + } } impl ValueEncode for u32 { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(*self) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(*self) + } } impl ValueEncode for u16 { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u16(*self) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u16(*self) + } } impl ValueEncode for bool { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_bool(*self) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_bool(*self) + } } impl ValueEncode for net::Ipv4Addr { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(u32::from(*self)) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(u32::from(*self)) + } } // It would be nice to use AsRef, or Deref for the following @@ -464,63 +430,45 @@ impl ValueEncode for net::Ipv4Addr { // Value{De,En}code) but it is not really worth the hassle. impl ValueEncode for str { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(self) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(self) + } } impl ValueEncode for String { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(self) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(self) + } } impl<'a> ValueEncode for &'a String { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_string(*self) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_string(*self) + } } impl ValueEncode for (T, U) { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - self.0.encode(encoder)?; - self.1.encode(encoder) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + self.0.encode(encoder)?; + self.1.encode(encoder) + } } impl ValueEncode for [T] { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - encoder.encode_u32(self.len() as u32)?; - for ref item in self { - item.encode(encoder)?; - } - Ok(()) + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + encoder.encode_u32(self.len() as u32)?; + for ref item in self { + item.encode(encoder)?; } + Ok(()) + } } impl ValueEncode for Vec { - fn encode( - &self, - encoder: &mut ValueEncoder, - ) -> Result<(), ValueEncodeError> { - let slice: &[T] = &*self; - slice.encode(encoder) - } + fn encode(&self, encoder: &mut ValueEncoder) -> Result<(), ValueEncodeError> { + let slice: &[T] = &*self; + slice.encode(encoder) + } } /*=======* @@ -529,439 +477,433 @@ impl ValueEncode for Vec { #[cfg(test)] pub mod tests { - use std::fmt; - use std::net; - use std::u16; - use std::u32; + use std::fmt; + use std::net; + use std::u16; + use std::u32; + + use bytes::{BufMut, BytesMut}; + + use super::{ + ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncoder, + }; - use bytes::{BufMut, BytesMut}; + // Declared here because assert_eq!(bytes, &[]) fails to infer types. + const EMPTY_BYTES: &'static [u8] = &[]; - use super::{ - ValueDecode, ValueDecodeError, ValueDecoder, ValueEncode, ValueEncoder, - }; + pub fn roundtrip(input: T) + where + T: fmt::Debug + Eq + PartialEq + ValueEncode + ValueDecode, + { + let mut bytes = BytesMut::new(); - // Declared here because assert_eq!(bytes, &[]) fails to infer types. - const EMPTY_BYTES: &'static [u8] = &[]; + ValueEncoder::new(&mut bytes).encode(&input).unwrap(); + let output = ValueDecoder::new(&bytes).decode::().unwrap(); - pub fn roundtrip(input: T) - where - T: fmt::Debug + Eq + PartialEq + ValueEncode + ValueDecode, - { - let mut bytes = BytesMut::new(); + assert_eq!(output, input); + } - ValueEncoder::new(&mut bytes).encode(&input).unwrap(); - let output = ValueDecoder::new(&bytes).decode::().unwrap(); + // A few integers and their corresponding byte encodings. + const U32_ENCODINGS: [(u32, [u8; 4]); 8] = [ + (0, [0, 0, 0, 0]), + (255, [255, 0, 0, 0]), + (256, [0, 1, 0, 0]), + (65535, [255, 255, 0, 0]), + (65536, [0, 0, 1, 0]), + (16777215, [255, 255, 255, 0]), + (16777216, [0, 0, 0, 1]), + (u32::MAX, [255, 255, 255, 255]), + ]; - assert_eq!(output, input); - } + #[test] + fn encode_u32() { + for &(val, ref encoded_bytes) in &U32_ENCODINGS { + let mut bytes = BytesMut::new(); + bytes.put_u8(13); - // A few integers and their corresponding byte encodings. - const U32_ENCODINGS: [(u32, [u8; 4]); 8] = [ - (0, [0, 0, 0, 0]), - (255, [255, 0, 0, 0]), - (256, [0, 1, 0, 0]), - (65535, [255, 255, 0, 0]), - (65536, [0, 0, 1, 0]), - (16777215, [255, 255, 255, 0]), - (16777216, [0, 0, 0, 1]), - (u32::MAX, [255, 255, 255, 255]), - ]; - - #[test] - fn encode_u32() { - for &(val, ref encoded_bytes) in &U32_ENCODINGS { - let mut bytes = BytesMut::new(); - bytes.put_u8(13); - - let mut expected_bytes = BytesMut::new(); - expected_bytes.put_u8(13); - expected_bytes.extend_from_slice(encoded_bytes); - - ValueEncoder::new(&mut bytes).encode_u32(val).unwrap(); - assert_eq!(bytes, expected_bytes); - } + let mut expected_bytes = BytesMut::new(); + expected_bytes.put_u8(13); + expected_bytes.extend_from_slice(encoded_bytes); + + ValueEncoder::new(&mut bytes).encode_u32(val).unwrap(); + assert_eq!(bytes, expected_bytes); } + } + + #[test] + fn decode_u32() { + for &(expected_val, ref bytes) in &U32_ENCODINGS { + let buffer = bytes.to_vec(); + let mut decoder = ValueDecoder::new(&buffer); - #[test] - fn decode_u32() { - for &(expected_val, ref bytes) in &U32_ENCODINGS { - let buffer = bytes.to_vec(); - let mut decoder = ValueDecoder::new(&buffer); + let val = decoder.decode::().unwrap(); + + assert_eq!(val, expected_val); + assert_eq!(decoder.bytes(), EMPTY_BYTES); + } + } + + #[test] + fn roundtrip_u32() { + for &(val, _) in &U32_ENCODINGS { + roundtrip(val) + } + } + + #[test] + fn decode_u32_unexpected_eof() { + let mut buffer = BytesMut::new(); + buffer.put_u8(13); + + let mut decoder = ValueDecoder::new(&buffer); + + let result = decoder.decode::(); + + assert_eq!( + result, + Err(ValueDecodeError::NotEnoughData { + expected: 4, + remaining: 1, + position: 0, + }) + ); + assert_eq!(decoder.bytes(), &[13]); + } + + #[test] + fn encode_bool_false() { + let mut bytes = BytesMut::new(); + bytes.put_u8(13); + + ValueEncoder::new(&mut bytes).encode_bool(false).unwrap(); + + assert_eq!(bytes, vec![13, 0]); + } + + #[test] + fn encode_bool_true() { + let mut bytes = BytesMut::new(); + bytes.put_u8(13); + + ValueEncoder::new(&mut bytes).encode_bool(true).unwrap(); + + assert_eq!(bytes, vec![13, 1]); + } + + #[test] + fn decode_bool_false() { + let buffer = vec![0]; + let mut decoder = ValueDecoder::new(&buffer); + + let val = decoder.decode::().unwrap(); + + assert!(!val); + assert_eq!(decoder.bytes(), EMPTY_BYTES); + } + + #[test] + fn decode_bool_true() { + let buffer = vec![1]; + let mut decoder = ValueDecoder::new(&buffer); + + let val = decoder.decode::().unwrap(); + + assert!(val); + assert_eq!(decoder.bytes(), EMPTY_BYTES); + } - let val = decoder.decode::().unwrap(); + #[test] + fn decode_bool_invalid() { + let buffer = vec![42]; - assert_eq!(val, expected_val); - assert_eq!(decoder.bytes(), EMPTY_BYTES); - } - } + let result = ValueDecoder::new(&buffer).decode::(); - #[test] - fn roundtrip_u32() { - for &(val, _) in &U32_ENCODINGS { - roundtrip(val) - } - } - - #[test] - fn decode_u32_unexpected_eof() { - let mut buffer = BytesMut::new(); - buffer.put_u8(13); - - let mut decoder = ValueDecoder::new(&buffer); - - let result = decoder.decode::(); - - assert_eq!( - result, - Err(ValueDecodeError::NotEnoughData { - expected: 4, - remaining: 1, - position: 0, - }) - ); - assert_eq!(decoder.bytes(), &[13]); - } + assert_eq!( + result, + Err(ValueDecodeError::InvalidBool { + value: 42, + position: 0, + }) + ); + } - #[test] - fn encode_bool_false() { - let mut bytes = BytesMut::new(); - bytes.put_u8(13); + #[test] + fn decode_bool_unexpected_eof() { + let buffer = vec![]; - ValueEncoder::new(&mut bytes).encode_bool(false).unwrap(); + let result = ValueDecoder::new(&buffer).decode::(); - assert_eq!(bytes, vec![13, 0]); - } - - #[test] - fn encode_bool_true() { - let mut bytes = BytesMut::new(); - bytes.put_u8(13); + assert_eq!( + result, + Err(ValueDecodeError::NotEnoughData { + expected: 1, + remaining: 0, + position: 0, + }) + ); + } - ValueEncoder::new(&mut bytes).encode_bool(true).unwrap(); - - assert_eq!(bytes, vec![13, 1]); + #[test] + fn roundtrip_bool() { + roundtrip(false); + roundtrip(true); + } + + #[test] + fn encode_u16() { + for &(val, ref encoded_bytes) in &U32_ENCODINGS { + if val > u16::MAX as u32 { + continue; + } + + let mut bytes = BytesMut::new(); + bytes.put_u8(13); + + let mut expected_bytes = BytesMut::new(); + expected_bytes.put_u8(13); + expected_bytes.extend(encoded_bytes); + + ValueEncoder::new(&mut bytes).encode(&(val as u16)).unwrap(); + + assert_eq!(bytes, expected_bytes); } + } - #[test] - fn decode_bool_false() { - let buffer = vec![0]; - let mut decoder = ValueDecoder::new(&buffer); - - let val = decoder.decode::().unwrap(); + #[test] + fn decode_u16() { + for &(expected_val, ref buffer) in &U32_ENCODINGS { + let mut decoder = ValueDecoder::new(buffer); - assert!(!val); + if expected_val <= u16::MAX as u32 { + let val = decoder.decode::().unwrap(); + assert_eq!(val, expected_val as u16); assert_eq!(decoder.bytes(), EMPTY_BYTES); - } - - #[test] - fn decode_bool_true() { - let buffer = vec![1]; - let mut decoder = ValueDecoder::new(&buffer); - - let val = decoder.decode::().unwrap(); - - assert!(val); - assert_eq!(decoder.bytes(), EMPTY_BYTES); - } - - #[test] - fn decode_bool_invalid() { - let buffer = vec![42]; - - let result = ValueDecoder::new(&buffer).decode::(); - + } else { assert_eq!( - result, - Err(ValueDecodeError::InvalidBool { - value: 42, - position: 0, - }) + decoder.decode::(), + Err(ValueDecodeError::InvalidU16 { + value: expected_val, + position: 0, + }) ); + } } + } - #[test] - fn decode_bool_unexpected_eof() { - let buffer = vec![]; + #[test] + fn decode_u16_unexpected_eof() { + let buffer = vec![]; + let mut decoder = ValueDecoder::new(&buffer); - let result = ValueDecoder::new(&buffer).decode::(); + let result = decoder.decode::(); - assert_eq!( - result, - Err(ValueDecodeError::NotEnoughData { - expected: 1, - remaining: 0, - position: 0, - }) - ); - } + assert_eq!( + result, + Err(ValueDecodeError::NotEnoughData { + expected: 4, + remaining: 0, + position: 0, + }) + ); + } - #[test] - fn roundtrip_bool() { - roundtrip(false); - roundtrip(true); + #[test] + fn roundtrip_u16() { + for &(val, _) in &U32_ENCODINGS { + if val <= u16::MAX as u32 { + roundtrip(val) + } } + } - #[test] - fn encode_u16() { - for &(val, ref encoded_bytes) in &U32_ENCODINGS { - if val > u16::MAX as u32 { - continue; - } - - let mut bytes = BytesMut::new(); - bytes.put_u8(13); + #[test] + fn encode_ipv4() { + for &(val, ref encoded_bytes) in &U32_ENCODINGS { + let mut bytes = BytesMut::new(); + bytes.put_u8(13); - let mut expected_bytes = BytesMut::new(); - expected_bytes.put_u8(13); - expected_bytes.extend(encoded_bytes); + let mut expected_bytes = BytesMut::new(); + expected_bytes.put_u8(13); + expected_bytes.extend(encoded_bytes); - ValueEncoder::new(&mut bytes).encode(&(val as u16)).unwrap(); + let addr = net::Ipv4Addr::from(val); - assert_eq!(bytes, expected_bytes); - } - } + ValueEncoder::new(&mut bytes).encode(&addr).unwrap(); - #[test] - fn decode_u16() { - for &(expected_val, ref buffer) in &U32_ENCODINGS { - let mut decoder = ValueDecoder::new(buffer); - - if expected_val <= u16::MAX as u32 { - let val = decoder.decode::().unwrap(); - assert_eq!(val, expected_val as u16); - assert_eq!(decoder.bytes(), EMPTY_BYTES); - } else { - assert_eq!( - decoder.decode::(), - Err(ValueDecodeError::InvalidU16 { - value: expected_val, - position: 0, - }) - ); - } - } + assert_eq!(bytes, expected_bytes); } + } - #[test] - fn decode_u16_unexpected_eof() { - let buffer = vec![]; - let mut decoder = ValueDecoder::new(&buffer); + #[test] + fn decode_ipv4() { + for &(expected_val, ref buffer) in &U32_ENCODINGS { + let mut decoder = ValueDecoder::new(buffer); - let result = decoder.decode::(); + let val = decoder.decode::().unwrap(); - assert_eq!( - result, - Err(ValueDecodeError::NotEnoughData { - expected: 4, - remaining: 0, - position: 0, - }) - ); + assert_eq!(val, net::Ipv4Addr::from(expected_val)); + assert_eq!(decoder.bytes(), EMPTY_BYTES); } + } - #[test] - fn roundtrip_u16() { - for &(val, _) in &U32_ENCODINGS { - if val <= u16::MAX as u32 { - roundtrip(val) - } - } + #[test] + fn roundtrip_ipv4() { + for &(val, _) in &U32_ENCODINGS { + roundtrip(net::Ipv4Addr::from(val)) } + } - #[test] - fn encode_ipv4() { - for &(val, ref encoded_bytes) in &U32_ENCODINGS { - let mut bytes = BytesMut::new(); - bytes.put_u8(13); + // A few strings and their corresponding encodings. + const STRING_ENCODINGS: [(&'static str, &'static [u8]); 4] = [ + ("", &[0, 0, 0, 0]), + ("hey!", &[4, 0, 0, 0, 104, 101, 121, 33]), + // Windows 1252 specific codepoints. + ("‘’“”€", &[5, 0, 0, 0, 145, 146, 147, 148, 128]), + // Undefined codepoints. They are not decoded to representable + // characters, but they do not generate errors either. In particular, + // they survive round-trips through the codec. + ( + "\u{81}\u{8D}\u{8F}\u{90}\u{9D}", + &[5, 0, 0, 0, 0x81, 0x8D, 0x8F, 0x90, 0x9D], + ), + ]; - let mut expected_bytes = BytesMut::new(); - expected_bytes.put_u8(13); - expected_bytes.extend(encoded_bytes); + #[test] + fn encode_string() { + for &(string, encoded_bytes) in &STRING_ENCODINGS { + let mut bytes = BytesMut::new(); + bytes.put_u8(13); - let addr = net::Ipv4Addr::from(val); + let mut expected_bytes = BytesMut::new(); + expected_bytes.put_u8(13); + expected_bytes.extend(encoded_bytes); - ValueEncoder::new(&mut bytes).encode(&addr).unwrap(); + ValueEncoder::new(&mut bytes).encode_string(string).unwrap(); - assert_eq!(bytes, expected_bytes); - } + assert_eq!(bytes, expected_bytes); } + } - #[test] - fn decode_ipv4() { - for &(expected_val, ref buffer) in &U32_ENCODINGS { - let mut decoder = ValueDecoder::new(buffer); - - let val = decoder.decode::().unwrap(); - - assert_eq!(val, net::Ipv4Addr::from(expected_val)); - assert_eq!(decoder.bytes(), EMPTY_BYTES); - } - } + #[test] + fn encode_string_with_unencodable_characters() { + let mut bytes = BytesMut::new(); - #[test] - fn roundtrip_ipv4() { - for &(val, _) in &U32_ENCODINGS { - roundtrip(net::Ipv4Addr::from(val)) - } - } + ValueEncoder::new(&mut bytes).encode_string("你好").unwrap(); - // A few strings and their corresponding encodings. - const STRING_ENCODINGS: [(&'static str, &'static [u8]); 4] = [ - ("", &[0, 0, 0, 0]), - ("hey!", &[4, 0, 0, 0, 104, 101, 121, 33]), - // Windows 1252 specific codepoints. - ("‘’“”€", &[5, 0, 0, 0, 145, 146, 147, 148, 128]), - // Undefined codepoints. They are not decoded to representable - // characters, but they do not generate errors either. In particular, - // they survive round-trips through the codec. - ( - "\u{81}\u{8D}\u{8F}\u{90}\u{9D}", - &[5, 0, 0, 0, 0x81, 0x8D, 0x8F, 0x90, 0x9D], - ), - ]; - - #[test] - fn encode_string() { - for &(string, encoded_bytes) in &STRING_ENCODINGS { - let mut bytes = BytesMut::new(); - bytes.put_u8(13); - - let mut expected_bytes = BytesMut::new(); - expected_bytes.put_u8(13); - expected_bytes.extend(encoded_bytes); - - ValueEncoder::new(&mut bytes).encode_string(string).unwrap(); - - assert_eq!(bytes, expected_bytes); - } - } + // Characters not in the Windows 1252 codepage are replaced with their + // decimal representations. Thus the output is longer than the input. + assert_eq!(bytes[0..4], [16, 0, 0, 0]); + assert_eq!(&bytes[4..], b"你好"); - #[test] - fn encode_string_with_unencodable_characters() { - let mut bytes = BytesMut::new(); + // The replaced characters are not decoded back to their original values. + assert_eq!( + ValueDecoder::new(&bytes).decode_string().unwrap(), + "你好" + ); + } - ValueEncoder::new(&mut bytes).encode_string("你好").unwrap(); + #[test] + fn decode_string() { + for &(expected_string, buffer) in &STRING_ENCODINGS { + let mut decoder = ValueDecoder::new(&buffer); - // Characters not in the Windows 1252 codepage are replaced with their - // decimal representations. Thus the output is longer than the input. - assert_eq!(bytes[0..4], [16, 0, 0, 0]); - assert_eq!(&bytes[4..], b"你好"); + let string = decoder.decode::().unwrap(); - // The replaced characters are not decoded back to their original values. - assert_eq!( - ValueDecoder::new(&bytes).decode_string().unwrap(), - "你好" - ); + assert_eq!(string, expected_string); + assert_eq!(decoder.bytes(), EMPTY_BYTES); } + } - #[test] - fn decode_string() { - for &(expected_string, buffer) in &STRING_ENCODINGS { - let mut decoder = ValueDecoder::new(&buffer); - - let string = decoder.decode::().unwrap(); - - assert_eq!(string, expected_string); - assert_eq!(decoder.bytes(), EMPTY_BYTES); - } + #[test] + fn roundtrip_string() { + for &(string, _) in &STRING_ENCODINGS { + roundtrip(string.to_string()) } + } - #[test] - fn roundtrip_string() { - for &(string, _) in &STRING_ENCODINGS { - roundtrip(string.to_string()) - } - } + #[test] + fn encode_pair_u32_string() { + let mut bytes = BytesMut::new(); + bytes.put_u8(13); - #[test] - fn encode_pair_u32_string() { - let mut bytes = BytesMut::new(); - bytes.put_u8(13); + let mut expected_bytes = BytesMut::new(); + expected_bytes.put_u8(13); - let mut expected_bytes = BytesMut::new(); - expected_bytes.put_u8(13); + let (integer, ref expected_integer_bytes) = U32_ENCODINGS[0]; + let (string, expected_string_bytes) = STRING_ENCODINGS[0]; - let (integer, ref expected_integer_bytes) = U32_ENCODINGS[0]; - let (string, expected_string_bytes) = STRING_ENCODINGS[0]; + expected_bytes.extend(expected_integer_bytes); + expected_bytes.extend(expected_string_bytes); - expected_bytes.extend(expected_integer_bytes); - expected_bytes.extend(expected_string_bytes); + ValueEncoder::new(&mut bytes) + .encode(&(integer, string.to_string())) + .unwrap(); - ValueEncoder::new(&mut bytes) - .encode(&(integer, string.to_string())) - .unwrap(); + assert_eq!(bytes, expected_bytes); + } - assert_eq!(bytes, expected_bytes); - } + #[test] + fn decode_pair_u32_string() { + let mut buffer = vec![]; - #[test] - fn decode_pair_u32_string() { - let mut buffer = vec![]; + let (expected_integer, ref integer_bytes) = U32_ENCODINGS[0]; + let (expected_string, string_bytes) = STRING_ENCODINGS[0]; - let (expected_integer, ref integer_bytes) = U32_ENCODINGS[0]; - let (expected_string, string_bytes) = STRING_ENCODINGS[0]; + buffer.extend(integer_bytes); + buffer.extend(string_bytes); - buffer.extend(integer_bytes); - buffer.extend(string_bytes); + let mut decoder = ValueDecoder::new(&buffer); - let mut decoder = ValueDecoder::new(&buffer); + let pair = decoder.decode::<(u32, String)>().unwrap(); - let pair = decoder.decode::<(u32, String)>().unwrap(); + assert_eq!(pair, (expected_integer, expected_string.to_string())); + assert_eq!(decoder.bytes(), EMPTY_BYTES); + } - assert_eq!(pair, (expected_integer, expected_string.to_string())); - assert_eq!(decoder.bytes(), EMPTY_BYTES); - } + #[test] + fn roundtrip_pair_u32_string() { + roundtrip((42u32, "hello world!".to_string())) + } - #[test] - fn roundtrip_pair_u32_string() { - roundtrip((42u32, "hello world!".to_string())) - } + #[test] + fn encode_u32_vector() { + let mut vec = vec![]; - #[test] - fn encode_u32_vector() { - let mut vec = vec![]; + let mut expected_bytes = BytesMut::new(); + expected_bytes.extend_from_slice(&[13, U32_ENCODINGS.len() as u8, 0, 0, 0]); - let mut expected_bytes = BytesMut::new(); - expected_bytes.extend_from_slice(&[ - 13, - U32_ENCODINGS.len() as u8, - 0, - 0, - 0, - ]); + for &(val, ref encoded_bytes) in &U32_ENCODINGS { + vec.push(val); + expected_bytes.extend(encoded_bytes); + } - for &(val, ref encoded_bytes) in &U32_ENCODINGS { - vec.push(val); - expected_bytes.extend(encoded_bytes); - } + let mut bytes = BytesMut::new(); + bytes.put_u8(13); - let mut bytes = BytesMut::new(); - bytes.put_u8(13); + ValueEncoder::new(&mut bytes).encode(&vec).unwrap(); - ValueEncoder::new(&mut bytes).encode(&vec).unwrap(); + assert_eq!(bytes, expected_bytes); + } - assert_eq!(bytes, expected_bytes); + #[test] + fn decode_u32_vector() { + let mut expected_vec = vec![]; + let mut buffer = vec![U32_ENCODINGS.len() as u8, 0, 0, 0]; + for &(expected_val, ref encoded_bytes) in &U32_ENCODINGS { + expected_vec.push(expected_val); + buffer.extend(encoded_bytes); } - #[test] - fn decode_u32_vector() { - let mut expected_vec = vec![]; - let mut buffer = vec![U32_ENCODINGS.len() as u8, 0, 0, 0]; - for &(expected_val, ref encoded_bytes) in &U32_ENCODINGS { - expected_vec.push(expected_val); - buffer.extend(encoded_bytes); - } + let mut decoder = ValueDecoder::new(&buffer); - let mut decoder = ValueDecoder::new(&buffer); + let vec = decoder.decode::>().unwrap(); - let vec = decoder.decode::>().unwrap(); + assert_eq!(vec, expected_vec); + assert_eq!(decoder.bytes(), EMPTY_BYTES); + } - assert_eq!(vec, expected_vec); - assert_eq!(decoder.bytes(), EMPTY_BYTES); - } - - #[test] - fn roundtrip_u32_vector() { - roundtrip(vec![0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - } + #[test] + fn roundtrip_u32_vector() { + roundtrip(vec![0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + } } diff --git a/src/room.rs b/src/room.rs index 4205b89..0e954b4 100644 --- a/src/room.rs +++ b/src/room.rs @@ -8,35 +8,35 @@ use crate::proto::{server, User}; /// This enumeration is the list of possible membership states for a chat room. #[derive(Clone, Copy, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)] pub enum Membership { - /// The user is not a member of this room. - NonMember, - /// The user has requested to join the room, but hasn't heard back from the - /// server yet. - Joining, - /// The user is a member of the room. - Member, - /// The user has request to leave the room, but hasn't heard back from the - /// server yet. - Leaving, + /// The user is not a member of this room. + NonMember, + /// The user has requested to join the room, but hasn't heard back from the + /// server yet. + Joining, + /// The user is a member of the room. + Member, + /// The user has request to leave the room, but hasn't heard back from the + /// server yet. + Leaving, } /// This enumeration is the list of visibility types for rooms that the user is /// a member of. #[derive(Clone, Copy, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)] pub enum Visibility { - /// This room is visible to any user. - Public, - /// This room is visible only to members, and the user owns it. - PrivateOwned, - /// This room is visible only to members, and someone else owns it. - PrivateOther, + /// This room is visible to any user. + Public, + /// This room is visible only to members, and the user owns it. + PrivateOwned, + /// This room is visible only to members, and someone else owns it. + PrivateOther, } /// This structure contains a chat room message. #[derive(Clone, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)] pub struct Message { - pub user_name: String, - pub message: String, + pub user_name: String, + pub message: String, } /// This structure contains the last known information about a chat room. @@ -44,356 +44,342 @@ pub struct Message { /// room hash table. #[derive(Clone, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)] pub struct Room { - /// The membership state of the user for the room. - pub membership: Membership, - /// The visibility of the room. - pub visibility: Visibility, - /// True if the user is one of the room's operators, False if the user is a - /// regular member. - pub operated: bool, - /// The number of users that are members of the room. - pub user_count: usize, - /// The name of the room's owner, if any. - pub owner: Option, - /// The names of the room's operators. - pub operators: collections::HashSet, - /// The names of the room's members. - pub members: collections::HashSet, - /// The messages sent to this chat room, in chronological order. - pub messages: Vec, - /// The tickers displayed in this room. - pub tickers: Vec<(String, String)>, + /// The membership state of the user for the room. + pub membership: Membership, + /// The visibility of the room. + pub visibility: Visibility, + /// True if the user is one of the room's operators, False if the user is a + /// regular member. + pub operated: bool, + /// The number of users that are members of the room. + pub user_count: usize, + /// The name of the room's owner, if any. + pub owner: Option, + /// The names of the room's operators. + pub operators: collections::HashSet, + /// The names of the room's members. + pub members: collections::HashSet, + /// The messages sent to this chat room, in chronological order. + pub messages: Vec, + /// The tickers displayed in this room. + pub tickers: Vec<(String, String)>, } impl Room { - /// Creates a new room with the given visibility and user count. - fn new(visibility: Visibility, user_count: usize) -> Self { - Room { - membership: Membership::NonMember, - visibility: visibility, - operated: false, - user_count: user_count, - owner: None, - operators: collections::HashSet::new(), - members: collections::HashSet::new(), - messages: Vec::new(), - tickers: Vec::new(), - } + /// Creates a new room with the given visibility and user count. + fn new(visibility: Visibility, user_count: usize) -> Self { + Room { + membership: Membership::NonMember, + visibility: visibility, + operated: false, + user_count: user_count, + owner: None, + operators: collections::HashSet::new(), + members: collections::HashSet::new(), + messages: Vec::new(), + tickers: Vec::new(), } + } } /// The error returned by RoomMap functions. #[derive(Debug)] pub enum Error { - RoomNotFound(String), - MembershipChangeInvalid(Membership, Membership), + RoomNotFound(String), + MembershipChangeInvalid(Membership, Membership), } impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::RoomNotFound(ref room_name) => { - write!(f, "room {:?} not found", room_name) - } - - Error::MembershipChangeInvalid(old_membership, new_membership) => { - write!( - f, - "cannot change membership from {:?} to {:?}", - old_membership, new_membership - ) - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::RoomNotFound(ref room_name) => { + write!(f, "room {:?} not found", room_name) + } + + Error::MembershipChangeInvalid(old_membership, new_membership) => { + write!( + f, + "cannot change membership from {:?} to {:?}", + old_membership, new_membership + ) + } } + } } impl error::Error for Error { - fn description(&self) -> &str { - match *self { - Error::RoomNotFound(_) => "room not found", - Error::MembershipChangeInvalid(_, _) => "cannot change membership", - } + fn description(&self) -> &str { + match *self { + Error::RoomNotFound(_) => "room not found", + Error::MembershipChangeInvalid(_, _) => "cannot change membership", } + } } /// Contains the mapping from room names to room data and provides a clean /// interface to interact with it. #[derive(Debug)] pub struct RoomMap { - /// The actual map from room names to room data. - map: collections::HashMap, + /// The actual map from room names to room data. + map: collections::HashMap, } impl RoomMap { - /// Creates an empty mapping. - pub fn new() -> Self { - RoomMap { - map: collections::HashMap::new(), - } + /// Creates an empty mapping. + pub fn new() -> Self { + RoomMap { + map: collections::HashMap::new(), } + } + + /// Looks up the given room name in the map, returning an immutable + /// reference to the associated data if found, or an error if not found. + fn get_strict(&self, room_name: &str) -> Result<&Room, Error> { + match self.map.get(room_name) { + Some(room) => Ok(room), + None => Err(Error::RoomNotFound(room_name.to_string())), + } + } + + /// Looks up the given room name in the map, returning a mutable + /// reference to the associated data if found, or an error if not found. + fn get_mut_strict(&mut self, room_name: &str) -> Result<&mut Room, Error> { + match self.map.get_mut(room_name) { + Some(room) => Ok(room), + None => Err(Error::RoomNotFound(room_name.to_string())), + } + } + + /// Updates one room in the map based on the information received in + /// a RoomListResponse and the potential previously stored information. + fn update_one( + &mut self, + name: String, + visibility: Visibility, + user_count: u32, + old_map: &mut collections::HashMap, + ) { + let room = match old_map.remove(&name) { + None => Room::new(Visibility::Public, user_count as usize), + Some(mut room) => { + room.visibility = visibility; + room.user_count = user_count as usize; + room + } + }; + if let Some(_) = self.map.insert(name, room) { + error!("Room present twice in room list response"); + } + } - /// Looks up the given room name in the map, returning an immutable - /// reference to the associated data if found, or an error if not found. - fn get_strict(&self, room_name: &str) -> Result<&Room, Error> { - match self.map.get(room_name) { - Some(room) => Ok(room), - None => Err(Error::RoomNotFound(room_name.to_string())), - } + /// Updates the map to reflect the information contained in the given + /// server response. + pub fn set_room_list(&mut self, mut response: server::RoomListResponse) { + // Replace the old mapping with an empty one. + let mut old_map = mem::replace(&mut self.map, collections::HashMap::new()); + + // Add all public rooms. + for (name, user_count) in response.rooms.drain(..) { + self.update_one(name, Visibility::Public, user_count, &mut old_map); } - /// Looks up the given room name in the map, returning a mutable - /// reference to the associated data if found, or an error if not found. - fn get_mut_strict(&mut self, room_name: &str) -> Result<&mut Room, Error> { - match self.map.get_mut(room_name) { - Some(room) => Ok(room), - None => Err(Error::RoomNotFound(room_name.to_string())), - } + // Add all private, owned, rooms. + for (name, user_count) in response.owned_private_rooms.drain(..) { + self.update_one(name, Visibility::PrivateOwned, user_count, &mut old_map); } - /// Updates one room in the map based on the information received in - /// a RoomListResponse and the potential previously stored information. - fn update_one( - &mut self, - name: String, - visibility: Visibility, - user_count: u32, - old_map: &mut collections::HashMap, - ) { - let room = match old_map.remove(&name) { - None => Room::new(Visibility::Public, user_count as usize), - Some(mut room) => { - room.visibility = visibility; - room.user_count = user_count as usize; - room - } - }; - if let Some(_) = self.map.insert(name, room) { - error!("Room present twice in room list response"); - } + // Add all private, unowned, rooms. + for (name, user_count) in response.other_private_rooms.drain(..) { + self.update_one(name, Visibility::PrivateOther, user_count, &mut old_map); } - /// Updates the map to reflect the information contained in the given - /// server response. - pub fn set_room_list(&mut self, mut response: server::RoomListResponse) { - // Replace the old mapping with an empty one. - let mut old_map = - mem::replace(&mut self.map, collections::HashMap::new()); - - // Add all public rooms. - for (name, user_count) in response.rooms.drain(..) { - self.update_one(name, Visibility::Public, user_count, &mut old_map); - } - - // Add all private, owned, rooms. - for (name, user_count) in response.owned_private_rooms.drain(..) { - self.update_one( - name, - Visibility::PrivateOwned, - user_count, - &mut old_map, - ); - } - - // Add all private, unowned, rooms. - for (name, user_count) in response.other_private_rooms.drain(..) { - self.update_one( - name, - Visibility::PrivateOther, - user_count, - &mut old_map, - ); - } - - // Mark all operated rooms as necessary. - for name in response.operated_private_room_names.iter() { - match self.map.get_mut(name) { - Some(room) => room.operated = true, - None => error!("Room {} is operated but does not exist", name), - } - } + // Mark all operated rooms as necessary. + for name in response.operated_private_room_names.iter() { + match self.map.get_mut(name) { + Some(room) => room.operated = true, + None => error!("Room {} is operated but does not exist", name), + } } + } - /// Returns the list of (room name, room data) representing all known rooms. - pub fn get_room_list(&self) -> Vec<(String, Room)> { - let mut rooms = Vec::new(); - for (room_name, room) in self.map.iter() { - rooms.push((room_name.clone(), room.clone())); - } - rooms + /// Returns the list of (room name, room data) representing all known rooms. + pub fn get_room_list(&self) -> Vec<(String, Room)> { + let mut rooms = Vec::new(); + for (room_name, room) in self.map.iter() { + rooms.push((room_name.clone(), room.clone())); } + rooms + } + + /// Records that we are now trying to join the given room. + /// If the room is not found, or if its membership is not `NonMember`, + /// returns an error. + pub fn start_joining(&mut self, room_name: &str) -> Result<(), Error> { + let room = self.get_mut_strict(room_name)?; + + match room.membership { + Membership::NonMember => { + room.membership = Membership::Joining; + Ok(()) + } - /// Records that we are now trying to join the given room. - /// If the room is not found, or if its membership is not `NonMember`, - /// returns an error. - pub fn start_joining(&mut self, room_name: &str) -> Result<(), Error> { - let room = self.get_mut_strict(room_name)?; - - match room.membership { - Membership::NonMember => { - room.membership = Membership::Joining; - Ok(()) - } - - membership => Err(Error::MembershipChangeInvalid( - membership, - Membership::Joining, - )), - } + membership => Err(Error::MembershipChangeInvalid( + membership, + Membership::Joining, + )), + } + } + + /// Records that we are now a member of the given room and updates the room + /// information. + pub fn join( + &mut self, + room_name: &str, + owner: Option, + mut operators: Vec, + members: &[User], + ) -> Result<(), Error> { + // First look up the room struct. + let room = self.get_mut_strict(room_name)?; + + // Log what's happening. + if let Membership::Joining = room.membership { + info!("Joined room {:?}", room_name); + } else { + warn!( + "Joined room {:?} but membership was already {:?}", + room_name, room.membership + ); } - /// Records that we are now a member of the given room and updates the room - /// information. - pub fn join( - &mut self, - room_name: &str, - owner: Option, - mut operators: Vec, - members: &[User], - ) -> Result<(), Error> { - // First look up the room struct. - let room = self.get_mut_strict(room_name)?; - - // Log what's happening. - if let Membership::Joining = room.membership { - info!("Joined room {:?}", room_name); - } else { - warn!( - "Joined room {:?} but membership was already {:?}", - room_name, room.membership - ); - } - - // Update the room struct. - room.membership = Membership::Member; - room.user_count = members.len(); - room.owner = owner; - - room.operators.clear(); - for user_name in operators.drain(..) { - room.operators.insert(user_name); - } - - room.members.clear(); - for user in members { - room.members.insert(user.name.clone()); - } + // Update the room struct. + room.membership = Membership::Member; + room.user_count = members.len(); + room.owner = owner; - Ok(()) + room.operators.clear(); + for user_name in operators.drain(..) { + room.operators.insert(user_name); } - /// Records that we are now trying to leave the given room. - /// If the room is not found, or if its membership status is not `Member`, - /// returns an error. - pub fn start_leaving(&mut self, room_name: &str) -> Result<(), Error> { - let room = self.get_mut_strict(room_name)?; - - match room.membership { - Membership::Member => { - room.membership = Membership::Leaving; - Ok(()) - } - - membership => Err(Error::MembershipChangeInvalid( - membership, - Membership::Leaving, - )), - } + room.members.clear(); + for user in members { + room.members.insert(user.name.clone()); } - /// Records that we have now left the given room. - pub fn leave(&mut self, room_name: &str) -> Result<(), Error> { - let room = self.get_mut_strict(room_name)?; - - match room.membership { - Membership::Leaving => info!("Left room {:?}", room_name), + Ok(()) + } - membership => warn!( - "Left room {:?} with wrong membership: {:?}", - room_name, membership - ), - } + /// Records that we are now trying to leave the given room. + /// If the room is not found, or if its membership status is not `Member`, + /// returns an error. + pub fn start_leaving(&mut self, room_name: &str) -> Result<(), Error> { + let room = self.get_mut_strict(room_name)?; - room.membership = Membership::NonMember; + match room.membership { + Membership::Member => { + room.membership = Membership::Leaving; Ok(()) - } + } - /// Saves the given message as the last one in the given room. - pub fn add_message( - &mut self, - room_name: &str, - message: Message, - ) -> Result<(), Error> { - let room = self.get_mut_strict(room_name)?; - room.messages.push(message); - Ok(()) + membership => Err(Error::MembershipChangeInvalid( + membership, + Membership::Leaving, + )), } + } - /// Inserts the given user in the given room's set of members. - /// Returns an error if the room is not found. - pub fn insert_member( - &mut self, - room_name: &str, - user_name: String, - ) -> Result<(), Error> { - let room = self.get_mut_strict(room_name)?; - room.members.insert(user_name); - Ok(()) - } + /// Records that we have now left the given room. + pub fn leave(&mut self, room_name: &str) -> Result<(), Error> { + let room = self.get_mut_strict(room_name)?; - /// Removes the given user from the given room's set of members. - /// Returns an error if the room is not found. - pub fn remove_member( - &mut self, - room_name: &str, - user_name: &str, - ) -> Result<(), Error> { - let room = self.get_mut_strict(room_name)?; - room.members.remove(user_name); - Ok(()) - } + match room.membership { + Membership::Leaving => info!("Left room {:?}", room_name), - /*---------* - * Tickers * - *---------*/ - - pub fn set_tickers( - &mut self, - room_name: &str, - tickers: Vec<(String, String)>, - ) -> Result<(), Error> { - let room = self.get_mut_strict(room_name)?; - room.tickers = tickers; - Ok(()) + membership => warn!( + "Left room {:?} with wrong membership: {:?}", + room_name, membership + ), } + + room.membership = Membership::NonMember; + Ok(()) + } + + /// Saves the given message as the last one in the given room. + pub fn add_message( + &mut self, + room_name: &str, + message: Message, + ) -> Result<(), Error> { + let room = self.get_mut_strict(room_name)?; + room.messages.push(message); + Ok(()) + } + + /// Inserts the given user in the given room's set of members. + /// Returns an error if the room is not found. + pub fn insert_member( + &mut self, + room_name: &str, + user_name: String, + ) -> Result<(), Error> { + let room = self.get_mut_strict(room_name)?; + room.members.insert(user_name); + Ok(()) + } + + /// Removes the given user from the given room's set of members. + /// Returns an error if the room is not found. + pub fn remove_member( + &mut self, + room_name: &str, + user_name: &str, + ) -> Result<(), Error> { + let room = self.get_mut_strict(room_name)?; + room.members.remove(user_name); + Ok(()) + } + + /*---------* + * Tickers * + *---------*/ + + pub fn set_tickers( + &mut self, + room_name: &str, + tickers: Vec<(String, String)>, + ) -> Result<(), Error> { + let room = self.get_mut_strict(room_name)?; + room.tickers = tickers; + Ok(()) + } } #[cfg(test)] mod tests { - use crate::proto::server::RoomListResponse; - - use super::{Room, RoomMap, Visibility}; - - #[test] - fn room_map_new_is_empty() { - assert_eq!(RoomMap::new().get_room_list(), vec![]); - } - - #[test] - fn room_map_get_strict() { - let mut rooms = RoomMap::new(); - rooms.set_room_list(RoomListResponse { - rooms: vec![ - ("room a".to_string(), 42), - ("room b".to_string(), 1337), - ], - owned_private_rooms: vec![], - other_private_rooms: vec![], - operated_private_room_names: vec![], - }); - - assert_eq!( - rooms.get_strict("room a").unwrap(), - &Room::new(Visibility::Public, 42) - ); - } + use crate::proto::server::RoomListResponse; + + use super::{Room, RoomMap, Visibility}; + + #[test] + fn room_map_new_is_empty() { + assert_eq!(RoomMap::new().get_room_list(), vec![]); + } + + #[test] + fn room_map_get_strict() { + let mut rooms = RoomMap::new(); + rooms.set_room_list(RoomListResponse { + rooms: vec![("room a".to_string(), 42), ("room b".to_string(), 1337)], + owned_private_rooms: vec![], + other_private_rooms: vec![], + operated_private_room_names: vec![], + }); + + assert_eq!( + rooms.get_strict("room a").unwrap(), + &Room::new(Visibility::Public, 42) + ); + } } diff --git a/src/user.rs b/src/user.rs index 3cd7ef5..34855fb 100644 --- a/src/user.rs +++ b/src/user.rs @@ -7,171 +7,170 @@ use crate::proto::{User, UserStatus}; /// The error returned when a user name was not found in the user map. #[derive(Debug)] pub struct UserNotFoundError { - /// The name of the user that wasn't found. - user_name: String, + /// The name of the user that wasn't found. + user_name: String, } impl fmt::Display for UserNotFoundError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "user \"{}\" not found", self.user_name) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "user \"{}\" not found", self.user_name) + } } impl error::Error for UserNotFoundError { - fn description(&self) -> &str { - "user not found" - } + fn description(&self) -> &str { + "user not found" + } } /// Contains the mapping from user names to user data and provides a clean /// interface to interact with it. #[derive(Debug)] pub struct UserMap { - /// The actual map from user names to user data and privileged status. - map: collections::HashMap, - /// The set of privileged users. - privileged: collections::HashSet, + /// The actual map from user names to user data and privileged status. + map: collections::HashMap, + /// The set of privileged users. + privileged: collections::HashSet, } impl UserMap { - /// Creates an empty mapping. - pub fn new() -> Self { - UserMap { - map: collections::HashMap::new(), - privileged: collections::HashSet::new(), - } - } - - /// Looks up the given user name in the map, returning an immutable - /// reference to the associated data if found. - pub fn get(&self, user_name: &str) -> Option<&User> { - self.map.get(user_name) - } - - /// Looks up the given user name in the map, returning a mutable reference - /// to the associated data if found, or an error if not found. - pub fn get_mut_strict( - &mut self, - user_name: &str, - ) -> Result<&mut User, UserNotFoundError> { - match self.map.get_mut(user_name) { - Some(user) => Ok(user), - None => Err(UserNotFoundError { - user_name: user_name.to_string(), - }), - } - } - - /// Inserts the given user info for the given user name in the mapping. - /// If there is already data under that name, it is replaced. - pub fn insert(&mut self, user: User) { - self.map.insert(user.name.clone(), user); - } - - /// Sets the given user's status to the given value, if such a user exists. - pub fn set_status( - &mut self, - user_name: &str, - status: UserStatus, - ) -> Result<(), UserNotFoundError> { - let user = self.get_mut_strict(user_name)?; - user.status = status; - Ok(()) - } - - /// Returns the list of (user name, user data) representing all known users. - pub fn get_list(&self) -> Vec<(String, User)> { - let mut users = Vec::new(); - for (user_name, user_data) in self.map.iter() { - users.push((user_name.clone(), user_data.clone())); - } - users - } - - /// Sets the set of privileged users to the given list. - pub fn set_all_privileged(&mut self, mut users: Vec) { - self.privileged.clear(); - for user_name in users.drain(..) { - self.privileged.insert(user_name); - } - } - - /// Returns a copy of the set of privileged users. - pub fn get_all_privileged(&self) -> Vec { - self.privileged.iter().map(|s| s.to_string()).collect() - } - - /// Marks the given user as privileged. - pub fn insert_privileged(&mut self, user_name: String) { - self.privileged.insert(user_name); - } - - /// Marks the given user as not privileged. - pub fn remove_privileged(&mut self, user_name: &str) { - self.privileged.remove(user_name); - } - - /// Checks if the given user is privileged. - pub fn is_privileged(&self, user_name: &str) -> bool { - self.privileged.contains(user_name) - } + /// Creates an empty mapping. + pub fn new() -> Self { + UserMap { + map: collections::HashMap::new(), + privileged: collections::HashSet::new(), + } + } + + /// Looks up the given user name in the map, returning an immutable + /// reference to the associated data if found. + pub fn get(&self, user_name: &str) -> Option<&User> { + self.map.get(user_name) + } + + /// Looks up the given user name in the map, returning a mutable reference + /// to the associated data if found, or an error if not found. + pub fn get_mut_strict( + &mut self, + user_name: &str, + ) -> Result<&mut User, UserNotFoundError> { + match self.map.get_mut(user_name) { + Some(user) => Ok(user), + None => Err(UserNotFoundError { + user_name: user_name.to_string(), + }), + } + } + + /// Inserts the given user info for the given user name in the mapping. + /// If there is already data under that name, it is replaced. + pub fn insert(&mut self, user: User) { + self.map.insert(user.name.clone(), user); + } + + /// Sets the given user's status to the given value, if such a user exists. + pub fn set_status( + &mut self, + user_name: &str, + status: UserStatus, + ) -> Result<(), UserNotFoundError> { + let user = self.get_mut_strict(user_name)?; + user.status = status; + Ok(()) + } + + /// Returns the list of (user name, user data) representing all known users. + pub fn get_list(&self) -> Vec<(String, User)> { + let mut users = Vec::new(); + for (user_name, user_data) in self.map.iter() { + users.push((user_name.clone(), user_data.clone())); + } + users + } + + /// Sets the set of privileged users to the given list. + pub fn set_all_privileged(&mut self, mut users: Vec) { + self.privileged.clear(); + for user_name in users.drain(..) { + self.privileged.insert(user_name); + } + } + + /// Returns a copy of the set of privileged users. + pub fn get_all_privileged(&self) -> Vec { + self.privileged.iter().map(|s| s.to_string()).collect() + } + + /// Marks the given user as privileged. + pub fn insert_privileged(&mut self, user_name: String) { + self.privileged.insert(user_name); + } + + /// Marks the given user as not privileged. + pub fn remove_privileged(&mut self, user_name: &str) { + self.privileged.remove(user_name); + } + + /// Checks if the given user is privileged. + pub fn is_privileged(&self, user_name: &str) -> bool { + self.privileged.contains(user_name) + } } #[cfg(test)] mod tests { - use super::UserMap; + use super::UserMap; - #[test] - fn new_is_empty() { - let users = UserMap::new(); + #[test] + fn new_is_empty() { + let users = UserMap::new(); - assert_eq!(users.get_list(), vec![]); - assert_eq!(users.get("bleep"), None); - } + assert_eq!(users.get_list(), vec![]); + assert_eq!(users.get("bleep"), None); + } - #[test] - fn set_get_all_privileged() { - let mut users = UserMap::new(); + #[test] + fn set_get_all_privileged() { + let mut users = UserMap::new(); - users - .set_all_privileged(vec!["bleep".to_string(), "bloop".to_string()]); + users.set_all_privileged(vec!["bleep".to_string(), "bloop".to_string()]); - let mut privileged = users.get_all_privileged(); - privileged.sort(); - assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]); - } + let mut privileged = users.get_all_privileged(); + privileged.sort(); + assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]); + } - #[test] - fn insert_privileged() { - let mut users = UserMap::new(); + #[test] + fn insert_privileged() { + let mut users = UserMap::new(); - users.insert_privileged("bleep".to_string()); - users.insert_privileged("bleep".to_string()); - users.insert_privileged("bloop".to_string()); + users.insert_privileged("bleep".to_string()); + users.insert_privileged("bleep".to_string()); + users.insert_privileged("bloop".to_string()); - let mut privileged = users.get_all_privileged(); - privileged.sort(); - assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]); - } + let mut privileged = users.get_all_privileged(); + privileged.sort(); + assert_eq!(privileged, vec!["bleep".to_string(), "bloop".to_string()]); + } - #[test] - fn remove_privileged() { - let mut users = UserMap::new(); - users.insert_privileged("bleep".to_string()); - users.insert_privileged("bloop".to_string()); + #[test] + fn remove_privileged() { + let mut users = UserMap::new(); + users.insert_privileged("bleep".to_string()); + users.insert_privileged("bloop".to_string()); - users.remove_privileged("bleep"); + users.remove_privileged("bleep"); - let privileged = users.get_all_privileged(); - assert_eq!(privileged, vec!["bloop".to_string()]); - } + let privileged = users.get_all_privileged(); + assert_eq!(privileged, vec!["bloop".to_string()]); + } - #[test] - fn is_privileged() { - let mut users = UserMap::new(); - users.insert_privileged("bleep".to_string()); + #[test] + fn is_privileged() { + let mut users = UserMap::new(); + users.insert_privileged("bleep".to_string()); - assert!(users.is_privileged("bleep")); - assert!(!users.is_privileged("bloop")); - } + assert!(users.is_privileged("bleep")); + assert!(!users.is_privileged("bloop")); + } }