Added support to send custom headers to upstream servers.

This commit is contained in:
Ara Sadoyan
2025-11-22 23:18:06 +01:00
parent 78c83b802f
commit 74821654f3
13 changed files with 321 additions and 115 deletions

16
Cargo.lock generated
View File

@@ -2478,6 +2478,7 @@ dependencies = [
"bytes", "bytes",
"encoding_rs", "encoding_rs",
"futures-core", "futures-core",
"futures-util",
"h2", "h2",
"http", "http",
"http-body", "http-body",
@@ -2499,12 +2500,14 @@ dependencies = [
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tokio-util",
"tower", "tower",
"tower-http", "tower-http",
"tower-service", "tower-service",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-streams",
"web-sys", "web-sys",
] ]
@@ -3499,6 +3502,19 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.77" version = "0.3.77"

View File

@@ -26,7 +26,7 @@ futures = "0.3.31"
notify = "8.2.0" notify = "8.2.0"
axum = { version = "0.8.4" } axum = { version = "0.8.4" }
axum-server = { version = "0.7.2", features = ["tls-openssl"] } axum-server = { version = "0.7.2", features = ["tls-openssl"] }
reqwest = { version = "0.12.23", features = ["json", "native-tls-alpn"] } reqwest = { version = "0.12.23", features = ["json", "native-tls-alpn", "stream"] }
#reqwest = { version = "0.12.15", features = ["json", "rustls-tls"] } #reqwest = { version = "0.12.15", features = ["json", "rustls-tls"] }
#reqwest = { version = "0.12.15", default-features = false, features = ["rustls-tls", "json"] } #reqwest = { version = "0.12.15", default-features = false, features = ["rustls-tls", "json"] }

View File

@@ -191,7 +191,10 @@ provider: "file"
sticky_sessions: false sticky_sessions: false
to_https: false to_https: false
rate_limit: 10 rate_limit: 10
headers: server_headers:
- "X-Forwarded-Proto:https"
- "X-Forwarded-Port:443"
client_headers:
- "Access-Control-Allow-Origin:*" - "Access-Control-Allow-Origin:*"
- "Access-Control-Allow-Methods:POST, GET, OPTIONS" - "Access-Control-Allow-Methods:POST, GET, OPTIONS"
- "Access-Control-Max-Age:86400" - "Access-Control-Max-Age:86400"
@@ -203,7 +206,10 @@ myhost.mydomain.com:
"/": "/":
rate_limit: 20 rate_limit: 20
to_https: false to_https: false
headers: server_headers:
- "X-Something-Else:Foobar"
- "X-Another-Header:Hohohohoho"
client_headers:
- "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Some-Thing:Yaaaaaaaaaaaaaaa"
- "X-Proxy-From:Hopaaaaaaaaaaaar" - "X-Proxy-From:Hopaaaaaaaaaaaar"
servers: servers:
@@ -211,7 +217,7 @@ myhost.mydomain.com:
- "127.0.0.2:8000" - "127.0.0.2:8000"
"/foo": "/foo":
to_https: true to_https: true
headers: client_headers:
- "X-Another-Header:Hohohohoho" - "X-Another-Header:Hohohohoho"
servers: servers:
- "127.0.0.4:8443" - "127.0.0.4:8443"
@@ -226,6 +232,8 @@ myhost.mydomain.com:
- Sticky sessions are disabled globally. This setting applies to all upstreams. If enabled all requests will be 301 redirected to HTTPS. - Sticky sessions are disabled globally. This setting applies to all upstreams. If enabled all requests will be 301 redirected to HTTPS.
- HTTP to HTTPS redirect disabled globally, but can be overridden by `to_https` setting per upstream. - HTTP to HTTPS redirect disabled globally, but can be overridden by `to_https` setting per upstream.
- All upstreams will receive custom headers : `X-Forwarded-Proto:https` and `X-Forwarded-Port:443`
- Additionally, myhost.mydomain.com with path `/` will receive custom headers : `X-Another-Header:Hohohohoho` and `X-Something-Else:Foobar`
- Requests to each hosted domains will be limited to 10 requests per second per virtualhost. - Requests to each hosted domains will be limited to 10 requests per second per virtualhost.
- Requests limits are calculated per requester ip plus requested virtualhost. - Requests limits are calculated per requester ip plus requested virtualhost.
- If the requester exceeds the limit it will receive `429 Too Many Requests` error. - If the requester exceeds the limit it will receive `429 Too Many Requests` error.

View File

@@ -3,11 +3,13 @@ provider: "file" # "file" "consul" "kubernetes"
sticky_sessions: false sticky_sessions: false
to_https: false to_https: false
rate_limit: 100 rate_limit: 100
headers: server_headers:
- "X-Forwarded-Proto:https"
- "X-Forwarded-Port:443"
client_headers:
- "Access-Control-Allow-Origin:*" - "Access-Control-Allow-Origin:*"
- "Access-Control-Allow-Methods:POST, GET, OPTIONS" - "Access-Control-Allow-Methods:POST, GET, OPTIONS"
- "Access-Control-Max-Age:86400" - "Access-Control-Max-Age:86400"
- "Strict-Transport-Security:max-age=31536000; includeSubDomains; preload"
#authorization: #authorization:
# type: "jwt" # type: "jwt"
# creds: "910517d9-f9a1-48de-8826-dbadacbd84af-cb6f830e-ab16-47ec-9d8f-0090de732774" # creds: "910517d9-f9a1-48de-8826-dbadacbd84af-cb6f830e-ab16-47ec-9d8f-0090de732774"
@@ -21,38 +23,38 @@ consul:
- "http://192.168.1.200:8500" - "http://192.168.1.200:8500"
- "http://192.168.1.201:8500" - "http://192.168.1.201:8500"
services: # hostname: The hostname to access the proxy server, upstream : The real service name in Consul database. services: # hostname: The hostname to access the proxy server, upstream : The real service name in Consul database.
- hostname: "vt-webapi-service" - hostname: "webapi-service"
upstream: "vt-webapi-service-health" upstream: "webapi-service-health"
path: "/one" path: "/one"
headers: client_headers:
- "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Some-Thing:Yaaaaaaaaaaaaaaa"
- "X-Proxy-From:Aralez" - "X-Proxy-From:Aralez"
rate_limit: 1 rate_limit: 1
to_https: false to_https: false
- hostname: "vt-webapi-service" - hostname: "webapi-service"
upstream: "vt-webapi-service-health" upstream: "webapi-service-health"
path: "/" path: "/"
token: "8e2db809-845b-45e1-8b47-2c8356a09da0-a4370955-18c2-4d6e-a8f8-ffcc0b47be81" # Consul server access token, If Consul auth is enabled token: "8e2db809-845b-45e1-8b47-2c8356a09da0-a4370955-18c2-4d6e-a8f8-ffcc0b47be81" # Consul server access token, If Consul auth is enabled
kubernetes: kubernetes:
servers: servers:
- "192.168.1.55:443" #For testing only, overrides with KUBERNETES_SERVICE_HOST : KUBERNETES_SERVICE_PORT_HTTPS env variables. - "192.168.1.55:443" #For testing only, overrides with KUBERNETES_SERVICE_HOST : KUBERNETES_SERVICE_PORT_HTTPS env variables.
services: services:
- hostname: "vt-webapi-service" - hostname: "webapi-service"
path: "/" path: "/"
upstream: "vt-webapi-service" upstream: "webapi-service"
- hostname: "vt-webapi-service" - hostname: "webapi-service"
upstream: "vt-console-service" upstream: "vt-console-service"
path: "/one" path: "/one"
headers: client_headers:
- "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Some-Thing:Yaaaaaaaaaaaaaaa"
- "X-Proxy-From:Aralez" - "X-Proxy-From:Aralez"
rate_limit: 100 rate_limit: 100
to_https: false to_https: false
- hostname: "vt-webapi-service" - hostname: "webapi-service"
upstream: "vt-rambulik-service" upstream: "vt-rambulik-service"
path: "/two" path: "/two"
- hostname: "vt-websocket-service" - hostname: "websocket-service"
upstream: "vt-websocket-service" upstream: "websocket-service"
path: "/" path: "/"
tokenpath: "/path/to/kubetoken.txt" #If not set, will default to /var/run/secrets/kubernetes.io/serviceaccount/token tokenpath: "/path/to/kubetoken.txt" #If not set, will default to /var/run/secrets/kubernetes.io/serviceaccount/token
upstreams: upstreams:
@@ -61,7 +63,7 @@ upstreams:
"/": "/":
rate_limit: 200 rate_limit: 200
to_https: false to_https: false
headers: client_headers:
- "X-Proxy-From:Aralez" - "X-Proxy-From:Aralez"
servers: servers:
- "127.0.0.1:8000" - "127.0.0.1:8000"
@@ -71,7 +73,10 @@ upstreams:
- "127.0.0.5:8000" - "127.0.0.5:8000"
"/ping": "/ping":
to_https: false to_https: false
headers: server_headers:
- "X-Forwarded-Proto:https"
- "X-Forwarded-Port:443"
client_headers:
- "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Some-Thing:Yaaaaaaaaaaaaaaa"
- "X-Proxy-From:Aralez" - "X-Proxy-From:Aralez"
servers: servers:
@@ -84,7 +89,7 @@ upstreams:
paths: paths:
"/": "/":
to_https: false to_https: false
headers: client_headers:
- "X-Some-Thing:Yaaaaaaaaaaaaaaa" - "X-Some-Thing:Yaaaaaaaaaaaaaaa"
servers: servers:
- "192.168.1.1:8000" - "192.168.1.1:8000"

View File

@@ -12,3 +12,4 @@ pub mod state;
pub mod structs; pub mod structs;
pub mod tls; pub mod tls;
pub mod tools; pub mod tools;
// pub mod watchksecret;

75
src/utils/httpclient.rs Normal file
View File

@@ -0,0 +1,75 @@
use crate::utils::kuberconsul::{match_path, ConsulService, KubeEndpoints};
use crate::utils::structs::{InnerMap, ServiceMapping};
use axum::http::{HeaderMap, HeaderValue};
use dashmap::DashMap;
use reqwest::Client;
use std::sync::atomic::AtomicUsize;
use std::time::Duration;
pub async fn for_consul(url: String, token: Option<String>, conf: &ServiceMapping) -> Option<DashMap<String, (Vec<InnerMap>, AtomicUsize)>> {
let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().ok()?;
let mut headers = HeaderMap::new();
if let Some(token) = token {
headers.insert("X-Consul-Token", HeaderValue::from_str(&token).unwrap());
}
let to = Duration::from_secs(1);
let resp = client.get(url).timeout(to).send().await.ok()?;
if !resp.status().is_success() {
eprintln!("Consul API returned status: {}", resp.status());
return None;
}
let mut inner_vec = Vec::new();
let upstreams: DashMap<String, (Vec<InnerMap>, AtomicUsize)> = DashMap::new();
let endpoints: Vec<ConsulService> = resp.json().await.ok()?;
for subsets in endpoints {
let addr = subsets.tagged_addresses.get("lan_ipv4").unwrap().address.clone();
let prt = subsets.tagged_addresses.get("lan_ipv4").unwrap().port.clone();
let to_add = InnerMap {
address: addr,
port: prt,
is_ssl: false,
is_http2: false,
to_https: conf.to_https.unwrap_or(false),
rate_limit: conf.rate_limit,
healthcheck: None,
};
inner_vec.push(to_add);
}
match_path(&conf, &upstreams, inner_vec.clone());
Some(upstreams)
}
pub async fn for_kuber(url: &str, token: &str, conf: &ServiceMapping) -> Option<DashMap<String, (Vec<InnerMap>, AtomicUsize)>> {
let to = Duration::from_secs(10);
let client = Client::builder().timeout(Duration::from_secs(10)).danger_accept_invalid_certs(true).build().ok()?;
let resp = client.get(url).timeout(to).bearer_auth(token).send().await.ok()?;
if !resp.status().is_success() {
eprintln!("Kubernetes API returned status: {}", resp.status());
return None;
}
let endpoints: KubeEndpoints = resp.json().await.ok()?;
let upstreams: DashMap<String, (Vec<InnerMap>, AtomicUsize)> = DashMap::new();
if let Some(subsets) = endpoints.subsets {
for subset in subsets {
if let (Some(addresses), Some(ports)) = (subset.addresses, subset.ports) {
let mut inner_vec = Vec::new();
for addr in addresses {
for port in &ports {
let to_add = InnerMap {
address: addr.ip.clone(),
port: port.port.clone(),
is_ssl: false,
is_http2: false,
to_https: conf.to_https.unwrap_or(false),
rate_limit: conf.rate_limit,
healthcheck: None,
};
inner_vec.push(to_add);
}
}
match_path(&conf, &upstreams, inner_vec.clone());
}
}
}
Some(upstreams)
}

View File

@@ -94,43 +94,45 @@ pub struct ConsulDiscovery;
impl ServiceDiscovery for KubernetesDiscovery { impl ServiceDiscovery for KubernetesDiscovery {
async fn fetch_upstreams(&self, config: Arc<Configuration>, mut toreturn: Sender<Configuration>) { async fn fetch_upstreams(&self, config: Arc<Configuration>, mut toreturn: Sender<Configuration>) {
let prev_upstreams = UpstreamsDashMap::new(); let prev_upstreams = UpstreamsDashMap::new();
loop {
let upstreams = UpstreamsDashMap::new();
if let Some(kuber) = config.kubernetes.clone() { if let Some(kuber) = config.kubernetes.clone() {
let path = kuber.tokenpath.unwrap_or("/var/run/secrets/kubernetes.io/serviceaccount/token".to_string()); let servers = kuber.servers.unwrap_or(vec![format!(
let token = read_token(path.as_str()).await; "{}:{}",
env::var("KUBERNETES_SERVICE_HOST").unwrap_or("0.0.0.0".to_string()),
env::var("KUBERNETES_SERVICE_PORT_HTTPS").unwrap_or("0".to_string())
)]);
let servers = kuber.servers.unwrap_or(vec![format!( let end = servers.len().saturating_sub(1);
"{}:{}", let num = if end > 0 { rand::rng().random_range(0..end) } else { 0 };
env::var("KUBERNETES_SERVICE_HOST").unwrap_or("0.0.0.0".to_string()), let server = servers.get(num).unwrap().to_string();
env::var("KUBERNETES_SERVICE_PORT_HTTPS").unwrap_or("0".to_string()) let path = kuber.tokenpath.unwrap_or("/var/run/secrets/kubernetes.io/serviceaccount/token".to_string());
)]); let token = read_token(path.as_str()).await;
// let mut oldcrt: HashMap<String, String> = HashMap::new();
let end = servers.len().saturating_sub(1); loop {
let num = if end > 0 { rand::rng().random_range(0..end) } else { 0 }; // crate::utils::watchksecret::watch_secret("ar-tls", "staging", server.clone(), token.clone(), &mut oldcrt).await;
let server = servers.get(num).unwrap().to_string(); let upstreams = UpstreamsDashMap::new();
if let Some(kuber) = config.kubernetes.clone() {
if let Some(svc) = kuber.services { if let Some(svc) = kuber.services {
for i in svc { for i in svc {
let header_list = DashMap::new(); let header_list = DashMap::new();
let mut hl = Vec::new(); let mut hl = Vec::new();
build_headers(&i.headers, config.as_ref(), &mut hl); build_headers(&i.client_headers, config.as_ref(), &mut hl);
if !hl.is_empty() { if !hl.is_empty() {
header_list.insert(i.path.clone().unwrap_or("/".to_string()), hl); header_list.insert(i.path.clone().unwrap_or("/".to_string()), hl);
config.headers.insert(i.hostname.clone(), header_list); config.client_headers.insert(i.hostname.clone(), header_list);
}
let url = format!("https://{}/api/v1/namespaces/staging/endpoints/{}", server, i.hostname);
let list = httpclient::for_kuber(&*url, &*token, &i).await;
list_to_upstreams(list, &upstreams, &i);
} }
}
let url = format!("https://{}/api/v1/namespaces/staging/endpoints/{}", server, i.hostname); if let Some(lt) = clone_compare(&upstreams, &prev_upstreams, &config).await {
let list = httpclient::for_kuber(&*url, &*token, &i).await; toreturn.send(lt).await.unwrap();
list_to_upstreams(list, &upstreams, &i);
} }
} }
if let Some(lt) = clone_compare(&upstreams, &prev_upstreams, &config).await { sleep(Duration::from_secs(5)).await;
toreturn.send(lt).await.unwrap();
}
} }
sleep(Duration::from_secs(5)).await;
} }
} }
} }
@@ -157,10 +159,10 @@ impl ServiceDiscovery for ConsulDiscovery {
for i in svc { for i in svc {
let header_list = DashMap::new(); let header_list = DashMap::new();
let mut hl = Vec::new(); let mut hl = Vec::new();
build_headers(&i.headers, config.as_ref(), &mut hl); build_headers(&i.client_headers, config.as_ref(), &mut hl);
if !hl.is_empty() { if !hl.is_empty() {
header_list.insert(i.path.clone().unwrap_or("/".to_string()), hl); header_list.insert(i.path.clone().unwrap_or("/".to_string()), hl);
config.headers.insert(i.hostname.clone(), header_list); config.client_headers.insert(i.hostname.clone(), header_list);
} }
let pref = ss.clone() + &i.upstream; let pref = ss.clone() + &i.upstream;
@@ -180,7 +182,8 @@ async fn clone_compare(upstreams: &UpstreamsDashMap, prev_upstreams: &UpstreamsD
if !compare_dashmaps(&upstreams, &prev_upstreams) { if !compare_dashmaps(&upstreams, &prev_upstreams) {
let tosend: Configuration = Configuration { let tosend: Configuration = Configuration {
upstreams: Default::default(), upstreams: Default::default(),
headers: config.headers.clone(), client_headers: config.client_headers.clone(),
server_headers: config.server_headers.clone(),
consul: config.consul.clone(), consul: config.consul.clone(),
kubernetes: config.kubernetes.clone(), kubernetes: config.kubernetes.clone(),
typecfg: config.typecfg.clone(), typecfg: config.typecfg.clone(),

View File

@@ -67,18 +67,33 @@ pub async fn load_configuration(d: &str, kind: &str) -> Option<Configuration> {
} }
async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) { async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) {
if let Some(headers) = &parsed.headers { let mut ch = Vec::new();
let mut hl = Vec::new(); ch.push(("Server".to_string(), "Aralez".to_string()));
// println!("{:?}", &parsed.client_headers);
if let Some(headers) = &parsed.client_headers {
for header in headers { for header in headers {
if let Some((key, val)) = header.split_once(':') { if let Some((key, val)) = header.split_once(':') {
hl.push((key.trim().to_string(), val.trim().to_string())); println!("{}:{}", key.trim().to_string(), val.trim().to_string());
ch.push((key.trim().to_string(), val.trim().to_string()));
} }
} }
let global_headers = DashMap::new();
global_headers.insert("/".to_string(), hl);
config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers);
} }
let global_headers = DashMap::new();
global_headers.insert("/".to_string(), ch);
config.client_headers.insert("GLOBAL_CLIENT_HEADERS".to_string(), global_headers);
let mut sh = Vec::new();
sh.push(("X-Proxy-Server".to_string(), "Aralez".to_string()));
if let Some(headers) = &parsed.server_headers {
for header in headers {
if let Some((key, val)) = header.split_once(':') {
sh.push((key.trim().to_string(), val.trim().to_string()));
}
}
}
let server_global_headers = DashMap::new();
server_global_headers.insert("/".to_string(), sh);
config.server_headers.insert("GLOBAL_SERVER_HEADERS".to_string(), server_global_headers);
config.extraparams.sticky_sessions = parsed.sticky_sessions; config.extraparams.sticky_sessions = parsed.sticky_sessions;
config.extraparams.to_https = parsed.to_https; config.extraparams.to_https = parsed.to_https;
@@ -102,15 +117,19 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) {
if let Some(upstreams) = &parsed.upstreams { if let Some(upstreams) = &parsed.upstreams {
for (hostname, host_config) in upstreams { for (hostname, host_config) in upstreams {
let path_map = DashMap::new(); let path_map = DashMap::new();
let header_list = DashMap::new(); let client_header_list = DashMap::new();
let server_header_list = DashMap::new();
for (path, path_config) in &host_config.paths { for (path, path_config) in &host_config.paths {
if let Some(rate) = &path_config.rate_limit { if let Some(rate) = &path_config.rate_limit {
info!("Applied Rate Limit for {} : {} request per second", hostname, rate); info!("Applied Rate Limit for {} : {} request per second", hostname, rate);
} }
let mut hl: Vec<(String, String)> = Vec::new(); let mut hl: Vec<(String, String)> = Vec::new();
build_headers(&path_config.headers, config, &mut hl); let mut sl: Vec<(String, String)> = Vec::new();
header_list.insert(path.clone(), hl); build_headers(&path_config.client_headers, config, &mut hl);
build_headers(&path_config.server_headers, config, &mut sl);
client_header_list.insert(path.clone(), hl);
server_header_list.insert(path.clone(), sl);
let mut server_list = Vec::new(); let mut server_list = Vec::new();
for server in &path_config.servers { for server in &path_config.servers {
@@ -130,7 +149,8 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) {
} }
path_map.insert(path.clone(), (server_list, AtomicUsize::new(0))); path_map.insert(path.clone(), (server_list, AtomicUsize::new(0)));
} }
config.headers.insert(hostname.clone(), header_list); config.client_headers.insert(hostname.clone(), client_header_list);
config.server_headers.insert(hostname.clone(), server_header_list);
imtdashmap.insert(hostname.clone(), path_map); imtdashmap.insert(hostname.clone(), path_map);
} }
@@ -218,19 +238,19 @@ fn log_builder(conf: &AppConfig) {
env_logger::builder().init(); env_logger::builder().init();
} }
pub fn build_headers(path_config: &Option<Vec<String>>, config: &Configuration, hl: &mut Vec<(String, String)>) { pub fn build_headers(path_config: &Option<Vec<String>>, _config: &Configuration, hl: &mut Vec<(String, String)>) {
if let Some(headers) = &path_config { if let Some(headers) = &path_config {
for header in headers { for header in headers {
if let Some((key, val)) = header.split_once(':') { if let Some((key, val)) = header.split_once(':') {
hl.push((key.trim().to_string(), val.trim().to_string())); hl.push((key.trim().to_string(), val.trim().to_string()));
} }
} }
if let Some(push) = config.headers.get("GLOBAL_HEADERS") { // if let Some(push) = config.client_headers.get("GLOBAL_HEADERS") {
for k in push.iter() { // for k in push.iter() {
for x in k.value() { // for x in k.value() {
hl.push(x.to_owned()); // hl.push(x.to_owned());
} // }
} // }
} // }
} }
} }

View File

@@ -15,7 +15,8 @@ pub struct ServiceMapping {
pub path: Option<String>, pub path: Option<String>,
pub to_https: Option<bool>, pub to_https: Option<bool>,
pub rate_limit: Option<isize>, pub rate_limit: Option<isize>,
pub headers: Option<Vec<String>>, pub client_headers: Option<Vec<String>>,
pub server_headers: Option<Vec<String>>,
} }
// pub type Services = DashMap<String, Vec<(String, Option<String>)>>; // pub type Services = DashMap<String, Vec<(String, Option<String>)>>;
@@ -50,7 +51,9 @@ pub struct Config {
#[serde(default)] #[serde(default)]
pub globals: Option<HashMap<String, Vec<String>>>, pub globals: Option<HashMap<String, Vec<String>>>,
#[serde(default)] #[serde(default)]
pub headers: Option<Vec<String>>, pub client_headers: Option<Vec<String>>,
#[serde(default)]
pub server_headers: Option<Vec<String>>,
#[serde(default)] #[serde(default)]
pub authorization: Option<HashMap<String, String>>, pub authorization: Option<HashMap<String, String>>,
#[serde(default)] #[serde(default)]
@@ -71,14 +74,16 @@ pub struct HostConfig {
pub struct PathConfig { pub struct PathConfig {
pub servers: Vec<String>, pub servers: Vec<String>,
pub to_https: Option<bool>, pub to_https: Option<bool>,
pub headers: Option<Vec<String>>, pub client_headers: Option<Vec<String>>,
pub server_headers: Option<Vec<String>>,
pub rate_limit: Option<isize>, pub rate_limit: Option<isize>,
pub healthcheck: Option<bool>, pub healthcheck: Option<bool>,
} }
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Configuration { pub struct Configuration {
pub upstreams: UpstreamsDashMap, pub upstreams: UpstreamsDashMap,
pub headers: Headers, pub client_headers: Headers,
pub server_headers: Headers,
pub consul: Option<Consul>, pub consul: Option<Consul>,
pub kubernetes: Option<Kubernetes>, pub kubernetes: Option<Kubernetes>,
pub typecfg: String, pub typecfg: String,

View File

@@ -85,22 +85,38 @@ impl BackgroundService for LB {
new.authentication = ss.extraparams.authentication.clone(); new.authentication = ss.extraparams.authentication.clone();
new.rate_limit = ss.extraparams.rate_limit; new.rate_limit = ss.extraparams.rate_limit;
self.extraparams.store(Arc::new(new)); self.extraparams.store(Arc::new(new));
self.headers.clear(); self.client_headers.clear();
self.server_headers.clear();
for entry in ss.upstreams.iter() { for entry in ss.upstreams.iter() {
let global_key = entry.key().clone(); let global_key = entry.key().clone();
let global_values = DashMap::new(); let client_global_values = DashMap::new();
let mut target_entry = ss.headers.entry(global_key).or_insert_with(DashMap::new); let server_global_values = DashMap::new();
target_entry.extend(global_values);
self.headers.insert(target_entry.key().to_owned(), target_entry.value().to_owned()); let mut client_target_entry = ss.client_headers.entry(global_key.clone()).or_insert_with(DashMap::new);
client_target_entry.extend(client_global_values);
let mut server_target_entry = ss.server_headers.entry(global_key).or_insert_with(DashMap::new);
server_target_entry.extend(server_global_values);
self.server_headers.insert(server_target_entry.key().to_owned(), server_target_entry.value().to_owned());
} }
for path in ss.headers.iter() { for path in ss.client_headers.iter() {
let path_key = path.key().clone(); let path_key = path.key().clone();
let path_headers = path.value().clone(); let path_headers = path.value().clone();
self.headers.insert(path_key.clone(), path_headers); self.client_headers.insert(path_key.clone(), path_headers);
if let Some(global_headers) = ss.headers.get("GLOBAL_HEADERS") { if let Some(global_headers) = ss.client_headers.get("GLOBAL_CLIENT_HEADERS") {
if let Some(existing_headers) = self.headers.get_mut(&path_key) { if let Some(existing_headers) = self.client_headers.get_mut(&path_key) {
merge_headers(&existing_headers, &global_headers);
}
}
}
for path in ss.server_headers.iter() {
let path_key = path.key().clone();
let path_headers = path.value().clone();
self.server_headers.insert(path_key.clone(), path_headers);
if let Some(global_headers) = ss.server_headers.get("GLOBAL_SERVER_HEADERS") {
if let Some(existing_headers) = self.server_headers.get_mut(&path_key) {
merge_headers(&existing_headers, &global_headers); merge_headers(&existing_headers, &global_headers);
} }
} }

View File

@@ -3,17 +3,22 @@ use crate::web::proxyhttp::LB;
use async_trait::async_trait; use async_trait::async_trait;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
#[derive(Debug, Clone)]
pub struct GetHostsReturHeaders {
pub client_headers: Option<Vec<(String, String)>>,
pub server_headers: Option<Vec<(String, String)>>,
}
#[async_trait] #[async_trait]
pub trait GetHost { pub trait GetHost {
fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option<InnerMap>; fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option<InnerMap>;
fn get_header(&self, peer: &str, path: &str) -> Option<Vec<(String, String)>>; fn get_header(&self, peer: &str, path: &str) -> Option<GetHostsReturHeaders>;
} }
#[async_trait] #[async_trait]
impl GetHost for LB { impl GetHost for LB {
fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option<InnerMap> { fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option<InnerMap> {
if let Some(b) = backend_id { if let Some(b) = backend_id {
if let Some(bb) = self.ump_byid.get(b) { if let Some(bb) = self.ump_byid.get(b) {
// println!("BIB :===> {:?}", Some(bb.value()));
return Some(bb.value().clone()); return Some(bb.value().clone());
} }
} }
@@ -45,33 +50,54 @@ impl GetHost for LB {
} }
} }
} }
// println!("Best Match :===> {:?}", best_match);
best_match best_match
} }
fn get_header(&self, peer: &str, path: &str) -> Option<Vec<(String, String)>> {
let host_entry = self.headers.get(peer)?; fn get_header(&self, peer: &str, path: &str) -> Option<GetHostsReturHeaders> {
let mut current_path = path.to_string(); let client_entry = self.client_headers.get(peer)?;
let mut best_match: Option<Vec<(String, String)>> = None; let server_entry = self.server_headers.get(peer)?;
let mut current_path = path;
let mut best_match = None;
loop { loop {
if let Some(entry) = host_entry.get(&current_path) { if let Some(entry) = client_entry.get(current_path) {
if !entry.value().is_empty() { if !entry.value().is_empty() {
best_match = Some(entry.value().clone()); best_match = Some(entry.value().clone());
break; break;
} }
} }
if let Some(pos) = current_path.rfind('/') { if let Some(pos) = current_path.rfind('/') {
current_path.truncate(pos); current_path = if pos == 0 { "/" } else { &current_path[..pos] };
} else { } else {
break; break;
} }
} }
if best_match.is_none() { current_path = path;
if let Some(entry) = host_entry.get("/") { let mut serv_match = None;
loop {
if let Some(entry) = server_entry.get(current_path) {
if !entry.value().is_empty() { if !entry.value().is_empty() {
best_match = Some(entry.value().clone()); serv_match = Some(entry.value().clone());
break;
}
}
if let Some(pos) = current_path.rfind('/') {
current_path = if pos == 0 { "/" } else { &current_path[..pos] };
} else {
break;
}
if best_match.is_none() {
if let Some(entry) = server_entry.get("/") {
if !entry.value().is_empty() {
best_match = Some(entry.value().clone());
break;
}
} }
} }
} }
best_match let result = GetHostsReturHeaders {
client_headers: best_match,
server_headers: serv_match,
};
Some(result)
} }
} }

View File

@@ -25,7 +25,8 @@ pub struct LB {
pub ump_upst: Arc<UpstreamsDashMap>, pub ump_upst: Arc<UpstreamsDashMap>,
pub ump_full: Arc<UpstreamsDashMap>, pub ump_full: Arc<UpstreamsDashMap>,
pub ump_byid: Arc<UpstreamsIdMap>, pub ump_byid: Arc<UpstreamsIdMap>,
pub headers: Arc<Headers>, pub client_headers: Arc<Headers>,
pub server_headers: Arc<Headers>,
pub config: Arc<AppConfig>, pub config: Arc<AppConfig>,
pub extraparams: Arc<ArcSwap<Extraparams>>, pub extraparams: Arc<ArcSwap<Extraparams>>,
} }
@@ -180,13 +181,22 @@ impl ProxyHttp for LB {
} }
} }
async fn upstream_request_filter(&self, _session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> { async fn upstream_request_filter(&self, session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> {
if let Some(hostname) = ctx.hostname.as_ref() { if let Some(hostname) = ctx.hostname.as_ref() {
upstream_request.insert_header("Host", hostname)?; upstream_request.insert_header("Host", hostname)?;
} }
if let Some(peer) = ctx.upstream_peer.as_ref() { if let Some(peer) = ctx.upstream_peer.as_ref() {
upstream_request.insert_header("X-Forwarded-For", peer.address.as_str())?; upstream_request.insert_header("X-Forwarded-For", peer.address.as_str())?;
} }
if let Some(headers) = self.get_header(ctx.hostname.as_ref().unwrap_or(&"localhost".to_string()), session.req_header().uri.path()) {
if let Some(client_headers) = headers.server_headers {
for k in client_headers {
upstream_request.insert_header(k.0, k.1)?;
}
}
}
Ok(()) Ok(())
} }
@@ -213,27 +223,46 @@ impl ProxyHttp for LB {
match ctx.hostname.as_ref() { match ctx.hostname.as_ref() {
Some(host) => { Some(host) => {
let path = session.req_header().uri.path(); let path = session.req_header().uri.path();
let host_header = host; let split_header = host.split_once(':');
let split_header = host_header.split_once(':');
match split_header { match split_header {
Some(sh) => { Some((host, _port)) => {
let yoyo = self.get_header(sh.0, path); if let Some(headers) = self.get_header(host, path) {
for k in yoyo.iter() { if let Some(server_headers) = headers.client_headers {
for t in k.iter() { for k in server_headers {
_upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); _upstream_response.insert_header(k.0, k.1).unwrap();
}
} }
} }
} }
None => { None => {
let yoyo = self.get_header(host_header, path); if let Some(headers) = self.get_header(host, path) {
for k in yoyo.iter() { if let Some(server_headers) = headers.client_headers {
for t in k.iter() { for k in server_headers {
_upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); _upstream_response.insert_header(k.0, k.1).unwrap();
}
} }
} }
} }
} }
// match split_header {
// Some(sh) => {
// let client_header = self.get_header(sh.0, path);
// for k in client_header.iter() {
// for t in k.iter() {
// _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap();
// }
// }
// }
// None => {
// let client_header = self.get_header(host_header, path);
// for k in client_header.iter() {
// for t in k.iter() {
// _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap();
// }
// }
// }
// }
} }
None => {} None => {}
} }

View File

@@ -27,7 +27,8 @@ pub fn run() {
let uf_config = Arc::new(DashMap::new()); let uf_config = Arc::new(DashMap::new());
let ff_config = Arc::new(DashMap::new()); let ff_config = Arc::new(DashMap::new());
let im_config = Arc::new(DashMap::new()); let im_config = Arc::new(DashMap::new());
let hh_config = Arc::new(DashMap::new()); let ch_config = Arc::new(DashMap::new());
let sh_config = Arc::new(DashMap::new());
let ec_config = Arc::new(ArcSwap::from_pointee(Extraparams { let ec_config = Arc::new(ArcSwap::from_pointee(Extraparams {
sticky_sessions: false, sticky_sessions: false,
@@ -43,7 +44,8 @@ pub fn run() {
ump_full: ff_config, ump_full: ff_config,
ump_byid: im_config, ump_byid: im_config,
config: cfg.clone(), config: cfg.clone(),
headers: hh_config, client_headers: ch_config,
server_headers: sh_config,
extraparams: ec_config, extraparams: ec_config,
}; };