levin: fragmented messages (#84)

* levin: fix fragmented messages & use bitflags

* levin: add a method to fragment a message

* levin: add tests for fragmented messages and fix issues

* fix docs

* tests: don't include bytes length

* levin: add support for sending fragmented
/ dummy messages

* fmt

* add fragmented handshake tests.

* fix handshake detection when fragmented
and alt (non-monero) protocol info

* add tracing logs

* remove `already_built`, this was an old way I was thinking of sending raw buckets

* clippy

* clippy 2

* Update net/levin/src/message.rs

Co-authored-by: hinto-janai <hinto.janai@protonmail.com>

* review comments

* add timeout to tests

* Update net/levin/src/header.rs

Co-authored-by: hinto-janai <hinto.janai@protonmail.com>

---------

Co-authored-by: hinto-janai <hinto.janai@protonmail.com>
This commit is contained in:
Boog900 2024-03-05 01:29:57 +00:00 committed by GitHub
parent 8ef70bf0cd
commit 159c8a3b48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 854 additions and 190 deletions

7
Cargo.lock generated
View file

@ -629,6 +629,7 @@ dependencies = [
"tar", "tar",
"tempfile", "tempfile",
"tokio", "tokio",
"tokio-util",
"zip", "zip",
] ]
@ -1418,9 +1419,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
name = "levin-cuprate" name = "levin-cuprate"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bitflags 2.4.2",
"bytes", "bytes",
"futures",
"proptest",
"rand",
"thiserror", "thiserror",
"tokio",
"tokio-util", "tokio-util",
"tracing",
] ]
[[package]] [[package]]

View file

@ -37,6 +37,7 @@ opt-level = 3
[workspace.dependencies] [workspace.dependencies]
async-trait = { version = "0.1.74", default-features = false } async-trait = { version = "0.1.74", default-features = false }
bitflags = { version = "2.4.2", default-features = false }
borsh = { version = "1.2.1", default-features = false } borsh = { version = "1.2.1", default-features = false }
bytes = { version = "1.5.0", default-features = false } bytes = { version = "1.5.0", default-features = false }
cfg-if = { version = "1.0.0", default-features = false } cfg-if = { version = "1.0.0", default-features = false }

View file

@ -7,8 +7,21 @@ license = "MIT"
authors = ["Boog900"] authors = ["Boog900"]
repository = "https://github.com/Cuprate/cuprate/tree/main/net/levin" repository = "https://github.com/Cuprate/cuprate/tree/main/net/levin"
[features]
default = []
tracing = ["dep:tracing", "tokio-util/tracing"]
[dependencies] [dependencies]
thiserror = { workspace = true } thiserror = { workspace = true }
bytes = { workspace = true } bytes = { workspace = true, features = ["std"] }
bitflags = { workspace = true }
tokio-util = { workspace = true, features = ["codec"]} tokio-util = { workspace = true, features = ["codec"]}
tracing = { workspace = true, features = ["std"], optional = true }
[dev-dependencies]
proptest = { workspace = true }
rand = { workspace = true, features = ["std", "std_rng"] }
tokio-util = { workspace = true, features = ["io-util"]}
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true, features = ["std"] }

View file

@ -15,12 +15,14 @@
//! A tokio-codec for levin buckets //! A tokio-codec for levin buckets
use std::marker::PhantomData; use std::{fmt::Debug, marker::PhantomData};
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{Decoder, Encoder};
use crate::{ use crate::{
header::{Flags, HEADER_SIZE},
message::{make_dummy_message, LevinMessage},
Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, MessageType, Protocol, Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, MessageType, Protocol,
}; };
@ -61,31 +63,49 @@ impl<C> LevinBucketCodec<C> {
} }
} }
impl<C: LevinCommand> Decoder for LevinBucketCodec<C> { impl<C: LevinCommand + Debug> Decoder for LevinBucketCodec<C> {
type Item = Bucket<C>; type Item = Bucket<C>;
type Error = BucketError; type Error = BucketError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop { loop {
match &self.state { match &self.state {
LevinBucketState::WaitingForHeader => { LevinBucketState::WaitingForHeader => {
if src.len() < BucketHead::<C>::SIZE { if src.len() < HEADER_SIZE {
return Ok(None); return Ok(None);
}; };
let head = BucketHead::<C>::from_bytes(src); let head = BucketHead::<C>::from_bytes(src);
#[cfg(feature = "tracing")]
tracing::trace!(
"Received new bucket header, command: {:?}, waiting for body, body len: {}",
head.command,
head.size
);
if head.size > self.protocol.max_packet_size if head.size > self.protocol.max_packet_size
|| head.size > head.command.bucket_size_limit() || head.size > head.command.bucket_size_limit()
{ {
#[cfg(feature = "tracing")]
tracing::debug!("Peer sent message which is too large.");
return Err(BucketError::BucketExceededMaxSize); return Err(BucketError::BucketExceededMaxSize);
} }
if !self.handshake_message_seen { if !self.handshake_message_seen {
if head.size > self.protocol.max_packet_size_before_handshake { if head.size > self.protocol.max_packet_size_before_handshake {
#[cfg(feature = "tracing")]
tracing::debug!("Peer sent message which is too large.");
return Err(BucketError::BucketExceededMaxSize); return Err(BucketError::BucketExceededMaxSize);
} }
if head.command.is_handshake() { if head.command.is_handshake() {
#[cfg(feature = "tracing")]
tracing::debug!(
"Peer handshake message seen, increasing bucket size limit."
);
self.handshake_message_seen = true; self.handshake_message_seen = true;
} }
} }
@ -109,6 +129,9 @@ impl<C: LevinCommand> Decoder for LevinBucketCodec<C> {
unreachable!() unreachable!()
}; };
#[cfg(feature = "tracing")]
tracing::trace!("Received full bucket for command: {:?}", header.command);
return Ok(Some(Bucket { return Ok(Some(Bucket {
header, header,
body: src.copy_to_bytes(body_len), body: src.copy_to_bytes(body_len),
@ -122,23 +145,26 @@ impl<C: LevinCommand> Decoder for LevinBucketCodec<C> {
impl<C: LevinCommand> Encoder<Bucket<C>> for LevinBucketCodec<C> { impl<C: LevinCommand> Encoder<Bucket<C>> for LevinBucketCodec<C> {
type Error = BucketError; type Error = BucketError;
fn encode(&mut self, item: Bucket<C>, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, item: Bucket<C>, dst: &mut BytesMut) -> Result<(), Self::Error> {
if let Some(additional) = if let Some(additional) = (HEADER_SIZE + item.body.len()).checked_sub(dst.capacity()) {
(BucketHead::<C>::SIZE + item.body.len()).checked_sub(dst.capacity())
{
dst.reserve(additional) dst.reserve(additional)
} }
item.header.write_bytes(dst); item.header.write_bytes_into(dst);
dst.put_slice(&item.body); dst.put_slice(&item.body);
Ok(()) Ok(())
} }
} }
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone)]
enum MessageState<C> { enum MessageState {
#[default] #[default]
WaitingForBucket, WaitingForBucket,
WaitingForRestOfFragment(Vec<Bytes>, MessageType, C), /// Waiting for the rest of a fragmented message.
///
/// We keep the fragmented message as a Vec<u8> instead of [`Bytes`](bytes::Bytes) as [`Bytes`](bytes::Bytes) could point to a
/// large allocation even if the [`Bytes`](bytes::Bytes) itself is small, so is not safe to keep around for long.
/// To prevent this attack vector completely we just use Vec<u8> for fragmented messages.
WaitingForRestOfFragment(Vec<u8>),
} }
/// A tokio-codec for levin messages or in other words the decoded body /// A tokio-codec for levin messages or in other words the decoded body
@ -147,7 +173,7 @@ enum MessageState<C> {
pub struct LevinMessageCodec<T: LevinBody> { pub struct LevinMessageCodec<T: LevinBody> {
message_ty: PhantomData<T>, message_ty: PhantomData<T>,
bucket_codec: LevinBucketCodec<T::Command>, bucket_codec: LevinBucketCodec<T::Command>,
state: MessageState<T::Command>, state: MessageState,
} }
impl<T: LevinBody> Default for LevinMessageCodec<T> { impl<T: LevinBody> Default for LevinMessageCodec<T> {
@ -173,107 +199,143 @@ impl<T: LevinBody> Decoder for LevinMessageCodec<T> {
let flags = &bucket.header.flags; let flags = &bucket.header.flags;
if flags.is_start_fragment() && flags.is_end_fragment() { if flags.contains(Flags::DUMMY) {
// Dummy message // Dummy message
return Ok(None);
#[cfg(feature = "tracing")]
tracing::trace!("Received DUMMY bucket from peer, ignoring.");
// We may have another bucket in `src`.
continue;
}; };
if flags.is_end_fragment() { if flags.contains(Flags::END_FRAGMENT) {
return Err(BucketError::InvalidHeaderFlags( return Err(BucketError::InvalidHeaderFlags(
"Flag end fragment received before a start fragment", "Flag end fragment received before a start fragment",
)); ));
}; };
if !flags.is_request() && !flags.is_response() { if flags.contains(Flags::START_FRAGMENT) {
return Err(BucketError::InvalidHeaderFlags( // monerod does not require a start flag before starting a fragmented message,
"Request and response flags both not set", // but will always produce one, so it is ok for us to require one.
));
}; #[cfg(feature = "tracing")]
tracing::debug!("Bucket is a fragment, waiting for rest of message.");
self.state = MessageState::WaitingForRestOfFragment(bucket.body.to_vec());
continue;
}
// Normal, non fragmented bucket
let message_type = MessageType::from_flags_and_have_to_return( let message_type = MessageType::from_flags_and_have_to_return(
bucket.header.flags, bucket.header.flags,
bucket.header.have_to_return_data, bucket.header.have_to_return_data,
)?; )?;
if flags.is_start_fragment() {
let _ = std::mem::replace(
&mut self.state,
MessageState::WaitingForRestOfFragment(
vec![bucket.body],
message_type,
bucket.header.command,
),
);
continue;
}
return Ok(Some(T::decode_message( return Ok(Some(T::decode_message(
&mut bucket.body, &mut bucket.body,
message_type, message_type,
bucket.header.command, bucket.header.command,
)?)); )?));
} }
MessageState::WaitingForRestOfFragment(bytes, ty, command) => { MessageState::WaitingForRestOfFragment(bytes) => {
let Some(bucket) = self.bucket_codec.decode(src)? else { let Some(bucket) = self.bucket_codec.decode(src)? else {
return Ok(None); return Ok(None);
}; };
let flags = &bucket.header.flags; let flags = &bucket.header.flags;
if flags.is_start_fragment() && flags.is_end_fragment() { if flags.contains(Flags::DUMMY) {
// Dummy message // Dummy message
return Ok(None);
#[cfg(feature = "tracing")]
tracing::trace!("Received DUMMY bucket from peer, ignoring.");
// We may have another bucket in `src`.
continue;
}; };
if !flags.is_request() && !flags.is_response() { let max_size = if self.bucket_codec.handshake_message_seen {
return Err(BucketError::InvalidHeaderFlags( self.bucket_codec.protocol.max_packet_size
"Request and response flags both not set", } else {
)); self.bucket_codec.protocol.max_packet_size_before_handshake
};
let message_type = MessageType::from_flags_and_have_to_return(
bucket.header.flags,
bucket.header.have_to_return_data,
)?;
if message_type != *ty {
return Err(BucketError::InvalidFragmentedMessage(
"Message type was inconsistent across fragments",
));
} }
.try_into()
.expect("Levin max message size is too large, does not fit into a usize.");
if bucket.header.command != *command { if bytes.len().saturating_add(bucket.body.len()) > max_size {
return Err(BucketError::InvalidFragmentedMessage(
"Command not consistent across fragments",
));
}
if bytes.len().saturating_add(bucket.body.len())
> command.bucket_size_limit().try_into().unwrap()
{
return Err(BucketError::InvalidFragmentedMessage( return Err(BucketError::InvalidFragmentedMessage(
"Fragmented message exceeded maximum size", "Fragmented message exceeded maximum size",
)); ));
} }
bytes.push(bucket.body); #[cfg(feature = "tracing")]
tracing::trace!("Received another bucket fragment.");
if flags.is_end_fragment() { bytes.extend_from_slice(bucket.body.as_ref());
let MessageState::WaitingForRestOfFragment(mut bytes, ty, command) =
if flags.contains(Flags::END_FRAGMENT) {
// make sure we only look at the internal bucket and don't use this.
drop(bucket);
let MessageState::WaitingForRestOfFragment(bytes) =
std::mem::replace(&mut self.state, MessageState::WaitingForBucket) std::mem::replace(&mut self.state, MessageState::WaitingForBucket)
else { else {
unreachable!(); unreachable!();
}; };
// TODO: this doesn't seem very efficient but I can't think of a better way. // Check there are enough bytes in the fragment to build a header.
bytes.reverse(); if bytes.len() < HEADER_SIZE {
let mut byte_vec: Box<dyn Buf> = Box::new(bytes.pop().unwrap()); return Err(BucketError::InvalidFragmentedMessage(
for bytes in bytes { "Fragmented message is not large enough to build a bucket.",
byte_vec = Box::new(byte_vec.chain(bytes)); ));
} }
return Ok(Some(T::decode_message(&mut byte_vec, ty, command)?)); let mut header_bytes = BytesMut::from(&bytes[0..HEADER_SIZE]);
let header = BucketHead::<T::Command>::from_bytes(&mut header_bytes);
if header.size > header.command.bucket_size_limit() {
return Err(BucketError::BucketExceededMaxSize);
}
// Check the fragmented message contains enough bytes to build the message.
if bytes.len().saturating_sub(HEADER_SIZE)
< header
.size
.try_into()
.map_err(|_| BucketError::BucketExceededMaxSize)?
{
return Err(BucketError::InvalidFragmentedMessage(
"Fragmented message does not have enough bytes to fill bucket body",
));
}
#[cfg(feature = "tracing")]
tracing::debug!(
"Received final fragment, combined message command: {:?}.",
header.command
);
let message_type = MessageType::from_flags_and_have_to_return(
header.flags,
header.have_to_return_data,
)?;
if header.command.is_handshake() {
#[cfg(feature = "tracing")]
tracing::debug!(
"Peer handshake message seen, increasing bucket size limit."
);
self.bucket_codec.handshake_message_seen = true;
}
return Ok(Some(T::decode_message(
&mut &bytes[HEADER_SIZE..],
message_type,
header.command,
)?));
} }
} }
} }
@ -281,12 +343,21 @@ impl<T: LevinBody> Decoder for LevinMessageCodec<T> {
} }
} }
impl<T: LevinBody> Encoder<T> for LevinMessageCodec<T> { impl<T: LevinBody> Encoder<LevinMessage<T>> for LevinMessageCodec<T> {
type Error = BucketError; type Error = BucketError;
fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, item: LevinMessage<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
let mut bucket_builder = BucketBuilder::default(); match item {
item.encode(&mut bucket_builder)?; LevinMessage::Body(body) => {
let mut bucket_builder = BucketBuilder::new(&self.bucket_codec.protocol);
body.encode(&mut bucket_builder)?;
let bucket = bucket_builder.finish(); let bucket = bucket_builder.finish();
self.bucket_codec.encode(bucket, dst) self.bucket_codec.encode(bucket, dst)
} }
LevinMessage::Bucket(bucket) => self.bucket_codec.encode(bucket, dst),
LevinMessage::Dummy(size) => {
let bucket = make_dummy_message(&self.bucket_codec.protocol, size);
self.bucket_codec.encode(bucket, dst)
}
}
}
} }

View file

@ -16,34 +16,47 @@
//! This module provides a struct BucketHead for the header of a levin protocol //! This module provides a struct BucketHead for the header of a levin protocol
//! message. //! message.
use bitflags::bitflags;
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use crate::LevinCommand; use crate::LevinCommand;
const REQUEST: u32 = 0b0000_0001; /// The size of the header (in bytes)
const RESPONSE: u32 = 0b0000_0010; pub const HEADER_SIZE: usize = 33;
const START_FRAGMENT: u32 = 0b0000_0100;
const END_FRAGMENT: u32 = 0b0000_1000;
/// Levin header flags /// Levin header flags
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Flags(u32); pub struct Flags(u32);
impl Flags { bitflags! {
pub const REQUEST: Flags = Flags(REQUEST); impl Flags: u32 {
pub const RESPONSE: Flags = Flags(RESPONSE); /// The request flag.
///
/// Depending on the `have_to_return_data` field in [`BucketHead`], this message is either
/// a request or notification.
const REQUEST = 0b0000_0001;
/// The response flags.
///
/// Messages with this set are responses to requests.
const RESPONSE = 0b0000_0010;
pub fn is_request(&self) -> bool { /// The start fragment flag.
self.0 & REQUEST != 0 ///
} /// Messages with this flag set tell the parser that the next messages until a message
pub fn is_response(&self) -> bool { /// with [`Flags::END_FRAGMENT`] should be combined into a single bucket.
self.0 & RESPONSE != 0 const START_FRAGMENT = 0b0000_0100;
} /// The end fragment flag.
pub fn is_start_fragment(&self) -> bool { ///
self.0 & START_FRAGMENT != 0 /// Messages with this flag set tell the parser that all fragments of a fragmented message
} /// have been sent.
pub fn is_end_fragment(&self) -> bool { const END_FRAGMENT = 0b0000_1000;
self.0 & END_FRAGMENT != 0
/// A dummy message.
///
/// Messages with this flag will be completely ignored by the parser.
const DUMMY = Self::START_FRAGMENT.bits() | Self::END_FRAGMENT.bits();
const _ = !0;
} }
} }
@ -81,15 +94,12 @@ pub struct BucketHead<C> {
} }
impl<C: LevinCommand> BucketHead<C> { impl<C: LevinCommand> BucketHead<C> {
/// The size of the header (in bytes)
pub const SIZE: usize = 33;
/// Builds the header from bytes, this function does not check any fields should /// Builds the header from bytes, this function does not check any fields should
/// match the expected ones. /// match the expected ones.
/// ///
/// # Panics /// # Panics
/// This function will panic if there aren't enough bytes to fill the header. /// This function will panic if there aren't enough bytes to fill the header.
/// Currently ['SIZE'](BucketHead::SIZE) /// Currently [HEADER_SIZE]
pub fn from_bytes(buf: &mut BytesMut) -> BucketHead<C> { pub fn from_bytes(buf: &mut BytesMut) -> BucketHead<C> {
BucketHead { BucketHead {
signature: buf.get_u64_le(), signature: buf.get_u64_le(),
@ -103,8 +113,8 @@ impl<C: LevinCommand> BucketHead<C> {
} }
/// Serializes the header /// Serializes the header
pub fn write_bytes(&self, dst: &mut BytesMut) { pub fn write_bytes_into(&self, dst: &mut BytesMut) {
dst.reserve(Self::SIZE); dst.reserve(HEADER_SIZE);
dst.put_u64_le(self.signature); dst.put_u64_le(self.signature);
dst.put_u64_le(self.size); dst.put_u64_le(self.size);

View file

@ -33,20 +33,29 @@
#![deny(unused_mut)] #![deny(unused_mut)]
//#![deny(missing_docs)] //#![deny(missing_docs)]
pub mod codec;
pub mod header;
pub use codec::*;
pub use header::BucketHead;
use std::fmt::Debug; use std::fmt::Debug;
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use thiserror::Error; use thiserror::Error;
pub mod codec;
pub mod header;
pub mod message;
pub use codec::*;
pub use header::BucketHead;
pub use message::LevinMessage;
use header::Flags;
/// The version field for bucket headers.
const MONERO_PROTOCOL_VERSION: u32 = 1; const MONERO_PROTOCOL_VERSION: u32 = 1;
/// The signature field for bucket headers, will be constant for all peers using the Monero levin
/// protocol.
const MONERO_LEVIN_SIGNATURE: u64 = 0x0101010101012101; const MONERO_LEVIN_SIGNATURE: u64 = 0x0101010101012101;
/// Maximum size a bucket can be before a handshake.
const MONERO_MAX_PACKET_SIZE_BEFORE_HANDSHAKE: u64 = 256 * 1000; // 256 KiB const MONERO_MAX_PACKET_SIZE_BEFORE_HANDSHAKE: u64 = 256 * 1000; // 256 KiB
/// Maximum size a bucket can be after a handshake.
const MONERO_MAX_PACKET_SIZE: u64 = 100_000_000; // 100MB const MONERO_MAX_PACKET_SIZE: u64 = 100_000_000; // 100MB
/// Possible Errors when working with levin buckets /// Possible Errors when working with levin buckets
@ -98,7 +107,7 @@ impl Default for Protocol {
} }
/// A levin Bucket /// A levin Bucket
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Bucket<C> { pub struct Bucket<C> {
/// The bucket header /// The bucket header
pub header: BucketHead<C>, pub header: BucketHead<C>,
@ -128,20 +137,19 @@ impl MessageType {
/// Returns the `MessageType` given the flags and have_to_return_data fields /// Returns the `MessageType` given the flags and have_to_return_data fields
pub fn from_flags_and_have_to_return( pub fn from_flags_and_have_to_return(
flags: header::Flags, flags: Flags,
have_to_return: bool, have_to_return: bool,
) -> Result<Self, BucketError> { ) -> Result<Self, BucketError> {
if flags.is_request() && have_to_return { Ok(match (flags, have_to_return) {
Ok(MessageType::Request) (Flags::REQUEST, true) => MessageType::Request,
} else if flags.is_request() { (Flags::REQUEST, false) => MessageType::Notification,
Ok(MessageType::Notification) (Flags::RESPONSE, false) => MessageType::Response,
} else if flags.is_response() && !have_to_return { _ => {
Ok(MessageType::Response) return Err(BucketError::InvalidHeaderFlags(
} else {
Err(BucketError::InvalidHeaderFlags(
"Unable to assign a message type to this bucket", "Unable to assign a message type to this bucket",
)) ))
} }
})
} }
pub fn as_flags(&self) -> header::Flags { pub fn as_flags(&self) -> header::Flags {
@ -162,20 +170,18 @@ pub struct BucketBuilder<C> {
body: Option<Bytes>, body: Option<Bytes>,
} }
impl<C> Default for BucketBuilder<C> { impl<C: LevinCommand> BucketBuilder<C> {
fn default() -> Self { pub fn new(protocol: &Protocol) -> Self {
Self { Self {
signature: Some(MONERO_LEVIN_SIGNATURE), signature: Some(protocol.signature),
ty: None, ty: None,
command: None, command: None,
return_code: None, return_code: None,
protocol_version: Some(MONERO_PROTOCOL_VERSION), protocol_version: Some(protocol.version),
body: None, body: None,
} }
} }
}
impl<C: LevinCommand> BucketBuilder<C> {
pub fn set_signature(&mut self, sig: u64) { pub fn set_signature(&mut self, sig: u64) {
self.signature = Some(sig) self.signature = Some(sig)
} }
@ -220,7 +226,7 @@ impl<C: LevinCommand> BucketBuilder<C> {
/// A levin body /// A levin body
pub trait LevinBody: Sized { pub trait LevinBody: Sized {
type Command: LevinCommand; type Command: LevinCommand + Debug;
/// Decodes the message from the data in the header /// Decodes the message from the data in the header
fn decode_message<B: Buf>( fn decode_message<B: Buf>(

206
net/levin/src/message.rs Normal file
View file

@ -0,0 +1,206 @@
//! Levin Messages
//!
//! This module contains the [`LevinMessage`], which allows sending bucket body's, full buckets or dummy messages.
//! The codec will not return [`LevinMessage`] instead it will only return bucket body's. [`LevinMessage`] allows
//! for more control over what is actually sent over the wire at certain times.
use bytes::{Bytes, BytesMut};
use crate::{
header::{Flags, HEADER_SIZE},
Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, Protocol,
};
/// A levin message that can be sent to a peer.
pub enum LevinMessage<T: LevinBody> {
/// A message body.
///
/// A levin header will be added to this message before it is sent to the peer.
Body(T),
/// A full levin bucket.
///
/// This bucket will be sent to the peer directly with no extra information.
///
/// This should only be used to send fragmented messages: [`make_fragmented_messages`]
Bucket(Bucket<T::Command>),
/// A dummy message.
///
/// A dummy message which the peer will ignore. The dummy message will be the exact size
/// (in bytes) of the given `usize` on the wire.
Dummy(usize),
}
impl<T: LevinBody> From<T> for LevinMessage<T> {
fn from(value: T) -> Self {
LevinMessage::Body(value)
}
}
impl<T: LevinBody> From<Bucket<T::Command>> for LevinMessage<T> {
fn from(value: Bucket<T::Command>) -> Self {
LevinMessage::Bucket(value)
}
}
/// This represents a dummy message to send to a peer.
///
/// The message, including the header, will be the exact size of the given `usize`.
/// This exists because it seems weird to do this:
/// ```rust,ignore
/// peer.send(1_000);
/// ```
/// This is a lot clearer:
/// ```rust,ignore
/// peer.send(Dummy(1_000));
/// ```
pub struct Dummy(pub usize);
impl<T: LevinBody> From<Dummy> for LevinMessage<T> {
fn from(value: Dummy) -> Self {
LevinMessage::Dummy(value.0)
}
}
/// Fragments the provided message into buckets which, when serialised, will all be the size of `fragment_size`.
///
/// This function will produce many buckets that have to be sent in order. When the peer receives these buckets
/// they will combine them to produce the original message.
///
/// The last bucket may be padded with zeros to make it the correct size, the format used to encode the body must
/// allow for extra data at the end of the message this to work.
///
/// `fragment_size` must be more than 2 * [`HEADER_SIZE`] otherwise this will panic.
pub fn make_fragmented_messages<T: LevinBody>(
protocol: &Protocol,
fragment_size: usize,
message: T,
) -> Result<Vec<Bucket<T::Command>>, BucketError> {
if fragment_size * 2 < HEADER_SIZE {
panic!(
"Fragment size: {fragment_size}, is too small, must be at least {}",
2 * HEADER_SIZE
);
}
let mut builder = BucketBuilder::new(protocol);
message.encode(&mut builder)?;
let mut bucket = builder.finish();
// Make sure we are not trying to fragment a fragment.
if !bucket
.header
.flags
.intersects(Flags::REQUEST | Flags::RESPONSE)
{
// If a bucket does not have the request or response bits set it is a fragment.
return Err(BucketError::InvalidFragmentedMessage(
"Can't make a fragmented message out of a message which is already fragmented",
));
}
// Check if the bucket can fit in one fragment.
if bucket.body.len() + HEADER_SIZE <= fragment_size {
// If it can pad the bucket upto the fragment size and just return this bucket.
if bucket.body.len() + HEADER_SIZE < fragment_size {
let mut new_body = BytesMut::from(bucket.body.as_ref());
// Epee's binary format will ignore extra data at the end so just pad with 0.
new_body.resize(fragment_size - HEADER_SIZE, 0);
bucket.body = new_body.freeze();
bucket.header.size = fragment_size
.try_into()
.expect("Bucket size does not fit into u64");
}
return Ok(vec![bucket]);
}
// A header put on all fragments.
// The first fragment will set the START flag, the last will set the END flag.
let fragment_head = BucketHead {
signature: protocol.signature,
size: (fragment_size - HEADER_SIZE)
.try_into()
.expect("Bucket size does not fit into u64"),
have_to_return_data: false,
// Just use a default command.
command: T::Command::from(0),
return_code: 0,
flags: Flags::empty(),
protocol_version: protocol.version,
};
// data_space - the amount of actual data we can fit in each fragment.
let data_space = fragment_size - HEADER_SIZE;
let amount_of_fragments = (bucket.body.len() + HEADER_SIZE).div_ceil(data_space);
let mut first_bucket_body = BytesMut::with_capacity(fragment_size);
// Fragmented messages store the whole fragmented bucket in the combined payloads not just the body
// so the first bucket contains 2 headers, a fragment header and the actual bucket header we are sending.
bucket.header.write_bytes_into(&mut first_bucket_body);
first_bucket_body.extend_from_slice(
bucket
.body
.split_to(fragment_size - (HEADER_SIZE * 2))
.as_ref(),
);
let mut buckets = Vec::with_capacity(amount_of_fragments);
buckets.push(Bucket {
header: fragment_head.clone(),
body: first_bucket_body.freeze(),
});
for mut bytes in (1..amount_of_fragments).map(|_| {
bucket
.body
.split_to((fragment_size - HEADER_SIZE).min(bucket.body.len()))
}) {
// make sure this fragment has the correct size - the last one might not, so pad it.
if bytes.len() + HEADER_SIZE < fragment_size {
let mut new_bytes = BytesMut::from(bytes.as_ref());
// Epee's binary format will ignore extra data at the end so just pad with 0.
new_bytes.resize(fragment_size - HEADER_SIZE, 0);
bytes = new_bytes.freeze();
}
buckets.push(Bucket {
header: fragment_head.clone(),
body: bytes,
});
}
buckets
.first_mut()
.unwrap()
.header
.flags
.toggle(Flags::START_FRAGMENT);
buckets
.last_mut()
.unwrap()
.header
.flags
.toggle(Flags::END_FRAGMENT);
Ok(buckets)
}
/// Makes a dummy message, which will be the size of `size` when sent over the wire.
pub(crate) fn make_dummy_message<T: LevinCommand>(protocol: &Protocol, size: usize) -> Bucket<T> {
// A header to put on the dummy message.
let header = BucketHead {
signature: protocol.signature,
size: size.try_into().expect("Bucket size does not fit into u64"),
have_to_return_data: false,
// Just use a default command.
command: T::from(0),
return_code: 0,
flags: Flags::DUMMY,
protocol_version: protocol.version,
};
let body = Bytes::from(vec![0; size - HEADER_SIZE]);
Bucket { header, body }
}

View file

@ -0,0 +1,151 @@
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use proptest::{prelude::any_with, prop_assert_eq, proptest, sample::size_range};
use rand::Fill;
use tokio::{
io::duplex,
time::{timeout, Duration},
};
use tokio_util::codec::{FramedRead, FramedWrite};
use levin_cuprate::{
message::make_fragmented_messages, BucketBuilder, BucketError, LevinBody, LevinCommand,
LevinMessageCodec, MessageType, Protocol,
};
/// A timeout put on streams so tests don't stall.
const TEST_TIMEOUT: Duration = Duration::from_secs(30);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TestCommands(u32);
impl From<u32> for TestCommands {
fn from(value: u32) -> Self {
Self(value)
}
}
impl From<TestCommands> for u32 {
fn from(value: TestCommands) -> Self {
value.0
}
}
impl LevinCommand for TestCommands {
fn bucket_size_limit(&self) -> u64 {
u64::MAX
}
fn is_handshake(&self) -> bool {
self.0 == 1
}
}
#[derive(Clone)]
enum TestBody {
Bytes(usize, Bytes),
}
impl LevinBody for TestBody {
type Command = TestCommands;
fn decode_message<B: Buf>(
body: &mut B,
_: MessageType,
_: Self::Command,
) -> Result<Self, BucketError> {
let size = body.get_u64_le().try_into().unwrap();
// bucket
Ok(TestBody::Bytes(size, body.copy_to_bytes(size)))
}
fn encode(self, builder: &mut BucketBuilder<Self::Command>) -> Result<(), BucketError> {
match self {
TestBody::Bytes(len, bytes) => {
let mut buf = BytesMut::new();
buf.put_u64_le(len as u64);
buf.extend_from_slice(bytes.as_ref());
builder.set_command(TestCommands(1));
builder.set_message_type(MessageType::Notification);
builder.set_return_code(0);
builder.set_body(buf.freeze());
}
}
Ok(())
}
}
#[tokio::test]
async fn codec_fragmented_messages() {
// Set up the fake connection
let (write, read) = duplex(100_000);
let mut read = FramedRead::new(read, LevinMessageCodec::<TestBody>::default());
let mut write = FramedWrite::new(write, LevinMessageCodec::<TestBody>::default());
// Create the message to fragment
let mut buf = BytesMut::from(vec![0; 10_000].as_slice());
let mut rng = rand::thread_rng();
buf.try_fill(&mut rng).unwrap();
let message = TestBody::Bytes(buf.len(), buf.freeze());
let fragments = make_fragmented_messages(&Protocol::default(), 3_000, message.clone()).unwrap();
for frag in fragments {
// Send each fragment
timeout(TEST_TIMEOUT, write.send(frag.into()))
.await
.unwrap()
.unwrap();
}
// only one message should be received.
let message2 = timeout(TEST_TIMEOUT, read.next())
.await
.unwrap()
.unwrap()
.unwrap();
match (message, message2) {
(TestBody::Bytes(_, buf), TestBody::Bytes(_, buf2)) => assert_eq!(buf, buf2),
}
}
proptest! {
#[test]
fn make_fragmented_messages_correct_size(fragment_size in 100_usize..5000, message_size in 0_usize..100_000) {
let mut bytes = BytesMut::new();
bytes.resize(message_size, 10);
let fragments = make_fragmented_messages(&Protocol::default(), fragment_size, TestBody::Bytes(bytes.len(), bytes.freeze())).unwrap();
let len = fragments.len();
for (i, fragment) in fragments.into_iter().enumerate() {
prop_assert_eq!(fragment.body.len() + 33, fragment_size, "numb_fragments:{}, index: {}", len, i)
}
}
#[test]
fn make_fragmented_messages_consistent(fragment_size in 100_usize..5_000, message in any_with::<Vec<u8>>(size_range(50_000).lift())) {
let fragments = make_fragmented_messages(&Protocol::default(), fragment_size, TestBody::Bytes(message.len(), Bytes::copy_from_slice(message.as_slice()))).unwrap();
let mut message2 = Vec::with_capacity(message.len());
// remove the header and the bytes length.
message2.extend_from_slice(&fragments[0].body[(33 + 8)..]);
for frag in fragments.iter().skip(1) {
message2.extend_from_slice(frag.body.as_ref())
}
prop_assert_eq!(message.as_slice(), &message2[0..message.len()], "numb_fragments: {}", fragments.len());
for byte in message2[message.len()..].iter(){
prop_assert_eq!(*byte, 0);
}
}
}

View file

@ -6,6 +6,9 @@ license = "MIT"
authors = ["Boog900"] authors = ["Boog900"]
repository = "https://github.com/SyntheticBird45/cuprate/tree/main/net/monero-wire" repository = "https://github.com/SyntheticBird45/cuprate/tree/main/net/monero-wire"
[features]
default = []
tracing = ["levin-cuprate/tracing"]
[dependencies] [dependencies]
levin-cuprate = {path="../levin"} levin-cuprate = {path="../levin"}

View file

@ -29,4 +29,7 @@ pub use levin_cuprate::BucketError;
pub use network_address::{NetZone, NetworkAddress}; pub use network_address::{NetZone, NetworkAddress};
pub use p2p::*; pub use p2p::*;
// re-export.
pub use levin_cuprate as levin;
pub type MoneroWireCodec = levin_cuprate::codec::LevinMessageCodec<Message>; pub type MoneroWireCodec = levin_cuprate::codec::LevinMessageCodec<Message>;

View file

@ -11,7 +11,7 @@ borsh = ["dep:borsh", "monero-pruning/borsh"]
[dependencies] [dependencies]
cuprate-helper = { path = "../../helper" } cuprate-helper = { path = "../../helper" }
monero-wire = { path = "../../net/monero-wire" } monero-wire = { path = "../../net/monero-wire", features = ["tracing"] }
monero-pruning = { path = "../../pruning" } monero-pruning = { path = "../../pruning" }
tokio = { workspace = true, features = ["net", "sync", "macros", "time"]} tokio = { workspace = true, features = ["net", "sync", "macros", "time"]}

View file

@ -86,7 +86,7 @@ where
} }
async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> { async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> {
Ok(self.peer_sink.send(mes).await?) Ok(self.peer_sink.send(mes.into()).await?)
} }
async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> { async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> {

View file

@ -257,7 +257,7 @@ where
"Peer didn't send support flags or has no features, sending request to make sure." "Peer didn't send support flags or has no features, sending request to make sure."
); );
peer_sink peer_sink
.send(Message::Request(RequestMessage::SupportFlags)) .send(Message::Request(RequestMessage::SupportFlags).into())
.await?; .await?;
let Message::Response(ResponseMessage::SupportFlags(support_flags_res)) = let Message::Response(ResponseMessage::SupportFlags(support_flags_res)) =
@ -346,7 +346,7 @@ where
tracing::debug!("Sending handshake request."); tracing::debug!("Sending handshake request.");
peer_sink peer_sink
.send(Message::Request(RequestMessage::Handshake(req))) .send(Message::Request(RequestMessage::Handshake(req)).into())
.await?; .await?;
Ok(()) Ok(())
@ -391,7 +391,7 @@ where
tracing::debug!("Sending handshake response."); tracing::debug!("Sending handshake response.");
peer_sink peer_sink
.send(Message::Response(ResponseMessage::Handshake(res))) .send(Message::Response(ResponseMessage::Handshake(res)).into())
.await?; .await?;
Ok(()) Ok(())
@ -476,8 +476,11 @@ async fn send_support_flags<Z: NetworkZone>(
) -> Result<(), HandshakeError> { ) -> Result<(), HandshakeError> {
tracing::debug!("Sending support flag response."); tracing::debug!("Sending support flag response.");
Ok(peer_sink Ok(peer_sink
.send(Message::Response(ResponseMessage::SupportFlags( .send(
SupportFlagsResponse { support_flags }, Message::Response(ResponseMessage::SupportFlags(SupportFlagsResponse {
))) support_flags,
}))
.into(),
)
.await?) .await?)
} }

View file

@ -5,7 +5,8 @@ use std::{fmt::Debug, future::Future, hash::Hash, pin::Pin};
use futures::{Sink, Stream}; use futures::{Sink, Stream};
use monero_wire::{ use monero_wire::{
network_address::NetworkAddressIncorrectZone, BucketError, Message, NetworkAddress, levin::LevinMessage, network_address::NetworkAddressIncorrectZone, BucketError, Message,
NetworkAddress,
}; };
pub mod client; pub mod client;
@ -103,7 +104,7 @@ pub trait NetworkZone: Clone + Copy + Send + 'static {
/// The stream (incoming data) type for this network. /// The stream (incoming data) type for this network.
type Stream: Stream<Item = Result<Message, BucketError>> + Unpin + Send + 'static; type Stream: Stream<Item = Result<Message, BucketError>> + Unpin + Send + 'static;
/// The sink (outgoing data) type for this network. /// The sink (outgoing data) type for this network.
type Sink: Sink<Message, Error = BucketError> + Unpin + Send + 'static; type Sink: Sink<LevinMessage<Message>, Error = BucketError> + Unpin + Send + 'static;
/// The inbound connection listener for this network. /// The inbound connection listener for this network.
type Listener: Stream< type Listener: Stream<
Item = Result<(Option<Self::Addr>, Self::Stream, Self::Sink), std::io::Error>, Item = Result<(Option<Self::Addr>, Self::Stream, Self::Sink), std::io::Error>,

View file

@ -30,7 +30,7 @@ pub struct ClearNetServerCfg {
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct ClearNet; pub enum ClearNet {}
#[async_trait::async_trait] #[async_trait::async_trait]
impl NetworkZone for ClearNet { impl NetworkZone for ClearNet {

View file

@ -0,0 +1,224 @@
//! This file contains a test for a handshake with monerod but uses fragmented messages.
use std::{
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use futures::{Stream, StreamExt};
use tokio::{
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
},
sync::{broadcast, Semaphore},
time::timeout,
};
use tokio_util::{
bytes::BytesMut,
codec::{Encoder, FramedRead, FramedWrite},
};
use tower::{Service, ServiceExt};
use cuprate_helper::network::Network;
use monero_p2p::{
client::{ConnectRequest, Connector, DoHandshakeRequest, HandShaker, InternalPeerID},
network_zones::ClearNetServerCfg,
ConnectionDirection, NetworkZone,
};
use monero_wire::{
common::PeerSupportFlags,
levin::{message::make_fragmented_messages, LevinMessage, Protocol},
BasicNodeData, Message, MoneroWireCodec,
};
use cuprate_test_utils::monerod::monerod;
mod utils;
use utils::*;
/// A network zone equal to clear net where every message sent is turned into a fragmented message.
/// Does not support sending fragmented or dummy messages manually.
#[derive(Clone, Copy)]
pub enum FragNet {}
#[async_trait::async_trait]
impl NetworkZone for FragNet {
const NAME: &'static str = "FragNet";
const ALLOW_SYNC: bool = true;
const DANDELION_PP: bool = true;
const CHECK_NODE_ID: bool = true;
type Addr = SocketAddr;
type Stream = FramedRead<OwnedReadHalf, MoneroWireCodec>;
type Sink = FramedWrite<OwnedWriteHalf, FragmentCodec>;
type Listener = InBoundStream;
type ServerCfg = ClearNetServerCfg;
async fn connect_to_peer(
addr: Self::Addr,
) -> Result<(Self::Stream, Self::Sink), std::io::Error> {
let (read, write) = TcpStream::connect(addr).await?.into_split();
Ok((
FramedRead::new(read, MoneroWireCodec::default()),
FramedWrite::new(write, FragmentCodec::default()),
))
}
async fn incoming_connection_listener(
config: Self::ServerCfg,
) -> Result<Self::Listener, std::io::Error> {
let listener = TcpListener::bind(config.addr).await?;
Ok(InBoundStream { listener })
}
}
pub struct InBoundStream {
listener: TcpListener,
}
impl Stream for InBoundStream {
type Item = Result<
(
Option<SocketAddr>,
FramedRead<OwnedReadHalf, MoneroWireCodec>,
FramedWrite<OwnedWriteHalf, FragmentCodec>,
),
std::io::Error,
>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.listener
.poll_accept(cx)
.map_ok(|(stream, addr)| {
let (read, write) = stream.into_split();
(
Some(addr),
FramedRead::new(read, MoneroWireCodec::default()),
FramedWrite::new(write, FragmentCodec::default()),
)
})
.map(Some)
}
}
#[derive(Default)]
pub struct FragmentCodec(MoneroWireCodec);
impl Encoder<LevinMessage<Message>> for FragmentCodec {
type Error = <MoneroWireCodec as Encoder<LevinMessage<Message>>>::Error;
fn encode(
&mut self,
item: LevinMessage<Message>,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
match item {
LevinMessage::Body(body) => {
// 66 is the minimum fragment size.
let fragments = make_fragmented_messages(&Protocol::default(), 66, body).unwrap();
for frag in fragments {
self.0.encode(frag.into(), dst)?;
}
}
_ => unreachable!("Handshakes should only send bucket bodys"),
}
Ok(())
}
}
#[tokio::test]
async fn fragmented_handshake_cuprate_to_monerod() {
let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test.
let semaphore = Arc::new(Semaphore::new(10));
let permit = semaphore.acquire_owned().await.unwrap();
let monerod = monerod(["--fixed-difficulty=1", "--out-peers=0"]).await;
let our_basic_node_data = BasicNodeData {
my_port: 0,
network_id: Network::Mainnet.network_id().into(),
peer_id: 87980,
support_flags: PeerSupportFlags::from(1_u32),
rpc_port: 0,
rpc_credits_per_hash: 0,
};
let handshaker = HandShaker::<FragNet, _, _, _>::new(
DummyAddressBook,
DummyCoreSyncSvc,
DummyPeerRequestHandlerSvc,
broadcast_tx,
our_basic_node_data,
);
let mut connector = Connector::new(handshaker);
connector
.ready()
.await
.unwrap()
.call(ConnectRequest {
addr: monerod.p2p_addr(),
permit,
})
.await
.unwrap();
}
#[tokio::test]
async fn fragmented_handshake_monerod_to_cuprate() {
let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test.
let semaphore = Arc::new(Semaphore::new(10));
let permit = semaphore.acquire_owned().await.unwrap();
let our_basic_node_data = BasicNodeData {
my_port: 18081,
network_id: Network::Mainnet.network_id().into(),
peer_id: 87980,
support_flags: PeerSupportFlags::from(1_u32),
rpc_port: 0,
rpc_credits_per_hash: 0,
};
let mut handshaker = HandShaker::<FragNet, _, _, _>::new(
DummyAddressBook,
DummyCoreSyncSvc,
DummyPeerRequestHandlerSvc,
broadcast_tx,
our_basic_node_data,
);
let addr = "127.0.0.1:18081".parse().unwrap();
let mut listener = FragNet::incoming_connection_listener(ClearNetServerCfg { addr })
.await
.unwrap();
let _monerod = monerod(["--add-exclusive-node=127.0.0.1:18081"]).await;
// Put a timeout on this just in case monerod doesn't make the connection to us.
let next_connection_fut = timeout(Duration::from_secs(30), listener.next());
if let Some(Ok((addr, stream, sink))) = next_connection_fut.await.unwrap() {
let _ = handshaker
.ready()
.await
.unwrap()
.call(DoHandshakeRequest {
addr: InternalPeerID::KnownAddr(addr.unwrap()), // This is clear net all addresses are known.
peer_stream: stream,
peer_sink: sink,
direction: ConnectionDirection::InBound,
permit,
})
.await
.unwrap();
} else {
panic!("Failed to receive connection from monerod.");
};
}

View file

@ -1,14 +1,16 @@
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use futures::{channel::mpsc, StreamExt}; use futures::StreamExt;
use tokio::{ use tokio::{
io::{duplex, split},
sync::{broadcast, Semaphore}, sync::{broadcast, Semaphore},
time::timeout, time::timeout,
}; };
use tokio_util::codec::{FramedRead, FramedWrite};
use tower::{Service, ServiceExt}; use tower::{Service, ServiceExt};
use cuprate_helper::network::Network; use cuprate_helper::network::Network;
use monero_wire::{common::PeerSupportFlags, BasicNodeData}; use monero_wire::{common::PeerSupportFlags, BasicNodeData, MoneroWireCodec};
use monero_p2p::{ use monero_p2p::{
client::{ConnectRequest, Connector, DoHandshakeRequest, HandShaker, InternalPeerID}, client::{ConnectRequest, Connector, DoHandshakeRequest, HandShaker, InternalPeerID},
@ -63,21 +65,23 @@ async fn handshake_cuprate_to_cuprate() {
our_basic_node_data_2, our_basic_node_data_2,
); );
let (p1_sender, p2_receiver) = mpsc::channel(5); let (p1, p2) = duplex(50_000);
let (p2_sender, p1_receiver) = mpsc::channel(5);
let (p1_receiver, p1_sender) = split(p1);
let (p2_receiver, p2_sender) = split(p2);
let p1_handshake_req = DoHandshakeRequest { let p1_handshake_req = DoHandshakeRequest {
addr: InternalPeerID::KnownAddr(TestNetZoneAddr(888)), addr: InternalPeerID::KnownAddr(TestNetZoneAddr(888)),
peer_stream: p2_receiver.map(Ok).boxed(), peer_stream: FramedRead::new(p2_receiver, MoneroWireCodec::default()),
peer_sink: p2_sender.into(), peer_sink: FramedWrite::new(p2_sender, MoneroWireCodec::default()),
direction: ConnectionDirection::OutBound, direction: ConnectionDirection::OutBound,
permit: permit_1, permit: permit_1,
}; };
let p2_handshake_req = DoHandshakeRequest { let p2_handshake_req = DoHandshakeRequest {
addr: InternalPeerID::KnownAddr(TestNetZoneAddr(444)), addr: InternalPeerID::KnownAddr(TestNetZoneAddr(444)),
peer_stream: p1_receiver.boxed().map(Ok).boxed(), peer_stream: FramedRead::new(p1_receiver, MoneroWireCodec::default()),
peer_sink: p1_sender.into(), peer_sink: FramedWrite::new(p1_sender, MoneroWireCodec::default()),
direction: ConnectionDirection::InBound, direction: ConnectionDirection::InBound,
permit: permit_2, permit: permit_2,
}; };

View file

@ -12,6 +12,7 @@ monero-p2p = {path = "../p2p/monero-p2p", features = ["borsh"] }
futures = { workspace = true, features = ["std"] } futures = { workspace = true, features = ["std"] }
async-trait = { workspace = true } async-trait = { workspace = true }
tokio = { workspace = true, features = ["full"] } tokio = { workspace = true, features = ["full"] }
tokio-util = { workspace = true }
reqwest = { workspace = true } reqwest = { workspace = true }
bytes = { workspace = true, features = ["std"] } bytes = { workspace = true, features = ["std"] }
tempfile = { workspace = true } tempfile = { workspace = true }

View file

@ -8,15 +8,16 @@ use std::{
io::Error, io::Error,
net::{Ipv4Addr, SocketAddr}, net::{Ipv4Addr, SocketAddr},
pin::Pin, pin::Pin,
task::{Context, Poll},
}; };
use borsh::{BorshDeserialize, BorshSerialize}; use borsh::{BorshDeserialize, BorshSerialize};
use futures::{channel::mpsc::Sender as InnerSender, stream::BoxStream, Sink, Stream}; use futures::Stream;
use tokio::io::{DuplexStream, ReadHalf, WriteHalf};
use tokio_util::codec::{FramedRead, FramedWrite};
use monero_wire::{ use monero_wire::{
network_address::{NetworkAddress, NetworkAddressIncorrectZone}, network_address::{NetworkAddress, NetworkAddressIncorrectZone},
BucketError, Message, MoneroWireCodec,
}; };
use monero_p2p::{NetZoneAddress, NetworkZone}; use monero_p2p::{NetZoneAddress, NetworkZone};
@ -62,47 +63,6 @@ impl TryFrom<NetworkAddress> for TestNetZoneAddr {
} }
} }
/// A wrapper around [`futures::channel::mpsc::Sender`] that changes the error to [`BucketError`].
pub struct Sender {
inner: InnerSender<Message>,
}
impl From<InnerSender<Message>> for Sender {
fn from(inner: InnerSender<Message>) -> Self {
Sender { inner }
}
}
impl Sink<Message> for Sender {
type Error = BucketError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.get_mut()
.inner
.poll_ready(cx)
.map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed")))
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
self.get_mut()
.inner
.start_send(item)
.map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed")))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.get_mut().inner)
.poll_flush(cx)
.map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed")))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.get_mut().inner)
.poll_close(cx)
.map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed")))
}
}
#[derive(Debug, Clone, Copy, Eq, PartialEq)] #[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct TestNetZone<const ALLOW_SYNC: bool, const DANDELION_PP: bool, const CHECK_NODE_ID: bool>; pub struct TestNetZone<const ALLOW_SYNC: bool, const DANDELION_PP: bool, const CHECK_NODE_ID: bool>;
@ -116,8 +76,8 @@ impl<const ALLOW_SYNC: bool, const DANDELION_PP: bool, const CHECK_NODE_ID: bool
const CHECK_NODE_ID: bool = CHECK_NODE_ID; const CHECK_NODE_ID: bool = CHECK_NODE_ID;
type Addr = TestNetZoneAddr; type Addr = TestNetZoneAddr;
type Stream = BoxStream<'static, Result<Message, BucketError>>; type Stream = FramedRead<ReadHalf<DuplexStream>, MoneroWireCodec>;
type Sink = Sender; type Sink = FramedWrite<WriteHalf<DuplexStream>, MoneroWireCodec>;
type Listener = Pin< type Listener = Pin<
Box< Box<
dyn Stream< dyn Stream<