read: use macros for set/getting ThreadLocal's based on backend

This commit is contained in:
hinto.janai 2024-04-25 17:43:08 -04:00
parent bb99739c59
commit 9a507e7053
No known key found for this signature in database
GPG key ID: D47CE05FA175A499

View file

@ -18,7 +18,7 @@ use cfg_if::cfg_if;
use crossbeam::channel::Receiver;
use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, edwards::CompressedEdwardsY, Scalar};
use futures::{channel::oneshot, ready};
use monero_serai::transaction::Timelock;
use monero_serai::{transaction::Timelock, H};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use thread_local::ThreadLocal;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
@ -34,7 +34,12 @@ use crate::{
config::ReaderThreads,
constants::DATABASE_CORRUPT_MSG,
error::RuntimeError,
ops::block::{get_block_extended_header_from_height, get_block_info},
ops::{
block::{get_block_extended_header_from_height, get_block_info},
blockchain::{cumulative_generated_coins, top_block_height},
key_image::key_image_exists,
output::get_output,
},
service::types::{ResponseReceiver, ResponseResult, ResponseSender},
tables::{BlockHeights, BlockInfos, KeyImages, NumOutputs, Outputs, Tables},
types::{Amount, AmountIndex, BlockHeight, KeyImage, OutputFlags, PreRctOutputId},
@ -171,8 +176,10 @@ impl tower::Service<ReadRequest> for DatabaseReadHandle {
// INVARIANT:
// The below `DatabaseReader` function impl block relies on this behavior.
let env = Arc::clone(&self.env);
self.pool
.spawn(move || map_request(permit, env, request, response_sender));
self.pool.spawn(move || {
let _permit: OwnedSemaphorePermit = permit;
map_request(&env, request, response_sender);
}); // drop(permit/env);
InfallibleOneshotReceiver::from(receiver)
}
@ -182,7 +189,6 @@ impl tower::Service<ReadRequest> for DatabaseReadHandle {
// This function maps [`Request`]s to function calls
// executed by the rayon DB reader threadpool.
#[allow(clippy::needless_pass_by_value)] // TODO: fix me
/// Map [`Request`]'s to specific database handler functions.
///
/// This is the main entrance into all `Request` handler functions.
@ -191,9 +197,8 @@ impl tower::Service<ReadRequest> for DatabaseReadHandle {
/// 2. Handler function is called
/// 3. [`Response`] is sent
fn map_request(
_permit: OwnedSemaphorePermit, // Permit for this request, dropped at end of function
env: Arc<ConcreteEnv>, // Access to the database
request: ReadRequest, // The request we must fulfill
env: &ConcreteEnv, // Access to the database
request: ReadRequest, // The request we must fulfill
response_sender: ResponseSender, // The channel we must send the response back to
) {
use ReadRequest as R;
@ -201,14 +206,14 @@ fn map_request(
/* TODO: pre-request handling, run some code for each request? */
let response = match request {
R::BlockExtendedHeader(block) => block_extended_header(&env, block),
R::BlockHash(block) => block_hash(&env, block),
R::BlockExtendedHeaderInRange(range) => block_extended_header_in_range(&env, range),
R::ChainHeight => chain_height(&env),
R::GeneratedCoins => generated_coins(&env),
R::Outputs(map) => outputs(&env, map),
R::NumberOutputsWithAmount(vec) => number_outputs_with_amount(&env, vec),
R::CheckKIsNotSpent(set) => check_k_is_not_spent(&env, set),
R::BlockExtendedHeader(block) => block_extended_header(env, block),
R::BlockHash(block) => block_hash(env, block),
R::BlockExtendedHeaderInRange(range) => block_extended_header_in_range(env, range),
R::ChainHeight => chain_height(env),
R::GeneratedCoins => generated_coins(env),
R::Outputs(map) => outputs(env, map),
R::NumberOutputsWithAmount(vec) => number_outputs_with_amount(env, vec),
R::CheckKIsNotSpent(set) => check_k_is_not_spent(env, set),
};
if let Err(e) = response_sender.send(response) {
@ -220,12 +225,86 @@ fn map_request(
}
//---------------------------------------------------------------------------------------------------- Thread Local
/// TODO: explain this.
/// `heed`'s transactions and tables are not `Sync`, so we cannot use
/// them with rayon, however, we set a feature such that they are `Send`.
///
/// Thus, before using rayon, we put the tx/table inside a
/// `ThreadLocal` which gives access to those threads.
///
/// <https://github.com/Cuprate/cuprate/pull/113#discussion_r1576762346>
#[inline]
fn thread_local<T: Send>(env: &impl Env) -> ThreadLocal<T> {
ThreadLocal::with_capacity(env.config().reader_threads.as_threads().get())
}
/// Only `heed` requires the above [`thread_local()`] function,
/// as `redb`'s transactions and tables are `Send + Sync`.
///
/// Thus, wrapping them in `ThreadLocal` is wasteful.
///
/// This macro branches depending on what backend we're using
/// and either returns `ThreadLocal<T>` or the T directly.
///
/// An imaginary signature would look something like:
/// ```ignore
/// fn set_tx_ro_and_tables() -> if heed {
/// (ThreadLocal<TxRo>, ThreadLocal<Tables>) }
/// } else {
/// (TxRo, Tables)
/// };
/// ```
///
/// See [`get_tx_ro_and_tables`] for retrieving the output.
///
/// # Early return
/// Note that this early returns with `?` from whatever scope
/// it was called from if `tx_ro()` or `open_tables()` errors.
///
/// # Example
/// ```ignore
/// // Outside scope, still single threaded.
/// // Set the transaction and tables.
/// let (tx_ro, tables) = set_tx_ro_and_tables!(env, env_inner);
///
/// iter
/// .into_par_iter() // <- we've entered `rayon` scope
/// .map(|_| {
/// // Access the outside scope's `tx_ro` and `tables`.
/// // If needed, this will initialize some `ThreadLocal`'s.
/// let (tx_ro, tables) = get_tx_ro_and_tables!(env_inner, tx_ro, tables);
///
/// /* do rayon stuff */
/// });
/// ```
macro_rules! set_tx_ro_and_tables {
($env:ident, $env_inner:ident) => {{
cfg_if::cfg_if! {
if #[cfg(all(feature = "redb", not(feature = "heed")))] {
let tx_ro = $env_inner.tx_ro()?;
let tables = $env_inner.open_tables(tx_ro)?;
(tx_ro, tables)
} else {
(thread_local($env), thread_local($env))
}
}
}};
}
/// Access the values set with [`set_tx_ro_and_tables`].
macro_rules! get_tx_ro_and_tables {
($env_inner:ident, $tx_ro:ident, $tables:ident) => {{
cfg_if::cfg_if! {
if #[cfg(all(feature = "redb", not(feature = "heed")))] {
($tx_ro, $tables)
} else {
let tx_ro = $tx_ro.get_or_try(|| $env_inner.tx_ro())?;
let tables = $tables.get_or_try(|| $env_inner.open_tables(tx_ro))?;
(tx_ro, tables)
}
}
}};
}
//---------------------------------------------------------------------------------------------------- Handler functions
// These are the actual functions that do stuff according to the incoming [`Request`].
//
@ -277,16 +356,12 @@ fn block_extended_header_in_range(
range: std::ops::Range<BlockHeight>,
) -> ResponseResult {
let env_inner = env.env_inner();
let (tx_ro, tables) = set_tx_ro_and_tables!(env, env_inner);
let tx_ro = thread_local(env);
let tables = thread_local(env);
// This iterator will early return as `Err` if there's even 1 error.
let vec = range
.into_par_iter()
.map(|block_height| {
let tx_ro = tx_ro.get_or_try(|| env_inner.tx_ro())?;
let tables = tables.get_or_try(|| env_inner.open_tables(tx_ro))?;
let (tx_ro, tables) = get_tx_ro_and_tables!(env_inner, tx_ro, tables);
get_block_extended_header_from_height(&block_height, tables)
})
.collect::<Result<Vec<ExtendedBlockHeader>, RuntimeError>>()?;
@ -302,8 +377,8 @@ fn chain_height(env: &ConcreteEnv) -> ResponseResult {
let table_block_heights = env_inner.open_db_ro::<BlockHeights>(&tx_ro)?;
let table_block_infos = env_inner.open_db_ro::<BlockInfos>(&tx_ro)?;
let top_height = crate::ops::blockchain::top_block_height(&table_block_heights)?;
let block_hash = crate::ops::block::get_block_info(&top_height, &table_block_infos)?.block_hash;
let top_height = top_block_height(&table_block_heights)?;
let block_hash = get_block_info(&top_height, &table_block_infos)?.block_hash;
Ok(Response::ChainHeight(top_height, block_hash))
}
@ -316,38 +391,35 @@ fn generated_coins(env: &ConcreteEnv) -> ResponseResult {
let table_block_heights = env_inner.open_db_ro::<BlockHeights>(&tx_ro)?;
let table_block_infos = env_inner.open_db_ro::<BlockInfos>(&tx_ro)?;
let top_height = crate::ops::blockchain::top_block_height(&table_block_heights)?;
let top_height = top_block_height(&table_block_heights)?;
Ok(Response::GeneratedCoins(
crate::ops::blockchain::cumulative_generated_coins(&top_height, &table_block_infos)?,
))
Ok(Response::GeneratedCoins(cumulative_generated_coins(
&top_height,
&table_block_infos,
)?))
}
/// [`ReadRequest::Outputs`].
#[inline]
fn outputs(env: &ConcreteEnv, map: HashMap<Amount, HashSet<AmountIndex>>) -> ResponseResult {
let env_inner = env.env_inner();
let tx_ro = thread_local(env);
let table_outputs = thread_local(env);
let (tx_ro, tables) = set_tx_ro_and_tables!(env, env_inner);
// -> Result<(AmountIndex, OutputOnChain), RuntimeError>
let inner_map = |amount, amount_index| {
let tx_ro = tx_ro.get_or_try(|| env_inner.tx_ro())?;
let table_outputs = table_outputs.get_or_try(|| env_inner.open_db_ro::<Outputs>(tx_ro))?;
let (tx_ro, tables) = get_tx_ro_and_tables!(env_inner, tx_ro, tables);
let pre_rct_output_id = PreRctOutputId {
amount,
amount_index,
};
let output = crate::ops::output::get_output(&pre_rct_output_id, table_outputs)?;
let output = get_output(&pre_rct_output_id, tables.outputs())?;
// Map `Output` -> `OutputOnChain`
// FIXME: This should be in a function somewhere.
//--- Map `Output` -> `OutputOnChain`
// FIXME: implement lookup table for common values:
// <https://github.com/monero-project/monero/blob/c8214782fb2a769c57382a999eaf099691c836e7/src/ringct/rctOps.cpp#L322>
let commitment = ED25519_BASEPOINT_POINT + monero_serai::H() * Scalar::from(amount);
let commitment = ED25519_BASEPOINT_POINT + H() * Scalar::from(amount);
let time_lock = if output
.output_flags
@ -393,18 +465,14 @@ fn outputs(env: &ConcreteEnv, map: HashMap<Amount, HashSet<AmountIndex>>) -> Res
#[inline]
fn number_outputs_with_amount(env: &ConcreteEnv, amounts: Vec<Amount>) -> ResponseResult {
let env_inner = env.env_inner();
let tx_ro = thread_local(env);
let table_num_outputs = thread_local(env);
let (tx_ro, tables) = set_tx_ro_and_tables!(env, env_inner);
let map = amounts
.into_par_iter()
.map(|amount| {
let tx_ro = tx_ro.get_or_try(|| env_inner.tx_ro())?;
let table_num_outputs =
table_num_outputs.get_or_try(|| env_inner.open_db_ro::<NumOutputs>(tx_ro))?;
let (tx_ro, tables) = get_tx_ro_and_tables!(env_inner, tx_ro, tables);
match table_num_outputs.get(&amount) {
match tables.num_outputs().get(&amount) {
// INVARIANT: #[cfg] @ lib.rs asserts `usize == u64`
#[allow(clippy::cast_possible_truncation)]
Ok(count) => Ok((amount, count as usize)),
@ -421,15 +489,11 @@ fn number_outputs_with_amount(env: &ConcreteEnv, amounts: Vec<Amount>) -> Respon
#[inline]
fn check_k_is_not_spent(env: &ConcreteEnv, key_images: HashSet<KeyImage>) -> ResponseResult {
let env_inner = env.env_inner();
let tx_ro = thread_local(env);
let table_key_images = thread_local(env);
let (tx_ro, tables) = set_tx_ro_and_tables!(env, env_inner);
let key_image_exists = |key_image| {
let tx_ro = tx_ro.get_or_try(|| env_inner.tx_ro())?;
let table_key_images =
table_key_images.get_or_try(|| env_inner.open_db_ro::<KeyImages>(tx_ro))?;
crate::ops::key_image::key_image_exists(&key_image, table_key_images)
let (tx_ro, tables) = get_tx_ro_and_tables!(env_inner, tx_ro, tables);
key_image_exists(&key_image, tables.key_images())
};
match key_images