heed: use Mutex for HeedTableRo's read tx

This commit is contained in:
hinto.janai 2024-04-28 21:00:14 -04:00
parent 8b8bb6342f
commit 7e8aae016c
No known key found for this signature in database
GPG key ID: D47CE05FA175A499
2 changed files with 39 additions and 18 deletions

View file

@ -6,7 +6,7 @@ use std::{
cell::RefCell,
fmt::Debug,
ops::RangeBounds,
sync::RwLockReadGuard,
sync::{Mutex, RwLockReadGuard},
};
use crate::{
@ -36,7 +36,7 @@ pub(super) struct HeedTableRo<'tx, T: Table> {
/// An already opened database table.
pub(super) db: HeedDb<T::Key, T::Value>,
/// The associated read-only transaction that opened this table.
pub(super) tx_ro: &'tx heed::RoTxn<'tx>,
pub(super) tx_ro: Mutex<&'tx heed::RoTxn<'tx>>,
}
/// An opened read/write database associated with a transaction.
@ -49,16 +49,25 @@ pub(super) struct HeedTableRw<'env, 'tx, T: Table> {
pub(super) tx_rw: &'tx RefCell<heed::RwTxn<'env>>,
}
/// SAFETY: `cuprate_database`'s Cargo.toml enables a feature
/// for `heed` that turns on the `MDB_NOTLS` flag for LMDB.
#[allow(clippy::non_send_fields_in_send_ty)]
/// SAFETY: 2 invariants for safety:
///
/// 1. `cuprate_database`'s Cargo.toml enables a feature
/// for `heed` that turns on the `MDB_NOTLS` flag for LMDB.
/// This makes read transactions `Send`, but only if that flag is enabled.
///
/// 2. Our `tx_ro` is wrapped in Mutex, as `&T: Send` only if `T: Sync`.
/// This is what is happening as we have `&TxRw`, not `TxRw`.
/// <https://github.com/Cuprate/cuprate/pull/113#discussion_r1582189108>
///
/// This is required as in `crate::service` we must put our transactions and
/// tables inside `ThreadLocal`'s to use across multiple threads.
///
/// `ThreadLocal<T>` requires that `T: Send`.
unsafe impl<T: Table> Send for HeedTableRo<'_, T> {}
unsafe impl<T: Table> Send for HeedTableRo<'_, T>
where
T::Key: Send,
T::Value: Send,
{
}
//---------------------------------------------------------------------------------------------------- Shared functions
// FIXME: we cannot just deref `HeedTableRw -> HeedTableRo` and
@ -121,7 +130,10 @@ impl<T: Table> DatabaseIter<T> for HeedTableRo<'_, T> {
where
Range: RangeBounds<T::Key> + 'a,
{
Ok(self.db.range(self.tx_ro, &range)?.map(|res| Ok(res?.1)))
Ok(self
.db
.range(&self.tx_ro.lock().unwrap(), &range)?
.map(|res| Ok(res?.1)))
}
#[inline]
@ -129,21 +141,30 @@ impl<T: Table> DatabaseIter<T> for HeedTableRo<'_, T> {
&self,
) -> Result<impl Iterator<Item = Result<(T::Key, T::Value), RuntimeError>> + '_, RuntimeError>
{
Ok(self.db.iter(self.tx_ro)?.map(|res| Ok(res?)))
Ok(self
.db
.iter(&self.tx_ro.lock().unwrap())?
.map(|res| Ok(res?)))
}
#[inline]
fn keys(
&self,
) -> Result<impl Iterator<Item = Result<T::Key, RuntimeError>> + '_, RuntimeError> {
Ok(self.db.iter(self.tx_ro)?.map(|res| Ok(res?.0)))
Ok(self
.db
.iter(&self.tx_ro.lock().unwrap())?
.map(|res| Ok(res?.0)))
}
#[inline]
fn values(
&self,
) -> Result<impl Iterator<Item = Result<T::Value, RuntimeError>> + '_, RuntimeError> {
Ok(self.db.iter(self.tx_ro)?.map(|res| Ok(res?.1)))
Ok(self
.db
.iter(&self.tx_ro.lock().unwrap())?
.map(|res| Ok(res?.1)))
}
}
@ -151,27 +172,27 @@ impl<T: Table> DatabaseIter<T> for HeedTableRo<'_, T> {
impl<T: Table> DatabaseRo<T> for HeedTableRo<'_, T> {
#[inline]
fn get(&self, key: &T::Key) -> Result<T::Value, RuntimeError> {
get::<T>(&self.db, self.tx_ro, key)
get::<T>(&self.db, &self.tx_ro.lock().unwrap(), key)
}
#[inline]
fn len(&self) -> Result<u64, RuntimeError> {
len::<T>(&self.db, self.tx_ro)
len::<T>(&self.db, &self.tx_ro.lock().unwrap())
}
#[inline]
fn first(&self) -> Result<(T::Key, T::Value), RuntimeError> {
first::<T>(&self.db, self.tx_ro)
first::<T>(&self.db, &self.tx_ro.lock().unwrap())
}
#[inline]
fn last(&self) -> Result<(T::Key, T::Value), RuntimeError> {
last::<T>(&self.db, self.tx_ro)
last::<T>(&self.db, &self.tx_ro.lock().unwrap())
}
#[inline]
fn is_empty(&self) -> Result<bool, RuntimeError> {
is_empty::<T>(&self.db, self.tx_ro)
is_empty::<T>(&self.db, &self.tx_ro.lock().unwrap())
}
}

View file

@ -6,7 +6,7 @@ use std::{
fmt::Debug,
num::NonZeroUsize,
ops::Deref,
sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard},
};
use heed::{DatabaseOpenOptions, EnvFlags, EnvOpenOptions};
@ -312,7 +312,7 @@ where
db: self
.open_database(tx_ro, Some(T::NAME))?
.expect(PANIC_MSG_MISSING_TABLE),
tx_ro,
tx_ro: Mutex::new(tx_ro),
})
}