heed: use Mutex for HeedTableRo's read tx

This commit is contained in:
hinto.janai 2024-04-28 21:00:14 -04:00
parent 8b8bb6342f
commit 7e8aae016c
No known key found for this signature in database
GPG key ID: D47CE05FA175A499
2 changed files with 39 additions and 18 deletions

View file

@ -6,7 +6,7 @@ use std::{
cell::RefCell, cell::RefCell,
fmt::Debug, fmt::Debug,
ops::RangeBounds, ops::RangeBounds,
sync::RwLockReadGuard, sync::{Mutex, RwLockReadGuard},
}; };
use crate::{ use crate::{
@ -36,7 +36,7 @@ pub(super) struct HeedTableRo<'tx, T: Table> {
/// An already opened database table. /// An already opened database table.
pub(super) db: HeedDb<T::Key, T::Value>, pub(super) db: HeedDb<T::Key, T::Value>,
/// The associated read-only transaction that opened this table. /// The associated read-only transaction that opened this table.
pub(super) tx_ro: &'tx heed::RoTxn<'tx>, pub(super) tx_ro: Mutex<&'tx heed::RoTxn<'tx>>,
} }
/// An opened read/write database associated with a transaction. /// An opened read/write database associated with a transaction.
@ -49,16 +49,25 @@ pub(super) struct HeedTableRw<'env, 'tx, T: Table> {
pub(super) tx_rw: &'tx RefCell<heed::RwTxn<'env>>, pub(super) tx_rw: &'tx RefCell<heed::RwTxn<'env>>,
} }
/// SAFETY: `cuprate_database`'s Cargo.toml enables a feature #[allow(clippy::non_send_fields_in_send_ty)]
/// for `heed` that turns on the `MDB_NOTLS` flag for LMDB. /// SAFETY: 2 invariants for safety:
/// ///
/// 1. `cuprate_database`'s Cargo.toml enables a feature
/// for `heed` that turns on the `MDB_NOTLS` flag for LMDB.
/// This makes read transactions `Send`, but only if that flag is enabled. /// This makes read transactions `Send`, but only if that flag is enabled.
/// ///
/// 2. Our `tx_ro` is wrapped in Mutex, as `&T: Send` only if `T: Sync`.
/// This is what is happening as we have `&TxRw`, not `TxRw`.
/// <https://github.com/Cuprate/cuprate/pull/113#discussion_r1582189108>
///
/// This is required as in `crate::service` we must put our transactions and /// This is required as in `crate::service` we must put our transactions and
/// tables inside `ThreadLocal`'s to use across multiple threads. /// tables inside `ThreadLocal`'s to use across multiple threads.
/// unsafe impl<T: Table> Send for HeedTableRo<'_, T>
/// `ThreadLocal<T>` requires that `T: Send`. where
unsafe impl<T: Table> Send for HeedTableRo<'_, T> {} T::Key: Send,
T::Value: Send,
{
}
//---------------------------------------------------------------------------------------------------- Shared functions //---------------------------------------------------------------------------------------------------- Shared functions
// FIXME: we cannot just deref `HeedTableRw -> HeedTableRo` and // FIXME: we cannot just deref `HeedTableRw -> HeedTableRo` and
@ -121,7 +130,10 @@ impl<T: Table> DatabaseIter<T> for HeedTableRo<'_, T> {
where where
Range: RangeBounds<T::Key> + 'a, Range: RangeBounds<T::Key> + 'a,
{ {
Ok(self.db.range(self.tx_ro, &range)?.map(|res| Ok(res?.1))) Ok(self
.db
.range(&self.tx_ro.lock().unwrap(), &range)?
.map(|res| Ok(res?.1)))
} }
#[inline] #[inline]
@ -129,21 +141,30 @@ impl<T: Table> DatabaseIter<T> for HeedTableRo<'_, T> {
&self, &self,
) -> Result<impl Iterator<Item = Result<(T::Key, T::Value), RuntimeError>> + '_, RuntimeError> ) -> Result<impl Iterator<Item = Result<(T::Key, T::Value), RuntimeError>> + '_, RuntimeError>
{ {
Ok(self.db.iter(self.tx_ro)?.map(|res| Ok(res?))) Ok(self
.db
.iter(&self.tx_ro.lock().unwrap())?
.map(|res| Ok(res?)))
} }
#[inline] #[inline]
fn keys( fn keys(
&self, &self,
) -> Result<impl Iterator<Item = Result<T::Key, RuntimeError>> + '_, RuntimeError> { ) -> Result<impl Iterator<Item = Result<T::Key, RuntimeError>> + '_, RuntimeError> {
Ok(self.db.iter(self.tx_ro)?.map(|res| Ok(res?.0))) Ok(self
.db
.iter(&self.tx_ro.lock().unwrap())?
.map(|res| Ok(res?.0)))
} }
#[inline] #[inline]
fn values( fn values(
&self, &self,
) -> Result<impl Iterator<Item = Result<T::Value, RuntimeError>> + '_, RuntimeError> { ) -> Result<impl Iterator<Item = Result<T::Value, RuntimeError>> + '_, RuntimeError> {
Ok(self.db.iter(self.tx_ro)?.map(|res| Ok(res?.1))) Ok(self
.db
.iter(&self.tx_ro.lock().unwrap())?
.map(|res| Ok(res?.1)))
} }
} }
@ -151,27 +172,27 @@ impl<T: Table> DatabaseIter<T> for HeedTableRo<'_, T> {
impl<T: Table> DatabaseRo<T> for HeedTableRo<'_, T> { impl<T: Table> DatabaseRo<T> for HeedTableRo<'_, T> {
#[inline] #[inline]
fn get(&self, key: &T::Key) -> Result<T::Value, RuntimeError> { fn get(&self, key: &T::Key) -> Result<T::Value, RuntimeError> {
get::<T>(&self.db, self.tx_ro, key) get::<T>(&self.db, &self.tx_ro.lock().unwrap(), key)
} }
#[inline] #[inline]
fn len(&self) -> Result<u64, RuntimeError> { fn len(&self) -> Result<u64, RuntimeError> {
len::<T>(&self.db, self.tx_ro) len::<T>(&self.db, &self.tx_ro.lock().unwrap())
} }
#[inline] #[inline]
fn first(&self) -> Result<(T::Key, T::Value), RuntimeError> { fn first(&self) -> Result<(T::Key, T::Value), RuntimeError> {
first::<T>(&self.db, self.tx_ro) first::<T>(&self.db, &self.tx_ro.lock().unwrap())
} }
#[inline] #[inline]
fn last(&self) -> Result<(T::Key, T::Value), RuntimeError> { fn last(&self) -> Result<(T::Key, T::Value), RuntimeError> {
last::<T>(&self.db, self.tx_ro) last::<T>(&self.db, &self.tx_ro.lock().unwrap())
} }
#[inline] #[inline]
fn is_empty(&self) -> Result<bool, RuntimeError> { fn is_empty(&self) -> Result<bool, RuntimeError> {
is_empty::<T>(&self.db, self.tx_ro) is_empty::<T>(&self.db, &self.tx_ro.lock().unwrap())
} }
} }

View file

@ -6,7 +6,7 @@ use std::{
fmt::Debug, fmt::Debug,
num::NonZeroUsize, num::NonZeroUsize,
ops::Deref, ops::Deref,
sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}, sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard},
}; };
use heed::{DatabaseOpenOptions, EnvFlags, EnvOpenOptions}; use heed::{DatabaseOpenOptions, EnvFlags, EnvOpenOptions};
@ -312,7 +312,7 @@ where
db: self db: self
.open_database(tx_ro, Some(T::NAME))? .open_database(tx_ro, Some(T::NAME))?
.expect(PANIC_MSG_MISSING_TABLE), .expect(PANIC_MSG_MISSING_TABLE),
tx_ro, tx_ro: Mutex::new(tx_ro),
}) })
} }