diff --git a/p2p/async-buffer/Cargo.toml b/p2p/async-buffer/Cargo.toml new file mode 100644 index 00000000..59f04301 --- /dev/null +++ b/p2p/async-buffer/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "async-buffer" +version = "0.1.0" +edition = "2021" +license = "MIT" +authors = ["Boog900"] + +[dependencies] +thiserror = { workspace = true } +futures = { workspace = true, features = ["std"] } +pin-project = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } \ No newline at end of file diff --git a/p2p/async-buffer/src/lib.rs b/p2p/async-buffer/src/lib.rs new file mode 100644 index 00000000..ded8c6a9 --- /dev/null +++ b/p2p/async-buffer/src/lib.rs @@ -0,0 +1,205 @@ +//! Async Buffer +//! +//! A bounded SPSC, FIFO, async buffer that supports arbitrary weights for values. +//! +//! Weight is used to bound the channel, on creation you specify a max weight and for each value you +//! specify a weight. +use std::{ + cmp::min, + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use futures::{ + channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}, + ready, + task::AtomicWaker, + Stream, StreamExt, +}; + +#[derive(thiserror::Error, Debug, Copy, Clone, Eq, PartialEq)] +pub enum BufferError { + #[error("The buffer did not have enough capacity.")] + NotEnoughCapacity, + #[error("The other end of the buffer disconnected.")] + Disconnected, +} + +/// Initializes a new buffer with the provided capacity. +/// +/// The capacity inputted is not the max number of items, it is the max combined weight of all items +/// in the buffer. +/// +/// It should be noted that if there are no items in the buffer then a single item of any capacity is accepted. +/// i.e. if the capacity is 5 and there are no items in the buffer then any item even if it's weight is >5 will be +/// accepted. +pub fn new_buffer(max_item_weight: usize) -> (BufferAppender, BufferStream) { + let (tx, rx) = unbounded(); + let sink_waker = Arc::new(AtomicWaker::new()); + let capacity_atomic = Arc::new(AtomicUsize::new(max_item_weight)); + + ( + BufferAppender { + queue: tx, + sink_waker: sink_waker.clone(), + capacity: capacity_atomic.clone(), + max_item_weight: capacity, + }, + BufferStream { + queue: rx, + sink_waker, + capacity: capacity_atomic, + }, + ) +} + +/// The stream side of the buffer. +pub struct BufferStream { + /// The internal queue of items. + queue: UnboundedReceiver<(T, usize)>, + /// The waker for the [`BufferAppender`] + sink_waker: Arc, + /// The current capacity of the buffer. + capacity: Arc, +} + +impl Stream for BufferStream { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Some((item, size)) = ready!(self.queue.poll_next_unpin(cx)) else { + return Poll::Ready(None); + }; + + // add the capacity back to the buffer. + self.capacity.fetch_add(size, Ordering::AcqRel); + // wake the sink. + self.sink_waker.wake(); + + Poll::Ready(Some(item)) + } +} + +/// The appender/sink side of the buffer. +pub struct BufferAppender { + /// The internal queue of items. + queue: UnboundedSender<(T, usize)>, + /// Our waker. + sink_waker: Arc, + /// The current capacity of the buffer. + capacity: Arc, + /// The max weight of an item, equal to the total allowed weight of the buffer. + max_item_weight: usize, +} + +impl BufferAppender { + /// Returns a future that resolves when the channel has enough capacity for + /// a single message of `size_needed`. + /// + /// It should be noted that if there are no items in the buffer then a single item of any capacity is accepted. + /// i.e. if the capacity is 5 and there are no items in the buffer then any item even if it's weight is >5 will be + /// accepted. + pub fn ready(&mut self, size_needed: usize) -> BufferSinkReady<'_, T> { + let size_needed = min(self.max_item_weight, size_needed); + + BufferSinkReady { + sink: self, + size_needed, + } + } + + /// Attempts to add an item to the buffer. + /// + /// # Errors + /// Returns an error if there is not enough capacity or the [`BufferStream`] was dropped. + pub fn try_send(&mut self, item: T, size_needed: usize) -> Result<(), BufferError> { + let size_needed = min(self.max_item_weight, size_needed); + + if self.capacity.load(Ordering::Acquire) < size_needed { + return Err(BufferError::NotEnoughCapacity); + } + + let prev_size = self.capacity.fetch_sub(size_needed, Ordering::AcqRel); + + // make sure we haven't wrapped the capacity around. + assert!(prev_size >= size_needed); + + self.queue + .unbounded_send((item, size_needed)) + .map_err(|_| BufferError::Disconnected)?; + + Ok(()) + } + + /// Waits for capacity in the buffer and then sends the item. + pub fn send(&mut self, item: T, size_needed: usize) -> BufferSinkSend<'_, T> { + BufferSinkSend { + ready: self.ready(size_needed), + item: Some(item), + } + } +} + +/// A [`Future`] for adding an item to the buffer. +#[pin_project::pin_project] +pub struct BufferSinkSend<'a, T> { + /// A future that resolves when the channel has capacity. + #[pin] + ready: BufferSinkReady<'a, T>, + /// The item to send. + /// + /// This is [`take`](Option::take)n and added to the buffer when there is enough capacity. + item: Option, +} + +impl<'a, T> Future for BufferSinkSend<'a, T> { + type Output = Result<(), BufferError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + let size_needed = this.ready.size_needed; + + this.ready.as_mut().poll(cx).map(|_| { + this.ready + .sink + .try_send(this.item.take().unwrap(), size_needed) + }) + } +} + +/// A [`Future`] for waiting for capacity in the buffer. +pub struct BufferSinkReady<'a, T> { + /// The sink side of the buffer. + sink: &'a mut BufferAppender, + /// The capacity needed. + /// + /// This future will wait forever if this is higher than the total availability of the buffer. + size_needed: usize, +} + +impl<'a, T> Future for BufferSinkReady<'a, T> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Check before setting the waker just in case it has capacity now, + if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed { + return Poll::Ready(()); + } + + // set the waker + self.sink.sink_waker.register(cx.waker()); + + // check the capacity again to avoid a race condition that would result in lost notifications. + if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed { + Poll::Ready(()) + } else { + Poll::Pending + } + } +} diff --git a/p2p/async-buffer/tests/basic.rs b/p2p/async-buffer/tests/basic.rs new file mode 100644 index 00000000..93717300 --- /dev/null +++ b/p2p/async-buffer/tests/basic.rs @@ -0,0 +1,37 @@ +use futures::{FutureExt, StreamExt}; + +use async_buffer::new_buffer; + +#[tokio::test] +async fn async_buffer_send_rec() { + let (mut tx, mut rx) = new_buffer(1000); + + tx.send(4, 5).await.unwrap(); + tx.send(8, 5).await.unwrap(); + + assert_eq!(rx.next().await.unwrap(), 4); + assert_eq!(rx.next().await.unwrap(), 8); +} + +#[tokio::test] +async fn capacity_reached() { + let (mut tx, mut rx) = new_buffer(1000); + + tx.send(4, 1000).await.unwrap(); + + assert!(tx.ready(1).now_or_never().is_none()); + + let fut = tx.ready(1); + + rx.next().await; + + assert!(fut.now_or_never().is_some()); +} + +#[tokio::test] +async fn single_item_over_capacity() { + let (mut tx, mut rx) = new_buffer(1000); + tx.send(4, 1_000_000).await.unwrap(); + + assert_eq!(rx.next().await.unwrap(), 4); +}