change the peer module to use the new API + fix a couple bugs

This commit is contained in:
Boog900 2023-07-21 15:24:14 +01:00
parent 8981260750
commit e5d5730ad6
No known key found for this signature in database
GPG key ID: 5401367FB7302004
15 changed files with 447 additions and 422 deletions

View file

@ -61,7 +61,7 @@ pub enum BucketError {
InvalidFragmentedMessage(&'static str), InvalidFragmentedMessage(&'static str),
/// Error decoding the body /// Error decoding the body
#[error("Error decoding bucket body")] #[error("Error decoding bucket body")]
BodyDecodingError(Box<dyn Debug + Send + Sync>), BodyDecodingError(Box<dyn std::error::Error + Send + Sync>),
/// I/O error /// I/O error
#[error("I/O error: {0}")] #[error("I/O error: {0}")]
IO(#[from] std::io::Error), IO(#[from] std::io::Error),

View file

@ -155,7 +155,8 @@ pub struct PeerListEntryBase {
/// The Peer Address /// The Peer Address
pub adr: NetworkAddress, pub adr: NetworkAddress,
/// The Peer ID /// The Peer ID
pub id: u64, #[epee_try_from_into(u64)]
pub id: PeerID,
/// The last Time The Peer Was Seen /// The last Time The Peer Was Seen
#[epee_default(0)] #[epee_default(0)]
pub last_seen: i64, pub last_seen: i64,

View file

@ -18,5 +18,6 @@ tokio-util = {version = "0.7.8", features=["codec"]}
tokio-stream = {version="0.1.14", features=["time"]} tokio-stream = {version="0.1.14", features=["time"]}
async-trait = "0.1.68" async-trait = "0.1.68"
tracing = "0.1.37" tracing = "0.1.37"
tracing-error = "0.2.0"
rand = "0.8.5" rand = "0.8.5"
pin-project = "1.0.12" pin-project = "1.0.12"

View file

@ -0,0 +1,130 @@
//! Counting active connections used by Cuprate.
//!
//! These types can be used to count any kind of active resource.
//! But they are currently used to track the number of open connections.
use std::{fmt, sync::Arc};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// A counter for active connections.
///
/// Creates a [`ConnectionTracker`] to track each active connection.
/// When these trackers are dropped, the counter gets notified.
pub struct ActiveConnectionCounter {
/// The limit for this type of connection, for diagnostics only.
/// The caller must enforce the limit by ignoring, delaying, or dropping connections.
limit: usize,
/// The label for this connection counter, typically its type.
label: Arc<str>,
semaphore: Arc<Semaphore>,
}
impl fmt::Debug for ActiveConnectionCounter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ActiveConnectionCounter")
.field("label", &self.label)
.field("count", &self.count())
.field("limit", &self.limit)
.finish()
}
}
impl ActiveConnectionCounter {
/// Create and return a new active connection counter.
pub fn new_counter() -> Self {
Self::new_counter_with(Semaphore::MAX_PERMITS, "Active Connections")
}
/// Create and return a new active connection counter with `limit` and `label`.
/// The caller must check and enforce limits using [`update_count()`](Self::update_count).
pub fn new_counter_with<S: ToString>(limit: usize, label: S) -> Self {
let label = label.to_string();
Self {
limit,
label: label.into(),
semaphore: Arc::new(Semaphore::new(limit)),
}
}
/// Create and return a new [`ConnectionTracker`], using a permit from the semaphore,
/// SAFETY:
/// This function will panic if the semaphore doesn't have anymore permits.
pub fn track_connection(&mut self) -> ConnectionTracker {
ConnectionTracker::new(self)
}
pub fn count(&self) -> usize {
let count = self
.limit
.checked_sub(self.semaphore.available_permits())
.expect("Limit is less than available connection permits");
tracing::trace!(
open_connections = ?count,
limit = ?self.limit,
label = ?self.label,
);
count
}
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
}
/// A per-connection tracker.
///
/// [`ActiveConnectionCounter`] creates a tracker instance for each active connection.
pub struct ConnectionTracker {
/// The permit for this connection, updates the semaphore when dropped.
permit: OwnedSemaphorePermit,
/// The label for this connection counter, typically its type.
label: Arc<str>,
}
impl fmt::Debug for ConnectionTracker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ConnectionTracker")
.field(&self.label)
.finish()
}
}
impl ConnectionTracker {
/// Create and return a new active connection tracker, and add 1 to `counter`.
/// All connection trackers share a label with their connection counter.
///
/// When the returned tracker is dropped, `counter` will be notified.
///
/// SAFETY:
/// This function will panic if the [`ActiveConnectionCounter`] doesn't have anymore permits.
fn new(counter: &mut ActiveConnectionCounter) -> Self {
tracing::debug!(
open_connections = ?counter.count(),
limit = ?counter.limit,
label = ?counter.label,
"opening a new peer connection",
);
Self {
permit: counter.semaphore.clone().try_acquire_owned().unwrap(),
label: counter.label.clone(),
}
}
}
impl Drop for ConnectionTracker {
fn drop(&mut self) {
tracing::debug!(
label = ?self.label,
"A peer connection has closed",
);
// the permit is automatically dropped
}
}

View file

@ -1,132 +1,98 @@
//! This module contains the address book [`Connection`](crate::peer::connection::Connection) handle
//! //!
//! # Why do we need a handle between the address book and connection task //! # Why do we need a handle between the address book and connection task
//! //!
//! When banning a peer we need to tell the connection task to close and //! When banning a peer we need to tell the connection task to close and
//! when we close a connection we need to remove it from our connection //! when we close a connection we need to tell the address book.
//! and anchor list.
//! //!
//! //!
use futures::channel::{mpsc, oneshot}; use std::time::Duration;
use futures::{FutureExt, SinkExt, StreamExt};
/// A message sent to tell the address book that a peer has disconnected. use futures::channel::mpsc;
pub struct PeerConnectionClosed(Option<std::time::Duration>); use futures::SinkExt;
use tokio_util::sync::CancellationToken;
pub enum ConnectionClosed { use crate::connection_counter::ConnectionTracker;
Closed(Option<std::time::Duration>),
Running, #[derive(Default, Debug)]
pub struct HandleBuilder {
tracker: Option<ConnectionTracker>,
} }
/// The connection side of the address book to connection impl HandleBuilder {
/// communication. pub fn set_tracker(&mut self, tracker: ConnectionTracker) {
#[derive(Debug)] self.tracker = Some(tracker)
pub struct ConnectionHandleAddressBookSide {
connection_closed_rx: oneshot::Receiver<PeerConnectionClosed>,
ban_tx: mpsc::Sender<std::time::Duration>,
}
impl ConnectionHandleAddressBookSide {
pub fn ban_peer(&mut self, duration: std::time::Duration) {
let _ = self.ban_tx.send(duration);
} }
pub fn check_connection_closed(&mut self) -> ConnectionClosed { pub fn build(self) -> (DisconnectSignal, ConnectionHandle, PeerHandle) {
let connection_closed = self let token = CancellationToken::new();
.connection_closed_rx let (tx, rx) = mpsc::channel(0);
.try_recv()
.expect("Will not be cancelled"); (
match connection_closed { DisconnectSignal {
Some(closed) => return ConnectionClosed::Closed(closed.0), token: token.clone(),
None => ConnectionClosed::Running, tracker: self.tracker.expect("Tracker was not set!"),
},
ConnectionHandle {
token: token.clone(),
ban: rx,
},
PeerHandle { ban: tx },
)
}
}
pub struct BanPeer(pub Duration);
/// A struct given to the connection task.
pub struct DisconnectSignal {
token: CancellationToken,
tracker: ConnectionTracker,
}
impl DisconnectSignal {
pub fn should_shutdown(&self) -> bool {
self.token.is_cancelled()
}
pub fn connection_closed(&self) {
self.token.cancel()
}
}
impl Drop for DisconnectSignal {
fn drop(&mut self) {
self.token.cancel()
}
}
/// A handle given to a task that needs to cancel this connection.
pub struct ConnectionHandle {
token: CancellationToken,
ban: mpsc::Receiver<BanPeer>,
}
impl ConnectionHandle {
pub fn is_closed(&self) -> bool {
self.token.is_cancelled()
}
pub fn check_should_ban(&mut self) -> Option<BanPeer> {
match self.ban.try_next() {
Ok(res) => res,
Err(_) => None,
} }
} }
} pub fn send_close_signal(&self) {
self.token.cancel()
/// The address book side of the address book to connection
/// communication.
#[derive(Debug)]
pub struct ConnectionHandleConnectionSide {
connection_closed_tx: Option<oneshot::Sender<PeerConnectionClosed>>,
ban_rx: mpsc::Receiver<std::time::Duration>,
ban_holder: Option<std::time::Duration>,
}
impl ConnectionHandleConnectionSide {
pub fn been_banned(&mut self) -> bool {
let ban_time =
self.ban_rx.next().now_or_never().and_then(|inner| {
Some(inner.expect("Handles to the connection task wont be dropped"))
});
let ret = ban_time.is_some();
self.ban_holder = ban_time;
ret
} }
} }
impl Drop for ConnectionHandleConnectionSide { /// A handle given to a task that needs to be able to ban a connection.
fn drop(&mut self) { #[derive(Clone)]
let tx = std::mem::replace(&mut self.connection_closed_tx, None) pub struct PeerHandle {
.expect("Will only be dropped once"); ban: mpsc::Sender<BanPeer>,
let _ = tx.send(PeerConnectionClosed(self.ban_holder)); }
}
} impl PeerHandle {
pub fn ban_peer(&mut self, duration: Duration) {
pub struct ConnectionHandleClientSide { let _ = self.ban.send(BanPeer(duration));
ban_tx: mpsc::Sender<std::time::Duration>,
}
impl ConnectionHandleClientSide {
pub fn ban_peer(&mut self, duration: std::time::Duration) {
let _ = self.ban_tx.send(duration);
}
}
/// Creates a new handle pair that can be given to the connection task and
/// address book respectively.
pub fn new_address_book_connection_handle() -> (
ConnectionHandleConnectionSide,
ConnectionHandleAddressBookSide,
ConnectionHandleClientSide,
) {
let (tx_closed, rx_closed) = oneshot::channel();
let (tx_ban, rx_ban) = mpsc::channel(0);
let c_h_c_s = ConnectionHandleConnectionSide {
connection_closed_tx: Some(tx_closed),
ban_rx: rx_ban,
ban_holder: None,
};
let c_h_a_s = ConnectionHandleAddressBookSide {
connection_closed_rx: rx_closed,
ban_tx: tx_ban.clone(),
};
let c_h_cl_s = ConnectionHandleClientSide { ban_tx: tx_ban };
(c_h_c_s, c_h_a_s, c_h_cl_s)
}
#[cfg(test)]
mod tests {
use super::new_address_book_connection_handle;
#[test]
fn close_connection_from_address_book() {
let (conn_side, mut addr_side) = new_address_book_connection_handle();
assert!(!conn_side.is_canceled());
assert!(!addr_side.connection_closed());
addr_side.kill_connection();
assert!(conn_side.is_canceled());
}
#[test]
fn close_connection_from_connection() {
let (conn_side, mut addr_side) = new_address_book_connection_handle();
assert!(!conn_side.is_canceled());
assert!(!addr_side.connection_closed());
drop(conn_side);
assert!(addr_side.connection_closed());
} }
} }

View file

@ -1,174 +0,0 @@
//! Counting active connections used by Cuprate.
//!
//! These types can be used to count any kind of active resource.
//! But they are currently used to track the number of open connections.
//!
//! This file was originally from Zebra
use std::{fmt, sync::Arc};
use std::sync::mpsc;
/// A signal sent by a [`Connection`][1] when it closes.
///
/// Used to count the number of open connections.
///
/// [1]: crate::peer::Connection
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionClosed;
/// A counter for active connections.
///
/// Creates a [`ConnectionTracker`] to track each active connection.
/// When these trackers are dropped, the counter gets notified.
pub struct ActiveConnectionCounter {
/// The number of active peers tracked using this counter.
count: usize,
/// The limit for this type of connection, for diagnostics only.
/// The caller must enforce the limit by ignoring, delaying, or dropping connections.
limit: usize,
/// The label for this connection counter, typically its type.
label: Arc<str>,
/// The channel used to send closed connection notifications.
close_notification_tx: mpsc::Sender<ConnectionClosed>,
/// The channel used to receive closed connection notifications.
close_notification_rx: mpsc::Receiver<ConnectionClosed>,
}
impl fmt::Debug for ActiveConnectionCounter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ActiveConnectionCounter")
.field("label", &self.label)
.field("count", &self.count)
.field("limit", &self.limit)
.finish()
}
}
impl ActiveConnectionCounter {
/// Create and return a new active connection counter.
pub fn new_counter() -> Self {
Self::new_counter_with(usize::MAX, "Active Connections")
}
/// Create and return a new active connection counter with `limit` and `label`.
/// The caller must check and enforce limits using [`update_count()`](Self::update_count).
pub fn new_counter_with<S: ToString>(limit: usize, label: S) -> Self {
// The number of items in this channel is bounded by the connection limit.
let (close_notification_tx, close_notification_rx) = mpsc::channel();
let label = label.to_string();
Self {
count: 0,
limit,
label: label.into(),
close_notification_rx,
close_notification_tx,
}
}
/// Create and return a new [`ConnectionTracker`], and add 1 to this counter.
///
/// When the returned tracker is dropped, this counter will be notified, and decreased by 1.
pub fn track_connection(&mut self) -> ConnectionTracker {
ConnectionTracker::new(self)
}
/// Check for closed connection notifications, and return the current connection count.
pub fn update_count(&mut self) -> usize {
let previous_connections = self.count;
// We ignore errors here:
// - TryRecvError::Empty means that there are no pending close notifications
// - TryRecvError::Closed is unreachable, because we hold a sender
while let Ok(ConnectionClosed) = self.close_notification_rx.try_recv() {
self.count -= 1;
tracing::debug!(
open_connections = ?self.count,
?previous_connections,
limit = ?self.limit,
label = ?self.label,
"a peer connection was closed",
);
}
tracing::trace!(
open_connections = ?self.count,
?previous_connections,
limit = ?self.limit,
label = ?self.label,
"updated active connection count",
);
#[cfg(feature = "progress-bar")]
self.connection_bar
.set_pos(u64::try_from(self.count).expect("fits in u64"))
.set_len(u64::try_from(self.limit).expect("fits in u64"));
self.count
}
}
impl Drop for ActiveConnectionCounter {
fn drop(&mut self) {}
}
/// A per-connection tracker.
///
/// [`ActiveConnectionCounter`] creates a tracker instance for each active connection.
/// When these trackers are dropped, the counter gets notified.
pub struct ConnectionTracker {
/// The channel used to send closed connection notifications on drop.
close_notification_tx: mpsc::Sender<ConnectionClosed>,
/// The label for this connection counter, typically its type.
label: Arc<str>,
}
impl fmt::Debug for ConnectionTracker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ConnectionTracker")
.field(&self.label)
.finish()
}
}
impl ConnectionTracker {
/// Create and return a new active connection tracker, and add 1 to `counter`.
/// All connection trackers share a label with their connection counter.
///
/// When the returned tracker is dropped, `counter` will be notified, and decreased by 1.
fn new(counter: &mut ActiveConnectionCounter) -> Self {
counter.count += 1;
tracing::debug!(
open_connections = ?counter.count,
limit = ?counter.limit,
label = ?counter.label,
"opening a new peer connection",
);
Self {
close_notification_tx: counter.close_notification_tx.clone(),
label: counter.label.clone(),
}
}
}
impl Drop for ConnectionTracker {
/// Notifies the corresponding connection counter that the connection has closed.
fn drop(&mut self) {
tracing::debug!(label = ?self.label, "closing a peer connection");
// We ignore disconnected errors, because the receiver can be dropped
// before some connections are dropped.
//
let _ = self.close_notification_tx.send(ConnectionClosed);
}
}

View file

@ -1,6 +1,7 @@
pub mod address_book; pub mod address_book;
pub mod config; pub mod config;
pub mod connection_tracker; mod connection_handle;
pub mod connection_counter;
mod constants; mod constants;
pub mod peer; pub mod peer;
pub mod peer_set; pub mod peer_set;

View file

@ -4,21 +4,10 @@ pub mod connector;
pub mod handshaker; pub mod handshaker;
pub mod load_tracked_client; pub mod load_tracked_client;
mod error;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
use monero_wire::levin::BucketError;
use thiserror::Error;
#[derive(Debug, Error, Clone, Copy)]
pub enum RequestServiceError {}
#[derive(Debug, Error, Clone, Copy)]
pub enum PeerError {
#[error("The connection task has closed.")]
ConnectionTaskClosed,
}
pub use client::Client; pub use client::Client;
pub use client::ConnectionInfo; pub use client::ConnectionInfo;
pub use connection::Connection; pub use connection::Connection;

View file

@ -10,21 +10,21 @@ use futures::{
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tower::BoxError; use tower::BoxError;
use crate::peer::handshaker::ConnectionAddr;
use cuprate_common::shutdown::set_shutting_down;
use cuprate_common::PruningSeed; use cuprate_common::PruningSeed;
use monero_wire::messages::PeerID;
use monero_wire::{messages::common::PeerSupportFlags, NetworkAddress}; use monero_wire::{messages::common::PeerSupportFlags, NetworkAddress};
use super::{connection::ClientRequest, PeerError}; use super::{
connection::ClientRequest,
error::{ErrorSlot, PeerError, SharedPeerError},
PeerError,
};
use crate::connection_handle::PeerHandle;
use crate::protocol::{InternalMessageRequest, InternalMessageResponse}; use crate::protocol::{InternalMessageRequest, InternalMessageResponse};
pub struct ConnectionInfo { pub struct ConnectionInfo {
pub addr: ConnectionAddr,
pub support_flags: PeerSupportFlags, pub support_flags: PeerSupportFlags,
pub pruning_seed: PruningSeed, pub pruning_seed: PruningSeed,
/// Peer ID pub handle: PeerHandle,
pub peer_id: PeerID,
pub rpc_port: u16, pub rpc_port: u16,
pub rpc_credits_per_hash: u32, pub rpc_credits_per_hash: u32,
} }
@ -37,6 +37,8 @@ pub struct Client {
server_tx: mpsc::Sender<ClientRequest>, server_tx: mpsc::Sender<ClientRequest>,
connection_task: JoinHandle<()>, connection_task: JoinHandle<()>,
heartbeat_task: JoinHandle<()>, heartbeat_task: JoinHandle<()>,
error_slot: ErrorSlot,
} }
impl Client { impl Client {
@ -46,6 +48,7 @@ impl Client {
server_tx: mpsc::Sender<ClientRequest>, server_tx: mpsc::Sender<ClientRequest>,
connection_task: JoinHandle<()>, connection_task: JoinHandle<()>,
heartbeat_task: JoinHandle<()>, heartbeat_task: JoinHandle<()>,
error_slot: ErrorSlot,
) -> Self { ) -> Self {
Client { Client {
connection_info, connection_info,
@ -53,6 +56,7 @@ impl Client {
server_tx, server_tx,
connection_task, connection_task,
heartbeat_task, heartbeat_task,
error_slot,
} }
} }
@ -140,7 +144,7 @@ impl Client {
impl tower::Service<InternalMessageRequest> for Client { impl tower::Service<InternalMessageRequest> for Client {
type Response = InternalMessageResponse; type Response = InternalMessageResponse;
type Error = PeerError; type Error = SharedPeerError;
type Future = type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@ -163,7 +167,7 @@ impl tower::Service<InternalMessageRequest> for Client {
.map_err(|e| e.into()) .map_err(|e| e.into())
}) })
.boxed(), .boxed(),
Err(_e) => { Err(_) => {
// TODO: better error handling // TODO: better error handling
futures::future::ready(Err(PeerError::ClientChannelClosed.into())).boxed() futures::future::ready(Err(PeerError::ClientChannelClosed.into())).boxed()
} }

View file

@ -1,123 +1,77 @@
use std::collections::HashSet;
use std::sync::atomic::AtomicU64;
use futures::channel::{mpsc, oneshot}; use futures::channel::{mpsc, oneshot};
use futures::stream::Fuse; use futures::{Sink, SinkExt, Stream};
use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
use crate::address_book::connection_handle::AddressBookConnectionHandle; use monero_wire::{BucketError, Message};
use levin::{MessageSink, MessageStream};
use monero_wire::messages::CoreSyncData;
use monero_wire::{levin, Message, NetworkAddress};
use tower::{BoxError, Service, ServiceExt}; use tower::{BoxError, Service, ServiceExt};
use crate::connection_tracker::{self, ConnectionTracker}; use crate::connection_handle::DisconnectSignal;
use crate::peer::error::{ErrorSlot, PeerError, SharedPeerError};
use crate::peer::handshaker::ConnectionAddr; use crate::peer::handshaker::ConnectionAddr;
use crate::protocol::{InternalMessageRequest, InternalMessageResponse}; use crate::protocol::internal_network::{MessageID, Request, Response};
use super::PeerError;
pub struct ClientRequest { pub struct ClientRequest {
pub req: InternalMessageRequest, pub req: Request,
pub tx: oneshot::Sender<Result<InternalMessageResponse, PeerError>>, pub tx: oneshot::Sender<Result<Response, SharedPeerError>>,
} }
pub enum State { pub enum State {
WaitingForRequest, WaitingForRequest,
WaitingForResponse { WaitingForResponse {
request: InternalMessageRequest, request_id: MessageID,
tx: oneshot::Sender<Result<InternalMessageResponse, PeerError>>, tx: oneshot::Sender<Result<Response, SharedPeerError>>,
}, },
} }
impl State { pub struct Connection<Svc, Snk> {
pub fn expected_response_id(&self) -> Option<u32> {
match self {
Self::WaitingForRequest => None,
Self::WaitingForResponse { request, tx: _ } => request.expected_id(),
}
}
}
pub struct Connection<Svc, Aw> {
address: ConnectionAddr, address: ConnectionAddr,
state: State, state: State,
sink: MessageSink<Aw, Message>, sink: Snk,
client_rx: mpsc::Receiver<ClientRequest>, client_rx: mpsc::Receiver<ClientRequest>,
/// A connection tracker that reduces the open connection count when dropped.
/// Used to limit the number of open connections in Zebra. error_slot: ErrorSlot,
///
/// This field does nothing until it is dropped.
///
/// # Security /// # Security
/// ///
/// If this connection tracker or `Connection`s are leaked, /// If this connection tracker or `Connection`s are leaked,
/// the number of active connections will appear higher than it actually is. /// the number of active connections will appear higher than it actually is.
/// If enough connections leak, Cuprate will stop making new connections. /// If enough connections leak, Cuprate will stop making new connections.
#[allow(dead_code)] connection_tracker: DisconnectSignal,
connection_tracker: ConnectionTracker,
/// A handle to our slot in the address book so we can tell the address
/// book when we disconnect and the address book can tell us to disconnect.
address_book_handle: AddressBookConnectionHandle,
svc: Svc, svc: Svc,
} }
impl<Svc, Aw> Connection<Svc, Aw> impl<Svc, Snk> Connection<Svc, Snk>
where where
Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = BoxError>, Svc: Service<Request, Response = Response, Error = BoxError>,
Aw: AsyncWrite + std::marker::Unpin, Snk: Sink<Message, Error = BucketError> + Unpin,
{ {
pub fn new( pub fn new(
address: ConnectionAddr, address: ConnectionAddr,
sink: MessageSink<Aw, Message>, sink: Snk,
client_rx: mpsc::Receiver<ClientRequest>, client_rx: mpsc::Receiver<ClientRequest>,
connection_tracker: ConnectionTracker, error_slot: ErrorSlot,
address_book_handle: AddressBookConnectionHandle, connection_tracker: DisconnectSignal,
svc: Svc, svc: Svc,
) -> Connection<Svc, Aw> { ) -> Connection<Svc, Snk> {
Connection { Connection {
address, address,
state: State::WaitingForRequest, state: State::WaitingForRequest,
sink, sink,
client_rx, client_rx,
error_slot,
connection_tracker, connection_tracker,
address_book_handle,
svc, svc,
} }
} }
async fn handle_response(&mut self, res: InternalMessageResponse) -> Result<(), PeerError> { async fn handle_response(&mut self, res: Response) -> Result<(), PeerError> {
let state = std::mem::replace(&mut self.state, State::WaitingForRequest); let state = std::mem::replace(&mut self.state, State::WaitingForRequest);
if let State::WaitingForResponse { request, tx } = state { if let State::WaitingForResponse { request_id, tx } = state {
match (request, &res) { if request_id != res.id() {
(InternalMessageRequest::Handshake(_), InternalMessageResponse::Handshake(_)) => {} // TODO: Fail here
( return Err(PeerError::PeerSentIncorrectResponse);
InternalMessageRequest::SupportFlags(_),
InternalMessageResponse::SupportFlags(_),
) => {}
(InternalMessageRequest::TimedSync(_), InternalMessageResponse::TimedSync(res)) => {
}
(
InternalMessageRequest::GetObjectsRequest(req),
InternalMessageResponse::GetObjectsResponse(res),
) => {}
(
InternalMessageRequest::ChainRequest(_),
InternalMessageResponse::ChainResponse(res),
) => {}
(
InternalMessageRequest::FluffyMissingTransactionsRequest(req),
InternalMessageResponse::NewFluffyBlock(blk),
) => {}
(
InternalMessageRequest::GetTxPoolCompliment(_),
InternalMessageResponse::NewTransactions(_),
) => {
// we could check we received no transactions that we said we knew about but thats going to happen later anyway when they get added to our
// mempool
}
_ => return Err(PeerError::ResponseError("Peer sent incorrect response")),
} }
// response passed our tests we can send it to the requestor
// response passed our tests we can send it to the requester
let _ = tx.send(Ok(res)); let _ = tx.send(Ok(res));
Ok(()) Ok(())
} else { } else {
@ -129,7 +83,7 @@ where
Ok(self.sink.send(mes.into()).await?) Ok(self.sink.send(mes.into()).await?)
} }
async fn handle_peer_request(&mut self, req: InternalMessageRequest) -> Result<(), PeerError> { async fn handle_peer_request(&mut self, req: Request) -> Result<(), PeerError> {
// we should check contents of peer requests for obvious errors like we do with responses // we should check contents of peer requests for obvious errors like we do with responses
todo!() todo!()
/* /*
@ -140,13 +94,13 @@ where
} }
async fn handle_client_request(&mut self, req: ClientRequest) -> Result<(), PeerError> { async fn handle_client_request(&mut self, req: ClientRequest) -> Result<(), PeerError> {
// check we need a response if req.req.needs_response() {
if let Some(_) = req.req.expected_id() {
self.state = State::WaitingForResponse { self.state = State::WaitingForResponse {
request: req.req.clone(), request_id: req.req.id(),
tx: req.tx, tx: req.tx,
}; };
} }
// TODO: send NA response to requester
self.send_message_to_peer(req.req).await self.send_message_to_peer(req.req).await
} }
@ -197,9 +151,7 @@ where
loop { loop {
let _res = match self.state { let _res = match self.state {
State::WaitingForRequest => self.state_waiting_for_request().await, State::WaitingForRequest => self.state_waiting_for_request().await,
State::WaitingForResponse { request: _, tx: _ } => { State::WaitingForResponse { .. } => self.state_waiting_for_response().await,
self.state_waiting_for_response().await
}
}; };
} }
} }

View file

@ -16,7 +16,7 @@ use tracing::Instrument;
use crate::peer::handshaker::ConnectionAddr; use crate::peer::handshaker::ConnectionAddr;
use crate::{ use crate::{
address_book::{AddressBookRequest, AddressBookResponse}, address_book::{AddressBookRequest, AddressBookResponse},
connection_tracker::ConnectionTracker, connection_counter::ConnectionTracker,
protocol::{ protocol::{
CoreSyncDataRequest, CoreSyncDataResponse, InternalMessageRequest, InternalMessageResponse, CoreSyncDataRequest, CoreSyncDataResponse, InternalMessageRequest, InternalMessageResponse,
}, },

112
p2p/src/peer/error.rs Normal file
View file

@ -0,0 +1,112 @@
use std::sync::{Arc, Mutex};
use thiserror::Error;
use tracing_error::TracedError;
use monero_wire::BucketError;
/// A wrapper around `Arc<PeerError>` that implements `Error`.
#[derive(Error, Debug, Clone)]
#[error(transparent)]
pub struct SharedPeerError(Arc<TracedError<PeerError>>);
impl<E> From<E> for SharedPeerError
where
PeerError: From<E>,
{
fn from(source: E) -> Self {
Self(Arc::new(TracedError::from(PeerError::from(source))))
}
}
impl SharedPeerError {
/// Returns a debug-formatted string describing the inner [`PeerError`].
///
/// Unfortunately, [`TracedError`] makes it impossible to get a reference to the original error.
pub fn inner_debug(&self) -> String {
format!("{:?}", self.0.as_ref())
}
}
#[derive(Debug, Error, Clone)]
pub enum PeerError {
#[error("The connection task has closed.")]
ConnectionTaskClosed,
#[error("The connected peer sent an incorrect response.")]
PeerSentIncorrectResponse,
#[error("The connected peer sent an incorrect response.")]
BucketError(#[from] BucketError)
}
/// A shared error slot for peer errors.
///
/// # Correctness
///
/// Error slots are shared between sync and async code. In async code, the error
/// mutex should be held for as short a time as possible. This avoids blocking
/// the async task thread on acquiring the mutex.
///
/// > If the value behind the mutex is just data, its usually appropriate to use a blocking mutex
/// > ...
/// > wrap the `Arc<Mutex<...>>` in a struct
/// > that provides non-async methods for performing operations on the data within,
/// > and only lock the mutex inside these methods
///
/// <https://docs.rs/tokio/1.15.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use>
#[derive(Default, Clone)]
pub struct ErrorSlot(Arc<std::sync::Mutex<Option<SharedPeerError>>>);
impl std::fmt::Debug for ErrorSlot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// don't hang if the mutex is locked
// show the panic if the mutex was poisoned
f.debug_struct("ErrorSlot")
.field("error", &self.0.try_lock())
.finish()
}
}
impl ErrorSlot {
/// Read the current error in the slot.
///
/// Returns `None` if there is no error in the slot.
///
/// # Correctness
///
/// Briefly locks the error slot's threaded `std::sync::Mutex`, to get a
/// reference to the error in the slot.
#[allow(clippy::unwrap_in_result)]
pub fn try_get_error(&self) -> Option<SharedPeerError> {
self.0
.lock()
.expect("error mutex should be unpoisoned")
.as_ref()
.cloned()
}
/// Update the current error in the slot.
///
/// Returns `Err(AlreadyErrored)` if there was already an error in the slot.
///
/// # Correctness
///
/// Briefly locks the error slot's threaded `std::sync::Mutex`, to check for
/// a previous error, then update the error in the slot.
#[allow(clippy::unwrap_in_result)]
pub fn try_update_error(&self, e: SharedPeerError) -> Result<(), AlreadyErrored> {
let mut guard = self.0.lock().expect("error mutex should be unpoisoned");
if let Some(original_error) = guard.clone() {
Err(AlreadyErrored { original_error })
} else {
*guard = Some(e);
Ok(())
}
}
}
/// Error returned when the [`ErrorSlot`] already contains an error.
#[derive(Clone, Debug)]
pub struct AlreadyErrored {
/// The original error in the error slot.
pub original_error: SharedPeerError,
}

View file

@ -40,7 +40,7 @@ use super::{
}; };
use crate::address_book::connection_handle::new_address_book_connection_handle; use crate::address_book::connection_handle::new_address_book_connection_handle;
use crate::address_book::{AddressBookRequest, AddressBookResponse}; use crate::address_book::{AddressBookRequest, AddressBookResponse};
use crate::connection_tracker::ConnectionTracker; use crate::connection_counter::ConnectionTracker;
use crate::constants::{ use crate::constants::{
CUPRATE_MINIMUM_SUPPORT_FLAGS, HANDSHAKE_TIMEOUT, P2P_MAX_PEERS_IN_HANDSHAKE, CUPRATE_MINIMUM_SUPPORT_FLAGS, HANDSHAKE_TIMEOUT, P2P_MAX_PEERS_IN_HANDSHAKE,
}; };

View file

@ -29,28 +29,6 @@ pub struct LoadTrackedClient {
connection_info: Arc<ConnectionInfo>, connection_info: Arc<ConnectionInfo>,
} }
impl LoadTrackedClient {
pub fn supports_fluffy_blocks(&self) -> bool {
self.connection_info.support_flags.supports_fluffy_blocks()
}
pub fn current_height(&self) -> u64 {
self.connection_info.peer_height.load(Ordering::SeqCst)
}
pub fn pruning_seed(&self) -> PruningSeed {
self.connection_info.pruning_seed
}
pub fn has_full_block(&self, block_height: u64) -> bool {
let seed = self.pruning_seed();
let Some(log_stripes) = seed.get_log_stripes() else {
return true;
};
seed.will_have_complete_block(block_height, self.current_height(), log_stripes)
}
}
/// Create a new [`LoadTrackedClient`] wrapping the provided `client` service. /// Create a new [`LoadTrackedClient`] wrapping the provided `client` service.
impl From<Client> for LoadTrackedClient { impl From<Client> for LoadTrackedClient {
fn from(client: Client) -> Self { fn from(client: Client) -> Self {

View file

@ -24,11 +24,30 @@
/// ///
use monero_wire::{ use monero_wire::{
ChainRequest, ChainResponse, FluffyMissingTransactionsRequest, GetObjectsRequest, ChainRequest, ChainResponse, FluffyMissingTransactionsRequest, GetObjectsRequest,
GetObjectsResponse, GetTxPoolCompliment, HandshakeRequest, HandshakeResponse, NewBlock, GetObjectsResponse, GetTxPoolCompliment, HandshakeRequest, HandshakeResponse, Message,
NewFluffyBlock, NewTransactions, PingResponse, SupportFlagsResponse, TimedSyncRequest, NewBlock, NewFluffyBlock, NewTransactions, PingResponse, SupportFlagsResponse,
TimedSyncResponse, TimedSyncRequest, TimedSyncResponse,
}; };
/// An enum representing a request/ response combination, so a handshake request
/// and response would have the same [`MessageID`]. This allows associating the
/// correct response to a request.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum MessageID {
Handshake,
TimedSync,
Ping,
SupportFlags,
GetObjects,
GetChain,
FluffyMissingTxs,
GetTxPollCompliment,
NewBlock,
NewFluffyBlock,
NewTransactions,
}
pub enum Request { pub enum Request {
Handshake(HandshakeRequest), Handshake(HandshakeRequest),
TimedSync(TimedSyncRequest), TimedSync(TimedSyncRequest),
@ -44,6 +63,34 @@ pub enum Request {
NewTransactions(NewTransactions), NewTransactions(NewTransactions),
} }
impl Request {
pub fn id(&self) -> MessageID {
match self {
Request::Handshake(_) => MessageID::Handshake,
Request::TimedSync(_) => MessageID::TimedSync,
Request::Ping => MessageID::Ping,
Request::SupportFlags => MessageID::SupportFlags,
Request::GetObjects(_) => MessageID::GetObjects,
Request::GetChain(_) => MessageID::GetChain,
Request::FluffyMissingTxs(_) => MessageID::FluffyMissingTxs,
Request::GetTxPollCompliment(_) => MessageID::GetTxPollCompliment,
Request::NewBlock(_) => MessageID::NewBlock,
Request::NewFluffyBlock(_) => MessageID::NewFluffyBlock,
Request::NewTransactions(_) => MessageID::NewTransactions,
}
}
pub fn needs_response(&self) -> bool {
match self {
Request::NewBlock(_) | Request::NewFluffyBlock(_) | Request::NewTransactions(_) => {
false
}
_ => true,
}
}
}
pub enum Response { pub enum Response {
Handshake(HandshakeResponse), Handshake(HandshakeResponse),
TimedSync(TimedSyncResponse), TimedSync(TimedSyncResponse),
@ -56,3 +103,21 @@ pub enum Response {
NewTransactions(NewTransactions), NewTransactions(NewTransactions),
NA, NA,
} }
impl Response {
pub fn id(&self) -> MessageID {
match self {
Response::Handshake(_) => MessageID::Handshake,
Response::TimedSync(_) => MessageID::TimedSync,
Response::Ping(_) => MessageID::Ping,
Response::SupportFlags(_) => MessageID::SupportFlags,
Response::GetObjects(_) => MessageID::GetObjects,
Response::GetChain(_) => MessageID::GetChain,
Response::NewFluffyBlock(_) => MessageID::NewBlock,
Response::NewTransactions(_) => MessageID::NewFluffyBlock,
Response::NA => panic!("Can't get message ID for a non existent response"),
}
}
}