check previous ID is correct

This commit is contained in:
Boog900 2024-06-21 02:32:16 +01:00
parent b8e50eb96d
commit f4f1dc29e5
No known key found for this signature in database
GPG key ID: 42AB1287CB0041C2
5 changed files with 62 additions and 28 deletions

View file

@ -39,19 +39,16 @@ use crate::{
mod block_queue; mod block_queue;
mod chain_tracker; mod chain_tracker;
use crate::block_downloader::request_chain::{initial_chain_search, request_chain_entry_from_peer};
use block_queue::{BlockQueue, ReadyQueueBatch};
use chain_tracker::{BlocksToRetrieve, ChainEntry, ChainTracker};
use download_batch::download_batch_task;
// TODO: check first block in batch prev_id
mod download_batch; mod download_batch;
mod request_chain; mod request_chain;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
use block_queue::{BlockQueue, ReadyQueueBatch};
use chain_tracker::{BlocksToRetrieve, ChainEntry, ChainTracker};
use download_batch::download_batch_task;
use request_chain::{initial_chain_search, request_chain_entry_from_peer};
/// A downloaded batch of blocks. /// A downloaded batch of blocks.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BlockBatch { pub struct BlockBatch {
@ -218,10 +215,6 @@ struct BlockDownloader<N: NetworkZone, S, C> {
amount_of_empty_chain_entries: usize, amount_of_empty_chain_entries: usize,
/// The running block download tasks. /// The running block download tasks.
///
/// Returns:
/// - The start height of the batch
/// - A result containing the batch or an error.
block_download_tasks: JoinSet<BlockDownloadTaskResponse<N>>, block_download_tasks: JoinSet<BlockDownloadTaskResponse<N>>,
/// The running chain entry tasks. /// The running chain entry tasks.
/// ///
@ -342,6 +335,7 @@ where
self.block_download_tasks.spawn(download_batch_task( self.block_download_tasks.spawn(download_batch_task(
client, client,
in_flight_batch.ids.clone(), in_flight_batch.ids.clone(),
in_flight_batch.prev_id,
in_flight_batch.start_height, in_flight_batch.start_height,
in_flight_batch.requests_sent, in_flight_batch.requests_sent,
)); ));
@ -383,15 +377,14 @@ where
) { ) {
tracing::debug!("Using peer to request a failed batch"); tracing::debug!("Using peer to request a failed batch");
// They should have the blocks so send the re-request to this peer. // They should have the blocks so send the re-request to this peer.
let ids = request.ids.clone();
let start_height = request.start_height;
request.requests_sent += 1; request.requests_sent += 1;
self.block_download_tasks.spawn(download_batch_task( self.block_download_tasks.spawn(download_batch_task(
client, client,
ids, request.ids.clone(),
start_height, request.prev_id,
request.start_height,
request.requests_sent, request.requests_sent,
)); ));
@ -426,6 +419,7 @@ where
self.block_download_tasks.spawn(download_batch_task( self.block_download_tasks.spawn(download_batch_task(
client, client,
block_entry_to_get.ids.clone(), block_entry_to_get.ids.clone(),
block_entry_to_get.prev_id,
block_entry_to_get.start_height, block_entry_to_get.start_height,
block_entry_to_get.requests_sent, block_entry_to_get.requests_sent,
)); ));

View file

@ -23,6 +23,8 @@ pub(crate) struct ChainEntry<N: NetworkZone> {
pub struct BlocksToRetrieve<N: NetworkZone> { pub struct BlocksToRetrieve<N: NetworkZone> {
/// The block IDs to get. /// The block IDs to get.
pub ids: ByteArrayVec<32>, pub ids: ByteArrayVec<32>,
/// The hash of the last block before this batch.
pub prev_id: [u8; 32],
/// The expected height of the first block in [`BlocksToRetrieve::ids`]. /// The expected height of the first block in [`BlocksToRetrieve::ids`].
pub start_height: u64, pub start_height: u64,
/// The peer who told us about this batch. /// The peer who told us about this batch.
@ -51,17 +53,24 @@ pub enum ChainTrackerError {
pub struct ChainTracker<N: NetworkZone> { pub struct ChainTracker<N: NetworkZone> {
/// A list of [`ChainEntry`]s, in order. /// A list of [`ChainEntry`]s, in order.
entries: VecDeque<ChainEntry<N>>, entries: VecDeque<ChainEntry<N>>,
/// The height of the first block, in the first entry in entries. /// The height of the first block, in the first entry in [`Self::entries`].
first_height: u64, first_height: u64,
/// The hash of the last block in the last entry. /// The hash of the last block in the last entry.
top_seen_hash: [u8; 32], top_seen_hash: [u8; 32],
/// The hash of the block one below [`Self::first_height`].
previous_hash: [u8; 32],
/// The hash of the genesis block. /// The hash of the genesis block.
our_genesis: [u8; 32], our_genesis: [u8; 32],
} }
impl<N: NetworkZone> ChainTracker<N> { impl<N: NetworkZone> ChainTracker<N> {
/// Creates a new chain tracker. /// Creates a new chain tracker.
pub fn new(new_entry: ChainEntry<N>, first_height: u64, our_genesis: [u8; 32]) -> Self { pub fn new(
new_entry: ChainEntry<N>,
first_height: u64,
our_genesis: [u8; 32],
previous_hash: [u8; 32],
) -> Self {
let top_seen_hash = *new_entry.ids.last().unwrap(); let top_seen_hash = *new_entry.ids.last().unwrap();
let mut entries = VecDeque::with_capacity(1); let mut entries = VecDeque::with_capacity(1);
entries.push_back(new_entry); entries.push_back(new_entry);
@ -70,6 +79,7 @@ impl<N: NetworkZone> ChainTracker<N> {
top_seen_hash, top_seen_hash,
entries, entries,
first_height, first_height,
previous_hash,
our_genesis, our_genesis,
} }
} }
@ -180,6 +190,7 @@ impl<N: NetworkZone> ChainTracker<N> {
let blocks = BlocksToRetrieve { let blocks = BlocksToRetrieve {
ids: ids_to_get.into(), ids: ids_to_get.into(),
prev_id: self.previous_hash,
start_height: self.first_height, start_height: self.first_height,
peer_who_told_us: entry.peer, peer_who_told_us: entry.peer,
peer_who_told_us_handle: entry.handle.clone(), peer_who_told_us_handle: entry.handle.clone(),
@ -188,6 +199,8 @@ impl<N: NetworkZone> ChainTracker<N> {
}; };
self.first_height += u64::try_from(end_idx).unwrap(); self.first_height += u64::try_from(end_idx).unwrap();
// TODO: improve ByteArrayVec API.
self.previous_hash = blocks.ids[blocks.ids.len() - 1];
if entry.ids.is_empty() { if entry.ids.is_empty() {
self.entries.pop_front(); self.entries.pop_front();

View file

@ -6,10 +6,10 @@ use tokio::time::timeout;
use tower::{Service, ServiceExt}; use tower::{Service, ServiceExt};
use tracing::instrument; use tracing::instrument;
use monero_p2p::{NetworkZone, PeerRequest, PeerResponse, handles::ConnectionHandle};
use monero_wire::protocol::{GetObjectsRequest, GetObjectsResponse};
use cuprate_helper::asynch::rayon_spawn_async; use cuprate_helper::asynch::rayon_spawn_async;
use fixed_bytes::ByteArrayVec; use fixed_bytes::ByteArrayVec;
use monero_p2p::{handles::ConnectionHandle, NetworkZone, PeerRequest, PeerResponse};
use monero_wire::protocol::{GetObjectsRequest, GetObjectsResponse};
use crate::{ use crate::{
block_downloader::{BlockBatch, BlockDownloadError, BlockDownloadTaskResponse}, block_downloader::{BlockBatch, BlockDownloadError, BlockDownloadTaskResponse},
@ -17,26 +17,26 @@ use crate::{
constants::{BLOCK_DOWNLOADER_REQUEST_TIMEOUT, MAX_TRANSACTION_BLOB_SIZE, MEDIUM_BAN}, constants::{BLOCK_DOWNLOADER_REQUEST_TIMEOUT, MAX_TRANSACTION_BLOB_SIZE, MEDIUM_BAN},
}; };
/// Attempts to request a batch of blocks from a peer, returning [`BlockDownloadTaskResponse`]. /// Attempts to request a batch of blocks from a peer, returning [`BlockDownloadTaskResponse`].
#[instrument( #[instrument(
level = "debug", level = "debug",
name = "download_batch", name = "download_batch",
skip_all, skip_all,
fields( fields(
start_height = expected_start_height, start_height = expected_start_height,
attempt = _attempt attempt = _attempt
) )
)] )]
pub async fn download_batch_task<N: NetworkZone>( pub async fn download_batch_task<N: NetworkZone>(
client: ClientPoolDropGuard<N>, client: ClientPoolDropGuard<N>,
ids: ByteArrayVec<32>, ids: ByteArrayVec<32>,
previous_id: [u8; 32],
expected_start_height: u64, expected_start_height: u64,
_attempt: usize, _attempt: usize,
) -> BlockDownloadTaskResponse<N> { ) -> BlockDownloadTaskResponse<N> {
BlockDownloadTaskResponse { BlockDownloadTaskResponse {
start_height: expected_start_height, start_height: expected_start_height,
result: request_batch_from_peer(client, ids, expected_start_height).await, result: request_batch_from_peer(client, ids, previous_id, expected_start_height).await,
} }
} }
@ -47,6 +47,7 @@ pub async fn download_batch_task<N: NetworkZone>(
async fn request_batch_from_peer<N: NetworkZone>( async fn request_batch_from_peer<N: NetworkZone>(
mut client: ClientPoolDropGuard<N>, mut client: ClientPoolDropGuard<N>,
ids: ByteArrayVec<32>, ids: ByteArrayVec<32>,
previous_id: [u8; 32],
expected_start_height: u64, expected_start_height: u64,
) -> Result<(ClientPoolDropGuard<N>, BlockBatch), BlockDownloadError> { ) -> Result<(ClientPoolDropGuard<N>, BlockBatch), BlockDownloadError> {
// Request the blocks. // Request the blocks.
@ -80,7 +81,13 @@ async fn request_batch_from_peer<N: NetworkZone>(
let peer_handle = client.info.handle.clone(); let peer_handle = client.info.handle.clone();
let blocks = rayon_spawn_async(move || { let blocks = rayon_spawn_async(move || {
deserialize_batch(blocks_response, expected_start_height, ids, peer_handle) deserialize_batch(
blocks_response,
expected_start_height,
ids,
previous_id,
peer_handle,
)
}) })
.await; .await;
@ -98,6 +105,7 @@ fn deserialize_batch(
blocks_response: GetObjectsResponse, blocks_response: GetObjectsResponse,
expected_start_height: u64, expected_start_height: u64,
requested_ids: ByteArrayVec<32>, requested_ids: ByteArrayVec<32>,
previous_id: [u8; 32],
peer_handle: ConnectionHandle, peer_handle: ConnectionHandle,
) -> Result<BlockBatch, BlockDownloadError> { ) -> Result<BlockBatch, BlockDownloadError> {
let blocks = blocks_response let blocks = blocks_response
@ -112,11 +120,26 @@ fn deserialize_batch(
let block = Block::read(&mut block_entry.block.as_ref()) let block = Block::read(&mut block_entry.block.as_ref())
.map_err(|_| BlockDownloadError::PeersResponseWasInvalid)?; .map_err(|_| BlockDownloadError::PeersResponseWasInvalid)?;
let block_hash = block.hash();
// Check the block matches the one requested and the peer sent enough transactions. // Check the block matches the one requested and the peer sent enough transactions.
if requested_ids[i] != block.hash() || block.txs.len() != block_entry.txs.len() { if requested_ids[i] != block_hash || block.txs.len() != block_entry.txs.len() {
return Err(BlockDownloadError::PeersResponseWasInvalid); return Err(BlockDownloadError::PeersResponseWasInvalid);
} }
// Check that the previous ID is correct for the first block.
// This is to protect use against banning the wrong peer.
// This must happen after the hash check.
if i == 0 && block.header.previous != previous_id {
tracing::warn!(
"Invalid chain, peer told us a block follows the chain when it doesn't."
);
// This peer probably did nothing wrong, it was the peer who told us this blockID which
// is misbehaving.
return Err(BlockDownloadError::ChainInvalid);
}
// Check the height lines up as expected. // Check the height lines up as expected.
// This must happen after the hash check. // This must happen after the hash check.
if !block if !block

View file

@ -219,6 +219,8 @@ where
return Err(BlockDownloadError::FailedToFindAChainToFollow); return Err(BlockDownloadError::FailedToFindAChainToFollow);
} }
let previous_id = hashes[first_unknown - 1];
let first_entry = ChainEntry { let first_entry = ChainEntry {
ids: hashes[first_unknown..].to_vec(), ids: hashes[first_unknown..].to_vec(),
peer: peer_id, peer: peer_id,
@ -230,7 +232,7 @@ where
first_entry.ids.len() first_entry.ids.len()
); );
let tracker = ChainTracker::new(first_entry, expected_height, our_genesis); let tracker = ChainTracker::new(first_entry, expected_height, our_genesis, previous_id);
Ok(tracker) Ok(tracker)
} }

View file

@ -160,6 +160,8 @@ impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
}; };
if let Err(e) = self.connection_tx.try_send(req) { if let Err(e) = self.connection_tx.try_send(req) {
// The connection task could have closed between a call to `poll_ready` and the call to
// `call`, which means if we don't handle the error here the receiver would panic.
use mpsc::error::TrySendError; use mpsc::error::TrySendError;
match e { match e {