diff --git a/p2p/monero-p2p/src/client/handshaker.rs b/p2p/monero-p2p/src/client/handshaker.rs index d72e7edc..38858b49 100644 --- a/p2p/monero-p2p/src/client/handshaker.rs +++ b/p2p/monero-p2p/src/client/handshaker.rs @@ -15,15 +15,18 @@ use monero_wire::{ PING_OK_RESPONSE_STATUS_TEXT, }, common::PeerSupportFlags, - BasicNodeData, BucketError, CoreSyncData, Message, RequestMessage, ResponseMessage, + BasicNodeData, BucketError, CoreSyncData, LevinCommand, Message, RequestMessage, + ResponseMessage, }; use crate::{ AddressBook, AddressBookRequest, AddressBookResponse, ConnectionDirection, CoreSyncDataRequest, - CoreSyncDataResponse, CoreSyncSvc, NetworkZone, PeerRequestHandler, + CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerRequestHandler, MAX_PEERS_IN_PEER_LIST_MESSAGE, }; +const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2; + #[derive(Debug, thiserror::Error)] pub enum HandshakeError { #[error("peer has the same node ID as us")] @@ -109,388 +112,332 @@ where let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %addr); - let state_machine = HandshakeStateMachine:: { - addr, - peer_stream, - peer_sink, - direction, - address_book, - core_sync_svc, - peer_request_svc, - our_basic_node_data, - state: HandshakeState::Start, - eager_protocol_messages: vec![], - }; - async move { // TODO: timeouts - state_machine.do_handshake().await + handshake( + addr, + peer_stream, + peer_sink, + direction, + address_book, + core_sync_svc, + peer_request_svc, + our_basic_node_data, + ) + .await } .instrument(span) .boxed() } } -/// The states a handshake can be in. -#[derive(Debug, Clone, Eq, PartialEq)] -enum HandshakeState { - /// The initial state. - /// - /// If this is an inbound handshake then this state means we - /// are waiting for a [`HandshakeRequest`]. - Start, - /// Waiting for a [`HandshakeResponse`]. - WaitingForHandshakeResponse, - /// Waiting for a [`SupportFlagsResponse`] - /// This contains the peers node data. - WaitingForSupportFlagResponse(BasicNodeData, CoreSyncData), - /// The handshake is complete. - /// This contains the peers node data. - Complete(BasicNodeData, CoreSyncData), - /// An invalid state, the handshake SM should not be in this state. - Invalid, -} - -impl HandshakeState { - /// Returns true if the handshake is completed. - pub fn is_complete(&self) -> bool { - matches!(self, Self::Complete(..)) - } - - /// returns the peers [`BasicNodeData`] and [`CoreSyncData`] if the peer - /// is in state [`HandshakeState::Complete`]. - pub fn peer_data(self) -> Option<(BasicNodeData, CoreSyncData)> { - match self { - HandshakeState::Complete(bnd, coresync) => Some((bnd, coresync)), - _ => None, - } - } -} - -struct HandshakeStateMachine { +#[allow(clippy::too_many_arguments)] +async fn handshake( addr: Z::Addr, - - peer_stream: Z::Stream, - peer_sink: Z::Sink, - + mut peer_stream: Z::Stream, + mut peer_sink: Z::Sink, direction: ConnectionDirection, - - address_book: AdrBook, - core_sync_svc: CSync, + mut address_book: AdrBook, + mut core_sync_svc: CSync, peer_request_svc: ReqHdlr, - our_basic_node_data: BasicNodeData, - - state: HandshakeState, - - /// Monero allows protocol messages to be sent before a handshake response, so we have to - /// keep track of them here. For saftey we only keep a Max of 2 messages. - eager_protocol_messages: Vec, -} - -impl HandshakeStateMachine +) -> Result<(), HandshakeError> where AdrBook: AddressBook, CSync: CoreSyncSvc, ReqHdlr: PeerRequestHandler, { - async fn send_handshake_request(&mut self) -> Result<(), HandshakeError> { - let CoreSyncDataResponse::Ours(our_core_sync_data) = self - .core_sync_svc - .ready() + let mut eager_protocol_messages = Vec::new(); + let mut allow_support_flag_req = true; + + let (peer_core_sync, mut peer_node_data) = match direction { + ConnectionDirection::InBound => { + tracing::debug!("waiting for handshake request."); + + let Message::Request(RequestMessage::Handshake(handshake_req)) = wait_for_message::( + LevinCommand::Handshake, + true, + &mut peer_sink, + &mut peer_stream, + &mut eager_protocol_messages, + &mut allow_support_flag_req, + our_basic_node_data.support_flags, + ) .await? - .call(CoreSyncDataRequest::Ours) - .await? - else { - panic!("core sync service returned wrong response!"); - }; + else { + panic!("wait_for_message returned ok with wrong message."); + }; - let req = HandshakeRequest { - node_data: self.our_basic_node_data.clone(), - payload_data: our_core_sync_data, - }; + tracing::debug!("Received handshake request."); - tracing::debug!("Sending handshake request."); - - self.peer_sink - .send(Message::Request(RequestMessage::Handshake(req))) + (handshake_req.payload_data, handshake_req.node_data) + } + ConnectionDirection::OutBound => { + send_hs_request::( + &mut peer_sink, + &mut core_sync_svc, + our_basic_node_data.clone(), + ) .await?; - Ok(()) - } + let Message::Response(ResponseMessage::Handshake(handshake_res)) = + wait_for_message::( + LevinCommand::Handshake, + false, + &mut peer_sink, + &mut peer_stream, + &mut eager_protocol_messages, + &mut allow_support_flag_req, + our_basic_node_data.support_flags, + ) + .await? + else { + panic!("wait_for_message returned ok with wrong message."); + }; - async fn send_handshake_response(&mut self) -> Result<(), HandshakeError> { - let CoreSyncDataResponse::Ours(our_core_sync_data) = self - .core_sync_svc - .ready() - .await? - .call(CoreSyncDataRequest::Ours) - .await? - else { - panic!("core sync service returned wrong response!"); - }; + if handshake_res.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE { + tracing::debug!("peer sent too many peers in response, cancelling handshake"); - let AddressBookResponse::Peers(our_peer_list) = self - .address_book - .ready() - .await? - .call(AddressBookRequest::GetWhitePeers( - MAX_PEERS_IN_PEER_LIST_MESSAGE, - )) - .await? - else { - panic!("Address book sent incorrect response"); - }; - - let res = HandshakeResponse { - node_data: self.our_basic_node_data.clone(), - payload_data: our_core_sync_data, - local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(), - }; - - tracing::debug!("Sending handshake response."); - - self.peer_sink - .send(Message::Response(ResponseMessage::Handshake(res))) - .await?; - - Ok(()) - } - - async fn send_support_flags(&mut self) -> Result<(), HandshakeError> { - let res = SupportFlagsResponse { - support_flags: self.our_basic_node_data.support_flags, - }; - - tracing::debug!("Sending support flag response."); - - self.peer_sink - .send(Message::Response(ResponseMessage::SupportFlags(res))) - .await?; - - Ok(()) - } - - async fn check_request_support_flags( - &mut self, - support_flags: &PeerSupportFlags, - ) -> Result { - Ok(if support_flags.is_empty() { - tracing::debug!( - "Peer didn't send support flags or has no features, sending request to make sure." - ); - self.peer_sink - .send(Message::Request(RequestMessage::SupportFlags)) - .await?; - true - } else { - false - }) - } - - async fn handle_handshake_response( - &mut self, - response: HandshakeResponse, - ) -> Result<(), HandshakeError> { - if response.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE { - tracing::debug!("peer sent too many peers in response, cancelling handshake"); - - return Err(HandshakeError::PeerSentInvalidMessage( - "Too many peers in peer list message (>250)", - )); - } - - if response.node_data.network_id != self.our_basic_node_data.network_id { - return Err(HandshakeError::IncorrectNetwork); - } - - if Z::CHECK_NODE_ID && response.node_data.peer_id == self.our_basic_node_data.peer_id { - return Err(HandshakeError::PeerHasSameNodeID); - } - - tracing::debug!( - "Telling address book about new peers, len: {}", - response.local_peerlist_new.len() - ); - - self.address_book - .ready() - .await? - .call(AddressBookRequest::IncomingPeerList( - response - .local_peerlist_new - .into_iter() - .map(TryInto::try_into) - .collect::>()?, - )) - .await?; - - if self - .check_request_support_flags(&response.node_data.support_flags) - .await? - { - self.state = HandshakeState::WaitingForSupportFlagResponse( - response.node_data, - response.payload_data, - ); - } else { - self.state = HandshakeState::Complete(response.node_data, response.payload_data); - } - - Ok(()) - } - - async fn handle_handshake_request( - &mut self, - request: HandshakeRequest, - ) -> Result<(), HandshakeError> { - // We don't respond here as if we did the other peer could accept the handshake before responding to a - // support flag request which then means we could recive other requests while waiting for the support - // flags. - - if request.node_data.network_id != self.our_basic_node_data.network_id { - return Err(HandshakeError::IncorrectNetwork); - } - - if Z::CHECK_NODE_ID && request.node_data.peer_id == self.our_basic_node_data.peer_id { - return Err(HandshakeError::PeerHasSameNodeID); - } - - if self - .check_request_support_flags(&request.node_data.support_flags) - .await? - { - self.state = HandshakeState::WaitingForSupportFlagResponse( - request.node_data, - request.payload_data, - ); - } else { - self.state = HandshakeState::Complete(request.node_data, request.payload_data); - } - - Ok(()) - } - - async fn handle_incoming_message(&mut self, message: Message) -> Result<(), HandshakeError> { - tracing::debug!("Received message from peer: {}", message.command()); - - if let Message::Protocol(protocol_message) = message { - if self.eager_protocol_messages.len() == 2 { - tracing::debug!("Peer sent too many protocl messages before a handshake response."); return Err(HandshakeError::PeerSentInvalidMessage( - "Peer sent too many protocol messages", + "Too many peers in peer list message (>250)", )); } + tracing::debug!( - "Protocol message getting added to queue for when handshake is complete." + "Telling address book about new peers, len: {}", + handshake_res.local_peerlist_new.len() ); - self.eager_protocol_messages.push(protocol_message); - return Ok(()); - } - match std::mem::replace(&mut self.state, HandshakeState::Invalid) { - HandshakeState::Start => match message { - Message::Request(RequestMessage::Ping) => { - // Set the state back to what it was before. - self.state = HandshakeState::Start; - Ok(self - .peer_sink - .send(Message::Response(ResponseMessage::Ping(PingResponse { - status: PING_OK_RESPONSE_STATUS_TEXT.to_string(), - peer_id: self.our_basic_node_data.peer_id, - }))) - .await?) - } - Message::Request(RequestMessage::Handshake(handshake_req)) => { - self.handle_handshake_request(handshake_req).await - } - _ => Err(HandshakeError::PeerSentInvalidMessage( - "Peer didn't send handshake request.", - )), - }, - HandshakeState::WaitingForHandshakeResponse => match message { - // TODO: only allow 1 support flag request. - Message::Request(RequestMessage::SupportFlags) => { - // Set the state back to what it was before. - self.state = HandshakeState::WaitingForHandshakeResponse; - self.send_support_flags().await - } - Message::Response(ResponseMessage::Handshake(res)) => { - self.handle_handshake_response(res).await - } - _ => Err(HandshakeError::PeerSentInvalidMessage( - "Peer didn't send handshake response.", - )), - }, - HandshakeState::WaitingForSupportFlagResponse(mut peer_node_data, peer_core_sync) => { - let Message::Response(ResponseMessage::SupportFlags(support_flags)) = message - else { - return Err(HandshakeError::PeerSentInvalidMessage( - "Peer didn't send support flags response.", - )); - }; - peer_node_data.support_flags = support_flags.support_flags; - self.state = HandshakeState::Complete(peer_node_data, peer_core_sync); - Ok(()) - } - HandshakeState::Complete(..) => { - panic!("Handshake is complete messages should no longer be handled here!") - } - HandshakeState::Invalid => panic!("Handshake state machine stayed in invalid state!"), + address_book + .ready() + .await? + .call(AddressBookRequest::IncomingPeerList( + handshake_res + .local_peerlist_new + .into_iter() + .map(TryInto::try_into) + .collect::>()?, + )) + .await?; + + (handshake_res.payload_data, handshake_res.node_data) } + }; + + if peer_node_data.network_id != our_basic_node_data.network_id { + return Err(HandshakeError::IncorrectNetwork); } - async fn advance_machine(&mut self) -> Result<(), HandshakeError> { - while !self.state.is_complete() { - tracing::debug!("Waiting for message from peer."); - - match self.peer_stream.next().await { - Some(message) => self.handle_incoming_message(message?).await?, - None => Err(BucketError::IO(std::io::Error::new( - std::io::ErrorKind::ConnectionAborted, - "The peer stream returned None", - )))?, - } - } - - Ok(()) + if Z::CHECK_NODE_ID && peer_node_data.peer_id == our_basic_node_data.peer_id { + return Err(HandshakeError::PeerHasSameNodeID); } - async fn do_outbound_handshake(&mut self) -> Result<(), HandshakeError> { - self.send_handshake_request().await?; - self.state = HandshakeState::WaitingForHandshakeResponse; - - self.advance_machine().await - } - - async fn do_inbound_handshake(&mut self) -> Result<(), HandshakeError> { - self.advance_machine().await?; - - debug_assert!(self.state.is_complete()); - - self.send_handshake_response().await - } - - async fn do_handshake(mut self) -> Result<(), HandshakeError> { - tracing::debug!("Beginning handshake."); - - match self.direction { - ConnectionDirection::OutBound => self.do_outbound_handshake().await?, - ConnectionDirection::InBound => self.do_inbound_handshake().await?, - } - - let HandshakeState::Complete(peer_node_data, peer_core_sync) = self.state else { - panic!("Hanshake completed not in complete state!"); - }; - - self.core_sync_svc - .ready() - .await? - .call(CoreSyncDataRequest::HandleIncoming(peer_core_sync)) + if peer_node_data.support_flags.is_empty() { + tracing::debug!( + "Peer didn't send support flags or has no features, sending request to make sure." + ); + peer_sink + .send(Message::Request(RequestMessage::SupportFlags)) .await?; - tracing::debug!("Handshake complete."); + let Message::Response(ResponseMessage::SupportFlags(support_flags_res)) = + wait_for_message::( + LevinCommand::SupportFlags, + false, + &mut peer_sink, + &mut peer_stream, + &mut eager_protocol_messages, + &mut allow_support_flag_req, + our_basic_node_data.support_flags, + ) + .await? + else { + panic!("wait_for_message returned ok with wrong message."); + }; - Ok(()) + tracing::debug!("Received support flag response."); + peer_node_data.support_flags = support_flags_res.support_flags; } + + if direction == ConnectionDirection::InBound { + send_hs_response::( + &mut peer_sink, + &mut core_sync_svc, + &mut address_book, + our_basic_node_data, + ) + .await?; + } + + core_sync_svc + .ready() + .await? + .call(CoreSyncDataRequest::HandleIncoming(peer_core_sync)) + .await?; + + tracing::debug!("Handshake complete."); + + Ok(()) +} + +/// Sends a [`HandshakeRequest`] to the peer. +async fn send_hs_request( + peer_sink: &mut Z::Sink, + core_sync_svc: &mut CSync, + our_basic_node_data: BasicNodeData, +) -> Result<(), HandshakeError> +where + CSync: CoreSyncSvc, +{ + let CoreSyncDataResponse::Ours(our_core_sync_data) = core_sync_svc + .ready() + .await? + .call(CoreSyncDataRequest::Ours) + .await? + else { + panic!("core sync service returned wrong response!"); + }; + + let req = HandshakeRequest { + node_data: our_basic_node_data, + payload_data: our_core_sync_data, + }; + + tracing::debug!("Sending handshake request."); + + peer_sink + .send(Message::Request(RequestMessage::Handshake(req))) + .await?; + + Ok(()) +} + +async fn send_hs_response( + peer_sink: &mut Z::Sink, + core_sync_svc: &mut CSync, + address_book: &mut AdrBook, + our_basic_node_data: BasicNodeData, +) -> Result<(), HandshakeError> +where + AdrBook: AddressBook, + CSync: CoreSyncSvc, +{ + let CoreSyncDataResponse::Ours(our_core_sync_data) = core_sync_svc + .ready() + .await? + .call(CoreSyncDataRequest::Ours) + .await? + else { + panic!("core sync service returned wrong response!"); + }; + + let AddressBookResponse::Peers(our_peer_list) = address_book + .ready() + .await? + .call(AddressBookRequest::GetWhitePeers( + MAX_PEERS_IN_PEER_LIST_MESSAGE, + )) + .await? + else { + panic!("Address book sent incorrect response"); + }; + + let res = HandshakeResponse { + node_data: our_basic_node_data, + payload_data: our_core_sync_data, + local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(), + }; + + tracing::debug!("Sending handshake response."); + + peer_sink + .send(Message::Response(ResponseMessage::Handshake(res))) + .await?; + + Ok(()) +} + +async fn wait_for_message( + levin_command: LevinCommand, + request: bool, + peer_sink: &mut Z::Sink, + peer_stream: &mut Z::Stream, + eager_protocol_messages: &mut Vec, + allow_support_flag_req: &mut bool, + support_flags: PeerSupportFlags, +) -> Result { + while let Some(message) = peer_stream.next().await { + let message = message?; + + match message { + Message::Protocol(protocol_message) => { + tracing::debug!( + "Received eager protocol message with ID: {}, adding to queue", + protocol_message.command() + ); + eager_protocol_messages.push(protocol_message); + if eager_protocol_messages.len() > MAX_EAGER_PROTOCOL_MESSAGES { + tracing::debug!( + "Peer sent too many protocl messages before a handshake response." + ); + return Err(HandshakeError::PeerSentInvalidMessage( + "Peer sent too many protocol messages", + )); + } + continue; + } + Message::Request(req_message) => { + if req_message.command() == levin_command && request { + return Ok(Message::Request(req_message)); + } + + if matches!(req_message, RequestMessage::SupportFlags) { + if !*allow_support_flag_req { + return Err(HandshakeError::PeerSentInvalidMessage( + "Peer sent 2 support flag requests", + )); + } + send_support_flags::(peer_sink, support_flags).await?; + // don't let the peer send more after the first request. + *allow_support_flag_req = false; + continue; + } + + return Err(HandshakeError::PeerSentInvalidMessage( + "Peer sent a admin request before responding to the handshake", + )); + } + Message::Response(res_message) if !request => { + if res_message.command() == levin_command { + return Ok(Message::Response(res_message)); + } + + tracing::debug!("Received unexpected response: {}", res_message.command()); + return Err(HandshakeError::PeerSentInvalidMessage( + "Peer sent an incorrect response", + )); + } + + _ => Err(HandshakeError::PeerSentInvalidMessage( + "Peer sent an incorrect message", + )), + }? + } + + Err(BucketError::IO(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "The peer stream returned None", + )))? +} + +async fn send_support_flags( + peer_sink: &mut Z::Sink, + support_flags: PeerSupportFlags, +) -> Result<(), HandshakeError> { + tracing::debug!("Sending support flag response."); + Ok(peer_sink + .send(Message::Response(ResponseMessage::SupportFlags( + SupportFlagsResponse { support_flags }, + ))) + .await?) }