Have simple-request Response's borrow the Client to ensure it's not prematurely dropped

This commit is contained in:
Luke Parker 2023-11-29 01:16:18 -05:00
parent 51bb434239
commit 8040fedddf
No known key found for this signature in database
3 changed files with 84 additions and 69 deletions

View file

@ -117,12 +117,50 @@ impl HttpRpc {
.map_err(|e| RpcError::ConnectionError(format!("couldn't make request: {e:?}"))) .map_err(|e| RpcError::ConnectionError(format!("couldn't make request: {e:?}")))
}; };
async fn body_from_response(response: Response<'_>) -> Result<Vec<u8>, RpcError> {
/*
let length = usize::try_from(
response
.headers()
.get("content-length")
.ok_or(RpcError::InvalidNode("no content-length header"))?
.to_str()
.map_err(|_| RpcError::InvalidNode("non-ascii content-length value"))?
.parse::<u32>()
.map_err(|_| RpcError::InvalidNode("non-u32 content-length value"))?,
)
.unwrap();
// Only pre-allocate 1 MB so a malicious node which claims a content-length of 1 GB actually
// has to send 1 GB of data to cause a 1 GB allocation
let mut res = Vec::with_capacity(length.max(1024 * 1024));
let mut body = response.into_body();
while res.len() < length {
let Some(data) = body.data().await else { break };
res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref());
}
*/
let mut res = Vec::with_capacity(128);
response
.body()
.await
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?
.read_to_end(&mut res)
.unwrap();
Ok(res)
}
for attempt in 0 .. 2 { for attempt in 0 .. 2 {
let response = match &self.authentication { return Ok(match &self.authentication {
Authentication::Unauthenticated(client) => client Authentication::Unauthenticated(client) => {
.request(request_fn(self.url.clone() + "/" + route)?) body_from_response(
.await client
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?, .request(request_fn(self.url.clone() + "/" + route)?)
.await
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?,
)
.await?
}
Authentication::Authenticated { username, password, connection } => { Authentication::Authenticated { username, password, connection } => {
let mut connection_lock = connection.lock().await; let mut connection_lock = connection.lock().await;
@ -168,26 +206,16 @@ impl HttpRpc {
); );
} }
let response_result = connection_lock let response = connection_lock
.1 .1
.request(request) .request(request)
.await .await
.map_err(|e| RpcError::ConnectionError(format!("{e:?}"))); .map_err(|e| RpcError::ConnectionError(format!("{e:?}")));
// If the connection entered an error state, drop the cached challenge as challenges are let (error, is_stale) = match &response {
// per-connection Err(e) => (Some(e.clone()), false),
// We don't need to create a new connection as simple-request will for us Ok(response) => (
if response_result.is_err() { None,
connection_lock.0 = None;
}
// If we're not already on our second attempt and:
// A) We had a connection error
// B) We need to re-auth due to this token being stale
// Move to the next loop iteration (retrying all of this)
if (attempt == 0) &&
(response_result.is_err() || {
let response = response_result.as_ref().unwrap();
if response.status() == StatusCode::UNAUTHORIZED { if response.status() == StatusCode::UNAUTHORIZED {
if let Some(header) = response.headers().get("www-authenticate") { if let Some(header) = response.headers().get("www-authenticate") {
header header
@ -201,49 +229,33 @@ impl HttpRpc {
} }
} else { } else {
false false
} },
}) ),
{ };
// Drop the cached authentication before we do
// If the connection entered an error state, drop the cached challenge as challenges are
// per-connection
// We don't need to create a new connection as simple-request will for us
if error.is_some() || is_stale {
connection_lock.0 = None; connection_lock.0 = None;
continue; // If we're not already on our second attempt, move to the next loop iteration
// (retrying all of this once)
if attempt == 0 {
continue;
}
if let Some(e) = error {
Err(e)?
} else {
debug_assert!(is_stale);
Err(RpcError::InvalidNode(
"node claimed fresh connection had stale authentication".to_string(),
))?
}
} else {
body_from_response(response.unwrap()).await?
} }
response_result?
} }
}; });
/*
let length = usize::try_from(
response
.headers()
.get("content-length")
.ok_or(RpcError::InvalidNode("no content-length header"))?
.to_str()
.map_err(|_| RpcError::InvalidNode("non-ascii content-length value"))?
.parse::<u32>()
.map_err(|_| RpcError::InvalidNode("non-u32 content-length value"))?,
)
.unwrap();
// Only pre-allocate 1 MB so a malicious node which claims a content-length of 1 GB actually
// has to send 1 GB of data to cause a 1 GB allocation
let mut res = Vec::with_capacity(length.max(1024 * 1024));
let mut body = response.into_body();
while res.len() < length {
let Some(data) = body.data().await else { break };
res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref());
}
*/
let mut res = Vec::with_capacity(128);
response
.body()
.await
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?
.read_to_end(&mut res)
.unwrap();
return Ok(res);
} }
unreachable!() unreachable!()

View file

@ -79,7 +79,7 @@ impl Client {
}) })
} }
pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response, Error> { pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> {
let request: Request = request.into(); let request: Request = request.into();
let mut request = request.0; let mut request = request.0;
if let Some(header_host) = request.headers().get(hyper::header::HOST) { if let Some(header_host) = request.headers().get(hyper::header::HOST) {
@ -111,7 +111,7 @@ impl Client {
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?); .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
} }
Ok(Response(match &self.connection { let response = match &self.connection {
Connection::ConnectionPool(client) => client.request(request).await.map_err(Error::Hyper)?, Connection::ConnectionPool(client) => client.request(request).await.map_err(Error::Hyper)?,
Connection::Connection { connector, host, connection } => { Connection::Connection { connector, host, connection } => {
let mut connection_lock = connection.lock().await; let mut connection_lock = connection.lock().await;
@ -125,8 +125,8 @@ impl Client {
let call_res = call_res.map_err(Error::ConnectionError); let call_res = call_res.map_err(Error::ConnectionError);
let (requester, connection) = let (requester, connection) =
hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?; hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?;
// This will die when we drop the requester, so we don't need to track an AbortHandle for // This will die when we drop the requester, so we don't need to track an AbortHandle
// it // for it
tokio::spawn(connection); tokio::spawn(connection);
*connection_lock = Some(requester); *connection_lock = Some(requester);
} }
@ -137,7 +137,7 @@ impl Client {
// Send the request // Send the request
let res = connection.send_request(request).await; let res = connection.send_request(request).await;
if let Ok(res) = res { if let Ok(res) = res {
return Ok(Response(res)); return Ok(Response(res, self));
} }
err = res.err(); err = res.err();
} }
@ -145,6 +145,8 @@ impl Client {
*connection_lock = None; *connection_lock = None;
Err(Error::Hyper(err.unwrap()))? Err(Error::Hyper(err.unwrap()))?
} }
})) };
Ok(Response(response, self))
} }
} }

View file

@ -4,11 +4,12 @@ use hyper::{
body::{Buf, Body}, body::{Buf, Body},
}; };
use crate::Error; use crate::{Client, Error};
// Borrows the client so its async task lives as long as this response exists.
#[derive(Debug)] #[derive(Debug)]
pub struct Response(pub(crate) hyper::Response<Body>); pub struct Response<'a>(pub(crate) hyper::Response<Body>, pub(crate) &'a Client);
impl Response { impl<'a> Response<'a> {
pub fn status(&self) -> StatusCode { pub fn status(&self) -> StatusCode {
self.0.status() self.0.status()
} }