config.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. use anyhow::{anyhow, bail, Context, Result};
  2. use serde::{Deserialize, Serialize};
  3. use std::collections::HashMap;
  4. use std::fmt::{Debug, Formatter};
  5. use std::ops::Deref;
  6. use std::path::Path;
  7. use tokio::fs;
  8. use url::Url;
  9. use crate::transport::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_SECS, DEFAULT_NODELAY};
  /// Default application-layer heartbeat interval, in seconds.
  /// Used as the serde default for `ServerConfig::heartbeat_interval`.
  const DEFAULT_HEARTBEAT_INTERVAL_SECS: u64 = 30;
  /// Default application-layer heartbeat timeout, in seconds.
  /// Used as the serde default for `ClientConfig::heartbeat_timeout`.
  const DEFAULT_HEARTBEAT_TIMEOUT_SECS: u64 = 40;
  13. /// String with Debug implementation that emits "MASKED"
  14. /// Used to mask sensitive strings when logging
  15. #[derive(Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
  16. pub struct MaskedString(String);
  17. impl Debug for MaskedString {
  18. fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
  19. f.write_str("MASKED")
  20. }
  21. }
  22. impl Deref for MaskedString {
  23. type Target = str;
  24. fn deref(&self) -> &Self::Target {
  25. &self.0
  26. }
  27. }
  28. impl From<&str> for MaskedString {
  29. fn from(s: &str) -> MaskedString {
  30. MaskedString(String::from(s))
  31. }
  32. }
  33. #[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq, Eq)]
  34. pub enum TransportType {
  35. #[serde(rename = "tcp")]
  36. Tcp,
  37. #[serde(rename = "tls")]
  38. Tls,
  39. #[serde(rename = "noise")]
  40. Noise,
  41. }
  42. impl Default for TransportType {
  43. fn default() -> TransportType {
  44. TransportType::Tcp
  45. }
  46. }
  47. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
  48. #[serde(deny_unknown_fields)]
  49. pub struct ClientServiceConfig {
  50. #[serde(rename = "type", default = "default_service_type")]
  51. pub service_type: ServiceType,
  52. #[serde(skip)]
  53. pub name: String,
  54. pub local_addr: String,
  55. pub token: Option<MaskedString>,
  56. pub nodelay: Option<bool>,
  57. }
  58. impl ClientServiceConfig {
  59. pub fn with_name(name: &str) -> ClientServiceConfig {
  60. ClientServiceConfig {
  61. name: name.to_string(),
  62. ..Default::default()
  63. }
  64. }
  65. }
  66. #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
  67. pub enum ServiceType {
  68. #[serde(rename = "tcp")]
  69. Tcp,
  70. #[serde(rename = "udp")]
  71. Udp,
  72. }
  73. impl Default for ServiceType {
  74. fn default() -> Self {
  75. ServiceType::Tcp
  76. }
  77. }
  78. fn default_service_type() -> ServiceType {
  79. Default::default()
  80. }
  81. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
  82. #[serde(deny_unknown_fields)]
  83. pub struct ServerServiceConfig {
  84. #[serde(rename = "type", default = "default_service_type")]
  85. pub service_type: ServiceType,
  86. #[serde(skip)]
  87. pub name: String,
  88. pub bind_addr: String,
  89. pub token: Option<MaskedString>,
  90. pub nodelay: Option<bool>,
  91. }
  92. impl ServerServiceConfig {
  93. pub fn with_name(name: &str) -> ServerServiceConfig {
  94. ServerServiceConfig {
  95. name: name.to_string(),
  96. ..Default::default()
  97. }
  98. }
  99. }
  100. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
  101. #[serde(deny_unknown_fields)]
  102. pub struct TlsConfig {
  103. pub hostname: Option<String>,
  104. pub trusted_root: Option<String>,
  105. pub pkcs12: Option<String>,
  106. pub pkcs12_password: Option<MaskedString>,
  107. }
  /// Serde default for `NoiseConfig::pattern`.
  fn default_noise_pattern() -> String {
      "Noise_NK_25519_ChaChaPoly_BLAKE2s".to_owned()
  }
  111. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
  112. #[serde(deny_unknown_fields)]
  113. pub struct NoiseConfig {
  114. #[serde(default = "default_noise_pattern")]
  115. pub pattern: String,
  116. pub local_private_key: Option<MaskedString>,
  117. pub remote_public_key: Option<String>,
  118. // TODO: Maybe psk can be added
  119. }
  /// Serde default for `TcpConfig::nodelay`: the transport module's `DEFAULT_NODELAY`.
  fn default_nodelay() -> bool {
      DEFAULT_NODELAY
  }
  /// Serde default for `TcpConfig::keepalive_secs`: the transport module's
  /// `DEFAULT_KEEPALIVE_SECS`.
  fn default_keepalive_secs() -> u64 {
      DEFAULT_KEEPALIVE_SECS
  }
  /// Serde default for `TcpConfig::keepalive_interval`: the transport module's
  /// `DEFAULT_KEEPALIVE_INTERVAL`.
  fn default_keepalive_interval() -> u64 {
      DEFAULT_KEEPALIVE_INTERVAL
  }
  129. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
  130. #[serde(deny_unknown_fields)]
  131. pub struct TcpConfig {
  132. #[serde(default = "default_nodelay")]
  133. pub nodelay: bool,
  134. #[serde(default = "default_keepalive_secs")]
  135. pub keepalive_secs: u64,
  136. #[serde(default = "default_keepalive_interval")]
  137. pub keepalive_interval: u64,
  138. pub proxy: Option<Url>,
  139. }
  140. impl Default for TcpConfig {
  141. fn default() -> Self {
  142. Self {
  143. nodelay: default_nodelay(),
  144. keepalive_secs: default_keepalive_secs(),
  145. keepalive_interval: default_keepalive_interval(),
  146. proxy: None,
  147. }
  148. }
  149. }
  150. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
  151. #[serde(deny_unknown_fields)]
  152. pub struct TransportConfig {
  153. #[serde(rename = "type")]
  154. pub transport_type: TransportType,
  155. #[serde(default)]
  156. pub tcp: TcpConfig,
  157. pub tls: Option<TlsConfig>,
  158. pub noise: Option<NoiseConfig>,
  159. }
  /// Serde default for `ClientConfig::heartbeat_timeout`:
  /// `DEFAULT_HEARTBEAT_TIMEOUT_SECS`.
  fn default_heartbeat_timeout() -> u64 {
      DEFAULT_HEARTBEAT_TIMEOUT_SECS
  }
  163. #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
  164. #[serde(deny_unknown_fields)]
  165. pub struct ClientConfig {
  166. pub remote_addr: String,
  167. pub default_token: Option<MaskedString>,
  168. pub services: HashMap<String, ClientServiceConfig>,
  169. #[serde(default)]
  170. pub transport: TransportConfig,
  171. #[serde(default = "default_heartbeat_timeout")]
  172. pub heartbeat_timeout: u64,
  173. }
  /// Serde default for `ServerConfig::heartbeat_interval`:
  /// `DEFAULT_HEARTBEAT_INTERVAL_SECS`.
  fn default_heartbeat_interval() -> u64 {
      DEFAULT_HEARTBEAT_INTERVAL_SECS
  }
  177. #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
  178. #[serde(deny_unknown_fields)]
  179. pub struct ServerConfig {
  180. pub bind_addr: String,
  181. pub default_token: Option<MaskedString>,
  182. pub services: HashMap<String, ServerServiceConfig>,
  183. #[serde(default)]
  184. pub transport: TransportConfig,
  185. #[serde(default = "default_heartbeat_interval")]
  186. pub heartbeat_interval: u64,
  187. }
  188. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
  189. #[serde(deny_unknown_fields)]
  190. pub struct Config {
  191. pub server: Option<ServerConfig>,
  192. pub client: Option<ClientConfig>,
  193. }
  194. impl Config {
  195. fn from_str(s: &str) -> Result<Config> {
  196. let mut config: Config = toml::from_str(s).with_context(|| "Failed to parse the config")?;
  197. if let Some(server) = config.server.as_mut() {
  198. Config::validate_server_config(server)?;
  199. }
  200. if let Some(client) = config.client.as_mut() {
  201. Config::validate_client_config(client)?;
  202. }
  203. if config.server.is_none() && config.client.is_none() {
  204. Err(anyhow!("Neither of `[server]` or `[client]` is defined"))
  205. } else {
  206. Ok(config)
  207. }
  208. }
  209. fn validate_server_config(server: &mut ServerConfig) -> Result<()> {
  210. // Validate services
  211. for (name, s) in &mut server.services {
  212. s.name = name.clone();
  213. if s.token.is_none() {
  214. s.token = server.default_token.clone();
  215. if s.token.is_none() {
  216. bail!("The token of service {} is not set", name);
  217. }
  218. }
  219. }
  220. Config::validate_transport_config(&server.transport, true)?;
  221. Ok(())
  222. }
  223. fn validate_client_config(client: &mut ClientConfig) -> Result<()> {
  224. // Validate services
  225. for (name, s) in &mut client.services {
  226. s.name = name.clone();
  227. if s.token.is_none() {
  228. s.token = client.default_token.clone();
  229. if s.token.is_none() {
  230. bail!("The token of service {} is not set", name);
  231. }
  232. }
  233. }
  234. Config::validate_transport_config(&client.transport, false)?;
  235. Ok(())
  236. }
  237. fn validate_transport_config(config: &TransportConfig, is_server: bool) -> Result<()> {
  238. config
  239. .tcp
  240. .proxy
  241. .as_ref()
  242. .map_or(Ok(()), |u| match u.scheme() {
  243. "socks5" => Ok(()),
  244. "http" => Ok(()),
  245. _ => Err(anyhow!(format!("Unknown proxy scheme: {}", u.scheme()))),
  246. })?;
  247. match config.transport_type {
  248. TransportType::Tcp => Ok(()),
  249. TransportType::Tls => {
  250. let tls_config = config
  251. .tls
  252. .as_ref()
  253. .ok_or_else(|| anyhow!("Missing TLS configuration"))?;
  254. if is_server {
  255. tls_config
  256. .pkcs12
  257. .as_ref()
  258. .and(tls_config.pkcs12_password.as_ref())
  259. .ok_or_else(|| anyhow!("Missing `pkcs12` or `pkcs12_password`"))?;
  260. }
  261. Ok(())
  262. }
  263. TransportType::Noise => {
  264. // The check is done in transport
  265. Ok(())
  266. }
  267. }
  268. }
  269. pub async fn from_file(path: &Path) -> Result<Config> {
  270. let s: String = fs::read_to_string(path)
  271. .await
  272. .with_context(|| format!("Failed to read the config {:?}", path))?;
  273. Config::from_str(&s).with_context(|| {
  274. "Configuration is invalid. Please refer to the configuration specification."
  275. })
  276. }
  277. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use anyhow::Result;
      use std::{fs, path::PathBuf};

      /// Recursively collects the paths of all regular files under `root`.
      fn list_config_files<T: AsRef<Path>>(root: T) -> Result<Vec<PathBuf>> {
          let mut files = Vec::new();
          for entry in fs::read_dir(root)? {
              let path = entry?.path();
              if path.is_file() {
                  files.push(path);
              } else if path.is_dir() {
                  files.extend(list_config_files(path)?);
              }
          }
          Ok(files)
      }

      /// Returns every `*.toml` file under `./examples`.
      fn get_all_example_config() -> Result<Vec<PathBuf>> {
          Ok(list_config_files("./examples")?
              .into_iter()
              // `Path::ends_with` compares whole path components, so
              // `ends_with(".toml")` only matches a file literally named
              // `.toml`. Compare the extension to select the example configs.
              .filter(|p| p.extension().map_or(false, |ext| ext == "toml"))
              .collect())
      }

      /// Every example configuration must parse and validate.
      #[test]
      fn test_example_config() -> Result<()> {
          for p in get_all_example_config()? {
              let s = fs::read_to_string(&p)?;
              Config::from_str(&s).with_context(|| format!("Example config {:?} is invalid", p))?;
          }
          Ok(())
      }

      /// Every file under `valid_config` must parse and validate.
      #[test]
      fn test_valid_config() -> Result<()> {
          for p in list_config_files("tests/config_test/valid_config")? {
              let s = fs::read_to_string(&p)?;
              Config::from_str(&s).with_context(|| format!("Config {:?} should be valid", p))?;
          }
          Ok(())
      }

      /// Every file under `invalid_config` must be rejected.
      #[test]
      fn test_invalid_config() -> Result<()> {
          for p in list_config_files("tests/config_test/invalid_config")? {
              let s = fs::read_to_string(&p)?;
              assert!(
                  Config::from_str(&s).is_err(),
                  "Config {:?} should be rejected",
                  p
              );
          }
          Ok(())
      }

      /// Token resolution for server services: error without any token,
      /// fallback to the default token, and no override of an explicit token.
      #[test]
      fn test_validate_server_config() -> Result<()> {
          let mut cfg = ServerConfig::default();

          cfg.services.insert(
              "foo1".into(),
              ServerServiceConfig {
                  service_type: ServiceType::Tcp,
                  name: "foo1".into(),
                  bind_addr: "127.0.0.1:80".into(),
                  token: None,
                  ..Default::default()
              },
          );

          // Missing the token
          assert!(Config::validate_server_config(&mut cfg).is_err());

          // Use the default token
          cfg.default_token = Some("123".into());
          assert!(Config::validate_server_config(&mut cfg).is_ok());
          assert_eq!(cfg.services["foo1"].token.as_deref(), Some("123"));

          // The default token won't override the service token
          cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
          assert!(Config::validate_server_config(&mut cfg).is_ok());
          assert_eq!(cfg.services["foo1"].token.as_deref(), Some("4"));

          Ok(())
      }

      /// Token resolution for client services: error without any token,
      /// fallback to the default token, and no override of an explicit token.
      #[test]
      fn test_validate_client_config() -> Result<()> {
          let mut cfg = ClientConfig::default();

          cfg.services.insert(
              "foo1".into(),
              ClientServiceConfig {
                  service_type: ServiceType::Tcp,
                  name: "foo1".into(),
                  local_addr: "127.0.0.1:80".into(),
                  token: None,
                  ..Default::default()
              },
          );

          // Missing the token
          assert!(Config::validate_client_config(&mut cfg).is_err());

          // Use the default token
          cfg.default_token = Some("123".into());
          assert!(Config::validate_client_config(&mut cfg).is_ok());
          assert_eq!(cfg.services["foo1"].token.as_deref(), Some("123"));

          // The default token won't override the service token
          cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
          assert!(Config::validate_client_config(&mut cfg).is_ok());
          assert_eq!(cfg.services["foo1"].token.as_deref(), Some("4"));

          Ok(())
      }
  }