Don't default to basic-auth if it's enabled, yet require it to be specified

This commit is contained in:
Luke Parker 2023-11-06 10:31:26 -05:00
parent b9983bf133
commit b680bb532b
No known key found for this signature in database
5 changed files with 117 additions and 66 deletions

View file

@ -6,7 +6,7 @@ use thiserror::Error;
use serde::{Deserialize, de::DeserializeOwned};
use serde_json::json;
use simple_request::{Request, Client};
use simple_request::{hyper, Request, Client};
use bitcoin::{
hashes::{Hash, hex::FromHex},
@ -107,18 +107,20 @@ impl Rpc {
method: &str,
params: serde_json::Value,
) -> Result<Response, RpcError> {
let mut request = Request::from(
hyper::Request::post(&self.url)
.header("Content-Type", "application/json")
.body(
serde_json::to_vec(&json!({ "jsonrpc": "2.0", "method": method, "params": params }))
.unwrap()
.into(),
)
.unwrap(),
);
request.with_basic_auth();
let mut res = self
.client
.request(
Request::post(&self.url)
.header("Content-Type", "application/json")
.body(
serde_json::to_vec(&json!({ "jsonrpc": "2.0", "method": method, "params": params }))
.unwrap()
.into(),
)
.unwrap(),
)
.request(request)
.await
.map_err(|_| RpcError::ConnectionError)?
.body()

View file

@ -2,27 +2,19 @@
#![doc = include_str!("../README.md")]
use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector};
use hyper::{
StatusCode,
header::{HeaderValue, HeaderMap},
body::{Buf, Body},
Response as HyperResponse,
client::HttpConnector,
};
pub use hyper::{self, Request};
use hyper::{header::HeaderValue, client::HttpConnector};
pub use hyper;
mod request;
pub use request::*;
mod response;
pub use response::*;
#[derive(Debug)]
pub struct Response(HyperResponse<Body>);
impl Response {
pub fn status(&self) -> StatusCode {
self.0.status()
}
pub fn headers(&self) -> &HeaderMap<HeaderValue> {
self.0.headers()
}
pub async fn body(self) -> Result<impl std::io::Read, hyper::Error> {
Ok(hyper::body::aggregate(self.0.into_body()).await?.reader())
}
pub enum Error {
InvalidUri,
Hyper(hyper::Error),
}
#[derive(Clone, Debug)]
@ -35,12 +27,6 @@ pub struct Client {
connection: Connection,
}
#[derive(Debug)]
pub enum Error {
InvalidHost,
Hyper(hyper::Error),
}
impl Client {
fn https_builder() -> HttpsConnector<HttpConnector> {
HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build()
@ -56,38 +42,14 @@ impl Client {
fn without_connection_pool() -> Client {}
*/
pub async fn request(&self, mut request: Request<Body>) -> Result<Response, Error> {
pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response, Error> {
let request: Request = request.into();
let mut request = request.0;
if request.headers().get(hyper::header::HOST).is_none() {
let host = request.uri().host().ok_or(Error::InvalidHost)?.to_string();
let host = request.uri().host().ok_or(Error::InvalidUri)?.to_string();
request
.headers_mut()
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidHost)?);
}
#[cfg(feature = "basic-auth")]
if request.headers().get(hyper::header::AUTHORIZATION).is_none() {
if let Some(authority) = request.uri().authority() {
let authority = authority.as_str();
if authority.contains('@') {
// Decode the username and password from the URI
let mut userpass = authority.split('@').next().unwrap().to_string();
// If the password is "", the URI may omit :, yet the authentication will still expect it
if !userpass.contains(':') {
userpass.push(':');
}
use zeroize::Zeroize;
use base64ct::{Encoding, Base64};
let mut encoded = Base64::encode_string(userpass.as_bytes());
userpass.zeroize();
request.headers_mut().insert(
hyper::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(),
);
encoded.zeroize();
}
}
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
}
Ok(Response(match &self.connection {

View file

@ -0,0 +1,66 @@
use hyper::body::Body;
#[cfg(feature = "basic-auth")]
use hyper::header::HeaderValue;
#[cfg(feature = "basic-auth")]
use crate::Error;
#[derive(Debug)]
pub struct Request(pub(crate) hyper::Request<Body>);
impl Request {
#[cfg(feature = "basic-auth")]
fn username_password_from_uri(&self) -> Result<(String, String), Error> {
if let Some(authority) = self.0.uri().authority() {
let authority = authority.as_str();
if authority.contains('@') {
// Decode the username and password from the URI
let mut userpass = authority.split('@').next().unwrap().to_string();
let mut userpass_iter = userpass.split(':');
let username = userpass_iter.next().unwrap().to_string();
let password = userpass_iter.next().map(str::to_string).unwrap_or_else(String::new);
zeroize::Zeroize::zeroize(&mut userpass);
return Ok((username, password));
}
}
Err(Error::InvalidUri)
}
#[cfg(feature = "basic-auth")]
pub fn basic_auth(&mut self, username: &str, password: &str) {
use zeroize::Zeroize;
use base64ct::{Encoding, Base64};
let mut formatted = format!("{username}:{password}");
let mut encoded = Base64::encode_string(formatted.as_bytes());
formatted.zeroize();
self.0.headers_mut().insert(
hyper::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(),
);
encoded.zeroize();
}
#[cfg(feature = "basic-auth")]
pub fn basic_auth_from_uri(&mut self) -> Result<(), Error> {
let (mut username, mut password) = self.username_password_from_uri()?;
self.basic_auth(&username, &password);
use zeroize::Zeroize;
username.zeroize();
password.zeroize();
Ok(())
}
#[cfg(feature = "basic-auth")]
pub fn with_basic_auth(&mut self) {
let _ = self.basic_auth_from_uri();
}
}
impl From<hyper::Request<Body>> for Request {
fn from(request: hyper::Request<Body>) -> Request {
Request(request)
}
}

View file

@ -0,0 +1,21 @@
use hyper::{
StatusCode,
header::{HeaderValue, HeaderMap},
body::{Buf, Body},
};
use crate::Error;
#[derive(Debug)]
pub struct Response(pub(crate) hyper::Response<Body>);
impl Response {
pub fn status(&self) -> StatusCode {
self.0.status()
}
pub fn headers(&self) -> &HeaderMap<HeaderValue> {
self.0.headers()
}
pub async fn body(self) -> Result<impl std::io::Read, Error> {
hyper::body::aggregate(self.0.into_body()).await.map(Buf::reader).map_err(Error::Hyper)
}
}

View file

@ -11,7 +11,7 @@ use schnorr_signatures::SchnorrSignature;
use serde::{Serialize, Deserialize};
use simple_request::{Request, Client};
use simple_request::{hyper::Request, Client};
use serai_env as env;