re-write p2p handshaker

This commit is contained in:
Boog900 2024-01-12 00:02:25 +00:00
parent d6495cdb01
commit 5e8221183e
No known key found for this signature in database
GPG key ID: 5401367FB7302004

View file

@ -15,15 +15,18 @@ use monero_wire::{
PING_OK_RESPONSE_STATUS_TEXT, PING_OK_RESPONSE_STATUS_TEXT,
}, },
common::PeerSupportFlags, common::PeerSupportFlags,
BasicNodeData, BucketError, CoreSyncData, Message, RequestMessage, ResponseMessage, BasicNodeData, BucketError, CoreSyncData, LevinCommand, Message, RequestMessage,
ResponseMessage,
}; };
use crate::{ use crate::{
AddressBook, AddressBookRequest, AddressBookResponse, ConnectionDirection, CoreSyncDataRequest, AddressBook, AddressBookRequest, AddressBookResponse, ConnectionDirection, CoreSyncDataRequest,
CoreSyncDataResponse, CoreSyncSvc, NetworkZone, PeerRequestHandler, CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerRequestHandler,
MAX_PEERS_IN_PEER_LIST_MESSAGE, MAX_PEERS_IN_PEER_LIST_MESSAGE,
}; };
const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum HandshakeError { pub enum HandshakeError {
#[error("peer has the same node ID as us")] #[error("peer has the same node ID as us")]
@ -109,7 +112,9 @@ where
let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %addr); let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %addr);
let state_machine = HandshakeStateMachine::<Z, _, _, _> { async move {
// TODO: timeouts
handshake(
addr, addr,
peer_stream, peer_stream,
peer_sink, peer_sink,
@ -118,85 +123,173 @@ where
core_sync_svc, core_sync_svc,
peer_request_svc, peer_request_svc,
our_basic_node_data, our_basic_node_data,
state: HandshakeState::Start, )
eager_protocol_messages: vec![], .await
};
async move {
// TODO: timeouts
state_machine.do_handshake().await
} }
.instrument(span) .instrument(span)
.boxed() .boxed()
} }
} }
/// The states a handshake can be in. #[allow(clippy::too_many_arguments)]
#[derive(Debug, Clone, Eq, PartialEq)] async fn handshake<Z: NetworkZone, AdrBook, CSync, ReqHdlr>(
enum HandshakeState {
/// The initial state.
///
/// If this is an inbound handshake then this state means we
/// are waiting for a [`HandshakeRequest`].
Start,
/// Waiting for a [`HandshakeResponse`].
WaitingForHandshakeResponse,
/// Waiting for a [`SupportFlagsResponse`]
/// This contains the peers node data.
WaitingForSupportFlagResponse(BasicNodeData, CoreSyncData),
/// The handshake is complete.
/// This contains the peers node data.
Complete(BasicNodeData, CoreSyncData),
/// An invalid state, the handshake SM should not be in this state.
Invalid,
}
impl HandshakeState {
/// Returns true if the handshake is completed.
pub fn is_complete(&self) -> bool {
matches!(self, Self::Complete(..))
}
/// returns the peers [`BasicNodeData`] and [`CoreSyncData`] if the peer
/// is in state [`HandshakeState::Complete`].
pub fn peer_data(self) -> Option<(BasicNodeData, CoreSyncData)> {
match self {
HandshakeState::Complete(bnd, coresync) => Some((bnd, coresync)),
_ => None,
}
}
}
struct HandshakeStateMachine<Z: NetworkZone, AdrBook, CSync, ReqHdlr> {
addr: Z::Addr, addr: Z::Addr,
mut peer_stream: Z::Stream,
peer_stream: Z::Stream, mut peer_sink: Z::Sink,
peer_sink: Z::Sink,
direction: ConnectionDirection, direction: ConnectionDirection,
mut address_book: AdrBook,
address_book: AdrBook, mut core_sync_svc: CSync,
core_sync_svc: CSync,
peer_request_svc: ReqHdlr, peer_request_svc: ReqHdlr,
our_basic_node_data: BasicNodeData, our_basic_node_data: BasicNodeData,
) -> Result<(), HandshakeError>
state: HandshakeState,
/// Monero allows protocol messages to be sent before a handshake response, so we have to
/// keep track of them here. For saftey we only keep a Max of 2 messages.
eager_protocol_messages: Vec<monero_wire::ProtocolMessage>,
}
impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> HandshakeStateMachine<Z, AdrBook, CSync, ReqHdlr>
where where
AdrBook: AddressBook<Z>, AdrBook: AddressBook<Z>,
CSync: CoreSyncSvc, CSync: CoreSyncSvc,
ReqHdlr: PeerRequestHandler, ReqHdlr: PeerRequestHandler,
{ {
async fn send_handshake_request(&mut self) -> Result<(), HandshakeError> { let mut eager_protocol_messages = Vec::new();
let CoreSyncDataResponse::Ours(our_core_sync_data) = self let mut allow_support_flag_req = true;
.core_sync_svc
let (peer_core_sync, mut peer_node_data) = match direction {
ConnectionDirection::InBound => {
tracing::debug!("waiting for handshake request.");
let Message::Request(RequestMessage::Handshake(handshake_req)) = wait_for_message::<Z>(
LevinCommand::Handshake,
true,
&mut peer_sink,
&mut peer_stream,
&mut eager_protocol_messages,
&mut allow_support_flag_req,
our_basic_node_data.support_flags,
)
.await?
else {
panic!("wait_for_message returned ok with wrong message.");
};
tracing::debug!("Received handshake request.");
(handshake_req.payload_data, handshake_req.node_data)
}
ConnectionDirection::OutBound => {
send_hs_request::<Z, _>(
&mut peer_sink,
&mut core_sync_svc,
our_basic_node_data.clone(),
)
.await?;
let Message::Response(ResponseMessage::Handshake(handshake_res)) =
wait_for_message::<Z>(
LevinCommand::Handshake,
false,
&mut peer_sink,
&mut peer_stream,
&mut eager_protocol_messages,
&mut allow_support_flag_req,
our_basic_node_data.support_flags,
)
.await?
else {
panic!("wait_for_message returned ok with wrong message.");
};
if handshake_res.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE {
tracing::debug!("peer sent too many peers in response, cancelling handshake");
return Err(HandshakeError::PeerSentInvalidMessage(
"Too many peers in peer list message (>250)",
));
}
tracing::debug!(
"Telling address book about new peers, len: {}",
handshake_res.local_peerlist_new.len()
);
address_book
.ready()
.await?
.call(AddressBookRequest::IncomingPeerList(
handshake_res
.local_peerlist_new
.into_iter()
.map(TryInto::try_into)
.collect::<Result<_, _>>()?,
))
.await?;
(handshake_res.payload_data, handshake_res.node_data)
}
};
if peer_node_data.network_id != our_basic_node_data.network_id {
return Err(HandshakeError::IncorrectNetwork);
}
if Z::CHECK_NODE_ID && peer_node_data.peer_id == our_basic_node_data.peer_id {
return Err(HandshakeError::PeerHasSameNodeID);
}
if peer_node_data.support_flags.is_empty() {
tracing::debug!(
"Peer didn't send support flags or has no features, sending request to make sure."
);
peer_sink
.send(Message::Request(RequestMessage::SupportFlags))
.await?;
let Message::Response(ResponseMessage::SupportFlags(support_flags_res)) =
wait_for_message::<Z>(
LevinCommand::SupportFlags,
false,
&mut peer_sink,
&mut peer_stream,
&mut eager_protocol_messages,
&mut allow_support_flag_req,
our_basic_node_data.support_flags,
)
.await?
else {
panic!("wait_for_message returned ok with wrong message.");
};
tracing::debug!("Received support flag response.");
peer_node_data.support_flags = support_flags_res.support_flags;
}
if direction == ConnectionDirection::InBound {
send_hs_response::<Z, _, _>(
&mut peer_sink,
&mut core_sync_svc,
&mut address_book,
our_basic_node_data,
)
.await?;
}
core_sync_svc
.ready()
.await?
.call(CoreSyncDataRequest::HandleIncoming(peer_core_sync))
.await?;
tracing::debug!("Handshake complete.");
Ok(())
}
/// Sends a [`HandshakeRequest`] to the peer.
async fn send_hs_request<Z: NetworkZone, CSync>(
peer_sink: &mut Z::Sink,
core_sync_svc: &mut CSync,
our_basic_node_data: BasicNodeData,
) -> Result<(), HandshakeError>
where
CSync: CoreSyncSvc,
{
let CoreSyncDataResponse::Ours(our_core_sync_data) = core_sync_svc
.ready() .ready()
.await? .await?
.call(CoreSyncDataRequest::Ours) .call(CoreSyncDataRequest::Ours)
@ -206,22 +299,30 @@ where
}; };
let req = HandshakeRequest { let req = HandshakeRequest {
node_data: self.our_basic_node_data.clone(), node_data: our_basic_node_data,
payload_data: our_core_sync_data, payload_data: our_core_sync_data,
}; };
tracing::debug!("Sending handshake request."); tracing::debug!("Sending handshake request.");
self.peer_sink peer_sink
.send(Message::Request(RequestMessage::Handshake(req))) .send(Message::Request(RequestMessage::Handshake(req)))
.await?; .await?;
Ok(()) Ok(())
} }
async fn send_handshake_response(&mut self) -> Result<(), HandshakeError> { async fn send_hs_response<Z: NetworkZone, CSync, AdrBook>(
let CoreSyncDataResponse::Ours(our_core_sync_data) = self peer_sink: &mut Z::Sink,
.core_sync_svc core_sync_svc: &mut CSync,
address_book: &mut AdrBook,
our_basic_node_data: BasicNodeData,
) -> Result<(), HandshakeError>
where
AdrBook: AddressBook<Z>,
CSync: CoreSyncSvc,
{
let CoreSyncDataResponse::Ours(our_core_sync_data) = core_sync_svc
.ready() .ready()
.await? .await?
.call(CoreSyncDataRequest::Ours) .call(CoreSyncDataRequest::Ours)
@ -230,8 +331,7 @@ where
panic!("core sync service returned wrong response!"); panic!("core sync service returned wrong response!");
}; };
let AddressBookResponse::Peers(our_peer_list) = self let AddressBookResponse::Peers(our_peer_list) = address_book
.address_book
.ready() .ready()
.await? .await?
.call(AddressBookRequest::GetWhitePeers( .call(AddressBookRequest::GetWhitePeers(
@ -243,254 +343,101 @@ where
}; };
let res = HandshakeResponse { let res = HandshakeResponse {
node_data: self.our_basic_node_data.clone(), node_data: our_basic_node_data,
payload_data: our_core_sync_data, payload_data: our_core_sync_data,
local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(), local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(),
}; };
tracing::debug!("Sending handshake response."); tracing::debug!("Sending handshake response.");
self.peer_sink peer_sink
.send(Message::Response(ResponseMessage::Handshake(res))) .send(Message::Response(ResponseMessage::Handshake(res)))
.await?; .await?;
Ok(()) Ok(())
} }
async fn send_support_flags(&mut self) -> Result<(), HandshakeError> { async fn wait_for_message<Z: NetworkZone>(
let res = SupportFlagsResponse { levin_command: LevinCommand,
support_flags: self.our_basic_node_data.support_flags, request: bool,
}; peer_sink: &mut Z::Sink,
peer_stream: &mut Z::Stream,
eager_protocol_messages: &mut Vec<monero_wire::ProtocolMessage>,
allow_support_flag_req: &mut bool,
support_flags: PeerSupportFlags,
) -> Result<Message, HandshakeError> {
while let Some(message) = peer_stream.next().await {
let message = message?;
tracing::debug!("Sending support flag response."); match message {
Message::Protocol(protocol_message) => {
self.peer_sink
.send(Message::Response(ResponseMessage::SupportFlags(res)))
.await?;
Ok(())
}
async fn check_request_support_flags(
&mut self,
support_flags: &PeerSupportFlags,
) -> Result<bool, HandshakeError> {
Ok(if support_flags.is_empty() {
tracing::debug!( tracing::debug!(
"Peer didn't send support flags or has no features, sending request to make sure." "Received eager protocol message with ID: {}, adding to queue",
protocol_message.command()
); );
self.peer_sink eager_protocol_messages.push(protocol_message);
.send(Message::Request(RequestMessage::SupportFlags)) if eager_protocol_messages.len() > MAX_EAGER_PROTOCOL_MESSAGES {
.await?;
true
} else {
false
})
}
async fn handle_handshake_response(
&mut self,
response: HandshakeResponse,
) -> Result<(), HandshakeError> {
if response.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE {
tracing::debug!("peer sent too many peers in response, cancelling handshake");
return Err(HandshakeError::PeerSentInvalidMessage(
"Too many peers in peer list message (>250)",
));
}
if response.node_data.network_id != self.our_basic_node_data.network_id {
return Err(HandshakeError::IncorrectNetwork);
}
if Z::CHECK_NODE_ID && response.node_data.peer_id == self.our_basic_node_data.peer_id {
return Err(HandshakeError::PeerHasSameNodeID);
}
tracing::debug!( tracing::debug!(
"Telling address book about new peers, len: {}", "Peer sent too many protocl messages before a handshake response."
response.local_peerlist_new.len()
); );
self.address_book
.ready()
.await?
.call(AddressBookRequest::IncomingPeerList(
response
.local_peerlist_new
.into_iter()
.map(TryInto::try_into)
.collect::<Result<_, _>>()?,
))
.await?;
if self
.check_request_support_flags(&response.node_data.support_flags)
.await?
{
self.state = HandshakeState::WaitingForSupportFlagResponse(
response.node_data,
response.payload_data,
);
} else {
self.state = HandshakeState::Complete(response.node_data, response.payload_data);
}
Ok(())
}
async fn handle_handshake_request(
&mut self,
request: HandshakeRequest,
) -> Result<(), HandshakeError> {
// We don't respond here as if we did the other peer could accept the handshake before responding to a
// support flag request which then means we could recive other requests while waiting for the support
// flags.
if request.node_data.network_id != self.our_basic_node_data.network_id {
return Err(HandshakeError::IncorrectNetwork);
}
if Z::CHECK_NODE_ID && request.node_data.peer_id == self.our_basic_node_data.peer_id {
return Err(HandshakeError::PeerHasSameNodeID);
}
if self
.check_request_support_flags(&request.node_data.support_flags)
.await?
{
self.state = HandshakeState::WaitingForSupportFlagResponse(
request.node_data,
request.payload_data,
);
} else {
self.state = HandshakeState::Complete(request.node_data, request.payload_data);
}
Ok(())
}
async fn handle_incoming_message(&mut self, message: Message) -> Result<(), HandshakeError> {
tracing::debug!("Received message from peer: {}", message.command());
if let Message::Protocol(protocol_message) = message {
if self.eager_protocol_messages.len() == 2 {
tracing::debug!("Peer sent too many protocl messages before a handshake response.");
return Err(HandshakeError::PeerSentInvalidMessage( return Err(HandshakeError::PeerSentInvalidMessage(
"Peer sent too many protocol messages", "Peer sent too many protocol messages",
)); ));
} }
tracing::debug!( continue;
"Protocol message getting added to queue for when handshake is complete." }
); Message::Request(req_message) => {
self.eager_protocol_messages.push(protocol_message); if req_message.command() == levin_command && request {
return Ok(()); return Ok(Message::Request(req_message));
} }
match std::mem::replace(&mut self.state, HandshakeState::Invalid) { if matches!(req_message, RequestMessage::SupportFlags) {
HandshakeState::Start => match message { if !*allow_support_flag_req {
Message::Request(RequestMessage::Ping) => {
// Set the state back to what it was before.
self.state = HandshakeState::Start;
Ok(self
.peer_sink
.send(Message::Response(ResponseMessage::Ping(PingResponse {
status: PING_OK_RESPONSE_STATUS_TEXT.to_string(),
peer_id: self.our_basic_node_data.peer_id,
})))
.await?)
}
Message::Request(RequestMessage::Handshake(handshake_req)) => {
self.handle_handshake_request(handshake_req).await
}
_ => Err(HandshakeError::PeerSentInvalidMessage(
"Peer didn't send handshake request.",
)),
},
HandshakeState::WaitingForHandshakeResponse => match message {
// TODO: only allow 1 support flag request.
Message::Request(RequestMessage::SupportFlags) => {
// Set the state back to what it was before.
self.state = HandshakeState::WaitingForHandshakeResponse;
self.send_support_flags().await
}
Message::Response(ResponseMessage::Handshake(res)) => {
self.handle_handshake_response(res).await
}
_ => Err(HandshakeError::PeerSentInvalidMessage(
"Peer didn't send handshake response.",
)),
},
HandshakeState::WaitingForSupportFlagResponse(mut peer_node_data, peer_core_sync) => {
let Message::Response(ResponseMessage::SupportFlags(support_flags)) = message
else {
return Err(HandshakeError::PeerSentInvalidMessage( return Err(HandshakeError::PeerSentInvalidMessage(
"Peer didn't send support flags response.", "Peer sent 2 support flag requests",
)); ));
};
peer_node_data.support_flags = support_flags.support_flags;
self.state = HandshakeState::Complete(peer_node_data, peer_core_sync);
Ok(())
}
HandshakeState::Complete(..) => {
panic!("Handshake is complete messages should no longer be handled here!")
}
HandshakeState::Invalid => panic!("Handshake state machine stayed in invalid state!"),
} }
send_support_flags::<Z>(peer_sink, support_flags).await?;
// don't let the peer send more after the first request.
*allow_support_flag_req = false;
continue;
} }
async fn advance_machine(&mut self) -> Result<(), HandshakeError> { return Err(HandshakeError::PeerSentInvalidMessage(
while !self.state.is_complete() { "Peer sent a admin request before responding to the handshake",
tracing::debug!("Waiting for message from peer."); ));
}
Message::Response(res_message) if !request => {
if res_message.command() == levin_command {
return Ok(Message::Response(res_message));
}
match self.peer_stream.next().await { tracing::debug!("Received unexpected response: {}", res_message.command());
Some(message) => self.handle_incoming_message(message?).await?, return Err(HandshakeError::PeerSentInvalidMessage(
None => Err(BucketError::IO(std::io::Error::new( "Peer sent an incorrect response",
));
}
_ => Err(HandshakeError::PeerSentInvalidMessage(
"Peer sent an incorrect message",
)),
}?
}
Err(BucketError::IO(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted, std::io::ErrorKind::ConnectionAborted,
"The peer stream returned None", "The peer stream returned None",
)))?, )))?
} }
}
async fn send_support_flags<Z: NetworkZone>(
Ok(()) peer_sink: &mut Z::Sink,
} support_flags: PeerSupportFlags,
) -> Result<(), HandshakeError> {
async fn do_outbound_handshake(&mut self) -> Result<(), HandshakeError> { tracing::debug!("Sending support flag response.");
self.send_handshake_request().await?; Ok(peer_sink
self.state = HandshakeState::WaitingForHandshakeResponse; .send(Message::Response(ResponseMessage::SupportFlags(
SupportFlagsResponse { support_flags },
self.advance_machine().await )))
} .await?)
async fn do_inbound_handshake(&mut self) -> Result<(), HandshakeError> {
self.advance_machine().await?;
debug_assert!(self.state.is_complete());
self.send_handshake_response().await
}
async fn do_handshake(mut self) -> Result<(), HandshakeError> {
tracing::debug!("Beginning handshake.");
match self.direction {
ConnectionDirection::OutBound => self.do_outbound_handshake().await?,
ConnectionDirection::InBound => self.do_inbound_handshake().await?,
}
let HandshakeState::Complete(peer_node_data, peer_core_sync) = self.state else {
panic!("Hanshake completed not in complete state!");
};
self.core_sync_svc
.ready()
.await?
.call(CoreSyncDataRequest::HandleIncoming(peer_core_sync))
.await?;
tracing::debug!("Handshake complete.");
Ok(())
}
} }