mirror of
https://github.com/sadoyan/aralez.git
synced 2026-04-30 23:08:40 +08:00
Dynamic load of SSL certificates from disk.
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
use openssl::ssl::{select_next_proto, AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef};
|
||||
use dashmap::DashMap;
|
||||
use log::error;
|
||||
use pingora::tls::ssl::{select_next_proto, AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef};
|
||||
use rustls_pemfile::{read_one, Item};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashSet;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
// use tokio::time::Instant;
|
||||
use x509_parser::extensions::GeneralName;
|
||||
use x509_parser::nom::Err as NomErr;
|
||||
use x509_parser::prelude::*;
|
||||
@@ -28,43 +31,56 @@ struct CertificateInfo {
|
||||
#[derive(Debug)]
|
||||
pub struct Certificates {
|
||||
configs: Vec<CertificateInfo>,
|
||||
name_map: DashMap<String, SslContext>,
|
||||
pub default_cert_path: String,
|
||||
pub default_key_path: String,
|
||||
}
|
||||
|
||||
impl Certificates {
|
||||
pub fn new(configs: &Vec<CertificateConfig>) -> Self {
|
||||
let default_cert = configs.first().expect("atleast one TLS certificate required");
|
||||
pub fn new(configs: &Vec<CertificateConfig>) -> Option<Self> {
|
||||
let default_cert = configs.first().expect("At least one TLS certificate required");
|
||||
let mut cert_infos = Vec::new();
|
||||
let name_map: DashMap<String, SslContext> = DashMap::new();
|
||||
for config in configs {
|
||||
cert_infos.push(
|
||||
load_cert_info(&config.cert_path, &config.key_path)
|
||||
.unwrap_or_else(|| panic!("unable to load certificate info | public: {}, private: {}", &config.cert_path, &config.key_path)),
|
||||
);
|
||||
let cert_info = load_cert_info(&config.cert_path, &config.key_path);
|
||||
match cert_info {
|
||||
Some(cert) => {
|
||||
for name in &cert.common_names {
|
||||
name_map.insert(name.clone(), cert.ssl_context.clone());
|
||||
}
|
||||
for name in &cert.alt_names {
|
||||
name_map.insert(name.clone(), cert.ssl_context.clone());
|
||||
}
|
||||
|
||||
cert_infos.push(cert)
|
||||
}
|
||||
None => {
|
||||
error!("Unable to load certificate info | public: {}, private: {}", &config.cert_path, &config.key_path);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
Self {
|
||||
Some(Self {
|
||||
name_map: name_map,
|
||||
configs: cert_infos,
|
||||
default_cert_path: default_cert.cert_path.clone(),
|
||||
default_key_path: default_cert.key_path.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn find_ssl_context(&self, server_name: &str) -> Option<&SslContext> {
|
||||
fn find_ssl_context(&self, server_name: &str) -> Option<SslContext> {
|
||||
if let Some(ctx) = self.name_map.get(server_name) {
|
||||
return Some(ctx.clone());
|
||||
}
|
||||
for config in &self.configs {
|
||||
// Exact name match
|
||||
if config.common_names.contains(&server_name.to_string()) || config.alt_names.contains(&server_name.to_string()) {
|
||||
return Some(&config.ssl_context);
|
||||
}
|
||||
|
||||
// Wildcard match
|
||||
for name in &config.common_names {
|
||||
if name.starts_with("*.") && server_name.ends_with(&name[1..]) {
|
||||
return Some(&config.ssl_context);
|
||||
return Some(config.ssl_context.clone());
|
||||
}
|
||||
}
|
||||
for name in &config.alt_names {
|
||||
if name.starts_with("*.") && server_name.ends_with(&name[1..]) {
|
||||
return Some(&config.ssl_context);
|
||||
return Some(config.ssl_context.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -74,16 +90,18 @@ impl Certificates {
|
||||
pub fn server_name_callback(&self, ssl_ref: &mut SslRef, ssl_alert: &mut SslAlert) -> Result<(), SniError> {
|
||||
let server_name = ssl_ref.servername(NameType::HOST_NAME);
|
||||
log::debug!("TLS connect: server_name = {:?}, ssl_ref = {:?}, ssl_alert = {:?}", server_name, ssl_ref, ssl_alert);
|
||||
// let start_time = Instant::now();
|
||||
if let Some(name) = server_name {
|
||||
match self.find_ssl_context(name) {
|
||||
Some(ctx) => {
|
||||
ssl_ref.set_ssl_context(ctx).map_err(|_| SniError::ALERT_FATAL)?;
|
||||
ssl_ref.set_ssl_context(&*ctx).map_err(|_| SniError::ALERT_FATAL)?;
|
||||
}
|
||||
None => {
|
||||
log::debug!("No matching server name found");
|
||||
}
|
||||
}
|
||||
}
|
||||
// println!("Context ==> {:?} <==", start_time.elapsed());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -160,7 +178,6 @@ fn load_cert_info(cert_path: &str, key_path: &str) -> Option<CertificateInfo> {
|
||||
|
||||
fn create_ssl_context(cert_path: &str, key_path: &str) -> Result<SslContext, Box<dyn std::error::Error>> {
|
||||
let mut ctx = SslContext::builder(SslMethod::tls())?;
|
||||
|
||||
ctx.set_certificate_chain_file(cert_path)?;
|
||||
ctx.set_private_key_file(key_path, SslFiletype::PEM)?;
|
||||
ctx.set_alpn_select_callback(prefer_h2);
|
||||
|
||||
Reference in New Issue
Block a user