Merge branch 'main' into cuprated

This commit is contained in:
Boog900 2024-07-03 01:57:41 +01:00
commit 415425a576
No known key found for this signature in database
GPG key ID: 42AB1287CB0041C2
25 changed files with 1258 additions and 177 deletions

1
Cargo.lock generated
View file

@ -506,6 +506,7 @@ dependencies = [
"monero-serai", "monero-serai",
"paste", "paste",
"pretty_assertions", "pretty_assertions",
"proptest",
"rayon", "rayon",
"tempfile", "tempfile",
"thread_local", "thread_local",

View file

@ -10,7 +10,7 @@ default = ["txpool"]
txpool = ["dep:rand_distr", "dep:tokio-util", "dep:tokio"] txpool = ["dep:rand_distr", "dep:tokio-util", "dep:tokio"]
[dependencies] [dependencies]
tower = { workspace = true, features = ["discover", "util"] } tower = { workspace = true, features = ["util"] }
tracing = { workspace = true, features = ["std"] } tracing = { workspace = true, features = ["std"] }
futures = { workspace = true, features = ["std"] } futures = { workspace = true, features = ["std"] }

View file

@ -26,9 +26,9 @@
//! The diffuse service should have a request of [`DiffuseRequest`](traits::DiffuseRequest) and it's error //! The diffuse service should have a request of [`DiffuseRequest`](traits::DiffuseRequest) and it's error
//! should be [`tower::BoxError`]. //! should be [`tower::BoxError`].
//! //!
//! ## Outbound Peer Discoverer //! ## Outbound Peer TryStream
//! //!
//! The outbound peer [`Discover`](tower::discover::Discover) should provide a stream of randomly selected outbound //! The outbound peer [`TryStream`](futures::TryStream) should provide a stream of randomly selected outbound
//! peers, these peers will then be used to route stem txs to. //! peers, these peers will then be used to route stem txs to.
//! //!
//! The peers will not be returned anywhere, so it is recommended to wrap them in some sort of drop guard that returns //! The peers will not be returned anywhere, so it is recommended to wrap them in some sort of drop guard that returns
@ -37,10 +37,10 @@
//! ## Peer Service //! ## Peer Service
//! //!
//! This service represents a connection to an individual peer, this should be returned from the Outbound Peer //! This service represents a connection to an individual peer, this should be returned from the Outbound Peer
//! Discover. This should immediately send the transaction to the peer when requested, i.e. it should _not_ set //! TryStream. This should immediately send the transaction to the peer when requested, it should _not_ set
//! a timer. //! a timer.
//! //!
//! The diffuse service should have a request of [`StemRequest`](traits::StemRequest) and it's error //! The peer service should have a request of [`StemRequest`](traits::StemRequest) and its error
//! should be [`tower::BoxError`]. //! should be [`tower::BoxError`].
//! //!
//! ## Backing Pool //! ## Backing Pool

View file

@ -6,11 +6,10 @@
//! ### What The Router Does Not Do //! ### What The Router Does Not Do
//! //!
//! It does not handle anything to do with keeping transactions long term, i.e. embargo timers and handling //! It does not handle anything to do with keeping transactions long term, i.e. embargo timers and handling
//! loops in the stem. It is up to implementers to do this if they decide not top use [`DandelionPool`](crate::pool::DandelionPool) //! loops in the stem. It is up to implementers to do this if they decide not to use [`DandelionPool`](crate::pool::DandelionPool)
//! //!
use std::{ use std::{
collections::HashMap, collections::HashMap,
future::Future,
hash::Hash, hash::Hash,
marker::PhantomData, marker::PhantomData,
pin::Pin, pin::Pin,
@ -18,12 +17,9 @@ use std::{
time::Instant, time::Instant,
}; };
use futures::TryFutureExt; use futures::{future::BoxFuture, FutureExt, TryFutureExt, TryStream};
use rand::{distributions::Bernoulli, prelude::*, thread_rng}; use rand::{distributions::Bernoulli, prelude::*, thread_rng};
use tower::{ use tower::Service;
discover::{Change, Discover},
Service,
};
use crate::{ use crate::{
traits::{DiffuseRequest, StemRequest}, traits::{DiffuseRequest, StemRequest},
@ -39,14 +35,22 @@ pub enum DandelionRouterError {
/// The broadcast service returned an error. /// The broadcast service returned an error.
#[error("Broadcast service returned an err: {0}.")] #[error("Broadcast service returned an err: {0}.")]
BroadcastError(tower::BoxError), BroadcastError(tower::BoxError),
/// The outbound peer discoverer returned an error, this is critical. /// The outbound peer stream returned an error, this is critical.
#[error("The outbound peer discoverer returned an err: {0}.")] #[error("The outbound peer stream returned an err: {0}.")]
OutboundPeerDiscoverError(tower::BoxError), OutboundPeerStreamError(tower::BoxError),
/// The outbound peer discoverer returned [`None`]. /// The outbound peer discoverer returned [`None`].
#[error("The outbound peer discoverer exited.")] #[error("The outbound peer discoverer exited.")]
OutboundPeerDiscoverExited, OutboundPeerDiscoverExited,
} }
/// A response from an attempt to retrieve an outbound peer.
pub enum OutboundPeer<ID, T> {
/// A peer.
Peer(ID, T),
/// The peer store is exhausted and has no more to return.
Exhausted,
}
/// The dandelion++ state. /// The dandelion++ state.
#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum State { pub enum State {
@ -116,9 +120,11 @@ pub struct DandelionRouter<P, B, ID, S, Tx> {
impl<Tx, ID, P, B, S> DandelionRouter<P, B, ID, S, Tx> impl<Tx, ID, P, B, S> DandelionRouter<P, B, ID, S, Tx>
where where
ID: Hash + Eq + Clone, ID: Hash + Eq + Clone,
P: Discover<Key = ID, Service = S, Error = tower::BoxError>, P: TryStream<Ok = OutboundPeer<ID, S>, Error = tower::BoxError>,
B: Service<DiffuseRequest<Tx>, Error = tower::BoxError>, B: Service<DiffuseRequest<Tx>, Error = tower::BoxError>,
B::Future: Send + 'static,
S: Service<StemRequest<Tx>, Error = tower::BoxError>, S: Service<StemRequest<Tx>, Error = tower::BoxError>,
S::Future: Send + 'static,
{ {
/// Creates a new [`DandelionRouter`], with the provided services and config. /// Creates a new [`DandelionRouter`], with the provided services and config.
/// ///
@ -165,15 +171,16 @@ where
match ready!(self match ready!(self
.outbound_peer_discover .outbound_peer_discover
.as_mut() .as_mut()
.poll_discover(cx) .try_poll_next(cx)
.map_err(DandelionRouterError::OutboundPeerDiscoverError)) .map_err(DandelionRouterError::OutboundPeerStreamError))
.ok_or(DandelionRouterError::OutboundPeerDiscoverExited)?? .ok_or(DandelionRouterError::OutboundPeerDiscoverExited)??
{ {
Change::Insert(key, svc) => { OutboundPeer::Peer(key, svc) => {
self.stem_peers.insert(key, svc); self.stem_peers.insert(key, svc);
} }
Change::Remove(key) => { OutboundPeer::Exhausted => {
self.stem_peers.remove(&key); tracing::warn!("Failed to retrieve enough outbound peers for optimal dandelion++, privacy may be degraded.");
return Poll::Ready(Ok(()));
} }
} }
} }
@ -181,11 +188,24 @@ where
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn fluff_tx(&mut self, tx: Tx) -> B::Future { fn fluff_tx(&mut self, tx: Tx) -> BoxFuture<'static, Result<State, DandelionRouterError>> {
self.broadcast_svc.call(DiffuseRequest(tx)) self.broadcast_svc
.call(DiffuseRequest(tx))
.map_ok(|_| State::Fluff)
.map_err(DandelionRouterError::BroadcastError)
.boxed()
}
fn stem_tx(
&mut self,
tx: Tx,
from: ID,
) -> BoxFuture<'static, Result<State, DandelionRouterError>> {
if self.stem_peers.is_empty() {
tracing::debug!("Stem peers are empty, fluffing stem transaction.");
return self.fluff_tx(tx);
} }
fn stem_tx(&mut self, tx: Tx, from: ID) -> S::Future {
loop { loop {
let stem_route = self.stem_routes.entry(from.clone()).or_insert_with(|| { let stem_route = self.stem_routes.entry(from.clone()).or_insert_with(|| {
self.stem_peers self.stem_peers
@ -201,11 +221,20 @@ where
continue; continue;
}; };
return peer.call(StemRequest(tx)); return peer
.call(StemRequest(tx))
.map_ok(|_| State::Stem)
.map_err(DandelionRouterError::PeerError)
.boxed();
} }
} }
fn stem_local_tx(&mut self, tx: Tx) -> S::Future { fn stem_local_tx(&mut self, tx: Tx) -> BoxFuture<'static, Result<State, DandelionRouterError>> {
if self.stem_peers.is_empty() {
tracing::warn!("Stem peers are empty, no outbound connections to stem local tx to, fluffing instead, privacy will be degraded.");
return self.fluff_tx(tx);
}
loop { loop {
let stem_route = self.local_route.get_or_insert_with(|| { let stem_route = self.local_route.get_or_insert_with(|| {
self.stem_peers self.stem_peers
@ -221,7 +250,11 @@ where
continue; continue;
}; };
return peer.call(StemRequest(tx)); return peer
.call(StemRequest(tx))
.map_ok(|_| State::Stem)
.map_err(DandelionRouterError::PeerError)
.boxed();
} }
} }
} }
@ -238,7 +271,7 @@ S: The Peer service - handles routing messages to a single node.
impl<Tx, ID, P, B, S> Service<DandelionRouteReq<Tx, ID>> for DandelionRouter<P, B, ID, S, Tx> impl<Tx, ID, P, B, S> Service<DandelionRouteReq<Tx, ID>> for DandelionRouter<P, B, ID, S, Tx>
where where
ID: Hash + Eq + Clone, ID: Hash + Eq + Clone,
P: Discover<Key = ID, Service = S, Error = tower::BoxError>, P: TryStream<Ok = OutboundPeer<ID, S>, Error = tower::BoxError>,
B: Service<DiffuseRequest<Tx>, Error = tower::BoxError>, B: Service<DiffuseRequest<Tx>, Error = tower::BoxError>,
B::Future: Send + 'static, B::Future: Send + 'static,
S: Service<StemRequest<Tx>, Error = tower::BoxError>, S: Service<StemRequest<Tx>, Error = tower::BoxError>,
@ -246,8 +279,7 @@ where
{ {
type Response = State; type Response = State;
type Error = DandelionRouterError; type Error = DandelionRouterError;
type Future = type Future = BoxFuture<'static, Result<State, DandelionRouterError>>;
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.epoch_start.elapsed() > self.config.epoch_duration { if self.epoch_start.elapsed() > self.config.epoch_duration {
@ -309,39 +341,23 @@ where
tracing::trace!(parent: &self.span, "Handling route request."); tracing::trace!(parent: &self.span, "Handling route request.");
match req.state { match req.state {
TxState::Fluff => Box::pin( TxState::Fluff => self.fluff_tx(req.tx),
self.fluff_tx(req.tx)
.map_ok(|_| State::Fluff)
.map_err(DandelionRouterError::BroadcastError),
),
TxState::Stem { from } => match self.current_state { TxState::Stem { from } => match self.current_state {
State::Fluff => { State::Fluff => {
tracing::debug!(parent: &self.span, "Fluffing stem tx."); tracing::debug!(parent: &self.span, "Fluffing stem tx.");
Box::pin(
self.fluff_tx(req.tx) self.fluff_tx(req.tx)
.map_ok(|_| State::Fluff)
.map_err(DandelionRouterError::BroadcastError),
)
} }
State::Stem => { State::Stem => {
tracing::trace!(parent: &self.span, "Steming transaction"); tracing::trace!(parent: &self.span, "Steming transaction");
Box::pin(
self.stem_tx(req.tx, from) self.stem_tx(req.tx, from)
.map_ok(|_| State::Stem)
.map_err(DandelionRouterError::PeerError),
)
} }
}, },
TxState::Local => { TxState::Local => {
tracing::debug!(parent: &self.span, "Steming local tx."); tracing::debug!(parent: &self.span, "Steming local tx.");
Box::pin(
self.stem_local_tx(req.tx) self.stem_local_tx(req.tx)
.map_ok(|_| State::Stem)
.map_err(DandelionRouterError::PeerError),
)
} }
} }
} }

View file

@ -3,43 +3,47 @@ mod router;
use std::{collections::HashMap, future::Future, hash::Hash, sync::Arc}; use std::{collections::HashMap, future::Future, hash::Hash, sync::Arc};
use futures::TryStreamExt; use futures::{Stream, StreamExt, TryStreamExt};
use tokio::sync::mpsc::{self, UnboundedReceiver}; use tokio::sync::mpsc::{self, UnboundedReceiver};
use tower::{ use tower::{util::service_fn, Service, ServiceExt};
discover::{Discover, ServiceList},
util::service_fn,
Service, ServiceExt,
};
use crate::{ use crate::{
traits::{TxStoreRequest, TxStoreResponse}, traits::{TxStoreRequest, TxStoreResponse},
State, OutboundPeer, State,
}; };
pub fn mock_discover_svc<Req: Send + 'static>() -> ( pub fn mock_discover_svc<Req: Send + 'static>() -> (
impl Discover< impl Stream<
Key = usize, Item = Result<
Service = impl Service< OutboundPeer<
usize,
impl Service<
Req, Req,
Future = impl Future<Output = Result<(), tower::BoxError>> + Send + 'static, Future = impl Future<Output = Result<(), tower::BoxError>> + Send + 'static,
Error = tower::BoxError, Error = tower::BoxError,
> + Send > + Send
+ 'static, + 'static,
Error = tower::BoxError,
>, >,
UnboundedReceiver<(u64, Req)>, tower::BoxError,
>,
>,
UnboundedReceiver<(usize, Req)>,
) { ) {
let (tx, rx) = mpsc::unbounded_channel(); let (tx, rx) = mpsc::unbounded_channel();
let discover = ServiceList::new((0..).map(move |i| { let discover = futures::stream::iter(0_usize..1_000_000)
.map(move |i| {
let tx_2 = tx.clone(); let tx_2 = tx.clone();
Ok::<_, tower::BoxError>(OutboundPeer::Peer(
i,
service_fn(move |req| { service_fn(move |req| {
tx_2.send((i, req)).unwrap(); tx_2.send((i, req)).unwrap();
async move { Ok::<(), tower::BoxError>(()) } async move { Ok::<(), tower::BoxError>(()) }
}),
))
}) })
}))
.map_err(Into::into); .map_err(Into::into);
(discover, rx) (discover, rx)

View file

@ -121,7 +121,7 @@ pub enum ChainSvcResponse {
/// The response for [`ChainSvcRequest::FindFirstUnknown`]. /// The response for [`ChainSvcRequest::FindFirstUnknown`].
/// ///
/// Contains the index of the first unknown block and its expected height. /// Contains the index of the first unknown block and its expected height.
FindFirstUnknown(usize, u64), FindFirstUnknown(Option<(usize, u64)>),
/// The response for [`ChainSvcRequest::CumulativeDifficulty`]. /// The response for [`ChainSvcRequest::CumulativeDifficulty`].
/// ///
/// The current cumulative difficulty of our chain. /// The current cumulative difficulty of our chain.

View file

@ -0,0 +1,238 @@
use std::{mem, sync::Arc};
use rand::prelude::SliceRandom;
use rand::thread_rng;
use tokio::{task::JoinSet, time::timeout};
use tower::{Service, ServiceExt};
use tracing::{instrument, Instrument, Span};
use cuprate_p2p_core::{
client::InternalPeerID,
handles::ConnectionHandle,
services::{PeerSyncRequest, PeerSyncResponse},
NetworkZone, PeerRequest, PeerResponse, PeerSyncSvc,
};
use cuprate_wire::protocol::{ChainRequest, ChainResponse};
use crate::{
block_downloader::{
chain_tracker::{ChainEntry, ChainTracker},
BlockDownloadError, ChainSvcRequest, ChainSvcResponse,
},
client_pool::{ClientPool, ClientPoolDropGuard},
constants::{
BLOCK_DOWNLOADER_REQUEST_TIMEOUT, INITIAL_CHAIN_REQUESTS_TO_SEND,
MAX_BLOCKS_IDS_IN_CHAIN_ENTRY, MEDIUM_BAN,
},
};
/// Request a chain entry from a peer.
///
/// Because the block downloader only follows and downloads one chain we only have to send the block hash of
/// top block we have found and the genesis block, this is then called `short_history`.
pub async fn request_chain_entry_from_peer<N: NetworkZone>(
mut client: ClientPoolDropGuard<N>,
short_history: [[u8; 32]; 2],
) -> Result<(ClientPoolDropGuard<N>, ChainEntry<N>), BlockDownloadError> {
let PeerResponse::GetChain(chain_res) = client
.ready()
.await?
.call(PeerRequest::GetChain(ChainRequest {
block_ids: short_history.into(),
prune: true,
}))
.await?
else {
panic!("Connection task returned wrong response!");
};
if chain_res.m_block_ids.is_empty()
|| chain_res.m_block_ids.len() > MAX_BLOCKS_IDS_IN_CHAIN_ENTRY
{
client.info.handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
// We must have at least one overlapping block.
if !(chain_res.m_block_ids[0] == short_history[0]
|| chain_res.m_block_ids[0] == short_history[1])
{
client.info.handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeersResponseWasInvalid);
}
// If the genesis is the overlapping block then this peer does not have our top tracked block in
// its chain.
if chain_res.m_block_ids[0] == short_history[1] {
return Err(BlockDownloadError::PeerDidNotHaveRequestedData);
}
let entry = ChainEntry {
ids: (&chain_res.m_block_ids).into(),
peer: client.info.id,
handle: client.info.handle.clone(),
};
Ok((client, entry))
}
/// Initial chain search, this function pulls [`INITIAL_CHAIN_REQUESTS_TO_SEND`] peers from the [`ClientPool`]
/// and sends chain requests to all of them.
///
/// We then wait for their response and choose the peer who claims the highest cumulative difficulty.
#[instrument(level = "error", skip_all)]
pub async fn initial_chain_search<N: NetworkZone, S, C>(
client_pool: &Arc<ClientPool<N>>,
mut peer_sync_svc: S,
mut our_chain_svc: C,
) -> Result<ChainTracker<N>, BlockDownloadError>
where
S: PeerSyncSvc<N>,
C: Service<ChainSvcRequest, Response = ChainSvcResponse, Error = tower::BoxError>,
{
tracing::debug!("Getting our chain history");
// Get our history.
let ChainSvcResponse::CompactHistory {
block_ids,
cumulative_difficulty,
} = our_chain_svc
.ready()
.await?
.call(ChainSvcRequest::CompactHistory)
.await?
else {
panic!("chain service sent wrong response.");
};
let our_genesis = *block_ids.last().expect("Blockchain had no genesis block.");
tracing::debug!("Getting a list of peers with higher cumulative difficulty");
let PeerSyncResponse::PeersToSyncFrom(mut peers) = peer_sync_svc
.ready()
.await?
.call(PeerSyncRequest::PeersToSyncFrom {
block_needed: None,
current_cumulative_difficulty: cumulative_difficulty,
})
.await?
else {
panic!("peer sync service sent wrong response.");
};
tracing::debug!(
"{} peers claim they have a higher cumulative difficulty",
peers.len()
);
// Shuffle the list to remove any possibility of peers being able to prioritize getting picked.
peers.shuffle(&mut thread_rng());
let mut peers = client_pool.borrow_clients(&peers);
let mut futs = JoinSet::new();
let req = PeerRequest::GetChain(ChainRequest {
block_ids: block_ids.into(),
prune: false,
});
tracing::debug!("Sending requests for chain entries.");
// Send the requests.
while futs.len() < INITIAL_CHAIN_REQUESTS_TO_SEND {
let Some(mut next_peer) = peers.next() else {
break;
};
let cloned_req = req.clone();
futs.spawn(timeout(
BLOCK_DOWNLOADER_REQUEST_TIMEOUT,
async move {
let PeerResponse::GetChain(chain_res) =
next_peer.ready().await?.call(cloned_req).await?
else {
panic!("connection task returned wrong response!");
};
Ok::<_, tower::BoxError>((
chain_res,
next_peer.info.id,
next_peer.info.handle.clone(),
))
}
.instrument(Span::current()),
));
}
let mut res: Option<(ChainResponse, InternalPeerID<_>, ConnectionHandle)> = None;
// Wait for the peers responses.
while let Some(task_res) = futs.join_next().await {
let Ok(Ok(task_res)) = task_res.unwrap() else {
continue;
};
match &mut res {
Some(res) => {
// res has already been set, replace it if this peer claims higher cumulative difficulty
if res.0.cumulative_difficulty() < task_res.0.cumulative_difficulty() {
let _ = mem::replace(res, task_res);
}
}
None => {
// res has not been set, set it now;
res = Some(task_res);
}
}
}
let Some((chain_res, peer_id, peer_handle)) = res else {
return Err(BlockDownloadError::FailedToFindAChainToFollow);
};
let hashes: Vec<[u8; 32]> = (&chain_res.m_block_ids).into();
// drop this to deallocate the [`Bytes`].
drop(chain_res);
tracing::debug!("Highest chin entry contained {} block Ids", hashes.len());
// Find the first unknown block in the batch.
let ChainSvcResponse::FindFirstUnknown(first_unknown_ret) = our_chain_svc
.ready()
.await?
.call(ChainSvcRequest::FindFirstUnknown(hashes.clone()))
.await?
else {
panic!("chain service sent wrong response.");
};
// We know all the blocks already
// TODO: The peer could still be on a different chain, however the chain might just be too far split.
let Some((first_unknown, expected_height)) = first_unknown_ret else {
return Err(BlockDownloadError::FailedToFindAChainToFollow);
};
// The peer must send at least one block we already know.
if first_unknown == 0 {
peer_handle.ban_peer(MEDIUM_BAN);
return Err(BlockDownloadError::PeerSentNoOverlappingBlocks);
}
let previous_id = hashes[first_unknown - 1];
let first_entry = ChainEntry {
ids: hashes[first_unknown..].to_vec(),
peer: peer_id,
handle: peer_handle,
};
tracing::debug!(
"Creating chain tracker with {} new block Ids",
first_entry.ids.len()
);
let tracker = ChainTracker::new(first_entry, expected_height, our_genesis, previous_id);
Ok(tracker)
}

View file

@ -0,0 +1,325 @@
use std::{
fmt::{Debug, Formatter},
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use futures::{FutureExt, StreamExt};
use indexmap::IndexMap;
use monero_serai::{
block::{Block, BlockHeader},
ringct::{RctBase, RctPrunable, RctSignatures},
transaction::{Input, Timelock, Transaction, TransactionPrefix},
};
use proptest::{collection::vec, prelude::*};
use tokio::{sync::Semaphore, time::timeout};
use tower::{service_fn, Service};
use cuprate_fixed_bytes::ByteArrayVec;
use cuprate_p2p_core::{
client::{mock_client, Client, InternalPeerID, PeerInformation},
network_zones::ClearNet,
services::{PeerSyncRequest, PeerSyncResponse},
ConnectionDirection, NetworkZone, PeerRequest, PeerResponse,
};
use cuprate_pruning::PruningSeed;
use cuprate_wire::{
common::{BlockCompleteEntry, TransactionBlobs},
protocol::{ChainResponse, GetObjectsResponse},
};
use crate::{
block_downloader::{download_blocks, BlockDownloaderConfig, ChainSvcRequest, ChainSvcResponse},
client_pool::ClientPool,
};
proptest! {
#![proptest_config(ProptestConfig {
cases: 4,
max_shrink_iters: 10,
timeout: 60 * 1000,
.. ProptestConfig::default()
})]
#[test]
fn test_block_downloader(blockchain in dummy_blockchain_stragtegy(), peers in 1_usize..128) {
let blockchain = Arc::new(blockchain);
let tokio_pool = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
tokio_pool.block_on(async move {
timeout(Duration::from_secs(600), async move {
let client_pool = ClientPool::new();
let mut peer_ids = Vec::with_capacity(peers);
for _ in 0..peers {
let client = mock_block_downloader_client(blockchain.clone());
peer_ids.push(client.info.id);
client_pool.add_new_client(client);
}
let stream = download_blocks(
client_pool,
SyncStateSvc(peer_ids) ,
OurChainSvc {
genesis: *blockchain.blocks.first().unwrap().0
},
BlockDownloaderConfig {
buffer_size: 1_000,
in_progress_queue_size: 10_000,
check_client_pool_interval: Duration::from_secs(5),
target_batch_size: 5_000,
initial_batch_size: 1,
});
let blocks = stream.map(|blocks| blocks.blocks).concat().await;
assert_eq!(blocks.len() + 1, blockchain.blocks.len());
for (i, block) in blocks.into_iter().enumerate() {
assert_eq!(&block, blockchain.blocks.get_index(i + 1).unwrap().1);
}
}).await
}).unwrap();
}
}
prop_compose! {
/// Returns a strategy to generate a [`Transaction`] that is valid for the block downloader.
fn dummy_transaction_stragtegy(height: u64)
(
extra in vec(any::<u8>(), 0..1_000),
timelock in 1_usize..50_000_000,
)
-> Transaction {
Transaction {
prefix: TransactionPrefix {
version: 1,
timelock: Timelock::Block(timelock),
inputs: vec![Input::Gen(height)],
outputs: vec![],
extra,
},
signatures: vec![],
rct_signatures: RctSignatures {
base: RctBase {
fee: 0,
pseudo_outs: vec![],
encrypted_amounts: vec![],
commitments: vec![],
},
prunable: RctPrunable::Null
},
}
}
}
prop_compose! {
/// Returns a strategy to generate a [`Block`] that is valid for the block downloader.
fn dummy_block_stragtegy(
height: u64,
previous: [u8; 32],
)
(
miner_tx in dummy_transaction_stragtegy(height),
txs in vec(dummy_transaction_stragtegy(height), 0..25)
)
-> (Block, Vec<Transaction>) {
(
Block {
header: BlockHeader {
major_version: 0,
minor_version: 0,
timestamp: 0,
previous,
nonce: 0,
},
miner_tx,
txs: txs.iter().map(Transaction::hash).collect(),
},
txs
)
}
}
/// A mock blockchain.
struct MockBlockchain {
blocks: IndexMap<[u8; 32], (Block, Vec<Transaction>)>,
}
impl Debug for MockBlockchain {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str("MockBlockchain")
}
}
prop_compose! {
/// Returns a strategy to generate a [`MockBlockchain`].
fn dummy_blockchain_stragtegy()(
blocks in vec(dummy_block_stragtegy(0, [0; 32]), 1..50_000),
) -> MockBlockchain {
let mut blockchain = IndexMap::new();
for (height, mut block) in blocks.into_iter().enumerate() {
if let Some(last) = blockchain.last() {
block.0.header.previous = *last.0;
block.0.miner_tx.prefix.inputs = vec![Input::Gen(height as u64)]
}
blockchain.insert(block.0.hash(), block);
}
MockBlockchain {
blocks: blockchain
}
}
}
fn mock_block_downloader_client(blockchain: Arc<MockBlockchain>) -> Client<ClearNet> {
let semaphore = Arc::new(Semaphore::new(1));
let (connection_guard, connection_handle) = cuprate_p2p_core::handles::HandleBuilder::new()
.with_permit(semaphore.try_acquire_owned().unwrap())
.build();
let request_handler = service_fn(move |req: PeerRequest| {
let bc = blockchain.clone();
async move {
match req {
PeerRequest::GetChain(chain_req) => {
let mut i = 0;
while !bc.blocks.contains_key(&chain_req.block_ids[i]) {
i += 1;
if i == chain_req.block_ids.len() {
i -= 1;
break;
}
}
let block_index = bc.blocks.get_index_of(&chain_req.block_ids[i]).unwrap();
let block_ids = bc
.blocks
.get_range(block_index..)
.unwrap()
.iter()
.map(|(id, _)| *id)
.take(200)
.collect::<Vec<_>>();
Ok(PeerResponse::GetChain(ChainResponse {
start_height: 0,
total_height: 0,
cumulative_difficulty_low64: 1,
cumulative_difficulty_top64: 0,
m_block_ids: block_ids.into(),
m_block_weights: vec![],
first_block: Default::default(),
}))
}
PeerRequest::GetObjects(obj) => {
let mut res = Vec::with_capacity(obj.blocks.len());
for i in 0..obj.blocks.len() {
let block = bc.blocks.get(&obj.blocks[i]).unwrap();
let block_entry = BlockCompleteEntry {
pruned: false,
block: block.0.serialize().into(),
txs: TransactionBlobs::Normal(
block
.1
.iter()
.map(Transaction::serialize)
.map(Into::into)
.collect(),
),
block_weight: 0,
};
res.push(block_entry);
}
Ok(PeerResponse::GetObjects(GetObjectsResponse {
blocks: res,
missed_ids: ByteArrayVec::from([]),
current_blockchain_height: 0,
}))
}
_ => panic!(),
}
}
.boxed()
});
let info = PeerInformation {
id: InternalPeerID::Unknown(rand::random()),
handle: connection_handle,
direction: ConnectionDirection::InBound,
pruning_seed: PruningSeed::NotPruned,
};
mock_client(info, connection_guard, request_handler)
}
#[derive(Clone)]
struct SyncStateSvc<Z: NetworkZone>(Vec<InternalPeerID<Z::Addr>>);
impl Service<PeerSyncRequest<ClearNet>> for SyncStateSvc<ClearNet> {
type Response = PeerSyncResponse<ClearNet>;
type Error = tower::BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: PeerSyncRequest<ClearNet>) -> Self::Future {
let peers = self.0.clone();
async move { Ok(PeerSyncResponse::PeersToSyncFrom(peers)) }.boxed()
}
}
struct OurChainSvc {
genesis: [u8; 32],
}
impl Service<ChainSvcRequest> for OurChainSvc {
type Response = ChainSvcResponse;
type Error = tower::BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: ChainSvcRequest) -> Self::Future {
let genesis = self.genesis;
async move {
Ok(match req {
ChainSvcRequest::CompactHistory => ChainSvcResponse::CompactHistory {
block_ids: vec![genesis],
cumulative_difficulty: 1,
},
ChainSvcRequest::FindFirstUnknown(_) => {
ChainSvcResponse::FindFirstUnknown(Some((1, 1)))
}
ChainSvcRequest::CumulativeDifficulty => ChainSvcResponse::CumulativeDifficulty(1),
})
}
.boxed()
}
}

View file

@ -45,8 +45,8 @@ rayon = { workspace = true, optional = true }
cuprate-helper = { path = "../../helper", features = ["thread"] } cuprate-helper = { path = "../../helper", features = ["thread"] }
cuprate-test-utils = { path = "../../test-utils" } cuprate-test-utils = { path = "../../test-utils" }
bytemuck = { version = "1.14.3", features = ["must_cast", "derive", "min_const_generics", "extern_crate_alloc"] }
tempfile = { version = "3.10.0" } tempfile = { version = "3.10.0" }
pretty_assertions = { workspace = true } pretty_assertions = { workspace = true }
proptest = { workspace = true }
hex = { workspace = true } hex = { workspace = true }
hex-literal = { workspace = true } hex-literal = { workspace = true }

View file

@ -169,8 +169,9 @@ mod test {
env_inner.open_tables(&tx_ro).unwrap(); env_inner.open_tables(&tx_ro).unwrap();
} }
/// Tests that directory [`cuprate_database::ConcreteEnv`] /// Tests that direct usage of
/// usage does NOT create all tables. /// [`cuprate_database::ConcreteEnv`]
/// does NOT create all tables.
#[test] #[test]
#[should_panic(expected = "`Result::unwrap()` on an `Err` value: TableNotFound")] #[should_panic(expected = "`Result::unwrap()` on an `Err` value: TableNotFound")]
fn test_no_tables_are_created() { fn test_no_tables_are_created() {

View file

@ -33,8 +33,69 @@ pub fn init(config: Config) -> Result<(DatabaseReadHandle, DatabaseWriteHandle),
Ok((readers, writer)) Ok((readers, writer))
} }
//---------------------------------------------------------------------------------------------------- Tests //---------------------------------------------------------------------------------------------------- Compact history
#[cfg(test)] /// Given a position in the compact history, returns the height offset that should be in that position.
mod test { ///
// use super::*; /// The height offset is the difference between the top block's height and the block height that should be in that position.
#[inline]
pub(super) const fn compact_history_index_to_height_offset<const INITIAL_BLOCKS: u64>(
i: u64,
) -> u64 {
// If the position is below the initial blocks just return the position back
if i <= INITIAL_BLOCKS {
i
} else {
// Otherwise we go with power of 2 offsets, the same as monerod.
// So (INITIAL_BLOCKS + 2), (INITIAL_BLOCKS + 2 + 4), (INITIAL_BLOCKS + 2 + 4 + 8)
// ref: <https://github.com/monero-project/monero/blob/cc73fe71162d564ffda8e549b79a350bca53c454/src/cryptonote_core/blockchain.cpp#L727>
INITIAL_BLOCKS + (2 << (i - INITIAL_BLOCKS)) - 2
}
}
/// Returns if the genesis block was _NOT_ included when calculating the height offsets.
///
/// The genesis must always be included in the compact history.
#[inline]
pub(super) const fn compact_history_genesis_not_included<const INITIAL_BLOCKS: u64>(
top_block_height: u64,
) -> bool {
// If the top block height is less than the initial blocks then it will always be included.
// Otherwise, we use the fact that to reach the genesis block this statement must be true (for a
// single `i`):
//
// `top_block_height - INITIAL_BLOCKS - 2^i + 2 == 0`
// which then means:
// `top_block_height - INITIAL_BLOCKS + 2 == 2^i`
// So if `top_block_height - INITIAL_BLOCKS + 2` is a power of 2 then the genesis block is in
// the compact history already.
top_block_height > INITIAL_BLOCKS && !(top_block_height - INITIAL_BLOCKS + 2).is_power_of_two()
}
//---------------------------------------------------------------------------------------------------- Tests
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use super::*;
proptest! {
#[test]
fn compact_history(top_height in 0_u64..500_000_000) {
let mut heights = (0..)
.map(compact_history_index_to_height_offset::<11>)
.map_while(|i| top_height.checked_sub(i))
.collect::<Vec<_>>();
if compact_history_genesis_not_included::<11>(top_height) {
heights.push(0);
}
// Make sure the genesis and top block are always included.
assert_eq!(*heights.last().unwrap(), 0);
assert_eq!(*heights.first().unwrap(), top_height);
heights.windows(2).for_each(|window| assert_ne!(window[0], window[1]));
}
}
} }

View file

@ -14,7 +14,7 @@ use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore; use tokio_util::sync::PollSemaphore;
use cuprate_database::{ConcreteEnv, DatabaseRo, Env, EnvInner, RuntimeError}; use cuprate_database::{ConcreteEnv, DatabaseRo, Env, EnvInner, RuntimeError};
use cuprate_helper::asynch::InfallibleOneshotReceiver; use cuprate_helper::{asynch::InfallibleOneshotReceiver, map::combine_low_high_bits_to_u128};
use cuprate_types::{ use cuprate_types::{
blockchain::{BCReadRequest, BCResponse}, blockchain::{BCReadRequest, BCResponse},
ExtendedBlockHeader, OutputOnChain, ExtendedBlockHeader, OutputOnChain,
@ -23,17 +23,20 @@ use cuprate_types::{
use crate::{ use crate::{
config::ReaderThreads, config::ReaderThreads,
open_tables::OpenTables, open_tables::OpenTables,
ops::block::block_exists,
ops::{ ops::{
block::{get_block_extended_header_from_height, get_block_info}, block::{
block_exists, get_block_extended_header_from_height, get_block_height, get_block_info,
},
blockchain::{cumulative_generated_coins, top_block_height}, blockchain::{cumulative_generated_coins, top_block_height},
key_image::key_image_exists, key_image::key_image_exists,
output::id_to_output_on_chain, output::id_to_output_on_chain,
}, },
service::types::{ResponseReceiver, ResponseResult, ResponseSender}, service::{
free::{compact_history_genesis_not_included, compact_history_index_to_height_offset},
types::{ResponseReceiver, ResponseResult, ResponseSender},
},
tables::{BlockHeights, BlockInfos, Tables}, tables::{BlockHeights, BlockInfos, Tables},
types::BlockHash, types::{Amount, AmountIndex, BlockHash, BlockHeight, KeyImage, PreRctOutputId},
types::{Amount, AmountIndex, BlockHeight, KeyImage, PreRctOutputId},
}; };
//---------------------------------------------------------------------------------------------------- DatabaseReadHandle //---------------------------------------------------------------------------------------------------- DatabaseReadHandle
@ -204,13 +207,15 @@ fn map_request(
let response = match request { let response = match request {
R::BlockExtendedHeader(block) => block_extended_header(env, block), R::BlockExtendedHeader(block) => block_extended_header(env, block),
R::BlockHash(block) => block_hash(env, block), R::BlockHash(block) => block_hash(env, block),
R::FilterUnknownHashes(hashes) => filter_unknown_hahses(env, hashes), R::FilterUnknownHashes(hashes) => filter_unknown_hashes(env, hashes),
R::BlockExtendedHeaderInRange(range) => block_extended_header_in_range(env, range), R::BlockExtendedHeaderInRange(range) => block_extended_header_in_range(env, range),
R::ChainHeight => chain_height(env), R::ChainHeight => chain_height(env),
R::GeneratedCoins => generated_coins(env), R::GeneratedCoins => generated_coins(env),
R::Outputs(map) => outputs(env, map), R::Outputs(map) => outputs(env, map),
R::NumberOutputsWithAmount(vec) => number_outputs_with_amount(env, vec), R::NumberOutputsWithAmount(vec) => number_outputs_with_amount(env, vec),
R::KeyImagesSpent(set) => key_images_spent(env, set), R::KeyImagesSpent(set) => key_images_spent(env, set),
R::CompactChainHistory => compact_chain_history(env),
R::FindFirstUnknown(block_ids) => find_first_unknown(env, &block_ids),
}; };
if let Err(e) = response_sender.send(response) { if let Err(e) = response_sender.send(response) {
@ -320,7 +325,7 @@ fn block_hash(env: &ConcreteEnv, block_height: BlockHeight) -> ResponseResult {
/// [`BCReadRequest::FilterUnknownHashes`]. /// [`BCReadRequest::FilterUnknownHashes`].
#[inline] #[inline]
fn filter_unknown_hahses(env: &ConcreteEnv, mut hashes: HashSet<BlockHash>) -> ResponseResult { fn filter_unknown_hashes(env: &ConcreteEnv, mut hashes: HashSet<BlockHash>) -> ResponseResult {
// Single-threaded, no `ThreadLocal` required. // Single-threaded, no `ThreadLocal` required.
let env_inner = env.env_inner(); let env_inner = env.env_inner();
let tx_ro = env_inner.tx_ro()?; let tx_ro = env_inner.tx_ro()?;
@ -525,3 +530,81 @@ fn key_images_spent(env: &ConcreteEnv, key_images: HashSet<KeyImage>) -> Respons
Some(Err(e)) => Err(e), // A database error occurred. Some(Err(e)) => Err(e), // A database error occurred.
} }
} }
/// [`BCReadRequest::CompactChainHistory`]
fn compact_chain_history(env: &ConcreteEnv) -> ResponseResult {
let env_inner = env.env_inner();
let tx_ro = env_inner.tx_ro()?;
let table_block_heights = env_inner.open_db_ro::<BlockHeights>(&tx_ro)?;
let table_block_infos = env_inner.open_db_ro::<BlockInfos>(&tx_ro)?;
let top_block_height = top_block_height(&table_block_heights)?;
let top_block_info = get_block_info(&top_block_height, &table_block_infos)?;
let cumulative_difficulty = combine_low_high_bits_to_u128(
top_block_info.cumulative_difficulty_low,
top_block_info.cumulative_difficulty_high,
);
/// The amount of top block IDs in the compact chain.
const INITIAL_BLOCKS: u64 = 11;
// rayon is not used here because the amount of block IDs is expected to be small.
let mut block_ids = (0..)
.map(compact_history_index_to_height_offset::<INITIAL_BLOCKS>)
.map_while(|i| top_block_height.checked_sub(i))
.map(|height| Ok(get_block_info(&height, &table_block_infos)?.block_hash))
.collect::<Result<Vec<_>, RuntimeError>>()?;
if compact_history_genesis_not_included::<INITIAL_BLOCKS>(top_block_height) {
block_ids.push(get_block_info(&0, &table_block_infos)?.block_hash);
}
Ok(BCResponse::CompactChainHistory {
cumulative_difficulty,
block_ids,
})
}
/// [`BCReadRequest::FindFirstUnknown`]
///
/// # Invariant
/// `block_ids` must be sorted in chronological block order, or else
/// the returned result is unspecified and meaningless, as this function
/// performs a binary search.
fn find_first_unknown(env: &ConcreteEnv, block_ids: &[BlockHash]) -> ResponseResult {
let env_inner = env.env_inner();
let tx_ro = env_inner.tx_ro()?;
let table_block_heights = env_inner.open_db_ro::<BlockHeights>(&tx_ro)?;
let mut err = None;
// Do a binary search to find the first unknown block in the batch.
let idx =
block_ids.partition_point(
|block_id| match block_exists(block_id, &table_block_heights) {
Ok(exists) => exists,
Err(e) => {
err.get_or_insert(e);
// if this happens the search is scrapped, just return `false` back.
false
}
},
);
if let Some(e) = err {
return Err(e);
}
Ok(if idx == block_ids.len() {
BCResponse::FindFirstUnknown(None)
} else if idx == 0 {
BCResponse::FindFirstUnknown(Some((0, 0)))
} else {
let last_known_height = get_block_height(&block_ids[idx - 1], &table_block_heights)?;
BCResponse::FindFirstUnknown(Some((idx, last_known_height + 1)))
})
}

View file

@ -46,7 +46,7 @@ use bytemuck::{Pod, Zeroable};
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use cuprate_database::StorableVec; use cuprate_database::{Key, StorableVec};
//---------------------------------------------------------------------------------------------------- Aliases //---------------------------------------------------------------------------------------------------- Aliases
// These type aliases exist as many Monero-related types are the exact same. // These type aliases exist as many Monero-related types are the exact same.
@ -143,6 +143,8 @@ pub struct PreRctOutputId {
pub amount_index: AmountIndex, pub amount_index: AmountIndex,
} }
impl Key for PreRctOutputId {}
//---------------------------------------------------------------------------------------------------- BlockInfoV3 //---------------------------------------------------------------------------------------------------- BlockInfoV3
/// Block information. /// Block information.
/// ///

View file

@ -7,26 +7,23 @@ use std::{
sync::{RwLock, RwLockReadGuard}, sync::{RwLock, RwLockReadGuard},
}; };
use heed::{EnvFlags, EnvOpenOptions}; use heed::{DatabaseFlags, EnvFlags, EnvOpenOptions};
use crate::{ use crate::{
backend::heed::{ backend::heed::{
database::{HeedTableRo, HeedTableRw}, database::{HeedTableRo, HeedTableRw},
storable::StorableHeed,
types::HeedDb, types::HeedDb,
}, },
config::{Config, SyncMode}, config::{Config, SyncMode},
database::{DatabaseIter, DatabaseRo, DatabaseRw}, database::{DatabaseIter, DatabaseRo, DatabaseRw},
env::{Env, EnvInner}, env::{Env, EnvInner},
error::{InitError, RuntimeError}, error::{InitError, RuntimeError},
key::{Key, KeyCompare},
resize::ResizeAlgorithm, resize::ResizeAlgorithm,
table::Table, table::Table,
}; };
//---------------------------------------------------------------------------------------------------- Consts
/// Panic message when there's a table missing.
const PANIC_MSG_MISSING_TABLE: &str =
"cuprate_database::Env should uphold the invariant that all tables are already created";
//---------------------------------------------------------------------------------------------------- ConcreteEnv //---------------------------------------------------------------------------------------------------- ConcreteEnv
/// A strongly typed, concrete database environment, backed by `heed`. /// A strongly typed, concrete database environment, backed by `heed`.
pub struct ConcreteEnv { pub struct ConcreteEnv {
@ -268,6 +265,10 @@ where
tx_ro: &heed::RoTxn<'env>, tx_ro: &heed::RoTxn<'env>,
) -> Result<impl DatabaseRo<T> + DatabaseIter<T>, RuntimeError> { ) -> Result<impl DatabaseRo<T> + DatabaseIter<T>, RuntimeError> {
// Open up a read-only database using our table's const metadata. // Open up a read-only database using our table's const metadata.
//
// INVARIANT: LMDB caches the ordering / comparison function from [`EnvInner::create_db`],
// and we're relying on that since we aren't setting that here.
// <https://github.com/Cuprate/cuprate/pull/198#discussion_r1659422277>
Ok(HeedTableRo { Ok(HeedTableRo {
db: self db: self
.open_database(tx_ro, Some(T::NAME))? .open_database(tx_ro, Some(T::NAME))?
@ -282,6 +283,10 @@ where
tx_rw: &RefCell<heed::RwTxn<'env>>, tx_rw: &RefCell<heed::RwTxn<'env>>,
) -> Result<impl DatabaseRw<T>, RuntimeError> { ) -> Result<impl DatabaseRw<T>, RuntimeError> {
// Open up a read/write database using our table's const metadata. // Open up a read/write database using our table's const metadata.
//
// INVARIANT: LMDB caches the ordering / comparison function from [`EnvInner::create_db`],
// and we're relying on that since we aren't setting that here.
// <https://github.com/Cuprate/cuprate/pull/198#discussion_r1659422277>
Ok(HeedTableRw { Ok(HeedTableRw {
db: self.create_database(&mut tx_rw.borrow_mut(), Some(T::NAME))?, db: self.create_database(&mut tx_rw.borrow_mut(), Some(T::NAME))?,
tx_rw, tx_rw,
@ -289,8 +294,33 @@ where
} }
fn create_db<T: Table>(&self, tx_rw: &RefCell<heed::RwTxn<'env>>) -> Result<(), RuntimeError> { fn create_db<T: Table>(&self, tx_rw: &RefCell<heed::RwTxn<'env>>) -> Result<(), RuntimeError> {
// INVARIANT: `heed` creates tables with `open_database` if they don't exist. // Create a database using our:
self.open_db_rw::<T>(tx_rw)?; // - [`Table`]'s const metadata.
// - (potentially) our [`Key`] comparison function
let mut tx_rw = tx_rw.borrow_mut();
let mut db = self.database_options();
db.name(T::NAME);
// Set the key comparison behavior.
match <T::Key>::KEY_COMPARE {
// Use LMDB's default comparison function.
KeyCompare::Default => {
db.create(&mut tx_rw)?;
}
// Instead of setting a custom [`heed::Comparator`],
// use this LMDB flag; it is ~10% faster.
KeyCompare::Number => {
db.flags(DatabaseFlags::INTEGER_KEY).create(&mut tx_rw)?;
}
// Use a custom comparison function if specified.
KeyCompare::Custom(_) => {
db.key_comparator::<StorableHeed<T::Key>>()
.create(&mut tx_rw)?;
}
}
Ok(()) Ok(())
} }
@ -301,18 +331,18 @@ where
) -> Result<(), RuntimeError> { ) -> Result<(), RuntimeError> {
let tx_rw = tx_rw.get_mut(); let tx_rw = tx_rw.get_mut();
// Open the table first... // Open the table. We don't care about flags or key
// comparison behavior since we're clearing it anyway.
let db: HeedDb<T::Key, T::Value> = self let db: HeedDb<T::Key, T::Value> = self
.open_database(tx_rw, Some(T::NAME))? .open_database(tx_rw, Some(T::NAME))?
.expect(PANIC_MSG_MISSING_TABLE); .ok_or(RuntimeError::TableNotFound)?;
// ...then clear it. db.clear(tx_rw)?;
Ok(db.clear(tx_rw)?)
Ok(())
} }
} }
//---------------------------------------------------------------------------------------------------- Tests //---------------------------------------------------------------------------------------------------- Tests
#[cfg(test)] #[cfg(test)]
mod test { mod tests {}
// use super::*;
}

View file

@ -1,11 +1,11 @@
//! `cuprate_database::Storable` <-> `heed` serde trait compatibility layer. //! `cuprate_database::Storable` <-> `heed` serde trait compatibility layer.
//---------------------------------------------------------------------------------------------------- Use //---------------------------------------------------------------------------------------------------- Use
use std::{borrow::Cow, marker::PhantomData}; use std::{borrow::Cow, cmp::Ordering, marker::PhantomData};
use heed::{BoxedError, BytesDecode, BytesEncode}; use heed::{BoxedError, BytesDecode, BytesEncode};
use crate::storable::Storable; use crate::{storable::Storable, Key};
//---------------------------------------------------------------------------------------------------- StorableHeed //---------------------------------------------------------------------------------------------------- StorableHeed
/// The glue struct that implements `heed`'s (de)serialization /// The glue struct that implements `heed`'s (de)serialization
@ -16,7 +16,19 @@ pub(super) struct StorableHeed<T>(PhantomData<T>)
where where
T: Storable + ?Sized; T: Storable + ?Sized;
//---------------------------------------------------------------------------------------------------- BytesDecode //---------------------------------------------------------------------------------------------------- Key
// If `Key` is also implemented, this can act as the comparison function.
impl<T> heed::Comparator for StorableHeed<T>
where
T: Key,
{
#[inline]
fn compare(a: &[u8], b: &[u8]) -> Ordering {
<T as Key>::KEY_COMPARE.as_compare_fn::<T>()(a, b)
}
}
//---------------------------------------------------------------------------------------------------- BytesDecode/Encode
impl<'a, T> BytesDecode<'a> for StorableHeed<T> impl<'a, T> BytesDecode<'a> for StorableHeed<T>
where where
T: Storable + 'static, T: Storable + 'static,
@ -30,7 +42,6 @@ where
} }
} }
//---------------------------------------------------------------------------------------------------- BytesEncode
impl<'a, T> BytesEncode<'a> for StorableHeed<T> impl<'a, T> BytesEncode<'a> for StorableHeed<T>
where where
T: Storable + ?Sized + 'a, T: Storable + ?Sized + 'a,
@ -57,6 +68,42 @@ mod test {
// - simplify trait bounds // - simplify trait bounds
// - make sure the right function is being called // - make sure the right function is being called
#[test]
/// Assert key comparison behavior is correct.
fn compare() {
fn test<T>(left: T, right: T, expected: Ordering)
where
T: Key + Ord + 'static,
{
println!("left: {left:?}, right: {right:?}, expected: {expected:?}");
assert_eq!(
<StorableHeed::<T> as heed::Comparator>::compare(
&<StorableHeed::<T> as heed::BytesEncode>::bytes_encode(&left).unwrap(),
&<StorableHeed::<T> as heed::BytesEncode>::bytes_encode(&right).unwrap()
),
expected
);
}
// Value comparison
test::<u8>(0, 255, Ordering::Less);
test::<u16>(0, 256, Ordering::Less);
test::<u32>(0, 256, Ordering::Less);
test::<u64>(0, 256, Ordering::Less);
test::<u128>(0, 256, Ordering::Less);
test::<usize>(0, 256, Ordering::Less);
test::<i8>(-1, 2, Ordering::Less);
test::<i16>(-1, 2, Ordering::Less);
test::<i32>(-1, 2, Ordering::Less);
test::<i64>(-1, 2, Ordering::Less);
test::<i128>(-1, 2, Ordering::Less);
test::<isize>(-1, 2, Ordering::Less);
// Byte comparison
test::<[u8; 2]>([1, 1], [1, 0], Ordering::Greater);
test::<[u8; 3]>([1, 2, 3], [1, 2, 3], Ordering::Equal);
}
#[test] #[test]
/// Assert `BytesEncode::bytes_encode` is accurate. /// Assert `BytesEncode::bytes_encode` is accurate.
fn bytes_encode() { fn bytes_encode() {

View file

@ -5,4 +5,7 @@ use crate::backend::heed::storable::StorableHeed;
//---------------------------------------------------------------------------------------------------- Types //---------------------------------------------------------------------------------------------------- Types
/// The concrete database type for `heed`, usable for reads and writes. /// The concrete database type for `heed`, usable for reads and writes.
//
// Key type Value type
// v v
pub(super) type HeedDb<K, V> = heed::Database<StorableHeed<K>, StorableHeed<V>>; pub(super) type HeedDb<K, V> = heed::Database<StorableHeed<K>, StorableHeed<V>>;

View file

@ -189,7 +189,10 @@ where
// 3. So it's not being used to open a table since that needs `&tx_rw` // 3. So it's not being used to open a table since that needs `&tx_rw`
// //
// Reader-open tables do not affect this, if they're open the below is still OK. // Reader-open tables do not affect this, if they're open the below is still OK.
redb::WriteTransaction::delete_table(tx_rw, table)?; if !redb::WriteTransaction::delete_table(tx_rw, table)? {
return Err(RuntimeError::TableNotFound);
}
// Re-create the table. // Re-create the table.
// `redb` creates tables if they don't exist, so this should never panic. // `redb` creates tables if they don't exist, so this should never panic.
redb::WriteTransaction::open_table(tx_rw, table)?; redb::WriteTransaction::open_table(tx_rw, table)?;
@ -200,6 +203,4 @@ where
//---------------------------------------------------------------------------------------------------- Tests //---------------------------------------------------------------------------------------------------- Tests
#[cfg(test)] #[cfg(test)]
mod test { mod tests {}
// use super::*;
}

View file

@ -25,7 +25,7 @@ where
{ {
#[inline] #[inline]
fn compare(left: &[u8], right: &[u8]) -> Ordering { fn compare(left: &[u8], right: &[u8]) -> Ordering {
<T as Key>::compare(left, right) <T as Key>::KEY_COMPARE.as_compare_fn::<T>()(left, right)
} }
} }
@ -93,8 +93,21 @@ mod test {
); );
} }
test::<i64>(-1, 2, Ordering::Greater); // bytes are greater, not the value // Value comparison
test::<u64>(0, 1, Ordering::Less); test::<u8>(0, 255, Ordering::Less);
test::<u16>(0, 256, Ordering::Less);
test::<u32>(0, 256, Ordering::Less);
test::<u64>(0, 256, Ordering::Less);
test::<u128>(0, 256, Ordering::Less);
test::<usize>(0, 256, Ordering::Less);
test::<i8>(-1, 2, Ordering::Less);
test::<i16>(-1, 2, Ordering::Less);
test::<i32>(-1, 2, Ordering::Less);
test::<i64>(-1, 2, Ordering::Less);
test::<i128>(-1, 2, Ordering::Less);
test::<isize>(-1, 2, Ordering::Less);
// Byte comparison
test::<[u8; 2]>([1, 1], [1, 0], Ordering::Greater); test::<[u8; 2]>([1, 1], [1, 0], Ordering::Greater);
test::<[u8; 3]>([1, 2, 3], [1, 2, 3], Ordering::Equal); test::<[u8; 3]>([1, 2, 3], [1, 2, 3], Ordering::Equal);
} }

View file

@ -156,6 +156,20 @@ fn non_manual_resize_2() {
env.current_map_size(); env.current_map_size();
} }
/// Tests that [`EnvInner::clear_db`] will return
/// [`RuntimeError::TableNotFound`] if the table doesn't exist.
#[test]
fn clear_db_table_not_found() {
let (env, _tmpdir) = tmp_concrete_env();
let env_inner = env.env_inner();
let mut tx_rw = env_inner.tx_rw().unwrap();
let err = env_inner.clear_db::<TestTable>(&mut tx_rw).unwrap_err();
assert!(matches!(err, RuntimeError::TableNotFound));
env_inner.create_db::<TestTable>(&tx_rw).unwrap();
env_inner.clear_db::<TestTable>(&mut tx_rw).unwrap();
}
/// Test all `DatabaseR{o,w}` operations. /// Test all `DatabaseR{o,w}` operations.
#[test] #[test]
fn db_read_write() { fn db_read_write() {
@ -165,11 +179,11 @@ fn db_read_write() {
let mut table = env_inner.open_db_rw::<TestTable>(&tx_rw).unwrap(); let mut table = env_inner.open_db_rw::<TestTable>(&tx_rw).unwrap();
/// The (1st) key. /// The (1st) key.
const KEY: u8 = 0; const KEY: u32 = 0;
/// The expected value. /// The expected value.
const VALUE: u64 = 0; const VALUE: u64 = 0;
/// How many `(key, value)` pairs will be inserted. /// How many `(key, value)` pairs will be inserted.
const N: u8 = 100; const N: u32 = 100;
/// Assert a u64 is the same as `VALUE`. /// Assert a u64 is the same as `VALUE`.
fn assert_value(value: u64) { fn assert_value(value: u64) {
@ -323,19 +337,35 @@ fn db_read_write() {
/// Assert that `key`'s in database tables are sorted in /// Assert that `key`'s in database tables are sorted in
/// an ordered B-Tree fashion, i.e. `min_value -> max_value`. /// an ordered B-Tree fashion, i.e. `min_value -> max_value`.
///
/// And that it is true for integers, e.g. `0` -> `10`.
#[test] #[test]
fn tables_are_sorted() { fn tables_are_sorted() {
let (env, _tmp) = tmp_concrete_env(); let (env, _tmp) = tmp_concrete_env();
let env_inner = env.env_inner(); let env_inner = env.env_inner();
/// Range of keys to insert, `{0, 1, 2 ... 256}`.
const RANGE: std::ops::Range<u32> = 0..257;
// Create tables and set flags / comparison flags.
{
let tx_rw = env_inner.tx_rw().unwrap();
env_inner.create_db::<TestTable>(&tx_rw).unwrap();
TxRw::commit(tx_rw).unwrap();
}
let tx_rw = env_inner.tx_rw().unwrap(); let tx_rw = env_inner.tx_rw().unwrap();
let mut table = env_inner.open_db_rw::<TestTable>(&tx_rw).unwrap(); let mut table = env_inner.open_db_rw::<TestTable>(&tx_rw).unwrap();
// Insert `{5, 4, 3, 2, 1, 0}`, assert each new // Insert range, assert each new
// number inserted is the minimum `first()` value. // number inserted is the minimum `last()` value.
for key in (0..6).rev() { for key in RANGE {
table.put(&key, &123).unwrap(); table.put(&key, &0).unwrap();
table.contains(&key).unwrap();
let (first, _) = table.first().unwrap(); let (first, _) = table.first().unwrap();
assert_eq!(first, key); let (last, _) = table.last().unwrap();
println!("first: {first}, last: {last}, key: {key}");
assert_eq!(last, key);
} }
drop(table); drop(table);
@ -348,7 +378,7 @@ fn tables_are_sorted() {
let table = env_inner.open_db_ro::<TestTable>(&tx_ro).unwrap(); let table = env_inner.open_db_ro::<TestTable>(&tx_ro).unwrap();
let iter = table.iter().unwrap(); let iter = table.iter().unwrap();
let keys = table.keys().unwrap(); let keys = table.keys().unwrap();
for ((i, iter), key) in (0..6).zip(iter).zip(keys) { for ((i, iter), key) in RANGE.zip(iter).zip(keys) {
let (iter, _) = iter.unwrap(); let (iter, _) = iter.unwrap();
let key = key.unwrap(); let key = key.unwrap();
assert_eq!(i, iter); assert_eq!(i, iter);
@ -359,14 +389,14 @@ fn tables_are_sorted() {
let mut table = env_inner.open_db_rw::<TestTable>(&tx_rw).unwrap(); let mut table = env_inner.open_db_rw::<TestTable>(&tx_rw).unwrap();
// Assert the `first()` values are the minimum, i.e. `{0, 1, 2}` // Assert the `first()` values are the minimum, i.e. `{0, 1, 2}`
for key in 0..3 { for key in [0, 1, 2] {
let (first, _) = table.first().unwrap(); let (first, _) = table.first().unwrap();
assert_eq!(first, key); assert_eq!(first, key);
table.delete(&key).unwrap(); table.delete(&key).unwrap();
} }
// Assert the `last()` values are the maximum, i.e. `{5, 4, 3}` // Assert the `last()` values are the maximum, i.e. `{256, 255, 254}`
for key in (3..6).rev() { for key in [256, 255, 254] {
let (last, _) = table.last().unwrap(); let (last, _) = table.last().unwrap();
assert_eq!(last, key); assert_eq!(last, key);
table.delete(&key).unwrap(); table.delete(&key).unwrap();

View file

@ -175,18 +175,16 @@ pub trait Env: Sized {
} }
//---------------------------------------------------------------------------------------------------- DatabaseRo //---------------------------------------------------------------------------------------------------- DatabaseRo
/// Document errors when opening tables in [`EnvInner`]. /// Document the INVARIANT that the `heed` backend
macro_rules! doc_table_error { /// must use [`EnvInner::create_db`] when initially
/// opening/creating tables.
macro_rules! doc_heed_create_db_invariant {
() => { () => {
r"# Errors r#"The first time you open/create tables, you _must_ use [`EnvInner::create_db`]
This will only return [`RuntimeError::Io`] on normal errors. to set the proper flags / [`Key`](crate::Key) comparison for the `heed` backend.
If the specified table is not created upon before this function is called, Subsequent table opens will follow the flags/ordering, but only if
this will return an error. [`EnvInner::create_db`] was the _first_ function to open/create it."#
Implementation detail you should NOT rely on:
- This only panics on `heed`
- `redb` will create the table if it does not exist"
}; };
} }
@ -204,7 +202,13 @@ Implementation detail you should NOT rely on:
/// Note that when opening tables with [`EnvInner::open_db_ro`], /// Note that when opening tables with [`EnvInner::open_db_ro`],
/// they must be created first or else it will return error. /// they must be created first or else it will return error.
/// ///
/// See [`EnvInner::open_db_rw`] and [`EnvInner::create_db`] for creating tables. /// Note that when opening tables with [`EnvInner::open_db_ro`],
/// they must be created first or else it will return error.
///
/// See [`EnvInner::create_db`] for creating tables.
///
/// # Invariant
#[doc = doc_heed_create_db_invariant!()]
pub trait EnvInner<'env, Ro, Rw> pub trait EnvInner<'env, Ro, Rw>
where where
Self: 'env, Self: 'env,
@ -243,6 +247,9 @@ where
/// ///
/// If the specified table is not created upon before this function is called, /// If the specified table is not created upon before this function is called,
/// this will return [`RuntimeError::TableNotFound`]. /// this will return [`RuntimeError::TableNotFound`].
///
/// # Invariant
#[doc = doc_heed_create_db_invariant!()]
fn open_db_ro<T: Table>( fn open_db_ro<T: Table>(
&self, &self,
tx_ro: &Ro, tx_ro: &Ro,
@ -262,18 +269,19 @@ where
/// # Errors /// # Errors
/// This will only return [`RuntimeError::Io`] on errors. /// This will only return [`RuntimeError::Io`] on errors.
/// ///
/// Implementation details: Both `heed` & `redb` backends create /// # Invariant
/// the table with this function if it does not already exist. For safety and #[doc = doc_heed_create_db_invariant!()]
/// clear intent, you should still consider using [`EnvInner::create_db`] instead.
fn open_db_rw<T: Table>(&self, tx_rw: &Rw) -> Result<impl DatabaseRw<T>, RuntimeError>; fn open_db_rw<T: Table>(&self, tx_rw: &Rw) -> Result<impl DatabaseRw<T>, RuntimeError>;
/// Create a database table. /// Create a database table.
/// ///
/// This will create the database [`Table`] /// This will create the database [`Table`] passed as a generic to this function.
/// passed as a generic to this function.
/// ///
/// # Errors /// # Errors
/// This will only return [`RuntimeError::Io`] on errors. /// This will only return [`RuntimeError::Io`] on errors.
///
/// # Invariant
#[doc = doc_heed_create_db_invariant!()]
fn create_db<T: Table>(&self, tx_rw: &Rw) -> Result<(), RuntimeError>; fn create_db<T: Table>(&self, tx_rw: &Rw) -> Result<(), RuntimeError>;
/// Clear all `(key, value)`'s from a database table. /// Clear all `(key, value)`'s from a database table.
@ -284,6 +292,10 @@ where
/// Note that this operation is tied to `tx_rw`, as such this /// Note that this operation is tied to `tx_rw`, as such this
/// function's effects can be aborted using [`TxRw::abort`]. /// function's effects can be aborted using [`TxRw::abort`].
/// ///
#[doc = doc_table_error!()] /// # Errors
/// This will return [`RuntimeError::Io`] on normal errors.
///
/// If the specified table is not created upon before this function is called,
/// this will return [`RuntimeError::TableNotFound`].
fn clear_db<T: Table>(&self, tx_rw: &mut Rw) -> Result<(), RuntimeError>; fn clear_db<T: Table>(&self, tx_rw: &mut Rw) -> Result<(), RuntimeError>;
} }

View file

@ -1,54 +1,177 @@
//! Database key abstraction; `trait Key`. //! Database key abstraction; `trait Key`.
//---------------------------------------------------------------------------------------------------- Import //---------------------------------------------------------------------------------------------------- Import
use std::cmp::Ordering; use std::{cmp::Ordering, fmt::Debug};
use crate::storable::Storable; use crate::{storable::Storable, StorableBytes, StorableStr, StorableVec};
//---------------------------------------------------------------------------------------------------- Table //---------------------------------------------------------------------------------------------------- Table
/// Database [`Table`](crate::table::Table) key metadata. /// Database [`Table`](crate::table::Table) key metadata.
/// ///
/// Purely compile time information for database table keys. /// Purely compile time information for database table keys.
// ///
// FIXME: this doesn't need to exist right now but /// ## Comparison
// may be used if we implement getting values using ranges. /// There are 2 differences between [`Key`] and [`Storable`]:
// <https://github.com/Cuprate/cuprate/pull/117#discussion_r1589378104> /// 1. [`Key`] must be [`Sized`]
pub trait Key: Storable + Sized { /// 2. [`Key`] represents a [`Storable`] type that defines a comparison function
/// The primary key type. ///
type Primary: Storable; /// The database backends will use [`Key::KEY_COMPARE`]
/// to sort the keys within database tables.
///
/// [`Key::KEY_COMPARE`] is pre-implemented as a straight byte comparison.
///
/// This default is overridden for numbers, which use a number comparison.
/// For example, [`u64`] keys are sorted as `{0, 1, 2 ... 999_998, 999_999, 1_000_000}`.
///
/// If you would like to re-define this for number types, consider;
/// 1. Creating a wrapper type around primitives like a `struct SortU8(pub u8)`
/// 2. Implement [`Key`] on that wrapper
/// 3. Define a custom [`Key::KEY_COMPARE`]
pub trait Key: Storable + Sized + Ord {
/// Compare 2 [`Key`]'s against each other. /// Compare 2 [`Key`]'s against each other.
/// ///
/// By default, this does a straight _byte_ comparison, /// # Defaults for types
/// not a comparison of the key's value. /// For arrays and vectors that contain a `T: Storable`,
/// this does a straight _byte_ comparison, not a comparison of the key's value.
/// ///
/// For [`StorableStr`], this will use [`str::cmp`], i.e. it is the same as the default behavior; it is a
/// [lexicographical comparison](https://doc.rust-lang.org/std/cmp/trait.Ord.html#lexicographical-comparison)
///
/// For all primitive number types ([`u8`], [`i128`], etc), this will
/// convert the bytes to the number using [`Storable::from_bytes`],
/// then do a number comparison.
///
/// # Example
/// ```rust /// ```rust
/// # use cuprate_database::*; /// # use cuprate_database::*;
/// // Normal byte comparison.
/// let vec1 = StorableVec(vec![0, 1]);
/// let vec2 = StorableVec(vec![255, 0]);
/// assert_eq!( /// assert_eq!(
/// <u64 as Key>::compare([0].as_slice(), [1].as_slice()), /// <StorableVec<u8> as Key>::KEY_COMPARE
/// .as_compare_fn::<StorableVec<u8>>()(&vec1, &vec2),
/// std::cmp::Ordering::Less, /// std::cmp::Ordering::Less,
/// ); /// );
///
/// // Integer comparison.
/// let byte1 = [0, 1]; // 256
/// let byte2 = [255, 0]; // 255
/// let num1 = u16::from_le_bytes(byte1);
/// let num2 = u16::from_le_bytes(byte2);
/// assert_eq!(num1, 256);
/// assert_eq!(num2, 255);
/// assert_eq!( /// assert_eq!(
/// <u64 as Key>::compare([1].as_slice(), [1].as_slice()), /// // 256 > 255
/// std::cmp::Ordering::Equal, /// <u16 as Key>::KEY_COMPARE.as_compare_fn::<u16>()(&byte1, &byte2),
/// );
/// assert_eq!(
/// <u64 as Key>::compare([2].as_slice(), [1].as_slice()),
/// std::cmp::Ordering::Greater, /// std::cmp::Ordering::Greater,
/// ); /// );
/// ``` /// ```
#[inline] const KEY_COMPARE: KeyCompare = KeyCompare::Default;
fn compare(left: &[u8], right: &[u8]) -> Ordering {
left.cmp(right)
}
} }
//---------------------------------------------------------------------------------------------------- Impl //---------------------------------------------------------------------------------------------------- Impl
impl<T> Key for T /// [`Ord`] comparison for arrays/vectors.
where impl<const N: usize, T> Key for [T; N] where T: Key + Storable + Sized + bytemuck::Pod {}
T: Storable + Sized, impl<T: bytemuck::Pod + Debug + Ord> Key for StorableVec<T> {}
{
type Primary = Self; /// [`Ord`] comparison for misc types.
///
/// This is not a blanket implementation because
/// it allows outer crates to define their own
/// comparison functions for their `T: Storable` types.
impl Key for () {}
impl Key for StorableBytes {}
impl Key for StorableStr {}
/// Number comparison.
///
/// # Invariant
/// This must _only_ be implemented for [`u32`], [`u64`] (and maybe [`usize`]).
///
/// This is because:
/// 1. We use LMDB's `INTEGER_KEY` flag when this enum variant is used
/// 2. LMDB only supports these types when using that flag
///
/// See: <https://docs.rs/heed/0.20.0-alpha.9/heed/struct.DatabaseFlags.html#associatedconstant.INTEGER_KEY>
///
/// Other numbers will still have the same behavior, but they use
/// [`impl_custom_numbers_key`] and essentially pass LMDB a "custom"
/// number compare function.
macro_rules! impl_number_key {
($($t:ident),* $(,)?) => {
$(
impl Key for $t {
const KEY_COMPARE: KeyCompare = KeyCompare::Number;
}
)*
};
}
impl_number_key!(u32, u64, usize);
#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
compile_error!("`cuprate_database`: `usize` must be equal to `u32` or `u64` for LMDB's `usize` key sorting to function correctly");
/// Custom number comparison for other numbers.
macro_rules! impl_custom_numbers_key {
($($t:ident),* $(,)?) => {
$(
impl Key for $t {
// Just forward the the number comparison function.
const KEY_COMPARE: KeyCompare = KeyCompare::Custom(|left, right| {
KeyCompare::Number.as_compare_fn::<$t>()(left, right)
});
}
)*
};
}
impl_custom_numbers_key!(u8, u16, u128, i8, i16, i32, i64, i128, isize);
//---------------------------------------------------------------------------------------------------- KeyCompare
/// Comparison behavior for [`Key`]s.
///
/// This determines how the database sorts [`Key`]s inside a database [`Table`](crate::Table).
///
/// See [`Key`] for more info.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum KeyCompare {
/// Use the default comparison behavior of the backend.
///
/// Currently, both `heed` and `redb` use
/// [lexicographical comparison](https://doc.rust-lang.org/1.79.0/std/cmp/trait.Ord.html#lexicographical-comparison)
/// by default, i.e. a straight byte comparison.
#[default]
Default,
/// A by-value number comparison, i.e. `255 < 256`.
///
/// This _behavior_ is implemented as the default for all number primitives,
/// although some implementations on numbers use [`KeyCompare::Custom`] due
/// to internal implementation details of LMDB.
Number,
/// A custom sorting function.
///
/// The input of the function is 2 [`Key`]s in byte form.
Custom(fn(&[u8], &[u8]) -> Ordering),
}
impl KeyCompare {
/// Return [`Self`] as a pure comparison function.
///
/// The returned function expects 2 [`Key`]s in byte form as input.
#[inline]
pub const fn as_compare_fn<K: Key>(self) -> fn(&[u8], &[u8]) -> Ordering {
match self {
Self::Default => std::cmp::Ord::cmp,
Self::Number => |left, right| {
let left = <K as Storable>::from_bytes(left);
let right = <K as Storable>::from_bytes(right);
std::cmp::Ord::cmp(&left, &right)
},
Self::Custom(f) => f,
}
}
} }
//---------------------------------------------------------------------------------------------------- Tests //---------------------------------------------------------------------------------------------------- Tests

View file

@ -126,10 +126,10 @@ pub use error::{InitError, RuntimeError};
pub mod resize; pub mod resize;
mod key; mod key;
pub use key::Key; pub use key::{Key, KeyCompare};
mod storable; mod storable;
pub use storable::{Storable, StorableBytes, StorableVec}; pub use storable::{Storable, StorableBytes, StorableStr, StorableVec};
mod table; mod table;
pub use table::Table; pub use table::Table;

View file

@ -1,7 +1,10 @@
//! (De)serialization for table keys & values. //! (De)serialization for table keys & values.
//---------------------------------------------------------------------------------------------------- Import //---------------------------------------------------------------------------------------------------- Import
use std::{borrow::Borrow, fmt::Debug}; use std::{
borrow::{Borrow, Cow},
fmt::Debug,
};
use bytemuck::Pod; use bytemuck::Pod;
use bytes::Bytes; use bytes::Bytes;
@ -194,6 +197,66 @@ impl<T> Borrow<[T]> for StorableVec<T> {
} }
} }
//---------------------------------------------------------------------------------------------------- StorableVec
/// A [`Storable`] string.
///
/// This is a wrapper around a `Cow<'static, str>`
/// that can be stored in the database.
///
/// # Invariant
/// [`StorableStr::from_bytes`] will panic
/// if the bytes are not UTF-8. This should normally
/// not be possible in database operations, although technically
/// you can call this function yourself and input bad data.
///
/// # Example
/// ```rust
/// # use cuprate_database::*;
/// # use std::borrow::Cow;
/// let string: StorableStr = StorableStr(Cow::Borrowed("a"));
///
/// // Into bytes.
/// let into = Storable::as_bytes(&string);
/// assert_eq!(into, &[97]);
///
/// // From bytes.
/// let from: StorableStr = Storable::from_bytes(&into);
/// assert_eq!(from, string);
/// ```
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct StorableStr(pub Cow<'static, str>);
impl Storable for StorableStr {
const BYTE_LENGTH: Option<usize> = None;
/// [`String::as_bytes`].
#[inline]
fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}
#[inline]
fn from_bytes(bytes: &[u8]) -> Self {
Self(Cow::Owned(std::str::from_utf8(bytes).unwrap().to_string()))
}
}
impl std::ops::Deref for StorableStr {
type Target = Cow<'static, str>;
#[inline]
fn deref(&self) -> &Cow<'static, str> {
&self.0
}
}
impl Borrow<Cow<'static, str>> for StorableStr {
#[inline]
fn borrow(&self) -> &Cow<'static, str> {
&self.0
}
}
//---------------------------------------------------------------------------------------------------- StorableBytes //---------------------------------------------------------------------------------------------------- StorableBytes
/// A [`Storable`] version of [`Bytes`]. /// A [`Storable`] version of [`Bytes`].
/// ///

View file

@ -15,7 +15,7 @@ pub(crate) struct TestTable;
impl Table for TestTable { impl Table for TestTable {
const NAME: &'static str = "test_table"; const NAME: &'static str = "test_table";
type Key = u8; type Key = u32;
type Value = u64; type Value = u64;
} }

View file

@ -83,10 +83,21 @@ pub enum BCReadRequest {
/// The input is a list of output amounts. /// The input is a list of output amounts.
NumberOutputsWithAmount(Vec<u64>), NumberOutputsWithAmount(Vec<u64>),
/// Check that all key images within a set arer not spent. /// Check that all key images within a set are not spent.
/// ///
/// Input is a set of key images. /// Input is a set of key images.
KeyImagesSpent(HashSet<[u8; 32]>), KeyImagesSpent(HashSet<[u8; 32]>),
/// A request for the compact chain history.
CompactChainHistory,
/// A request to find the first unknown block ID in a list of block IDs.
////
/// # Invariant
/// The [`Vec`] containing the block IDs must be sorted in chronological block
/// order, or else the returned response is unspecified and meaningless,
/// as this request performs a binary search.
FindFirstUnknown(Vec<[u8; 32]>),
} }
//---------------------------------------------------------------------------------------------------- WriteRequest //---------------------------------------------------------------------------------------------------- WriteRequest
@ -164,6 +175,23 @@ pub enum BCResponse {
/// The inner value is `false` if _none_ of the key images were spent. /// The inner value is `false` if _none_ of the key images were spent.
KeyImagesSpent(bool), KeyImagesSpent(bool),
/// Response to [`BCReadRequest::CompactChainHistory`].
CompactChainHistory {
/// A list of blocks IDs in our chain, starting with the most recent block, all the way to the genesis block.
///
/// These blocks should be in reverse chronological order, not every block is needed.
block_ids: Vec<[u8; 32]>,
/// The current cumulative difficulty of the chain.
cumulative_difficulty: u128,
},
/// The response for [`BCReadRequest::FindFirstUnknown`].
///
/// Contains the index of the first unknown block and its expected height.
///
/// This will be [`None`] if all blocks were known.
FindFirstUnknown(Option<(usize, u64)>),
//------------------------------------------------------ Writes //------------------------------------------------------ Writes
/// Response to [`BCWriteRequest::WriteBlock`]. /// Response to [`BCWriteRequest::WriteBlock`].
/// ///