split rpc calls into 3 Services

This commit is contained in:
hinto.janai 2024-09-02 20:15:12 -04:00
parent c57020d805
commit 68ba5f4781
No known key found for this signature in database
GPG key ID: D47CE05FA175A499
6 changed files with 196 additions and 122 deletions

View file

@ -3,6 +3,11 @@
//---------------------------------------------------------------------------------------------------- Use
use std::task::Poll;
use cuprate_rpc_types::{
bin::{BinRequest, BinResponse},
json::{JsonRpcRequest, JsonRpcResponse},
other::{OtherRequest, OtherResponse},
};
use futures::channel::oneshot::channel;
use serde::{Deserialize, Serialize};
use tower::Service;
@ -36,32 +41,57 @@ impl RpcHandler for CupratedRpcHandler {
}
}
impl Service<RpcRequest> for CupratedRpcHandler {
type Response = RpcResponse;
// INVARIANT:
//
// We don't need to check for `self.is_restricted()`
// here because `cuprate-rpc-interface` handles that.
//
// We can assume the request coming has the required permissions.
impl Service<JsonRpcRequest> for CupratedRpcHandler {
type Response = JsonRpcResponse;
type Error = RpcError;
type Future = InfallibleOneshotReceiver<Result<RpcResponse, RpcError>>;
type Future = InfallibleOneshotReceiver<Result<JsonRpcResponse, RpcError>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
/// INVARIANT:
///
/// We don't need to check for `self.is_restricted()`
/// here because `cuprate-rpc-interface` handles that.
///
/// We can assume the request coming has the required permissions.
fn call(&mut self, req: RpcRequest) -> Self::Future {
fn call(&mut self, request: JsonRpcRequest) -> Self::Future {
let state = Self::clone(self);
let resp = match req {
RpcRequest::JsonRpc(r) => {
RpcResponse::JsonRpc(json::map_request(state, r).expect("TODO"))
} // JSON-RPC 2.0 requests.
RpcRequest::Binary(r) => RpcResponse::Binary(bin::map_request(state, r).expect("TODO")), // Binary requests.
RpcRequest::Other(r) => RpcResponse::Other(other::map_request(state, r).expect("TODO")), // JSON (but not JSON-RPC) requests.
};
let response = json::map_request(state, request).expect("TODO");
todo!()
}
}
impl Service<BinRequest> for CupratedRpcHandler {
type Response = BinResponse;
type Error = RpcError;
type Future = InfallibleOneshotReceiver<Result<BinResponse, RpcError>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, request: BinRequest) -> Self::Future {
let state = Self::clone(self);
let response = bin::map_request(state, request).expect("TODO");
todo!()
}
}
impl Service<OtherRequest> for CupratedRpcHandler {
type Response = OtherResponse;
type Error = RpcError;
type Future = InfallibleOneshotReceiver<Result<OtherResponse, RpcError>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, request: OtherRequest) -> Self::Future {
let state = Self::clone(self);
let response = other::map_request(state, request).expect("TODO");
todo!()
}
}

View file

@ -67,13 +67,8 @@ macro_rules! generate_endpoints_inner {
paste::paste! {
{
// Send request.
let request = RpcRequest::Binary($request);
let channel = $handler.oneshot(request).await?;
let response = $handler.oneshot($request).await?;
// Assert the response from the inner handler is correct.
let RpcResponse::Binary(response) = channel else {
panic!("RPC handler did not return a binary response");
};
let BinResponse::$variant(response) = response else {
panic!("RPC handler returned incorrect response");
};

View file

@ -50,13 +50,7 @@ pub(crate) async fn json_rpc<H: RpcHandler>(
}
// Send request.
let request = RpcRequest::JsonRpc(request.body);
let channel = handler.oneshot(request).await?;
// Assert the response from the inner handler is correct.
let RpcResponse::JsonRpc(response) = channel else {
panic!("RPC handler returned incorrect response");
};
let response = handler.oneshot(request.body).await?;
Ok(Json(Response::ok(id, response)))
}

View file

@ -81,13 +81,9 @@ macro_rules! generate_endpoints_inner {
}
// Send request.
let request = RpcRequest::Other(OtherRequest::$variant($request));
let channel = $handler.oneshot(request).await?;
let request = OtherRequest::$variant($request);
let response = $handler.oneshot(request).await?;
// Assert the response from the inner handler is correct.
let RpcResponse::Other(response) = channel else {
panic!("RPC handler did not return a binary response");
};
let OtherResponse::$variant(response) = response else {
panic!("RPC handler returned incorrect response")
};

View file

@ -3,6 +3,11 @@
//---------------------------------------------------------------------------------------------------- Use
use std::future::Future;
use cuprate_rpc_types::{
bin::{BinRequest, BinResponse},
json::{JsonRpcRequest, JsonRpcResponse},
other::{OtherRequest, OtherResponse},
};
use tower::Service;
use crate::{rpc_error::RpcError, rpc_request::RpcRequest, rpc_response::RpcResponse};
@ -33,10 +38,22 @@ pub trait RpcHandler:
+ Sync
+ 'static
+ Service<
RpcRequest,
Response = RpcResponse,
JsonRpcRequest,
Response = JsonRpcResponse,
Error = RpcError,
Future: Future<Output = Result<RpcResponse, RpcError>> + Send + Sync + 'static,
Future: Future<Output = Result<JsonRpcResponse, RpcError>> + Send + Sync + 'static,
>
+ Service<
OtherRequest,
Response = OtherResponse,
Error = RpcError,
Future: Future<Output = Result<OtherResponse, RpcError>> + Send + Sync + 'static,
>
+ Service<
BinRequest,
Response = BinResponse,
Error = RpcError,
Future: Future<Output = Result<BinResponse, RpcError>> + Send + Sync + 'static,
>
{
/// Is this [`RpcHandler`] restricted?

View file

@ -3,6 +3,11 @@
//---------------------------------------------------------------------------------------------------- Use
use std::task::Poll;
use cuprate_rpc_types::{
bin::{BinRequest, BinResponse},
json::{JsonRpcRequest, JsonRpcResponse},
other::{OtherRequest, OtherResponse},
};
use futures::channel::oneshot::channel;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@ -41,96 +46,133 @@ impl RpcHandler for RpcHandlerDummy {
}
}
impl Service<RpcRequest> for RpcHandlerDummy {
type Response = RpcResponse;
impl Service<JsonRpcRequest> for RpcHandlerDummy {
type Response = JsonRpcResponse;
type Error = RpcError;
type Future = InfallibleOneshotReceiver<Result<RpcResponse, RpcError>>;
type Future = InfallibleOneshotReceiver<Result<JsonRpcResponse, RpcError>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: RpcRequest) -> Self::Future {
use cuprate_rpc_types::bin::BinRequest as BReq;
use cuprate_rpc_types::bin::BinResponse as BResp;
use cuprate_rpc_types::json::JsonRpcRequest as JReq;
use cuprate_rpc_types::json::JsonRpcResponse as JResp;
use cuprate_rpc_types::other::OtherRequest as OReq;
use cuprate_rpc_types::other::OtherResponse as OResp;
fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
use cuprate_rpc_types::json::JsonRpcRequest as Req;
use cuprate_rpc_types::json::JsonRpcResponse as Resp;
#[rustfmt::skip]
#[allow(clippy::default_trait_access)]
let resp = match req {
RpcRequest::JsonRpc(j) => RpcResponse::JsonRpc(match j {
JReq::GetBlockCount(_) => JResp::GetBlockCount(Default::default()),
JReq::OnGetBlockHash(_) => JResp::OnGetBlockHash(Default::default()),
JReq::SubmitBlock(_) => JResp::SubmitBlock(Default::default()),
JReq::GenerateBlocks(_) => JResp::GenerateBlocks(Default::default()),
JReq::GetLastBlockHeader(_) => JResp::GetLastBlockHeader(Default::default()),
JReq::GetBlockHeaderByHash(_) => JResp::GetBlockHeaderByHash(Default::default()),
JReq::GetBlockHeaderByHeight(_) => JResp::GetBlockHeaderByHeight(Default::default()),
JReq::GetBlockHeadersRange(_) => JResp::GetBlockHeadersRange(Default::default()),
JReq::GetBlock(_) => JResp::GetBlock(Default::default()),
JReq::GetConnections(_) => JResp::GetConnections(Default::default()),
JReq::GetInfo(_) => JResp::GetInfo(Default::default()),
JReq::HardForkInfo(_) => JResp::HardForkInfo(Default::default()),
JReq::SetBans(_) => JResp::SetBans(Default::default()),
JReq::GetBans(_) => JResp::GetBans(Default::default()),
JReq::Banned(_) => JResp::Banned(Default::default()),
JReq::FlushTransactionPool(_) => JResp::FlushTransactionPool(Default::default()),
JReq::GetOutputHistogram(_) => JResp::GetOutputHistogram(Default::default()),
JReq::GetCoinbaseTxSum(_) => JResp::GetCoinbaseTxSum(Default::default()),
JReq::GetVersion(_) => JResp::GetVersion(Default::default()),
JReq::GetFeeEstimate(_) => JResp::GetFeeEstimate(Default::default()),
JReq::GetAlternateChains(_) => JResp::GetAlternateChains(Default::default()),
JReq::RelayTx(_) => JResp::RelayTx(Default::default()),
JReq::SyncInfo(_) => JResp::SyncInfo(Default::default()),
JReq::GetTransactionPoolBacklog(_) => JResp::GetTransactionPoolBacklog(Default::default()),
JReq::GetMinerData(_) => JResp::GetMinerData(Default::default()),
JReq::PruneBlockchain(_) => JResp::PruneBlockchain(Default::default()),
JReq::CalcPow(_) => JResp::CalcPow(Default::default()),
JReq::FlushCache(_) => JResp::FlushCache(Default::default()),
JReq::AddAuxPow(_) => JResp::AddAuxPow(Default::default()),
JReq::GetTxIdsLoose(_) => JResp::GetTxIdsLoose(Default::default()),
}),
RpcRequest::Binary(b) => RpcResponse::Binary(match b {
BReq::GetBlocks(_) => BResp::GetBlocks(Default::default()),
BReq::GetBlocksByHeight(_) => BResp::GetBlocksByHeight(Default::default()),
BReq::GetHashes(_) => BResp::GetHashes(Default::default()),
BReq::GetOutputIndexes(_) => BResp::GetOutputIndexes(Default::default()),
BReq::GetOuts(_) => BResp::GetOuts(Default::default()),
BReq::GetTransactionPoolHashes(_) => BResp::GetTransactionPoolHashes(Default::default()),
BReq::GetOutputDistribution(_) => BResp::GetOutputDistribution(Default::default()),
}),
RpcRequest::Other(o) => RpcResponse::Other(match o {
OReq::GetHeight(_) => OResp::GetHeight(Default::default()),
OReq::GetTransactions(_) => OResp::GetTransactions(Default::default()),
OReq::GetAltBlocksHashes(_) => OResp::GetAltBlocksHashes(Default::default()),
OReq::IsKeyImageSpent(_) => OResp::IsKeyImageSpent(Default::default()),
OReq::SendRawTransaction(_) => OResp::SendRawTransaction(Default::default()),
OReq::StartMining(_) => OResp::StartMining(Default::default()),
OReq::StopMining(_) => OResp::StopMining(Default::default()),
OReq::MiningStatus(_) => OResp::MiningStatus(Default::default()),
OReq::SaveBc(_) => OResp::SaveBc(Default::default()),
OReq::GetPeerList(_) => OResp::GetPeerList(Default::default()),
OReq::SetLogHashRate(_) => OResp::SetLogHashRate(Default::default()),
OReq::SetLogLevel(_) => OResp::SetLogLevel(Default::default()),
OReq::SetLogCategories(_) => OResp::SetLogCategories(Default::default()),
OReq::SetBootstrapDaemon(_) => OResp::SetBootstrapDaemon(Default::default()),
OReq::GetTransactionPool(_) => OResp::GetTransactionPool(Default::default()),
OReq::GetTransactionPoolStats(_) => OResp::GetTransactionPoolStats(Default::default()),
OReq::StopDaemon(_) => OResp::StopDaemon(Default::default()),
OReq::GetLimit(_) => OResp::GetLimit(Default::default()),
OReq::SetLimit(_) => OResp::SetLimit(Default::default()),
OReq::OutPeers(_) => OResp::OutPeers(Default::default()),
OReq::InPeers(_) => OResp::InPeers(Default::default()),
OReq::GetNetStats(_) => OResp::GetNetStats(Default::default()),
OReq::GetOuts(_) => OResp::GetOuts(Default::default()),
OReq::Update(_) => OResp::Update(Default::default()),
OReq::PopBlocks(_) => OResp::PopBlocks(Default::default()),
OReq::GetTransactionPoolHashes(_) => OResp::GetTransactionPoolHashes(Default::default()),
OReq::GetPublicNodes(_) => OResp::GetPublicNodes(Default::default()),
})
Req::GetBlockCount(_) => Resp::GetBlockCount(Default::default()),
Req::OnGetBlockHash(_) => Resp::OnGetBlockHash(Default::default()),
Req::SubmitBlock(_) => Resp::SubmitBlock(Default::default()),
Req::GenerateBlocks(_) => Resp::GenerateBlocks(Default::default()),
Req::GetLastBlockHeader(_) => Resp::GetLastBlockHeader(Default::default()),
Req::GetBlockHeaderByHash(_) => Resp::GetBlockHeaderByHash(Default::default()),
Req::GetBlockHeaderByHeight(_) => Resp::GetBlockHeaderByHeight(Default::default()),
Req::GetBlockHeadersRange(_) => Resp::GetBlockHeadersRange(Default::default()),
Req::GetBlock(_) => Resp::GetBlock(Default::default()),
Req::GetConnections(_) => Resp::GetConnections(Default::default()),
Req::GetInfo(_) => Resp::GetInfo(Default::default()),
Req::HardForkInfo(_) => Resp::HardForkInfo(Default::default()),
Req::SetBans(_) => Resp::SetBans(Default::default()),
Req::GetBans(_) => Resp::GetBans(Default::default()),
Req::Banned(_) => Resp::Banned(Default::default()),
Req::FlushTransactionPool(_) => Resp::FlushTransactionPool(Default::default()),
Req::GetOutputHistogram(_) => Resp::GetOutputHistogram(Default::default()),
Req::GetCoinbaseTxSum(_) => Resp::GetCoinbaseTxSum(Default::default()),
Req::GetVersion(_) => Resp::GetVersion(Default::default()),
Req::GetFeeEstimate(_) => Resp::GetFeeEstimate(Default::default()),
Req::GetAlternateChains(_) => Resp::GetAlternateChains(Default::default()),
Req::RelayTx(_) => Resp::RelayTx(Default::default()),
Req::SyncInfo(_) => Resp::SyncInfo(Default::default()),
Req::GetTransactionPoolBacklog(_) => {
Resp::GetTransactionPoolBacklog(Default::default())
}
Req::GetMinerData(_) => Resp::GetMinerData(Default::default()),
Req::PruneBlockchain(_) => Resp::PruneBlockchain(Default::default()),
Req::CalcPow(_) => Resp::CalcPow(Default::default()),
Req::FlushCache(_) => Resp::FlushCache(Default::default()),
Req::AddAuxPow(_) => Resp::AddAuxPow(Default::default()),
Req::GetTxIdsLoose(_) => Resp::GetTxIdsLoose(Default::default()),
};
let (tx, rx) = channel();
drop(tx.send(Ok(resp)));
InfallibleOneshotReceiver::from(rx)
}
}
impl Service<BinRequest> for RpcHandlerDummy {
type Response = BinResponse;
type Error = RpcError;
type Future = InfallibleOneshotReceiver<Result<BinResponse, RpcError>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: BinRequest) -> Self::Future {
use cuprate_rpc_types::bin::BinRequest as Req;
use cuprate_rpc_types::bin::BinResponse as Resp;
#[allow(clippy::default_trait_access)]
let resp = match req {
Req::GetBlocks(_) => Resp::GetBlocks(Default::default()),
Req::GetBlocksByHeight(_) => Resp::GetBlocksByHeight(Default::default()),
Req::GetHashes(_) => Resp::GetHashes(Default::default()),
Req::GetOutputIndexes(_) => Resp::GetOutputIndexes(Default::default()),
Req::GetOuts(_) => Resp::GetOuts(Default::default()),
Req::GetTransactionPoolHashes(_) => Resp::GetTransactionPoolHashes(Default::default()),
Req::GetOutputDistribution(_) => Resp::GetOutputDistribution(Default::default()),
};
let (tx, rx) = channel();
drop(tx.send(Ok(resp)));
InfallibleOneshotReceiver::from(rx)
}
}
impl Service<OtherRequest> for RpcHandlerDummy {
type Response = OtherResponse;
type Error = RpcError;
type Future = InfallibleOneshotReceiver<Result<OtherResponse, RpcError>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: OtherRequest) -> Self::Future {
use cuprate_rpc_types::other::OtherRequest as Req;
use cuprate_rpc_types::other::OtherResponse as Resp;
#[allow(clippy::default_trait_access)]
let resp = match req {
Req::GetHeight(_) => Resp::GetHeight(Default::default()),
Req::GetTransactions(_) => Resp::GetTransactions(Default::default()),
Req::GetAltBlocksHashes(_) => Resp::GetAltBlocksHashes(Default::default()),
Req::IsKeyImageSpent(_) => Resp::IsKeyImageSpent(Default::default()),
Req::SendRawTransaction(_) => Resp::SendRawTransaction(Default::default()),
Req::StartMining(_) => Resp::StartMining(Default::default()),
Req::StopMining(_) => Resp::StopMining(Default::default()),
Req::MiningStatus(_) => Resp::MiningStatus(Default::default()),
Req::SaveBc(_) => Resp::SaveBc(Default::default()),
Req::GetPeerList(_) => Resp::GetPeerList(Default::default()),
Req::SetLogHashRate(_) => Resp::SetLogHashRate(Default::default()),
Req::SetLogLevel(_) => Resp::SetLogLevel(Default::default()),
Req::SetLogCategories(_) => Resp::SetLogCategories(Default::default()),
Req::SetBootstrapDaemon(_) => Resp::SetBootstrapDaemon(Default::default()),
Req::GetTransactionPool(_) => Resp::GetTransactionPool(Default::default()),
Req::GetTransactionPoolStats(_) => Resp::GetTransactionPoolStats(Default::default()),
Req::StopDaemon(_) => Resp::StopDaemon(Default::default()),
Req::GetLimit(_) => Resp::GetLimit(Default::default()),
Req::SetLimit(_) => Resp::SetLimit(Default::default()),
Req::OutPeers(_) => Resp::OutPeers(Default::default()),
Req::InPeers(_) => Resp::InPeers(Default::default()),
Req::GetNetStats(_) => Resp::GetNetStats(Default::default()),
Req::GetOuts(_) => Resp::GetOuts(Default::default()),
Req::Update(_) => Resp::Update(Default::default()),
Req::PopBlocks(_) => Resp::PopBlocks(Default::default()),
Req::GetTransactionPoolHashes(_) => Resp::GetTransactionPoolHashes(Default::default()),
Req::GetPublicNodes(_) => Resp::GetPublicNodes(Default::default()),
};
let (tx, rx) = channel();