diff --git a/net/levin/Cargo.toml b/net/levin/Cargo.toml index 407cef14..0101e082 100644 --- a/net/levin/Cargo.toml +++ b/net/levin/Cargo.toml @@ -1,15 +1,14 @@ [package] -name = "levin" +name = "levin-cuprate" version = "0.1.0" edition = "2021" description = "A crate for working with the Levin protocol in Rust." license = "MIT" authors = ["Boog900"] -repository = "https://github.com/SyntheticBird45/cuprate/tree/main/net/levin" +repository = "https://github.com/Cuprate/cuprate/tree/main/net/levin" [dependencies] -thiserror = "1.0.24" -byteorder = "1.4.3" -futures = "0.3" +thiserror = "1" bytes = "1" -pin-project = "1" +tokio-util = {version = "0.7", features = ["codec"]} + diff --git a/net/levin/src/bucket_sink.rs b/net/levin/src/bucket_sink.rs deleted file mode 100644 index 50f82251..00000000 --- a/net/levin/src/bucket_sink.rs +++ /dev/null @@ -1,107 +0,0 @@ -// Rust Levin Library -// Written in 2023 by -// Cuprate Contributors -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// - -//! This module provides a `BucketSink` struct, which writes buckets to the -//! provided `AsyncWrite`. If you are a user of this library you should -//! probably use `MessageSink` instead. - -use std::collections::VecDeque; -use std::pin::Pin; -use std::task::Poll; - -use bytes::{Buf, BytesMut}; -use futures::ready; -use futures::sink::Sink; -use futures::AsyncWrite; -use pin_project::pin_project; - -use crate::{Bucket, BucketError}; - -/// A BucketSink writes Bucket instances to the provided AsyncWrite target. -#[pin_project] -pub struct BucketSink { - #[pin] - writer: W, - buffer: VecDeque, -} - -impl BucketSink { - /// Creates a new [`BucketSink`] from the given [`AsyncWrite`] writer. - pub fn new(writer: W) -> Self { - BucketSink { - writer, - buffer: VecDeque::with_capacity(2), - } - } -} - -impl Sink for BucketSink { - type Error = BucketError; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - _: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Poll::Ready(Ok(())) - } - - fn start_send(mut self: Pin<&mut Self>, item: Bucket) -> Result<(), Self::Error> { - let buf = item.to_bytes(); - self.buffer.push_back(BytesMut::from(&buf[..])); - Ok(()) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.project(); - let mut w = this.writer; - let buffer = this.buffer; - - loop { - match ready!(w.as_mut().poll_flush(cx)) { - Err(err) => return Poll::Ready(Err(err.into())), - Ok(()) => { - if let Some(buf) = buffer.front() { - match ready!(w.as_mut().poll_write(cx, buf)) { - Err(e) => match e.kind() { - std::io::ErrorKind::WouldBlock => return std::task::Poll::Pending, - _ => return Poll::Ready(Err(e.into())), - }, - Ok(len) => { - if len == buffer[0].len() { - buffer.pop_front(); - } else { - buffer[0].advance(len); - } - } - } - } else { - return Poll::Ready(Ok(())); - } - } - } - } - } - - fn poll_close( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - ready!(self.project().writer.poll_close(cx))?; - Poll::Ready(Ok(())) - } -} diff --git a/net/levin/src/bucket_stream.rs b/net/levin/src/bucket_stream.rs deleted file mode 100644 index f8ad8f12..00000000 --- a/net/levin/src/bucket_stream.rs +++ /dev/null @@ -1,152 +0,0 @@ -// Rust Levin Library -// Written in 2023 by -// Cuprate Contributors -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// - -//! This module provides a `BucketStream` struct, which is a stream of `Bucket`s, -//! where only the header is decoded. If you are a user of this library you should -//! probably use `MessageStream` instead. - -use std::task::Poll; - -use bytes::{Buf, BytesMut}; -use futures::stream::Stream; -use futures::{ready, AsyncRead}; -use pin_project::pin_project; - -use super::{Bucket, BucketError, BucketHead}; - -/// An enum representing the decoding state of a `BucketStream`. -#[derive(Debug, Clone)] -enum BucketDecoder { - /// Waiting for the header of a `Bucket`. - WaitingForHeader, - /// Waiting for the body of a `Bucket` with the given header. - WaitingForBody(BucketHead), -} - -impl BucketDecoder { - /// Returns the number of bytes needed to complete the current decoding state. - pub fn bytes_needed(&self) -> usize { - match self { - Self::WaitingForHeader => BucketHead::SIZE, - Self::WaitingForBody(bucket_head) => bucket_head.size as usize, - } - } - - /// Tries to decode a `Bucket` from the given buffer, returning the decoded `Bucket` and the - /// number of bytes consumed from the buffer. - pub fn try_decode_bucket( - &mut self, - mut buf: &[u8], - ) -> Result<(Option, usize), BucketError> { - let mut len = 0; - - // first we decode header - if let BucketDecoder::WaitingForHeader = self { - if buf.len() < BucketHead::SIZE { - return Ok((None, 0)); - } - let header = BucketHead::from_bytes(&mut buf)?; - len += BucketHead::SIZE; - *self = BucketDecoder::WaitingForBody(header); - }; - - // next we check we have enough bytes to fill the body - if let &mut Self::WaitingForBody(head) = self { - if buf.len() < head.size as usize { - return Ok((None, len)); - } - *self = BucketDecoder::WaitingForHeader; - Ok(( - Some(Bucket { - header: head, - body: buf.copy_to_bytes(buf.len()), - }), - len + head.size as usize, - )) - } else { - unreachable!() - } - } -} - -/// A stream of `Bucket`s, with only the header decoded. -#[pin_project] -#[derive(Debug, Clone)] -pub struct BucketStream { - #[pin] - stream: S, - decoder: BucketDecoder, - buffer: BytesMut, -} - -impl BucketStream { - /// Creates a new `BucketStream` from the given `AsyncRead` stream. - pub fn new(stream: S) -> Self { - BucketStream { - stream, - decoder: BucketDecoder::WaitingForHeader, - buffer: BytesMut::with_capacity(1024), - } - } -} - -impl Stream for BucketStream { - type Item = Result; - - /// Attempt to read from the underlying stream into the buffer until enough bytes are received to construct a `Bucket`. - /// - /// If enough bytes are received, return the decoded `Bucket`, if not enough bytes are received to construct a `Bucket`, - /// return `Poll::Pending`. This will never return `Poll::Ready(None)`. - /// - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - let mut stream = this.stream; - let decoder = this.decoder; - let buffer = this.buffer; - - loop { - // this is a bit ugly but all we are doing is calculating the amount of bytes we - // need to build the rest of a bucket if this is zero it means we need to start - // reading a new bucket - let mut bytes_needed = buffer.len().saturating_sub(decoder.bytes_needed()); - if bytes_needed == 0 { - bytes_needed = 1024 - } - - let mut buf = vec![0; bytes_needed]; - match ready!(stream.as_mut().poll_read(cx, &mut buf)) { - Err(e) => match e.kind() { - std::io::ErrorKind::WouldBlock => return std::task::Poll::Pending, - std::io::ErrorKind::Interrupted => continue, - _ => return Poll::Ready(Some(Err(BucketError::IO(e)))), - }, - Ok(len) => { - buffer.extend(&buf[..len]); - - let (bucket, len) = decoder.try_decode_bucket(buffer)?; - buffer.advance(len); - if let Some(bucket) = bucket { - return Poll::Ready(Some(Ok(bucket))); - } else { - continue; - } - } - } - } - } -} diff --git a/net/levin/src/codec.rs b/net/levin/src/codec.rs new file mode 100644 index 00000000..24f69ee1 --- /dev/null +++ b/net/levin/src/codec.rs @@ -0,0 +1,235 @@ +// Rust Levin Library +// Written in 2023 by +// Cuprate Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// + +//! A tokio-codec for levin buckets + +use std::io::ErrorKind; +use std::marker::PhantomData; + +use bytes::{Buf, BufMut, BytesMut}; +use tokio_util::codec::{Decoder, Encoder}; + +use crate::{ + Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, MessageType, + LEVIN_DEFAULT_MAX_PACKET_SIZE, +}; + +/// The levin tokio-codec for decoding and encoding levin buckets +#[derive(Default)] +pub enum LevinCodec { + /// Waiting for the peer to send a header. + #[default] + WaitingForHeader, + /// Waiting for a peer to send a body. + WaitingForBody(BucketHead), +} + +impl Decoder for LevinCodec { + type Item = Bucket; + type Error = BucketError; + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + loop { + match self { + LevinCodec::WaitingForHeader => { + if src.len() < BucketHead::SIZE { + return Ok(None); + }; + + let head = BucketHead::from_bytes(src)?; + let _ = std::mem::replace(self, LevinCodec::WaitingForBody(head)); + } + LevinCodec::WaitingForBody(head) => { + // We size check header while decoding it. + let body_len = head + .size + .try_into() + .map_err(|_| BucketError::BucketExceededMaxSize)?; + if src.len() < body_len { + src.reserve(body_len - src.len()); + return Ok(None); + } + + let LevinCodec::WaitingForBody(header) = std::mem::replace(self, LevinCodec::WaitingForHeader) else { + unreachable!() + }; + + return Ok(Some(Bucket { + header, + body: src.copy_to_bytes(body_len).into(), + })); + } + } + } + } +} + +impl Encoder for LevinCodec { + type Error = BucketError; + fn encode(&mut self, item: Bucket, dst: &mut BytesMut) -> Result<(), Self::Error> { + if dst.capacity() < BucketHead::SIZE + item.body.len() { + return Err(BucketError::IO(std::io::Error::new( + ErrorKind::OutOfMemory, + "Not enough capacity to write the bucket", + ))); + } + item.header.write_bytes(dst); + dst.put_slice(&item.body); + Ok(()) + } +} + +#[derive(Default)] +enum MessageState { + #[default] + WaitingForBucket, + WaitingForRestOfFragment(Vec, MessageType, u32), +} + +/// A tokio-codec for levin messages or in other words the decoded body +/// of a levin bucket. +pub struct LevinMessageCodec { + message_ty: PhantomData, + bucket_codec: LevinCodec, + state: MessageState, +} + +impl Decoder for LevinMessageCodec { + type Item = T; + type Error = BucketError; + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + loop { + match &mut self.state { + MessageState::WaitingForBucket => { + let Some(bucket) = self.bucket_codec.decode(src)? else { + return Ok(None); + }; + + let end_fragment = bucket.header.flags.end_fragment; + let start_fragment = bucket.header.flags.start_fragment; + let request = bucket.header.flags.request; + let response = bucket.header.flags.response; + + if start_fragment && end_fragment { + // Dummy message + return Ok(None); + }; + + if end_fragment { + return Err(BucketError::InvalidHeaderFlags( + "Flag end fragment received before a start fragment", + )); + }; + + if !request && !response { + return Err(BucketError::InvalidHeaderFlags( + "Request and response flags both not set", + )); + }; + + let message_type = MessageType::from_flags_and_have_to_return( + bucket.header.flags, + bucket.header.have_to_return_data, + )?; + + if start_fragment { + let _ = std::mem::replace( + &mut self.state, + MessageState::WaitingForRestOfFragment( + bucket.body.to_vec(), + message_type, + bucket.header.protocol_version, + ), + ); + + continue; + } + + return Ok(Some(T::decode_message( + &bucket.body, + message_type, + bucket.header.command, + )?)); + } + MessageState::WaitingForRestOfFragment(bytes, ty, command) => { + let Some(bucket) = self.bucket_codec.decode(src)? else { + return Ok(None); + }; + + let end_fragment = bucket.header.flags.end_fragment; + let start_fragment = bucket.header.flags.start_fragment; + let request = bucket.header.flags.request; + let response = bucket.header.flags.response; + + if start_fragment && end_fragment { + // Dummy message + return Ok(None); + }; + + if !request && !response { + return Err(BucketError::InvalidHeaderFlags( + "Request and response flags both not set", + )); + }; + + let message_type = MessageType::from_flags_and_have_to_return( + bucket.header.flags, + bucket.header.have_to_return_data, + )?; + + if message_type != *ty { + return Err(BucketError::InvalidFragmentedMessage( + "Message type was inconsistent across fragments", + )); + } + + if bucket.header.command != *command { + return Err(BucketError::InvalidFragmentedMessage( + "Command not consistent across message", + )); + } + + if bytes.len() + bucket.body.len() + > LEVIN_DEFAULT_MAX_PACKET_SIZE.try_into().unwrap() + { + return Err(BucketError::InvalidFragmentedMessage( + "Fragmented message exceeded maximum size", + )); + } + + bytes.append(&mut bucket.body.to_vec()); + + if end_fragment { + let MessageState::WaitingForRestOfFragment(bytes, ty, command) = + std::mem::replace(&mut self.state, MessageState::WaitingForBucket) else { + unreachable!(); + }; + + return Ok(Some(T::decode_message(&bytes, ty, command)?)); + } + } + } + } + } +} + +impl Encoder for LevinMessageCodec { + type Error = BucketError; + fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { + let mut bucket_builder = BucketBuilder::default(); + item.encode(&mut bucket_builder)?; + let bucket = bucket_builder.finish(); + self.bucket_codec.encode(bucket, dst) + } +} diff --git a/net/levin/src/header.rs b/net/levin/src/header.rs index bd50a05b..d6b930f7 100644 --- a/net/levin/src/header.rs +++ b/net/levin/src/header.rs @@ -16,117 +16,63 @@ //! This module provides a struct BucketHead for the header of a levin protocol //! message. -use std::io::Read; +use crate::LEVIN_DEFAULT_MAX_PACKET_SIZE; +use bytes::{Buf, BufMut, BytesMut}; use super::{BucketError, LEVIN_SIGNATURE, PROTOCOL_VERSION}; -use byteorder::{LittleEndian, ReadBytesExt}; +const REQUEST: u32 = 0b0000_0001; +const RESPONSE: u32 = 0b0000_0010; +const START_FRAGMENT: u32 = 0b0000_0100; +const END_FRAGMENT: u32 = 0b0000_1000; -/// The Flags for the levin header +/// Levin header flags #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct Flags(u32); +pub struct Flags { + /// Q bit + pub request: bool, + /// S bit + pub response: bool, + /// B bit + pub start_fragment: bool, + /// E bit + pub end_fragment: bool, +} -pub(crate) const REQUEST: Flags = Flags(0b0000_0001); -pub(crate) const RESPONSE: Flags = Flags(0b0000_0010); -const START_FRAGMENT: Flags = Flags(0b0000_0100); -const END_FRAGMENT: Flags = Flags(0b0000_1000); -const DUMMY: Flags = Flags(0b0000_1100); // both start and end fragment set - -impl Flags { - fn contains_flag(&self, rhs: Self) -> bool { - self & &rhs == rhs - } - - /// Converts the inner flags to little endian bytes - pub fn to_le_bytes(&self) -> [u8; 4] { - self.0.to_le_bytes() - } - - /// Checks if the flags have the `REQUEST` flag set and - /// does not have the `RESPONSE` flag set, this does - /// not check for other flags - pub fn is_request(&self) -> bool { - self.contains_flag(REQUEST) && !self.contains_flag(RESPONSE) - } - - /// Checks if the flags have the `RESPONSE` flag set and - /// does not have the `REQUEST` flag set, this does - /// not check for other flags - pub fn is_response(&self) -> bool { - self.contains_flag(RESPONSE) && !self.contains_flag(REQUEST) - } - - /// Checks if the flags have the `START_FRAGMENT`and the - /// `END_FRAGMENT` flags set, this does - /// not check for other flags - pub fn is_dummy(&self) -> bool { - self.contains_flag(DUMMY) - } - - /// Checks if the flags have the `START_FRAGMENT` flag - /// set and does not have the `END_FRAGMENT` flag set, this - /// does not check for other flags - pub fn is_start_fragment(&self) -> bool { - self.contains_flag(START_FRAGMENT) && !self.is_dummy() - } - - /// Checks if the flags have the `END_FRAGMENT` flag - /// set and does not have the `START_FRAGMENT` flag set, this - /// does not check for other flags - pub fn is_end_fragment(&self) -> bool { - self.contains_flag(END_FRAGMENT) && !self.is_dummy() - } - - /// Sets the `REQUEST` flag - pub fn set_flag_request(&mut self) { - *self |= REQUEST - } - - /// Sets the `RESPONSE` flag - pub fn set_flag_response(&mut self) { - *self |= RESPONSE - } - - /// Sets the `START_FRAGMENT` flag - pub fn set_flag_start_fragment(&mut self) { - *self |= START_FRAGMENT - } - - /// Sets the `END_FRAGMENT` flag - pub fn set_flag_end_fragment(&mut self) { - *self |= END_FRAGMENT - } - - /// Sets the `START_FRAGMENT` and `END_FRAGMENT` flag - pub fn set_flag_dummy(&mut self) { - self.set_flag_start_fragment(); - self.set_flag_end_fragment(); +impl TryFrom for Flags { + type Error = BucketError; + fn try_from(value: u32) -> Result { + let flags = Flags { + request: value & REQUEST > 0, + response: value & RESPONSE > 0, + start_fragment: value & START_FRAGMENT > 0, + end_fragment: value & END_FRAGMENT > 0, + }; + if flags.request && flags.response { + return Err(BucketError::InvalidHeaderFlags( + "Request and Response bits set", + )); + }; + Ok(flags) } } -impl From for Flags { - fn from(value: u32) -> Self { - Flags(value) - } -} - -impl core::ops::BitAnd for &Flags { - type Output = Flags; - fn bitand(self, rhs: Self) -> Self::Output { - Flags(self.0 & rhs.0) - } -} - -impl core::ops::BitOr for &Flags { - type Output = Flags; - fn bitor(self, rhs: Self) -> Self::Output { - Flags(self.0 | rhs.0) - } -} - -impl core::ops::BitOrAssign for Flags { - fn bitor_assign(&mut self, rhs: Self) { - self.0 |= rhs.0 +impl From for u32 { + fn from(value: Flags) -> Self { + let mut ret = 0; + if value.request { + ret |= REQUEST; + }; + if value.response { + ret |= RESPONSE; + }; + if value.start_fragment { + ret |= START_FRAGMENT; + }; + if value.end_fragment { + ret |= END_FRAGMENT; + }; + ret } } @@ -156,7 +102,7 @@ impl BucketHead { pub const SIZE: usize = 33; /// Builds the header in a Monero specific way - pub fn build( + pub fn build_monero( payload_size: u64, have_to_return_data: bool, command: u32, @@ -176,52 +122,38 @@ impl BucketHead { /// Builds the header from bytes, this function does not check any fields should /// match the expected ones (signature, protocol_version) - pub fn from_bytes(r: &mut R) -> Result { + /// + /// # Panics + /// This function will panic if there aren't enough bytes to fill the header. + /// Currently ['SIZE'](BucketHead::SIZE) + pub fn from_bytes(buf: &mut BytesMut) -> Result { let header = BucketHead { - signature: r.read_u64::()?, - size: r.read_u64::()?, - have_to_return_data: r.read_u8()? != 0, - command: r.read_u32::()?, - return_code: r.read_i32::()?, - // this is incorrect an will not work for fragmented messages - flags: Flags::from(r.read_u32::()?), - protocol_version: r.read_u32::()?, + signature: buf.get_u64_le(), + size: buf.get_u64_le(), + have_to_return_data: buf.get_u8() != 0, + command: buf.get_u32_le(), + return_code: buf.get_i32_le(), + flags: Flags::try_from(buf.get_u32_le())?, + protocol_version: buf.get_u32_le(), }; + if header.size > LEVIN_DEFAULT_MAX_PACKET_SIZE { + return Err(BucketError::BucketExceededMaxSize); + } + Ok(header) } /// Serializes the header - pub fn to_bytes(&self) -> Vec { - let mut out = Vec::with_capacity(BucketHead::SIZE); - out.extend_from_slice(&self.signature.to_le_bytes()); - out.extend_from_slice(&self.size.to_le_bytes()); - out.push(if self.have_to_return_data { 1 } else { 0 }); - out.extend_from_slice(&self.command.to_le_bytes()); - out.extend_from_slice(&self.return_code.to_le_bytes()); - out.extend_from_slice(&self.flags.to_le_bytes()); - out.extend_from_slice(&self.protocol_version.to_le_bytes()); - out - } -} - -#[cfg(test)] -mod tests { - use super::Flags; - - #[test] - fn set_flags() { - macro_rules! set_and_check { - ($set:ident, $check:ident) => { - let mut flag = Flags::default(); - flag.$set(); - assert!(flag.$check()); - }; - } - set_and_check!(set_flag_request, is_request); - set_and_check!(set_flag_response, is_response); - set_and_check!(set_flag_start_fragment, is_start_fragment); - set_and_check!(set_flag_end_fragment, is_end_fragment); - set_and_check!(set_flag_dummy, is_dummy); + pub fn write_bytes(&self, dst: &mut BytesMut) { + dst.reserve(BucketHead::SIZE); + + dst.put_u64_le(self.signature); + dst.put_u64_le(self.size); + dst.put_u8(if self.have_to_return_data { 1 } else { 0 }); + dst.put_u32_le(self.command); + dst.put_i32_le(self.return_code); + dst.put_u32_le(self.flags.into()); + dst.put_u32_le(self.protocol_version); } } diff --git a/net/levin/src/lib.rs b/net/levin/src/lib.rs index c9f08fdd..c1f85109 100644 --- a/net/levin/src/lib.rs +++ b/net/levin/src/lib.rs @@ -19,8 +19,8 @@ //! //! The Levin protocol is a network protocol used in the Monero cryptocurrency. It is used for //! peer-to-peer communication between nodes. This crate provides a Rust implementation of the Levin -//! header serialization and allows developers to define their own bucket bodies so this is not a -//! complete Monero networking crate. +//! header serialization and allows developers to define their own bucket bodies, for a complete +//! monero protocol crate see: monero-wire. //! //! ## License //! @@ -31,75 +31,50 @@ #![deny(non_upper_case_globals)] #![deny(non_camel_case_types)] #![deny(unused_mut)] -#![deny(missing_docs)] +//#![deny(missing_docs)] -pub mod bucket_sink; -pub mod bucket_stream; +pub mod codec; pub mod header; -pub mod message_sink; -pub mod message_stream; +pub use codec::LevinCodec; pub use header::BucketHead; -pub use message_sink::MessageSink; -pub use message_stream::MessageStream; use std::fmt::Debug; -use bytes::Bytes; use thiserror::Error; +const PROTOCOL_VERSION: u32 = 1; +const LEVIN_SIGNATURE: u64 = 0x0101010101012101; +const LEVIN_DEFAULT_MAX_PACKET_SIZE: u64 = 100_000_000; // 100MB + /// Possible Errors when working with levin buckets #[derive(Error, Debug)] pub enum BucketError { - /// Unsupported p2p command. - #[error("Unsupported p2p command: {0}")] - UnsupportedP2pCommand(u32), - /// Revived header with incorrect signature. - #[error("Revived header with incorrect signature: {0}")] - IncorrectSignature(u64), - /// Header contains unknown flags. - #[error("Header contains unknown flags")] - UnknownFlags, - /// Revived header with unknown protocol version. - #[error("Revived header with unknown protocol version: {0}")] - UnknownProtocolVersion(u32), - /// More bytes needed to parse data. - #[error("More bytes needed to parse data")] - NotEnoughBytes, - /// Failed to decode bucket body. - #[error("Failed to decode bucket body: {0}")] - FailedToDecodeBucketBody(String), - /// Failed to encode bucket body. - #[error("Failed to encode bucket body: {0}")] - FailedToEncodeBucketBody(String), - /// IO Error. - #[error("IO Error: {0}")] + /// Invalid header flags + #[error("Invalid header flags: {0}")] + InvalidHeaderFlags(&'static str), + /// Levin bucket exceeded max size + #[error("Levin bucket exceeded max size")] + BucketExceededMaxSize, + /// Invalid Fragmented Message + #[error("Levin fragmented message was invalid: {0}")] + InvalidFragmentedMessage(&'static str), + /// I/O error + #[error("I/O error: {0}")] IO(#[from] std::io::Error), - /// Peer sent an error response code. - #[error("Peer sent an error response code: {0}")] - Error(i32), } -const PROTOCOL_VERSION: u32 = 1; -const LEVIN_SIGNATURE: u64 = 0x0101010101012101; - /// A levin Bucket #[derive(Debug)] pub struct Bucket { - header: BucketHead, - body: Bytes, -} - -impl Bucket { - fn to_bytes(&self) -> Bytes { - let mut buf = self.header.to_bytes(); - buf.extend(self.body.iter()); - buf.into() - } + /// The bucket header + pub header: BucketHead, + /// The bucket body + pub body: Vec, } /// An enum representing if the message is a request or response -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq)] pub enum MessageType { /// Request Request, @@ -123,23 +98,94 @@ impl MessageType { flags: header::Flags, have_to_return: bool, ) -> Result { - if flags.is_request() && have_to_return { + if flags.request && have_to_return { Ok(MessageType::Request) - } else if flags.is_request() { + } else if flags.request { Ok(MessageType::Notification) - } else if flags.is_response() && !have_to_return { + } else if flags.response && !have_to_return { Ok(MessageType::Response) } else { - Err(BucketError::UnknownFlags) + Err(BucketError::InvalidHeaderFlags( + "Unable to assign a message type to this bucket", + )) + } + } + + pub fn as_flags(&self) -> header::Flags { + match self { + MessageType::Request | MessageType::Notification => header::Flags { + request: true, + ..Default::default() + }, + MessageType::Response => header::Flags { + response: true, + ..Default::default() + }, } } } -impl From for header::Flags { - fn from(val: MessageType) -> Self { - match val { - MessageType::Request | MessageType::Notification => header::REQUEST, - MessageType::Response => header::RESPONSE, +pub struct BucketBuilder { + signature: Option, + ty: Option, + command: Option, + return_code: Option, + protocol_version: Option, + body: Option>, +} + +impl Default for BucketBuilder { + fn default() -> Self { + Self { + signature: Some(LEVIN_SIGNATURE), + ty: None, + command: None, + return_code: None, + protocol_version: Some(PROTOCOL_VERSION), + body: None, + } + } +} + +impl BucketBuilder { + pub fn set_signature(&mut self, sig: u64) { + self.signature = Some(sig) + } + + pub fn set_message_type(&mut self, ty: MessageType) { + self.ty = Some(ty) + } + + pub fn set_command(&mut self, command: u32) { + self.command = Some(command) + } + + pub fn set_return_code(&mut self, code: i32) { + self.return_code = Some(code) + } + + pub fn set_protocol_version(&mut self, version: u32) { + self.protocol_version = Some(version) + } + + pub fn set_body(&mut self, body: Vec) { + self.body = Some(body) + } + + pub fn finish(self) -> Bucket { + let body = self.body.unwrap(); + let ty = self.ty.unwrap(); + Bucket { + header: BucketHead { + signature: self.signature.unwrap(), + size: body.len().try_into().unwrap(), + have_to_return_data: ty.have_to_return_data(), + command: self.command.unwrap(), + return_code: self.return_code.unwrap(), + flags: ty.as_flags(), + protocol_version: self.protocol_version.unwrap(), + }, + body, } } } @@ -150,11 +196,5 @@ pub trait LevinBody: Sized { fn decode_message(buf: &[u8], typ: MessageType, command: u32) -> Result; /// Encodes the message - /// - /// returns: - /// return_code: i32, - /// command: u32, - /// message_type: MessageType - /// bytes: Vec - fn encode(&self) -> Result<(i32, u32, MessageType, Vec), BucketError>; + fn encode(&self, builder: &mut BucketBuilder) -> Result<(), BucketError>; } diff --git a/net/levin/src/message_sink.rs b/net/levin/src/message_sink.rs deleted file mode 100644 index dfcf2765..00000000 --- a/net/levin/src/message_sink.rs +++ /dev/null @@ -1,92 +0,0 @@ -// Rust Levin Library -// Written in 2023 by -// Cuprate Contributors -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// - -//! This module provides a `MessageSink` struct, which serializes user defined messages -//! into levin `Bucket`s and passes them onto the `BucketSink` - -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::Poll; - -use futures::AsyncWrite; -use futures::Sink; -use pin_project::pin_project; - -use crate::bucket_sink::BucketSink; -use crate::Bucket; -use crate::BucketError; -use crate::BucketHead; -use crate::LevinBody; - -/// A Sink that converts levin messages to buckets and passes them onto the `BucketSink` -#[pin_project] -pub struct MessageSink { - #[pin] - bucket_sink: BucketSink, - phantom: PhantomData, -} - -impl MessageSink { - /// Creates a new sink from the provided [`AsyncWrite`] - pub fn new(writer: W) -> Self { - MessageSink { - bucket_sink: BucketSink::new(writer), - phantom: PhantomData, - } - } -} - -impl Sink for MessageSink { - type Error = BucketError; - - fn poll_ready( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().bucket_sink.poll_ready(cx) - } - - fn start_send(self: Pin<&mut Self>, item: E) -> Result<(), Self::Error> { - let (return_code, command, message_type, body) = item.encode()?; - let header = BucketHead::build( - body.len() as u64, - message_type.have_to_return_data(), - command, - message_type.into(), - return_code, - ); - - let bucket = Bucket { - header, - body: body.into(), - }; - - self.project().bucket_sink.start_send(bucket) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().bucket_sink.poll_flush(cx) - } - - fn poll_close( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().bucket_sink.poll_close(cx) - } -} diff --git a/net/levin/src/message_stream.rs b/net/levin/src/message_stream.rs deleted file mode 100644 index ce680747..00000000 --- a/net/levin/src/message_stream.rs +++ /dev/null @@ -1,99 +0,0 @@ -// Rust Levin Library -// Written in 2023 by -// Cuprate Contributors -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// - -//! This modual provides a `MessageStream` which deserializes partially decoded `Bucket`s -//! into full Buckets using a user provided `LevinBody` - -use std::marker::PhantomData; -use std::task::Poll; - -use futures::ready; -use futures::AsyncRead; -use futures::Stream; -use pin_project::pin_project; - -use crate::bucket_stream::BucketStream; -use crate::BucketError; -use crate::LevinBody; -use crate::MessageType; -use crate::LEVIN_SIGNATURE; -use crate::PROTOCOL_VERSION; - -/// A stream that reads from the underlying `BucketStream` and uses the the -/// methods on the `LevinBody` trait to decode the inner messages(bodies) -#[pin_project] -pub struct MessageStream { - #[pin] - bucket_stream: BucketStream, - phantom: PhantomData, -} - -impl MessageStream { - /// Creates a new stream from the provided `AsyncRead` - pub fn new(stream: S) -> Self { - MessageStream { - bucket_stream: BucketStream::new(stream), - phantom: PhantomData, - } - } -} - -impl Stream for MessageStream { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.project(); - match ready!(this.bucket_stream.poll_next(cx)).expect("BucketStream will never return None") - { - Err(e) => Poll::Ready(Some(Err(e))), - Ok(bucket) => { - if bucket.header.signature != LEVIN_SIGNATURE { - return Err(BucketError::IncorrectSignature(bucket.header.signature))?; - } - - if bucket.header.protocol_version != PROTOCOL_VERSION { - return Err(BucketError::UnknownProtocolVersion( - bucket.header.protocol_version, - ))?; - } - - // TODO: we shouldn't return an error if the peer sends an error response we should define a new network - // message: Error. - if bucket.header.return_code < 0 - || (bucket.header.return_code == 0 && bucket.header.flags.is_response()) - { - return Err(BucketError::Error(bucket.header.return_code))?; - } - - if bucket.header.flags.is_dummy() { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - - Poll::Ready(Some(D::decode_message( - &bucket.body, - MessageType::from_flags_and_have_to_return( - bucket.header.flags, - bucket.header.have_to_return_data, - )?, - bucket.header.command, - ))) - } - } - } -}