add chain_tracker

Boog900 2024-05-25 00:53:13 +01:00
parent 596fed775a
commit 3270560711
GPG key ID: 42AB1287CB0041C2
5 changed files with 319 additions and 33 deletions

View file

@ -48,7 +48,7 @@ pub fn new_buffer<T>(max_item_weight: usize) -> (BufferAppender<T>, BufferStream
queue: tx,
sink_waker: sink_waker.clone(),
capacity: capacity_atomic.clone(),
max_item_weight: capacity,
max_item_weight,
},
BufferStream {
queue: rx,

View file

@ -1,29 +1,38 @@
//! # Block Downloader
//!
use std::collections::VecDeque;
mod chain_tracker;
use std::collections::{BTreeMap, BinaryHeap, VecDeque};
use std::sync::Arc;
use std::time::Duration;
use monero_serai::{block::Block, transaction::Transaction};
use rand::prelude::*;
use tokio::task::JoinSet;
use tokio::time::{interval, MissedTickBehavior};
use tower::{Service, ServiceExt};
use crate::block_downloader::chain_tracker::{ChainEntry, ChainTracker};
use async_buffer::{BufferAppender, BufferStream};
use fixed_bytes::ByteArrayVec;
use monero_p2p::client::InternalPeerID;
use monero_p2p::{
handles::ConnectionHandle,
services::{PeerSyncRequest, PeerSyncResponse},
NetworkZone, PeerRequest, PeerResponse, PeerSyncSvc,
};
use monero_wire::protocol::ChainRequest;
use monero_wire::protocol::{ChainRequest, ChainResponse};
use crate::client_pool::ClientPool;
use crate::constants::INITIAL_CHAIN_REQUESTS_TO_SEND;
use crate::client_pool::{ClientPool, ClientPoolDropGuard};
use crate::constants::{INITIAL_CHAIN_REQUESTS_TO_SEND, MEDIUM_BAN};
/// A downloaded batch of blocks.
pub struct BlockBatch {
/// The blocks.
blocks: (Block, Vec<Transaction>),
blocks: Vec<(Block, Vec<Transaction>)>,
/// The size of this batch in bytes.
size: usize,
/// The peer that gave us this batch.
peer_handle: ConnectionHandle,
}
@ -36,12 +45,20 @@ pub struct BlockDownloaderConfig {
buffer_size: usize,
/// The size of the in-progress queue at which we stop requesting more blocks.
in_progress_queue_size: usize,
/// The [`Duration`] between checking the client pool for free peers.
check_client_pool_interval: Duration,
/// The target size of a single batch of blocks (in bytes).
target_batch_size: usize,
/// The initial number of blocks to request in one batch.
initial_batch_size: usize,
}
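
The relationship between these fields is easiest to see with concrete numbers. A purely illustrative construction within this module, using made-up values rather than anything defined by this commit:

```rust
// Illustrative values only (not defaults from this commit): batches start at
// `initial_batch_size` blocks and are re-sized towards `target_batch_size` bytes.
let config = BlockDownloaderConfig {
    buffer_size: 50_000_000,            // buffer up to ~50 MB of downloaded batches
    in_progress_queue_size: 50_000_000, // stop requesting once ~50 MB is in flight
    check_client_pool_interval: Duration::from_secs(30),
    target_batch_size: 5_000_000,       // aim for ~5 MB per requested batch
    initial_batch_size: 100,            // the first requests ask for 100 blocks each
};
```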
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Ord, Eq, thiserror::Error)]
#[derive(Debug, thiserror::Error)]
pub enum BlockDownloadError {
#[error("Failed to find a more advanced chain to follow")]
FailedToFindAChainToFollow,
#[error("The peer did not send any overlapping blocks, unknown start height.")]
PeerSentNoOverlappingBlocks,
#[error("Service error: {0}")]
ServiceError(#[from] tower::BoxError),
}
@ -52,6 +69,8 @@ pub enum ChainSvcRequest {
CompactHistory,
/// A request to find the first unknown block.
FindFirstUnknown(Vec<[u8; 32]>),
/// A request for our current cumulative difficulty.
CumulativeDifficulty,
}
/// The response type for the chain service.
@ -64,6 +83,8 @@ pub enum ChainSvcResponse {
/// The response for [`ChainSvcRequest::FindFirstUnknown`], contains the index of the first unknown
/// block.
FindFirstUnknown(usize),
/// The response for [`ChainSvcRequest::CumulativeDifficulty`], contains our current cumulative difficulty.
CumulativeDifficulty(u128),
}
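
A minimal sketch of how a backing chain service might answer the two request variants touched here. The `known_ids` set and `cumulative_difficulty` argument are stand-ins for the node's own state, not part of the commit:

```rust
use std::collections::HashSet;

// Sketch only: answers `FindFirstUnknown` and the new `CumulativeDifficulty`
// request; building the compact history is out of scope here.
fn handle_chain_request(
    req: ChainSvcRequest,
    known_ids: &HashSet<[u8; 32]>,
    cumulative_difficulty: u128,
) -> ChainSvcResponse {
    match req {
        ChainSvcRequest::FindFirstUnknown(ids) => {
            // Index of the first ID we have never seen, or `ids.len()` if we know them all.
            let first_unknown = ids
                .iter()
                .position(|id| !known_ids.contains(id))
                .unwrap_or(ids.len());
            ChainSvcResponse::FindFirstUnknown(first_unknown)
        }
        ChainSvcRequest::CumulativeDifficulty => {
            ChainSvcResponse::CumulativeDifficulty(cumulative_difficulty)
        }
        ChainSvcRequest::CompactHistory => unimplemented!(),
    }
}
```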
/// # Block Downloader
@ -83,6 +104,7 @@ pub fn download_blocks<N: NetworkZone, S>(
) -> BufferStream<BlockBatch> {
let (buffer_appender, buffer_stream) = async_buffer::new_buffer(config.buffer_size);
/*
tokio::spawn(block_downloader(
client_pool,
peer_sync_svc,
@ -90,33 +112,89 @@ pub fn download_blocks<N: NetworkZone, S>(
buffer_appender,
));
*/
buffer_stream
}
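
A rough caller-side sketch, assuming `BufferStream` implements `futures::Stream` (as the stream half of `async_buffer`); the surrounding setup of the client pool, sync service and chain service is omitted:

```rust
use futures::StreamExt;

// Sketch only: drain the stream returned by `download_blocks` and inspect each batch.
async fn consume_batches(mut batches: BufferStream<BlockBatch>) {
    while let Some(batch) = batches.next().await {
        // Each `BlockBatch` carries whole blocks with their transactions, plus the
        // handle of the peer that served it (so it can be banned if the batch is bad).
        tracing::debug!("batch: {} blocks, {} bytes", batch.blocks.len(), batch.size);
    }
}
```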
async fn block_downloader<N: NetworkZone, S>(
struct BlockDownloader<N: NetworkZone, S, C> {
client_pool: Arc<ClientPool<N>>,
peer_sync_svc: S,
config: BlockDownloaderConfig,
our_chain_svc: C,
block_download_tasks: JoinSet<()>,
chain_entry_task: JoinSet<()>,
buffer_appender: BufferAppender<BlockBatch>,
) -> Result<(), tower::BoxError> {
todo!()
config: BlockDownloaderConfig,
}
struct BestChainFound {
common_ancestor: [u8; 32],
next_hashes: VecDeque<[u8; 32]>,
from_peer: ConnectionHandle,
async fn block_downloader<N: NetworkZone, S, C>(
client_pool: Arc<ClientPool<N>>,
mut peer_sync_svc: S,
mut our_chain_svc: C,
config: BlockDownloaderConfig,
buffer_appender: BufferAppender<BlockBatch>,
) -> Result<(), BlockDownloadError>
where
S: PeerSyncSvc<N> + Clone,
C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>
+ Send
+ 'static,
C::Future: Send + 'static,
{
let mut best_chain_found =
initial_chain_search(&client_pool, peer_sync_svc.clone(), &mut our_chain_svc).await?;
let tasks = JoinSet::new();
let mut ready_queue = BinaryHeap::new();
let mut inflight_queue = BTreeMap::new();
let mut next_request_id = 0;
// The request ID at which we last updated `amount_of_blocks_to_request`.
// `amount_of_blocks_to_request` is updated for every new batch of blocks that comes in.
let mut amount_of_blocks_to_request_updated_at = next_request_id;
// The number of blocks to request in one batch; this is dynamically updated based on block size.
let mut amount_of_blocks_to_request = config.initial_batch_size;
let mut check_client_pool_interval = interval(config.check_client_pool_interval);
check_client_pool_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
tokio::select! {
_ = check_client_pool_interval.tick() => {
todo!()
}
}
}
}
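
The `ready_queue`/`inflight_queue` bookkeeping is only declared here and the select loop is still a `todo!()`. One way the `BinaryHeap` could hand back completed batches in ascending height order is a wrapper with reversed ordering on the start height; this is a sketch of that idea, not the commit's implementation:

```rust
use std::cmp::Ordering;

// Sketch only: `BinaryHeap` is a max-heap, so reverse the comparison to pop the
// batch with the *lowest* start height first.
struct ReadyBatch {
    start_height: u64,
    batch: BlockBatch,
}

impl PartialEq for ReadyBatch {
    fn eq(&self, other: &Self) -> bool {
        self.start_height == other.start_height
    }
}

impl Eq for ReadyBatch {}

impl PartialOrd for ReadyBatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ReadyBatch {
    fn cmp(&self, other: &Self) -> Ordering {
        other.start_height.cmp(&self.start_height)
    }
}
```

With this ordering, `BinaryHeap<ReadyBatch>::peek()` always exposes the lowest-height completed batch, i.e. the next one to push into the buffer.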
async fn handle_free_peer<N: NetworkZone>(
peer: ClientPoolDropGuard<N>,
chain_tracker: &mut ChainTracker<N>,
next_batch_size: usize,
) {
if chain_tracker.block_requests_queued(next_batch_size) < 15
&& chain_tracker.should_ask_for_next_chain_entry(&peer.info.pruning_seed)
{}
}
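
The branch body above is left empty in this commit. A hedged sketch of what it might eventually do, assuming `ChainRequest` mirrors the cryptonote `NOTIFY_REQUEST_CHAIN` fields (`block_ids`, `prune`), which this diff does not show:

```rust
// Sketch only: once queued batches drop below the threshold and the peer's
// pruning seed can serve the next block, ask it for the next chain entry using
// the tracker's short history ([top_seen_hash, our_genesis]).
let history = chain_tracker.get_simple_history();
let next_entry_req = ChainRequest {
    // The `Vec<[u8; 32]> -> ByteArrayVec<32>` conversion is used elsewhere in this commit.
    block_ids: history.to_vec().into(),
    prune: true,
};
// `next_entry_req` would then be sent over `peer`'s connection and the resulting
// `ChainResponse` fed into `ChainTracker::add_entry`.
```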
async fn initial_chain_search<N: NetworkZone, S, C>(
client_pool: &ClientPool<N>,
client_pool: &Arc<ClientPool<N>>,
mut peer_sync_svc: S,
mut our_chain_svc: C,
) -> Result<BestChainFound, BlockDownloadError>
) -> Result<ChainTracker<N>, BlockDownloadError>
where
S: PeerSyncSvc<N>,
C: Service<ChainSvcRequest, Response = ChainSvcResponse> + Send + 'static,
C::Future: Send + 'static,
C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>,
{
let ChainSvcResponse::CompactHistory {
block_ids,
@ -130,6 +208,8 @@ where
panic!("chain service sent wrong response.");
};
let our_genesis = *block_ids.last().expect("Blockchain had no genesis block.");
let PeerSyncResponse::PeersToSyncFrom(mut peers) = peer_sync_svc
.ready()
.await?
@ -165,14 +245,17 @@ where
panic!("connection task returned wrong response!");
};
Ok((chain_res, next_peer.info.handle.clone()))
Ok((chain_res, next_peer.info.id, next_peer.info.handle.clone()))
});
}
let mut res = None;
let mut res: Option<(ChainResponse, InternalPeerID<_>, ConnectionHandle)> = None;
while let Some(task_res) = futs.join_next().await {
let Ok(task_res) = task_res.unwrap() else {
let Ok(task_res): Result<
(ChainResponse, InternalPeerID<_>, ConnectionHandle),
tower::BoxError,
> = task_res.unwrap() else {
continue;
};
@ -188,11 +271,14 @@ where
}
}
let Some((chain_res, peer_handle)) = res else {
let Some((chain_res, peer_id, peer_handle)) = res else {
return Err(BlockDownloadError::FailedToFindAChainToFollow);
};
let hashes: Vec<[u8; 32]> = chain_res.m_block_ids.into();
let hashes: Vec<[u8; 32]> = (&chain_res.m_block_ids).into();
let start_height = chain_res.start_height;
// drop this to deallocate the [`Bytes`].
drop(chain_res);
let ChainSvcResponse::FindFirstUnknown(first_unknown) = our_chain_svc
.ready()
@ -202,7 +288,19 @@ where
else {
panic!("chain service sent wrong response.");
};
todo!()
if first_unknown == 0 {
peer_handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeerSentNoOverlappingBlocks);
}
let first_entry = ChainEntry {
ids: hashes[first_unknown..].to_vec(),
peer: peer_id,
handle: peer_handle,
};
let tracker = ChainTracker::new(first_entry, start_height, our_genesis);
Ok(tracker)
}
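
To make the overlap rule concrete, a small illustrative example with invented values:

```rust
// Illustrative only: our chain service recognises the first two IDs of a 4-ID
// response, so `first_unknown == 2` and only the tail becomes the first entry.
let hashes: Vec<[u8; 32]> = vec![[0; 32], [1; 32], [2; 32], [3; 32]];
let first_unknown = 2;
assert_ne!(first_unknown, 0); // 0 would mean no overlapping blocks at all -> MEDIUM_BAN
let unknown_ids = hashes[first_unknown..].to_vec(); // [[2; 32], [3; 32]]
```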

View file

@ -0,0 +1,182 @@
use fixed_bytes::ByteArrayVec;
use std::{cmp::min, collections::VecDeque};
use monero_p2p::{client::InternalPeerID, handles::ConnectionHandle, NetworkZone};
use monero_pruning::{PruningSeed, CRYPTONOTE_MAX_BLOCK_HEIGHT};
use monero_wire::protocol::ChainResponse;
use crate::constants::MEDIUM_BAN;
/// A new chain entry to add to our chain tracker.
#[derive(Debug)]
pub(crate) struct ChainEntry<N: NetworkZone> {
/// A list of block IDs.
pub ids: Vec<[u8; 32]>,
/// The peer who told us about this chain entry.
pub peer: InternalPeerID<N::Addr>,
/// The handle of the peer who told us about this chain entry.
pub handle: ConnectionHandle,
}
/// A batch of blocks to retrieve.
pub struct BlocksToRetrieve<N: NetworkZone> {
/// The block IDs to get.
pub ids: ByteArrayVec<32>,
/// The expected height of the first block in `ids`.
pub start_height: u64,
/// The peer who told us about this batch.
pub peer_who_told_us: InternalPeerID<N::Addr>,
/// The handle of the peer who told us about this batch.
pub peer_who_told_us_handle: ConnectionHandle,
}
/// An error returned from the [`ChainTracker`].
pub enum ChainTrackerError {
/// The new chain entry is invalid.
NewEntryIsInvalid,
/// The new chain entry does not follow the currently tracked chain.
NewEntryDoesNotFollowChain,
}
/// # Chain Tracker
///
/// This struct allows following a single chain. It takes in [`ChainEntry`]s and
/// allows getting [`BlocksToRetrieve`].
pub struct ChainTracker<N: NetworkZone> {
/// A list of [`ChainEntry`]s, in order.
entries: VecDeque<ChainEntry<N>>,
/// The height of the first block in the first entry of `entries`.
first_height: u64,
/// The hash of the last block in the last entry.
top_seen_hash: [u8; 32],
/// The hash of the genesis block.
our_genesis: [u8; 32],
}
impl<N: NetworkZone> ChainTracker<N> {
pub fn new(new_entry: ChainEntry<N>, first_height: u64, our_genesis: [u8; 32]) -> Self {
let top_seen_hash = *new_entry.ids.last().unwrap();
let mut entries = VecDeque::with_capacity(1);
entries.push_back(new_entry);
Self {
top_seen_hash,
entries,
first_height,
our_genesis,
}
}
/// Returns `true` if the peer is expected to have the next block after our highest seen block
/// according to their pruning seed.
pub fn should_ask_for_next_chain_entry(&self, seed: &PruningSeed) -> bool {
let top_block_idx = self
.entries
.iter()
.map(|entry| entry.ids.len())
.sum::<usize>();
seed.has_full_block(
self.first_height + u64::try_from(top_block_idx).unwrap(),
CRYPTONOTE_MAX_BLOCK_HEIGHT,
)
}
/// Returns the simple history: the highest seen block and the genesis block.
pub fn get_simple_history(&self) -> [[u8; 32]; 2] {
[self.top_seen_hash, self.our_genesis]
}
/// Returns the total number of queued batches for a certain `batch_size`.
pub fn block_requests_queued(&self, batch_size: usize) -> usize {
self.entries
.iter()
.map(|entry| entry.ids.len().div_ceil(batch_size))
.sum()
}
/// Attempts to add an incoming [`ChainResponse`] to the chain tracker.
pub fn add_entry(
&mut self,
mut chain_entry: ChainResponse,
peer: InternalPeerID<N::Addr>,
handle: ConnectionHandle,
) -> Result<(), ChainTrackerError> {
// TODO: check chain entries length.
if chain_entry.m_block_ids.is_empty() {
// The peer must send at least one overlapping block.
handle.ban_peer(MEDIUM_BAN);
return Err(ChainTrackerError::NewEntryIsInvalid);
}
if self
.entries
.back()
.is_some_and(|last_entry| last_entry.ids.last().unwrap() != &chain_entry.m_block_ids[0])
{
return Err(ChainTrackerError::NewEntryDoesNotFollowChain);
}
tracing::warn!("len: {}", chain_entry.m_block_ids.len());
let new_entry = ChainEntry {
// ignore the first block - we already know it.
ids: (&chain_entry.m_block_ids.split_off(1)).into(),
peer,
handle,
};
self.top_seen_hash = *new_entry.ids.last().unwrap();
self.entries.push_back(new_entry);
Ok(())
}
/// Returns the next batch of blocks to request, if the given pruning seed can serve them.
pub fn blocks_to_get(
&mut self,
pruning_seed: &PruningSeed,
max_blocks: usize,
) -> Option<BlocksToRetrieve<N>> {
if !pruning_seed.has_full_block(self.first_height, CRYPTONOTE_MAX_BLOCK_HEIGHT) {
return None;
}
// TODO: make sure max block height is enforced.
let entry = self.entries.front_mut()?;
// Calculate the end index for this batch: the smallest of `max_blocks`, the length of the entry, and
// the index of the next pruned block for this seed.
let end_idx = min(
min(entry.ids.len(), max_blocks),
usize::try_from(
pruning_seed
.get_next_pruned_block(self.first_height, CRYPTONOTE_MAX_BLOCK_HEIGHT)
// We check that the first height is less than CRYPTONOTE_MAX_BLOCK_HEIGHT in the response task.
.unwrap()
// Use a big value as a fallback if the seed does no pruning.
.unwrap_or(CRYPTONOTE_MAX_BLOCK_HEIGHT)
- self.first_height,
)
.unwrap(),
);
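// Worked example (illustrative, not from the commit): if the front entry holds
// 250 IDs, `max_blocks` is 100 and the peer's seed prunes blocks starting at
// `first_height + 60`, then `end_idx = min(min(250, 100), 60) = 60`, so we only
// request blocks this peer still stores in full.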
if end_idx == 0 {
return None;
}
let ids_to_get = entry.ids.drain(0..end_idx).collect::<Vec<_>>();
let blocks = BlocksToRetrieve {
ids: ids_to_get.into(),
start_height: self.first_height,
peer_who_told_us: entry.peer,
peer_who_told_us_handle: entry.handle.clone(),
};
self.first_height += u64::try_from(end_idx).unwrap();
if entry.ids.is_empty() {
self.entries.pop_front();
}
Some(blocks)
}
}
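
A short usage sketch of the tracker, using only the APIs defined above; the caller is assumed to already hold a tracker and the target peer's pruning seed:

```rust
// Sketch only: repeatedly pull fixed-size batches for one peer until the seed
// can no longer serve the next height or the front entry is exhausted.
fn drain_batches<N: NetworkZone>(
    tracker: &mut ChainTracker<N>,
    seed: &PruningSeed,
    batch_size: usize,
) -> Vec<BlocksToRetrieve<N>> {
    let mut batches = Vec::new();
    while let Some(batch) = tracker.blocks_to_get(seed, batch_size) {
        batches.push(batch);
    }
    batches
}
```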

View file

@ -138,13 +138,16 @@ impl<N: NetworkZone> ClientPool<N> {
pub fn borrow_clients<'a, 'b>(
self: &'a Arc<Self>,
peers: &'b [InternalPeerID<N::Addr>],
) -> impl Iterator<Item = ClientPoolDropGuard<N>> + Captures<(&'a (), &'b ())> {
) -> impl Iterator<Item = ClientPoolDropGuard<N>> + sealed::Captures<(&'a (), &'b ())> {
peers.iter().filter_map(|peer| self.borrow_client(peer))
}
}
/// TODO: Remove me when 2024 Rust
///
/// https://rust-lang.github.io/rfcs/3498-lifetime-capture-rules-2024.html#the-captures-trick
trait Captures<U> {}
impl<T: ?Sized, U> Captures<U> for T {}
mod sealed {
/// TODO: Remove me when 2024 Rust
///
/// https://rust-lang.github.io/rfcs/3498-lifetime-capture-rules-2024.html#the-captures-trick
pub trait Captures<U> {}
impl<T: ?Sized, U> Captures<U> for T {}
}

View file

@ -12,6 +12,9 @@ pub(crate) const OUTBOUND_CONNECTION_ATTEMPT_TIMEOUT: Duration = Duration::from_
/// The duration of a short ban.
pub(crate) const SHORT_BAN: Duration = Duration::from_secs(60 * 10);
/// The duration of a medium ban.
pub(crate) const MEDIUM_BAN: Duration = Duration::from_secs(60 * 10 * 24);
/// The default amount of time between inbound diffusion flushes.
pub(crate) const DIFFUSION_FLUSH_AVERAGE_SECONDS_INBOUND: Duration = Duration::from_secs(5);