diff --git a/Cargo.lock b/Cargo.lock index 32e86d36..4bb47a9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -629,6 +629,7 @@ dependencies = [ "tar", "tempfile", "tokio", + "tokio-util", "zip", ] @@ -1418,9 +1419,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" name = "levin-cuprate" version = "0.1.0" dependencies = [ + "bitflags 2.4.2", "bytes", + "futures", + "proptest", + "rand", "thiserror", + "tokio", "tokio-util", + "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3bf700a4..73a10687 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ opt-level = 3 [workspace.dependencies] async-trait = { version = "0.1.74", default-features = false } +bitflags = { version = "2.4.2", default-features = false } borsh = { version = "1.2.1", default-features = false } bytes = { version = "1.5.0", default-features = false } cfg-if = { version = "1.0.0", default-features = false } diff --git a/net/levin/Cargo.toml b/net/levin/Cargo.toml index 790d4369..d7ca94b5 100644 --- a/net/levin/Cargo.toml +++ b/net/levin/Cargo.toml @@ -7,8 +7,21 @@ license = "MIT" authors = ["Boog900"] repository = "https://github.com/Cuprate/cuprate/tree/main/net/levin" +[features] +default = [] +tracing = ["dep:tracing", "tokio-util/tracing"] + [dependencies] thiserror = { workspace = true } -bytes = { workspace = true } +bytes = { workspace = true, features = ["std"] } +bitflags = { workspace = true } tokio-util = { workspace = true, features = ["codec"]} +tracing = { workspace = true, features = ["std"], optional = true } + +[dev-dependencies] +proptest = { workspace = true } +rand = { workspace = true, features = ["std", "std_rng"] } +tokio-util = { workspace = true, features = ["io-util"]} +tokio = { workspace = true, features = ["full"] } +futures = { workspace = true, features = ["std"] } \ No newline at end of file diff --git a/net/levin/src/codec.rs b/net/levin/src/codec.rs index 0c6e8b8f..3718d8c3 100644 --- a/net/levin/src/codec.rs +++ 
b/net/levin/src/codec.rs @@ -15,12 +15,14 @@ //! A tokio-codec for levin buckets -use std::marker::PhantomData; +use std::{fmt::Debug, marker::PhantomData}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use tokio_util::codec::{Decoder, Encoder}; use crate::{ + header::{Flags, HEADER_SIZE}, + message::{make_dummy_message, LevinMessage}, Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, MessageType, Protocol, }; @@ -61,31 +63,49 @@ impl LevinBucketCodec { } } -impl Decoder for LevinBucketCodec { +impl Decoder for LevinBucketCodec { type Item = Bucket; type Error = BucketError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { loop { match &self.state { LevinBucketState::WaitingForHeader => { - if src.len() < BucketHead::::SIZE { + if src.len() < HEADER_SIZE { return Ok(None); }; let head = BucketHead::::from_bytes(src); + #[cfg(feature = "tracing")] + tracing::trace!( + "Received new bucket header, command: {:?}, waiting for body, body len: {}", + head.command, + head.size + ); + if head.size > self.protocol.max_packet_size || head.size > head.command.bucket_size_limit() { + #[cfg(feature = "tracing")] + tracing::debug!("Peer sent message which is too large."); + return Err(BucketError::BucketExceededMaxSize); } if !self.handshake_message_seen { if head.size > self.protocol.max_packet_size_before_handshake { + #[cfg(feature = "tracing")] + tracing::debug!("Peer sent message which is too large."); + return Err(BucketError::BucketExceededMaxSize); } if head.command.is_handshake() { + #[cfg(feature = "tracing")] + tracing::debug!( + "Peer handshake message seen, increasing bucket size limit." 
+ ); + self.handshake_message_seen = true; } } @@ -109,6 +129,9 @@ impl Decoder for LevinBucketCodec { unreachable!() }; + #[cfg(feature = "tracing")] + tracing::trace!("Received full bucket for command: {:?}", header.command); + return Ok(Some(Bucket { header, body: src.copy_to_bytes(body_len), @@ -122,23 +145,26 @@ impl Decoder for LevinBucketCodec { impl Encoder> for LevinBucketCodec { type Error = BucketError; fn encode(&mut self, item: Bucket, dst: &mut BytesMut) -> Result<(), Self::Error> { - if let Some(additional) = - (BucketHead::::SIZE + item.body.len()).checked_sub(dst.capacity()) - { + if let Some(additional) = (HEADER_SIZE + item.body.len()).checked_sub(dst.capacity()) { dst.reserve(additional) } - item.header.write_bytes(dst); + item.header.write_bytes_into(dst); dst.put_slice(&item.body); Ok(()) } } #[derive(Default, Debug, Clone)] -enum MessageState { +enum MessageState { #[default] WaitingForBucket, - WaitingForRestOfFragment(Vec, MessageType, C), + /// Waiting for the rest of a fragmented message. + /// + /// We keep the fragmented message as a Vec instead of [`Bytes`](bytes::Bytes) as [`Bytes`](bytes::Bytes) could point to a + /// large allocation even if the [`Bytes`](bytes::Bytes) itself is small, so is not safe to keep around for long. + /// To prevent this attack vector completely we just use Vec for fragmented messages. 
+ WaitingForRestOfFragment(Vec), } /// A tokio-codec for levin messages or in other words the decoded body @@ -147,7 +173,7 @@ enum MessageState { pub struct LevinMessageCodec { message_ty: PhantomData, bucket_codec: LevinBucketCodec, - state: MessageState, + state: MessageState, } impl Default for LevinMessageCodec { @@ -173,107 +199,143 @@ impl Decoder for LevinMessageCodec { let flags = &bucket.header.flags; - if flags.is_start_fragment() && flags.is_end_fragment() { + if flags.contains(Flags::DUMMY) { // Dummy message - return Ok(None); + + #[cfg(feature = "tracing")] + tracing::trace!("Received DUMMY bucket from peer, ignoring."); + // We may have another bucket in `src`. + continue; }; - if flags.is_end_fragment() { + if flags.contains(Flags::END_FRAGMENT) { return Err(BucketError::InvalidHeaderFlags( "Flag end fragment received before a start fragment", )); }; - if !flags.is_request() && !flags.is_response() { - return Err(BucketError::InvalidHeaderFlags( - "Request and response flags both not set", - )); - }; + if flags.contains(Flags::START_FRAGMENT) { + // monerod does not require a start flag before starting a fragmented message, + // but will always produce one, so it is ok for us to require one. 
+ + #[cfg(feature = "tracing")] + tracing::debug!("Bucket is a fragment, waiting for rest of message."); + + self.state = MessageState::WaitingForRestOfFragment(bucket.body.to_vec()); + + continue; + } + + // Normal, non fragmented bucket let message_type = MessageType::from_flags_and_have_to_return( bucket.header.flags, bucket.header.have_to_return_data, )?; - if flags.is_start_fragment() { - let _ = std::mem::replace( - &mut self.state, - MessageState::WaitingForRestOfFragment( - vec![bucket.body], - message_type, - bucket.header.command, - ), - ); - - continue; - } - return Ok(Some(T::decode_message( &mut bucket.body, message_type, bucket.header.command, )?)); } - MessageState::WaitingForRestOfFragment(bytes, ty, command) => { + MessageState::WaitingForRestOfFragment(bytes) => { let Some(bucket) = self.bucket_codec.decode(src)? else { return Ok(None); }; let flags = &bucket.header.flags; - if flags.is_start_fragment() && flags.is_end_fragment() { + if flags.contains(Flags::DUMMY) { // Dummy message - return Ok(None); + + #[cfg(feature = "tracing")] + tracing::trace!("Received DUMMY bucket from peer, ignoring."); + // We may have another bucket in `src`. 
+ continue; }; - if !flags.is_request() && !flags.is_response() { - return Err(BucketError::InvalidHeaderFlags( - "Request and response flags both not set", - )); - }; - - let message_type = MessageType::from_flags_and_have_to_return( - bucket.header.flags, - bucket.header.have_to_return_data, - )?; - - if message_type != *ty { - return Err(BucketError::InvalidFragmentedMessage( - "Message type was inconsistent across fragments", - )); + let max_size = if self.bucket_codec.handshake_message_seen { + self.bucket_codec.protocol.max_packet_size + } else { + self.bucket_codec.protocol.max_packet_size_before_handshake } + .try_into() + .expect("Levin max message size is too large, does not fit into a usize."); - if bucket.header.command != *command { - return Err(BucketError::InvalidFragmentedMessage( - "Command not consistent across fragments", - )); - } - - if bytes.len().saturating_add(bucket.body.len()) - > command.bucket_size_limit().try_into().unwrap() - { + if bytes.len().saturating_add(bucket.body.len()) > max_size { return Err(BucketError::InvalidFragmentedMessage( "Fragmented message exceeded maximum size", )); } - bytes.push(bucket.body); + #[cfg(feature = "tracing")] + tracing::trace!("Received another bucket fragment."); - if flags.is_end_fragment() { - let MessageState::WaitingForRestOfFragment(mut bytes, ty, command) = + bytes.extend_from_slice(bucket.body.as_ref()); + + if flags.contains(Flags::END_FRAGMENT) { + // make sure we only look at the internal bucket and don't use this. + drop(bucket); + + let MessageState::WaitingForRestOfFragment(bytes) = std::mem::replace(&mut self.state, MessageState::WaitingForBucket) else { unreachable!(); }; - // TODO: this doesn't seem very efficient but I can't think of a better way. - bytes.reverse(); - let mut byte_vec: Box = Box::new(bytes.pop().unwrap()); - for bytes in bytes { - byte_vec = Box::new(byte_vec.chain(bytes)); + // Check there are enough bytes in the fragment to build a header. 
+ if bytes.len() < HEADER_SIZE { + return Err(BucketError::InvalidFragmentedMessage( + "Fragmented message is not large enough to build a bucket.", + )); } - return Ok(Some(T::decode_message(&mut byte_vec, ty, command)?)); + let mut header_bytes = BytesMut::from(&bytes[0..HEADER_SIZE]); + + let header = BucketHead::::from_bytes(&mut header_bytes); + + if header.size > header.command.bucket_size_limit() { + return Err(BucketError::BucketExceededMaxSize); + } + + // Check the fragmented message contains enough bytes to build the message. + if bytes.len().saturating_sub(HEADER_SIZE) + < header + .size + .try_into() + .map_err(|_| BucketError::BucketExceededMaxSize)? + { + return Err(BucketError::InvalidFragmentedMessage( + "Fragmented message does not have enough bytes to fill bucket body", + )); + } + + #[cfg(feature = "tracing")] + tracing::debug!( + "Received final fragment, combined message command: {:?}.", + header.command + ); + + let message_type = MessageType::from_flags_and_have_to_return( + header.flags, + header.have_to_return_data, + )?; + + if header.command.is_handshake() { + #[cfg(feature = "tracing")] + tracing::debug!( + "Peer handshake message seen, increasing bucket size limit." 
+ ); + + self.bucket_codec.handshake_message_seen = true; + } + + return Ok(Some(T::decode_message( + &mut &bytes[HEADER_SIZE..], + message_type, + header.command, + )?)); } } } @@ -281,12 +343,21 @@ impl Decoder for LevinMessageCodec { } } -impl Encoder for LevinMessageCodec { +impl Encoder> for LevinMessageCodec { type Error = BucketError; - fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { - let mut bucket_builder = BucketBuilder::default(); - item.encode(&mut bucket_builder)?; - let bucket = bucket_builder.finish(); - self.bucket_codec.encode(bucket, dst) + fn encode(&mut self, item: LevinMessage, dst: &mut BytesMut) -> Result<(), Self::Error> { + match item { + LevinMessage::Body(body) => { + let mut bucket_builder = BucketBuilder::new(&self.bucket_codec.protocol); + body.encode(&mut bucket_builder)?; + let bucket = bucket_builder.finish(); + self.bucket_codec.encode(bucket, dst) + } + LevinMessage::Bucket(bucket) => self.bucket_codec.encode(bucket, dst), + LevinMessage::Dummy(size) => { + let bucket = make_dummy_message(&self.bucket_codec.protocol, size); + self.bucket_codec.encode(bucket, dst) + } + } } } diff --git a/net/levin/src/header.rs b/net/levin/src/header.rs index 42418ded..7acd0858 100644 --- a/net/levin/src/header.rs +++ b/net/levin/src/header.rs @@ -16,34 +16,47 @@ //! This module provides a struct BucketHead for the header of a levin protocol //! message. +use bitflags::bitflags; use bytes::{Buf, BufMut, BytesMut}; use crate::LevinCommand; -const REQUEST: u32 = 0b0000_0001; -const RESPONSE: u32 = 0b0000_0010; -const START_FRAGMENT: u32 = 0b0000_0100; -const END_FRAGMENT: u32 = 0b0000_1000; +/// The size of the header (in bytes) +pub const HEADER_SIZE: usize = 33; /// Levin header flags #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct Flags(u32); -impl Flags { - pub const REQUEST: Flags = Flags(REQUEST); - pub const RESPONSE: Flags = Flags(RESPONSE); +bitflags! 
{ + impl Flags: u32 { + /// The request flag. + /// + /// Depending on the `have_to_return_data` field in [`BucketHead`], this message is either + /// a request or notification. + const REQUEST = 0b0000_0001; + /// The response flag. + /// + /// Messages with this flag set are responses to requests. + const RESPONSE = 0b0000_0010; - pub fn is_request(&self) -> bool { - self.0 & REQUEST != 0 - } - pub fn is_response(&self) -> bool { - self.0 & RESPONSE != 0 - } - pub fn is_start_fragment(&self) -> bool { - self.0 & START_FRAGMENT != 0 - } - pub fn is_end_fragment(&self) -> bool { - self.0 & END_FRAGMENT != 0 + /// The start fragment flag. + /// + /// Messages with this flag set tell the parser that the next messages until a message + /// with [`Flags::END_FRAGMENT`] should be combined into a single bucket. + const START_FRAGMENT = 0b0000_0100; + /// The end fragment flag. + /// + /// Messages with this flag set tell the parser that all fragments of a fragmented message + /// have been sent. + const END_FRAGMENT = 0b0000_1000; + + /// A dummy message. + /// + /// Messages with this flag will be completely ignored by the parser. + const DUMMY = Self::START_FRAGMENT.bits() | Self::END_FRAGMENT.bits(); + + const _ = !0; } } @@ -81,15 +94,12 @@ pub struct BucketHead { } impl BucketHead { - /// The size of the header (in bytes) - pub const SIZE: usize = 33; - /// Builds the header from bytes, this function does not check any fields should /// match the expected ones. /// /// # Panics /// This function will panic if there aren't enough bytes to fill the header. 
- /// Currently ['SIZE'](BucketHead::SIZE) + /// Currently [HEADER_SIZE] pub fn from_bytes(buf: &mut BytesMut) -> BucketHead { BucketHead { signature: buf.get_u64_le(), @@ -103,8 +113,8 @@ impl BucketHead { } /// Serializes the header - pub fn write_bytes(&self, dst: &mut BytesMut) { - dst.reserve(Self::SIZE); + pub fn write_bytes_into(&self, dst: &mut BytesMut) { + dst.reserve(HEADER_SIZE); dst.put_u64_le(self.signature); dst.put_u64_le(self.size); diff --git a/net/levin/src/lib.rs b/net/levin/src/lib.rs index 6be0edb3..0a247f72 100644 --- a/net/levin/src/lib.rs +++ b/net/levin/src/lib.rs @@ -33,20 +33,29 @@ #![deny(unused_mut)] //#![deny(missing_docs)] -pub mod codec; -pub mod header; - -pub use codec::*; -pub use header::BucketHead; - use std::fmt::Debug; use bytes::{Buf, Bytes}; use thiserror::Error; +pub mod codec; +pub mod header; +pub mod message; + +pub use codec::*; +pub use header::BucketHead; +pub use message::LevinMessage; + +use header::Flags; + +/// The version field for bucket headers. const MONERO_PROTOCOL_VERSION: u32 = 1; +/// The signature field for bucket headers, will be constant for all peers using the Monero levin +/// protocol. const MONERO_LEVIN_SIGNATURE: u64 = 0x0101010101012101; +/// Maximum size a bucket can be before a handshake. const MONERO_MAX_PACKET_SIZE_BEFORE_HANDSHAKE: u64 = 256 * 1000; // 256 KiB +/// Maximum size a bucket can be after a handshake. 
const MONERO_MAX_PACKET_SIZE: u64 = 100_000_000; // 100MB /// Possible Errors when working with levin buckets @@ -98,7 +107,7 @@ impl Default for Protocol { } /// A levin Bucket -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Bucket { /// The bucket header pub header: BucketHead, @@ -128,20 +137,19 @@ impl MessageType { /// Returns the `MessageType` given the flags and have_to_return_data fields pub fn from_flags_and_have_to_return( - flags: header::Flags, + flags: Flags, have_to_return: bool, ) -> Result { - if flags.is_request() && have_to_return { - Ok(MessageType::Request) - } else if flags.is_request() { - Ok(MessageType::Notification) - } else if flags.is_response() && !have_to_return { - Ok(MessageType::Response) - } else { - Err(BucketError::InvalidHeaderFlags( - "Unable to assign a message type to this bucket", - )) - } + Ok(match (flags, have_to_return) { + (Flags::REQUEST, true) => MessageType::Request, + (Flags::REQUEST, false) => MessageType::Notification, + (Flags::RESPONSE, false) => MessageType::Response, + _ => { + return Err(BucketError::InvalidHeaderFlags( + "Unable to assign a message type to this bucket", + )) + } + }) } pub fn as_flags(&self) -> header::Flags { @@ -162,20 +170,18 @@ pub struct BucketBuilder { body: Option, } -impl Default for BucketBuilder { - fn default() -> Self { +impl BucketBuilder { + pub fn new(protocol: &Protocol) -> Self { Self { - signature: Some(MONERO_LEVIN_SIGNATURE), + signature: Some(protocol.signature), ty: None, command: None, return_code: None, - protocol_version: Some(MONERO_PROTOCOL_VERSION), + protocol_version: Some(protocol.version), body: None, } } -} -impl BucketBuilder { pub fn set_signature(&mut self, sig: u64) { self.signature = Some(sig) } @@ -220,7 +226,7 @@ impl BucketBuilder { /// A levin body pub trait LevinBody: Sized { - type Command: LevinCommand; + type Command: LevinCommand + Debug; /// Decodes the message from the data in the header fn decode_message( diff --git 
a/net/levin/src/message.rs b/net/levin/src/message.rs new file mode 100644 index 00000000..dd60fdd0 --- /dev/null +++ b/net/levin/src/message.rs @@ -0,0 +1,206 @@ +//! Levin Messages +//! +//! This module contains the [`LevinMessage`], which allows sending bucket bodies, full buckets or dummy messages. +//! The codec will not return [`LevinMessage`]; instead it will only return bucket bodies. [`LevinMessage`] allows +//! for more control over what is actually sent over the wire at certain times. +use bytes::{Bytes, BytesMut}; + +use crate::{ + header::{Flags, HEADER_SIZE}, + Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, Protocol, +}; + +/// A levin message that can be sent to a peer. +pub enum LevinMessage { + /// A message body. + /// + /// A levin header will be added to this message before it is sent to the peer. + Body(T), + /// A full levin bucket. + /// + /// This bucket will be sent to the peer directly with no extra information. + /// + /// This should only be used to send fragmented messages: [`make_fragmented_messages`] + Bucket(Bucket), + /// A dummy message. + /// + /// A dummy message which the peer will ignore. The dummy message will be the exact size + /// (in bytes) of the given `usize` on the wire. + Dummy(usize), +} + +impl From for LevinMessage { + fn from(value: T) -> Self { + LevinMessage::Body(value) + } +} + +impl From> for LevinMessage { + fn from(value: Bucket) -> Self { + LevinMessage::Bucket(value) + } +} + +/// This represents a dummy message to send to a peer. +/// +/// The message, including the header, will be the exact size of the given `usize`. 
+/// This exists because it seems weird to do this: +/// ```rust,ignore +/// peer.send(1_000); +/// ``` +/// This is a lot clearer: +/// ```rust,ignore +/// peer.send(Dummy(1_000)); +/// ``` +pub struct Dummy(pub usize); + +impl From for LevinMessage { + fn from(value: Dummy) -> Self { + LevinMessage::Dummy(value.0) + } +} + +/// Fragments the provided message into buckets which, when serialised, will all be the size of `fragment_size`. +/// +/// This function will produce many buckets that have to be sent in order. When the peer receives these buckets +/// they will combine them to produce the original message. +/// +/// The last bucket may be padded with zeros to make it the correct size, the format used to encode the body must +/// allow for extra data at the end of the message for this to work. +/// +/// `fragment_size` must be at least 2 * [`HEADER_SIZE`] otherwise this will panic. +pub fn make_fragmented_messages( + protocol: &Protocol, + fragment_size: usize, + message: T, +) -> Result>, BucketError> { + if fragment_size * 2 < HEADER_SIZE { + panic!( + "Fragment size: {fragment_size}, is too small, must be at least {}", + 2 * HEADER_SIZE + ); + } + + let mut builder = BucketBuilder::new(protocol); + message.encode(&mut builder)?; + let mut bucket = builder.finish(); + + // Make sure we are not trying to fragment a fragment. + if !bucket + .header + .flags + .intersects(Flags::REQUEST | Flags::RESPONSE) + { + // If a bucket does not have the request or response bits set it is a fragment. + return Err(BucketError::InvalidFragmentedMessage( + "Can't make a fragmented message out of a message which is already fragmented", + )); + } + + // Check if the bucket can fit in one fragment. + if bucket.body.len() + HEADER_SIZE <= fragment_size { + // If it can, pad the bucket up to the fragment size and just return this bucket. 
+ if bucket.body.len() + HEADER_SIZE < fragment_size { + let mut new_body = BytesMut::from(bucket.body.as_ref()); + // Epee's binary format will ignore extra data at the end so just pad with 0. + new_body.resize(fragment_size - HEADER_SIZE, 0); + + bucket.body = new_body.freeze(); + bucket.header.size = fragment_size + .try_into() + .expect("Bucket size does not fit into u64"); + } + + return Ok(vec![bucket]); + } + + // A header put on all fragments. + // The first fragment will set the START flag, the last will set the END flag. + let fragment_head = BucketHead { + signature: protocol.signature, + size: (fragment_size - HEADER_SIZE) + .try_into() + .expect("Bucket size does not fit into u64"), + have_to_return_data: false, + // Just use a default command. + command: T::Command::from(0), + return_code: 0, + flags: Flags::empty(), + protocol_version: protocol.version, + }; + + // data_space - the amount of actual data we can fit in each fragment. + let data_space = fragment_size - HEADER_SIZE; + + let amount_of_fragments = (bucket.body.len() + HEADER_SIZE).div_ceil(data_space); + + let mut first_bucket_body = BytesMut::with_capacity(fragment_size); + // Fragmented messages store the whole fragmented bucket in the combined payloads not just the body + // so the first bucket contains 2 headers, a fragment header and the actual bucket header we are sending. + bucket.header.write_bytes_into(&mut first_bucket_body); + first_bucket_body.extend_from_slice( + bucket + .body + .split_to(fragment_size - (HEADER_SIZE * 2)) + .as_ref(), + ); + + let mut buckets = Vec::with_capacity(amount_of_fragments); + buckets.push(Bucket { + header: fragment_head.clone(), + body: first_bucket_body.freeze(), + }); + + for mut bytes in (1..amount_of_fragments).map(|_| { + bucket + .body + .split_to((fragment_size - HEADER_SIZE).min(bucket.body.len())) + }) { + // make sure this fragment has the correct size - the last one might not, so pad it. 
+ if bytes.len() + HEADER_SIZE < fragment_size { + let mut new_bytes = BytesMut::from(bytes.as_ref()); + // Epee's binary format will ignore extra data at the end so just pad with 0. + new_bytes.resize(fragment_size - HEADER_SIZE, 0); + bytes = new_bytes.freeze(); + } + + buckets.push(Bucket { + header: fragment_head.clone(), + body: bytes, + }); + } + + buckets + .first_mut() + .unwrap() + .header + .flags + .toggle(Flags::START_FRAGMENT); + buckets + .last_mut() + .unwrap() + .header + .flags + .toggle(Flags::END_FRAGMENT); + + Ok(buckets) +} + +/// Makes a dummy message, which will be the size of `size` when sent over the wire. +pub(crate) fn make_dummy_message(protocol: &Protocol, size: usize) -> Bucket { + // A header to put on the dummy message. + let header = BucketHead { + signature: protocol.signature, + size: size.try_into().expect("Bucket size does not fit into u64"), + have_to_return_data: false, + // Just use a default command. + command: T::from(0), + return_code: 0, + flags: Flags::DUMMY, + protocol_version: protocol.version, + }; + + let body = Bytes::from(vec![0; size - HEADER_SIZE]); + + Bucket { header, body } +} diff --git a/net/levin/tests/fragmented_message.rs b/net/levin/tests/fragmented_message.rs new file mode 100644 index 00000000..7598e2ca --- /dev/null +++ b/net/levin/tests/fragmented_message.rs @@ -0,0 +1,151 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use proptest::{prelude::any_with, prop_assert_eq, proptest, sample::size_range}; +use rand::Fill; +use tokio::{ + io::duplex, + time::{timeout, Duration}, +}; +use tokio_util::codec::{FramedRead, FramedWrite}; + +use levin_cuprate::{ + message::make_fragmented_messages, BucketBuilder, BucketError, LevinBody, LevinCommand, + LevinMessageCodec, MessageType, Protocol, +}; + +/// A timeout put on streams so tests don't stall. 
+const TEST_TIMEOUT: Duration = Duration::from_secs(30); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct TestCommands(u32); + +impl From for TestCommands { + fn from(value: u32) -> Self { + Self(value) + } +} + +impl From for u32 { + fn from(value: TestCommands) -> Self { + value.0 + } +} + +impl LevinCommand for TestCommands { + fn bucket_size_limit(&self) -> u64 { + u64::MAX + } + + fn is_handshake(&self) -> bool { + self.0 == 1 + } +} + +#[derive(Clone)] +enum TestBody { + Bytes(usize, Bytes), +} + +impl LevinBody for TestBody { + type Command = TestCommands; + + fn decode_message( + body: &mut B, + _: MessageType, + _: Self::Command, + ) -> Result { + let size = body.get_u64_le().try_into().unwrap(); + // bucket + Ok(TestBody::Bytes(size, body.copy_to_bytes(size))) + } + + fn encode(self, builder: &mut BucketBuilder) -> Result<(), BucketError> { + match self { + TestBody::Bytes(len, bytes) => { + let mut buf = BytesMut::new(); + buf.put_u64_le(len as u64); + buf.extend_from_slice(bytes.as_ref()); + + builder.set_command(TestCommands(1)); + builder.set_message_type(MessageType::Notification); + builder.set_return_code(0); + builder.set_body(buf.freeze()); + } + } + + Ok(()) + } +} + +#[tokio::test] +async fn codec_fragmented_messages() { + // Set up the fake connection + let (write, read) = duplex(100_000); + + let mut read = FramedRead::new(read, LevinMessageCodec::::default()); + let mut write = FramedWrite::new(write, LevinMessageCodec::::default()); + + // Create the message to fragment + let mut buf = BytesMut::from(vec![0; 10_000].as_slice()); + let mut rng = rand::thread_rng(); + buf.try_fill(&mut rng).unwrap(); + + let message = TestBody::Bytes(buf.len(), buf.freeze()); + + let fragments = make_fragmented_messages(&Protocol::default(), 3_000, message.clone()).unwrap(); + + for frag in fragments { + // Send each fragment + timeout(TEST_TIMEOUT, write.send(frag.into())) + .await + .unwrap() + .unwrap(); + } + + // only one message should be 
received. + let message2 = timeout(TEST_TIMEOUT, read.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + + match (message, message2) { + (TestBody::Bytes(_, buf), TestBody::Bytes(_, buf2)) => assert_eq!(buf, buf2), + } +} + +proptest! { + #[test] + fn make_fragmented_messages_correct_size(fragment_size in 100_usize..5000, message_size in 0_usize..100_000) { + let mut bytes = BytesMut::new(); + bytes.resize(message_size, 10); + + let fragments = make_fragmented_messages(&Protocol::default(), fragment_size, TestBody::Bytes(bytes.len(), bytes.freeze())).unwrap(); + let len = fragments.len(); + + for (i, fragment) in fragments.into_iter().enumerate() { + prop_assert_eq!(fragment.body.len() + 33, fragment_size, "numb_fragments:{}, index: {}", len, i) + } + } + + #[test] + fn make_fragmented_messages_consistent(fragment_size in 100_usize..5_000, message in any_with::>(size_range(50_000).lift())) { + let fragments = make_fragmented_messages(&Protocol::default(), fragment_size, TestBody::Bytes(message.len(), Bytes::copy_from_slice(message.as_slice()))).unwrap(); + + let mut message2 = Vec::with_capacity(message.len()); + + // remove the header and the bytes length. 
+ message2.extend_from_slice(&fragments[0].body[(33 + 8)..]); + + for frag in fragments.iter().skip(1) { + message2.extend_from_slice(frag.body.as_ref()) + } + + prop_assert_eq!(message.as_slice(), &message2[0..message.len()], "numb_fragments: {}", fragments.len()); + + for byte in message2[message.len()..].iter(){ + prop_assert_eq!(*byte, 0); + } + } + +} diff --git a/net/monero-wire/Cargo.toml b/net/monero-wire/Cargo.toml index 5a8f4b86..611fb080 100644 --- a/net/monero-wire/Cargo.toml +++ b/net/monero-wire/Cargo.toml @@ -6,6 +6,9 @@ license = "MIT" authors = ["Boog900"] repository = "https://github.com/SyntheticBird45/cuprate/tree/main/net/monero-wire" +[features] +default = [] +tracing = ["levin-cuprate/tracing"] [dependencies] levin-cuprate = {path="../levin"} diff --git a/net/monero-wire/src/lib.rs b/net/monero-wire/src/lib.rs index f69fd2dc..27e6481d 100644 --- a/net/monero-wire/src/lib.rs +++ b/net/monero-wire/src/lib.rs @@ -29,4 +29,7 @@ pub use levin_cuprate::BucketError; pub use network_address::{NetZone, NetworkAddress}; pub use p2p::*; +// re-export. 
+pub use levin_cuprate as levin; + pub type MoneroWireCodec = levin_cuprate::codec::LevinMessageCodec; diff --git a/p2p/monero-p2p/Cargo.toml b/p2p/monero-p2p/Cargo.toml index 5c39035d..7abd9c86 100644 --- a/p2p/monero-p2p/Cargo.toml +++ b/p2p/monero-p2p/Cargo.toml @@ -11,7 +11,7 @@ borsh = ["dep:borsh", "monero-pruning/borsh"] [dependencies] cuprate-helper = { path = "../../helper" } -monero-wire = { path = "../../net/monero-wire" } +monero-wire = { path = "../../net/monero-wire", features = ["tracing"] } monero-pruning = { path = "../../pruning" } tokio = { workspace = true, features = ["net", "sync", "macros", "time"]} diff --git a/p2p/monero-p2p/src/client/connection.rs b/p2p/monero-p2p/src/client/connection.rs index 1b49454c..b458c3da 100644 --- a/p2p/monero-p2p/src/client/connection.rs +++ b/p2p/monero-p2p/src/client/connection.rs @@ -86,7 +86,7 @@ where } async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> { - Ok(self.peer_sink.send(mes).await?) + Ok(self.peer_sink.send(mes.into()).await?) } async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> { diff --git a/p2p/monero-p2p/src/client/handshaker.rs b/p2p/monero-p2p/src/client/handshaker.rs index 6ea828d9..52f8e2e7 100644 --- a/p2p/monero-p2p/src/client/handshaker.rs +++ b/p2p/monero-p2p/src/client/handshaker.rs @@ -257,7 +257,7 @@ where "Peer didn't send support flags or has no features, sending request to make sure." 
); peer_sink - .send(Message::Request(RequestMessage::SupportFlags)) + .send(Message::Request(RequestMessage::SupportFlags).into()) .await?; let Message::Response(ResponseMessage::SupportFlags(support_flags_res)) = @@ -346,7 +346,7 @@ where tracing::debug!("Sending handshake request."); peer_sink - .send(Message::Request(RequestMessage::Handshake(req))) + .send(Message::Request(RequestMessage::Handshake(req)).into()) .await?; Ok(()) @@ -391,7 +391,7 @@ where tracing::debug!("Sending handshake response."); peer_sink - .send(Message::Response(ResponseMessage::Handshake(res))) + .send(Message::Response(ResponseMessage::Handshake(res)).into()) .await?; Ok(()) @@ -476,8 +476,11 @@ async fn send_support_flags( ) -> Result<(), HandshakeError> { tracing::debug!("Sending support flag response."); Ok(peer_sink - .send(Message::Response(ResponseMessage::SupportFlags( - SupportFlagsResponse { support_flags }, - ))) + .send( + Message::Response(ResponseMessage::SupportFlags(SupportFlagsResponse { + support_flags, + })) + .into(), + ) .await?) } diff --git a/p2p/monero-p2p/src/lib.rs b/p2p/monero-p2p/src/lib.rs index 74b4a3ee..8eb309be 100644 --- a/p2p/monero-p2p/src/lib.rs +++ b/p2p/monero-p2p/src/lib.rs @@ -5,7 +5,8 @@ use std::{fmt::Debug, future::Future, hash::Hash, pin::Pin}; use futures::{Sink, Stream}; use monero_wire::{ - network_address::NetworkAddressIncorrectZone, BucketError, Message, NetworkAddress, + levin::LevinMessage, network_address::NetworkAddressIncorrectZone, BucketError, Message, + NetworkAddress, }; pub mod client; @@ -103,7 +104,7 @@ pub trait NetworkZone: Clone + Copy + Send + 'static { /// The stream (incoming data) type for this network. type Stream: Stream> + Unpin + Send + 'static; /// The sink (outgoing data) type for this network. - type Sink: Sink + Unpin + Send + 'static; + type Sink: Sink, Error = BucketError> + Unpin + Send + 'static; /// The inbound connection listener for this network. 
type Listener: Stream< Item = Result<(Option, Self::Stream, Self::Sink), std::io::Error>, diff --git a/p2p/monero-p2p/src/network_zones/clear.rs b/p2p/monero-p2p/src/network_zones/clear.rs index 508b0ab2..4086b48a 100644 --- a/p2p/monero-p2p/src/network_zones/clear.rs +++ b/p2p/monero-p2p/src/network_zones/clear.rs @@ -30,7 +30,7 @@ pub struct ClearNetServerCfg { } #[derive(Clone, Copy)] -pub struct ClearNet; +pub enum ClearNet {} #[async_trait::async_trait] impl NetworkZone for ClearNet { diff --git a/p2p/monero-p2p/tests/fragmented_handshake.rs b/p2p/monero-p2p/tests/fragmented_handshake.rs new file mode 100644 index 00000000..fdc25193 --- /dev/null +++ b/p2p/monero-p2p/tests/fragmented_handshake.rs @@ -0,0 +1,224 @@ +//! This file contains a test for a handshake with monerod but uses fragmented messages. +use std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use futures::{Stream, StreamExt}; +use tokio::{ + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpListener, TcpStream, + }, + sync::{broadcast, Semaphore}, + time::timeout, +}; +use tokio_util::{ + bytes::BytesMut, + codec::{Encoder, FramedRead, FramedWrite}, +}; +use tower::{Service, ServiceExt}; + +use cuprate_helper::network::Network; +use monero_p2p::{ + client::{ConnectRequest, Connector, DoHandshakeRequest, HandShaker, InternalPeerID}, + network_zones::ClearNetServerCfg, + ConnectionDirection, NetworkZone, +}; +use monero_wire::{ + common::PeerSupportFlags, + levin::{message::make_fragmented_messages, LevinMessage, Protocol}, + BasicNodeData, Message, MoneroWireCodec, +}; + +use cuprate_test_utils::monerod::monerod; + +mod utils; +use utils::*; + +/// A network zone equal to clear net where every message sent is turned into a fragmented message. +/// Does not support sending fragmented or dummy messages manually. 
+#[derive(Clone, Copy)] +pub enum FragNet {} + +#[async_trait::async_trait] +impl NetworkZone for FragNet { + const NAME: &'static str = "FragNet"; + const ALLOW_SYNC: bool = true; + const DANDELION_PP: bool = true; + const CHECK_NODE_ID: bool = true; + + type Addr = SocketAddr; + type Stream = FramedRead; + type Sink = FramedWrite; + type Listener = InBoundStream; + + type ServerCfg = ClearNetServerCfg; + + async fn connect_to_peer( + addr: Self::Addr, + ) -> Result<(Self::Stream, Self::Sink), std::io::Error> { + let (read, write) = TcpStream::connect(addr).await?.into_split(); + Ok(( + FramedRead::new(read, MoneroWireCodec::default()), + FramedWrite::new(write, FragmentCodec::default()), + )) + } + + async fn incoming_connection_listener( + config: Self::ServerCfg, + ) -> Result { + let listener = TcpListener::bind(config.addr).await?; + Ok(InBoundStream { listener }) + } +} + +pub struct InBoundStream { + listener: TcpListener, +} + +impl Stream for InBoundStream { + type Item = Result< + ( + Option, + FramedRead, + FramedWrite, + ), + std::io::Error, + >; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.listener + .poll_accept(cx) + .map_ok(|(stream, addr)| { + let (read, write) = stream.into_split(); + ( + Some(addr), + FramedRead::new(read, MoneroWireCodec::default()), + FramedWrite::new(write, FragmentCodec::default()), + ) + }) + .map(Some) + } +} + +#[derive(Default)] +pub struct FragmentCodec(MoneroWireCodec); + +impl Encoder> for FragmentCodec { + type Error = >>::Error; + + fn encode( + &mut self, + item: LevinMessage, + dst: &mut BytesMut, + ) -> Result<(), Self::Error> { + match item { + LevinMessage::Body(body) => { + // 66 is the minimum fragment size. 
+ let fragments = make_fragmented_messages(&Protocol::default(), 66, body).unwrap(); + + for frag in fragments { + self.0.encode(frag.into(), dst)?; + } + } + _ => unreachable!("Handshakes should only send bucket bodys"), + } + Ok(()) + } +} + +#[tokio::test] +async fn fragmented_handshake_cuprate_to_monerod() { + let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test. + let semaphore = Arc::new(Semaphore::new(10)); + let permit = semaphore.acquire_owned().await.unwrap(); + + let monerod = monerod(["--fixed-difficulty=1", "--out-peers=0"]).await; + + let our_basic_node_data = BasicNodeData { + my_port: 0, + network_id: Network::Mainnet.network_id().into(), + peer_id: 87980, + support_flags: PeerSupportFlags::from(1_u32), + rpc_port: 0, + rpc_credits_per_hash: 0, + }; + + let handshaker = HandShaker::::new( + DummyAddressBook, + DummyCoreSyncSvc, + DummyPeerRequestHandlerSvc, + broadcast_tx, + our_basic_node_data, + ); + + let mut connector = Connector::new(handshaker); + + connector + .ready() + .await + .unwrap() + .call(ConnectRequest { + addr: monerod.p2p_addr(), + permit, + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn fragmented_handshake_monerod_to_cuprate() { + let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test. 
+ let semaphore = Arc::new(Semaphore::new(10)); + let permit = semaphore.acquire_owned().await.unwrap(); + + let our_basic_node_data = BasicNodeData { + my_port: 18081, + network_id: Network::Mainnet.network_id().into(), + peer_id: 87980, + support_flags: PeerSupportFlags::from(1_u32), + rpc_port: 0, + rpc_credits_per_hash: 0, + }; + + let mut handshaker = HandShaker::::new( + DummyAddressBook, + DummyCoreSyncSvc, + DummyPeerRequestHandlerSvc, + broadcast_tx, + our_basic_node_data, + ); + + let addr = "127.0.0.1:18081".parse().unwrap(); + + let mut listener = FragNet::incoming_connection_listener(ClearNetServerCfg { addr }) + .await + .unwrap(); + + let _monerod = monerod(["--add-exclusive-node=127.0.0.1:18081"]).await; + + // Put a timeout on this just in case monerod doesn't make the connection to us. + let next_connection_fut = timeout(Duration::from_secs(30), listener.next()); + + if let Some(Ok((addr, stream, sink))) = next_connection_fut.await.unwrap() { + let _ = handshaker + .ready() + .await + .unwrap() + .call(DoHandshakeRequest { + addr: InternalPeerID::KnownAddr(addr.unwrap()), // This is clear net all addresses are known. 
+ peer_stream: stream, + peer_sink: sink, + direction: ConnectionDirection::InBound, + permit, + }) + .await + .unwrap(); + } else { + panic!("Failed to receive connection from monerod."); + }; +} diff --git a/p2p/monero-p2p/tests/handshake.rs b/p2p/monero-p2p/tests/handshake.rs index dacc150c..2634263d 100644 --- a/p2p/monero-p2p/tests/handshake.rs +++ b/p2p/monero-p2p/tests/handshake.rs @@ -1,14 +1,16 @@ use std::{sync::Arc, time::Duration}; -use futures::{channel::mpsc, StreamExt}; +use futures::StreamExt; use tokio::{ + io::{duplex, split}, sync::{broadcast, Semaphore}, time::timeout, }; +use tokio_util::codec::{FramedRead, FramedWrite}; use tower::{Service, ServiceExt}; use cuprate_helper::network::Network; -use monero_wire::{common::PeerSupportFlags, BasicNodeData}; +use monero_wire::{common::PeerSupportFlags, BasicNodeData, MoneroWireCodec}; use monero_p2p::{ client::{ConnectRequest, Connector, DoHandshakeRequest, HandShaker, InternalPeerID}, @@ -63,21 +65,23 @@ async fn handshake_cuprate_to_cuprate() { our_basic_node_data_2, ); - let (p1_sender, p2_receiver) = mpsc::channel(5); - let (p2_sender, p1_receiver) = mpsc::channel(5); + let (p1, p2) = duplex(50_000); + + let (p1_receiver, p1_sender) = split(p1); + let (p2_receiver, p2_sender) = split(p2); let p1_handshake_req = DoHandshakeRequest { addr: InternalPeerID::KnownAddr(TestNetZoneAddr(888)), - peer_stream: p2_receiver.map(Ok).boxed(), - peer_sink: p2_sender.into(), + peer_stream: FramedRead::new(p2_receiver, MoneroWireCodec::default()), + peer_sink: FramedWrite::new(p2_sender, MoneroWireCodec::default()), direction: ConnectionDirection::OutBound, permit: permit_1, }; let p2_handshake_req = DoHandshakeRequest { addr: InternalPeerID::KnownAddr(TestNetZoneAddr(444)), - peer_stream: p1_receiver.boxed().map(Ok).boxed(), - peer_sink: p1_sender.into(), + peer_stream: FramedRead::new(p1_receiver, MoneroWireCodec::default()), + peer_sink: FramedWrite::new(p1_sender, MoneroWireCodec::default()), direction: 
ConnectionDirection::InBound, permit: permit_2, }; diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 6d653a12..b25dfde3 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -12,6 +12,7 @@ monero-p2p = {path = "../p2p/monero-p2p", features = ["borsh"] } futures = { workspace = true, features = ["std"] } async-trait = { workspace = true } tokio = { workspace = true, features = ["full"] } +tokio-util = { workspace = true } reqwest = { workspace = true } bytes = { workspace = true, features = ["std"] } tempfile = { workspace = true } diff --git a/test-utils/src/test_netzone.rs b/test-utils/src/test_netzone.rs index d5e2ad54..709d5567 100644 --- a/test-utils/src/test_netzone.rs +++ b/test-utils/src/test_netzone.rs @@ -8,15 +8,16 @@ use std::{ io::Error, net::{Ipv4Addr, SocketAddr}, pin::Pin, - task::{Context, Poll}, }; use borsh::{BorshDeserialize, BorshSerialize}; -use futures::{channel::mpsc::Sender as InnerSender, stream::BoxStream, Sink, Stream}; +use futures::Stream; +use tokio::io::{DuplexStream, ReadHalf, WriteHalf}; +use tokio_util::codec::{FramedRead, FramedWrite}; use monero_wire::{ network_address::{NetworkAddress, NetworkAddressIncorrectZone}, - BucketError, Message, + MoneroWireCodec, }; use monero_p2p::{NetZoneAddress, NetworkZone}; @@ -62,47 +63,6 @@ impl TryFrom for TestNetZoneAddr { } } -/// A wrapper around [`futures::channel::mpsc::Sender`] that changes the error to [`BucketError`]. 
-pub struct Sender { - inner: InnerSender, -} - -impl From> for Sender { - fn from(inner: InnerSender) -> Self { - Sender { inner } - } -} - -impl Sink for Sender { - type Error = BucketError; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut() - .inner - .poll_ready(cx) - .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) - } - - fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - self.get_mut() - .inner - .start_send(item) - .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().inner) - .poll_flush(cx) - .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().inner) - .poll_close(cx) - .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) - } -} - #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct TestNetZone; @@ -116,8 +76,8 @@ impl>; - type Sink = Sender; + type Stream = FramedRead, MoneroWireCodec>; + type Sink = FramedWrite, MoneroWireCodec>; type Listener = Pin< Box< dyn Stream<