mirror of
https://github.com/serai-dex/serai.git
synced 2024-12-22 19:49:22 +00:00
Have simple-request Response's borrow the Client to ensure it's not prematurely dropped
This commit is contained in:
parent
51bb434239
commit
8040fedddf
3 changed files with 84 additions and 69 deletions
|
@ -117,12 +117,50 @@ impl HttpRpc {
|
||||||
.map_err(|e| RpcError::ConnectionError(format!("couldn't make request: {e:?}")))
|
.map_err(|e| RpcError::ConnectionError(format!("couldn't make request: {e:?}")))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
async fn body_from_response(response: Response<'_>) -> Result<Vec<u8>, RpcError> {
|
||||||
|
/*
|
||||||
|
let length = usize::try_from(
|
||||||
|
response
|
||||||
|
.headers()
|
||||||
|
.get("content-length")
|
||||||
|
.ok_or(RpcError::InvalidNode("no content-length header"))?
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| RpcError::InvalidNode("non-ascii content-length value"))?
|
||||||
|
.parse::<u32>()
|
||||||
|
.map_err(|_| RpcError::InvalidNode("non-u32 content-length value"))?,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
// Only pre-allocate 1 MB so a malicious node which claims a content-length of 1 GB actually
|
||||||
|
// has to send 1 GB of data to cause a 1 GB allocation
|
||||||
|
let mut res = Vec::with_capacity(length.max(1024 * 1024));
|
||||||
|
let mut body = response.into_body();
|
||||||
|
while res.len() < length {
|
||||||
|
let Some(data) = body.data().await else { break };
|
||||||
|
res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref());
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
let mut res = Vec::with_capacity(128);
|
||||||
|
response
|
||||||
|
.body()
|
||||||
|
.await
|
||||||
|
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?
|
||||||
|
.read_to_end(&mut res)
|
||||||
|
.unwrap();
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
for attempt in 0 .. 2 {
|
for attempt in 0 .. 2 {
|
||||||
let response = match &self.authentication {
|
return Ok(match &self.authentication {
|
||||||
Authentication::Unauthenticated(client) => client
|
Authentication::Unauthenticated(client) => {
|
||||||
.request(request_fn(self.url.clone() + "/" + route)?)
|
body_from_response(
|
||||||
.await
|
client
|
||||||
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?,
|
.request(request_fn(self.url.clone() + "/" + route)?)
|
||||||
|
.await
|
||||||
|
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
}
|
||||||
Authentication::Authenticated { username, password, connection } => {
|
Authentication::Authenticated { username, password, connection } => {
|
||||||
let mut connection_lock = connection.lock().await;
|
let mut connection_lock = connection.lock().await;
|
||||||
|
|
||||||
|
@ -168,26 +206,16 @@ impl HttpRpc {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let response_result = connection_lock
|
let response = connection_lock
|
||||||
.1
|
.1
|
||||||
.request(request)
|
.request(request)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")));
|
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")));
|
||||||
|
|
||||||
// If the connection entered an error state, drop the cached challenge as challenges are
|
let (error, is_stale) = match &response {
|
||||||
// per-connection
|
Err(e) => (Some(e.clone()), false),
|
||||||
// We don't need to create a new connection as simple-request will for us
|
Ok(response) => (
|
||||||
if response_result.is_err() {
|
None,
|
||||||
connection_lock.0 = None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we're not already on our second attempt and:
|
|
||||||
// A) We had a connection error
|
|
||||||
// B) We need to re-auth due to this token being stale
|
|
||||||
// Move to the next loop iteration (retrying all of this)
|
|
||||||
if (attempt == 0) &&
|
|
||||||
(response_result.is_err() || {
|
|
||||||
let response = response_result.as_ref().unwrap();
|
|
||||||
if response.status() == StatusCode::UNAUTHORIZED {
|
if response.status() == StatusCode::UNAUTHORIZED {
|
||||||
if let Some(header) = response.headers().get("www-authenticate") {
|
if let Some(header) = response.headers().get("www-authenticate") {
|
||||||
header
|
header
|
||||||
|
@ -201,49 +229,33 @@ impl HttpRpc {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
},
|
||||||
})
|
),
|
||||||
{
|
};
|
||||||
// Drop the cached authentication before we do
|
|
||||||
|
// If the connection entered an error state, drop the cached challenge as challenges are
|
||||||
|
// per-connection
|
||||||
|
// We don't need to create a new connection as simple-request will for us
|
||||||
|
if error.is_some() || is_stale {
|
||||||
connection_lock.0 = None;
|
connection_lock.0 = None;
|
||||||
continue;
|
// If we're not already on our second attempt, move to the next loop iteration
|
||||||
|
// (retrying all of this once)
|
||||||
|
if attempt == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Some(e) = error {
|
||||||
|
Err(e)?
|
||||||
|
} else {
|
||||||
|
debug_assert!(is_stale);
|
||||||
|
Err(RpcError::InvalidNode(
|
||||||
|
"node claimed fresh connection had stale authentication".to_string(),
|
||||||
|
))?
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body_from_response(response.unwrap()).await?
|
||||||
}
|
}
|
||||||
|
|
||||||
response_result?
|
|
||||||
}
|
}
|
||||||
};
|
});
|
||||||
|
|
||||||
/*
|
|
||||||
let length = usize::try_from(
|
|
||||||
response
|
|
||||||
.headers()
|
|
||||||
.get("content-length")
|
|
||||||
.ok_or(RpcError::InvalidNode("no content-length header"))?
|
|
||||||
.to_str()
|
|
||||||
.map_err(|_| RpcError::InvalidNode("non-ascii content-length value"))?
|
|
||||||
.parse::<u32>()
|
|
||||||
.map_err(|_| RpcError::InvalidNode("non-u32 content-length value"))?,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
// Only pre-allocate 1 MB so a malicious node which claims a content-length of 1 GB actually
|
|
||||||
// has to send 1 GB of data to cause a 1 GB allocation
|
|
||||||
let mut res = Vec::with_capacity(length.max(1024 * 1024));
|
|
||||||
let mut body = response.into_body();
|
|
||||||
while res.len() < length {
|
|
||||||
let Some(data) = body.data().await else { break };
|
|
||||||
res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref());
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
let mut res = Vec::with_capacity(128);
|
|
||||||
response
|
|
||||||
.body()
|
|
||||||
.await
|
|
||||||
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?
|
|
||||||
.read_to_end(&mut res)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
return Ok(res);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unreachable!()
|
unreachable!()
|
||||||
|
|
|
@ -79,7 +79,7 @@ impl Client {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response, Error> {
|
pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> {
|
||||||
let request: Request = request.into();
|
let request: Request = request.into();
|
||||||
let mut request = request.0;
|
let mut request = request.0;
|
||||||
if let Some(header_host) = request.headers().get(hyper::header::HOST) {
|
if let Some(header_host) = request.headers().get(hyper::header::HOST) {
|
||||||
|
@ -111,7 +111,7 @@ impl Client {
|
||||||
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
|
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Response(match &self.connection {
|
let response = match &self.connection {
|
||||||
Connection::ConnectionPool(client) => client.request(request).await.map_err(Error::Hyper)?,
|
Connection::ConnectionPool(client) => client.request(request).await.map_err(Error::Hyper)?,
|
||||||
Connection::Connection { connector, host, connection } => {
|
Connection::Connection { connector, host, connection } => {
|
||||||
let mut connection_lock = connection.lock().await;
|
let mut connection_lock = connection.lock().await;
|
||||||
|
@ -125,8 +125,8 @@ impl Client {
|
||||||
let call_res = call_res.map_err(Error::ConnectionError);
|
let call_res = call_res.map_err(Error::ConnectionError);
|
||||||
let (requester, connection) =
|
let (requester, connection) =
|
||||||
hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?;
|
hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?;
|
||||||
// This will die when we drop the requester, so we don't need to track an AbortHandle for
|
// This will die when we drop the requester, so we don't need to track an AbortHandle
|
||||||
// it
|
// for it
|
||||||
tokio::spawn(connection);
|
tokio::spawn(connection);
|
||||||
*connection_lock = Some(requester);
|
*connection_lock = Some(requester);
|
||||||
}
|
}
|
||||||
|
@ -137,7 +137,7 @@ impl Client {
|
||||||
// Send the request
|
// Send the request
|
||||||
let res = connection.send_request(request).await;
|
let res = connection.send_request(request).await;
|
||||||
if let Ok(res) = res {
|
if let Ok(res) = res {
|
||||||
return Ok(Response(res));
|
return Ok(Response(res, self));
|
||||||
}
|
}
|
||||||
err = res.err();
|
err = res.err();
|
||||||
}
|
}
|
||||||
|
@ -145,6 +145,8 @@ impl Client {
|
||||||
*connection_lock = None;
|
*connection_lock = None;
|
||||||
Err(Error::Hyper(err.unwrap()))?
|
Err(Error::Hyper(err.unwrap()))?
|
||||||
}
|
}
|
||||||
}))
|
};
|
||||||
|
|
||||||
|
Ok(Response(response, self))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,11 +4,12 @@ use hyper::{
|
||||||
body::{Buf, Body},
|
body::{Buf, Body},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::Error;
|
use crate::{Client, Error};
|
||||||
|
|
||||||
|
// Borrows the client so its async task lives as long as this response exists.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Response(pub(crate) hyper::Response<Body>);
|
pub struct Response<'a>(pub(crate) hyper::Response<Body>, pub(crate) &'a Client);
|
||||||
impl Response {
|
impl<'a> Response<'a> {
|
||||||
pub fn status(&self) -> StatusCode {
|
pub fn status(&self) -> StatusCode {
|
||||||
self.0.status()
|
self.0.status()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue