config.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. use anyhow::{anyhow, bail, Context, Result};
  2. use serde::{Deserialize, Serialize};
  3. use std::collections::HashMap;
  4. use std::fmt::{Debug, Formatter};
  5. use std::ops::Deref;
  6. use std::path::Path;
  7. use tokio::fs;
  8. use url::Url;
  9. use crate::transport::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_SECS, DEFAULT_NODELAY};
  10. /// Application-layer heartbeat interval in secs
  11. const DEFAULT_HEARTBEAT_INTERVAL_SECS: u64 = 30;
  12. const DEFAULT_HEARTBEAT_TIMEOUT_SECS: u64 = 40;
  13. /// Client
  14. const DEFAULT_CLIENT_RETRY_INTERVAL_SECS: u64 = 1;
  15. /// String with Debug implementation that emits "MASKED"
  16. /// Used to mask sensitive strings when logging
  17. #[derive(Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
  18. pub struct MaskedString(String);
  19. impl Debug for MaskedString {
  20. fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
  21. f.write_str("MASKED")
  22. }
  23. }
  24. impl Deref for MaskedString {
  25. type Target = str;
  26. fn deref(&self) -> &Self::Target {
  27. &self.0
  28. }
  29. }
  30. impl From<&str> for MaskedString {
  31. fn from(s: &str) -> MaskedString {
  32. MaskedString(String::from(s))
  33. }
  34. }
  35. #[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq, Eq, Default)]
  36. pub enum TransportType {
  37. #[default]
  38. #[serde(rename = "tcp")]
  39. Tcp,
  40. #[serde(rename = "tls")]
  41. Tls,
  42. #[serde(rename = "noise")]
  43. Noise,
  44. #[serde(rename = "websocket")]
  45. Websocket,
  46. }
  47. /// Per service config
  48. /// All Option are optional in configuration but must be Some value in runtime
  49. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
  50. #[serde(deny_unknown_fields)]
  51. pub struct ClientServiceConfig {
  52. #[serde(rename = "type", default = "default_service_type")]
  53. pub service_type: ServiceType,
  54. #[serde(skip)]
  55. pub name: String,
  56. pub local_addr: String,
  57. #[serde(default)] // Default to false
  58. pub prefer_ipv6: bool,
  59. pub token: Option<MaskedString>,
  60. pub nodelay: Option<bool>,
  61. pub retry_interval: Option<u64>,
  62. }
  63. impl ClientServiceConfig {
  64. pub fn with_name(name: &str) -> ClientServiceConfig {
  65. ClientServiceConfig {
  66. name: name.to_string(),
  67. ..Default::default()
  68. }
  69. }
  70. }
  71. #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Default)]
  72. pub enum ServiceType {
  73. #[serde(rename = "tcp")]
  74. #[default]
  75. Tcp,
  76. #[serde(rename = "udp")]
  77. Udp,
  78. }
  79. fn default_service_type() -> ServiceType {
  80. Default::default()
  81. }
  82. /// Per service config
  83. /// All Option are optional in configuration but must be Some value in runtime
  84. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
  85. #[serde(deny_unknown_fields)]
  86. pub struct ServerServiceConfig {
  87. #[serde(rename = "type", default = "default_service_type")]
  88. pub service_type: ServiceType,
  89. #[serde(skip)]
  90. pub name: String,
  91. pub bind_addr: String,
  92. pub token: Option<MaskedString>,
  93. pub nodelay: Option<bool>,
  94. pub enable_proxy_protocol: Option<bool>,
  95. }
  96. impl ServerServiceConfig {
  97. pub fn with_name(name: &str) -> ServerServiceConfig {
  98. ServerServiceConfig {
  99. name: name.to_string(),
  100. ..Default::default()
  101. }
  102. }
  103. }
  104. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
  105. #[serde(deny_unknown_fields)]
  106. pub struct TlsConfig {
  107. pub hostname: Option<String>,
  108. pub trusted_root: Option<String>,
  109. pub pkcs12: Option<String>,
  110. pub pkcs12_password: Option<MaskedString>,
  111. }
  112. fn default_noise_pattern() -> String {
  113. String::from("Noise_NK_25519_ChaChaPoly_BLAKE2s")
  114. }
  115. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
  116. #[serde(deny_unknown_fields)]
  117. pub struct NoiseConfig {
  118. #[serde(default = "default_noise_pattern")]
  119. pub pattern: String,
  120. pub local_private_key: Option<MaskedString>,
  121. pub remote_public_key: Option<String>,
  122. // TODO: Maybe psk can be added
  123. }
  124. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
  125. #[serde(deny_unknown_fields)]
  126. pub struct WebsocketConfig {
  127. pub tls: bool,
  128. }
  129. fn default_nodelay() -> bool {
  130. DEFAULT_NODELAY
  131. }
  132. fn default_keepalive_secs() -> u64 {
  133. DEFAULT_KEEPALIVE_SECS
  134. }
  135. fn default_keepalive_interval() -> u64 {
  136. DEFAULT_KEEPALIVE_INTERVAL
  137. }
  138. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
  139. #[serde(deny_unknown_fields)]
  140. pub struct TcpConfig {
  141. #[serde(default = "default_nodelay")]
  142. pub nodelay: bool,
  143. #[serde(default = "default_keepalive_secs")]
  144. pub keepalive_secs: u64,
  145. #[serde(default = "default_keepalive_interval")]
  146. pub keepalive_interval: u64,
  147. pub proxy: Option<Url>,
  148. }
  149. impl Default for TcpConfig {
  150. fn default() -> Self {
  151. Self {
  152. nodelay: default_nodelay(),
  153. keepalive_secs: default_keepalive_secs(),
  154. keepalive_interval: default_keepalive_interval(),
  155. proxy: None,
  156. }
  157. }
  158. }
  159. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
  160. #[serde(deny_unknown_fields)]
  161. pub struct TransportConfig {
  162. #[serde(rename = "type")]
  163. pub transport_type: TransportType,
  164. #[serde(default)]
  165. pub tcp: TcpConfig,
  166. pub tls: Option<TlsConfig>,
  167. pub noise: Option<NoiseConfig>,
  168. pub websocket: Option<WebsocketConfig>,
  169. }
  170. fn default_heartbeat_timeout() -> u64 {
  171. DEFAULT_HEARTBEAT_TIMEOUT_SECS
  172. }
  173. fn default_client_retry_interval() -> u64 {
  174. DEFAULT_CLIENT_RETRY_INTERVAL_SECS
  175. }
  176. #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
  177. #[serde(deny_unknown_fields)]
  178. pub struct ClientConfig {
  179. pub remote_addr: String,
  180. pub default_token: Option<MaskedString>,
  181. pub prefer_ipv6: Option<bool>,
  182. pub services: HashMap<String, ClientServiceConfig>,
  183. #[serde(default)]
  184. pub transport: TransportConfig,
  185. #[serde(default = "default_heartbeat_timeout")]
  186. pub heartbeat_timeout: u64,
  187. #[serde(default = "default_client_retry_interval")]
  188. pub retry_interval: u64,
  189. }
  190. fn default_heartbeat_interval() -> u64 {
  191. DEFAULT_HEARTBEAT_INTERVAL_SECS
  192. }
  193. #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
  194. #[serde(deny_unknown_fields)]
  195. pub struct ServerConfig {
  196. pub bind_addr: String,
  197. pub default_token: Option<MaskedString>,
  198. pub services: HashMap<String, ServerServiceConfig>,
  199. #[serde(default)]
  200. pub transport: TransportConfig,
  201. #[serde(default = "default_heartbeat_interval")]
  202. pub heartbeat_interval: u64,
  203. }
  204. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
  205. #[serde(deny_unknown_fields)]
  206. pub struct Config {
  207. pub server: Option<ServerConfig>,
  208. pub client: Option<ClientConfig>,
  209. }
  210. impl Config {
  211. fn from_str(s: &str) -> Result<Config> {
  212. let mut config: Config = toml::from_str(s).with_context(|| "Failed to parse the config")?;
  213. if let Some(server) = config.server.as_mut() {
  214. Config::validate_server_config(server)?;
  215. }
  216. if let Some(client) = config.client.as_mut() {
  217. Config::validate_client_config(client)?;
  218. }
  219. if config.server.is_none() && config.client.is_none() {
  220. Err(anyhow!("Neither of `[server]` or `[client]` is defined"))
  221. } else {
  222. Ok(config)
  223. }
  224. }
  225. fn validate_server_config(server: &mut ServerConfig) -> Result<()> {
  226. // Validate services
  227. for (name, s) in &mut server.services {
  228. s.name = name.clone();
  229. if s.token.is_none() {
  230. s.token = server.default_token.clone();
  231. if s.token.is_none() {
  232. bail!("The token of service {} is not set", name);
  233. }
  234. }
  235. }
  236. Config::validate_transport_config(&server.transport, true)?;
  237. Ok(())
  238. }
  239. fn validate_client_config(client: &mut ClientConfig) -> Result<()> {
  240. // Validate services
  241. for (name, s) in &mut client.services {
  242. s.name = name.clone();
  243. if s.token.is_none() {
  244. s.token = client.default_token.clone();
  245. if s.token.is_none() {
  246. bail!("The token of service {} is not set", name);
  247. }
  248. }
  249. if s.retry_interval.is_none() {
  250. s.retry_interval = Some(client.retry_interval);
  251. }
  252. }
  253. Config::validate_transport_config(&client.transport, false)?;
  254. Ok(())
  255. }
  256. fn validate_transport_config(config: &TransportConfig, is_server: bool) -> Result<()> {
  257. config
  258. .tcp
  259. .proxy
  260. .as_ref()
  261. .map_or(Ok(()), |u| match u.scheme() {
  262. "socks5" => Ok(()),
  263. "http" => Ok(()),
  264. _ => Err(anyhow!(format!("Unknown proxy scheme: {}", u.scheme()))),
  265. })?;
  266. match config.transport_type {
  267. TransportType::Tcp => Ok(()),
  268. TransportType::Tls => {
  269. let tls_config = config
  270. .tls
  271. .as_ref()
  272. .ok_or_else(|| anyhow!("Missing TLS configuration"))?;
  273. if is_server {
  274. tls_config
  275. .pkcs12
  276. .as_ref()
  277. .and(tls_config.pkcs12_password.as_ref())
  278. .ok_or_else(|| anyhow!("Missing `pkcs12` or `pkcs12_password`"))?;
  279. }
  280. Ok(())
  281. }
  282. TransportType::Noise => {
  283. // The check is done in transport
  284. Ok(())
  285. }
  286. TransportType::Websocket => Ok(()),
  287. }
  288. }
  289. pub async fn from_file(path: &Path) -> Result<Config> {
  290. let s: String = fs::read_to_string(path)
  291. .await
  292. .with_context(|| format!("Failed to read the config {:?}", path))?;
  293. Config::from_str(&s).with_context(|| {
  294. "Configuration is invalid. Please refer to the configuration specification."
  295. })
  296. }
  297. }
  298. #[cfg(test)]
  299. mod tests {
  300. use super::*;
  301. use std::{fs, path::PathBuf};
  302. use anyhow::Result;
  303. fn list_config_files<T: AsRef<Path>>(root: T) -> Result<Vec<PathBuf>> {
  304. let mut files = Vec::new();
  305. for entry in fs::read_dir(root)? {
  306. let entry = entry?;
  307. let path = entry.path();
  308. if path.is_file() {
  309. files.push(path);
  310. } else if path.is_dir() {
  311. files.append(&mut list_config_files(path)?);
  312. }
  313. }
  314. Ok(files)
  315. }
  316. fn get_all_example_config() -> Result<Vec<PathBuf>> {
  317. Ok(list_config_files("./examples")?
  318. .into_iter()
  319. .filter(|x| x.ends_with(".toml"))
  320. .collect())
  321. }
  322. #[test]
  323. fn test_example_config() -> Result<()> {
  324. let paths = get_all_example_config()?;
  325. for p in paths {
  326. let s = fs::read_to_string(p)?;
  327. Config::from_str(&s)?;
  328. }
  329. Ok(())
  330. }
  331. #[test]
  332. fn test_valid_config() -> Result<()> {
  333. let paths = list_config_files("tests/config_test/valid_config")?;
  334. for p in paths {
  335. let s = fs::read_to_string(p)?;
  336. Config::from_str(&s)?;
  337. }
  338. Ok(())
  339. }
  340. #[test]
  341. fn test_invalid_config() -> Result<()> {
  342. let paths = list_config_files("tests/config_test/invalid_config")?;
  343. for p in paths {
  344. let s = fs::read_to_string(p)?;
  345. assert!(Config::from_str(&s).is_err());
  346. }
  347. Ok(())
  348. }
  349. #[test]
  350. fn test_validate_server_config() -> Result<()> {
  351. let mut cfg = ServerConfig::default();
  352. cfg.services.insert(
  353. "foo1".into(),
  354. ServerServiceConfig {
  355. service_type: ServiceType::Tcp,
  356. name: "foo1".into(),
  357. bind_addr: "127.0.0.1:80".into(),
  358. token: None,
  359. ..Default::default()
  360. },
  361. );
  362. // Missing the token
  363. assert!(Config::validate_server_config(&mut cfg).is_err());
  364. // Use the default token
  365. cfg.default_token = Some("123".into());
  366. assert!(Config::validate_server_config(&mut cfg).is_ok());
  367. assert_eq!(
  368. cfg.services
  369. .get("foo1")
  370. .as_ref()
  371. .unwrap()
  372. .token
  373. .as_ref()
  374. .unwrap()
  375. .0,
  376. "123"
  377. );
  378. // The default token won't override the service token
  379. cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
  380. assert!(Config::validate_server_config(&mut cfg).is_ok());
  381. assert_eq!(
  382. cfg.services
  383. .get("foo1")
  384. .as_ref()
  385. .unwrap()
  386. .token
  387. .as_ref()
  388. .unwrap()
  389. .0,
  390. "4"
  391. );
  392. Ok(())
  393. }
  394. #[test]
  395. fn test_validate_client_config() -> Result<()> {
  396. let mut cfg = ClientConfig::default();
  397. cfg.services.insert(
  398. "foo1".into(),
  399. ClientServiceConfig {
  400. service_type: ServiceType::Tcp,
  401. name: "foo1".into(),
  402. local_addr: "127.0.0.1:80".into(),
  403. token: None,
  404. ..Default::default()
  405. },
  406. );
  407. // Missing the token
  408. assert!(Config::validate_client_config(&mut cfg).is_err());
  409. // Use the default token
  410. cfg.default_token = Some("123".into());
  411. assert!(Config::validate_client_config(&mut cfg).is_ok());
  412. assert_eq!(
  413. cfg.services
  414. .get("foo1")
  415. .as_ref()
  416. .unwrap()
  417. .token
  418. .as_ref()
  419. .unwrap()
  420. .0,
  421. "123"
  422. );
  423. // The default token won't override the service token
  424. cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
  425. assert!(Config::validate_client_config(&mut cfg).is_ok());
  426. assert_eq!(
  427. cfg.services
  428. .get("foo1")
  429. .as_ref()
  430. .unwrap()
  431. .token
  432. .as_ref()
  433. .unwrap()
  434. .0,
  435. "4"
  436. );
  437. Ok(())
  438. }
  439. }