From 8aff2fa87571caf1e9a6476073ead495e9464596 Mon Sep 17 00:00:00 2001 From: Ara Sadoyan Date: Tue, 14 Apr 2026 16:11:24 +0200 Subject: [PATCH] Standardizing implementation of #17 --- src/utils/auth.rs | 82 ++++++++++++++++++++++++++++++++++++-------- src/web/proxyhttp.rs | 2 +- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/utils/auth.rs b/src/utils/auth.rs index 1e054be..7f2d450 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -1,22 +1,25 @@ use crate::utils::jwt::check_jwt; +// use reqwest::Client; +use axum::http::StatusCode; use base64::engine::general_purpose::STANDARD; use base64::Engine; use pingora_proxy::Session; -// use reqwest::Client; use std::collections::HashMap; use std::sync::{Arc, LazyLock}; use subtle::ConstantTimeEq; use urlencoding::decode; -// --------------------------------- // +// use pingora::http::{RequestHeader, ResponseHeader, StatusCode}; use pingora::http::RequestHeader; +// --------------------------------- // use pingora_core::connectors::http::Connector; use pingora_core::upstreams::peer::HttpPeer; +use pingora_http::ResponseHeader; // --------------------------------- // #[async_trait::async_trait] trait AuthValidator { - async fn validate(&self, session: &Session) -> bool; + async fn validate(&self, session: &mut Session) -> bool; } struct BasicAuth<'a>(&'a str); struct ApiKeyAuth<'a>(&'a str); @@ -27,11 +30,11 @@ pub static AUTH_CONNECTOR: LazyLock = LazyLock::new(|| Connector::new #[async_trait::async_trait] impl AuthValidator for ForwardAuth<'_> { - async fn validate(&self, session: &Session) -> bool { - let method = match session.req_header().method.as_str() { - "POST" => "POST", - _ => "GET", - }; + async fn validate(&self, session: &mut Session) -> bool { + // let method = match session.req_header().method.as_str() { + // "HEAD" => "HEAD", + // _ => "GET", + // }; let auth_url = self.0; @@ -64,7 +67,7 @@ impl AuthValidator for ForwardAuth<'_> { } }; - let mut auth_req = match RequestHeader::build(method, uri.as_bytes(), None) { + let mut auth_req = match RequestHeader::build("GET", uri.as_bytes(), None) { Ok(r) => r, Err(e) => { log::warn!("ForwardAuth: failed to build request: {}", e); @@ -76,10 +79,21 @@ impl AuthValidator for ForwardAuth<'_> { // auth_req.headers = session.req_header().headers.clone(); auth_req.insert_header("Host", addr).ok(); auth_req.insert_header("X-Forwarded-Uri", uri).ok(); + auth_req.insert_header("X-Forwarded-Method", session.req_header().method.as_str()).ok(); if let Some(auth) = session.req_header().headers.get("authorization") { auth_req.insert_header("Authorization", auth.clone()).ok(); } + if let Some(cookie) = session.req_header().headers.get("cookie") { + auth_req.insert_header("Cookie", cookie.clone()).ok(); + } + + if tls { + auth_req.insert_header("X-Forwarded-Proto", "https").ok(); + } else { + auth_req.insert_header("X-Forwarded-Proto", "http").ok(); + } + if let Err(e) = http_session.write_request_header(Box::new(auth_req)).await { log::warn!("ForwardAuth: write failed: {}", e); return false; @@ -93,15 +107,53 @@ impl AuthValidator for ForwardAuth<'_> { } }; + let auth_headers_to_forward: Vec<(String, String)> = if let Some(resp_header) = http_session.response_header() { + resp_header + .headers + .iter() + .filter_map(|(name, value)| { + let name_str = name.as_str(); + if name_str.starts_with("x-") || name_str.starts_with("remote-") || name_str.starts_with("locat") { + value.to_str().ok().map(|v| (name_str.to_string(), v.to_string())) + } else { + None + } + }) + .collect() + } else { + Vec::new() + }; + AUTH_CONNECTOR.release_http_session(http_session, &peer, None).await; - (200..300).contains(&status) + if (200..300).contains(&status) { + for (name, value) in auth_headers_to_forward { + session.req_header_mut().insert_header(name, value).ok(); + } + true + } else if status == 302 || status == 301 { + let resp = ResponseHeader::build(StatusCode::MOVED_PERMANENTLY, None); + match resp { + Ok(mut r) => { + for (name, value) in auth_headers_to_forward { + r.insert_header(name, value).ok(); + } + let _ = r.insert_header("Content-Length", "0"); + let _ = session.write_response_header(Box::new(r), true).await; + true + } + Err(_) => return false, + } + } else { + false + } + // (200..300).contains(&status) } } #[async_trait::async_trait] impl AuthValidator for BasicAuth<'_> { - async fn validate(&self, session: &Session) -> bool { + async fn validate(&self, session: &mut Session) -> bool { if let Some(header) = session.get_header("authorization") { if let Some(h) = header.to_str().ok() { if let Some((_, val)) = h.split_once(' ') { @@ -119,7 +171,7 @@ impl AuthValidator for BasicAuth<'_> { #[async_trait::async_trait] impl AuthValidator for ApiKeyAuth<'_> { - async fn validate(&self, session: &Session) -> bool { + async fn validate(&self, session: &mut Session) -> bool { if let Some(header) = session.get_header("x-api-key") { if let Some(h) = header.to_str().ok() { return h.as_bytes().ct_eq(self.0.as_bytes()).into(); @@ -131,7 +183,7 @@ impl AuthValidator for ApiKeyAuth<'_> { #[async_trait::async_trait] impl AuthValidator for JwtAuth<'_> { - async fn validate(&self, session: &Session) -> bool { + async fn validate(&self, session: &mut Session) -> bool { let jwtsecret = self.0; if let Some(tok) = get_query_param(session, "araleztoken") { return check_jwt(tok.as_str(), jwtsecret); @@ -149,7 +201,7 @@ impl AuthValidator for JwtAuth<'_> { } } -pub async fn authenticate(auth_type: &Arc, credentials: &Arc, session: &Session) -> bool { +pub async fn authenticate(auth_type: &Arc, credentials: &Arc, session: &mut Session) -> bool { match &**auth_type { "basic" => BasicAuth(credentials).validate(session).await, "apikey" => ApiKeyAuth(credentials).validate(session).await, @@ -162,7 +214,7 @@ pub async fn authenticate(auth_type: &Arc, credentials: &Arc, session: } } -pub fn get_query_param(session: &Session, key: &str) -> Option { +pub fn get_query_param(session: &mut Session, key: &str) -> Option { let query = session.req_header().uri.query()?; let params: HashMap<_, _> = query diff --git a/src/web/proxyhttp.rs b/src/web/proxyhttp.rs index e0bb2b4..35d9027 100644 --- a/src/web/proxyhttp.rs +++ b/src/web/proxyhttp.rs @@ -89,7 +89,7 @@ impl ProxyHttp for LB { None => return Ok(false), Some(ref innermap) => { if let Some(auth) = _ctx.extraparams.authentication.as_ref().or(innermap.authorization.as_ref()) { - if !authenticate(&auth.auth_type, &auth.auth_cred, &session).await { + if !authenticate(&auth.auth_type, &auth.auth_cred, session).await { let _ = session.respond_error(401).await; warn!("Forbidden: {:?}, {}", session.client_addr(), session.req_header().uri.path()); return Ok(true);