diff --git a/coins/bitcoin/src/rpc.rs b/coins/bitcoin/src/rpc.rs index fc25da94..73a215e0 100644 --- a/coins/bitcoin/src/rpc.rs +++ b/coins/bitcoin/src/rpc.rs @@ -6,7 +6,7 @@ use thiserror::Error; use serde::{Deserialize, de::DeserializeOwned}; use serde_json::json; -use simple_request::{Request, Client}; +use simple_request::{hyper, Request, Client}; use bitcoin::{ hashes::{Hash, hex::FromHex}, @@ -107,18 +107,20 @@ impl Rpc { method: &str, params: serde_json::Value, ) -> Result { + let mut request = Request::from( + hyper::Request::post(&self.url) + .header("Content-Type", "application/json") + .body( + serde_json::to_vec(&json!({ "jsonrpc": "2.0", "method": method, "params": params })) + .unwrap() + .into(), + ) + .unwrap(), + ); + request.with_basic_auth(); let mut res = self .client - .request( - Request::post(&self.url) - .header("Content-Type", "application/json") - .body( - serde_json::to_vec(&json!({ "jsonrpc": "2.0", "method": method, "params": params })) - .unwrap() - .into(), - ) - .unwrap(), - ) + .request(request) .await .map_err(|_| RpcError::ConnectionError)? .body() diff --git a/common/request/src/lib.rs b/common/request/src/lib.rs index edb879fb..99bf3b03 100644 --- a/common/request/src/lib.rs +++ b/common/request/src/lib.rs @@ -2,27 +2,19 @@ #![doc = include_str!("../README.md")] use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector}; -use hyper::{ - StatusCode, - header::{HeaderValue, HeaderMap}, - body::{Buf, Body}, - Response as HyperResponse, - client::HttpConnector, -}; -pub use hyper::{self, Request}; +use hyper::{header::HeaderValue, client::HttpConnector}; +pub use hyper; + +mod request; +pub use request::*; + +mod response; +pub use response::*; #[derive(Debug)] -pub struct Response(HyperResponse); -impl Response { - pub fn status(&self) -> StatusCode { - self.0.status() - } - pub fn headers(&self) -> &HeaderMap { - self.0.headers() - } - pub async fn body(self) -> Result { - Ok(hyper::body::aggregate(self.0.into_body()).await?.reader()) - } +pub enum Error { + InvalidUri, + Hyper(hyper::Error), } #[derive(Clone, Debug)] @@ -35,12 +27,6 @@ pub struct Client { connection: Connection, } -#[derive(Debug)] -pub enum Error { - InvalidHost, - Hyper(hyper::Error), -} - impl Client { fn https_builder() -> HttpsConnector { HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build() @@ -56,38 +42,14 @@ impl Client { fn without_connection_pool() -> Client {} */ - pub async fn request(&self, mut request: Request) -> Result { + pub async fn request>(&self, request: R) -> Result { + let request: Request = request.into(); + let mut request = request.0; if request.headers().get(hyper::header::HOST).is_none() { - let host = request.uri().host().ok_or(Error::InvalidHost)?.to_string(); + let host = request.uri().host().ok_or(Error::InvalidUri)?.to_string(); request .headers_mut() - .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidHost)?); - } - - #[cfg(feature = "basic-auth")] - if request.headers().get(hyper::header::AUTHORIZATION).is_none() { - if let Some(authority) = request.uri().authority() { - let authority = authority.as_str(); - if authority.contains('@') { - // Decode the username and password from the URI - let mut userpass = authority.split('@').next().unwrap().to_string(); - // If the password is "", the URI may omit :, yet the authentication will still expect it - if !userpass.contains(':') { - userpass.push(':'); - } - - use zeroize::Zeroize; - use base64ct::{Encoding, Base64}; - - let mut encoded = Base64::encode_string(userpass.as_bytes()); - userpass.zeroize(); - request.headers_mut().insert( - hyper::header::AUTHORIZATION, - HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(), - ); - encoded.zeroize(); - } - } + .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?); } Ok(Response(match &self.connection { diff --git a/common/request/src/request.rs b/common/request/src/request.rs new file mode 100644 index 00000000..1117e9fd --- /dev/null +++ b/common/request/src/request.rs @@ -0,0 +1,66 @@ +use hyper::body::Body; +#[cfg(feature = "basic-auth")] +use hyper::header::HeaderValue; + +#[cfg(feature = "basic-auth")] +use crate::Error; + +#[derive(Debug)] +pub struct Request(pub(crate) hyper::Request); +impl Request { + #[cfg(feature = "basic-auth")] + fn username_password_from_uri(&self) -> Result<(String, String), Error> { + if let Some(authority) = self.0.uri().authority() { + let authority = authority.as_str(); + if authority.contains('@') { + // Decode the username and password from the URI + let mut userpass = authority.split('@').next().unwrap().to_string(); + + let mut userpass_iter = userpass.split(':'); + let username = userpass_iter.next().unwrap().to_string(); + let password = userpass_iter.next().map(str::to_string).unwrap_or_else(String::new); + zeroize::Zeroize::zeroize(&mut userpass); + + return Ok((username, password)); + } + } + Err(Error::InvalidUri) + } + + #[cfg(feature = "basic-auth")] + pub fn basic_auth(&mut self, username: &str, password: &str) { + use zeroize::Zeroize; + use base64ct::{Encoding, Base64}; + + let mut formatted = format!("{username}:{password}"); + let mut encoded = Base64::encode_string(formatted.as_bytes()); + formatted.zeroize(); + self.0.headers_mut().insert( + hyper::header::AUTHORIZATION, + HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(), + ); + encoded.zeroize(); + } + + #[cfg(feature = "basic-auth")] + pub fn basic_auth_from_uri(&mut self) -> Result<(), Error> { + let (mut username, mut password) = self.username_password_from_uri()?; + self.basic_auth(&username, &password); + + use zeroize::Zeroize; + username.zeroize(); + password.zeroize(); + + Ok(()) + } + + #[cfg(feature = "basic-auth")] + pub fn with_basic_auth(&mut self) { + let _ = self.basic_auth_from_uri(); + } +} +impl From> for Request { + fn from(request: hyper::Request) -> Request { + Request(request) + } +} diff --git a/common/request/src/response.rs b/common/request/src/response.rs new file mode 100644 index 00000000..4611324a --- /dev/null +++ b/common/request/src/response.rs @@ -0,0 +1,21 @@ +use hyper::{ + StatusCode, + header::{HeaderValue, HeaderMap}, + body::{Buf, Body}, +}; + +use crate::Error; + +#[derive(Debug)] +pub struct Response(pub(crate) hyper::Response); +impl Response { + pub fn status(&self) -> StatusCode { + self.0.status() + } + pub fn headers(&self) -> &HeaderMap { + self.0.headers() + } + pub async fn body(self) -> Result { + hyper::body::aggregate(self.0.into_body()).await.map(Buf::reader).map_err(Error::Hyper) + } +} diff --git a/message-queue/src/client.rs b/message-queue/src/client.rs index 9adcc2cd..ed103102 100644 --- a/message-queue/src/client.rs +++ b/message-queue/src/client.rs @@ -11,7 +11,7 @@ use schnorr_signatures::SchnorrSignature; use serde::{Serialize, Deserialize}; -use simple_request::{Request, Client}; +use simple_request::{hyper::Request, Client}; use serai_env as env;