use crate::utils::jwt::{check_jwt, JWT_TOKEN}; use crate::utils::structs::InnerAuth; use axum::http::StatusCode; use base64::engine::general_purpose::STANDARD; use base64::Engine; use pingora::http::RequestHeader; use pingora_core::connectors::http::Connector; use pingora_core::upstreams::peer::HttpPeer; use pingora_http::ResponseHeader; use pingora_proxy::Session; use std::collections::HashMap; use std::sync::LazyLock; use subtle::ConstantTimeEq; use urlencoding::decode; #[async_trait::async_trait] trait AuthValidator { async fn validate(&self, session: &mut Session) -> bool; } struct BasicAuth<'a>(&'a str); struct ApiKeyAuth<'a>(&'a str); struct JwtAuth(); struct ForwardAuth<'a>(&'a str); pub static AUTH_CONNECTOR: LazyLock = LazyLock::new(|| Connector::new(None)); #[async_trait::async_trait] impl AuthValidator for ForwardAuth<'_> { async fn validate(&self, session: &mut Session) -> bool { let method = match session.req_header().method.as_str() { "HEAD" => "HEAD", _ => "GET", }; let auth_url = self.0; let (plain, tls) = if let Some(p) = auth_url.strip_prefix("http://") { (p, false) } else if let Some(p) = auth_url.strip_prefix("https://") { (p, true) } else { return false; }; let (addr, uri) = if let Some(pos) = plain.find('/') { (&plain[..pos], &plain[pos..]) } else { (plain, "/") }; let hp = match split_host_port(addr, tls) { Some(hp) => hp, None => return false, }; let peer = HttpPeer::new((hp.0, hp.1), tls, hp.0.to_string()); let (mut http_session, _) = match AUTH_CONNECTOR.get_http_session(&peer).await { Ok(s) => s, Err(e) => { log::warn!("ForwardAuth: connect failed: {}", e); return false; } }; let mut auth_req = match RequestHeader::build(method, uri.as_bytes(), None) { Ok(r) => r, Err(e) => { log::warn!("ForwardAuth: failed to build request: {}", e); return false; } }; // auth_req.headers = session.req_header().headers.clone(); auth_req.insert_header("Host", addr).ok(); auth_req.insert_header("X-Forwarded-Uri", uri).ok(); auth_req.insert_header("X-Forwarded-Method", session.req_header().method.as_str()).ok(); if let Some(auth) = session.req_header().headers.get("authorization") { auth_req.insert_header("Authorization", auth.clone()).ok(); } if let Some(cookie) = session.req_header().headers.get("cookie") { auth_req.insert_header("Cookie", cookie.clone()).ok(); } if tls { auth_req.insert_header("X-Forwarded-Proto", "https").ok(); } else { auth_req.insert_header("X-Forwarded-Proto", "http").ok(); } if let Err(e) = http_session.write_request_header(Box::new(auth_req)).await { log::warn!("ForwardAuth: write failed: {}", e); return false; } let status = match http_session.read_response_header().await { Ok(_) => http_session.response_header().map(|r| r.status.as_u16()).unwrap_or(500), Err(e) => { log::warn!("ForwardAuth: read failed: {}", e); return false; } }; let auth_headers_to_forward: Vec<(String, String)> = if let Some(resp_header) = http_session.response_header() { resp_header .headers .iter() .filter_map(|(name, value)| { let name_str = name.as_str(); if name_str.starts_with("x-") || name_str.starts_with("remote-") || name_str.starts_with("locat") { value.to_str().ok().map(|v| (name_str.to_string(), v.to_string())) } else { None } }) .collect() } else { Vec::new() }; AUTH_CONNECTOR.release_http_session(http_session, &peer, None).await; if (200..300).contains(&status) { for (name, value) in auth_headers_to_forward { session.req_header_mut().insert_header(name, value).ok(); } true } else if status == 302 || status == 301 { let resp = ResponseHeader::build(StatusCode::MOVED_PERMANENTLY, None); match resp { Ok(mut r) => { for (name, value) in auth_headers_to_forward { r.insert_header(name, value).ok(); } let _ = r.insert_header("Content-Length", "0"); let _ = session.write_response_header(Box::new(r), true).await; true } Err(_) => return false, } } else { false } } } #[async_trait::async_trait] impl AuthValidator for BasicAuth<'_> { async fn validate(&self, session: &mut Session) -> bool { if let Some(header) = session.get_header("authorization") { if let Ok(h) = header.to_str() { if let Some((_, val)) = h.split_once(' ') { if let Ok(decoded) = STANDARD.decode(val) { if decoded.as_slice().ct_eq(self.0.as_bytes()).into() { return true; } } } } } false } } #[async_trait::async_trait] impl AuthValidator for ApiKeyAuth<'_> { async fn validate(&self, session: &mut Session) -> bool { if let Some(header) = session.get_header("x-api-key") { if let Ok(h) = header.to_str() { return h.as_bytes().ct_eq(self.0.as_bytes()).into(); } } false } } #[async_trait::async_trait] impl AuthValidator for JwtAuth { async fn validate(&self, session: &mut Session) -> bool { if let Some(jwtsecret) = JWT_TOKEN.clone() { if let Some(tok) = get_query_param(session, "araleztoken") { return check_jwt(tok.as_str(), jwtsecret.as_ref()); } if let Some(auth_header) = session.get_header("authorization") { if let Ok(header_str) = auth_header.to_str() { if let Some((scheme, token)) = header_str.split_once(' ') { if scheme.eq_ignore_ascii_case("bearer") { return check_jwt(token, jwtsecret.as_ref()); } } } } } false } } pub async fn authenticate(auth: &InnerAuth, session: &mut Session) -> bool { match &*auth.auth_type { "basic" => BasicAuth(&*auth.auth_cred).validate(session).await, "apikey" => ApiKeyAuth(&*auth.auth_cred).validate(session).await, "jwt" => JwtAuth().validate(session).await, "forward" => ForwardAuth(&*auth.auth_cred).validate(session).await, _ => { log::warn!("Unsupported authentication mechanism : {}", &*auth.auth_type); false } } } pub fn get_query_param(session: &mut Session, key: &str) -> Option { let query = session.req_header().uri.query()?; let params: HashMap<_, _> = query .split('&') .filter_map(|pair| { let mut parts = pair.splitn(2, '='); let k = parts.next()?; let v = parts.next().unwrap_or(""); // Some params might have no value Some((k, v)) }) .collect(); params.get(key).and_then(|v| decode(v).ok()).map(|s| s.to_string()) } #[allow(clippy::needless_return)] fn split_host_port(addr: &str, tls: bool) -> Option<(&str, u16, bool, &str)> { match addr.split_once(':') { Some((h, p)) => match p.parse::() { Ok(port) => return Some((h, port, tls, h)), Err(_) => { log::warn!("ForwardAuth: invalid port in {}", addr); return None; } }, None => { if tls { return Some((addr, 443u16, tls, addr)); } else { return Some((addr, 80u16, tls, addr)); } } }; }