fix UB + use weaker ordering

transmuting the Vec is not safe as Rust could make layout optimisations for `Vec<T>`` that it can't for `Vec<UnsafeCell<MaybeUninit<V>>>``
This commit is contained in:
Boog900 2024-04-03 03:17:42 +01:00
parent 0dbe783a45
commit 11bf3900ac
No known key found for this signature in database
GPG key ID: 42AB1287CB0041C2
2 changed files with 30 additions and 16 deletions

View file

@ -2,7 +2,7 @@ use std::{
cell::UnsafeCell,
cmp::min,
hash::Hash,
mem::{needs_drop, MaybeUninit},
mem::{needs_drop, ManuallyDrop, MaybeUninit},
ops::Range,
sync::{
atomic::{AtomicUsize, Ordering},
@ -25,7 +25,7 @@ pub(crate) struct SharedConcurrentMapBuilder<K, V> {
/// Values that we are initialising, will be the length of `index_set`.
///
/// The index for a keys value is given by the keys index in `index_set`.
values: Option<Vec<UnsafeCell<MaybeUninit<V>>>>,
values: Option<ManuallyDrop<Vec<UnsafeCell<MaybeUninit<V>>>>>,
/// A marker for if a value in `values` is initialised.
initialised_values: Vec<UnsafeCell<bool>>,
@ -40,11 +40,11 @@ unsafe impl<K: Sync, V: Sync> Sync for SharedConcurrentMapBuilder<K, V> {}
impl<K, V> SharedConcurrentMapBuilder<K, V> {
/// Returns a new [`SharedConcurrentMapBuilder`], with the keys needed in an [`IndexSet`].
pub fn new(keys_needed: IndexSet<K>) -> SharedConcurrentMapBuilder<K, V> {
let values = Some(
let values = Some(ManuallyDrop::new(
(0..keys_needed.len())
.map(|_| UnsafeCell::new(MaybeUninit::uninit()))
.collect(),
);
));
let initialised_values = (0..keys_needed.len())
.map(|_| UnsafeCell::new(false))
.collect();
@ -64,8 +64,9 @@ impl<K, V> Drop for SharedConcurrentMapBuilder<K, V> {
// Values in a MaybeUninit will not be dropped so we need to drop them manually.
// This will only be ran when all workers have dropped their handles.
if let Some(values) = &mut self.values {
if needs_drop::<V>() {
if let Some(values) = &self.values {
for init_value in self
.initialised_values
.iter()
@ -86,6 +87,12 @@ impl<K, V> Drop for SharedConcurrentMapBuilder<K, V> {
unsafe { value.assume_init_drop() }
}
}
// SAFETY:
// This is drop code this is the only reference and this will not be used after being dropped.
unsafe {
ManuallyDrop::drop(values);
}
}
}
}
@ -111,8 +118,8 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
return Err(*err);
}
// TODO: can we use a weaker Ordering?
let start = self.0.current_index.fetch_add(amt, Ordering::SeqCst);
// TODO: We can use Relaxed, explain why.
let start = self.0.current_index.fetch_add(amt, Ordering::Relaxed);
if start >= values.len() {
// No work to do, all given out.
@ -142,7 +149,7 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
return Err(*err);
}
let values = inner.values.take().unwrap();
let mut values = inner.values.take().unwrap();
if inner.current_index.load(Ordering::Relaxed) < values.len() {
return Err(ConcurrentMapBuilderError::WorkWasNotFinishedBeforeInit);
@ -152,8 +159,15 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
// - UnsafeCell<MaybeUninit<T>> has the same bit pattern as T.
// - If any value is unitised that means work wasn't handed out which we just
// checked for, or work handed out was not completed which is checked for in
// the Drop impl of MapBuilderWork.
let values: Vec<V> = unsafe { std::mem::transmute(values) };
// the Drop impl of MapBuilderWork and if the Drop impl is not ran then there
// will be a reference on the Arc so this code will not be reached.
// - The `values` Vec is wrapped in ManuallyDrop so the inner vec will not be dropped.
let values: Vec<V> = unsafe {
let capacity = values.capacity();
let len = values.len();
Vec::from_raw_parts(values.as_mut_ptr() as *mut V, len, capacity)
};
Ok(Some(BuiltMap {
index_set: inner.index_set.take().unwrap(),

View file

@ -47,9 +47,9 @@ fn build() {
use std::time::Duration;
let mut keys = IndexSet::new();
keys.extend(0..100_u8);
keys.extend(0..1000_u16);
let map_builder = BuiltMap::<u8, u8>::builder(keys);
let map_builder = BuiltMap::<u16, u16>::builder(keys);
let map_builder2 = map_builder.clone();
@ -63,7 +63,7 @@ fn build() {
for key in keys_needed {
println!("Thread1: {}", key);
work.insert_next_value(*key).unwrap();
std::thread::sleep(Duration::from_millis(100));
std::thread::sleep(Duration::from_millis(10));
}
});
@ -79,7 +79,7 @@ fn build() {
for key in keys_needed {
println!("Thread2: {}", key);
work.insert_next_value(*key).unwrap();
std::thread::sleep(Duration::from_millis(100));
std::thread::sleep(Duration::from_millis(10));
}
});
@ -95,7 +95,7 @@ fn build() {
for key in keys_needed {
println!("Thread3: {}", key);
work.insert_next_value(*key).unwrap();
std::thread::sleep(Duration::from_millis(100));
std::thread::sleep(Duration::from_millis(10));
}
});