From f135106a4428f1b401846b6adacc6cb28b25a14b Mon Sep 17 00:00:00 2001 From: Ara Sadoyan Date: Wed, 8 Apr 2026 19:05:19 +0200 Subject: [PATCH] Changes in authentication --- Cargo.lock | 1 + Cargo.toml | 1 + src/utils/auth.rs | 36 +++++++++++------------------------- src/web/webserver.rs | 18 +++++++----------- 4 files changed, 20 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f2e2030..7c388f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -147,6 +147,7 @@ dependencies = [ "serde_json", "serde_yml", "sha2 0.11.0", + "subtle", "tokio", "tonic", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index 7812737..efbf91d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,3 +43,4 @@ tower-http = { version = "0.6.8", features = ["fs"] } privdrop = "0.5.6" ctrlc = "3.5.2" serde_json = "1.0.149" +subtle = "2.6.1" diff --git a/src/utils/auth.rs b/src/utils/auth.rs index 2a636c9..3a68066 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -4,6 +4,7 @@ use base64::Engine; use pingora_proxy::Session; use std::collections::HashMap; use std::sync::Arc; +use subtle::ConstantTimeEq; use urlencoding::decode; trait AuthValidator { @@ -19,8 +20,8 @@ impl AuthValidator for BasicAuth<'_> { if let Some(h) = header.to_str().ok() { if let Some((_, val)) = h.split_once(' ') { if let Some(decoded) = STANDARD.decode(val).ok() { - if let Some(decoded_str) = String::from_utf8(decoded).ok() { - return decoded_str == self.0; + if decoded.as_slice().ct_eq(self.0.as_bytes()).into() { + return true; } } } @@ -33,10 +34,9 @@ impl AuthValidator for BasicAuth<'_> { impl AuthValidator for ApiKeyAuth<'_> { fn validate(&self, session: &Session) -> bool { if let Some(header) = session.get_header("x-api-key") { - if let Some(header) = header.to_str().ok() { - return header == self.0; + if let Some(h) = header.to_str().ok() { + return h.as_bytes().ct_eq(self.0.as_bytes()).into(); } - // return header.to_str().ok().unwrap() == self.0; } false } @@ -60,27 +60,14 @@ impl AuthValidator for JwtAuth<'_> { false } } -fn validate(auth: &dyn AuthValidator, session: &Session) -> bool { - auth.validate(session) -} -// pub fn authenticate(c: &[Arc], session: &Session) -> bool { pub fn authenticate(auth_type: &Arc, credentials: &Arc, session: &Session) -> bool { - match &*auth_type.clone() { - "basic" => { - let auth = BasicAuth(&*credentials.clone()); - validate(&auth, session) - } - "apikey" => { - let auth = ApiKeyAuth(&*credentials.clone()); - validate(&auth, session) - } - "jwt" => { - let auth = JwtAuth(&*credentials.clone()); - validate(&auth, session) - } + match &**auth_type { + "basic" => BasicAuth(credentials).validate(session), + "apikey" => ApiKeyAuth(credentials).validate(session), + "jwt" => JwtAuth(credentials).validate(session), _ => { - println!("Unsupported authentication mechanism : {}", auth_type); + log::warn!("Unsupported authentication mechanism : {}", auth_type); false } } @@ -98,6 +85,5 @@ pub fn get_query_param(session: &Session, key: &str) -> Option { Some((k, v)) }) .collect(); - - params.get(key).map(|v| decode(v).ok()).flatten().map(|s| s.to_string()) + params.get(key).and_then(|v| decode(v).ok()).map(|s| s.to_string()) } diff --git a/src/web/webserver.rs b/src/web/webserver.rs index 1d9f4c3..c7bcaf4 100644 --- a/src/web/webserver.rs +++ b/src/web/webserver.rs @@ -3,7 +3,7 @@ use crate::utils::structs::{Config, Configuration, UpstreamsDashMap}; use crate::utils::tools::{upstreams_liveness_json, upstreams_to_json}; use axum::body::Body; use axum::extract::{Query, State}; -use axum::http::{Response, StatusCode}; +use axum::http::{header::HeaderMap, Response, StatusCode}; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::{Json, Router}; @@ -18,6 +18,7 @@ use std::collections::HashMap; // use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use subtle::ConstantTimeEq; use tokio::net::TcpListener; use tower_http::services::ServeDir; @@ -88,23 +89,18 @@ pub async fn run_server(config: &APIUpstreamProvider, mut to_return: Sender, Query(params): Query>, content: String) -> impl IntoResponse { +async fn conf(State(st): State, Query(params): Query>, headers: HeaderMap, content: String) -> impl IntoResponse { if !st.config_api_enabled { return Response::builder().status(StatusCode::FORBIDDEN).body(Body::from("Config API is disabled !\n")).unwrap(); } - if let Some(s) = params.get("key") { - if s.to_owned() == st.master_key { + if let Some(s) = headers.get("x-api-key").and_then(|v| v.to_str().ok()).or(params.get("key").map(|s| s.as_str())) { + if s.as_bytes().ct_eq(st.master_key.as_bytes()).into() { let strcontent = content.as_str(); let parsed = serde_yml::from_str::(strcontent); match parsed { Ok(_) => { - if let Some(s) = params.get("key") { - if s.to_owned() == st.master_key { - let _ = tokio::spawn(async move { apply_config(content.as_str(), st).await }); - return Response::builder().status(StatusCode::OK).body(Body::from("Accepted! Applying in background\n")).unwrap(); - } - } - return Response::builder().status(StatusCode::FORBIDDEN).body(Body::from("Access Denied !\n")).unwrap(); + let _ = tokio::spawn(async move { apply_config(content.as_str(), st).await }); + return Response::builder().status(StatusCode::OK).body(Body::from("Accepted! Applying in background\n")).unwrap(); } Err(err) => { error!("Failed to parse upstreams file: {}", err);