config_watcher.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. use crate::{
  2. config::{ClientConfig, ClientServiceConfig, ServerConfig, ServerServiceConfig},
  3. Config,
  4. };
  5. use anyhow::{Context, Result};
  6. use std::{
  7. collections::HashMap,
  8. env,
  9. path::{Path, PathBuf},
  10. };
  11. use tokio::sync::{broadcast, mpsc};
  12. use tracing::{error, info, instrument};
  13. #[cfg(feature = "notify")]
  14. use notify::{EventKind, RecursiveMode, Watcher};
  15. #[derive(Debug, PartialEq, Eq, Clone)]
  16. pub enum ConfigChange {
  17. General(Box<Config>), // Trigger a full restart
  18. ServerChange(ServerServiceChange),
  19. ClientChange(ClientServiceChange),
  20. }
  21. #[derive(Debug, PartialEq, Eq, Clone)]
  22. pub enum ClientServiceChange {
  23. Add(ClientServiceConfig),
  24. Delete(String),
  25. }
  26. #[derive(Debug, PartialEq, Eq, Clone)]
  27. pub enum ServerServiceChange {
  28. Add(ServerServiceConfig),
  29. Delete(String),
  30. }
  31. trait InstanceConfig: Clone {
  32. type ServiceConfig: PartialEq + Eq + Clone;
  33. fn equal_without_service(&self, rhs: &Self) -> bool;
  34. fn service_delete_change(s: String) -> ConfigChange;
  35. fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange;
  36. fn get_services(&self) -> &HashMap<String, Self::ServiceConfig>;
  37. }
  38. impl InstanceConfig for ServerConfig {
  39. type ServiceConfig = ServerServiceConfig;
  40. fn equal_without_service(&self, rhs: &Self) -> bool {
  41. let left = ServerConfig {
  42. services: Default::default(),
  43. ..self.clone()
  44. };
  45. let right = ServerConfig {
  46. services: Default::default(),
  47. ..rhs.clone()
  48. };
  49. left == right
  50. }
  51. fn service_delete_change(s: String) -> ConfigChange {
  52. ConfigChange::ServerChange(ServerServiceChange::Delete(s))
  53. }
  54. fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange {
  55. ConfigChange::ServerChange(ServerServiceChange::Add(cfg))
  56. }
  57. fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
  58. &self.services
  59. }
  60. }
  61. impl InstanceConfig for ClientConfig {
  62. type ServiceConfig = ClientServiceConfig;
  63. fn equal_without_service(&self, rhs: &Self) -> bool {
  64. let left = ClientConfig {
  65. services: Default::default(),
  66. ..self.clone()
  67. };
  68. let right = ClientConfig {
  69. services: Default::default(),
  70. ..rhs.clone()
  71. };
  72. left == right
  73. }
  74. fn service_delete_change(s: String) -> ConfigChange {
  75. ConfigChange::ClientChange(ClientServiceChange::Delete(s))
  76. }
  77. fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange {
  78. ConfigChange::ClientChange(ClientServiceChange::Add(cfg))
  79. }
  80. fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
  81. &self.services
  82. }
  83. }
  84. pub struct ConfigWatcherHandle {
  85. pub event_rx: mpsc::UnboundedReceiver<ConfigChange>,
  86. }
  87. impl ConfigWatcherHandle {
  88. pub async fn new(path: &Path, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
  89. let (event_tx, event_rx) = mpsc::unbounded_channel();
  90. let origin_cfg = Config::from_file(path).await?;
  91. // Initial start
  92. event_tx
  93. .send(ConfigChange::General(Box::new(origin_cfg.clone())))
  94. .unwrap();
  95. tokio::spawn(config_watcher(
  96. path.to_owned(),
  97. shutdown_rx,
  98. event_tx,
  99. origin_cfg,
  100. ));
  101. Ok(ConfigWatcherHandle { event_rx })
  102. }
  103. }
  104. // Fake config watcher when compiling without `notify`
  105. #[cfg(not(feature = "notify"))]
  106. async fn config_watcher(
  107. _path: PathBuf,
  108. mut shutdown_rx: broadcast::Receiver<bool>,
  109. _event_tx: mpsc::UnboundedSender<ConfigChange>,
  110. _old: Config,
  111. ) -> Result<()> {
  112. // Do nothing except waiting for ctrl-c
  113. let _ = shutdown_rx.recv().await;
  114. Ok(())
  115. }
  #[cfg(feature = "notify")]
  #[instrument(skip(shutdown_rx, event_tx, old))]
  async fn config_watcher(
      path: PathBuf,
      mut shutdown_rx: broadcast::Receiver<bool>,
      event_tx: mpsc::UnboundedSender<ConfigChange>,
      mut old: Config,
  ) -> Result<()> {
      // Work with an absolute path so its parent directory is well defined.
      let path = if path.is_absolute() {
          path
      } else {
          env::current_dir()?.join(path)
      };
      let watch_dir = path
          .parent()
          .expect("config file should have a parent dir");
      // Owned copy of the file name, moved into the notify callback below.
      let target_name = path.file_name().map(|name| name.to_owned());

      // Bridges notify's synchronous callback into the async loop below.
      let (fevent_tx, mut fevent_rx) = mpsc::unbounded_channel();
      let mut watcher =
          notify::recommended_watcher(move |res: Result<notify::Event, _>| {
              let event = match res {
                  Ok(event) => event,
                  Err(e) => {
                      error!("watch error: {:#}", e);
                      return;
                  }
              };
              // The whole parent directory is watched, so keep only modifications
              // that touch a path with the config file's name.
              let is_config_write = matches!(event.kind, EventKind::Modify(_))
                  && event
                      .paths
                      .iter()
                      .any(|p| p.file_name() == target_name.as_deref());
              if is_config_write {
                  // Send only fails once the receiving loop is gone; nothing to do then.
                  let _ = fevent_tx.send(true);
              }
          })?;
      watcher.watch(watch_dir, RecursiveMode::NonRecursive)?;
      info!("Start watching the config");

      loop {
          tokio::select! {
              signal = fevent_rx.recv() => {
                  // Every sender is gone: no further file events can arrive.
                  if signal.is_none() {
                      break;
                  }
                  info!("Rescan the configuration");
                  let new = match Config::from_file(&path)
                      .await
                      .with_context(|| "The changed configuration is invalid. Ignored")
                  {
                      Ok(cfg) => cfg,
                      Err(e) => {
                          // Keep running on the last good configuration.
                          error!("{:#}", e);
                          continue;
                      }
                  };
                  for event in calculate_events(&old, &new).unwrap_or_default() {
                      event_tx.send(event)?;
                  }
                  old = new;
              },
              _ = shutdown_rx.recv() => break,
          }
      }

      info!("Config watcher exiting");
      Ok(())
  }
  177. fn calculate_events(old: &Config, new: &Config) -> Option<Vec<ConfigChange>> {
  178. if old == new {
  179. return None;
  180. }
  181. if (old.server.is_some() != new.server.is_some())
  182. || (old.client.is_some() != new.client.is_some())
  183. {
  184. return Some(vec![ConfigChange::General(Box::new(new.clone()))]);
  185. }
  186. let mut ret = vec![];
  187. if old.server != new.server {
  188. match calculate_instance_config_events(
  189. old.server.as_ref().unwrap(),
  190. new.server.as_ref().unwrap(),
  191. ) {
  192. Some(mut v) => ret.append(&mut v),
  193. None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]),
  194. }
  195. }
  196. if old.client != new.client {
  197. match calculate_instance_config_events(
  198. old.client.as_ref().unwrap(),
  199. new.client.as_ref().unwrap(),
  200. ) {
  201. Some(mut v) => ret.append(&mut v),
  202. None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]),
  203. }
  204. }
  205. Some(ret)
  206. }
  207. // None indicates a General change needed
  208. fn calculate_instance_config_events<T: InstanceConfig>(
  209. old: &T,
  210. new: &T,
  211. ) -> Option<Vec<ConfigChange>> {
  212. if !old.equal_without_service(new) {
  213. return None;
  214. }
  215. let old = old.get_services();
  216. let new = new.get_services();
  217. let deletions = old
  218. .keys()
  219. .filter(|&name| new.get(name).is_none())
  220. .map(|x| T::service_delete_change(x.to_owned()));
  221. let addition = new
  222. .iter()
  223. .filter(|(name, c)| old.get(*name) != Some(*c))
  224. .map(|(_, c)| T::service_add_change(c.clone()));
  225. Some(deletions.chain(addition).collect())
  226. }
  #[cfg(test)]
  mod test {
      use crate::config::ServerConfig;

      use super::*;

      // Builds any `FromIterator<(K, V)>` collection (in practice a `HashMap`)
      // from `key => value` pairs.
      macro_rules! collection {
          ($($k:expr => $v:expr),* $(,)?) => {{
              use std::iter::{Iterator, IntoIterator};
              Iterator::collect(IntoIterator::into_iter([$(($k, $v),)*]))
          }};
      }

      /// Concatenates `prefix` and `rest` into a new `String`.
      fn prefixed(prefix: &str, rest: &str) -> String {
          let mut key = String::from(prefix);
          key.push_str(rest);
          key
      }

      /// Stable sort key for an event, so event lists can be compared without
      /// depending on `HashMap` iteration order.
      fn sort_key(change: &ConfigChange) -> String {
          match change {
              ConfigChange::General(_) => String::from("g"),
              ConfigChange::ServerChange(ServerServiceChange::Add(c)) => {
                  prefixed("s_add_", &c.name)
              }
              ConfigChange::ServerChange(ServerServiceChange::Delete(s)) => prefixed("s_del_", s),
              ConfigChange::ClientChange(ClientServiceChange::Add(c)) => {
                  prefixed("c_add_", &c.name)
              }
              ConfigChange::ClientChange(ClientServiceChange::Delete(s)) => prefixed("c_del_", s),
          }
      }

      #[test]
      fn test_calculate_events() {
          struct Case {
              old: Config,
              new: Config,
          }

          let tests = [
              // A client section appears: full restart.
              Case {
                  old: Config {
                      server: Some(Default::default()),
                      client: None,
                  },
                  new: Config {
                      server: Some(Default::default()),
                      client: Some(Default::default()),
                  },
              },
              // A non-service server setting changes: full restart.
              Case {
                  old: Config {
                      server: Some(ServerConfig {
                          bind_addr: String::from("127.0.0.1:2334"),
                          ..Default::default()
                      }),
                      client: None,
                  },
                  new: Config {
                      server: Some(ServerConfig {
                          bind_addr: String::from("127.0.0.1:2333"),
                          services: collection!(String::from("foo") => Default::default()),
                          ..Default::default()
                      }),
                      client: None,
                  },
              },
              // A server service is added.
              Case {
                  old: Config {
                      server: Some(Default::default()),
                      client: None,
                  },
                  new: Config {
                      server: Some(ServerConfig {
                          services: collection!(String::from("foo") => Default::default()),
                          ..Default::default()
                      }),
                      client: None,
                  },
              },
              // A server service is removed.
              Case {
                  old: Config {
                      server: Some(ServerConfig {
                          services: collection!(String::from("foo") => Default::default()),
                          ..Default::default()
                      }),
                      client: None,
                  },
                  new: Config {
                      server: Some(Default::default()),
                      client: None,
                  },
              },
              // Mixed additions and removals on both sides; server `foo2` is unchanged.
              Case {
                  old: Config {
                      server: Some(ServerConfig {
                          services: collection!(String::from("foo1") => ServerServiceConfig::with_name("foo1"), String::from("foo2") => ServerServiceConfig::with_name("foo2")),
                          ..Default::default()
                      }),
                      client: Some(ClientConfig {
                          services: collection!(String::from("foo1") => ClientServiceConfig::with_name("foo1"), String::from("foo2") => ClientServiceConfig::with_name("foo2")),
                          ..Default::default()
                      }),
                  },
                  new: Config {
                      server: Some(ServerConfig {
                          services: collection!(String::from("bar1") => ServerServiceConfig::with_name("bar1"), String::from("foo2") => ServerServiceConfig::with_name("foo2")),
                          ..Default::default()
                      }),
                      client: Some(ClientConfig {
                          services: collection!(String::from("bar1") => ClientServiceConfig::with_name("bar1"), String::from("bar2") => ClientServiceConfig::with_name("bar2")),
                          ..Default::default()
                      }),
                  },
              },
          ];

          let mut expected = [
              vec![ConfigChange::General(Box::new(tests[0].new.clone()))],
              vec![ConfigChange::General(Box::new(tests[1].new.clone()))],
              vec![ConfigChange::ServerChange(ServerServiceChange::Add(
                  Default::default(),
              ))],
              vec![ConfigChange::ServerChange(ServerServiceChange::Delete(
                  String::from("foo"),
              ))],
              vec![
                  ConfigChange::ServerChange(ServerServiceChange::Delete(String::from("foo1"))),
                  ConfigChange::ServerChange(ServerServiceChange::Add(
                      tests[4].new.server.as_ref().unwrap().services["bar1"].clone(),
                  )),
                  ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo1"))),
                  ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo2"))),
                  ConfigChange::ClientChange(ClientServiceChange::Add(
                      tests[4].new.client.as_ref().unwrap().services["bar1"].clone(),
                  )),
                  ConfigChange::ClientChange(ClientServiceChange::Add(
                      tests[4].new.client.as_ref().unwrap().services["bar2"].clone(),
                  )),
              ],
          ];

          assert_eq!(tests.len(), expected.len());
          for (case, want) in tests.iter().zip(expected.iter_mut()) {
              let mut got = calculate_events(&case.old, &case.new).unwrap();
              got.sort_by_cached_key(sort_key);
              want.sort_by_cached_key(sort_key);
              assert_eq!(got, *want);
          }

          // Identical configs produce no events at all.
          assert_eq!(
              calculate_events(
                  &Config {
                      server: Default::default(),
                      client: None,
                  },
                  &Config {
                      server: Default::default(),
                      client: None,
                  },
              ),
              None
          );
      }
  }