diff --git a/Cargo.lock b/Cargo.lock index ed8dd39e..96bd38dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4877,8 +4877,6 @@ dependencies = [ "group", "hex", "hex-literal", - "hyper", - "hyper-rustls", "modular-frost", "monero-generators", "multiexp", @@ -4890,6 +4888,7 @@ dependencies = [ "serde", "serde_json", "sha3", + "simple-request", "std-shims", "subtle", "thiserror", diff --git a/coins/monero/Cargo.toml b/coins/monero/Cargo.toml index 5c317ed0..cf0549fe 100644 --- a/coins/monero/Cargo.toml +++ b/coins/monero/Cargo.toml @@ -55,9 +55,7 @@ base58-monero = { version = "2", default-features = false, features = ["check"] # Used for the provided HTTP RPC digest_auth = { version = "0.3", default-features = false, optional = true } -# Deprecated here means to enable deprecated warnings, not to restore deprecated APIs -hyper = { version = "0.14", default-features = false, features = ["http1", "tcp", "client", "backports", "deprecated"], optional = true } -hyper-rustls = { version = "0.24", default-features = false, features = ["http1", "native-tokio"], optional = true } +simple-request = { path = "../../common/request", version = "0.1", default-features = false, optional = true } tokio = { version = "1", default-features = false, optional = true } [build-dependencies] @@ -102,7 +100,7 @@ std = [ "base58-monero/std", ] -http-rpc = ["digest_auth", "hyper", "hyper-rustls", "tokio/time", "tokio/rt"] +http-rpc = ["digest_auth", "simple-request", "tokio"] multisig = ["transcript", "frost", "dleq", "std"] binaries = ["tokio/rt-multi-thread", "tokio/macros", "http-rpc"] experimental = [] diff --git a/coins/monero/src/rpc/http.rs b/coins/monero/src/rpc/http.rs index 6489d052..b7462cd0 100644 --- a/coins/monero/src/rpc/http.rs +++ b/coins/monero/src/rpc/http.rs @@ -1,25 +1,25 @@ -use core::str::FromStr; +use std::io::Read; use async_trait::async_trait; use digest_auth::AuthContext; -use hyper::{ - Uri, header::HeaderValue, Request, service::Service, client::connect::HttpConnector, Client, +use simple_request::{ + hyper::{header::HeaderValue, Request}, + Client, }; -use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use crate::rpc::{RpcError, RpcConnection, Rpc}; #[derive(Clone, Debug)] enum Authentication { // If unauthenticated, reuse a single client - Unauthenticated(Client>), + Unauthenticated(Client), // If authenticated, don't reuse clients so that each connection makes its own connection // This ensures that if a nonce is requested, another caller doesn't make a request invalidating // it // We could acquire a mutex over the client, yet creating a new client is preferred for the // possibility of parallelism - Authenticated(HttpsConnector, String, String), + Authenticated { username: String, password: String }, } /// An HTTP(S) transport for the RPC. @@ -37,9 +37,6 @@ impl HttpRpc { /// A daemon requiring authentication can be used via including the username and password in the /// URL. pub fn new(mut url: String) -> Result, RpcError> { - let https_builder = - HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build(); - let authentication = if url.contains('@') { // Parse out the username and password let url_clone = url; @@ -64,13 +61,12 @@ impl HttpRpc { if split_userpass.len() > 2 { Err(RpcError::ConnectionError("invalid amount of passwords".to_string()))?; } - Authentication::Authenticated( - https_builder, - split_userpass[0].to_string(), - split_userpass.get(1).unwrap_or(&"").to_string(), - ) + Authentication::Authenticated { + username: split_userpass[0].to_string(), + password: split_userpass.get(1).unwrap_or(&"").to_string(), + } } else { - Authentication::Unauthenticated(Client::builder().build(https_builder)) + Authentication::Unauthenticated(Client::with_connection_pool()) }; Ok(Rpc(HttpRpc { authentication, url })) @@ -79,45 +75,26 @@ impl HttpRpc { impl HttpRpc { async fn inner_post(&self, route: &str, body: Vec) -> Result, RpcError> { - let request = |uri| { - Request::post(uri) - .header(hyper::header::HOST, { - let mut host = self.url.clone(); - if let Some(protocol_pos) = host.find("://") { - host.drain(0 .. (protocol_pos + 3)); - } - host - }) - .body(body.clone().into()) - .unwrap() - }; + let request = |uri| Request::post(uri).body(body.clone().into()).unwrap(); - let mut connection_task_handle = None; + let mut connection = None; let response = match &self.authentication { Authentication::Unauthenticated(client) => client .request(request(self.url.clone() + "/" + route)) .await - .map_err(|e| RpcError::ConnectionError(e.to_string()))?, - Authentication::Authenticated(https_builder, user, pass) => { - let connection = https_builder - .clone() - .call( - self - .url - .parse() - .map_err(|e: ::Err| RpcError::ConnectionError(e.to_string()))?, - ) + .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?, + Authentication::Authenticated { username, password } => { + // This Client will drop and replace its connection on error, when monero-serai requires + // a single socket for the lifetime of this function + // Since dropping the connection will raise an error, and this function aborts on any + // error, this is fine + let client = Client::without_connection_pool(self.url.clone()) + .map_err(|_| RpcError::ConnectionError("invalid URL".to_string()))?; + let mut response = client + .request(request("/".to_string() + route)) .await - .map_err(|e| RpcError::ConnectionError(e.to_string()))?; - let (mut requester, connection) = hyper::client::conn::http1::handshake(connection) - .await - .map_err(|e| RpcError::ConnectionError(e.to_string()))?; - let connection_task = tokio::spawn(connection); + .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?; - let mut response = requester - .send_request(request("/".to_string() + route)) - .await - .map_err(|e| RpcError::ConnectionError(e.to_string()))?; // Only provide authentication if this daemon actually expects it if let Some(header) = response.headers().get("www-authenticate") { let mut request = request("/".to_string() + route); @@ -131,8 +108,8 @@ impl HttpRpc { ) .map_err(|_| RpcError::InvalidNode("invalid digest-auth response"))? .respond(&AuthContext::new_post::<_, _, _, &[u8]>( - user, - pass, + username, + password, "/".to_string() + route, None, )) @@ -142,19 +119,16 @@ impl HttpRpc { .unwrap(), ); - // Wait for the connection to be ready again - requester.ready().await.map_err(|e| RpcError::ConnectionError(e.to_string()))?; - // Make the request with the response challenge - response = requester - .send_request(request) + response = client + .request(request) .await - .map_err(|e| RpcError::ConnectionError(e.to_string()))?; - - // Also embed the requester so it's not dropped, causing the connection to close - connection_task_handle = Some((requester, connection_task.abort_handle())); + .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?; } + // Store the client so it's not dropped yet + connection = Some(client); + response } }; @@ -177,19 +151,19 @@ impl HttpRpc { let mut body = response.into_body(); while res.len() < length { let Some(data) = body.data().await else { break }; - res.extend(data.map_err(|e| RpcError::ConnectionError(e.to_string()))?.as_ref()); + res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref()); } */ - let res = hyper::body::to_bytes(response.into_body()) + let mut res = Vec::with_capacity(128); + response + .body() .await - .map_err(|e| RpcError::ConnectionError(e.to_string()))? - .to_vec(); + .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))? + .read_to_end(&mut res) + .unwrap(); - if let Some((_, connection_task)) = connection_task_handle { - // Clean up the connection task - connection_task.abort(); - } + drop(connection); Ok(res) } @@ -201,6 +175,6 @@ impl RpcConnection for HttpRpc { // TODO: Make this timeout configurable tokio::time::timeout(core::time::Duration::from_secs(30), self.inner_post(route, body)) .await - .map_err(|e| RpcError::ConnectionError(e.to_string()))? + .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))? } } diff --git a/common/request/src/lib.rs b/common/request/src/lib.rs index 99bf3b03..86eb1aac 100644 --- a/common/request/src/lib.rs +++ b/common/request/src/lib.rs @@ -1,8 +1,18 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![doc = include_str!("../README.md")] +use std::sync::Arc; + +use tokio::sync::Mutex; + use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector}; -use hyper::{header::HeaderValue, client::HttpConnector}; +use hyper::{ + Uri, + header::HeaderValue, + body::Body, + service::Service, + client::{HttpConnector, conn::http1::SendRequest}, +}; pub use hyper; mod request; @@ -14,12 +24,20 @@ pub use response::*; #[derive(Debug)] pub enum Error { InvalidUri, + MissingHost, + InconsistentHost, + SslError, Hyper(hyper::Error), } #[derive(Clone, Debug)] enum Connection { ConnectionPool(hyper::Client>), + Connection { + https_builder: HttpsConnector, + host: Uri, + connection: Arc>>>, + }, } #[derive(Clone, Debug)] @@ -38,15 +56,53 @@ impl Client { } } - /* - fn without_connection_pool() -> Client {} - */ + pub fn without_connection_pool(host: String) -> Result { + Ok(Client { + connection: Connection::Connection { + https_builder: HttpsConnectorBuilder::new() + .with_native_roots() + .https_or_http() + .enable_http1() + .build(), + host: { + let uri: Uri = host.parse().map_err(|_| Error::InvalidUri)?; + if uri.host().is_none() { + Err(Error::MissingHost)?; + }; + uri + }, + connection: Arc::new(Mutex::new(None)), + }, + }) + } pub async fn request>(&self, request: R) -> Result { let request: Request = request.into(); let mut request = request.0; - if request.headers().get(hyper::header::HOST).is_none() { - let host = request.uri().host().ok_or(Error::InvalidUri)?.to_string(); + if let Some(header_host) = request.headers().get(hyper::header::HOST) { + match &self.connection { + Connection::ConnectionPool(_) => {} + Connection::Connection { host, .. } => { + if header_host.to_str().map_err(|_| Error::InvalidUri)? != host.host().unwrap() { + Err(Error::InconsistentHost)?; + } + } + } + } else { + let host = match &self.connection { + Connection::ConnectionPool(_) => { + request.uri().host().ok_or(Error::MissingHost)?.to_string() + } + Connection::Connection { host, .. } => { + let host_str = host.host().unwrap(); + if let Some(uri_host) = request.uri().host() { + if host_str != uri_host { + Err(Error::InconsistentHost)?; + } + } + host_str.to_string() + } + }; request .headers_mut() .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?); @@ -54,6 +110,36 @@ impl Client { Ok(Response(match &self.connection { Connection::ConnectionPool(client) => client.request(request).await.map_err(Error::Hyper)?, + Connection::Connection { https_builder, host, connection } => { + let mut connection_lock = connection.lock().await; + + // If there's not a connection... + if connection_lock.is_none() { + let (requester, connection) = hyper::client::conn::http1::handshake( + https_builder.clone().call(host.clone()).await.map_err(|_| Error::SslError)?, + ) + .await + .map_err(Error::Hyper)?; + // This will die when we drop the requester, so we don't need to track an AbortHandle for + // it + tokio::spawn(connection); + *connection_lock = Some(requester); + } + + let connection = connection_lock.as_mut().unwrap(); + let mut err = connection.ready().await.err(); + if err.is_none() { + // Send the request + let res = connection.send_request(request).await; + if let Ok(res) = res { + return Ok(Response(res)); + } + err = res.err(); + } + // Since this connection has been put into an error state, drop it + *connection_lock = None; + Err(Error::Hyper(err.unwrap()))? + } })) } }