From 0000ee96b3f27e82be86b7a8fb8a1b5e9dc77b66 Mon Sep 17 00:00:00 2001 From: Boog900 <54e72d8a-345f-4599-bd90-c6b9bc7d0ec5@aleeas.com> Date: Thu, 7 Sep 2023 21:38:56 +0100 Subject: [PATCH] Clean up some of the rpc code --- consensus/src/rpc.rs | 73 +++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/consensus/src/rpc.rs b/consensus/src/rpc.rs index d2a542f..30ffb10 100644 --- a/consensus/src/rpc.rs +++ b/consensus/src/rpc.rs @@ -5,10 +5,9 @@ use std::task::{Context, Poll}; use futures::lock::{OwnedMutexGuard, OwnedMutexLockFuture}; use futures::{FutureExt, TryFutureExt}; -use monero_serai::rpc::{HttpRpc, RpcConnection}; +use monero_serai::rpc::{HttpRpc, RpcConnection, RpcError}; use serde::Deserialize; use serde_json::json; -use tower::BoxError; use cuprate_common::BlockID; @@ -20,26 +19,30 @@ enum RpcState { Acquiring(OwnedMutexLockFuture>), Acquired(OwnedMutexGuard>), } -pub struct Rpc( - Arc>>, - RpcState, - Arc>, -); +pub struct Rpc { + rpc: Arc>>, + rpc_state: RpcState, + error_slot: Arc>>, +} impl Rpc { pub fn new_http(addr: String) -> Rpc { let http_rpc = HttpRpc::new(addr).unwrap(); - Rpc( - Arc::new(futures::lock::Mutex::new(http_rpc)), - RpcState::Locked, - Arc::new(Mutex::new(false)), - ) + Rpc { + rpc: Arc::new(futures::lock::Mutex::new(http_rpc)), + rpc_state: RpcState::Locked, + error_slot: Arc::new(Mutex::new(None)), + } } } impl Clone for Rpc { fn clone(&self) -> Self { - Rpc(Arc::clone(&self.0), RpcState::Locked, Arc::clone(&self.2)) + Rpc { + rpc: Arc::clone(&self.rpc), + rpc_state: RpcState::Locked, + error_slot: Arc::clone(&self.error_slot), + } } } @@ -50,14 +53,16 @@ impl tower::Service f Pin> + Send + 'static>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - if *self.2.lock().unwrap() { - return Poll::Ready(Err("Rpc has errored".into())); + if let Some(rpc_error) = self.error_slot.lock().unwrap().clone() { + return Poll::Ready(Err(rpc_error.into())); } loop { - match &mut self.1 { - RpcState::Locked => self.1 = RpcState::Acquiring(self.0.clone().lock_owned()), + match &mut self.rpc_state { + RpcState::Locked => { + self.rpc_state = RpcState::Acquiring(Arc::clone(&self.rpc).lock_owned()) + } RpcState::Acquiring(rpc) => { - self.1 = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx))) + self.rpc_state = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx))) } RpcState::Acquired(_) => return Poll::Ready(Ok(())), } @@ -65,49 +70,47 @@ impl tower::Service f } fn call(&mut self, req: DatabaseRequest) -> Self::Future { - let RpcState::Acquired(rpc) = std::mem::replace(&mut self.1, RpcState::Locked) else { + let RpcState::Acquired(rpc) = std::mem::replace(&mut self.rpc_state, RpcState::Locked) + else { panic!("poll_ready was not called first!"); }; - let err = self.2.clone(); + let err_slot = self.error_slot.clone(); match req { DatabaseRequest::ChainHeight => async move { - let res = rpc + let res: Result<_, RpcError> = rpc .get_height() .map_ok(|height| DatabaseResponse::ChainHeight(height.try_into().unwrap())) - .map_err(Into::into) .await; - if res.is_err() { - *err.lock().unwrap() = true; + if let Err(e) = &res { + *err_slot.lock().unwrap() = Some(e.clone()); } - res + res.map_err(Into::into) } .boxed(), DatabaseRequest::BlockHeader(id) => match id { BlockID::Hash(hash) => async move { - let res = rpc + let res: Result<_, RpcError> = rpc .get_block(hash) .map_ok(|block| DatabaseResponse::BlockHeader(block.header)) - .map_err(Into::::into) .await; - if res.is_err() { - *err.lock().unwrap() = true; + if let Err(e) = &res { + *err_slot.lock().unwrap() = Some(e.clone()); } - res + res.map_err(Into::into) } .boxed(), BlockID::Height(height) => async move { - let res = rpc + let res: Result<_, RpcError> = rpc .get_block_by_number(height.try_into().unwrap()) .map_ok(|block| DatabaseResponse::BlockHeader(block.header)) - .map_err(Into::into) .await; - if res.is_err() { - *err.lock().unwrap() = true; + if let Err(e) = &res { + *err_slot.lock().unwrap() = Some(e.clone()); } - res + res.map_err(Into::into) } .boxed(), },