diff --git a/src/utils/parceyaml.rs b/src/utils/parceyaml.rs index c02e9d8..5a6aa61 100644 --- a/src/utils/parceyaml.rs +++ b/src/utils/parceyaml.rs @@ -67,7 +67,6 @@ pub async fn load_configuration(d: &str, kind: &str) -> (Option, async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) { let mut ch: Vec<(Arc, Arc)> = Vec::new(); - ch.push((Arc::from("Server"), Arc::from("Aralez"))); if let Some(headers) = &parsed.client_headers { for header in headers { if let Some((key, val)) = header.split_once(':') { @@ -80,7 +79,6 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) config.client_headers.insert(Arc::from("GLOBAL_CLIENT_HEADERS"), global_headers); let mut sh: Vec<(Arc, Arc)> = Vec::new(); - sh.push((Arc::from("X-Proxy-Server"), Arc::from("Aralez"))); if let Some(headers) = &parsed.server_headers { for header in headers { if let Some((key, val)) = header.split_once(':') { diff --git a/src/web/gethosts.rs b/src/web/gethosts.rs index bbc4ce8..efef9ea 100644 --- a/src/web/gethosts.rs +++ b/src/web/gethosts.rs @@ -56,49 +56,66 @@ impl GetHost for LB { } fn get_header(&self, peer: &str, path: &str) -> Option { - let client_entry = self.client_headers.get(peer)?; - let server_entry = self.server_headers.get(peer)?; + let client_entry = self.client_headers.get(peer); + let server_entry = self.server_headers.get(peer); + if client_entry.is_none() && server_entry.is_none() { + return None; + } let mut current_path = path; let mut clnt_match = None; - loop { - if let Some(entry) = client_entry.get(current_path) { - if !entry.value().is_empty() { - clnt_match = Some(entry.value().clone()); + if let Some(client_entry) = client_entry { + loop { + if let Some(entry) = client_entry.get(current_path) { + if !entry.value().is_empty() { + clnt_match = Some(entry.value().clone()); + break; + } + } + if current_path == "/" { + break; + } + if let Some(pos) = current_path.rfind('/') { + current_path = if pos == 0 { "/" } else { ¤t_path[..pos] }; + } else { break; } - } - if let Some(pos) = current_path.rfind('/') { - current_path = if pos == 0 { "/" } else { ¤t_path[..pos] }; - } else { - break; } } current_path = path; let mut serv_match = None; - loop { - if let Some(entry) = server_entry.get(current_path) { - if !entry.value().is_empty() { - serv_match = Some(entry.value().clone()); - break; - } - } - if let Some(pos) = current_path.rfind('/') { - current_path = if pos == 0 { "/" } else { ¤t_path[..pos] }; - } else { - break; - } - if serv_match.is_none() { - if let Some(entry) = server_entry.get("/") { + if let Some(server_entry) = server_entry { + loop { + if let Some(entry) = server_entry.get(current_path) { if !entry.value().is_empty() { serv_match = Some(entry.value().clone()); break; } } + if current_path == "/" { + if let Some(entry) = server_entry.get("/") { + if !entry.value().is_empty() { + serv_match = Some(entry.value().clone()); + break; + } + } + break; + } + if let Some(pos) = current_path.rfind('/') { + current_path = if pos == 0 { "/" } else { ¤t_path[..pos] }; + } else { + break; + } } } - Some(GetHostsReturHeaders { + let result = GetHostsReturHeaders { client_headers: clnt_match, server_headers: serv_match, - }) + }; + + if result.client_headers.is_some() || result.server_headers.is_some() { + Some(result) + } else { + None + } } } diff --git a/src/web/proxyhttp.rs b/src/web/proxyhttp.rs index 6e39ac0..218021e 100644 --- a/src/web/proxyhttp.rs +++ b/src/web/proxyhttp.rs @@ -1,7 +1,7 @@ use crate::utils::auth::authenticate; use crate::utils::metrics::*; use crate::utils::structs::{AppConfig, Extraparams, Headers, InnerMap, UpstreamsDashMap, UpstreamsIdMap}; -use crate::web::gethosts::GetHost; +use crate::web::gethosts::{GetHost, GetHostsReturHeaders}; use arc_swap::ArcSwap; use async_trait::async_trait; use axum::body::Bytes; @@ -41,7 +41,7 @@ pub struct Context { hostname: Option>, upstream_peer: Option>, extraparams: arc_swap::Guard>, - client_headers: Arc, Arc)>>, + client_headers: Option, Arc)>>>, } #[async_trait] @@ -56,7 +56,7 @@ impl ProxyHttp for LB { hostname: None, upstream_peer: None, extraparams: self.extraparams.load(), - client_headers: Arc::new(Vec::new()), + client_headers: None, } } async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result { @@ -178,24 +178,28 @@ impl ProxyHttp for LB { } async fn upstream_request_filter(&self, session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> { - if let Some(hostname) = ctx.hostname.as_ref() { - upstream_request.insert_header("Host", hostname.as_ref())?; + if let Some(hostname) = ctx.hostname.as_deref() { + upstream_request.insert_header("Host", hostname)?; } - if let Some(peer) = ctx.upstream_peer.as_ref() { + if let Some(peer) = &ctx.upstream_peer { upstream_request.insert_header("X-Forwarded-For", &*peer.address)?; } + let hostname = ctx.hostname.as_deref().unwrap_or("localhost"); + let path = session.req_header().uri.path(); - if let Some(headers) = self.get_header(ctx.hostname.as_ref().unwrap_or(&Arc::from("localhost")), session.req_header().uri.path()) { - if let Some(server_headers) = headers.server_headers { - for (k, v) in server_headers.iter() { - upstream_request.insert_header(k.to_string(), v.as_ref())?; - } - } - if let Some(client_headers) = headers.client_headers { - ctx.client_headers = Arc::new(client_headers); + let GetHostsReturHeaders { server_headers, client_headers } = match self.get_header(hostname, path) { + Some(h) => h, + None => return Ok(()), + }; + + if let Some(sh) = server_headers { + for (k, v) in sh { + upstream_request.insert_header(k.to_string(), v.as_ref())?; } } - + if let Some(ch) = client_headers { + ctx.client_headers = Some(Arc::new(ch)); + } Ok(()) } async fn response_filter(&self, session: &mut Session, _upstream_response: &mut ResponseHeader, ctx: &mut Self::CTX) -> Result<()> { @@ -211,8 +215,10 @@ impl ProxyHttp for LB { session.write_response_header(Box::new(redirect_response), false).await?; } - for (k, v) in ctx.client_headers.iter() { - _upstream_response.insert_header(k.to_string(), v.as_ref())?; + if let Some(client_headers) = &ctx.client_headers { + for (k, v) in client_headers.iter() { + _upstream_response.insert_header(k.to_string(), v.as_ref())?; + } } session.set_keepalive(Some(300));