Re-work levin to remove a lot of the complexities (#24)

This commit is contained in:
Boog900 2023-07-13 21:10:52 +00:00 committed by GitHub
parent f9a735b51f
commit e6e8bdaf6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 421 additions and 665 deletions

View file

@ -1,15 +1,14 @@
[package]
name = "levin"
name = "levin-cuprate"
version = "0.1.0"
edition = "2021"
description = "A crate for working with the Levin protocol in Rust."
license = "MIT"
authors = ["Boog900"]
repository = "https://github.com/SyntheticBird45/cuprate/tree/main/net/levin"
repository = "https://github.com/Cuprate/cuprate/tree/main/net/levin"
[dependencies]
thiserror = "1.0.24"
byteorder = "1.4.3"
futures = "0.3"
thiserror = "1"
bytes = "1"
pin-project = "1"
tokio-util = {version = "0.7", features = ["codec"]}

View file

@ -1,107 +0,0 @@
// Rust Levin Library
// Written in 2023 by
// Cuprate Contributors
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
//! This module provides a `BucketSink` struct, which writes buckets to the
//! provided `AsyncWrite`. If you are a user of this library you should
//! probably use `MessageSink` instead.
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::Poll;
use bytes::{Buf, BytesMut};
use futures::ready;
use futures::sink::Sink;
use futures::AsyncWrite;
use pin_project::pin_project;
use crate::{Bucket, BucketError};
/// A BucketSink writes Bucket instances to the provided AsyncWrite target.
#[pin_project]
pub struct BucketSink<W> {
#[pin]
writer: W,
buffer: VecDeque<BytesMut>,
}
impl<W: AsyncWrite + std::marker::Unpin> BucketSink<W> {
/// Creates a new [`BucketSink`] from the given [`AsyncWrite`] writer.
pub fn new(writer: W) -> Self {
BucketSink {
writer,
buffer: VecDeque::with_capacity(2),
}
}
}
impl<W: AsyncWrite + std::marker::Unpin> Sink<Bucket> for BucketSink<W> {
type Error = BucketError;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
_: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn start_send(mut self: Pin<&mut Self>, item: Bucket) -> Result<(), Self::Error> {
let buf = item.to_bytes();
self.buffer.push_back(BytesMut::from(&buf[..]));
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let this = self.project();
let mut w = this.writer;
let buffer = this.buffer;
loop {
match ready!(w.as_mut().poll_flush(cx)) {
Err(err) => return Poll::Ready(Err(err.into())),
Ok(()) => {
if let Some(buf) = buffer.front() {
match ready!(w.as_mut().poll_write(cx, buf)) {
Err(e) => match e.kind() {
std::io::ErrorKind::WouldBlock => return std::task::Poll::Pending,
_ => return Poll::Ready(Err(e.into())),
},
Ok(len) => {
if len == buffer[0].len() {
buffer.pop_front();
} else {
buffer[0].advance(len);
}
}
}
} else {
return Poll::Ready(Ok(()));
}
}
}
}
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
ready!(self.project().writer.poll_close(cx))?;
Poll::Ready(Ok(()))
}
}

View file

@ -1,152 +0,0 @@
// Rust Levin Library
// Written in 2023 by
// Cuprate Contributors
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
//! This module provides a `BucketStream` struct, which is a stream of `Bucket`s,
//! where only the header is decoded. If you are a user of this library you should
//! probably use `MessageStream` instead.
use std::task::Poll;
use bytes::{Buf, BytesMut};
use futures::stream::Stream;
use futures::{ready, AsyncRead};
use pin_project::pin_project;
use super::{Bucket, BucketError, BucketHead};
/// An enum representing the decoding state of a `BucketStream`.
#[derive(Debug, Clone)]
enum BucketDecoder {
/// Waiting for the header of a `Bucket`.
WaitingForHeader,
/// Waiting for the body of a `Bucket` with the given header.
WaitingForBody(BucketHead),
}
impl BucketDecoder {
/// Returns the number of bytes needed to complete the current decoding state.
pub fn bytes_needed(&self) -> usize {
match self {
Self::WaitingForHeader => BucketHead::SIZE,
Self::WaitingForBody(bucket_head) => bucket_head.size as usize,
}
}
/// Tries to decode a `Bucket` from the given buffer, returning the decoded `Bucket` and the
/// number of bytes consumed from the buffer.
pub fn try_decode_bucket(
&mut self,
mut buf: &[u8],
) -> Result<(Option<Bucket>, usize), BucketError> {
let mut len = 0;
// first we decode header
if let BucketDecoder::WaitingForHeader = self {
if buf.len() < BucketHead::SIZE {
return Ok((None, 0));
}
let header = BucketHead::from_bytes(&mut buf)?;
len += BucketHead::SIZE;
*self = BucketDecoder::WaitingForBody(header);
};
// next we check we have enough bytes to fill the body
if let &mut Self::WaitingForBody(head) = self {
if buf.len() < head.size as usize {
return Ok((None, len));
}
*self = BucketDecoder::WaitingForHeader;
Ok((
Some(Bucket {
header: head,
body: buf.copy_to_bytes(buf.len()),
}),
len + head.size as usize,
))
} else {
unreachable!()
}
}
}
/// A stream of `Bucket`s, with only the header decoded.
#[pin_project]
#[derive(Debug, Clone)]
pub struct BucketStream<S> {
#[pin]
stream: S,
decoder: BucketDecoder,
buffer: BytesMut,
}
impl<S: AsyncRead> BucketStream<S> {
/// Creates a new `BucketStream` from the given `AsyncRead` stream.
pub fn new(stream: S) -> Self {
BucketStream {
stream,
decoder: BucketDecoder::WaitingForHeader,
buffer: BytesMut::with_capacity(1024),
}
}
}
impl<S: AsyncRead + std::marker::Unpin> Stream for BucketStream<S> {
type Item = Result<Bucket, BucketError>;
/// Attempt to read from the underlying stream into the buffer until enough bytes are received to construct a `Bucket`.
///
/// If enough bytes are received, return the decoded `Bucket`, if not enough bytes are received to construct a `Bucket`,
/// return `Poll::Pending`. This will never return `Poll::Ready(None)`.
///
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
let mut stream = this.stream;
let decoder = this.decoder;
let buffer = this.buffer;
loop {
// this is a bit ugly but all we are doing is calculating the amount of bytes we
// need to build the rest of a bucket if this is zero it means we need to start
// reading a new bucket
let mut bytes_needed = buffer.len().saturating_sub(decoder.bytes_needed());
if bytes_needed == 0 {
bytes_needed = 1024
}
let mut buf = vec![0; bytes_needed];
match ready!(stream.as_mut().poll_read(cx, &mut buf)) {
Err(e) => match e.kind() {
std::io::ErrorKind::WouldBlock => return std::task::Poll::Pending,
std::io::ErrorKind::Interrupted => continue,
_ => return Poll::Ready(Some(Err(BucketError::IO(e)))),
},
Ok(len) => {
buffer.extend(&buf[..len]);
let (bucket, len) = decoder.try_decode_bucket(buffer)?;
buffer.advance(len);
if let Some(bucket) = bucket {
return Poll::Ready(Some(Ok(bucket)));
} else {
continue;
}
}
}
}
}
}

235
net/levin/src/codec.rs Normal file
View file

@ -0,0 +1,235 @@
// Rust Levin Library
// Written in 2023 by
// Cuprate Contributors
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
//! A tokio-codec for levin buckets
use std::io::ErrorKind;
use std::marker::PhantomData;
use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::{
Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, MessageType,
LEVIN_DEFAULT_MAX_PACKET_SIZE,
};
/// The levin tokio-codec for decoding and encoding levin buckets
#[derive(Default)]
pub enum LevinCodec {
/// Waiting for the peer to send a header.
#[default]
WaitingForHeader,
/// Waiting for a peer to send a body.
WaitingForBody(BucketHead),
}
impl Decoder for LevinCodec {
type Item = Bucket;
type Error = BucketError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
match self {
LevinCodec::WaitingForHeader => {
if src.len() < BucketHead::SIZE {
return Ok(None);
};
let head = BucketHead::from_bytes(src)?;
let _ = std::mem::replace(self, LevinCodec::WaitingForBody(head));
}
LevinCodec::WaitingForBody(head) => {
// We size check header while decoding it.
let body_len = head
.size
.try_into()
.map_err(|_| BucketError::BucketExceededMaxSize)?;
if src.len() < body_len {
src.reserve(body_len - src.len());
return Ok(None);
}
let LevinCodec::WaitingForBody(header) = std::mem::replace(self, LevinCodec::WaitingForHeader) else {
unreachable!()
};
return Ok(Some(Bucket {
header,
body: src.copy_to_bytes(body_len).into(),
}));
}
}
}
}
}
impl Encoder<Bucket> for LevinCodec {
type Error = BucketError;
fn encode(&mut self, item: Bucket, dst: &mut BytesMut) -> Result<(), Self::Error> {
if dst.capacity() < BucketHead::SIZE + item.body.len() {
return Err(BucketError::IO(std::io::Error::new(
ErrorKind::OutOfMemory,
"Not enough capacity to write the bucket",
)));
}
item.header.write_bytes(dst);
dst.put_slice(&item.body);
Ok(())
}
}
#[derive(Default)]
enum MessageState {
#[default]
WaitingForBucket,
WaitingForRestOfFragment(Vec<u8>, MessageType, u32),
}
/// A tokio-codec for levin messages or in other words the decoded body
/// of a levin bucket.
pub struct LevinMessageCodec<T> {
message_ty: PhantomData<T>,
bucket_codec: LevinCodec,
state: MessageState,
}
impl<T: LevinBody> Decoder for LevinMessageCodec<T> {
type Item = T;
type Error = BucketError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
match &mut self.state {
MessageState::WaitingForBucket => {
let Some(bucket) = self.bucket_codec.decode(src)? else {
return Ok(None);
};
let end_fragment = bucket.header.flags.end_fragment;
let start_fragment = bucket.header.flags.start_fragment;
let request = bucket.header.flags.request;
let response = bucket.header.flags.response;
if start_fragment && end_fragment {
// Dummy message
return Ok(None);
};
if end_fragment {
return Err(BucketError::InvalidHeaderFlags(
"Flag end fragment received before a start fragment",
));
};
if !request && !response {
return Err(BucketError::InvalidHeaderFlags(
"Request and response flags both not set",
));
};
let message_type = MessageType::from_flags_and_have_to_return(
bucket.header.flags,
bucket.header.have_to_return_data,
)?;
if start_fragment {
let _ = std::mem::replace(
&mut self.state,
MessageState::WaitingForRestOfFragment(
bucket.body.to_vec(),
message_type,
bucket.header.protocol_version,
),
);
continue;
}
return Ok(Some(T::decode_message(
&bucket.body,
message_type,
bucket.header.command,
)?));
}
MessageState::WaitingForRestOfFragment(bytes, ty, command) => {
let Some(bucket) = self.bucket_codec.decode(src)? else {
return Ok(None);
};
let end_fragment = bucket.header.flags.end_fragment;
let start_fragment = bucket.header.flags.start_fragment;
let request = bucket.header.flags.request;
let response = bucket.header.flags.response;
if start_fragment && end_fragment {
// Dummy message
return Ok(None);
};
if !request && !response {
return Err(BucketError::InvalidHeaderFlags(
"Request and response flags both not set",
));
};
let message_type = MessageType::from_flags_and_have_to_return(
bucket.header.flags,
bucket.header.have_to_return_data,
)?;
if message_type != *ty {
return Err(BucketError::InvalidFragmentedMessage(
"Message type was inconsistent across fragments",
));
}
if bucket.header.command != *command {
return Err(BucketError::InvalidFragmentedMessage(
"Command not consistent across message",
));
}
if bytes.len() + bucket.body.len()
> LEVIN_DEFAULT_MAX_PACKET_SIZE.try_into().unwrap()
{
return Err(BucketError::InvalidFragmentedMessage(
"Fragmented message exceeded maximum size",
));
}
bytes.append(&mut bucket.body.to_vec());
if end_fragment {
let MessageState::WaitingForRestOfFragment(bytes, ty, command) =
std::mem::replace(&mut self.state, MessageState::WaitingForBucket) else {
unreachable!();
};
return Ok(Some(T::decode_message(&bytes, ty, command)?));
}
}
}
}
}
}
impl<T: LevinBody> Encoder<T> for LevinMessageCodec<T> {
type Error = BucketError;
fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
let mut bucket_builder = BucketBuilder::default();
item.encode(&mut bucket_builder)?;
let bucket = bucket_builder.finish();
self.bucket_codec.encode(bucket, dst)
}
}

View file

@ -16,117 +16,63 @@
//! This module provides a struct BucketHead for the header of a levin protocol
//! message.
use std::io::Read;
use crate::LEVIN_DEFAULT_MAX_PACKET_SIZE;
use bytes::{Buf, BufMut, BytesMut};
use super::{BucketError, LEVIN_SIGNATURE, PROTOCOL_VERSION};
use byteorder::{LittleEndian, ReadBytesExt};
const REQUEST: u32 = 0b0000_0001;
const RESPONSE: u32 = 0b0000_0010;
const START_FRAGMENT: u32 = 0b0000_0100;
const END_FRAGMENT: u32 = 0b0000_1000;
/// The Flags for the levin header
/// Levin header flags
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Flags(u32);
pub struct Flags {
/// Q bit
pub request: bool,
/// S bit
pub response: bool,
/// B bit
pub start_fragment: bool,
/// E bit
pub end_fragment: bool,
}
pub(crate) const REQUEST: Flags = Flags(0b0000_0001);
pub(crate) const RESPONSE: Flags = Flags(0b0000_0010);
const START_FRAGMENT: Flags = Flags(0b0000_0100);
const END_FRAGMENT: Flags = Flags(0b0000_1000);
const DUMMY: Flags = Flags(0b0000_1100); // both start and end fragment set
impl Flags {
fn contains_flag(&self, rhs: Self) -> bool {
self & &rhs == rhs
}
/// Converts the inner flags to little endian bytes
pub fn to_le_bytes(&self) -> [u8; 4] {
self.0.to_le_bytes()
}
/// Checks if the flags have the `REQUEST` flag set and
/// does not have the `RESPONSE` flag set, this does
/// not check for other flags
pub fn is_request(&self) -> bool {
self.contains_flag(REQUEST) && !self.contains_flag(RESPONSE)
}
/// Checks if the flags have the `RESPONSE` flag set and
/// does not have the `REQUEST` flag set, this does
/// not check for other flags
pub fn is_response(&self) -> bool {
self.contains_flag(RESPONSE) && !self.contains_flag(REQUEST)
}
/// Checks if the flags have the `START_FRAGMENT`and the
/// `END_FRAGMENT` flags set, this does
/// not check for other flags
pub fn is_dummy(&self) -> bool {
self.contains_flag(DUMMY)
}
/// Checks if the flags have the `START_FRAGMENT` flag
/// set and does not have the `END_FRAGMENT` flag set, this
/// does not check for other flags
pub fn is_start_fragment(&self) -> bool {
self.contains_flag(START_FRAGMENT) && !self.is_dummy()
}
/// Checks if the flags have the `END_FRAGMENT` flag
/// set and does not have the `START_FRAGMENT` flag set, this
/// does not check for other flags
pub fn is_end_fragment(&self) -> bool {
self.contains_flag(END_FRAGMENT) && !self.is_dummy()
}
/// Sets the `REQUEST` flag
pub fn set_flag_request(&mut self) {
*self |= REQUEST
}
/// Sets the `RESPONSE` flag
pub fn set_flag_response(&mut self) {
*self |= RESPONSE
}
/// Sets the `START_FRAGMENT` flag
pub fn set_flag_start_fragment(&mut self) {
*self |= START_FRAGMENT
}
/// Sets the `END_FRAGMENT` flag
pub fn set_flag_end_fragment(&mut self) {
*self |= END_FRAGMENT
}
/// Sets the `START_FRAGMENT` and `END_FRAGMENT` flag
pub fn set_flag_dummy(&mut self) {
self.set_flag_start_fragment();
self.set_flag_end_fragment();
impl TryFrom<u32> for Flags {
type Error = BucketError;
fn try_from(value: u32) -> Result<Self, Self::Error> {
let flags = Flags {
request: value & REQUEST > 0,
response: value & RESPONSE > 0,
start_fragment: value & START_FRAGMENT > 0,
end_fragment: value & END_FRAGMENT > 0,
};
if flags.request && flags.response {
return Err(BucketError::InvalidHeaderFlags(
"Request and Response bits set",
));
};
Ok(flags)
}
}
impl From<u32> for Flags {
fn from(value: u32) -> Self {
Flags(value)
}
}
impl core::ops::BitAnd for &Flags {
type Output = Flags;
fn bitand(self, rhs: Self) -> Self::Output {
Flags(self.0 & rhs.0)
}
}
impl core::ops::BitOr for &Flags {
type Output = Flags;
fn bitor(self, rhs: Self) -> Self::Output {
Flags(self.0 | rhs.0)
}
}
impl core::ops::BitOrAssign for Flags {
fn bitor_assign(&mut self, rhs: Self) {
self.0 |= rhs.0
impl From<Flags> for u32 {
fn from(value: Flags) -> Self {
let mut ret = 0;
if value.request {
ret |= REQUEST;
};
if value.response {
ret |= RESPONSE;
};
if value.start_fragment {
ret |= START_FRAGMENT;
};
if value.end_fragment {
ret |= END_FRAGMENT;
};
ret
}
}
@ -156,7 +102,7 @@ impl BucketHead {
pub const SIZE: usize = 33;
/// Builds the header in a Monero specific way
pub fn build(
pub fn build_monero(
payload_size: u64,
have_to_return_data: bool,
command: u32,
@ -176,52 +122,38 @@ impl BucketHead {
/// Builds the header from bytes, this function does not check any fields should
/// match the expected ones (signature, protocol_version)
pub fn from_bytes<R: Read + ?Sized>(r: &mut R) -> Result<BucketHead, BucketError> {
///
/// # Panics
/// This function will panic if there aren't enough bytes to fill the header.
/// Currently ['SIZE'](BucketHead::SIZE)
pub fn from_bytes(buf: &mut BytesMut) -> Result<BucketHead, BucketError> {
let header = BucketHead {
signature: r.read_u64::<LittleEndian>()?,
size: r.read_u64::<LittleEndian>()?,
have_to_return_data: r.read_u8()? != 0,
command: r.read_u32::<LittleEndian>()?,
return_code: r.read_i32::<LittleEndian>()?,
// this is incorrect an will not work for fragmented messages
flags: Flags::from(r.read_u32::<LittleEndian>()?),
protocol_version: r.read_u32::<LittleEndian>()?,
signature: buf.get_u64_le(),
size: buf.get_u64_le(),
have_to_return_data: buf.get_u8() != 0,
command: buf.get_u32_le(),
return_code: buf.get_i32_le(),
flags: Flags::try_from(buf.get_u32_le())?,
protocol_version: buf.get_u32_le(),
};
if header.size > LEVIN_DEFAULT_MAX_PACKET_SIZE {
return Err(BucketError::BucketExceededMaxSize);
}
Ok(header)
}
/// Serializes the header
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(BucketHead::SIZE);
out.extend_from_slice(&self.signature.to_le_bytes());
out.extend_from_slice(&self.size.to_le_bytes());
out.push(if self.have_to_return_data { 1 } else { 0 });
out.extend_from_slice(&self.command.to_le_bytes());
out.extend_from_slice(&self.return_code.to_le_bytes());
out.extend_from_slice(&self.flags.to_le_bytes());
out.extend_from_slice(&self.protocol_version.to_le_bytes());
out
}
}
#[cfg(test)]
mod tests {
use super::Flags;
#[test]
fn set_flags() {
macro_rules! set_and_check {
($set:ident, $check:ident) => {
let mut flag = Flags::default();
flag.$set();
assert!(flag.$check());
};
}
set_and_check!(set_flag_request, is_request);
set_and_check!(set_flag_response, is_response);
set_and_check!(set_flag_start_fragment, is_start_fragment);
set_and_check!(set_flag_end_fragment, is_end_fragment);
set_and_check!(set_flag_dummy, is_dummy);
pub fn write_bytes(&self, dst: &mut BytesMut) {
dst.reserve(BucketHead::SIZE);
dst.put_u64_le(self.signature);
dst.put_u64_le(self.size);
dst.put_u8(if self.have_to_return_data { 1 } else { 0 });
dst.put_u32_le(self.command);
dst.put_i32_le(self.return_code);
dst.put_u32_le(self.flags.into());
dst.put_u32_le(self.protocol_version);
}
}

View file

@ -19,8 +19,8 @@
//!
//! The Levin protocol is a network protocol used in the Monero cryptocurrency. It is used for
//! peer-to-peer communication between nodes. This crate provides a Rust implementation of the Levin
//! header serialization and allows developers to define their own bucket bodies so this is not a
//! complete Monero networking crate.
//! header serialization and allows developers to define their own bucket bodies, for a complete
//! monero protocol crate see: monero-wire.
//!
//! ## License
//!
@ -31,75 +31,50 @@
#![deny(non_upper_case_globals)]
#![deny(non_camel_case_types)]
#![deny(unused_mut)]
#![deny(missing_docs)]
//#![deny(missing_docs)]
pub mod bucket_sink;
pub mod bucket_stream;
pub mod codec;
pub mod header;
pub mod message_sink;
pub mod message_stream;
pub use codec::LevinCodec;
pub use header::BucketHead;
pub use message_sink::MessageSink;
pub use message_stream::MessageStream;
use std::fmt::Debug;
use bytes::Bytes;
use thiserror::Error;
const PROTOCOL_VERSION: u32 = 1;
const LEVIN_SIGNATURE: u64 = 0x0101010101012101;
const LEVIN_DEFAULT_MAX_PACKET_SIZE: u64 = 100_000_000; // 100MB
/// Possible Errors when working with levin buckets
#[derive(Error, Debug)]
pub enum BucketError {
/// Unsupported p2p command.
#[error("Unsupported p2p command: {0}")]
UnsupportedP2pCommand(u32),
/// Revived header with incorrect signature.
#[error("Revived header with incorrect signature: {0}")]
IncorrectSignature(u64),
/// Header contains unknown flags.
#[error("Header contains unknown flags")]
UnknownFlags,
/// Revived header with unknown protocol version.
#[error("Revived header with unknown protocol version: {0}")]
UnknownProtocolVersion(u32),
/// More bytes needed to parse data.
#[error("More bytes needed to parse data")]
NotEnoughBytes,
/// Failed to decode bucket body.
#[error("Failed to decode bucket body: {0}")]
FailedToDecodeBucketBody(String),
/// Failed to encode bucket body.
#[error("Failed to encode bucket body: {0}")]
FailedToEncodeBucketBody(String),
/// IO Error.
#[error("IO Error: {0}")]
/// Invalid header flags
#[error("Invalid header flags: {0}")]
InvalidHeaderFlags(&'static str),
/// Levin bucket exceeded max size
#[error("Levin bucket exceeded max size")]
BucketExceededMaxSize,
/// Invalid Fragmented Message
#[error("Levin fragmented message was invalid: {0}")]
InvalidFragmentedMessage(&'static str),
/// I/O error
#[error("I/O error: {0}")]
IO(#[from] std::io::Error),
/// Peer sent an error response code.
#[error("Peer sent an error response code: {0}")]
Error(i32),
}
const PROTOCOL_VERSION: u32 = 1;
const LEVIN_SIGNATURE: u64 = 0x0101010101012101;
/// A levin Bucket
#[derive(Debug)]
pub struct Bucket {
header: BucketHead,
body: Bytes,
}
impl Bucket {
fn to_bytes(&self) -> Bytes {
let mut buf = self.header.to_bytes();
buf.extend(self.body.iter());
buf.into()
}
/// The bucket header
pub header: BucketHead,
/// The bucket body
pub body: Vec<u8>,
}
/// An enum representing if the message is a request or response
#[derive(Debug)]
#[derive(Debug, Eq, PartialEq)]
pub enum MessageType {
/// Request
Request,
@ -123,23 +98,94 @@ impl MessageType {
flags: header::Flags,
have_to_return: bool,
) -> Result<Self, BucketError> {
if flags.is_request() && have_to_return {
if flags.request && have_to_return {
Ok(MessageType::Request)
} else if flags.is_request() {
} else if flags.request {
Ok(MessageType::Notification)
} else if flags.is_response() && !have_to_return {
} else if flags.response && !have_to_return {
Ok(MessageType::Response)
} else {
Err(BucketError::UnknownFlags)
Err(BucketError::InvalidHeaderFlags(
"Unable to assign a message type to this bucket",
))
}
}
pub fn as_flags(&self) -> header::Flags {
match self {
MessageType::Request | MessageType::Notification => header::Flags {
request: true,
..Default::default()
},
MessageType::Response => header::Flags {
response: true,
..Default::default()
},
}
}
}
impl From<MessageType> for header::Flags {
fn from(val: MessageType) -> Self {
match val {
MessageType::Request | MessageType::Notification => header::REQUEST,
MessageType::Response => header::RESPONSE,
pub struct BucketBuilder {
signature: Option<u64>,
ty: Option<MessageType>,
command: Option<u32>,
return_code: Option<i32>,
protocol_version: Option<u32>,
body: Option<Vec<u8>>,
}
impl Default for BucketBuilder {
fn default() -> Self {
Self {
signature: Some(LEVIN_SIGNATURE),
ty: None,
command: None,
return_code: None,
protocol_version: Some(PROTOCOL_VERSION),
body: None,
}
}
}
impl BucketBuilder {
pub fn set_signature(&mut self, sig: u64) {
self.signature = Some(sig)
}
pub fn set_message_type(&mut self, ty: MessageType) {
self.ty = Some(ty)
}
pub fn set_command(&mut self, command: u32) {
self.command = Some(command)
}
pub fn set_return_code(&mut self, code: i32) {
self.return_code = Some(code)
}
pub fn set_protocol_version(&mut self, version: u32) {
self.protocol_version = Some(version)
}
pub fn set_body(&mut self, body: Vec<u8>) {
self.body = Some(body)
}
pub fn finish(self) -> Bucket {
let body = self.body.unwrap();
let ty = self.ty.unwrap();
Bucket {
header: BucketHead {
signature: self.signature.unwrap(),
size: body.len().try_into().unwrap(),
have_to_return_data: ty.have_to_return_data(),
command: self.command.unwrap(),
return_code: self.return_code.unwrap(),
flags: ty.as_flags(),
protocol_version: self.protocol_version.unwrap(),
},
body,
}
}
}
@ -150,11 +196,5 @@ pub trait LevinBody: Sized {
fn decode_message(buf: &[u8], typ: MessageType, command: u32) -> Result<Self, BucketError>;
/// Encodes the message
///
/// returns:
/// return_code: i32,
/// command: u32,
/// message_type: MessageType
/// bytes: Vec<u8>
fn encode(&self) -> Result<(i32, u32, MessageType, Vec<u8>), BucketError>;
fn encode(&self, builder: &mut BucketBuilder) -> Result<(), BucketError>;
}

View file

@ -1,92 +0,0 @@
// Rust Levin Library
// Written in 2023 by
// Cuprate Contributors
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
//! This module provides a `MessageSink` struct, which serializes user defined messages
//! into levin `Bucket`s and passes them onto the `BucketSink`
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::Poll;
use futures::AsyncWrite;
use futures::Sink;
use pin_project::pin_project;
use crate::bucket_sink::BucketSink;
use crate::Bucket;
use crate::BucketError;
use crate::BucketHead;
use crate::LevinBody;
/// A Sink that converts levin messages to buckets and passes them onto the `BucketSink`
#[pin_project]
pub struct MessageSink<W, E> {
#[pin]
bucket_sink: BucketSink<W>,
phantom: PhantomData<E>,
}
impl<W: AsyncWrite + std::marker::Unpin, E: LevinBody> MessageSink<W, E> {
/// Creates a new sink from the provided [`AsyncWrite`]
pub fn new(writer: W) -> Self {
MessageSink {
bucket_sink: BucketSink::new(writer),
phantom: PhantomData,
}
}
}
impl<W: AsyncWrite + std::marker::Unpin, E: LevinBody> Sink<E> for MessageSink<W, E> {
type Error = BucketError;
fn poll_ready(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.project().bucket_sink.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: E) -> Result<(), Self::Error> {
let (return_code, command, message_type, body) = item.encode()?;
let header = BucketHead::build(
body.len() as u64,
message_type.have_to_return_data(),
command,
message_type.into(),
return_code,
);
let bucket = Bucket {
header,
body: body.into(),
};
self.project().bucket_sink.start_send(bucket)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.project().bucket_sink.poll_flush(cx)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.project().bucket_sink.poll_close(cx)
}
}

View file

@ -1,99 +0,0 @@
// Rust Levin Library
// Written in 2023 by
// Cuprate Contributors
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
//! This modual provides a `MessageStream` which deserializes partially decoded `Bucket`s
//! into full Buckets using a user provided `LevinBody`
use std::marker::PhantomData;
use std::task::Poll;
use futures::ready;
use futures::AsyncRead;
use futures::Stream;
use pin_project::pin_project;
use crate::bucket_stream::BucketStream;
use crate::BucketError;
use crate::LevinBody;
use crate::MessageType;
use crate::LEVIN_SIGNATURE;
use crate::PROTOCOL_VERSION;
/// A stream that reads from the underlying `BucketStream` and uses the the
/// methods on the `LevinBody` trait to decode the inner messages(bodies)
#[pin_project]
pub struct MessageStream<S, D> {
#[pin]
bucket_stream: BucketStream<S>,
phantom: PhantomData<D>,
}
impl<D: LevinBody, S: AsyncRead + std::marker::Unpin> MessageStream<S, D> {
/// Creates a new stream from the provided `AsyncRead`
pub fn new(stream: S) -> Self {
MessageStream {
bucket_stream: BucketStream::new(stream),
phantom: PhantomData,
}
}
}
impl<D: LevinBody, S: AsyncRead + std::marker::Unpin> Stream for MessageStream<S, D> {
type Item = Result<D, BucketError>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.project();
match ready!(this.bucket_stream.poll_next(cx)).expect("BucketStream will never return None")
{
Err(e) => Poll::Ready(Some(Err(e))),
Ok(bucket) => {
if bucket.header.signature != LEVIN_SIGNATURE {
return Err(BucketError::IncorrectSignature(bucket.header.signature))?;
}
if bucket.header.protocol_version != PROTOCOL_VERSION {
return Err(BucketError::UnknownProtocolVersion(
bucket.header.protocol_version,
))?;
}
// TODO: we shouldn't return an error if the peer sends an error response we should define a new network
// message: Error.
if bucket.header.return_code < 0
|| (bucket.header.return_code == 0 && bucket.header.flags.is_response())
{
return Err(BucketError::Error(bucket.header.return_code))?;
}
if bucket.header.flags.is_dummy() {
cx.waker().wake_by_ref();
return Poll::Pending;
}
Poll::Ready(Some(D::decode_message(
&bucket.body,
MessageType::from_flags_and_have_to_return(
bucket.header.flags,
bucket.header.have_to_return_data,
)?,
bucket.header.command,
)))
}
}
}
}