diff --git a/Cargo.lock b/Cargo.lock index 1363cd05..fab36ef4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -598,6 +598,7 @@ dependencies = [ name = "cuprate-p2p" version = "0.1.0" dependencies = [ + "async-buffer", "bytes", "cuprate-helper", "cuprate-test-utils", @@ -612,16 +613,17 @@ dependencies = [ "monero-serai", "monero-wire", "pin-project", + "proptest", "rand", "rand_distr", "rayon", "thiserror", "tokio", "tokio-stream", + "tokio-test", "tokio-util", "tower", "tracing", - "tracing-subscriber", ] [[package]] @@ -1571,16 +1573,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -1628,12 +1620,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "page_size" version = "0.6.0" @@ -2254,15 +2240,6 @@ dependencies = [ "keccak", ] -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -2615,18 +2592,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", ] [[package]] @@ -2635,12 +2600,7 @@ version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ - "nu-ansi-term", - "sharded-slab", - "smallvec", - "thread_local", "tracing-core", - "tracing-log", ] [[package]] @@ -2699,12 +2659,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "valuable" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" - [[package]] name = "version_check" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index 8100af72..7be28732 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,6 +89,7 @@ tempfile = { version = "3" } pretty_assertions = { version = "1.4.0" } proptest = { version = "1" } proptest-derive = { version = "0.4.0" } +tokio-test = { version = "0.4.4" } ## TODO: ## Potential dependencies. diff --git a/p2p/cuprate-p2p/Cargo.toml b/p2p/cuprate-p2p/Cargo.toml index accd9841..ab477a83 100644 --- a/p2p/cuprate-p2p/Cargo.toml +++ b/p2p/cuprate-p2p/Cargo.toml @@ -36,3 +36,4 @@ tracing = { workspace = true, features = ["std", "attributes"] } cuprate-test-utils = { path = "../../test-utils" } indexmap = { workspace = true } proptest = { workspace = true } +tokio-test = { workspace = true } diff --git a/p2p/cuprate-p2p/src/block_downloader.rs b/p2p/cuprate-p2p/src/block_downloader.rs index e047c152..e88ab13c 100644 --- a/p2p/cuprate-p2p/src/block_downloader.rs +++ b/p2p/cuprate-p2p/src/block_downloader.rs @@ -6,7 +6,7 @@ //! //! The block downloader is started by [`download_blocks`]. use std::{ - cmp::{max, min, Ordering, Reverse}, + cmp::{max, min, Reverse}, collections::{BTreeMap, BinaryHeap, HashSet}, mem, sync::Arc, @@ -34,25 +34,32 @@ use monero_p2p::{ NetworkZone, PeerRequest, PeerResponse, PeerSyncSvc, }; use monero_pruning::{PruningSeed, CRYPTONOTE_MAX_BLOCK_HEIGHT}; -use monero_wire::protocol::{ChainRequest, ChainResponse, GetObjectsRequest}; +use monero_wire::protocol::{ChainRequest, ChainResponse}; use crate::{ client_pool::{ClientPool, ClientPoolDropGuard}, constants::{ BLOCK_DOWNLOADER_REQUEST_TIMEOUT, EMPTY_CHAIN_ENTRIES_BEFORE_TOP_ASSUMED, INITIAL_CHAIN_REQUESTS_TO_SEND, LONG_BAN, MAX_BLOCKS_IDS_IN_CHAIN_ENTRY, - MAX_BLOCK_BATCH_LEN, MAX_DOWNLOAD_FAILURES, MAX_TRANSACTION_BLOB_SIZE, MEDIUM_BAN, + MAX_BLOCK_BATCH_LEN, MAX_DOWNLOAD_FAILURES, MEDIUM_BAN, }, }; +mod block_queue; mod chain_tracker; -use chain_tracker::{BlocksToRetrieve, ChainEntry, ChainTracker}; +use block_queue::{BlockQueue, ReadyQueueBatch}; +use chain_tracker::{BlocksToRetrieve, ChainEntry, ChainTracker}; +use download_batch::download_batch_task; + +// TODO: check first block in batch prev_id + +mod download_batch; #[cfg(test)] mod tests; /// A downloaded batch of blocks. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BlockBatch { /// The blocks. pub blocks: Vec<(Block, Vec)>, @@ -172,43 +179,6 @@ where buffer_stream } -/// A batch of blocks in the ready queue, waiting for previous blocks to come in, so they can -/// be passed into the buffer. -/// -/// The [`Eq`] and [`Ord`] impl on this type will only take into account the `start_height`, this -/// is because the block downloader will only download one chain at once so no 2 batches can have -/// the same `start_height`. -/// -/// Also, the [`Ord`] impl is reversed so older blocks (lower height) come first in a [`BinaryHeap`]. -#[derive(Debug)] -struct ReadyQueueBatch { - /// The start height of the batch. - start_height: u64, - /// The batch of blocks. - block_batch: BlockBatch, -} - -impl Eq for ReadyQueueBatch {} - -impl PartialEq for ReadyQueueBatch { - fn eq(&self, other: &Self) -> bool { - self.start_height.eq(&other.start_height) - } -} - -impl PartialOrd for ReadyQueueBatch { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for ReadyQueueBatch { - fn cmp(&self, other: &Self) -> Ordering { - // reverse the ordering so older blocks (lower height) come first in a [`BinaryHeap`] - self.start_height.cmp(&other.start_height).reverse() - } -} - /// # Block Downloader /// /// This is the block downloader, which finds a chain to follow and attempts to follow it, adding the @@ -223,7 +193,7 @@ impl Ord for ReadyQueueBatch { /// to download blocks from. /// /// For each peer we will then allocate a batch of blocks for them to retrieve, as these blocks come in -/// we add them to queue for pushing into the [`async_buffer`], once we have the oldest block downloaded +/// we add them to the [`BlockQueue`] for pushing into the [`async_buffer`], once we have the oldest block downloaded /// we send it into the buffer, repeating this until the oldest current block is still being downloaded. /// /// When a peer has finished downloading blocks we add it to our list of ready peers, so it can be used to @@ -270,18 +240,12 @@ struct BlockDownloader { /// This is a map of batch start heights to block IDs and related information of the batch. inflight_requests: BTreeMap>, - /// A queue of ready batches. - ready_batches: BinaryHeap, - /// The size, in bytes, of all the batches in [`Self::ready_batches`]. - ready_batches_size: usize, - /// A queue of start heights from failed batches that should be retried. /// /// Wrapped in [`Reverse`] so we prioritize early batches. failed_batches: BinaryHeap>, - /// The [`BufferAppender`] that gives blocks to Cuprate. - buffer_appender: BufferAppender, + block_queue: BlockQueue, /// The [`BlockDownloaderConfig`]. config: BlockDownloaderConfig, @@ -315,10 +279,8 @@ where block_download_tasks: JoinSet::new(), chain_entry_task: JoinSet::new(), inflight_requests: BTreeMap::new(), - ready_batches: BinaryHeap::new(), - ready_batches_size: 0, + block_queue: BlockQueue::new(buffer_appender), failed_batches: BinaryHeap::new(), - buffer_appender, config, } } @@ -360,7 +322,7 @@ where ) -> Option> { tracing::debug!( "Requesting an inflight batch, current ready queue size: {}", - self.ready_batches_size + self.block_queue.size() ); assert!( @@ -368,7 +330,7 @@ where "We need requests inflight to be able to send the request again", ); - let oldest_ready_batch = self.ready_batches.peek().unwrap().start_height; + let oldest_ready_batch = self.block_queue.oldest_ready_batch().unwrap(); for (_, in_flight_batch) in self.inflight_requests.range_mut(0..oldest_ready_batch) { if in_flight_batch.requests_sent >= 2 { @@ -383,29 +345,12 @@ where return Some(client); } - in_flight_batch.requests_sent += 1; - - tracing::debug!( - "Sending request for batch, total requests sent for batch: {}", - in_flight_batch.requests_sent - ); - - let ids = in_flight_batch.ids.clone(); - let start_height = in_flight_batch.start_height; - - self.block_download_tasks.spawn( - async move { - BlockDownloadTaskResponse { - start_height, - result: request_batch_from_peer(client, ids, start_height).await, - } - } - .instrument(tracing::debug_span!( - "download_batch", - start_height, - attempt = in_flight_batch.requests_sent - )), - ); + self.block_download_tasks.spawn(download_batch_task( + client, + in_flight_batch.ids.clone(), + in_flight_batch.start_height, + in_flight_batch.requests_sent, + )); return None; } @@ -449,19 +394,12 @@ where request.requests_sent += 1; - self.block_download_tasks.spawn( - async move { - BlockDownloadTaskResponse { - start_height, - result: request_batch_from_peer(client, ids, start_height).await, - } - } - .instrument(tracing::debug_span!( - "download_batch", - start_height, - attempt = request.requests_sent - )), - ); + self.block_download_tasks.spawn(download_batch_task( + client, + ids, + start_height, + request.requests_sent, + )); // Remove the failure, we have just handled it. self.failed_batches.pop(); @@ -473,7 +411,7 @@ where } // If our ready queue is too large send duplicate requests for the blocks we are waiting on. - if self.ready_batches_size >= self.config.in_progress_queue_size { + if self.block_queue.size() >= self.config.in_progress_queue_size { return self.request_inflight_batch_again(client).await; } @@ -491,24 +429,12 @@ where self.inflight_requests .insert(block_entry_to_get.start_height, block_entry_to_get.clone()); - self.block_download_tasks.spawn( - async move { - BlockDownloadTaskResponse { - start_height: block_entry_to_get.start_height, - result: request_batch_from_peer( - client, - block_entry_to_get.ids, - block_entry_to_get.start_height, - ) - .await, - } - } - .instrument(tracing::debug_span!( - "download_batch", - block_entry_to_get.start_height, - attempt = block_entry_to_get.requests_sent - )), - ); + self.block_download_tasks.spawn(download_batch_task( + client, + block_entry_to_get.ids.clone(), + block_entry_to_get.start_height, + block_entry_to_get.requests_sent, + )); None } @@ -607,47 +533,6 @@ where Ok(()) } - /// Checks if we have batches ready to send down the [`BufferAppender`]. - /// - /// We guarantee that blocks sent down the buffer are sent in the correct order. - async fn push_new_blocks(&mut self) -> Result<(), BlockDownloadError> { - while let Some(ready_batch) = self.ready_batches.peek() { - // Check if this ready batch's start height is higher than the lowest in flight request. - // If there is a lower start height in the inflight requests then this is _not_ the next batch - // to send down the buffer. - if self - .inflight_requests - .first_key_value() - .is_some_and(|(&lowest_start_height, _)| { - ready_batch.start_height > lowest_start_height - }) - { - break; - } - - // Our next ready batch is older (lower height) than the oldest in flight, push it down the - // buffer. - let ready_batch = self.ready_batches.pop().unwrap(); - - let size = ready_batch.block_batch.size; - self.ready_batches_size -= size; - - tracing::debug!( - "Pushing batch to buffer, new ready batches size: {}", - self.ready_batches_size - ); - - self.buffer_appender - .send(ready_batch.block_batch, size) - .await - .map_err(|_| BlockDownloadError::BufferWasClosed)?; - - // Loops back to check the next oldest ready batch. - } - - Ok(()) - } - /// Handles a response to a request to get blocks from a peer. async fn handle_download_batch_res( &mut self, @@ -723,15 +608,15 @@ where self.amount_of_blocks_to_request_updated_at = start_height; } - // Add the batch to the queue of ready batches. - self.ready_batches_size += block_batch.size; - self.ready_batches.push(ReadyQueueBatch { - start_height, - block_batch, - }); - - // Attempt to push new batches to the buffer. - self.push_new_blocks().await?; + self.block_queue + .add_incoming_batch( + ReadyQueueBatch { + start_height, + block_batch, + }, + self.inflight_requests.first_key_value().map(|(k, _)| *k), + ) + .await?; pending_peers .entry(client.info.pruning_seed) @@ -859,136 +744,6 @@ fn calculate_next_block_batch_size( min(next_batch_len, MAX_BLOCK_BATCH_LEN) } -/// Requests a sequential batch of blocks from a peer. -/// -/// This function will validate the blocks that were downloaded were the ones asked for and that they match -/// the expected height. -async fn request_batch_from_peer( - mut client: ClientPoolDropGuard, - ids: ByteArrayVec<32>, - expected_start_height: u64, -) -> Result<(ClientPoolDropGuard, BlockBatch), BlockDownloadError> { - // Request the blocks. - let blocks_response = timeout(BLOCK_DOWNLOADER_REQUEST_TIMEOUT, async { - let PeerResponse::GetObjects(blocks_response) = client - .ready() - .await? - .call(PeerRequest::GetObjects(GetObjectsRequest { - blocks: ids.clone(), - pruned: false, - })) - .await? - else { - panic!("Connection task returned wrong response."); - }; - - Ok::<_, BlockDownloadError>(blocks_response) - }) - .await - .map_err(|_| BlockDownloadError::TimedOut)??; - - // Initial sanity checks - if blocks_response.blocks.len() > ids.len() { - client.info.handle.ban_peer(MEDIUM_BAN); - return Err(BlockDownloadError::PeersResponseWasInvalid); - } - - if blocks_response.blocks.len() != ids.len() { - return Err(BlockDownloadError::PeerDidNotHaveRequestedData); - } - - let blocks = rayon_spawn_async(move || { - let blocks = blocks_response - .blocks - .into_par_iter() - .enumerate() - .map(|(i, block_entry)| { - let expected_height = u64::try_from(i).unwrap() + expected_start_height; - - let mut size = block_entry.block.len(); - - let block = Block::read(&mut block_entry.block.as_ref()) - .map_err(|_| BlockDownloadError::PeersResponseWasInvalid)?; - - // Check the block matches the one requested and the peer sent enough transactions. - if ids[i] != block.hash() || block.txs.len() != block_entry.txs.len() { - return Err(BlockDownloadError::PeersResponseWasInvalid); - } - - // Check the height lines up as expected. - // This must happen after the hash check. - if !block - .number() - .is_some_and(|height| height == expected_height) - { - tracing::warn!( - "Invalid chain, expected height: {expected_height}, got height: {:?}", - block.number() - ); - - // This peer probably did nothing wrong, it was the peer who told us this blockID which - // is misbehaving. - return Err(BlockDownloadError::ChainInvalid); - } - - // Deserialize the transactions. - let txs = block_entry - .txs - .take_normal() - .ok_or(BlockDownloadError::PeersResponseWasInvalid)? - .into_iter() - .map(|tx_blob| { - size += tx_blob.len(); - - if tx_blob.len() > MAX_TRANSACTION_BLOB_SIZE { - return Err(BlockDownloadError::PeersResponseWasInvalid); - } - - Transaction::read(&mut tx_blob.as_ref()) - .map_err(|_| BlockDownloadError::PeersResponseWasInvalid) - }) - .collect::, _>>()?; - - // Make sure the transactions in the block were the ones the peer sent. - let mut expected_txs = block.txs.iter().collect::>(); - - for tx in &txs { - if !expected_txs.remove(&tx.hash()) { - return Err(BlockDownloadError::PeersResponseWasInvalid); - } - } - - if !expected_txs.is_empty() { - return Err(BlockDownloadError::PeersResponseWasInvalid); - } - - Ok(((block, txs), size)) - }) - .collect::, Vec<_>), _>>(); - - blocks - }) - .await; - - let (blocks, sizes) = blocks.inspect_err(|e| { - // If the peers response was invalid, ban it. - if matches!(e, BlockDownloadError::PeersResponseWasInvalid) { - client.info.handle.ban_peer(MEDIUM_BAN); - } - })?; - - let peer_handle = client.info.handle.clone(); - - Ok(( - client, - BlockBatch { - blocks, - size: sizes.iter().sum(), - peer_handle, - }, - )) -} - /// Request a chain entry from a peer. /// /// Because the block downloader only follows and downloads one chain we only have to send the block hash of diff --git a/p2p/cuprate-p2p/src/block_downloader/block_queue.rs b/p2p/cuprate-p2p/src/block_downloader/block_queue.rs new file mode 100644 index 00000000..addee4b3 --- /dev/null +++ b/p2p/cuprate-p2p/src/block_downloader/block_queue.rs @@ -0,0 +1,172 @@ +use std::{cmp::Ordering, collections::BinaryHeap}; + +use async_buffer::BufferAppender; + +use super::{BlockBatch, BlockDownloadError}; + +/// A batch of blocks in the ready queue, waiting for previous blocks to come in, so they can +/// be passed into the buffer. +/// +/// The [`Eq`] and [`Ord`] impl on this type will only take into account the `start_height`, this +/// is because the block downloader will only download one chain at once so no 2 batches can have +/// the same `start_height`. +/// +/// Also, the [`Ord`] impl is reversed so older blocks (lower height) come first in a [`BinaryHeap`]. +#[derive(Debug, Clone)] +pub struct ReadyQueueBatch { + /// The start height of the batch. + pub start_height: u64, + /// The batch of blocks. + pub block_batch: BlockBatch, +} + +impl Eq for ReadyQueueBatch {} + +impl PartialEq for ReadyQueueBatch { + fn eq(&self, other: &Self) -> bool { + self.start_height.eq(&other.start_height) + } +} + +impl PartialOrd for ReadyQueueBatch { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ReadyQueueBatch { + fn cmp(&self, other: &Self) -> Ordering { + // reverse the ordering so older blocks (lower height) come first in a [`BinaryHeap`] + self.start_height.cmp(&other.start_height).reverse() + } +} + +/// The block queue that holds downloaded block batches, adding them to the [`async_buffer`] when the +/// oldest batch has been downloaded. +pub struct BlockQueue { + /// A queue of ready batches. + ready_batches: BinaryHeap, + /// The size, in bytes, of all the batches in [`Self::ready_batches`]. + ready_batches_size: usize, + + /// The [`BufferAppender`] that gives blocks to Cuprate. + buffer_appender: BufferAppender, +} + +impl BlockQueue { + /// Creates a new [`BlockQueue`]. + pub fn new(buffer_appender: BufferAppender) -> BlockQueue { + BlockQueue { + ready_batches: BinaryHeap::new(), + ready_batches_size: 0, + buffer_appender, + } + } + + /// Returns the oldest batch that has not been put in the [`async_buffer`] yet. + pub fn oldest_ready_batch(&self) -> Option { + self.ready_batches.peek().map(|batch| batch.start_height) + } + + /// Returns the size of all the batches that hav not been put into the [`async_buffer`] yet. + pub fn size(&self) -> usize { + self.ready_batches_size + } + + /// Adds an incoming batch to the queue and checks if we can push any batches into the [`async_buffer`]. + /// + /// `oldest_in_flight_start_height` Should be the start height of the oldest batch that is still inflight, if + /// there are no batches inflight then this should be [`None`]. + pub async fn add_incoming_batch( + &mut self, + new_batch: ReadyQueueBatch, + oldest_in_flight_start_height: Option, + ) -> Result<(), BlockDownloadError> { + self.ready_batches_size += new_batch.block_batch.size; + self.ready_batches.push(new_batch); + + // The height to stop pushing batches into the buffer. + let height_to_stop_at = oldest_in_flight_start_height.unwrap_or(u64::MAX); + + while self + .ready_batches + .peek() + .is_some_and(|batch| batch.start_height <= height_to_stop_at) + { + let batch = self + .ready_batches + .pop() + .expect("We just checked we have a batch in the buffer"); + + let batch_size = batch.block_batch.size; + + self.ready_batches_size -= batch_size; + self.buffer_appender + .send(batch.block_batch, batch_size) + .await + .map_err(|_| BlockDownloadError::BufferWasClosed)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use futures::StreamExt; + use std::{collections::BTreeSet, sync::Arc}; + + use proptest::{collection::vec, prelude::*}; + use tokio::sync::Semaphore; + use tokio_test::block_on; + + use monero_p2p::handles::HandleBuilder; + + use super::*; + + prop_compose! { + fn read_batch_stratergy()(start_height in 0_u64..500_000_000) -> ReadyQueueBatch { + // TODO: The permit will not be needed here when + let (_, peer_handle) = HandleBuilder::new().with_permit(Arc::new(Semaphore::new(1)).try_acquire_owned().unwrap()).build(); + + ReadyQueueBatch { + start_height, + block_batch: BlockBatch { + blocks: vec![], + size: start_height as usize, + peer_handle, + }, + } + } + } + + proptest! { + #[test] + fn block_queue_returns_items_in_order(batches in vec(read_batch_stratergy(), 0..10_000)) { + block_on(async move { + let (buffer_tx, mut buffer_rx) = async_buffer::new_buffer(usize::MAX); + + let mut queue = BlockQueue::new(buffer_tx); + + let mut sorted_batches = BTreeSet::from_iter(batches.clone()); + let mut soreted_batch_2 = sorted_batches.clone(); + + for batch in batches { + if sorted_batches.remove(&batch) { + queue.add_incoming_batch(batch, sorted_batches.last().map(|batch| batch.start_height)).await.unwrap(); + } + } + + assert_eq!(queue.size(), 0); + assert!(queue.oldest_ready_batch().is_none()); + drop(queue); + + while let Some(batch) = buffer_rx.next().await { + let last_batch = soreted_batch_2.pop_last().unwrap(); + + assert_eq!(batch.size, last_batch.block_batch.size); + } + }); + } + } +} diff --git a/p2p/cuprate-p2p/src/block_downloader/download_batch.rs b/p2p/cuprate-p2p/src/block_downloader/download_batch.rs new file mode 100644 index 00000000..1ac01b08 --- /dev/null +++ b/p2p/cuprate-p2p/src/block_downloader/download_batch.rs @@ -0,0 +1,175 @@ +use cuprate_helper::asynch::rayon_spawn_async; +use fixed_bytes::ByteArrayVec; +use monero_p2p::handles::ConnectionHandle; +use monero_p2p::{NetworkZone, PeerRequest, PeerResponse}; +use monero_serai::block::Block; +use monero_serai::transaction::Transaction; +use monero_wire::protocol::{GetObjectsRequest, GetObjectsResponse}; +use rayon::prelude::*; +use std::collections::HashSet; +use tokio::time::timeout; +use tower::{Service, ServiceExt}; +use tracing::instrument; + +use crate::block_downloader::BlockDownloadTaskResponse; +use crate::{ + block_downloader::{BlockBatch, BlockDownloadError}, + client_pool::ClientPoolDropGuard, + constants::{BLOCK_DOWNLOADER_REQUEST_TIMEOUT, MAX_TRANSACTION_BLOB_SIZE, MEDIUM_BAN}, +}; + +#[instrument( + level = "debug", + name = "download_batch", + skip_all, + fields( + start_height = expected_start_height, + attempt + ) +)] +pub async fn download_batch_task( + client: ClientPoolDropGuard, + ids: ByteArrayVec<32>, + expected_start_height: u64, + attempt: usize, +) -> BlockDownloadTaskResponse { + BlockDownloadTaskResponse { + start_height: expected_start_height, + result: request_batch_from_peer(client, ids, expected_start_height).await, + } +} + +/// Requests a sequential batch of blocks from a peer. +/// +/// This function will validate the blocks that were downloaded were the ones asked for and that they match +/// the expected height. +async fn request_batch_from_peer( + mut client: ClientPoolDropGuard, + ids: ByteArrayVec<32>, + expected_start_height: u64, +) -> Result<(ClientPoolDropGuard, BlockBatch), BlockDownloadError> { + // Request the blocks. + let blocks_response = timeout(BLOCK_DOWNLOADER_REQUEST_TIMEOUT, async { + let PeerResponse::GetObjects(blocks_response) = client + .ready() + .await? + .call(PeerRequest::GetObjects(GetObjectsRequest { + blocks: ids.clone(), + pruned: false, + })) + .await? + else { + panic!("Connection task returned wrong response."); + }; + + Ok::<_, BlockDownloadError>(blocks_response) + }) + .await + .map_err(|_| BlockDownloadError::TimedOut)??; + + // Initial sanity checks + if blocks_response.blocks.len() > ids.len() { + client.info.handle.ban_peer(MEDIUM_BAN); + return Err(BlockDownloadError::PeersResponseWasInvalid); + } + + if blocks_response.blocks.len() != ids.len() { + return Err(BlockDownloadError::PeerDidNotHaveRequestedData); + } + let peer_handle = client.info.handle.clone(); + + let blocks = rayon_spawn_async(move || { + deserialize_batch(blocks_response, expected_start_height, ids, peer_handle) + }) + .await; + + let batch = blocks.inspect_err(|e| { + // If the peers response was invalid, ban it. + if matches!(e, BlockDownloadError::PeersResponseWasInvalid) { + client.info.handle.ban_peer(MEDIUM_BAN); + } + })?; + + Ok((client, batch)) +} + +fn deserialize_batch( + blocks_response: GetObjectsResponse, + expected_start_height: u64, + requested_ids: ByteArrayVec<32>, + peer_handle: ConnectionHandle, +) -> Result { + let blocks = blocks_response + .blocks + .into_par_iter() + .enumerate() + .map(|(i, block_entry)| { + let expected_height = u64::try_from(i).unwrap() + expected_start_height; + + let mut size = block_entry.block.len(); + + let block = Block::read(&mut block_entry.block.as_ref()) + .map_err(|_| BlockDownloadError::PeersResponseWasInvalid)?; + + // Check the block matches the one requested and the peer sent enough transactions. + if requested_ids[i] != block.hash() || block.txs.len() != block_entry.txs.len() { + return Err(BlockDownloadError::PeersResponseWasInvalid); + } + + // Check the height lines up as expected. + // This must happen after the hash check. + if !block + .number() + .is_some_and(|height| height == expected_height) + { + tracing::warn!( + "Invalid chain, expected height: {expected_height}, got height: {:?}", + block.number() + ); + + // This peer probably did nothing wrong, it was the peer who told us this blockID which + // is misbehaving. + return Err(BlockDownloadError::ChainInvalid); + } + + // Deserialize the transactions. + let txs = block_entry + .txs + .take_normal() + .ok_or(BlockDownloadError::PeersResponseWasInvalid)? + .into_iter() + .map(|tx_blob| { + size += tx_blob.len(); + + if tx_blob.len() > MAX_TRANSACTION_BLOB_SIZE { + return Err(BlockDownloadError::PeersResponseWasInvalid); + } + + Transaction::read(&mut tx_blob.as_ref()) + .map_err(|_| BlockDownloadError::PeersResponseWasInvalid) + }) + .collect::, _>>()?; + + // Make sure the transactions in the block were the ones the peer sent. + let mut expected_txs = block.txs.iter().collect::>(); + + for tx in &txs { + if !expected_txs.remove(&tx.hash()) { + return Err(BlockDownloadError::PeersResponseWasInvalid); + } + } + + if !expected_txs.is_empty() { + return Err(BlockDownloadError::PeersResponseWasInvalid); + } + + Ok(((block, txs), size)) + }) + .collect::, Vec<_>), _>>()?; + + Ok(BlockBatch { + blocks: blocks.0, + size: blocks.1.into_iter().sum(), + peer_handle, + }) +}