diff --git a/src/utils/auth.rs b/src/utils/auth.rs index afdfe7a..c2f6cad 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -57,22 +57,23 @@ fn validate(auth: &dyn AuthValidator, session: &Session) -> bool { auth.validate(session) } -pub fn authenticate(c: &[Arc], session: &Session) -> bool { - match &*c[0] { +// pub fn authenticate(c: &[Arc], session: &Session) -> bool { +pub fn authenticate(auth_type: &Arc, credentials: &Arc, session: &Session) -> bool { + match &*auth_type.clone() { "basic" => { - let auth = BasicAuth(&*c[1]); + let auth = BasicAuth(&*credentials.clone()); validate(&auth, session) } "apikey" => { - let auth = ApiKeyAuth(&*c[1]); + let auth = ApiKeyAuth(&*credentials.clone()); validate(&auth, session) } "jwt" => { - let auth = JwtAuth(&*c[1]); + let auth = JwtAuth(&*credentials.clone()); validate(&auth, session) } _ => { - println!("Unsupported authentication mechanism : {}", c[0]); + println!("Unsupported authentication mechanism : {}", auth_type); false } } diff --git a/src/utils/healthcheck.rs b/src/utils/healthcheck.rs index 15ddb94..2761730 100644 --- a/src/utils/healthcheck.rs +++ b/src/utils/healthcheck.rs @@ -70,6 +70,7 @@ async fn build_upstreams(fullist: &UpstreamsDashMap, method: &str, client: &Clie to_https: upstream.to_https, rate_limit: upstream.rate_limit, healthcheck: upstream.healthcheck, + authorization: upstream.authorization.clone(), }; if scheme.healthcheck.unwrap_or(true) { diff --git a/src/utils/httpclient.rs b/src/utils/httpclient.rs index f681c76..519d998 100644 --- a/src/utils/httpclient.rs +++ b/src/utils/httpclient.rs @@ -35,6 +35,7 @@ pub async fn for_consul(url: String, token: Option, conf: &ServiceMappin to_https: conf.to_https.unwrap_or(false), rate_limit: conf.rate_limit, healthcheck: None, + authorization: None, }); inner_vec.push(to_add); } @@ -68,6 +69,7 @@ pub async fn for_kuber(url: &str, token: &str, conf: &ServiceMapping) -> Option< to_https: conf.to_https.unwrap_or(false), rate_limit: conf.rate_limit, healthcheck: None, + authorization: None, }); inner_vec.push(to_add); } diff --git a/src/utils/parceyaml.rs b/src/utils/parceyaml.rs index da962e3..ae9359d 100644 --- a/src/utils/parceyaml.rs +++ b/src/utils/parceyaml.rs @@ -34,7 +34,10 @@ pub async fn load_configuration(d: &str, kind: &str) -> (Option, }; let parsed: Config = match serde_yaml::from_str(&yaml_data) { - Ok(cfg) => cfg, + Ok(cfg) => { + // println!("{:#?}", cfg); + cfg + } Err(e) => { error!("Failed to parse upstreams file: {}", e); return (None, e.to_string()); @@ -97,6 +100,7 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) info!("Applied Global Rate Limit : {} request per second", rate); } + // ======================================================================================== // if let Some(auth) = &parsed.authorization { let name = auth.get("type").unwrap_or(&"".to_string()).to_string(); let creds = auth.get("creds").unwrap_or(&"".to_string()).to_string(); @@ -107,6 +111,7 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) } else { config.extraparams.authentication = DashMap::new(); } + // ======================================================================================== // } async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { @@ -126,9 +131,17 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { build_headers(&path_config.server_headers, config, &mut sl); client_header_list.insert(Arc::from(path.as_str()), hl); server_header_list.insert(Arc::from(path.as_str()), sl); - let mut server_list = Vec::new(); for server in &path_config.servers { + let mut path_auth: Option> = None; + if let Some(pa) = &path_config.authorization { + let y: InnerAuth = InnerAuth { + auth_type: Arc::from(pa.auth_type.clone()), + auth_cred: Arc::from(pa.auth_cred.clone()), + }; + path_auth = Some(Arc::from(y)); + } + if let Some((ip, port_str)) = server.split_once(':') { if let Ok(port) = port_str.parse::() { server_list.push(Arc::from(InnerMap { @@ -139,6 +152,7 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { to_https: path_config.to_https.unwrap_or(false), rate_limit: path_config.rate_limit, healthcheck: path_config.healthcheck, + authorization: path_auth, })); } } diff --git a/src/utils/structs.rs b/src/utils/structs.rs index 4403fad..6d83c07 100644 --- a/src/utils/structs.rs +++ b/src/utils/structs.rs @@ -14,6 +14,7 @@ pub struct Extraparams { pub to_https: Option, pub sticky_sessions: bool, pub authentication: DashMap, Vec>>, + // pub authentication: InnerAuth, pub rate_limit: Option, } @@ -70,7 +71,13 @@ pub struct HostConfig { pub paths: HashMap, pub rate_limit: Option, } - +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct Auth { + #[serde(rename = "type")] + pub auth_type: String, + #[serde(rename = "creds")] + pub auth_cred: String, +} #[derive(Debug, Default, Serialize, Deserialize)] pub struct PathConfig { pub servers: Vec, @@ -80,6 +87,8 @@ pub struct PathConfig { pub server_headers: Option>, pub rate_limit: Option, pub healthcheck: Option, + // pub authorization: Option>, + pub authorization: Option, } #[derive(Debug, Default)] pub struct Configuration { @@ -116,7 +125,13 @@ pub struct AppConfig { pub rungroup: Option, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct InnerAuth { + pub auth_type: Arc, + pub auth_cred: Arc, +} + +#[derive(Debug, Clone, PartialEq, Eq)] pub struct InnerMap { pub address: Arc, pub port: u16, @@ -125,6 +140,8 @@ pub struct InnerMap { pub to_https: bool, pub rate_limit: Option, pub healthcheck: Option, + // pub authorization: Option, Arc>>, + pub authorization: Option>, } #[allow(dead_code)] @@ -139,6 +156,7 @@ impl InnerMap { to_https: Default::default(), rate_limit: Default::default(), healthcheck: Default::default(), + authorization: Default::default(), } } } diff --git a/src/utils/tools.rs b/src/utils/tools.rs index 7f2212c..5faec8f 100644 --- a/src/utils/tools.rs +++ b/src/utils/tools.rs @@ -124,11 +124,23 @@ pub fn compare_dashmaps(map1: &UpstreamsDashMap, map2: &UpstreamsDashMap) -> boo return false; // Path exists in map1 but not in map2 }; let (vec2, _counter2) = entry2.value(); - let set1: HashSet<_> = vec1.iter().collect(); - let set2: HashSet<_> = vec2.iter().collect(); - if set1 != set2 { + + if vec1.len() != vec2.len() { return false; } + for item in vec1.iter() { + let count1 = vec1.iter().filter(|&x| x == item).count(); + let count2 = vec2.iter().filter(|&x| x == item).count(); + if count1 != count2 { + return false; + } + } + + // let set1: HashSet<_> = vec1.iter().collect(); + // let set2: HashSet<_> = vec2.iter().collect(); + // if set1 != set2 { + // return false; + // } } } true @@ -168,6 +180,7 @@ pub fn clone_idmap_into(original: &UpstreamsDashMap, cloned: &UpstreamsIdMap) { to_https: false, rate_limit: None, healthcheck: None, + authorization: None, }; cloned.insert(id, Arc::from(to_add)); diff --git a/src/web/proxyhttp.rs b/src/web/proxyhttp.rs index 756c474..93e4098 100644 --- a/src/web/proxyhttp.rs +++ b/src/web/proxyhttp.rs @@ -70,17 +70,20 @@ impl ProxyHttp for LB { } async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result { let ep = _ctx.extraparams.as_ref(); + + // ======================================================================================== // + println!("{:?}", ep); if let Some(auth) = ep.authentication.get("authorization") { - let authenticated = authenticate(auth.value(), &session); + let authenticated = authenticate(&auth.value()[0], &auth.value()[1], &session); if !authenticated { let _ = session.respond_error(401).await; warn!("Forbidden: {:?}, {}", session.client_addr(), session.req_header().uri.path()); return Ok(true); } }; + // ======================================================================================== // let hostname = return_header_host_from_upstream(session, &self.ump_upst); - _ctx.hostname = hostname; let mut backend_id = None; @@ -101,9 +104,19 @@ impl ProxyHttp for LB { None => return Ok(false), Some(host) => { let optioninnermap = self.get_host(host, session.req_header().uri.path(), backend_id); + match optioninnermap { None => return Ok(false), Some(ref innermap) => { + if let Some(auth) = &innermap.authorization { + let authenticated = authenticate(&auth.auth_type, &auth.auth_cred, &session); + if !authenticated { + let _ = session.respond_error(401).await; + warn!("Forbidden: {:?}, {}", session.client_addr(), session.req_header().uri.path()); + return Ok(true); + } + } + if let Some(rate) = innermap.rate_limit.or(ep.rate_limit) { let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip()); let curr_window_requests = RATE_LIMITER.observe(&rate_key, 1);