From 74821654f3b9bc2c8825db796fa4143d493ecf91 Mon Sep 17 00:00:00 2001 From: Ara Sadoyan Date: Sat, 22 Nov 2025 23:18:06 +0100 Subject: [PATCH] Added support to send custom headers to upstream servers. --- Cargo.lock | 16 +++++++++ Cargo.toml | 2 +- README.md | 14 ++++++-- etc/upstreams.yaml | 39 ++++++++++++--------- src/utils.rs | 1 + src/utils/httpclient.rs | 75 ++++++++++++++++++++++++++++++++++++++++ src/utils/kuberconsul.rs | 69 ++++++++++++++++++------------------ src/utils/parceyaml.rs | 58 +++++++++++++++++++++---------- src/utils/structs.rs | 13 ++++--- src/web/bgservice.rs | 34 +++++++++++++----- src/web/gethosts.rs | 52 +++++++++++++++++++++------- src/web/proxyhttp.rs | 57 ++++++++++++++++++++++-------- src/web/start.rs | 6 ++-- 13 files changed, 321 insertions(+), 115 deletions(-) create mode 100644 src/utils/httpclient.rs diff --git a/Cargo.lock b/Cargo.lock index db94e1a..01a20c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2478,6 +2478,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", + "futures-util", "h2", "http", "http-body", @@ -2499,12 +2500,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", ] @@ -3499,6 +3502,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" diff --git a/Cargo.toml b/Cargo.toml index 0eda853..3557a85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ futures = "0.3.31" notify = "8.2.0" axum = { version = "0.8.4" } axum-server = { version = "0.7.2", features = ["tls-openssl"] } -reqwest = { version = "0.12.23", features = ["json", "native-tls-alpn"] } +reqwest = { version = "0.12.23", features = ["json", "native-tls-alpn", "stream"] } #reqwest = { version = "0.12.15", features = ["json", "rustls-tls"] } #reqwest = { version = "0.12.15", default-features = false, features = ["rustls-tls", "json"] } diff --git a/README.md b/README.md index 60d7961..0fba6f9 100644 --- a/README.md +++ b/README.md @@ -191,7 +191,10 @@ provider: "file" sticky_sessions: false to_https: false rate_limit: 10 -headers: +server_headers: + - "X-Forwarded-Proto:https" + - "X-Forwarded-Port:443" +client_headers: - "Access-Control-Allow-Origin:*" - "Access-Control-Allow-Methods:POST, GET, OPTIONS" - "Access-Control-Max-Age:86400" @@ -203,7 +206,10 @@ myhost.mydomain.com: "/": rate_limit: 20 to_https: false - headers: + server_headers: + - "X-Something-Else:Foobar" + - "X-Another-Header:Hohohohoho" + client_headers: - "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Proxy-From:Hopaaaaaaaaaaaar" servers: @@ -211,7 +217,7 @@ myhost.mydomain.com: - "127.0.0.2:8000" "/foo": to_https: true - headers: + client_headers: - "X-Another-Header:Hohohohoho" servers: - "127.0.0.4:8443" @@ -226,6 +232,8 @@ myhost.mydomain.com: - Sticky sessions are disabled globally. This setting applies to all upstreams. If enabled all requests will be 301 redirected to HTTPS. - HTTP to HTTPS redirect disabled globally, but can be overridden by `to_https` setting per upstream. +- All upstreams will receive custom headers : `X-Forwarded-Proto:https` and `X-Forwarded-Port:443` +- Additionally, myhost.mydomain.com with path `/` will receive custom headers : `X-Another-Header:Hohohohoho` and `X-Something-Else:Foobar` - Requests to each hosted domains will be limited to 10 requests per second per virtualhost. - Requests limits are calculated per requester ip plus requested virtualhost. - If the requester exceeds the limit it will receive `429 Too Many Requests` error. diff --git a/etc/upstreams.yaml b/etc/upstreams.yaml index 7ffaf0d..fb08590 100644 --- a/etc/upstreams.yaml +++ b/etc/upstreams.yaml @@ -3,11 +3,13 @@ provider: "file" # "file" "consul" "kubernetes" sticky_sessions: false to_https: false rate_limit: 100 -headers: +server_headers: + - "X-Forwarded-Proto:https" + - "X-Forwarded-Port:443" +client_headers: - "Access-Control-Allow-Origin:*" - "Access-Control-Allow-Methods:POST, GET, OPTIONS" - "Access-Control-Max-Age:86400" - - "Strict-Transport-Security:max-age=31536000; includeSubDomains; preload" #authorization: # type: "jwt" # creds: "910517d9-f9a1-48de-8826-dbadacbd84af-cb6f830e-ab16-47ec-9d8f-0090de732774" @@ -21,38 +23,38 @@ consul: - "http://192.168.1.200:8500" - "http://192.168.1.201:8500" services: # hostname: The hostname to access the proxy server, upstream : The real service name in Consul database. - - hostname: "vt-webapi-service" - upstream: "vt-webapi-service-health" + - hostname: "webapi-service" + upstream: "webapi-service-health" path: "/one" - headers: + client_headers: - "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Proxy-From:Aralez" rate_limit: 1 to_https: false - - hostname: "vt-webapi-service" - upstream: "vt-webapi-service-health" + - hostname: "webapi-service" + upstream: "webapi-service-health" path: "/" token: "8e2db809-845b-45e1-8b47-2c8356a09da0-a4370955-18c2-4d6e-a8f8-ffcc0b47be81" # Consul server access token, If Consul auth is enabled kubernetes: servers: - "192.168.1.55:443" #For testing only, overrides with KUBERNETES_SERVICE_HOST : KUBERNETES_SERVICE_PORT_HTTPS env variables. services: - - hostname: "vt-webapi-service" + - hostname: "webapi-service" path: "/" - upstream: "vt-webapi-service" - - hostname: "vt-webapi-service" + upstream: "webapi-service" + - hostname: "webapi-service" upstream: "vt-console-service" path: "/one" - headers: + client_headers: - "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Proxy-From:Aralez" rate_limit: 100 to_https: false - - hostname: "vt-webapi-service" + - hostname: "webapi-service" upstream: "vt-rambulik-service" path: "/two" - - hostname: "vt-websocket-service" - upstream: "vt-websocket-service" + - hostname: "websocket-service" + upstream: "websocket-service" path: "/" tokenpath: "/path/to/kubetoken.txt" #If not set, will default to /var/run/secrets/kubernetes.io/serviceaccount/token upstreams: @@ -61,7 +63,7 @@ upstreams: "/": rate_limit: 200 to_https: false - headers: + client_headers: - "X-Proxy-From:Aralez" servers: - "127.0.0.1:8000" @@ -71,7 +73,10 @@ upstreams: - "127.0.0.5:8000" "/ping": to_https: false - headers: + server_headers: + - "X-Forwarded-Proto:https" + - "X-Forwarded-Port:443" + client_headers: - "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Proxy-From:Aralez" servers: @@ -84,7 +89,7 @@ upstreams: paths: "/": to_https: false - headers: + client_headers: - "X-Some-Thing:Yaaaaaaaaaaaaaaa" servers: - "192.168.1.1:8000" diff --git a/src/utils.rs b/src/utils.rs index 3a2ee68..9e3eff6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -12,3 +12,4 @@ pub mod state; pub mod structs; pub mod tls; pub mod tools; +// pub mod watchksecret; diff --git a/src/utils/httpclient.rs b/src/utils/httpclient.rs new file mode 100644 index 0000000..2997ceb --- /dev/null +++ b/src/utils/httpclient.rs @@ -0,0 +1,75 @@ +use crate::utils::kuberconsul::{match_path, ConsulService, KubeEndpoints}; +use crate::utils::structs::{InnerMap, ServiceMapping}; +use axum::http::{HeaderMap, HeaderValue}; +use dashmap::DashMap; +use reqwest::Client; +use std::sync::atomic::AtomicUsize; +use std::time::Duration; + +pub async fn for_consul(url: String, token: Option, conf: &ServiceMapping) -> Option, AtomicUsize)>> { + let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().ok()?; + let mut headers = HeaderMap::new(); + if let Some(token) = token { + headers.insert("X-Consul-Token", HeaderValue::from_str(&token).unwrap()); + } + let to = Duration::from_secs(1); + let resp = client.get(url).timeout(to).send().await.ok()?; + if !resp.status().is_success() { + eprintln!("Consul API returned status: {}", resp.status()); + return None; + } + let mut inner_vec = Vec::new(); + let upstreams: DashMap, AtomicUsize)> = DashMap::new(); + let endpoints: Vec = resp.json().await.ok()?; + for subsets in endpoints { + let addr = subsets.tagged_addresses.get("lan_ipv4").unwrap().address.clone(); + let prt = subsets.tagged_addresses.get("lan_ipv4").unwrap().port.clone(); + let to_add = InnerMap { + address: addr, + port: prt, + is_ssl: false, + is_http2: false, + to_https: conf.to_https.unwrap_or(false), + rate_limit: conf.rate_limit, + healthcheck: None, + }; + inner_vec.push(to_add); + } + match_path(&conf, &upstreams, inner_vec.clone()); + Some(upstreams) +} + +pub async fn for_kuber(url: &str, token: &str, conf: &ServiceMapping) -> Option, AtomicUsize)>> { + let to = Duration::from_secs(10); + let client = Client::builder().timeout(Duration::from_secs(10)).danger_accept_invalid_certs(true).build().ok()?; + let resp = client.get(url).timeout(to).bearer_auth(token).send().await.ok()?; + if !resp.status().is_success() { + eprintln!("Kubernetes API returned status: {}", resp.status()); + return None; + } + let endpoints: KubeEndpoints = resp.json().await.ok()?; + let upstreams: DashMap, AtomicUsize)> = DashMap::new(); + if let Some(subsets) = endpoints.subsets { + for subset in subsets { + if let (Some(addresses), Some(ports)) = (subset.addresses, subset.ports) { + let mut inner_vec = Vec::new(); + for addr in addresses { + for port in &ports { + let to_add = InnerMap { + address: addr.ip.clone(), + port: port.port.clone(), + is_ssl: false, + is_http2: false, + to_https: conf.to_https.unwrap_or(false), + rate_limit: conf.rate_limit, + healthcheck: None, + }; + inner_vec.push(to_add); + } + } + match_path(&conf, &upstreams, inner_vec.clone()); + } + } + } + Some(upstreams) +} diff --git a/src/utils/kuberconsul.rs b/src/utils/kuberconsul.rs index 0d8b4a7..36fc284 100644 --- a/src/utils/kuberconsul.rs +++ b/src/utils/kuberconsul.rs @@ -94,43 +94,45 @@ pub struct ConsulDiscovery; impl ServiceDiscovery for KubernetesDiscovery { async fn fetch_upstreams(&self, config: Arc, mut toreturn: Sender) { let prev_upstreams = UpstreamsDashMap::new(); - loop { - let upstreams = UpstreamsDashMap::new(); - if let Some(kuber) = config.kubernetes.clone() { - let path = kuber.tokenpath.unwrap_or("/var/run/secrets/kubernetes.io/serviceaccount/token".to_string()); - let token = read_token(path.as_str()).await; + if let Some(kuber) = config.kubernetes.clone() { + let servers = kuber.servers.unwrap_or(vec![format!( + "{}:{}", + env::var("KUBERNETES_SERVICE_HOST").unwrap_or("0.0.0.0".to_string()), + env::var("KUBERNETES_SERVICE_PORT_HTTPS").unwrap_or("0".to_string()) + )]); - let servers = kuber.servers.unwrap_or(vec![format!( - "{}:{}", - env::var("KUBERNETES_SERVICE_HOST").unwrap_or("0.0.0.0".to_string()), - env::var("KUBERNETES_SERVICE_PORT_HTTPS").unwrap_or("0".to_string()) - )]); + let end = servers.len().saturating_sub(1); + let num = if end > 0 { rand::rng().random_range(0..end) } else { 0 }; + let server = servers.get(num).unwrap().to_string(); + let path = kuber.tokenpath.unwrap_or("/var/run/secrets/kubernetes.io/serviceaccount/token".to_string()); + let token = read_token(path.as_str()).await; + // let mut oldcrt: HashMap = HashMap::new(); - let end = servers.len().saturating_sub(1); - let num = if end > 0 { rand::rng().random_range(0..end) } else { 0 }; - let server = servers.get(num).unwrap().to_string(); - - if let Some(svc) = kuber.services { - for i in svc { - let header_list = DashMap::new(); - let mut hl = Vec::new(); - build_headers(&i.headers, config.as_ref(), &mut hl); - if !hl.is_empty() { - header_list.insert(i.path.clone().unwrap_or("/".to_string()), hl); - config.headers.insert(i.hostname.clone(), header_list); + loop { + // crate::utils::watchksecret::watch_secret("ar-tls", "staging", server.clone(), token.clone(), &mut oldcrt).await; + let upstreams = UpstreamsDashMap::new(); + if let Some(kuber) = config.kubernetes.clone() { + if let Some(svc) = kuber.services { + for i in svc { + let header_list = DashMap::new(); + let mut hl = Vec::new(); + build_headers(&i.client_headers, config.as_ref(), &mut hl); + if !hl.is_empty() { + header_list.insert(i.path.clone().unwrap_or("/".to_string()), hl); + config.client_headers.insert(i.hostname.clone(), header_list); + } + let url = format!("https://{}/api/v1/namespaces/staging/endpoints/{}", server, i.hostname); + let list = httpclient::for_kuber(&*url, &*token, &i).await; + list_to_upstreams(list, &upstreams, &i); } - - let url = format!("https://{}/api/v1/namespaces/staging/endpoints/{}", server, i.hostname); - let list = httpclient::for_kuber(&*url, &*token, &i).await; - list_to_upstreams(list, &upstreams, &i); + } + if let Some(lt) = clone_compare(&upstreams, &prev_upstreams, &config).await { + toreturn.send(lt).await.unwrap(); } } - if let Some(lt) = clone_compare(&upstreams, &prev_upstreams, &config).await { - toreturn.send(lt).await.unwrap(); - } + sleep(Duration::from_secs(5)).await; } - sleep(Duration::from_secs(5)).await; } } } @@ -157,10 +159,10 @@ impl ServiceDiscovery for ConsulDiscovery { for i in svc { let header_list = DashMap::new(); let mut hl = Vec::new(); - build_headers(&i.headers, config.as_ref(), &mut hl); + build_headers(&i.client_headers, config.as_ref(), &mut hl); if !hl.is_empty() { header_list.insert(i.path.clone().unwrap_or("/".to_string()), hl); - config.headers.insert(i.hostname.clone(), header_list); + config.client_headers.insert(i.hostname.clone(), header_list); } let pref = ss.clone() + &i.upstream; @@ -180,7 +182,8 @@ async fn clone_compare(upstreams: &UpstreamsDashMap, prev_upstreams: &UpstreamsD if !compare_dashmaps(&upstreams, &prev_upstreams) { let tosend: Configuration = Configuration { upstreams: Default::default(), - headers: config.headers.clone(), + client_headers: config.client_headers.clone(), + server_headers: config.server_headers.clone(), consul: config.consul.clone(), kubernetes: config.kubernetes.clone(), typecfg: config.typecfg.clone(), diff --git a/src/utils/parceyaml.rs b/src/utils/parceyaml.rs index 35262d0..b8fd1b9 100644 --- a/src/utils/parceyaml.rs +++ b/src/utils/parceyaml.rs @@ -67,18 +67,33 @@ pub async fn load_configuration(d: &str, kind: &str) -> Option { } async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) { - if let Some(headers) = &parsed.headers { - let mut hl = Vec::new(); + let mut ch = Vec::new(); + ch.push(("Server".to_string(), "Aralez".to_string())); + // println!("{:?}", &parsed.client_headers); + if let Some(headers) = &parsed.client_headers { for header in headers { if let Some((key, val)) = header.split_once(':') { - hl.push((key.trim().to_string(), val.trim().to_string())); + println!("{}:{}", key.trim().to_string(), val.trim().to_string()); + ch.push((key.trim().to_string(), val.trim().to_string())); } } - - let global_headers = DashMap::new(); - global_headers.insert("/".to_string(), hl); - config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); } + let global_headers = DashMap::new(); + global_headers.insert("/".to_string(), ch); + config.client_headers.insert("GLOBAL_CLIENT_HEADERS".to_string(), global_headers); + + let mut sh = Vec::new(); + sh.push(("X-Proxy-Server".to_string(), "Aralez".to_string())); + if let Some(headers) = &parsed.server_headers { + for header in headers { + if let Some((key, val)) = header.split_once(':') { + sh.push((key.trim().to_string(), val.trim().to_string())); + } + } + } + let server_global_headers = DashMap::new(); + server_global_headers.insert("/".to_string(), sh); + config.server_headers.insert("GLOBAL_SERVER_HEADERS".to_string(), server_global_headers); config.extraparams.sticky_sessions = parsed.sticky_sessions; config.extraparams.to_https = parsed.to_https; @@ -102,15 +117,19 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { if let Some(upstreams) = &parsed.upstreams { for (hostname, host_config) in upstreams { let path_map = DashMap::new(); - let header_list = DashMap::new(); + let client_header_list = DashMap::new(); + let server_header_list = DashMap::new(); for (path, path_config) in &host_config.paths { if let Some(rate) = &path_config.rate_limit { info!("Applied Rate Limit for {} : {} request per second", hostname, rate); } let mut hl: Vec<(String, String)> = Vec::new(); - build_headers(&path_config.headers, config, &mut hl); - header_list.insert(path.clone(), hl); + let mut sl: Vec<(String, String)> = Vec::new(); + build_headers(&path_config.client_headers, config, &mut hl); + build_headers(&path_config.server_headers, config, &mut sl); + client_header_list.insert(path.clone(), hl); + server_header_list.insert(path.clone(), sl); let mut server_list = Vec::new(); for server in &path_config.servers { @@ -130,7 +149,8 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { } path_map.insert(path.clone(), (server_list, AtomicUsize::new(0))); } - config.headers.insert(hostname.clone(), header_list); + config.client_headers.insert(hostname.clone(), client_header_list); + config.server_headers.insert(hostname.clone(), server_header_list); imtdashmap.insert(hostname.clone(), path_map); } @@ -218,19 +238,19 @@ fn log_builder(conf: &AppConfig) { env_logger::builder().init(); } -pub fn build_headers(path_config: &Option>, config: &Configuration, hl: &mut Vec<(String, String)>) { +pub fn build_headers(path_config: &Option>, _config: &Configuration, hl: &mut Vec<(String, String)>) { if let Some(headers) = &path_config { for header in headers { if let Some((key, val)) = header.split_once(':') { hl.push((key.trim().to_string(), val.trim().to_string())); } } - if let Some(push) = config.headers.get("GLOBAL_HEADERS") { - for k in push.iter() { - for x in k.value() { - hl.push(x.to_owned()); - } - } - } + // if let Some(push) = config.client_headers.get("GLOBAL_HEADERS") { + // for k in push.iter() { + // for x in k.value() { + // hl.push(x.to_owned()); + // } + // } + // } } } diff --git a/src/utils/structs.rs b/src/utils/structs.rs index 8dceb84..ab9d598 100644 --- a/src/utils/structs.rs +++ b/src/utils/structs.rs @@ -15,7 +15,8 @@ pub struct ServiceMapping { pub path: Option, pub to_https: Option, pub rate_limit: Option, - pub headers: Option>, + pub client_headers: Option>, + pub server_headers: Option>, } // pub type Services = DashMap)>>; @@ -50,7 +51,9 @@ pub struct Config { #[serde(default)] pub globals: Option>>, #[serde(default)] - pub headers: Option>, + pub client_headers: Option>, + #[serde(default)] + pub server_headers: Option>, #[serde(default)] pub authorization: Option>, #[serde(default)] @@ -71,14 +74,16 @@ pub struct HostConfig { pub struct PathConfig { pub servers: Vec, pub to_https: Option, - pub headers: Option>, + pub client_headers: Option>, + pub server_headers: Option>, pub rate_limit: Option, pub healthcheck: Option, } #[derive(Debug, Default)] pub struct Configuration { pub upstreams: UpstreamsDashMap, - pub headers: Headers, + pub client_headers: Headers, + pub server_headers: Headers, pub consul: Option, pub kubernetes: Option, pub typecfg: String, diff --git a/src/web/bgservice.rs b/src/web/bgservice.rs index a0a6444..7bb5b9c 100644 --- a/src/web/bgservice.rs +++ b/src/web/bgservice.rs @@ -85,22 +85,38 @@ impl BackgroundService for LB { new.authentication = ss.extraparams.authentication.clone(); new.rate_limit = ss.extraparams.rate_limit; self.extraparams.store(Arc::new(new)); - self.headers.clear(); + self.client_headers.clear(); + self.server_headers.clear(); for entry in ss.upstreams.iter() { let global_key = entry.key().clone(); - let global_values = DashMap::new(); - let mut target_entry = ss.headers.entry(global_key).or_insert_with(DashMap::new); - target_entry.extend(global_values); - self.headers.insert(target_entry.key().to_owned(), target_entry.value().to_owned()); + let client_global_values = DashMap::new(); + let server_global_values = DashMap::new(); + + let mut client_target_entry = ss.client_headers.entry(global_key.clone()).or_insert_with(DashMap::new); + client_target_entry.extend(client_global_values); + let mut server_target_entry = ss.server_headers.entry(global_key).or_insert_with(DashMap::new); + server_target_entry.extend(server_global_values); + self.server_headers.insert(server_target_entry.key().to_owned(), server_target_entry.value().to_owned()); } - for path in ss.headers.iter() { + for path in ss.client_headers.iter() { let path_key = path.key().clone(); let path_headers = path.value().clone(); - self.headers.insert(path_key.clone(), path_headers); - if let Some(global_headers) = ss.headers.get("GLOBAL_HEADERS") { - if let Some(existing_headers) = self.headers.get_mut(&path_key) { + self.client_headers.insert(path_key.clone(), path_headers); + if let Some(global_headers) = ss.client_headers.get("GLOBAL_CLIENT_HEADERS") { + if let Some(existing_headers) = self.client_headers.get_mut(&path_key) { + merge_headers(&existing_headers, &global_headers); + } + } + } + + for path in ss.server_headers.iter() { + let path_key = path.key().clone(); + let path_headers = path.value().clone(); + self.server_headers.insert(path_key.clone(), path_headers); + if let Some(global_headers) = ss.server_headers.get("GLOBAL_SERVER_HEADERS") { + if let Some(existing_headers) = self.server_headers.get_mut(&path_key) { merge_headers(&existing_headers, &global_headers); } } diff --git a/src/web/gethosts.rs b/src/web/gethosts.rs index 25449f0..45ef7fe 100644 --- a/src/web/gethosts.rs +++ b/src/web/gethosts.rs @@ -3,17 +3,22 @@ use crate::web::proxyhttp::LB; use async_trait::async_trait; use std::sync::atomic::Ordering; +#[derive(Debug, Clone)] +pub struct GetHostsReturHeaders { + pub client_headers: Option>, + pub server_headers: Option>, +} + #[async_trait] pub trait GetHost { fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option; - fn get_header(&self, peer: &str, path: &str) -> Option>; + fn get_header(&self, peer: &str, path: &str) -> Option; } #[async_trait] impl GetHost for LB { fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option { if let Some(b) = backend_id { if let Some(bb) = self.ump_byid.get(b) { - // println!("BIB :===> {:?}", Some(bb.value())); return Some(bb.value().clone()); } } @@ -45,33 +50,54 @@ impl GetHost for LB { } } } - // println!("Best Match :===> {:?}", best_match); best_match } - fn get_header(&self, peer: &str, path: &str) -> Option> { - let host_entry = self.headers.get(peer)?; - let mut current_path = path.to_string(); - let mut best_match: Option> = None; + + fn get_header(&self, peer: &str, path: &str) -> Option { + let client_entry = self.client_headers.get(peer)?; + let server_entry = self.server_headers.get(peer)?; + let mut current_path = path; + let mut best_match = None; loop { - if let Some(entry) = host_entry.get(¤t_path) { + if let Some(entry) = client_entry.get(current_path) { if !entry.value().is_empty() { best_match = Some(entry.value().clone()); break; } } if let Some(pos) = current_path.rfind('/') { - current_path.truncate(pos); + current_path = if pos == 0 { "/" } else { ¤t_path[..pos] }; } else { break; } } - if best_match.is_none() { - if let Some(entry) = host_entry.get("/") { + current_path = path; + let mut serv_match = None; + loop { + if let Some(entry) = server_entry.get(current_path) { if !entry.value().is_empty() { - best_match = Some(entry.value().clone()); + serv_match = Some(entry.value().clone()); + break; + } + } + if let Some(pos) = current_path.rfind('/') { + current_path = if pos == 0 { "/" } else { ¤t_path[..pos] }; + } else { + break; + } + if best_match.is_none() { + if let Some(entry) = server_entry.get("/") { + if !entry.value().is_empty() { + best_match = Some(entry.value().clone()); + break; + } } } } - best_match + let result = GetHostsReturHeaders { + client_headers: best_match, + server_headers: serv_match, + }; + Some(result) } } diff --git a/src/web/proxyhttp.rs b/src/web/proxyhttp.rs index 0cd9165..f33caee 100644 --- a/src/web/proxyhttp.rs +++ b/src/web/proxyhttp.rs @@ -25,7 +25,8 @@ pub struct LB { pub ump_upst: Arc, pub ump_full: Arc, pub ump_byid: Arc, - pub headers: Arc, + pub client_headers: Arc, + pub server_headers: Arc, pub config: Arc, pub extraparams: Arc>, } @@ -180,13 +181,22 @@ impl ProxyHttp for LB { } } - async fn upstream_request_filter(&self, _session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> { + async fn upstream_request_filter(&self, session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> { if let Some(hostname) = ctx.hostname.as_ref() { upstream_request.insert_header("Host", hostname)?; } if let Some(peer) = ctx.upstream_peer.as_ref() { upstream_request.insert_header("X-Forwarded-For", peer.address.as_str())?; } + + if let Some(headers) = self.get_header(ctx.hostname.as_ref().unwrap_or(&"localhost".to_string()), session.req_header().uri.path()) { + if let Some(client_headers) = headers.server_headers { + for k in client_headers { + upstream_request.insert_header(k.0, k.1)?; + } + } + } + Ok(()) } @@ -213,27 +223,46 @@ impl ProxyHttp for LB { match ctx.hostname.as_ref() { Some(host) => { let path = session.req_header().uri.path(); - let host_header = host; - let split_header = host_header.split_once(':'); - + let split_header = host.split_once(':'); match split_header { - Some(sh) => { - let yoyo = self.get_header(sh.0, path); - for k in yoyo.iter() { - for t in k.iter() { - _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); + Some((host, _port)) => { + if let Some(headers) = self.get_header(host, path) { + if let Some(server_headers) = headers.client_headers { + for k in server_headers { + _upstream_response.insert_header(k.0, k.1).unwrap(); + } } } } None => { - let yoyo = self.get_header(host_header, path); - for k in yoyo.iter() { - for t in k.iter() { - _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); + if let Some(headers) = self.get_header(host, path) { + if let Some(server_headers) = headers.client_headers { + for k in server_headers { + _upstream_response.insert_header(k.0, k.1).unwrap(); + } } } } } + + // match split_header { + // Some(sh) => { + // let client_header = self.get_header(sh.0, path); + // for k in client_header.iter() { + // for t in k.iter() { + // _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); + // } + // } + // } + // None => { + // let client_header = self.get_header(host_header, path); + // for k in client_header.iter() { + // for t in k.iter() { + // _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); + // } + // } + // } + // } } None => {} } diff --git a/src/web/start.rs b/src/web/start.rs index 71e9ada..9e89fd7 100644 --- a/src/web/start.rs +++ b/src/web/start.rs @@ -27,7 +27,8 @@ pub fn run() { let uf_config = Arc::new(DashMap::new()); let ff_config = Arc::new(DashMap::new()); let im_config = Arc::new(DashMap::new()); - let hh_config = Arc::new(DashMap::new()); + let ch_config = Arc::new(DashMap::new()); + let sh_config = Arc::new(DashMap::new()); let ec_config = Arc::new(ArcSwap::from_pointee(Extraparams { sticky_sessions: false, @@ -43,7 +44,8 @@ pub fn run() { ump_full: ff_config, ump_byid: im_config, config: cfg.clone(), - headers: hh_config, + client_headers: ch_config, + server_headers: sh_config, extraparams: ec_config, };