Clean up some of the rpc code

This commit is contained in:
Boog900 2023-09-07 21:38:56 +01:00
parent b243ff0021
commit 0000ee96b3
No known key found for this signature in database
GPG key ID: 5401367FB7302004

View file

@ -5,10 +5,9 @@ use std::task::{Context, Poll};
use futures::lock::{OwnedMutexGuard, OwnedMutexLockFuture};
use futures::{FutureExt, TryFutureExt};
use monero_serai::rpc::{HttpRpc, RpcConnection};
use monero_serai::rpc::{HttpRpc, RpcConnection, RpcError};
use serde::Deserialize;
use serde_json::json;
use tower::BoxError;
use cuprate_common::BlockID;
@ -20,26 +19,30 @@ enum RpcState<R: RpcConnection> {
Acquiring(OwnedMutexLockFuture<monero_serai::rpc::Rpc<R>>),
Acquired(OwnedMutexGuard<monero_serai::rpc::Rpc<R>>),
}
pub struct Rpc<R: RpcConnection>(
Arc<futures::lock::Mutex<monero_serai::rpc::Rpc<R>>>,
RpcState<R>,
Arc<Mutex<bool>>,
);
pub struct Rpc<R: RpcConnection> {
rpc: Arc<futures::lock::Mutex<monero_serai::rpc::Rpc<R>>>,
rpc_state: RpcState<R>,
error_slot: Arc<Mutex<Option<RpcError>>>,
}
impl Rpc<HttpRpc> {
pub fn new_http(addr: String) -> Rpc<HttpRpc> {
let http_rpc = HttpRpc::new(addr).unwrap();
Rpc(
Arc::new(futures::lock::Mutex::new(http_rpc)),
RpcState::Locked,
Arc::new(Mutex::new(false)),
)
Rpc {
rpc: Arc::new(futures::lock::Mutex::new(http_rpc)),
rpc_state: RpcState::Locked,
error_slot: Arc::new(Mutex::new(None)),
}
}
}
impl<R: RpcConnection> Clone for Rpc<R> {
fn clone(&self) -> Self {
Rpc(Arc::clone(&self.0), RpcState::Locked, Arc::clone(&self.2))
Rpc {
rpc: Arc::clone(&self.rpc),
rpc_state: RpcState::Locked,
error_slot: Arc::clone(&self.error_slot),
}
}
}
@ -50,14 +53,16 @@ impl<R: RpcConnection + Send + Sync + 'static> tower::Service<DatabaseRequest> f
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if *self.2.lock().unwrap() {
return Poll::Ready(Err("Rpc has errored".into()));
if let Some(rpc_error) = self.error_slot.lock().unwrap().clone() {
return Poll::Ready(Err(rpc_error.into()));
}
loop {
match &mut self.1 {
RpcState::Locked => self.1 = RpcState::Acquiring(self.0.clone().lock_owned()),
match &mut self.rpc_state {
RpcState::Locked => {
self.rpc_state = RpcState::Acquiring(Arc::clone(&self.rpc).lock_owned())
}
RpcState::Acquiring(rpc) => {
self.1 = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx)))
self.rpc_state = RpcState::Acquired(futures::ready!(rpc.poll_unpin(cx)))
}
RpcState::Acquired(_) => return Poll::Ready(Ok(())),
}
@ -65,49 +70,47 @@ impl<R: RpcConnection + Send + Sync + 'static> tower::Service<DatabaseRequest> f
}
fn call(&mut self, req: DatabaseRequest) -> Self::Future {
let RpcState::Acquired(rpc) = std::mem::replace(&mut self.1, RpcState::Locked) else {
let RpcState::Acquired(rpc) = std::mem::replace(&mut self.rpc_state, RpcState::Locked)
else {
panic!("poll_ready was not called first!");
};
let err = self.2.clone();
let err_slot = self.error_slot.clone();
match req {
DatabaseRequest::ChainHeight => async move {
let res = rpc
let res: Result<_, RpcError> = rpc
.get_height()
.map_ok(|height| DatabaseResponse::ChainHeight(height.try_into().unwrap()))
.map_err(Into::into)
.await;
if res.is_err() {
*err.lock().unwrap() = true;
if let Err(e) = &res {
*err_slot.lock().unwrap() = Some(e.clone());
}
res
res.map_err(Into::into)
}
.boxed(),
DatabaseRequest::BlockHeader(id) => match id {
BlockID::Hash(hash) => async move {
let res = rpc
let res: Result<_, RpcError> = rpc
.get_block(hash)
.map_ok(|block| DatabaseResponse::BlockHeader(block.header))
.map_err(Into::<BoxError>::into)
.await;
if res.is_err() {
*err.lock().unwrap() = true;
if let Err(e) = &res {
*err_slot.lock().unwrap() = Some(e.clone());
}
res
res.map_err(Into::into)
}
.boxed(),
BlockID::Height(height) => async move {
let res = rpc
let res: Result<_, RpcError> = rpc
.get_block_by_number(height.try_into().unwrap())
.map_ok(|block| DatabaseResponse::BlockHeader(block.header))
.map_err(Into::into)
.await;
if res.is_err() {
*err.lock().unwrap() = true;
if let Err(e) = &res {
*err_slot.lock().unwrap() = Some(e.clone());
}
res
res.map_err(Into::into)
}
.boxed(),
},