use crate::utils::healthcheck; use crate::utils::state::{is_first_run, mark_not_first_run}; use crate::utils::structs::*; use crate::utils::tools::{clone_dashmap, clone_dashmap_into, print_upstreams}; use dashmap::DashMap; use log::{error, info, warn}; use std::collections::HashMap; use std::path::Path; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use std::{env, fs}; pub async fn load_configuration(d: &str, kind: &str) -> (Option, String) { let mut conf_files = Vec::new(); let yaml_data = match kind { "filepath" => match fs::read_to_string(d) { Ok(data) => { let mut confdir = Path::new(d).parent().unwrap().to_path_buf(); confdir.push("conf.d"); if let Ok(entries) = fs::read_dir(&confdir) { let mut paths: Vec<_> = entries .flatten() .map(|e| e.path()) .filter(|p| p.extension().and_then(|e| e.to_str()) == Some("yaml")) .collect(); paths.sort(); for path in paths { let content = fs::read_to_string(&path); match content { Ok(content) => { conf_files.push(content); } Err(e) => { error!("Reading: {}: {:?}", path.display(), e) } }; } } info!("Reading upstreams from {}", d); data } Err(e) => { error!("Reading: {}: {:?}", d, e); warn!("Running with empty upstreams list, update it via API"); return (None, e.to_string()); } }, "content" => { info!("Reading upstreams from API post body"); d.to_string() } _ => { error!("Mismatched parameter, only filepath|content is allowed"); return (None, "Mismatched parameter, only filepath|content is allowed".to_string()); } }; let mut parsed: Config = match serde_yml::from_str(&yaml_data) { Ok(cfg) => cfg, Err(e) => { error!("Failed to parse upstreams file: {}", e); return (None, e.to_string()); } }; if let Some(ref mut upstreams) = parsed.upstreams { for uconf in conf_files { let p: HashMap = match serde_yml::from_str(&uconf) { Ok(ucfg) => ucfg, Err(e) => { error!("Failed to parse upstreams file: {}", e); return (None, e.to_string()); } }; upstreams.extend(p); } } let mut toreturn = Configuration::default(); populate_headers_and_auth(&mut toreturn, &parsed).await; toreturn.typecfg = parsed.provider.clone(); match parsed.provider.as_str() { "file" => { populate_file_upstreams(&mut toreturn, &parsed).await; (Some(toreturn), "Ok".to_string()) } "consul" => { toreturn.consul = parsed.consul; (toreturn.consul.is_some().then_some(toreturn), "Ok".to_string()) } "kubernetes" => { toreturn.kubernetes = parsed.kubernetes; (toreturn.kubernetes.is_some().then_some(toreturn), "Ok".to_string()) } _ => { warn!("Unknown provider {}", parsed.provider); (None, "Unknown provider".to_string()) } } } async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) { let mut ch: Vec<(String, Arc)> = Vec::new(); if let Some(headers) = &parsed.client_headers { for header in headers { if let Some((key, val)) = header.split_once(':') { ch.push((key.to_string(), Arc::from(val))); } } } let global_headers: DashMap, Vec<(String, Arc)>> = DashMap::new(); global_headers.insert(Arc::from("/"), ch); config.client_headers.insert(Arc::from("GLOBAL_CLIENT_HEADERS"), global_headers); let mut sh: Vec<(String, Arc)> = Vec::new(); if let Some(headers) = &parsed.server_headers { for header in headers { if let Some((key, val)) = header.split_once(':') { sh.push((key.to_string(), Arc::from(val.trim()))); } } } let server_global_headers: DashMap, Vec<(String, Arc)>> = DashMap::new(); server_global_headers.insert(Arc::from("/"), sh); config.server_headers.insert(Arc::from("GLOBAL_SERVER_HEADERS"), server_global_headers); config.extraparams.to_https = parsed.to_https; config.extraparams.sticky_sessions = parsed.sticky_sessions; config.extraparams.rate_limit = parsed.rate_limit; if let Some(rate) = &parsed.rate_limit { info!("Applied Global Rate Limit : {} request per second", rate); } if let Some(pa) = &parsed.authorization { let y: InnerAuth = InnerAuth { auth_type: Arc::from(pa.auth_type.clone()), auth_cred: Arc::from(pa.auth_cred.clone()), }; config.extraparams.authentication = Some(Arc::from(y)); } } async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { let imtdashmap = UpstreamsDashMap::new(); if let Some(upstreams) = &parsed.upstreams { for (hostname, host_config) in upstreams { let path_map = DashMap::new(); let client_header_list = DashMap::new(); let server_header_list = DashMap::new(); for (path, path_config) in &host_config.paths { if let Some(rate) = &path_config.rate_limit { info!("Applied Rate Limit for {} : {} request per second", hostname, rate); } let mut hl: Vec<(String, Arc)> = Vec::new(); let mut sl: Vec<(String, Arc)> = Vec::new(); build_headers(&path_config.client_headers, config, &mut hl); build_headers(&path_config.server_headers, config, &mut sl); client_header_list.insert(Arc::from(path.as_str()), hl); server_header_list.insert(Arc::from(path.as_str()), sl); let mut server_list = Vec::new(); for server in &path_config.servers { let mut path_auth: Option> = None; if let Some(pa) = &path_config.authorization { let y: InnerAuth = InnerAuth { auth_type: Arc::from(pa.auth_type.clone()), auth_cred: Arc::from(pa.auth_cred.clone()), }; path_auth = Some(Arc::from(y)); } let redirect_link = path_config.redirect_to.as_ref().map(|www| Arc::from(www.as_str())); if let Some((ip, port_str)) = server.split_once(':') { if let Ok(port) = port_str.parse::() { server_list.push(Arc::from(InnerMap { address: Arc::from(ip), port, is_ssl: true, is_http2: false, to_https: path_config.to_https.unwrap_or(false), rate_limit: path_config.rate_limit, healthcheck: path_config.healthcheck, redirect_to: redirect_link, authorization: path_auth, })); } } } path_map.insert(Arc::from(path.clone()), (server_list, AtomicUsize::new(0))); } config.client_headers.insert(Arc::from(hostname.clone()), client_header_list); config.server_headers.insert(Arc::from(hostname.clone()), server_header_list); imtdashmap.insert(Arc::from(hostname.clone()), path_map); } if is_first_run() { clone_dashmap_into(&imtdashmap, &config.upstreams); mark_not_first_run(); } else { let y = clone_dashmap(&imtdashmap); let r = healthcheck::initiate_upstreams(y).await; clone_dashmap_into(&r, &config.upstreams); } info!("Upstream Config:"); print_upstreams(&config.upstreams); } } pub fn parce_main_config(path: &str) -> AppConfig { let data = fs::read_to_string(path).unwrap(); let reply = DashMap::new(); let cfg: HashMap = serde_yml::from_str(&*data).expect("Failed to parse main config file"); let mut cfo: AppConfig = serde_yml::from_str(&*data).expect("Failed to parse main config file"); log_builder(&cfo); cfo.hc_method = cfo.hc_method.to_uppercase(); for (k, v) in cfg { reply.insert(k.to_string(), v.to_string()); } if let Some((ip, port_str)) = cfo.config_address.split_once(':') { if let Ok(port) = port_str.parse::() { cfo.local_server = Option::from((ip.to_string(), port)); } } // if let Some(tlsport_cfg) = cfo.proxy_address_tls.clone() { // if let Some((_, port_str)) = tlsport_cfg.split_once(':') { // if let Ok(port) = port_str.parse::() { // cfo.proxy_port_tls = Some(port); // } // } // }; if let Some(tlsport_cfg) = cfo.proxy_address_tls.clone() { if let Some((_, port_str)) = tlsport_cfg.split_once(':') { cfo.proxy_port_tls = Some(port_str.to_string()); } }; if let Some((_, port_str)) = cfo.proxy_address_http.split_once(':') { cfo.proxy_port = Some(port_str.to_string()); } cfo.proxy_tls_grade = parce_tls_grades(cfo.proxy_tls_grade.clone()); cfo } fn parce_tls_grades(what: Option) -> Option { match what { Some(g) => match g.to_ascii_lowercase().as_str() { "high" => { // info!("TLS grade set to: [ HIGH ]"); Some("high".to_string()) } "medium" => { // info!("TLS grade set to: [ MEDIUM ]"); Some("medium".to_string()) } "unsafe" => { // info!("TLS grade set to: [ UNSAFE ]"); Some("unsafe".to_string()) } _ => { warn!("Error parsing TLS grade, defaulting to: `medium`"); Some("medium".to_string()) } }, None => { warn!("TLS grade not set, defaulting to: medium"); Some("b".to_string()) } } } fn log_builder(conf: &AppConfig) { let log_level = conf.log_level.clone(); unsafe { match log_level.as_str() { "info" => env::set_var("RUST_LOG", "info"), "error" => env::set_var("RUST_LOG", "error"), "warn" => env::set_var("RUST_LOG", "warn"), "debug" => env::set_var("RUST_LOG", "debug"), "trace" => env::set_var("RUST_LOG", "trace"), "off" => env::set_var("RUST_LOG", "off"), _ => { println!("Error reading log level, defaulting to: INFO"); env::set_var("RUST_LOG", "info") } } } env_logger::builder().init(); } pub fn build_headers(path_config: &Option>, _config: &Configuration, hl: &mut Vec<(String, Arc)>) { if let Some(headers) = &path_config { for header in headers { if let Some((key, val)) = header.split_once(':') { hl.push((key.trim().to_string(), Arc::from(val.trim()))); } } } }