p2p: add handshake timeouts

This commit is contained in:
Boog900 2024-01-22 18:18:15 +00:00
parent 81eec5cbbb
commit f894ff6f1b
No known key found for this signature in database
GPG key ID: 5401367FB7302004
6 changed files with 98 additions and 47 deletions

View file

@ -14,7 +14,7 @@ cuprate-helper = { path = "../../helper" }
monero-wire = { path = "../../net/monero-wire" }
monero-pruning = { path = "../../pruning" }
tokio = { workspace = true, features = ["net", "sync", "macros"]}
tokio = { workspace = true, features = ["net", "sync", "macros", "time"]}
tokio-util = { workspace = true, features = ["codec"] }
tokio-stream = { workspace = true, features = ["sync"]}
futures = { workspace = true, features = ["std", "async-await"] }

View file

@ -4,10 +4,14 @@ use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use futures::{FutureExt, SinkExt, StreamExt};
use tokio::sync::{broadcast, mpsc, OwnedSemaphorePermit};
use tokio::{
sync::{broadcast, mpsc, OwnedSemaphorePermit},
time::{error::Elapsed, timeout},
};
use tower::{Service, ServiceExt};
use tracing::Instrument;
@ -30,22 +34,25 @@ use crate::{
};
const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2;
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(120);
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError {
#[error("peer has the same node ID as us")]
#[error("The handshake timed out")]
TimedOut(#[from] Elapsed),
#[error("Peer has the same node ID as us")]
PeerHasSameNodeID,
#[error("peer is on a different network")]
#[error("Peer is on a different network")]
IncorrectNetwork,
#[error("peer sent a peer list with peers from different zones")]
#[error("Peer sent a peer list with peers from different zones")]
PeerSentIncorrectPeerList(#[from] crate::services::PeerListConversionError),
#[error("peer sent invalid message: {0}")]
#[error("Peer sent invalid message: {0}")]
PeerSentInvalidMessage(&'static str),
#[error("Levin bucket error: {0}")]
LevinBucketError(#[from] BucketError),
#[error("Internal service error: {0}")]
InternalSvcErr(#[from] tower::BoxError),
#[error("i/o error: {0}")]
#[error("I/O error: {0}")]
IO(#[from] std::io::Error),
}
@ -108,14 +115,6 @@ where
}
fn call(&mut self, req: DoHandshakeRequest<Z>) -> Self::Future {
let DoHandshakeRequest {
addr,
peer_stream,
peer_sink,
direction,
permit,
} = req;
let broadcast_rx = self.broadcast_tx.subscribe();
let address_book = self.address_book.clone();
@ -123,37 +122,31 @@ where
let core_sync_svc = self.core_sync_svc.clone();
let our_basic_node_data = self.our_basic_node_data.clone();
let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %addr);
let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %req.addr);
async move {
// TODO: timeouts
timeout(
HANDSHAKE_TIMEOUT,
handshake(
addr,
peer_stream,
peer_sink,
direction,
permit,
req,
broadcast_rx,
address_book,
core_sync_svc,
peer_request_svc,
our_basic_node_data,
),
)
.await
.await?
}
.instrument(span)
.boxed()
}
}
#[allow(clippy::too_many_arguments)]
/// This function completes a handshake with the requested peer.
async fn handshake<Z: NetworkZone, AdrBook, CSync, ReqHdlr>(
addr: InternalPeerID<Z::Addr>,
mut peer_stream: Z::Stream,
mut peer_sink: Z::Sink,
direction: ConnectionDirection,
req: DoHandshakeRequest<Z>,
permit: OwnedSemaphorePermit,
broadcast_rx: broadcast::Receiver<Arc<PeerBroadcast>>,
mut address_book: AdrBook,
@ -166,6 +159,14 @@ where
CSync: CoreSyncSvc,
ReqHdlr: PeerRequestHandler,
{
let DoHandshakeRequest {
addr,
mut peer_stream,
mut peer_sink,
direction,
permit,
} = req;
let mut eager_protocol_messages = Vec::new();
let mut allow_support_flag_req = true;
@ -443,7 +444,7 @@ async fn wait_for_message<Z: NetworkZone>(
}
return Err(HandshakeError::PeerSentInvalidMessage(
"Peer sent a admin request before responding to the handshake",
"Peer sent an admin request before responding to the handshake",
));
}
Message::Response(res_message) if !request => {

View file

@ -61,7 +61,6 @@ pub trait NetZoneAddress:
+ borsh::BorshDeserialize
+ Hash
+ Eq
+ Clone
+ Copy
+ Send
+ Unpin
@ -105,6 +104,10 @@ pub trait NetworkZone: Clone + Copy + Send + 'static {
type Stream: Stream<Item = Result<Message, BucketError>> + Unpin + Send + 'static;
/// The sink (outgoing data) type for this network.
type Sink: Sink<Message, Error = BucketError> + Unpin + Send + 'static;
/// The inbound connection listener for this network.
type Listener: Stream<
Item = Result<(Option<Self::Addr>, Self::Stream, Self::Sink), std::io::Error>,
>;
/// Config used to start a server which listens for incoming connections.
type ServerCfg;
@ -112,7 +115,9 @@ pub trait NetworkZone: Clone + Copy + Send + 'static {
addr: Self::Addr,
) -> Result<(Self::Stream, Self::Sink), std::io::Error>;
async fn incoming_connection_listener(config: Self::ServerCfg) -> ();
async fn incoming_connection_listener(
config: Self::ServerCfg,
) -> Result<Self::Listener, std::io::Error>;
}
pub(crate) trait AddressBook<Z: NetworkZone>:

View file

@ -1,10 +1,13 @@
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use monero_wire::MoneroWireCodec;
use futures::Stream;
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
TcpListener, TcpStream,
};
use tokio_util::codec::{FramedRead, FramedWrite};
@ -22,11 +25,13 @@ impl NetZoneAddress for SocketAddr {
}
}
pub struct ClearNetServerCfg {
addr: SocketAddr,
}
#[derive(Clone, Copy)]
pub struct ClearNet;
pub struct ClearNetServerCfg {}
#[async_trait::async_trait]
impl NetworkZone for ClearNet {
const NAME: &'static str = "ClearNet";
@ -37,8 +42,9 @@ impl NetworkZone for ClearNet {
type Addr = SocketAddr;
type Stream = FramedRead<OwnedReadHalf, MoneroWireCodec>;
type Sink = FramedWrite<OwnedWriteHalf, MoneroWireCodec>;
type Listener = InBoundStream;
type ServerCfg = ();
type ServerCfg = ClearNetServerCfg;
async fn connect_to_peer(
addr: Self::Addr,
@ -50,7 +56,39 @@ impl NetworkZone for ClearNet {
))
}
async fn incoming_connection_listener(config: Self::ServerCfg) -> () {
todo!()
async fn incoming_connection_listener(
config: Self::ServerCfg,
) -> Result<Self::Listener, std::io::Error> {
let listener = TcpListener::bind(config.addr).await?;
Ok(InBoundStream { listener })
}
}
pub struct InBoundStream {
listener: TcpListener,
}
impl Stream for InBoundStream {
type Item = Result<
(
Option<SocketAddr>,
FramedRead<OwnedReadHalf, MoneroWireCodec>,
FramedWrite<OwnedWriteHalf, MoneroWireCodec>,
),
std::io::Error,
>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.listener
.poll_accept(cx)
.map_ok(|(stream, addr)| {
let (read, write) = stream.into_split();
(
Some(addr),
FramedRead::new(read, MoneroWireCodec::default()),
FramedWrite::new(write, MoneroWireCodec::default()),
)
})
.map(Some)
}
}

View file

@ -5,7 +5,7 @@ edition = "2021"
[dependencies]
monero-wire = {path = "../net/monero-wire"}
monero-p2p = {path = "../p2p/monero-p2p" }
monero-p2p = {path = "../p2p/monero-p2p", features = ["borsh"] }
futures = { workspace = true, features = ["std"] }
async-trait = { workspace = true }

View file

@ -7,7 +7,7 @@ use std::{
};
use borsh::{BorshDeserialize, BorshSerialize};
use futures::{channel::mpsc::Sender as InnerSender, stream::BoxStream, Sink};
use futures::{channel::mpsc::Sender as InnerSender, stream::BoxStream, Sink, Stream};
use monero_wire::{
network_address::{NetworkAddress, NetworkAddressIncorrectZone},
@ -111,13 +111,20 @@ impl<const ALLOW_SYNC: bool, const DANDELION_PP: bool, const CHECK_NODE_ID: bool
type Addr = TestNetZoneAddr;
type Stream = BoxStream<'static, Result<Message, BucketError>>;
type Sink = Sender;
type Listener = Pin<
Box<
dyn Stream<
Item = Result<(Option<Self::Addr>, Self::Stream, Self::Sink), std::io::Error>,
>,
>,
>;
type ServerCfg = ();
async fn connect_to_peer(_: Self::Addr) -> Result<(Self::Stream, Self::Sink), Error> {
unimplemented!()
}
async fn incoming_connection_listener(_: Self::ServerCfg) -> () {
async fn incoming_connection_listener(_: Self::ServerCfg) -> Result<Self::Listener, Error> {
unimplemented!()
}
}