diff --git a/Cargo.lock b/Cargo.lock index 83138ea..d5d41b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -840,7 +840,6 @@ dependencies = [ "cuprate-test-utils", "cuprate-types", "cuprate-wire", - "dashmap", "futures", "indexmap", "monero-serai", diff --git a/binaries/cuprated/src/blockchain/syncer.rs b/binaries/cuprated/src/blockchain/syncer.rs index 913c983..69ad330 100644 --- a/binaries/cuprated/src/blockchain/syncer.rs +++ b/binaries/cuprated/src/blockchain/syncer.rs @@ -12,7 +12,7 @@ use tracing::instrument; use cuprate_consensus::{BlockChainContext, BlockChainContextRequest, BlockChainContextResponse}; use cuprate_p2p::{ block_downloader::{BlockBatch, BlockDownloaderConfig, ChainSvcRequest, ChainSvcResponse}, - NetworkInterface, + NetworkInterface, PeerSetRequest, PeerSetResponse, }; use cuprate_p2p_core::ClearNet; @@ -28,15 +28,11 @@ pub enum SyncerError { } /// The syncer tasks that makes sure we are fully synchronised with our connected peers. -#[expect( - clippy::significant_drop_tightening, - reason = "Client pool which will be removed" -)] #[instrument(level = "debug", skip_all)] pub async fn syncer( mut context_svc: C, our_chain: CN, - clearnet_interface: NetworkInterface, + mut clearnet_interface: NetworkInterface, incoming_block_batch_tx: mpsc::Sender, stop_current_block_downloader: Arc, block_downloader_config: BlockDownloaderConfig, @@ -67,8 +63,6 @@ where unreachable!(); }; - let client_pool = clearnet_interface.client_pool(); - tracing::debug!("Waiting for new sync info in top sync channel"); loop { @@ -79,9 +73,20 @@ where check_update_blockchain_context(&mut context_svc, &mut blockchain_ctx).await?; let raw_blockchain_context = blockchain_ctx.unchecked_blockchain_context(); - if !client_pool.contains_client_with_more_cumulative_difficulty( - raw_blockchain_context.cumulative_difficulty, - ) { + let PeerSetResponse::MostPoWSeen { + cumulative_difficulty, + .. + } = clearnet_interface + .peer_set() + .ready() + .await? + .call(PeerSetRequest::MostPoWSeen) + .await? + else { + unreachable!(); + }; + + if cumulative_difficulty <= raw_blockchain_context.cumulative_difficulty { continue; } diff --git a/binaries/cuprated/src/txpool/dandelion.rs b/binaries/cuprated/src/txpool/dandelion.rs index d791b62..00d9f5a 100644 --- a/binaries/cuprated/src/txpool/dandelion.rs +++ b/binaries/cuprated/src/txpool/dandelion.rs @@ -59,7 +59,7 @@ pub fn dandelion_router(clear_net: NetworkInterface) -> ConcreteDandel diffuse_service::DiffuseService { clear_net_broadcast_service: clear_net.broadcast_svc(), }, - stem_service::OutboundPeerStream { clear_net }, + stem_service::OutboundPeerStream::new(clear_net), DANDELION_CONFIG, ) } diff --git a/binaries/cuprated/src/txpool/dandelion/stem_service.rs b/binaries/cuprated/src/txpool/dandelion/stem_service.rs index 5c0ba65..2debfd4 100644 --- a/binaries/cuprated/src/txpool/dandelion/stem_service.rs +++ b/binaries/cuprated/src/txpool/dandelion/stem_service.rs @@ -1,14 +1,15 @@ use std::{ + future::Future, pin::Pin, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use bytes::Bytes; -use futures::Stream; +use futures::{future::BoxFuture, FutureExt, Stream}; use tower::Service; use cuprate_dandelion_tower::{traits::StemRequest, OutboundPeer}; -use cuprate_p2p::{ClientPoolDropGuard, NetworkInterface}; +use cuprate_p2p::{ClientDropGuard, NetworkInterface, PeerSetRequest, PeerSetResponse}; use cuprate_p2p_core::{ client::{Client, InternalPeerID}, ClearNet, NetworkZone, PeerRequest, ProtocolRequest, @@ -19,7 +20,17 @@ use crate::{p2p::CrossNetworkInternalPeerId, txpool::dandelion::DandelionTx}; /// The dandelion outbound peer stream. pub struct OutboundPeerStream { - pub clear_net: NetworkInterface, + clear_net: NetworkInterface, + state: OutboundPeerStreamState, +} + +impl OutboundPeerStream { + pub const fn new(clear_net: NetworkInterface) -> Self { + Self { + clear_net, + state: OutboundPeerStreamState::Standby, + } + } } impl Stream for OutboundPeerStream { @@ -28,23 +39,49 @@ impl Stream for OutboundPeerStream { tower::BoxError, >; - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - // TODO: make the outbound peer choice random. - Poll::Ready(Some(Ok(self - .clear_net - .client_pool() - .outbound_client() - .map_or(OutboundPeer::Exhausted, |client| { - OutboundPeer::Peer( - CrossNetworkInternalPeerId::ClearNet(client.info.id), - StemPeerService(client), - ) - })))) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match &mut self.state { + OutboundPeerStreamState::Standby => { + let peer_set = self.clear_net.peer_set(); + let res = ready!(peer_set.poll_ready(cx)); + + self.state = OutboundPeerStreamState::AwaitingPeer( + peer_set.call(PeerSetRequest::StemPeer).boxed(), + ); + } + OutboundPeerStreamState::AwaitingPeer(fut) => { + let res = ready!(fut.poll_unpin(cx)); + + return Poll::Ready(Some(res.map(|res| { + let PeerSetResponse::StemPeer(stem_peer) = res else { + unreachable!() + }; + + match stem_peer { + Some(peer) => OutboundPeer::Peer( + CrossNetworkInternalPeerId::ClearNet(peer.info.id), + StemPeerService(peer), + ), + None => OutboundPeer::Exhausted, + } + }))); + } + } + } } } +/// The state of the [`OutboundPeerStream`]. +enum OutboundPeerStreamState { + /// Standby state. + Standby, + /// Awaiting a response from the peer-set. + AwaitingPeer(BoxFuture<'static, Result, tower::BoxError>>), +} + /// The stem service, used to send stem txs. -pub struct StemPeerService(ClientPoolDropGuard); +pub struct StemPeerService(ClientDropGuard); impl Service> for StemPeerService { type Response = as Service>::Response; diff --git a/p2p/p2p-core/src/client.rs b/p2p/p2p-core/src/client.rs index 73b33ba..f2fde67 100644 --- a/p2p/p2p-core/src/client.rs +++ b/p2p/p2p-core/src/client.rs @@ -27,9 +27,11 @@ mod connector; pub mod handshaker; mod request_handler; mod timeout_monitor; +mod weak; pub use connector::{ConnectRequest, Connector}; pub use handshaker::{DoHandshakeRequest, HandshakeError, HandshakerBuilder}; +pub use weak::WeakClient; /// An internal identifier for a given peer, will be their address if known /// or a random u128 if not. @@ -128,6 +130,17 @@ impl Client { } .into() } + + /// Create a [`WeakClient`] for this [`Client`]. + pub fn downgrade(&self) -> WeakClient { + WeakClient { + info: self.info.clone(), + connection_tx: self.connection_tx.downgrade(), + semaphore: self.semaphore.clone(), + permit: None, + error: self.error.clone(), + } + } } impl Service for Client { diff --git a/p2p/p2p-core/src/client/weak.rs b/p2p/p2p-core/src/client/weak.rs new file mode 100644 index 0000000..90f25dd --- /dev/null +++ b/p2p/p2p-core/src/client/weak.rs @@ -0,0 +1,114 @@ +use std::task::{ready, Context, Poll}; + +use futures::channel::oneshot; +use tokio::sync::{mpsc, OwnedSemaphorePermit}; +use tokio_util::sync::PollSemaphore; +use tower::Service; + +use cuprate_helper::asynch::InfallibleOneshotReceiver; + +use crate::{ + client::{connection, PeerInformation}, + NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError, +}; + +/// A weak handle to a [`Client`](super::Client). +/// +/// When this is dropped the peer will not be disconnected. +pub struct WeakClient { + /// Information on the connected peer. + pub info: PeerInformation, + + /// The channel to the [`Connection`](connection::Connection) task. + pub(super) connection_tx: mpsc::WeakSender, + + /// The semaphore that limits the requests sent to the peer. + pub(super) semaphore: PollSemaphore, + /// A permit for the semaphore, will be [`Some`] after `poll_ready` returns ready. + pub(super) permit: Option, + + /// The error slot shared between the [`Client`] and [`Connection`](connection::Connection). + pub(super) error: SharedError, +} + +impl WeakClient { + /// Internal function to set an error on the [`SharedError`]. + fn set_err(&self, err: PeerError) -> tower::BoxError { + let err_str = err.to_string(); + match self.error.try_insert_err(err) { + Ok(()) => err_str, + Err(e) => e.to_string(), + } + .into() + } +} + +impl Service for WeakClient { + type Response = PeerResponse; + type Error = tower::BoxError; + type Future = InfallibleOneshotReceiver>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(err) = self.error.try_get_err() { + return Poll::Ready(Err(err.to_string().into())); + } + + if self.connection_tx.strong_count() == 0 { + let err = self.set_err(PeerError::ClientChannelClosed); + return Poll::Ready(Err(err)); + } + + if self.permit.is_some() { + return Poll::Ready(Ok(())); + } + + let permit = ready!(self.semaphore.poll_acquire(cx)) + .expect("Client semaphore should not be closed!"); + + self.permit = Some(permit); + + Poll::Ready(Ok(())) + } + + #[expect(clippy::significant_drop_tightening)] + fn call(&mut self, request: PeerRequest) -> Self::Future { + let permit = self + .permit + .take() + .expect("poll_ready did not return ready before call to call"); + + let (tx, rx) = oneshot::channel(); + let req = connection::ConnectionTaskRequest { + response_channel: tx, + request, + permit: Some(permit), + }; + + match self.connection_tx.upgrade() { + None => { + self.set_err(PeerError::ClientChannelClosed); + + let resp = Err(PeerError::ClientChannelClosed.into()); + drop(req.response_channel.send(resp)); + } + Some(sender) => { + if let Err(e) = sender.try_send(req) { + // The connection task could have closed between a call to `poll_ready` and the call to + // `call`, which means if we don't handle the error here the receiver would panic. + use mpsc::error::TrySendError; + + match e { + TrySendError::Closed(req) | TrySendError::Full(req) => { + self.set_err(PeerError::ClientChannelClosed); + + let resp = Err(PeerError::ClientChannelClosed.into()); + drop(req.response_channel.send(resp)); + } + } + } + } + } + + rx.into() + } +} diff --git a/p2p/p2p/Cargo.toml b/p2p/p2p/Cargo.toml index 866fb91..e6ebccb 100644 --- a/p2p/p2p/Cargo.toml +++ b/p2p/p2p/Cargo.toml @@ -20,12 +20,12 @@ monero-serai = { workspace = true, features = ["std"] } tower = { workspace = true, features = ["buffer"] } tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } -rayon = { workspace = true } tokio-util = { workspace = true } +rayon = { workspace = true } tokio-stream = { workspace = true, features = ["sync", "time"] } futures = { workspace = true, features = ["std"] } pin-project = { workspace = true } -dashmap = { workspace = true } +indexmap = { workspace = true, features = ["std"] } thiserror = { workspace = true } bytes = { workspace = true, features = ["std"] } diff --git a/p2p/p2p/src/block_downloader.rs b/p2p/p2p/src/block_downloader.rs index 5ebcedc..ee335c9 100644 --- a/p2p/p2p/src/block_downloader.rs +++ b/p2p/p2p/src/block_downloader.rs @@ -8,7 +8,6 @@ use std::{ cmp::{max, min, Reverse}, collections::{BTreeMap, BinaryHeap}, - sync::Arc, time::Duration, }; @@ -18,7 +17,7 @@ use tokio::{ task::JoinSet, time::{interval, timeout, MissedTickBehavior}, }; -use tower::{Service, ServiceExt}; +use tower::{util::BoxCloneService, Service, ServiceExt}; use tracing::{instrument, Instrument, Span}; use cuprate_async_buffer::{BufferAppender, BufferStream}; @@ -27,11 +26,11 @@ use cuprate_p2p_core::{handles::ConnectionHandle, NetworkZone}; use cuprate_pruning::PruningSeed; use crate::{ - client_pool::{ClientPool, ClientPoolDropGuard}, constants::{ BLOCK_DOWNLOADER_REQUEST_TIMEOUT, EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED, LONG_BAN, MAX_BLOCK_BATCH_LEN, MAX_DOWNLOAD_FAILURES, }, + peer_set::ClientDropGuard, }; mod block_queue; @@ -41,6 +40,7 @@ mod request_chain; #[cfg(test)] mod tests; +use crate::peer_set::{PeerSetRequest, PeerSetResponse}; use block_queue::{BlockQueue, ReadyQueueBatch}; use chain_tracker::{BlocksToRetrieve, ChainEntry, ChainTracker}; use download_batch::download_batch_task; @@ -135,7 +135,7 @@ pub enum ChainSvcResponse { /// call this function again, so it can start the search again. #[instrument(level = "error", skip_all, name = "block_downloader")] pub fn download_blocks( - client_pool: Arc>, + peer_set: BoxCloneService, tower::BoxError>, our_chain_svc: C, config: BlockDownloaderConfig, ) -> BufferStream @@ -147,8 +147,7 @@ where { let (buffer_appender, buffer_stream) = cuprate_async_buffer::new_buffer(config.buffer_size); - let block_downloader = - BlockDownloader::new(client_pool, our_chain_svc, buffer_appender, config); + let block_downloader = BlockDownloader::new(peer_set, our_chain_svc, buffer_appender, config); tokio::spawn( block_downloader @@ -186,8 +185,8 @@ where /// - download an already requested batch of blocks (this might happen due to an error in the previous request /// or because the queue of ready blocks is too large, so we need the oldest block to clear it). struct BlockDownloader { - /// The client pool. - client_pool: Arc>, + /// The peer set. + peer_set: BoxCloneService, tower::BoxError>, /// The service that holds our current chain state. our_chain_svc: C, @@ -208,7 +207,7 @@ struct BlockDownloader { /// /// Returns a result of the chain entry or an error. #[expect(clippy::type_complexity)] - chain_entry_task: JoinSet, ChainEntry), BlockDownloadError>>, + chain_entry_task: JoinSet, ChainEntry), BlockDownloadError>>, /// The current inflight requests. /// @@ -235,13 +234,13 @@ where { /// Creates a new [`BlockDownloader`] fn new( - client_pool: Arc>, + peer_set: BoxCloneService, tower::BoxError>, our_chain_svc: C, buffer_appender: BufferAppender, config: BlockDownloaderConfig, ) -> Self { Self { - client_pool, + peer_set, our_chain_svc, amount_of_blocks_to_request: config.initial_batch_len, amount_of_blocks_to_request_updated_at: 0, @@ -259,7 +258,7 @@ where fn check_pending_peers( &mut self, chain_tracker: &mut ChainTracker, - pending_peers: &mut BTreeMap>>, + pending_peers: &mut BTreeMap>>, ) { tracing::debug!("Checking if we can give any work to pending peers."); @@ -286,11 +285,11 @@ where /// This function will find the batch(es) that we are waiting on to clear our ready queue and sends another request /// for them. /// - /// Returns the [`ClientPoolDropGuard`] back if it doesn't have the batch according to its pruning seed. + /// Returns the [`ClientDropGuard`] back if it doesn't have the batch according to its pruning seed. fn request_inflight_batch_again( &mut self, - client: ClientPoolDropGuard, - ) -> Option> { + client: ClientDropGuard, + ) -> Option> { tracing::debug!( "Requesting an inflight batch, current ready queue size: {}", self.block_queue.size() @@ -336,13 +335,13 @@ where /// /// The batch requested will depend on our current state, failed batches will be prioritised. /// - /// Returns the [`ClientPoolDropGuard`] back if it doesn't have the data we currently need according + /// Returns the [`ClientDropGuard`] back if it doesn't have the data we currently need according /// to its pruning seed. fn request_block_batch( &mut self, chain_tracker: &mut ChainTracker, - client: ClientPoolDropGuard, - ) -> Option> { + client: ClientDropGuard, + ) -> Option> { tracing::trace!("Using peer to request a batch of blocks."); // First look to see if we have any failed requests. while let Some(failed_request) = self.failed_batches.peek() { @@ -416,13 +415,13 @@ where /// This function will use our current state to decide if we should send a request for a chain entry /// or if we should request a batch of blocks. /// - /// Returns the [`ClientPoolDropGuard`] back if it doesn't have the data we currently need according + /// Returns the [`ClientDropGuard`] back if it doesn't have the data we currently need according /// to its pruning seed. fn try_handle_free_client( &mut self, chain_tracker: &mut ChainTracker, - client: ClientPoolDropGuard, - ) -> Option> { + client: ClientDropGuard, + ) -> Option> { // We send 2 requests, so if one of them is slow or doesn't have the next chain, we still have a backup. if self.chain_entry_task.len() < 2 // If we have had too many failures then assume the tip has been found so no more chain entries. @@ -463,7 +462,7 @@ where async fn check_for_free_clients( &mut self, chain_tracker: &mut ChainTracker, - pending_peers: &mut BTreeMap>>, + pending_peers: &mut BTreeMap>>, ) -> Result<(), BlockDownloadError> { tracing::debug!("Checking for free peers"); @@ -478,10 +477,19 @@ where panic!("Chain service returned wrong response."); }; - for client in self - .client_pool - .clients_with_more_cumulative_difficulty(current_cumulative_difficulty) - { + let PeerSetResponse::PeersWithMorePoW(clients) = self + .peer_set + .ready() + .await? + .call(PeerSetRequest::PeersWithMorePoW( + current_cumulative_difficulty, + )) + .await? + else { + unreachable!(); + }; + + for client in clients { pending_peers .entry(client.info.pruning_seed) .or_default() @@ -497,9 +505,9 @@ where async fn handle_download_batch_res( &mut self, start_height: usize, - res: Result<(ClientPoolDropGuard, BlockBatch), BlockDownloadError>, + res: Result<(ClientDropGuard, BlockBatch), BlockDownloadError>, chain_tracker: &mut ChainTracker, - pending_peers: &mut BTreeMap>>, + pending_peers: &mut BTreeMap>>, ) -> Result<(), BlockDownloadError> { tracing::debug!("Handling block download response"); @@ -593,7 +601,7 @@ where /// Starts the main loop of the block downloader. async fn run(mut self) -> Result<(), BlockDownloadError> { let mut chain_tracker = - initial_chain_search(&self.client_pool, &mut self.our_chain_svc).await?; + initial_chain_search(&mut self.peer_set, &mut self.our_chain_svc).await?; let mut pending_peers = BTreeMap::new(); @@ -662,7 +670,7 @@ struct BlockDownloadTaskResponse { /// The start height of the batch. start_height: usize, /// A result containing the batch or an error. - result: Result<(ClientPoolDropGuard, BlockBatch), BlockDownloadError>, + result: Result<(ClientDropGuard, BlockBatch), BlockDownloadError>, } /// Returns if a peer has all the blocks in a range, according to its [`PruningSeed`]. diff --git a/p2p/p2p/src/block_downloader/download_batch.rs b/p2p/p2p/src/block_downloader/download_batch.rs index bbb14b3..ef621ce 100644 --- a/p2p/p2p/src/block_downloader/download_batch.rs +++ b/p2p/p2p/src/block_downloader/download_batch.rs @@ -16,8 +16,8 @@ use cuprate_wire::protocol::{GetObjectsRequest, GetObjectsResponse}; use crate::{ block_downloader::{BlockBatch, BlockDownloadError, BlockDownloadTaskResponse}, - client_pool::ClientPoolDropGuard, constants::{BLOCK_DOWNLOADER_REQUEST_TIMEOUT, MAX_TRANSACTION_BLOB_SIZE, MEDIUM_BAN}, + peer_set::ClientDropGuard, }; /// Attempts to request a batch of blocks from a peer, returning [`BlockDownloadTaskResponse`]. @@ -32,7 +32,7 @@ use crate::{ )] #[expect(clippy::used_underscore_binding)] pub async fn download_batch_task( - client: ClientPoolDropGuard, + client: ClientDropGuard, ids: ByteArrayVec<32>, previous_id: [u8; 32], expected_start_height: usize, @@ -49,11 +49,11 @@ pub async fn download_batch_task( /// This function will validate the blocks that were downloaded were the ones asked for and that they match /// the expected height. async fn request_batch_from_peer( - mut client: ClientPoolDropGuard, + mut client: ClientDropGuard, ids: ByteArrayVec<32>, previous_id: [u8; 32], expected_start_height: usize, -) -> Result<(ClientPoolDropGuard, BlockBatch), BlockDownloadError> { +) -> Result<(ClientDropGuard, BlockBatch), BlockDownloadError> { let request = PeerRequest::Protocol(ProtocolRequest::GetObjects(GetObjectsRequest { blocks: ids.clone(), pruned: false, diff --git a/p2p/p2p/src/block_downloader/request_chain.rs b/p2p/p2p/src/block_downloader/request_chain.rs index d6a2a0a..4e0f855 100644 --- a/p2p/p2p/src/block_downloader/request_chain.rs +++ b/p2p/p2p/src/block_downloader/request_chain.rs @@ -1,7 +1,7 @@ -use std::{mem, sync::Arc}; +use std::mem; use tokio::{task::JoinSet, time::timeout}; -use tower::{Service, ServiceExt}; +use tower::{util::BoxCloneService, Service, ServiceExt}; use tracing::{instrument, Instrument, Span}; use cuprate_p2p_core::{ @@ -15,11 +15,11 @@ use crate::{ chain_tracker::{ChainEntry, ChainTracker}, BlockDownloadError, ChainSvcRequest, ChainSvcResponse, }, - client_pool::{ClientPool, ClientPoolDropGuard}, constants::{ BLOCK_DOWNLOADER_REQUEST_TIMEOUT, INITIAL_CHAIN_REQUESTS_TO_SEND, MAX_BLOCKS_IDS_IN_CHAIN_ENTRY, MEDIUM_BAN, }, + peer_set::{ClientDropGuard, PeerSetRequest, PeerSetResponse}, }; /// Request a chain entry from a peer. @@ -27,9 +27,9 @@ use crate::{ /// Because the block downloader only follows and downloads one chain we only have to send the block hash of /// top block we have found and the genesis block, this is then called `short_history`. pub(crate) async fn request_chain_entry_from_peer( - mut client: ClientPoolDropGuard, + mut client: ClientDropGuard, short_history: [[u8; 32]; 2], -) -> Result<(ClientPoolDropGuard, ChainEntry), BlockDownloadError> { +) -> Result<(ClientDropGuard, ChainEntry), BlockDownloadError> { let PeerResponse::Protocol(ProtocolResponse::GetChain(chain_res)) = client .ready() .await? @@ -80,7 +80,7 @@ pub(crate) async fn request_chain_entry_from_peer( /// We then wait for their response and choose the peer who claims the highest cumulative difficulty. #[instrument(level = "error", skip_all)] pub async fn initial_chain_search( - client_pool: &Arc>, + peer_set: &mut BoxCloneService, tower::BoxError>, mut our_chain_svc: C, ) -> Result, BlockDownloadError> where @@ -102,9 +102,15 @@ where let our_genesis = *block_ids.last().expect("Blockchain had no genesis block."); - let mut peers = client_pool - .clients_with_more_cumulative_difficulty(cumulative_difficulty) - .into_iter(); + let PeerSetResponse::PeersWithMorePoW(clients) = peer_set + .ready() + .await? + .call(PeerSetRequest::PeersWithMorePoW(cumulative_difficulty)) + .await? + else { + unreachable!(); + }; + let mut peers = clients.into_iter(); let mut futs = JoinSet::new(); diff --git a/p2p/p2p/src/block_downloader/tests.rs b/p2p/p2p/src/block_downloader/tests.rs index dd07cce..2d00358 100644 --- a/p2p/p2p/src/block_downloader/tests.rs +++ b/p2p/p2p/src/block_downloader/tests.rs @@ -14,8 +14,8 @@ use monero_serai::{ transaction::{Input, Timelock, Transaction, TransactionPrefix}, }; use proptest::{collection::vec, prelude::*}; -use tokio::time::timeout; -use tower::{service_fn, Service}; +use tokio::{sync::mpsc, time::timeout}; +use tower::{buffer::Buffer, service_fn, Service, ServiceExt}; use cuprate_fixed_bytes::ByteArrayVec; use cuprate_p2p_core::{ @@ -31,7 +31,7 @@ use cuprate_wire::{ use crate::{ block_downloader::{download_blocks, BlockDownloaderConfig, ChainSvcRequest, ChainSvcResponse}, - client_pool::ClientPool, + peer_set::PeerSet, }; proptest! { @@ -48,19 +48,20 @@ proptest! { let tokio_pool = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); - #[expect(clippy::significant_drop_tightening)] tokio_pool.block_on(async move { timeout(Duration::from_secs(600), async move { - let client_pool = ClientPool::new(); + let (new_connection_tx, new_connection_rx) = mpsc::channel(peers); + + let peer_set = PeerSet::new(new_connection_rx); for _ in 0..peers { let client = mock_block_downloader_client(Arc::clone(&blockchain)); - client_pool.add_new_client(client); + new_connection_tx.try_send(client).unwrap(); } let stream = download_blocks( - client_pool, + Buffer::new(peer_set, 10).boxed_clone(), OurChainSvc { genesis: *blockchain.blocks.first().unwrap().0 }, diff --git a/p2p/p2p/src/client_pool.rs b/p2p/p2p/src/client_pool.rs deleted file mode 100644 index 67c8f11..0000000 --- a/p2p/p2p/src/client_pool.rs +++ /dev/null @@ -1,188 +0,0 @@ -//! # Client Pool. -//! -//! The [`ClientPool`], is a pool of currently connected peers that can be pulled from. -//! It does _not_ necessarily contain every connected peer as another place could have -//! taken a peer from the pool. -//! -//! When taking peers from the pool they are wrapped in [`ClientPoolDropGuard`], which -//! returns the peer to the pool when it is dropped. -//! -//! Internally the pool is a [`DashMap`] which means care should be taken in `async` code -//! as internally this uses blocking `RwLock`s. -use std::sync::Arc; - -use dashmap::DashMap; -use tokio::sync::mpsc; -use tracing::{Instrument, Span}; - -use cuprate_p2p_core::{ - client::{Client, InternalPeerID}, - handles::ConnectionHandle, - ConnectionDirection, NetworkZone, -}; - -pub(crate) mod disconnect_monitor; -mod drop_guard_client; - -pub use drop_guard_client::ClientPoolDropGuard; - -/// The client pool, which holds currently connected free peers. -/// -/// See the [module docs](self) for more. -pub struct ClientPool { - /// The connected [`Client`]s. - clients: DashMap, Client>, - /// A channel to send new peer ids down to monitor for disconnect. - new_connection_tx: mpsc::UnboundedSender<(ConnectionHandle, InternalPeerID)>, -} - -impl ClientPool { - /// Returns a new [`ClientPool`] wrapped in an [`Arc`]. - pub fn new() -> Arc { - let (tx, rx) = mpsc::unbounded_channel(); - - let pool = Arc::new(Self { - clients: DashMap::new(), - new_connection_tx: tx, - }); - - tokio::spawn( - disconnect_monitor::disconnect_monitor(rx, Arc::clone(&pool)) - .instrument(Span::current()), - ); - - pool - } - - /// Adds a [`Client`] to the pool, the client must have previously been taken from the - /// pool. - /// - /// See [`ClientPool::add_new_client`] to add a [`Client`] which was not taken from the pool before. - /// - /// # Panics - /// This function panics if `client` already exists in the pool. - fn add_client(&self, client: Client) { - let handle = client.info.handle.clone(); - let id = client.info.id; - - // Fast path: if the client is disconnected don't add it to the peer set. - if handle.is_closed() { - return; - } - - assert!(self.clients.insert(id, client).is_none()); - - // We have to check this again otherwise we could have a race condition where a - // peer is disconnected after the first check, the disconnect monitor tries to remove it, - // and then it is added to the pool. - if handle.is_closed() { - self.remove_client(&id); - } - } - - /// Adds a _new_ [`Client`] to the pool, this client should be a new connection, and not already - /// from the pool. - /// - /// # Panics - /// This function panics if `client` already exists in the pool. - pub fn add_new_client(&self, client: Client) { - self.new_connection_tx - .send((client.info.handle.clone(), client.info.id)) - .unwrap(); - - self.add_client(client); - } - - /// Remove a [`Client`] from the pool. - /// - /// [`None`] is returned if the client did not exist in the pool. - fn remove_client(&self, peer: &InternalPeerID) -> Option> { - self.clients.remove(peer).map(|(_, client)| client) - } - - /// Borrows a [`Client`] from the pool. - /// - /// The [`Client`] is wrapped in [`ClientPoolDropGuard`] which - /// will return the client to the pool when it's dropped. - /// - /// See [`Self::borrow_clients`] for borrowing multiple clients. - pub fn borrow_client( - self: &Arc, - peer: &InternalPeerID, - ) -> Option> { - self.remove_client(peer).map(|client| ClientPoolDropGuard { - pool: Arc::clone(self), - client: Some(client), - }) - } - - /// Borrows multiple [`Client`]s from the pool. - /// - /// Note that the returned iterator is not guaranteed to contain every peer asked for. - /// - /// See [`Self::borrow_client`] for borrowing a single client. - pub fn borrow_clients<'a, 'b>( - self: &'a Arc, - peers: &'b [InternalPeerID], - ) -> impl Iterator> + sealed::Captures<(&'a (), &'b ())> { - peers.iter().filter_map(|peer| self.borrow_client(peer)) - } - - /// Borrows all [`Client`]s from the pool that have claimed a higher cumulative difficulty than - /// the amount passed in. - /// - /// The [`Client`]s are wrapped in [`ClientPoolDropGuard`] which - /// will return the clients to the pool when they are dropped. - pub fn clients_with_more_cumulative_difficulty( - self: &Arc, - cumulative_difficulty: u128, - ) -> Vec> { - let peers = self - .clients - .iter() - .filter_map(|element| { - let peer_sync_info = element.value().info.core_sync_data.lock().unwrap(); - - if peer_sync_info.cumulative_difficulty() > cumulative_difficulty { - Some(*element.key()) - } else { - None - } - }) - .collect::>(); - - self.borrow_clients(&peers).collect() - } - - /// Checks all clients in the pool checking if any claim a higher cumulative difficulty than the - /// amount specified. - pub fn contains_client_with_more_cumulative_difficulty( - &self, - cumulative_difficulty: u128, - ) -> bool { - self.clients.iter().any(|element| { - let sync_data = element.value().info.core_sync_data.lock().unwrap(); - sync_data.cumulative_difficulty() > cumulative_difficulty - }) - } - - /// Returns the first outbound peer when iterating over the peers. - pub fn outbound_client(self: &Arc) -> Option> { - let client = self - .clients - .iter() - .find(|element| element.value().info.direction == ConnectionDirection::Outbound)?; - let id = *client.key(); - - Some(self.borrow_client(&id).unwrap()) - } -} - -mod sealed { - /// TODO: Remove me when 2024 Rust - /// - /// - pub trait Captures {} - - impl Captures for T {} -} diff --git a/p2p/p2p/src/client_pool/disconnect_monitor.rs b/p2p/p2p/src/client_pool/disconnect_monitor.rs deleted file mode 100644 index f54b560..0000000 --- a/p2p/p2p/src/client_pool/disconnect_monitor.rs +++ /dev/null @@ -1,83 +0,0 @@ -//! # Disconnect Monitor -//! -//! This module contains the [`disconnect_monitor`] task, which monitors connected peers for disconnection -//! and then removes them from the [`ClientPool`] if they do. -use std::{ - future::Future, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use futures::{stream::FuturesUnordered, StreamExt}; -use tokio::sync::mpsc; -use tokio_util::sync::WaitForCancellationFutureOwned; -use tracing::instrument; - -use cuprate_p2p_core::{client::InternalPeerID, handles::ConnectionHandle, NetworkZone}; - -use super::ClientPool; - -/// The disconnect monitor task. -#[instrument(level = "info", skip_all)] -pub async fn disconnect_monitor( - mut new_connection_rx: mpsc::UnboundedReceiver<(ConnectionHandle, InternalPeerID)>, - client_pool: Arc>, -) { - // We need to hold a weak reference otherwise the client pool and this would hold a reference to - // each other causing the pool to be leaked. - let weak_client_pool = Arc::downgrade(&client_pool); - drop(client_pool); - - tracing::info!("Starting peer disconnect monitor."); - - let mut futs: FuturesUnordered> = FuturesUnordered::new(); - - loop { - tokio::select! { - Some((con_handle, peer_id)) = new_connection_rx.recv() => { - tracing::debug!("Monitoring {peer_id} for disconnect"); - futs.push(PeerDisconnectFut { - closed_fut: con_handle.closed(), - peer_id: Some(peer_id), - }); - } - Some(peer_id) = futs.next() => { - tracing::debug!("{peer_id} has disconnected, removing from client pool."); - let Some(pool) = weak_client_pool.upgrade() else { - tracing::info!("Peer disconnect monitor shutting down."); - return; - }; - - pool.remove_client(&peer_id); - drop(pool); - } - else => { - tracing::info!("Peer disconnect monitor shutting down."); - return; - } - } - } -} - -/// A [`Future`] that resolves when a peer disconnects. -#[pin_project::pin_project] -pub(crate) struct PeerDisconnectFut { - /// The inner [`Future`] that resolves when a peer disconnects. - #[pin] - pub(crate) closed_fut: WaitForCancellationFutureOwned, - /// The peers ID. - pub(crate) peer_id: Option>, -} - -impl Future for PeerDisconnectFut { - type Output = InternalPeerID; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - this.closed_fut - .poll(cx) - .map(|()| this.peer_id.take().unwrap()) - } -} diff --git a/p2p/p2p/src/client_pool/drop_guard_client.rs b/p2p/p2p/src/client_pool/drop_guard_client.rs deleted file mode 100644 index b10c4e9..0000000 --- a/p2p/p2p/src/client_pool/drop_guard_client.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::{ - ops::{Deref, DerefMut}, - sync::Arc, -}; - -use cuprate_p2p_core::{client::Client, NetworkZone}; - -use crate::client_pool::ClientPool; - -/// A wrapper around [`Client`] which returns the client to the [`ClientPool`] when dropped. -pub struct ClientPoolDropGuard { - /// The [`ClientPool`] to return the peer to. - pub(super) pool: Arc>, - /// The [`Client`]. - /// - /// This is set to [`Some`] when this guard is created, then - /// [`take`](Option::take)n and returned to the pool when dropped. - pub(super) client: Option>, -} - -impl Deref for ClientPoolDropGuard { - type Target = Client; - - fn deref(&self) -> &Self::Target { - self.client.as_ref().unwrap() - } -} - -impl DerefMut for ClientPoolDropGuard { - fn deref_mut(&mut self) -> &mut Self::Target { - self.client.as_mut().unwrap() - } -} - -impl Drop for ClientPoolDropGuard { - fn drop(&mut self) { - let client = self.client.take().unwrap(); - - self.pool.add_client(client); - } -} diff --git a/p2p/p2p/src/connection_maintainer.rs b/p2p/p2p/src/connection_maintainer.rs index cd9d931..245fbf1 100644 --- a/p2p/p2p/src/connection_maintainer.rs +++ b/p2p/p2p/src/connection_maintainer.rs @@ -21,7 +21,6 @@ use cuprate_p2p_core::{ }; use crate::{ - client_pool::ClientPool, config::P2PConfig, constants::{HANDSHAKE_TIMEOUT, MAX_SEED_CONNECTIONS, OUTBOUND_CONNECTION_ATTEMPT_TIMEOUT}, }; @@ -46,7 +45,7 @@ pub struct MakeConnectionRequest { /// This handles maintaining a minimum number of connections and making extra connections when needed, upto a maximum. pub struct OutboundConnectionKeeper { /// The pool of currently connected peers. - pub client_pool: Arc>, + pub new_peers_tx: mpsc::Sender>, /// The channel that tells us to make new _extra_ outbound connections. pub make_connection_rx: mpsc::Receiver, /// The address book service @@ -77,7 +76,7 @@ where { pub fn new( config: P2PConfig, - client_pool: Arc>, + new_peers_tx: mpsc::Sender>, make_connection_rx: mpsc::Receiver, address_book_svc: A, connector_svc: C, @@ -86,7 +85,7 @@ where .expect("Gray peer percent is incorrect should be 0..=1"); Self { - client_pool, + new_peers_tx, make_connection_rx, address_book_svc, connector_svc, @@ -149,7 +148,7 @@ where /// Connects to a given outbound peer. #[instrument(level = "info", skip_all)] async fn connect_to_outbound_peer(&mut self, permit: OwnedSemaphorePermit, addr: N::Addr) { - let client_pool = Arc::clone(&self.client_pool); + let new_peers_tx = self.new_peers_tx.clone(); let connection_fut = self .connector_svc .ready() @@ -164,7 +163,7 @@ where async move { #[expect(clippy::significant_drop_in_scrutinee)] if let Ok(Ok(peer)) = timeout(HANDSHAKE_TIMEOUT, connection_fut).await { - client_pool.add_new_client(peer); + drop(new_peers_tx.send(peer).await); } } .instrument(Span::current()), diff --git a/p2p/p2p/src/inbound_server.rs b/p2p/p2p/src/inbound_server.rs index 6e793bd..0479560 100644 --- a/p2p/p2p/src/inbound_server.rs +++ b/p2p/p2p/src/inbound_server.rs @@ -6,7 +6,7 @@ use std::{pin::pin, sync::Arc}; use futures::{SinkExt, StreamExt}; use tokio::{ - sync::Semaphore, + sync::{mpsc, Semaphore}, task::JoinSet, time::{sleep, timeout}, }; @@ -24,7 +24,6 @@ use cuprate_wire::{ }; use crate::{ - client_pool::ClientPool, constants::{ HANDSHAKE_TIMEOUT, INBOUND_CONNECTION_COOL_DOWN, PING_REQUEST_CONCURRENCY, PING_REQUEST_TIMEOUT, @@ -36,7 +35,7 @@ use crate::{ /// and initiate handshake if needed, after verifying the address isn't banned. #[instrument(level = "warn", skip_all)] pub async fn inbound_server( - client_pool: Arc>, + new_connection_tx: mpsc::Sender>, mut handshaker: HS, mut address_book: A, config: P2PConfig, @@ -111,13 +110,13 @@ where permit: Some(permit), }); - let cloned_pool = Arc::clone(&client_pool); + let new_connection_tx = new_connection_tx.clone(); tokio::spawn( async move { let client = timeout(HANDSHAKE_TIMEOUT, fut).await; if let Ok(Ok(peer)) = client { - cloned_pool.add_new_client(peer); + drop(new_connection_tx.send(peer).await); } } .instrument(Span::current()), diff --git a/p2p/p2p/src/lib.rs b/p2p/p2p/src/lib.rs index 541784c..fb50658 100644 --- a/p2p/p2p/src/lib.rs +++ b/p2p/p2p/src/lib.rs @@ -18,17 +18,18 @@ use cuprate_p2p_core::{ pub mod block_downloader; mod broadcast; -pub mod client_pool; pub mod config; pub mod connection_maintainer; pub mod constants; mod inbound_server; +mod peer_set; use block_downloader::{BlockBatch, BlockDownloaderConfig, ChainSvcRequest, ChainSvcResponse}; pub use broadcast::{BroadcastRequest, BroadcastSvc}; -pub use client_pool::{ClientPool, ClientPoolDropGuard}; pub use config::{AddressBookConfig, P2PConfig}; use connection_maintainer::MakeConnectionRequest; +use peer_set::PeerSet; +pub use peer_set::{ClientDropGuard, PeerSetRequest, PeerSetResponse}; /// Initializes the P2P [`NetworkInterface`] for a specific [`NetworkZone`]. /// @@ -54,7 +55,10 @@ where cuprate_address_book::init_address_book(config.address_book_config.clone()).await?; let address_book = Buffer::new( address_book, - config.max_inbound_connections + config.outbound_connections, + config + .max_inbound_connections + .checked_add(config.outbound_connections) + .unwrap(), ); // Use the default config. Changing the defaults affects tx fluff times, which could affect D++ so for now don't allow changing @@ -83,19 +87,25 @@ where let outbound_handshaker = outbound_handshaker_builder.build(); - let client_pool = ClientPool::new(); - + let (new_connection_tx, new_connection_rx) = mpsc::channel( + config + .outbound_connections + .checked_add(config.max_inbound_connections) + .unwrap(), + ); let (make_connection_tx, make_connection_rx) = mpsc::channel(3); let outbound_connector = Connector::new(outbound_handshaker); let outbound_connection_maintainer = connection_maintainer::OutboundConnectionKeeper::new( config.clone(), - Arc::clone(&client_pool), + new_connection_tx.clone(), make_connection_rx, address_book.clone(), outbound_connector, ); + let peer_set = PeerSet::new(new_connection_rx); + let mut background_tasks = JoinSet::new(); background_tasks.spawn( @@ -105,7 +115,7 @@ where ); background_tasks.spawn( inbound_server::inbound_server( - Arc::clone(&client_pool), + new_connection_tx, inbound_handshaker, address_book.clone(), config, @@ -121,7 +131,7 @@ where ); Ok(NetworkInterface { - pool: client_pool, + peer_set: Buffer::new(peer_set, 10).boxed_clone(), broadcast_svc, make_connection_tx, address_book: address_book.boxed_clone(), @@ -133,7 +143,7 @@ where #[derive(Clone)] pub struct NetworkInterface { /// A pool of free connected peers. - pool: Arc>, + peer_set: BoxCloneService, tower::BoxError>, /// A [`Service`] that allows broadcasting to all connected peers. broadcast_svc: BroadcastSvc, /// A channel to request extra connections. @@ -163,7 +173,7 @@ impl NetworkInterface { + 'static, C::Future: Send + 'static, { - block_downloader::download_blocks(Arc::clone(&self.pool), our_chain_service, config) + block_downloader::download_blocks(self.peer_set.clone(), our_chain_service, config) } /// Returns the address book service. @@ -173,8 +183,10 @@ impl NetworkInterface { self.address_book.clone() } - /// Borrows the `ClientPool`, for access to connected peers. - pub const fn client_pool(&self) -> &Arc> { - &self.pool + /// Borrows the `PeerSet`, for access to connected peers. + pub fn peer_set( + &mut self, + ) -> &mut BoxCloneService, tower::BoxError> { + &mut self.peer_set } } diff --git a/p2p/p2p/src/peer_set.rs b/p2p/p2p/src/peer_set.rs new file mode 100644 index 0000000..498eaaf --- /dev/null +++ b/p2p/p2p/src/peer_set.rs @@ -0,0 +1,217 @@ +use std::{ + future::{ready, Future, Ready}, + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use futures::{stream::FuturesUnordered, StreamExt}; +use indexmap::{IndexMap, IndexSet}; +use rand::{seq::index::sample, thread_rng}; +use tokio::sync::mpsc::Receiver; +use tokio_util::sync::WaitForCancellationFutureOwned; +use tower::Service; + +use cuprate_helper::cast::u64_to_usize; +use cuprate_p2p_core::{ + client::{Client, InternalPeerID}, + ConnectionDirection, NetworkZone, +}; + +mod client_wrappers; + +pub use client_wrappers::ClientDropGuard; +use client_wrappers::StoredClient; + +/// A request to the peer-set. +pub enum PeerSetRequest { + /// The most claimed proof-of-work from a peer in the peer-set. + MostPoWSeen, + /// Peers with more cumulative difficulty than the given cumulative difficulty. + /// + /// Returned peers will be remembered and won't be returned from subsequent calls until the guard is dropped. + PeersWithMorePoW(u128), + /// A random outbound peer. + /// + /// The returned peer will be remembered and won't be returned from subsequent calls until the guard is dropped. + StemPeer, +} + +/// A response from the peer-set. +pub enum PeerSetResponse { + /// [`PeerSetRequest::MostPoWSeen`] + MostPoWSeen { + /// The cumulative difficulty claimed. + cumulative_difficulty: u128, + /// The height claimed. + height: usize, + /// The claimed hash of the top block. + top_hash: [u8; 32], + }, + /// [`PeerSetRequest::PeersWithMorePoW`] + /// + /// Returned peers will be remembered and won't be returned from subsequent calls until the guard is dropped. + PeersWithMorePoW(Vec>), + /// [`PeerSetRequest::StemPeer`] + /// + /// The returned peer will be remembered and won't be returned from subsequent calls until the guard is dropped. + StemPeer(Option>), +} + +/// A [`Future`] that completes when a peer disconnects. +#[pin_project::pin_project] +struct ClosedConnectionFuture { + #[pin] + fut: WaitForCancellationFutureOwned, + id: Option>, +} + +impl Future for ClosedConnectionFuture { + type Output = InternalPeerID; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + this.fut.poll(cx).map(|()| this.id.take().unwrap()) + } +} + +/// A collection of all connected peers on a [`NetworkZone`]. +pub(crate) struct PeerSet { + /// The connected peers. + peers: IndexMap, StoredClient>, + /// A [`FuturesUnordered`] that resolves when a peer disconnects. + closed_connections: FuturesUnordered>, + /// The [`InternalPeerID`]s of all outbound peers. + outbound_peers: IndexSet>, + /// A channel of new peers from the inbound server or outbound connector. + new_peers: Receiver>, +} + +impl PeerSet { + pub(crate) fn new(new_peers: Receiver>) -> Self { + Self { + peers: IndexMap::new(), + closed_connections: FuturesUnordered::new(), + outbound_peers: IndexSet::new(), + new_peers, + } + } + + /// Polls the new peers channel for newly connected peers. + fn poll_new_peers(&mut self, cx: &mut Context<'_>) { + while let Poll::Ready(Some(new_peer)) = self.new_peers.poll_recv(cx) { + if new_peer.info.direction == ConnectionDirection::Outbound { + self.outbound_peers.insert(new_peer.info.id); + } + + self.closed_connections.push(ClosedConnectionFuture { + fut: new_peer.info.handle.closed(), + id: Some(new_peer.info.id), + }); + + self.peers + .insert(new_peer.info.id, StoredClient::new(new_peer)); + } + } + + /// Remove disconnected peers from the peer set. + fn remove_dead_peers(&mut self, cx: &mut Context<'_>) { + while let Poll::Ready(Some(dead_peer)) = self.closed_connections.poll_next_unpin(cx) { + let Some(peer) = self.peers.swap_remove(&dead_peer) else { + continue; + }; + + if peer.client.info.direction == ConnectionDirection::Outbound { + self.outbound_peers.swap_remove(&peer.client.info.id); + } + + self.peers.swap_remove(&dead_peer); + } + } + + /// [`PeerSetRequest::MostPoWSeen`] + fn most_pow_seen(&self) -> PeerSetResponse { + let most_pow_chain = self + .peers + .values() + .map(|peer| { + let core_sync_data = peer.client.info.core_sync_data.lock().unwrap(); + + ( + core_sync_data.cumulative_difficulty(), + u64_to_usize(core_sync_data.current_height), + core_sync_data.top_id, + ) + }) + .max_by_key(|(cumulative_difficulty, ..)| *cumulative_difficulty) + .unwrap_or_default(); + + PeerSetResponse::MostPoWSeen { + cumulative_difficulty: most_pow_chain.0, + height: most_pow_chain.1, + top_hash: most_pow_chain.2, + } + } + + /// [`PeerSetRequest::PeersWithMorePoW`] + fn peers_with_more_pow(&self, cumulative_difficulty: u128) -> PeerSetResponse { + PeerSetResponse::PeersWithMorePoW( + self.peers + .values() + .filter(|&client| { + !client.is_downloading_blocks() + && client + .client + .info + .core_sync_data + .lock() + .unwrap() + .cumulative_difficulty() + > cumulative_difficulty + }) + .map(StoredClient::downloading_blocks_guard) + .collect(), + ) + } + + /// [`PeerSetRequest::StemPeer`] + fn random_peer_for_stem(&self) -> PeerSetResponse { + PeerSetResponse::StemPeer( + sample( + &mut thread_rng(), + self.outbound_peers.len(), + self.outbound_peers.len(), + ) + .into_iter() + .find_map(|i| { + let peer = self.outbound_peers.get_index(i).unwrap(); + let client = self.peers.get(peer).unwrap(); + (!client.is_a_stem_peer()).then(|| client.stem_peer_guard()) + }), + ) + } +} + +impl Service for PeerSet { + type Response = PeerSetResponse; + type Error = tower::BoxError; + type Future = Ready>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_new_peers(cx); + self.remove_dead_peers(cx); + + // TODO: should we return `Pending` if we don't have any peers? + + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: PeerSetRequest) -> Self::Future { + ready(match req { + PeerSetRequest::MostPoWSeen => Ok(self.most_pow_seen()), + PeerSetRequest::PeersWithMorePoW(cumulative_difficulty) => { + Ok(self.peers_with_more_pow(cumulative_difficulty)) + } + PeerSetRequest::StemPeer => Ok(self.random_peer_for_stem()), + }) + } +} diff --git a/p2p/p2p/src/peer_set/client_wrappers.rs b/p2p/p2p/src/peer_set/client_wrappers.rs new file mode 100644 index 0000000..97d7493 --- /dev/null +++ b/p2p/p2p/src/peer_set/client_wrappers.rs @@ -0,0 +1,86 @@ +use std::{ + ops::{Deref, DerefMut}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use cuprate_p2p_core::{ + client::{Client, WeakClient}, + NetworkZone, +}; + +/// A client stored in the peer-set. +pub(super) struct StoredClient { + pub client: Client, + /// An [`AtomicBool`] for if the peer is currently downloading blocks. + downloading_blocks: Arc, + /// An [`AtomicBool`] for if the peer is currently being used to stem txs. + stem_peer: Arc, +} + +impl StoredClient { + pub(super) fn new(client: Client) -> Self { + Self { + client, + downloading_blocks: Arc::new(AtomicBool::new(false)), + stem_peer: Arc::new(AtomicBool::new(false)), + } + } + + /// Returns [`true`] if the [`StoredClient`] is currently downloading blocks. + pub(super) fn is_downloading_blocks(&self) -> bool { + self.downloading_blocks.load(Ordering::Relaxed) + } + + /// Returns [`true`] if the [`StoredClient`] is currently being used to stem txs. + pub(super) fn is_a_stem_peer(&self) -> bool { + self.stem_peer.load(Ordering::Relaxed) + } + + /// Returns a [`ClientDropGuard`] that while it is alive keeps the [`StoredClient`] in the downloading blocks state. + pub(super) fn downloading_blocks_guard(&self) -> ClientDropGuard { + self.downloading_blocks.store(true, Ordering::Relaxed); + + ClientDropGuard { + client: self.client.downgrade(), + bool: Arc::clone(&self.downloading_blocks), + } + } + + /// Returns a [`ClientDropGuard`] that while it is alive keeps the [`StoredClient`] in the stemming peers state. + pub(super) fn stem_peer_guard(&self) -> ClientDropGuard { + self.stem_peer.store(true, Ordering::Relaxed); + + ClientDropGuard { + client: self.client.downgrade(), + bool: Arc::clone(&self.stem_peer), + } + } +} + +/// A [`Drop`] guard for a client returned from the peer-set. +pub struct ClientDropGuard { + client: WeakClient, + bool: Arc, +} + +impl Deref for ClientDropGuard { + type Target = WeakClient; + fn deref(&self) -> &Self::Target { + &self.client + } +} + +impl DerefMut for ClientDropGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.client + } +} + +impl Drop for ClientDropGuard { + fn drop(&mut self) { + self.bool.store(false, Ordering::Relaxed); + } +}