add an iterator for ByteArrayVec

This commit is contained in:
Boog900 2024-06-29 00:12:38 +01:00
parent 7e9891de5b
commit cf46d58d17
No known key found for this signature in database
GPG key ID: 42AB1287CB0041C2
4 changed files with 193 additions and 112 deletions

View file

@ -0,0 +1,119 @@
mod into_iter;
use core::ops::Index;
use bytes::{BufMut, Bytes, BytesMut};
use crate::FixedByteError;
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ByteArrayVec<const N: usize>(Bytes);
impl<const N: usize> ByteArrayVec<N> {
pub fn len(&self) -> usize {
self.0.len() / N
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn take_bytes(self) -> Bytes {
self.0
}
/// Splits the byte array vec into two at the given index.
///
/// Afterwards self contains elements [0, at), and the returned [`ByteArrayVec`] contains elements [at, len).
///
/// This is an O(1) operation that just increases the reference count and sets a few indices.
///
/// # Panics
/// Panics if at > len.
pub fn split_off(&mut self, at: usize) -> Self {
Self(self.0.split_off(at * N))
}
}
impl<const N: usize> IntoIterator for ByteArrayVec<N> {
type Item = [u8; N];
type IntoIter = into_iter::ByteArrayVecIterator<N>;
fn into_iter(self) -> Self::IntoIter {
into_iter::ByteArrayVecIterator(self.0)
}
}
impl<const N: usize> From<ByteArrayVec<N>> for Vec<[u8; N]> {
fn from(value: ByteArrayVec<N>) -> Self {
value.into_iter().collect()
}
}
impl<const N: usize> From<Vec<[u8; N]>> for ByteArrayVec<N> {
fn from(value: Vec<[u8; N]>) -> Self {
let mut bytes = BytesMut::with_capacity(N * value.len());
for i in value.into_iter() {
bytes.extend_from_slice(&i)
}
ByteArrayVec(bytes.freeze())
}
}
impl<const N: usize> TryFrom<Bytes> for ByteArrayVec<N> {
type Error = FixedByteError;
fn try_from(value: Bytes) -> Result<Self, Self::Error> {
if value.len() % N != 0 {
return Err(FixedByteError::InvalidLength);
}
Ok(ByteArrayVec(value))
}
}
impl<const N: usize> From<[u8; N]> for ByteArrayVec<N> {
fn from(value: [u8; N]) -> Self {
ByteArrayVec(Bytes::copy_from_slice(value.as_slice()))
}
}
impl<const N: usize, const LEN: usize> From<[[u8; N]; LEN]> for ByteArrayVec<N> {
fn from(value: [[u8; N]; LEN]) -> Self {
let mut bytes = BytesMut::with_capacity(N * LEN);
for val in value.into_iter() {
bytes.put_slice(val.as_slice());
}
ByteArrayVec(bytes.freeze())
}
}
impl<const N: usize> TryFrom<Vec<u8>> for ByteArrayVec<N> {
type Error = FixedByteError;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
if value.len() % N != 0 {
return Err(FixedByteError::InvalidLength);
}
Ok(ByteArrayVec(Bytes::from(value)))
}
}
impl<const N: usize> Index<usize> for ByteArrayVec<N> {
type Output = [u8; N];
fn index(&self, index: usize) -> &Self::Output {
if (index + 1) * N > self.0.len() {
panic!("Index out of range, idx: {}, length: {}", index, self.len());
}
self.0[index * N..(index + 1) * N]
.as_ref()
.try_into()
.unwrap()
}
}

View file

@ -0,0 +1,66 @@
use bytes::{Buf, Bytes};
pub struct ByteArrayVecIterator<const N: usize>(pub(crate) Bytes);
impl<const N: usize> Iterator for ByteArrayVecIterator<N> {
type Item = [u8; N];
fn next(&mut self) -> Option<Self::Item> {
if self.0.is_empty() {
return None;
}
let next = self.0[..N].try_into().unwrap();
self.0.advance(N);
Some(next)
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.0.len() / N, Some(self.0.len() / N))
}
fn last(self) -> Option<Self::Item>
where
Self: Sized,
{
if self.0.is_empty() {
return None;
}
Some(self.0[self.0.len() - N..].try_into().unwrap())
}
fn nth(&mut self, n: usize) -> Option<Self::Item> {
let iters_left = self.0.len() / N;
if iters_left.checked_sub(n).is_none() {
return None;
}
self.0.advance(n * N - N);
self.next()
}
}
impl<const N: usize> DoubleEndedIterator for ByteArrayVecIterator<N> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.0.is_empty() {
return None;
}
Some(self.0[self.0.len() - N..].try_into().unwrap())
}
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
let iters_left = self.0.len() / N;
if iters_left.checked_sub(n).is_none() {
return None;
}
self.0.truncate(self.0.len() - n * N - N);
self.next_back()
}
}

View file

@ -1,9 +1,12 @@
use core::{
fmt::{Debug, Formatter},
ops::{Deref, Index},
ops::Deref,
};
use bytes::{BufMut, Bytes, BytesMut};
use bytes::Bytes;
mod byte_array_vec;
pub use byte_array_vec::ByteArrayVec;
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum FixedByteError {
@ -87,114 +90,6 @@ impl<const N: usize> TryFrom<Vec<u8>> for ByteArray<N> {
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ByteArrayVec<const N: usize>(Bytes);
impl<const N: usize> ByteArrayVec<N> {
pub fn len(&self) -> usize {
self.0.len() / N
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn take_bytes(self) -> Bytes {
self.0
}
/// Splits the byte array vec into two at the given index.
///
/// Afterwards self contains elements [0, at), and the returned [`ByteArrayVec`] contains elements [at, len).
///
/// This is an O(1) operation that just increases the reference count and sets a few indices.
///
/// # Panics
/// Panics if at > len.
pub fn split_off(&mut self, at: usize) -> Self {
Self(self.0.split_off(at * N))
}
}
impl<const N: usize> From<&ByteArrayVec<N>> for Vec<[u8; N]> {
fn from(value: &ByteArrayVec<N>) -> Self {
let mut out = Vec::with_capacity(value.len());
for i in 0..value.len() {
out.push(value[i])
}
out
}
}
impl<const N: usize> From<Vec<[u8; N]>> for ByteArrayVec<N> {
fn from(value: Vec<[u8; N]>) -> Self {
let mut bytes = BytesMut::with_capacity(N * value.len());
for i in value.into_iter() {
bytes.extend_from_slice(&i)
}
ByteArrayVec(bytes.freeze())
}
}
impl<const N: usize> TryFrom<Bytes> for ByteArrayVec<N> {
type Error = FixedByteError;
fn try_from(value: Bytes) -> Result<Self, Self::Error> {
if value.len() % N != 0 {
return Err(FixedByteError::InvalidLength);
}
Ok(ByteArrayVec(value))
}
}
impl<const N: usize> From<[u8; N]> for ByteArrayVec<N> {
fn from(value: [u8; N]) -> Self {
ByteArrayVec(Bytes::copy_from_slice(value.as_slice()))
}
}
impl<const N: usize, const LEN: usize> From<[[u8; N]; LEN]> for ByteArrayVec<N> {
fn from(value: [[u8; N]; LEN]) -> Self {
let mut bytes = BytesMut::with_capacity(N * LEN);
for val in value.into_iter() {
bytes.put_slice(val.as_slice());
}
ByteArrayVec(bytes.freeze())
}
}
impl<const N: usize> TryFrom<Vec<u8>> for ByteArrayVec<N> {
type Error = FixedByteError;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
if value.len() % N != 0 {
return Err(FixedByteError::InvalidLength);
}
Ok(ByteArrayVec(Bytes::from(value)))
}
}
impl<const N: usize> Index<usize> for ByteArrayVec<N> {
type Output = [u8; N];
fn index(&self, index: usize) -> &Self::Output {
if (index + 1) * N > self.0.len() {
panic!("Index out of range, idx: {}, length: {}", index, self.len());
}
self.0[index * N..(index + 1) * N]
.as_ref()
.try_into()
.unwrap()
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -68,7 +68,7 @@ pub async fn request_chain_entry_from_peer<N: NetworkZone>(
}
let entry = ChainEntry {
ids: (&chain_res.m_block_ids).into(),
ids: chain_res.m_block_ids.into(),
peer: client.info.id,
handle: client.info.handle.clone(),
};
@ -191,7 +191,8 @@ where
return Err(BlockDownloadError::FailedToFindAChainToFollow);
};
let hashes: Vec<[u8; 32]> = (&chain_res.m_block_ids).into();
// .clone here is not a full clone as the underlying data is [`Bytes`].
let hashes: Vec<[u8; 32]> = chain_res.m_block_ids.clone().into();
// drop this to deallocate the [`Bytes`].
drop(chain_res);