diff --git a/concurrent-map-builder/src/builder.rs b/concurrent-map-builder/src/builder.rs index a9151039..f5ae91fb 100644 --- a/concurrent-map-builder/src/builder.rs +++ b/concurrent-map-builder/src/builder.rs @@ -2,7 +2,7 @@ use std::{ cell::UnsafeCell, cmp::min, hash::Hash, - mem::{needs_drop, MaybeUninit}, + mem::{needs_drop, ManuallyDrop, MaybeUninit}, ops::Range, sync::{ atomic::{AtomicUsize, Ordering}, @@ -25,7 +25,7 @@ pub(crate) struct SharedConcurrentMapBuilder { /// Values that we are initialising, will be the length of `index_set`. /// /// The index for a keys value is given by the keys index in `index_set`. - values: Option>>>, + values: Option>>>>, /// A marker for if a value in `values` is initialised. initialised_values: Vec>, @@ -40,11 +40,11 @@ unsafe impl Sync for SharedConcurrentMapBuilder {} impl SharedConcurrentMapBuilder { /// Returns a new [`SharedConcurrentMapBuilder`], with the keys needed in an [`IndexSet`]. pub fn new(keys_needed: IndexSet) -> SharedConcurrentMapBuilder { - let values = Some( + let values = Some(ManuallyDrop::new( (0..keys_needed.len()) .map(|_| UnsafeCell::new(MaybeUninit::uninit())) .collect(), - ); + )); let initialised_values = (0..keys_needed.len()) .map(|_| UnsafeCell::new(false)) .collect(); @@ -64,8 +64,9 @@ impl Drop for SharedConcurrentMapBuilder { // Values in a MaybeUninit will not be dropped so we need to drop them manually. // This will only be ran when all workers have dropped their handles. - if needs_drop::() { - if let Some(values) = &self.values { + + if let Some(values) = &mut self.values { + if needs_drop::() { for init_value in self .initialised_values .iter() @@ -86,6 +87,12 @@ impl Drop for SharedConcurrentMapBuilder { unsafe { value.assume_init_drop() } } } + + // SAFETY: + // This is drop code this is the only reference and this will not be used after being dropped. + unsafe { + ManuallyDrop::drop(values); + } } } } @@ -111,8 +118,8 @@ impl ConcurrentMapBuilder { return Err(*err); } - // TODO: can we use a weaker Ordering? - let start = self.0.current_index.fetch_add(amt, Ordering::SeqCst); + // TODO: We can use Relaxed, explain why. + let start = self.0.current_index.fetch_add(amt, Ordering::Relaxed); if start >= values.len() { // No work to do, all given out. @@ -142,7 +149,7 @@ impl ConcurrentMapBuilder { return Err(*err); } - let values = inner.values.take().unwrap(); + let mut values = inner.values.take().unwrap(); if inner.current_index.load(Ordering::Relaxed) < values.len() { return Err(ConcurrentMapBuilderError::WorkWasNotFinishedBeforeInit); @@ -152,8 +159,15 @@ impl ConcurrentMapBuilder { // - UnsafeCell> has the same bit pattern as T. // - If any value is unitised that means work wasn't handed out which we just // checked for, or work handed out was not completed which is checked for in - // the Drop impl of MapBuilderWork. - let values: Vec = unsafe { std::mem::transmute(values) }; + // the Drop impl of MapBuilderWork and if the Drop impl is not ran then there + // will be a reference on the Arc so this code will not be reached. + // - The `values` Vec is wrapped in ManuallyDrop so the inner vec will not be dropped. + let values: Vec = unsafe { + let capacity = values.capacity(); + let len = values.len(); + + Vec::from_raw_parts(values.as_mut_ptr() as *mut V, len, capacity) + }; Ok(Some(BuiltMap { index_set: inner.index_set.take().unwrap(), diff --git a/concurrent-map-builder/src/lib.rs b/concurrent-map-builder/src/lib.rs index 6f597868..002a052e 100644 --- a/concurrent-map-builder/src/lib.rs +++ b/concurrent-map-builder/src/lib.rs @@ -47,9 +47,9 @@ fn build() { use std::time::Duration; let mut keys = IndexSet::new(); - keys.extend(0..100_u8); + keys.extend(0..1000_u16); - let map_builder = BuiltMap::::builder(keys); + let map_builder = BuiltMap::::builder(keys); let map_builder2 = map_builder.clone(); @@ -63,7 +63,7 @@ fn build() { for key in keys_needed { println!("Thread1: {}", key); work.insert_next_value(*key).unwrap(); - std::thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(10)); } }); @@ -79,7 +79,7 @@ fn build() { for key in keys_needed { println!("Thread2: {}", key); work.insert_next_value(*key).unwrap(); - std::thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(10)); } }); @@ -95,7 +95,7 @@ fn build() { for key in keys_needed { println!("Thread3: {}", key); work.insert_next_value(*key).unwrap(); - std::thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(10)); } });