Dynamic load of SSL certificates from disk.

This commit is contained in:
Ara Sadoyan
2025-06-19 18:32:44 +02:00
parent 60b7b3aa7a
commit 8d4e434d6a
9 changed files with 205 additions and 166 deletions

View File

@@ -78,3 +78,9 @@ pub struct AppConfig {
pub local_server: Option<(String, u16)>,
pub proxy_certificates: Option<String>,
}
// #[derive(Debug)]
// pub struct CertificateMove {
// pub cert_tx: Sender<CertificateConfig>,
// pub cert_rx: Receiver<CertificateConfig>,
// }

View File

@@ -1,9 +1,12 @@
use openssl::ssl::{select_next_proto, AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef};
use dashmap::DashMap;
use log::error;
use pingora::tls::ssl::{select_next_proto, AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef};
use rustls_pemfile::{read_one, Item};
use serde::Deserialize;
use std::collections::HashSet;
use std::fs::File;
use std::io::BufReader;
// use tokio::time::Instant;
use x509_parser::extensions::GeneralName;
use x509_parser::nom::Err as NomErr;
use x509_parser::prelude::*;
@@ -28,43 +31,56 @@ struct CertificateInfo {
#[derive(Debug)]
pub struct Certificates {
configs: Vec<CertificateInfo>,
name_map: DashMap<String, SslContext>,
pub default_cert_path: String,
pub default_key_path: String,
}
impl Certificates {
pub fn new(configs: &Vec<CertificateConfig>) -> Self {
let default_cert = configs.first().expect("atleast one TLS certificate required");
pub fn new(configs: &Vec<CertificateConfig>) -> Option<Self> {
let default_cert = configs.first().expect("At least one TLS certificate required");
let mut cert_infos = Vec::new();
let name_map: DashMap<String, SslContext> = DashMap::new();
for config in configs {
cert_infos.push(
load_cert_info(&config.cert_path, &config.key_path)
.unwrap_or_else(|| panic!("unable to load certificate info | public: {}, private: {}", &config.cert_path, &config.key_path)),
);
let cert_info = load_cert_info(&config.cert_path, &config.key_path);
match cert_info {
Some(cert) => {
for name in &cert.common_names {
name_map.insert(name.clone(), cert.ssl_context.clone());
}
for name in &cert.alt_names {
name_map.insert(name.clone(), cert.ssl_context.clone());
}
cert_infos.push(cert)
}
None => {
error!("Unable to load certificate info | public: {}, private: {}", &config.cert_path, &config.key_path);
return None;
}
}
}
Self {
Some(Self {
name_map: name_map,
configs: cert_infos,
default_cert_path: default_cert.cert_path.clone(),
default_key_path: default_cert.key_path.clone(),
}
})
}
fn find_ssl_context(&self, server_name: &str) -> Option<&SslContext> {
fn find_ssl_context(&self, server_name: &str) -> Option<SslContext> {
if let Some(ctx) = self.name_map.get(server_name) {
return Some(ctx.clone());
}
for config in &self.configs {
// Exact name match
if config.common_names.contains(&server_name.to_string()) || config.alt_names.contains(&server_name.to_string()) {
return Some(&config.ssl_context);
}
// Wildcard match
for name in &config.common_names {
if name.starts_with("*.") && server_name.ends_with(&name[1..]) {
return Some(&config.ssl_context);
return Some(config.ssl_context.clone());
}
}
for name in &config.alt_names {
if name.starts_with("*.") && server_name.ends_with(&name[1..]) {
return Some(&config.ssl_context);
return Some(config.ssl_context.clone());
}
}
}
@@ -74,16 +90,18 @@ impl Certificates {
pub fn server_name_callback(&self, ssl_ref: &mut SslRef, ssl_alert: &mut SslAlert) -> Result<(), SniError> {
let server_name = ssl_ref.servername(NameType::HOST_NAME);
log::debug!("TLS connect: server_name = {:?}, ssl_ref = {:?}, ssl_alert = {:?}", server_name, ssl_ref, ssl_alert);
// let start_time = Instant::now();
if let Some(name) = server_name {
match self.find_ssl_context(name) {
Some(ctx) => {
ssl_ref.set_ssl_context(ctx).map_err(|_| SniError::ALERT_FATAL)?;
ssl_ref.set_ssl_context(&*ctx).map_err(|_| SniError::ALERT_FATAL)?;
}
None => {
log::debug!("No matching server name found");
}
}
}
// println!("Context ==> {:?} <==", start_time.elapsed());
Ok(())
}
}
@@ -160,7 +178,6 @@ fn load_cert_info(cert_path: &str, key_path: &str) -> Option<CertificateInfo> {
fn create_ssl_context(cert_path: &str, key_path: &str) -> Result<SslContext, Box<dyn std::error::Error>> {
let mut ctx = SslContext::builder(SslMethod::tls())?;
ctx.set_certificate_chain_file(cert_path)?;
ctx.set_private_key_file(key_path, SslFiletype::PEM)?;
ctx.set_alpn_select_callback(prefer_h2);

View File

@@ -1,12 +1,17 @@
use crate::utils::structs::{UpstreamsDashMap, UpstreamsIdMap};
use crate::utils::tls;
use crate::utils::tls::CertificateConfig;
use dashmap::DashMap;
use log::{error, info};
use notify::{event::ModifyKind, Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use sha2::{Digest, Sha256};
use std::any::type_name;
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::fs;
use std::sync::atomic::AtomicUsize;
use std::sync::mpsc::{channel, Sender};
use std::time::{Duration, Instant};
#[allow(dead_code)]
pub fn print_upstreams(upstreams: &UpstreamsDashMap) {
@@ -162,7 +167,7 @@ pub fn listdir(dir: String) -> Vec<tls::CertificateConfig> {
inner.push(name.clone() + ".crt");
inner.push(name.clone() + ".key");
f.insert(domain[domain.len() - 1].to_owned(), inner);
let y = tls::CertificateConfig {
let y = CertificateConfig {
cert_path: name.clone() + ".crt",
key_path: name.clone() + ".key",
};
@@ -170,7 +175,7 @@ pub fn listdir(dir: String) -> Vec<tls::CertificateConfig> {
}
}
for (_, v) in f.iter() {
let y = tls::CertificateConfig {
let y = CertificateConfig {
cert_path: v[0].clone(),
key_path: v[1].clone(),
};
@@ -178,3 +183,30 @@ pub fn listdir(dir: String) -> Vec<tls::CertificateConfig> {
}
certificate_configs
}
pub fn watch_folder(path: String, sender: Sender<Vec<CertificateConfig>>) -> notify::Result<()> {
let (tx, rx) = channel();
let mut watcher = RecommendedWatcher::new(tx, Config::default())?;
watcher.watch(path.as_ref(), RecursiveMode::Recursive)?;
info!("Watching for certificates in : {}", path);
let certificate_configs = listdir(path.clone());
sender.send(certificate_configs)?;
let mut start = Instant::now();
loop {
match rx.recv_timeout(Duration::from_secs(1)) {
Ok(Ok(event)) => match &event.kind {
EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(_) | EventKind::Remove(_) => {
if start.elapsed() > Duration::from_secs(1) {
start = Instant::now();
let certificate_configs = listdir(path.clone());
sender.send(certificate_configs)?;
info!("Certificate changed: {:?}, {:?}", event.kind, event.paths);
}
}
_ => {}
},
Ok(Err(e)) => error!("Watch error: {:?}", e),
Err(_) => {}
}
}
}