P2P: Block downloader (#132)

* impl async buffer

* clippy

* p2p changes

* clippy

* a few more docs

* init cuprate-p2p

* remove some unrelated code and add some docs

* start documenting client_pool.rs

* add more docs

* typo

* fix docs

* use JoinSet in connection maintainer

* small changes

* add peer sync state svc

* add broadcast svc

* add more docs

* add some tests

* add a test

* fix merge

* add another test

* unify PeerDisconnectFut and add more docs

* start network init

* add an inbound connection server

* remove crate doc for now

* fix address book docs

* fix leak in client pool

* correct comment

* fix merge + add some docs

* review comments

* init block downloader

* fix doc

* initial chain search

* add chain_tracker

* move block downloader to struct

* spawn task when getting blocks

* check for free peers and handle batch response

* add test bin

* working block downloader

* dynamic batch sizes

* dandelion_tower -> dandelion-tower

* fix async-buffer builds

* check if incoming peers are banned

* add interface methods

* update docs

* use a JoinSet for background network tasks

* dynamic batch size changes

* Keep a longer queue of blocks to get

* more checks on incoming data

* fix merge

* fix imports

* add more docs

* add some limits on messages

* keep peers that don't have the currently needed data

* fix clippy

* fix .lock

* fix stopping the block downloader

* clean up API and add more docs

* tracing + bug fixes

* fix panic

* doc changes

* remove test_init

* remove spammy log

* fix previous merge

* add a test

* fix test

* remove test unwrap

* order imports correctly

* clean up test

* add a timeout

* fix tests

* review fixes

* make `BlockDownloader` pub

* make `initial_chain_search` pub

* make `block_downloader` private

* Apply suggestions from code review

Co-authored-by: hinto-janai <hinto.janai@protonmail.com>

* split some sections into separate modules

* split chain requests

* sort imports

* check previous ID is correct

* fix typos

* Apply suggestions from code review

Co-authored-by: hinto-janai <hinto.janai@protonmail.com>

---------

Co-authored-by: hinto-janai <hinto.janai@protonmail.com>
Boog900 2024-06-22 00:29:40 +00:00 committed by GitHub
parent ff1172f2ab
commit 10aac8cbb2
14 changed files with 2033 additions and 71 deletions

Cargo.lock

@@ -598,6 +598,7 @@ dependencies = [
name = "cuprate-p2p"
version = "0.1.0"
dependencies = [
"async-buffer",
"bytes",
"cuprate-helper",
"cuprate-test-utils",
@@ -612,16 +613,17 @@ dependencies = [
"monero-serai",
"monero-wire",
"pin-project",
"proptest",
"rand",
"rand_distr",
"rayon",
"thiserror",
"tokio",
"tokio-stream",
"tokio-test",
"tokio-util",
"tower",
"tracing",
"tracing-subscriber",
]
[[package]]
@@ -1564,16 +1566,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -1621,12 +1613,6 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "page_size"
version = "0.6.0"
@@ -2235,15 +2221,6 @@ dependencies = [
"keccak",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
@@ -2596,18 +2573,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
@@ -2616,12 +2581,7 @@ version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
]
[[package]]
@@ -2680,12 +2640,6 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "version_check"
version = "0.9.4"


@@ -89,6 +89,7 @@ tempfile = { version = "3" }
pretty_assertions = { version = "1.4.0" }
proptest = { version = "1" }
proptest-derive = { version = "0.4.0" }
tokio-test = { version = "0.4.4" }
## TODO:
## Potential dependencies.


@@ -11,7 +11,8 @@ monero-wire = { path = "../../net/monero-wire" }
monero-p2p = { path = "../monero-p2p", features = ["borsh"] }
monero-address-book = { path = "../address-book" }
monero-pruning = { path = "../../pruning" }
cuprate-helper = { path = "../../helper", features = ["asynch"] }
cuprate-helper = { path = "../../helper", features = ["asynch"], default-features = false }
async-buffer = { path = "../async-buffer" }
monero-serai = { workspace = true, features = ["std"] }
@@ -26,13 +27,13 @@ dashmap = { workspace = true }
thiserror = { workspace = true }
bytes = { workspace = true, features = ["std"] }
indexmap = { workspace = true, features = ["std"] }
rand = { workspace = true, features = ["std", "std_rng"] }
rand_distr = { workspace = true, features = ["std"] }
hex = { workspace = true, features = ["std"] }
tracing = { workspace = true, features = ["std", "attributes"] }
tracing-subscriber = "0.3.18"
[dev-dependencies]
cuprate-test-utils = { path = "../../test-utils" }
indexmap = { workspace = true }
proptest = { workspace = true }
tokio-test = { workspace = true }


@@ -0,0 +1,733 @@
//! # Block Downloader
//!
//! This module contains the [`BlockDownloader`], which finds a chain to
//! download from our connected peers and downloads it. See the actual
//! `struct` documentation for implementation details.
//!
//! The block downloader is started by [`download_blocks`].
use std::{
cmp::{max, min, Reverse},
collections::{BTreeMap, BinaryHeap},
sync::Arc,
time::Duration,
};
use futures::TryFutureExt;
use monero_serai::{block::Block, transaction::Transaction};
use tokio::{
task::JoinSet,
time::{interval, timeout, MissedTickBehavior},
};
use tower::{Service, ServiceExt};
use tracing::{instrument, Instrument, Span};
use async_buffer::{BufferAppender, BufferStream};
use monero_p2p::{
handles::ConnectionHandle,
services::{PeerSyncRequest, PeerSyncResponse},
NetworkZone, PeerSyncSvc,
};
use monero_pruning::{PruningSeed, CRYPTONOTE_MAX_BLOCK_HEIGHT};
use crate::{
client_pool::{ClientPool, ClientPoolDropGuard},
constants::{
BLOCK_DOWNLOADER_REQUEST_TIMEOUT, EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED, LONG_BAN,
MAX_BLOCK_BATCH_LEN, MAX_DOWNLOAD_FAILURES,
},
};
mod block_queue;
mod chain_tracker;
mod download_batch;
mod request_chain;
#[cfg(test)]
mod tests;
use block_queue::{BlockQueue, ReadyQueueBatch};
use chain_tracker::{BlocksToRetrieve, ChainEntry, ChainTracker};
use download_batch::download_batch_task;
use request_chain::{initial_chain_search, request_chain_entry_from_peer};
/// A downloaded batch of blocks.
#[derive(Debug, Clone)]
pub struct BlockBatch {
/// The blocks.
pub blocks: Vec<(Block, Vec<Transaction>)>,
/// The size in bytes of this batch.
pub size: usize,
/// The peer that gave us this batch.
pub peer_handle: ConnectionHandle,
}
/// The block downloader config.
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Ord, Eq)]
pub struct BlockDownloaderConfig {
/// The size in bytes of the buffer between the block downloader and the place which
/// is consuming the downloaded blocks.
pub buffer_size: usize,
/// The size of the in progress queue (in bytes) at which we stop requesting more blocks.
pub in_progress_queue_size: usize,
/// The [`Duration`] between checking the client pool for free peers.
pub check_client_pool_interval: Duration,
/// The target size of a single batch of blocks (in bytes).
pub target_batch_size: usize,
/// The number of blocks to request in the initial batch.
pub initial_batch_size: usize,
}
/// An error that occurred in the [`BlockDownloader`].
#[derive(Debug, thiserror::Error)]
pub enum BlockDownloadError {
#[error("A request to a peer timed out.")]
TimedOut,
#[error("The block buffer was closed.")]
BufferWasClosed,
#[error("The peers we requested data from did not have all the data.")]
PeerDidNotHaveRequestedData,
#[error("The peers response to a request was invalid.")]
PeersResponseWasInvalid,
#[error("The chain we are following is invalid.")]
ChainInvalid,
#[error("Failed to find a more advanced chain to follow")]
FailedToFindAChainToFollow,
#[error("The peer did not send any overlapping blocks, unknown start height.")]
PeerSentNoOverlappingBlocks,
#[error("Service error: {0}")]
ServiceError(#[from] tower::BoxError),
}
/// The request type for the chain service.
pub enum ChainSvcRequest {
/// A request for the current chain history.
CompactHistory,
/// A request to find the first unknown block ID in a list of block IDs.
FindFirstUnknown(Vec<[u8; 32]>),
/// A request for our current cumulative difficulty.
CumulativeDifficulty,
}
/// The response type for the chain service.
pub enum ChainSvcResponse {
/// The response for [`ChainSvcRequest::CompactHistory`].
CompactHistory {
/// A list of block IDs in our chain, starting with the most recent block, all the way to the genesis block.
///
/// These block IDs should be in reverse chronological order; not every block is needed.
block_ids: Vec<[u8; 32]>,
/// The current cumulative difficulty of the chain.
cumulative_difficulty: u128,
},
/// The response for [`ChainSvcRequest::FindFirstUnknown`].
///
/// Contains the index of the first unknown block and its expected height.
FindFirstUnknown(usize, u64),
/// The response for [`ChainSvcRequest::CumulativeDifficulty`].
///
/// The current cumulative difficulty of our chain.
CumulativeDifficulty(u128),
}
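// An editor's sketch (not part of this commit): a minimal chain service for the
// request/response types above, built with `tower::service_fn`. A real service
// would query the blockchain database; this one assumes a chain containing only a
// genesis block with hash `[0; 32]` and cumulative difficulty 1.
#[cfg(test)]
fn example_chain_service() -> impl Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>
{
    tower::service_fn(|req: ChainSvcRequest| async move {
        Ok::<_, tower::BoxError>(match req {
            ChainSvcRequest::CompactHistory => ChainSvcResponse::CompactHistory {
                // Most recent first, ending at genesis; here they are the same block.
                block_ids: vec![[0; 32]],
                cumulative_difficulty: 1,
            },
            // Claim the first unknown block is at index 1, expected height 1.
            ChainSvcRequest::FindFirstUnknown(_) => ChainSvcResponse::FindFirstUnknown(1, 1),
            ChainSvcRequest::CumulativeDifficulty => ChainSvcResponse::CumulativeDifficulty(1),
        })
    })
}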
/// This function starts the block downloader and returns a [`BufferStream`] that will produce
/// a sequential stream of blocks.
///
/// The block downloader will pick the longest chain and will follow it for as long as possible,
/// the blocks given from the [`BufferStream`] will be in order.
///
/// The block downloader may fail before the whole chain is downloaded. If this happens, you can
/// call this function again to restart the search and the download.
#[instrument(level = "error", skip_all, name = "block_downloader")]
pub fn download_blocks<N: NetworkZone, S, C>(
client_pool: Arc<ClientPool<N>>,
peer_sync_svc: S,
our_chain_svc: C,
config: BlockDownloaderConfig,
) -> BufferStream<BlockBatch>
where
S: PeerSyncSvc<N> + Clone,
C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>
+ Send
+ 'static,
C::Future: Send + 'static,
{
let (buffer_appender, buffer_stream) = async_buffer::new_buffer(config.buffer_size);
let block_downloader = BlockDownloader::new(
client_pool,
peer_sync_svc,
our_chain_svc,
buffer_appender,
config,
);
tokio::spawn(
block_downloader
.run()
.inspect_err(|e| tracing::debug!("Error downloading blocks: {e}"))
.instrument(Span::current()),
);
buffer_stream
}
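// An editor's sketch (not part of this commit) of how a consumer could drive
// [`download_blocks`]. The config values are illustrative assumptions, not
// recommended settings; batches arrive oldest-first, ready for verification.
#[cfg(test)]
async fn example_consume_blocks<N: NetworkZone, S, C>(
    client_pool: Arc<ClientPool<N>>,
    peer_sync_svc: S,
    our_chain_svc: C,
) where
    S: PeerSyncSvc<N> + Clone,
    C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>
        + Send
        + 'static,
    C::Future: Send + 'static,
{
    use futures::StreamExt;

    let mut stream = download_blocks(
        client_pool,
        peer_sync_svc,
        our_chain_svc,
        BlockDownloaderConfig {
            buffer_size: 50_000_000,
            in_progress_queue_size: 50_000_000,
            check_client_pool_interval: Duration::from_secs(30),
            target_batch_size: 5_000_000,
            initial_batch_size: 1,
        },
    );

    // The stream ends when the downloader stops (tip found or an error); the
    // caller decides whether to call `download_blocks` again.
    while let Some(batch) = stream.next().await {
        tracing::info!("batch of {} blocks ({} bytes)", batch.blocks.len(), batch.size);
    }
}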
/// # Block Downloader
///
/// This is the block downloader, which finds a chain to follow and attempts to follow it, adding the
/// downloaded blocks to an [`async_buffer`].
///
/// ## Implementation Details
///
/// The first step in downloading blocks is to find a chain to follow; this is done by [`initial_chain_search`],
/// whose docs detail how the search works.
///
/// With an initial list of block IDs to follow the block downloader will then look for available peers
/// to download blocks from.
///
/// For each free peer we will then allocate a batch of blocks for it to retrieve. As these blocks come in,
/// we add them to the [`BlockQueue`] for pushing into the [`async_buffer`]: once the oldest batch has been
/// downloaded we send it into the buffer, repeating until the oldest batch is one that is still being downloaded.
///
/// When a peer has finished downloading blocks we add it to our list of ready peers, so it can be used to
/// request more data.
///
/// Ready peers will either:
/// - download the next batch of blocks
/// - request the next chain entry
/// - download an already requested batch of blocks (this might happen due to an error in the previous request
/// or because the queue of ready blocks is too large, so we need the oldest block to clear it).
struct BlockDownloader<N: NetworkZone, S, C> {
/// The client pool.
client_pool: Arc<ClientPool<N>>,
/// The service that holds the peer's sync states.
peer_sync_svc: S,
/// The service that holds our current chain state.
our_chain_svc: C,
/// The amount of blocks to request in the next batch.
amount_of_blocks_to_request: usize,
/// The height at which [`Self::amount_of_blocks_to_request`] was updated.
amount_of_blocks_to_request_updated_at: u64,
/// The amount of consecutive empty chain entries we received.
///
/// An empty chain entry means we reached the peer's chain tip.
amount_of_empty_chain_entries: usize,
/// The running block download tasks.
block_download_tasks: JoinSet<BlockDownloadTaskResponse<N>>,
/// The running chain entry tasks.
///
/// Returns a result of the chain entry or an error.
#[allow(clippy::type_complexity)]
chain_entry_task: JoinSet<Result<(ClientPoolDropGuard<N>, ChainEntry<N>), BlockDownloadError>>,
/// The current inflight requests.
///
/// This is a map of batch start heights to block IDs and related information of the batch.
inflight_requests: BTreeMap<u64, BlocksToRetrieve<N>>,
/// A queue of start heights from failed batches that should be retried.
///
/// Wrapped in [`Reverse`] so we prioritize early batches.
failed_batches: BinaryHeap<Reverse<u64>>,
block_queue: BlockQueue,
/// The [`BlockDownloaderConfig`].
config: BlockDownloaderConfig,
}
impl<N: NetworkZone, S, C> BlockDownloader<N, S, C>
where
S: PeerSyncSvc<N> + Clone,
C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>
+ Send
+ 'static,
C::Future: Send + 'static,
{
/// Creates a new [`BlockDownloader`]
fn new(
client_pool: Arc<ClientPool<N>>,
peer_sync_svc: S,
our_chain_svc: C,
buffer_appender: BufferAppender<BlockBatch>,
config: BlockDownloaderConfig,
) -> Self {
Self {
client_pool,
peer_sync_svc,
our_chain_svc,
amount_of_blocks_to_request: config.initial_batch_size,
amount_of_blocks_to_request_updated_at: 0,
amount_of_empty_chain_entries: 0,
block_download_tasks: JoinSet::new(),
chain_entry_task: JoinSet::new(),
inflight_requests: BTreeMap::new(),
block_queue: BlockQueue::new(buffer_appender),
failed_batches: BinaryHeap::new(),
config,
}
}
/// Checks if we can make use of any peers that are currently pending requests.
async fn check_pending_peers(
&mut self,
chain_tracker: &mut ChainTracker<N>,
pending_peers: &mut BTreeMap<PruningSeed, Vec<ClientPoolDropGuard<N>>>,
) {
tracing::debug!("Checking if we can give any work to pending peers.");
for (_, peers) in pending_peers.iter_mut() {
while let Some(peer) = peers.pop() {
if peer.info.handle.is_closed() {
// Peer has disconnected, drop it.
continue;
}
if let Some(peer) = self.try_handle_free_client(chain_tracker, peer).await {
// This peer is fine, however it does not have the data we currently need. This can only happen
// because of its pruning seed, so skip over all peers with the same pruning seed.
peers.push(peer);
break;
}
}
}
}
/// Attempts to send another request for an inflight batch.
///
/// This function finds the batch(es) we are waiting on to clear our ready queue and sends another request
/// for them.
///
/// Returns the [`ClientPoolDropGuard`] back if it doesn't have the batch according to its pruning seed.
async fn request_inflight_batch_again(
&mut self,
client: ClientPoolDropGuard<N>,
) -> Option<ClientPoolDropGuard<N>> {
tracing::debug!(
"Requesting an inflight batch, current ready queue size: {}",
self.block_queue.size()
);
assert!(
!self.inflight_requests.is_empty(),
"We need requests inflight to be able to send the request again",
);
let oldest_ready_batch = self.block_queue.oldest_ready_batch().unwrap();
for (_, in_flight_batch) in self.inflight_requests.range_mut(0..oldest_ready_batch) {
if in_flight_batch.requests_sent >= 2 {
continue;
}
if !client_has_block_in_range(
&client.info.pruning_seed,
in_flight_batch.start_height,
in_flight_batch.ids.len(),
) {
return Some(client);
}
self.block_download_tasks.spawn(download_batch_task(
client,
in_flight_batch.ids.clone(),
in_flight_batch.prev_id,
in_flight_batch.start_height,
in_flight_batch.requests_sent,
));
return None;
}
tracing::debug!("Could not find an inflight request applicable for this peer.");
Some(client)
}
/// Spawns a task to request blocks from the given peer.
///
/// The batch requested will depend on our current state, failed batches will be prioritised.
///
/// Returns the [`ClientPoolDropGuard`] back if it doesn't have the data we currently need according
/// to its pruning seed.
async fn request_block_batch(
&mut self,
chain_tracker: &mut ChainTracker<N>,
client: ClientPoolDropGuard<N>,
) -> Option<ClientPoolDropGuard<N>> {
tracing::trace!("Using peer to request a batch of blocks.");
// First look to see if we have any failed requests.
while let Some(failed_request) = self.failed_batches.peek() {
// Check if we still have the request that failed - another peer could have completed it after
// failure.
let Some(request) = self.inflight_requests.get_mut(&failed_request.0) else {
// We don't have the request in flight so remove the failure.
self.failed_batches.pop();
continue;
};
// Check if this peer has the blocks according to their pruning seed.
if client_has_block_in_range(
&client.info.pruning_seed,
request.start_height,
request.ids.len(),
) {
tracing::debug!("Using peer to request a failed batch");
// They should have the blocks so send the re-request to this peer.
request.requests_sent += 1;
self.block_download_tasks.spawn(download_batch_task(
client,
request.ids.clone(),
request.prev_id,
request.start_height,
request.requests_sent,
));
// Remove the failure, we have just handled it.
self.failed_batches.pop();
return None;
}
// The peer doesn't have the batch according to its pruning seed.
break;
}
// If our ready queue is too large send duplicate requests for the blocks we are waiting on.
if self.block_queue.size() >= self.config.in_progress_queue_size {
return self.request_inflight_batch_again(client).await;
}
// No failed requests that we can handle, request some new blocks.
let Some(mut block_entry_to_get) = chain_tracker
.blocks_to_get(&client.info.pruning_seed, self.amount_of_blocks_to_request)
else {
return Some(client);
};
tracing::debug!("Requesting a new batch of blocks");
block_entry_to_get.requests_sent = 1;
self.inflight_requests
.insert(block_entry_to_get.start_height, block_entry_to_get.clone());
self.block_download_tasks.spawn(download_batch_task(
client,
block_entry_to_get.ids.clone(),
block_entry_to_get.prev_id,
block_entry_to_get.start_height,
block_entry_to_get.requests_sent,
));
None
}
/// Attempts to give work to a free client.
///
/// This function will use our current state to decide if we should send a request for a chain entry
/// or if we should request a batch of blocks.
///
/// Returns the [`ClientPoolDropGuard`] back if it doesn't have the data we currently need according
/// to its pruning seed.
async fn try_handle_free_client(
&mut self,
chain_tracker: &mut ChainTracker<N>,
client: ClientPoolDropGuard<N>,
) -> Option<ClientPoolDropGuard<N>> {
// We send 2 requests, so if one of them is slow or doesn't have the next chain, we still have a backup.
if self.chain_entry_task.len() < 2
// If we have had too many failures then assume the tip has been found so no more chain entries.
&& self.amount_of_empty_chain_entries <= EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED
// Check we have a big buffer of pending block IDs to retrieve, we don't want to be waiting around
// for a chain entry.
&& chain_tracker.block_requests_queued(self.amount_of_blocks_to_request) < 500
// Make sure this peer actually has the chain.
&& chain_tracker.should_ask_for_next_chain_entry(&client.info.pruning_seed)
{
tracing::debug!("Requesting next chain entry");
let history = chain_tracker.get_simple_history();
self.chain_entry_task.spawn(
async move {
timeout(
BLOCK_DOWNLOADER_REQUEST_TIMEOUT,
request_chain_entry_from_peer(client, history),
)
.await
.map_err(|_| BlockDownloadError::TimedOut)?
}
.instrument(tracing::debug_span!(
"request_chain_entry",
current_height = chain_tracker.top_height()
)),
);
return None;
}
// Request a batch of blocks instead.
self.request_block_batch(chain_tracker, client).await
}
/// Checks the [`ClientPool`] for free peers.
async fn check_for_free_clients(
&mut self,
chain_tracker: &mut ChainTracker<N>,
pending_peers: &mut BTreeMap<PruningSeed, Vec<ClientPoolDropGuard<N>>>,
) -> Result<(), BlockDownloadError> {
tracing::debug!("Checking for free peers");
// This value might be slightly behind but that's ok.
let ChainSvcResponse::CumulativeDifficulty(current_cumulative_difficulty) = self
.our_chain_svc
.ready()
.await?
.call(ChainSvcRequest::CumulativeDifficulty)
.await?
else {
panic!("Chain service returned wrong response.");
};
let PeerSyncResponse::PeersToSyncFrom(peers) = self
.peer_sync_svc
.ready()
.await?
.call(PeerSyncRequest::PeersToSyncFrom {
current_cumulative_difficulty,
block_needed: None,
})
.await?
else {
panic!("Peer sync service returned wrong response.");
};
tracing::debug!("Response received from peer sync service");
for client in self.client_pool.borrow_clients(&peers) {
pending_peers
.entry(client.info.pruning_seed)
.or_default()
.push(client);
}
self.check_pending_peers(chain_tracker, pending_peers).await;
Ok(())
}
/// Handles a response to a request to get blocks from a peer.
async fn handle_download_batch_res(
&mut self,
start_height: u64,
res: Result<(ClientPoolDropGuard<N>, BlockBatch), BlockDownloadError>,
chain_tracker: &mut ChainTracker<N>,
pending_peers: &mut BTreeMap<PruningSeed, Vec<ClientPoolDropGuard<N>>>,
) -> Result<(), BlockDownloadError> {
tracing::debug!("Handling block download response");
match res {
Err(e) => {
if matches!(e, BlockDownloadError::ChainInvalid) {
// If the chain was invalid ban the peer who told us about it and error here to stop the
// block downloader.
self.inflight_requests.get(&start_height).inspect(|entry| {
tracing::warn!(
"Received an invalid chain from peer: {}, exiting block downloader (it should be restarted).",
entry.peer_who_told_us
);
entry.peer_who_told_us_handle.ban_peer(LONG_BAN);
});
return Err(e);
}
// Add the request to the failed list.
if let Some(batch) = self.inflight_requests.get_mut(&start_height) {
tracing::debug!("Error downloading batch: {e}");
batch.failures += 1;
if batch.failures > MAX_DOWNLOAD_FAILURES {
tracing::debug!(
"Too many errors downloading blocks, stopping the block downloader."
);
return Err(BlockDownloadError::TimedOut);
}
self.failed_batches.push(Reverse(start_height));
}
Ok(())
}
Ok((client, block_batch)) => {
// Remove the batch from the inflight batches.
if self.inflight_requests.remove(&start_height).is_none() {
tracing::debug!("Already retrieved batch");
// If it was already retrieved then there is nothing else to do.
pending_peers
.entry(client.info.pruning_seed)
.or_default()
.push(client);
self.check_pending_peers(chain_tracker, pending_peers).await;
return Ok(());
};
// If this batch starts above the height at which we last updated `amount_of_blocks_to_request`,
// update it again.
if start_height > self.amount_of_blocks_to_request_updated_at {
self.amount_of_blocks_to_request = calculate_next_block_batch_size(
block_batch.size,
block_batch.blocks.len(),
self.config.target_batch_size,
);
tracing::debug!(
"Updating batch size of new batches, new size: {}",
self.amount_of_blocks_to_request
);
self.amount_of_blocks_to_request_updated_at = start_height;
}
self.block_queue
.add_incoming_batch(
ReadyQueueBatch {
start_height,
block_batch,
},
self.inflight_requests.first_key_value().map(|(k, _)| *k),
)
.await?;
pending_peers
.entry(client.info.pruning_seed)
.or_default()
.push(client);
self.check_pending_peers(chain_tracker, pending_peers).await;
Ok(())
}
}
}
/// Starts the main loop of the block downloader.
async fn run(mut self) -> Result<(), BlockDownloadError> {
let mut chain_tracker = initial_chain_search(
&self.client_pool,
self.peer_sync_svc.clone(),
&mut self.our_chain_svc,
)
.await?;
let mut pending_peers = BTreeMap::new();
tracing::info!("Attempting to download blocks from peers, this may take a while.");
let mut check_client_pool_interval = interval(self.config.check_client_pool_interval);
check_client_pool_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
self.check_for_free_clients(&mut chain_tracker, &mut pending_peers)
.await?;
loop {
tokio::select! {
_ = check_client_pool_interval.tick() => {
tracing::debug!("Checking client pool for free peers, timer fired.");
self.check_for_free_clients(&mut chain_tracker, &mut pending_peers).await?;
// If we have no inflight requests, and we have had too many empty chain entries in a row assume the top has been found.
if self.inflight_requests.is_empty() && self.amount_of_empty_chain_entries >= EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED {
tracing::debug!("Failed to find any more chain entries, probably fround the top");
return Ok(());
}
}
Some(res) = self.block_download_tasks.join_next() => {
let BlockDownloadTaskResponse {
start_height,
result
} = res.expect("Download batch future panicked");
self.handle_download_batch_res(start_height, result, &mut chain_tracker, &mut pending_peers).await?;
// If we have no inflight requests, and we have had too many empty chain entries in a row assume the top has been found.
if self.inflight_requests.is_empty() && self.amount_of_empty_chain_entries >= EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED {
tracing::debug!("Failed to find any more chain entries, probably fround the top");
return Ok(());
}
}
Some(Ok(res)) = self.chain_entry_task.join_next() => {
match res {
Ok((client, entry)) => {
if chain_tracker.add_entry(entry).is_ok() {
tracing::debug!("Successfully added chain entry to chain tracker.");
self.amount_of_empty_chain_entries = 0;
} else {
tracing::debug!("Failed to add incoming chain entry to chain tracker.");
self.amount_of_empty_chain_entries += 1;
}
pending_peers
.entry(client.info.pruning_seed)
.or_default()
.push(client);
self.check_pending_peers(&mut chain_tracker, &mut pending_peers).await;
}
Err(_) => self.amount_of_empty_chain_entries += 1
}
}
}
}
}
}
/// The return value from the block download tasks.
struct BlockDownloadTaskResponse<N: NetworkZone> {
/// The start height of the batch.
start_height: u64,
/// A result containing the batch or an error.
result: Result<(ClientPoolDropGuard<N>, BlockBatch), BlockDownloadError>,
}
/// Returns if a peer has all the blocks in a range, according to its [`PruningSeed`].
fn client_has_block_in_range(pruning_seed: &PruningSeed, start_height: u64, length: usize) -> bool {
pruning_seed.has_full_block(start_height, CRYPTONOTE_MAX_BLOCK_HEIGHT)
&& pruning_seed.has_full_block(
start_height + u64::try_from(length).unwrap(),
CRYPTONOTE_MAX_BLOCK_HEIGHT,
)
}
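// Editor's note (not part of this commit): checking only the range's two endpoints
// is sound under the assumption that batches are far smaller than a pruning
// stripe - kept and pruned blocks alternate in long contiguous stripes, so if a
// peer has the full block at both ends of a small range it has every block in
// between.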
/// Calculates the next amount of blocks to request in a batch.
///
/// Parameters:
/// - `previous_batch_size` is the size, in bytes, of the last batch
/// - `previous_batch_len` is the amount of blocks in the last batch
/// - `target_batch_size` is the target size, in bytes, of a batch
fn calculate_next_block_batch_size(
previous_batch_size: usize,
previous_batch_len: usize,
target_batch_size: usize,
) -> usize {
// The average block size of the last batch of blocks, multiplied by 2 as a safety margin for
// future blocks.
let adjusted_average_block_size = max((previous_batch_size * 2) / previous_batch_len, 1);
// Set the amount of blocks to request equal to our target batch size divided by the adjusted_average_block_size.
let next_batch_len = max(target_batch_size / adjusted_average_block_size, 1);
// Cap the amount of growth to 1.5x the previous batch len, to prevent a small block causing us to request
// a huge amount of blocks.
let next_batch_len = min(next_batch_len, (previous_batch_len * 3).div_ceil(2));
// Cap the length to the maximum allowed.
min(next_batch_len, MAX_BLOCK_BATCH_LEN)
}
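// An editor's worked example (not part of this commit) of the batch size
// calculation above: a 24-block batch totalling 2_400_000 bytes with a
// 10_000_000 byte target gives an adjusted average of 200_000 bytes/block,
// a raw next length of 10_000_000 / 200_000 = 50, and a growth cap of
// (24 * 3).div_ceil(2) = 36, so 36 blocks are requested next.
#[cfg(test)]
mod batch_size_example {
    use super::*;

    #[test]
    fn worked_example() {
        assert_eq!(
            calculate_next_block_batch_size(2_400_000, 24, 10_000_000),
            min(36, MAX_BLOCK_BATCH_LEN)
        );
    }
}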


@@ -0,0 +1,172 @@
use std::{cmp::Ordering, collections::BinaryHeap};
use async_buffer::BufferAppender;
use super::{BlockBatch, BlockDownloadError};
/// A batch of blocks in the ready queue, waiting for previous blocks to come in, so they can
/// be passed into the buffer.
///
/// The [`Eq`] and [`Ord`] impls on this type only take the `start_height` into account; this
/// is fine because the block downloader only downloads one chain at a time, so no 2 batches can have
/// the same `start_height`.
///
/// Also, the [`Ord`] impl is reversed so older blocks (lower height) come first in a [`BinaryHeap`].
#[derive(Debug, Clone)]
pub struct ReadyQueueBatch {
/// The start height of the batch.
pub start_height: u64,
/// The batch of blocks.
pub block_batch: BlockBatch,
}
impl Eq for ReadyQueueBatch {}
impl PartialEq<Self> for ReadyQueueBatch {
fn eq(&self, other: &Self) -> bool {
self.start_height.eq(&other.start_height)
}
}
impl PartialOrd<Self> for ReadyQueueBatch {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for ReadyQueueBatch {
fn cmp(&self, other: &Self) -> Ordering {
// reverse the ordering so older blocks (lower height) come first in a [`BinaryHeap`]
self.start_height.cmp(&other.start_height).reverse()
}
}
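// Editor's note (not part of this commit): because [`Ord`] is reversed, a
// `BinaryHeap<ReadyQueueBatch>` behaves as a min-heap on `start_height` -
// pushing batches starting at heights 30, 10 and 20 in any order makes
// `peek()`/`pop()` yield 10, then 20, then 30.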
/// The block queue that holds downloaded block batches, adding them to the [`async_buffer`] when the
/// oldest batch has been downloaded.
pub struct BlockQueue {
/// A queue of ready batches.
ready_batches: BinaryHeap<ReadyQueueBatch>,
/// The size, in bytes, of all the batches in [`Self::ready_batches`].
ready_batches_size: usize,
/// The [`BufferAppender`] that gives blocks to Cuprate.
buffer_appender: BufferAppender<BlockBatch>,
}
impl BlockQueue {
/// Creates a new [`BlockQueue`].
pub fn new(buffer_appender: BufferAppender<BlockBatch>) -> BlockQueue {
BlockQueue {
ready_batches: BinaryHeap::new(),
ready_batches_size: 0,
buffer_appender,
}
}
/// Returns the oldest batch that has not been put in the [`async_buffer`] yet.
pub fn oldest_ready_batch(&self) -> Option<u64> {
self.ready_batches.peek().map(|batch| batch.start_height)
}
/// Returns the size of all the batches that have not been put into the [`async_buffer`] yet.
pub fn size(&self) -> usize {
self.ready_batches_size
}
/// Adds an incoming batch to the queue and checks if we can push any batches into the [`async_buffer`].
///
/// `oldest_in_flight_start_height` should be the start height of the oldest batch that is still inflight;
/// if there are no batches inflight then this should be [`None`].
pub async fn add_incoming_batch(
&mut self,
new_batch: ReadyQueueBatch,
oldest_in_flight_start_height: Option<u64>,
) -> Result<(), BlockDownloadError> {
self.ready_batches_size += new_batch.block_batch.size;
self.ready_batches.push(new_batch);
// The height to stop pushing batches into the buffer.
let height_to_stop_at = oldest_in_flight_start_height.unwrap_or(u64::MAX);
while self
.ready_batches
.peek()
.is_some_and(|batch| batch.start_height <= height_to_stop_at)
{
let batch = self
.ready_batches
.pop()
.expect("We just checked we have a batch in the buffer");
let batch_size = batch.block_batch.size;
self.ready_batches_size -= batch_size;
self.buffer_appender
.send(batch.block_batch, batch_size)
.await
.map_err(|_| BlockDownloadError::BufferWasClosed)?;
}
Ok(())
}
}
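// An editor's worked example (not part of this commit): suppose batches starting
// at heights 0 and 10 are ready while the oldest inflight batch starts at height
// 5. `add_incoming_batch` flushes only the height-0 batch into the buffer (10 > 5,
// so it waits). When the height-5 batch later arrives with no older batch
// inflight, `oldest_in_flight_start_height` is `None` and heights 5 and 10 are
// both flushed, preserving order.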
#[cfg(test)]
mod tests {
use futures::StreamExt;
use std::{collections::BTreeSet, sync::Arc};
use proptest::{collection::vec, prelude::*};
use tokio::sync::Semaphore;
use tokio_test::block_on;
use monero_p2p::handles::HandleBuilder;
use super::*;
prop_compose! {
fn ready_batch_strategy()(start_height in 0_u64..500_000_000) -> ReadyQueueBatch {
// TODO: The permit will not be needed here when
let (_, peer_handle) = HandleBuilder::new().with_permit(Arc::new(Semaphore::new(1)).try_acquire_owned().unwrap()).build();
ReadyQueueBatch {
start_height,
block_batch: BlockBatch {
blocks: vec![],
size: start_height as usize,
peer_handle,
},
}
}
}
proptest! {
#[test]
fn block_queue_returns_items_in_order(batches in vec(ready_batch_strategy(), 0..10_000)) {
block_on(async move {
let (buffer_tx, mut buffer_rx) = async_buffer::new_buffer(usize::MAX);
let mut queue = BlockQueue::new(buffer_tx);
let mut sorted_batches = BTreeSet::from_iter(batches.clone());
let mut sorted_batches_2 = sorted_batches.clone();
for batch in batches {
if sorted_batches.remove(&batch) {
queue.add_incoming_batch(batch, sorted_batches.last().map(|batch| batch.start_height)).await.unwrap();
}
}
assert_eq!(queue.size(), 0);
assert!(queue.oldest_ready_batch().is_none());
drop(queue);
while let Some(batch) = buffer_rx.next().await {
let last_batch = sorted_batches_2.pop_last().unwrap();
assert_eq!(batch.size, last_batch.block_batch.size);
}
});
}
}
}


@@ -0,0 +1,211 @@
use std::{cmp::min, collections::VecDeque};
use fixed_bytes::ByteArrayVec;
use monero_p2p::{client::InternalPeerID, handles::ConnectionHandle, NetworkZone};
use monero_pruning::{PruningSeed, CRYPTONOTE_MAX_BLOCK_HEIGHT};
use crate::constants::MEDIUM_BAN;
/// A new chain entry to add to our chain tracker.
#[derive(Debug)]
pub(crate) struct ChainEntry<N: NetworkZone> {
/// A list of block IDs.
pub ids: Vec<[u8; 32]>,
/// The peer who told us about this chain entry.
pub peer: InternalPeerID<N::Addr>,
/// The handle of the peer who told us about this chain entry.
pub handle: ConnectionHandle,
}
/// A batch of blocks to retrieve.
#[derive(Clone)]
pub struct BlocksToRetrieve<N: NetworkZone> {
/// The block IDs to get.
pub ids: ByteArrayVec<32>,
/// The hash of the last block before this batch.
pub prev_id: [u8; 32],
/// The expected height of the first block in [`BlocksToRetrieve::ids`].
pub start_height: u64,
/// The peer who told us about this batch.
pub peer_who_told_us: InternalPeerID<N::Addr>,
/// The handle of the peer who told us about this batch.
pub peer_who_told_us_handle: ConnectionHandle,
/// The number of requests sent for this batch.
pub requests_sent: usize,
/// The number of times this batch has been requested from a peer and failed.
pub failures: usize,
}
/// An error returned from the [`ChainTracker`].
#[derive(Debug, Clone)]
pub enum ChainTrackerError {
/// The new chain entry is invalid.
NewEntryIsInvalid,
/// The new chain entry does not follow from the top of our chain tracker.
NewEntryDoesNotFollowChain,
}
/// # Chain Tracker
///
/// This struct allows following a single chain. It takes in [`ChainEntry`]s and
/// allows getting [`BlocksToRetrieve`].
pub struct ChainTracker<N: NetworkZone> {
/// A list of [`ChainEntry`]s, in order.
entries: VecDeque<ChainEntry<N>>,
/// The height of the first block, in the first entry in [`Self::entries`].
first_height: u64,
/// The hash of the last block in the last entry.
top_seen_hash: [u8; 32],
/// The hash of the block one below [`Self::first_height`].
previous_hash: [u8; 32],
/// The hash of the genesis block.
our_genesis: [u8; 32],
}
impl<N: NetworkZone> ChainTracker<N> {
/// Creates a new chain tracker.
pub fn new(
new_entry: ChainEntry<N>,
first_height: u64,
our_genesis: [u8; 32],
previous_hash: [u8; 32],
) -> Self {
let top_seen_hash = *new_entry.ids.last().unwrap();
let mut entries = VecDeque::with_capacity(1);
entries.push_back(new_entry);
Self {
top_seen_hash,
entries,
first_height,
previous_hash,
our_genesis,
}
}
/// Returns `true` if the peer is expected to have the next block after our highest seen block
/// according to their pruning seed.
pub fn should_ask_for_next_chain_entry(&self, seed: &PruningSeed) -> bool {
seed.has_full_block(self.top_height(), CRYPTONOTE_MAX_BLOCK_HEIGHT)
}
/// Returns the simple history, the highest seen block and the genesis block.
pub fn get_simple_history(&self) -> [[u8; 32]; 2] {
[self.top_seen_hash, self.our_genesis]
}
/// Returns the height of the highest block we are tracking.
pub fn top_height(&self) -> u64 {
let top_block_idx = self
.entries
.iter()
.map(|entry| entry.ids.len())
.sum::<usize>();
self.first_height + u64::try_from(top_block_idx).unwrap()
}
/// Returns the total number of queued batches for a certain `batch_size`.
///
/// # Panics
/// This function panics if `batch_size` is `0`.
pub fn block_requests_queued(&self, batch_size: usize) -> usize {
self.entries
.iter()
.map(|entry| entry.ids.len().div_ceil(batch_size))
.sum()
}
/// Attempts to add an incoming [`ChainEntry`] to the chain tracker.
pub fn add_entry(&mut self, mut chain_entry: ChainEntry<N>) -> Result<(), ChainTrackerError> {
if chain_entry.ids.is_empty() {
// The peer must send at least one overlapping block.
chain_entry.handle.ban_peer(MEDIUM_BAN);
return Err(ChainTrackerError::NewEntryIsInvalid);
}
if chain_entry.ids.len() == 1 {
return Err(ChainTrackerError::NewEntryDoesNotFollowChain);
}
if self
.entries
.back()
.is_some_and(|last_entry| last_entry.ids.last().unwrap() != &chain_entry.ids[0])
{
return Err(ChainTrackerError::NewEntryDoesNotFollowChain);
}
let new_entry = ChainEntry {
// ignore the first block - we already know it.
ids: chain_entry.ids.split_off(1),
peer: chain_entry.peer,
handle: chain_entry.handle,
};
self.top_seen_hash = *new_entry.ids.last().unwrap();
self.entries.push_back(new_entry);
Ok(())
}
/// Returns a batch of blocks to request.
///
/// The returned batch's length will be less than or equal to `max_blocks`.
pub fn blocks_to_get(
&mut self,
pruning_seed: &PruningSeed,
max_blocks: usize,
) -> Option<BlocksToRetrieve<N>> {
if !pruning_seed.has_full_block(self.first_height, CRYPTONOTE_MAX_BLOCK_HEIGHT) {
return None;
}
let entry = self.entries.front_mut()?;
// Calculate the end index for this batch; it will be the smallest of:
// - `max_blocks`
// - the length of the chain entry
// - the index of the next pruned block for this seed
let end_idx = min(
min(entry.ids.len(), max_blocks),
usize::try_from(
pruning_seed
.get_next_pruned_block(self.first_height, CRYPTONOTE_MAX_BLOCK_HEIGHT)
.expect("We use local values to calculate height which should be below the sanity limit")
// Use a big value as a fallback if the seed does no pruning.
.unwrap_or(CRYPTONOTE_MAX_BLOCK_HEIGHT)
- self.first_height,
)
.unwrap(),
);
if end_idx == 0 {
return None;
}
let ids_to_get = entry.ids.drain(0..end_idx).collect::<Vec<_>>();
let blocks = BlocksToRetrieve {
ids: ids_to_get.into(),
prev_id: self.previous_hash,
start_height: self.first_height,
peer_who_told_us: entry.peer,
peer_who_told_us_handle: entry.handle.clone(),
requests_sent: 0,
failures: 0,
};
self.first_height += u64::try_from(end_idx).unwrap();
// TODO: improve ByteArrayVec API.
self.previous_hash = blocks.ids[blocks.ids.len() - 1];
if entry.ids.is_empty() {
self.entries.pop_front();
}
Some(blocks)
}
}
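// An editor's sketch (not part of this commit) of the intended call pattern, with
// a hypothetical `spawn_download` helper: seed the tracker with the entry found by
// the initial chain search, then alternate between draining blocks for free peers
// and feeding in new entries:
//
//     let mut tracker = ChainTracker::new(first_entry, first_height, genesis, prev_hash);
//     while let Some(batch) = tracker.blocks_to_get(&peer_seed, 100) {
//         spawn_download(batch); // request `batch.ids` starting at `batch.start_height`.
//     }
//     // When a `ChainEntry` response arrives:
//     // tracker.add_entry(entry)?; // rejected if it does not connect to the top.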


@@ -0,0 +1,199 @@
use std::collections::HashSet;
use monero_serai::{block::Block, transaction::Transaction};
use rayon::prelude::*;
use tokio::time::timeout;
use tower::{Service, ServiceExt};
use tracing::instrument;
use cuprate_helper::asynch::rayon_spawn_async;
use fixed_bytes::ByteArrayVec;
use monero_p2p::{handles::ConnectionHandle, NetworkZone, PeerRequest, PeerResponse};
use monero_wire::protocol::{GetObjectsRequest, GetObjectsResponse};
use crate::{
block_downloader::{BlockBatch, BlockDownloadError, BlockDownloadTaskResponse},
client_pool::ClientPoolDropGuard,
constants::{BLOCK_DOWNLOADER_REQUEST_TIMEOUT, MAX_TRANSACTION_BLOB_SIZE, MEDIUM_BAN},
};
/// Attempts to request a batch of blocks from a peer, returning [`BlockDownloadTaskResponse`].
#[instrument(
level = "debug",
name = "download_batch",
skip_all,
fields(
start_height = expected_start_height,
attempt = _attempt
)
)]
pub async fn download_batch_task<N: NetworkZone>(
client: ClientPoolDropGuard<N>,
ids: ByteArrayVec<32>,
previous_id: [u8; 32],
expected_start_height: u64,
_attempt: usize,
) -> BlockDownloadTaskResponse<N> {
BlockDownloadTaskResponse {
start_height: expected_start_height,
result: request_batch_from_peer(client, ids, previous_id, expected_start_height).await,
}
}
/// Requests a sequential batch of blocks from a peer.
///
/// This function validates that the downloaded blocks are the ones asked for and that they match
/// the expected heights.
async fn request_batch_from_peer<N: NetworkZone>(
mut client: ClientPoolDropGuard<N>,
ids: ByteArrayVec<32>,
previous_id: [u8; 32],
expected_start_height: u64,
) -> Result<(ClientPoolDropGuard<N>, BlockBatch), BlockDownloadError> {
// Request the blocks.
let blocks_response = timeout(BLOCK_DOWNLOADER_REQUEST_TIMEOUT, async {
let PeerResponse::GetObjects(blocks_response) = client
.ready()
.await?
.call(PeerRequest::GetObjects(GetObjectsRequest {
blocks: ids.clone(),
pruned: false,
}))
.await?
else {
panic!("Connection task returned wrong response.");
};
Ok::<_, BlockDownloadError>(blocks_response)
})
.await
.map_err(|_| BlockDownloadError::TimedOut)??;
// Initial sanity checks
if blocks_response.blocks.len() > ids.len() {
client.info.handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
if blocks_response.blocks.len() != ids.len() {
return Err(BlockDownloadError::PeerDidNotHaveRequestedData);
}
let peer_handle = client.info.handle.clone();
let blocks = rayon_spawn_async(move || {
deserialize_batch(
blocks_response,
expected_start_height,
ids,
previous_id,
peer_handle,
)
})
.await;
let batch = blocks.inspect_err(|e| {
// If the peer's response was invalid, ban it.
if matches!(e, BlockDownloadError::PeersResponseWasInvalid) {
client.info.handle.ban_peer(MEDIUM_BAN);
}
})?;
Ok((client, batch))
}
fn deserialize_batch(
blocks_response: GetObjectsResponse,
expected_start_height: u64,
requested_ids: ByteArrayVec<32>,
previous_id: [u8; 32],
peer_handle: ConnectionHandle,
) -> Result<BlockBatch, BlockDownloadError> {
let blocks = blocks_response
.blocks
.into_par_iter()
.enumerate()
.map(|(i, block_entry)| {
let expected_height = u64::try_from(i).unwrap() + expected_start_height;
let mut size = block_entry.block.len();
let block = Block::read(&mut block_entry.block.as_ref())
.map_err(|_| BlockDownloadError::PeersResponseWasInvalid)?;
let block_hash = block.hash();
// Check the block matches the one requested and the peer sent enough transactions.
if requested_ids[i] != block_hash || block.txs.len() != block_entry.txs.len() {
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
// Check that the previous ID is correct for the first block.
// This is to protect us against banning the wrong peer.
// This must happen after the hash check.
if i == 0 && block.header.previous != previous_id {
tracing::warn!(
"Invalid chain, peer told us a block follows the chain when it doesn't."
);
// This peer probably did nothing wrong; it was the peer who told us about this block ID that
// is misbehaving.
return Err(BlockDownloadError::ChainInvalid);
}
// Check the height lines up as expected.
// This must happen after the hash check.
if !block
.number()
.is_some_and(|height| height == expected_height)
{
tracing::warn!(
"Invalid chain, expected height: {expected_height}, got height: {:?}",
block.number()
);
// This peer probably did nothing wrong; it was the peer who told us about this block ID that
// is misbehaving.
return Err(BlockDownloadError::ChainInvalid);
}
// Deserialize the transactions.
let txs = block_entry
.txs
.take_normal()
.ok_or(BlockDownloadError::PeersResponseWasInvalid)?
.into_iter()
.map(|tx_blob| {
size += tx_blob.len();
if tx_blob.len() > MAX_TRANSACTION_BLOB_SIZE {
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
Transaction::read(&mut tx_blob.as_ref())
.map_err(|_| BlockDownloadError::PeersResponseWasInvalid)
})
.collect::<Result<Vec<_>, _>>()?;
// Make sure the transactions in the block were the ones the peer sent.
let mut expected_txs = block.txs.iter().collect::<HashSet<_>>();
for tx in &txs {
if !expected_txs.remove(&tx.hash()) {
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
}
if !expected_txs.is_empty() {
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
Ok(((block, txs), size))
})
.collect::<Result<(Vec<_>, Vec<_>), _>>()?;
Ok(BlockBatch {
blocks: blocks.0,
size: blocks.1.into_iter().sum(),
peer_handle,
})
}
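// Editor's note (not part of this commit): the transaction check above is
// deliberately two-sided - every downloaded transaction must remove its hash from
// `expected_txs`, and the set must end up empty. Together with the earlier length
// check this rejects missing, extra and duplicated transactions.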


@@ -0,0 +1,238 @@
use std::{mem, sync::Arc};
use rand::prelude::SliceRandom;
use rand::thread_rng;
use tokio::{task::JoinSet, time::timeout};
use tower::{Service, ServiceExt};
use tracing::{instrument, Instrument, Span};
use monero_p2p::{
client::InternalPeerID,
handles::ConnectionHandle,
services::{PeerSyncRequest, PeerSyncResponse},
NetworkZone, PeerRequest, PeerResponse, PeerSyncSvc,
};
use monero_wire::protocol::{ChainRequest, ChainResponse};
use crate::{
block_downloader::{
chain_tracker::{ChainEntry, ChainTracker},
BlockDownloadError, ChainSvcRequest, ChainSvcResponse,
},
client_pool::{ClientPool, ClientPoolDropGuard},
constants::{
BLOCK_DOWNLOADER_REQUEST_TIMEOUT, INITIAL_CHAIN_REQUESTS_TO_SEND,
MAX_BLOCKS_IDS_IN_CHAIN_ENTRY, MEDIUM_BAN,
},
};
/// Request a chain entry from a peer.
///
/// Because the block downloader only follows and downloads one chain, we only have to send the block hash of the
/// top block we have found and the genesis block; this pair is called the `short_history`.
pub async fn request_chain_entry_from_peer<N: NetworkZone>(
mut client: ClientPoolDropGuard<N>,
short_history: [[u8; 32]; 2],
) -> Result<(ClientPoolDropGuard<N>, ChainEntry<N>), BlockDownloadError> {
let PeerResponse::GetChain(chain_res) = client
.ready()
.await?
.call(PeerRequest::GetChain(ChainRequest {
block_ids: short_history.into(),
prune: true,
}))
.await?
else {
panic!("Connection task returned wrong response!");
};
if chain_res.m_block_ids.is_empty()
|| chain_res.m_block_ids.len() > MAX_BLOCKS_IDS_IN_CHAIN_ENTRY
{
client.info.handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
// We must have at least one overlapping block.
if !(chain_res.m_block_ids[0] == short_history[0]
|| chain_res.m_block_ids[0] == short_history[1])
{
client.info.handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
// If the genesis is the overlapping block then this peer does not have our top tracked block in
// its chain.
if chain_res.m_block_ids[0] == short_history[1] {
return Err(BlockDownloadError::PeerDidNotHaveRequestedData);
}
let entry = ChainEntry {
ids: (&chain_res.m_block_ids).into(),
peer: client.info.id,
handle: client.info.handle.clone(),
};
Ok((client, entry))
}
/// Initial chain search, this function pulls [`INITIAL_CHAIN_REQUESTS_TO_SEND`] peers from the [`ClientPool`]
/// and sends chain requests to all of them.
///
/// We then wait for their responses and choose the peer that claims the highest cumulative difficulty.
#[instrument(level = "error", skip_all)]
pub async fn initial_chain_search<N: NetworkZone, S, C>(
client_pool: &Arc<ClientPool<N>>,
mut peer_sync_svc: S,
mut our_chain_svc: C,
) -> Result<ChainTracker<N>, BlockDownloadError>
where
S: PeerSyncSvc<N>,
C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>,
{
tracing::debug!("Getting our chain history");
// Get our history.
let ChainSvcResponse::CompactHistory {
block_ids,
cumulative_difficulty,
} = our_chain_svc
.ready()
.await?
.call(ChainSvcRequest::CompactHistory)
.await?
else {
panic!("chain service sent wrong response.");
};
let our_genesis = *block_ids.last().expect("Blockchain had no genesis block.");
tracing::debug!("Getting a list of peers with higher cumulative difficulty");
let PeerSyncResponse::PeersToSyncFrom(mut peers) = peer_sync_svc
.ready()
.await?
.call(PeerSyncRequest::PeersToSyncFrom {
block_needed: None,
current_cumulative_difficulty: cumulative_difficulty,
})
.await?
else {
panic!("peer sync service sent wrong response.");
};
tracing::debug!(
"{} peers claim they have a higher cumulative difficulty",
peers.len()
);
// Shuffle the list to remove any possibility of peers being able to prioritize getting picked.
peers.shuffle(&mut thread_rng());
let mut peers = client_pool.borrow_clients(&peers);
let mut futs = JoinSet::new();
let req = PeerRequest::GetChain(ChainRequest {
block_ids: block_ids.into(),
prune: false,
});
tracing::debug!("Sending requests for chain entries.");
// Send the requests.
while futs.len() < INITIAL_CHAIN_REQUESTS_TO_SEND {
let Some(mut next_peer) = peers.next() else {
break;
};
let cloned_req = req.clone();
futs.spawn(timeout(
BLOCK_DOWNLOADER_REQUEST_TIMEOUT,
async move {
let PeerResponse::GetChain(chain_res) =
next_peer.ready().await?.call(cloned_req).await?
else {
panic!("connection task returned wrong response!");
};
Ok::<_, tower::BoxError>((
chain_res,
next_peer.info.id,
next_peer.info.handle.clone(),
))
}
.instrument(Span::current()),
));
}
let mut res: Option<(ChainResponse, InternalPeerID<_>, ConnectionHandle)> = None;
// Wait for the peers responses.
while let Some(task_res) = futs.join_next().await {
let Ok(Ok(task_res)) = task_res.unwrap() else {
continue;
};
match &mut res {
Some(res) => {
// `res` has already been set; replace it if this peer claims a higher cumulative difficulty.
if res.0.cumulative_difficulty() < task_res.0.cumulative_difficulty() {
let _ = mem::replace(res, task_res);
}
}
None => {
// `res` has not been set, set it now.
res = Some(task_res);
}
}
}
let Some((chain_res, peer_id, peer_handle)) = res else {
return Err(BlockDownloadError::FailedToFindAChainToFollow);
};
let hashes: Vec<[u8; 32]> = (&chain_res.m_block_ids).into();
// drop this to deallocate the [`Bytes`].
drop(chain_res);
tracing::debug!("Highest chin entry contained {} block Ids", hashes.len());
// Find the first unknown block in the batch.
let ChainSvcResponse::FindFirstUnknown(first_unknown, expected_height) = our_chain_svc
.ready()
.await?
.call(ChainSvcRequest::FindFirstUnknown(hashes.clone()))
.await?
else {
panic!("chain service sent wrong response.");
};
// The peer must send at least one block we already know.
if first_unknown == 0 {
peer_handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeerSentNoOverlappingBlocks);
}
// We know all the blocks already
// TODO: The peer could still be on a different chain, however the chain might just be too far split.
if first_unknown == hashes.len() {
return Err(BlockDownloadError::FailedToFindAChainToFollow);
}
let previous_id = hashes[first_unknown - 1];
let first_entry = ChainEntry {
ids: hashes[first_unknown..].to_vec(),
peer: peer_id,
handle: peer_handle,
};
tracing::debug!(
"Creating chain tracker with {} new block Ids",
first_entry.ids.len()
);
let tracker = ChainTracker::new(first_entry, expected_height, our_genesis, previous_id);
Ok(tracker)
}


@@ -0,0 +1,323 @@
use std::{
fmt::{Debug, Formatter},
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use futures::{FutureExt, StreamExt};
use indexmap::IndexMap;
use monero_serai::{
block::{Block, BlockHeader},
ringct::{RctBase, RctPrunable, RctSignatures},
transaction::{Input, Timelock, Transaction, TransactionPrefix},
};
use proptest::{collection::vec, prelude::*};
use tokio::{sync::Semaphore, time::timeout};
use tower::{service_fn, Service};
use fixed_bytes::ByteArrayVec;
use monero_p2p::{
client::{mock_client, Client, InternalPeerID, PeerInformation},
network_zones::ClearNet,
services::{PeerSyncRequest, PeerSyncResponse},
ConnectionDirection, NetworkZone, PeerRequest, PeerResponse,
};
use monero_pruning::PruningSeed;
use monero_wire::{
common::{BlockCompleteEntry, TransactionBlobs},
protocol::{ChainResponse, GetObjectsResponse},
};
use crate::{
block_downloader::{download_blocks, BlockDownloaderConfig, ChainSvcRequest, ChainSvcResponse},
client_pool::ClientPool,
};
proptest! {
#![proptest_config(ProptestConfig {
cases: 4,
max_shrink_iters: 10,
timeout: 60 * 1000,
.. ProptestConfig::default()
})]
#[test]
fn test_block_downloader(blockchain in dummy_blockchain_strategy(), peers in 1_usize..128) {
let blockchain = Arc::new(blockchain);
let tokio_pool = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
tokio_pool.block_on(async move {
timeout(Duration::from_secs(600), async move {
let client_pool = ClientPool::new();
let mut peer_ids = Vec::with_capacity(peers);
for _ in 0..peers {
let client = mock_block_downloader_client(blockchain.clone());
peer_ids.push(client.info.id);
client_pool.add_new_client(client);
}
let stream = download_blocks(
client_pool,
SyncStateSvc(peer_ids),
OurChainSvc {
genesis: *blockchain.blocks.first().unwrap().0
},
BlockDownloaderConfig {
buffer_size: 1_000,
in_progress_queue_size: 10_000,
check_client_pool_interval: Duration::from_secs(5),
target_batch_size: 5_000,
initial_batch_size: 1,
});
let blocks = stream.map(|blocks| blocks.blocks).concat().await;
assert_eq!(blocks.len() + 1, blockchain.blocks.len());
for (i, block) in blocks.into_iter().enumerate() {
assert_eq!(&block, blockchain.blocks.get_index(i + 1).unwrap().1);
}
}).await
}).unwrap();
}
}
prop_compose! {
/// Returns a strategy to generate a [`Transaction`] that is valid for the block downloader.
fn dummy_transaction_strategy(height: u64)
(
extra in vec(any::<u8>(), 0..1_000),
timelock in 0_usize..50_000_000,
)
-> Transaction {
Transaction {
prefix: TransactionPrefix {
version: 1,
timelock: Timelock::Block(timelock),
inputs: vec![Input::Gen(height)],
outputs: vec![],
extra,
},
signatures: vec![],
rct_signatures: RctSignatures {
base: RctBase {
fee: 0,
pseudo_outs: vec![],
encrypted_amounts: vec![],
commitments: vec![],
},
prunable: RctPrunable::Null
},
}
}
}
prop_compose! {
/// Returns a strategy to generate a [`Block`] that is valid for the block downloader.
fn dummy_block_strategy(
height: u64,
previous: [u8; 32],
)
(
miner_tx in dummy_transaction_strategy(height),
txs in vec(dummy_transaction_strategy(height), 0..25)
)
-> (Block, Vec<Transaction>) {
(
Block {
header: BlockHeader {
major_version: 0,
minor_version: 0,
timestamp: 0,
previous,
nonce: 0,
},
miner_tx,
txs: txs.iter().map(Transaction::hash).collect(),
},
txs
)
}
}
/// A mock blockchain.
struct MockBlockchain {
blocks: IndexMap<[u8; 32], (Block, Vec<Transaction>)>,
}
impl Debug for MockBlockchain {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str("MockBlockchain")
}
}
prop_compose! {
/// Returns a strategy to generate a [`MockBlockchain`].
fn dummy_blockchain_strategy()(
blocks in vec(dummy_block_strategy(0, [0; 32]), 1..50_000),
) -> MockBlockchain {
let mut blockchain = IndexMap::new();
for (height, mut block) in blocks.into_iter().enumerate() {
if let Some(last) = blockchain.last() {
block.0.header.previous = *last.0;
block.0.miner_tx.prefix.inputs = vec![Input::Gen(height as u64)]
}
blockchain.insert(block.0.hash(), block);
}
MockBlockchain {
blocks: blockchain
}
}
}
fn mock_block_downloader_client(blockchain: Arc<MockBlockchain>) -> Client<ClearNet> {
let semaphore = Arc::new(Semaphore::new(1));
let (connection_guard, connection_handle) = monero_p2p::handles::HandleBuilder::new()
.with_permit(semaphore.try_acquire_owned().unwrap())
.build();
let request_handler = service_fn(move |req: PeerRequest| {
let bc = blockchain.clone();
async move {
match req {
PeerRequest::GetChain(chain_req) => {
let mut i = 0;
while !bc.blocks.contains_key(&chain_req.block_ids[i]) {
i += 1;
if i == chain_req.block_ids.len() {
i -= 1;
break;
}
}
let block_index = bc.blocks.get_index_of(&chain_req.block_ids[i]).unwrap();
let block_ids = bc
.blocks
.get_range(block_index..)
.unwrap()
.iter()
.map(|(id, _)| *id)
.take(200)
.collect::<Vec<_>>();
Ok(PeerResponse::GetChain(ChainResponse {
start_height: 0,
total_height: 0,
cumulative_difficulty_low64: 1,
cumulative_difficulty_top64: 0,
m_block_ids: block_ids.into(),
m_block_weights: vec![],
first_block: Default::default(),
}))
}
PeerRequest::GetObjects(obj) => {
let mut res = Vec::with_capacity(obj.blocks.len());
for i in 0..obj.blocks.len() {
let block = bc.blocks.get(&obj.blocks[i]).unwrap();
let block_entry = BlockCompleteEntry {
pruned: false,
block: block.0.serialize().into(),
txs: TransactionBlobs::Normal(
block
.1
.iter()
.map(Transaction::serialize)
.map(Into::into)
.collect(),
),
block_weight: 0,
};
res.push(block_entry);
}
Ok(PeerResponse::GetObjects(GetObjectsResponse {
blocks: res,
missed_ids: ByteArrayVec::from([]),
current_blockchain_height: 0,
}))
}
_ => panic!(),
}
}
.boxed()
});
let info = PeerInformation {
id: InternalPeerID::Unknown(rand::random()),
handle: connection_handle,
direction: ConnectionDirection::InBound,
pruning_seed: PruningSeed::NotPruned,
};
mock_client(info, connection_guard, request_handler)
}
#[derive(Clone)]
struct SyncStateSvc<Z: NetworkZone>(Vec<InternalPeerID<Z::Addr>>);
impl Service<PeerSyncRequest<ClearNet>> for SyncStateSvc<ClearNet> {
type Response = PeerSyncResponse<ClearNet>;
type Error = tower::BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: PeerSyncRequest<ClearNet>) -> Self::Future {
let peers = self.0.clone();
async move { Ok(PeerSyncResponse::PeersToSyncFrom(peers)) }.boxed()
}
}
struct OurChainSvc {
genesis: [u8; 32],
}
impl Service<ChainSvcRequest> for OurChainSvc {
type Response = ChainSvcResponse;
type Error = tower::BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: ChainSvcRequest) -> Self::Future {
let genesis = self.genesis;
async move {
Ok(match req {
ChainSvcRequest::CompactHistory => ChainSvcResponse::CompactHistory {
block_ids: vec![genesis],
cumulative_difficulty: 1,
},
ChainSvcRequest::FindFirstUnknown(_) => ChainSvcResponse::FindFirstUnknown(1, 1),
ChainSvcRequest::CumulativeDifficulty => ChainSvcResponse::CumulativeDifficulty(1),
})
}
.boxed()
}
}

View file

@ -126,13 +126,16 @@ impl<N: NetworkZone> ClientPool<N> {
pub fn borrow_clients<'a, 'b>(
self: &'a Arc<Self>,
peers: &'b [InternalPeerID<N::Addr>],
) -> impl Iterator<Item = ClientPoolDropGuard<N>> + Captures<(&'a (), &'b ())> {
) -> impl Iterator<Item = ClientPoolDropGuard<N>> + sealed::Captures<(&'a (), &'b ())> {
peers.iter().filter_map(|peer| self.borrow_client(peer))
}
}
/// TODO: Remove me when 2024 Rust
///
/// https://rust-lang.github.io/rfcs/3498-lifetime-capture-rules-2024.html#the-captures-trick
trait Captures<U> {}
impl<T: ?Sized, U> Captures<U> for T {}
mod sealed {
/// TODO: Remove me when 2024 Rust
///
/// https://rust-lang.github.io/rfcs/3498-lifetime-capture-rules-2024.html#the-captures-trick
pub trait Captures<U> {}
impl<T: ?Sized, U> Captures<U> for T {}
}
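
For readers unfamiliar with the trick the RFC link describes, here is a minimal, self-contained sketch (illustrative names only, not crate code): on pre-2024 editions an `impl Trait` return type only captures the lifetimes that appear in its bounds, so an empty marker trait is used to smuggle `'a` and `'b` into the opaque type.

mod sealed {
    /// Marker trait whose only job is to mention lifetimes in an `impl Trait` bound.
    pub trait Captures<U> {}
    impl<T: ?Sized, U> Captures<U> for T {}
}

// Without the `Captures` bound, neither `'a` nor `'b` would be captured by the
// opaque return type, and borrowing from both slices would fail to compile.
fn pairs<'a, 'b>(
    xs: &'a [u32],
    ys: &'b [u32],
) -> impl Iterator<Item = (u32, u32)> + sealed::Captures<(&'a (), &'b ())> {
    xs.iter().flat_map(move |&x| ys.iter().map(move |&y| (x, y)))
}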

View file

@ -12,6 +12,12 @@ pub(crate) const OUTBOUND_CONNECTION_ATTEMPT_TIMEOUT: Duration = Duration::from_
/// The duration of a short ban.
pub(crate) const SHORT_BAN: Duration = Duration::from_secs(60 * 10);
/// The duration of a medium ban.
pub(crate) const MEDIUM_BAN: Duration = Duration::from_secs(60 * 60 * 24);
/// The duration of a long ban.
pub(crate) const LONG_BAN: Duration = Duration::from_secs(60 * 60 * 24 * 7);
/// The default amount of time between inbound diffusion flushes.
pub(crate) const DIFFUSION_FLUSH_AVERAGE_SECONDS_INBOUND: Duration = Duration::from_secs(5);
@ -34,6 +40,35 @@ pub(crate) const MAX_TXS_IN_BROADCAST_CHANNEL: usize = 50;
/// TODO: it might be a good idea to make this configurable.
pub(crate) const INBOUND_CONNECTION_COOL_DOWN: Duration = Duration::from_millis(500);
/// The initial number of chain requests to send to find the best chain to sync from.
pub(crate) const INITIAL_CHAIN_REQUESTS_TO_SEND: usize = 3;
/// The enforced maximum number of blocks to request in a batch.
///
/// Requesting more than this will cause the peer to disconnect and potentially lead to bans.
pub(crate) const MAX_BLOCK_BATCH_LEN: usize = 100;
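
In other words, any dynamically computed batch size must be clamped to this cap before a request is sent; a short sketch of the idea (the helper name is hypothetical, not part of this crate):

/// Hypothetical helper: cap a dynamically computed batch size at the
/// protocol-enforced maximum so a request can never trigger a ban.
pub(crate) fn clamp_batch_len(wanted: usize) -> usize {
    wanted.min(MAX_BLOCK_BATCH_LEN)
}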
/// The timeout that the block downloader will use for requests.
pub(crate) const BLOCK_DOWNLOADER_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// The maximum size of a transaction: a sanity limit that all transactions, across all hard-forks,
/// must stay under.
///
/// ref: <https://monero-book.cuprate.org/consensus_rules/transactions.html#transaction-size>
pub(crate) const MAX_TRANSACTION_BLOB_SIZE: usize = 1_000_000;
/// The maximum number of block IDs allowed in a chain entry response.
///
/// ref: <https://github.com/monero-project/monero/blob/cc73fe71162d564ffda8e549b79a350bca53c454/src/cryptonote_config.h#L97>
// TODO: link to the protocol book when this section is added.
pub(crate) const MAX_BLOCKS_IDS_IN_CHAIN_ENTRY: usize = 25_000;
/// The number of times downloading a specific batch can fail before we stop attempting to download it.
pub(crate) const MAX_DOWNLOAD_FAILURES: usize = 5;
/// The number of empty chain entries to receive before we assume we have reached the top of the chain.
pub(crate) const EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED: usize = 5;
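
To make the last two constants concrete, here is a hedged sketch of how such a counter is typically consumed (a hypothetical type, not the downloader's actual state machine):

/// Hypothetical sketch: track consecutive empty chain entries and report when
/// the downloader should assume it has reached the chain tip.
struct TopOfChainDetector {
    empty_entries_seen: usize,
}

impl TopOfChainDetector {
    /// Call on every chain entry response; returns `true` once we assume the top.
    fn on_chain_entry(&mut self, new_block_ids: usize) -> bool {
        if new_block_ids == 0 {
            self.empty_entries_seen += 1;
        } else {
            // A non-empty entry means the chain is still growing under us.
            self.empty_entries_seen = 0;
        }
        self.empty_entries_seen >= EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED
    }
}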
#[cfg(test)]
mod tests {
use super::*;
@ -44,4 +79,10 @@ mod tests {
fn outbound_diffusion_flush_shorter_than_inbound() {
assert!(DIFFUSION_FLUSH_AVERAGE_SECONDS_OUTBOUND < DIFFUSION_FLUSH_AVERAGE_SECONDS_INBOUND);
}
/// Checks that the ban time increases from short to long.
#[test]
fn ban_times_sanity_check() {
assert!(SHORT_BAN < MEDIUM_BAN && MEDIUM_BAN < LONG_BAN);
}
}

View file

@ -4,22 +4,24 @@
//! a certain [`NetworkZone`]
use std::sync::Arc;
use async_buffer::BufferStream;
use futures::FutureExt;
use tokio::{
sync::{mpsc, watch},
task::JoinSet,
};
use tokio_stream::wrappers::WatchStream;
use tower::{buffer::Buffer, util::BoxCloneService, ServiceExt};
use tower::{buffer::Buffer, util::BoxCloneService, Service, ServiceExt};
use tracing::{instrument, Instrument, Span};
use monero_p2p::{
client::Connector,
client::InternalPeerID,
services::{AddressBookRequest, AddressBookResponse},
services::{AddressBookRequest, AddressBookResponse, PeerSyncRequest},
CoreSyncSvc, NetworkZone, PeerRequestHandler,
};
mod block_downloader;
mod broadcast;
mod client_pool;
pub mod config;
@ -28,6 +30,7 @@ mod constants;
mod inbound_server;
mod sync_states;
use block_downloader::{BlockBatch, BlockDownloaderConfig, ChainSvcRequest, ChainSvcResponse};
pub use broadcast::{BroadcastRequest, BroadcastSvc};
use client_pool::ClientPoolDropGuard;
pub use config::P2PConfig;
@ -87,7 +90,7 @@ where
let inbound_handshaker = monero_p2p::client::HandShaker::new(
address_book.clone(),
sync_states_svc,
sync_states_svc.clone(),
core_sync_svc.clone(),
peer_req_handler,
inbound_mkr,
@ -136,6 +139,7 @@ where
broadcast_svc,
top_block_watch,
make_connection_tx,
sync_states_svc,
address_book: address_book.boxed_clone(),
_background_tasks: Arc::new(background_tasks),
})
@ -156,6 +160,8 @@ pub struct NetworkInterface<N: NetworkZone> {
make_connection_tx: mpsc::Sender<MakeConnectionRequest>,
/// The address book service.
address_book: BoxCloneService<AddressBookRequest<N>, AddressBookResponse<N>, tower::BoxError>,
/// The peers' sync states service.
sync_states_svc: Buffer<sync_states::PeerSyncSvc<N>, PeerSyncRequest<N>>,
/// Background tasks that will be aborted when this interface is dropped.
_background_tasks: Arc<JoinSet<()>>,
}
@ -166,6 +172,26 @@ impl<N: NetworkZone> NetworkInterface<N> {
self.broadcast_svc.clone()
}
/// Starts the block downloader and returns a stream that will yield sequentially downloaded blocks.
pub fn block_downloader<C>(
&self,
our_chain_service: C,
config: BlockDownloaderConfig,
) -> BufferStream<BlockBatch>
where
C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>
+ Send
+ 'static,
C::Future: Send + 'static,
{
block_downloader::download_blocks(
self.pool.clone(),
self.sync_states_svc.clone(),
our_chain_service,
config,
)
}
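
A hedged usage sketch (assumes `BlockDownloaderConfig` is reachable by the caller and reuses the `OurChainSvc` test stub from earlier in this diff; the config values mirror the test's):

use std::time::Duration;

use futures::StreamExt;

async fn sync_blocks(net: NetworkInterface<ClearNet>, our_chain: OurChainSvc) {
    let batches = net.block_downloader(
        our_chain,
        BlockDownloaderConfig {
            buffer_size: 1_000,
            in_progress_queue_size: 10_000,
            check_client_pool_interval: Duration::from_secs(5),
            target_batch_size: 5_000,
            initial_batch_size: 1,
        },
    );
    futures::pin_mut!(batches);

    // Batches arrive sequentially, ready to be fed to verification/storage.
    while let Some(batch) = batches.next().await {
        let _blocks = batch.blocks;
    }
}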
/// Returns a stream which yields the highest seen sync state from a connected peer.
pub fn top_sync_stream(&self) -> WatchStream<sync_states::NewSyncInfo> {
WatchStream::from_changes(self.top_block_watch.clone())

View file

@ -10,7 +10,7 @@ default = ["borsh"]
borsh = ["dep:borsh", "monero-pruning/borsh"]
[dependencies]
cuprate-helper = { path = "../../helper" }
cuprate-helper = { path = "../../helper", features = ["asynch"], default-features = false }
monero-wire = { path = "../../net/monero-wire", features = ["tracing"] }
monero-pruning = { path = "../../pruning" }

View file

@ -10,13 +10,15 @@ use tokio::{
task::JoinHandle,
};
use tokio_util::sync::PollSemaphore;
use tower::Service;
use tower::{Service, ServiceExt};
use tracing::Instrument;
use cuprate_helper::asynch::InfallibleOneshotReceiver;
use monero_pruning::PruningSeed;
use crate::{
handles::ConnectionHandle, ConnectionDirection, NetworkZone, PeerError, PeerRequest,
PeerResponse, SharedError,
handles::{ConnectionGuard, ConnectionHandle},
ConnectionDirection, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
};
mod connection;
@ -26,7 +28,6 @@ mod timeout_monitor;
pub use connector::{ConnectRequest, Connector};
pub use handshaker::{DoHandshakeRequest, HandShaker, HandshakeError};
use monero_pruning::PruningSeed;
/// An internal identifier for a given peer; this will be their address if known,
/// or a random u128 if not.
@ -158,11 +159,70 @@ impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
permit: Some(permit),
};
self.connection_tx
.try_send(req)
.map_err(|_| ())
.expect("poll_ready should have been called");
if let Err(e) = self.connection_tx.try_send(req) {
// The connection task could have closed between the call to `poll_ready` and this call to
// `call`; if we didn't handle the error here, the response receiver would panic.
use mpsc::error::TrySendError;
match e {
TrySendError::Closed(req) | TrySendError::Full(req) => {
self.set_err(PeerError::ClientChannelClosed);
let _ = req
.response_channel
.send(Err(PeerError::ClientChannelClosed.into()));
}
}
}
rx.into()
}
}
/// Creates a mock [`Client`] for testing purposes.
///
/// `request_handler` will be used to handle requests sent to the [`Client`].
pub fn mock_client<Z: NetworkZone, S>(
info: PeerInformation<Z::Addr>,
connection_guard: ConnectionGuard,
mut request_handler: S,
) -> Client<Z>
where
S: crate::PeerRequestHandler,
{
let (tx, mut rx) = mpsc::channel(1);
let task_span = tracing::error_span!("mock_connection", addr = %info.id);
let task_handle = tokio::spawn(
async move {
let _guard = connection_guard;
loop {
let Some(req): Option<connection::ConnectionTaskRequest> = rx.recv().await else {
tracing::debug!("Channel closed, closing mock connection");
return;
};
tracing::debug!("Received new request: {:?}", req.request.id());
let res = request_handler
.ready()
.await
.unwrap()
.call(req.request)
.await
.unwrap();
tracing::debug!("Sending back response");
let _ = req.response_channel.send(Ok(res));
}
}
.instrument(task_span),
);
let timeout_task = tokio::spawn(futures::future::pending());
let semaphore = Arc::new(Semaphore::new(1));
let error_slot = SharedError::new();
Client::new(info, tx, task_handle, timeout_task, semaphore, error_slot)
}