diff --git a/src/utils/auth.rs b/src/utils/auth.rs index 5e89c97..cb1ab7b 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -1,4 +1,5 @@ -use crate::utils::jwt::check_jwt; +use crate::utils::jwt::{check_jwt, JWT_TOKEN}; +use crate::utils::structs::InnerAuth; use axum::http::StatusCode; use base64::engine::general_purpose::STANDARD; use base64::Engine; @@ -8,7 +9,7 @@ use pingora_core::upstreams::peer::HttpPeer; use pingora_http::ResponseHeader; use pingora_proxy::Session; use std::collections::HashMap; -use std::sync::{Arc, LazyLock}; +use std::sync::LazyLock; use subtle::ConstantTimeEq; use urlencoding::decode; @@ -18,7 +19,7 @@ trait AuthValidator { } struct BasicAuth<'a>(&'a str); struct ApiKeyAuth<'a>(&'a str); -struct JwtAuth<'a>(&'a str); +struct JwtAuth(); struct ForwardAuth<'a>(&'a str); pub static AUTH_CONNECTOR: LazyLock = LazyLock::new(|| Connector::new(None)); @@ -175,18 +176,19 @@ impl AuthValidator for ApiKeyAuth<'_> { } #[async_trait::async_trait] -impl AuthValidator for JwtAuth<'_> { +impl AuthValidator for JwtAuth { async fn validate(&self, session: &mut Session) -> bool { - println!("{:?}", self.0); - let jwtsecret = self.0; - if let Some(tok) = get_query_param(session, "araleztoken") { - return check_jwt(tok.as_str(), jwtsecret); - } - if let Some(auth_header) = session.get_header("authorization") { - if let Ok(header_str) = auth_header.to_str() { - if let Some((scheme, token)) = header_str.split_once(' ') { - if scheme.eq_ignore_ascii_case("bearer") { - return check_jwt(token, jwtsecret); + if let Some(jwtsecret) = JWT_TOKEN.clone() { + // println!(" ===> {:?}", jwtsecret); + if let Some(tok) = get_query_param(session, "araleztoken") { + return check_jwt(tok.as_str(), jwtsecret.as_ref()); + } + if let Some(auth_header) = session.get_header("authorization") { + if let Ok(header_str) = auth_header.to_str() { + if let Some((scheme, token)) = header_str.split_once(' ') { + if scheme.eq_ignore_ascii_case("bearer") { + return check_jwt(token, jwtsecret.as_ref()); + } } } } @@ -195,14 +197,14 @@ impl AuthValidator for JwtAuth<'_> { } } -pub async fn authenticate(auth_type: &Arc, credentials: &Arc, session: &mut Session) -> bool { - match &**auth_type { - "basic" => BasicAuth(credentials).validate(session).await, - "apikey" => ApiKeyAuth(credentials).validate(session).await, - "jwt" => JwtAuth(credentials).validate(session).await, - "forward" => ForwardAuth(credentials).validate(session).await, +pub async fn authenticate(auth: &InnerAuth, session: &mut Session) -> bool { + match &*auth.auth_type { + "basic" => BasicAuth(&*auth.auth_cred).validate(session).await, + "apikey" => ApiKeyAuth(&*auth.auth_cred).validate(session).await, + "jwt" => JwtAuth().validate(session).await, + "forward" => ForwardAuth(&*auth.auth_cred).validate(session).await, _ => { - log::warn!("Unsupported authentication mechanism : {}", auth_type); + log::warn!("Unsupported authentication mechanism : {}", &*auth.auth_type); false } } diff --git a/src/utils/jwt.rs b/src/utils/jwt.rs index 55d7ea3..e1b003b 100644 --- a/src/utils/jwt.rs +++ b/src/utils/jwt.rs @@ -4,8 +4,9 @@ use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; use moka::sync::Cache; use moka::Expiry; use serde::{Deserialize, Serialize}; +use std::env; use std::hash::{Hash, Hasher}; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant, SystemTime}; #[derive(Debug, Serialize, Deserialize)] @@ -23,6 +24,11 @@ struct Expired { static JWT_VALIDATION: LazyLock = LazyLock::new(|| Validation::new(Algorithm::HS256)); +pub static JWT_TOKEN: LazyLock>> = LazyLock::new(|| match env::var("JWT_KEY") { + Ok(key) if !key.is_empty() => Some(Arc::from(key.as_str())), + _ => None, +}); + static JWT_CACHE: LazyLock> = LazyLock::new(|| Cache::builder().max_capacity(100_000).expire_after(JwtExpiry).build()); struct JwtExpiry; impl Expiry for JwtExpiry { diff --git a/src/utils/parceyaml.rs b/src/utils/parceyaml.rs index 294732b..280e9e1 100644 --- a/src/utils/parceyaml.rs +++ b/src/utils/parceyaml.rs @@ -85,6 +85,7 @@ pub async fn load_configuration(d: &str, kind: &str) -> (Option, let mut parsed: Config = match serde_yml::from_str(&yaml_data) { Ok(cfg) => cfg, Err(e) => { + println!("================================================"); error!("Failed to parse upstreams file: {}", e); return (None, e.to_string()); } @@ -136,6 +137,7 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) } } } + let global_headers: DashMap, Vec<(String, Arc)>> = DashMap::new(); global_headers.insert(Arc::from("/"), ch); config.client_headers.insert(Arc::from("GLOBAL_CLIENT_HEADERS"), global_headers); @@ -162,7 +164,7 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) if let Some(pa) = &parsed.authorization { let y: InnerAuth = InnerAuth { auth_type: Arc::from(pa.auth_type.clone()), - auth_cred: Arc::from(pa.auth_cred.clone()), + auth_cred: Arc::from(pa.auth_cred.clone().unwrap_or_default()), }; config.extraparams.authentication = Some(Arc::from(y)); } @@ -191,7 +193,7 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { if let Some(pa) = &path_config.authorization { let y: InnerAuth = InnerAuth { auth_type: Arc::from(pa.auth_type.clone()), - auth_cred: Arc::from(pa.auth_cred.clone()), + auth_cred: Arc::from(pa.auth_cred.clone().unwrap_or_default()), }; path_auth = Some(Arc::from(y)); } diff --git a/src/utils/structs.rs b/src/utils/structs.rs index 6e64a83..ff2c506 100644 --- a/src/utils/structs.rs +++ b/src/utils/structs.rs @@ -77,7 +77,7 @@ pub struct Auth { #[serde(rename = "type")] pub auth_type: String, #[serde(rename = "data")] - pub auth_cred: String, + pub auth_cred: Option, } #[derive(Debug, Default, Serialize, Deserialize)] pub struct PathConfig { diff --git a/src/web/proxyhttp.rs b/src/web/proxyhttp.rs index ee9b40d..535c14b 100644 --- a/src/web/proxyhttp.rs +++ b/src/web/proxyhttp.rs @@ -85,7 +85,7 @@ impl ProxyHttp for LB { None => return Ok(false), Some(ref innermap) => { if let Some(auth) = _ctx.extraparams.authentication.as_ref().or(innermap.authorization.as_ref()) { - if !authenticate(&auth.auth_type, &auth.auth_cred, session).await { + if !authenticate(&auth, session).await { let _ = session.respond_error(401).await; warn!("Forbidden: {:?}, {}", session.client_addr(), session.req_header().uri.path()); return Ok(true);