mirror of
https://github.com/Cuprate/cuprate.git
synced 2024-11-16 15:58:17 +00:00
return the Client
after a handshake
This commit is contained in:
parent
5e8221183e
commit
478a8c1545
11 changed files with 390 additions and 108 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
@ -1198,6 +1198,7 @@ dependencies = [
|
|||
"monero-wire",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tracing",
|
||||
|
@ -2041,6 +2042,18 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.10"
|
||||
|
|
|
@ -13,8 +13,9 @@ borsh = ["dep:borsh"]
|
|||
monero-wire = {path= "../../net/monero-wire"}
|
||||
cuprate-common = {path = "../../common", features = ["borsh"]}
|
||||
|
||||
tokio = {version= "1.34.0", default-features = false, features = ["net"]}
|
||||
tokio = {version= "1.34.0", default-features = false, features = ["net", "sync"]}
|
||||
tokio-util = { version = "0.7.10", default-features = false, features = ["codec"] }
|
||||
tokio-stream = {version = "0.1.14", default-features = false, features = ["sync"]}
|
||||
futures = "0.3.29"
|
||||
async-trait = "0.1.74"
|
||||
tower = { version= "0.4.13", features = ["util"] }
|
||||
|
|
|
@ -1,3 +1,20 @@
|
|||
use std::fmt::Formatter;
|
||||
use std::{
|
||||
fmt::{Debug, Display},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::channel::oneshot;
|
||||
use tokio::{sync::mpsc, task::JoinHandle};
|
||||
use tokio_util::sync::PollSender;
|
||||
use tower::Service;
|
||||
|
||||
use cuprate_common::tower_utils::InfallibleOneshotReceiver;
|
||||
|
||||
use crate::{
|
||||
handles::ConnectionHandle, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
|
||||
};
|
||||
|
||||
mod conector;
|
||||
mod connection;
|
||||
pub mod handshaker;
|
||||
|
@ -12,3 +29,85 @@ pub enum InternalPeerID<A> {
|
|||
KnownAddr(A),
|
||||
Unknown(u64),
|
||||
}
|
||||
|
||||
impl<A: Display> Display for InternalPeerID<A> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
InternalPeerID::KnownAddr(addr) => addr.fmt(f),
|
||||
InternalPeerID::Unknown(id) => f.write_str(&format!("Unknown addr, ID: {}", id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Client<Z: NetworkZone> {
|
||||
id: InternalPeerID<Z::Addr>,
|
||||
handle: ConnectionHandle,
|
||||
|
||||
connection_tx: PollSender<connection::ConnectionTaskRequest>,
|
||||
connection_handle: JoinHandle<()>,
|
||||
|
||||
error: SharedError<PeerError>,
|
||||
}
|
||||
|
||||
impl<Z: NetworkZone> Client<Z> {
|
||||
pub fn new(
|
||||
id: InternalPeerID<Z::Addr>,
|
||||
handle: ConnectionHandle,
|
||||
connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
|
||||
connection_handle: JoinHandle<()>,
|
||||
error: SharedError<PeerError>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
handle,
|
||||
connection_tx: PollSender::new(connection_tx),
|
||||
connection_handle,
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
fn set_err(&self, err: PeerError) -> tower::BoxError {
|
||||
let err_str = err.to_string();
|
||||
match self.error.try_insert_err(err) {
|
||||
Ok(_) => err_str,
|
||||
Err(e) => e.to_string(),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
|
||||
type Response = PeerResponse;
|
||||
type Error = tower::BoxError;
|
||||
type Future = InfallibleOneshotReceiver<Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if let Some(err) = self.error.try_get_err() {
|
||||
return Poll::Ready(Err(err.to_string().into()));
|
||||
}
|
||||
|
||||
if self.connection_handle.is_finished() {
|
||||
let err = self.set_err(PeerError::ClientChannelClosed);
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
self.connection_tx
|
||||
.poll_reserve(cx)
|
||||
.map_err(|_| PeerError::ClientChannelClosed.into())
|
||||
}
|
||||
|
||||
fn call(&mut self, request: PeerRequest) -> Self::Future {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let req = connection::ConnectionTaskRequest {
|
||||
response_channel: tx,
|
||||
request,
|
||||
};
|
||||
|
||||
self.connection_tx
|
||||
.send_item(req)
|
||||
.map_err(|_| ())
|
||||
.expect("poll_ready should have been called");
|
||||
|
||||
rx.into()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,15 +5,17 @@ use std::{
|
|||
};
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
use tower::{Service, ServiceExt};
|
||||
|
||||
use crate::{
|
||||
client::{DoHandshakeRequest, HandShaker, HandshakeError},
|
||||
client::{Client, DoHandshakeRequest, HandShaker, HandshakeError, InternalPeerID},
|
||||
AddressBook, ConnectionDirection, CoreSyncSvc, NetworkZone, PeerRequestHandler,
|
||||
};
|
||||
|
||||
pub struct ConnectRequest<Z: NetworkZone> {
|
||||
pub addr: Z::Addr,
|
||||
pub permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
pub struct Connector<Z: NetworkZone, AdrBook, CSync, ReqHdlr> {
|
||||
|
@ -33,7 +35,7 @@ where
|
|||
CSync: CoreSyncSvc + Clone,
|
||||
ReqHdlr: PeerRequestHandler + Clone,
|
||||
{
|
||||
type Response = ();
|
||||
type Response = Client<Z>;
|
||||
type Error = HandshakeError;
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
|
||||
|
@ -49,7 +51,8 @@ where
|
|||
async move {
|
||||
let (peer_stream, peer_sink) = Z::connect_to_peer(req.addr).await?;
|
||||
let req = DoHandshakeRequest {
|
||||
addr: req.addr,
|
||||
addr: InternalPeerID::KnownAddr(req.addr),
|
||||
permit: req.permit,
|
||||
peer_stream,
|
||||
peer_sink,
|
||||
direction: ConnectionDirection::OutBound,
|
||||
|
|
|
@ -1,58 +1,65 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
stream::FusedStream,
|
||||
channel::oneshot,
|
||||
stream::{Fuse, FusedStream},
|
||||
SinkExt, StreamExt,
|
||||
};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
|
||||
use tower::ServiceExt;
|
||||
|
||||
use monero_wire::{LevinCommand, Message};
|
||||
use monero_wire::{LevinCommand, Message, ProtocolMessage};
|
||||
|
||||
use crate::{MessageID, NetworkZone, PeerError, PeerRequest, PeerRequestHandler, PeerResponse};
|
||||
use crate::{
|
||||
handles::ConnectionGuard, MessageID, NetworkZone, PeerBroadcast, PeerError, PeerRequest,
|
||||
PeerRequestHandler, PeerResponse, SharedError,
|
||||
};
|
||||
|
||||
pub struct ConnectionTaskRequest {
|
||||
request: PeerRequest,
|
||||
response_channel: oneshot::Sender<Result<PeerResponse, PeerError>>,
|
||||
pub request: PeerRequest,
|
||||
pub response_channel: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
|
||||
}
|
||||
|
||||
pub enum State {
|
||||
WaitingForRequest,
|
||||
WaitingForResponse {
|
||||
request_id: MessageID,
|
||||
tx: oneshot::Sender<Result<PeerResponse, PeerError>>,
|
||||
tx: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl State {
|
||||
/// Returns if the [`LevinCommand`] is the correct response message for our request.
|
||||
///
|
||||
/// e.g that we didn't get a block for a txs request.
|
||||
fn levin_command_response(&self, command: LevinCommand) -> bool {
|
||||
match self {
|
||||
State::WaitingForResponse { request_id, .. } => matches!(
|
||||
(request_id, command),
|
||||
(MessageID::Handshake, LevinCommand::Handshake)
|
||||
| (MessageID::TimedSync, LevinCommand::TimedSync)
|
||||
| (MessageID::Ping, LevinCommand::Ping)
|
||||
| (MessageID::SupportFlags, LevinCommand::SupportFlags)
|
||||
| (MessageID::GetObjects, LevinCommand::GetObjectsResponse)
|
||||
| (MessageID::GetChain, LevinCommand::ChainResponse)
|
||||
| (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock)
|
||||
| (
|
||||
MessageID::GetTxPoolCompliment,
|
||||
LevinCommand::NewTransactions
|
||||
)
|
||||
),
|
||||
_ => panic!("We are not in a state to be checking responses!"),
|
||||
}
|
||||
}
|
||||
/// Returns if the [`LevinCommand`] is the correct response message for our request.
|
||||
///
|
||||
/// e.g that we didn't get a block for a txs request.
|
||||
fn levin_command_response(message_id: &MessageID, command: LevinCommand) -> bool {
|
||||
matches!(
|
||||
(message_id, command),
|
||||
(MessageID::Handshake, LevinCommand::Handshake)
|
||||
| (MessageID::TimedSync, LevinCommand::TimedSync)
|
||||
| (MessageID::Ping, LevinCommand::Ping)
|
||||
| (MessageID::SupportFlags, LevinCommand::SupportFlags)
|
||||
| (MessageID::GetObjects, LevinCommand::GetObjectsResponse)
|
||||
| (MessageID::GetChain, LevinCommand::ChainResponse)
|
||||
| (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock)
|
||||
| (
|
||||
MessageID::GetTxPoolCompliment,
|
||||
LevinCommand::NewTransactions
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
pub struct Connection<Z: NetworkZone, ReqHndlr> {
|
||||
peer_sink: Z::Sink,
|
||||
|
||||
state: State,
|
||||
client_rx: mpsc::Receiver<ConnectionTaskRequest>,
|
||||
client_rx: Fuse<ReceiverStream<ConnectionTaskRequest>>,
|
||||
broadcast_rx: Fuse<BroadcastStream<Arc<PeerBroadcast>>>,
|
||||
|
||||
peer_request_handler: ReqHndlr,
|
||||
|
||||
connection_guard: ConnectionGuard,
|
||||
error: SharedError<PeerError>,
|
||||
}
|
||||
|
||||
impl<Z: NetworkZone, ReqHndlr> Connection<Z, ReqHndlr>
|
||||
|
@ -62,47 +69,24 @@ where
|
|||
pub fn new(
|
||||
peer_sink: Z::Sink,
|
||||
client_rx: mpsc::Receiver<ConnectionTaskRequest>,
|
||||
|
||||
broadcast_rx: broadcast::Receiver<Arc<PeerBroadcast>>,
|
||||
peer_request_handler: ReqHndlr,
|
||||
connection_guard: ConnectionGuard,
|
||||
error: SharedError<PeerError>,
|
||||
) -> Connection<Z, ReqHndlr> {
|
||||
Connection {
|
||||
peer_sink,
|
||||
state: State::WaitingForRequest,
|
||||
client_rx,
|
||||
client_rx: ReceiverStream::new(client_rx).fuse(),
|
||||
broadcast_rx: BroadcastStream::new(broadcast_rx).fuse(),
|
||||
peer_request_handler,
|
||||
connection_guard,
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_response(&mut self, res: PeerResponse) -> Result<(), PeerError> {
|
||||
let state = std::mem::replace(&mut self.state, State::WaitingForRequest);
|
||||
if let State::WaitingForResponse { request_id, tx } = state {
|
||||
if request_id != res.id() {
|
||||
// TODO: Fail here
|
||||
return Err(PeerError::PeerSentIncorrectResponse);
|
||||
}
|
||||
|
||||
// TODO: do more tests here
|
||||
|
||||
// response passed our tests we can send it to the requester
|
||||
let _ = tx.send(Ok(res));
|
||||
Ok(())
|
||||
} else {
|
||||
unreachable!("This will only be called when in state WaitingForResponse");
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_message_to_peer(&mut self, mes: impl Into<Message>) -> Result<(), PeerError> {
|
||||
Ok(self.peer_sink.send(mes.into()).await?)
|
||||
}
|
||||
|
||||
async fn handle_peer_request(&mut self, _req: PeerRequest) -> Result<(), PeerError> {
|
||||
// we should check contents of peer requests for obvious errors like we do with responses
|
||||
todo!()
|
||||
/*
|
||||
let ready_svc = self.svc.ready().await?;
|
||||
let res = ready_svc.call(req).await?;
|
||||
self.send_message_to_peer(res).await
|
||||
*/
|
||||
async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> {
|
||||
Ok(self.peer_sink.send(mes).await?)
|
||||
}
|
||||
|
||||
async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> {
|
||||
|
@ -111,26 +95,72 @@ where
|
|||
request_id: req.request.id(),
|
||||
tx: req.response_channel,
|
||||
};
|
||||
} else {
|
||||
// TODO: we should send this after sending the message to the peer.
|
||||
req.response_channel.send(Ok(PeerResponse::NA));
|
||||
}
|
||||
self.send_message_to_peer(req.request.into()).await
|
||||
}
|
||||
|
||||
async fn handle_peer_request(&mut self, req: PeerRequest) -> Result<(), PeerError> {
|
||||
let ready_svc = self.peer_request_handler.ready().await?;
|
||||
let res = ready_svc.call(req).await?;
|
||||
if matches!(res, PeerResponse::NA) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.send_message_to_peer(res.try_into().unwrap()).await
|
||||
}
|
||||
|
||||
async fn handle_potential_response(&mut self, mes: Message) -> Result<(), PeerError> {
|
||||
if mes.is_request() {
|
||||
return self.handle_peer_request(mes.try_into().unwrap()).await;
|
||||
}
|
||||
|
||||
let State::WaitingForResponse { request_id, .. } = &self.state else {
|
||||
panic!("Not in correct state, can't receive response!")
|
||||
};
|
||||
|
||||
if levin_command_response(request_id, mes.command()) {
|
||||
// TODO: Do more checks before returning response.
|
||||
|
||||
let State::WaitingForResponse { tx, .. } =
|
||||
std::mem::replace(&mut self.state, State::WaitingForRequest)
|
||||
else {
|
||||
panic!("Not in correct state, can't receive response!")
|
||||
};
|
||||
|
||||
let _ = tx.send(Ok(mes.try_into().unwrap()));
|
||||
Ok(())
|
||||
} else {
|
||||
self.handle_peer_request(
|
||||
mes.try_into()
|
||||
.map_err(|_| PeerError::PeerSentInvalidMessage)?,
|
||||
)
|
||||
.await
|
||||
}
|
||||
// TODO: send NA response to requester
|
||||
self.send_message_to_peer(req.request).await
|
||||
}
|
||||
|
||||
async fn state_waiting_for_request<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError>
|
||||
where
|
||||
Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin,
|
||||
{
|
||||
futures::select! {
|
||||
peer_message = stream.next() => {
|
||||
match peer_message.expect("MessageStream will never return None") {
|
||||
Ok(message) => {
|
||||
self.handle_peer_request(message.try_into().map_err(|_| PeerError::ResponseError(""))?).await
|
||||
},
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
},
|
||||
tokio::select! {
|
||||
biased;
|
||||
bradcast_req = self.broadcast_rx.next() => {
|
||||
todo!()
|
||||
}
|
||||
client_req = self.client_rx.next() => {
|
||||
self.handle_client_request(client_req.ok_or(PeerError::ClientChannelClosed)?).await
|
||||
if let Some(client_req) = client_req {
|
||||
self.handle_client_request(client_req).await?
|
||||
}
|
||||
Err(PeerError::ClientChannelClosed)
|
||||
},
|
||||
peer_message = stream.next() => {
|
||||
if let Some(peer_message) = peer_message {
|
||||
self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await?
|
||||
}
|
||||
Err(PeerError::ConnectionClosed)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -139,38 +169,69 @@ where
|
|||
where
|
||||
Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin,
|
||||
{
|
||||
// put a timeout on this
|
||||
let peer_message = stream
|
||||
.next()
|
||||
.await
|
||||
.expect("MessageStream will never return None")?;
|
||||
|
||||
if !peer_message.is_request() && self.state.levin_command_response(peer_message.command()) {
|
||||
if let Ok(res) = peer_message.try_into() {
|
||||
Ok(self.handle_response(res).await?)
|
||||
} else {
|
||||
// im almost certain this is impossible to hit, but im not certain enough to use unreachable!()
|
||||
Err(PeerError::ResponseError("Peer sent incorrect response"))
|
||||
tokio::select! {
|
||||
biased;
|
||||
bradcast_req = self.broadcast_rx.next() => {
|
||||
todo!()
|
||||
}
|
||||
} else if let Ok(req) = peer_message.try_into() {
|
||||
self.handle_peer_request(req).await
|
||||
} else {
|
||||
// this can be hit if the peer sends an incorrect response message
|
||||
Err(PeerError::ResponseError("Peer sent incorrect response"))
|
||||
peer_message = stream.next() => {
|
||||
if let Some(peer_message) = peer_message {
|
||||
self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await?
|
||||
}
|
||||
Err(PeerError::ConnectionClosed)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run<Str>(mut self, mut stream: Str)
|
||||
pub async fn run<Str>(mut self, mut stream: Str, eager_protocol_messages: Vec<ProtocolMessage>)
|
||||
where
|
||||
Str: FusedStream<Item = Result<Message, monero_wire::BucketError>> + Unpin,
|
||||
{
|
||||
for message in eager_protocol_messages {
|
||||
let message = Message::Protocol(message).try_into();
|
||||
|
||||
let res = match message {
|
||||
Ok(mes) => self.handle_peer_request(mes).await,
|
||||
Err(_) => Err(PeerError::PeerSentInvalidMessage),
|
||||
};
|
||||
|
||||
if let Err(err) = res {
|
||||
return self.shutdown(err);
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let _res = match self.state {
|
||||
if self.connection_guard.should_shutdown() {
|
||||
return self.shutdown(PeerError::ConnectionClosed);
|
||||
}
|
||||
|
||||
let res = match self.state {
|
||||
State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await,
|
||||
State::WaitingForResponse { .. } => {
|
||||
self.state_waiting_for_response(&mut stream).await
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = res {
|
||||
return self.shutdown(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn shutdown(mut self, err: PeerError) {
|
||||
tracing::debug!("Connection task shutting down: {}", err);
|
||||
let mut client_rx = self.client_rx.into_inner().into_inner();
|
||||
client_rx.close();
|
||||
|
||||
let err_str = err.to_string();
|
||||
if let Err(err) = self.error.try_insert_err(err) {
|
||||
tracing::debug!("Shared error already contains an error: {}", err);
|
||||
}
|
||||
|
||||
while let Ok(req) = client_rx.try_recv() {
|
||||
let _ = req.response_channel.send(Err(err_str.clone().into()));
|
||||
}
|
||||
|
||||
self.connection_guard.connection_closed();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,10 +2,12 @@ use std::{
|
|||
future::Future,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::{FutureExt, SinkExt, StreamExt};
|
||||
use tokio::sync::{broadcast, mpsc, OwnedSemaphorePermit};
|
||||
use tower::{Service, ServiceExt};
|
||||
use tracing::Instrument;
|
||||
|
||||
|
@ -20,9 +22,11 @@ use monero_wire::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
client::{connection::Connection, Client, InternalPeerID},
|
||||
handles::HandleBuilder,
|
||||
AddressBook, AddressBookRequest, AddressBookResponse, ConnectionDirection, CoreSyncDataRequest,
|
||||
CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerRequestHandler,
|
||||
MAX_PEERS_IN_PEER_LIST_MESSAGE,
|
||||
CoreSyncDataResponse, CoreSyncSvc, MessageID, NetworkZone, PeerBroadcast, PeerRequestHandler,
|
||||
SharedError, MAX_PEERS_IN_PEER_LIST_MESSAGE,
|
||||
};
|
||||
|
||||
const MAX_EAGER_PROTOCOL_MESSAGES: usize = 2;
|
||||
|
@ -46,10 +50,11 @@ pub enum HandshakeError {
|
|||
}
|
||||
|
||||
pub struct DoHandshakeRequest<Z: NetworkZone> {
|
||||
pub addr: Z::Addr,
|
||||
pub addr: InternalPeerID<Z::Addr>,
|
||||
pub peer_stream: Z::Stream,
|
||||
pub peer_sink: Z::Sink,
|
||||
pub direction: ConnectionDirection,
|
||||
pub permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -60,6 +65,8 @@ pub struct HandShaker<Z: NetworkZone, AdrBook, CSync, ReqHdlr> {
|
|||
|
||||
our_basic_node_data: BasicNodeData,
|
||||
|
||||
broadcast_tx: broadcast::Sender<Arc<PeerBroadcast>>,
|
||||
|
||||
_zone: PhantomData<Z>,
|
||||
}
|
||||
|
||||
|
@ -69,12 +76,15 @@ impl<Z: NetworkZone, AdrBook, CSync, ReqHdlr> HandShaker<Z, AdrBook, CSync, ReqH
|
|||
core_sync_svc: CSync,
|
||||
peer_request_svc: ReqHdlr,
|
||||
|
||||
broadcast_tx: broadcast::Sender<Arc<PeerBroadcast>>,
|
||||
|
||||
our_basic_node_data: BasicNodeData,
|
||||
) -> Self {
|
||||
Self {
|
||||
address_book,
|
||||
core_sync_svc,
|
||||
peer_request_svc,
|
||||
broadcast_tx,
|
||||
our_basic_node_data,
|
||||
_zone: PhantomData,
|
||||
}
|
||||
|
@ -88,7 +98,7 @@ where
|
|||
CSync: CoreSyncSvc + Clone,
|
||||
ReqHdlr: PeerRequestHandler + Clone,
|
||||
{
|
||||
type Response = ();
|
||||
type Response = Client<Z>;
|
||||
type Error = HandshakeError;
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
|
||||
|
@ -103,8 +113,11 @@ where
|
|||
peer_stream,
|
||||
peer_sink,
|
||||
direction,
|
||||
permit,
|
||||
} = req;
|
||||
|
||||
let broadcast_rx = self.broadcast_tx.subscribe();
|
||||
|
||||
let address_book = self.address_book.clone();
|
||||
let peer_request_svc = self.peer_request_svc.clone();
|
||||
let core_sync_svc = self.core_sync_svc.clone();
|
||||
|
@ -119,6 +132,8 @@ where
|
|||
peer_stream,
|
||||
peer_sink,
|
||||
direction,
|
||||
permit,
|
||||
broadcast_rx,
|
||||
address_book,
|
||||
core_sync_svc,
|
||||
peer_request_svc,
|
||||
|
@ -133,15 +148,19 @@ where
|
|||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handshake<Z: NetworkZone, AdrBook, CSync, ReqHdlr>(
|
||||
addr: Z::Addr,
|
||||
addr: InternalPeerID<Z::Addr>,
|
||||
mut peer_stream: Z::Stream,
|
||||
mut peer_sink: Z::Sink,
|
||||
direction: ConnectionDirection,
|
||||
|
||||
permit: OwnedSemaphorePermit,
|
||||
broadcast_rx: broadcast::Receiver<Arc<PeerBroadcast>>,
|
||||
|
||||
mut address_book: AdrBook,
|
||||
mut core_sync_svc: CSync,
|
||||
peer_request_svc: ReqHdlr,
|
||||
our_basic_node_data: BasicNodeData,
|
||||
) -> Result<(), HandshakeError>
|
||||
) -> Result<Client<Z>, HandshakeError>
|
||||
where
|
||||
AdrBook: AddressBook<Z>,
|
||||
CSync: CoreSyncSvc,
|
||||
|
@ -277,7 +296,27 @@ where
|
|||
|
||||
tracing::debug!("Handshake complete.");
|
||||
|
||||
Ok(())
|
||||
let error_slot = SharedError::new();
|
||||
|
||||
let (connection_guard, handle, _) = HandleBuilder::new().with_permit(permit).build();
|
||||
|
||||
let (connection_tx, client_rx) = mpsc::channel(3);
|
||||
|
||||
let connection = Connection::<Z, _>::new(
|
||||
peer_sink,
|
||||
client_rx,
|
||||
broadcast_rx,
|
||||
peer_request_svc,
|
||||
connection_guard,
|
||||
error_slot.clone(),
|
||||
);
|
||||
|
||||
let connection_handle =
|
||||
tokio::spawn(connection.run(peer_stream.fuse(), eager_protocol_messages));
|
||||
|
||||
let client = Client::<Z>::new(addr, handle, connection_tx, connection_handle, error_slot);
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Sends a [`HandshakeRequest`] to the peer.
|
||||
|
|
|
@ -1,12 +1,48 @@
|
|||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
pub struct SharedError<T>(Arc<OnceLock<T>>);
|
||||
|
||||
impl<T> Clone for SharedError<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for SharedError<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> SharedError<T> {
|
||||
pub fn new() -> Self {
|
||||
Self(Arc::new(OnceLock::new()))
|
||||
}
|
||||
|
||||
pub fn try_get_err(&self) -> Option<&T> {
|
||||
self.0.get()
|
||||
}
|
||||
|
||||
pub fn try_insert_err(&self, err: T) -> Result<(), &T> {
|
||||
self.0.set(err).map_err(|_| self.0.get().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PeerError {
|
||||
#[error("The connection was closed.")]
|
||||
ConnectionClosed,
|
||||
#[error("The connection tasks client channel was closed")]
|
||||
ClientChannelClosed,
|
||||
#[error("error with peer response: {0}")]
|
||||
ResponseError(&'static str),
|
||||
#[error("the peer sent an incorrect response to our request")]
|
||||
PeerSentIncorrectResponse,
|
||||
#[error("bucket error")]
|
||||
#[error("the peer sent an invalid message")]
|
||||
PeerSentInvalidMessage,
|
||||
#[error("inner service error: {0}")]
|
||||
ServiceError(#[from] tower::BoxError),
|
||||
#[error("bucket error: {0}")]
|
||||
BucketError(#[from] monero_wire::BucketError),
|
||||
#[error("handshake error: {0}")]
|
||||
Handshake(#[from] crate::client::HandshakeError),
|
||||
|
|
|
@ -11,6 +11,10 @@ pub struct HandleBuilder {
|
|||
}
|
||||
|
||||
impl HandleBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self { permit: None }
|
||||
}
|
||||
|
||||
pub fn with_permit(mut self, permit: OwnedSemaphorePermit) -> Self {
|
||||
self.permit = Some(permit);
|
||||
self
|
||||
|
|
|
@ -55,6 +55,13 @@ pub enum MessageID {
|
|||
NewTransactions,
|
||||
}
|
||||
|
||||
/// This is a sub-set of [`PeerRequest`] for requests that should be sent to all nodes.
|
||||
pub enum PeerBroadcast {
|
||||
Transactions(NewTransactions),
|
||||
NewBlock(NewBlock),
|
||||
NewFluffyBlock(NewFluffyBlock),
|
||||
}
|
||||
|
||||
pub enum PeerRequest {
|
||||
Handshake(HandshakeRequest),
|
||||
TimedSync(TimedSyncRequest),
|
||||
|
|
|
@ -5,6 +5,7 @@ use monero_wire::{Message, ProtocolMessage, RequestMessage, ResponseMessage};
|
|||
|
||||
use super::{PeerRequest, PeerResponse};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MessageConversionError;
|
||||
|
||||
macro_rules! match_body {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use std::sync::Arc;
|
||||
use std::{net::SocketAddr, str::FromStr};
|
||||
|
||||
use futures::{channel::mpsc, StreamExt};
|
||||
use tokio::sync::{broadcast, Semaphore};
|
||||
use tower::{Service, ServiceExt};
|
||||
|
||||
use cuprate_common::Network;
|
||||
|
@ -13,6 +15,7 @@ use monero_p2p::{
|
|||
};
|
||||
|
||||
use cuprate_test_utils::test_netzone::{TestNetZone, TestNetZoneAddr};
|
||||
use monero_p2p::client::InternalPeerID;
|
||||
|
||||
mod utils;
|
||||
use utils::*;
|
||||
|
@ -22,6 +25,11 @@ async fn handshake_cuprate_to_cuprate() {
|
|||
// Tests a Cuprate <-> Cuprate handshake by making 2 handshake services and making them talk to
|
||||
// each other.
|
||||
|
||||
let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test.
|
||||
let semaphore = Arc::new(Semaphore::new(10));
|
||||
let permit_1 = semaphore.clone().acquire_owned().await.unwrap();
|
||||
let permit_2 = semaphore.acquire_owned().await.unwrap();
|
||||
|
||||
let our_basic_node_data_1 = BasicNodeData {
|
||||
my_port: 0,
|
||||
network_id: Network::Mainnet.network_id(),
|
||||
|
@ -39,6 +47,7 @@ async fn handshake_cuprate_to_cuprate() {
|
|||
DummyAddressBook,
|
||||
DummyCoreSyncSvc,
|
||||
DummyPeerRequestHandlerSvc,
|
||||
broadcast_tx.clone(),
|
||||
our_basic_node_data_1,
|
||||
);
|
||||
|
||||
|
@ -46,6 +55,7 @@ async fn handshake_cuprate_to_cuprate() {
|
|||
DummyAddressBook,
|
||||
DummyCoreSyncSvc,
|
||||
DummyPeerRequestHandlerSvc,
|
||||
broadcast_tx.clone(),
|
||||
our_basic_node_data_2,
|
||||
);
|
||||
|
||||
|
@ -53,17 +63,19 @@ async fn handshake_cuprate_to_cuprate() {
|
|||
let (p2_sender, p1_receiver) = mpsc::channel(5);
|
||||
|
||||
let p1_handshake_req = DoHandshakeRequest {
|
||||
addr: TestNetZoneAddr(888),
|
||||
addr: InternalPeerID::KnownAddr(TestNetZoneAddr(888)),
|
||||
peer_stream: p2_receiver.map(Ok).boxed(),
|
||||
peer_sink: p2_sender.into(),
|
||||
direction: ConnectionDirection::OutBound,
|
||||
permit: permit_1,
|
||||
};
|
||||
|
||||
let p2_handshake_req = DoHandshakeRequest {
|
||||
addr: TestNetZoneAddr(444),
|
||||
addr: InternalPeerID::KnownAddr(TestNetZoneAddr(444)),
|
||||
peer_stream: p1_receiver.boxed().map(Ok).boxed(),
|
||||
peer_sink: p1_sender.into(),
|
||||
direction: ConnectionDirection::InBound,
|
||||
permit: permit_2,
|
||||
};
|
||||
|
||||
let p1 = tokio::spawn(async move {
|
||||
|
@ -93,6 +105,10 @@ async fn handshake_cuprate_to_cuprate() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn handshake() {
|
||||
let (broadcast_tx, _) = broadcast::channel(1); // this isn't actually used in this test.
|
||||
let semaphore = Arc::new(Semaphore::new(10));
|
||||
let permit = semaphore.acquire_owned().await.unwrap();
|
||||
|
||||
let addr = "127.0.0.1:18080";
|
||||
|
||||
let our_basic_node_data = BasicNodeData {
|
||||
|
@ -108,6 +124,7 @@ async fn handshake() {
|
|||
DummyAddressBook,
|
||||
DummyCoreSyncSvc,
|
||||
DummyPeerRequestHandlerSvc,
|
||||
broadcast_tx,
|
||||
our_basic_node_data,
|
||||
);
|
||||
|
||||
|
@ -119,6 +136,7 @@ async fn handshake() {
|
|||
.unwrap()
|
||||
.call(ConnectRequest {
|
||||
addr: SocketAddr::from_str(addr).unwrap(),
|
||||
permit,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
Loading…
Reference in a new issue