diff --git a/Cargo.toml b/Cargo.toml index 888d5ad..1d9606b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,5 +50,6 @@ x509-parser = "0.17.0" rustls-pemfile = "2.2.0" tower-http = { version = "0.6.6", features = ["fs"] } once_cell = "1.20.2" +#moka = { version = "0.12.10", features = ["sync"] } diff --git a/src/utils/healthcheck.rs b/src/utils/healthcheck.rs index c717b60..27f7dad 100644 --- a/src/utils/healthcheck.rs +++ b/src/utils/healthcheck.rs @@ -13,6 +13,7 @@ use tonic::transport::Endpoint; pub async fn hc2(upslist: Arc, fullist: Arc, idlist: Arc, params: (&str, u64)) { let mut period = interval(Duration::from_secs(params.1)); let mut first_run = 0; + let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap(); loop { tokio::select! { _ = period.tick() => { @@ -27,8 +28,9 @@ pub async fn hc2(upslist: Arc, fullist: Arc, let mut innervec= Vec::new(); for k in path_entry.value().0 .iter().enumerate() { let mut _link = String::new(); - let tls = detect_tls(k.1.address.as_str(), &k.1.port).await; + let tls = detect_tls(k.1.address.as_str(), &k.1.port, &client).await; let mut is_h2 = false; + if tls.1 == Some(Version::HTTP_2) { is_h2 = true; } @@ -43,7 +45,7 @@ pub async fn hc2(upslist: Arc, fullist: Arc, is_http2: is_h2, to_https: k.1.to_https, }; - let resp = http_request(_link.as_str(), params.0, "").await; + let resp = http_request(_link.as_str(), params.0, "", &client).await; match resp.0 { true => { if resp.1 { @@ -86,33 +88,26 @@ pub async fn hc2(upslist: Arc, fullist: Arc, } } -#[allow(dead_code)] -async fn http_request(url: &str, method: &str, payload: &str) -> (bool, bool) { - let client = Client::builder().danger_accept_invalid_certs(true).build().unwrap(); - let timeout = Duration::from_secs(1); +async fn http_request(url: &str, method: &str, payload: &str, client: &Client) -> (bool, bool) { if !["POST", "GET", "HEAD"].contains(&method) { error!("Method {} not supported. Only GET|POST|HEAD are supported ", method); return (false, false); } - async fn send_request(client: &Client, method: &str, url: &str, payload: &str, timeout: Duration) -> Option { + async fn send_request(client: &Client, method: &str, url: &str, payload: &str) -> Option { match method { - "POST" => client.post(url).body(payload.to_owned()).timeout(timeout).send().await.ok(), - "GET" => client.get(url).timeout(timeout).send().await.ok(), - "HEAD" => client.head(url).timeout(timeout).send().await.ok(), + "POST" => client.post(url).body(payload.to_owned()).send().await.ok(), + "GET" => client.get(url).send().await.ok(), + "HEAD" => client.head(url).send().await.ok(), _ => None, } } - match send_request(&client, method, url, payload, timeout).await { + match send_request(&client, method, url, payload).await { Some(response) => { let status = response.status().as_u16(); ((99..499).contains(&status), false) } - None => { - // let fallback_url = url.replace("https", "http"); - // ping_grpc(&fallback_url).await - (ping_grpc(&url).await, true) - } + None => (ping_grpc(&url).await, true), } } @@ -123,10 +118,7 @@ pub async fn ping_grpc(addr: &str) -> bool { let endpoint = endpoint.timeout(Duration::from_secs(2)); match tokio::time::timeout(Duration::from_secs(3), endpoint.connect()).await { - Ok(Ok(_channel)) => { - // println!("{:?} ==> {:?} ==> {}", endpoint, _channel, addr); - true - } + Ok(Ok(_channel)) => true, _ => false, } } else { @@ -134,15 +126,24 @@ pub async fn ping_grpc(addr: &str) -> bool { } } -async fn detect_tls(ip: &str, port: &u16) -> (bool, Option) { - let url = format!("https://{}:{}", ip, port); - // let url = format!("{}:{}", ip, port); - let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap(); - match client.get(&url).send().await { - Ok(response) => (true, Some(response.version())), - Err(e) => { - if e.is_builder() || e.is_connect() || e.to_string().contains("tls") { - (false, None) +async fn detect_tls(ip: &str, port: &u16, client: &Client) -> (bool, Option) { + let https_url = format!("https://{}:{}", ip, port); + match client.get(&https_url).send().await { + Ok(response) => { + // println!("{} => {:?} (HTTPS)", https_url, response.version()); + return (true, Some(response.version())); + } + _ => {} + } + let http_url = format!("http://{}:{}", ip, port); + match client.get(&http_url).send().await { + Ok(response) => { + // println!("{} => {:?} (HTTP)", http_url, response.version()); + (false, Some(response.version())) + } + Err(_) => { + if ping_grpc(&http_url).await { + (false, Some(Version::HTTP_2)) } else { (false, None) } diff --git a/src/utils/parceyaml.rs b/src/utils/parceyaml.rs index cfb6f64..4b0b300 100644 --- a/src/utils/parceyaml.rs +++ b/src/utils/parceyaml.rs @@ -1,139 +1,136 @@ use crate::utils::structs::*; use dashmap::DashMap; use log::{error, info, warn}; -use serde_yaml::Error; use std::collections::HashMap; use std::fs; use std::sync::atomic::AtomicUsize; pub fn load_configuration(d: &str, kind: &str) -> Option { - let mut toreturn: Configuration = Configuration { - upstreams: Default::default(), - headers: Default::default(), - consul: None, - typecfg: "".to_string(), - extraparams: Extraparams { - sticky_sessions: false, - to_https: None, - authentication: DashMap::new(), - rate_limit: None, + let yaml_data = match kind { + "filepath" => match fs::read_to_string(d) { + Ok(data) => { + info!("Reading upstreams from {}", d); + data + } + Err(e) => { + error!("Reading: {}: {:?}", d, e); + warn!("Running with empty upstreams list, update it via API"); + return None; + } }, - }; - toreturn.upstreams = UpstreamsDashMap::new(); - toreturn.headers = Headers::new(); - - let mut yaml_data = d.to_string(); - match kind { - "filepath" => { - let _ = match fs::read_to_string(d) { - Ok(data) => { - info!("Reading upstreams from {}", d); - yaml_data = data - } - Err(e) => { - error!("Reading: {}: {:?}", d, e.to_string()); - warn!("Running with empty upstreams list, update it via API"); - return None; - } - }; - } "content" => { info!("Reading upstreams from API post body"); + d.to_string() } - _ => error!("Mismatched parameter, only filepath|content is allowed "), - } + _ => { + error!("Mismatched parameter, only filepath|content is allowed"); + return None; + } + }; - let p: Result = serde_yaml::from_str(&yaml_data); - match p { - Ok(parsed) => { - let global_headers = DashMap::new(); - let mut hl = Vec::new(); - if let Some(headers) = &parsed.headers { - for header in headers.iter() { - if let Some((key, val)) = header.split_once(':') { - hl.push((key.to_string(), val.to_string())); - } - } - global_headers.insert("/".to_string(), hl); - toreturn.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); - toreturn.extraparams.sticky_sessions = parsed.sticky_sessions; - toreturn.extraparams.to_https = parsed.to_https; - toreturn.extraparams.rate_limit = parsed.rate_limit; - } - if let Some(auth) = &parsed.authorization { - let name = auth.get("type").unwrap().to_string(); - let creds = auth.get("creds").unwrap().to_string(); - let val: Vec = vec![name, creds]; - toreturn.extraparams.authentication.insert("authorization".to_string(), val); - } else { - toreturn.extraparams.authentication = DashMap::new(); - } - match parsed.provider.as_str() { - "file" => { - toreturn.typecfg = "file".to_string(); - if let Some(upstream) = parsed.upstreams { - for (hostname, host_config) in upstream { - let path_map = DashMap::new(); - let header_list = DashMap::new(); - for (path, path_config) in host_config.paths { - let mut server_list = Vec::new(); - let mut hl = Vec::new(); - if let Some(headers) = &path_config.headers { - for header in headers.iter().by_ref() { - if let Some((key, val)) = header.split_once(':') { - hl.push((key.to_string(), val.to_string())); - } - } - } - header_list.insert(path.clone(), hl); - for server in path_config.servers { - if let Some((ip, port_str)) = server.split_once(':') { - if let Ok(port) = port_str.parse::() { - let to_https = path_config.to_https.unwrap_or(false); - let sl = InnerMap { - address: ip.to_string(), - port: port, - is_ssl: true, - is_http2: false, - to_https: to_https, - }; - server_list.push(sl); - } - } - } - path_map.insert(path, (server_list, AtomicUsize::new(0))); - } - toreturn.headers.insert(hostname.clone(), header_list); - toreturn.upstreams.insert(hostname, path_map); - } - } - Some(toreturn) - } - "consul" => { - toreturn.typecfg = "consul".to_string(); - let consul = parsed.consul; - match consul { - Some(consul) => { - toreturn.consul = Some(consul); - Some(toreturn) - } - None => None, - } - } - "kubernetes" => None, - _ => { - warn!("Unknown provider {}", parsed.provider); - None - } - } - } + let parsed: Config = match serde_yaml::from_str(&yaml_data) { + Ok(cfg) => cfg, Err(e) => { error!("Failed to parse upstreams file: {}", e); + return None; + } + }; + + let mut toreturn = Configuration::default(); + + populate_headers_and_auth(&mut toreturn, &parsed); + toreturn.typecfg = parsed.provider.clone(); + + match parsed.provider.as_str() { + "file" => { + populate_file_upstreams(&mut toreturn, &parsed); + Some(toreturn) + } + "consul" => { + toreturn.consul = parsed.consul; + if toreturn.consul.is_some() { + Some(toreturn) + } else { + None + } + } + "kubernetes" => None, + _ => { + warn!("Unknown provider {}", parsed.provider); None } } } +fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) { + if let Some(headers) = &parsed.headers { + let mut hl = Vec::new(); + for header in headers { + if let Some((key, val)) = header.split_once(':') { + hl.push((key.trim().to_string(), val.trim().to_string())); + } + } + + let global_headers = DashMap::new(); + global_headers.insert("/".to_string(), hl); + config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); + } + + config.extraparams.sticky_sessions = parsed.sticky_sessions; + config.extraparams.to_https = parsed.to_https; + config.extraparams.rate_limit = parsed.rate_limit; + + if let Some(auth) = &parsed.authorization { + let name = auth.get("type").unwrap_or(&"".to_string()).to_string(); + let creds = auth.get("creds").unwrap_or(&"".to_string()).to_string(); + config.extraparams.authentication.insert("authorization".to_string(), vec![name, creds]); + } else { + config.extraparams.authentication = DashMap::new(); + } +} + +fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { + if let Some(upstreams) = &parsed.upstreams { + for (hostname, host_config) in upstreams { + let path_map = DashMap::new(); + let header_list = DashMap::new(); + + for (path, path_config) in &host_config.paths { + let mut server_list = Vec::new(); + let mut hl = Vec::new(); + + if let Some(headers) = &path_config.headers { + for header in headers { + if let Some((key, val)) = header.split_once(':') { + hl.push((key.trim().to_string(), val.trim().to_string())); + } + } + } + header_list.insert(path.clone(), hl); + + for server in &path_config.servers { + if let Some((ip, port_str)) = server.split_once(':') { + if let Ok(port) = port_str.parse::() { + server_list.push(InnerMap { + address: ip.trim().to_string(), + port, + is_ssl: true, + is_http2: false, + to_https: path_config.to_https.unwrap_or(false), + }); + } + } + } + + path_map.insert(path.clone(), (server_list, AtomicUsize::new(0))); + } + + config.headers.insert(hostname.clone(), header_list); + config.upstreams.insert(hostname.clone(), path_map); + } + } +} + pub fn parce_main_config(path: &str) -> AppConfig { info!("Parsing configuration"); let data = fs::read_to_string(path).unwrap(); diff --git a/src/utils/structs.rs b/src/utils/structs.rs index 7dfd286..6a251bc 100644 --- a/src/utils/structs.rs +++ b/src/utils/structs.rs @@ -14,7 +14,7 @@ pub struct ServiceMapping { pub real: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct Extraparams { pub sticky_sessions: bool, pub to_https: Option, @@ -60,7 +60,7 @@ pub struct PathConfig { pub headers: Option>, pub rate_limit: Option, } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Configuration { pub upstreams: UpstreamsDashMap, pub headers: Headers, @@ -102,49 +102,11 @@ pub struct InnerMap { impl InnerMap { pub fn new() -> Self { Self { - address: String::new(), - port: 0, - is_ssl: false, - is_http2: false, - to_https: false, + address: Default::default(), + port: Default::default(), + is_ssl: Default::default(), + is_http2: Default::default(), + to_https: Default::default(), } } } - -/* -impl InnerMap { - pub fn new(address: String, port: u16) -> Self { - Self { - address, - port, - is_ssl: false, // Default values - is_http2: false, - to_https: false, - } - } - pub fn address(&self) -> &str { - &self.address - } - - pub fn port(&self) -> u16 { - self.port - } - - // Setters with validation - pub fn with_ssl(mut self, ssl: bool) -> Result { - self.is_ssl = ssl; - Ok(self) - } - - pub fn with_http2(mut self, http2: bool) -> Result { - self.is_http2 = http2; - Ok(self) - } - - pub fn with_to_https(mut self, to_https: bool) -> Result { - self.to_https = to_https; - Ok(self) - } -} - -*/ diff --git a/src/utils/tools.rs b/src/utils/tools.rs index f4c5f77..d7a2e1d 100644 --- a/src/utils/tools.rs +++ b/src/utils/tools.rs @@ -21,17 +21,13 @@ pub fn print_upstreams(upstreams: &UpstreamsDashMap) { for path_entry in host_entry.value().iter() { let path = path_entry.key(); - println!(" Path: {}", path); + println!(" Path: {}", path); for f in path_entry.value().0.clone() { println!( - " ===> IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}", + " IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}", f.address, f.port, f.is_ssl, f.is_http2, f.to_https ); } - // { address: "127.0.0.4", port: 8000, is_ssl: false, is_http2: false, to_https: false } - // for (ip, port, ssl, vers, to_https) in path_entry.value().0.clone() { - // println!(" ===> IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}", ip, port, ssl, vers, to_https); - // } } } } diff --git a/src/web/proxyhttp.rs b/src/web/proxyhttp.rs index 4a70824..4210d39 100644 --- a/src/web/proxyhttp.rs +++ b/src/web/proxyhttp.rs @@ -18,6 +18,8 @@ use std::sync::Arc; use std::time::Duration; use tokio::time::Instant; +static RATE_LIMITER: Lazy = Lazy::new(|| Rate::new(Duration::from_secs(1))); + #[derive(Clone)] pub struct LB { pub ump_upst: Arc, @@ -35,10 +37,6 @@ pub struct Context { start_time: Instant, hostname: Option, } -// Rate limiter -static RATE_LIMITER: Lazy = Lazy::new(|| Rate::new(Duration::from_secs(1))); -// max request per second per client -// static MAX_REQ_PER_SEC: isize = 1; #[async_trait] impl ProxyHttp for LB { @@ -64,6 +62,7 @@ impl ProxyHttp for LB { let hostname = return_header_host(&session); _ctx.hostname = hostname.clone(); + if let Some(rate) = self.extraparams.load().rate_limit { match hostname { None => return Ok(false),