Clean up some of the rpc code

This commit is contained in:
Boog900 2023-09-07 21:38:56 +01:00
parent b243ff0021
commit 0000ee96b3
No known key found for this signature in database
GPG key ID: 5401367FB7302004

View file

@ -5,10 +5,9 @@ use std::task::{Context, Poll};
use futures::lock::{OwnedMutexGuard, OwnedMutexLockFuture}; use futures::lock::{OwnedMutexGuard, OwnedMutexLockFuture};
use futures::{FutureExt, TryFutureExt}; use futures::{FutureExt, TryFutureExt};
use monero_serai::rpc::{HttpRpc, RpcConnection}; use monero_serai::rpc::{HttpRpc, RpcConnection, RpcError};
use serde::Deserialize; use serde::Deserialize;
use serde_json::json; use serde_json::json;
use tower::BoxError;
use cuprate_common::BlockID; use cuprate_common::BlockID;
@ -20,26 +19,30 @@ enum RpcState<R: RpcConnection> {
Acquiring(OwnedMutexLockFuture<monero_serai::rpc::Rpc<R>>), Acquiring(OwnedMutexLockFuture<monero_serai::rpc::Rpc<R>>),
Acquired(OwnedMutexGuard<monero_serai::rpc::Rpc<R>>), Acquired(OwnedMutexGuard<monero_serai::rpc::Rpc<R>>),
} }
pub struct Rpc<R: RpcConnection>( pub struct Rpc<R: RpcConnection> {
Arc<futures::lock::Mutex<monero_serai::rpc::Rpc<R>>>, rpc: Arc<futures::lock::Mutex<monero_serai::rpc::Rpc<R>>>,
RpcState<R>, rpc_state: RpcState<R>,
Arc<Mutex<bool>>, error_slot: Arc<Mutex<Option<RpcError>>>,
); }
impl Rpc<HttpRpc> { impl Rpc<HttpRpc> {
pub fn new_http(addr: String) -> Rpc<HttpRpc> { pub fn new_http(addr: String) -> Rpc<HttpRpc> {
let http_rpc = HttpRpc::new(addr).unwrap(); let http_rpc = HttpRpc::new(addr).unwrap();
Rpc( Rpc {
Arc::new(futures::lock::Mutex::new(http_rpc)), rpc: Arc::new(futures::lock::Mutex::new(http_rpc)),
RpcState::Locked, rpc_state: RpcState::Locked,
Arc::new(Mutex::new(false)), error_slot: Arc::new(Mutex::new(None)),
) }
} }
} }
impl<R: RpcConnection> Clone for Rpc<R> { impl<R: RpcConnection> Clone for Rpc<R> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Rpc(Arc::clone(&self.0), RpcState::Locked, Arc::clone(&self.2)) Rpc {
rpc: Arc::clone(&self.rpc),
rpc_state: RpcState::Locked,
error_slot: Arc::clone(&self.error_slot),
}
} }
} }
@ -50,14 +53,16 @@ impl<R: RpcConnection + Send + Sync + 'static> tower::Service<DatabaseRequest> f
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if *self.2.lock().unwrap() { if let Some(rpc_error) = self.error_slot.lock().unwrap().clone() {
return Poll::Ready(Err("Rpc has errored".into())); return Poll::Ready(Err(rpc_error.into()));
} }
loop { loop {
match &mut self.1 { match &mut self.rpc_state {
RpcState::Locked => self.1 = RpcState::Acquiring(self.0.clone().lock_owned()), RpcState::Locked => {
self.rpc_state = RpcState::Acquiring(Arc::clone(&self.rpc).lock_owned())
}
RpcState::Acquiring(rpc) => { RpcState::Acquiring(rpc) => {
self.1 = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx))) self.rpc_state = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx)))
} }
RpcState::Acquired(_) => return Poll::Ready(Ok(())), RpcState::Acquired(_) => return Poll::Ready(Ok(())),
} }
@ -65,49 +70,47 @@ impl<R: RpcConnection + Send + Sync + 'static> tower::Service<DatabaseRequest> f
} }
fn call(&mut self, req: DatabaseRequest) -> Self::Future { fn call(&mut self, req: DatabaseRequest) -> Self::Future {
let RpcState::Acquired(rpc) = std::mem::replace(&mut self.1, RpcState::Locked) else { let RpcState::Acquired(rpc) = std::mem::replace(&mut self.rpc_state, RpcState::Locked)
else {
panic!("poll_ready was not called first!"); panic!("poll_ready was not called first!");
}; };
let err = self.2.clone(); let err_slot = self.error_slot.clone();
match req { match req {
DatabaseRequest::ChainHeight => async move { DatabaseRequest::ChainHeight => async move {
let res = rpc let res: Result<_, RpcError> = rpc
.get_height() .get_height()
.map_ok(|height| DatabaseResponse::ChainHeight(height.try_into().unwrap())) .map_ok(|height| DatabaseResponse::ChainHeight(height.try_into().unwrap()))
.map_err(Into::into)
.await; .await;
if res.is_err() { if let Err(e) = &res {
*err.lock().unwrap() = true; *err_slot.lock().unwrap() = Some(e.clone());
} }
res res.map_err(Into::into)
} }
.boxed(), .boxed(),
DatabaseRequest::BlockHeader(id) => match id { DatabaseRequest::BlockHeader(id) => match id {
BlockID::Hash(hash) => async move { BlockID::Hash(hash) => async move {
let res = rpc let res: Result<_, RpcError> = rpc
.get_block(hash) .get_block(hash)
.map_ok(|block| DatabaseResponse::BlockHeader(block.header)) .map_ok(|block| DatabaseResponse::BlockHeader(block.header))
.map_err(Into::<BoxError>::into)
.await; .await;
if res.is_err() { if let Err(e) = &res {
*err.lock().unwrap() = true; *err_slot.lock().unwrap() = Some(e.clone());
} }
res res.map_err(Into::into)
} }
.boxed(), .boxed(),
BlockID::Height(height) => async move { BlockID::Height(height) => async move {
let res = rpc let res: Result<_, RpcError> = rpc
.get_block_by_number(height.try_into().unwrap()) .get_block_by_number(height.try_into().unwrap())
.map_ok(|block| DatabaseResponse::BlockHeader(block.header)) .map_ok(|block| DatabaseResponse::BlockHeader(block.header))
.map_err(Into::into)
.await; .await;
if res.is_err() { if let Err(e) = &res {
*err.lock().unwrap() = true; *err_slot.lock().unwrap() = Some(e.clone());
} }
res res.map_err(Into::into)
} }
.boxed(), .boxed(),
}, },