change the peer module to use the new API + fix a couple bugs

This commit is contained in:
Boog900 2023-07-21 15:24:14 +01:00
parent 8981260750
commit e5d5730ad6
No known key found for this signature in database
GPG key ID: 5401367FB7302004
15 changed files with 447 additions and 422 deletions

View file

@ -61,7 +61,7 @@ pub enum BucketError {
InvalidFragmentedMessage(&'static str),
/// Error decoding the body
#[error("Error decoding bucket body")]
BodyDecodingError(Box<dyn Debug + Send + Sync>),
BodyDecodingError(Box<dyn std::error::Error + Send + Sync>),
/// I/O error
#[error("I/O error: {0}")]
IO(#[from] std::io::Error),

View file

@ -155,7 +155,8 @@ pub struct PeerListEntryBase {
/// The Peer Address
pub adr: NetworkAddress,
/// The Peer ID
pub id: u64,
#[epee_try_from_into(u64)]
pub id: PeerID,
/// The last Time The Peer Was Seen
#[epee_default(0)]
pub last_seen: i64,

View file

@ -18,5 +18,6 @@ tokio-util = {version = "0.7.8", features=["codec"]}
tokio-stream = {version="0.1.14", features=["time"]}
async-trait = "0.1.68"
tracing = "0.1.37"
tracing-error = "0.2.0"
rand = "0.8.5"
pin-project = "1.0.12"

View file

@ -0,0 +1,130 @@
//! Counting active connections used by Cuprate.
//!
//! These types can be used to count any kind of active resource.
//! But they are currently used to track the number of open connections.
use std::{fmt, sync::Arc};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// A counter for active connections.
///
/// Creates a [`ConnectionTracker`] to track each active connection.
/// When these trackers are dropped, the counter gets notified.
pub struct ActiveConnectionCounter {
/// The limit for this type of connection, for diagnostics only.
/// The caller must enforce the limit by ignoring, delaying, or dropping connections.
limit: usize,
/// The label for this connection counter, typically its type.
label: Arc<str>,
semaphore: Arc<Semaphore>,
}
impl fmt::Debug for ActiveConnectionCounter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ActiveConnectionCounter")
.field("label", &self.label)
.field("count", &self.count())
.field("limit", &self.limit)
.finish()
}
}
impl ActiveConnectionCounter {
/// Create and return a new active connection counter.
pub fn new_counter() -> Self {
Self::new_counter_with(Semaphore::MAX_PERMITS, "Active Connections")
}
/// Create and return a new active connection counter with `limit` and `label`.
/// The caller must check and enforce limits using [`update_count()`](Self::update_count).
pub fn new_counter_with<S: ToString>(limit: usize, label: S) -> Self {
let label = label.to_string();
Self {
limit,
label: label.into(),
semaphore: Arc::new(Semaphore::new(limit)),
}
}
/// Create and return a new [`ConnectionTracker`], using a permit from the semaphore,
/// SAFETY:
/// This function will panic if the semaphore doesn't have anymore permits.
pub fn track_connection(&mut self) -> ConnectionTracker {
ConnectionTracker::new(self)
}
pub fn count(&self) -> usize {
let count = self
.limit
.checked_sub(self.semaphore.available_permits())
.expect("Limit is less than available connection permits");
tracing::trace!(
open_connections = ?count,
limit = ?self.limit,
label = ?self.label,
);
count
}
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
}
/// A per-connection tracker.
///
/// [`ActiveConnectionCounter`] creates a tracker instance for each active connection.
pub struct ConnectionTracker {
/// The permit for this connection, updates the semaphore when dropped.
permit: OwnedSemaphorePermit,
/// The label for this connection counter, typically its type.
label: Arc<str>,
}
impl fmt::Debug for ConnectionTracker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ConnectionTracker")
.field(&self.label)
.finish()
}
}
impl ConnectionTracker {
/// Create and return a new active connection tracker, and add 1 to `counter`.
/// All connection trackers share a label with their connection counter.
///
/// When the returned tracker is dropped, `counter` will be notified.
///
/// SAFETY:
/// This function will panic if the [`ActiveConnectionCounter`] doesn't have anymore permits.
fn new(counter: &mut ActiveConnectionCounter) -> Self {
tracing::debug!(
open_connections = ?counter.count(),
limit = ?counter.limit,
label = ?counter.label,
"opening a new peer connection",
);
Self {
permit: counter.semaphore.clone().try_acquire_owned().unwrap(),
label: counter.label.clone(),
}
}
}
impl Drop for ConnectionTracker {
fn drop(&mut self) {
tracing::debug!(
label = ?self.label,
"A peer connection has closed",
);
// the permit is automatically dropped
}
}

View file

@ -1,132 +1,98 @@
//! This module contains the address book [`Connection`](crate::peer::connection::Connection) handle
//!
//! # Why do we need a handle between the address book and connection task
//!
//! When banning a peer we need to tell the connection task to close and
//! when we close a connection we need to remove it from our connection
//! and anchor list.
//! when we close a connection we need to tell the address book.
//!
//!
use futures::channel::{mpsc, oneshot};
use futures::{FutureExt, SinkExt, StreamExt};
use std::time::Duration;
/// A message sent to tell the address book that a peer has disconnected.
pub struct PeerConnectionClosed(Option<std::time::Duration>);
use futures::channel::mpsc;
use futures::SinkExt;
use tokio_util::sync::CancellationToken;
pub enum ConnectionClosed {
Closed(Option<std::time::Duration>),
Running,
use crate::connection_counter::ConnectionTracker;
#[derive(Default, Debug)]
pub struct HandleBuilder {
tracker: Option<ConnectionTracker>,
}
/// The connection side of the address book to connection
/// communication.
#[derive(Debug)]
pub struct ConnectionHandleAddressBookSide {
connection_closed_rx: oneshot::Receiver<PeerConnectionClosed>,
ban_tx: mpsc::Sender<std::time::Duration>,
}
impl ConnectionHandleAddressBookSide {
pub fn ban_peer(&mut self, duration: std::time::Duration) {
let _ = self.ban_tx.send(duration);
impl HandleBuilder {
pub fn set_tracker(&mut self, tracker: ConnectionTracker) {
self.tracker = Some(tracker)
}
pub fn check_connection_closed(&mut self) -> ConnectionClosed {
let connection_closed = self
.connection_closed_rx
.try_recv()
.expect("Will not be cancelled");
match connection_closed {
Some(closed) => return ConnectionClosed::Closed(closed.0),
None => ConnectionClosed::Running,
pub fn build(self) -> (DisconnectSignal, ConnectionHandle, PeerHandle) {
let token = CancellationToken::new();
let (tx, rx) = mpsc::channel(0);
(
DisconnectSignal {
token: token.clone(),
tracker: self.tracker.expect("Tracker was not set!"),
},
ConnectionHandle {
token: token.clone(),
ban: rx,
},
PeerHandle { ban: tx },
)
}
}
pub struct BanPeer(pub Duration);
/// A struct given to the connection task.
pub struct DisconnectSignal {
token: CancellationToken,
tracker: ConnectionTracker,
}
impl DisconnectSignal {
pub fn should_shutdown(&self) -> bool {
self.token.is_cancelled()
}
pub fn connection_closed(&self) {
self.token.cancel()
}
}
impl Drop for DisconnectSignal {
fn drop(&mut self) {
self.token.cancel()
}
}
/// A handle given to a task that needs to cancel this connection.
pub struct ConnectionHandle {
token: CancellationToken,
ban: mpsc::Receiver<BanPeer>,
}
impl ConnectionHandle {
pub fn is_closed(&self) -> bool {
self.token.is_cancelled()
}
pub fn check_should_ban(&mut self) -> Option<BanPeer> {
match self.ban.try_next() {
Ok(res) => res,
Err(_) => None,
}
}
}
/// The address book side of the address book to connection
/// communication.
#[derive(Debug)]
pub struct ConnectionHandleConnectionSide {
connection_closed_tx: Option<oneshot::Sender<PeerConnectionClosed>>,
ban_rx: mpsc::Receiver<std::time::Duration>,
ban_holder: Option<std::time::Duration>,
}
impl ConnectionHandleConnectionSide {
pub fn been_banned(&mut self) -> bool {
let ban_time =
self.ban_rx.next().now_or_never().and_then(|inner| {
Some(inner.expect("Handles to the connection task wont be dropped"))
});
let ret = ban_time.is_some();
self.ban_holder = ban_time;
ret
pub fn send_close_signal(&self) {
self.token.cancel()
}
}
impl Drop for ConnectionHandleConnectionSide {
fn drop(&mut self) {
let tx = std::mem::replace(&mut self.connection_closed_tx, None)
.expect("Will only be dropped once");
let _ = tx.send(PeerConnectionClosed(self.ban_holder));
}
}
pub struct ConnectionHandleClientSide {
ban_tx: mpsc::Sender<std::time::Duration>,
}
impl ConnectionHandleClientSide {
pub fn ban_peer(&mut self, duration: std::time::Duration) {
let _ = self.ban_tx.send(duration);
}
}
/// Creates a new handle pair that can be given to the connection task and
/// address book respectively.
pub fn new_address_book_connection_handle() -> (
ConnectionHandleConnectionSide,
ConnectionHandleAddressBookSide,
ConnectionHandleClientSide,
) {
let (tx_closed, rx_closed) = oneshot::channel();
let (tx_ban, rx_ban) = mpsc::channel(0);
let c_h_c_s = ConnectionHandleConnectionSide {
connection_closed_tx: Some(tx_closed),
ban_rx: rx_ban,
ban_holder: None,
};
let c_h_a_s = ConnectionHandleAddressBookSide {
connection_closed_rx: rx_closed,
ban_tx: tx_ban.clone(),
};
let c_h_cl_s = ConnectionHandleClientSide { ban_tx: tx_ban };
(c_h_c_s, c_h_a_s, c_h_cl_s)
}
#[cfg(test)]
mod tests {
use super::new_address_book_connection_handle;
#[test]
fn close_connection_from_address_book() {
let (conn_side, mut addr_side) = new_address_book_connection_handle();
assert!(!conn_side.is_canceled());
assert!(!addr_side.connection_closed());
addr_side.kill_connection();
assert!(conn_side.is_canceled());
}
#[test]
fn close_connection_from_connection() {
let (conn_side, mut addr_side) = new_address_book_connection_handle();
assert!(!conn_side.is_canceled());
assert!(!addr_side.connection_closed());
drop(conn_side);
assert!(addr_side.connection_closed());
/// A handle given to a task that needs to be able to ban a connection.
#[derive(Clone)]
pub struct PeerHandle {
ban: mpsc::Sender<BanPeer>,
}
impl PeerHandle {
pub fn ban_peer(&mut self, duration: Duration) {
let _ = self.ban.send(BanPeer(duration));
}
}

View file

@ -1,174 +0,0 @@
//! Counting active connections used by Cuprate.
//!
//! These types can be used to count any kind of active resource.
//! But they are currently used to track the number of open connections.
//!
//! This file was originally from Zebra
use std::{fmt, sync::Arc};
use std::sync::mpsc;
/// A signal sent by a [`Connection`][1] when it closes.
///
/// Used to count the number of open connections.
///
/// [1]: crate::peer::Connection
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionClosed;
/// A counter for active connections.
///
/// Creates a [`ConnectionTracker`] to track each active connection.
/// When these trackers are dropped, the counter gets notified.
pub struct ActiveConnectionCounter {
/// The number of active peers tracked using this counter.
count: usize,
/// The limit for this type of connection, for diagnostics only.
/// The caller must enforce the limit by ignoring, delaying, or dropping connections.
limit: usize,
/// The label for this connection counter, typically its type.
label: Arc<str>,
/// The channel used to send closed connection notifications.
close_notification_tx: mpsc::Sender<ConnectionClosed>,
/// The channel used to receive closed connection notifications.
close_notification_rx: mpsc::Receiver<ConnectionClosed>,
}
impl fmt::Debug for ActiveConnectionCounter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ActiveConnectionCounter")
.field("label", &self.label)
.field("count", &self.count)
.field("limit", &self.limit)
.finish()
}
}
impl ActiveConnectionCounter {
/// Create and return a new active connection counter.
pub fn new_counter() -> Self {
Self::new_counter_with(usize::MAX, "Active Connections")
}
/// Create and return a new active connection counter with `limit` and `label`.
/// The caller must check and enforce limits using [`update_count()`](Self::update_count).
pub fn new_counter_with<S: ToString>(limit: usize, label: S) -> Self {
// The number of items in this channel is bounded by the connection limit.
let (close_notification_tx, close_notification_rx) = mpsc::channel();
let label = label.to_string();
Self {
count: 0,
limit,
label: label.into(),
close_notification_rx,
close_notification_tx,
}
}
/// Create and return a new [`ConnectionTracker`], and add 1 to this counter.
///
/// When the returned tracker is dropped, this counter will be notified, and decreased by 1.
pub fn track_connection(&mut self) -> ConnectionTracker {
ConnectionTracker::new(self)
}
/// Check for closed connection notifications, and return the current connection count.
pub fn update_count(&mut self) -> usize {
let previous_connections = self.count;
// We ignore errors here:
// - TryRecvError::Empty means that there are no pending close notifications
// - TryRecvError::Closed is unreachable, because we hold a sender
while let Ok(ConnectionClosed) = self.close_notification_rx.try_recv() {
self.count -= 1;
tracing::debug!(
open_connections = ?self.count,
?previous_connections,
limit = ?self.limit,
label = ?self.label,
"a peer connection was closed",
);
}
tracing::trace!(
open_connections = ?self.count,
?previous_connections,
limit = ?self.limit,
label = ?self.label,
"updated active connection count",
);
#[cfg(feature = "progress-bar")]
self.connection_bar
.set_pos(u64::try_from(self.count).expect("fits in u64"))
.set_len(u64::try_from(self.limit).expect("fits in u64"));
self.count
}
}
impl Drop for ActiveConnectionCounter {
fn drop(&mut self) {}
}
/// A per-connection tracker.
///
/// [`ActiveConnectionCounter`] creates a tracker instance for each active connection.
/// When these trackers are dropped, the counter gets notified.
pub struct ConnectionTracker {
/// The channel used to send closed connection notifications on drop.
close_notification_tx: mpsc::Sender<ConnectionClosed>,
/// The label for this connection counter, typically its type.
label: Arc<str>,
}
impl fmt::Debug for ConnectionTracker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ConnectionTracker")
.field(&self.label)
.finish()
}
}
impl ConnectionTracker {
/// Create and return a new active connection tracker, and add 1 to `counter`.
/// All connection trackers share a label with their connection counter.
///
/// When the returned tracker is dropped, `counter` will be notified, and decreased by 1.
fn new(counter: &mut ActiveConnectionCounter) -> Self {
counter.count += 1;
tracing::debug!(
open_connections = ?counter.count,
limit = ?counter.limit,
label = ?counter.label,
"opening a new peer connection",
);
Self {
close_notification_tx: counter.close_notification_tx.clone(),
label: counter.label.clone(),
}
}
}
impl Drop for ConnectionTracker {
/// Notifies the corresponding connection counter that the connection has closed.
fn drop(&mut self) {
tracing::debug!(label = ?self.label, "closing a peer connection");
// We ignore disconnected errors, because the receiver can be dropped
// before some connections are dropped.
//
let _ = self.close_notification_tx.send(ConnectionClosed);
}
}

View file

@ -1,6 +1,7 @@
pub mod address_book;
pub mod config;
pub mod connection_tracker;
mod connection_handle;
pub mod connection_counter;
mod constants;
pub mod peer;
pub mod peer_set;

View file

@ -4,21 +4,10 @@ pub mod connector;
pub mod handshaker;
pub mod load_tracked_client;
mod error;
#[cfg(test)]
mod tests;
use monero_wire::levin::BucketError;
use thiserror::Error;
#[derive(Debug, Error, Clone, Copy)]
pub enum RequestServiceError {}
#[derive(Debug, Error, Clone, Copy)]
pub enum PeerError {
#[error("The connection task has closed.")]
ConnectionTaskClosed,
}
pub use client::Client;
pub use client::ConnectionInfo;
pub use connection::Connection;

View file

@ -10,21 +10,21 @@ use futures::{
use tokio::task::JoinHandle;
use tower::BoxError;
use crate::peer::handshaker::ConnectionAddr;
use cuprate_common::shutdown::set_shutting_down;
use cuprate_common::PruningSeed;
use monero_wire::messages::PeerID;
use monero_wire::{messages::common::PeerSupportFlags, NetworkAddress};
use super::{connection::ClientRequest, PeerError};
use super::{
connection::ClientRequest,
error::{ErrorSlot, PeerError, SharedPeerError},
PeerError,
};
use crate::connection_handle::PeerHandle;
use crate::protocol::{InternalMessageRequest, InternalMessageResponse};
pub struct ConnectionInfo {
pub addr: ConnectionAddr,
pub support_flags: PeerSupportFlags,
pub pruning_seed: PruningSeed,
/// Peer ID
pub peer_id: PeerID,
pub handle: PeerHandle,
pub rpc_port: u16,
pub rpc_credits_per_hash: u32,
}
@ -37,6 +37,8 @@ pub struct Client {
server_tx: mpsc::Sender<ClientRequest>,
connection_task: JoinHandle<()>,
heartbeat_task: JoinHandle<()>,
error_slot: ErrorSlot,
}
impl Client {
@ -46,6 +48,7 @@ impl Client {
server_tx: mpsc::Sender<ClientRequest>,
connection_task: JoinHandle<()>,
heartbeat_task: JoinHandle<()>,
error_slot: ErrorSlot,
) -> Self {
Client {
connection_info,
@ -53,6 +56,7 @@ impl Client {
server_tx,
connection_task,
heartbeat_task,
error_slot,
}
}
@ -140,7 +144,7 @@ impl Client {
impl tower::Service<InternalMessageRequest> for Client {
type Response = InternalMessageResponse;
type Error = PeerError;
type Error = SharedPeerError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@ -163,7 +167,7 @@ impl tower::Service<InternalMessageRequest> for Client {
.map_err(|e| e.into())
})
.boxed(),
Err(_e) => {
Err(_) => {
// TODO: better error handling
futures::future::ready(Err(PeerError::ClientChannelClosed.into())).boxed()
}

View file

@ -1,123 +1,77 @@
use std::collections::HashSet;
use std::sync::atomic::AtomicU64;
use futures::channel::{mpsc, oneshot};
use futures::stream::Fuse;
use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
use futures::{Sink, SinkExt, Stream};
use crate::address_book::connection_handle::AddressBookConnectionHandle;
use levin::{MessageSink, MessageStream};
use monero_wire::messages::CoreSyncData;
use monero_wire::{levin, Message, NetworkAddress};
use monero_wire::{BucketError, Message};
use tower::{BoxError, Service, ServiceExt};
use crate::connection_tracker::{self, ConnectionTracker};
use crate::connection_handle::DisconnectSignal;
use crate::peer::error::{ErrorSlot, PeerError, SharedPeerError};
use crate::peer::handshaker::ConnectionAddr;
use crate::protocol::{InternalMessageRequest, InternalMessageResponse};
use super::PeerError;
use crate::protocol::internal_network::{MessageID, Request, Response};
pub struct ClientRequest {
pub req: InternalMessageRequest,
pub tx: oneshot::Sender<Result<InternalMessageResponse, PeerError>>,
pub req: Request,
pub tx: oneshot::Sender<Result<Response, SharedPeerError>>,
}
pub enum State {
WaitingForRequest,
WaitingForResponse {
request: InternalMessageRequest,
tx: oneshot::Sender<Result<InternalMessageResponse, PeerError>>,
request_id: MessageID,
tx: oneshot::Sender<Result<Response, SharedPeerError>>,
},
}
impl State {
pub fn expected_response_id(&self) -> Option<u32> {
match self {
Self::WaitingForRequest => None,
Self::WaitingForResponse { request, tx: _ } => request.expected_id(),
}
}
}
pub struct Connection<Svc, Aw> {
pub struct Connection<Svc, Snk> {
address: ConnectionAddr,
state: State,
sink: MessageSink<Aw, Message>,
sink: Snk,
client_rx: mpsc::Receiver<ClientRequest>,
/// A connection tracker that reduces the open connection count when dropped.
/// Used to limit the number of open connections in Zebra.
///
/// This field does nothing until it is dropped.
///
error_slot: ErrorSlot,
/// # Security
///
/// If this connection tracker or `Connection`s are leaked,
/// the number of active connections will appear higher than it actually is.
/// If enough connections leak, Cuprate will stop making new connections.
#[allow(dead_code)]
connection_tracker: ConnectionTracker,
/// A handle to our slot in the address book so we can tell the address
/// book when we disconnect and the address book can tell us to disconnect.
address_book_handle: AddressBookConnectionHandle,
connection_tracker: DisconnectSignal,
svc: Svc,
}
impl<Svc, Aw> Connection<Svc, Aw>
impl<Svc, Snk> Connection<Svc, Snk>
where
Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = BoxError>,
Aw: AsyncWrite + std::marker::Unpin,
Svc: Service<Request, Response = Response, Error = BoxError>,
Snk: Sink<Message, Error = BucketError> + Unpin,
{
pub fn new(
address: ConnectionAddr,
sink: MessageSink<Aw, Message>,
sink: Snk,
client_rx: mpsc::Receiver<ClientRequest>,
connection_tracker: ConnectionTracker,
address_book_handle: AddressBookConnectionHandle,
error_slot: ErrorSlot,
connection_tracker: DisconnectSignal,
svc: Svc,
) -> Connection<Svc, Aw> {
) -> Connection<Svc, Snk> {
Connection {
address,
state: State::WaitingForRequest,
sink,
client_rx,
error_slot,
connection_tracker,
address_book_handle,
svc,
}
}
async fn handle_response(&mut self, res: InternalMessageResponse) -> Result<(), PeerError> {
async fn handle_response(&mut self, res: Response) -> Result<(), PeerError> {
let state = std::mem::replace(&mut self.state, State::WaitingForRequest);
if let State::WaitingForResponse { request, tx } = state {
match (request, &res) {
(InternalMessageRequest::Handshake(_), InternalMessageResponse::Handshake(_)) => {}
(
InternalMessageRequest::SupportFlags(_),
InternalMessageResponse::SupportFlags(_),
) => {}
(InternalMessageRequest::TimedSync(_), InternalMessageResponse::TimedSync(res)) => {
}
(
InternalMessageRequest::GetObjectsRequest(req),
InternalMessageResponse::GetObjectsResponse(res),
) => {}
(
InternalMessageRequest::ChainRequest(_),
InternalMessageResponse::ChainResponse(res),
) => {}
(
InternalMessageRequest::FluffyMissingTransactionsRequest(req),
InternalMessageResponse::NewFluffyBlock(blk),
) => {}
(
InternalMessageRequest::GetTxPoolCompliment(_),
InternalMessageResponse::NewTransactions(_),
) => {
// we could check we received no transactions that we said we knew about but thats going to happen later anyway when they get added to our
// mempool
}
_ => return Err(PeerError::ResponseError("Peer sent incorrect response")),
if let State::WaitingForResponse { request_id, tx } = state {
if request_id != res.id() {
// TODO: Fail here
return Err(PeerError::PeerSentIncorrectResponse);
}
// response passed our tests we can send it to the requestor
// response passed our tests we can send it to the requester
let _ = tx.send(Ok(res));
Ok(())
} else {
@ -129,7 +83,7 @@ where
Ok(self.sink.send(mes.into()).await?)
}
async fn handle_peer_request(&mut self, req: InternalMessageRequest) -> Result<(), PeerError> {
async fn handle_peer_request(&mut self, req: Request) -> Result<(), PeerError> {
// we should check contents of peer requests for obvious errors like we do with responses
todo!()
/*
@ -140,13 +94,13 @@ where
}
async fn handle_client_request(&mut self, req: ClientRequest) -> Result<(), PeerError> {
// check we need a response
if let Some(_) = req.req.expected_id() {
if req.req.needs_response() {
self.state = State::WaitingForResponse {
request: req.req.clone(),
request_id: req.req.id(),
tx: req.tx,
};
}
// TODO: send NA response to requester
self.send_message_to_peer(req.req).await
}
@ -197,9 +151,7 @@ where
loop {
let _res = match self.state {
State::WaitingForRequest => self.state_waiting_for_request().await,
State::WaitingForResponse { request: _, tx: _ } => {
self.state_waiting_for_response().await
}
State::WaitingForResponse { .. } => self.state_waiting_for_response().await,
};
}
}

View file

@ -16,7 +16,7 @@ use tracing::Instrument;
use crate::peer::handshaker::ConnectionAddr;
use crate::{
address_book::{AddressBookRequest, AddressBookResponse},
connection_tracker::ConnectionTracker,
connection_counter::ConnectionTracker,
protocol::{
CoreSyncDataRequest, CoreSyncDataResponse, InternalMessageRequest, InternalMessageResponse,
},

112
p2p/src/peer/error.rs Normal file
View file

@ -0,0 +1,112 @@
use std::sync::{Arc, Mutex};
use thiserror::Error;
use tracing_error::TracedError;
use monero_wire::BucketError;
/// A wrapper around `Arc<PeerError>` that implements `Error`.
#[derive(Error, Debug, Clone)]
#[error(transparent)]
pub struct SharedPeerError(Arc<TracedError<PeerError>>);
impl<E> From<E> for SharedPeerError
where
PeerError: From<E>,
{
fn from(source: E) -> Self {
Self(Arc::new(TracedError::from(PeerError::from(source))))
}
}
impl SharedPeerError {
/// Returns a debug-formatted string describing the inner [`PeerError`].
///
/// Unfortunately, [`TracedError`] makes it impossible to get a reference to the original error.
pub fn inner_debug(&self) -> String {
format!("{:?}", self.0.as_ref())
}
}
#[derive(Debug, Error, Clone)]
pub enum PeerError {
#[error("The connection task has closed.")]
ConnectionTaskClosed,
#[error("The connected peer sent an incorrect response.")]
PeerSentIncorrectResponse,
#[error("The connected peer sent an incorrect response.")]
BucketError(#[from] BucketError)
}
/// A shared error slot for peer errors.
///
/// # Correctness
///
/// Error slots are shared between sync and async code. In async code, the error
/// mutex should be held for as short a time as possible. This avoids blocking
/// the async task thread on acquiring the mutex.
///
/// > If the value behind the mutex is just data, its usually appropriate to use a blocking mutex
/// > ...
/// > wrap the `Arc<Mutex<...>>` in a struct
/// > that provides non-async methods for performing operations on the data within,
/// > and only lock the mutex inside these methods
///
/// <https://docs.rs/tokio/1.15.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use>
#[derive(Default, Clone)]
pub struct ErrorSlot(Arc<std::sync::Mutex<Option<SharedPeerError>>>);
impl std::fmt::Debug for ErrorSlot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// don't hang if the mutex is locked
// show the panic if the mutex was poisoned
f.debug_struct("ErrorSlot")
.field("error", &self.0.try_lock())
.finish()
}
}
impl ErrorSlot {
/// Read the current error in the slot.
///
/// Returns `None` if there is no error in the slot.
///
/// # Correctness
///
/// Briefly locks the error slot's threaded `std::sync::Mutex`, to get a
/// reference to the error in the slot.
#[allow(clippy::unwrap_in_result)]
pub fn try_get_error(&self) -> Option<SharedPeerError> {
self.0
.lock()
.expect("error mutex should be unpoisoned")
.as_ref()
.cloned()
}
/// Update the current error in the slot.
///
/// Returns `Err(AlreadyErrored)` if there was already an error in the slot.
///
/// # Correctness
///
/// Briefly locks the error slot's threaded `std::sync::Mutex`, to check for
/// a previous error, then update the error in the slot.
#[allow(clippy::unwrap_in_result)]
pub fn try_update_error(&self, e: SharedPeerError) -> Result<(), AlreadyErrored> {
let mut guard = self.0.lock().expect("error mutex should be unpoisoned");
if let Some(original_error) = guard.clone() {
Err(AlreadyErrored { original_error })
} else {
*guard = Some(e);
Ok(())
}
}
}
/// Error returned when the [`ErrorSlot`] already contains an error.
#[derive(Clone, Debug)]
pub struct AlreadyErrored {
/// The original error in the error slot.
pub original_error: SharedPeerError,
}

View file

@ -40,7 +40,7 @@ use super::{
};
use crate::address_book::connection_handle::new_address_book_connection_handle;
use crate::address_book::{AddressBookRequest, AddressBookResponse};
use crate::connection_tracker::ConnectionTracker;
use crate::connection_counter::ConnectionTracker;
use crate::constants::{
CUPRATE_MINIMUM_SUPPORT_FLAGS, HANDSHAKE_TIMEOUT, P2P_MAX_PEERS_IN_HANDSHAKE,
};

View file

@ -29,28 +29,6 @@ pub struct LoadTrackedClient {
connection_info: Arc<ConnectionInfo>,
}
impl LoadTrackedClient {
pub fn supports_fluffy_blocks(&self) -> bool {
self.connection_info.support_flags.supports_fluffy_blocks()
}
pub fn current_height(&self) -> u64 {
self.connection_info.peer_height.load(Ordering::SeqCst)
}
pub fn pruning_seed(&self) -> PruningSeed {
self.connection_info.pruning_seed
}
pub fn has_full_block(&self, block_height: u64) -> bool {
let seed = self.pruning_seed();
let Some(log_stripes) = seed.get_log_stripes() else {
return true;
};
seed.will_have_complete_block(block_height, self.current_height(), log_stripes)
}
}
/// Create a new [`LoadTrackedClient`] wrapping the provided `client` service.
impl From<Client> for LoadTrackedClient {
fn from(client: Client) -> Self {

View file

@ -24,11 +24,30 @@
///
use monero_wire::{
ChainRequest, ChainResponse, FluffyMissingTransactionsRequest, GetObjectsRequest,
GetObjectsResponse, GetTxPoolCompliment, HandshakeRequest, HandshakeResponse, NewBlock,
NewFluffyBlock, NewTransactions, PingResponse, SupportFlagsResponse, TimedSyncRequest,
TimedSyncResponse,
GetObjectsResponse, GetTxPoolCompliment, HandshakeRequest, HandshakeResponse, Message,
NewBlock, NewFluffyBlock, NewTransactions, PingResponse, SupportFlagsResponse,
TimedSyncRequest, TimedSyncResponse,
};
/// An enum representing a request/ response combination, so a handshake request
/// and response would have the same [`MessageID`]. This allows associating the
/// correct response to a request.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum MessageID {
Handshake,
TimedSync,
Ping,
SupportFlags,
GetObjects,
GetChain,
FluffyMissingTxs,
GetTxPollCompliment,
NewBlock,
NewFluffyBlock,
NewTransactions,
}
pub enum Request {
Handshake(HandshakeRequest),
TimedSync(TimedSyncRequest),
@ -44,6 +63,34 @@ pub enum Request {
NewTransactions(NewTransactions),
}
impl Request {
pub fn id(&self) -> MessageID {
match self {
Request::Handshake(_) => MessageID::Handshake,
Request::TimedSync(_) => MessageID::TimedSync,
Request::Ping => MessageID::Ping,
Request::SupportFlags => MessageID::SupportFlags,
Request::GetObjects(_) => MessageID::GetObjects,
Request::GetChain(_) => MessageID::GetChain,
Request::FluffyMissingTxs(_) => MessageID::FluffyMissingTxs,
Request::GetTxPollCompliment(_) => MessageID::GetTxPollCompliment,
Request::NewBlock(_) => MessageID::NewBlock,
Request::NewFluffyBlock(_) => MessageID::NewFluffyBlock,
Request::NewTransactions(_) => MessageID::NewTransactions,
}
}
pub fn needs_response(&self) -> bool {
match self {
Request::NewBlock(_) | Request::NewFluffyBlock(_) | Request::NewTransactions(_) => {
false
}
_ => true,
}
}
}
pub enum Response {
Handshake(HandshakeResponse),
TimedSync(TimedSyncResponse),
@ -56,3 +103,21 @@ pub enum Response {
NewTransactions(NewTransactions),
NA,
}
impl Response {
pub fn id(&self) -> MessageID {
match self {
Response::Handshake(_) => MessageID::Handshake,
Response::TimedSync(_) => MessageID::TimedSync,
Response::Ping(_) => MessageID::Ping,
Response::SupportFlags(_) => MessageID::SupportFlags,
Response::GetObjects(_) => MessageID::GetObjects,
Response::GetChain(_) => MessageID::GetChain,
Response::NewFluffyBlock(_) => MessageID::NewBlock,
Response::NewTransactions(_) => MessageID::NewFluffyBlock,
Response::NA => panic!("Can't get message ID for a non existent response"),
}
}
}