From 6875a9a11d8dbccbc314fc408865807f768587ca Mon Sep 17 00:00:00 2001 From: Boog900 <54e72d8a-345f-4599-bd90-c6b9bc7d0ec5@aleeas.com> Date: Thu, 27 Jul 2023 00:45:28 +0100 Subject: [PATCH] add try_from/from conversion between `Message` and `Request`/`Response` --- p2p/src/peer/connection.rs | 33 ++-- p2p/src/peer/error.rs | 14 +- p2p/src/protocol/internal_network.rs | 10 +- p2p/src/protocol/internal_network/try_from.rs | 163 ++++++++++++++++++ 4 files changed, 200 insertions(+), 20 deletions(-) create mode 100644 p2p/src/protocol/internal_network/try_from.rs diff --git a/p2p/src/peer/connection.rs b/p2p/src/peer/connection.rs index 5502c6b..5a60a4b 100644 --- a/p2p/src/peer/connection.rs +++ b/p2p/src/peer/connection.rs @@ -1,8 +1,9 @@ use futures::channel::{mpsc, oneshot}; -use futures::{Sink, SinkExt, Stream}; +use futures::stream::FusedStream; +use futures::{Sink, SinkExt, Stream, StreamExt}; use monero_wire::{BucketError, Message}; -use tower::{BoxError, Service, ServiceExt}; +use tower::{BoxError, Service}; use crate::connection_handle::DisconnectSignal; use crate::peer::error::{ErrorSlot, PeerError, SharedPeerError}; @@ -104,12 +105,15 @@ where self.send_message_to_peer(req.req).await } - async fn state_waiting_for_request(&mut self) -> Result<(), PeerError> { + async fn state_waiting_for_request(&mut self, stream: &mut Str) -> Result<(), PeerError> + where + Str: FusedStream> + Unpin, + { futures::select! { - peer_message = self.stream.next() => { + peer_message = stream.next() => { match peer_message.expect("MessageStream will never return None") { Ok(message) => { - self.handle_peer_request(message.try_into().map_err(|_| PeerError::PeerSentUnexpectedResponse)?).await + self.handle_peer_request(message.try_into().map_err(|_| PeerError::ResponseError(""))?).await }, Err(e) => Err(e.into()), } @@ -120,10 +124,12 @@ where } } - async fn state_waiting_for_response(&mut self) -> Result<(), PeerError> { + async fn state_waiting_for_response(&mut self, stream: &mut Str) -> Result<(), PeerError> + where + Str: FusedStream> + Unpin, + { // put a timeout on this - let peer_message = self - .stream + let peer_message = stream .next() .await .expect("MessageStream will never return None")?; @@ -147,11 +153,16 @@ where } } - pub async fn run(mut self) { + pub async fn run(mut self, mut stream: Str) + where + Str: FusedStream> + Unpin, + { loop { let _res = match self.state { - State::WaitingForRequest => self.state_waiting_for_request().await, - State::WaitingForResponse { .. } => self.state_waiting_for_response().await, + State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await, + State::WaitingForResponse { .. } => { + self.state_waiting_for_response(&mut stream).await + } }; } } diff --git a/p2p/src/peer/error.rs b/p2p/src/peer/error.rs index ee5ddd8..bbf3650 100644 --- a/p2p/src/peer/error.rs +++ b/p2p/src/peer/error.rs @@ -1,8 +1,8 @@ use std::sync::{Arc, Mutex}; +use monero_wire::BucketError; use thiserror::Error; use tracing_error::TracedError; -use monero_wire::BucketError; /// A wrapper around `Arc` that implements `Error`. #[derive(Error, Debug, Clone)] @@ -27,14 +27,18 @@ impl SharedPeerError { } } -#[derive(Debug, Error, Clone)] +#[derive(Debug, Error)] pub enum PeerError { #[error("The connection task has closed.")] ConnectionTaskClosed, + #[error("Error with peers response: {0}.")] + ResponseError(&'static str), + #[error("The connected peer sent an an unexpected response message.")] + PeerSentUnexpectedResponse, #[error("The connected peer sent an incorrect response.")] - PeerSentIncorrectResponse, - #[error("The connected peer sent an incorrect response.")] - BucketError(#[from] BucketError) + BucketError(#[from] BucketError), + #[error("The channel was closed.")] + ClientChannelClosed, } /// A shared error slot for peer errors. diff --git a/p2p/src/protocol/internal_network.rs b/p2p/src/protocol/internal_network.rs index 73351b8..42a419e 100644 --- a/p2p/src/protocol/internal_network.rs +++ b/p2p/src/protocol/internal_network.rs @@ -25,10 +25,12 @@ use monero_wire::{ ChainRequest, ChainResponse, FluffyMissingTransactionsRequest, GetObjectsRequest, GetObjectsResponse, GetTxPoolCompliment, HandshakeRequest, HandshakeResponse, Message, - NewBlock, NewFluffyBlock, NewTransactions, PingResponse, SupportFlagsResponse, + NewBlock, NewFluffyBlock, NewTransactions, PingResponse, RequestMessage, SupportFlagsResponse, TimedSyncRequest, TimedSyncResponse, }; +mod try_from; + /// An enum representing a request/ response combination, so a handshake request /// and response would have the same [`MessageID`]. This allows associating the /// correct response to a request. @@ -42,7 +44,7 @@ pub enum MessageID { GetObjects, GetChain, FluffyMissingTxs, - GetTxPollCompliment, + GetTxPoolCompliment, NewBlock, NewFluffyBlock, NewTransactions, @@ -57,7 +59,7 @@ pub enum Request { GetObjects(GetObjectsRequest), GetChain(ChainRequest), FluffyMissingTxs(FluffyMissingTransactionsRequest), - GetTxPollCompliment(GetTxPoolCompliment), + GetTxPoolCompliment(GetTxPoolCompliment), NewBlock(NewBlock), NewFluffyBlock(NewFluffyBlock), NewTransactions(NewTransactions), @@ -74,7 +76,7 @@ impl Request { Request::GetObjects(_) => MessageID::GetObjects, Request::GetChain(_) => MessageID::GetChain, Request::FluffyMissingTxs(_) => MessageID::FluffyMissingTxs, - Request::GetTxPollCompliment(_) => MessageID::GetTxPollCompliment, + Request::GetTxPoolCompliment(_) => MessageID::GetTxPoolCompliment, Request::NewBlock(_) => MessageID::NewBlock, Request::NewFluffyBlock(_) => MessageID::NewFluffyBlock, Request::NewTransactions(_) => MessageID::NewTransactions, diff --git a/p2p/src/protocol/internal_network/try_from.rs b/p2p/src/protocol/internal_network/try_from.rs new file mode 100644 index 0000000..c8c9ec5 --- /dev/null +++ b/p2p/src/protocol/internal_network/try_from.rs @@ -0,0 +1,163 @@ +//! This module contains the implementations of [`TryFrom`] and [`From`] to convert between +//! [`Message`], [`Request`] and [`Response`]. + +use monero_wire::messages::{Message, ProtocolMessage, RequestMessage, ResponseMessage}; + +use super::{Request, Response}; + +pub struct MessageConversionError; + + +macro_rules! match_body { + (match $value: ident {$($body:tt)*} ($left:pat => $right_ty:expr) $($todo:tt)*) => { + match_body!( match $value { + $left => $right_ty, + $($body)* + } $($todo)* ) + }; + (match $value: ident {$($body:tt)*}) => { + match $value { + $($body)* + } + }; +} + + +macro_rules! from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + impl From<$left_ty> for $right_ty { + fn from(value: $left_ty) -> Self { + match_body!( match value {} + $(($left_ty::$left$(($val))? => $right_ty::$right$(($vall))?))+ + ) + } + } + }; +} + +macro_rules! try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + impl TryFrom<$left_ty> for $right_ty { + type Error = MessageConversionError; + + fn try_from(value: $left_ty) -> Result { + Ok(match_body!( match value { + _ => return Err(MessageConversionError) + } + $(($left_ty::$left$(($val))? => $right_ty::$right$(($vall))?))+ + )) + } + } + }; +} + +macro_rules! from_try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + try_from!($left_ty, $right_ty, {$($left $(($val))? = $right $(($vall))?,)+}); + from!($right_ty, $left_ty, {$($right $(($val))? = $left $(($vall))?,)+}); + }; +} + +macro_rules! try_from_try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + try_from!($left_ty, $right_ty, {$($left $(($val))? = $right $(($vall))?,)+}); + try_from!($right_ty, $left_ty, {$($right $(($val))? = $left $(($val))?,)+}); + }; +} + +from_try_from!(Request, RequestMessage,{ + Handshake(val) = Handshake(val), + Ping = Ping, + SupportFlags = SupportFlags, + TimedSync(val) = TimedSync(val), +}); + +try_from_try_from!(Request, ProtocolMessage,{ + NewBlock(val) = NewBlock(val), + NewFluffyBlock(val) = NewFluffyBlock(val), + GetObjects(val) = GetObjectsRequest(val), + GetChain(val) = ChainRequest(val), + NewTransactions(val) = NewTransactions(val), + FluffyMissingTxs(val) = FluffyMissingTransactionsRequest(val), + GetTxPoolCompliment(val) = GetTxPoolCompliment(val), +}); + + + +impl TryFrom for Request { + type Error = MessageConversionError; + + fn try_from(value: Message) -> Result { + match value { + Message::Request(req) => Ok(req.into()), + Message::Protocol(pro) => pro.try_into(), + _ => Err(MessageConversionError), + } + } +} + +impl From for Message { + fn from(value: Request) -> Self { + match value { + Request::Handshake(val) => Message::Request(RequestMessage::Handshake(val)), + Request::Ping => Message::Request(RequestMessage::Ping), + Request::SupportFlags => Message::Request(RequestMessage::SupportFlags), + Request::TimedSync(val) => Message::Request(RequestMessage::TimedSync(val)), + + Request::NewBlock(val) => Message::Protocol(ProtocolMessage::NewBlock(val)), + Request::NewFluffyBlock(val) => Message::Protocol(ProtocolMessage::NewFluffyBlock(val)), + Request::GetObjects(val) => Message::Protocol(ProtocolMessage::GetObjectsRequest(val)), + Request::GetChain(val) => Message::Protocol(ProtocolMessage::ChainRequest(val)), + Request::NewTransactions(val) => Message::Protocol(ProtocolMessage::NewTransactions(val)), + Request::FluffyMissingTxs(val) => Message::Protocol(ProtocolMessage::FluffyMissingTransactionsRequest(val)), + Request::GetTxPoolCompliment(val) => Message::Protocol(ProtocolMessage::GetTxPoolCompliment(val)), + } + } +} + +from_try_from!(Response, ResponseMessage,{ + Handshake(val) = Handshake(val), + Ping(val) = Ping(val), + SupportFlags(val) = SupportFlags(val), + TimedSync(val) = TimedSync(val), +}); + +try_from_try_from!(Response, ProtocolMessage,{ + NewFluffyBlock(val) = NewFluffyBlock(val), + GetObjects(val) = GetObjectsResponse(val), + GetChain(val) = ChainEntryResponse(val), + NewTransactions(val) = NewTransactions(val), + +}); + +impl TryFrom for Response { + type Error = MessageConversionError; + + fn try_from(value: Message) -> Result { + match value { + Message::Response(res) => Ok(res.into()), + Message::Protocol(pro) => pro.try_into(), + _ => Err(MessageConversionError), + } + } +} + +impl TryFrom for Message { + type Error = MessageConversionError; + + fn try_from(value: Response) -> Result { + Ok(match value { + Response::Handshake(val) => Message::Response(ResponseMessage::Handshake(val)), + Response::Ping(val) => Message::Response(ResponseMessage::Ping(val)), + Response::SupportFlags(val) => Message::Response(ResponseMessage::SupportFlags(val)), + Response::TimedSync(val) => Message::Response(ResponseMessage::TimedSync(val)), + + Response::NewFluffyBlock(val) => Message::Protocol(ProtocolMessage::NewFluffyBlock(val)), + Response::GetObjects(val) => Message::Protocol(ProtocolMessage::GetObjectsResponse(val)), + Response::GetChain(val) => Message::Protocol(ProtocolMessage::ChainEntryResponse(val)), + Response::NewTransactions(val) => Message::Protocol(ProtocolMessage::NewTransactions(val)), + + Response::NA => return Err(MessageConversionError), + }) + } +}