From 9b4ee26a2b588b2b9c05540690f8fa4b8cb0adf9 Mon Sep 17 00:00:00 2001 From: Ara Sadoyan Date: Mon, 13 Apr 2026 20:06:57 +0200 Subject: [PATCH] Working on #17 --- Cargo.lock | 1 + Cargo.toml | 2 +- src/utils/auth.rs | 126 +++++++++++++++++++++++++++++++++++++++---- src/web/proxyhttp.rs | 2 +- 4 files changed, 120 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7c388f6..8e90eaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3060,6 +3060,7 @@ dependencies = [ "base64", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2", diff --git a/Cargo.toml b/Cargo.toml index efbf91d..bb9796d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ log = "0.4.29" futures = "0.3.32" notify = "9.0.0-rc.2" axum = { version = "0.8.8" } -reqwest = { version = "0.13.2", features = ["json", "stream"] } +reqwest = { version = "0.13.2", features = ["json", "stream", "blocking"] } serde_yml = "0.0.12" rand = "0.10.0" base64 = "0.22.1" diff --git a/src/utils/auth.rs b/src/utils/auth.rs index 3a68066..1e054be 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -2,20 +2,106 @@ use crate::utils::jwt::check_jwt; use base64::engine::general_purpose::STANDARD; use base64::Engine; use pingora_proxy::Session; +// use reqwest::Client; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use subtle::ConstantTimeEq; use urlencoding::decode; +// --------------------------------- // +use pingora::http::RequestHeader; +use pingora_core::connectors::http::Connector; +use pingora_core::upstreams::peer::HttpPeer; +// --------------------------------- // + +#[async_trait::async_trait] trait AuthValidator { - fn validate(&self, session: &Session) -> bool; + async fn validate(&self, session: &Session) -> bool; } struct BasicAuth<'a>(&'a str); struct ApiKeyAuth<'a>(&'a str); struct JwtAuth<'a>(&'a str); +struct ForwardAuth<'a>(&'a str); +pub static AUTH_CONNECTOR: LazyLock = LazyLock::new(|| Connector::new(None)); + +#[async_trait::async_trait] +impl AuthValidator for ForwardAuth<'_> { + async fn validate(&self, session: &Session) -> bool { + let method = match session.req_header().method.as_str() { + "POST" => "POST", + _ => "GET", + }; + + let auth_url = self.0; + + let (plain, tls) = if let Some(p) = auth_url.strip_prefix("http://") { + (p, false) + } else if let Some(p) = auth_url.strip_prefix("https://") { + (p, true) + } else { + return false; + }; + + let (addr, uri) = if let Some(pos) = plain.find('/') { + (&plain[..pos], &plain[pos..]) + } else { + (plain, "/") + }; + + let hp = match split_host_port(addr, tls) { + Some(hp) => hp, + None => return false, + }; + + let peer = HttpPeer::new((hp.0, hp.1), tls, hp.0.to_string()); + + let (mut http_session, _) = match AUTH_CONNECTOR.get_http_session(&peer).await { + Ok(s) => s, + Err(e) => { + log::warn!("ForwardAuth: connect failed: {}", e); + return false; + } + }; + + let mut auth_req = match RequestHeader::build(method, uri.as_bytes(), None) { + Ok(r) => r, + Err(e) => { + log::warn!("ForwardAuth: failed to build request: {}", e); + return false; + } + }; + + // Filter headers ???? + // auth_req.headers = session.req_header().headers.clone(); + auth_req.insert_header("Host", addr).ok(); + auth_req.insert_header("X-Forwarded-Uri", uri).ok(); + if let Some(auth) = session.req_header().headers.get("authorization") { + auth_req.insert_header("Authorization", auth.clone()).ok(); + } + + if let Err(e) = http_session.write_request_header(Box::new(auth_req)).await { + log::warn!("ForwardAuth: write failed: {}", e); + return false; + } + + let status = match http_session.read_response_header().await { + Ok(_) => http_session.response_header().map(|r| r.status.as_u16()).unwrap_or(500), + Err(e) => { + log::warn!("ForwardAuth: read failed: {}", e); + return false; + } + }; + + AUTH_CONNECTOR.release_http_session(http_session, &peer, None).await; + + (200..300).contains(&status) + } +} + +#[async_trait::async_trait] impl AuthValidator for BasicAuth<'_> { - fn validate(&self, session: &Session) -> bool { + async fn validate(&self, session: &Session) -> bool { if let Some(header) = session.get_header("authorization") { if let Some(h) = header.to_str().ok() { if let Some((_, val)) = h.split_once(' ') { @@ -31,8 +117,9 @@ impl AuthValidator for BasicAuth<'_> { } } +#[async_trait::async_trait] impl AuthValidator for ApiKeyAuth<'_> { - fn validate(&self, session: &Session) -> bool { + async fn validate(&self, session: &Session) -> bool { if let Some(header) = session.get_header("x-api-key") { if let Some(h) = header.to_str().ok() { return h.as_bytes().ct_eq(self.0.as_bytes()).into(); @@ -42,8 +129,9 @@ impl AuthValidator for ApiKeyAuth<'_> { } } +#[async_trait::async_trait] impl AuthValidator for JwtAuth<'_> { - fn validate(&self, session: &Session) -> bool { + async fn validate(&self, session: &Session) -> bool { let jwtsecret = self.0; if let Some(tok) = get_query_param(session, "araleztoken") { return check_jwt(tok.as_str(), jwtsecret); @@ -61,11 +149,12 @@ impl AuthValidator for JwtAuth<'_> { } } -pub fn authenticate(auth_type: &Arc, credentials: &Arc, session: &Session) -> bool { +pub async fn authenticate(auth_type: &Arc, credentials: &Arc, session: &Session) -> bool { match &**auth_type { - "basic" => BasicAuth(credentials).validate(session), - "apikey" => ApiKeyAuth(credentials).validate(session), - "jwt" => JwtAuth(credentials).validate(session), + "basic" => BasicAuth(credentials).validate(session).await, + "apikey" => ApiKeyAuth(credentials).validate(session).await, + "jwt" => JwtAuth(credentials).validate(session).await, + "forward" => ForwardAuth(credentials).validate(session).await, _ => { log::warn!("Unsupported authentication mechanism : {}", auth_type); false @@ -87,3 +176,22 @@ pub fn get_query_param(session: &Session, key: &str) -> Option { .collect(); params.get(key).and_then(|v| decode(v).ok()).map(|s| s.to_string()) } + +fn split_host_port(addr: &str, tls: bool) -> Option<(&str, u16, bool, &str)> { + match addr.split_once(':') { + Some((h, p)) => match p.parse::() { + Ok(port) => return Some((h, port, tls, h)), + Err(_) => { + log::warn!("ForwardAuth: invalid port in {}", addr); + return None; + } + }, + None => { + if tls { + return Some((addr, 443u16, tls, addr)); + } else { + return Some((addr, 80u16, tls, addr)); + } + } + }; +} diff --git a/src/web/proxyhttp.rs b/src/web/proxyhttp.rs index 1a4ad20..e0bb2b4 100644 --- a/src/web/proxyhttp.rs +++ b/src/web/proxyhttp.rs @@ -89,7 +89,7 @@ impl ProxyHttp for LB { None => return Ok(false), Some(ref innermap) => { if let Some(auth) = _ctx.extraparams.authentication.as_ref().or(innermap.authorization.as_ref()) { - if !authenticate(&auth.auth_type, &auth.auth_cred, &session) { + if !authenticate(&auth.auth_type, &auth.auth_cred, &session).await { let _ = session.respond_error(401).await; warn!("Forbidden: {:?}, {}", session.client_addr(), session.req_header().uri.path()); return Ok(true);