//! Async Buffer //! //! A bounded SPSC, FIFO, async buffer that supports arbitrary weights for values. //! //! Weight is used to bound the channel, on creation you specify a max weight and for each value you //! specify a weight. use std::{ cmp::min, future::Future, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, Arc, }, task::{Context, Poll}, }; use futures::{ channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}, ready, task::AtomicWaker, Stream, StreamExt, }; #[derive(thiserror::Error, Debug, Copy, Clone, Eq, PartialEq)] pub enum BufferError { #[error("The buffer did not have enough capacity.")] NotEnoughCapacity, #[error("The other end of the buffer disconnected.")] Disconnected, } /// Initializes a new buffer with the provided capacity. /// /// The capacity inputted is not the max number of items, it is the max combined weight of all items /// in the buffer. /// /// It should be noted that if there are no items in the buffer then a single item of any capacity is accepted. /// i.e. if the capacity is 5 and there are no items in the buffer then any item even if it's weight is >5 will be /// accepted. pub fn new_buffer(max_item_weight: usize) -> (BufferAppender, BufferStream) { let (tx, rx) = unbounded(); let sink_waker = Arc::new(AtomicWaker::new()); let capacity_atomic = Arc::new(AtomicUsize::new(max_item_weight)); ( BufferAppender { queue: tx, sink_waker: sink_waker.clone(), capacity: capacity_atomic.clone(), max_item_weight, }, BufferStream { queue: rx, sink_waker, capacity: capacity_atomic, }, ) } /// The stream side of the buffer. pub struct BufferStream { /// The internal queue of items. queue: UnboundedReceiver<(T, usize)>, /// The waker for the [`BufferAppender`] sink_waker: Arc, /// The current capacity of the buffer. capacity: Arc, } impl Stream for BufferStream { type Item = T; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Some((item, size)) = ready!(self.queue.poll_next_unpin(cx)) else { return Poll::Ready(None); }; // add the capacity back to the buffer. self.capacity.fetch_add(size, Ordering::AcqRel); // wake the sink. self.sink_waker.wake(); Poll::Ready(Some(item)) } } /// The appender/sink side of the buffer. pub struct BufferAppender { /// The internal queue of items. queue: UnboundedSender<(T, usize)>, /// Our waker. sink_waker: Arc, /// The current capacity of the buffer. capacity: Arc, /// The max weight of an item, equal to the total allowed weight of the buffer. max_item_weight: usize, } impl BufferAppender { /// Returns a future that resolves when the channel has enough capacity for /// a single message of `size_needed`. /// /// It should be noted that if there are no items in the buffer then a single item of any capacity is accepted. /// i.e. if the capacity is 5 and there are no items in the buffer then any item even if it's weight is >5 will be /// accepted. pub fn ready(&mut self, size_needed: usize) -> BufferSinkReady<'_, T> { let size_needed = min(self.max_item_weight, size_needed); BufferSinkReady { sink: self, size_needed, } } /// Attempts to add an item to the buffer. /// /// # Errors /// Returns an error if there is not enough capacity or the [`BufferStream`] was dropped. pub fn try_send(&mut self, item: T, size_needed: usize) -> Result<(), BufferError> { let size_needed = min(self.max_item_weight, size_needed); if self.capacity.load(Ordering::Acquire) < size_needed { return Err(BufferError::NotEnoughCapacity); } let prev_size = self.capacity.fetch_sub(size_needed, Ordering::AcqRel); // make sure we haven't wrapped the capacity around. assert!(prev_size >= size_needed); self.queue .unbounded_send((item, size_needed)) .map_err(|_| BufferError::Disconnected)?; Ok(()) } /// Waits for capacity in the buffer and then sends the item. pub fn send(&mut self, item: T, size_needed: usize) -> BufferSinkSend<'_, T> { BufferSinkSend { ready: self.ready(size_needed), item: Some(item), } } } /// A [`Future`] for adding an item to the buffer. #[pin_project::pin_project] pub struct BufferSinkSend<'a, T> { /// A future that resolves when the channel has capacity. #[pin] ready: BufferSinkReady<'a, T>, /// The item to send. /// /// This is [`take`](Option::take)n and added to the buffer when there is enough capacity. item: Option, } impl Future for BufferSinkSend<'_, T> { type Output = Result<(), BufferError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); let size_needed = this.ready.size_needed; this.ready.as_mut().poll(cx).map(|_| { this.ready .sink .try_send(this.item.take().unwrap(), size_needed) }) } } /// A [`Future`] for waiting for capacity in the buffer. pub struct BufferSinkReady<'a, T> { /// The sink side of the buffer. sink: &'a mut BufferAppender, /// The capacity needed. /// /// This future will wait forever if this is higher than the total availability of the buffer. size_needed: usize, } impl Future for BufferSinkReady<'_, T> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // Check before setting the waker just in case it has capacity now, if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed { return Poll::Ready(()); } // set the waker self.sink.sink_waker.register(cx.waker()); // check the capacity again to avoid a race condition that would result in lost notifications. if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed { Poll::Ready(()) } else { Poll::Pending } } }