#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![doc = include_str!("../README.md")] use std::sync::Arc; use tokio::sync::Mutex; use tower_service::Service as TowerService; #[cfg(feature = "tls")] use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector}; use hyper::{Uri, header::HeaderValue, body::Bytes, client::conn::http1::SendRequest}; use hyper_util::{ rt::tokio::TokioExecutor, client::legacy::{Client as HyperClient, connect::HttpConnector}, }; pub use hyper; mod request; pub use request::*; mod response; pub use response::*; #[derive(Debug)] pub enum Error { InvalidUri, MissingHost, InconsistentHost, ConnectionError(Box<dyn Send + Sync + std::error::Error>), Hyper(hyper::Error), HyperUtil(hyper_util::client::legacy::Error), } #[cfg(not(feature = "tls"))] type Connector = HttpConnector; #[cfg(feature = "tls")] type Connector = HttpsConnector<HttpConnector>; #[derive(Clone, Debug)] enum Connection { ConnectionPool(HyperClient<Connector, Full<Bytes>>), Connection { connector: Connector, host: Uri, connection: Arc<Mutex<Option<SendRequest<Full<Bytes>>>>>, }, } #[derive(Clone, Debug)] pub struct Client { connection: Connection, } impl Client { fn connector() -> Connector { let mut res = HttpConnector::new(); res.set_keepalive(Some(core::time::Duration::from_secs(60))); res.set_nodelay(true); res.set_reuse_address(true); #[cfg(feature = "tls")] res.enforce_http(false); #[cfg(feature = "tls")] let res = HttpsConnectorBuilder::new() .with_native_roots() .expect("couldn't fetch system's SSL roots") .https_or_http() .enable_http1() .wrap_connector(res); res } pub fn with_connection_pool() -> Client { Client { connection: Connection::ConnectionPool( HyperClient::builder(TokioExecutor::new()) .pool_idle_timeout(core::time::Duration::from_secs(60)) .build(Self::connector()), ), } } pub fn without_connection_pool(host: &str) -> Result<Client, Error> { Ok(Client { connection: Connection::Connection { connector: Self::connector(), host: { let uri: Uri = host.parse().map_err(|_| Error::InvalidUri)?; if uri.host().is_none() { Err(Error::MissingHost)?; }; uri }, connection: Arc::new(Mutex::new(None)), }, }) } pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> { let request: Request = request.into(); let mut request = request.0; if let Some(header_host) = request.headers().get(hyper::header::HOST) { match &self.connection { Connection::ConnectionPool(_) => {} Connection::Connection { host, .. } => { if header_host.to_str().map_err(|_| Error::InvalidUri)? != host.host().unwrap() { Err(Error::InconsistentHost)?; } } } } else { let host = match &self.connection { Connection::ConnectionPool(_) => { request.uri().host().ok_or(Error::MissingHost)?.to_string() } Connection::Connection { host, .. } => { let host_str = host.host().unwrap(); if let Some(uri_host) = request.uri().host() { if host_str != uri_host { Err(Error::InconsistentHost)?; } } host_str.to_string() } }; request .headers_mut() .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?); } let response = match &self.connection { Connection::ConnectionPool(client) => { client.request(request).await.map_err(Error::HyperUtil)? } Connection::Connection { connector, host, connection } => { let mut connection_lock = connection.lock().await; // If there's not a connection... if connection_lock.is_none() { let call_res = connector.clone().call(host.clone()).await; #[cfg(not(feature = "tls"))] let call_res = call_res.map_err(|e| Error::ConnectionError(format!("{e:?}").into())); #[cfg(feature = "tls")] let call_res = call_res.map_err(Error::ConnectionError); let (requester, connection) = hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?; // This will die when we drop the requester, so we don't need to track an AbortHandle // for it tokio::spawn(connection); *connection_lock = Some(requester); } let connection = connection_lock.as_mut().unwrap(); let mut err = connection.ready().await.err(); if err.is_none() { // Send the request let res = connection.send_request(request).await; if let Ok(res) = res { return Ok(Response(res, self)); } err = res.err(); } // Since this connection has been put into an error state, drop it *connection_lock = None; Err(Error::Hyper(err.unwrap()))? } }; Ok(Response(response, self)) } }