| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379 |
- use anyhow::{anyhow, bail, Context, Result};
- use serde::{Deserialize, Serialize};
- use std::collections::HashMap;
- use std::path::Path;
- use tokio::fs;
- #[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq)]
- pub enum TransportType {
- #[serde(rename = "tcp")]
- Tcp,
- #[serde(rename = "tls")]
- Tls,
- #[serde(rename = "noise")]
- Noise,
- }
- impl Default for TransportType {
- fn default() -> TransportType {
- TransportType::Tcp
- }
- }
- #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
- pub struct ClientServiceConfig {
- #[serde(rename = "type", default = "default_service_type")]
- pub service_type: ServiceType,
- #[serde(skip)]
- pub name: String,
- pub local_addr: String,
- pub token: Option<String>,
- }
- impl ClientServiceConfig {
- pub fn with_name(name: &str) -> ClientServiceConfig {
- ClientServiceConfig {
- name: name.to_string(),
- ..Default::default()
- }
- }
- }
- #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
- pub enum ServiceType {
- #[serde(rename = "tcp")]
- Tcp,
- #[serde(rename = "udp")]
- Udp,
- }
- impl Default for ServiceType {
- fn default() -> Self {
- ServiceType::Tcp
- }
- }
- fn default_service_type() -> ServiceType {
- Default::default()
- }
- #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
- pub struct ServerServiceConfig {
- #[serde(rename = "type", default = "default_service_type")]
- pub service_type: ServiceType,
- #[serde(skip)]
- pub name: String,
- pub bind_addr: String,
- pub token: Option<String>,
- }
- impl ServerServiceConfig {
- pub fn with_name(name: &str) -> ServerServiceConfig {
- ServerServiceConfig {
- name: name.to_string(),
- ..Default::default()
- }
- }
- }
- #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
- pub struct TlsConfig {
- pub hostname: Option<String>,
- pub trusted_root: Option<String>,
- pub pkcs12: Option<String>,
- pub pkcs12_password: Option<String>,
- }
/// Handshake pattern used when a noise config omits `pattern`.
fn default_noise_pattern() -> String {
    "Noise_NK_25519_ChaChaPoly_BLAKE2s".to_owned()
}
- #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
- pub struct NoiseConfig {
- #[serde(default = "default_noise_pattern")]
- pub pattern: String,
- pub local_private_key: Option<String>,
- pub remote_public_key: Option<String>,
- // TODO: Maybe psk can be added
- }
- #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
- pub struct TransportConfig {
- #[serde(rename = "type")]
- pub transport_type: TransportType,
- pub tls: Option<TlsConfig>,
- pub noise: Option<NoiseConfig>,
- }
- fn default_transport() -> TransportConfig {
- Default::default()
- }
- #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
- pub struct ClientConfig {
- pub remote_addr: String,
- pub default_token: Option<String>,
- pub services: HashMap<String, ClientServiceConfig>,
- #[serde(default = "default_transport")]
- pub transport: TransportConfig,
- }
- #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
- pub struct ServerConfig {
- pub bind_addr: String,
- pub default_token: Option<String>,
- pub services: HashMap<String, ServerServiceConfig>,
- #[serde(default = "default_transport")]
- pub transport: TransportConfig,
- }
- #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
- #[serde(deny_unknown_fields)]
- pub struct Config {
- pub server: Option<ServerConfig>,
- pub client: Option<ClientConfig>,
- }
- impl Config {
- fn from_str(s: &str) -> Result<Config> {
- let mut config: Config = toml::from_str(s).with_context(|| "Failed to parse the config")?;
- if let Some(server) = config.server.as_mut() {
- Config::validate_server_config(server)?;
- }
- if let Some(client) = config.client.as_mut() {
- Config::validate_client_config(client)?;
- }
- if config.server.is_none() && config.client.is_none() {
- Err(anyhow!("Neither of `[server]` or `[client]` is defined"))
- } else {
- Ok(config)
- }
- }
- fn validate_server_config(server: &mut ServerConfig) -> Result<()> {
- // Validate services
- for (name, s) in &mut server.services {
- s.name = name.clone();
- if s.token.is_none() {
- s.token = server.default_token.clone();
- if s.token.is_none() {
- bail!("The token of service {} is not set", name);
- }
- }
- }
- Config::validate_transport_config(&server.transport, true)?;
- Ok(())
- }
- fn validate_client_config(client: &mut ClientConfig) -> Result<()> {
- // Validate services
- for (name, s) in &mut client.services {
- s.name = name.clone();
- if s.token.is_none() {
- s.token = client.default_token.clone();
- if s.token.is_none() {
- bail!("The token of service {} is not set", name);
- }
- }
- }
- Config::validate_transport_config(&client.transport, false)?;
- Ok(())
- }
- fn validate_transport_config(config: &TransportConfig, is_server: bool) -> Result<()> {
- match config.transport_type {
- TransportType::Tcp => Ok(()),
- TransportType::Tls => {
- let tls_config = config
- .tls
- .as_ref()
- .ok_or(anyhow!("Missing TLS configuration"))?;
- if is_server {
- tls_config
- .pkcs12
- .as_ref()
- .and(tls_config.pkcs12_password.as_ref())
- .ok_or(anyhow!("Missing `pkcs12` or `pkcs12_password`"))?;
- } else {
- tls_config
- .trusted_root
- .as_ref()
- .ok_or(anyhow!("Missing `trusted_root`"))?;
- }
- Ok(())
- }
- TransportType::Noise => {
- // The check is done in transport
- Ok(())
- }
- }
- }
- pub async fn from_file(path: &Path) -> Result<Config> {
- let s: String = fs::read_to_string(path)
- .await
- .with_context(|| format!("Failed to read the config {:?}", path))?;
- Config::from_str(&s).with_context(|| {
- "Configuration is invalid. Please refer to the configuration specification."
- })
- }
- }
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsStr;
    use std::{fs, path::PathBuf};

    use anyhow::Result;

    /// Recursively lists every regular file below `root`.
    fn list_config_files<T: AsRef<Path>>(root: T) -> Result<Vec<PathBuf>> {
        let mut files = Vec::new();
        for entry in fs::read_dir(root)? {
            let path = entry?.path();
            if path.is_file() {
                files.push(path);
            } else if path.is_dir() {
                files.append(&mut list_config_files(path)?);
            }
        }
        Ok(files)
    }

    /// Lists every `*.toml` file below `./examples`.
    fn get_all_example_config() -> Result<Vec<PathBuf>> {
        // `Path::ends_with(".toml")` compares whole path components, so it only
        // matched files literally named `.toml` and the example test checked
        // nothing. Compare the file extension instead.
        Ok(list_config_files("./examples")?
            .into_iter()
            .filter(|x| x.extension() == Some(OsStr::new("toml")))
            .collect())
    }

    /// Reads and parses the config at `path`, naming the file in any error.
    fn parse_config_file(path: &Path) -> Result<Config> {
        let s = fs::read_to_string(path).with_context(|| format!("Failed to read {:?}", path))?;
        Config::from_str(&s).with_context(|| format!("Failed to parse {:?}", path))
    }

    #[test]
    fn test_example_config() -> Result<()> {
        for p in get_all_example_config()? {
            parse_config_file(&p)?;
        }
        Ok(())
    }

    #[test]
    fn test_valid_config() -> Result<()> {
        for p in list_config_files("tests/config_test/valid_config")? {
            parse_config_file(&p)?;
        }
        Ok(())
    }

    #[test]
    fn test_invalid_config() -> Result<()> {
        for p in list_config_files("tests/config_test/invalid_config")? {
            let s = fs::read_to_string(&p)?;
            assert!(Config::from_str(&s).is_err(), "{:?} should be rejected", p);
        }
        Ok(())
    }

    #[test]
    fn test_validate_server_config() -> Result<()> {
        let mut cfg = ServerConfig::default();
        cfg.services.insert(
            "foo1".into(),
            ServerServiceConfig {
                service_type: ServiceType::Tcp,
                name: "foo1".into(),
                bind_addr: "127.0.0.1:80".into(),
                token: None,
            },
        );

        // Missing the token
        assert!(Config::validate_server_config(&mut cfg).is_err());

        // Use the default token
        cfg.default_token = Some("123".into());
        Config::validate_server_config(&mut cfg)?;
        assert_eq!(cfg.services["foo1"].token.as_deref(), Some("123"));

        // The default token won't override the service token
        cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
        Config::validate_server_config(&mut cfg)?;
        assert_eq!(cfg.services["foo1"].token.as_deref(), Some("4"));
        Ok(())
    }

    #[test]
    fn test_validate_client_config() -> Result<()> {
        let mut cfg = ClientConfig::default();
        cfg.services.insert(
            "foo1".into(),
            ClientServiceConfig {
                service_type: ServiceType::Tcp,
                name: "foo1".into(),
                local_addr: "127.0.0.1:80".into(),
                token: None,
            },
        );

        // Missing the token
        assert!(Config::validate_client_config(&mut cfg).is_err());

        // Use the default token
        cfg.default_token = Some("123".into());
        Config::validate_client_config(&mut cfg)?;
        assert_eq!(cfg.services["foo1"].token.as_deref(), Some("123"));

        // The default token won't override the service token
        cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
        Config::validate_client_config(&mut cfg)?;
        assert_eq!(cfg.services["foo1"].token.as_deref(), Some("4"));
        Ok(())
    }
}
|