From 8557073c15d8226b876869bea8403513857edb99 Mon Sep 17 00:00:00 2001 From: Boog900 <boog900@tutanota.com> Date: Thu, 30 Nov 2023 18:09:05 +0000 Subject: [PATCH] p2p changes (#38) * start re-working p2p to work with change monero-wire * start re-working p2p to work with change monero-wire adds back some changes from #22 * change the peer module to use the new API + fix a couple bugs * remove peer set for now * add try_from/from conversion between `Message` and `Request`/`Response` * Allow specifying other parameters in levin-cuprate * add new `LevinCommand` enum and clean up monero-wire message de/encoding * fix issues with merge * start splitting up p2p crate into smaller crates. * add monerod action from serai to test network code * remove tracing in tests --- .github/actions/monerod-regtest/action.yml | 62 ++ .github/workflows/ci.yml | 3 + Cargo.toml | 4 +- common/src/pruning.rs | 1 + consensus/src/rpc/connection.rs | 8 +- net/levin/src/codec.rs | 140 +++-- net/levin/src/header.rs | 123 ++-- net/levin/src/lib.rs | 103 ++- net/monero-wire/src/lib.rs | 10 +- net/monero-wire/src/network_address.rs | 39 +- .../src/network_address/serde_helper.rs | 34 +- net/monero-wire/src/p2p.rs | 306 +++++++-- net/monero-wire/src/p2p/common.rs | 17 +- p2p/Cargo.toml | 12 +- p2p/monero-peer/Cargo.toml | 28 + p2p/monero-peer/src/client.rs | 6 + p2p/monero-peer/src/client/conector.rs | 61 ++ p2p/monero-peer/src/client/connection.rs | 176 ++++++ p2p/monero-peer/src/client/handshaker.rs | 494 +++++++++++++++ p2p/monero-peer/src/error.rs | 15 + p2p/monero-peer/src/lib.rs | 157 +++++ p2p/monero-peer/src/network_zones.rs | 3 + p2p/monero-peer/src/network_zones/clear.rs | 43 ++ p2p/monero-peer/src/protocol.rs | 130 ++++ p2p/monero-peer/src/protocol/try_from.rs | 179 ++++++ p2p/monero-peer/src/services.rs | 61 ++ p2p/monero-peer/tests/handshake.rs | 125 ++++ p2p/monero-peer/tests/utils.rs | 95 +++ p2p/src/address_book.rs | 187 +++--- p2p/src/address_book/addr_book_client.rs | 125 ++-- p2p/src/address_book/address_book.rs | 586 ++++++++++++++---- .../address_book/address_book/peer_list.rs | 319 +++++----- .../address_book/peer_list/tests.rs | 176 ++++++ p2p/src/address_book/address_book/tests.rs | 81 +++ p2p/src/address_book/connection_handle.rs | 110 ++++ p2p/src/config.rs | 78 +++ p2p/src/connection_counter.rs | 130 ++++ p2p/src/connection_handle.rs | 98 +++ p2p/src/constants.rs | 58 ++ p2p/src/lib.rs | 78 +++ p2p/src/peer.rs | 44 +- p2p/src/peer/client.rs | 126 +++- p2p/src/peer/connection.rs | 161 ++--- p2p/src/peer/connector.rs | 159 +++++ p2p/src/peer/error.rs | 116 ++++ p2p/src/peer/handshaker.rs | 565 +++++++++++------ p2p/src/peer/load_tracked_client.rs | 74 +++ p2p/src/peer/tests/handshake.rs | 2 +- p2p/src/protocol.rs | 24 +- p2p/src/protocol/internal_network.rs | 248 +++----- p2p/src/protocol/internal_network/try_from.rs | 163 +++++ p2p/src/protocol/lib.rs | 13 - p2p/src/protocol/temp_database.rs | 36 -- p2p/sync-states/Cargo.toml | 21 - p2p/sync-states/src/lib.rs | 538 ---------------- p2p/sync-states/tests/mod.rs | 109 ---- test-utils/Cargo.toml | 11 + test-utils/src/lib.rs | 1 + test-utils/src/test_netzone.rs | 109 ++++ 59 files changed, 5079 insertions(+), 1902 deletions(-) create mode 100644 .github/actions/monerod-regtest/action.yml create mode 100644 p2p/monero-peer/Cargo.toml create mode 100644 p2p/monero-peer/src/client.rs create mode 100644 p2p/monero-peer/src/client/conector.rs create mode 100644 p2p/monero-peer/src/client/connection.rs create mode 100644 p2p/monero-peer/src/client/handshaker.rs create mode 100644 p2p/monero-peer/src/error.rs create mode 100644 p2p/monero-peer/src/lib.rs create mode 100644 p2p/monero-peer/src/network_zones.rs create mode 100644 p2p/monero-peer/src/network_zones/clear.rs create mode 100644 p2p/monero-peer/src/protocol.rs create mode 100644 p2p/monero-peer/src/protocol/try_from.rs create mode 100644 p2p/monero-peer/src/services.rs create mode 100644 p2p/monero-peer/tests/handshake.rs create mode 100644 p2p/monero-peer/tests/utils.rs create mode 100644 p2p/src/address_book/address_book/peer_list/tests.rs create mode 100644 p2p/src/address_book/address_book/tests.rs create mode 100644 p2p/src/address_book/connection_handle.rs create mode 100644 p2p/src/config.rs create mode 100644 p2p/src/connection_counter.rs create mode 100644 p2p/src/connection_handle.rs create mode 100644 p2p/src/constants.rs create mode 100644 p2p/src/peer/connector.rs create mode 100644 p2p/src/peer/error.rs create mode 100644 p2p/src/peer/load_tracked_client.rs create mode 100644 p2p/src/protocol/internal_network/try_from.rs delete mode 100644 p2p/src/protocol/lib.rs delete mode 100644 p2p/src/protocol/temp_database.rs delete mode 100644 p2p/sync-states/Cargo.toml delete mode 100644 p2p/sync-states/src/lib.rs delete mode 100644 p2p/sync-states/tests/mod.rs create mode 100644 test-utils/Cargo.toml create mode 100644 test-utils/src/lib.rs create mode 100644 test-utils/src/test_netzone.rs diff --git a/.github/actions/monerod-regtest/action.yml b/.github/actions/monerod-regtest/action.yml new file mode 100644 index 0000000..73551b2 --- /dev/null +++ b/.github/actions/monerod-regtest/action.yml @@ -0,0 +1,62 @@ +# MIT License +# +# Copyright (c) 2022-2023 Luke Parker +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# Initially taken from Serai Dex: https://github.com/serai-dex/serai/blob/b823413c9b7ae6747b9af99e18379cfc49f4271a/.github/actions/monero/action.yml. + + + +name: monero-regtest +description: Spawns a regtest Monero daemon + +inputs: + version: + description: "Version to download and run" + required: false + default: v0.18.2.0 + +runs: + using: "composite" + steps: + - name: Monero Daemon Cache + id: cache-monerod + uses: actions/cache@704facf57e6136b1bc63b828d79edcd491f0ee84 + with: + path: monerod + key: monerod-${{ runner.os }}-${{ runner.arch }}-${{ inputs.version }} + + - name: Download the Monero Daemon + if: steps.cache-monerod.outputs.cache-hit != 'true' + # Calculates OS/ARCH to demonstrate it, yet then locks to linux-x64 due + # to the contained folder not following the same naming scheme and + # requiring further expansion not worth doing right now + shell: bash + run: | + RUNNER_OS=${{ runner.os }} + RUNNER_ARCH=${{ runner.arch }} + + RUNNER_OS=${RUNNER_OS,,} + RUNNER_ARCH=${RUNNER_ARCH,,} + + RUNNER_OS=linux + RUNNER_ARCH=x64 + + FILE=monero-$RUNNER_OS-$RUNNER_ARCH-${{ inputs.version }}.tar.bz2 + wget https://downloads.getmonero.org/cli/$FILE + tar -xvf $FILE + + mv monero-x86_64-linux-gnu-${{ inputs.version }}/monerod monerod + + - name: Monero Regtest Daemon + shell: bash + run: ./monerod --regtest --fixed-difficulty=1 --detach --out-peers 0 \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 687c556..997ce3d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,9 @@ jobs: path: target key: ${{ matrix.os }} + - name: Spawn monerod + uses: ./.github/actions/monerod-regtest + - name: Install dependencies run: sudo apt install -y libboost-dev diff --git a/Cargo.toml b/Cargo.toml index b37f88c..52aa2a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,8 +9,8 @@ members = [ # "database", "net/levin", "net/monero-wire", - # "p2p", - # "p2p/sync-states" + "p2p/monero-peer", + "test-utils" ] [profile.release] diff --git a/common/src/pruning.rs b/common/src/pruning.rs index ccd1d55..b78cde8 100644 --- a/common/src/pruning.rs +++ b/common/src/pruning.rs @@ -50,6 +50,7 @@ pub enum PruningError { /// // Internally we use an Option<u32> to represent if a pruning seed is 0 (None)which means // no pruning will take place. +#[derive(Debug, Clone, Copy)] pub struct PruningSeed(Option<u32>); impl PruningSeed { diff --git a/consensus/src/rpc/connection.rs b/consensus/src/rpc/connection.rs index 7873cff..c54e74e 100644 --- a/consensus/src/rpc/connection.rs +++ b/consensus/src/rpc/connection.rs @@ -10,11 +10,11 @@ use std::{ use curve25519_dalek::edwards::CompressedEdwardsY; use futures::{ channel::{mpsc, oneshot}, - ready, FutureExt, SinkExt, StreamExt, TryStreamExt, + FutureExt, StreamExt, }; use monero_serai::{ block::Block, - rpc::{HttpRpc, Rpc, RpcError}, + rpc::{HttpRpc, Rpc}, transaction::Transaction, }; use monero_wire::common::{BlockCompleteEntry, TransactionBlobs}; @@ -216,7 +216,7 @@ impl RpcConnection { let blocks: Response = monero_epee_bin_serde::from_bytes(res)?; - Ok(rayon_spawn_async(|| { + rayon_spawn_async(|| { blocks .blocks .into_par_iter() @@ -237,7 +237,7 @@ impl RpcConnection { }) .collect::<Result<_, tower::BoxError>>() }) - .await?) + .await } async fn get_outputs( diff --git a/net/levin/src/codec.rs b/net/levin/src/codec.rs index 8333ffa..37f6579 100644 --- a/net/levin/src/codec.rs +++ b/net/levin/src/codec.rs @@ -22,36 +22,79 @@ use bytes::{Buf, BufMut, BytesMut}; use tokio_util::codec::{Decoder, Encoder}; use crate::{ - Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, MessageType, - LEVIN_DEFAULT_MAX_PACKET_SIZE, + Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, MessageType, Protocol, }; -/// The levin tokio-codec for decoding and encoding levin buckets -#[derive(Default)] -pub enum LevinCodec { +#[derive(Debug, Clone)] +pub enum LevinBucketState<C> { /// Waiting for the peer to send a header. - #[default] WaitingForHeader, /// Waiting for a peer to send a body. - WaitingForBody(BucketHead), + WaitingForBody(BucketHead<C>), } -impl Decoder for LevinCodec { - type Item = Bucket; +/// The levin tokio-codec for decoding and encoding raw levin buckets +/// +#[derive(Debug, Clone)] +pub struct LevinBucketCodec<C> { + state: LevinBucketState<C>, + protocol: Protocol, + handshake_message_seen: bool, +} + +impl<C> Default for LevinBucketCodec<C> { + fn default() -> Self { + LevinBucketCodec { + state: LevinBucketState::WaitingForHeader, + protocol: Protocol::default(), + handshake_message_seen: false, + } + } +} + +impl<C> LevinBucketCodec<C> { + pub fn new(protocol: Protocol) -> Self { + LevinBucketCodec { + state: LevinBucketState::WaitingForHeader, + protocol, + handshake_message_seen: false, + } + } +} + +impl<C: LevinCommand> Decoder for LevinBucketCodec<C> { + type Item = Bucket<C>; type Error = BucketError; fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { loop { - match self { - LevinCodec::WaitingForHeader => { - if src.len() < BucketHead::SIZE { + match &self.state { + LevinBucketState::WaitingForHeader => { + if src.len() < BucketHead::<C>::SIZE { return Ok(None); }; - let head = BucketHead::from_bytes(src)?; - let _ = std::mem::replace(self, LevinCodec::WaitingForBody(head)); + let head = BucketHead::<C>::from_bytes(src); + + if head.size > self.protocol.max_packet_size + || head.size > head.command.bucket_size_limit() + { + return Err(BucketError::BucketExceededMaxSize); + } + + if !self.handshake_message_seen { + if head.size > self.protocol.max_packet_size_before_handshake { + return Err(BucketError::BucketExceededMaxSize); + } + + if head.command.is_handshake() { + self.handshake_message_seen = true; + } + } + + let _ = + std::mem::replace(&mut self.state, LevinBucketState::WaitingForBody(head)); } - LevinCodec::WaitingForBody(head) => { - // We size check header while decoding it. + LevinBucketState::WaitingForBody(head) => { let body_len = head .size .try_into() @@ -61,8 +104,8 @@ impl Decoder for LevinCodec { return Ok(None); } - let LevinCodec::WaitingForBody(header) = - std::mem::replace(self, LevinCodec::WaitingForHeader) + let LevinBucketState::WaitingForBody(header) = + std::mem::replace(&mut self.state, LevinBucketState::WaitingForHeader) else { unreachable!() }; @@ -77,10 +120,10 @@ impl Decoder for LevinCodec { } } -impl Encoder<Bucket> for LevinCodec { +impl<C: LevinCommand> Encoder<Bucket<C>> for LevinBucketCodec<C> { type Error = BucketError; - fn encode(&mut self, item: Bucket, dst: &mut BytesMut) -> Result<(), Self::Error> { - if dst.capacity() < BucketHead::SIZE + item.body.len() { + fn encode(&mut self, item: Bucket<C>, dst: &mut BytesMut) -> Result<(), Self::Error> { + if dst.capacity() < BucketHead::<C>::SIZE + item.body.len() { return Err(BucketError::IO(std::io::Error::new( ErrorKind::OutOfMemory, "Not enough capacity to write the bucket", @@ -92,19 +135,30 @@ impl Encoder<Bucket> for LevinCodec { } } -#[derive(Default)] -enum MessageState { +#[derive(Default, Debug, Clone)] +enum MessageState<C> { #[default] WaitingForBucket, - WaitingForRestOfFragment(Vec<u8>, MessageType, u32), + WaitingForRestOfFragment(Vec<u8>, MessageType, C), } /// A tokio-codec for levin messages or in other words the decoded body /// of a levin bucket. -pub struct LevinMessageCodec<T> { +#[derive(Debug, Clone)] +pub struct LevinMessageCodec<T: LevinBody> { message_ty: PhantomData<T>, - bucket_codec: LevinCodec, - state: MessageState, + bucket_codec: LevinBucketCodec<T::Command>, + state: MessageState<T::Command>, +} + +impl<T: LevinBody> Default for LevinMessageCodec<T> { + fn default() -> Self { + Self { + message_ty: Default::default(), + bucket_codec: Default::default(), + state: Default::default(), + } + } } impl<T: LevinBody> Decoder for LevinMessageCodec<T> { @@ -118,23 +172,20 @@ impl<T: LevinBody> Decoder for LevinMessageCodec<T> { return Ok(None); }; - let end_fragment = bucket.header.flags.end_fragment; - let start_fragment = bucket.header.flags.start_fragment; - let request = bucket.header.flags.request; - let response = bucket.header.flags.response; + let flags = &bucket.header.flags; - if start_fragment && end_fragment { + if flags.is_start_fragment() && flags.is_end_fragment() { // Dummy message return Ok(None); }; - if end_fragment { + if flags.is_end_fragment() { return Err(BucketError::InvalidHeaderFlags( "Flag end fragment received before a start fragment", )); }; - if !request && !response { + if !flags.is_request() && !flags.is_response() { return Err(BucketError::InvalidHeaderFlags( "Request and response flags both not set", )); @@ -145,13 +196,13 @@ impl<T: LevinBody> Decoder for LevinMessageCodec<T> { bucket.header.have_to_return_data, )?; - if start_fragment { + if flags.is_start_fragment() { let _ = std::mem::replace( &mut self.state, MessageState::WaitingForRestOfFragment( bucket.body.to_vec(), message_type, - bucket.header.protocol_version, + bucket.header.command, ), ); @@ -169,17 +220,14 @@ impl<T: LevinBody> Decoder for LevinMessageCodec<T> { return Ok(None); }; - let end_fragment = bucket.header.flags.end_fragment; - let start_fragment = bucket.header.flags.start_fragment; - let request = bucket.header.flags.request; - let response = bucket.header.flags.response; + let flags = &bucket.header.flags; - if start_fragment && end_fragment { + if flags.is_start_fragment() && flags.is_end_fragment() { // Dummy message return Ok(None); }; - if !request && !response { + if !flags.is_request() && !flags.is_response() { return Err(BucketError::InvalidHeaderFlags( "Request and response flags both not set", )); @@ -198,12 +246,12 @@ impl<T: LevinBody> Decoder for LevinMessageCodec<T> { if bucket.header.command != *command { return Err(BucketError::InvalidFragmentedMessage( - "Command not consistent across message", + "Command not consistent across fragments", )); } - if bytes.len() + bucket.body.len() - > LEVIN_DEFAULT_MAX_PACKET_SIZE.try_into().unwrap() + if bytes.len().saturating_add(bucket.body.len()) + > command.bucket_size_limit().try_into().unwrap() { return Err(BucketError::InvalidFragmentedMessage( "Fragmented message exceeded maximum size", @@ -212,7 +260,7 @@ impl<T: LevinBody> Decoder for LevinMessageCodec<T> { bytes.append(&mut bucket.body.to_vec()); - if end_fragment { + if flags.is_end_fragment() { let MessageState::WaitingForRestOfFragment(bytes, ty, command) = std::mem::replace(&mut self.state, MessageState::WaitingForBucket) else { diff --git a/net/levin/src/header.rs b/net/levin/src/header.rs index d6b930f..3435293 100644 --- a/net/levin/src/header.rs +++ b/net/levin/src/header.rs @@ -13,13 +13,27 @@ // copies or substantial portions of the Software. // +// Rust Levin Library +// Written in 2023 by +// Cuprate Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// + //! This module provides a struct BucketHead for the header of a levin protocol //! message. -use crate::LEVIN_DEFAULT_MAX_PACKET_SIZE; use bytes::{Buf, BufMut, BytesMut}; -use super::{BucketError, LEVIN_SIGNATURE, PROTOCOL_VERSION}; +use crate::LevinCommand; const REQUEST: u32 = 0b0000_0001; const RESPONSE: u32 = 0b0000_0010; @@ -28,57 +42,41 @@ const END_FRAGMENT: u32 = 0b0000_1000; /// Levin header flags #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct Flags { - /// Q bit - pub request: bool, - /// S bit - pub response: bool, - /// B bit - pub start_fragment: bool, - /// E bit - pub end_fragment: bool, +pub struct Flags(u32); + +impl Flags { + pub const REQUEST: Flags = Flags(REQUEST); + pub const RESPONSE: Flags = Flags(RESPONSE); + + pub fn is_request(&self) -> bool { + self.0 & REQUEST != 0 + } + pub fn is_response(&self) -> bool { + self.0 & RESPONSE != 0 + } + pub fn is_start_fragment(&self) -> bool { + self.0 & START_FRAGMENT != 0 + } + pub fn is_end_fragment(&self) -> bool { + self.0 & END_FRAGMENT != 0 + } } -impl TryFrom<u32> for Flags { - type Error = BucketError; - fn try_from(value: u32) -> Result<Self, Self::Error> { - let flags = Flags { - request: value & REQUEST > 0, - response: value & RESPONSE > 0, - start_fragment: value & START_FRAGMENT > 0, - end_fragment: value & END_FRAGMENT > 0, - }; - if flags.request && flags.response { - return Err(BucketError::InvalidHeaderFlags( - "Request and Response bits set", - )); - }; - Ok(flags) +impl From<u32> for Flags { + fn from(value: u32) -> Self { + Flags(value) } } impl From<Flags> for u32 { fn from(value: Flags) -> Self { - let mut ret = 0; - if value.request { - ret |= REQUEST; - }; - if value.response { - ret |= RESPONSE; - }; - if value.start_fragment { - ret |= START_FRAGMENT; - }; - if value.end_fragment { - ret |= END_FRAGMENT; - }; - ret + value.0 } } /// The Header of a Bucket. This contains #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct BucketHead { +pub struct BucketHead<C> { /// The network signature, should be `LEVIN_SIGNATURE` for Monero pub signature: u64, /// The size of the body @@ -87,7 +85,7 @@ pub struct BucketHead { /// messages require responses but don't have this set (some notifications) pub have_to_return_data: bool, /// Command - pub command: u32, + pub command: C, /// Return Code - will be 0 for requests and >0 for ok responses otherwise will be /// a negative number corresponding to the error pub return_code: i32, @@ -97,61 +95,36 @@ pub struct BucketHead { pub protocol_version: u32, } -impl BucketHead { +impl<C: LevinCommand> BucketHead<C> { /// The size of the header (in bytes) pub const SIZE: usize = 33; - /// Builds the header in a Monero specific way - pub fn build_monero( - payload_size: u64, - have_to_return_data: bool, - command: u32, - flags: Flags, - return_code: i32, - ) -> BucketHead { - BucketHead { - signature: LEVIN_SIGNATURE, - size: payload_size, - have_to_return_data, - command, - return_code, - flags, - protocol_version: PROTOCOL_VERSION, - } - } - /// Builds the header from bytes, this function does not check any fields should - /// match the expected ones (signature, protocol_version) + /// match the expected ones. /// /// # Panics /// This function will panic if there aren't enough bytes to fill the header. /// Currently ['SIZE'](BucketHead::SIZE) - pub fn from_bytes(buf: &mut BytesMut) -> Result<BucketHead, BucketError> { - let header = BucketHead { + pub fn from_bytes(buf: &mut BytesMut) -> BucketHead<C> { + BucketHead { signature: buf.get_u64_le(), size: buf.get_u64_le(), have_to_return_data: buf.get_u8() != 0, - command: buf.get_u32_le(), + command: buf.get_u32_le().into(), return_code: buf.get_i32_le(), - flags: Flags::try_from(buf.get_u32_le())?, + flags: Flags::from(buf.get_u32_le()), protocol_version: buf.get_u32_le(), - }; - - if header.size > LEVIN_DEFAULT_MAX_PACKET_SIZE { - return Err(BucketError::BucketExceededMaxSize); } - - Ok(header) } /// Serializes the header pub fn write_bytes(&self, dst: &mut BytesMut) { - dst.reserve(BucketHead::SIZE); + dst.reserve(Self::SIZE); dst.put_u64_le(self.signature); dst.put_u64_le(self.size); dst.put_u8(if self.have_to_return_data { 1 } else { 0 }); - dst.put_u32_le(self.command); + dst.put_u32_le(self.command.clone().into()); dst.put_i32_le(self.return_code); dst.put_u32_le(self.flags.into()); dst.put_u32_le(self.protocol_version); diff --git a/net/levin/src/lib.rs b/net/levin/src/lib.rs index 7442f8b..22ffd12 100644 --- a/net/levin/src/lib.rs +++ b/net/levin/src/lib.rs @@ -36,16 +36,17 @@ pub mod codec; pub mod header; -pub use codec::LevinCodec; +pub use codec::*; pub use header::BucketHead; use std::fmt::Debug; use thiserror::Error; -const PROTOCOL_VERSION: u32 = 1; -const LEVIN_SIGNATURE: u64 = 0x0101010101012101; -const LEVIN_DEFAULT_MAX_PACKET_SIZE: u64 = 100_000_000; // 100MB +const MONERO_PROTOCOL_VERSION: u32 = 1; +const MONERO_LEVIN_SIGNATURE: u64 = 0x0101010101012101; +const MONERO_MAX_PACKET_SIZE_BEFORE_HANDSHAKE: u64 = 256 * 1000; // 256 KiB +const MONERO_MAX_PACKET_SIZE: u64 = 100_000_000; // 100MB /// Possible Errors when working with levin buckets #[derive(Error, Debug)] @@ -59,28 +60,53 @@ pub enum BucketError { /// Invalid Fragmented Message #[error("Levin fragmented message was invalid: {0}")] InvalidFragmentedMessage(&'static str), + /// The Header did not have the correct signature + #[error("Levin header had incorrect signature")] + InvalidHeaderSignature, /// Error decoding the body - #[error("Error decoding bucket body: {0}")] - BodyDecodingError(Box<dyn std::error::Error>), - /// The levin command is unknown - #[error("The levin command is unknown")] + #[error("Error decoding bucket body")] + BodyDecodingError(Box<dyn std::error::Error + Send + Sync>), + /// Unknown command ID + #[error("Unknown command ID")] UnknownCommand, /// I/O error #[error("I/O error: {0}")] IO(#[from] std::io::Error), } +/// Levin protocol settings, allows setting custom parameters. +/// +/// For Monero use [`Protocol::default()`] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct Protocol { + pub version: u32, + pub signature: u64, + pub max_packet_size_before_handshake: u64, + pub max_packet_size: u64, +} + +impl Default for Protocol { + fn default() -> Self { + Protocol { + version: MONERO_PROTOCOL_VERSION, + signature: MONERO_LEVIN_SIGNATURE, + max_packet_size_before_handshake: MONERO_MAX_PACKET_SIZE_BEFORE_HANDSHAKE, + max_packet_size: MONERO_MAX_PACKET_SIZE, + } + } +} + /// A levin Bucket #[derive(Debug)] -pub struct Bucket { +pub struct Bucket<C> { /// The bucket header - pub header: BucketHead, + pub header: BucketHead<C>, /// The bucket body pub body: Vec<u8>, } /// An enum representing if the message is a request, response or notification. -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum MessageType { /// Request Request, @@ -104,11 +130,11 @@ impl MessageType { flags: header::Flags, have_to_return: bool, ) -> Result<Self, BucketError> { - if flags.request && have_to_return { + if flags.is_request() && have_to_return { Ok(MessageType::Request) - } else if flags.request { + } else if flags.is_request() { Ok(MessageType::Notification) - } else if flags.response && !have_to_return { + } else if flags.is_response() && !have_to_return { Ok(MessageType::Response) } else { Err(BucketError::InvalidHeaderFlags( @@ -119,42 +145,36 @@ impl MessageType { pub fn as_flags(&self) -> header::Flags { match self { - MessageType::Request | MessageType::Notification => header::Flags { - request: true, - ..Default::default() - }, - MessageType::Response => header::Flags { - response: true, - ..Default::default() - }, + MessageType::Request | MessageType::Notification => header::Flags::REQUEST, + MessageType::Response => header::Flags::RESPONSE, } } } #[derive(Debug)] -pub struct BucketBuilder { +pub struct BucketBuilder<C> { signature: Option<u64>, ty: Option<MessageType>, - command: Option<u32>, + command: Option<C>, return_code: Option<i32>, protocol_version: Option<u32>, body: Option<Vec<u8>>, } -impl Default for BucketBuilder { +impl<C> Default for BucketBuilder<C> { fn default() -> Self { Self { - signature: Some(LEVIN_SIGNATURE), + signature: Some(MONERO_LEVIN_SIGNATURE), ty: None, command: None, return_code: None, - protocol_version: Some(PROTOCOL_VERSION), + protocol_version: Some(MONERO_PROTOCOL_VERSION), body: None, } } } -impl BucketBuilder { +impl<C: LevinCommand> BucketBuilder<C> { pub fn set_signature(&mut self, sig: u64) { self.signature = Some(sig) } @@ -163,7 +183,7 @@ impl BucketBuilder { self.ty = Some(ty) } - pub fn set_command(&mut self, command: u32) { + pub fn set_command(&mut self, command: C) { self.command = Some(command) } @@ -179,7 +199,7 @@ impl BucketBuilder { self.body = Some(body) } - pub fn finish(self) -> Bucket { + pub fn finish(self) -> Bucket<C> { let body = self.body.unwrap(); let ty = self.ty.unwrap(); Bucket { @@ -199,9 +219,28 @@ impl BucketBuilder { /// A levin body pub trait LevinBody: Sized { + type Command: LevinCommand; + /// Decodes the message from the data in the header - fn decode_message(body: &[u8], typ: MessageType, command: u32) -> Result<Self, BucketError>; + fn decode_message( + body: &[u8], + typ: MessageType, + command: Self::Command, + ) -> Result<Self, BucketError>; /// Encodes the message - fn encode(&self, builder: &mut BucketBuilder) -> Result<(), BucketError>; + fn encode(&self, builder: &mut BucketBuilder<Self::Command>) -> Result<(), BucketError>; +} + +/// The levin commands. +/// +/// Implementers should account for all possible u32 values, this means +/// you will probably need some sort of `Unknown` variant. +pub trait LevinCommand: From<u32> + Into<u32> + PartialEq + Clone { + /// Returns the size limit for this command. + /// + /// must be less than [`usize::MAX`] + fn bucket_size_limit(&self) -> u64; + /// Returns if this is a handshake + fn is_handshake(&self) -> bool; } diff --git a/net/monero-wire/src/lib.rs b/net/monero-wire/src/lib.rs index 232b86c..019756a 100644 --- a/net/monero-wire/src/lib.rs +++ b/net/monero-wire/src/lib.rs @@ -22,18 +22,12 @@ //! //! This project is licensed under the MIT License. -// Coding conventions -#![forbid(unsafe_code)] -#![deny(non_upper_case_globals)] -#![deny(non_camel_case_types)] -#![deny(unused_mut)] -//#![deny(missing_docs)] - pub mod network_address; pub mod p2p; mod serde_helpers; -pub use network_address::NetworkAddress; +pub use levin_cuprate::BucketError; +pub use network_address::{NetZone, NetworkAddress}; pub use p2p::*; pub type MoneroWireCodec = levin_cuprate::codec::LevinMessageCodec<Message>; diff --git a/net/monero-wire/src/network_address.rs b/net/monero-wire/src/network_address.rs index d1aab44..34ed812 100644 --- a/net/monero-wire/src/network_address.rs +++ b/net/monero-wire/src/network_address.rs @@ -17,8 +17,7 @@ //! Monero network. Core Monero has 4 main addresses: IPv4, IPv6, Tor, //! I2p. Currently this module only has IPv(4/6). //! -use std::net::{SocketAddrV4, SocketAddrV6}; -use std::{hash::Hash, net}; +use std::{hash::Hash, net, net::SocketAddr}; use serde::{Deserialize, Serialize}; @@ -38,16 +37,13 @@ pub enum NetZone { #[serde(try_from = "TaggedNetworkAddress")] #[serde(into = "TaggedNetworkAddress")] pub enum NetworkAddress { - /// IPv4 - IPv4(SocketAddrV4), - /// IPv6 - IPv6(SocketAddrV6), + Clear(SocketAddr), } impl NetworkAddress { pub fn get_zone(&self) -> NetZone { match self { - NetworkAddress::IPv4(_) | NetworkAddress::IPv6(_) => NetZone::Public, + NetworkAddress::Clear(_) => NetZone::Public, } } @@ -63,29 +59,42 @@ impl NetworkAddress { pub fn port(&self) -> u16 { match self { - NetworkAddress::IPv4(ip) => ip.port(), - NetworkAddress::IPv6(ip) => ip.port(), + NetworkAddress::Clear(ip) => ip.port(), } } } impl From<net::SocketAddrV4> for NetworkAddress { fn from(value: net::SocketAddrV4) -> Self { - NetworkAddress::IPv4(value) + NetworkAddress::Clear(value.into()) } } impl From<net::SocketAddrV6> for NetworkAddress { fn from(value: net::SocketAddrV6) -> Self { - NetworkAddress::IPv6(value) + NetworkAddress::Clear(value.into()) } } -impl From<net::SocketAddr> for NetworkAddress { - fn from(value: net::SocketAddr) -> Self { +impl From<SocketAddr> for NetworkAddress { + fn from(value: SocketAddr) -> Self { match value { - net::SocketAddr::V4(v4) => v4.into(), - net::SocketAddr::V6(v6) => v6.into(), + SocketAddr::V4(v4) => v4.into(), + SocketAddr::V6(v6) => v6.into(), + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, thiserror::Error)] +#[error("Network address is not in the correct zone")] +pub struct NetworkAddressIncorrectZone; + +impl TryFrom<NetworkAddress> for SocketAddr { + type Error = NetworkAddressIncorrectZone; + fn try_from(value: NetworkAddress) -> Result<Self, Self::Error> { + match value { + NetworkAddress::Clear(addr) => Ok(addr), + //_ => Err(NetworkAddressIncorrectZone) } } } diff --git a/net/monero-wire/src/network_address/serde_helper.rs b/net/monero-wire/src/network_address/serde_helper.rs index d08533f..2349161 100644 --- a/net/monero-wire/src/network_address/serde_helper.rs +++ b/net/monero-wire/src/network_address/serde_helper.rs @@ -1,4 +1,4 @@ -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -30,20 +30,22 @@ impl TryFrom<TaggedNetworkAddress> for NetworkAddress { impl From<NetworkAddress> for TaggedNetworkAddress { fn from(value: NetworkAddress) -> Self { match value { - NetworkAddress::IPv4(addr) => TaggedNetworkAddress { - ty: 1, - addr: AllFieldsNetworkAddress { - m_ip: Some(u32::from_be_bytes(addr.ip().octets())), - m_port: Some(addr.port()), - ..Default::default() + NetworkAddress::Clear(addr) => match addr { + SocketAddr::V4(addr) => TaggedNetworkAddress { + ty: 1, + addr: AllFieldsNetworkAddress { + m_ip: Some(u32::from_be_bytes(addr.ip().octets())), + m_port: Some(addr.port()), + ..Default::default() + }, }, - }, - NetworkAddress::IPv6(addr) => TaggedNetworkAddress { - ty: 2, - addr: AllFieldsNetworkAddress { - addr: Some(addr.ip().octets()), - m_port: Some(addr.port()), - ..Default::default() + SocketAddr::V6(addr) => TaggedNetworkAddress { + ty: 2, + addr: AllFieldsNetworkAddress { + addr: Some(addr.ip().octets()), + m_port: Some(addr.port()), + ..Default::default() + }, }, }, } @@ -63,8 +65,8 @@ struct AllFieldsNetworkAddress { impl AllFieldsNetworkAddress { fn try_into_network_address(self, ty: u8) -> Option<NetworkAddress> { Some(match ty { - 1 => NetworkAddress::IPv4(SocketAddrV4::new(Ipv4Addr::from(self.m_ip?), self.m_port?)), - 2 => NetworkAddress::IPv6(SocketAddrV6::new( + 1 => NetworkAddress::from(SocketAddrV4::new(Ipv4Addr::from(self.m_ip?), self.m_port?)), + 2 => NetworkAddress::from(SocketAddrV6::new( Ipv6Addr::from(self.addr?), self.m_port?, 0, diff --git a/net/monero-wire/src/p2p.rs b/net/monero-wire/src/p2p.rs index 48b316a..86c258e 100644 --- a/net/monero-wire/src/p2p.rs +++ b/net/monero-wire/src/p2p.rs @@ -16,7 +16,10 @@ //! This module defines a Monero `Message` enum which contains //! every possible Monero network message (levin body) -use levin_cuprate::{BucketBuilder, BucketError, LevinBody, MessageType}; +use levin_cuprate::{ + BucketBuilder, BucketError, LevinBody, LevinCommand as LevinCommandTrait, MessageType, +}; +use std::fmt::Formatter; pub mod admin; pub mod common; @@ -26,6 +29,127 @@ use admin::*; pub use common::{BasicNodeData, CoreSyncData, PeerListEntryBase}; use protocol::*; +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub enum LevinCommand { + Handshake, + TimedSync, + Ping, + SupportFlags, + + NewBlock, + NewTransactions, + GetObjectsRequest, + GetObjectsResponse, + ChainRequest, + ChainResponse, + NewFluffyBlock, + FluffyMissingTxsRequest, + GetTxPoolCompliment, + + Unknown(u32), +} + +impl std::fmt::Display for LevinCommand { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let LevinCommand::Unknown(id) = self { + return f.write_str(&format!("unknown id: {}", id)); + } + + f.write_str(match self { + LevinCommand::Handshake => "handshake", + LevinCommand::TimedSync => "timed sync", + LevinCommand::Ping => "ping", + LevinCommand::SupportFlags => "support flags", + + LevinCommand::NewBlock => "new block", + LevinCommand::NewTransactions => "new transactions", + LevinCommand::GetObjectsRequest => "get objects request", + LevinCommand::GetObjectsResponse => "get objects response", + LevinCommand::ChainRequest => "chain request", + LevinCommand::ChainResponse => "chain response", + LevinCommand::NewFluffyBlock => "new fluffy block", + LevinCommand::FluffyMissingTxsRequest => "fluffy missing transaction request", + LevinCommand::GetTxPoolCompliment => "get transaction pool compliment", + + LevinCommand::Unknown(_) => unreachable!(), + }) + } +} + +impl LevinCommandTrait for LevinCommand { + fn bucket_size_limit(&self) -> u64 { + // https://github.com/monero-project/monero/blob/00fd416a99686f0956361d1cd0337fe56e58d4a7/src/cryptonote_basic/connection_context.cpp#L37 + match self { + LevinCommand::Handshake => 65536, + LevinCommand::TimedSync => 65536, + LevinCommand::Ping => 4096, + LevinCommand::SupportFlags => 4096, + + LevinCommand::NewBlock => 1024 * 1024 * 128, // 128 MB (max packet is a bit less than 100 MB though) + LevinCommand::NewTransactions => 1024 * 1024 * 128, // 128 MB (max packet is a bit less than 100 MB though) + LevinCommand::GetObjectsRequest => 1024 * 1024 * 2, // 2 MB + LevinCommand::GetObjectsResponse => 1024 * 1024 * 128, // 128 MB (max packet is a bit less than 100 MB though) + LevinCommand::ChainRequest => 512 * 1024, // 512 kB + LevinCommand::ChainResponse => 1024 * 1024 * 4, // 4 MB + LevinCommand::NewFluffyBlock => 1024 * 1024 * 4, // 4 MB + LevinCommand::FluffyMissingTxsRequest => 1024 * 1024, // 1 MB + LevinCommand::GetTxPoolCompliment => 1024 * 1024 * 4, // 4 MB + + LevinCommand::Unknown(_) => usize::MAX.try_into().unwrap_or(u64::MAX), + } + } + + fn is_handshake(&self) -> bool { + matches!(self, LevinCommand::Handshake) + } +} + +impl From<u32> for LevinCommand { + fn from(value: u32) -> Self { + match value { + 1001 => LevinCommand::Handshake, + 1002 => LevinCommand::TimedSync, + 1003 => LevinCommand::Ping, + 1007 => LevinCommand::SupportFlags, + + 2001 => LevinCommand::NewBlock, + 2002 => LevinCommand::NewTransactions, + 2003 => LevinCommand::GetObjectsRequest, + 2004 => LevinCommand::GetObjectsResponse, + 2006 => LevinCommand::ChainRequest, + 2007 => LevinCommand::ChainResponse, + 2008 => LevinCommand::NewFluffyBlock, + 2009 => LevinCommand::FluffyMissingTxsRequest, + 2010 => LevinCommand::GetTxPoolCompliment, + + x => LevinCommand::Unknown(x), + } + } +} + +impl From<LevinCommand> for u32 { + fn from(value: LevinCommand) -> Self { + match value { + LevinCommand::Handshake => 1001, + LevinCommand::TimedSync => 1002, + LevinCommand::Ping => 1003, + LevinCommand::SupportFlags => 1007, + + LevinCommand::NewBlock => 2001, + LevinCommand::NewTransactions => 2002, + LevinCommand::GetObjectsRequest => 2003, + LevinCommand::GetObjectsResponse => 2004, + LevinCommand::ChainRequest => 2006, + LevinCommand::ChainResponse => 2007, + LevinCommand::NewFluffyBlock => 2008, + LevinCommand::FluffyMissingTxsRequest => 2009, + LevinCommand::GetTxPoolCompliment => 2010, + + LevinCommand::Unknown(x) => x, + } + } +} + fn decode_message<T: serde::de::DeserializeOwned, Ret>( ret: impl FnOnce(T) -> Ret, buf: &[u8], @@ -36,9 +160,9 @@ fn decode_message<T: serde::de::DeserializeOwned, Ret>( } fn build_message<T: serde::Serialize>( - id: u32, + id: LevinCommand, val: &T, - builder: &mut BucketBuilder, + builder: &mut BucketBuilder<LevinCommand>, ) -> Result<(), BucketError> { builder.set_command(id); builder.set_body( @@ -61,34 +185,66 @@ pub enum ProtocolMessage { } impl ProtocolMessage { - fn decode(buf: &[u8], command: u32) -> Result<Self, BucketError> { + pub fn command(&self) -> LevinCommand { + use LevinCommand as C; + + match self { + ProtocolMessage::NewBlock(_) => C::NewBlock, + ProtocolMessage::NewFluffyBlock(_) => C::NewFluffyBlock, + ProtocolMessage::GetObjectsRequest(_) => C::GetObjectsRequest, + ProtocolMessage::GetObjectsResponse(_) => C::GetObjectsResponse, + ProtocolMessage::ChainRequest(_) => C::ChainRequest, + ProtocolMessage::ChainEntryResponse(_) => C::ChainResponse, + ProtocolMessage::NewTransactions(_) => C::NewTransactions, + ProtocolMessage::FluffyMissingTransactionsRequest(_) => C::FluffyMissingTxsRequest, + ProtocolMessage::GetTxPoolCompliment(_) => C::GetTxPoolCompliment, + } + } + + fn decode(buf: &[u8], command: LevinCommand) -> Result<Self, BucketError> { + use LevinCommand as C; + Ok(match command { - 2001 => decode_message(ProtocolMessage::NewBlock, buf)?, - 2002 => decode_message(ProtocolMessage::NewTransactions, buf)?, - 2003 => decode_message(ProtocolMessage::GetObjectsRequest, buf)?, - 2004 => decode_message(ProtocolMessage::GetObjectsResponse, buf)?, - 2006 => decode_message(ProtocolMessage::ChainRequest, buf)?, - 2007 => decode_message(ProtocolMessage::ChainEntryResponse, buf)?, - 2008 => decode_message(ProtocolMessage::NewFluffyBlock, buf)?, - 2009 => decode_message(ProtocolMessage::FluffyMissingTransactionsRequest, buf)?, - 2010 => decode_message(ProtocolMessage::GetTxPoolCompliment, buf)?, + C::NewBlock => decode_message(ProtocolMessage::NewBlock, buf)?, + C::NewTransactions => decode_message(ProtocolMessage::NewTransactions, buf)?, + C::GetObjectsRequest => decode_message(ProtocolMessage::GetObjectsRequest, buf)?, + C::GetObjectsResponse => decode_message(ProtocolMessage::GetObjectsResponse, buf)?, + C::ChainRequest => decode_message(ProtocolMessage::ChainRequest, buf)?, + C::ChainResponse => decode_message(ProtocolMessage::ChainEntryResponse, buf)?, + C::NewFluffyBlock => decode_message(ProtocolMessage::NewFluffyBlock, buf)?, + C::FluffyMissingTxsRequest => { + decode_message(ProtocolMessage::FluffyMissingTransactionsRequest, buf)? + } + C::GetTxPoolCompliment => decode_message(ProtocolMessage::GetTxPoolCompliment, buf)?, _ => return Err(BucketError::UnknownCommand), }) } - fn build(&self, builder: &mut BucketBuilder) -> Result<(), BucketError> { + fn build(&self, builder: &mut BucketBuilder<LevinCommand>) -> Result<(), BucketError> { + use LevinCommand as C; + match self { - ProtocolMessage::NewBlock(val) => build_message(2001, val, builder)?, - ProtocolMessage::NewTransactions(val) => build_message(2002, val, builder)?, - ProtocolMessage::GetObjectsRequest(val) => build_message(2003, val, builder)?, - ProtocolMessage::GetObjectsResponse(val) => build_message(2004, val, builder)?, - ProtocolMessage::ChainRequest(val) => build_message(2006, val, builder)?, - ProtocolMessage::ChainEntryResponse(val) => build_message(2007, &val, builder)?, - ProtocolMessage::NewFluffyBlock(val) => build_message(2008, val, builder)?, - ProtocolMessage::FluffyMissingTransactionsRequest(val) => { - build_message(2009, val, builder)? + ProtocolMessage::NewBlock(val) => build_message(C::NewBlock, val, builder)?, + ProtocolMessage::NewTransactions(val) => { + build_message(C::NewTransactions, val, builder)? + } + ProtocolMessage::GetObjectsRequest(val) => { + build_message(C::GetObjectsRequest, val, builder)? + } + ProtocolMessage::GetObjectsResponse(val) => { + build_message(C::GetObjectsResponse, val, builder)? + } + ProtocolMessage::ChainRequest(val) => build_message(C::ChainRequest, val, builder)?, + ProtocolMessage::ChainEntryResponse(val) => { + build_message(C::ChainResponse, &val, builder)? + } + ProtocolMessage::NewFluffyBlock(val) => build_message(C::NewFluffyBlock, val, builder)?, + ProtocolMessage::FluffyMissingTransactionsRequest(val) => { + build_message(C::FluffyMissingTxsRequest, val, builder)? + } + ProtocolMessage::GetTxPoolCompliment(val) => { + build_message(C::GetTxPoolCompliment, val, builder)? } - ProtocolMessage::GetTxPoolCompliment(val) => build_message(2010, val, builder)?, } Ok(()) } @@ -102,26 +258,41 @@ pub enum RequestMessage { } impl RequestMessage { - fn decode(buf: &[u8], command: u32) -> Result<Self, BucketError> { + pub fn command(&self) -> LevinCommand { + use LevinCommand as C; + + match self { + RequestMessage::Handshake(_) => C::Handshake, + RequestMessage::Ping => C::Ping, + RequestMessage::SupportFlags => C::SupportFlags, + RequestMessage::TimedSync(_) => C::TimedSync, + } + } + + fn decode(buf: &[u8], command: LevinCommand) -> Result<Self, BucketError> { + use LevinCommand as C; + Ok(match command { - 1001 => decode_message(RequestMessage::Handshake, buf)?, - 1002 => decode_message(RequestMessage::TimedSync, buf)?, - 1003 => RequestMessage::Ping, - 1007 => RequestMessage::SupportFlags, + C::Handshake => decode_message(RequestMessage::Handshake, buf)?, + C::TimedSync => decode_message(RequestMessage::TimedSync, buf)?, + C::Ping => RequestMessage::Ping, + C::SupportFlags => RequestMessage::SupportFlags, _ => return Err(BucketError::UnknownCommand), }) } - fn build(&self, builder: &mut BucketBuilder) -> Result<(), BucketError> { + fn build(&self, builder: &mut BucketBuilder<LevinCommand>) -> Result<(), BucketError> { + use LevinCommand as C; + match self { - RequestMessage::Handshake(val) => build_message(1001, val, builder)?, - RequestMessage::TimedSync(val) => build_message(1002, val, builder)?, + RequestMessage::Handshake(val) => build_message(C::Handshake, val, builder)?, + RequestMessage::TimedSync(val) => build_message(C::TimedSync, val, builder)?, RequestMessage::Ping => { - builder.set_command(1003); + builder.set_command(C::Ping); builder.set_body(Vec::new()); } RequestMessage::SupportFlags => { - builder.set_command(1007); + builder.set_command(C::SupportFlags); builder.set_body(Vec::new()); } } @@ -137,22 +308,37 @@ pub enum ResponseMessage { } impl ResponseMessage { - fn decode(buf: &[u8], command: u32) -> Result<Self, BucketError> { + pub fn command(&self) -> LevinCommand { + use LevinCommand as C; + + match self { + ResponseMessage::Handshake(_) => C::Handshake, + ResponseMessage::Ping(_) => C::Ping, + ResponseMessage::SupportFlags(_) => C::SupportFlags, + ResponseMessage::TimedSync(_) => C::TimedSync, + } + } + + fn decode(buf: &[u8], command: LevinCommand) -> Result<Self, BucketError> { + use LevinCommand as C; + Ok(match command { - 1001 => decode_message(ResponseMessage::Handshake, buf)?, - 1002 => decode_message(ResponseMessage::TimedSync, buf)?, - 1003 => decode_message(ResponseMessage::Ping, buf)?, - 1007 => decode_message(ResponseMessage::SupportFlags, buf)?, + C::Handshake => decode_message(ResponseMessage::Handshake, buf)?, + C::TimedSync => decode_message(ResponseMessage::TimedSync, buf)?, + C::Ping => decode_message(ResponseMessage::Ping, buf)?, + C::SupportFlags => decode_message(ResponseMessage::SupportFlags, buf)?, _ => return Err(BucketError::UnknownCommand), }) } - fn build(&self, builder: &mut BucketBuilder) -> Result<(), BucketError> { + fn build(&self, builder: &mut BucketBuilder<LevinCommand>) -> Result<(), BucketError> { + use LevinCommand as C; + match self { - ResponseMessage::Handshake(val) => build_message(1001, val, builder)?, - ResponseMessage::TimedSync(val) => build_message(1002, val, builder)?, - ResponseMessage::Ping(val) => build_message(1003, val, builder)?, - ResponseMessage::SupportFlags(val) => build_message(1007, val, builder)?, + ResponseMessage::Handshake(val) => build_message(C::Handshake, val, builder)?, + ResponseMessage::TimedSync(val) => build_message(C::TimedSync, val, builder)?, + ResponseMessage::Ping(val) => build_message(C::Ping, val, builder)?, + ResponseMessage::SupportFlags(val) => build_message(C::SupportFlags, val, builder)?, } Ok(()) } @@ -164,8 +350,36 @@ pub enum Message { Protocol(ProtocolMessage), } +impl Message { + pub fn is_request(&self) -> bool { + matches!(self, Message::Request(_)) + } + + pub fn is_response(&self) -> bool { + matches!(self, Message::Response(_)) + } + + pub fn is_protocol(&self) -> bool { + matches!(self, Message::Protocol(_)) + } + + pub fn command(&self) -> LevinCommand { + match self { + Message::Request(mes) => mes.command(), + Message::Response(mes) => mes.command(), + Message::Protocol(mes) => mes.command(), + } + } +} + impl LevinBody for Message { - fn decode_message(body: &[u8], typ: MessageType, command: u32) -> Result<Self, BucketError> { + type Command = LevinCommand; + + fn decode_message( + body: &[u8], + typ: MessageType, + command: LevinCommand, + ) -> Result<Self, BucketError> { Ok(match typ { MessageType::Request => Message::Request(RequestMessage::decode(body, command)?), MessageType::Response => Message::Response(ResponseMessage::decode(body, command)?), @@ -173,7 +387,7 @@ impl LevinBody for Message { }) } - fn encode(&self, builder: &mut BucketBuilder) -> Result<(), BucketError> { + fn encode(&self, builder: &mut BucketBuilder<LevinCommand>) -> Result<(), BucketError> { match self { Message::Protocol(pro) => { builder.set_message_type(MessageType::Notification); diff --git a/net/monero-wire/src/p2p/common.rs b/net/monero-wire/src/p2p/common.rs index 8037694..e08e5ff 100644 --- a/net/monero-wire/src/p2p/common.rs +++ b/net/monero-wire/src/p2p/common.rs @@ -39,27 +39,14 @@ impl From<PeerSupportFlags> for u32 { } } -/* impl PeerSupportFlags { - const FLUFFY_BLOCKS: u32 = 0b0000_0001; - /// checks if `self` has all the flags that `other` has - pub fn contains(&self, other: &PeerSupportFlags) -> bool { - self.0. & other.0 == other.0 - } - pub fn supports_fluffy_blocks(&self) -> bool { - self.0 & Self::FLUFFY_BLOCKS == Self::FLUFFY_BLOCKS - } - pub fn get_support_flag_fluffy_blocks() -> Self { - PeerSupportFlags { - support_flags: Self::FLUFFY_BLOCKS, - } - } + //const FLUFFY_BLOCKS: u32 = 0b0000_0001; pub fn is_empty(&self) -> bool { self.0 == 0 } } -*/ + impl From<u8> for PeerSupportFlags { fn from(value: u8) -> Self { PeerSupportFlags(value.into()) diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index 594c9d7..b90ea4b 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "cuprate-peer" +name = "cuprate-p2p" version = "0.1.0" edition = "2021" license = "AGPL-3.0-only" @@ -12,8 +12,12 @@ thiserror = "1.0.39" cuprate-common = {path = "../common"} monero-wire = {path= "../net/monero-wire"} futures = "0.3.26" -tower = {version = "0.4.13", features = ["util", "steer"]} -tokio = {version= "1.27", features=["rt", "time"]} +tower = {version = "0.4.13", features = ["util", "steer", "load", "discover", "load-shed", "buffer", "timeout"]} +tokio = {version= "1.27", features=["rt", "time", "net"]} +tokio-util = {version = "0.7.8", features=["codec"]} +tokio-stream = {version="0.1.14", features=["time"]} async-trait = "0.1.68" tracing = "0.1.37" -rand = "0.8.5" \ No newline at end of file +tracing-error = "0.2.0" +rand = "0.8.5" +pin-project = "1.0.12" diff --git a/p2p/monero-peer/Cargo.toml b/p2p/monero-peer/Cargo.toml new file mode 100644 index 0000000..03ef392 --- /dev/null +++ b/p2p/monero-peer/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "monero-peer" +version = "0.1.0" +edition = "2021" + +[features] +default = [] + +[dependencies] +monero-wire = {path= "../../net/monero-wire"} +cuprate-common = {path = "../../common"} + +tokio = {version= "1.34.0", default-features = false, features = ["net"]} +tokio-util = { version = "0.7.10", default-features = false, features = ["codec"] } +futures = "0.3.29" +async-trait = "0.1.74" + +tower = { version= "0.4.13", features = ["util"] } +thiserror = "1.0.50" + +tracing = "0.1.40" + +[dev-dependencies] +cuprate-test-utils = {path = "../../test-utils"} + +hex = "0.4.3" +tokio = {version= "1.34.0", default-features = false, features = ["net", "rt-multi-thread", "rt", "macros"]} +tracing-subscriber = "0.3" diff --git a/p2p/monero-peer/src/client.rs b/p2p/monero-peer/src/client.rs new file mode 100644 index 0000000..87d63a3 --- /dev/null +++ b/p2p/monero-peer/src/client.rs @@ -0,0 +1,6 @@ +mod conector; +mod connection; +pub mod handshaker; + +pub use conector::{ConnectRequest, Connector}; +pub use handshaker::{DoHandshakeRequest, HandShaker, HandshakeError}; diff --git a/p2p/monero-peer/src/client/conector.rs b/p2p/monero-peer/src/client/conector.rs new file mode 100644 index 0000000..2a0e151 --- /dev/null +++ b/p2p/monero-peer/src/client/conector.rs @@ -0,0 +1,61 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::FutureExt; +use tower::{Service, ServiceExt}; + +use crate::{ + client::{DoHandshakeRequest, HandShaker, HandshakeError}, + AddressBook, ConnectionDirection, CoreSyncSvc, NetworkZone, PeerRequestHandler, +}; + +pub struct ConnectRequest<Z: NetworkZone> { + pub addr: Z::Addr, +} + +pub struct Connector<Z: NetworkZone, AdrBook, CSync, ReqHdlr> { + handshaker: HandShaker<Z, AdrBook, CSync, ReqHdlr>, +} + +impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> Connector<Z, AdrBook, CSync, ReqHdlr> { + pub fn new(handshaker: HandShaker<Z, AdrBook, CSync, ReqHdlr>) -> Self { + Self { handshaker } + } +} + +impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> Service<ConnectRequest<Z>> + for Connector<Z, AdrBook, CSync, ReqHdlr> +where + AdrBook: AddressBook<Z> + Clone, + CSync: CoreSyncSvc + Clone, + ReqHdlr: PeerRequestHandler + Clone, +{ + type Response = (); + type Error = HandshakeError; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: ConnectRequest<Z>) -> Self::Future { + tracing::debug!("Connecting to peer: {}", req.addr); + let mut handshaker = self.handshaker.clone(); + + async move { + let (peer_stream, peer_sink) = Z::connect_to_peer(req.addr.clone()).await?; + let req = DoHandshakeRequest { + addr: req.addr, + peer_stream, + peer_sink, + direction: ConnectionDirection::OutBound, + }; + handshaker.ready().await?.call(req).await + } + .boxed() + } +} diff --git a/p2p/monero-peer/src/client/connection.rs b/p2p/monero-peer/src/client/connection.rs new file mode 100644 index 0000000..e8de44d --- /dev/null +++ b/p2p/monero-peer/src/client/connection.rs @@ -0,0 +1,176 @@ +use futures::{ + channel::{mpsc, oneshot}, + stream::FusedStream, + SinkExt, StreamExt, +}; + +use monero_wire::{LevinCommand, Message}; + +use crate::{MessageID, NetworkZone, PeerError, PeerRequest, PeerRequestHandler, PeerResponse}; + +pub struct ConnectionTaskRequest { + request: PeerRequest, + response_channel: oneshot::Sender<Result<PeerResponse, PeerError>>, +} + +pub enum State { + WaitingForRequest, + WaitingForResponse { + request_id: MessageID, + tx: oneshot::Sender<Result<PeerResponse, PeerError>>, + }, +} + +impl State { + /// Returns if the [`LevinCommand`] is the correct response message for our request. + /// + /// e.g that we didn't get a block for a txs request. + fn levin_command_response(&self, command: LevinCommand) -> bool { + match self { + State::WaitingForResponse { request_id, .. } => matches!( + (request_id, command), + (MessageID::Handshake, LevinCommand::Handshake) + | (MessageID::TimedSync, LevinCommand::TimedSync) + | (MessageID::Ping, LevinCommand::Ping) + | (MessageID::SupportFlags, LevinCommand::SupportFlags) + | (MessageID::GetObjects, LevinCommand::GetObjectsResponse) + | (MessageID::GetChain, LevinCommand::ChainResponse) + | (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock) + | ( + MessageID::GetTxPoolCompliment, + LevinCommand::NewTransactions + ) + ), + _ => false, + } + } +} + +pub struct Connection<Z: NetworkZone, ReqHndlr> { + peer_sink: Z::Sink, + + state: State, + client_rx: mpsc::Receiver<ConnectionTaskRequest>, + + peer_request_handler: ReqHndlr, +} + +impl<Z: NetworkZone, ReqHndlr> Connection<Z, ReqHndlr> +where + ReqHndlr: PeerRequestHandler, +{ + pub fn new( + peer_sink: Z::Sink, + client_rx: mpsc::Receiver<ConnectionTaskRequest>, + + peer_request_handler: ReqHndlr, + ) -> Connection<Z, ReqHndlr> { + Connection { + peer_sink, + state: State::WaitingForRequest, + client_rx, + peer_request_handler, + } + } + + async fn handle_response(&mut self, res: PeerResponse) -> Result<(), PeerError> { + let state = std::mem::replace(&mut self.state, State::WaitingForRequest); + if let State::WaitingForResponse { request_id, tx } = state { + if request_id != res.id() { + // TODO: Fail here + return Err(PeerError::PeerSentIncorrectResponse); + } + + // TODO: do more tests here + + // response passed our tests we can send it to the requester + let _ = tx.send(Ok(res)); + Ok(()) + } else { + unreachable!("This will only be called when in state WaitingForResponse"); + } + } + + async fn send_message_to_peer(&mut self, mes: impl Into<Message>) -> Result<(), PeerError> { + Ok(self.peer_sink.send(mes.into()).await?) + } + + async fn handle_peer_request(&mut self, _req: PeerRequest) -> Result<(), PeerError> { + // we should check contents of peer requests for obvious errors like we do with responses + todo!() + /* + let ready_svc = self.svc.ready().await?; + let res = ready_svc.call(req).await?; + self.send_message_to_peer(res).await + */ + } + + async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> { + if req.request.needs_response() { + self.state = State::WaitingForResponse { + request_id: req.request.id(), + tx: req.response_channel, + }; + } + // TODO: send NA response to requester + self.send_message_to_peer(req.request).await + } + + async fn state_waiting_for_request<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError> + where + Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin, + { + futures::select! { + peer_message = stream.next() => { + match peer_message.expect("MessageStream will never return None") { + Ok(message) => { + self.handle_peer_request(message.try_into().map_err(|_| PeerError::ResponseError(""))?).await + }, + Err(e) => Err(e.into()), + } + }, + client_req = self.client_rx.next() => { + self.handle_client_request(client_req.ok_or(PeerError::ClientChannelClosed)?).await + }, + } + } + + async fn state_waiting_for_response<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError> + where + Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin, + { + // put a timeout on this + let peer_message = stream + .next() + .await + .expect("MessageStream will never return None")?; + + if !peer_message.is_request() && self.state.levin_command_response(peer_message.command()) { + if let Ok(res) = peer_message.try_into() { + Ok(self.handle_response(res).await?) + } else { + // im almost certain this is impossible to hit, but im not certain enough to use unreachable!() + Err(PeerError::ResponseError("Peer sent incorrect response")) + } + } else if let Ok(req) = peer_message.try_into() { + self.handle_peer_request(req).await + } else { + // this can be hit if the peer sends an incorrect response message + Err(PeerError::ResponseError("Peer sent incorrect response")) + } + } + + pub async fn run<Str>(mut self, mut stream: Str) + where + Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin, + { + loop { + let _res = match self.state { + State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await, + State::WaitingForResponse { .. } => { + self.state_waiting_for_response(&mut stream).await + } + }; + } + } +} diff --git a/p2p/monero-peer/src/client/handshaker.rs b/p2p/monero-peer/src/client/handshaker.rs new file mode 100644 index 0000000..4482cd3 --- /dev/null +++ b/p2p/monero-peer/src/client/handshaker.rs @@ -0,0 +1,494 @@ +use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{FutureExt, SinkExt, StreamExt}; +use tower::{Service, ServiceExt}; +use tracing::Instrument; + +use monero_wire::{ + admin::{ + HandshakeRequest, HandshakeResponse, PingResponse, SupportFlagsResponse, + PING_OK_RESPONSE_STATUS_TEXT, + }, + common::PeerSupportFlags, + BasicNodeData, BucketError, CoreSyncData, Message, RequestMessage, ResponseMessage, +}; + +use crate::{ + AddressBook, AddressBookRequest, AddressBookResponse, ConnectionDirection, CoreSyncDataRequest, + CoreSyncDataResponse, CoreSyncSvc, NetworkZone, PeerRequestHandler, + MAX_PEERS_IN_PEER_LIST_MESSAGE, +}; + +#[derive(Debug, thiserror::Error)] +pub enum HandshakeError { + #[error("peer has the same node ID as us")] + PeerHasSameNodeID, + #[error("peer is on a different network")] + IncorrectNetwork, + #[error("peer sent a peer list with peers from different zones")] + PeerSentIncorrectZonePeerList(#[from] crate::NetworkAddressIncorrectZone), + #[error("peer sent invalid message: {0}")] + PeerSentInvalidMessage(&'static str), + #[error("Levin bucket error: {0}")] + LevinBucketError(#[from] BucketError), + #[error("Internal service error: {0}")] + InternalSvcErr(#[from] tower::BoxError), + #[error("i/o error: {0}")] + IO(#[from] std::io::Error), +} + +pub struct DoHandshakeRequest<Z: NetworkZone> { + pub addr: Z::Addr, + pub peer_stream: Z::Stream, + pub peer_sink: Z::Sink, + pub direction: ConnectionDirection, +} + +#[derive(Debug, Clone)] +pub struct HandShaker<Z: NetworkZone, AdrBook, CSync, ReqHdlr> { + address_book: AdrBook, + core_sync_svc: CSync, + peer_request_svc: ReqHdlr, + + our_basic_node_data: BasicNodeData, + + _zone: PhantomData<Z>, +} + +impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> HandShaker<Z, AdrBook, CSync, ReqHdlr> { + pub fn new( + address_book: AdrBook, + core_sync_svc: CSync, + peer_request_svc: ReqHdlr, + + our_basic_node_data: BasicNodeData, + ) -> Self { + Self { + address_book, + core_sync_svc, + peer_request_svc, + our_basic_node_data, + _zone: PhantomData, + } + } +} + +impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> Service<DoHandshakeRequest<Z>> + for HandShaker<Z, AdrBook, CSync, ReqHdlr> +where + AdrBook: AddressBook<Z> + Clone, + CSync: CoreSyncSvc + Clone, + ReqHdlr: PeerRequestHandler + Clone, +{ + type Response = (); + type Error = HandshakeError; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: DoHandshakeRequest<Z>) -> Self::Future { + let DoHandshakeRequest { + addr, + peer_stream, + peer_sink, + direction, + } = req; + + let address_book = self.address_book.clone(); + let peer_request_svc = self.peer_request_svc.clone(); + let core_sync_svc = self.core_sync_svc.clone(); + let our_basic_node_data = self.our_basic_node_data.clone(); + + let span = tracing::info_span!(parent: &tracing::Span::current(), "handshaker", %addr); + + let state_machine = HandshakeStateMachine::<Z, _, _, _> { + addr, + peer_stream, + peer_sink, + direction, + address_book, + core_sync_svc, + peer_request_svc, + our_basic_node_data, + state: HandshakeState::Start, + eager_protocol_messages: vec![], + }; + + async move { + // TODO: timeouts + state_machine.do_handshake().await + } + .instrument(span) + .boxed() + } +} + +/// The states a handshake can be in. +#[derive(Debug, Clone, Eq, PartialEq)] +enum HandshakeState { + /// The initial state. + /// + /// If this is an inbound handshake then this state means we + /// are waiting for a [`HandshakeRequest`]. + Start, + /// Waiting for a [`HandshakeResponse`]. + WaitingForHandshakeResponse, + /// Waiting for a [`SupportFlagsResponse`] + /// This contains the peers node data. + WaitingForSupportFlagResponse(BasicNodeData, CoreSyncData), + /// The handshake is complete. + /// This contains the peers node data. + Complete(BasicNodeData, CoreSyncData), + /// An invalid state, the handshake SM should not be in this state. + Invalid, +} + +impl HandshakeState { + /// Returns true if the handshake is completed. + pub fn is_complete(&self) -> bool { + matches!(self, Self::Complete(..)) + } + + /// returns the peers [`BasicNodeData`] and [`CoreSyncData`] if the peer + /// is in state [`HandshakeState::Complete`]. + pub fn peer_data(self) -> Option<(BasicNodeData, CoreSyncData)> { + match self { + HandshakeState::Complete(bnd, coresync) => Some((bnd, coresync)), + _ => None, + } + } +} + +struct HandshakeStateMachine<Z: NetworkZone, AdrBook, CSync, ReqHdlr> { + addr: Z::Addr, + + peer_stream: Z::Stream, + peer_sink: Z::Sink, + + direction: ConnectionDirection, + + address_book: AdrBook, + core_sync_svc: CSync, + peer_request_svc: ReqHdlr, + + our_basic_node_data: BasicNodeData, + + state: HandshakeState, + + /// Monero allows protocol messages to be sent before a handshake response, so we have to + /// keep track of them here. For saftey we only keep a Max of 2 messages. + eager_protocol_messages: Vec<monero_wire::ProtocolMessage>, +} + +impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> HandshakeStateMachine<Z, AdrBook, CSync, ReqHdlr> +where + AdrBook: AddressBook<Z>, + CSync: CoreSyncSvc, + ReqHdlr: PeerRequestHandler, +{ + async fn send_handshake_request(&mut self) -> Result<(), HandshakeError> { + let CoreSyncDataResponse::Ours(our_core_sync_data) = self + .core_sync_svc + .ready() + .await? + .call(CoreSyncDataRequest::Ours) + .await? + else { + panic!("core sync service returned wrong response!"); + }; + + let req = HandshakeRequest { + node_data: self.our_basic_node_data.clone(), + payload_data: our_core_sync_data, + }; + + tracing::debug!("Sending handshake request."); + + self.peer_sink + .send(Message::Request(RequestMessage::Handshake(req))) + .await?; + + Ok(()) + } + + async fn send_handshake_response(&mut self) -> Result<(), HandshakeError> { + let CoreSyncDataResponse::Ours(our_core_sync_data) = self + .core_sync_svc + .ready() + .await? + .call(CoreSyncDataRequest::Ours) + .await? + else { + panic!("core sync service returned wrong response!"); + }; + + let AddressBookResponse::Peers(our_peer_list) = self + .address_book + .ready() + .await? + .call(AddressBookRequest::GetPeers(MAX_PEERS_IN_PEER_LIST_MESSAGE)) + .await? + else { + panic!("Address book sent incorrect response"); + }; + + let res = HandshakeResponse { + node_data: self.our_basic_node_data.clone(), + payload_data: our_core_sync_data, + local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(), + }; + + tracing::debug!("Sending handshake response."); + + self.peer_sink + .send(Message::Response(ResponseMessage::Handshake(res))) + .await?; + + Ok(()) + } + + async fn send_support_flags(&mut self) -> Result<(), HandshakeError> { + let res = SupportFlagsResponse { + support_flags: self.our_basic_node_data.support_flags, + }; + + tracing::debug!("Sending support flag response."); + + self.peer_sink + .send(Message::Response(ResponseMessage::SupportFlags(res))) + .await?; + + Ok(()) + } + + async fn check_request_support_flags( + &mut self, + support_flags: &PeerSupportFlags, + ) -> Result<bool, HandshakeError> { + Ok(if support_flags.is_empty() { + tracing::debug!( + "Peer didn't send support flags or has no features, sending request to make sure." + ); + self.peer_sink + .send(Message::Request(RequestMessage::SupportFlags)) + .await?; + true + } else { + false + }) + } + + async fn handle_handshake_response( + &mut self, + response: HandshakeResponse, + ) -> Result<(), HandshakeError> { + if response.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE { + tracing::debug!("peer sent too many peers in response, cancelling handshake"); + + return Err(HandshakeError::PeerSentInvalidMessage( + "Too many peers in peer list message (>250)", + )); + } + + if response.node_data.network_id != self.our_basic_node_data.network_id { + return Err(HandshakeError::IncorrectNetwork); + } + + if Z::CHECK_NODE_ID && response.node_data.peer_id == self.our_basic_node_data.peer_id { + return Err(HandshakeError::PeerHasSameNodeID); + } + + tracing::debug!( + "Telling address book about new peers, len: {}", + response.local_peerlist_new.len() + ); + + self.address_book + .ready() + .await? + .call(AddressBookRequest::IncomingPeerList( + response + .local_peerlist_new + .into_iter() + .map(TryInto::try_into) + .collect::<Result<_, _>>()?, + )) + .await?; + + if self + .check_request_support_flags(&response.node_data.support_flags) + .await? + { + self.state = HandshakeState::WaitingForSupportFlagResponse( + response.node_data, + response.payload_data, + ); + } else { + self.state = HandshakeState::Complete(response.node_data, response.payload_data); + } + + Ok(()) + } + + async fn handle_handshake_request( + &mut self, + request: HandshakeRequest, + ) -> Result<(), HandshakeError> { + // We don't respond here as if we did the other peer could accept the handshake before responding to a + // support flag request which then means we could recive other requests while waiting for the support + // flags. + + if request.node_data.network_id != self.our_basic_node_data.network_id { + return Err(HandshakeError::IncorrectNetwork); + } + + if Z::CHECK_NODE_ID && request.node_data.peer_id == self.our_basic_node_data.peer_id { + return Err(HandshakeError::PeerHasSameNodeID); + } + + if self + .check_request_support_flags(&request.node_data.support_flags) + .await? + { + self.state = HandshakeState::WaitingForSupportFlagResponse( + request.node_data, + request.payload_data, + ); + } else { + self.state = HandshakeState::Complete(request.node_data, request.payload_data); + } + + Ok(()) + } + + async fn handle_incoming_message(&mut self, message: Message) -> Result<(), HandshakeError> { + tracing::debug!("Received message from peer: {}", message.command()); + + if let Message::Protocol(protocol_message) = message { + if self.eager_protocol_messages.len() == 2 { + tracing::debug!("Peer sent too many protocl messages before a handshake response."); + return Err(HandshakeError::PeerSentInvalidMessage( + "Peer sent too many protocol messages", + )); + } + tracing::debug!( + "Protocol message getting added to queue for when handshake is complete." + ); + self.eager_protocol_messages.push(protocol_message); + return Ok(()); + } + + match std::mem::replace(&mut self.state, HandshakeState::Invalid) { + HandshakeState::Start => match message { + Message::Request(RequestMessage::Ping) => { + // Set the state back to what it was before. + self.state = HandshakeState::Start; + Ok(self + .peer_sink + .send(Message::Response(ResponseMessage::Ping(PingResponse { + status: PING_OK_RESPONSE_STATUS_TEXT.to_string(), + peer_id: self.our_basic_node_data.peer_id, + }))) + .await?) + } + Message::Request(RequestMessage::Handshake(handshake_req)) => { + self.handle_handshake_request(handshake_req).await + } + _ => Err(HandshakeError::PeerSentInvalidMessage( + "Peer didn't send handshake request.", + )), + }, + HandshakeState::WaitingForHandshakeResponse => match message { + // TODO: only allow 1 support flag request. + Message::Request(RequestMessage::SupportFlags) => { + // Set the state back to what it was before. + self.state = HandshakeState::WaitingForHandshakeResponse; + self.send_support_flags().await + } + Message::Response(ResponseMessage::Handshake(res)) => { + self.handle_handshake_response(res).await + } + _ => Err(HandshakeError::PeerSentInvalidMessage( + "Peer didn't send handshake response.", + )), + }, + HandshakeState::WaitingForSupportFlagResponse(mut peer_node_data, peer_core_sync) => { + let Message::Response(ResponseMessage::SupportFlags(support_flags)) = message + else { + return Err(HandshakeError::PeerSentInvalidMessage( + "Peer didn't send support flags response.", + )); + }; + peer_node_data.support_flags = support_flags.support_flags; + self.state = HandshakeState::Complete(peer_node_data, peer_core_sync); + Ok(()) + } + HandshakeState::Complete(..) => { + panic!("Handshake is complete messages should no longer be handled here!") + } + HandshakeState::Invalid => panic!("Handshake state machine stayed in invalid state!"), + } + } + + async fn advance_machine(&mut self) -> Result<(), HandshakeError> { + while !self.state.is_complete() { + tracing::debug!("Waiting for message from peer."); + + match self.peer_stream.next().await { + Some(message) => self.handle_incoming_message(message?).await?, + None => Err(BucketError::IO(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "The peer stream returned None", + )))?, + } + } + + Ok(()) + } + + async fn do_outbound_handshake(&mut self) -> Result<(), HandshakeError> { + self.send_handshake_request().await?; + self.state = HandshakeState::WaitingForHandshakeResponse; + + self.advance_machine().await + } + + async fn do_inbound_handshake(&mut self) -> Result<(), HandshakeError> { + self.advance_machine().await?; + + debug_assert!(self.state.is_complete()); + + self.send_handshake_response().await + } + + async fn do_handshake(mut self) -> Result<(), HandshakeError> { + tracing::debug!("Beginning handshake."); + + match self.direction { + ConnectionDirection::OutBound => self.do_outbound_handshake().await?, + ConnectionDirection::InBound => self.do_inbound_handshake().await?, + } + + let HandshakeState::Complete(peer_node_data, peer_core_sync) = self.state else { + panic!("Hanshake completed not in complete state!"); + }; + + self.core_sync_svc + .ready() + .await? + .call(CoreSyncDataRequest::HandleIncoming(peer_core_sync)) + .await?; + + tracing::debug!("Handshake complete."); + + Ok(()) + } +} diff --git a/p2p/monero-peer/src/error.rs b/p2p/monero-peer/src/error.rs new file mode 100644 index 0000000..046a659 --- /dev/null +++ b/p2p/monero-peer/src/error.rs @@ -0,0 +1,15 @@ +#[derive(Debug, thiserror::Error)] +pub enum PeerError { + #[error("The connection tasks client channel was closed")] + ClientChannelClosed, + #[error("error with peer response: {0}")] + ResponseError(&'static str), + #[error("the peer sent an incorrect response to our request")] + PeerSentIncorrectResponse, + #[error("bucket error")] + BucketError(#[from] monero_wire::BucketError), + #[error("handshake error: {0}")] + Handshake(#[from] crate::client::HandshakeError), + #[error("i/o error: {0}")] + IO(#[from] std::io::Error), +} diff --git a/p2p/monero-peer/src/lib.rs b/p2p/monero-peer/src/lib.rs new file mode 100644 index 0000000..25fe966 --- /dev/null +++ b/p2p/monero-peer/src/lib.rs @@ -0,0 +1,157 @@ +#![allow(unused)] + +use std::{future::Future, pin::Pin}; + +use futures::{Sink, Stream}; + +use monero_wire::{ + network_address::NetworkAddressIncorrectZone, BucketError, Message, NetworkAddress, +}; + +pub mod client; +pub mod error; +pub mod network_zones; +pub mod protocol; +pub mod services; + +pub use error::*; +pub use protocol::*; +use services::*; + +const MAX_PEERS_IN_PEER_LIST_MESSAGE: usize = 250; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ConnectionDirection { + InBound, + OutBound, +} + +/// An abstraction over a network zone (tor/i2p/clear) +#[async_trait::async_trait] +pub trait NetworkZone: Clone + Send + 'static { + /// Allow syncing over this network. + /// + /// Not recommended for anonymity networks. + const ALLOW_SYNC: bool; + /// Enable dandelion++ for this network. + /// + /// This is unneeded on anonymity networks. + const DANDELION_PP: bool; + /// Check if our node ID matches the incoming peers node ID for this network. + /// + /// This has privacy implications on an anonymity network if true so should be set + /// to false. + const CHECK_NODE_ID: bool; + + /// The address type of this network. + type Addr: TryFrom<NetworkAddress, Error = NetworkAddressIncorrectZone> + + Into<NetworkAddress> + + std::fmt::Display + + Clone + + Send + + 'static; + /// The stream (incoming data) type for this network. + type Stream: Stream<Item = Result<Message, BucketError>> + Unpin + Send + 'static; + /// The sink (outgoing data) type for this network. + type Sink: Sink<Message, Error = BucketError> + Unpin + Send + 'static; + /// Config used to start a server which listens for incoming connections. + type ServerCfg; + + async fn connect_to_peer( + addr: Self::Addr, + ) -> Result<(Self::Stream, Self::Sink), std::io::Error>; + + async fn incoming_connection_listener(config: Self::ServerCfg) -> (); +} + +pub(crate) trait AddressBook<Z: NetworkZone>: + tower::Service< + AddressBookRequest<Z>, + Response = AddressBookResponse<Z>, + Error = tower::BoxError, + Future = Pin< + Box< + dyn Future<Output = Result<AddressBookResponse<Z>, tower::BoxError>> + + Send + + 'static, + >, + >, + > + Send + + 'static +{ +} + +impl<T, Z: NetworkZone> AddressBook<Z> for T where + T: tower::Service< + AddressBookRequest<Z>, + Response = AddressBookResponse<Z>, + Error = tower::BoxError, + Future = Pin< + Box< + dyn Future<Output = Result<AddressBookResponse<Z>, tower::BoxError>> + + Send + + 'static, + >, + >, + > + Send + + 'static +{ +} + +pub(crate) trait CoreSyncSvc: + tower::Service< + CoreSyncDataRequest, + Response = CoreSyncDataResponse, + Error = tower::BoxError, + Future = Pin< + Box< + dyn Future<Output = Result<CoreSyncDataResponse, tower::BoxError>> + Send + 'static, + >, + >, + > + Send + + 'static +{ +} + +impl<T> CoreSyncSvc for T where + T: tower::Service< + CoreSyncDataRequest, + Response = CoreSyncDataResponse, + Error = tower::BoxError, + Future = Pin< + Box< + dyn Future<Output = Result<CoreSyncDataResponse, tower::BoxError>> + + Send + + 'static, + >, + >, + > + Send + + 'static +{ +} + +pub(crate) trait PeerRequestHandler: + tower::Service< + PeerRequest, + Response = PeerResponse, + Error = tower::BoxError, + Future = Pin< + Box<dyn Future<Output = Result<PeerResponse, tower::BoxError>> + Send + 'static>, + >, + > + Send + + 'static +{ +} + +impl<T> PeerRequestHandler for T where + T: tower::Service< + PeerRequest, + Response = PeerResponse, + Error = tower::BoxError, + Future = Pin< + Box<dyn Future<Output = Result<PeerResponse, tower::BoxError>> + Send + 'static>, + >, + > + Send + + 'static +{ +} diff --git a/p2p/monero-peer/src/network_zones.rs b/p2p/monero-peer/src/network_zones.rs new file mode 100644 index 0000000..cc20402 --- /dev/null +++ b/p2p/monero-peer/src/network_zones.rs @@ -0,0 +1,3 @@ +mod clear; + +pub use clear::{ClearNet, ClearNetServerCfg}; diff --git a/p2p/monero-peer/src/network_zones/clear.rs b/p2p/monero-peer/src/network_zones/clear.rs new file mode 100644 index 0000000..cc35285 --- /dev/null +++ b/p2p/monero-peer/src/network_zones/clear.rs @@ -0,0 +1,43 @@ +use std::net::SocketAddr; + +use monero_wire::MoneroWireCodec; + +use tokio::net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, +}; +use tokio_util::codec::{FramedRead, FramedWrite}; + +use crate::NetworkZone; + +#[derive(Clone)] +pub struct ClearNet; + +pub struct ClearNetServerCfg {} + +#[async_trait::async_trait] +impl NetworkZone for ClearNet { + const ALLOW_SYNC: bool = true; + const DANDELION_PP: bool = true; + const CHECK_NODE_ID: bool = true; + + type Addr = SocketAddr; + type Stream = FramedRead<OwnedReadHalf, MoneroWireCodec>; + type Sink = FramedWrite<OwnedWriteHalf, MoneroWireCodec>; + + type ServerCfg = (); + + async fn connect_to_peer( + addr: Self::Addr, + ) -> Result<(Self::Stream, Self::Sink), std::io::Error> { + let (read, write) = TcpStream::connect(addr).await?.into_split(); + Ok(( + FramedRead::new(read, MoneroWireCodec::default()), + FramedWrite::new(write, MoneroWireCodec::default()), + )) + } + + async fn incoming_connection_listener(config: Self::ServerCfg) -> () { + todo!() + } +} diff --git a/p2p/monero-peer/src/protocol.rs b/p2p/monero-peer/src/protocol.rs new file mode 100644 index 0000000..56edd81 --- /dev/null +++ b/p2p/monero-peer/src/protocol.rs @@ -0,0 +1,130 @@ +/// This module defines InternalRequests and InternalResponses. Cuprate's P2P works by translating network messages into an internal +/// request/ response, this is easy for levin "requests" and "responses" (admin messages) but takes a bit more work with "notifications" +/// (protocol messages). +/// +/// Some notifications are easy to translate, like `GetObjectsRequest` is obviously a request but others like `NewFluffyBlock` are a +/// bit tricker. To translate a `NewFluffyBlock` into a request/ response we will have to look to see if we asked for `FluffyMissingTransactionsRequest` +/// if we have we interpret `NewFluffyBlock` as a response if not its a request that doesn't require a response. +/// +/// Here is every P2P request/ response. *note admin messages are already request/ response so "Handshake" is actually made of a HandshakeRequest & HandshakeResponse +/// +/// Admin: +/// Handshake, +/// TimedSync, +/// Ping, +/// SupportFlags +/// Protocol: +/// Request: GetObjectsRequest, Response: GetObjectsResponse, +/// Request: ChainRequest, Response: ChainResponse, +/// Request: FluffyMissingTransactionsRequest, Response: NewFluffyBlock, <- these 2 could be requests or responses +/// Request: GetTxPoolCompliment, Response: NewTransactions, <- +/// Request: NewBlock, Response: None, +/// Request: NewFluffyBlock, Response: None, +/// Request: NewTransactions, Response: None +/// +/// +use monero_wire::{ + admin::{ + HandshakeRequest, HandshakeResponse, PingResponse, SupportFlagsResponse, TimedSyncRequest, + TimedSyncResponse, + }, + protocol::{ + ChainRequest, ChainResponse, FluffyMissingTransactionsRequest, GetObjectsRequest, + GetObjectsResponse, GetTxPoolCompliment, NewBlock, NewFluffyBlock, NewTransactions, + }, +}; + +mod try_from; + +/// An enum representing a request/ response combination, so a handshake request +/// and response would have the same [`MessageID`]. This allows associating the +/// correct response to a request. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub enum MessageID { + Handshake, + TimedSync, + Ping, + SupportFlags, + + GetObjects, + GetChain, + FluffyMissingTxs, + GetTxPoolCompliment, + NewBlock, + NewFluffyBlock, + NewTransactions, +} + +pub enum PeerRequest { + Handshake(HandshakeRequest), + TimedSync(TimedSyncRequest), + Ping, + SupportFlags, + + GetObjects(GetObjectsRequest), + GetChain(ChainRequest), + FluffyMissingTxs(FluffyMissingTransactionsRequest), + GetTxPoolCompliment(GetTxPoolCompliment), + NewBlock(NewBlock), + NewFluffyBlock(NewFluffyBlock), + NewTransactions(NewTransactions), +} + +impl PeerRequest { + pub fn id(&self) -> MessageID { + match self { + PeerRequest::Handshake(_) => MessageID::Handshake, + PeerRequest::TimedSync(_) => MessageID::TimedSync, + PeerRequest::Ping => MessageID::Ping, + PeerRequest::SupportFlags => MessageID::SupportFlags, + + PeerRequest::GetObjects(_) => MessageID::GetObjects, + PeerRequest::GetChain(_) => MessageID::GetChain, + PeerRequest::FluffyMissingTxs(_) => MessageID::FluffyMissingTxs, + PeerRequest::GetTxPoolCompliment(_) => MessageID::GetTxPoolCompliment, + PeerRequest::NewBlock(_) => MessageID::NewBlock, + PeerRequest::NewFluffyBlock(_) => MessageID::NewFluffyBlock, + PeerRequest::NewTransactions(_) => MessageID::NewTransactions, + } + } + + pub fn needs_response(&self) -> bool { + !matches!( + self, + PeerRequest::NewBlock(_) + | PeerRequest::NewFluffyBlock(_) + | PeerRequest::NewTransactions(_) + ) + } +} + +pub enum PeerResponse { + Handshake(HandshakeResponse), + TimedSync(TimedSyncResponse), + Ping(PingResponse), + SupportFlags(SupportFlagsResponse), + + GetObjects(GetObjectsResponse), + GetChain(ChainResponse), + NewFluffyBlock(NewFluffyBlock), + NewTransactions(NewTransactions), + NA, +} + +impl PeerResponse { + pub fn id(&self) -> MessageID { + match self { + PeerResponse::Handshake(_) => MessageID::Handshake, + PeerResponse::TimedSync(_) => MessageID::TimedSync, + PeerResponse::Ping(_) => MessageID::Ping, + PeerResponse::SupportFlags(_) => MessageID::SupportFlags, + + PeerResponse::GetObjects(_) => MessageID::GetObjects, + PeerResponse::GetChain(_) => MessageID::GetChain, + PeerResponse::NewFluffyBlock(_) => MessageID::NewBlock, + PeerResponse::NewTransactions(_) => MessageID::NewFluffyBlock, + + PeerResponse::NA => panic!("Can't get message ID for a non existent response"), + } + } +} diff --git a/p2p/monero-peer/src/protocol/try_from.rs b/p2p/monero-peer/src/protocol/try_from.rs new file mode 100644 index 0000000..4e4ebdb --- /dev/null +++ b/p2p/monero-peer/src/protocol/try_from.rs @@ -0,0 +1,179 @@ +//! This module contains the implementations of [`TryFrom`] and [`From`] to convert between +//! [`Message`], [`PeerRequest`] and [`PeerResponse`]. + +use monero_wire::{Message, ProtocolMessage, RequestMessage, ResponseMessage}; + +use super::{PeerRequest, PeerResponse}; + +pub struct MessageConversionError; + +macro_rules! match_body { + (match $value: ident {$($body:tt)*} ($left:pat => $right_ty:expr) $($todo:tt)*) => { + match_body!( match $value { + $left => $right_ty, + $($body)* + } $($todo)* ) + }; + (match $value: ident {$($body:tt)*}) => { + match $value { + $($body)* + } + }; +} + +macro_rules! from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + impl From<$left_ty> for $right_ty { + fn from(value: $left_ty) -> Self { + match_body!( match value {} + $(($left_ty::$left$(($val))? => $right_ty::$right$(($vall))?))+ + ) + } + } + }; +} + +macro_rules! try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + impl TryFrom<$left_ty> for $right_ty { + type Error = MessageConversionError; + + fn try_from(value: $left_ty) -> Result<Self, Self::Error> { + Ok(match_body!( match value { + _ => return Err(MessageConversionError) + } + $(($left_ty::$left$(($val))? => $right_ty::$right$(($vall))?))+ + )) + } + } + }; +} + +macro_rules! from_try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + try_from!($left_ty, $right_ty, {$($left $(($val))? = $right $(($vall))?,)+}); + from!($right_ty, $left_ty, {$($right $(($val))? = $left $(($vall))?,)+}); + }; +} + +macro_rules! try_from_try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + try_from!($left_ty, $right_ty, {$($left $(($val))? = $right $(($vall))?,)+}); + try_from!($right_ty, $left_ty, {$($right $(($val))? = $left $(($val))?,)+}); + }; +} + +from_try_from!(PeerRequest, RequestMessage,{ + Handshake(val) = Handshake(val), + Ping = Ping, + SupportFlags = SupportFlags, + TimedSync(val) = TimedSync(val), +}); + +try_from_try_from!(PeerRequest, ProtocolMessage,{ + NewBlock(val) = NewBlock(val), + NewFluffyBlock(val) = NewFluffyBlock(val), + GetObjects(val) = GetObjectsRequest(val), + GetChain(val) = ChainRequest(val), + NewTransactions(val) = NewTransactions(val), + FluffyMissingTxs(val) = FluffyMissingTransactionsRequest(val), + GetTxPoolCompliment(val) = GetTxPoolCompliment(val), +}); + +impl TryFrom<Message> for PeerRequest { + type Error = MessageConversionError; + + fn try_from(value: Message) -> Result<Self, Self::Error> { + match value { + Message::Request(req) => Ok(req.into()), + Message::Protocol(pro) => pro.try_into(), + _ => Err(MessageConversionError), + } + } +} + +impl From<PeerRequest> for Message { + fn from(value: PeerRequest) -> Self { + match value { + PeerRequest::Handshake(val) => Message::Request(RequestMessage::Handshake(val)), + PeerRequest::Ping => Message::Request(RequestMessage::Ping), + PeerRequest::SupportFlags => Message::Request(RequestMessage::SupportFlags), + PeerRequest::TimedSync(val) => Message::Request(RequestMessage::TimedSync(val)), + + PeerRequest::NewBlock(val) => Message::Protocol(ProtocolMessage::NewBlock(val)), + PeerRequest::NewFluffyBlock(val) => { + Message::Protocol(ProtocolMessage::NewFluffyBlock(val)) + } + PeerRequest::GetObjects(val) => { + Message::Protocol(ProtocolMessage::GetObjectsRequest(val)) + } + PeerRequest::GetChain(val) => Message::Protocol(ProtocolMessage::ChainRequest(val)), + PeerRequest::NewTransactions(val) => { + Message::Protocol(ProtocolMessage::NewTransactions(val)) + } + PeerRequest::FluffyMissingTxs(val) => { + Message::Protocol(ProtocolMessage::FluffyMissingTransactionsRequest(val)) + } + PeerRequest::GetTxPoolCompliment(val) => { + Message::Protocol(ProtocolMessage::GetTxPoolCompliment(val)) + } + } + } +} + +from_try_from!(PeerResponse, ResponseMessage,{ + Handshake(val) = Handshake(val), + Ping(val) = Ping(val), + SupportFlags(val) = SupportFlags(val), + TimedSync(val) = TimedSync(val), +}); + +try_from_try_from!(PeerResponse, ProtocolMessage,{ + NewFluffyBlock(val) = NewFluffyBlock(val), + GetObjects(val) = GetObjectsResponse(val), + GetChain(val) = ChainEntryResponse(val), + NewTransactions(val) = NewTransactions(val), + +}); + +impl TryFrom<Message> for PeerResponse { + type Error = MessageConversionError; + + fn try_from(value: Message) -> Result<Self, Self::Error> { + match value { + Message::Response(res) => Ok(res.into()), + Message::Protocol(pro) => pro.try_into(), + _ => Err(MessageConversionError), + } + } +} + +impl TryFrom<PeerResponse> for Message { + type Error = MessageConversionError; + + fn try_from(value: PeerResponse) -> Result<Self, Self::Error> { + Ok(match value { + PeerResponse::Handshake(val) => Message::Response(ResponseMessage::Handshake(val)), + PeerResponse::Ping(val) => Message::Response(ResponseMessage::Ping(val)), + PeerResponse::SupportFlags(val) => { + Message::Response(ResponseMessage::SupportFlags(val)) + } + PeerResponse::TimedSync(val) => Message::Response(ResponseMessage::TimedSync(val)), + + PeerResponse::NewFluffyBlock(val) => { + Message::Protocol(ProtocolMessage::NewFluffyBlock(val)) + } + PeerResponse::GetObjects(val) => { + Message::Protocol(ProtocolMessage::GetObjectsResponse(val)) + } + PeerResponse::GetChain(val) => { + Message::Protocol(ProtocolMessage::ChainEntryResponse(val)) + } + PeerResponse::NewTransactions(val) => { + Message::Protocol(ProtocolMessage::NewTransactions(val)) + } + + PeerResponse::NA => return Err(MessageConversionError), + }) + } +} diff --git a/p2p/monero-peer/src/services.rs b/p2p/monero-peer/src/services.rs new file mode 100644 index 0000000..db1187f --- /dev/null +++ b/p2p/monero-peer/src/services.rs @@ -0,0 +1,61 @@ +use monero_wire::PeerListEntryBase; + +use crate::{NetworkAddressIncorrectZone, NetworkZone}; + +pub enum CoreSyncDataRequest { + Ours, + HandleIncoming(monero_wire::CoreSyncData), +} + +pub enum CoreSyncDataResponse { + Ours(monero_wire::CoreSyncData), + Ok, +} + +pub struct ZoneSpecificPeerListEntryBase<Z: NetworkZone> { + pub adr: Z::Addr, + pub id: u64, + pub last_seen: i64, + pub pruning_seed: u32, + pub rpc_port: u16, + pub rpc_credits_per_hash: u32, +} + +impl<Z: NetworkZone> From<ZoneSpecificPeerListEntryBase<Z>> for monero_wire::PeerListEntryBase { + fn from(value: ZoneSpecificPeerListEntryBase<Z>) -> Self { + Self { + adr: value.adr.into(), + id: value.id, + last_seen: value.last_seen, + pruning_seed: value.pruning_seed, + rpc_port: value.rpc_port, + rpc_credits_per_hash: value.rpc_credits_per_hash, + } + } +} + +impl<Z: NetworkZone> TryFrom<monero_wire::PeerListEntryBase> for ZoneSpecificPeerListEntryBase<Z> { + type Error = NetworkAddressIncorrectZone; + + fn try_from(value: PeerListEntryBase) -> Result<Self, Self::Error> { + Ok(Self { + adr: value.adr.try_into()?, + id: value.id, + last_seen: value.last_seen, + pruning_seed: value.pruning_seed, + rpc_port: value.rpc_port, + rpc_credits_per_hash: value.rpc_credits_per_hash, + }) + } +} + +pub enum AddressBookRequest<Z: NetworkZone> { + NewConnection(Z::Addr, ZoneSpecificPeerListEntryBase<Z>), + IncomingPeerList(Vec<ZoneSpecificPeerListEntryBase<Z>>), + GetPeers(usize), +} + +pub enum AddressBookResponse<Z: NetworkZone> { + Ok, + Peers(Vec<ZoneSpecificPeerListEntryBase<Z>>), +} diff --git a/p2p/monero-peer/tests/handshake.rs b/p2p/monero-peer/tests/handshake.rs new file mode 100644 index 0000000..846ae86 --- /dev/null +++ b/p2p/monero-peer/tests/handshake.rs @@ -0,0 +1,125 @@ +use std::{net::SocketAddr, str::FromStr}; + +use futures::{channel::mpsc, StreamExt}; +use tower::{Service, ServiceExt}; + +use cuprate_common::Network; +use monero_wire::{common::PeerSupportFlags, BasicNodeData}; + +use monero_peer::{ + client::{ConnectRequest, Connector, DoHandshakeRequest, HandShaker}, + network_zones::ClearNet, + ConnectionDirection, +}; + +use cuprate_test_utils::test_netzone::{TestNetZone, TestNetZoneAddr}; + +mod utils; +use utils::*; + +#[tokio::test] +async fn handshake_cuprate_to_cuprate() { + // Tests a Cuprate <-> Cuprate handshake by making 2 handshake services and making them talk to + // each other. + + let our_basic_node_data_1 = BasicNodeData { + my_port: 0, + network_id: Network::Mainnet.network_id(), + peer_id: 87980, + // TODO: This fails if the support flags are empty (0) + support_flags: PeerSupportFlags::from(1_u32), + rpc_port: 0, + rpc_credits_per_hash: 0, + }; + // make sure both node IDs are different + let mut our_basic_node_data_2 = our_basic_node_data_1.clone(); + our_basic_node_data_2.peer_id = 2344; + + let mut handshaker_1 = HandShaker::<TestNetZone<true, true, true>, _, _, _>::new( + DummyAddressBook, + DummyCoreSyncSvc, + DummyPeerRequestHandlerSvc, + our_basic_node_data_1, + ); + + let mut handshaker_2 = HandShaker::<TestNetZone<true, true, true>, _, _, _>::new( + DummyAddressBook, + DummyCoreSyncSvc, + DummyPeerRequestHandlerSvc, + our_basic_node_data_2, + ); + + let (p1_sender, p2_receiver) = mpsc::channel(5); + let (p2_sender, p1_receiver) = mpsc::channel(5); + + let p1_handshake_req = DoHandshakeRequest { + addr: TestNetZoneAddr(888), + peer_stream: p2_receiver.map(Ok).boxed(), + peer_sink: p2_sender.into(), + direction: ConnectionDirection::OutBound, + }; + + let p2_handshake_req = DoHandshakeRequest { + addr: TestNetZoneAddr(444), + peer_stream: p1_receiver.boxed().map(Ok).boxed(), + peer_sink: p1_sender.into(), + direction: ConnectionDirection::InBound, + }; + + let p1 = tokio::spawn(async move { + handshaker_1 + .ready() + .await + .unwrap() + .call(p1_handshake_req) + .await + .unwrap() + }); + + let p2 = tokio::spawn(async move { + handshaker_2 + .ready() + .await + .unwrap() + .call(p2_handshake_req) + .await + .unwrap() + }); + + let (res1, res2) = futures::join!(p1, p2); + res1.unwrap(); + res2.unwrap(); +} + +#[tokio::test] +async fn handshake() { + let addr = "127.0.0.1:18080"; + + let our_basic_node_data = BasicNodeData { + my_port: 0, + network_id: Network::Mainnet.network_id(), + peer_id: 87980, + support_flags: PeerSupportFlags::from(1_u32), + rpc_port: 0, + rpc_credits_per_hash: 0, + }; + + let handshaker = HandShaker::<ClearNet, _, _, _>::new( + DummyAddressBook, + DummyCoreSyncSvc, + DummyPeerRequestHandlerSvc, + our_basic_node_data, + ); + + let mut connector = Connector::new(handshaker); + + connector + .ready() + .await + .unwrap() + .call(ConnectRequest { + addr: SocketAddr::from_str(addr).unwrap(), + }) + .await + .unwrap(); +} diff --git a/p2p/monero-peer/tests/utils.rs b/p2p/monero-peer/tests/utils.rs new file mode 100644 index 0000000..6ff1cd4 --- /dev/null +++ b/p2p/monero-peer/tests/utils.rs @@ -0,0 +1,95 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::FutureExt; +use tower::Service; + +use monero_peer::{ + services::{ + AddressBookRequest, AddressBookResponse, CoreSyncDataRequest, CoreSyncDataResponse, + }, + NetworkZone, PeerRequest, PeerResponse, +}; + +#[derive(Clone)] +pub struct DummyAddressBook; + +impl<Z: NetworkZone> Service<AddressBookRequest<Z>> for DummyAddressBook { + type Response = AddressBookResponse<Z>; + type Error = tower::BoxError; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: AddressBookRequest<Z>) -> Self::Future { + async move { + Ok(match req { + AddressBookRequest::GetPeers(_) => AddressBookResponse::Peers(vec![]), + _ => AddressBookResponse::Ok, + }) + } + .boxed() + } +} + +#[derive(Clone)] +pub struct DummyCoreSyncSvc; + +impl Service<CoreSyncDataRequest> for DummyCoreSyncSvc { + type Response = CoreSyncDataResponse; + type Error = tower::BoxError; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CoreSyncDataRequest) -> Self::Future { + async move { + match req { + CoreSyncDataRequest::Ours => { + Ok(CoreSyncDataResponse::Ours(monero_wire::CoreSyncData { + cumulative_difficulty: 1, + cumulative_difficulty_top64: 0, + current_height: 1, + pruning_seed: 0, + top_id: hex::decode( + "418015bb9ae982a1975da7d79277c2705727a56894ba0fb246adaabb1f4632e3", + ) + .unwrap() + .try_into() + .unwrap(), + top_version: 1, + })) + } + CoreSyncDataRequest::HandleIncoming(_) => Ok(CoreSyncDataResponse::Ok), + } + } + .boxed() + } +} + +#[derive(Clone)] +pub struct DummyPeerRequestHandlerSvc; + +impl Service<PeerRequest> for DummyPeerRequestHandlerSvc { + type Response = PeerResponse; + type Error = tower::BoxError; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + todo!() + } + + fn call(&mut self, req: PeerRequest) -> Self::Future { + todo!() + } +} diff --git a/p2p/src/address_book.rs b/p2p/src/address_book.rs index c7b465b..6865137 100644 --- a/p2p/src/address_book.rs +++ b/p2p/src/address_book.rs @@ -1,120 +1,157 @@ +//! Cuprate Address Book +//! +//! This module holds the logic for persistent peer storage. +//! Cuprates address book is modeled as a [`tower::Service`] +//! The request is [`AddressBookRequest`] and the response is +//! [`AddressBookResponse`]. +//! +//! Cuprate, like monerod, actually has 3 address books, one +//! for each [`NetZone`]. This is to reduce the possibility of +//! clear net peers getting linked to their dark counterparts +//! and so peers will only get told about peers they can +//! connect to. +//! + mod addr_book_client; -pub(crate) mod address_book; +mod address_book; +pub mod connection_handle; + +use cuprate_common::PruningSeed; +use monero_wire::{messages::PeerListEntryBase, network_address::NetZone, NetworkAddress, PeerID}; + +use connection_handle::ConnectionAddressBookHandle; pub use addr_book_client::start_address_book; -use monero_wire::{messages::PeerListEntryBase, network_address::NetZone, NetworkAddress}; - -const MAX_WHITE_LIST_PEERS: usize = 1000; -const MAX_GRAY_LIST_PEERS: usize = 5000; - +/// Possible errors when dealing with the address book. +/// This is boxed when returning an error in the [`tower::Service`]. #[derive(Debug, thiserror::Error)] pub enum AddressBookError { + /// The peer is not in the address book for this zone. #[error("Peer was not found in book")] PeerNotFound, + /// The peer list is empty. #[error("The peer list is empty")] PeerListEmpty, + /// The peers pruning seed has changed. + #[error("The peers pruning seed has changed")] + PeersPruningSeedChanged, + /// The peer is banned. + #[error("The peer is banned")] + PeerIsBanned, + /// When handling a received peer list, the list contains + /// a peer in a different [`NetZone`] #[error("Peer sent an address out of it's net-zone")] PeerSentAnAddressOutOfZone, + /// The channel to the address book has closed unexpectedly. #[error("The address books channel has closed.")] AddressBooksChannelClosed, + /// The address book task has exited. + #[error("The address book task has exited.")] + AddressBookTaskExited, + /// The peer file store has failed. #[error("Peer Store Error: {0}")] PeerStoreError(&'static str), } +/// A message sent to tell the address book that a peer has disconnected. +pub struct PeerConnectionClosed; + +/// A request to the address book. #[derive(Debug)] pub enum AddressBookRequest { + /// A request to handle an incoming peer list. HandleNewPeerList(Vec<PeerListEntryBase>, NetZone), - SetPeerSeen(NetworkAddress, i64), - BanPeer(NetworkAddress, chrono::NaiveDateTime), - AddPeerToAnchor(NetworkAddress), - RemovePeerFromAnchor(NetworkAddress), - UpdatePeerInfo(PeerListEntryBase), + /// Updates the `last_seen` timestamp of this peer. + SetPeerSeen(PeerID, chrono::NaiveDateTime, NetZone), + /// Bans a peer for the specified duration. This request + /// will send disconnect signals to all peers with the same + /// [`ban_identifier`](NetworkAddress::ban_identifier). + BanPeer(PeerID, std::time::Duration, NetZone), + /// Adds a peer to the connected list + ConnectedToPeer { + /// The net zone of this connection. + zone: NetZone, + /// A handle between the connection and address book. + connection_handle: ConnectionAddressBookHandle, + /// The connection addr, None if the peer is using a + /// hidden network. + addr: Option<NetworkAddress>, + /// The peers id. + id: PeerID, + /// If the peer is reachable by our node. + reachable: bool, + /// The last seen timestamp, note: Cuprate may skip updating this + /// field on some inbound messages + last_seen: chrono::NaiveDateTime, + /// The peers pruning seed + pruning_seed: PruningSeed, + /// The peers port. + rpc_port: u16, + /// The peers rpc credits per hash + rpc_credits_per_hash: u32, + }, - GetRandomGrayPeer(NetZone), - GetRandomWhitePeer(NetZone), + /// A request to get and eempty the anchor list, + /// used when starting the node. + GetAndEmptyAnchorList(NetZone), + /// Get a random Gray peer from the peer list + /// If a pruning seed is given we will select from + /// peers with that seed and peers that dont prune. + GetRandomGrayPeer(NetZone, Option<PruningSeed>), + /// Get a random White peer from the peer list + /// If a pruning seed is given we will select from + /// peers with that seed and peers that dont prune. + GetRandomWhitePeer(NetZone, Option<PruningSeed>), + /// Get a list of random peers from the white list, + /// The list will be less than or equal to the provided + /// len. + GetRandomWhitePeers(NetZone, usize), } impl std::fmt::Display for AddressBookRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::HandleNewPeerList(_, _) => f.write_str("HandleNewPeerList"), - Self::SetPeerSeen(_, _) => f.write_str("SetPeerSeen"), - Self::BanPeer(_, _) => f.write_str("BanPeer"), - Self::AddPeerToAnchor(_) => f.write_str("AddPeerToAnchor"), - Self::RemovePeerFromAnchor(_) => f.write_str("RemovePeerFromAnchor"), - Self::UpdatePeerInfo(_) => f.write_str("UpdatePeerInfo"), - Self::GetRandomGrayPeer(_) => f.write_str("GetRandomGrayPeer"), - Self::GetRandomWhitePeer(_) => f.write_str("GetRandomWhitePeer"), + Self::HandleNewPeerList(..) => f.write_str("HandleNewPeerList"), + Self::SetPeerSeen(..) => f.write_str("SetPeerSeen"), + Self::BanPeer(..) => f.write_str("BanPeer"), + Self::ConnectedToPeer { .. } => f.write_str("ConnectedToPeer"), + + Self::GetAndEmptyAnchorList(_) => f.write_str("GetAndEmptyAnchorList"), + Self::GetRandomGrayPeer(..) => f.write_str("GetRandomGrayPeer"), + Self::GetRandomWhitePeer(..) => f.write_str("GetRandomWhitePeer"), + Self::GetRandomWhitePeers(_, len) => { + f.write_str(&format!("GetRandomWhitePeers, len: {len}")) + } } } } impl AddressBookRequest { + /// Gets the [`NetZone`] for this request so we can + /// route it to the required address book. pub fn get_zone(&self) -> NetZone { match self { Self::HandleNewPeerList(_, zone) => *zone, - Self::SetPeerSeen(peer, _) => peer.get_zone(), - Self::BanPeer(peer, _) => peer.get_zone(), - Self::AddPeerToAnchor(peer) => peer.get_zone(), - Self::RemovePeerFromAnchor(peer) => peer.get_zone(), - Self::UpdatePeerInfo(peer) => peer.adr.get_zone(), + Self::SetPeerSeen(.., zone) => *zone, + Self::BanPeer(.., zone) => *zone, + Self::ConnectedToPeer { zone, .. } => *zone, - Self::GetRandomGrayPeer(zone) => *zone, - Self::GetRandomWhitePeer(zone) => *zone, + Self::GetAndEmptyAnchorList(zone) => *zone, + Self::GetRandomGrayPeer(zone, _) => *zone, + Self::GetRandomWhitePeer(zone, _) => *zone, + Self::GetRandomWhitePeers(zone, _) => *zone, } } } +/// A response from the AddressBook. #[derive(Debug)] pub enum AddressBookResponse { + /// The request was handled ok. Ok, + /// A peer. Peer(PeerListEntryBase), -} - -#[derive(Debug, Clone)] -pub struct AddressBookConfig { - max_white_peers: usize, - max_gray_peers: usize, -} - -impl Default for AddressBookConfig { - fn default() -> Self { - AddressBookConfig { - max_white_peers: MAX_WHITE_LIST_PEERS, - max_gray_peers: MAX_GRAY_LIST_PEERS, - } - } -} - -#[async_trait::async_trait] -pub trait AddressBookStore: Clone { - type Error: Into<AddressBookError>; - /// Loads the peers from the peer store. - /// returns (in order): - /// the white list, - /// the gray list, - /// the anchor list, - /// the ban list - async fn load_peers( - &mut self, - zone: NetZone, - ) -> Result< - ( - Vec<PeerListEntryBase>, // white list - Vec<PeerListEntryBase>, // gray list - Vec<NetworkAddress>, // anchor list - Vec<(NetworkAddress, chrono::NaiveDateTime)>, // ban list - ), - Self::Error, - >; - - async fn save_peers( - &mut self, - zone: NetZone, - white: Vec<PeerListEntryBase>, - gray: Vec<PeerListEntryBase>, - anchor: Vec<NetworkAddress>, - bans: Vec<(NetworkAddress, chrono::NaiveDateTime)>, // ban lists - ) -> Result<(), Self::Error>; + /// A list of peers. + Peers(Vec<PeerListEntryBase>), } diff --git a/p2p/src/address_book/addr_book_client.rs b/p2p/src/address_book/addr_book_client.rs index 5101cd2..f35d265 100644 --- a/p2p/src/address_book/addr_book_client.rs +++ b/p2p/src/address_book/addr_book_client.rs @@ -1,38 +1,44 @@ +//! This module holds the address books client and [`tower::Service`]. +//! +//! To start the address book use [`start_address_book`]. +// TODO: Store banned peers persistently. use std::future::Future; use std::pin::Pin; +use std::task::Poll; use futures::channel::{mpsc, oneshot}; use futures::FutureExt; -use tokio::task::spawn; +use tokio::task::{spawn, JoinHandle}; use tower::steer::Steer; +use tower::BoxError; +use tracing::Instrument; use monero_wire::network_address::NetZone; -use super::address_book::{AddressBook, AddressBookClientRequest}; -use super::{ - AddressBookConfig, AddressBookError, AddressBookRequest, AddressBookResponse, AddressBookStore, -}; +use crate::{Config, P2PStore}; +use super::address_book::{AddressBook, AddressBookClientRequest}; +use super::{AddressBookError, AddressBookRequest, AddressBookResponse}; + +/// Start the address book. +/// Under the hood this function spawns 3 address books +/// for the 3 [`NetZone`] and combines them into a [`tower::Steer`](Steer). pub async fn start_address_book<S>( peer_store: S, - config: AddressBookConfig, + config: Config, ) -> Result< impl tower::Service< - AddressBookRequest, - Response = AddressBookResponse, - Error = AddressBookError, - Future = Pin< - Box< - dyn Future<Output = Result<AddressBookResponse, AddressBookError>> - + Send - + 'static, - >, - >, - > + Clone, - AddressBookError, + AddressBookRequest, + Response = AddressBookResponse, + Error = BoxError, + Future = Pin< + Box<dyn Future<Output = Result<AddressBookResponse, BoxError>> + Send + 'static>, + >, + >, + BoxError, > where - S: AddressBookStore, + S: P2PStore, { let mut builder = AddressBookBuilder::new(peer_store, config); @@ -40,11 +46,13 @@ where let tor = builder.build(NetZone::Tor).await?; let i2p = builder.build(NetZone::I2p).await?; + // This list MUST be in the same order as closuer in the `Steer` func let books = vec![public, tor, i2p]; Ok(Steer::new( books, |req: &AddressBookRequest, _: &[_]| match req.get_zone() { + // This: NetZone::Public => 0, NetZone::Tor => 1, NetZone::I2p => 2, @@ -52,68 +60,105 @@ where )) } -pub struct AddressBookBuilder<S> { +/// An address book builder. +/// This: +/// - starts the address book +/// - creates and returns the `AddressBookClient` +struct AddressBookBuilder<S> { peer_store: S, - config: AddressBookConfig, + config: Config, } impl<S> AddressBookBuilder<S> where - S: AddressBookStore, + S: P2PStore, { - fn new(peer_store: S, config: AddressBookConfig) -> Self { + fn new(peer_store: S, config: Config) -> Self { AddressBookBuilder { peer_store, config } } + /// Builds the address book for a specific [`NetZone`] async fn build(&mut self, zone: NetZone) -> Result<AddressBookClient, AddressBookError> { - let (white, gray, anchor, bans) = - self.peer_store.load_peers(zone).await.map_err(Into::into)?; + let (white, gray, anchor) = self + .peer_store + .load_peers(zone) + .await + .map_err(|e| AddressBookError::PeerStoreError(e))?; - let book = AddressBook::new(self.config.clone(), zone, white, gray, anchor, bans); + let book = AddressBook::new( + self.config.clone(), + zone, + white, + gray, + anchor, + vec![], + self.peer_store.clone(), + ); - let (tx, rx) = mpsc::channel(5); + let (tx, rx) = mpsc::channel(0); - spawn(book.run(rx)); + let book_span = tracing::info_span!("AddressBook", book = book.book_name()); - Ok(AddressBookClient { book: tx }) + let book_handle = spawn(book.run(rx).instrument(book_span)); + + Ok(AddressBookClient { + book: tx, + book_handle, + }) } } -#[derive(Debug, Clone)] +/// The Client for an individual address book. +#[derive(Debug)] struct AddressBookClient { + /// The channel to pass requests to the address book. book: mpsc::Sender<AddressBookClientRequest>, + /// The address book task handle. + book_handle: JoinHandle<()>, } impl tower::Service<AddressBookRequest> for AddressBookClient { - type Error = AddressBookError; type Response = AddressBookResponse; + type Error = BoxError; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), Self::Error>> { - self.book - .poll_ready(cx) - .map_err(|_| AddressBookError::AddressBooksChannelClosed) + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> { + // Check the channel + match self.book.poll_ready(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(_)) => { + return Poll::Ready(Err(AddressBookError::AddressBooksChannelClosed.into())) + } + } + + // Check the address book task is still running + match self.book_handle.poll_unpin(cx) { + // The address book is still running + Poll::Pending => Poll::Ready(Ok(())), + // The address book task has exited + Poll::Ready(_) => Err(AddressBookError::AddressBookTaskExited)?, + } } fn call(&mut self, req: AddressBookRequest) -> Self::Future { let (tx, rx) = oneshot::channel(); // get the callers span - let span = tracing::span::Span::current(); + let span = tracing::debug_span!(parent: &tracing::span::Span::current(), "AddressBook"); let req = AddressBookClientRequest { req, tx, span }; match self.book.try_send(req) { Err(_e) => { // I'm assuming all callers will call `poll_ready` first (which they are supposed to) - futures::future::ready(Err(AddressBookError::AddressBooksChannelClosed)).boxed() + futures::future::ready(Err(AddressBookError::AddressBooksChannelClosed.into())) + .boxed() } Ok(()) => async move { rx.await .expect("Address Book will not drop requests until completed") + .map_err(Into::into) } .boxed(), } diff --git a/p2p/src/address_book/address_book.rs b/p2p/src/address_book/address_book.rs index 680bf10..84d149b 100644 --- a/p2p/src/address_book/address_book.rs +++ b/p2p/src/address_book/address_book.rs @@ -1,70 +1,145 @@ +//! This module contains the actual address book logic. +//! +//! The address book is split into multiple [`PeerList`]: +//! +//! - A White list: For peers we have connected to ourselves. +//! +//! - A Gray list: For Peers we have been told about but +//! haven't connected to ourselves. +//! +//! - An Anchor list: This holds peers we are currently +//! connected to that are reachable if we were to +//! connect to them again. For example an inbound proxy +//! connection would not get added to this list as we cant +//! connect to this peer ourselves. Behind the scenes we +//! are just storing the key to a peer in the white list. +//! use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use futures::stream::FuturesUnordered; use futures::{ channel::{mpsc, oneshot}, - StreamExt, + FutureExt, Stream, StreamExt, }; -use rand::{Rng, SeedableRng}; -use std::time::Duration; +use pin_project::pin_project; +use rand::prelude::SliceRandom; +use cuprate_common::shutdown::is_shutting_down; use cuprate_common::PruningSeed; -use monero_wire::{messages::PeerListEntryBase, network_address::NetZone, NetworkAddress}; +use monero_wire::{messages::PeerListEntryBase, network_address::NetZone, NetworkAddress, PeerID}; -use super::{AddressBookConfig, AddressBookError, AddressBookRequest, AddressBookResponse}; +use super::{AddressBookError, AddressBookRequest, AddressBookResponse}; +use crate::address_book::connection_handle::ConnectionAddressBookHandle; +use crate::{constants::ADDRESS_BOOK_SAVE_INTERVAL, Config, P2PStore}; mod peer_list; use peer_list::PeerList; -pub(crate) struct AddressBookClientRequest { - pub req: AddressBookRequest, - pub tx: oneshot::Sender<Result<AddressBookResponse, AddressBookError>>, +#[cfg(test)] +mod tests; +/// A request sent to the address book task. +pub(crate) struct AddressBookClientRequest { + /// The request + pub req: AddressBookRequest, + /// A oneshot to send the result down + pub tx: oneshot::Sender<Result<AddressBookResponse, AddressBookError>>, + /// The tracing span to keep the context of the request pub span: tracing::Span, } -pub struct AddressBook { - zone: NetZone, - config: AddressBookConfig, - white_list: PeerList, - gray_list: PeerList, - anchor_list: HashSet<NetworkAddress>, - - baned_peers: HashMap<NetworkAddress, chrono::NaiveDateTime>, - - rng: rand::rngs::StdRng, - //banned_subnets:, +/// An entry in the connected list. +pub struct ConnectionPeerEntry { + /// A oneshot sent from the Connection when it has finished. + connection_handle: ConnectionAddressBookHandle, + /// The connection addr, None if the peer is connected through + /// a hidden network. + addr: Option<NetworkAddress>, + /// If the peer is reachable by our node. + reachable: bool, + /// The last seen timestamp, note: Cuprate may skip updating this + /// field on some inbound messages + last_seen: chrono::NaiveDateTime, + /// The peers pruning seed + pruning_seed: PruningSeed, + /// The peers port. + rpc_port: u16, + /// The peers rpc credits per hash + rpc_credits_per_hash: u32, } -impl AddressBook { +/// A future that resolves when a peer is unbanned. +#[pin_project(project = EnumProj)] +pub struct BanedPeerFut(Vec<u8>, #[pin] tokio::time::Sleep); + +impl Future for BanedPeerFut { + type Output = Vec<u8>; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + match this.1.poll_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => Poll::Ready(this.0.clone()), + } + } +} + +/// The address book for a specific [`NetZone`] +pub struct AddressBook<PeerStore> { + /// The [`NetZone`] of this address book. + zone: NetZone, + /// A copy of the nodes configuration. + config: Config, + /// The Address books white list. + white_list: PeerList, + /// The Address books gray list. + gray_list: PeerList, + /// The Address books anchor list. + anchor_list: HashSet<NetworkAddress>, + /// The Currently connected peers. + connected_peers: HashMap<PeerID, ConnectionPeerEntry>, + /// A tuple of: + /// - A hashset of [`ban_identifier`](NetworkAddress::ban_identifier) + /// - A [`FuturesUnordered`] which contains futures for every ban_id + /// that will resolve when the ban_id should be un banned. + baned_peers: (HashSet<Vec<u8>>, FuturesUnordered<BanedPeerFut>), + /// The peer store to save the peers to persistent storage + p2p_store: PeerStore, +} + +impl<PeerStore: P2PStore> AddressBook<PeerStore> { + /// Creates a new address book for a given [`NetZone`] pub fn new( - config: AddressBookConfig, + config: Config, zone: NetZone, white_peers: Vec<PeerListEntryBase>, gray_peers: Vec<PeerListEntryBase>, anchor_peers: Vec<NetworkAddress>, baned_peers: Vec<(NetworkAddress, chrono::NaiveDateTime)>, - ) -> AddressBook { - let rng = rand::prelude::StdRng::from_entropy(); + p2p_store: PeerStore, + ) -> Self { let white_list = PeerList::new(white_peers); let gray_list = PeerList::new(gray_peers); let anchor_list = HashSet::from_iter(anchor_peers); - let baned_peers = HashMap::from_iter(baned_peers); + let baned_peers = (HashSet::new(), FuturesUnordered::new()); - let mut book = AddressBook { + let connected_peers = HashMap::new(); + + AddressBook { zone, config, white_list, gray_list, anchor_list, + connected_peers, baned_peers, - rng, - }; - - book.check_unban_peers(); - - book + p2p_store, + } } + /// Returns the books name (Based on the [`NetZone`]) pub const fn book_name(&self) -> &'static str { match self.zone { NetZone::Public => "PublicAddressBook", @@ -73,80 +148,137 @@ impl AddressBook { } } + /// Returns the length of the white list fn len_white_list(&self) -> usize { self.white_list.len() } + /// Returns the length of the gray list fn len_gray_list(&self) -> usize { self.gray_list.len() } + /// Returns the length of the anchor list + fn len_anchor_list(&self) -> usize { + self.anchor_list.len() + } + + /// Returns the length of the banned list + fn len_banned_list(&self) -> usize { + self.baned_peers.0.len() + } + + /// Returns the maximum length of the white list + /// *note this list can grow bigger if we are connected to more + /// than this amount. fn max_white_peers(&self) -> usize { - self.config.max_white_peers + self.config.max_white_peers() } + /// Returns the maximum length of the gray list fn max_gray_peers(&self) -> usize { - self.config.max_gray_peers + self.config.max_gray_peers() } + /// Checks if a peer is banned. fn is_peer_banned(&self, peer: &NetworkAddress) -> bool { - self.baned_peers.contains_key(peer) + self.baned_peers.0.contains(&peer.ban_identifier()) } + /// Checks if banned peers should be unbanned as the duration has elapsed fn check_unban_peers(&mut self) { - let mut now = chrono::Utc::now().naive_utc(); - self.baned_peers.retain(|_, time| time > &mut now) - } - - fn ban_peer(&mut self, peer: NetworkAddress, till: chrono::NaiveDateTime) { - let now = chrono::Utc::now().naive_utc(); - if now > till { - return; + while let Some(Some(addr)) = Pin::new(&mut self.baned_peers.1).next().now_or_never() { + tracing::debug!("Unbanning peer: {addr:?}"); + self.baned_peers.0.remove(&addr); } - - tracing::debug!("Banning peer: {peer:?} until: {till}"); - - self.baned_peers.insert(peer, till); } - fn add_peer_to_anchor(&mut self, peer: NetworkAddress) -> Result<(), AddressBookError> { - tracing::debug!("Adding peer: {peer:?} to anchor list"); - // is peer in gray list - if let Some(peer_eb) = self.gray_list.remove_peer(&peer) { - self.white_list.add_new_peer(peer_eb); - self.anchor_list.insert(peer); - Ok(()) - } else { - if !self.white_list.contains_peer(&peer) { - return Err(AddressBookError::PeerNotFound); + /// Checks if peers have disconnected, if they have removing them from the + /// connected and anchor list. + fn check_connected_peers(&mut self) { + let mut remove_from_anchor = vec![]; + // We dont have to worry about updating our white list with the information + // before we remove the peers as that happens on every save. + self.connected_peers.retain(|_, peer| { + if !peer.connection_handle.connection_closed() { + // add the peer to the list to get removed from the anchor + if let Some(addr) = peer.addr { + remove_from_anchor.push(addr) + } + false + } else { + true + } + }); + // If we are shutting down we want to keep our anchor peers for + // the next time we boot up so we dont remove disconnecting peers + // from the anchor list if we are shutting down. + if !is_shutting_down() { + for peer in remove_from_anchor { + self.anchor_list.remove(&peer); } - self.anchor_list.insert(peer); - Ok(()) } } - fn remove_peer_from_anchor(&mut self, peer: NetworkAddress) { - let _ = self.anchor_list.remove(&peer); - } - - fn set_peer_seen( + // Bans the peer and tells the connection tasks of peers with the same ban id to shutdown. + fn ban_peer( &mut self, - peer: NetworkAddress, - last_seen: i64, + peer: PeerID, + time: std::time::Duration, ) -> Result<(), AddressBookError> { - if let Some(mut peer) = self.gray_list.remove_peer(&peer) { - peer.last_seen = last_seen; - self.white_list.add_new_peer(peer); - } else { - let peer = self - .white_list - .get_peer_mut(&peer) - .ok_or(AddressBookError::PeerNotFound)?; - peer.last_seen = last_seen; + tracing::debug!("Banning peer: {peer:?} for: {time:?}"); + + let Some(conn_entry) = self.connected_peers.get(&peer) else { + tracing::debug!("Peer is not in connected list"); + return Err(AddressBookError::PeerNotFound); + }; + // tell the connection task to finish. + conn_entry.connection_handle.kill_connection(); + // try find the NetworkAddress of the peer + let Some(addr) = conn_entry.addr else { + tracing::debug!("Peer does not have an address we can ban"); + return Ok(()); + }; + + let ban_id = addr.ban_identifier(); + + self.white_list.remove_peers_with_ban_id(&ban_id); + self.gray_list.remove_peers_with_ban_id(&ban_id); + // Dont remove from anchor list or connection list as this will happen when + // the connection is closed. + + // tell the connection task of peers with the same ban id to shutdown. + for conn in self.connected_peers.values() { + if let Some(addr) = conn.addr { + if addr.ban_identifier() == ban_id { + conn.connection_handle.kill_connection() + } + } } + + // add the ban identifier to the ban list + self.baned_peers.0.insert(ban_id.clone()); + self.baned_peers + .1 + .push(BanedPeerFut(ban_id, tokio::time::sleep(time))); Ok(()) } + /// Update the last seen timestamp of a connected peer. + fn update_last_seen( + &mut self, + peer: PeerID, + last_seen: chrono::NaiveDateTime, + ) -> Result<(), AddressBookError> { + if let Some(mut peer) = self.connected_peers.get_mut(&peer) { + peer.last_seen = last_seen; + Ok(()) + } else { + Err(AddressBookError::PeerNotFound) + } + } + + /// adds a peer to the gray list. fn add_peer_to_gray_list(&mut self, mut peer: PeerListEntryBase) { if self.white_list.contains_peer(&peer.adr) { return; @@ -157,6 +289,9 @@ impl AddressBook { } } + /// handles an incoming peer list, + /// dose some basic validation on the addresses + /// appends the good peers to our book. fn handle_new_peerlist( &mut self, mut peers: Vec<PeerListEntryBase>, @@ -198,77 +333,262 @@ impl AddressBook { } } - fn get_random_gray_peer(&mut self) -> Option<PeerListEntryBase> { - self.gray_list.get_random_peer(&mut self.rng).map(|p| *p) + /// Gets a random peer from our gray list. + /// If pruning seed is set we will get a peer with that pruning seed. + fn get_random_gray_peer( + &mut self, + pruning_seed: Option<PruningSeed>, + ) -> Option<PeerListEntryBase> { + self.gray_list + .get_random_peer(&mut rand::thread_rng(), pruning_seed.map(Into::into)) + .map(|p| *p) } - fn get_random_white_peer(&mut self) -> Option<PeerListEntryBase> { - self.white_list.get_random_peer(&mut self.rng).map(|p| *p) + /// Gets a random peer from our white list. + /// If pruning seed is set we will get a peer with that pruning seed. + fn get_random_white_peer( + &mut self, + pruning_seed: Option<PruningSeed>, + ) -> Option<PeerListEntryBase> { + self.white_list + .get_random_peer(&mut rand::thread_rng(), pruning_seed.map(Into::into)) + .map(|p| *p) } - fn update_peer_info(&mut self, peer: PeerListEntryBase) -> Result<(), AddressBookError> { - if let Some(peer_stored) = self.gray_list.get_peer_mut(&peer.adr) { - *peer_stored = peer; - Ok(()) - } else if let Some(peer_stored) = self.white_list.get_peer_mut(&peer.adr) { - *peer_stored = peer; - Ok(()) - } else { - return Err(AddressBookError::PeerNotFound); + /// Gets random peers from our white list. + /// will be less than or equal to `len`. + fn get_random_white_peers(&mut self, len: usize) -> Vec<PeerListEntryBase> { + let white_len = self.white_list.len(); + let len = if len < white_len { len } else { white_len }; + let mut white_peers: Vec<&PeerListEntryBase> = self.white_list.iter_all_peers().collect(); + white_peers.shuffle(&mut rand::thread_rng()); + white_peers[0..len].iter().map(|peb| **peb).collect() + } + + /// Updates an entry in the white list, if the peer is not found and `reachable` is true then + /// the peer will be added to the white list. + fn update_white_list_peer_entry( + &mut self, + addr: &NetworkAddress, + id: PeerID, + conn_entry: &ConnectionPeerEntry, + ) -> Result<(), AddressBookError> { + if let Some(peb) = self.white_list.get_peer_mut(addr) { + if peb.pruning_seed == conn_entry.pruning_seed.into() { + return Err(AddressBookError::PeersPruningSeedChanged); + } + peb.id = id; + peb.last_seen = conn_entry.last_seen.timestamp(); + peb.rpc_port = conn_entry.rpc_port; + peb.rpc_credits_per_hash = conn_entry.rpc_credits_per_hash; + peb.pruning_seed = conn_entry.pruning_seed.into(); + } else if conn_entry.reachable { + // if the peer is reachable add it to our white list + let peb = PeerListEntryBase { + id, + adr: *addr, + last_seen: conn_entry.last_seen.timestamp(), + rpc_port: conn_entry.rpc_port, + rpc_credits_per_hash: conn_entry.rpc_credits_per_hash, + pruning_seed: conn_entry.pruning_seed.into(), + }; + self.white_list.add_new_peer(peb); + } + Ok(()) + } + + /// Handles a new connection, adding it to the white list if the + /// peer is reachable by our node. + fn handle_new_connection( + &mut self, + connection_handle: ConnectionAddressBookHandle, + addr: Option<NetworkAddress>, + id: PeerID, + reachable: bool, + last_seen: chrono::NaiveDateTime, + pruning_seed: PruningSeed, + rpc_port: u16, + rpc_credits_per_hash: u32, + ) -> Result<(), AddressBookError> { + let connection_entry = ConnectionPeerEntry { + connection_handle, + addr, + reachable, + last_seen, + pruning_seed, + rpc_port, + rpc_credits_per_hash, + }; + if let Some(addr) = addr { + if self.baned_peers.0.contains(&addr.ban_identifier()) { + return Err(AddressBookError::PeerIsBanned); + } + // remove the peer from the gray list as we know it's active. + let _ = self.gray_list.remove_peer(&addr); + if !reachable { + // If we can't reach the peer remove it from the white list as well + let _ = self.white_list.remove_peer(&addr); + } else { + // The peer is reachable, update our white list and add it to the anchor connections. + self.update_white_list_peer_entry(&addr, id, &connection_entry)?; + self.anchor_list.insert(addr); + } + } + + self.connected_peers.insert(id, connection_entry); + self.white_list + .reduce_list(&self.anchor_list, self.max_white_peers()); + Ok(()) + } + + /// Get and empties the anchor list, used at startup to + /// connect to some peers we were previously connected to. + fn get_and_empty_anchor_list(&mut self) -> Vec<PeerListEntryBase> { + self.anchor_list + .drain() + .map(|addr| { + self.white_list + .get_peer(&addr) + .expect("If peer is in anchor it must be in white list") + .clone() + }) + .collect() + } + + /// Handles an [`AddressBookClientRequest`] to the address book. + async fn handle_request(&mut self, req: AddressBookClientRequest) { + let _guard = req.span.enter(); + + tracing::trace!("received request: {}", req.req); + + let res = match req.req { + AddressBookRequest::HandleNewPeerList(new_peers, _) => self + .handle_new_peerlist(new_peers) + .map(|_| AddressBookResponse::Ok), + AddressBookRequest::SetPeerSeen(peer, last_seen, _) => self + .update_last_seen(peer, last_seen) + .map(|_| AddressBookResponse::Ok), + AddressBookRequest::BanPeer(peer, time, _) => { + self.ban_peer(peer, time).map(|_| AddressBookResponse::Ok) + } + AddressBookRequest::ConnectedToPeer { + zone: _, + connection_handle, + addr, + id, + reachable, + last_seen, + pruning_seed, + rpc_port, + rpc_credits_per_hash, + } => self + .handle_new_connection( + connection_handle, + addr, + id, + reachable, + last_seen, + pruning_seed, + rpc_port, + rpc_credits_per_hash, + ) + .map(|_| AddressBookResponse::Ok), + + AddressBookRequest::GetAndEmptyAnchorList(_) => { + Ok(AddressBookResponse::Peers(self.get_and_empty_anchor_list())) + } + + AddressBookRequest::GetRandomGrayPeer(_, pruning_seed) => { + match self.get_random_gray_peer(pruning_seed) { + Some(peer) => Ok(AddressBookResponse::Peer(peer)), + None => Err(AddressBookError::PeerListEmpty), + } + } + AddressBookRequest::GetRandomWhitePeer(_, pruning_seed) => { + match self.get_random_white_peer(pruning_seed) { + Some(peer) => Ok(AddressBookResponse::Peer(peer)), + None => Err(AddressBookError::PeerListEmpty), + } + } + AddressBookRequest::GetRandomWhitePeers(_, len) => { + Ok(AddressBookResponse::Peers(self.get_random_white_peers(len))) + } + }; + + if let Err(e) = &res { + tracing::debug!("Error when handling request, err: {e}") + } + + let _ = req.tx.send(res); + } + + /// Updates the white list with the information in the `connected_peers` list. + /// This only updates the `last_seen` timestamp as that's the only thing that should + /// change during connections. + fn update_white_list_with_conn_list(&mut self) { + for (_, peer) in self.connected_peers.iter() { + if peer.reachable { + if let Some(peer_eb) = self.white_list.get_peer_mut(&peer.addr.unwrap()) { + peer_eb.last_seen = peer.last_seen.timestamp(); + } + } } } + /// Saves the address book to persistent storage. + /// TODO: save the banned peer list. + #[tracing::instrument(level="trace", skip(self), fields(name = self.book_name()) )] + async fn save(&mut self) { + self.update_white_list_with_conn_list(); + tracing::trace!( + "white_len: {}, gray_len: {}, anchor_len: {}, banned_len: {}", + self.len_white_list(), + self.len_gray_list(), + self.len_anchor_list(), + self.len_banned_list() + ); + let res = self + .p2p_store + .save_peers( + self.zone, + (&self.white_list).into(), + (&self.gray_list).into(), + self.anchor_list.iter().collect(), + ) + .await; + match res { + Ok(()) => tracing::trace!("Complete"), + Err(e) => tracing::error!("Error saving address book: {e}"), + } + } + + /// Runs the address book task + /// Should be spawned in a task. pub(crate) async fn run(mut self, mut rx: mpsc::Receiver<AddressBookClientRequest>) { + let mut save_interval = { + let mut interval = tokio::time::interval(ADDRESS_BOOK_SAVE_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + // Interval ticks at 0, interval, 2 interval, ... + // this is just to ignore the first tick + interval.tick().await; + tokio_stream::wrappers::IntervalStream::new(interval).fuse() + }; + loop { - let Some(req) = rx.next().await else { - // the client has been dropped the node has *possibly* shut down - return; - }; - self.check_unban_peers(); - - let span = tracing::debug_span!(parent: &req.span, "AddressBook"); - let _guard = span.enter(); - - tracing::debug!("{} received request: {}", self.book_name(), req.req); - - let res = match req.req { - AddressBookRequest::HandleNewPeerList(new_peers, _) => self - .handle_new_peerlist(new_peers) - .map(|_| AddressBookResponse::Ok), - AddressBookRequest::SetPeerSeen(peer, last_seen) => self - .set_peer_seen(peer, last_seen) - .map(|_| AddressBookResponse::Ok), - AddressBookRequest::BanPeer(peer, till) => { - self.ban_peer(peer, till); - Ok(AddressBookResponse::Ok) + self.check_connected_peers(); + futures::select! { + req = rx.next() => { + if let Some(req) = req { + self.handle_request(req).await + } else { + tracing::debug!("{} req channel closed, saving and shutting down book", self.book_name()); + self.save().await; + return; + } } - AddressBookRequest::AddPeerToAnchor(peer) => self - .add_peer_to_anchor(peer) - .map(|_| AddressBookResponse::Ok), - AddressBookRequest::RemovePeerFromAnchor(peer) => { - self.remove_peer_from_anchor(peer); - Ok(AddressBookResponse::Ok) - } - AddressBookRequest::UpdatePeerInfo(peer) => { - self.update_peer_info(peer).map(|_| AddressBookResponse::Ok) - } - - AddressBookRequest::GetRandomGrayPeer(_) => match self.get_random_gray_peer() { - Some(peer) => Ok(AddressBookResponse::Peer(peer)), - None => Err(AddressBookError::PeerListEmpty), - }, - AddressBookRequest::GetRandomWhitePeer(_) => match self.get_random_white_peer() { - Some(peer) => Ok(AddressBookResponse::Peer(peer)), - None => Err(AddressBookError::PeerListEmpty), - }, - }; - - if let Err(e) = &res { - tracing::debug!("Error when handling request, err: {e}") + _ = save_interval.next() => self.save().await } - - let _ = req.tx.send(res); } } } diff --git a/p2p/src/address_book/address_book/peer_list.rs b/p2p/src/address_book/address_book/peer_list.rs index 4afa55d..de8a17e 100644 --- a/p2p/src/address_book/address_book/peer_list.rs +++ b/p2p/src/address_book/address_book/peer_list.rs @@ -1,17 +1,42 @@ +//! This module contains the individual address books peer lists. +//! use std::collections::{HashMap, HashSet}; +use std::hash::Hash; +use cuprate_common::CRYPTONOTE_PRUNING_LOG_STRIPES; use monero_wire::{messages::PeerListEntryBase, NetworkAddress}; use rand::Rng; +#[cfg(test)] +mod tests; + +/// A Peer list in the address book. +/// +/// This could either be the white list or gray list. pub struct PeerList { + /// The peers with their peer data. peers: HashMap<NetworkAddress, PeerListEntryBase>, + /// An index of Pruning seed to address, so + /// can quickly grab peers with the pruning seed + /// we want. pruning_idxs: HashMap<u32, Vec<NetworkAddress>>, + /// An index of [`ban_identifier`](NetworkAddress::ban_identifier) to Address + /// to allow us to quickly remove baned peers. + ban_id_idxs: HashMap<Vec<u8>, Vec<NetworkAddress>>, +} + +impl<'a> Into<Vec<&'a PeerListEntryBase>> for &'a PeerList { + fn into(self) -> Vec<&'a PeerListEntryBase> { + self.peers.iter().map(|(_, peb)| peb).collect() + } } impl PeerList { + /// Creates a new peer list. pub fn new(list: Vec<PeerListEntryBase>) -> PeerList { let mut peers = HashMap::with_capacity(list.len()); - let mut pruning_idxs = HashMap::with_capacity(8); + let mut pruning_idxs = HashMap::with_capacity(2 << CRYPTONOTE_PRUNING_LOG_STRIPES); + let mut ban_id_idxs = HashMap::with_capacity(list.len()); // worse case, every peer has a different NetworkAddress and ban id for peer in list { peers.insert(peer.adr, peer); @@ -20,79 +45,157 @@ impl PeerList { .entry(peer.pruning_seed) .or_insert_with(Vec::new) .push(peer.adr); + + ban_id_idxs + .entry(peer.adr.ban_identifier()) + .or_insert_with(Vec::new) + .push(peer.adr); } PeerList { peers, pruning_idxs, + ban_id_idxs, } } + /// Gets the length of the peer list pub fn len(&self) -> usize { self.peers.len() } + /// Gets the amount of peers with a specific seed. + pub fn len_by_seed(&self, pruning_seed: &u32) -> usize { + self.pruning_idxs + .get(pruning_seed) + .map(|indexes| indexes.len()) + .unwrap_or(0) + } + + /// Adds a new peer to the peer list pub fn add_new_peer(&mut self, peer: PeerListEntryBase) { - if self.peers.insert(peer.adr, peer.clone()).is_none() { + if let None = self.peers.insert(peer.adr, peer) { self.pruning_idxs .entry(peer.pruning_seed) .or_insert_with(Vec::new) .push(peer.adr); + + self.ban_id_idxs + .entry(peer.adr.ban_identifier()) + .or_insert_with(Vec::new) + .push(peer.adr); } } + /// Gets a reference to a peer pub fn get_peer(&self, peer: &NetworkAddress) -> Option<&PeerListEntryBase> { self.peers.get(peer) } - pub fn get_peer_by_idx(&self, n: usize) -> Option<&PeerListEntryBase> { - self.peers.iter().nth(n).map(|(_, ret)| ret) + /// Returns an iterator over every peer in this peer list + pub fn iter_all_peers(&self) -> impl Iterator<Item = &PeerListEntryBase> { + self.peers.values() } - pub fn get_random_peer<R: Rng>(&self, r: &mut R) -> Option<&PeerListEntryBase> { - let len = self.len(); - if len == 0 { - None - } else { - let n = r.gen_range(0..len); + /// Returns a random peer. + /// If the pruning seed is specified then we will get a random peer with + /// that pruning seed otherwise we will just get a random peer in the whole + /// list. + pub fn get_random_peer<R: Rng>( + &self, + r: &mut R, + pruning_seed: Option<u32>, + ) -> Option<&PeerListEntryBase> { + if let Some(seed) = pruning_seed { + let mut peers = self.get_peers_with_pruning(&seed)?; + let len = self.len_by_seed(&seed); + if len == 0 { + None + } else { + let n = r.gen_range(0..len); - self.get_peer_by_idx(n) + peers.nth(n) + } + } else { + let mut peers = self.iter_all_peers(); + let len = self.len(); + if len == 0 { + None + } else { + let n = r.gen_range(0..len); + + peers.nth(n) + } } } + /// Returns a mutable reference to a peer. pub fn get_peer_mut(&mut self, peer: &NetworkAddress) -> Option<&mut PeerListEntryBase> { self.peers.get_mut(peer) } + /// Returns true if the list contains this peer. pub fn contains_peer(&self, peer: &NetworkAddress) -> bool { self.peers.contains_key(peer) } - pub fn get_peers_by_pruning_seed( + /// Returns an iterator of peer info of peers with a specific pruning seed. + fn get_peers_with_pruning( &self, seed: &u32, ) -> Option<impl Iterator<Item = &PeerListEntryBase>> { let addrs = self.pruning_idxs.get(seed)?; - Some(addrs.iter().filter_map(move |addr| self.peers.get(addr))) + + Some(addrs.iter().map(move |addr| { + self.peers + .get(addr) + .expect("Address must be in peer list if we have an idx for it") + })) } + /// Removes a peer from the pruning idx + /// + /// MUST NOT BE USED ALONE fn remove_peer_pruning_idx(&mut self, peer: &PeerListEntryBase) { - if let Some(peer_list) = self.pruning_idxs.get_mut(&peer.pruning_seed) { - if let Some(idx) = peer_list.iter().position(|peer_adr| peer_adr == &peer.adr) { - peer_list.remove(idx); - } else { - unreachable!("This function will only be called when the peer exists."); - } - } else { - unreachable!("Pruning seed must exist if a peer has that seed."); - } + remove_peer_idx(&mut self.pruning_idxs, &peer.pruning_seed, &peer.adr) } + /// Removes a peer from the ban idx + /// + /// MUST NOT BE USED ALONE + fn remove_peer_ban_idx(&mut self, peer: &PeerListEntryBase) { + remove_peer_idx(&mut self.ban_id_idxs, &peer.adr.ban_identifier(), &peer.adr) + } + + /// Removes a peer from all the indexes + /// + /// MUST NOT BE USED ALONE + fn remove_peer_from_all_idxs(&mut self, peer: &PeerListEntryBase) { + self.remove_peer_ban_idx(peer); + self.remove_peer_pruning_idx(peer); + } + + /// Removes a peer from the peer list pub fn remove_peer(&mut self, peer: &NetworkAddress) -> Option<PeerListEntryBase> { let peer_eb = self.peers.remove(peer)?; - self.remove_peer_pruning_idx(&peer_eb); + self.remove_peer_from_all_idxs(&peer_eb); Some(peer_eb) } + /// Removes all peers with a specific ban id. + pub fn remove_peers_with_ban_id(&mut self, ban_id: &Vec<u8>) { + let Some(addresses) = self.ban_id_idxs.get(ban_id) else { + // No peers to ban + return; + }; + for addr in addresses.clone() { + self.remove_peer(&addr); + } + } + + /// Tries to reduce the peer list to `new_len`. + /// + /// This function could keep the list bigger than `new_len` if `must_keep_peers`s length + /// is larger than new_len, in that case we will remove as much as we can. pub fn reduce_list(&mut self, must_keep_peers: &HashSet<NetworkAddress>, new_len: usize) { if new_len >= self.len() { return; @@ -118,165 +221,19 @@ impl PeerList { } } -#[cfg(test)] -mod tests { - use std::{collections::HashSet, vec}; - - use monero_wire::{messages::PeerListEntryBase, NetworkAddress}; - use rand::Rng; - - use super::PeerList; - - fn make_fake_peer_list(numb_o_peers: usize) -> PeerList { - let mut peer_list = vec![PeerListEntryBase::default(); numb_o_peers]; - for (idx, peer) in peer_list.iter_mut().enumerate() { - let NetworkAddress::IPv4(ip) = &mut peer.adr else {panic!("this test requires default to be ipv4")}; - ip.m_ip += idx as u32; - } - - PeerList::new(peer_list) - } - - fn make_fake_peer_list_with_random_pruning_seeds(numb_o_peers: usize) -> PeerList { - let mut r = rand::thread_rng(); - - let mut peer_list = vec![PeerListEntryBase::default(); numb_o_peers]; - for (idx, peer) in peer_list.iter_mut().enumerate() { - let NetworkAddress::IPv4(ip) = &mut peer.adr else {panic!("this test requires default to be ipv4")}; - ip.m_ip += idx as u32; - - peer.pruning_seed = if r.gen_bool(0.4) { - 0 - } else { - r.gen_range(384..=391) - }; - } - - PeerList::new(peer_list) - } - - #[test] - fn peer_list_reduce_length() { - let mut peer_list = make_fake_peer_list(2090); - let must_keep_peers = HashSet::new(); - - let target_len = 2000; - - peer_list.reduce_list(&must_keep_peers, target_len); - - assert_eq!(peer_list.len(), target_len); - } - - #[test] - fn peer_list_reduce_length_with_peers_we_need() { - let mut peer_list = make_fake_peer_list(500); - let must_keep_peers = HashSet::from_iter(peer_list.peers.iter().map(|(adr, _)| *adr)); - - let target_len = 49; - - peer_list.reduce_list(&must_keep_peers, target_len); - - // we can't remove any of the peers we said we need them all - assert_eq!(peer_list.len(), 500); - } - - #[test] - fn peer_list_get_peers_by_pruning_seed() { - let mut r = rand::thread_rng(); - - let peer_list = make_fake_peer_list_with_random_pruning_seeds(1000); - let seed = if r.gen_bool(0.4) { - 0 +/// Remove a peer from an index. +fn remove_peer_idx<T: Hash + Eq + PartialEq>( + idx_map: &mut HashMap<T, Vec<NetworkAddress>>, + idx: &T, + addr: &NetworkAddress, +) { + if let Some(peer_list) = idx_map.get_mut(idx) { + if let Some(idx) = peer_list.iter().position(|peer_adr| peer_adr == addr) { + peer_list.swap_remove(idx); } else { - r.gen_range(384..=391) - }; - - let peers_with_seed = peer_list - .get_peers_by_pruning_seed(&seed) - .expect("If you hit this buy a lottery ticket"); - - for peer in peers_with_seed { - assert_eq!(peer.pruning_seed, seed); + unreachable!("This function will only be called when the peer exists."); } - - assert_eq!(peer_list.len(), 1000); - } - - #[test] - fn peer_list_remove_specific_peer() { - let mut peer_list = make_fake_peer_list_with_random_pruning_seeds(100); - - // generate peer at a random point in the list - let mut peer = NetworkAddress::default(); - let NetworkAddress::IPv4(ip) = &mut peer else {panic!("this test requires default to be ipv4")}; - ip.m_ip += 50; - - assert!(peer_list.remove_peer(&peer).is_some()); - - let pruning_idxs = peer_list.pruning_idxs; - let peers = peer_list.peers; - - for (_, addrs) in pruning_idxs { - addrs.iter().for_each(|adr| assert!(adr != &peer)) - } - - assert!(!peers.contains_key(&peer)); - } - - #[test] - fn peer_list_pruning_idxs_are_correct() { - let peer_list = make_fake_peer_list_with_random_pruning_seeds(100); - let mut total_len = 0; - - for (seed, list) in peer_list.pruning_idxs { - for peer in list.iter() { - assert_eq!(peer_list.peers.get(peer).unwrap().pruning_seed, seed); - total_len += 1; - } - } - - assert_eq!(total_len, peer_list.peers.len()) - } - - #[test] - fn peer_list_add_new_peer() { - let mut peer_list = make_fake_peer_list(10); - let mut new_peer = PeerListEntryBase::default(); - let NetworkAddress::IPv4(ip) = &mut new_peer.adr else {panic!("this test requires default to be ipv4")}; - ip.m_ip += 50; - - peer_list.add_new_peer(new_peer.clone()); - - assert_eq!(peer_list.len(), 11); - assert_eq!(peer_list.get_peer(&new_peer.adr), Some(&new_peer)); - assert!(peer_list - .pruning_idxs - .get(&new_peer.pruning_seed) - .unwrap() - .contains(&new_peer.adr)); - } - - #[test] - fn peer_list_add_existing_peer() { - let mut peer_list = make_fake_peer_list(10); - let existing_peer = peer_list - .get_peer(&NetworkAddress::default()) - .unwrap() - .clone(); - - peer_list.add_new_peer(existing_peer.clone()); - - assert_eq!(peer_list.len(), 10); - assert_eq!(peer_list.get_peer(&existing_peer.adr), Some(&existing_peer)); - } - - #[test] - fn peer_list_get_non_existent_peer() { - let peer_list = make_fake_peer_list(10); - let mut non_existent_peer = NetworkAddress::default(); - let NetworkAddress::IPv4(ip) = &mut non_existent_peer else {panic!("this test requires default to be ipv4")}; - ip.m_ip += 50; - - assert_eq!(peer_list.get_peer(&non_existent_peer), None); + } else { + unreachable!("Index must exist if a peer has that index"); } } diff --git a/p2p/src/address_book/address_book/peer_list/tests.rs b/p2p/src/address_book/address_book/peer_list/tests.rs new file mode 100644 index 0000000..00ca37c --- /dev/null +++ b/p2p/src/address_book/address_book/peer_list/tests.rs @@ -0,0 +1,176 @@ +use std::{collections::HashSet, vec}; + +use monero_wire::{messages::PeerListEntryBase, NetworkAddress}; +use rand::Rng; + +use super::PeerList; + +fn make_fake_peer_list(numb_o_peers: usize) -> PeerList { + let mut peer_list = vec![PeerListEntryBase::default(); numb_o_peers]; + for (idx, peer) in peer_list.iter_mut().enumerate() { + let NetworkAddress::IPv4(ip) = &mut peer.adr else {panic!("this test requires default to be ipv4")}; + ip.m_ip += idx as u32; + } + + PeerList::new(peer_list) +} + +fn make_fake_peer_list_with_random_pruning_seeds(numb_o_peers: usize) -> PeerList { + let mut r = rand::thread_rng(); + + let mut peer_list = vec![PeerListEntryBase::default(); numb_o_peers]; + for (idx, peer) in peer_list.iter_mut().enumerate() { + let NetworkAddress::IPv4(ip) = &mut peer.adr else {panic!("this test requires default to be ipv4")}; + ip.m_ip += idx as u32; + ip.m_port += r.gen_range(0..15); + + peer.pruning_seed = if r.gen_bool(0.4) { + 0 + } else { + r.gen_range(384..=391) + }; + } + + PeerList::new(peer_list) +} + +#[test] +fn peer_list_reduce_length() { + let mut peer_list = make_fake_peer_list(2090); + let must_keep_peers = HashSet::new(); + + let target_len = 2000; + + peer_list.reduce_list(&must_keep_peers, target_len); + + assert_eq!(peer_list.len(), target_len); +} + +#[test] +fn peer_list_reduce_length_with_peers_we_need() { + let mut peer_list = make_fake_peer_list(500); + let must_keep_peers = HashSet::from_iter(peer_list.peers.iter().map(|(adr, _)| *adr)); + + let target_len = 49; + + peer_list.reduce_list(&must_keep_peers, target_len); + + // we can't remove any of the peers we said we need them all + assert_eq!(peer_list.len(), 500); +} + +#[test] +fn peer_list_get_peers_by_pruning_seed() { + let mut r = rand::thread_rng(); + + let peer_list = make_fake_peer_list_with_random_pruning_seeds(1000); + let seed = if r.gen_bool(0.4) { + 0 + } else { + r.gen_range(384..=391) + }; + + let peers_with_seed = peer_list + .get_peers_with_pruning(&seed) + .expect("If you hit this buy a lottery ticket"); + + for peer in peers_with_seed { + assert_eq!(peer.pruning_seed, seed); + } + + assert_eq!(peer_list.len(), 1000); +} + +#[test] +fn peer_list_remove_specific_peer() { + let mut peer_list = make_fake_peer_list_with_random_pruning_seeds(100); + + let peer = peer_list + .get_random_peer(&mut rand::thread_rng(), None) + .unwrap() + .clone(); + + assert!(peer_list.remove_peer(&peer.adr).is_some()); + + let pruning_idxs = peer_list.pruning_idxs; + let peers = peer_list.peers; + + for (_, addrs) in pruning_idxs { + addrs.iter().for_each(|adr| assert_ne!(adr, &peer.adr)) + } + + assert!(!peers.contains_key(&peer.adr)); +} + +#[test] +fn peer_list_pruning_idxs_are_correct() { + let peer_list = make_fake_peer_list_with_random_pruning_seeds(100); + let mut total_len = 0; + + for (seed, list) in peer_list.pruning_idxs { + for peer in list.iter() { + assert_eq!(peer_list.peers.get(peer).unwrap().pruning_seed, seed); + total_len += 1; + } + } + + assert_eq!(total_len, peer_list.peers.len()) +} + +#[test] +fn peer_list_add_new_peer() { + let mut peer_list = make_fake_peer_list(10); + let mut new_peer = PeerListEntryBase::default(); + let NetworkAddress::IPv4(ip) = &mut new_peer.adr else {panic!("this test requires default to be ipv4")}; + ip.m_ip += 50; + + peer_list.add_new_peer(new_peer.clone()); + + assert_eq!(peer_list.len(), 11); + assert_eq!(peer_list.get_peer(&new_peer.adr), Some(&new_peer)); + assert!(peer_list + .pruning_idxs + .get(&new_peer.pruning_seed) + .unwrap() + .contains(&new_peer.adr)); +} + +#[test] +fn peer_list_add_existing_peer() { + let mut peer_list = make_fake_peer_list(10); + let existing_peer = peer_list + .get_peer(&NetworkAddress::default()) + .unwrap() + .clone(); + + peer_list.add_new_peer(existing_peer.clone()); + + assert_eq!(peer_list.len(), 10); + assert_eq!(peer_list.get_peer(&existing_peer.adr), Some(&existing_peer)); +} + +#[test] +fn peer_list_get_non_existent_peer() { + let peer_list = make_fake_peer_list(10); + let mut non_existent_peer = NetworkAddress::default(); + let NetworkAddress::IPv4(ip) = &mut non_existent_peer else {panic!("this test requires default to be ipv4")}; + ip.m_ip += 50; + + assert_eq!(peer_list.get_peer(&non_existent_peer), None); +} + +#[test] +fn peer_list_ban_peers() { + let mut peer_list = make_fake_peer_list_with_random_pruning_seeds(100); + let peer = peer_list + .get_random_peer(&mut rand::thread_rng(), None) + .unwrap(); + let ban_id = peer.adr.ban_identifier(); + assert!(peer_list.contains_peer(&peer.adr)); + assert_ne!(peer_list.ban_id_idxs.get(&ban_id).unwrap().len(), 0); + peer_list.remove_peers_with_ban_id(&ban_id); + assert_eq!(peer_list.ban_id_idxs.get(&ban_id).unwrap().len(), 0); + for (addr, _) in peer_list.peers { + assert_ne!(addr.ban_identifier(), ban_id); + } +} diff --git a/p2p/src/address_book/address_book/tests.rs b/p2p/src/address_book/address_book/tests.rs new file mode 100644 index 0000000..acf7460 --- /dev/null +++ b/p2p/src/address_book/address_book/tests.rs @@ -0,0 +1,81 @@ +use super::*; +use crate::NetZoneBasicNodeData; +use monero_wire::network_address::IPv4Address; +use rand::Rng; + +fn create_random_net_address<R: Rng>(r: &mut R) -> NetworkAddress { + NetworkAddress::IPv4(IPv4Address { + m_ip: r.gen(), + m_port: r.gen(), + }) +} + +fn create_random_net_addr_vec<R: Rng>(r: &mut R, len: usize) -> Vec<NetworkAddress> { + let mut ret = Vec::with_capacity(len); + for i in 0..len { + ret.push(create_random_net_address(r)); + } + ret +} + +fn create_random_peer<R: Rng>(r: &mut R) -> PeerListEntryBase { + PeerListEntryBase { + adr: create_random_net_address(r), + pruning_seed: r.gen_range(384..=391), + id: PeerID(r.gen()), + last_seen: r.gen(), + rpc_port: r.gen(), + rpc_credits_per_hash: r.gen(), + } +} + +fn create_random_peer_vec<R: Rng>(r: &mut R, len: usize) -> Vec<PeerListEntryBase> { + let mut ret = Vec::with_capacity(len); + for i in 0..len { + ret.push(create_random_peer(r)); + } + ret +} + +#[derive(Clone)] +pub struct MockPeerStore; + +#[async_trait::async_trait] +impl P2PStore for MockPeerStore { + async fn basic_node_data(&mut self) -> Result<Option<NetZoneBasicNodeData>, &'static str> { + unimplemented!() + } + async fn save_basic_node_data( + &mut self, + node_id: &NetZoneBasicNodeData, + ) -> Result<(), &'static str> { + unimplemented!() + } + async fn load_peers( + &mut self, + zone: NetZone, + ) -> Result< + ( + Vec<PeerListEntryBase>, + Vec<PeerListEntryBase>, + Vec<NetworkAddress>, + ), + &'static str, + > { + let mut r = rand::thread_rng(); + Ok(( + create_random_peer_vec(&mut r, 300), + create_random_peer_vec(&mut r, 1500), + create_random_net_addr_vec(&mut r, 50), + )) + } + async fn save_peers( + &mut self, + zone: NetZone, + white: Vec<&PeerListEntryBase>, + gray: Vec<&PeerListEntryBase>, + anchor: Vec<&NetworkAddress>, + ) -> Result<(), &'static str> { + todo!() + } +} diff --git a/p2p/src/address_book/connection_handle.rs b/p2p/src/address_book/connection_handle.rs new file mode 100644 index 0000000..1f36155 --- /dev/null +++ b/p2p/src/address_book/connection_handle.rs @@ -0,0 +1,110 @@ +//! This module contains the address book [`Connection`](crate::peer::connection::Connection) handle +//! +//! # Why do we need a handle between the address book and connection task +//! +//! When banning a peer we need to tell the connection task to close and +//! when we close a connection we need to remove it from our connection +//! and anchor list. +//! +//! +use futures::channel::oneshot; +use tokio_util::sync::CancellationToken; + +/// A message sent to tell the address book that a peer has disconnected. +pub struct PeerConnectionClosed; + +/// The connection side of the address book to connection +/// communication. +#[derive(Debug)] +pub struct AddressBookConnectionHandle { + connection_closed: Option<oneshot::Sender<PeerConnectionClosed>>, + close: CancellationToken, +} + +impl AddressBookConnectionHandle { + /// Returns true if the address book has told us to kill the + /// connection. + pub fn is_canceled(&self) -> bool { + self.close.is_cancelled() + } +} + +impl Drop for AddressBookConnectionHandle { + fn drop(&mut self) { + let connection_closed = std::mem::replace(&mut self.connection_closed, None).unwrap(); + let _ = connection_closed.send(PeerConnectionClosed); + } +} + +/// The address book side of the address book to connection +/// communication. +#[derive(Debug)] +pub struct ConnectionAddressBookHandle { + connection_closed: oneshot::Receiver<PeerConnectionClosed>, + killer: CancellationToken, +} + +impl ConnectionAddressBookHandle { + /// Checks if the connection task has closed, returns + /// true if the task has closed + pub fn connection_closed(&mut self) -> bool { + let Ok(mes) = self.connection_closed.try_recv() else { + panic!("This must not be called again after returning true and the connection task must tell us if a connection is closed") + }; + match mes { + None => false, + Some(_) => true, + } + } + + /// Ends the connection task, the caller of this function should + /// wait to be told the connection has closed by [`check_if_connection_closed`](Self::check_if_connection_closed) + /// before acting on the closed connection. + pub fn kill_connection(&self) { + self.killer.cancel() + } +} + +/// Creates a new handle pair that can be given to the connection task and +/// address book respectively. +pub fn new_address_book_connection_handle( +) -> (AddressBookConnectionHandle, ConnectionAddressBookHandle) { + let (tx, rx) = oneshot::channel(); + let token = CancellationToken::new(); + + let ab_c_h = AddressBookConnectionHandle { + connection_closed: Some(tx), + close: token.clone(), + }; + let c_ab_h = ConnectionAddressBookHandle { + connection_closed: rx, + killer: token, + }; + + (ab_c_h, c_ab_h) +} + +#[cfg(test)] +mod tests { + use crate::address_book::connection_handle::new_address_book_connection_handle; + + #[test] + fn close_connection_from_address_book() { + let (conn_side, mut addr_side) = new_address_book_connection_handle(); + + assert!(!conn_side.is_canceled()); + assert!(!addr_side.connection_closed()); + addr_side.kill_connection(); + assert!(conn_side.is_canceled()); + } + + #[test] + fn close_connection_from_connection() { + let (conn_side, mut addr_side) = new_address_book_connection_handle(); + + assert!(!conn_side.is_canceled()); + assert!(!addr_side.connection_closed()); + drop(conn_side); + assert!(addr_side.connection_closed()); + } +} diff --git a/p2p/src/config.rs b/p2p/src/config.rs new file mode 100644 index 0000000..9d8db6b --- /dev/null +++ b/p2p/src/config.rs @@ -0,0 +1,78 @@ +use cuprate_common::Network; +use monero_wire::messages::{common::PeerSupportFlags, BasicNodeData, PeerID}; + +use crate::{ + constants::{ + CUPRATE_SUPPORT_FLAGS, DEFAULT_IN_PEERS, DEFAULT_LOAD_OUT_PEERS_MULTIPLIER, + DEFAULT_TARGET_OUT_PEERS, MAX_GRAY_LIST_PEERS, MAX_WHITE_LIST_PEERS, + }, + NodeID, +}; + +#[derive(Debug, Clone, Copy)] +pub struct Config { + /// Port + my_port: u32, + /// The Network + network: Network, + /// RPC Port + rpc_port: u16, + + target_out_peers: usize, + out_peers_load_multiplier: usize, + max_in_peers: usize, + max_white_peers: usize, + max_gray_peers: usize, +} + +impl Default for Config { + fn default() -> Self { + Config { + my_port: 18080, + network: Network::MainNet, + rpc_port: 18081, + target_out_peers: DEFAULT_TARGET_OUT_PEERS, + out_peers_load_multiplier: DEFAULT_LOAD_OUT_PEERS_MULTIPLIER, + max_in_peers: DEFAULT_IN_PEERS, + max_white_peers: MAX_WHITE_LIST_PEERS, + max_gray_peers: MAX_GRAY_LIST_PEERS, + } + } +} + +impl Config { + pub fn basic_node_data(&self, peer_id: PeerID) -> BasicNodeData { + BasicNodeData { + my_port: self.my_port, + network_id: self.network.network_id(), + peer_id, + support_flags: CUPRATE_SUPPORT_FLAGS, + rpc_port: self.rpc_port, + rpc_credits_per_hash: 0, + } + } + + pub fn peerset_total_connection_limit(&self) -> usize { + self.target_out_peers * self.out_peers_load_multiplier + self.max_in_peers + } + + pub fn network(&self) -> Network { + self.network + } + + pub fn max_white_peers(&self) -> usize { + self.max_white_peers + } + + pub fn max_gray_peers(&self) -> usize { + self.max_gray_peers + } + + pub fn public_port(&self) -> u32 { + self.my_port + } + + pub fn public_rpc_port(&self) -> u16 { + self.rpc_port + } +} diff --git a/p2p/src/connection_counter.rs b/p2p/src/connection_counter.rs new file mode 100644 index 0000000..922b47f --- /dev/null +++ b/p2p/src/connection_counter.rs @@ -0,0 +1,130 @@ +//! Counting active connections used by Cuprate. +//! +//! These types can be used to count any kind of active resource. +//! But they are currently used to track the number of open connections. + +use std::{fmt, sync::Arc}; + +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// A counter for active connections. +/// +/// Creates a [`ConnectionTracker`] to track each active connection. +/// When these trackers are dropped, the counter gets notified. +pub struct ActiveConnectionCounter { + /// The limit for this type of connection, for diagnostics only. + /// The caller must enforce the limit by ignoring, delaying, or dropping connections. + limit: usize, + + /// The label for this connection counter, typically its type. + label: Arc<str>, + + semaphore: Arc<Semaphore>, +} + +impl fmt::Debug for ActiveConnectionCounter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ActiveConnectionCounter") + .field("label", &self.label) + .field("count", &self.count()) + .field("limit", &self.limit) + .finish() + } +} + +impl ActiveConnectionCounter { + /// Create and return a new active connection counter. + pub fn new_counter() -> Self { + Self::new_counter_with(Semaphore::MAX_PERMITS, "Active Connections") + } + + /// Create and return a new active connection counter with `limit` and `label`. + /// The caller must check and enforce limits using [`update_count()`](Self::update_count). + pub fn new_counter_with<S: ToString>(limit: usize, label: S) -> Self { + let label = label.to_string(); + + Self { + limit, + label: label.into(), + semaphore: Arc::new(Semaphore::new(limit)), + } + } + + /// Create and return a new [`ConnectionTracker`], using a permit from the semaphore, + /// SAFETY: + /// This function will panic if the semaphore doesn't have anymore permits. + pub fn track_connection(&mut self) -> ConnectionTracker { + ConnectionTracker::new(self) + } + + pub fn count(&self) -> usize { + let count = self + .limit + .checked_sub(self.semaphore.available_permits()) + .expect("Limit is less than available connection permits"); + + tracing::trace!( + open_connections = ?count, + limit = ?self.limit, + label = ?self.label, + ); + + count + } + + pub fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } +} + +/// A per-connection tracker. +/// +/// [`ActiveConnectionCounter`] creates a tracker instance for each active connection. +pub struct ConnectionTracker { + /// The permit for this connection, updates the semaphore when dropped. + permit: OwnedSemaphorePermit, + + /// The label for this connection counter, typically its type. + label: Arc<str>, +} + +impl fmt::Debug for ConnectionTracker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ConnectionTracker") + .field(&self.label) + .finish() + } +} + +impl ConnectionTracker { + /// Create and return a new active connection tracker, and add 1 to `counter`. + /// All connection trackers share a label with their connection counter. + /// + /// When the returned tracker is dropped, `counter` will be notified. + /// + /// SAFETY: + /// This function will panic if the [`ActiveConnectionCounter`] doesn't have anymore permits. + fn new(counter: &mut ActiveConnectionCounter) -> Self { + tracing::debug!( + open_connections = ?counter.count(), + limit = ?counter.limit, + label = ?counter.label, + "opening a new peer connection", + ); + + Self { + permit: counter.semaphore.clone().try_acquire_owned().unwrap(), + label: counter.label.clone(), + } + } +} + +impl Drop for ConnectionTracker { + fn drop(&mut self) { + tracing::debug!( + label = ?self.label, + "A peer connection has closed", + ); + // the permit is automatically dropped + } +} \ No newline at end of file diff --git a/p2p/src/connection_handle.rs b/p2p/src/connection_handle.rs new file mode 100644 index 0000000..f3d3601 --- /dev/null +++ b/p2p/src/connection_handle.rs @@ -0,0 +1,98 @@ +//! +//! # Why do we need a handle between the address book and connection task +//! +//! When banning a peer we need to tell the connection task to close and +//! when we close a connection we need to tell the address book. +//! +//! +use std::time::Duration; + +use futures::channel::mpsc; +use futures::SinkExt; +use tokio_util::sync::CancellationToken; + +use crate::connection_counter::ConnectionTracker; + +#[derive(Default, Debug)] +pub struct HandleBuilder { + tracker: Option<ConnectionTracker>, +} + +impl HandleBuilder { + pub fn set_tracker(&mut self, tracker: ConnectionTracker) { + self.tracker = Some(tracker) + } + + pub fn build(self) -> (DisconnectSignal, ConnectionHandle, PeerHandle) { + let token = CancellationToken::new(); + let (tx, rx) = mpsc::channel(0); + + ( + DisconnectSignal { + token: token.clone(), + tracker: self.tracker.expect("Tracker was not set!"), + }, + ConnectionHandle { + token: token.clone(), + ban: rx, + }, + PeerHandle { ban: tx }, + ) + } +} + +pub struct BanPeer(pub Duration); + +/// A struct given to the connection task. +pub struct DisconnectSignal { + token: CancellationToken, + tracker: ConnectionTracker, +} + +impl DisconnectSignal { + pub fn should_shutdown(&self) -> bool { + self.token.is_cancelled() + } + pub fn connection_closed(&self) { + self.token.cancel() + } +} + +impl Drop for DisconnectSignal { + fn drop(&mut self) { + self.token.cancel() + } +} + +/// A handle given to a task that needs to cancel this connection. +pub struct ConnectionHandle { + token: CancellationToken, + ban: mpsc::Receiver<BanPeer>, +} + +impl ConnectionHandle { + pub fn is_closed(&self) -> bool { + self.token.is_cancelled() + } + pub fn check_should_ban(&mut self) -> Option<BanPeer> { + match self.ban.try_next() { + Ok(res) => res, + Err(_) => None, + } + } + pub fn send_close_signal(&self) { + self.token.cancel() + } +} + +/// A handle given to a task that needs to be able to ban a connection. +#[derive(Clone)] +pub struct PeerHandle { + ban: mpsc::Sender<BanPeer>, +} + +impl PeerHandle { + pub fn ban_peer(&mut self, duration: Duration) { + let _ = self.ban.send(BanPeer(duration)); + } +} diff --git a/p2p/src/constants.rs b/p2p/src/constants.rs new file mode 100644 index 0000000..4d3c900 --- /dev/null +++ b/p2p/src/constants.rs @@ -0,0 +1,58 @@ +use core::time::Duration; + +use monero_wire::messages::common::PeerSupportFlags; + +pub const CUPRATE_SUPPORT_FLAGS: PeerSupportFlags = + PeerSupportFlags::get_support_flag_fluffy_blocks(); + +pub const CUPRATE_MINIMUM_SUPPORT_FLAGS: PeerSupportFlags = + PeerSupportFlags::get_support_flag_fluffy_blocks(); + +pub const DEFAULT_TARGET_OUT_PEERS: usize = 20; + +pub const DEFAULT_LOAD_OUT_PEERS_MULTIPLIER: usize = 3; + +pub const DEFAULT_IN_PEERS: usize = 20; + +pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5); + +pub const ADDRESS_BOOK_SAVE_INTERVAL: Duration = Duration::from_secs(60); + +pub const ADDRESS_BOOK_BUFFER_SIZE: usize = 3; + +pub const PEERSET_BUFFER_SIZE: usize = 3; + +/// The maximum size of the address books white list. +/// This number is copied from monerod. +pub const MAX_WHITE_LIST_PEERS: usize = 1000; + +/// The maximum size of the address books gray list. +/// This number is copied from monerod. +pub const MAX_GRAY_LIST_PEERS: usize = 5000; + +/// The max amount of peers that can be sent in one +/// message. +pub const P2P_MAX_PEERS_IN_HANDSHAKE: usize = 250; + +/// The timeout for sending a message to a remote peer, +/// and receiving a response from a remote peer. +pub const REQUEST_TIMEOUT: Duration = Duration::from_secs(20); + +/// The default RTT estimate for peer responses. +/// +/// We choose a high value for the default RTT, so that new peers must prove they +/// are fast, before we prefer them to other peers. This is particularly +/// important on testnet, which has a small number of peers, which are often +/// slow. +/// +/// Make the default RTT slightly higher than the request timeout. +pub const EWMA_DEFAULT_RTT: Duration = Duration::from_secs(REQUEST_TIMEOUT.as_secs() + 1); + +/// The decay time for the EWMA response time metric used for load balancing. +/// +/// This should be much larger than the `SYNC_RESTART_TIMEOUT`, so we choose +/// better peers when we restart the sync. +pub const EWMA_DECAY_TIME_NANOS: f64 = 200.0 * NANOS_PER_SECOND; + +/// The number of nanoseconds in one second. +const NANOS_PER_SECOND: f64 = 1_000_000_000.0; diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 7f61797..cf1fc44 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -1,3 +1,81 @@ pub mod address_book; +pub mod config; +pub mod connection_counter; +mod connection_handle; +mod constants; pub mod peer; mod protocol; + +pub use config::Config; +use rand::Rng; + +#[derive(Debug, Clone)] +pub struct NetZoneBasicNodeData { + public: monero_wire::BasicNodeData, + tor: monero_wire::BasicNodeData, + i2p: monero_wire::BasicNodeData, +} + +impl NetZoneBasicNodeData { + pub fn basic_node_data(&self, net_zone: &monero_wire::NetZone) -> monero_wire::BasicNodeData { + match net_zone { + monero_wire::NetZone::Public => self.public.clone(), + _ => todo!(), + } + } + pub fn new(config: &Config, node_id: &NodeID) -> Self { + let bnd = monero_wire::BasicNodeData { + my_port: config.public_port(), + network_id: config.network().network_id(), + peer_id: node_id.public, + support_flags: constants::CUPRATE_SUPPORT_FLAGS, + rpc_port: config.public_rpc_port(), + rpc_credits_per_hash: 0, + }; + + // obviously this is wrong, i will change when i add tor support + NetZoneBasicNodeData { + public: bnd.clone(), + tor: bnd.clone(), + i2p: bnd, + } + } +} + +#[async_trait::async_trait] +pub trait P2PStore: Clone + Send + 'static { + /// Loads the peers from the peer store. + /// returns (in order): + /// the white list, + /// the gray list, + /// the anchor list, + /// the ban list + async fn load_peers( + &mut self, + zone: monero_wire::NetZone, + ) -> Result< + ( + Vec<monero_wire::PeerListEntryBase>, // white list + Vec<monero_wire::PeerListEntryBase>, // gray list + Vec<monero_wire::NetworkAddress>, // anchor list + // Vec<(monero_wire::NetworkAddress, chrono::NaiveDateTime)>, // ban list + ), + &'static str, + >; + + async fn save_peers( + &mut self, + zone: monero_wire::NetZone, + white: Vec<&monero_wire::PeerListEntryBase>, + gray: Vec<&monero_wire::PeerListEntryBase>, + anchor: Vec<&monero_wire::NetworkAddress>, + // bans: Vec<(&monero_wire::NetworkAddress, &chrono::NaiveDateTime)>, // ban lists + ) -> Result<(), &'static str>; + + async fn basic_node_data(&mut self) -> Result<Option<NetZoneBasicNodeData>, &'static str>; + + async fn save_basic_node_data( + &mut self, + node_id: &NetZoneBasicNodeData, + ) -> Result<(), &'static str>; +} diff --git a/p2p/src/peer.rs b/p2p/src/peer.rs index 5bb16aa..5d1f2ae 100644 --- a/p2p/src/peer.rs +++ b/p2p/src/peer.rs @@ -1,42 +1,16 @@ pub mod client; pub mod connection; +pub mod connector; pub mod handshaker; +pub mod load_tracked_client; +mod error; #[cfg(test)] mod tests; -use monero_wire::levin::BucketError; -use thiserror::Error; - -#[derive(Debug, Error, Clone, Copy)] -pub enum RequestServiceError {} - -#[derive(Debug, Error, Clone, Copy)] -pub enum PeerError { - #[error("Peer is on a different network")] - PeerIsOnAnotherNetwork, - #[error("Peer sent an unexpected response")] - PeerSentUnSolicitedResponse, - #[error("Internal service did not respond when required")] - InternalServiceDidNotRespond, - #[error("Connection to peer has been terminated")] - PeerConnectionClosed, - #[error("The Client `internal` channel was closed")] - ClientChannelClosed, - #[error("The Peer sent an unexpected response")] - PeerSentUnexpectedResponse, - #[error("The peer sent a bad response: {0}")] - ResponseError(&'static str), - #[error("Internal service error: {0}")] - InternalService(#[from] RequestServiceError), - #[error("Internal peer sync channel closed")] - InternalPeerSyncChannelClosed, - #[error("Levin Error")] - LevinError, // remove me, this is just temporary -} - -impl From<BucketError> for PeerError { - fn from(_: BucketError) -> Self { - PeerError::LevinError - } -} +pub use client::Client; +pub use client::ConnectionInfo; +pub use connection::Connection; +pub use connector::{Connector, OutboundConnectorRequest}; +pub use handshaker::Handshaker; +pub use load_tracked_client::LoadTrackedClient; diff --git a/p2p/src/peer/client.rs b/p2p/src/peer/client.rs index 163ebd4..b79a80c 100644 --- a/p2p/src/peer/client.rs +++ b/p2p/src/peer/client.rs @@ -1,45 +1,150 @@ use std::pin::Pin; +use std::sync::atomic::AtomicU64; +use std::task::{Context, Poll}; use std::{future::Future, sync::Arc}; -use crate::protocol::{InternalMessageRequest, InternalMessageResponse}; use futures::{ channel::{mpsc, oneshot}, FutureExt, }; -use monero_wire::messages::PeerID; +use tokio::task::JoinHandle; +use tower::BoxError; + +use cuprate_common::PruningSeed; use monero_wire::{messages::common::PeerSupportFlags, NetworkAddress}; -use super::{connection::ClientRequest, PeerError}; +use super::{ + connection::ClientRequest, + error::{ErrorSlot, PeerError, SharedPeerError}, + PeerError, +}; +use crate::connection_handle::PeerHandle; +use crate::protocol::{InternalMessageRequest, InternalMessageResponse}; pub struct ConnectionInfo { - pub addr: NetworkAddress, pub support_flags: PeerSupportFlags, - /// Peer ID - pub peer_id: PeerID, + pub pruning_seed: PruningSeed, + pub handle: PeerHandle, pub rpc_port: u16, pub rpc_credits_per_hash: u32, } pub struct Client { pub connection_info: Arc<ConnectionInfo>, + /// Used to shut down the corresponding heartbeat. + /// This is always Some except when we take it on drop. + heartbeat_shutdown_tx: Option<oneshot::Sender<()>>, server_tx: mpsc::Sender<ClientRequest>, + connection_task: JoinHandle<()>, + heartbeat_task: JoinHandle<()>, + + error_slot: ErrorSlot, } impl Client { pub fn new( connection_info: Arc<ConnectionInfo>, + heartbeat_shutdown_tx: oneshot::Sender<()>, server_tx: mpsc::Sender<ClientRequest>, + connection_task: JoinHandle<()>, + heartbeat_task: JoinHandle<()>, + error_slot: ErrorSlot, ) -> Self { Client { connection_info, + heartbeat_shutdown_tx: Some(heartbeat_shutdown_tx), server_tx, + connection_task, + heartbeat_task, + error_slot, + } + } + + /// Check if this connection's heartbeat task has exited. + #[allow(clippy::unwrap_in_result)] + fn check_heartbeat(&mut self, cx: &mut Context<'_>) -> Result<(), SharedPeerError> { + let is_canceled = self + .heartbeat_shutdown_tx + .as_mut() + .expect("only taken on drop") + .poll_canceled(cx) + .is_ready(); + + if is_canceled { + return self.set_task_exited_error( + "heartbeat", + PeerError::HeartbeatTaskExited("Task was cancelled".to_string()), + ); + } + + match self.heartbeat_task.poll_unpin(cx) { + Poll::Pending => { + // Heartbeat task is still running. + Ok(()) + } + Poll::Ready(Ok(Ok(_))) => { + // Heartbeat task stopped unexpectedly, without panic or error. + self.set_task_exited_error( + "heartbeat", + PeerError::HeartbeatTaskExited( + "Heartbeat task stopped unexpectedly".to_string(), + ), + ) + } + Poll::Ready(Ok(Err(error))) => { + // Heartbeat task stopped unexpectedly, with error. + self.set_task_exited_error( + "heartbeat", + PeerError::HeartbeatTaskExited(error.to_string()), + ) + } + Poll::Ready(Err(error)) => { + // Heartbeat task was cancelled. + if error.is_cancelled() { + self.set_task_exited_error( + "heartbeat", + PeerError::HeartbeatTaskExited("Task was cancelled".to_string()), + ) + } + // Heartbeat task stopped with panic. + else if error.is_panic() { + panic!("heartbeat task has panicked: {error}"); + } + // Heartbeat task stopped with error. + else { + self.set_task_exited_error( + "heartbeat", + PeerError::HeartbeatTaskExited(error.to_string()), + ) + } + } + } + } + + /// Check if the connection's task has exited. + fn check_connection(&mut self, context: &mut Context<'_>) -> Result<(), PeerError> { + match self.connection_task.poll_unpin(context) { + Poll::Pending => { + // Connection task is still running. + Ok(()) + } + Poll::Ready(Ok(())) => { + // Connection task stopped unexpectedly, without panicking. + return Err(PeerError::ConnectionTaskClosed); + } + Poll::Ready(Err(error)) => { + // Connection task stopped unexpectedly with a panic. shut the node down. + tracing::error!("Peer Connection task panicked: {error}, shutting the node down!"); + set_shutting_down(); + return Err(PeerError::ConnectionTaskClosed); + } } } } impl tower::Service<InternalMessageRequest> for Client { - type Error = PeerError; type Response = InternalMessageResponse; + type Error = SharedPeerError; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; @@ -49,7 +154,7 @@ impl tower::Service<InternalMessageRequest> for Client { ) -> std::task::Poll<Result<(), Self::Error>> { self.server_tx .poll_ready(cx) - .map_err(|e| PeerError::ClientChannelClosed) + .map_err(|e| PeerError::ClientChannelClosed.into()) } fn call(&mut self, req: InternalMessageRequest) -> Self::Future { let (tx, rx) = oneshot::channel(); @@ -59,11 +164,12 @@ impl tower::Service<InternalMessageRequest> for Client { .map(|recv_result| { recv_result .expect("ClientRequest oneshot sender must not be dropped before send") + .map_err(|e| e.into()) }) .boxed(), - Err(_e) => { + Err(_) => { // TODO: better error handling - futures::future::ready(Err(PeerError::ClientChannelClosed)).boxed() + futures::future::ready(Err(PeerError::ClientChannelClosed.into())).boxed() } } } diff --git a/p2p/src/peer/connection.rs b/p2p/src/peer/connection.rs index 28c1d1e..d518d2e 100644 --- a/p2p/src/peer/connection.rs +++ b/p2p/src/peer/connection.rs @@ -1,116 +1,78 @@ -use std::collections::HashSet; - use futures::channel::{mpsc, oneshot}; -use futures::stream::Fuse; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::stream::FusedStream; +use futures::{Sink, SinkExt, Stream, StreamExt}; -use levin::{MessageSink, MessageStream}; -use monero_wire::messages::CoreSyncData; -use monero_wire::{levin, Message, NetworkAddress}; -use tower::{Service, ServiceExt}; +use monero_wire::{Message, BucketError}; +use tower::{BoxError, Service}; -use crate::protocol::{ - InternalMessageRequest, InternalMessageResponse, BLOCKS_IDS_SYNCHRONIZING_MAX_COUNT, - P2P_MAX_PEERS_IN_HANDSHAKE, -}; - -use super::PeerError; - -pub enum PeerSyncChange { - CoreSyncData(NetworkAddress, CoreSyncData), - ObjectsResponse(NetworkAddress, Vec<[u8; 32]>, u64), - PeerDisconnected(NetworkAddress), -} +use crate::connection_handle::DisconnectSignal; +use crate::peer::error::{ErrorSlot, PeerError, SharedPeerError}; +use crate::peer::handshaker::ConnectionAddr; +use crate::protocol::internal_network::{MessageID, Request, Response}; pub struct ClientRequest { - pub req: InternalMessageRequest, - pub tx: oneshot::Sender<Result<InternalMessageResponse, PeerError>>, + pub req: Request, + pub tx: oneshot::Sender<Result<Response, SharedPeerError>>, } pub enum State { WaitingForRequest, WaitingForResponse { - request: InternalMessageRequest, - tx: oneshot::Sender<Result<InternalMessageResponse, PeerError>>, + request_id: MessageID, + tx: oneshot::Sender<Result<Response, SharedPeerError>>, }, } -impl State { - pub fn expected_response_id(&self) -> Option<u32> { - match self { - Self::WaitingForRequest => None, - Self::WaitingForResponse { request, tx: _ } => request.expected_id(), - } - } -} - -pub struct Connection<Svc, Aw, Ar> { - address: NetworkAddress, +pub struct Connection<Svc, Snk> { + address: ConnectionAddr, state: State, - sink: MessageSink<Aw, Message>, - stream: Fuse<MessageStream<Ar, Message>>, + sink: Snk, client_rx: mpsc::Receiver<ClientRequest>, - sync_state_tx: mpsc::Sender<PeerSyncChange>, + + error_slot: ErrorSlot, + + /// # Security + /// + /// If this connection tracker or `Connection`s are leaked, + /// the number of active connections will appear higher than it actually is. + /// If enough connections leak, Cuprate will stop making new connections. + connection_tracker: DisconnectSignal, + svc: Svc, } -impl<Svc, Aw, Ar> Connection<Svc, Aw, Ar> +impl<Svc, Snk> Connection<Svc, Snk> where - Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = PeerError>, - Aw: AsyncWrite + std::marker::Unpin, - Ar: AsyncRead + std::marker::Unpin, + Svc: Service<Request, Response = Response, Error = BoxError>, + Snk: Sink<Message, Error = BucketError> + Unpin, { pub fn new( - address: NetworkAddress, - sink: MessageSink<Aw, Message>, - stream: MessageStream<Ar, Message>, + address: ConnectionAddr, + sink: Snk, client_rx: mpsc::Receiver<ClientRequest>, - sync_state_tx: mpsc::Sender<PeerSyncChange>, + error_slot: ErrorSlot, + connection_tracker: DisconnectSignal, svc: Svc, - ) -> Connection<Svc, Aw, Ar> { + ) -> Connection<Svc, Snk> { Connection { address, state: State::WaitingForRequest, sink, - stream: stream.fuse(), client_rx, - sync_state_tx, + error_slot, + connection_tracker, svc, } } - async fn handle_response(&mut self, res: InternalMessageResponse) -> Result<(), PeerError> { + async fn handle_response(&mut self, res: Response) -> Result<(), PeerError> { let state = std::mem::replace(&mut self.state, State::WaitingForRequest); - if let State::WaitingForResponse { request, tx } = state { - match (request, &res) { - (InternalMessageRequest::Handshake(_), InternalMessageResponse::Handshake(_)) => {} - ( - InternalMessageRequest::SupportFlags(_), - InternalMessageResponse::SupportFlags(_), - ) => {} - (InternalMessageRequest::TimedSync(_), InternalMessageResponse::TimedSync(res)) => { - } - ( - InternalMessageRequest::GetObjectsRequest(req), - InternalMessageResponse::GetObjectsResponse(res), - ) => {} - ( - InternalMessageRequest::ChainRequest(_), - InternalMessageResponse::ChainResponse(res), - ) => {} - ( - InternalMessageRequest::FluffyMissingTransactionsRequest(req), - InternalMessageResponse::NewFluffyBlock(blk), - ) => {} - ( - InternalMessageRequest::GetTxPoolCompliment(_), - InternalMessageResponse::NewTransactions(_), - ) => { - // we could check we received no transactions that we said we knew about but thats going to happen later anyway when they get added to our - // mempool - } - _ => return Err(PeerError::ResponseError("Peer sent incorrect response")), + if let State::WaitingForResponse { request_id, tx } = state { + if request_id != res.id() { + // TODO: Fail here + return Err(PeerError::PeerSentIncorrectResponse); } - // response passed our tests we can send it to the requestor + + // response passed our tests we can send it to the requester let _ = tx.send(Ok(res)); Ok(()) } else { @@ -122,30 +84,36 @@ where Ok(self.sink.send(mes.into()).await?) } - async fn handle_peer_request(&mut self, req: InternalMessageRequest) -> Result<(), PeerError> { + async fn handle_peer_request(&mut self, req: Request) -> Result<(), PeerError> { // we should check contents of peer requests for obvious errors like we do with responses + todo!() + /* let ready_svc = self.svc.ready().await?; let res = ready_svc.call(req).await?; self.send_message_to_peer(res).await + */ } async fn handle_client_request(&mut self, req: ClientRequest) -> Result<(), PeerError> { - // check we need a response - if let Some(_) = req.req.expected_id() { + if req.req.needs_response() { self.state = State::WaitingForResponse { - request: req.req.clone(), + request_id: req.req.id(), tx: req.tx, }; } + // TODO: send NA response to requester self.send_message_to_peer(req.req).await } - async fn state_waiting_for_request(&mut self) -> Result<(), PeerError> { + async fn state_waiting_for_request<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError> + where + Str: FusedStream<Item = Result<Message, BucketError>> + Unpin, + { futures::select! { - peer_message = self.stream.next() => { + peer_message = stream.next() => { match peer_message.expect("MessageStream will never return None") { Ok(message) => { - self.handle_peer_request(message.try_into().map_err(|_| PeerError::PeerSentUnexpectedResponse)?).await + self.handle_peer_request(message.try_into().map_err(|_| PeerError::ResponseError(""))?).await }, Err(e) => Err(e.into()), } @@ -156,10 +124,12 @@ where } } - async fn state_waiting_for_response(&mut self) -> Result<(), PeerError> { + async fn state_waiting_for_response<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError> + where + Str: FusedStream<Item = Result<Message, BucketError>> + Unpin, + { // put a timeout on this - let peer_message = self - .stream + let peer_message = stream .next() .await .expect("MessageStream will never return None")?; @@ -183,12 +153,15 @@ where } } - pub async fn run(mut self) { + pub async fn run<Str>(mut self, mut stream: Str) + where + Str: FusedStream<Item = Result<Message, BucketError>> + Unpin, + { loop { let _res = match self.state { - State::WaitingForRequest => self.state_waiting_for_request().await, - State::WaitingForResponse { request: _, tx: _ } => { - self.state_waiting_for_response().await + State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await, + State::WaitingForResponse { .. } => { + self.state_waiting_for_response(&mut stream).await } }; } diff --git a/p2p/src/peer/connector.rs b/p2p/src/peer/connector.rs new file mode 100644 index 0000000..28f09f9 --- /dev/null +++ b/p2p/src/peer/connector.rs @@ -0,0 +1,159 @@ +//! Wrapper around handshake logic that also opens a TCP connection. + +use std::{ + future::Future, + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{AsyncRead, AsyncWrite, FutureExt}; +use monero_wire::{network_address::NetZone, NetworkAddress}; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; +use tower::{BoxError, Service, ServiceExt}; +use tracing::Instrument; + +use crate::peer::handshaker::ConnectionAddr; +use crate::{ + address_book::{AddressBookRequest, AddressBookResponse}, + connection_counter::ConnectionTracker, + protocol::{ + CoreSyncDataRequest, CoreSyncDataResponse, InternalMessageRequest, InternalMessageResponse, + }, +}; + +use super::{ + handshaker::{DoHandshakeRequest, Handshaker}, + Client, +}; + +async fn connect(addr: &NetworkAddress) -> Result<(impl AsyncRead, impl AsyncWrite), BoxError> { + match addr.get_zone() { + NetZone::Public => { + let stream = + tokio::net::TcpStream::connect(SocketAddr::try_from(*addr).unwrap()).await?; + let (read, write) = stream.into_split(); + Ok((read.compat(), write.compat_write())) + } + _ => unimplemented!(), + } +} + +/// A wrapper around [`Handshake`] that opens a connection before +/// forwarding to the inner handshake service. Writing this as its own +/// [`tower::Service`] lets us apply unified timeout policies, etc. +#[derive(Debug, Clone)] +pub struct Connector<Svc, CoreSync, AdrBook> +where + CoreSync: Service<CoreSyncDataRequest, Response = CoreSyncDataResponse, Error = BoxError> + + Clone + + Send + + 'static, + CoreSync::Future: Send, + + Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = BoxError> + + Clone + + Send + + 'static, + Svc::Future: Send, + + AdrBook: Service<AddressBookRequest, Response = AddressBookResponse, Error = BoxError> + + Clone + + Send + + 'static, + AdrBook::Future: Send, +{ + handshaker: Handshaker<Svc, CoreSync, AdrBook>, +} + +impl<Svc, CoreSync, AdrBook> Connector<Svc, CoreSync, AdrBook> +where + CoreSync: Service<CoreSyncDataRequest, Response = CoreSyncDataResponse, Error = BoxError> + + Clone + + Send + + 'static, + CoreSync::Future: Send, + + Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = BoxError> + + Clone + + Send + + 'static, + Svc::Future: Send, + + AdrBook: Service<AddressBookRequest, Response = AddressBookResponse, Error = BoxError> + + Clone + + Send + + 'static, + AdrBook::Future: Send, +{ + pub fn new(handshaker: Handshaker<Svc, CoreSync, AdrBook>) -> Self { + Connector { handshaker } + } +} + +/// A connector request. +/// Contains the information needed to make an outbound connection to the peer. +pub struct OutboundConnectorRequest { + /// The Monero listener address of the peer. + pub addr: NetworkAddress, + + /// A connection tracker that reduces the open connection count when dropped. + /// + /// Used to limit the number of open connections in Cuprate. + pub connection_tracker: ConnectionTracker, +} + +impl<Svc, CoreSync, AdrBook> Service<OutboundConnectorRequest> for Connector<Svc, CoreSync, AdrBook> +where + CoreSync: Service<CoreSyncDataRequest, Response = CoreSyncDataResponse, Error = BoxError> + + Clone + + Send + + 'static, + CoreSync::Future: Send, + + Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = BoxError> + + Clone + + Send + + 'static, + Svc::Future: Send, + + AdrBook: Service<AddressBookRequest, Response = AddressBookResponse, Error = BoxError> + + Clone + + Send + + 'static, + AdrBook::Future: Send, +{ + type Response = (NetworkAddress, Client); + type Error = BoxError; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: OutboundConnectorRequest) -> Self::Future { + let OutboundConnectorRequest { + addr: address, + connection_tracker, + }: OutboundConnectorRequest = req; + + let hs = self.handshaker.clone(); + let connector_span = tracing::info_span!("connector", peer = ?address); + + async move { + let (read, write) = connect(&address).await?; + let client = hs + .oneshot(DoHandshakeRequest { + read, + write, + addr: ConnectionAddr::OutBound { address }, + connection_tracker, + }) + .await?; + Ok((address, client)) + } + .instrument(connector_span) + .boxed() + } +} diff --git a/p2p/src/peer/error.rs b/p2p/src/peer/error.rs new file mode 100644 index 0000000..bbf3650 --- /dev/null +++ b/p2p/src/peer/error.rs @@ -0,0 +1,116 @@ +use std::sync::{Arc, Mutex}; + +use monero_wire::BucketError; +use thiserror::Error; +use tracing_error::TracedError; + +/// A wrapper around `Arc<PeerError>` that implements `Error`. +#[derive(Error, Debug, Clone)] +#[error(transparent)] +pub struct SharedPeerError(Arc<TracedError<PeerError>>); + +impl<E> From<E> for SharedPeerError +where + PeerError: From<E>, +{ + fn from(source: E) -> Self { + Self(Arc::new(TracedError::from(PeerError::from(source)))) + } +} + +impl SharedPeerError { + /// Returns a debug-formatted string describing the inner [`PeerError`]. + /// + /// Unfortunately, [`TracedError`] makes it impossible to get a reference to the original error. + pub fn inner_debug(&self) -> String { + format!("{:?}", self.0.as_ref()) + } +} + +#[derive(Debug, Error)] +pub enum PeerError { + #[error("The connection task has closed.")] + ConnectionTaskClosed, + #[error("Error with peers response: {0}.")] + ResponseError(&'static str), + #[error("The connected peer sent an an unexpected response message.")] + PeerSentUnexpectedResponse, + #[error("The connected peer sent an incorrect response.")] + BucketError(#[from] BucketError), + #[error("The channel was closed.")] + ClientChannelClosed, +} + +/// A shared error slot for peer errors. +/// +/// # Correctness +/// +/// Error slots are shared between sync and async code. In async code, the error +/// mutex should be held for as short a time as possible. This avoids blocking +/// the async task thread on acquiring the mutex. +/// +/// > If the value behind the mutex is just data, it’s usually appropriate to use a blocking mutex +/// > ... +/// > wrap the `Arc<Mutex<...>>` in a struct +/// > that provides non-async methods for performing operations on the data within, +/// > and only lock the mutex inside these methods +/// +/// <https://docs.rs/tokio/1.15.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use> +#[derive(Default, Clone)] +pub struct ErrorSlot(Arc<std::sync::Mutex<Option<SharedPeerError>>>); + +impl std::fmt::Debug for ErrorSlot { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // don't hang if the mutex is locked + // show the panic if the mutex was poisoned + f.debug_struct("ErrorSlot") + .field("error", &self.0.try_lock()) + .finish() + } +} + +impl ErrorSlot { + /// Read the current error in the slot. + /// + /// Returns `None` if there is no error in the slot. + /// + /// # Correctness + /// + /// Briefly locks the error slot's threaded `std::sync::Mutex`, to get a + /// reference to the error in the slot. + #[allow(clippy::unwrap_in_result)] + pub fn try_get_error(&self) -> Option<SharedPeerError> { + self.0 + .lock() + .expect("error mutex should be unpoisoned") + .as_ref() + .cloned() + } + + /// Update the current error in the slot. + /// + /// Returns `Err(AlreadyErrored)` if there was already an error in the slot. + /// + /// # Correctness + /// + /// Briefly locks the error slot's threaded `std::sync::Mutex`, to check for + /// a previous error, then update the error in the slot. + #[allow(clippy::unwrap_in_result)] + pub fn try_update_error(&self, e: SharedPeerError) -> Result<(), AlreadyErrored> { + let mut guard = self.0.lock().expect("error mutex should be unpoisoned"); + + if let Some(original_error) = guard.clone() { + Err(AlreadyErrored { original_error }) + } else { + *guard = Some(e); + Ok(()) + } + } +} + +/// Error returned when the [`ErrorSlot`] already contains an error. +#[derive(Clone, Debug)] +pub struct AlreadyErrored { + /// The original error in the error slot. + pub original_error: SharedPeerError, +} diff --git a/p2p/src/peer/handshaker.rs b/p2p/src/peer/handshaker.rs index fdf8411..e1b4641 100644 --- a/p2p/src/peer/handshaker.rs +++ b/p2p/src/peer/handshaker.rs @@ -1,274 +1,360 @@ +/// This module contains the logic for turning [`AsyncRead`] and [`AsyncWrite`] +/// into [`Client`] and [`Connection`]. +/// +/// The main entry point is modeled as a [`tower::Service`] the struct being +/// [`Handshaker`]. The [`Handshaker`] accepts handshake requests: [`DoHandshakeRequest`] +/// and creates a state machine that's drives the handshake forward: [`HandshakeSM`] and +/// eventually outputs a [`Client`] and [`Connection`]. +/// use std::future::Future; +use std::net::SocketAddr; use std::pin::Pin; -use std::sync::Arc; -use futures::FutureExt; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; -use monero_wire::messages::admin::{SupportFlagsRequest, SupportFlagsResponse}; -use monero_wire::messages::MessageRequest; +use futures::{channel::mpsc, sink::Sink, SinkExt, Stream}; +use futures::{FutureExt, StreamExt}; use thiserror::Error; -use tokio::time; -use tower::{Service, ServiceExt}; - -use crate::address_book::{AddressBookError, AddressBookRequest, AddressBookResponse}; -use crate::protocol::temp_database::{DataBaseRequest, DataBaseResponse, DatabaseError}; -use crate::protocol::{ - Direction, InternalMessageRequest, InternalMessageResponse, P2P_MAX_PEERS_IN_HANDSHAKE, +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, }; -use cuprate_common::{HardForks, Network, PruningSeed}; +use tokio_util::codec::{FramedRead, FramedWrite}; +use tower::{BoxError, Service, ServiceExt}; +use tracing::Instrument; + +use cuprate_common::{Network, PruningSeed}; +use monero_wire::messages::admin::SupportFlagsResponse; use monero_wire::{ - levin::{BucketError, MessageSink, MessageStream}, messages::{ admin::{HandshakeRequest, HandshakeResponse}, common::PeerSupportFlags, - BasicNodeData, CoreSyncData, MessageResponse, PeerID, PeerListEntryBase, + BasicNodeData, CoreSyncData, PeerID, PeerListEntryBase, }, - Message, NetworkAddress, + BucketError, Message, MoneroWireCodec, NetZone, NetworkAddress, RequestMessage, + ResponseMessage, }; -use tracing::Instrument; -use super::client::Client; use super::{ - client::ConnectionInfo, - connection::{ClientRequest, Connection, PeerSyncChange}, + client::{Client, ConnectionInfo}, + connection::Connection, PeerError, }; +use crate::address_book::connection_handle::new_address_book_connection_handle; +use crate::address_book::{AddressBookRequest, AddressBookResponse}; +use crate::connection_counter::ConnectionTracker; +use crate::constants::{ + CUPRATE_MINIMUM_SUPPORT_FLAGS, HANDSHAKE_TIMEOUT, P2P_MAX_PEERS_IN_HANDSHAKE, +}; +use crate::protocol::{ + CoreSyncDataRequest, CoreSyncDataResponse, Direction, InternalMessageRequest, + InternalMessageResponse, +}; +use crate::NetZoneBasicNodeData; +/// Possible handshake errors #[derive(Debug, Error)] pub enum HandShakeError { + /// The peer did not complete the handshake fast enough. #[error("The peer did not complete the handshake fast enough")] PeerTimedOut, + /// The Peer has non-standard pruning. #[error("The peer has a weird pruning scheme")] PeerClaimedWeirdPruning, - #[error("The peer has an unexpected top version")] - PeerHasUnexpectedTopVersion, + /// The peer does not have the minimum support flags #[error("The peer does not have the minimum support flags")] PeerDoesNotHaveTheMinimumSupportFlags, + /// The peer is not on the network we are on (MAINNET|TESTNET|STAGENET) #[error("The peer is on a different network")] PeerIsOnADifferentNetwork, - #[error("Address book err: {0}")] - AddressBookError(#[from] AddressBookError), + /// The peer sent us too many peers, more than [`P2P_MAX_PEERS_IN_HANDSHAKE`] #[error("The peer sent too many peers, considered spamming")] PeerSentTooManyPeers, + /// The peer sent an incorrect response #[error("The peer sent a wrong response to our handshake")] PeerSentWrongResponse, - #[error("The syncer returned an error")] - DataBaseError(#[from] DatabaseError), + /// Error communicating with peer #[error("Bucket error while communicating with peer: {0}")] BucketError(#[from] BucketError), } -pub struct NetworkConfig { - /// Port - my_port: u32, - /// The Network +/// An address used to connect to a peer. +#[derive(Debug, Copy, Clone)] +pub enum ConnectionAddr { + /// Outbound connection to another peer. + OutBound { address: NetworkAddress }, + /// An inbound direct connection to our node. + InBoundDirect { transient_address: SocketAddr }, + /// An inbound connection through a hidden network + /// like Tor/ I2p + InBoundProxy { net_zone: NetZone }, +} + +impl ConnectionAddr { + /// Gets the [`NetworkAddress`] of this connection. + pub fn get_network_address(&self, port: u16) -> Option<NetworkAddress> { + match self { + ConnectionAddr::OutBound { address } => Some(*address), + _ => None, + } + } + /// Gets the [`NetZone`] of this connection. + pub fn get_zone(&self) -> NetZone { + match self { + ConnectionAddr::OutBound { address } => address.get_zone(), + ConnectionAddr::InBoundDirect { .. } => NetZone::Public, + ConnectionAddr::InBoundProxy { net_zone } => *net_zone, + } + } + + /// Gets the [`Direction`] of this connection. + pub fn direction(&self) -> Direction { + match self { + ConnectionAddr::OutBound { .. } => Direction::Outbound, + ConnectionAddr::InBoundDirect { .. } | ConnectionAddr::InBoundProxy { .. } => { + Direction::Inbound + } + } + } +} + +/// A request to handshake with a peer. +pub struct DoHandshakeRequest<W, R> { + /// The read-half of the connection. + pub read: R, + /// The write-half of the connection. + pub write: W, + /// The [`ConnectionAddr`] of this connection. + pub addr: ConnectionAddr, + /// The [`ConnectionTracker`] of this connection. + pub connection_tracker: ConnectionTracker, +} + +/// A [`Service`] that accepts [`DoHandshakeRequest`] and +/// produces a [`Client`] and [`Connection`]. +#[derive(Debug, Clone)] +pub struct Handshaker<Svc, CoreSync, AdrBook> { + /// A collection of our [`BasicNodeData`] for each [`NetZone`] + /// for more info see: [`NetZoneBasicNodeData`] + basic_node_data: NetZoneBasicNodeData, + /// The [`Network`] our node is using network: Network, - /// Peer ID - peer_id: PeerID, - /// RPC Port - rpc_port: u16, - /// RPC Credits Per Hash - rpc_credits_per_hash: u32, - our_support_flags: PeerSupportFlags, - minimum_peer_support_flags: PeerSupportFlags, - handshake_timeout: time::Duration, - max_in_peers: u32, - target_out_peers: u32, -} - -impl Default for NetworkConfig { - fn default() -> Self { - NetworkConfig { - my_port: 18080, - network: Network::MainNet, - peer_id: PeerID(21), - rpc_port: 0, - rpc_credits_per_hash: 0, - our_support_flags: PeerSupportFlags::get_support_flag_fluffy_blocks(), - minimum_peer_support_flags: PeerSupportFlags::from(0_u32), - handshake_timeout: time::Duration::from_secs(5), - max_in_peers: 13, - target_out_peers: 21, - } - } -} - -impl NetworkConfig { - pub fn basic_node_data(&self) -> BasicNodeData { - BasicNodeData { - my_port: self.my_port, - network_id: self.network.network_id(), - peer_id: self.peer_id, - support_flags: self.our_support_flags, - rpc_port: self.rpc_port, - rpc_credits_per_hash: self.rpc_credits_per_hash, - } - } -} - -pub struct Handshake<W, R> { - sink: MessageSink<W, Message>, - stream: MessageStream<R, Message>, - direction: Direction, - addr: NetworkAddress, -} - -pub struct Handshaker<Bc, Svc, AdrBook> { - config: NetworkConfig, + /// The span [`Connection`] tasks will be [`tracing::instrument`]ed with parent_span: tracing::Span, + /// The address book [`Service`] address_book: AdrBook, - blockchain: Bc, - peer_sync_states: mpsc::Sender<PeerSyncChange>, + /// A [`Service`] to handle incoming [`CoreSyncData`] and to get + /// our [`CoreSyncData`]. + core_sync_svc: CoreSync, + /// A service given to the [`Connection`] task to answer incoming + /// requests to our node. peer_request_service: Svc, } -impl<Bc, Svc, AdrBook, W, R> tower::Service<Handshake<W, R>> for Handshaker<Bc, Svc, AdrBook> +impl<Svc, CoreSync, AdrBook> Handshaker<Svc, CoreSync, AdrBook> { + pub fn new( + basic_node_data: NetZoneBasicNodeData, + network: Network, + address_book: AdrBook, + core_sync_svc: CoreSync, + peer_request_service: Svc, + ) -> Self { + Handshaker { + basic_node_data, + network, + parent_span: tracing::Span::current(), + address_book, + core_sync_svc, + peer_request_service, + } + } +} + +impl<Svc, CoreSync, AdrBook, W, R> Service<DoHandshakeRequest<W, R>> + for Handshaker<Svc, CoreSync, AdrBook> where - Bc: Service<DataBaseRequest, Response = DataBaseResponse, Error = DatabaseError> + CoreSync: Service<CoreSyncDataRequest, Response = CoreSyncDataResponse, Error = BoxError> + Clone + Send + 'static, - Bc::Future: Send, + CoreSync::Future: Send, - Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = PeerError> + Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = BoxError> + Clone + Send + 'static, Svc::Future: Send, - AdrBook: Service<AddressBookRequest, Response = AddressBookResponse, Error = AddressBookError> + AdrBook: Service<AddressBookRequest, Response = AddressBookResponse, Error = BoxError> + Clone + Send + 'static, AdrBook::Future: Send, - W: AsyncWrite + std::marker::Unpin + Send + 'static, - R: AsyncRead + std::marker::Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, + R: AsyncRead + Unpin + Send + 'static, { - type Error = HandShakeError; type Response = Client; + type Error = BoxError; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; fn poll_ready( &mut self, - cx: &mut std::task::Context<'_>, + _cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Result<(), Self::Error>> { + // We are always ready. std::task::Poll::Ready(Ok(())) } - fn call(&mut self, req: Handshake<W, R>) -> Self::Future { - let Handshake { - sink: mut peer_sink, - stream: mut peer_stream, - direction, + fn call(&mut self, req: DoHandshakeRequest<W, R>) -> Self::Future { + let DoHandshakeRequest { + read, + write, addr, + connection_tracker, } = req; + // create the levin message stream/ sink. + let peer_stream = FramedRead::new(read, MoneroWireCodec::default()); + let peer_sink = FramedWrite::new(write, MoneroWireCodec::default()); + + // The span the handshake state machine will use let span = tracing::debug_span!("Handshaker"); + // The span the connection task will use. let connection_span = tracing::debug_span!(parent: &self.parent_span, "Connection"); - let blockchain = self.blockchain.clone(); + // clone the services that the handshake state machine will need. + let core_sync_svc = self.core_sync_svc.clone(); let address_book = self.address_book.clone(); - let syncer_tx = self.peer_sync_states.clone(); let peer_request_service = self.peer_request_service.clone(); let state_machine = HandshakeSM { peer_sink, peer_stream, - direction, addr, - network: self.config.network, - basic_node_data: self.config.basic_node_data(), - minimum_support_flags: self.config.minimum_peer_support_flags, + network: self.network, + basic_node_data: self.basic_node_data.basic_node_data(&addr.get_zone()), address_book, - blockchain, + core_sync_svc, peer_request_service, connection_span, + connection_tracker, state: HandshakeState::Start, }; - - let ret = time::timeout(self.config.handshake_timeout, state_machine.do_handshake()); + // although callers should use a timeout do one here as well just to be safe. + let ret = time::timeout(HANDSHAKE_TIMEOUT, state_machine.do_handshake()); async move { match ret.await { Ok(handshake) => handshake, - Err(_) => Err(HandShakeError::PeerTimedOut), + Err(_) => Err(HandShakeError::PeerTimedOut.into()), } } + .instrument(span) .boxed() } } +/// The states a handshake can be in. enum HandshakeState { + /// The initial state. + /// if this is an inbound handshake then this state means we + /// are waiting for a [`HandshakeRequest`]. Start, + /// Waiting for a [`HandshakeResponse`]. WaitingForHandshakeResponse, - WaitingForSupportFlagResponse(BasicNodeData), - Complete(BasicNodeData), + /// Waiting for a [`SupportFlagsResponse`] + /// This contains the peers node data. + WaitingForSupportFlagResponse(BasicNodeData, CoreSyncData), + /// The handshake is complete. + /// This contains the peers node data. + Complete(BasicNodeData, CoreSyncData), } impl HandshakeState { + /// Returns true if the handshake is completed. pub fn is_complete(&self) -> bool { - matches!(self, HandshakeState::Complete(_)) + matches!(self, Self::Complete(..)) } - pub fn peer_basic_node_data(self) -> Option<BasicNodeData> { + /// returns the peers [`BasicNodeData`] and [`CoreSyncData`] if the peer + /// is in state [`HandshakeState::Complete`]. + pub fn peer_data(self) -> Option<(BasicNodeData, CoreSyncData)> { match self { - HandshakeState::Complete(sup) => Some(sup), + HandshakeState::Complete(bnd, coresync) => Some((bnd, coresync)), _ => None, } } } -struct HandshakeSM<Bc, Svc, AdrBook, W, R> { - peer_sink: MessageSink<W, Message>, - peer_stream: MessageStream<R, Message>, - direction: Direction, - addr: NetworkAddress, +/// The state machine that drives a handshake forward and +/// accepts requests (that can happen during a handshake) +/// from a peer. +struct HandshakeSM<Svc, CoreSync, AdrBook, W, R> { + /// The levin [`FramedWrite`] for the peer. + peer_sink: W, + /// The levin [`FramedRead`] for the peer. + peer_stream: R, + /// The [`ConnectionAddr`] for the peer. + addr: ConnectionAddr, + /// The [`Network`] we are on. network: Network, + /// Our [`BasicNodeData`]. basic_node_data: BasicNodeData, - minimum_support_flags: PeerSupportFlags, + /// The address book [`Service`] address_book: AdrBook, - blockchain: Bc, + /// The core sync [`Service`] to handle incoming + /// [`CoreSyncData`] and to retrieve ours. + core_sync_svc: CoreSync, + /// The [`Service`] passed to the [`Connection`] + /// task to handle incoming peer requests. peer_request_service: Svc, + + /// The [`tracing::Span`] the [`Connection`] task + /// will be [`tracing::instrument`]ed with. connection_span: tracing::Span, + /// A connection tracker to keep track of the + /// number of connections Cuprate is making. + connection_tracker: ConnectionTracker, state: HandshakeState, } -impl<Bc, Svc, AdrBook, W, R> HandshakeSM<Bc, Svc, AdrBook, W, R> +impl<Svc, CoreSync, AdrBook, W, R> HandshakeSM<Svc, CoreSync, AdrBook, W, R> where - Bc: Service<DataBaseRequest, Response = DataBaseResponse, Error = DatabaseError> + CoreSync: Service<CoreSyncDataRequest, Response = CoreSyncDataResponse, Error = BoxError> + Clone + Send + 'static, - Bc::Future: Send, + CoreSync::Future: Send, - Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = PeerError> + Svc: Service<InternalMessageRequest, Response = InternalMessageResponse, Error = BoxError> + Clone + Send + 'static, Svc::Future: Send, - AdrBook: Service<AddressBookRequest, Response = AddressBookResponse, Error = AddressBookError> + AdrBook: Service<AddressBookRequest, Response = AddressBookResponse, Error = BoxError> + Clone + Send + 'static, AdrBook::Future: Send, - W: AsyncWrite + std::marker::Unpin + Send + 'static, - R: AsyncRead + std::marker::Unpin + Send + 'static, + W: Sink<Message, Error = BucketError> + Unpin, + R: Stream<Item = Result<Message, BucketError>> + Unpin, { - async fn get_our_core_sync(&mut self) -> Result<CoreSyncData, DatabaseError> { - let blockchain = self.blockchain.ready().await?; - let DataBaseResponse::CoreSyncData(core_sync) = blockchain.call(DataBaseRequest::CoreSyncData).await? else { - unreachable!("Database will always return the requested item") + /// Gets our [`CoreSyncData`] from the `core_sync_svc`. + async fn get_our_core_sync(&mut self) -> Result<CoreSyncData, BoxError> { + let core_sync_svc = self.core_sync_svc.ready().await?; + let CoreSyncDataResponse::Ours(core_sync) = core_sync_svc.call(CoreSyncDataRequest::GetOurs).await? else { + unreachable!("The Service must give correct responses"); }; + tracing::trace!("Got core sync data: {core_sync:?}"); Ok(core_sync) } + /// Sends a [`HandshakeRequest`] to the peer. async fn send_handshake_req( &mut self, node_data: BasicNodeData, @@ -281,59 +367,62 @@ where tracing::trace!("Sending handshake request: {handshake_req:?}"); - let message: Message = Message::Request(handshake_req.into()); + let message: Message = Message::Request(RequestMessage::Handshake(handshake_req)); self.peer_sink.send(message).await?; Ok(()) } - async fn get_handshake_res(&mut self) -> Result<HandshakeResponse, HandShakeError> { - // put a timeout on this - let Message::Response(MessageResponse::Handshake(handshake_res)) = self.peer_stream.next().await.expect("MessageSink will not return None")? else { - return Err(HandShakeError::PeerSentWrongResponse); - }; - - tracing::trace!("Received handshake response: {handshake_res:?}"); - - Ok(handshake_res) - } - + /// Sends a [`SupportFlagsRequest`] to the peer. + /// This is done when a peer sends no support flags in their + /// [`HandshakeRequest`] or [`HandshakeResponse`]. + /// + /// *note because Cuprate has minimum required support flags this won't + /// happeen but is included here just in case this changes. async fn send_support_flag_req(&mut self) -> Result<(), HandShakeError> { tracing::trace!("Peer sent no support flags, sending request"); - let message: Message = Message::Request(SupportFlagsRequest.into()); + let message: Message = Message::Request(RequestMessage::SupportFlags); self.peer_sink.send(message).await?; Ok(()) } - async fn handle_handshake_response( - &mut self, - res: HandshakeResponse, - ) -> Result<(), HandShakeError> { + /// Handles an incoming [`HandshakeResponse`]. + async fn handle_handshake_response(&mut self, res: HandshakeResponse) -> Result<(), BoxError> { let HandshakeResponse { node_data: peer_node_data, payload_data: peer_core_sync, local_peerlist_new, } = res; - if !peer_node_data - .support_flags - .contains(&self.minimum_support_flags) - { - tracing::debug!("Handshake failed: peer does not have minimum support flags"); - return Err(HandShakeError::PeerDoesNotHaveTheMinimumSupportFlags); - } - + // Check the peer is on the correct network. if peer_node_data.network_id != self.network.network_id() { tracing::debug!("Handshake failed: peer is on a different network"); - return Err(HandShakeError::PeerIsOnADifferentNetwork); + return Err(HandShakeError::PeerIsOnADifferentNetwork.into()); } + // Check the peer meets the minimum support flags. + if !peer_node_data + .support_flags + .contains(&CUPRATE_MINIMUM_SUPPORT_FLAGS) + { + tracing::debug!("Handshake failed: peer does not have minimum required support flags"); + return Err(HandShakeError::PeerDoesNotHaveTheMinimumSupportFlags.into()); + } + + // Check the peer didn't send too many peers. if local_peerlist_new.len() > P2P_MAX_PEERS_IN_HANDSHAKE { tracing::debug!("Handshake failed: peer sent too many peers in response"); - return Err(HandShakeError::PeerSentTooManyPeers); + return Err(HandShakeError::PeerSentTooManyPeers.into()); } + // Tell the sync mgr about the new incoming core sync data. + self.core_sync_svc + .ready() + .await? + .call(CoreSyncDataRequest::NewIncoming(peer_core_sync.clone())) + .await?; + // Tell the address book about the new peers self.address_book .ready() @@ -344,52 +433,65 @@ where )) .await?; - // coresync, pruning seed - + // This won't actually happen (as long as we have a none 0 minimum support flags) + // it's just included here for completeness. if peer_node_data.support_flags.is_empty() { self.send_support_flag_req().await?; - self.state = HandshakeState::WaitingForSupportFlagResponse(peer_node_data); + self.state = + HandshakeState::WaitingForSupportFlagResponse(peer_node_data, peer_core_sync); } else { - self.state = HandshakeState::Complete(peer_node_data); + // this will always happen. + self.state = HandshakeState::Complete(peer_node_data, peer_core_sync); } Ok(()) } - async fn handle_message_response( - &mut self, - response: MessageResponse, - ) -> Result<(), HandShakeError> { - match (&mut self.state, response) { + /// Handles a [`MessageResponse`]. + async fn handle_message_response(&mut self, response: ResponseMessage) -> Result<(), BoxError> { + // The functions called here will change the state of the HandshakeSM so `HandshakeState::Start` + // is just used as a place holder. + // + // doing this allows us to not clone the BasicNodeData and CoreSyncData for WaitingForSupportFlagResponse. + let prv_state = std::mem::replace(&mut self.state, HandshakeState::Start); + + match (prv_state, response) { ( HandshakeState::WaitingForHandshakeResponse, - MessageResponse::Handshake(handshake), + ResponseMessage::Handshake(handshake), ) => self.handle_handshake_response(handshake).await, ( - HandshakeState::WaitingForSupportFlagResponse(bnd), - MessageResponse::SupportFlags(support_flags), + HandshakeState::WaitingForSupportFlagResponse(mut bnd, coresync), + ResponseMessage::SupportFlags(support_flags), ) => { bnd.support_flags = support_flags.support_flags; - self.state = HandshakeState::Complete(bnd.clone()); + self.state = HandshakeState::Complete(bnd, coresync); Ok(()) } - _ => Err(HandShakeError::PeerSentWrongResponse), + _ => Err(HandShakeError::PeerSentWrongResponse.into()), } } + /// Sends our [`PeerSupportFlags`] to the peer. async fn send_support_flags( &mut self, support_flags: PeerSupportFlags, ) -> Result<(), HandShakeError> { - let message = Message::Response(SupportFlagsResponse { support_flags }.into()); + let message = Message::Response(ResponseMessage::SupportFlags(SupportFlagsResponse { + support_flags, + })); self.peer_sink.send(message).await?; Ok(()) } - async fn do_outbound_handshake(&mut self) -> Result<(), HandShakeError> { + /// Attempts an outbound handshake with the peer. + async fn do_outbound_handshake(&mut self) -> Result<(), BoxError> { + // Get the data needed for the handshake request. let core_sync = self.get_our_core_sync().await?; + // send the handshake request. self.send_handshake_req(self.basic_node_data.clone(), core_sync) .await?; + // set the state to waiting for a response. self.state = HandshakeState::WaitingForHandshakeResponse; while !self.state.is_complete() { @@ -397,14 +499,17 @@ where Some(mes) => { let mes = mes?; match mes { - Message::Request(MessageRequest::SupportFlags(_)) => { + Message::Request(RequestMessage::SupportFlags) => { + // The only request we should be getting during an outbound handshake + // is a support flag request. self.send_support_flags(self.basic_node_data.support_flags) .await? } Message::Response(response) => { + // This could be a handshake response or a support flags response. self.handle_message_response(response).await? } - _ => return Err(HandShakeError::PeerSentWrongResponse), + _ => return Err(HandShakeError::PeerSentWrongResponse.into()), } } None => unreachable!("peer_stream wont return None"), @@ -414,40 +519,108 @@ where Ok(()) } - async fn do_handshake(mut self) -> Result<Client, HandShakeError> { - match self.direction { - Direction::Outbound => self.do_outbound_handshake().await?, + /// Completes a handshake with a peer. + async fn do_handshake(mut self) -> Result<Client, BoxError> { + let mut peer_reachable = false; + match self.addr.direction() { + Direction::Outbound => { + self.do_outbound_handshake().await?; + // If this is an outbound handshake then obviously the peer + // is reachable. + peer_reachable = true + } Direction::Inbound => todo!(), } - let (server_tx, server_rx) = mpsc::channel(3); + let (server_tx, server_rx) = mpsc::channel(0); - let (replace_me, replace_me_rx) = mpsc::channel(3); - - let peer_node_data = self + let (peer_node_data, coresync) = self .state - .peer_basic_node_data() + .peer_data() .expect("We must be in state complete to be here"); + + let pruning_seed = PruningSeed::try_from(coresync.pruning_seed).map_err(|e| Box::new(e))?; + + // create the handle between the Address book and the connection task to + // allow the address book to shutdown the connection task and to update + // the address book when the connection is closed. + let (book_connection_side_handle, connection_book_side_handle) = + new_address_book_connection_handle(); + + // tell the address book about the new connection. + self.address_book + .ready() + .await? + .call(AddressBookRequest::ConnectedToPeer { + zone: self.addr.get_zone(), + connection_handle: connection_book_side_handle, + addr: self.addr.get_network_address( + peer_node_data + .my_port + .try_into() + .map_err(|_| "Peer sent a port that does not fit into a u16")?, + ), + id: peer_node_data.peer_id, + reachable: peer_reachable, + last_seen: chrono::Utc::now().naive_utc(), + pruning_seed: pruning_seed.clone(), + rpc_port: peer_node_data.rpc_port, + rpc_credits_per_hash: peer_node_data.rpc_credits_per_hash, + }) + .await?; + + // This block below is for keeping the last seen times in the address book + // upto date. We only update the last seen times on timed syncs to reduce + // the load on the address book. + // + // first clone the items needed + let mut address_book = self.address_book.clone(); + let peer_id = peer_node_data.peer_id; + let net_zone = self.addr.get_zone(); + + /* + let peer_stream = self.peer_stream.then(|mes| async move { + if let Ok(mes) = &mes { + if mes.id() == TimedSync::ID { + if let Ok(ready_book) = address_book.ready().await { + // we dont care about address book errors here, If there is a problem + // with the address book the node will get shutdown. + let _ = ready_book + .call(AddressBookRequest::SetPeerSeen( + peer_id, + chrono::Utc::now().naive_utc(), + net_zone, + )) + .await; + } + } + } + // return the message + mes + }); + + */ + let connection = Connection::new( + self.addr, + self.peer_sink, + server_rx, + self.connection_tracker, + book_connection_side_handle, + self.peer_request_service, + ); + + let connection_task = tokio::task::spawn(connection.run().instrument(self.connection_span)); + let connection_info = ConnectionInfo { addr: self.addr, support_flags: peer_node_data.support_flags, + pruning_seed, peer_id: peer_node_data.peer_id, rpc_port: peer_node_data.rpc_port, rpc_credits_per_hash: peer_node_data.rpc_credits_per_hash, }; - let connection = Connection::new( - self.addr, - self.peer_sink, - self.peer_stream, - server_rx, - replace_me, - self.peer_request_service, - ); - - let client = Client::new(connection_info.into(), server_tx); - - tokio::task::spawn(connection.run().instrument(self.connection_span)); + let client = Client::new(connection_info.into(), /* futures::futures_channel::oneshot::Sender<()> */, server_tx, connection_task, /* tokio::task::JoinHandle<()> */); Ok(client) } diff --git a/p2p/src/peer/load_tracked_client.rs b/p2p/src/peer/load_tracked_client.rs new file mode 100644 index 0000000..8ac5e04 --- /dev/null +++ b/p2p/src/peer/load_tracked_client.rs @@ -0,0 +1,74 @@ +//! A peer connection service wrapper type to handle load tracking and provide access to the +//! reported protocol version. + +use std::sync::atomic::Ordering; +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use cuprate_common::PruningSeed; +use tower::{ + load::{Load, PeakEwma}, + Service, +}; + +use crate::{ + constants::{EWMA_DECAY_TIME_NANOS, EWMA_DEFAULT_RTT}, + peer::{Client, ConnectionInfo}, +}; + +/// A client service wrapper that keeps track of its load. +/// +/// It also keeps track of the peer's reported protocol version. +pub struct LoadTrackedClient { + /// A service representing a connected peer, wrapped in a load tracker. + service: PeakEwma<Client>, + + /// The metadata for the connected peer `service`. + connection_info: Arc<ConnectionInfo>, +} + +/// Create a new [`LoadTrackedClient`] wrapping the provided `client` service. +impl From<Client> for LoadTrackedClient { + fn from(client: Client) -> Self { + let connection_info = client.connection_info.clone(); + + let service = PeakEwma::new( + client, + EWMA_DEFAULT_RTT, + EWMA_DECAY_TIME_NANOS, + tower::load::CompleteOnResponse::default(), + ); + + LoadTrackedClient { + service, + connection_info, + } + } +} + +impl<Request> Service<Request> for LoadTrackedClient +where + Client: Service<Request>, +{ + type Response = <Client as Service<Request>>::Response; + type Error = <Client as Service<Request>>::Error; + type Future = <PeakEwma<Client> as Service<Request>>::Future; + + fn poll_ready(&mut self, context: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.service.poll_ready(context) + } + + fn call(&mut self, request: Request) -> Self::Future { + self.service.call(request) + } +} + +impl Load for LoadTrackedClient { + type Metric = <PeakEwma<Client> as Load>::Metric; + + fn load(&self) -> Self::Metric { + self.service.load() + } +} diff --git a/p2p/src/peer/tests/handshake.rs b/p2p/src/peer/tests/handshake.rs index 0421162..e7ed008 100644 --- a/p2p/src/peer/tests/handshake.rs +++ b/p2p/src/peer/tests/handshake.rs @@ -1 +1 @@ -pub use crate::peer::handshaker::{Handshake, Handshaker}; +pub use crate::peer::handshaker::Handshaker; diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index d271a9a..235a2b4 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -1,13 +1,29 @@ pub mod internal_network; -pub mod temp_database; pub use internal_network::{InternalMessageRequest, InternalMessageResponse}; -pub const BLOCKS_IDS_SYNCHRONIZING_DEFAULT_COUNT: usize = 10000; -pub const BLOCKS_IDS_SYNCHRONIZING_MAX_COUNT: usize = 25000; -pub const P2P_MAX_PEERS_IN_HANDSHAKE: usize = 250; +use monero_wire::messages::CoreSyncData; +/// A request to a [`tower::Service`] that handles sync states. +pub enum CoreSyncDataRequest { + /// Get our [`CoreSyncData`]. + GetOurs, + /// Handle an incoming [`CoreSyncData`]. + NewIncoming(CoreSyncData), +} + +/// A response from a [`tower::Service`] that handles sync states. +pub enum CoreSyncDataResponse { + /// Our [`CoreSyncData`] + Ours(CoreSyncData), + /// The incoming [`CoreSyncData`] is ok. + Ok, +} + +/// The direction of a connection. pub enum Direction { + /// An inbound connection. Inbound, + /// An outbound connection. Outbound, } diff --git a/p2p/src/protocol/internal_network.rs b/p2p/src/protocol/internal_network.rs index daa17df..42a419e 100644 --- a/p2p/src/protocol/internal_network.rs +++ b/p2p/src/protocol/internal_network.rs @@ -22,162 +22,104 @@ /// Request: NewFluffyBlock, Response: None, /// Request: NewTransactions, Response: None /// -use monero_wire::messages::{ - AdminMessage, ChainRequest, ChainResponse, FluffyMissingTransactionsRequest, GetObjectsRequest, - GetObjectsResponse, GetTxPoolCompliment, Handshake, Message, MessageNotification, - MessageRequest, MessageResponse, NewBlock, NewFluffyBlock, NewTransactions, Ping, - ProtocolMessage, SupportFlags, TimedSync, +use monero_wire::{ + ChainRequest, ChainResponse, FluffyMissingTransactionsRequest, GetObjectsRequest, + GetObjectsResponse, GetTxPoolCompliment, HandshakeRequest, HandshakeResponse, Message, + NewBlock, NewFluffyBlock, NewTransactions, PingResponse, RequestMessage, SupportFlagsResponse, + TimedSyncRequest, TimedSyncResponse, }; -macro_rules! client_request_peer_response { - ( - Admin: - $($admin_mes:ident),+ - Protocol: - $(Request: $protocol_req:ident, Response: $(SOME: $protocol_res:ident)? $(NULL: $none:expr)? ),+ - ) => { +mod try_from; - #[derive(Debug, Clone)] - pub enum InternalMessageRequest { - $($admin_mes(<$admin_mes as AdminMessage>::Request),)+ - $($protocol_req(<$protocol_req as ProtocolMessage>::Notification),)+ - } +/// An enum representing a request/ response combination, so a handshake request +/// and response would have the same [`MessageID`]. This allows associating the +/// correct response to a request. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub enum MessageID { + Handshake, + TimedSync, + Ping, + SupportFlags, - impl InternalMessageRequest { - pub fn get_str_name(&self) -> &'static str { - match self { - $(InternalMessageRequest::$admin_mes(_) => $admin_mes::NAME,)+ - $(InternalMessageRequest::$protocol_req(_) => $protocol_req::NAME,)+ - } - } - pub fn id(&self) -> u32 { - match self { - $(InternalMessageRequest::$admin_mes(_) => $admin_mes::ID,)+ - $(InternalMessageRequest::$protocol_req(_) => $protocol_req::ID,)+ - } - } - pub fn expected_id(&self) -> Option<u32> { - match self { - $(InternalMessageRequest::$admin_mes(_) => Some($admin_mes::ID),)+ - $(InternalMessageRequest::$protocol_req(_) => $(Some($protocol_res::ID))? $($none)?,)+ - } - } - pub fn is_levin_request(&self) -> bool { - match self { - $(InternalMessageRequest::$admin_mes(_) => true,)+ - $(InternalMessageRequest::$protocol_req(_) => false,)+ - } - } - } - - impl From<MessageRequest> for InternalMessageRequest { - fn from(value: MessageRequest) -> Self { - match value { - $(MessageRequest::$admin_mes(mes) => InternalMessageRequest::$admin_mes(mes),)+ - } - } - } - - impl Into<Message> for InternalMessageRequest { - fn into(self) -> Message { - match self { - $(InternalMessageRequest::$admin_mes(mes) => Message::Request(MessageRequest::$admin_mes(mes)),)+ - $(InternalMessageRequest::$protocol_req(mes) => Message::Notification(MessageNotification::$protocol_req(mes)),)+ - } - } - } - - #[derive(Debug)] - pub struct NotAnInternalRequest; - - impl TryFrom<Message> for InternalMessageRequest { - type Error = NotAnInternalRequest; - fn try_from(value: Message) -> Result<Self, Self::Error> { - match value { - Message::Response(_) => Err(NotAnInternalRequest), - Message::Request(req) => Ok(req.into()), - Message::Notification(noti) => { - match noti { - $(MessageNotification::$protocol_req(noti) => Ok(InternalMessageRequest::$protocol_req(noti)),)+ - _ => Err(NotAnInternalRequest), - } - } - } - } - } - - #[derive(Debug, Clone)] - pub enum InternalMessageResponse { - $($admin_mes(<$admin_mes as AdminMessage>::Response),)+ - $($($protocol_res(<$protocol_res as ProtocolMessage>::Notification),)?)+ - } - - impl InternalMessageResponse { - pub fn get_str_name(&self) -> &'static str { - match self { - $(InternalMessageResponse::$admin_mes(_) => $admin_mes::NAME,)+ - $($(InternalMessageResponse::$protocol_res(_) => $protocol_res::NAME,)?)+ - } - } - pub fn id(&self) -> u32 { - match self{ - $(InternalMessageResponse::$admin_mes(_) => $admin_mes::ID,)+ - $($(InternalMessageResponse::$protocol_res(_) => $protocol_res::ID,)?)+ - } - } - } - - impl From<MessageResponse> for InternalMessageResponse { - fn from(value: MessageResponse) -> Self { - match value { - $(MessageResponse::$admin_mes(mes) => InternalMessageResponse::$admin_mes(mes),)+ - } - } - } - - impl Into<Message> for InternalMessageResponse { - fn into(self) -> Message { - match self { - $(InternalMessageResponse::$admin_mes(mes) => Message::Response(MessageResponse::$admin_mes(mes)),)+ - $($(InternalMessageResponse::$protocol_res(mes) => Message::Notification(MessageNotification::$protocol_res(mes)),)?)+ - } - } - } - - #[derive(Debug)] - pub struct NotAnInternalResponse; - - impl TryFrom<Message> for InternalMessageResponse { - type Error = NotAnInternalResponse; - fn try_from(value: Message) -> Result<Self, Self::Error> { - match value { - Message::Response(res) => Ok(res.into()), - Message::Request(_) => Err(NotAnInternalResponse), - Message::Notification(noti) => { - match noti { - $($(MessageNotification::$protocol_res(noti) => Ok(InternalMessageResponse::$protocol_res(noti)),)?)+ - _ => Err(NotAnInternalResponse), - } - } - } - } - } - }; + GetObjects, + GetChain, + FluffyMissingTxs, + GetTxPoolCompliment, + NewBlock, + NewFluffyBlock, + NewTransactions, } -client_request_peer_response!( - Admin: - Handshake, - TimedSync, - Ping, - SupportFlags - Protocol: - Request: GetObjectsRequest, Response: SOME: GetObjectsResponse, - Request: ChainRequest, Response: SOME: ChainResponse, - Request: FluffyMissingTransactionsRequest, Response: SOME: NewFluffyBlock, // these 2 could be requests or responses - Request: GetTxPoolCompliment, Response: SOME: NewTransactions, // - // these don't need to be responded to - Request: NewBlock, Response: NULL: None, - Request: NewFluffyBlock, Response: NULL: None, - Request: NewTransactions, Response: NULL: None -); +pub enum Request { + Handshake(HandshakeRequest), + TimedSync(TimedSyncRequest), + Ping, + SupportFlags, + + GetObjects(GetObjectsRequest), + GetChain(ChainRequest), + FluffyMissingTxs(FluffyMissingTransactionsRequest), + GetTxPoolCompliment(GetTxPoolCompliment), + NewBlock(NewBlock), + NewFluffyBlock(NewFluffyBlock), + NewTransactions(NewTransactions), +} + +impl Request { + pub fn id(&self) -> MessageID { + match self { + Request::Handshake(_) => MessageID::Handshake, + Request::TimedSync(_) => MessageID::TimedSync, + Request::Ping => MessageID::Ping, + Request::SupportFlags => MessageID::SupportFlags, + + Request::GetObjects(_) => MessageID::GetObjects, + Request::GetChain(_) => MessageID::GetChain, + Request::FluffyMissingTxs(_) => MessageID::FluffyMissingTxs, + Request::GetTxPoolCompliment(_) => MessageID::GetTxPoolCompliment, + Request::NewBlock(_) => MessageID::NewBlock, + Request::NewFluffyBlock(_) => MessageID::NewFluffyBlock, + Request::NewTransactions(_) => MessageID::NewTransactions, + } + } + + pub fn needs_response(&self) -> bool { + match self { + Request::NewBlock(_) | Request::NewFluffyBlock(_) | Request::NewTransactions(_) => { + false + } + _ => true, + } + } +} + +pub enum Response { + Handshake(HandshakeResponse), + TimedSync(TimedSyncResponse), + Ping(PingResponse), + SupportFlags(SupportFlagsResponse), + + GetObjects(GetObjectsResponse), + GetChain(ChainResponse), + NewFluffyBlock(NewFluffyBlock), + NewTransactions(NewTransactions), + NA, +} + +impl Response { + pub fn id(&self) -> MessageID { + match self { + Response::Handshake(_) => MessageID::Handshake, + Response::TimedSync(_) => MessageID::TimedSync, + Response::Ping(_) => MessageID::Ping, + Response::SupportFlags(_) => MessageID::SupportFlags, + + Response::GetObjects(_) => MessageID::GetObjects, + Response::GetChain(_) => MessageID::GetChain, + Response::NewFluffyBlock(_) => MessageID::NewBlock, + Response::NewTransactions(_) => MessageID::NewFluffyBlock, + + Response::NA => panic!("Can't get message ID for a non existent response"), + } + } +} diff --git a/p2p/src/protocol/internal_network/try_from.rs b/p2p/src/protocol/internal_network/try_from.rs new file mode 100644 index 0000000..c8c9ec5 --- /dev/null +++ b/p2p/src/protocol/internal_network/try_from.rs @@ -0,0 +1,163 @@ +//! This module contains the implementations of [`TryFrom`] and [`From`] to convert between +//! [`Message`], [`Request`] and [`Response`]. + +use monero_wire::messages::{Message, ProtocolMessage, RequestMessage, ResponseMessage}; + +use super::{Request, Response}; + +pub struct MessageConversionError; + + +macro_rules! match_body { + (match $value: ident {$($body:tt)*} ($left:pat => $right_ty:expr) $($todo:tt)*) => { + match_body!( match $value { + $left => $right_ty, + $($body)* + } $($todo)* ) + }; + (match $value: ident {$($body:tt)*}) => { + match $value { + $($body)* + } + }; +} + + +macro_rules! from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + impl From<$left_ty> for $right_ty { + fn from(value: $left_ty) -> Self { + match_body!( match value {} + $(($left_ty::$left$(($val))? => $right_ty::$right$(($vall))?))+ + ) + } + } + }; +} + +macro_rules! try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + impl TryFrom<$left_ty> for $right_ty { + type Error = MessageConversionError; + + fn try_from(value: $left_ty) -> Result<Self, Self::Error> { + Ok(match_body!( match value { + _ => return Err(MessageConversionError) + } + $(($left_ty::$left$(($val))? => $right_ty::$right$(($vall))?))+ + )) + } + } + }; +} + +macro_rules! from_try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + try_from!($left_ty, $right_ty, {$($left $(($val))? = $right $(($vall))?,)+}); + from!($right_ty, $left_ty, {$($right $(($val))? = $left $(($vall))?,)+}); + }; +} + +macro_rules! try_from_try_from { + ($left_ty:ident, $right_ty:ident, {$($left:ident $(($val: ident))? = $right:ident $(($vall: ident))?,)+}) => { + try_from!($left_ty, $right_ty, {$($left $(($val))? = $right $(($vall))?,)+}); + try_from!($right_ty, $left_ty, {$($right $(($val))? = $left $(($val))?,)+}); + }; +} + +from_try_from!(Request, RequestMessage,{ + Handshake(val) = Handshake(val), + Ping = Ping, + SupportFlags = SupportFlags, + TimedSync(val) = TimedSync(val), +}); + +try_from_try_from!(Request, ProtocolMessage,{ + NewBlock(val) = NewBlock(val), + NewFluffyBlock(val) = NewFluffyBlock(val), + GetObjects(val) = GetObjectsRequest(val), + GetChain(val) = ChainRequest(val), + NewTransactions(val) = NewTransactions(val), + FluffyMissingTxs(val) = FluffyMissingTransactionsRequest(val), + GetTxPoolCompliment(val) = GetTxPoolCompliment(val), +}); + + + +impl TryFrom<Message> for Request { + type Error = MessageConversionError; + + fn try_from(value: Message) -> Result<Self, Self::Error> { + match value { + Message::Request(req) => Ok(req.into()), + Message::Protocol(pro) => pro.try_into(), + _ => Err(MessageConversionError), + } + } +} + +impl From<Request> for Message { + fn from(value: Request) -> Self { + match value { + Request::Handshake(val) => Message::Request(RequestMessage::Handshake(val)), + Request::Ping => Message::Request(RequestMessage::Ping), + Request::SupportFlags => Message::Request(RequestMessage::SupportFlags), + Request::TimedSync(val) => Message::Request(RequestMessage::TimedSync(val)), + + Request::NewBlock(val) => Message::Protocol(ProtocolMessage::NewBlock(val)), + Request::NewFluffyBlock(val) => Message::Protocol(ProtocolMessage::NewFluffyBlock(val)), + Request::GetObjects(val) => Message::Protocol(ProtocolMessage::GetObjectsRequest(val)), + Request::GetChain(val) => Message::Protocol(ProtocolMessage::ChainRequest(val)), + Request::NewTransactions(val) => Message::Protocol(ProtocolMessage::NewTransactions(val)), + Request::FluffyMissingTxs(val) => Message::Protocol(ProtocolMessage::FluffyMissingTransactionsRequest(val)), + Request::GetTxPoolCompliment(val) => Message::Protocol(ProtocolMessage::GetTxPoolCompliment(val)), + } + } +} + +from_try_from!(Response, ResponseMessage,{ + Handshake(val) = Handshake(val), + Ping(val) = Ping(val), + SupportFlags(val) = SupportFlags(val), + TimedSync(val) = TimedSync(val), +}); + +try_from_try_from!(Response, ProtocolMessage,{ + NewFluffyBlock(val) = NewFluffyBlock(val), + GetObjects(val) = GetObjectsResponse(val), + GetChain(val) = ChainEntryResponse(val), + NewTransactions(val) = NewTransactions(val), + +}); + +impl TryFrom<Message> for Response { + type Error = MessageConversionError; + + fn try_from(value: Message) -> Result<Self, Self::Error> { + match value { + Message::Response(res) => Ok(res.into()), + Message::Protocol(pro) => pro.try_into(), + _ => Err(MessageConversionError), + } + } +} + +impl TryFrom<Response> for Message { + type Error = MessageConversionError; + + fn try_from(value: Response) -> Result<Self, Self::Error> { + Ok(match value { + Response::Handshake(val) => Message::Response(ResponseMessage::Handshake(val)), + Response::Ping(val) => Message::Response(ResponseMessage::Ping(val)), + Response::SupportFlags(val) => Message::Response(ResponseMessage::SupportFlags(val)), + Response::TimedSync(val) => Message::Response(ResponseMessage::TimedSync(val)), + + Response::NewFluffyBlock(val) => Message::Protocol(ProtocolMessage::NewFluffyBlock(val)), + Response::GetObjects(val) => Message::Protocol(ProtocolMessage::GetObjectsResponse(val)), + Response::GetChain(val) => Message::Protocol(ProtocolMessage::ChainEntryResponse(val)), + Response::NewTransactions(val) => Message::Protocol(ProtocolMessage::NewTransactions(val)), + + Response::NA => return Err(MessageConversionError), + }) + } +} diff --git a/p2p/src/protocol/lib.rs b/p2p/src/protocol/lib.rs deleted file mode 100644 index d271a9a..0000000 --- a/p2p/src/protocol/lib.rs +++ /dev/null @@ -1,13 +0,0 @@ -pub mod internal_network; -pub mod temp_database; - -pub use internal_network::{InternalMessageRequest, InternalMessageResponse}; - -pub const BLOCKS_IDS_SYNCHRONIZING_DEFAULT_COUNT: usize = 10000; -pub const BLOCKS_IDS_SYNCHRONIZING_MAX_COUNT: usize = 25000; -pub const P2P_MAX_PEERS_IN_HANDSHAKE: usize = 250; - -pub enum Direction { - Inbound, - Outbound, -} diff --git a/p2p/src/protocol/temp_database.rs b/p2p/src/protocol/temp_database.rs deleted file mode 100644 index 82016bf..0000000 --- a/p2p/src/protocol/temp_database.rs +++ /dev/null @@ -1,36 +0,0 @@ -use monero_wire::messages::CoreSyncData; -use thiserror::Error; - -pub enum BlockKnown { - No, - OnMainChain, - OnSideChain, - KnownBad, -} - -impl BlockKnown { - pub fn is_known(&self) -> bool { - !matches!(self, BlockKnown::No) - } -} - -pub enum DataBaseRequest { - CurrentHeight, - CumulativeDifficulty, - CoreSyncData, - Chain, - BlockHeight([u8; 32]), - BlockKnown([u8; 32]), -} - -pub enum DataBaseResponse { - CurrentHeight(u64), - CumulativeDifficulty(u128), - CoreSyncData(CoreSyncData), - Chain(Vec<[u8; 32]>), - BlockHeight(Option<u64>), - BlockKnown(BlockKnown), -} - -#[derive(Debug, Error, PartialEq, Eq)] -pub enum DatabaseError {} diff --git a/p2p/sync-states/Cargo.toml b/p2p/sync-states/Cargo.toml deleted file mode 100644 index 65e275a..0000000 --- a/p2p/sync-states/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "cuprate-sync-states" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -cuprate-common = {path = "../../common"} -cuprate-peer = {path = "../peer"} -cuprate-protocol = {path = "../protocol"} -monero = {git="https://github.com/Boog900/monero-rs.git", branch="db", features=["database"]} -monero-wire = {path= "../../net/monero-wire"} -futures = "0.3.26" -tower = {version = "0.4.13", features = ["util"]} -thiserror = "1.0.39" - - -tokio = {version="1.1", features=["full"]} -tokio-util = {version ="0.7", features=["compat"]} - diff --git a/p2p/sync-states/src/lib.rs b/p2p/sync-states/src/lib.rs deleted file mode 100644 index 79bbe20..0000000 --- a/p2p/sync-states/src/lib.rs +++ /dev/null @@ -1,538 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::sync::{Arc, Mutex}; - -use futures::channel::mpsc; -use futures::StreamExt; -use monero::Hash; -use thiserror::Error; -use tower::{Service, ServiceExt}; - -use cuprate_common::{hardforks, HardForks}; -use cuprate_peer::connection::PeerSyncChange; -use cuprate_protocol::temp_database::{ - BlockKnown, DataBaseRequest, DataBaseResponse, DatabaseError, -}; -use cuprate_protocol::{InternalMessageRequest, InternalMessageResponse}; -use monero_wire::messages::protocol::ChainResponse; -use monero_wire::messages::{ChainRequest, CoreSyncData}; -use monero_wire::{Message, NetworkAddress}; - -// TODO: Move this!!!!!!! -// ******************************** - -pub enum PeerSetRequest { - DisconnectPeer(NetworkAddress), - BanPeer(NetworkAddress), - SendRequest(InternalMessageRequest, Option<NetworkAddress>), -} - -pub struct PeerSetResponse { - peer: NetworkAddress, - response: Option<InternalMessageResponse>, -} - -// ******************************* -#[derive(Debug, Default)] -pub struct IndividualPeerSync { - height: u64, - // no grantee this is the same block as height - top_id: Hash, - top_version: u8, - cumulative_difficulty: u128, - /// the height the list of needed blocks starts at - start_height: u64, - /// list of block hashes our node does not have. - needed_blocks: Vec<(Hash, Option<u64>)>, -} - -#[derive(Debug, Default)] -pub struct PeersSyncData { - peers: HashMap<NetworkAddress, IndividualPeerSync>, -} - -impl PeersSyncData { - pub fn new_core_sync_data( - &mut self, - id: &NetworkAddress, - core_sync: CoreSyncData, - ) -> Result<(), SyncStatesError> { - let peer_data = self.peers.get_mut(&id); - if peer_data.is_none() { - let ips = IndividualPeerSync { - height: core_sync.current_height, - top_id: core_sync.top_id, - top_version: core_sync.top_version, - cumulative_difficulty: core_sync.cumulative_difficulty(), - start_height: 0, - needed_blocks: vec![], - }; - self.peers.insert(*id, ips); - } else { - let peer_data = peer_data.unwrap(); - if peer_data.height > core_sync.current_height { - return Err(SyncStatesError::PeersHeightHasDropped); - } - if peer_data.cumulative_difficulty > core_sync.cumulative_difficulty() { - return Err(SyncStatesError::PeersCumulativeDifficultyDropped); - } - peer_data.height = core_sync.current_height; - peer_data.cumulative_difficulty = core_sync.cumulative_difficulty(); - peer_data.top_id = core_sync.top_id; - peer_data.top_version = core_sync.top_version; - } - Ok(()) - } - - pub fn new_chain_response( - &mut self, - id: &NetworkAddress, - chain_response: ChainResponse, - needed_blocks: Vec<(Hash, Option<u64>)>, - ) -> Result<(), SyncStatesError> { - let peer_data = self - .peers - .get_mut(&id) - .expect("Peers must give use their core sync before chain response"); - - // it's sad we have to do this so late in the response validation process - if peer_data.height > chain_response.total_height { - return Err(SyncStatesError::PeersHeightHasDropped); - } - if peer_data.cumulative_difficulty > chain_response.cumulative_difficulty() { - return Err(SyncStatesError::PeersCumulativeDifficultyDropped); - } - - peer_data.cumulative_difficulty = chain_response.cumulative_difficulty(); - peer_data.height = chain_response.total_height; - peer_data.start_height = chain_response.start_height - + chain_response.m_block_ids.len() as u64 - - needed_blocks.len() as u64; - peer_data.needed_blocks = needed_blocks; - Ok(()) - } - // returns true if we have ran out of known blocks for that peer - pub fn new_objects_response( - &mut self, - id: &NetworkAddress, - mut block_ids: HashSet<Hash>, - ) -> Result<bool, SyncStatesError> { - let peer_data = self - .peers - .get_mut(id) - .expect("Peers must give use their core sync before objects response"); - let mut i = 0; - if peer_data.needed_blocks.is_empty() { - return Ok(true); - } - while !block_ids.contains(&peer_data.needed_blocks[i].0) { - i += 1; - if i == peer_data.needed_blocks.len() { - peer_data.needed_blocks = vec![]; - peer_data.start_height = 0; - return Ok(true); - } - } - for _ in 0..block_ids.len() { - if !block_ids.remove(&peer_data.needed_blocks[i].0) { - return Err(SyncStatesError::PeerSentAnUnexpectedBlockId); - } - i += 1; - if i == peer_data.needed_blocks.len() { - peer_data.needed_blocks = vec![]; - peer_data.start_height = 0; - return Ok(true); - } - } - peer_data.needed_blocks = peer_data.needed_blocks[i..].to_vec(); - peer_data.start_height = peer_data.start_height + i as u64; - return Ok(false); - } - - pub fn peer_disconnected(&mut self, id: &NetworkAddress) { - let _ = self.peers.remove(id); - } -} - -#[derive(Debug, Error, PartialEq, Eq)] -pub enum SyncStatesError { - #[error("Peer sent a block id we know is bad")] - PeerSentKnownBadBlock, - #[error("Peer sent a block id we weren't expecting")] - PeerSentAnUnexpectedBlockId, - #[error("Peer sent a chain entry where we don't know the start")] - PeerSentNoneOverlappingFirstBlock, - #[error("We have the peers block just at a different height")] - WeHaveBlockAtDifferentHeight, - #[error("The peer sent a top version we weren't expecting")] - PeerSentBadTopVersion, - #[error("The peer sent a weird pruning seed")] - PeerSentBadPruningSeed, - #[error("The peer height has dropped")] - PeersHeightHasDropped, - #[error("The peers cumulative difficulty has dropped")] - PeersCumulativeDifficultyDropped, - #[error("Our database returned an error: {0}")] - DataBaseError(#[from] DatabaseError), -} - -pub struct SyncStates<Db> { - peer_sync_rx: mpsc::Receiver<PeerSyncChange>, - hardforks: HardForks, - peer_sync_states: Arc<Mutex<PeersSyncData>>, - blockchain: Db, -} - -impl<Db> SyncStates<Db> -where - Db: Service<DataBaseRequest, Response = DataBaseResponse, Error = DatabaseError>, -{ - pub fn new( - peer_sync_rx: mpsc::Receiver<PeerSyncChange>, - hardforks: HardForks, - peer_sync_states: Arc<Mutex<PeersSyncData>>, - blockchain: Db, - ) -> Self { - SyncStates { - peer_sync_rx, - hardforks, - peer_sync_states, - blockchain, - } - } - async fn send_database_request( - &mut self, - req: DataBaseRequest, - ) -> Result<DataBaseResponse, DatabaseError> { - let ready_blockchain = self.blockchain.ready().await?; - ready_blockchain.call(req).await - } - - async fn handle_core_sync_change( - &mut self, - id: &NetworkAddress, - core_sync: CoreSyncData, - ) -> Result<bool, SyncStatesError> { - if core_sync.current_height > 0 { - let version = self - .hardforks - .get_ideal_version_from_height(core_sync.current_height - 1); - if version >= 6 && version != core_sync.top_version { - return Err(SyncStatesError::PeerSentBadTopVersion); - } - } - if core_sync.pruning_seed != 0 { - let log_stripes = - monero::database::pruning::get_pruning_log_stripes(core_sync.pruning_seed); - let stripe = - monero::database::pruning::get_pruning_stripe_for_seed(core_sync.pruning_seed); - if stripe != monero::database::pruning::CRYPTONOTE_PRUNING_LOG_STRIPES - || stripe > (1 << log_stripes) - { - return Err(SyncStatesError::PeerSentBadPruningSeed); - } - } - //if core_sync.current_height > max block numb - let DataBaseResponse::BlockHeight(height) = self.send_database_request(DataBaseRequest::BlockHeight(core_sync.top_id)).await? else { - unreachable!("the blockchain won't send the wrong response"); - }; - - let behind: bool; - - if let Some(height) = height { - if height != core_sync.current_height { - return Err(SyncStatesError::WeHaveBlockAtDifferentHeight); - } - behind = false; - } else { - let DataBaseResponse::CumulativeDifficulty(cumulative_diff) = self.send_database_request(DataBaseRequest::CumulativeDifficulty).await? else { - unreachable!("the blockchain won't send the wrong response"); - }; - // if their chain has more POW we want it - if cumulative_diff < core_sync.cumulative_difficulty() { - behind = true; - } else { - behind = false; - } - } - - let mut sync_states = self.peer_sync_states.lock().unwrap(); - sync_states.new_core_sync_data(id, core_sync)?; - - Ok(behind) - } - - async fn handle_chain_entry_response( - &mut self, - id: &NetworkAddress, - chain_response: ChainResponse, - ) -> Result<(), SyncStatesError> { - let mut expect_unknown = false; - let mut needed_blocks = Vec::with_capacity(chain_response.m_block_ids.len()); - - for (index, block_id) in chain_response.m_block_ids.iter().enumerate() { - let DataBaseResponse::BlockKnown(known) = self.send_database_request(DataBaseRequest::BlockKnown(*block_id)).await? else { - unreachable!("the blockchain won't send the wrong response"); - }; - if index == 0 { - if !known.is_known() { - return Err(SyncStatesError::PeerSentNoneOverlappingFirstBlock); - } - } else { - match known { - BlockKnown::No => expect_unknown = true, - BlockKnown::OnMainChain => { - if expect_unknown { - return Err(SyncStatesError::PeerSentAnUnexpectedBlockId); - } else { - let DataBaseResponse::BlockHeight(height) = self.send_database_request(DataBaseRequest::BlockHeight(*block_id)).await? else { - unreachable!("the blockchain won't send the wrong response"); - }; - if chain_response.start_height + index as u64 - != height.expect("We already know this block is in our main chain.") - { - return Err(SyncStatesError::WeHaveBlockAtDifferentHeight); - } - } - } - BlockKnown::OnSideChain => { - if expect_unknown { - return Err(SyncStatesError::PeerSentAnUnexpectedBlockId); - } - } - BlockKnown::KnownBad => return Err(SyncStatesError::PeerSentKnownBadBlock), - } - } - let block_weight = chain_response.m_block_weights.get(index).map(|f| f.clone()); - needed_blocks.push((*block_id, block_weight)); - } - let mut sync_states = self.peer_sync_states.lock().unwrap(); - sync_states.new_chain_response(id, chain_response, needed_blocks)?; - Ok(()) - } - - async fn build_chain_request(&mut self) -> Result<ChainRequest, DatabaseError> { - let DataBaseResponse::Chain(ids) = self.send_database_request(DataBaseRequest::Chain).await? else { - unreachable!("the blockchain won't send the wrong response"); - }; - - Ok(ChainRequest { - block_ids: ids, - prune: false, - }) - } - - async fn get_peers_chain_entry<Svc>( - &mut self, - peer_set: &mut Svc, - id: &NetworkAddress, - ) -> Result<ChainResponse, DatabaseError> - where - Svc: Service<PeerSetRequest, Response = PeerSetResponse, Error = DatabaseError>, - { - let chain_req = self.build_chain_request().await?; - let ready_set = peer_set.ready().await.unwrap(); - let response: PeerSetResponse = ready_set - .call(PeerSetRequest::SendRequest( - Message::Notification(chain_req.into()) - .try_into() - .expect("Chain request can always be converted to IMR"), - Some(*id), - )) - .await?; - let InternalMessageResponse::ChainResponse(response) = response.response.expect("peer set will return a result for a chain request") else { - unreachable!("peer set will return correct response"); - }; - - Ok(response) - } - - async fn get_and_handle_chain_entry<Svc>( - &mut self, - peer_set: &mut Svc, - id: NetworkAddress, - ) -> Result<(), SyncStatesError> - where - Svc: Service<PeerSetRequest, Response = PeerSetResponse, Error = DatabaseError>, - { - let chain_response = self.get_peers_chain_entry(peer_set, &id).await?; - self.handle_chain_entry_response(&id, chain_response).await - } - - async fn handle_objects_response( - &mut self, - id: NetworkAddress, - block_ids: Vec<Hash>, - peers_height: u64, - ) -> Result<bool, SyncStatesError> { - let mut sync_states = self.peer_sync_states.lock().unwrap(); - let ran_out_of_blocks = - sync_states.new_objects_response(&id, HashSet::from_iter(block_ids))?; - drop(sync_states); - if ran_out_of_blocks { - let DataBaseResponse::CurrentHeight(our_height) = self.send_database_request(DataBaseRequest::CurrentHeight).await? else { - unreachable!("the blockchain won't send the wrong response"); - }; - if our_height < peers_height { - return Ok(true); - } - } - Ok(false) - } - - fn handle_peer_disconnect(&mut self, id: NetworkAddress) { - let mut sync_states = self.peer_sync_states.lock().unwrap(); - sync_states.peer_disconnected(&id); - } - - pub async fn run<Svc>(mut self, mut peer_set: Svc) - where - Svc: Service<PeerSetRequest, Response = PeerSetResponse, Error = DatabaseError>, - { - loop { - let Some(change) = self.peer_sync_rx.next().await else { - // is this best? - return; - }; - - match change { - PeerSyncChange::CoreSyncData(id, csd) => { - match self.handle_core_sync_change(&id, csd).await { - Err(_) => { - // TODO: check if error needs ban or forget - let ready_set = peer_set.ready().await.unwrap(); - let res = ready_set.call(PeerSetRequest::BanPeer(id)).await; - } - Ok(request_chain) => { - if request_chain { - self.get_and_handle_chain_entry(&mut peer_set, id).await; - } - } - } - } - PeerSyncChange::ObjectsResponse(id, block_ids, height) => { - match self.handle_objects_response(id, block_ids, height).await { - Err(_) => { - // TODO: check if error needs ban or forget - let ready_set = peer_set.ready().await.unwrap(); - let res = ready_set.call(PeerSetRequest::BanPeer(id)).await; - } - Ok(res) => { - if res { - self.get_and_handle_chain_entry(&mut peer_set, id).await; - } - } - } - } - PeerSyncChange::PeerDisconnected(id) => { - self.handle_peer_disconnect(id); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use monero::Hash; - use monero_wire::messages::{ChainResponse, CoreSyncData}; - - use crate::{PeersSyncData, SyncStatesError}; - - #[test] - fn peer_sync_data_good_core_sync() { - let mut peer_sync_states = PeersSyncData::default(); - let core_sync = CoreSyncData::new(65346753, 1232, 389, Hash::null(), 1); - - peer_sync_states - .new_core_sync_data(&monero_wire::NetworkAddress::default(), core_sync) - .unwrap(); - - let new_core_sync = CoreSyncData::new(65346754, 1233, 389, Hash::null(), 1); - - peer_sync_states - .new_core_sync_data(&monero_wire::NetworkAddress::default(), new_core_sync) - .unwrap(); - - let peer = peer_sync_states - .peers - .get(&monero_wire::NetworkAddress::default()) - .unwrap(); - assert_eq!(peer.height, 1233); - assert_eq!(peer.cumulative_difficulty, 65346754); - } - - #[test] - fn peer_sync_data_peer_height_dropped() { - let mut peer_sync_states = PeersSyncData::default(); - let core_sync = CoreSyncData::new(65346753, 1232, 389, Hash::null(), 1); - - peer_sync_states - .new_core_sync_data(&monero_wire::NetworkAddress::default(), core_sync) - .unwrap(); - - let new_core_sync = CoreSyncData::new(65346754, 1231, 389, Hash::null(), 1); - - let res = peer_sync_states - .new_core_sync_data(&monero_wire::NetworkAddress::default(), new_core_sync) - .unwrap_err(); - - assert_eq!(res, SyncStatesError::PeersHeightHasDropped); - } - - #[test] - fn peer_sync_data_peer_cumulative_difficulty_dropped() { - let mut peer_sync_states = PeersSyncData::default(); - let core_sync = CoreSyncData::new(65346753, 1232, 389, Hash::null(), 1); - - peer_sync_states - .new_core_sync_data(&monero_wire::NetworkAddress::default(), core_sync) - .unwrap(); - - let new_core_sync = CoreSyncData::new(65346752, 1233, 389, Hash::null(), 1); - - let res = peer_sync_states - .new_core_sync_data(&monero_wire::NetworkAddress::default(), new_core_sync) - .unwrap_err(); - - assert_eq!(res, SyncStatesError::PeersCumulativeDifficultyDropped); - } - - #[test] - fn peer_sync_new_chain_response() { - let mut peer_sync_states = PeersSyncData::default(); - let core_sync = CoreSyncData::new(65346753, 1232, 389, Hash::null(), 1); - - peer_sync_states - .new_core_sync_data(&monero_wire::NetworkAddress::default(), core_sync) - .unwrap(); - - let chain_response = ChainResponse::new( - 10, - 1233, - 65346754, - vec![Hash::new(&[1]), Hash::new(&[2])], - vec![], - vec![], - ); - - let needed_blocks = vec![(Hash::new(&[2]), None)]; - - peer_sync_states - .new_chain_response( - &monero_wire::NetworkAddress::default(), - chain_response, - needed_blocks, - ) - .unwrap(); - - let peer = peer_sync_states - .peers - .get(&monero_wire::NetworkAddress::default()) - .unwrap(); - - assert_eq!(peer.start_height, 11); - assert_eq!(peer.height, 1233); - assert_eq!(peer.cumulative_difficulty, 65346754); - assert_eq!(peer.needed_blocks, vec![(Hash::new(&[2]), None)]); - } -} diff --git a/p2p/sync-states/tests/mod.rs b/p2p/sync-states/tests/mod.rs deleted file mode 100644 index 89896b9..0000000 --- a/p2p/sync-states/tests/mod.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::{ - pin::Pin, - str::FromStr, - sync::{Arc, Mutex}, -}; - -use cuprate_common::{HardForks, Network}; -use cuprate_peer::PeerError; -use cuprate_protocol::{ - temp_database::{BlockKnown, DataBaseRequest, DataBaseResponse, DatabaseError}, - Direction, InternalMessageRequest, InternalMessageResponse, -}; -use cuprate_sync_states::SyncStates; -use futures::{channel::mpsc, Future, FutureExt}; -use monero::Hash; -use monero_wire::messages::{admin::HandshakeResponse, CoreSyncData}; -use tower::ServiceExt; - -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; - -struct TestBlockchain; - -impl tower::Service<DataBaseRequest> for TestBlockchain { - type Error = DatabaseError; - type Response = DataBaseResponse; - type Future = - Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), Self::Error>> { - std::task::Poll::Ready(Ok(())) - } - fn call(&mut self, req: DataBaseRequest) -> Self::Future { - let res = match req { - DataBaseRequest::BlockHeight(h) => DataBaseResponse::BlockHeight(Some(221)), - DataBaseRequest::BlockKnown(_) => DataBaseResponse::BlockKnown(BlockKnown::OnMainChain), - DataBaseRequest::Chain => todo!(), - DataBaseRequest::CoreSyncData => { - DataBaseResponse::CoreSyncData(CoreSyncData::new(0, 0, 0, Hash::null(), 0)) - } - DataBaseRequest::CumulativeDifficulty => DataBaseResponse::CumulativeDifficulty(0), - DataBaseRequest::CurrentHeight => DataBaseResponse::CurrentHeight(0), - }; - - async { Ok(res) }.boxed() - } -} - -#[derive(Debug, Clone)] -struct TestPeerRequest; - -impl tower::Service<InternalMessageRequest> for TestPeerRequest { - type Error = PeerError; - type Response = InternalMessageResponse; - type Future = - Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), Self::Error>> { - todo!() - } - fn call(&mut self, req: InternalMessageRequest) -> Self::Future { - todo!() - } -} - -#[tokio::test] -async fn test_p2p_conn() { - let conf = cuprate_peer::handshaker::NetworkConfig::default(); - let (addr_tx, addr_rx) = mpsc::channel(21); - let (sync_tx, sync_rx) = mpsc::channel(21); - let peer_sync_states = Arc::new(Mutex::default()); - - let peer_sync_states = SyncStates::new( - sync_rx, - HardForks::new(Network::MainNet), - peer_sync_states, - TestBlockchain, - ); - - let mut handshaker = cuprate_peer::handshaker::Handshaker::new( - conf, - addr_tx, - TestBlockchain, - sync_tx, - TestPeerRequest.boxed_clone(), - ); - - let soc = tokio::net::TcpSocket::new_v4().unwrap(); - let addr = std::net::SocketAddr::from_str("127.0.0.1:18080").unwrap(); - - let mut con = soc.connect(addr).await.unwrap(); - - let (r_h, w_h) = con.split(); - - let (client, conn) = handshaker - .complete_handshake( - r_h.compat(), - w_h.compat_write(), - Direction::Outbound, - monero_wire::NetworkAddress::default(), - ) - .await - .unwrap(); - - //conn.run().await; -} diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml new file mode 100644 index 0000000..8096b09 --- /dev/null +++ b/test-utils/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "cuprate-test-utils" +version = "0.1.0" +edition = "2021" + +[dependencies] +monero-wire = {path = "../net/monero-wire"} +monero-peer = {path = "../p2p/monero-peer"} + +futures = "0.3.29" +async-trait = "0.1.74" \ No newline at end of file diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs new file mode 100644 index 0000000..e3870a6 --- /dev/null +++ b/test-utils/src/lib.rs @@ -0,0 +1 @@ +pub mod test_netzone; diff --git a/test-utils/src/test_netzone.rs b/test-utils/src/test_netzone.rs new file mode 100644 index 0000000..34e3b56 --- /dev/null +++ b/test-utils/src/test_netzone.rs @@ -0,0 +1,109 @@ +use std::{ + fmt::Formatter, + io::Error, + net::{Ipv4Addr, SocketAddr}, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{channel::mpsc::Sender as InnerSender, stream::BoxStream, Sink}; + +use monero_wire::{ + network_address::{NetworkAddress, NetworkAddressIncorrectZone}, + BucketError, Message, +}; + +use monero_peer::NetworkZone; + +#[derive(Clone)] +pub struct TestNetZoneAddr(pub u32); + +impl std::fmt::Display for TestNetZoneAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("test client, id: {}", self.0).as_str()) + } +} + +impl From<TestNetZoneAddr> for NetworkAddress { + fn from(value: TestNetZoneAddr) -> Self { + NetworkAddress::Clear(SocketAddr::new(Ipv4Addr::from(value.0).into(), 18080)) + } +} + +impl TryFrom<NetworkAddress> for TestNetZoneAddr { + type Error = NetworkAddressIncorrectZone; + + fn try_from(value: NetworkAddress) -> Result<Self, Self::Error> { + match value { + NetworkAddress::Clear(soc) => match soc { + SocketAddr::V4(v4) => Ok(TestNetZoneAddr(u32::from_be_bytes(v4.ip().octets()))), + _ => panic!("None v4 address in test code"), + }, + } + } +} + +pub struct Sender { + inner: InnerSender<Message>, +} + +impl From<InnerSender<Message>> for Sender { + fn from(inner: InnerSender<Message>) -> Self { + Sender { inner } + } +} + +impl Sink<Message> for Sender { + type Error = BucketError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.get_mut() + .inner + .poll_ready(cx) + .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) + } + + fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + self.get_mut() + .inner + .start_send(item) + .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::new(&mut self.get_mut().inner) + .poll_flush(cx) + .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::new(&mut self.get_mut().inner) + .poll_close(cx) + .map_err(|_| BucketError::IO(std::io::Error::other("mock connection channel closed"))) + } +} + +#[derive(Clone)] +pub struct TestNetZone<const ALLOW_SYNC: bool, const DANDELION_PP: bool, const CHECK_NODE_ID: bool>; + +#[async_trait::async_trait] +impl<const ALLOW_SYNC: bool, const DANDELION_PP: bool, const CHECK_NODE_ID: bool> NetworkZone + for TestNetZone<ALLOW_SYNC, DANDELION_PP, CHECK_NODE_ID> +{ + const ALLOW_SYNC: bool = ALLOW_SYNC; + const DANDELION_PP: bool = DANDELION_PP; + const CHECK_NODE_ID: bool = CHECK_NODE_ID; + + type Addr = TestNetZoneAddr; + type Stream = BoxStream<'static, Result<Message, BucketError>>; + type Sink = Sender; + type ServerCfg = (); + + async fn connect_to_peer(_: Self::Addr) -> Result<(Self::Stream, Self::Sink), Error> { + unimplemented!() + } + + async fn incoming_connection_listener(_: Self::ServerCfg) -> () { + unimplemented!() + } +}