Code cleanup

This commit is contained in:
Ara Sadoyan
2025-07-22 17:40:58 +02:00
parent 51c88c8f7c
commit 6f012cee69
6 changed files with 156 additions and 200 deletions

View File

@@ -50,5 +50,6 @@ x509-parser = "0.17.0"
rustls-pemfile = "2.2.0" rustls-pemfile = "2.2.0"
tower-http = { version = "0.6.6", features = ["fs"] } tower-http = { version = "0.6.6", features = ["fs"] }
once_cell = "1.20.2" once_cell = "1.20.2"
#moka = { version = "0.12.10", features = ["sync"] }

View File

@@ -13,6 +13,7 @@ use tonic::transport::Endpoint;
pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>, idlist: Arc<UpstreamsIdMap>, params: (&str, u64)) { pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>, idlist: Arc<UpstreamsIdMap>, params: (&str, u64)) {
let mut period = interval(Duration::from_secs(params.1)); let mut period = interval(Duration::from_secs(params.1));
let mut first_run = 0; let mut first_run = 0;
let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap();
loop { loop {
tokio::select! { tokio::select! {
_ = period.tick() => { _ = period.tick() => {
@@ -27,8 +28,9 @@ pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>,
let mut innervec= Vec::new(); let mut innervec= Vec::new();
for k in path_entry.value().0 .iter().enumerate() { for k in path_entry.value().0 .iter().enumerate() {
let mut _link = String::new(); let mut _link = String::new();
let tls = detect_tls(k.1.address.as_str(), &k.1.port).await; let tls = detect_tls(k.1.address.as_str(), &k.1.port, &client).await;
let mut is_h2 = false; let mut is_h2 = false;
if tls.1 == Some(Version::HTTP_2) { if tls.1 == Some(Version::HTTP_2) {
is_h2 = true; is_h2 = true;
} }
@@ -43,7 +45,7 @@ pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>,
is_http2: is_h2, is_http2: is_h2,
to_https: k.1.to_https, to_https: k.1.to_https,
}; };
let resp = http_request(_link.as_str(), params.0, "").await; let resp = http_request(_link.as_str(), params.0, "", &client).await;
match resp.0 { match resp.0 {
true => { true => {
if resp.1 { if resp.1 {
@@ -86,33 +88,26 @@ pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>,
} }
} }
#[allow(dead_code)] async fn http_request(url: &str, method: &str, payload: &str, client: &Client) -> (bool, bool) {
async fn http_request(url: &str, method: &str, payload: &str) -> (bool, bool) {
let client = Client::builder().danger_accept_invalid_certs(true).build().unwrap();
let timeout = Duration::from_secs(1);
if !["POST", "GET", "HEAD"].contains(&method) { if !["POST", "GET", "HEAD"].contains(&method) {
error!("Method {} not supported. Only GET|POST|HEAD are supported ", method); error!("Method {} not supported. Only GET|POST|HEAD are supported ", method);
return (false, false); return (false, false);
} }
async fn send_request(client: &Client, method: &str, url: &str, payload: &str, timeout: Duration) -> Option<reqwest::Response> { async fn send_request(client: &Client, method: &str, url: &str, payload: &str) -> Option<reqwest::Response> {
match method { match method {
"POST" => client.post(url).body(payload.to_owned()).timeout(timeout).send().await.ok(), "POST" => client.post(url).body(payload.to_owned()).send().await.ok(),
"GET" => client.get(url).timeout(timeout).send().await.ok(), "GET" => client.get(url).send().await.ok(),
"HEAD" => client.head(url).timeout(timeout).send().await.ok(), "HEAD" => client.head(url).send().await.ok(),
_ => None, _ => None,
} }
} }
match send_request(&client, method, url, payload, timeout).await { match send_request(&client, method, url, payload).await {
Some(response) => { Some(response) => {
let status = response.status().as_u16(); let status = response.status().as_u16();
((99..499).contains(&status), false) ((99..499).contains(&status), false)
} }
None => { None => (ping_grpc(&url).await, true),
// let fallback_url = url.replace("https", "http");
// ping_grpc(&fallback_url).await
(ping_grpc(&url).await, true)
}
} }
} }
@@ -123,10 +118,7 @@ pub async fn ping_grpc(addr: &str) -> bool {
let endpoint = endpoint.timeout(Duration::from_secs(2)); let endpoint = endpoint.timeout(Duration::from_secs(2));
match tokio::time::timeout(Duration::from_secs(3), endpoint.connect()).await { match tokio::time::timeout(Duration::from_secs(3), endpoint.connect()).await {
Ok(Ok(_channel)) => { Ok(Ok(_channel)) => true,
// println!("{:?} ==> {:?} ==> {}", endpoint, _channel, addr);
true
}
_ => false, _ => false,
} }
} else { } else {
@@ -134,15 +126,24 @@ pub async fn ping_grpc(addr: &str) -> bool {
} }
} }
async fn detect_tls(ip: &str, port: &u16) -> (bool, Option<Version>) { async fn detect_tls(ip: &str, port: &u16, client: &Client) -> (bool, Option<Version>) {
let url = format!("https://{}:{}", ip, port); let https_url = format!("https://{}:{}", ip, port);
// let url = format!("{}:{}", ip, port); match client.get(&https_url).send().await {
let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap(); Ok(response) => {
match client.get(&url).send().await { // println!("{} => {:?} (HTTPS)", https_url, response.version());
Ok(response) => (true, Some(response.version())), return (true, Some(response.version()));
Err(e) => { }
if e.is_builder() || e.is_connect() || e.to_string().contains("tls") { _ => {}
(false, None) }
let http_url = format!("http://{}:{}", ip, port);
match client.get(&http_url).send().await {
Ok(response) => {
// println!("{} => {:?} (HTTP)", http_url, response.version());
(false, Some(response.version()))
}
Err(_) => {
if ping_grpc(&http_url).await {
(false, Some(Version::HTTP_2))
} else { } else {
(false, None) (false, None)
} }

View File

@@ -1,139 +1,136 @@
use crate::utils::structs::*; use crate::utils::structs::*;
use dashmap::DashMap; use dashmap::DashMap;
use log::{error, info, warn}; use log::{error, info, warn};
use serde_yaml::Error;
use std::collections::HashMap; use std::collections::HashMap;
use std::fs; use std::fs;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
pub fn load_configuration(d: &str, kind: &str) -> Option<Configuration> { pub fn load_configuration(d: &str, kind: &str) -> Option<Configuration> {
let mut toreturn: Configuration = Configuration { let yaml_data = match kind {
upstreams: Default::default(), "filepath" => match fs::read_to_string(d) {
headers: Default::default(), Ok(data) => {
consul: None, info!("Reading upstreams from {}", d);
typecfg: "".to_string(), data
extraparams: Extraparams { }
sticky_sessions: false, Err(e) => {
to_https: None, error!("Reading: {}: {:?}", d, e);
authentication: DashMap::new(), warn!("Running with empty upstreams list, update it via API");
rate_limit: None, return None;
}
}, },
};
toreturn.upstreams = UpstreamsDashMap::new();
toreturn.headers = Headers::new();
let mut yaml_data = d.to_string();
match kind {
"filepath" => {
let _ = match fs::read_to_string(d) {
Ok(data) => {
info!("Reading upstreams from {}", d);
yaml_data = data
}
Err(e) => {
error!("Reading: {}: {:?}", d, e.to_string());
warn!("Running with empty upstreams list, update it via API");
return None;
}
};
}
"content" => { "content" => {
info!("Reading upstreams from API post body"); info!("Reading upstreams from API post body");
d.to_string()
} }
_ => error!("Mismatched parameter, only filepath|content is allowed "), _ => {
} error!("Mismatched parameter, only filepath|content is allowed");
return None;
}
};
let p: Result<Config, Error> = serde_yaml::from_str(&yaml_data); let parsed: Config = match serde_yaml::from_str(&yaml_data) {
match p { Ok(cfg) => cfg,
Ok(parsed) => {
let global_headers = DashMap::new();
let mut hl = Vec::new();
if let Some(headers) = &parsed.headers {
for header in headers.iter() {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.to_string(), val.to_string()));
}
}
global_headers.insert("/".to_string(), hl);
toreturn.headers.insert("GLOBAL_HEADERS".to_string(), global_headers);
toreturn.extraparams.sticky_sessions = parsed.sticky_sessions;
toreturn.extraparams.to_https = parsed.to_https;
toreturn.extraparams.rate_limit = parsed.rate_limit;
}
if let Some(auth) = &parsed.authorization {
let name = auth.get("type").unwrap().to_string();
let creds = auth.get("creds").unwrap().to_string();
let val: Vec<String> = vec![name, creds];
toreturn.extraparams.authentication.insert("authorization".to_string(), val);
} else {
toreturn.extraparams.authentication = DashMap::new();
}
match parsed.provider.as_str() {
"file" => {
toreturn.typecfg = "file".to_string();
if let Some(upstream) = parsed.upstreams {
for (hostname, host_config) in upstream {
let path_map = DashMap::new();
let header_list = DashMap::new();
for (path, path_config) in host_config.paths {
let mut server_list = Vec::new();
let mut hl = Vec::new();
if let Some(headers) = &path_config.headers {
for header in headers.iter().by_ref() {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.to_string(), val.to_string()));
}
}
}
header_list.insert(path.clone(), hl);
for server in path_config.servers {
if let Some((ip, port_str)) = server.split_once(':') {
if let Ok(port) = port_str.parse::<u16>() {
let to_https = path_config.to_https.unwrap_or(false);
let sl = InnerMap {
address: ip.to_string(),
port: port,
is_ssl: true,
is_http2: false,
to_https: to_https,
};
server_list.push(sl);
}
}
}
path_map.insert(path, (server_list, AtomicUsize::new(0)));
}
toreturn.headers.insert(hostname.clone(), header_list);
toreturn.upstreams.insert(hostname, path_map);
}
}
Some(toreturn)
}
"consul" => {
toreturn.typecfg = "consul".to_string();
let consul = parsed.consul;
match consul {
Some(consul) => {
toreturn.consul = Some(consul);
Some(toreturn)
}
None => None,
}
}
"kubernetes" => None,
_ => {
warn!("Unknown provider {}", parsed.provider);
None
}
}
}
Err(e) => { Err(e) => {
error!("Failed to parse upstreams file: {}", e); error!("Failed to parse upstreams file: {}", e);
return None;
}
};
let mut toreturn = Configuration::default();
populate_headers_and_auth(&mut toreturn, &parsed);
toreturn.typecfg = parsed.provider.clone();
match parsed.provider.as_str() {
"file" => {
populate_file_upstreams(&mut toreturn, &parsed);
Some(toreturn)
}
"consul" => {
toreturn.consul = parsed.consul;
if toreturn.consul.is_some() {
Some(toreturn)
} else {
None
}
}
"kubernetes" => None,
_ => {
warn!("Unknown provider {}", parsed.provider);
None None
} }
} }
} }
fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) {
if let Some(headers) = &parsed.headers {
let mut hl = Vec::new();
for header in headers {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.trim().to_string(), val.trim().to_string()));
}
}
let global_headers = DashMap::new();
global_headers.insert("/".to_string(), hl);
config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers);
}
config.extraparams.sticky_sessions = parsed.sticky_sessions;
config.extraparams.to_https = parsed.to_https;
config.extraparams.rate_limit = parsed.rate_limit;
if let Some(auth) = &parsed.authorization {
let name = auth.get("type").unwrap_or(&"".to_string()).to_string();
let creds = auth.get("creds").unwrap_or(&"".to_string()).to_string();
config.extraparams.authentication.insert("authorization".to_string(), vec![name, creds]);
} else {
config.extraparams.authentication = DashMap::new();
}
}
fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) {
if let Some(upstreams) = &parsed.upstreams {
for (hostname, host_config) in upstreams {
let path_map = DashMap::new();
let header_list = DashMap::new();
for (path, path_config) in &host_config.paths {
let mut server_list = Vec::new();
let mut hl = Vec::new();
if let Some(headers) = &path_config.headers {
for header in headers {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.trim().to_string(), val.trim().to_string()));
}
}
}
header_list.insert(path.clone(), hl);
for server in &path_config.servers {
if let Some((ip, port_str)) = server.split_once(':') {
if let Ok(port) = port_str.parse::<u16>() {
server_list.push(InnerMap {
address: ip.trim().to_string(),
port,
is_ssl: true,
is_http2: false,
to_https: path_config.to_https.unwrap_or(false),
});
}
}
}
path_map.insert(path.clone(), (server_list, AtomicUsize::new(0)));
}
config.headers.insert(hostname.clone(), header_list);
config.upstreams.insert(hostname.clone(), path_map);
}
}
}
pub fn parce_main_config(path: &str) -> AppConfig { pub fn parce_main_config(path: &str) -> AppConfig {
info!("Parsing configuration"); info!("Parsing configuration");
let data = fs::read_to_string(path).unwrap(); let data = fs::read_to_string(path).unwrap();

View File

@@ -14,7 +14,7 @@ pub struct ServiceMapping {
pub real: String, pub real: String,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct Extraparams { pub struct Extraparams {
pub sticky_sessions: bool, pub sticky_sessions: bool,
pub to_https: Option<bool>, pub to_https: Option<bool>,
@@ -60,7 +60,7 @@ pub struct PathConfig {
pub headers: Option<Vec<String>>, pub headers: Option<Vec<String>>,
pub rate_limit: Option<isize>, pub rate_limit: Option<isize>,
} }
#[derive(Debug)] #[derive(Debug, Default)]
pub struct Configuration { pub struct Configuration {
pub upstreams: UpstreamsDashMap, pub upstreams: UpstreamsDashMap,
pub headers: Headers, pub headers: Headers,
@@ -102,49 +102,11 @@ pub struct InnerMap {
impl InnerMap { impl InnerMap {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
address: String::new(), address: Default::default(),
port: 0, port: Default::default(),
is_ssl: false, is_ssl: Default::default(),
is_http2: false, is_http2: Default::default(),
to_https: false, to_https: Default::default(),
} }
} }
} }
/*
impl InnerMap {
pub fn new(address: String, port: u16) -> Self {
Self {
address,
port,
is_ssl: false, // Default values
is_http2: false,
to_https: false,
}
}
pub fn address(&self) -> &str {
&self.address
}
pub fn port(&self) -> u16 {
self.port
}
// Setters with validation
pub fn with_ssl(mut self, ssl: bool) -> Result<Self, String> {
self.is_ssl = ssl;
Ok(self)
}
pub fn with_http2(mut self, http2: bool) -> Result<Self, String> {
self.is_http2 = http2;
Ok(self)
}
pub fn with_to_https(mut self, to_https: bool) -> Result<Self, String> {
self.to_https = to_https;
Ok(self)
}
}
*/

View File

@@ -21,17 +21,13 @@ pub fn print_upstreams(upstreams: &UpstreamsDashMap) {
for path_entry in host_entry.value().iter() { for path_entry in host_entry.value().iter() {
let path = path_entry.key(); let path = path_entry.key();
println!(" Path: {}", path); println!(" Path: {}", path);
for f in path_entry.value().0.clone() { for f in path_entry.value().0.clone() {
println!( println!(
" ===> IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}", " IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}",
f.address, f.port, f.is_ssl, f.is_http2, f.to_https f.address, f.port, f.is_ssl, f.is_http2, f.to_https
); );
} }
// { address: "127.0.0.4", port: 8000, is_ssl: false, is_http2: false, to_https: false }
// for (ip, port, ssl, vers, to_https) in path_entry.value().0.clone() {
// println!(" ===> IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}", ip, port, ssl, vers, to_https);
// }
} }
} }
} }

View File

@@ -18,6 +18,8 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
#[derive(Clone)] #[derive(Clone)]
pub struct LB { pub struct LB {
pub ump_upst: Arc<UpstreamsDashMap>, pub ump_upst: Arc<UpstreamsDashMap>,
@@ -35,10 +37,6 @@ pub struct Context {
start_time: Instant, start_time: Instant,
hostname: Option<String>, hostname: Option<String>,
} }
// Rate limiter
static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
// max request per second per client
// static MAX_REQ_PER_SEC: isize = 1;
#[async_trait] #[async_trait]
impl ProxyHttp for LB { impl ProxyHttp for LB {
@@ -64,6 +62,7 @@ impl ProxyHttp for LB {
let hostname = return_header_host(&session); let hostname = return_header_host(&session);
_ctx.hostname = hostname.clone(); _ctx.hostname = hostname.clone();
if let Some(rate) = self.extraparams.load().rate_limit { if let Some(rate) = self.extraparams.load().rate_limit {
match hostname { match hostname {
None => return Ok(false), None => return Ok(false),