Use a connection task for RPC connections.

This commit is contained in:
Boog900 2023-11-18 14:00:33 +00:00
parent 10b7400b17
commit 343e979e82
No known key found for this signature in database
GPG key ID: 5401367FB7302004
13 changed files with 677 additions and 481 deletions

View file

@ -36,14 +36,15 @@ crypto-bigint = "0.5"
curve25519-dalek = "4" curve25519-dalek = "4"
randomx-rs = "1" randomx-rs = "1"
monero-serai = {git="https://github.com/Cuprate/serai.git", rev = "39eafae"} monero-serai = {git="https://github.com/serai-dex/serai.git", rev = "c328e5e"}
multiexp = {git="https://github.com/Cuprate/serai.git", rev = "39eafae"} multiexp = {git="https://github.com/serai-dex/serai.git", rev = "c328e5e"}
dalek-ff-group = {git="https://github.com/Cuprate/serai.git", rev = "39eafae"} dalek-ff-group = {git="https://github.com/serai-dex/serai.git", rev = "c328e5e"}
cuprate-common = {path = "../common"} cuprate-common = {path = "../common"}
cryptonight-cuprate = {path = "../cryptonight"} cryptonight-cuprate = {path = "../cryptonight"}
rayon = "1" rayon = "1"
thread_local = "1.1.7"
tokio = "1" tokio = "1"
tokio-util = "0.7" tokio-util = "0.7"

View file

@ -0,0 +1,47 @@
use std::cell::UnsafeCell;
use multiexp::BatchVerifier as InternalBatchVerifier;
use rayon::prelude::*;
use thread_local::ThreadLocal;
use crate::ConsensusError;
/// A multi threaded batch verifier.
pub struct MultiThreadedBatchVerifier {
internal: ThreadLocal<UnsafeCell<InternalBatchVerifier<usize, dalek_ff_group::EdwardsPoint>>>,
}
impl MultiThreadedBatchVerifier {
/// Create a new multithreaded batch verifier,
pub fn new(numb_threads: usize) -> MultiThreadedBatchVerifier {
MultiThreadedBatchVerifier {
internal: ThreadLocal::with_capacity(numb_threads),
}
}
pub fn queue_statement(
&self,
stmt: impl FnOnce(
&mut InternalBatchVerifier<usize, dalek_ff_group::EdwardsPoint>,
) -> Result<(), ConsensusError>,
) -> Result<(), ConsensusError> {
let verifier_cell = self
.internal
.get_or(|| UnsafeCell::new(InternalBatchVerifier::new(0)));
// SAFETY: This is safe for 2 reasons:
// 1. each thread gets a different batch verifier.
// 2. only this function `queue_statement` will get the inner batch verifier, it's private.
//
// TODO: it's probably ok to just use RefCell
stmt(unsafe { &mut *verifier_cell.get() })
}
pub fn verify(self) -> bool {
self.internal
.into_iter()
.map(UnsafeCell::into_inner)
.par_bridge()
.find_any(|batch_verifer| !batch_verifer.verify_vartime())
.is_none()
}
}

View file

@ -24,7 +24,7 @@ use monero_consensus::{
mod tx_pool; mod tx_pool;
const MAX_BLOCKS_IN_RANGE: u64 = 1000; const MAX_BLOCKS_IN_RANGE: u64 = 500;
const MAX_BLOCKS_HEADERS_IN_RANGE: u64 = 500; const MAX_BLOCKS_HEADERS_IN_RANGE: u64 = 500;
/// Calls for a batch of blocks, returning the response and the time it took. /// Calls for a batch of blocks, returning the response and the time it took.
@ -82,19 +82,19 @@ where
D::Future: Send + 'static, D::Future: Send + 'static,
{ {
let mut next_fut = tokio::spawn(call_batch( let mut next_fut = tokio::spawn(call_batch(
start_height..(start_height + (MAX_BLOCKS_IN_RANGE * 2)).min(chain_height), start_height..(start_height + (MAX_BLOCKS_IN_RANGE * 3)).min(chain_height),
database.clone(), database.clone(),
)); ));
for next_batch_start in (start_height..chain_height) for next_batch_start in (start_height..chain_height)
.step_by((MAX_BLOCKS_IN_RANGE * 2) as usize) .step_by((MAX_BLOCKS_IN_RANGE * 3) as usize)
.skip(1) .skip(1)
{ {
// Call the next batch while we handle this batch. // Call the next batch while we handle this batch.
let current_fut = std::mem::replace( let current_fut = std::mem::replace(
&mut next_fut, &mut next_fut,
tokio::spawn(call_batch( tokio::spawn(call_batch(
next_batch_start..(next_batch_start + (MAX_BLOCKS_IN_RANGE * 2)).min(chain_height), next_batch_start..(next_batch_start + (MAX_BLOCKS_IN_RANGE * 3)).min(chain_height),
database.clone(), database.clone(),
)), )),
); );
@ -105,7 +105,7 @@ where
tracing::info!( tracing::info!(
"Retrived batch: {:?}, chain height: {}", "Retrived batch: {:?}, chain height: {}",
(next_batch_start - (MAX_BLOCKS_IN_RANGE * 2))..(next_batch_start), (next_batch_start - (MAX_BLOCKS_IN_RANGE * 3))..(next_batch_start),
chain_height chain_height
); );
@ -162,7 +162,7 @@ where
call_blocks(new_tx_chan, block_tx, start_height, chain_height, database).await call_blocks(new_tx_chan, block_tx, start_height, chain_height, database).await
}); });
let (mut prepared_blocks_tx, mut prepared_blocks_rx) = mpsc::channel(2); let (mut prepared_blocks_tx, mut prepared_blocks_rx) = mpsc::channel(3);
let mut cloned_block_verifier = block_verifier.clone(); let mut cloned_block_verifier = block_verifier.clone();
tokio::spawn(async move { tokio::spawn(async move {
@ -170,14 +170,14 @@ where
while !next_blocks.is_empty() { while !next_blocks.is_empty() {
tracing::info!( tracing::info!(
"preparing next batch, number of blocks: {}", "preparing next batch, number of blocks: {}",
next_blocks.len().min(100) next_blocks.len().min(150)
); );
let res = cloned_block_verifier let res = cloned_block_verifier
.ready() .ready()
.await? .await?
.call(VerifyBlockRequest::BatchSetup( .call(VerifyBlockRequest::BatchSetup(
next_blocks.drain(0..next_blocks.len().min(100)).collect(), next_blocks.drain(0..next_blocks.len().min(150)).collect(),
)) ))
.await; .await;
@ -242,7 +242,7 @@ async fn main() {
let urls = vec![ let urls = vec![
"http://xmr-node.cakewallet.com:18081".to_string(), "http://xmr-node.cakewallet.com:18081".to_string(),
"http://node.sethforprivacy.com".to_string(), "https://node.sethforprivacy.com".to_string(),
"http://nodex.monerujo.io:18081".to_string(), "http://nodex.monerujo.io:18081".to_string(),
"http://nodes.hashvault.pro:18081".to_string(), "http://nodes.hashvault.pro:18081".to_string(),
"http://node.c3pool.com:18081".to_string(), "http://node.c3pool.com:18081".to_string(),
@ -254,7 +254,7 @@ async fn main() {
"http://145.239.97.211:18089".to_string(), "http://145.239.97.211:18089".to_string(),
// //
"http://xmr-node.cakewallet.com:18081".to_string(), "http://xmr-node.cakewallet.com:18081".to_string(),
"http://node.sethforprivacy.com".to_string(), "https://node.sethforprivacy.com".to_string(),
"http://nodex.monerujo.io:18081".to_string(), "http://nodex.monerujo.io:18081".to_string(),
"http://nodes.hashvault.pro:18081".to_string(), "http://nodes.hashvault.pro:18081".to_string(),
"http://node.c3pool.com:18081".to_string(), "http://node.c3pool.com:18081".to_string(),

View file

@ -175,32 +175,17 @@ fn prepare_block(block: Block) -> Result<PrePreparedBlock, ConsensusError> {
} }
}; };
let block_hashing_blob = block.serialize_hashable(); tracing::debug!("preparing block: {}", height);
let (pow_hash, mut prepared_block) = rayon::join(
|| { Ok(PrePreparedBlock {
// we calculate the POW hash on a different task because this takes a massive amount of time.
calculate_pow_hash(&block_hashing_blob, height, &hf_version)
},
|| {
PrePreparedBlock {
block_blob: block.serialize(), block_blob: block.serialize(),
block_hash: block.hash(), block_hash: block.hash(),
// set a dummy pow hash for now. We use u8::MAX so if something odd happens and this value isn't changed it will fail for pow_hash: calculate_pow_hash(&block.serialize_hashable(), height, &hf_version)?,
// difficulties > 1.
pow_hash: [u8::MAX; 32],
miner_tx_weight: block.miner_tx.weight(), miner_tx_weight: block.miner_tx.weight(),
block, block,
hf_vote, hf_vote,
hf_version, hf_version,
} })
},
);
prepared_block.pow_hash = pow_hash?;
tracing::debug!("prepared block: {}", height);
Ok(prepared_block)
} }
async fn verify_prepared_main_chain_block<C, TxV, TxP>( async fn verify_prepared_main_chain_block<C, TxV, TxP>(
@ -231,13 +216,19 @@ where
tracing::debug!("got blockchain context: {:?}", context); tracing::debug!("got blockchain context: {:?}", context);
let txs = if !block.block.txs.is_empty() {
let TxPoolResponse::Transactions(txs) = tx_pool let TxPoolResponse::Transactions(txs) = tx_pool
.oneshot(TxPoolRequest::Transactions(block.block.txs.clone())) .oneshot(TxPoolRequest::Transactions(block.block.txs.clone()))
.await?; .await?;
txs
} else {
vec![]
};
let block_weight = block.miner_tx_weight + txs.iter().map(|tx| tx.tx_weight).sum::<usize>(); let block_weight = block.miner_tx_weight + txs.iter().map(|tx| tx.tx_weight).sum::<usize>();
let total_fees = txs.iter().map(|tx| tx.fee).sum::<u64>(); let total_fees = txs.iter().map(|tx| tx.fee).sum::<u64>();
if !txs.is_empty() {
tx_verifier_svc tx_verifier_svc
.oneshot(VerifyTxRequest::Block { .oneshot(VerifyTxRequest::Block {
txs: txs.clone(), txs: txs.clone(),
@ -247,6 +238,7 @@ where
re_org_token: context.re_org_token.clone(), re_org_token: context.re_org_token.clone(),
}) })
.await?; .await?;
}
let generated_coins = miner_tx::check_miner_tx( let generated_coins = miner_tx::check_miner_tx(
&block.block.miner_tx, &block.block.miner_tx,

View file

@ -4,6 +4,7 @@ use std::{
sync::Arc, sync::Arc,
}; };
mod batch_verifier;
pub mod block; pub mod block;
pub mod context; pub mod context;
pub mod genesis; pub mod genesis;

View file

@ -16,21 +16,13 @@ use futures::{
FutureExt, StreamExt, TryFutureExt, TryStreamExt, FutureExt, StreamExt, TryFutureExt, TryStreamExt,
}; };
use monero_serai::rpc::{HttpRpc, RpcConnection, RpcError}; use monero_serai::rpc::{HttpRpc, RpcConnection, RpcError};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tower::{balance::p2c::Balance, util::BoxService, ServiceExt}; use tower::{balance::p2c::Balance, util::BoxService, ServiceExt};
use tracing::{instrument, Instrument}; use tracing_subscriber::filter::FilterExt;
use cuprate_common::BlockID; use crate::{DatabaseRequest, DatabaseResponse};
use monero_wire::common::{BlockCompleteEntry, TransactionBlobs};
use crate::{
helper::rayon_spawn_async, DatabaseRequest, DatabaseResponse, ExtendedBlockHeader, HardFork,
OutputOnChain,
};
pub mod cache; pub mod cache;
mod connection;
mod discover; mod discover;
use cache::ScanningCache; use cache::ScanningCache;
@ -90,29 +82,35 @@ pub fn init_rpc_load_balancer(
Box<dyn Future<Output = Result<DatabaseResponse, tower::BoxError>> + Send + 'static>, Box<dyn Future<Output = Result<DatabaseResponse, tower::BoxError>> + Send + 'static>,
>, >,
> + Clone { > + Clone {
let (rpc_discoverer_tx, rpc_discoverer_rx) = futures::channel::mpsc::channel(30); let (rpc_discoverer_tx, rpc_discoverer_rx) = futures::channel::mpsc::channel(0);
let rpc_balance = Balance::new(rpc_discoverer_rx.map(Result::<_, tower::BoxError>::Ok)); let rpc_balance = Balance::new(Box::pin(
let timeout = tower::timeout::Timeout::new(rpc_balance, Duration::from_secs(300)); rpc_discoverer_rx.map(Result::<_, tower::BoxError>::Ok),
let rpc_buffer = tower::buffer::Buffer::new(BoxService::new(timeout), 50); ));
let rpc_buffer = tower::buffer::Buffer::new(rpc_balance, 500);
let rpcs = tower::retry::Retry::new(Attempts(10), rpc_buffer); let rpcs = tower::retry::Retry::new(Attempts(10), rpc_buffer);
let discover = discover::RPCDiscover { let discover = discover::RPCDiscover {
initial_list: addresses, initial_list: addresses,
ok_channel: rpc_discoverer_tx, ok_channel: rpc_discoverer_tx,
already_connected: Default::default(), already_connected: Default::default(),
cache, cache: cache.clone(),
}; };
tokio::spawn(discover.run()); tokio::spawn(discover.run());
RpcBalancer { rpcs, config } RpcBalancer {
rpcs,
config,
cache,
}
} }
#[derive(Clone)] #[derive(Clone)]
pub struct RpcBalancer<T: Clone> { pub struct RpcBalancer<T: Clone> {
rpcs: T, rpcs: T,
config: Arc<RwLock<RpcConfig>>, config: Arc<RwLock<RpcConfig>>,
cache: Arc<RwLock<ScanningCache>>,
} }
impl<T> tower::Service<DatabaseRequest> for RpcBalancer<T> impl<T> tower::Service<DatabaseRequest> for RpcBalancer<T>
@ -138,7 +136,27 @@ where
let config_mutex = self.config.clone(); let config_mutex = self.config.clone();
let config = config_mutex.read().unwrap(); let config = config_mutex.read().unwrap();
let cache = self.cache.clone();
match req { match req {
DatabaseRequest::CheckKIsNotSpent(kis) => async move {
Ok(DatabaseResponse::CheckKIsNotSpent(
cache.read().unwrap().are_kis_spent(kis),
))
}
.boxed(),
DatabaseRequest::GeneratedCoins => async move {
Ok(DatabaseResponse::GeneratedCoins(
cache.read().unwrap().already_generated_coins,
))
}
.boxed(),
DatabaseRequest::NumberOutputsWithAmount(amt) => async move {
Ok(DatabaseResponse::NumberOutputsWithAmount(
cache.read().unwrap().numb_outs(amt),
))
}
.boxed(),
DatabaseRequest::BlockBatchInRange(range) => { DatabaseRequest::BlockBatchInRange(range) => {
let resp_to_ret = |resp: DatabaseResponse| { let resp_to_ret = |resp: DatabaseResponse| {
let DatabaseResponse::BlockBatchInRange(pow_info) = resp else { let DatabaseResponse::BlockBatchInRange(pow_info) = resp else {
@ -265,373 +283,3 @@ where
} }
.boxed() .boxed()
} }
enum RpcState<R: RpcConnection> {
Locked,
Acquiring(OwnedMutexLockFuture<monero_serai::rpc::Rpc<R>>),
Acquired(OwnedMutexGuard<monero_serai::rpc::Rpc<R>>),
}
pub struct Rpc<R: RpcConnection> {
rpc: Arc<futures::lock::Mutex<monero_serai::rpc::Rpc<R>>>,
addr: String,
rpc_state: RpcState<R>,
cache: Arc<RwLock<ScanningCache>>,
error_slot: Arc<Mutex<Option<RpcError>>>,
}
impl Rpc<HttpRpc> {
pub fn new_http(addr: String, cache: Arc<RwLock<ScanningCache>>) -> Rpc<HttpRpc> {
let http_rpc = HttpRpc::new(addr.clone()).unwrap();
Rpc {
rpc: Arc::new(futures::lock::Mutex::new(http_rpc)),
addr,
rpc_state: RpcState::Locked,
cache,
error_slot: Arc::new(Mutex::new(None)),
}
}
}
impl<R: RpcConnection + Send + Sync + 'static> tower::Service<DatabaseRequest> for Rpc<R> {
type Response = DatabaseResponse;
type Error = tower::BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(rpc_error) = self.error_slot.lock().unwrap().clone() {
return Poll::Ready(Err(rpc_error.into()));
}
loop {
match &mut self.rpc_state {
RpcState::Locked => {
self.rpc_state = RpcState::Acquiring(Arc::clone(&self.rpc).lock_owned())
}
RpcState::Acquiring(rpc) => {
self.rpc_state = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx)))
}
RpcState::Acquired(_) => return Poll::Ready(Ok(())),
}
}
}
fn call(&mut self, req: DatabaseRequest) -> Self::Future {
let RpcState::Acquired(rpc) = std::mem::replace(&mut self.rpc_state, RpcState::Locked)
else {
panic!("poll_ready was not called first!");
};
let cache = self.cache.clone();
let span = tracing::info_span!("rpc_request", addr = &self.addr);
let err_slot = self.error_slot.clone();
match req {
DatabaseRequest::BlockHash(height) => async move {
let res: Result<_, RpcError> = rpc
.get_block_hash(height as usize)
.map_ok(DatabaseResponse::BlockHash)
.await;
if let Err(e) = &res {
*err_slot.lock().unwrap() = Some(e.clone());
}
res.map_err(Into::into)
}
.instrument(span)
.boxed(),
DatabaseRequest::ChainHeight => async move {
let height = cache.read().unwrap().height;
let hash = rpc
.get_block_hash((height - 1) as usize)
.await
.map_err(Into::<tower::BoxError>::into)?;
Ok(DatabaseResponse::ChainHeight(height, hash))
}
.instrument(span)
.boxed(),
DatabaseRequest::CheckKIsNotSpent(kis) => async move {
Ok(DatabaseResponse::CheckKIsNotSpent(
cache.read().unwrap().are_kis_spent(kis),
))
}
.instrument(span)
.boxed(),
DatabaseRequest::GeneratedCoins => async move {
Ok(DatabaseResponse::GeneratedCoins(
cache.read().unwrap().already_generated_coins,
))
}
.instrument(span)
.boxed(),
DatabaseRequest::BlockExtendedHeader(id) => {
get_block_info(id, rpc).instrument(span).boxed()
}
DatabaseRequest::BlockExtendedHeaderInRange(range) => {
get_block_info_in_range(range, rpc).instrument(span).boxed()
}
DatabaseRequest::BlockBatchInRange(range) => {
get_blocks_in_range(range, rpc).instrument(span).boxed()
}
DatabaseRequest::Outputs(out_ids) => {
get_outputs(out_ids, cache, rpc).instrument(span).boxed()
}
DatabaseRequest::NumberOutputsWithAmount(amt) => async move {
Ok(DatabaseResponse::NumberOutputsWithAmount(
cache.read().unwrap().numb_outs(amt) as usize,
))
}
.boxed(),
}
}
}
#[instrument(skip_all)]
async fn get_outputs<R: RpcConnection>(
out_ids: HashMap<u64, HashSet<u64>>,
cache: Arc<RwLock<ScanningCache>>,
rpc: OwnedMutexGuard<monero_serai::rpc::Rpc<R>>,
) -> Result<DatabaseResponse, tower::BoxError> {
tracing::info!(
"Getting outputs len: {}",
out_ids.values().map(|amt_map| amt_map.len()).sum::<usize>()
);
#[derive(Serialize, Copy, Clone)]
struct OutputID {
amount: u64,
index: u64,
}
#[derive(Serialize, Clone)]
struct Request {
outputs: Vec<OutputID>,
}
#[derive(Deserialize)]
struct OutputRes {
height: u64,
key: [u8; 32],
mask: [u8; 32],
txid: [u8; 32],
}
#[derive(Deserialize)]
struct Response {
outs: Vec<OutputRes>,
}
let outputs = rayon_spawn_async(|| {
out_ids
.into_par_iter()
.flat_map(|(amt, amt_map)| {
amt_map
.into_iter()
.map(|amt_idx| OutputID {
amount: amt,
index: amt_idx,
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.await;
let res = rpc
.bin_call(
"get_outs.bin",
monero_epee_bin_serde::to_bytes(&Request {
outputs: outputs.clone(),
})?,
)
.await?;
rayon_spawn_async(move || {
let outs: Response = monero_epee_bin_serde::from_bytes(&res)?;
tracing::info!("Got outputs len: {}", outs.outs.len());
let mut ret = HashMap::new();
let cache = cache.read().unwrap();
for (out, idx) in outs.outs.iter().zip(outputs) {
ret.entry(idx.amount).or_insert_with(HashMap::new).insert(
idx.index,
OutputOnChain {
height: out.height,
time_lock: cache.outputs_time_lock(&out.txid),
// we unwrap these as we are checking already approved rings so if these points are bad
// then a bad proof has been approved.
key: CompressedEdwardsY::from_slice(&out.key)
.unwrap()
.decompress()
.unwrap(),
mask: CompressedEdwardsY::from_slice(&out.mask)
.unwrap()
.decompress()
.unwrap(),
},
);
}
Ok(DatabaseResponse::Outputs(ret))
})
.await
}
async fn get_blocks_in_range<R: RpcConnection>(
range: Range<u64>,
rpc: OwnedMutexGuard<monero_serai::rpc::Rpc<R>>,
) -> Result<DatabaseResponse, tower::BoxError> {
tracing::info!("Getting blocks in range: {:?}", range);
#[derive(Serialize)]
pub struct Request {
pub heights: Vec<u64>,
}
#[derive(Deserialize)]
pub struct Response {
pub blocks: Vec<BlockCompleteEntry>,
}
let res = rpc
.bin_call(
"get_blocks_by_height.bin",
monero_epee_bin_serde::to_bytes(&Request {
heights: range.collect(),
})?,
)
.await?;
let blocks: Response = monero_epee_bin_serde::from_bytes(res)?;
Ok(DatabaseResponse::BlockBatchInRange(
rayon_spawn_async(|| {
blocks
.blocks
.into_par_iter()
.map(|b| {
Ok((
monero_serai::block::Block::read(&mut b.block.as_slice())?,
match b.txs {
TransactionBlobs::Pruned(_) => {
return Err("node sent pruned txs!".into())
}
TransactionBlobs::Normal(txs) => txs
.into_par_iter()
.map(|tx| {
monero_serai::transaction::Transaction::read(&mut tx.as_slice())
})
.collect::<Result<_, _>>()?,
TransactionBlobs::None => vec![],
},
))
})
.collect::<Result<_, tower::BoxError>>()
})
.await?,
))
}
#[derive(Deserialize, Debug)]
struct BlockInfo {
cumulative_difficulty: u64,
cumulative_difficulty_top64: u64,
timestamp: u64,
block_weight: usize,
long_term_weight: usize,
major_version: u8,
minor_version: u8,
}
async fn get_block_info_in_range<R: RpcConnection>(
range: Range<u64>,
rpc: OwnedMutexGuard<monero_serai::rpc::Rpc<R>>,
) -> Result<DatabaseResponse, tower::BoxError> {
#[derive(Deserialize, Debug)]
struct Response {
headers: Vec<BlockInfo>,
}
let res = rpc
.json_rpc_call::<Response>(
"get_block_headers_range",
Some(json!({"start_height": range.start, "end_height": range.end - 1})),
)
.await?;
tracing::info!("Retrieved block headers in range: {:?}", range);
Ok(DatabaseResponse::BlockExtendedHeaderInRange(
rayon_spawn_async(|| {
res.headers
.into_par_iter()
.map(|info| ExtendedBlockHeader {
version: HardFork::from_version(&info.major_version)
.expect("previously checked block has incorrect version"),
vote: HardFork::from_vote(&info.minor_version),
timestamp: info.timestamp,
cumulative_difficulty: u128_from_low_high(
info.cumulative_difficulty,
info.cumulative_difficulty_top64,
),
block_weight: info.block_weight,
long_term_weight: info.long_term_weight,
})
.collect()
})
.await,
))
}
async fn get_block_info<R: RpcConnection>(
id: BlockID,
rpc: OwnedMutexGuard<monero_serai::rpc::Rpc<R>>,
) -> Result<DatabaseResponse, tower::BoxError> {
tracing::info!("Retrieving block info with id: {}", id);
#[derive(Deserialize, Debug)]
struct Response {
block_header: BlockInfo,
}
let info = match id {
BlockID::Height(height) => {
let res = rpc
.json_rpc_call::<Response>(
"get_block_header_by_height",
Some(json!({"height": height})),
)
.await?;
res.block_header
}
BlockID::Hash(hash) => {
let res = rpc
.json_rpc_call::<Response>("get_block_header_by_hash", Some(json!({"hash": hash})))
.await?;
res.block_header
}
};
Ok(DatabaseResponse::BlockExtendedHeader(ExtendedBlockHeader {
version: HardFork::from_version(&info.major_version)
.expect("previously checked block has incorrect version"),
vote: HardFork::from_vote(&info.minor_version),
timestamp: info.timestamp,
cumulative_difficulty: u128_from_low_high(
info.cumulative_difficulty,
info.cumulative_difficulty_top64,
),
block_weight: info.block_weight,
long_term_weight: info.long_term_weight,
}))
}
fn u128_from_low_high(low: u64, high: u64) -> u128 {
let res: u128 = high as u128;
res << 64 | low as u128
}

View file

@ -22,7 +22,7 @@ use crate::transactions::TransactionVerificationData;
#[derive(Debug, Default, Clone, Encode, Decode)] #[derive(Debug, Default, Clone, Encode, Decode)]
pub struct ScanningCache { pub struct ScanningCache {
// network: u8, // network: u8,
numb_outs: HashMap<u64, u64>, numb_outs: HashMap<u64, usize>,
time_locked_out: HashMap<[u8; 32], u64>, time_locked_out: HashMap<[u8; 32], u64>,
kis: HashSet<[u8; 32]>, kis: HashSet<[u8; 32]>,
pub already_generated_coins: u64, pub already_generated_coins: u64,
@ -112,15 +112,15 @@ impl ScanningCache {
} }
} }
pub fn total_outs(&self) -> u64 { pub fn total_outs(&self) -> usize {
self.numb_outs.values().sum() self.numb_outs.values().sum()
} }
pub fn numb_outs(&self, amount: u64) -> u64 { pub fn numb_outs(&self, amount: u64) -> usize {
*self.numb_outs.get(&amount).unwrap_or(&0) *self.numb_outs.get(&amount).unwrap_or(&0)
} }
pub fn add_outs(&mut self, amount: u64, count: u64) { pub fn add_outs(&mut self, amount: u64, count: usize) {
if let Some(numb_outs) = self.numb_outs.get_mut(&amount) { if let Some(numb_outs) = self.numb_outs.get_mut(&amount) {
*numb_outs += count; *numb_outs += count;
} else { } else {

View file

@ -0,0 +1,441 @@
use std::{
collections::{HashMap, HashSet},
future::Future,
ops::Range,
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
};
use curve25519_dalek::edwards::CompressedEdwardsY;
use futures::{
channel::{mpsc, oneshot},
ready, FutureExt, SinkExt, StreamExt, TryStreamExt,
};
use monero_serai::{
block::Block,
rpc::{HttpRpc, Rpc, RpcError},
transaction::Transaction,
};
use monero_wire::common::{BlockCompleteEntry, TransactionBlobs};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::{
task::JoinHandle,
time::{timeout, Duration},
};
use tower::Service;
use tracing::{instrument, Instrument};
use cuprate_common::BlockID;
use super::ScanningCache;
use crate::{
helper::rayon_spawn_async, DatabaseRequest, DatabaseResponse, ExtendedBlockHeader, HardFork,
OutputOnChain,
};
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
pub struct RpcConnectionSvc {
pub(crate) address: String,
pub(crate) rpc_task_handle: JoinHandle<()>,
pub(crate) rpc_task_chan: mpsc::Sender<RpcReq>,
}
impl Service<DatabaseRequest> for RpcConnectionSvc {
type Response = DatabaseResponse;
type Error = tower::BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.rpc_task_handle.is_finished() {
return Poll::Ready(Err("RPC task has exited!".into()));
}
self.rpc_task_chan.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: DatabaseRequest) -> Self::Future {
let (tx, rx) = oneshot::channel();
let req = RpcReq {
req,
res_chan: tx,
span: tracing::info_span!(parent: &tracing::Span::current(), "rpc", addr = &self.address),
};
self.rpc_task_chan
.try_send(req)
.expect("poll_ready should be called first!");
async move {
rx.await
.expect("sender will not be dropped without response")
}
.boxed()
}
}
pub(crate) struct RpcReq {
req: DatabaseRequest,
res_chan: oneshot::Sender<Result<DatabaseResponse, tower::BoxError>>,
span: tracing::Span,
}
pub struct RpcConnection {
pub(crate) address: String,
pub(crate) con: Rpc<HttpRpc>,
pub(crate) cache: Arc<RwLock<ScanningCache>>,
pub(crate) req_chan: mpsc::Receiver<RpcReq>,
}
impl RpcConnection {
async fn get_block_hash(&self, height: u64) -> Result<[u8; 32], tower::BoxError> {
self.con
.get_block_hash(height.try_into().unwrap())
.await
.map_err(Into::into)
}
async fn get_extended_block_header(
&self,
id: BlockID,
) -> Result<ExtendedBlockHeader, tower::BoxError> {
tracing::info!("Retrieving block info with id: {}", id);
#[derive(Deserialize, Debug)]
struct Response {
block_header: BlockInfo,
}
let info = match id {
BlockID::Height(height) => {
let res = self
.con
.json_rpc_call::<Response>(
"get_block_header_by_height",
Some(json!({"height": height})),
)
.await?;
res.block_header
}
BlockID::Hash(hash) => {
let res = self
.con
.json_rpc_call::<Response>(
"get_block_header_by_hash",
Some(json!({"hash": hash})),
)
.await?;
res.block_header
}
};
Ok(ExtendedBlockHeader {
version: HardFork::from_version(&info.major_version)
.expect("previously checked block has incorrect version"),
vote: HardFork::from_vote(&info.minor_version),
timestamp: info.timestamp,
cumulative_difficulty: u128_from_low_high(
info.cumulative_difficulty,
info.cumulative_difficulty_top64,
),
block_weight: info.block_weight,
long_term_weight: info.long_term_weight,
})
}
async fn get_extended_block_header_in_range(
&self,
range: Range<u64>,
) -> Result<Vec<ExtendedBlockHeader>, tower::BoxError> {
#[derive(Deserialize, Debug)]
struct Response {
headers: Vec<BlockInfo>,
}
let res = self
.con
.json_rpc_call::<Response>(
"get_block_headers_range",
Some(json!({"start_height": range.start, "end_height": range.end - 1})),
)
.await?;
tracing::info!("Retrieved block headers in range: {:?}", range);
Ok(rayon_spawn_async(|| {
res.headers
.into_iter()
.map(|info| ExtendedBlockHeader {
version: HardFork::from_version(&info.major_version)
.expect("previously checked block has incorrect version"),
vote: HardFork::from_vote(&info.minor_version),
timestamp: info.timestamp,
cumulative_difficulty: u128_from_low_high(
info.cumulative_difficulty,
info.cumulative_difficulty_top64,
),
block_weight: info.block_weight,
long_term_weight: info.long_term_weight,
})
.collect()
})
.await)
}
async fn get_blocks_in_range(
&self,
range: Range<u64>,
) -> Result<Vec<(Block, Vec<Transaction>)>, tower::BoxError> {
tracing::info!("Getting blocks in range: {:?}", range);
#[derive(Serialize)]
pub struct Request {
pub heights: Vec<u64>,
}
#[derive(Deserialize)]
pub struct Response {
pub blocks: Vec<BlockCompleteEntry>,
}
let res = self
.con
.bin_call(
"get_blocks_by_height.bin",
monero_epee_bin_serde::to_bytes(&Request {
heights: range.collect(),
})?,
)
.await?;
let blocks: Response = monero_epee_bin_serde::from_bytes(res)?;
Ok(rayon_spawn_async(|| {
blocks
.blocks
.into_par_iter()
.map(|b| {
Ok((
Block::read(&mut b.block.as_slice())?,
match b.txs {
TransactionBlobs::Pruned(_) => {
return Err("node sent pruned txs!".into())
}
TransactionBlobs::Normal(txs) => txs
.into_par_iter()
.map(|tx| Transaction::read(&mut tx.as_slice()))
.collect::<Result<_, _>>()?,
TransactionBlobs::None => vec![],
},
))
})
.collect::<Result<_, tower::BoxError>>()
})
.await?)
}
async fn get_outputs(
&self,
out_ids: HashMap<u64, HashSet<u64>>,
) -> Result<HashMap<u64, HashMap<u64, OutputOnChain>>, tower::BoxError> {
tracing::info!(
"Getting outputs len: {}",
out_ids.values().map(|amt_map| amt_map.len()).sum::<usize>()
);
#[derive(Serialize, Copy, Clone)]
struct OutputID {
amount: u64,
index: u64,
}
#[derive(Serialize, Clone)]
struct Request {
outputs: Vec<OutputID>,
}
#[derive(Deserialize)]
struct OutputRes {
height: u64,
key: [u8; 32],
mask: [u8; 32],
txid: [u8; 32],
}
#[derive(Deserialize)]
struct Response {
outs: Vec<OutputRes>,
}
let outputs = out_ids
.into_iter()
.flat_map(|(amt, amt_map)| {
amt_map
.into_iter()
.map(|amt_idx| OutputID {
amount: amt,
index: amt_idx,
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let res = self
.con
.bin_call(
"get_outs.bin",
monero_epee_bin_serde::to_bytes(&Request {
outputs: outputs.clone(),
})?,
)
.await?;
let cache = self.cache.clone();
let span = tracing::Span::current();
rayon_spawn_async(move || {
let outs: Response = monero_epee_bin_serde::from_bytes(&res)?;
tracing::info!(parent: &span, "Got outputs len: {}", outs.outs.len());
let mut ret = HashMap::new();
let cache = cache.read().unwrap();
for (out, idx) in outs.outs.iter().zip(outputs) {
ret.entry(idx.amount).or_insert_with(HashMap::new).insert(
idx.index,
OutputOnChain {
height: out.height,
time_lock: cache.outputs_time_lock(&out.txid),
// we unwrap these as we are checking already approved rings so if these points are bad
// then a bad proof has been approved.
key: CompressedEdwardsY::from_slice(&out.key)
.unwrap()
.decompress()
.unwrap(),
mask: CompressedEdwardsY::from_slice(&out.mask)
.unwrap()
.decompress()
.unwrap(),
},
);
}
Ok(ret)
})
.await
}
async fn handle_request(
&mut self,
req: DatabaseRequest,
) -> Result<DatabaseResponse, tower::BoxError> {
match req {
DatabaseRequest::BlockHash(height) => {
timeout(DEFAULT_TIMEOUT, self.get_block_hash(height))
.await?
.map(DatabaseResponse::BlockHash)
}
DatabaseRequest::ChainHeight => {
let height = self.cache.read().unwrap().height;
let hash = timeout(DEFAULT_TIMEOUT, self.get_block_hash(height - 1)).await??;
Ok(DatabaseResponse::ChainHeight(height, hash))
}
DatabaseRequest::BlockExtendedHeader(id) => {
timeout(DEFAULT_TIMEOUT, self.get_extended_block_header(id))
.await?
.map(DatabaseResponse::BlockExtendedHeader)
}
DatabaseRequest::BlockExtendedHeaderInRange(range) => timeout(
DEFAULT_TIMEOUT,
self.get_extended_block_header_in_range(range),
)
.await?
.map(DatabaseResponse::BlockExtendedHeaderInRange),
DatabaseRequest::BlockBatchInRange(range) => {
timeout(DEFAULT_TIMEOUT, self.get_blocks_in_range(range))
.await?
.map(DatabaseResponse::BlockBatchInRange)
}
DatabaseRequest::Outputs(out_ids) => {
timeout(DEFAULT_TIMEOUT, self.get_outputs(out_ids))
.await?
.map(DatabaseResponse::Outputs)
}
DatabaseRequest::NumberOutputsWithAmount(_)
| DatabaseRequest::GeneratedCoins
| DatabaseRequest::CheckKIsNotSpent(_) => {
panic!("Request does not need RPC connection!")
}
}
}
#[instrument(level = "info", skip(self), fields(addr = self.address))]
pub async fn check_rpc_alive(&self) -> Result<(), tower::BoxError> {
tracing::debug!("Checking RPC connection");
let res = timeout(Duration::from_secs(10), self.con.get_height()).await;
let ok = matches!(res, Ok(Ok(_)));
if !ok {
tracing::warn!("RPC connection test failed");
return Err("RPC connection test failed".into());
}
tracing::info!("RPC connection Ok");
Ok(())
}
pub async fn run(mut self) {
while let Some(req) = self.req_chan.next().await {
let RpcReq {
req,
span,
res_chan,
} = req;
let res = self.handle_request(req).instrument(span.clone()).await;
let is_err = res.is_err();
if is_err {
tracing::warn!(parent: &span, "Error from RPC: {:?}", res)
}
let _ = res_chan.send(res);
if is_err && self.check_rpc_alive().await.is_err() {
break;
}
}
tracing::warn!("Shutting down RPC connection: {}", self.address);
self.req_chan.close();
while let Some(req) = self.req_chan.try_next().unwrap() {
let _ = req.res_chan.send(Err("RPC connection closed!".into()));
}
}
}
#[derive(Deserialize, Debug)]
struct BlockInfo {
cumulative_difficulty: u64,
cumulative_difficulty_top64: u64,
timestamp: u64,
block_weight: usize,
long_term_weight: usize,
major_version: u8,
minor_version: u8,
}
fn u128_from_low_high(low: u64, high: u64) -> u128 {
let res: u128 = high as u128;
res << 64 | low as u128
}

View file

@ -10,46 +10,51 @@ use futures::{
SinkExt, StreamExt, SinkExt, StreamExt,
}; };
use monero_serai::rpc::HttpRpc; use monero_serai::rpc::HttpRpc;
use tokio::time::timeout;
use tower::{discover::Change, load::PeakEwma}; use tower::{discover::Change, load::PeakEwma};
use tracing::instrument; use tracing::instrument;
use super::{cache::ScanningCache, Rpc}; use super::{
cache::ScanningCache,
connection::{RpcConnection, RpcConnectionSvc},
};
#[instrument(skip(cache))] #[instrument(skip(cache))]
async fn check_rpc(addr: String, cache: Arc<RwLock<ScanningCache>>) -> Option<Rpc<HttpRpc>> { async fn check_rpc(addr: String, cache: Arc<RwLock<ScanningCache>>) -> Option<RpcConnectionSvc> {
tracing::debug!("Sending request to node."); tracing::debug!("Sending request to node.");
let rpc = HttpRpc::new(addr.clone()).ok()?;
// make sure the RPC is actually reachable
timeout(Duration::from_secs(2), rpc.get_height())
.await
.ok()?
.ok()?;
tracing::debug!("Node sent ok response."); let con = HttpRpc::new(addr.clone()).await.ok()?;
let (tx, rx) = mpsc::channel(1);
let rpc = RpcConnection {
address: addr.clone(),
con,
cache,
req_chan: rx,
};
Some(Rpc::new_http(addr, cache)) rpc.check_rpc_alive().await.ok()?;
let handle = tokio::spawn(rpc.run());
Some(RpcConnectionSvc {
address: addr,
rpc_task_chan: tx,
rpc_task_handle: handle,
})
} }
pub(crate) struct RPCDiscover { pub(crate) struct RPCDiscover {
pub initial_list: Vec<String>, pub initial_list: Vec<String>,
pub ok_channel: mpsc::Sender<Change<usize, PeakEwma<Rpc<HttpRpc>>>>, pub ok_channel: mpsc::Sender<Change<usize, PeakEwma<RpcConnectionSvc>>>,
pub already_connected: HashSet<String>, pub already_connected: usize,
pub cache: Arc<RwLock<ScanningCache>>, pub cache: Arc<RwLock<ScanningCache>>,
} }
impl RPCDiscover { impl RPCDiscover {
async fn found_rpc(&mut self, rpc: Rpc<HttpRpc>) -> Result<(), SendError> { async fn found_rpc(&mut self, rpc: RpcConnectionSvc) -> Result<(), SendError> {
//if self.already_connected.contains(&rpc.addr) { self.already_connected += 1;
// return Ok(());
//}
tracing::info!("Connecting to node: {}", &rpc.addr);
let addr = rpc.addr.clone();
self.ok_channel self.ok_channel
.send(Change::Insert( .send(Change::Insert(
self.already_connected.len(), self.already_connected,
PeakEwma::new( PeakEwma::new(
rpc, rpc,
Duration::from_secs(5000), Duration::from_secs(5000),
@ -58,7 +63,6 @@ impl RPCDiscover {
), ),
)) ))
.await?; .await?;
self.already_connected.insert(addr);
Ok(()) Ok(())
} }

View file

@ -1,7 +1,7 @@
use std::ops::Deref;
use std::{ use std::{
collections::HashSet, collections::HashSet,
future::Future, future::Future,
ops::Deref,
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},

View file

@ -38,6 +38,7 @@ pub async fn batch_refresh_ring_member_info<D: Database + Clone + Send + Sync +
let (txs_needing_full_refresh, txs_needing_partial_refresh) = let (txs_needing_full_refresh, txs_needing_partial_refresh) =
ring_member_info_needing_refresh(txs_verification_data, hf); ring_member_info_needing_refresh(txs_verification_data, hf);
if !txs_needing_full_refresh.is_empty() {
batch_fill_ring_member_info( batch_fill_ring_member_info(
&txs_needing_full_refresh, &txs_needing_full_refresh,
hf, hf,
@ -45,6 +46,7 @@ pub async fn batch_refresh_ring_member_info<D: Database + Clone + Send + Sync +
database.clone(), database.clone(),
) )
.await?; .await?;
}
for tx_v_data in txs_needing_partial_refresh { for tx_v_data in txs_needing_partial_refresh {
let decoy_info = if hf != &HardFork::V1 { let decoy_info = if hf != &HardFork::V1 {
@ -88,10 +90,9 @@ fn ring_member_info_needing_refresh(
for tx in txs_verification_data { for tx in txs_verification_data {
let tx_ring_member_info = tx.rings_member_info.lock().unwrap(); let tx_ring_member_info = tx.rings_member_info.lock().unwrap();
// if we don't have ring members or if a re-org has happened or if we changed hf do a full refresh. // if we don't have ring members or if a re-org has happened do a full refresh.
// doing a full refresh each hf isn't needed now but its so rare it makes sense to just do a full one.
if let Some(tx_ring_member_info) = tx_ring_member_info.deref() { if let Some(tx_ring_member_info) = tx_ring_member_info.deref() {
if tx_ring_member_info.re_org_token.reorg_happened() || &tx_ring_member_info.hf != hf { if tx_ring_member_info.re_org_token.reorg_happened() {
txs_needing_full_refresh.push(tx.clone()); txs_needing_full_refresh.push(tx.clone());
continue; continue;
} }
@ -102,10 +103,17 @@ fn ring_member_info_needing_refresh(
// if any input does not have a 0 amount do a partial refresh, this is because some decoy info // if any input does not have a 0 amount do a partial refresh, this is because some decoy info
// data is based on the amount of non-ringCT outputs at a certain point. // data is based on the amount of non-ringCT outputs at a certain point.
if tx.tx.prefix.inputs.iter().any(|inp| match inp { // Or if a hf has happened as this will change the default minimum decoys.
if &tx_ring_member_info
.as_ref()
.expect("We just checked if this was None")
.hf
!= hf
|| tx.tx.prefix.inputs.iter().any(|inp| match inp {
Input::Gen(_) => false, Input::Gen(_) => false,
Input::ToKey { amount, .. } => amount.is_some(), Input::ToKey { amount, .. } => amount.is_some(),
}) { })
{
txs_needing_partial_refresh.push(tx.clone()); txs_needing_partial_refresh.push(tx.clone());
} }
} }

View file

@ -56,6 +56,29 @@ fn check_decoy_info(decoy_info: &DecoyInfo, hf: &HardFork) -> Result<(), Consens
Ok(()) Ok(())
} }
/// Checks that the key image is torsion free.
///
/// https://cuprate.github.io/monero-book/consensus_rules/transactions.html#torsion-free-key-image
pub(crate) fn check_key_images_torsion(input: &Input) -> Result<(), ConsensusError> {
match input {
Input::ToKey { key_image, .. } => {
// this happens in monero-serai but we may as well duplicate the check.
if !key_image.is_torsion_free() {
return Err(ConsensusError::TransactionHasInvalidInput(
"key image has torsion",
));
}
}
_ => {
return Err(ConsensusError::TransactionHasInvalidInput(
"Input not ToKey",
))
}
}
Ok(())
}
/// Checks the inputs key images for torsion and for duplicates in the transaction. /// Checks the inputs key images for torsion and for duplicates in the transaction.
/// ///
/// The `spent_kis` parameter is not meant to be a complete list of key images, just a list of related transactions /// The `spent_kis` parameter is not meant to be a complete list of key images, just a list of related transactions
@ -211,6 +234,37 @@ fn sum_inputs_v1(inputs: &[Input]) -> Result<u64, ConsensusError> {
Ok(sum) Ok(sum)
} }
/// Checks the inputs semantics are valid.
///
/// This does all the checks that don't need blockchain context.
///
/// Although technically hard-fork is contextual data we class it as not because
/// blocks keep their hf in the header.
pub fn check_inputs_semantics(
inputs: &[Input],
hf: &HardFork,
tx_version: &TxVersion,
) -> Result<u64, ConsensusError> {
if inputs.is_empty() {
return Err(ConsensusError::TransactionHasInvalidInput("no inputs"));
}
for input in inputs {
check_input_type(input)?;
check_input_has_decoys(input)?;
check_ring_members_unique(input, hf)?;
check_key_images_torsion(input)?;
}
check_inputs_sorted(inputs, hf)?;
match tx_version {
TxVersion::RingSignatures => sum_inputs_v1(inputs),
_ => panic!("TODO: RCT"),
}
}
/// Checks all input consensus rules. /// Checks all input consensus rules.
/// ///
/// TODO: list rules. /// TODO: list rules.

View file

@ -40,14 +40,14 @@ pub fn verify_inputs_signatures(
panic!("How did we build a ring with no decoys?"); panic!("How did we build a ring with no decoys?");
}; };
if !sig.verify_ring_signature(tx_sig_hash, ring, key_image) { if !sig.verify(tx_sig_hash, ring, key_image) {
return Err(ConsensusError::TransactionSignatureInvalid( return Err(ConsensusError::TransactionSignatureInvalid(
"Invalid ring signature", "Invalid ring signature",
)); ));
} }
Ok(()) Ok(())
})?; })?;
}, }
_ => panic!("tried to verify v1 tx with a non v1 ring"), _ => panic!("tried to verify v1 tx with a non v1 ring"),
} }
Ok(()) Ok(())