return the Client after a handshake

This commit is contained in:
Boog900 2024-01-13 00:07:35 +00:00
parent 5e8221183e
commit 478a8c1545
No known key found for this signature in database
GPG key ID: 5401367FB7302004
11 changed files with 390 additions and 108 deletions

13
Cargo.lock generated
View file

@ -1198,6 +1198,7 @@ dependencies = [
"monero-wire",
"thiserror",
"tokio",
"tokio-stream",
"tokio-util",
"tower",
"tracing",
@ -2041,6 +2042,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "tokio-util"
version = "0.7.10"

View file

@ -13,8 +13,9 @@ borsh = ["dep:borsh"]
monero-wire = {path= "../../net/monero-wire"}
cuprate-common = {path = "../../common", features = ["borsh"]}
tokio = {version= "1.34.0", default-features = false, features = ["net"]}
tokio = {version= "1.34.0", default-features = false, features = ["net", "sync"]}
tokio-util = { version = "0.7.10", default-features = false, features = ["codec"] }
tokio-stream = {version = "0.1.14", default-features = false, features = ["sync"]}
futures = "0.3.29"
async-trait = "0.1.74"
tower = { version= "0.4.13", features = ["util"] }

View file

@ -1,3 +1,20 @@
use std::fmt::Formatter;
use std::{
fmt::{Debug, Display},
task::{Context, Poll},
};
use futures::channel::oneshot;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_util::sync::PollSender;
use tower::Service;
use cuprate_common::tower_utils::InfallibleOneshotReceiver;
use crate::{
handles::ConnectionHandle, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
};
mod conector;
mod connection;
pub mod handshaker;
@ -12,3 +29,85 @@ pub enum InternalPeerID<A> {
KnownAddr(A),
Unknown(u64),
}
impl<A: Display> Display for InternalPeerID<A> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
InternalPeerID::KnownAddr(addr) => addr.fmt(f),
InternalPeerID::Unknown(id) => f.write_str(&format!("Unknown addr, ID: {}", id)),
}
}
}
pub struct Client<Z: NetworkZone> {
id: InternalPeerID<Z::Addr>,
handle: ConnectionHandle,
connection_tx: PollSender<connection::ConnectionTaskRequest>,
connection_handle: JoinHandle<()>,
error: SharedError<PeerError>,
}
impl<Z: NetworkZone> Client<Z> {
pub fn new(
id: InternalPeerID<Z::Addr>,
handle: ConnectionHandle,
connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
connection_handle: JoinHandle<()>,
error: SharedError<PeerError>,
) -> Self {
Self {
id,
handle,
connection_tx: PollSender::new(connection_tx),
connection_handle,
error,
}
}
fn set_err(&self, err: PeerError) -> tower::BoxError {
let err_str = err.to_string();
match self.error.try_insert_err(err) {
Ok(_) => err_str,
Err(e) => e.to_string(),
}
.into()
}
}
impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
type Response = PeerResponse;
type Error = tower::BoxError;
type Future = InfallibleOneshotReceiver<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(err) = self.error.try_get_err() {
return Poll::Ready(Err(err.to_string().into()));
}
if self.connection_handle.is_finished() {
let err = self.set_err(PeerError::ClientChannelClosed);
return Poll::Ready(Err(err));
}
self.connection_tx
.poll_reserve(cx)
.map_err(|_| PeerError::ClientChannelClosed.into())
}
fn call(&mut self, request: PeerRequest) -> Self::Future {
let (tx, rx) = oneshot::channel();
let req = connection::ConnectionTaskRequest {
response_channel: tx,
request,
};
self.connection_tx
.send_item(req)
.map_err(|_| ())
.expect("poll_ready should have been called");
rx.into()
}
}

View file

@ -5,15 +5,17 @@ use std::{
};
use futures::FutureExt;
use tokio::sync::OwnedSemaphorePermit;
use tower::{Service, ServiceExt};
use crate::{
client::{DoHandshakeRequest, HandShaker, HandshakeError},
client::{Client, DoHandshakeRequest, HandShaker, HandshakeError, InternalPeerID},
AddressBook, ConnectionDirection, CoreSyncSvc, NetworkZone, PeerRequestHandler,
};
pub struct ConnectRequest<Z: NetworkZone> {
pub addr: Z::Addr,
pub permit: OwnedSemaphorePermit,
}
pub struct Connector<Z: NetworkZone, AdrBook, CSync, ReqHdlr> {
@ -33,7 +35,7 @@ where
CSync: CoreSyncSvc + Clone,
ReqHdlr: PeerRequestHandler + Clone,
{
type Response = ();
type Response = Client<Z>;
type Error = HandshakeError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@ -49,7 +51,8 @@ where
async move {
let (peer_stream, peer_sink) = Z::connect_to_peer(req.addr).await?;
let req = DoHandshakeRequest {
addr: req.addr,
addr: InternalPeerID::KnownAddr(req.addr),
permit: req.permit,
peer_stream,
peer_sink,
direction: ConnectionDirection::OutBound,

View file

@ -1,58 +1,65 @@
use std::sync::Arc;
use futures::{
channel::{mpsc, oneshot},
stream::FusedStream,
channel::oneshot,
stream::{Fuse, FusedStream},
SinkExt, StreamExt,
};
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tower::ServiceExt;
use monero_wire::{LevinCommand, Message};
use monero_wire::{LevinCommand, Message, ProtocolMessage};
use crate::{MessageID, NetworkZone, PeerError, PeerRequest, PeerRequestHandler, PeerResponse};
use crate::{
handles::ConnectionGuard, MessageID, NetworkZone, PeerBroadcast, PeerError, PeerRequest,
PeerRequestHandler, PeerResponse, SharedError,
};
pub struct ConnectionTaskRequest {
request: PeerRequest,
response_channel: oneshot::Sender<Result<PeerResponse, PeerError>>,
pub request: PeerRequest,
pub response_channel: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
}
pub enum State {
WaitingForRequest,
WaitingForResponse {
request_id: MessageID,
tx: oneshot::Sender<Result<PeerResponse, PeerError>>,
tx: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
},
}
impl State {
/// Returns if the [`LevinCommand`] is the correct response message for our request.
///
/// e.g that we didn't get a block for a txs request.
fn levin_command_response(&self, command: LevinCommand) -> bool {
match self {
State::WaitingForResponse { request_id, .. } => matches!(
(request_id, command),
(MessageID::Handshake, LevinCommand::Handshake)
| (MessageID::TimedSync, LevinCommand::TimedSync)
| (MessageID::Ping, LevinCommand::Ping)
| (MessageID::SupportFlags, LevinCommand::SupportFlags)
| (MessageID::GetObjects, LevinCommand::GetObjectsResponse)
| (MessageID::GetChain, LevinCommand::ChainResponse)
| (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock)
| (
MessageID::GetTxPoolCompliment,
LevinCommand::NewTransactions
)
),
_ => panic!("We are not in a state to be checking responses!"),
}
}
/// Returns if the [`LevinCommand`] is the correct response message for our request.
///
/// e.g that we didn't get a block for a txs request.
fn levin_command_response(message_id: &MessageID, command: LevinCommand) -> bool {
matches!(
(message_id, command),
(MessageID::Handshake, LevinCommand::Handshake)
| (MessageID::TimedSync, LevinCommand::TimedSync)
| (MessageID::Ping, LevinCommand::Ping)
| (MessageID::SupportFlags, LevinCommand::SupportFlags)
| (MessageID::GetObjects, LevinCommand::GetObjectsResponse)
| (MessageID::GetChain, LevinCommand::ChainResponse)
| (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock)
| (
MessageID::GetTxPoolCompliment,
LevinCommand::NewTransactions
)
)
}
pub struct Connection<Z: NetworkZone, ReqHndlr> {
peer_sink: Z::Sink,
state: State,
client_rx: mpsc::Receiver<ConnectionTaskRequest>,
client_rx: Fuse<ReceiverStream<ConnectionTaskRequest>>,
broadcast_rx: Fuse<BroadcastStream<Arc<PeerBroadcast>>>,
peer_request_handler: ReqHndlr,
connection_guard: ConnectionGuard,
error: SharedError<PeerError>,
}
impl<Z: NetworkZone, ReqHndlr> Connection<Z, ReqHndlr>
@ -62,47 +69,24 @@ where
pub fn new(
peer_sink: Z::Sink,
client_rx: mpsc::Receiver<ConnectionTaskRequest>,
broadcast_rx: broadcast::Receiver<Arc<PeerBroadcast>>,
peer_request_handler: ReqHndlr,
connection_guard: ConnectionGuard,
error: SharedError<PeerError>,
) -> Connection<Z, ReqHndlr> {
Connection {
peer_sink,
state: State::WaitingForRequest,
client_rx,
client_rx: ReceiverStream::new(client_rx).fuse(),
broadcast_rx: BroadcastStream::new(broadcast_rx).fuse(),
peer_request_handler,
connection_guard,
error,
}
}
async fn handle_response(&mut self, res: PeerResponse) -> Result<(), PeerError> {
let state = std::mem::replace(&mut self.state, State::WaitingForRequest);
if let State::WaitingForResponse { request_id, tx } = state {
if request_id != res.id() {
// TODO: Fail here
return Err(PeerError::PeerSentIncorrectResponse);
}
// TODO: do more tests here
// response passed our tests we can send it to the requester
let _ = tx.send(Ok(res));
Ok(())
} else {
unreachable!("This will only be called when in state WaitingForResponse");
}
}
async fn send_message_to_peer(&mut self, mes: impl Into<Message>) -> Result<(), PeerError> {
Ok(self.peer_sink.send(mes.into()).await?)
}
async fn handle_peer_request(&mut self, _req: PeerRequest) -> Result<(), PeerError> {
// we should check contents of peer requests for obvious errors like we do with responses
todo!()
/*
let ready_svc = self.svc.ready().await?;
let res = ready_svc.call(req).await?;
self.send_message_to_peer(res).await
*/
async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> {
Ok(self.peer_sink.send(mes).await?)
}
async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> {
@ -111,26 +95,72 @@ where
request_id: req.request.id(),
tx: req.response_channel,
};
} else {
// TODO: we should send this after sending the message to the peer.
req.response_channel.send(Ok(PeerResponse::NA));
}
self.send_message_to_peer(req.request.into()).await
}
async fn handle_peer_request(&mut self, req: PeerRequest) -> Result<(), PeerError> {
let ready_svc = self.peer_request_handler.ready().await?;
let res = ready_svc.call(req).await?;
if matches!(res, PeerResponse::NA) {
return Ok(());
}
self.send_message_to_peer(res.try_into().unwrap()).await
}
async fn handle_potential_response(&mut self, mes: Message) -> Result<(), PeerError> {
if mes.is_request() {
return self.handle_peer_request(mes.try_into().unwrap()).await;
}
let State::WaitingForResponse { request_id, .. } = &self.state else {
panic!("Not in correct state, can't receive response!")
};
if levin_command_response(request_id, mes.command()) {
// TODO: Do more checks before returning response.
let State::WaitingForResponse { tx, .. } =
std::mem::replace(&mut self.state, State::WaitingForRequest)
else {
panic!("Not in correct state, can't receive response!")
};
let _ = tx.send(Ok(mes.try_into().unwrap()));
Ok(())
} else {
self.handle_peer_request(
mes.try_into()
.map_err(|_| PeerError::PeerSentInvalidMessage)?,
)
.await
}
// TODO: send NA response to requester
self.send_message_to_peer(req.request).await
}
async fn state_waiting_for_request<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError>
where
Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin,
{
futures::select! {
peer_message = stream.next() => {
match peer_message.expect("MessageStream will never return None") {
Ok(message) => {
self.handle_peer_request(message.try_into().map_err(|_| PeerError::ResponseError(""))?).await
},
Err(e) => Err(e.into()),
}
},
tokio::select! {
biased;
bradcast_req = self.broadcast_rx.next() => {
todo!()
}
client_req = self.client_rx.next() => {
self.handle_client_request(client_req.ok_or(PeerError::ClientChannelClosed)?).await
if let Some(client_req) = client_req {
self.handle_client_request(client_req).await?
}
Err(PeerError::ClientChannelClosed)
},
peer_message = stream.next() => {
if let Some(peer_message) = peer_message {
self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await?
}
Err(PeerError::ConnectionClosed)
},
}
}
@ -139,38 +169,69 @@ where
where
Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin,
{
// put a timeout on this
let peer_message = stream
.next()
.await
.expect("MessageStream will never return None")?;
if !peer_message.is_request() && self.state.levin_command_response(peer_message.command()) {
if let Ok(res) = peer_message.try_into() {
Ok(self.handle_response(res).await?)
} else {
// im almost certain this is impossible to hit, but im not certain enough to use unreachable!()
Err(PeerError::ResponseError("Peer sent incorrect response"))
tokio::select! {
biased;
bradcast_req = self.broadcast_rx.next() => {
todo!()
}
} else if let Ok(req) = peer_message.try_into() {
self.handle_peer_request(req).await
} else {
// this can be hit if the peer sends an incorrect response message
Err(PeerError::ResponseError("Peer sent incorrect response"))
peer_message = stream.next() => {
if let Some(peer_message) = peer_message {
self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await?
}
Err(PeerError::ConnectionClosed)
},
}
}
pub async fn run<Str>(mut self, mut stream: Str)
pub async fn run<Str>(mut self, mut stream: Str, eager_protocol_messages: Vec<ProtocolMessage>)
where
Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin,
{
for message in eager_protocol_messages {
let message = Message::Protocol(message).try_into();
let res = match message {
Ok(mes) => self.handle_peer_request(mes).await,
Err(_) => Err(PeerError::PeerSentInvalidMessage),
};
if let Err(err) = res {
return self.shutdown(err);
}
}
loop {
let _res = match self.state {
if self.connection_guard.should_shutdown() {
return self.shutdown(PeerError::ConnectionClosed);
}
let res = match self.state {
State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await,
State::WaitingForResponse { .. } => {
self.state_waiting_for_response(&mut stream).await
}
};
if let Err(err) = res {
return self.shutdown(err);
}
}
}
fn shutdown(mut self, err: PeerError) {
tracing::debug!("Connection task shutting down: {}", err);
let mut client_rx = self.client_rx.into_inner().into_inner();
client_rx.close();
let err_str = err.to_string();
if let Err(err) = self.error.try_insert_err(err) {
tracing::debug!("Shared error already contains an error: {}", err);
}
while let Ok(req) = client_rx.try_recv() {
let _ = req.response_channel.send(Err(err_str.clone().into()));
}
self.connection_guard.connection_closed();
}
}

View file

@ -2,10 +2,12 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::{FutureExt, SinkExt, StreamExt};
use tokio::sync::{broadcast, mpsc, OwnedSemaphorePermit};
use tower::{Service, ServiceExt};
use tracing::Instrument;
@ -20,9 +22,11 @@ use monero_wire::{
};
use crate::{
client::{connection::Connection, Client, InternalPeerID},
handles::HandleBuilder,
AddressBook, AddressBookRequest, AddressBookResponse, ConnectionDirection, CoreSyncDataRequest,
CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerRequestHandler,
MAX_PEERS_IN_PEER_LIST_MESSAGE,
CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerBroadcast, PeerRequestHandler,
SharedError, MAX_PEERS_IN_PEER_LIST_MESSAGE,
};
const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2;
@ -46,10 +50,11 @@ pub enum HandshakeError {
}
pub struct DoHandshakeRequest<Z: NetworkZone> {
pub addr: Z::Addr,
pub addr: InternalPeerID<Z::Addr>,
pub peer_stream: Z::Stream,
pub peer_sink: Z::Sink,
pub direction: ConnectionDirection,
pub permit: OwnedSemaphorePermit,
}
#[derive(Debug, Clone)]
@ -60,6 +65,8 @@ pub struct HandShaker<Z: NetworkZone, AdrBook, CSync, ReqHdlr> {
our_basic_node_data: BasicNodeData,
broadcast_tx: broadcast::Sender<Arc<PeerBroadcast>>,
_zone: PhantomData<Z>,
}
@ -69,12 +76,15 @@ impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> HandShaker<Z, AdrBook, CSync, ReqH
core_sync_svc: CSync,
peer_request_svc: ReqHdlr,
broadcast_tx: broadcast::Sender<Arc<PeerBroadcast>>,
our_basic_node_data: BasicNodeData,
) -> Self {
Self {
address_book,
core_sync_svc,
peer_request_svc,
broadcast_tx,
our_basic_node_data,
_zone: PhantomData,
}
@ -88,7 +98,7 @@ where
CSync: CoreSyncSvc + Clone,
ReqHdlr: PeerRequestHandler + Clone,
{
type Response = ();
type Response = Client<Z>;
type Error = HandshakeError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@ -103,8 +113,11 @@ where
peer_stream,
peer_sink,
direction,
permit,
} = req;
let broadcast_rx = self.broadcast_tx.subscribe();
let address_book = self.address_book.clone();
let peer_request_svc = self.peer_request_svc.clone();
let core_sync_svc = self.core_sync_svc.clone();
@ -119,6 +132,8 @@ where
peer_stream,
peer_sink,
direction,
permit,
broadcast_rx,
address_book,
core_sync_svc,
peer_request_svc,
@ -133,15 +148,19 @@ where
#[allow(clippy::too_many_arguments)]
async fn handshake<Z: NetworkZone, AdrBook, CSync, ReqHdlr>(
addr: Z::Addr,
addr: InternalPeerID<Z::Addr>,
mut peer_stream: Z::Stream,
mut peer_sink: Z::Sink,
direction: ConnectionDirection,
permit: OwnedSemaphorePermit,
broadcast_rx: broadcast::Receiver<Arc<PeerBroadcast>>,
mut address_book: AdrBook,
mut core_sync_svc: CSync,
peer_request_svc: ReqHdlr,
our_basic_node_data: BasicNodeData,
) -> Result<(), HandshakeError>
) -> Result<Client<Z>, HandshakeError>
where
AdrBook: AddressBook<Z>,
CSync: CoreSyncSvc,
@ -277,7 +296,27 @@ where
tracing::debug!("Handshake complete.");
Ok(())
let error_slot = SharedError::new();
let (connection_guard, handle, _) = HandleBuilder::new().with_permit(permit).build();
let (connection_tx, client_rx) = mpsc::channel(3);
let connection = Connection::<Z, _>::new(
peer_sink,
client_rx,
broadcast_rx,
peer_request_svc,
connection_guard,
error_slot.clone(),
);
let connection_handle =
tokio::spawn(connection.run(peer_stream.fuse(), eager_protocol_messages));
let client = Client::<Z>::new(addr, handle, connection_tx, connection_handle, error_slot);
Ok(client)
}
/// Sends a [`HandshakeRequest`] to the peer.

View file

@ -1,12 +1,48 @@
use std::sync::{Arc, OnceLock};
pub struct SharedError<T>(Arc<OnceLock<T>>);
impl<T> Clone for SharedError<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<T> Default for SharedError<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> SharedError<T> {
pub fn new() -> Self {
Self(Arc::new(OnceLock::new()))
}
pub fn try_get_err(&self) -> Option<&T> {
self.0.get()
}
pub fn try_insert_err(&self, err: T) -> Result<(), &T> {
self.0.set(err).map_err(|_| self.0.get().unwrap())
}
}
#[derive(Debug, thiserror::Error)]
pub enum PeerError {
#[error("The connection was closed.")]
ConnectionClosed,
#[error("The connection tasks client channel was closed")]
ClientChannelClosed,
#[error("error with peer response: {0}")]
ResponseError(&'static str),
#[error("the peer sent an incorrect response to our request")]
PeerSentIncorrectResponse,
#[error("bucket error")]
#[error("the peer sent an invalid message")]
PeerSentInvalidMessage,
#[error("inner service error: {0}")]
ServiceError(#[from] tower::BoxError),
#[error("bucket error: {0}")]
BucketError(#[from] monero_wire::BucketError),
#[error("handshake error: {0}")]
Handshake(#[from] crate::client::HandshakeError),

View file

@ -11,6 +11,10 @@ pub struct HandleBuilder {
}
impl HandleBuilder {
pub fn new() -> Self {
Self { permit: None }
}
pub fn with_permit(mut self, permit: OwnedSemaphorePermit) -> Self {
self.permit = Some(permit);
self

View file

@ -55,6 +55,13 @@ pub enum MessageID {
NewTransactions,
}
/// This is a sub-set of [`PeerRequest`] for requests that should be sent to all nodes.
pub enum PeerBroadcast {
Transactions(NewTransactions),
NewBlock(NewBlock),
NewFluffyBlock(NewFluffyBlock),
}
pub enum PeerRequest {
Handshake(HandshakeRequest),
TimedSync(TimedSyncRequest),

View file

@ -5,6 +5,7 @@ use monero_wire::{Message, ProtocolMessage, RequestMessage, ResponseMessage};
use super::{PeerRequest, PeerResponse};
#[derive(Debug)]
pub struct MessageConversionError;
macro_rules! match_body {

View file

@ -1,6 +1,8 @@
use std::sync::Arc;
use std::{net::SocketAddr, str::FromStr};
use futures::{channel::mpsc, StreamExt};
use tokio::sync::{broadcast, Semaphore};
use tower::{Service, ServiceExt};
use cuprate_common::Network;
@ -13,6 +15,7 @@ use monero_p2p::{
};
use cuprate_test_utils::test_netzone::{TestNetZone, TestNetZoneAddr};
use monero_p2p::client::InternalPeerID;
mod utils;
use utils::*;
@ -22,6 +25,11 @@ async fn handshake_cuprate_to_cuprate() {
// Tests a Cuprate <-> Cuprate handshake by making 2 handshake services and making them talk to
// each other.
let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test.
let semaphore = Arc::new(Semaphore::new(10));
let permit_1 = semaphore.clone().acquire_owned().await.unwrap();
let permit_2 = semaphore.acquire_owned().await.unwrap();
let our_basic_node_data_1 = BasicNodeData {
my_port: 0,
network_id: Network::Mainnet.network_id(),
@ -39,6 +47,7 @@ async fn handshake_cuprate_to_cuprate() {
DummyAddressBook,
DummyCoreSyncSvc,
DummyPeerRequestHandlerSvc,
broadcast_tx.clone(),
our_basic_node_data_1,
);
@ -46,6 +55,7 @@ async fn handshake_cuprate_to_cuprate() {
DummyAddressBook,
DummyCoreSyncSvc,
DummyPeerRequestHandlerSvc,
broadcast_tx.clone(),
our_basic_node_data_2,
);
@ -53,17 +63,19 @@ async fn handshake_cuprate_to_cuprate() {
let (p2_sender, p1_receiver) = mpsc::channel(5);
let p1_handshake_req = DoHandshakeRequest {
addr: TestNetZoneAddr(888),
addr: InternalPeerID::KnownAddr(TestNetZoneAddr(888)),
peer_stream: p2_receiver.map(Ok).boxed(),
peer_sink: p2_sender.into(),
direction: ConnectionDirection::OutBound,
permit: permit_1,
};
let p2_handshake_req = DoHandshakeRequest {
addr: TestNetZoneAddr(444),
addr: InternalPeerID::KnownAddr(TestNetZoneAddr(444)),
peer_stream: p1_receiver.boxed().map(Ok).boxed(),
peer_sink: p1_sender.into(),
direction: ConnectionDirection::InBound,
permit: permit_2,
};
let p1 = tokio::spawn(async move {
@ -93,6 +105,10 @@ async fn handshake_cuprate_to_cuprate() {
#[tokio::test]
async fn handshake() {
let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test.
let semaphore = Arc::new(Semaphore::new(10));
let permit = semaphore.acquire_owned().await.unwrap();
let addr = "127.0.0.1:18080";
let our_basic_node_data = BasicNodeData {
@ -108,6 +124,7 @@ async fn handshake() {
DummyAddressBook,
DummyCoreSyncSvc,
DummyPeerRequestHandlerSvc,
broadcast_tx,
our_basic_node_data,
);
@ -119,6 +136,7 @@ async fn handshake() {
.unwrap()
.call(ConnectRequest {
addr: SocketAddr::from_str(addr).unwrap(),
permit,
})
.await
.unwrap();