p2p: add handshake timeouts

This commit is contained in:
Boog900 2024-01-22 18:18:15 +00:00
parent 81eec5cbbb
commit f894ff6f1b
No known key found for this signature in database
GPG key ID: 5401367FB7302004
6 changed files with 98 additions and 47 deletions

View file

@ -14,7 +14,7 @@ cuprate-helper = { path = "../../helper" }
monero-wire = { path = "../../net/monero-wire" } monero-wire = { path = "../../net/monero-wire" }
monero-pruning = { path = "../../pruning" } monero-pruning = { path = "../../pruning" }
tokio = { workspace = true, features = ["net", "sync", "macros"]} tokio = { workspace = true, features = ["net", "sync", "macros", "time"]}
tokio-util = { workspace = true, features = ["codec"] } tokio-util = { workspace = true, features = ["codec"] }
tokio-stream = { workspace = true, features = ["sync"]} tokio-stream = { workspace = true, features = ["sync"]}
futures = { workspace = true, features = ["std", "async-await"] } futures = { workspace = true, features = ["std", "async-await"] }

View file

@ -4,10 +4,14 @@ use std::{
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
time::Duration,
}; };
use futures::{FutureExt, SinkExt, StreamExt}; use futures::{FutureExt, SinkExt, StreamExt};
use tokio::sync::{broadcast, mpsc, OwnedSemaphorePermit}; use tokio::{
sync::{broadcast, mpsc, OwnedSemaphorePermit},
time::{error::Elapsed, timeout},
};
use tower::{Service, ServiceExt}; use tower::{Service, ServiceExt};
use tracing::Instrument; use tracing::Instrument;
@ -30,22 +34,25 @@ use crate::{
}; };
const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2; const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2;
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(120);
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum HandshakeError { pub enum HandshakeError {
#[error("peer has the same node ID as us")] #[error("The handshake timed out")]
TimedOut(#[from] Elapsed),
#[error("Peer has the same node ID as us")]
PeerHasSameNodeID, PeerHasSameNodeID,
#[error("peer is on a different network")] #[error("Peer is on a different network")]
IncorrectNetwork, IncorrectNetwork,
#[error("peer sent a peer list with peers from different zones")] #[error("Peer sent a peer list with peers from different zones")]
PeerSentIncorrectPeerList(#[from] crate::services::PeerListConversionError), PeerSentIncorrectPeerList(#[from] crate::services::PeerListConversionError),
#[error("peer sent invalid message: {0}")] #[error("Peer sent invalid message: {0}")]
PeerSentInvalidMessage(&'static str), PeerSentInvalidMessage(&'static str),
#[error("Levin bucket error: {0}")] #[error("Levin bucket error: {0}")]
LevinBucketError(#[from] BucketError), LevinBucketError(#[from] BucketError),
#[error("Internal service error: {0}")] #[error("Internal service error: {0}")]
InternalSvcErr(#[from] tower::BoxError), InternalSvcErr(#[from] tower::BoxError),
#[error("i/o error: {0}")] #[error("I/O error: {0}")]
IO(#[from] std::io::Error), IO(#[from] std::io::Error),
} }
@ -108,14 +115,6 @@ where
} }
fn call(&mut self, req: DoHandshakeRequest<Z>) -> Self::Future { fn call(&mut self, req: DoHandshakeRequest<Z>) -> Self::Future {
let DoHandshakeRequest {
addr,
peer_stream,
peer_sink,
direction,
permit,
} = req;
let broadcast_rx = self.broadcast_tx.subscribe(); let broadcast_rx = self.broadcast_tx.subscribe();
let address_book = self.address_book.clone(); let address_book = self.address_book.clone();
@ -123,37 +122,31 @@ where
let core_sync_svc = self.core_sync_svc.clone(); let core_sync_svc = self.core_sync_svc.clone();
let our_basic_node_data = self.our_basic_node_data.clone(); let our_basic_node_data = self.our_basic_node_data.clone();
let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %addr); let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %req.addr);
async move { async move {
// TODO: timeouts timeout(
HANDSHAKE_TIMEOUT,
handshake( handshake(
addr, req,
peer_stream,
peer_sink,
direction,
permit,
broadcast_rx, broadcast_rx,
address_book, address_book,
core_sync_svc, core_sync_svc,
peer_request_svc, peer_request_svc,
our_basic_node_data, our_basic_node_data,
),
) )
.await .await?
} }
.instrument(span) .instrument(span)
.boxed() .boxed()
} }
} }
#[allow(clippy::too_many_arguments)] /// This function completes a handshake with the requested peer.
async fn handshake<Z: NetworkZone, AdrBook, CSync, ReqHdlr>( async fn handshake<Z: NetworkZone, AdrBook, CSync, ReqHdlr>(
addr: InternalPeerID<Z::Addr>, req: DoHandshakeRequest<Z>,
mut peer_stream: Z::Stream,
mut peer_sink: Z::Sink,
direction: ConnectionDirection,
permit: OwnedSemaphorePermit,
broadcast_rx: broadcast::Receiver<Arc<PeerBroadcast>>, broadcast_rx: broadcast::Receiver<Arc<PeerBroadcast>>,
mut address_book: AdrBook, mut address_book: AdrBook,
@ -166,6 +159,14 @@ where
CSync: CoreSyncSvc, CSync: CoreSyncSvc,
ReqHdlr: PeerRequestHandler, ReqHdlr: PeerRequestHandler,
{ {
let DoHandshakeRequest {
addr,
mut peer_stream,
mut peer_sink,
direction,
permit,
} = req;
let mut eager_protocol_messages = Vec::new(); let mut eager_protocol_messages = Vec::new();
let mut allow_support_flag_req = true; let mut allow_support_flag_req = true;
@ -443,7 +444,7 @@ async fn wait_for_message<Z: NetworkZone>(
} }
return Err(HandshakeError::PeerSentInvalidMessage( return Err(HandshakeError::PeerSentInvalidMessage(
"Peer sent a admin request before responding to the handshake", "Peer sent an admin request before responding to the handshake",
)); ));
} }
Message::Response(res_message) if !request => { Message::Response(res_message) if !request => {

View file

@ -61,7 +61,6 @@ pub trait NetZoneAddress:
+ borsh::BorshDeserialize + borsh::BorshDeserialize
+ Hash + Hash
+ Eq + Eq
+ Clone
+ Copy + Copy
+ Send + Send
+ Unpin + Unpin
@ -105,6 +104,10 @@ pub trait NetworkZone: Clone + Copy + Send + 'static {
type Stream: Stream<Item = Result<Message, BucketError>> + Unpin + Send + 'static; type Stream: Stream<Item = Result<Message, BucketError>> + Unpin + Send + 'static;
/// The sink (outgoing data) type for this network. /// The sink (outgoing data) type for this network.
type Sink: Sink<Message, Error = BucketError> + Unpin + Send + 'static; type Sink: Sink<Message, Error = BucketError> + Unpin + Send + 'static;
/// The inbound connection listener for this network.
type Listener: Stream<
Item = Result<(Option<Self::Addr>, Self::Stream, Self::Sink), std::io::Error>,
>;
/// Config used to start a server which listens for incoming connections. /// Config used to start a server which listens for incoming connections.
type ServerCfg; type ServerCfg;
@ -112,7 +115,9 @@ pub trait NetworkZone: Clone + Copy + Send + 'static {
addr: Self::Addr, addr: Self::Addr,
) -> Result<(Self::Stream, Self::Sink), std::io::Error>; ) -> Result<(Self::Stream, Self::Sink), std::io::Error>;
async fn incoming_connection_listener(config: Self::ServerCfg) -> (); async fn incoming_connection_listener(
config: Self::ServerCfg,
) -> Result<Self::Listener, std::io::Error>;
} }
pub(crate) trait AddressBook<Z: NetworkZone>: pub(crate) trait AddressBook<Z: NetworkZone>:

View file

@ -1,10 +1,13 @@
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use monero_wire::MoneroWireCodec; use monero_wire::MoneroWireCodec;
use futures::Stream;
use tokio::net::{ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf}, tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream, TcpListener, TcpStream,
}; };
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
@ -22,11 +25,13 @@ impl NetZoneAddress for SocketAddr {
} }
} }
pub struct ClearNetServerCfg {
addr: SocketAddr,
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct ClearNet; pub struct ClearNet;
pub struct ClearNetServerCfg {}
#[async_trait::async_trait] #[async_trait::async_trait]
impl NetworkZone for ClearNet { impl NetworkZone for ClearNet {
const NAME: &'static str = "ClearNet"; const NAME: &'static str = "ClearNet";
@ -37,8 +42,9 @@ impl NetworkZone for ClearNet {
type Addr = SocketAddr; type Addr = SocketAddr;
type Stream = FramedRead<OwnedReadHalf, MoneroWireCodec>; type Stream = FramedRead<OwnedReadHalf, MoneroWireCodec>;
type Sink = FramedWrite<OwnedWriteHalf, MoneroWireCodec>; type Sink = FramedWrite<OwnedWriteHalf, MoneroWireCodec>;
type Listener = InBoundStream;
type ServerCfg = (); type ServerCfg = ClearNetServerCfg;
async fn connect_to_peer( async fn connect_to_peer(
addr: Self::Addr, addr: Self::Addr,
@ -50,7 +56,39 @@ impl NetworkZone for ClearNet {
)) ))
} }
async fn incoming_connection_listener(config: Self::ServerCfg) -> () { async fn incoming_connection_listener(
todo!() config: Self::ServerCfg,
) -> Result<Self::Listener, std::io::Error> {
let listener = TcpListener::bind(config.addr).await?;
Ok(InBoundStream { listener })
}
}
pub struct InBoundStream {
listener: TcpListener,
}
impl Stream for InBoundStream {
type Item = Result<
(
Option<SocketAddr>,
FramedRead<OwnedReadHalf, MoneroWireCodec>,
FramedWrite<OwnedWriteHalf, MoneroWireCodec>,
),
std::io::Error,
>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.listener
.poll_accept(cx)
.map_ok(|(stream, addr)| {
let (read, write) = stream.into_split();
(
Some(addr),
FramedRead::new(read, MoneroWireCodec::default()),
FramedWrite::new(write, MoneroWireCodec::default()),
)
})
.map(Some)
} }
} }

View file

@ -5,7 +5,7 @@ edition = "2021"
[dependencies] [dependencies]
monero-wire = {path = "../net/monero-wire"} monero-wire = {path = "../net/monero-wire"}
monero-p2p = {path = "../p2p/monero-p2p" } monero-p2p = {path = "../p2p/monero-p2p", features = ["borsh"] }
futures = { workspace = true, features = ["std"] } futures = { workspace = true, features = ["std"] }
async-trait = { workspace = true } async-trait = { workspace = true }

View file

@ -7,7 +7,7 @@ use std::{
}; };
use borsh::{BorshDeserialize, BorshSerialize}; use borsh::{BorshDeserialize, BorshSerialize};
use futures::{channel::mpsc::Sender as InnerSender, stream::BoxStream, Sink}; use futures::{channel::mpsc::Sender as InnerSender, stream::BoxStream, Sink, Stream};
use monero_wire::{ use monero_wire::{
network_address::{NetworkAddress, NetworkAddressIncorrectZone}, network_address::{NetworkAddress, NetworkAddressIncorrectZone},
@ -111,13 +111,20 @@ impl<const ALLOW_SYNC: bool, const DANDELION_PP: bool, const CHECK_NODE_ID: bool
type Addr = TestNetZoneAddr; type Addr = TestNetZoneAddr;
type Stream = BoxStream<'static, Result<Message, BucketError>>; type Stream = BoxStream<'static, Result<Message, BucketError>>;
type Sink = Sender; type Sink = Sender;
type Listener = Pin<
Box<
dyn Stream<
Item = Result<(Option<Self::Addr>, Self::Stream, Self::Sink), std::io::Error>,
>,
>,
>;
type ServerCfg = (); type ServerCfg = ();
async fn connect_to_peer(_: Self::Addr) -> Result<(Self::Stream, Self::Sink), Error> { async fn connect_to_peer(_: Self::Addr) -> Result<(Self::Stream, Self::Sink), Error> {
unimplemented!() unimplemented!()
} }
async fn incoming_connection_listener(_: Self::ServerCfg) -> () { async fn incoming_connection_listener(_: Self::ServerCfg) -> Result<Self::Listener, Error> {
unimplemented!() unimplemented!()
} }
} }