fix UB + use weaker ordering

transmuting the Vec is not safe as Rust could make layout optimisations for `Vec<T>`` that it can't for `Vec<UnsafeCell<MaybeUninit<V>>>``
This commit is contained in:
Boog900 2024-04-03 03:17:42 +01:00
parent 0dbe783a45
commit 11bf3900ac
No known key found for this signature in database
GPG key ID: 42AB1287CB0041C2
2 changed files with 30 additions and 16 deletions

View file

@ -2,7 +2,7 @@ use std::{
cell::UnsafeCell, cell::UnsafeCell,
cmp::min, cmp::min,
hash::Hash, hash::Hash,
mem::{needs_drop, MaybeUninit}, mem::{needs_drop, ManuallyDrop, MaybeUninit},
ops::Range, ops::Range,
sync::{ sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
@ -25,7 +25,7 @@ pub(crate) struct SharedConcurrentMapBuilder<K, V> {
/// Values that we are initialising, will be the length of `index_set`. /// Values that we are initialising, will be the length of `index_set`.
/// ///
/// The index for a keys value is given by the keys index in `index_set`. /// The index for a keys value is given by the keys index in `index_set`.
values: Option<Vec<UnsafeCell<MaybeUninit<V>>>>, values: Option<ManuallyDrop<Vec<UnsafeCell<MaybeUninit<V>>>>>,
/// A marker for if a value in `values` is initialised. /// A marker for if a value in `values` is initialised.
initialised_values: Vec<UnsafeCell<bool>>, initialised_values: Vec<UnsafeCell<bool>>,
@ -40,11 +40,11 @@ unsafe impl<K: Sync, V: Sync> Sync for SharedConcurrentMapBuilder<K, V> {}
impl<K, V> SharedConcurrentMapBuilder<K, V> { impl<K, V> SharedConcurrentMapBuilder<K, V> {
/// Returns a new [`SharedConcurrentMapBuilder`], with the keys needed in an [`IndexSet`]. /// Returns a new [`SharedConcurrentMapBuilder`], with the keys needed in an [`IndexSet`].
pub fn new(keys_needed: IndexSet<K>) -> SharedConcurrentMapBuilder<K, V> { pub fn new(keys_needed: IndexSet<K>) -> SharedConcurrentMapBuilder<K, V> {
let values = Some( let values = Some(ManuallyDrop::new(
(0..keys_needed.len()) (0..keys_needed.len())
.map(|_| UnsafeCell::new(MaybeUninit::uninit())) .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
.collect(), .collect(),
); ));
let initialised_values = (0..keys_needed.len()) let initialised_values = (0..keys_needed.len())
.map(|_| UnsafeCell::new(false)) .map(|_| UnsafeCell::new(false))
.collect(); .collect();
@ -64,8 +64,9 @@ impl<K, V> Drop for SharedConcurrentMapBuilder<K, V> {
// Values in a MaybeUninit will not be dropped so we need to drop them manually. // Values in a MaybeUninit will not be dropped so we need to drop them manually.
// This will only be ran when all workers have dropped their handles. // This will only be ran when all workers have dropped their handles.
if let Some(values) = &mut self.values {
if needs_drop::<V>() { if needs_drop::<V>() {
if let Some(values) = &self.values {
for init_value in self for init_value in self
.initialised_values .initialised_values
.iter() .iter()
@ -86,6 +87,12 @@ impl<K, V> Drop for SharedConcurrentMapBuilder<K, V> {
unsafe { value.assume_init_drop() } unsafe { value.assume_init_drop() }
} }
} }
// SAFETY:
// This is drop code this is the only reference and this will not be used after being dropped.
unsafe {
ManuallyDrop::drop(values);
}
} }
} }
} }
@ -111,8 +118,8 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
return Err(*err); return Err(*err);
} }
// TODO: can we use a weaker Ordering? // TODO: We can use Relaxed, explain why.
let start = self.0.current_index.fetch_add(amt, Ordering::SeqCst); let start = self.0.current_index.fetch_add(amt, Ordering::Relaxed);
if start >= values.len() { if start >= values.len() {
// No work to do, all given out. // No work to do, all given out.
@ -142,7 +149,7 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
return Err(*err); return Err(*err);
} }
let values = inner.values.take().unwrap(); let mut values = inner.values.take().unwrap();
if inner.current_index.load(Ordering::Relaxed) < values.len() { if inner.current_index.load(Ordering::Relaxed) < values.len() {
return Err(ConcurrentMapBuilderError::WorkWasNotFinishedBeforeInit); return Err(ConcurrentMapBuilderError::WorkWasNotFinishedBeforeInit);
@ -152,8 +159,15 @@ impl<K, V> ConcurrentMapBuilder<K, V> {
// - UnsafeCell<MaybeUninit<T>> has the same bit pattern as T. // - UnsafeCell<MaybeUninit<T>> has the same bit pattern as T.
// - If any value is unitised that means work wasn't handed out which we just // - If any value is unitised that means work wasn't handed out which we just
// checked for, or work handed out was not completed which is checked for in // checked for, or work handed out was not completed which is checked for in
// the Drop impl of MapBuilderWork. // the Drop impl of MapBuilderWork and if the Drop impl is not ran then there
let values: Vec<V> = unsafe { std::mem::transmute(values) }; // will be a reference on the Arc so this code will not be reached.
// - The `values` Vec is wrapped in ManuallyDrop so the inner vec will not be dropped.
let values: Vec<V> = unsafe {
let capacity = values.capacity();
let len = values.len();
Vec::from_raw_parts(values.as_mut_ptr() as *mut V, len, capacity)
};
Ok(Some(BuiltMap { Ok(Some(BuiltMap {
index_set: inner.index_set.take().unwrap(), index_set: inner.index_set.take().unwrap(),

View file

@ -47,9 +47,9 @@ fn build() {
use std::time::Duration; use std::time::Duration;
let mut keys = IndexSet::new(); let mut keys = IndexSet::new();
keys.extend(0..100_u8); keys.extend(0..1000_u16);
let map_builder = BuiltMap::<u8, u8>::builder(keys); let map_builder = BuiltMap::<u16, u16>::builder(keys);
let map_builder2 = map_builder.clone(); let map_builder2 = map_builder.clone();
@ -63,7 +63,7 @@ fn build() {
for key in keys_needed { for key in keys_needed {
println!("Thread1: {}", key); println!("Thread1: {}", key);
work.insert_next_value(*key).unwrap(); work.insert_next_value(*key).unwrap();
std::thread::sleep(Duration::from_millis(100)); std::thread::sleep(Duration::from_millis(10));
} }
}); });
@ -79,7 +79,7 @@ fn build() {
for key in keys_needed { for key in keys_needed {
println!("Thread2: {}", key); println!("Thread2: {}", key);
work.insert_next_value(*key).unwrap(); work.insert_next_value(*key).unwrap();
std::thread::sleep(Duration::from_millis(100)); std::thread::sleep(Duration::from_millis(10));
} }
}); });
@ -95,7 +95,7 @@ fn build() {
for key in keys_needed { for key in keys_needed {
println!("Thread3: {}", key); println!("Thread3: {}", key);
work.insert_next_value(*key).unwrap(); work.insert_next_value(*key).unwrap();
std::thread::sleep(Duration::from_millis(100)); std::thread::sleep(Duration::from_millis(10));
} }
}); });