diff --git a/consensus/Cargo.toml b/consensus/Cargo.toml index 846c1ac..dce5ce8 100644 --- a/consensus/Cargo.toml +++ b/consensus/Cargo.toml @@ -36,14 +36,15 @@ crypto-bigint = "0.5" curve25519-dalek = "4" randomx-rs = "1" -monero-serai = {git="https://github.com/Cuprate/serai.git", rev = "39eafae"} -multiexp = {git="https://github.com/Cuprate/serai.git", rev = "39eafae"} -dalek-ff-group = {git="https://github.com/Cuprate/serai.git", rev = "39eafae"} +monero-serai = {git="https://github.com/serai-dex/serai.git", rev = "c328e5e"} +multiexp = {git="https://github.com/serai-dex/serai.git", rev = "c328e5e"} +dalek-ff-group = {git="https://github.com/serai-dex/serai.git", rev = "c328e5e"} cuprate-common = {path = "../common"} cryptonight-cuprate = {path = "../cryptonight"} rayon = "1" +thread_local = "1.1.7" tokio = "1" tokio-util = "0.7" diff --git a/consensus/src/batch_verifier.rs b/consensus/src/batch_verifier.rs new file mode 100644 index 0000000..0d4067b --- /dev/null +++ b/consensus/src/batch_verifier.rs @@ -0,0 +1,47 @@ +use std::cell::UnsafeCell; + +use multiexp::BatchVerifier as InternalBatchVerifier; +use rayon::prelude::*; +use thread_local::ThreadLocal; + +use crate::ConsensusError; + +/// A multi threaded batch verifier. +pub struct MultiThreadedBatchVerifier { + internal: ThreadLocal>>, +} + +impl MultiThreadedBatchVerifier { + /// Create a new multithreaded batch verifier, + pub fn new(numb_threads: usize) -> MultiThreadedBatchVerifier { + MultiThreadedBatchVerifier { + internal: ThreadLocal::with_capacity(numb_threads), + } + } + + pub fn queue_statement( + &self, + stmt: impl FnOnce( + &mut InternalBatchVerifier, + ) -> Result<(), ConsensusError>, + ) -> Result<(), ConsensusError> { + let verifier_cell = self + .internal + .get_or(|| UnsafeCell::new(InternalBatchVerifier::new(0))); + // SAFETY: This is safe for 2 reasons: + // 1. each thread gets a different batch verifier. + // 2. only this function `queue_statement` will get the inner batch verifier, it's private. + // + // TODO: it's probably ok to just use RefCell + stmt(unsafe { &mut *verifier_cell.get() }) + } + + pub fn verify(self) -> bool { + self.internal + .into_iter() + .map(UnsafeCell::into_inner) + .par_bridge() + .find_any(|batch_verifer| !batch_verifer.verify_vartime()) + .is_none() + } +} diff --git a/consensus/src/bin/scan_chain.rs b/consensus/src/bin/scan_chain.rs index 2a62f6b..6e05c05 100644 --- a/consensus/src/bin/scan_chain.rs +++ b/consensus/src/bin/scan_chain.rs @@ -24,7 +24,7 @@ use monero_consensus::{ mod tx_pool; -const MAX_BLOCKS_IN_RANGE: u64 = 1000; +const MAX_BLOCKS_IN_RANGE: u64 = 500; const MAX_BLOCKS_HEADERS_IN_RANGE: u64 = 500; /// Calls for a batch of blocks, returning the response and the time it took. @@ -82,19 +82,19 @@ where D::Future: Send + 'static, { let mut next_fut = tokio::spawn(call_batch( - start_height..(start_height + (MAX_BLOCKS_IN_RANGE * 2)).min(chain_height), + start_height..(start_height + (MAX_BLOCKS_IN_RANGE * 3)).min(chain_height), database.clone(), )); for next_batch_start in (start_height..chain_height) - .step_by((MAX_BLOCKS_IN_RANGE * 2) as usize) + .step_by((MAX_BLOCKS_IN_RANGE * 3) as usize) .skip(1) { // Call the next batch while we handle this batch. let current_fut = std::mem::replace( &mut next_fut, tokio::spawn(call_batch( - next_batch_start..(next_batch_start + (MAX_BLOCKS_IN_RANGE * 2)).min(chain_height), + next_batch_start..(next_batch_start + (MAX_BLOCKS_IN_RANGE * 3)).min(chain_height), database.clone(), )), ); @@ -105,7 +105,7 @@ where tracing::info!( "Retrived batch: {:?}, chain height: {}", - (next_batch_start - (MAX_BLOCKS_IN_RANGE * 2))..(next_batch_start), + (next_batch_start - (MAX_BLOCKS_IN_RANGE * 3))..(next_batch_start), chain_height ); @@ -162,7 +162,7 @@ where call_blocks(new_tx_chan, block_tx, start_height, chain_height, database).await }); - let (mut prepared_blocks_tx, mut prepared_blocks_rx) = mpsc::channel(2); + let (mut prepared_blocks_tx, mut prepared_blocks_rx) = mpsc::channel(3); let mut cloned_block_verifier = block_verifier.clone(); tokio::spawn(async move { @@ -170,14 +170,14 @@ where while !next_blocks.is_empty() { tracing::info!( "preparing next batch, number of blocks: {}", - next_blocks.len().min(100) + next_blocks.len().min(150) ); let res = cloned_block_verifier .ready() .await? .call(VerifyBlockRequest::BatchSetup( - next_blocks.drain(0..next_blocks.len().min(100)).collect(), + next_blocks.drain(0..next_blocks.len().min(150)).collect(), )) .await; @@ -242,7 +242,7 @@ async fn main() { let urls = vec![ "http://xmr-node.cakewallet.com:18081".to_string(), - "http://node.sethforprivacy.com".to_string(), + "https://node.sethforprivacy.com".to_string(), "http://nodex.monerujo.io:18081".to_string(), "http://nodes.hashvault.pro:18081".to_string(), "http://node.c3pool.com:18081".to_string(), @@ -254,7 +254,7 @@ async fn main() { "http://145.239.97.211:18089".to_string(), // "http://xmr-node.cakewallet.com:18081".to_string(), - "http://node.sethforprivacy.com".to_string(), + "https://node.sethforprivacy.com".to_string(), "http://nodex.monerujo.io:18081".to_string(), "http://nodes.hashvault.pro:18081".to_string(), "http://node.c3pool.com:18081".to_string(), diff --git a/consensus/src/block.rs b/consensus/src/block.rs index 689aea7..3e8331f 100644 --- a/consensus/src/block.rs +++ b/consensus/src/block.rs @@ -175,32 +175,17 @@ fn prepare_block(block: Block) -> Result { } }; - let block_hashing_blob = block.serialize_hashable(); - let (pow_hash, mut prepared_block) = rayon::join( - || { - // we calculate the POW hash on a different task because this takes a massive amount of time. - calculate_pow_hash(&block_hashing_blob, height, &hf_version) - }, - || { - PrePreparedBlock { - block_blob: block.serialize(), - block_hash: block.hash(), - // set a dummy pow hash for now. We use u8::MAX so if something odd happens and this value isn't changed it will fail for - // difficulties > 1. - pow_hash: [u8::MAX; 32], - miner_tx_weight: block.miner_tx.weight(), - block, - hf_vote, - hf_version, - } - }, - ); + tracing::debug!("preparing block: {}", height); - prepared_block.pow_hash = pow_hash?; - - tracing::debug!("prepared block: {}", height); - - Ok(prepared_block) + Ok(PrePreparedBlock { + block_blob: block.serialize(), + block_hash: block.hash(), + pow_hash: calculate_pow_hash(&block.serialize_hashable(), height, &hf_version)?, + miner_tx_weight: block.miner_tx.weight(), + block, + hf_vote, + hf_version, + }) } async fn verify_prepared_main_chain_block( @@ -231,22 +216,29 @@ where tracing::debug!("got blockchain context: {:?}", context); - let TxPoolResponse::Transactions(txs) = tx_pool - .oneshot(TxPoolRequest::Transactions(block.block.txs.clone())) - .await?; + let txs = if !block.block.txs.is_empty() { + let TxPoolResponse::Transactions(txs) = tx_pool + .oneshot(TxPoolRequest::Transactions(block.block.txs.clone())) + .await?; + txs + } else { + vec![] + }; let block_weight = block.miner_tx_weight + txs.iter().map(|tx| tx.tx_weight).sum::(); let total_fees = txs.iter().map(|tx| tx.fee).sum::(); - tx_verifier_svc - .oneshot(VerifyTxRequest::Block { - txs: txs.clone(), - current_chain_height: context.chain_height, - time_for_time_lock: context.current_adjusted_timestamp_for_time_lock(), - hf: context.current_hard_fork, - re_org_token: context.re_org_token.clone(), - }) - .await?; + if !txs.is_empty() { + tx_verifier_svc + .oneshot(VerifyTxRequest::Block { + txs: txs.clone(), + current_chain_height: context.chain_height, + time_for_time_lock: context.current_adjusted_timestamp_for_time_lock(), + hf: context.current_hard_fork, + re_org_token: context.re_org_token.clone(), + }) + .await?; + } let generated_coins = miner_tx::check_miner_tx( &block.block.miner_tx, diff --git a/consensus/src/lib.rs b/consensus/src/lib.rs index f1cd38a..190b8a6 100644 --- a/consensus/src/lib.rs +++ b/consensus/src/lib.rs @@ -4,6 +4,7 @@ use std::{ sync::Arc, }; +mod batch_verifier; pub mod block; pub mod context; pub mod genesis; diff --git a/consensus/src/rpc.rs b/consensus/src/rpc.rs index 0f6ecde..507b9d4 100644 --- a/consensus/src/rpc.rs +++ b/consensus/src/rpc.rs @@ -16,21 +16,13 @@ use futures::{ FutureExt, StreamExt, TryFutureExt, TryStreamExt, }; use monero_serai::rpc::{HttpRpc, RpcConnection, RpcError}; -use rayon::prelude::*; -use serde::{Deserialize, Serialize}; -use serde_json::json; use tower::{balance::p2c::Balance, util::BoxService, ServiceExt}; -use tracing::{instrument, Instrument}; +use tracing_subscriber::filter::FilterExt; -use cuprate_common::BlockID; -use monero_wire::common::{BlockCompleteEntry, TransactionBlobs}; - -use crate::{ - helper::rayon_spawn_async, DatabaseRequest, DatabaseResponse, ExtendedBlockHeader, HardFork, - OutputOnChain, -}; +use crate::{DatabaseRequest, DatabaseResponse}; pub mod cache; +mod connection; mod discover; use cache::ScanningCache; @@ -90,29 +82,35 @@ pub fn init_rpc_load_balancer( Box> + Send + 'static>, >, > + Clone { - let (rpc_discoverer_tx, rpc_discoverer_rx) = futures::channel::mpsc::channel(30); + let (rpc_discoverer_tx, rpc_discoverer_rx) = futures::channel::mpsc::channel(0); - let rpc_balance = Balance::new(rpc_discoverer_rx.map(Result::<_, tower::BoxError>::Ok)); - let timeout = tower::timeout::Timeout::new(rpc_balance, Duration::from_secs(300)); - let rpc_buffer = tower::buffer::Buffer::new(BoxService::new(timeout), 50); + let rpc_balance = Balance::new(Box::pin( + rpc_discoverer_rx.map(Result::<_, tower::BoxError>::Ok), + )); + let rpc_buffer = tower::buffer::Buffer::new(rpc_balance, 500); let rpcs = tower::retry::Retry::new(Attempts(10), rpc_buffer); let discover = discover::RPCDiscover { initial_list: addresses, ok_channel: rpc_discoverer_tx, already_connected: Default::default(), - cache, + cache: cache.clone(), }; tokio::spawn(discover.run()); - RpcBalancer { rpcs, config } + RpcBalancer { + rpcs, + config, + cache, + } } #[derive(Clone)] pub struct RpcBalancer { rpcs: T, config: Arc>, + cache: Arc>, } impl tower::Service for RpcBalancer @@ -138,7 +136,27 @@ where let config_mutex = self.config.clone(); let config = config_mutex.read().unwrap(); + let cache = self.cache.clone(); + match req { + DatabaseRequest::CheckKIsNotSpent(kis) => async move { + Ok(DatabaseResponse::CheckKIsNotSpent( + cache.read().unwrap().are_kis_spent(kis), + )) + } + .boxed(), + DatabaseRequest::GeneratedCoins => async move { + Ok(DatabaseResponse::GeneratedCoins( + cache.read().unwrap().already_generated_coins, + )) + } + .boxed(), + DatabaseRequest::NumberOutputsWithAmount(amt) => async move { + Ok(DatabaseResponse::NumberOutputsWithAmount( + cache.read().unwrap().numb_outs(amt), + )) + } + .boxed(), DatabaseRequest::BlockBatchInRange(range) => { let resp_to_ret = |resp: DatabaseResponse| { let DatabaseResponse::BlockBatchInRange(pow_info) = resp else { @@ -265,373 +283,3 @@ where } .boxed() } - -enum RpcState { - Locked, - Acquiring(OwnedMutexLockFuture>), - Acquired(OwnedMutexGuard>), -} -pub struct Rpc { - rpc: Arc>>, - addr: String, - rpc_state: RpcState, - cache: Arc>, - error_slot: Arc>>, -} - -impl Rpc { - pub fn new_http(addr: String, cache: Arc>) -> Rpc { - let http_rpc = HttpRpc::new(addr.clone()).unwrap(); - Rpc { - rpc: Arc::new(futures::lock::Mutex::new(http_rpc)), - addr, - rpc_state: RpcState::Locked, - cache, - error_slot: Arc::new(Mutex::new(None)), - } - } -} - -impl tower::Service for Rpc { - type Response = DatabaseResponse; - type Error = tower::BoxError; - type Future = - Pin> + Send + 'static>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(rpc_error) = self.error_slot.lock().unwrap().clone() { - return Poll::Ready(Err(rpc_error.into())); - } - loop { - match &mut self.rpc_state { - RpcState::Locked => { - self.rpc_state = RpcState::Acquiring(Arc::clone(&self.rpc).lock_owned()) - } - RpcState::Acquiring(rpc) => { - self.rpc_state = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx))) - } - RpcState::Acquired(_) => return Poll::Ready(Ok(())), - } - } - } - - fn call(&mut self, req: DatabaseRequest) -> Self::Future { - let RpcState::Acquired(rpc) = std::mem::replace(&mut self.rpc_state, RpcState::Locked) - else { - panic!("poll_ready was not called first!"); - }; - - let cache = self.cache.clone(); - - let span = tracing::info_span!("rpc_request", addr = &self.addr); - - let err_slot = self.error_slot.clone(); - - match req { - DatabaseRequest::BlockHash(height) => async move { - let res: Result<_, RpcError> = rpc - .get_block_hash(height as usize) - .map_ok(DatabaseResponse::BlockHash) - .await; - if let Err(e) = &res { - *err_slot.lock().unwrap() = Some(e.clone()); - } - res.map_err(Into::into) - } - .instrument(span) - .boxed(), - DatabaseRequest::ChainHeight => async move { - let height = cache.read().unwrap().height; - - let hash = rpc - .get_block_hash((height - 1) as usize) - .await - .map_err(Into::::into)?; - - Ok(DatabaseResponse::ChainHeight(height, hash)) - } - .instrument(span) - .boxed(), - DatabaseRequest::CheckKIsNotSpent(kis) => async move { - Ok(DatabaseResponse::CheckKIsNotSpent( - cache.read().unwrap().are_kis_spent(kis), - )) - } - .instrument(span) - .boxed(), - DatabaseRequest::GeneratedCoins => async move { - Ok(DatabaseResponse::GeneratedCoins( - cache.read().unwrap().already_generated_coins, - )) - } - .instrument(span) - .boxed(), - - DatabaseRequest::BlockExtendedHeader(id) => { - get_block_info(id, rpc).instrument(span).boxed() - } - DatabaseRequest::BlockExtendedHeaderInRange(range) => { - get_block_info_in_range(range, rpc).instrument(span).boxed() - } - DatabaseRequest::BlockBatchInRange(range) => { - get_blocks_in_range(range, rpc).instrument(span).boxed() - } - DatabaseRequest::Outputs(out_ids) => { - get_outputs(out_ids, cache, rpc).instrument(span).boxed() - } - DatabaseRequest::NumberOutputsWithAmount(amt) => async move { - Ok(DatabaseResponse::NumberOutputsWithAmount( - cache.read().unwrap().numb_outs(amt) as usize, - )) - } - .boxed(), - } - } -} - -#[instrument(skip_all)] -async fn get_outputs( - out_ids: HashMap>, - cache: Arc>, - rpc: OwnedMutexGuard>, -) -> Result { - tracing::info!( - "Getting outputs len: {}", - out_ids.values().map(|amt_map| amt_map.len()).sum::() - ); - - #[derive(Serialize, Copy, Clone)] - struct OutputID { - amount: u64, - index: u64, - } - - #[derive(Serialize, Clone)] - struct Request { - outputs: Vec, - } - - #[derive(Deserialize)] - struct OutputRes { - height: u64, - key: [u8; 32], - mask: [u8; 32], - txid: [u8; 32], - } - - #[derive(Deserialize)] - struct Response { - outs: Vec, - } - - let outputs = rayon_spawn_async(|| { - out_ids - .into_par_iter() - .flat_map(|(amt, amt_map)| { - amt_map - .into_iter() - .map(|amt_idx| OutputID { - amount: amt, - index: amt_idx, - }) - .collect::>() - }) - .collect::>() - }) - .await; - - let res = rpc - .bin_call( - "get_outs.bin", - monero_epee_bin_serde::to_bytes(&Request { - outputs: outputs.clone(), - })?, - ) - .await?; - - rayon_spawn_async(move || { - let outs: Response = monero_epee_bin_serde::from_bytes(&res)?; - - tracing::info!("Got outputs len: {}", outs.outs.len()); - - let mut ret = HashMap::new(); - let cache = cache.read().unwrap(); - - for (out, idx) in outs.outs.iter().zip(outputs) { - ret.entry(idx.amount).or_insert_with(HashMap::new).insert( - idx.index, - OutputOnChain { - height: out.height, - time_lock: cache.outputs_time_lock(&out.txid), - // we unwrap these as we are checking already approved rings so if these points are bad - // then a bad proof has been approved. - key: CompressedEdwardsY::from_slice(&out.key) - .unwrap() - .decompress() - .unwrap(), - mask: CompressedEdwardsY::from_slice(&out.mask) - .unwrap() - .decompress() - .unwrap(), - }, - ); - } - Ok(DatabaseResponse::Outputs(ret)) - }) - .await -} - -async fn get_blocks_in_range( - range: Range, - rpc: OwnedMutexGuard>, -) -> Result { - tracing::info!("Getting blocks in range: {:?}", range); - - #[derive(Serialize)] - pub struct Request { - pub heights: Vec, - } - - #[derive(Deserialize)] - pub struct Response { - pub blocks: Vec, - } - - let res = rpc - .bin_call( - "get_blocks_by_height.bin", - monero_epee_bin_serde::to_bytes(&Request { - heights: range.collect(), - })?, - ) - .await?; - - let blocks: Response = monero_epee_bin_serde::from_bytes(res)?; - - Ok(DatabaseResponse::BlockBatchInRange( - rayon_spawn_async(|| { - blocks - .blocks - .into_par_iter() - .map(|b| { - Ok(( - monero_serai::block::Block::read(&mut b.block.as_slice())?, - match b.txs { - TransactionBlobs::Pruned(_) => { - return Err("node sent pruned txs!".into()) - } - TransactionBlobs::Normal(txs) => txs - .into_par_iter() - .map(|tx| { - monero_serai::transaction::Transaction::read(&mut tx.as_slice()) - }) - .collect::>()?, - TransactionBlobs::None => vec![], - }, - )) - }) - .collect::>() - }) - .await?, - )) -} - -#[derive(Deserialize, Debug)] -struct BlockInfo { - cumulative_difficulty: u64, - cumulative_difficulty_top64: u64, - timestamp: u64, - block_weight: usize, - long_term_weight: usize, - - major_version: u8, - minor_version: u8, -} - -async fn get_block_info_in_range( - range: Range, - rpc: OwnedMutexGuard>, -) -> Result { - #[derive(Deserialize, Debug)] - struct Response { - headers: Vec, - } - - let res = rpc - .json_rpc_call::( - "get_block_headers_range", - Some(json!({"start_height": range.start, "end_height": range.end - 1})), - ) - .await?; - - tracing::info!("Retrieved block headers in range: {:?}", range); - - Ok(DatabaseResponse::BlockExtendedHeaderInRange( - rayon_spawn_async(|| { - res.headers - .into_par_iter() - .map(|info| ExtendedBlockHeader { - version: HardFork::from_version(&info.major_version) - .expect("previously checked block has incorrect version"), - vote: HardFork::from_vote(&info.minor_version), - timestamp: info.timestamp, - cumulative_difficulty: u128_from_low_high( - info.cumulative_difficulty, - info.cumulative_difficulty_top64, - ), - block_weight: info.block_weight, - long_term_weight: info.long_term_weight, - }) - .collect() - }) - .await, - )) -} - -async fn get_block_info( - id: BlockID, - rpc: OwnedMutexGuard>, -) -> Result { - tracing::info!("Retrieving block info with id: {}", id); - - #[derive(Deserialize, Debug)] - struct Response { - block_header: BlockInfo, - } - - let info = match id { - BlockID::Height(height) => { - let res = rpc - .json_rpc_call::( - "get_block_header_by_height", - Some(json!({"height": height})), - ) - .await?; - res.block_header - } - BlockID::Hash(hash) => { - let res = rpc - .json_rpc_call::("get_block_header_by_hash", Some(json!({"hash": hash}))) - .await?; - res.block_header - } - }; - - Ok(DatabaseResponse::BlockExtendedHeader(ExtendedBlockHeader { - version: HardFork::from_version(&info.major_version) - .expect("previously checked block has incorrect version"), - vote: HardFork::from_vote(&info.minor_version), - timestamp: info.timestamp, - cumulative_difficulty: u128_from_low_high( - info.cumulative_difficulty, - info.cumulative_difficulty_top64, - ), - block_weight: info.block_weight, - long_term_weight: info.long_term_weight, - })) -} - -fn u128_from_low_high(low: u64, high: u64) -> u128 { - let res: u128 = high as u128; - res << 64 | low as u128 -} diff --git a/consensus/src/rpc/cache.rs b/consensus/src/rpc/cache.rs index 2c3366c..3ad0531 100644 --- a/consensus/src/rpc/cache.rs +++ b/consensus/src/rpc/cache.rs @@ -22,7 +22,7 @@ use crate::transactions::TransactionVerificationData; #[derive(Debug, Default, Clone, Encode, Decode)] pub struct ScanningCache { // network: u8, - numb_outs: HashMap, + numb_outs: HashMap, time_locked_out: HashMap<[u8; 32], u64>, kis: HashSet<[u8; 32]>, pub already_generated_coins: u64, @@ -112,15 +112,15 @@ impl ScanningCache { } } - pub fn total_outs(&self) -> u64 { + pub fn total_outs(&self) -> usize { self.numb_outs.values().sum() } - pub fn numb_outs(&self, amount: u64) -> u64 { + pub fn numb_outs(&self, amount: u64) -> usize { *self.numb_outs.get(&amount).unwrap_or(&0) } - pub fn add_outs(&mut self, amount: u64, count: u64) { + pub fn add_outs(&mut self, amount: u64, count: usize) { if let Some(numb_outs) = self.numb_outs.get_mut(&amount) { *numb_outs += count; } else { diff --git a/consensus/src/rpc/connection.rs b/consensus/src/rpc/connection.rs new file mode 100644 index 0000000..7873cff --- /dev/null +++ b/consensus/src/rpc/connection.rs @@ -0,0 +1,441 @@ +use std::{ + collections::{HashMap, HashSet}, + future::Future, + ops::Range, + pin::Pin, + sync::{Arc, RwLock}, + task::{Context, Poll}, +}; + +use curve25519_dalek::edwards::CompressedEdwardsY; +use futures::{ + channel::{mpsc, oneshot}, + ready, FutureExt, SinkExt, StreamExt, TryStreamExt, +}; +use monero_serai::{ + block::Block, + rpc::{HttpRpc, Rpc, RpcError}, + transaction::Transaction, +}; +use monero_wire::common::{BlockCompleteEntry, TransactionBlobs}; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use tokio::{ + task::JoinHandle, + time::{timeout, Duration}, +}; +use tower::Service; +use tracing::{instrument, Instrument}; + +use cuprate_common::BlockID; + +use super::ScanningCache; +use crate::{ + helper::rayon_spawn_async, DatabaseRequest, DatabaseResponse, ExtendedBlockHeader, HardFork, + OutputOnChain, +}; +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300); + +pub struct RpcConnectionSvc { + pub(crate) address: String, + + pub(crate) rpc_task_handle: JoinHandle<()>, + pub(crate) rpc_task_chan: mpsc::Sender, +} + +impl Service for RpcConnectionSvc { + type Response = DatabaseResponse; + type Error = tower::BoxError; + type Future = + Pin> + Send + 'static>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.rpc_task_handle.is_finished() { + return Poll::Ready(Err("RPC task has exited!".into())); + } + self.rpc_task_chan.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: DatabaseRequest) -> Self::Future { + let (tx, rx) = oneshot::channel(); + + let req = RpcReq { + req, + res_chan: tx, + span: tracing::info_span!(parent: &tracing::Span::current(), "rpc", addr = &self.address), + }; + + self.rpc_task_chan + .try_send(req) + .expect("poll_ready should be called first!"); + + async move { + rx.await + .expect("sender will not be dropped without response") + } + .boxed() + } +} + +pub(crate) struct RpcReq { + req: DatabaseRequest, + res_chan: oneshot::Sender>, + span: tracing::Span, +} + +pub struct RpcConnection { + pub(crate) address: String, + + pub(crate) con: Rpc, + pub(crate) cache: Arc>, + + pub(crate) req_chan: mpsc::Receiver, +} + +impl RpcConnection { + async fn get_block_hash(&self, height: u64) -> Result<[u8; 32], tower::BoxError> { + self.con + .get_block_hash(height.try_into().unwrap()) + .await + .map_err(Into::into) + } + + async fn get_extended_block_header( + &self, + id: BlockID, + ) -> Result { + tracing::info!("Retrieving block info with id: {}", id); + + #[derive(Deserialize, Debug)] + struct Response { + block_header: BlockInfo, + } + + let info = match id { + BlockID::Height(height) => { + let res = self + .con + .json_rpc_call::( + "get_block_header_by_height", + Some(json!({"height": height})), + ) + .await?; + res.block_header + } + BlockID::Hash(hash) => { + let res = self + .con + .json_rpc_call::( + "get_block_header_by_hash", + Some(json!({"hash": hash})), + ) + .await?; + res.block_header + } + }; + + Ok(ExtendedBlockHeader { + version: HardFork::from_version(&info.major_version) + .expect("previously checked block has incorrect version"), + vote: HardFork::from_vote(&info.minor_version), + timestamp: info.timestamp, + cumulative_difficulty: u128_from_low_high( + info.cumulative_difficulty, + info.cumulative_difficulty_top64, + ), + block_weight: info.block_weight, + long_term_weight: info.long_term_weight, + }) + } + + async fn get_extended_block_header_in_range( + &self, + range: Range, + ) -> Result, tower::BoxError> { + #[derive(Deserialize, Debug)] + struct Response { + headers: Vec, + } + + let res = self + .con + .json_rpc_call::( + "get_block_headers_range", + Some(json!({"start_height": range.start, "end_height": range.end - 1})), + ) + .await?; + + tracing::info!("Retrieved block headers in range: {:?}", range); + + Ok(rayon_spawn_async(|| { + res.headers + .into_iter() + .map(|info| ExtendedBlockHeader { + version: HardFork::from_version(&info.major_version) + .expect("previously checked block has incorrect version"), + vote: HardFork::from_vote(&info.minor_version), + timestamp: info.timestamp, + cumulative_difficulty: u128_from_low_high( + info.cumulative_difficulty, + info.cumulative_difficulty_top64, + ), + block_weight: info.block_weight, + long_term_weight: info.long_term_weight, + }) + .collect() + }) + .await) + } + + async fn get_blocks_in_range( + &self, + range: Range, + ) -> Result)>, tower::BoxError> { + tracing::info!("Getting blocks in range: {:?}", range); + + #[derive(Serialize)] + pub struct Request { + pub heights: Vec, + } + + #[derive(Deserialize)] + pub struct Response { + pub blocks: Vec, + } + + let res = self + .con + .bin_call( + "get_blocks_by_height.bin", + monero_epee_bin_serde::to_bytes(&Request { + heights: range.collect(), + })?, + ) + .await?; + + let blocks: Response = monero_epee_bin_serde::from_bytes(res)?; + + Ok(rayon_spawn_async(|| { + blocks + .blocks + .into_par_iter() + .map(|b| { + Ok(( + Block::read(&mut b.block.as_slice())?, + match b.txs { + TransactionBlobs::Pruned(_) => { + return Err("node sent pruned txs!".into()) + } + TransactionBlobs::Normal(txs) => txs + .into_par_iter() + .map(|tx| Transaction::read(&mut tx.as_slice())) + .collect::>()?, + TransactionBlobs::None => vec![], + }, + )) + }) + .collect::>() + }) + .await?) + } + + async fn get_outputs( + &self, + out_ids: HashMap>, + ) -> Result>, tower::BoxError> { + tracing::info!( + "Getting outputs len: {}", + out_ids.values().map(|amt_map| amt_map.len()).sum::() + ); + + #[derive(Serialize, Copy, Clone)] + struct OutputID { + amount: u64, + index: u64, + } + + #[derive(Serialize, Clone)] + struct Request { + outputs: Vec, + } + + #[derive(Deserialize)] + struct OutputRes { + height: u64, + key: [u8; 32], + mask: [u8; 32], + txid: [u8; 32], + } + + #[derive(Deserialize)] + struct Response { + outs: Vec, + } + + let outputs = out_ids + .into_iter() + .flat_map(|(amt, amt_map)| { + amt_map + .into_iter() + .map(|amt_idx| OutputID { + amount: amt, + index: amt_idx, + }) + .collect::>() + }) + .collect::>(); + + let res = self + .con + .bin_call( + "get_outs.bin", + monero_epee_bin_serde::to_bytes(&Request { + outputs: outputs.clone(), + })?, + ) + .await?; + + let cache = self.cache.clone(); + let span = tracing::Span::current(); + rayon_spawn_async(move || { + let outs: Response = monero_epee_bin_serde::from_bytes(&res)?; + + tracing::info!(parent: &span, "Got outputs len: {}", outs.outs.len()); + + let mut ret = HashMap::new(); + let cache = cache.read().unwrap(); + + for (out, idx) in outs.outs.iter().zip(outputs) { + ret.entry(idx.amount).or_insert_with(HashMap::new).insert( + idx.index, + OutputOnChain { + height: out.height, + time_lock: cache.outputs_time_lock(&out.txid), + // we unwrap these as we are checking already approved rings so if these points are bad + // then a bad proof has been approved. + key: CompressedEdwardsY::from_slice(&out.key) + .unwrap() + .decompress() + .unwrap(), + mask: CompressedEdwardsY::from_slice(&out.mask) + .unwrap() + .decompress() + .unwrap(), + }, + ); + } + Ok(ret) + }) + .await + } + + async fn handle_request( + &mut self, + req: DatabaseRequest, + ) -> Result { + match req { + DatabaseRequest::BlockHash(height) => { + timeout(DEFAULT_TIMEOUT, self.get_block_hash(height)) + .await? + .map(DatabaseResponse::BlockHash) + } + DatabaseRequest::ChainHeight => { + let height = self.cache.read().unwrap().height; + + let hash = timeout(DEFAULT_TIMEOUT, self.get_block_hash(height - 1)).await??; + + Ok(DatabaseResponse::ChainHeight(height, hash)) + } + DatabaseRequest::BlockExtendedHeader(id) => { + timeout(DEFAULT_TIMEOUT, self.get_extended_block_header(id)) + .await? + .map(DatabaseResponse::BlockExtendedHeader) + } + DatabaseRequest::BlockExtendedHeaderInRange(range) => timeout( + DEFAULT_TIMEOUT, + self.get_extended_block_header_in_range(range), + ) + .await? + .map(DatabaseResponse::BlockExtendedHeaderInRange), + DatabaseRequest::BlockBatchInRange(range) => { + timeout(DEFAULT_TIMEOUT, self.get_blocks_in_range(range)) + .await? + .map(DatabaseResponse::BlockBatchInRange) + } + DatabaseRequest::Outputs(out_ids) => { + timeout(DEFAULT_TIMEOUT, self.get_outputs(out_ids)) + .await? + .map(DatabaseResponse::Outputs) + } + DatabaseRequest::NumberOutputsWithAmount(_) + | DatabaseRequest::GeneratedCoins + | DatabaseRequest::CheckKIsNotSpent(_) => { + panic!("Request does not need RPC connection!") + } + } + } + + #[instrument(level = "info", skip(self), fields(addr = self.address))] + pub async fn check_rpc_alive(&self) -> Result<(), tower::BoxError> { + tracing::debug!("Checking RPC connection"); + + let res = timeout(Duration::from_secs(10), self.con.get_height()).await; + let ok = matches!(res, Ok(Ok(_))); + + if !ok { + tracing::warn!("RPC connection test failed"); + return Err("RPC connection test failed".into()); + } + tracing::info!("RPC connection Ok"); + + Ok(()) + } + + pub async fn run(mut self) { + while let Some(req) = self.req_chan.next().await { + let RpcReq { + req, + span, + res_chan, + } = req; + + let res = self.handle_request(req).instrument(span.clone()).await; + + let is_err = res.is_err(); + if is_err { + tracing::warn!(parent: &span, "Error from RPC: {:?}", res) + } + + let _ = res_chan.send(res); + + if is_err && self.check_rpc_alive().await.is_err() { + break; + } + } + + tracing::warn!("Shutting down RPC connection: {}", self.address); + + self.req_chan.close(); + while let Some(req) = self.req_chan.try_next().unwrap() { + let _ = req.res_chan.send(Err("RPC connection closed!".into())); + } + } +} + +#[derive(Deserialize, Debug)] +struct BlockInfo { + cumulative_difficulty: u64, + cumulative_difficulty_top64: u64, + timestamp: u64, + block_weight: usize, + long_term_weight: usize, + + major_version: u8, + minor_version: u8, +} + +fn u128_from_low_high(low: u64, high: u64) -> u128 { + let res: u128 = high as u128; + res << 64 | low as u128 +} diff --git a/consensus/src/rpc/discover.rs b/consensus/src/rpc/discover.rs index 66f098f..ca01072 100644 --- a/consensus/src/rpc/discover.rs +++ b/consensus/src/rpc/discover.rs @@ -10,46 +10,51 @@ use futures::{ SinkExt, StreamExt, }; use monero_serai::rpc::HttpRpc; -use tokio::time::timeout; use tower::{discover::Change, load::PeakEwma}; use tracing::instrument; -use super::{cache::ScanningCache, Rpc}; +use super::{ + cache::ScanningCache, + connection::{RpcConnection, RpcConnectionSvc}, +}; #[instrument(skip(cache))] -async fn check_rpc(addr: String, cache: Arc>) -> Option> { +async fn check_rpc(addr: String, cache: Arc>) -> Option { tracing::debug!("Sending request to node."); - let rpc = HttpRpc::new(addr.clone()).ok()?; - // make sure the RPC is actually reachable - timeout(Duration::from_secs(2), rpc.get_height()) - .await - .ok()? - .ok()?; - tracing::debug!("Node sent ok response."); + let con = HttpRpc::new(addr.clone()).await.ok()?; + let (tx, rx) = mpsc::channel(1); + let rpc = RpcConnection { + address: addr.clone(), + con, + cache, + req_chan: rx, + }; - Some(Rpc::new_http(addr, cache)) + rpc.check_rpc_alive().await.ok()?; + let handle = tokio::spawn(rpc.run()); + + Some(RpcConnectionSvc { + address: addr, + rpc_task_chan: tx, + rpc_task_handle: handle, + }) } pub(crate) struct RPCDiscover { pub initial_list: Vec, - pub ok_channel: mpsc::Sender>>>, - pub already_connected: HashSet, + pub ok_channel: mpsc::Sender>>, + pub already_connected: usize, pub cache: Arc>, } impl RPCDiscover { - async fn found_rpc(&mut self, rpc: Rpc) -> Result<(), SendError> { - //if self.already_connected.contains(&rpc.addr) { - // return Ok(()); - //} + async fn found_rpc(&mut self, rpc: RpcConnectionSvc) -> Result<(), SendError> { + self.already_connected += 1; - tracing::info!("Connecting to node: {}", &rpc.addr); - - let addr = rpc.addr.clone(); self.ok_channel .send(Change::Insert( - self.already_connected.len(), + self.already_connected, PeakEwma::new( rpc, Duration::from_secs(5000), @@ -58,7 +63,6 @@ impl RPCDiscover { ), )) .await?; - self.already_connected.insert(addr); Ok(()) } diff --git a/consensus/src/transactions.rs b/consensus/src/transactions.rs index 9c9cfb0..73e12c5 100644 --- a/consensus/src/transactions.rs +++ b/consensus/src/transactions.rs @@ -1,7 +1,7 @@ -use std::ops::Deref; use std::{ collections::HashSet, future::Future, + ops::Deref, pin::Pin, sync::Arc, task::{Context, Poll}, diff --git a/consensus/src/transactions/contextual_data.rs b/consensus/src/transactions/contextual_data.rs index 7d3b9d3..a55b6ae 100644 --- a/consensus/src/transactions/contextual_data.rs +++ b/consensus/src/transactions/contextual_data.rs @@ -38,13 +38,15 @@ pub async fn batch_refresh_ring_member_info false, - Input::ToKey { amount, .. } => amount.is_some(), - }) { + // Or if a hf has happened as this will change the default minimum decoys. + if &tx_ring_member_info + .as_ref() + .expect("We just checked if this was None") + .hf + != hf + || tx.tx.prefix.inputs.iter().any(|inp| match inp { + Input::Gen(_) => false, + Input::ToKey { amount, .. } => amount.is_some(), + }) + { txs_needing_partial_refresh.push(tx.clone()); } } diff --git a/consensus/src/transactions/inputs.rs b/consensus/src/transactions/inputs.rs index 253eec1..1edf0c8 100644 --- a/consensus/src/transactions/inputs.rs +++ b/consensus/src/transactions/inputs.rs @@ -56,6 +56,29 @@ fn check_decoy_info(decoy_info: &DecoyInfo, hf: &HardFork) -> Result<(), Consens Ok(()) } +/// Checks that the key image is torsion free. +/// +/// https://cuprate.github.io/monero-book/consensus_rules/transactions.html#torsion-free-key-image +pub(crate) fn check_key_images_torsion(input: &Input) -> Result<(), ConsensusError> { + match input { + Input::ToKey { key_image, .. } => { + // this happens in monero-serai but we may as well duplicate the check. + if !key_image.is_torsion_free() { + return Err(ConsensusError::TransactionHasInvalidInput( + "key image has torsion", + )); + } + } + _ => { + return Err(ConsensusError::TransactionHasInvalidInput( + "Input not ToKey", + )) + } + } + + Ok(()) +} + /// Checks the inputs key images for torsion and for duplicates in the transaction. /// /// The `spent_kis` parameter is not meant to be a complete list of key images, just a list of related transactions @@ -211,6 +234,37 @@ fn sum_inputs_v1(inputs: &[Input]) -> Result { Ok(sum) } +/// Checks the inputs semantics are valid. +/// +/// This does all the checks that don't need blockchain context. +/// +/// Although technically hard-fork is contextual data we class it as not because +/// blocks keep their hf in the header. +pub fn check_inputs_semantics( + inputs: &[Input], + hf: &HardFork, + tx_version: &TxVersion, +) -> Result { + if inputs.is_empty() { + return Err(ConsensusError::TransactionHasInvalidInput("no inputs")); + } + + for input in inputs { + check_input_type(input)?; + check_input_has_decoys(input)?; + + check_ring_members_unique(input, hf)?; + check_key_images_torsion(input)?; + } + + check_inputs_sorted(inputs, hf)?; + + match tx_version { + TxVersion::RingSignatures => sum_inputs_v1(inputs), + _ => panic!("TODO: RCT"), + } +} + /// Checks all input consensus rules. /// /// TODO: list rules. diff --git a/consensus/src/transactions/sigs/ring_sigs.rs b/consensus/src/transactions/sigs/ring_sigs.rs index b73d264..4195897 100644 --- a/consensus/src/transactions/sigs/ring_sigs.rs +++ b/consensus/src/transactions/sigs/ring_sigs.rs @@ -40,15 +40,15 @@ pub fn verify_inputs_signatures( panic!("How did we build a ring with no decoys?"); }; - if !sig.verify_ring_signature(tx_sig_hash, ring, key_image) { + if !sig.verify(tx_sig_hash, ring, key_image) { return Err(ConsensusError::TransactionSignatureInvalid( "Invalid ring signature", )); } Ok(()) })?; - }, - _ => panic!("tried to verify v1 tx with a non v1 ring"), + } + _ => panic!("tried to verify v1 tx with a non v1 ring"), } Ok(()) }