mirror of
https://github.com/Cuprate/cuprate.git
synced 2024-12-22 11:39:26 +00:00
fix UB + use weaker ordering
transmuting the Vec is not safe as Rust could make layout optimisations for `Vec<T>`` that it can't for `Vec<UnsafeCell<MaybeUninit<V>>>``
This commit is contained in:
parent
0dbe783a45
commit
11bf3900ac
2 changed files with 30 additions and 16 deletions
|
@ -2,7 +2,7 @@ use std::{
|
||||||
cell::UnsafeCell,
|
cell::UnsafeCell,
|
||||||
cmp::min,
|
cmp::min,
|
||||||
hash::Hash,
|
hash::Hash,
|
||||||
mem::{needs_drop, MaybeUninit},
|
mem::{needs_drop, ManuallyDrop, MaybeUninit},
|
||||||
ops::Range,
|
ops::Range,
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
@ -25,7 +25,7 @@ pub(crate) struct SharedConcurrentMapBuilder<K, V> {
|
||||||
/// Values that we are initialising, will be the length of `index_set`.
|
/// Values that we are initialising, will be the length of `index_set`.
|
||||||
///
|
///
|
||||||
/// The index for a keys value is given by the keys index in `index_set`.
|
/// The index for a keys value is given by the keys index in `index_set`.
|
||||||
values: Option<Vec<UnsafeCell<MaybeUninit<V>>>>,
|
values: Option<ManuallyDrop<Vec<UnsafeCell<MaybeUninit<V>>>>>,
|
||||||
/// A marker for if a value in `values` is initialised.
|
/// A marker for if a value in `values` is initialised.
|
||||||
initialised_values: Vec<UnsafeCell<bool>>,
|
initialised_values: Vec<UnsafeCell<bool>>,
|
||||||
|
|
||||||
|
@ -40,11 +40,11 @@ unsafe impl<K: Sync, V: Sync> Sync for SharedConcurrentMapBuilder<K, V> {}
|
||||||
impl<K, V> SharedConcurrentMapBuilder<K, V> {
|
impl<K, V> SharedConcurrentMapBuilder<K, V> {
|
||||||
/// Returns a new [`SharedConcurrentMapBuilder`], with the keys needed in an [`IndexSet`].
|
/// Returns a new [`SharedConcurrentMapBuilder`], with the keys needed in an [`IndexSet`].
|
||||||
pub fn new(keys_needed: IndexSet<K>) -> SharedConcurrentMapBuilder<K, V> {
|
pub fn new(keys_needed: IndexSet<K>) -> SharedConcurrentMapBuilder<K, V> {
|
||||||
let values = Some(
|
let values = Some(ManuallyDrop::new(
|
||||||
(0..keys_needed.len())
|
(0..keys_needed.len())
|
||||||
.map(|_| UnsafeCell::new(MaybeUninit::uninit()))
|
.map(|_| UnsafeCell::new(MaybeUninit::uninit()))
|
||||||
.collect(),
|
.collect(),
|
||||||
);
|
));
|
||||||
let initialised_values = (0..keys_needed.len())
|
let initialised_values = (0..keys_needed.len())
|
||||||
.map(|_| UnsafeCell::new(false))
|
.map(|_| UnsafeCell::new(false))
|
||||||
.collect();
|
.collect();
|
||||||
|
@ -64,8 +64,9 @@ impl<K, V> Drop for SharedConcurrentMapBuilder<K, V> {
|
||||||
// Values in a MaybeUninit will not be dropped so we need to drop them manually.
|
// Values in a MaybeUninit will not be dropped so we need to drop them manually.
|
||||||
|
|
||||||
// This will only be ran when all workers have dropped their handles.
|
// This will only be ran when all workers have dropped their handles.
|
||||||
if needs_drop::<V>() {
|
|
||||||
if let Some(values) = &self.values {
|
if let Some(values) = &mut self.values {
|
||||||
|
if needs_drop::<V>() {
|
||||||
for init_value in self
|
for init_value in self
|
||||||
.initialised_values
|
.initialised_values
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -86,6 +87,12 @@ impl<K, V> Drop for SharedConcurrentMapBuilder<K, V> {
|
||||||
unsafe { value.assume_init_drop() }
|
unsafe { value.assume_init_drop() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SAFETY:
|
||||||
|
// This is drop code this is the only reference and this will not be used after being dropped.
|
||||||
|
unsafe {
|
||||||
|
ManuallyDrop::drop(values);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -111,8 +118,8 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
|
||||||
return Err(*err);
|
return Err(*err);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: can we use a weaker Ordering?
|
// TODO: We can use Relaxed, explain why.
|
||||||
let start = self.0.current_index.fetch_add(amt, Ordering::SeqCst);
|
let start = self.0.current_index.fetch_add(amt, Ordering::Relaxed);
|
||||||
|
|
||||||
if start >= values.len() {
|
if start >= values.len() {
|
||||||
// No work to do, all given out.
|
// No work to do, all given out.
|
||||||
|
@ -142,7 +149,7 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
|
||||||
return Err(*err);
|
return Err(*err);
|
||||||
}
|
}
|
||||||
|
|
||||||
let values = inner.values.take().unwrap();
|
let mut values = inner.values.take().unwrap();
|
||||||
|
|
||||||
if inner.current_index.load(Ordering::Relaxed) < values.len() {
|
if inner.current_index.load(Ordering::Relaxed) < values.len() {
|
||||||
return Err(ConcurrentMapBuilderError::WorkWasNotFinishedBeforeInit);
|
return Err(ConcurrentMapBuilderError::WorkWasNotFinishedBeforeInit);
|
||||||
|
@ -152,8 +159,15 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
|
||||||
// - UnsafeCell<MaybeUninit<T>> has the same bit pattern as T.
|
// - UnsafeCell<MaybeUninit<T>> has the same bit pattern as T.
|
||||||
// - If any value is unitised that means work wasn't handed out which we just
|
// - If any value is unitised that means work wasn't handed out which we just
|
||||||
// checked for, or work handed out was not completed which is checked for in
|
// checked for, or work handed out was not completed which is checked for in
|
||||||
// the Drop impl of MapBuilderWork.
|
// the Drop impl of MapBuilderWork and if the Drop impl is not ran then there
|
||||||
let values: Vec<V> = unsafe { std::mem::transmute(values) };
|
// will be a reference on the Arc so this code will not be reached.
|
||||||
|
// - The `values` Vec is wrapped in ManuallyDrop so the inner vec will not be dropped.
|
||||||
|
let values: Vec<V> = unsafe {
|
||||||
|
let capacity = values.capacity();
|
||||||
|
let len = values.len();
|
||||||
|
|
||||||
|
Vec::from_raw_parts(values.as_mut_ptr() as *mut V, len, capacity)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Some(BuiltMap {
|
Ok(Some(BuiltMap {
|
||||||
index_set: inner.index_set.take().unwrap(),
|
index_set: inner.index_set.take().unwrap(),
|
||||||
|
|
|
@ -47,9 +47,9 @@ fn build() {
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
let mut keys = IndexSet::new();
|
let mut keys = IndexSet::new();
|
||||||
keys.extend(0..100_u8);
|
keys.extend(0..1000_u16);
|
||||||
|
|
||||||
let map_builder = BuiltMap::<u8, u8>::builder(keys);
|
let map_builder = BuiltMap::<u16, u16>::builder(keys);
|
||||||
|
|
||||||
let map_builder2 = map_builder.clone();
|
let map_builder2 = map_builder.clone();
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ fn build() {
|
||||||
for key in keys_needed {
|
for key in keys_needed {
|
||||||
println!("Thread1: {}", key);
|
println!("Thread1: {}", key);
|
||||||
work.insert_next_value(*key).unwrap();
|
work.insert_next_value(*key).unwrap();
|
||||||
std::thread::sleep(Duration::from_millis(100));
|
std::thread::sleep(Duration::from_millis(10));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ fn build() {
|
||||||
for key in keys_needed {
|
for key in keys_needed {
|
||||||
println!("Thread2: {}", key);
|
println!("Thread2: {}", key);
|
||||||
work.insert_next_value(*key).unwrap();
|
work.insert_next_value(*key).unwrap();
|
||||||
std::thread::sleep(Duration::from_millis(100));
|
std::thread::sleep(Duration::from_millis(10));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ fn build() {
|
||||||
for key in keys_needed {
|
for key in keys_needed {
|
||||||
println!("Thread3: {}", key);
|
println!("Thread3: {}", key);
|
||||||
work.insert_next_value(*key).unwrap();
|
work.insert_next_value(*key).unwrap();
|
||||||
std::thread::sleep(Duration::from_millis(100));
|
std::thread::sleep(Duration::from_millis(10));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue