diff --git a/Cargo.lock b/Cargo.lock index 000a9885..0dc645a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1198,6 +1198,7 @@ dependencies = [ "monero-wire", "thiserror", "tokio", + "tokio-stream", "tokio-util", "tower", "tracing", @@ -2041,6 +2042,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/p2p/monero-p2p/Cargo.toml b/p2p/monero-p2p/Cargo.toml index b9a80900..1566c537 100644 --- a/p2p/monero-p2p/Cargo.toml +++ b/p2p/monero-p2p/Cargo.toml @@ -13,8 +13,9 @@ borsh = ["dep:borsh"] monero-wire = {path= "../../net/monero-wire"} cuprate-common = {path = "../../common", features = ["borsh"]} -tokio = {version= "1.34.0", default-features = false, features = ["net"]} +tokio = {version= "1.34.0", default-features = false, features = ["net", "sync"]} tokio-util = { version = "0.7.10", default-features = false, features = ["codec"] } +tokio-stream = {version = "0.1.14", default-features = false, features = ["sync"]} futures = "0.3.29" async-trait = "0.1.74" tower = { version= "0.4.13", features = ["util"] } diff --git a/p2p/monero-p2p/src/client.rs b/p2p/monero-p2p/src/client.rs index 44167020..dca66e2f 100644 --- a/p2p/monero-p2p/src/client.rs +++ b/p2p/monero-p2p/src/client.rs @@ -1,3 +1,20 @@ +use std::fmt::Formatter; +use std::{ + fmt::{Debug, Display}, + task::{Context, Poll}, +}; + +use futures::channel::oneshot; +use tokio::{sync::mpsc, task::JoinHandle}; +use tokio_util::sync::PollSender; +use tower::Service; + +use cuprate_common::tower_utils::InfallibleOneshotReceiver; + +use crate::{ + handles::ConnectionHandle, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError, +}; + mod conector; mod connection; pub mod handshaker; @@ -12,3 +29,85 @@ pub enum InternalPeerID { KnownAddr(A), Unknown(u64), } + +impl Display for InternalPeerID { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + InternalPeerID::KnownAddr(addr) => addr.fmt(f), + InternalPeerID::Unknown(id) => f.write_str(&format!("Unknown addr, ID: {}", id)), + } + } +} + +pub struct Client { + id: InternalPeerID, + handle: ConnectionHandle, + + connection_tx: PollSender, + connection_handle: JoinHandle<()>, + + error: SharedError, +} + +impl Client { + pub fn new( + id: InternalPeerID, + handle: ConnectionHandle, + connection_tx: mpsc::Sender, + connection_handle: JoinHandle<()>, + error: SharedError, + ) -> Self { + Self { + id, + handle, + connection_tx: PollSender::new(connection_tx), + connection_handle, + error, + } + } + + fn set_err(&self, err: PeerError) -> tower::BoxError { + let err_str = err.to_string(); + match self.error.try_insert_err(err) { + Ok(_) => err_str, + Err(e) => e.to_string(), + } + .into() + } +} + +impl Service for Client { + type Response = PeerResponse; + type Error = tower::BoxError; + type Future = InfallibleOneshotReceiver>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(err) = self.error.try_get_err() { + return Poll::Ready(Err(err.to_string().into())); + } + + if self.connection_handle.is_finished() { + let err = self.set_err(PeerError::ClientChannelClosed); + return Poll::Ready(Err(err)); + } + + self.connection_tx + .poll_reserve(cx) + .map_err(|_| PeerError::ClientChannelClosed.into()) + } + + fn call(&mut self, request: PeerRequest) -> Self::Future { + let (tx, rx) = oneshot::channel(); + let req = connection::ConnectionTaskRequest { + response_channel: tx, + request, + }; + + self.connection_tx + .send_item(req) + .map_err(|_| ()) + .expect("poll_ready should have been called"); + + rx.into() + } +} diff --git a/p2p/monero-p2p/src/client/conector.rs b/p2p/monero-p2p/src/client/conector.rs index 5ee69f04..3f9f6047 100644 --- a/p2p/monero-p2p/src/client/conector.rs +++ b/p2p/monero-p2p/src/client/conector.rs @@ -5,15 +5,17 @@ use std::{ }; use futures::FutureExt; +use tokio::sync::OwnedSemaphorePermit; use tower::{Service, ServiceExt}; use crate::{ - client::{DoHandshakeRequest, HandShaker, HandshakeError}, + client::{Client, DoHandshakeRequest, HandShaker, HandshakeError, InternalPeerID}, AddressBook, ConnectionDirection, CoreSyncSvc, NetworkZone, PeerRequestHandler, }; pub struct ConnectRequest { pub addr: Z::Addr, + pub permit: OwnedSemaphorePermit, } pub struct Connector { @@ -33,7 +35,7 @@ where CSync: CoreSyncSvc + Clone, ReqHdlr: PeerRequestHandler + Clone, { - type Response = (); + type Response = Client; type Error = HandshakeError; type Future = Pin> + Send + 'static>>; @@ -49,7 +51,8 @@ where async move { let (peer_stream, peer_sink) = Z::connect_to_peer(req.addr).await?; let req = DoHandshakeRequest { - addr: req.addr, + addr: InternalPeerID::KnownAddr(req.addr), + permit: req.permit, peer_stream, peer_sink, direction: ConnectionDirection::OutBound, diff --git a/p2p/monero-p2p/src/client/connection.rs b/p2p/monero-p2p/src/client/connection.rs index bb445ba3..a05b0220 100644 --- a/p2p/monero-p2p/src/client/connection.rs +++ b/p2p/monero-p2p/src/client/connection.rs @@ -1,58 +1,65 @@ +use std::sync::Arc; + use futures::{ - channel::{mpsc, oneshot}, - stream::FusedStream, + channel::oneshot, + stream::{Fuse, FusedStream}, SinkExt, StreamExt, }; +use tokio::sync::{broadcast, mpsc}; +use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; +use tower::ServiceExt; -use monero_wire::{LevinCommand, Message}; +use monero_wire::{LevinCommand, Message, ProtocolMessage}; -use crate::{MessageID, NetworkZone, PeerError, PeerRequest, PeerRequestHandler, PeerResponse}; +use crate::{ + handles::ConnectionGuard, MessageID, NetworkZone, PeerBroadcast, PeerError, PeerRequest, + PeerRequestHandler, PeerResponse, SharedError, +}; pub struct ConnectionTaskRequest { - request: PeerRequest, - response_channel: oneshot::Sender>, + pub request: PeerRequest, + pub response_channel: oneshot::Sender>, } pub enum State { WaitingForRequest, WaitingForResponse { request_id: MessageID, - tx: oneshot::Sender>, + tx: oneshot::Sender>, }, } -impl State { - /// Returns if the [`LevinCommand`] is the correct response message for our request. - /// - /// e.g that we didn't get a block for a txs request. - fn levin_command_response(&self, command: LevinCommand) -> bool { - match self { - State::WaitingForResponse { request_id, .. } => matches!( - (request_id, command), - (MessageID::Handshake, LevinCommand::Handshake) - | (MessageID::TimedSync, LevinCommand::TimedSync) - | (MessageID::Ping, LevinCommand::Ping) - | (MessageID::SupportFlags, LevinCommand::SupportFlags) - | (MessageID::GetObjects, LevinCommand::GetObjectsResponse) - | (MessageID::GetChain, LevinCommand::ChainResponse) - | (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock) - | ( - MessageID::GetTxPoolCompliment, - LevinCommand::NewTransactions - ) - ), - _ => panic!("We are not in a state to be checking responses!"), - } - } +/// Returns if the [`LevinCommand`] is the correct response message for our request. +/// +/// e.g that we didn't get a block for a txs request. +fn levin_command_response(message_id: &MessageID, command: LevinCommand) -> bool { + matches!( + (message_id, command), + (MessageID::Handshake, LevinCommand::Handshake) + | (MessageID::TimedSync, LevinCommand::TimedSync) + | (MessageID::Ping, LevinCommand::Ping) + | (MessageID::SupportFlags, LevinCommand::SupportFlags) + | (MessageID::GetObjects, LevinCommand::GetObjectsResponse) + | (MessageID::GetChain, LevinCommand::ChainResponse) + | (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock) + | ( + MessageID::GetTxPoolCompliment, + LevinCommand::NewTransactions + ) + ) } pub struct Connection { peer_sink: Z::Sink, state: State, - client_rx: mpsc::Receiver, + client_rx: Fuse>, + broadcast_rx: Fuse>>, peer_request_handler: ReqHndlr, + + connection_guard: ConnectionGuard, + error: SharedError, } impl Connection @@ -62,47 +69,24 @@ where pub fn new( peer_sink: Z::Sink, client_rx: mpsc::Receiver, - + broadcast_rx: broadcast::Receiver>, peer_request_handler: ReqHndlr, + connection_guard: ConnectionGuard, + error: SharedError, ) -> Connection { Connection { peer_sink, state: State::WaitingForRequest, - client_rx, + client_rx: ReceiverStream::new(client_rx).fuse(), + broadcast_rx: BroadcastStream::new(broadcast_rx).fuse(), peer_request_handler, + connection_guard, + error, } } - async fn handle_response(&mut self, res: PeerResponse) -> Result<(), PeerError> { - let state = std::mem::replace(&mut self.state, State::WaitingForRequest); - if let State::WaitingForResponse { request_id, tx } = state { - if request_id != res.id() { - // TODO: Fail here - return Err(PeerError::PeerSentIncorrectResponse); - } - - // TODO: do more tests here - - // response passed our tests we can send it to the requester - let _ = tx.send(Ok(res)); - Ok(()) - } else { - unreachable!("This will only be called when in state WaitingForResponse"); - } - } - - async fn send_message_to_peer(&mut self, mes: impl Into) -> Result<(), PeerError> { - Ok(self.peer_sink.send(mes.into()).await?) - } - - async fn handle_peer_request(&mut self, _req: PeerRequest) -> Result<(), PeerError> { - // we should check contents of peer requests for obvious errors like we do with responses - todo!() - /* - let ready_svc = self.svc.ready().await?; - let res = ready_svc.call(req).await?; - self.send_message_to_peer(res).await - */ + async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> { + Ok(self.peer_sink.send(mes).await?) } async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> { @@ -111,26 +95,72 @@ where request_id: req.request.id(), tx: req.response_channel, }; + } else { + // TODO: we should send this after sending the message to the peer. + req.response_channel.send(Ok(PeerResponse::NA)); + } + self.send_message_to_peer(req.request.into()).await + } + + async fn handle_peer_request(&mut self, req: PeerRequest) -> Result<(), PeerError> { + let ready_svc = self.peer_request_handler.ready().await?; + let res = ready_svc.call(req).await?; + if matches!(res, PeerResponse::NA) { + return Ok(()); + } + + self.send_message_to_peer(res.try_into().unwrap()).await + } + + async fn handle_potential_response(&mut self, mes: Message) -> Result<(), PeerError> { + if mes.is_request() { + return self.handle_peer_request(mes.try_into().unwrap()).await; + } + + let State::WaitingForResponse { request_id, .. } = &self.state else { + panic!("Not in correct state, can't receive response!") + }; + + if levin_command_response(request_id, mes.command()) { + // TODO: Do more checks before returning response. + + let State::WaitingForResponse { tx, .. } = + std::mem::replace(&mut self.state, State::WaitingForRequest) + else { + panic!("Not in correct state, can't receive response!") + }; + + let _ = tx.send(Ok(mes.try_into().unwrap())); + Ok(()) + } else { + self.handle_peer_request( + mes.try_into() + .map_err(|_| PeerError::PeerSentInvalidMessage)?, + ) + .await } - // TODO: send NA response to requester - self.send_message_to_peer(req.request).await } async fn state_waiting_for_request(&mut self, stream: &mut Str) -> Result<(), PeerError> where Str: FusedStream> + Unpin, { - futures::select! { - peer_message = stream.next() => { - match peer_message.expect("MessageStream will never return None") { - Ok(message) => { - self.handle_peer_request(message.try_into().map_err(|_| PeerError::ResponseError(""))?).await - }, - Err(e) => Err(e.into()), - } - }, + tokio::select! { + biased; + bradcast_req = self.broadcast_rx.next() => { + todo!() + } client_req = self.client_rx.next() => { - self.handle_client_request(client_req.ok_or(PeerError::ClientChannelClosed)?).await + if let Some(client_req) = client_req { + self.handle_client_request(client_req).await? + } + Err(PeerError::ClientChannelClosed) + }, + peer_message = stream.next() => { + if let Some(peer_message) = peer_message { + self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await? + } + Err(PeerError::ConnectionClosed) }, } } @@ -139,38 +169,69 @@ where where Str: FusedStream> + Unpin, { - // put a timeout on this - let peer_message = stream - .next() - .await - .expect("MessageStream will never return None")?; - - if !peer_message.is_request() && self.state.levin_command_response(peer_message.command()) { - if let Ok(res) = peer_message.try_into() { - Ok(self.handle_response(res).await?) - } else { - // im almost certain this is impossible to hit, but im not certain enough to use unreachable!() - Err(PeerError::ResponseError("Peer sent incorrect response")) + tokio::select! { + biased; + bradcast_req = self.broadcast_rx.next() => { + todo!() } - } else if let Ok(req) = peer_message.try_into() { - self.handle_peer_request(req).await - } else { - // this can be hit if the peer sends an incorrect response message - Err(PeerError::ResponseError("Peer sent incorrect response")) + peer_message = stream.next() => { + if let Some(peer_message) = peer_message { + self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await? + } + Err(PeerError::ConnectionClosed) + }, } } - pub async fn run(mut self, mut stream: Str) + pub async fn run(mut self, mut stream: Str, eager_protocol_messages: Vec) where Str: FusedStream> + Unpin, { + for message in eager_protocol_messages { + let message = Message::Protocol(message).try_into(); + + let res = match message { + Ok(mes) => self.handle_peer_request(mes).await, + Err(_) => Err(PeerError::PeerSentInvalidMessage), + }; + + if let Err(err) = res { + return self.shutdown(err); + } + } + loop { - let _res = match self.state { + if self.connection_guard.should_shutdown() { + return self.shutdown(PeerError::ConnectionClosed); + } + + let res = match self.state { State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await, State::WaitingForResponse { .. } => { self.state_waiting_for_response(&mut stream).await } }; + + if let Err(err) = res { + return self.shutdown(err); + } } } + + fn shutdown(mut self, err: PeerError) { + tracing::debug!("Connection task shutting down: {}", err); + let mut client_rx = self.client_rx.into_inner().into_inner(); + client_rx.close(); + + let err_str = err.to_string(); + if let Err(err) = self.error.try_insert_err(err) { + tracing::debug!("Shared error already contains an error: {}", err); + } + + while let Ok(req) = client_rx.try_recv() { + let _ = req.response_channel.send(Err(err_str.clone().into())); + } + + self.connection_guard.connection_closed(); + } } diff --git a/p2p/monero-p2p/src/client/handshaker.rs b/p2p/monero-p2p/src/client/handshaker.rs index 38858b49..86536f7b 100644 --- a/p2p/monero-p2p/src/client/handshaker.rs +++ b/p2p/monero-p2p/src/client/handshaker.rs @@ -2,10 +2,12 @@ use std::{ future::Future, marker::PhantomData, pin::Pin, + sync::Arc, task::{Context, Poll}, }; use futures::{FutureExt, SinkExt, StreamExt}; +use tokio::sync::{broadcast, mpsc, OwnedSemaphorePermit}; use tower::{Service, ServiceExt}; use tracing::Instrument; @@ -20,9 +22,11 @@ use monero_wire::{ }; use crate::{ + client::{connection::Connection, Client, InternalPeerID}, + handles::HandleBuilder, AddressBook, AddressBookRequest, AddressBookResponse, ConnectionDirection, CoreSyncDataRequest, - CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerRequestHandler, - MAX_PEERS_IN_PEER_LIST_MESSAGE, + CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerBroadcast, PeerRequestHandler, + SharedError, MAX_PEERS_IN_PEER_LIST_MESSAGE, }; const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2; @@ -46,10 +50,11 @@ pub enum HandshakeError { } pub struct DoHandshakeRequest { - pub addr: Z::Addr, + pub addr: InternalPeerID, pub peer_stream: Z::Stream, pub peer_sink: Z::Sink, pub direction: ConnectionDirection, + pub permit: OwnedSemaphorePermit, } #[derive(Debug, Clone)] @@ -60,6 +65,8 @@ pub struct HandShaker { our_basic_node_data: BasicNodeData, + broadcast_tx: broadcast::Sender>, + _zone: PhantomData, } @@ -69,12 +76,15 @@ impl HandShaker>, + our_basic_node_data: BasicNodeData, ) -> Self { Self { address_book, core_sync_svc, peer_request_svc, + broadcast_tx, our_basic_node_data, _zone: PhantomData, } @@ -88,7 +98,7 @@ where CSync: CoreSyncSvc + Clone, ReqHdlr: PeerRequestHandler + Clone, { - type Response = (); + type Response = Client; type Error = HandshakeError; type Future = Pin> + Send + 'static>>; @@ -103,8 +113,11 @@ where peer_stream, peer_sink, direction, + permit, } = req; + let broadcast_rx = self.broadcast_tx.subscribe(); + let address_book = self.address_book.clone(); let peer_request_svc = self.peer_request_svc.clone(); let core_sync_svc = self.core_sync_svc.clone(); @@ -119,6 +132,8 @@ where peer_stream, peer_sink, direction, + permit, + broadcast_rx, address_book, core_sync_svc, peer_request_svc, @@ -133,15 +148,19 @@ where #[allow(clippy::too_many_arguments)] async fn handshake( - addr: Z::Addr, + addr: InternalPeerID, mut peer_stream: Z::Stream, mut peer_sink: Z::Sink, direction: ConnectionDirection, + + permit: OwnedSemaphorePermit, + broadcast_rx: broadcast::Receiver>, + mut address_book: AdrBook, mut core_sync_svc: CSync, peer_request_svc: ReqHdlr, our_basic_node_data: BasicNodeData, -) -> Result<(), HandshakeError> +) -> Result, HandshakeError> where AdrBook: AddressBook, CSync: CoreSyncSvc, @@ -277,7 +296,27 @@ where tracing::debug!("Handshake complete."); - Ok(()) + let error_slot = SharedError::new(); + + let (connection_guard, handle, _) = HandleBuilder::new().with_permit(permit).build(); + + let (connection_tx, client_rx) = mpsc::channel(3); + + let connection = Connection::::new( + peer_sink, + client_rx, + broadcast_rx, + peer_request_svc, + connection_guard, + error_slot.clone(), + ); + + let connection_handle = + tokio::spawn(connection.run(peer_stream.fuse(), eager_protocol_messages)); + + let client = Client::::new(addr, handle, connection_tx, connection_handle, error_slot); + + Ok(client) } /// Sends a [`HandshakeRequest`] to the peer. diff --git a/p2p/monero-p2p/src/error.rs b/p2p/monero-p2p/src/error.rs index 046a6599..2b8ace84 100644 --- a/p2p/monero-p2p/src/error.rs +++ b/p2p/monero-p2p/src/error.rs @@ -1,12 +1,48 @@ +use std::sync::{Arc, OnceLock}; + +pub struct SharedError(Arc>); + +impl Clone for SharedError { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Default for SharedError { + fn default() -> Self { + Self::new() + } +} + +impl SharedError { + pub fn new() -> Self { + Self(Arc::new(OnceLock::new())) + } + + pub fn try_get_err(&self) -> Option<&T> { + self.0.get() + } + + pub fn try_insert_err(&self, err: T) -> Result<(), &T> { + self.0.set(err).map_err(|_| self.0.get().unwrap()) + } +} + #[derive(Debug, thiserror::Error)] pub enum PeerError { + #[error("The connection was closed.")] + ConnectionClosed, #[error("The connection tasks client channel was closed")] ClientChannelClosed, #[error("error with peer response: {0}")] ResponseError(&'static str), #[error("the peer sent an incorrect response to our request")] PeerSentIncorrectResponse, - #[error("bucket error")] + #[error("the peer sent an invalid message")] + PeerSentInvalidMessage, + #[error("inner service error: {0}")] + ServiceError(#[from] tower::BoxError), + #[error("bucket error: {0}")] BucketError(#[from] monero_wire::BucketError), #[error("handshake error: {0}")] Handshake(#[from] crate::client::HandshakeError), diff --git a/p2p/monero-p2p/src/handles.rs b/p2p/monero-p2p/src/handles.rs index 912726e3..c07a76dc 100644 --- a/p2p/monero-p2p/src/handles.rs +++ b/p2p/monero-p2p/src/handles.rs @@ -11,6 +11,10 @@ pub struct HandleBuilder { } impl HandleBuilder { + pub fn new() -> Self { + Self { permit: None } + } + pub fn with_permit(mut self, permit: OwnedSemaphorePermit) -> Self { self.permit = Some(permit); self diff --git a/p2p/monero-p2p/src/protocol.rs b/p2p/monero-p2p/src/protocol.rs index 56edd810..a4fb6e9e 100644 --- a/p2p/monero-p2p/src/protocol.rs +++ b/p2p/monero-p2p/src/protocol.rs @@ -55,6 +55,13 @@ pub enum MessageID { NewTransactions, } +/// This is a sub-set of [`PeerRequest`] for requests that should be sent to all nodes. +pub enum PeerBroadcast { + Transactions(NewTransactions), + NewBlock(NewBlock), + NewFluffyBlock(NewFluffyBlock), +} + pub enum PeerRequest { Handshake(HandshakeRequest), TimedSync(TimedSyncRequest), diff --git a/p2p/monero-p2p/src/protocol/try_from.rs b/p2p/monero-p2p/src/protocol/try_from.rs index 4e4ebdb9..02a5233e 100644 --- a/p2p/monero-p2p/src/protocol/try_from.rs +++ b/p2p/monero-p2p/src/protocol/try_from.rs @@ -5,6 +5,7 @@ use monero_wire::{Message, ProtocolMessage, RequestMessage, ResponseMessage}; use super::{PeerRequest, PeerResponse}; +#[derive(Debug)] pub struct MessageConversionError; macro_rules! match_body { diff --git a/p2p/monero-p2p/tests/handshake.rs b/p2p/monero-p2p/tests/handshake.rs index 8385c55d..51f32558 100644 --- a/p2p/monero-p2p/tests/handshake.rs +++ b/p2p/monero-p2p/tests/handshake.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; use std::{net::SocketAddr, str::FromStr}; use futures::{channel::mpsc, StreamExt}; +use tokio::sync::{broadcast, Semaphore}; use tower::{Service, ServiceExt}; use cuprate_common::Network; @@ -13,6 +15,7 @@ use monero_p2p::{ }; use cuprate_test_utils::test_netzone::{TestNetZone, TestNetZoneAddr}; +use monero_p2p::client::InternalPeerID; mod utils; use utils::*; @@ -22,6 +25,11 @@ async fn handshake_cuprate_to_cuprate() { // Tests a Cuprate <-> Cuprate handshake by making 2 handshake services and making them talk to // each other. + let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test. + let semaphore = Arc::new(Semaphore::new(10)); + let permit_1 = semaphore.clone().acquire_owned().await.unwrap(); + let permit_2 = semaphore.acquire_owned().await.unwrap(); + let our_basic_node_data_1 = BasicNodeData { my_port: 0, network_id: Network::Mainnet.network_id(), @@ -39,6 +47,7 @@ async fn handshake_cuprate_to_cuprate() { DummyAddressBook, DummyCoreSyncSvc, DummyPeerRequestHandlerSvc, + broadcast_tx.clone(), our_basic_node_data_1, ); @@ -46,6 +55,7 @@ async fn handshake_cuprate_to_cuprate() { DummyAddressBook, DummyCoreSyncSvc, DummyPeerRequestHandlerSvc, + broadcast_tx.clone(), our_basic_node_data_2, ); @@ -53,17 +63,19 @@ async fn handshake_cuprate_to_cuprate() { let (p2_sender, p1_receiver) = mpsc::channel(5); let p1_handshake_req = DoHandshakeRequest { - addr: TestNetZoneAddr(888), + addr: InternalPeerID::KnownAddr(TestNetZoneAddr(888)), peer_stream: p2_receiver.map(Ok).boxed(), peer_sink: p2_sender.into(), direction: ConnectionDirection::OutBound, + permit: permit_1, }; let p2_handshake_req = DoHandshakeRequest { - addr: TestNetZoneAddr(444), + addr: InternalPeerID::KnownAddr(TestNetZoneAddr(444)), peer_stream: p1_receiver.boxed().map(Ok).boxed(), peer_sink: p1_sender.into(), direction: ConnectionDirection::InBound, + permit: permit_2, }; let p1 = tokio::spawn(async move { @@ -93,6 +105,10 @@ async fn handshake_cuprate_to_cuprate() { #[tokio::test] async fn handshake() { + let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test. + let semaphore = Arc::new(Semaphore::new(10)); + let permit = semaphore.acquire_owned().await.unwrap(); + let addr = "127.0.0.1:18080"; let our_basic_node_data = BasicNodeData { @@ -108,6 +124,7 @@ async fn handshake() { DummyAddressBook, DummyCoreSyncSvc, DummyPeerRequestHandlerSvc, + broadcast_tx, our_basic_node_data, ); @@ -119,6 +136,7 @@ async fn handshake() { .unwrap() .call(ConnectRequest { addr: SocketAddr::from_str(addr).unwrap(), + permit, }) .await .unwrap();