Code cleanup

This commit is contained in:
Ara Sadoyan
2025-07-22 17:40:58 +02:00
parent 51c88c8f7c
commit 6f012cee69
6 changed files with 156 additions and 200 deletions

View File

@@ -50,5 +50,6 @@ x509-parser = "0.17.0"
rustls-pemfile = "2.2.0"
tower-http = { version = "0.6.6", features = ["fs"] }
once_cell = "1.20.2"
#moka = { version = "0.12.10", features = ["sync"] }

View File

@@ -13,6 +13,7 @@ use tonic::transport::Endpoint;
pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>, idlist: Arc<UpstreamsIdMap>, params: (&str, u64)) {
let mut period = interval(Duration::from_secs(params.1));
let mut first_run = 0;
let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap();
loop {
tokio::select! {
_ = period.tick() => {
@@ -27,8 +28,9 @@ pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>,
let mut innervec= Vec::new();
for k in path_entry.value().0 .iter().enumerate() {
let mut _link = String::new();
let tls = detect_tls(k.1.address.as_str(), &k.1.port).await;
let tls = detect_tls(k.1.address.as_str(), &k.1.port, &client).await;
let mut is_h2 = false;
if tls.1 == Some(Version::HTTP_2) {
is_h2 = true;
}
@@ -43,7 +45,7 @@ pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>,
is_http2: is_h2,
to_https: k.1.to_https,
};
let resp = http_request(_link.as_str(), params.0, "").await;
let resp = http_request(_link.as_str(), params.0, "", &client).await;
match resp.0 {
true => {
if resp.1 {
@@ -86,33 +88,26 @@ pub async fn hc2(upslist: Arc<UpstreamsDashMap>, fullist: Arc<UpstreamsDashMap>,
}
}
#[allow(dead_code)]
async fn http_request(url: &str, method: &str, payload: &str) -> (bool, bool) {
let client = Client::builder().danger_accept_invalid_certs(true).build().unwrap();
let timeout = Duration::from_secs(1);
async fn http_request(url: &str, method: &str, payload: &str, client: &Client) -> (bool, bool) {
if !["POST", "GET", "HEAD"].contains(&method) {
error!("Method {} not supported. Only GET|POST|HEAD are supported ", method);
return (false, false);
}
async fn send_request(client: &Client, method: &str, url: &str, payload: &str, timeout: Duration) -> Option<reqwest::Response> {
async fn send_request(client: &Client, method: &str, url: &str, payload: &str) -> Option<reqwest::Response> {
match method {
"POST" => client.post(url).body(payload.to_owned()).timeout(timeout).send().await.ok(),
"GET" => client.get(url).timeout(timeout).send().await.ok(),
"HEAD" => client.head(url).timeout(timeout).send().await.ok(),
"POST" => client.post(url).body(payload.to_owned()).send().await.ok(),
"GET" => client.get(url).send().await.ok(),
"HEAD" => client.head(url).send().await.ok(),
_ => None,
}
}
match send_request(&client, method, url, payload, timeout).await {
match send_request(&client, method, url, payload).await {
Some(response) => {
let status = response.status().as_u16();
((99..499).contains(&status), false)
}
None => {
// let fallback_url = url.replace("https", "http");
// ping_grpc(&fallback_url).await
(ping_grpc(&url).await, true)
}
None => (ping_grpc(&url).await, true),
}
}
@@ -123,10 +118,7 @@ pub async fn ping_grpc(addr: &str) -> bool {
let endpoint = endpoint.timeout(Duration::from_secs(2));
match tokio::time::timeout(Duration::from_secs(3), endpoint.connect()).await {
Ok(Ok(_channel)) => {
// println!("{:?} ==> {:?} ==> {}", endpoint, _channel, addr);
true
}
Ok(Ok(_channel)) => true,
_ => false,
}
} else {
@@ -134,15 +126,24 @@ pub async fn ping_grpc(addr: &str) -> bool {
}
}
async fn detect_tls(ip: &str, port: &u16) -> (bool, Option<Version>) {
let url = format!("https://{}:{}", ip, port);
// let url = format!("{}:{}", ip, port);
let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap();
match client.get(&url).send().await {
Ok(response) => (true, Some(response.version())),
Err(e) => {
if e.is_builder() || e.is_connect() || e.to_string().contains("tls") {
(false, None)
async fn detect_tls(ip: &str, port: &u16, client: &Client) -> (bool, Option<Version>) {
let https_url = format!("https://{}:{}", ip, port);
match client.get(&https_url).send().await {
Ok(response) => {
// println!("{} => {:?} (HTTPS)", https_url, response.version());
return (true, Some(response.version()));
}
_ => {}
}
let http_url = format!("http://{}:{}", ip, port);
match client.get(&http_url).send().await {
Ok(response) => {
// println!("{} => {:?} (HTTP)", http_url, response.version());
(false, Some(response.version()))
}
Err(_) => {
if ping_grpc(&http_url).await {
(false, Some(Version::HTTP_2))
} else {
(false, None)
}

View File

@@ -1,139 +1,136 @@
use crate::utils::structs::*;
use dashmap::DashMap;
use log::{error, info, warn};
use serde_yaml::Error;
use std::collections::HashMap;
use std::fs;
use std::sync::atomic::AtomicUsize;
pub fn load_configuration(d: &str, kind: &str) -> Option<Configuration> {
let mut toreturn: Configuration = Configuration {
upstreams: Default::default(),
headers: Default::default(),
consul: None,
typecfg: "".to_string(),
extraparams: Extraparams {
sticky_sessions: false,
to_https: None,
authentication: DashMap::new(),
rate_limit: None,
let yaml_data = match kind {
"filepath" => match fs::read_to_string(d) {
Ok(data) => {
info!("Reading upstreams from {}", d);
data
}
Err(e) => {
error!("Reading: {}: {:?}", d, e);
warn!("Running with empty upstreams list, update it via API");
return None;
}
},
};
toreturn.upstreams = UpstreamsDashMap::new();
toreturn.headers = Headers::new();
let mut yaml_data = d.to_string();
match kind {
"filepath" => {
let _ = match fs::read_to_string(d) {
Ok(data) => {
info!("Reading upstreams from {}", d);
yaml_data = data
}
Err(e) => {
error!("Reading: {}: {:?}", d, e.to_string());
warn!("Running with empty upstreams list, update it via API");
return None;
}
};
}
"content" => {
info!("Reading upstreams from API post body");
d.to_string()
}
_ => error!("Mismatched parameter, only filepath|content is allowed "),
}
_ => {
error!("Mismatched parameter, only filepath|content is allowed");
return None;
}
};
let p: Result<Config, Error> = serde_yaml::from_str(&yaml_data);
match p {
Ok(parsed) => {
let global_headers = DashMap::new();
let mut hl = Vec::new();
if let Some(headers) = &parsed.headers {
for header in headers.iter() {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.to_string(), val.to_string()));
}
}
global_headers.insert("/".to_string(), hl);
toreturn.headers.insert("GLOBAL_HEADERS".to_string(), global_headers);
toreturn.extraparams.sticky_sessions = parsed.sticky_sessions;
toreturn.extraparams.to_https = parsed.to_https;
toreturn.extraparams.rate_limit = parsed.rate_limit;
}
if let Some(auth) = &parsed.authorization {
let name = auth.get("type").unwrap().to_string();
let creds = auth.get("creds").unwrap().to_string();
let val: Vec<String> = vec![name, creds];
toreturn.extraparams.authentication.insert("authorization".to_string(), val);
} else {
toreturn.extraparams.authentication = DashMap::new();
}
match parsed.provider.as_str() {
"file" => {
toreturn.typecfg = "file".to_string();
if let Some(upstream) = parsed.upstreams {
for (hostname, host_config) in upstream {
let path_map = DashMap::new();
let header_list = DashMap::new();
for (path, path_config) in host_config.paths {
let mut server_list = Vec::new();
let mut hl = Vec::new();
if let Some(headers) = &path_config.headers {
for header in headers.iter().by_ref() {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.to_string(), val.to_string()));
}
}
}
header_list.insert(path.clone(), hl);
for server in path_config.servers {
if let Some((ip, port_str)) = server.split_once(':') {
if let Ok(port) = port_str.parse::<u16>() {
let to_https = path_config.to_https.unwrap_or(false);
let sl = InnerMap {
address: ip.to_string(),
port: port,
is_ssl: true,
is_http2: false,
to_https: to_https,
};
server_list.push(sl);
}
}
}
path_map.insert(path, (server_list, AtomicUsize::new(0)));
}
toreturn.headers.insert(hostname.clone(), header_list);
toreturn.upstreams.insert(hostname, path_map);
}
}
Some(toreturn)
}
"consul" => {
toreturn.typecfg = "consul".to_string();
let consul = parsed.consul;
match consul {
Some(consul) => {
toreturn.consul = Some(consul);
Some(toreturn)
}
None => None,
}
}
"kubernetes" => None,
_ => {
warn!("Unknown provider {}", parsed.provider);
None
}
}
}
let parsed: Config = match serde_yaml::from_str(&yaml_data) {
Ok(cfg) => cfg,
Err(e) => {
error!("Failed to parse upstreams file: {}", e);
return None;
}
};
let mut toreturn = Configuration::default();
populate_headers_and_auth(&mut toreturn, &parsed);
toreturn.typecfg = parsed.provider.clone();
match parsed.provider.as_str() {
"file" => {
populate_file_upstreams(&mut toreturn, &parsed);
Some(toreturn)
}
"consul" => {
toreturn.consul = parsed.consul;
if toreturn.consul.is_some() {
Some(toreturn)
} else {
None
}
}
"kubernetes" => None,
_ => {
warn!("Unknown provider {}", parsed.provider);
None
}
}
}
fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) {
if let Some(headers) = &parsed.headers {
let mut hl = Vec::new();
for header in headers {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.trim().to_string(), val.trim().to_string()));
}
}
let global_headers = DashMap::new();
global_headers.insert("/".to_string(), hl);
config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers);
}
config.extraparams.sticky_sessions = parsed.sticky_sessions;
config.extraparams.to_https = parsed.to_https;
config.extraparams.rate_limit = parsed.rate_limit;
if let Some(auth) = &parsed.authorization {
let name = auth.get("type").unwrap_or(&"".to_string()).to_string();
let creds = auth.get("creds").unwrap_or(&"".to_string()).to_string();
config.extraparams.authentication.insert("authorization".to_string(), vec![name, creds]);
} else {
config.extraparams.authentication = DashMap::new();
}
}
fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) {
if let Some(upstreams) = &parsed.upstreams {
for (hostname, host_config) in upstreams {
let path_map = DashMap::new();
let header_list = DashMap::new();
for (path, path_config) in &host_config.paths {
let mut server_list = Vec::new();
let mut hl = Vec::new();
if let Some(headers) = &path_config.headers {
for header in headers {
if let Some((key, val)) = header.split_once(':') {
hl.push((key.trim().to_string(), val.trim().to_string()));
}
}
}
header_list.insert(path.clone(), hl);
for server in &path_config.servers {
if let Some((ip, port_str)) = server.split_once(':') {
if let Ok(port) = port_str.parse::<u16>() {
server_list.push(InnerMap {
address: ip.trim().to_string(),
port,
is_ssl: true,
is_http2: false,
to_https: path_config.to_https.unwrap_or(false),
});
}
}
}
path_map.insert(path.clone(), (server_list, AtomicUsize::new(0)));
}
config.headers.insert(hostname.clone(), header_list);
config.upstreams.insert(hostname.clone(), path_map);
}
}
}
pub fn parce_main_config(path: &str) -> AppConfig {
info!("Parsing configuration");
let data = fs::read_to_string(path).unwrap();

View File

@@ -14,7 +14,7 @@ pub struct ServiceMapping {
pub real: String,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct Extraparams {
pub sticky_sessions: bool,
pub to_https: Option<bool>,
@@ -60,7 +60,7 @@ pub struct PathConfig {
pub headers: Option<Vec<String>>,
pub rate_limit: Option<isize>,
}
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct Configuration {
pub upstreams: UpstreamsDashMap,
pub headers: Headers,
@@ -102,49 +102,11 @@ pub struct InnerMap {
impl InnerMap {
pub fn new() -> Self {
Self {
address: String::new(),
port: 0,
is_ssl: false,
is_http2: false,
to_https: false,
address: Default::default(),
port: Default::default(),
is_ssl: Default::default(),
is_http2: Default::default(),
to_https: Default::default(),
}
}
}
/*
impl InnerMap {
pub fn new(address: String, port: u16) -> Self {
Self {
address,
port,
is_ssl: false, // Default values
is_http2: false,
to_https: false,
}
}
pub fn address(&self) -> &str {
&self.address
}
pub fn port(&self) -> u16 {
self.port
}
// Setters with validation
pub fn with_ssl(mut self, ssl: bool) -> Result<Self, String> {
self.is_ssl = ssl;
Ok(self)
}
pub fn with_http2(mut self, http2: bool) -> Result<Self, String> {
self.is_http2 = http2;
Ok(self)
}
pub fn with_to_https(mut self, to_https: bool) -> Result<Self, String> {
self.to_https = to_https;
Ok(self)
}
}
*/

View File

@@ -21,17 +21,13 @@ pub fn print_upstreams(upstreams: &UpstreamsDashMap) {
for path_entry in host_entry.value().iter() {
let path = path_entry.key();
println!(" Path: {}", path);
println!(" Path: {}", path);
for f in path_entry.value().0.clone() {
println!(
" ===> IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}",
" IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}",
f.address, f.port, f.is_ssl, f.is_http2, f.to_https
);
}
// { address: "127.0.0.4", port: 8000, is_ssl: false, is_http2: false, to_https: false }
// for (ip, port, ssl, vers, to_https) in path_entry.value().0.clone() {
// println!(" ===> IP: {}, Port: {}, SSL: {}, H2: {}, To HTTPS: {}", ip, port, ssl, vers, to_https);
// }
}
}
}

View File

@@ -18,6 +18,8 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
#[derive(Clone)]
pub struct LB {
pub ump_upst: Arc<UpstreamsDashMap>,
@@ -35,10 +37,6 @@ pub struct Context {
start_time: Instant,
hostname: Option<String>,
}
// Rate limiter
static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
// max request per second per client
// static MAX_REQ_PER_SEC: isize = 1;
#[async_trait]
impl ProxyHttp for LB {
@@ -64,6 +62,7 @@ impl ProxyHttp for LB {
let hostname = return_header_host(&session);
_ctx.hostname = hostname.clone();
if let Some(rate) = self.extraparams.load().rate_limit {
match hostname {
None => return Ok(false),