Parcourir la source

refactor: ConfigChange (#191)

Yujia Qiao il y a 3 ans
Parent
commit
ea01c42da7
4 fichiers modifiés avec 146 ajouts et 137 suppressions
  1. 29 23
      src/client.rs
  2. 87 85
      src/config_watcher.rs
  3. 5 6
      src/lib.rs
  4. 25 23
      src/server.rs

+ 29 - 23
src/client.rs

@@ -1,5 +1,5 @@
 use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
-use crate::config_watcher::ServiceChange;
+use crate::config_watcher::{ClientServiceChange, ConfigChange};
 use crate::helper::udp_connect;
 use crate::protocol::Hello::{self, *};
 use crate::protocol::{
@@ -31,7 +31,7 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE
 pub async fn run_client(
     config: Config,
     shutdown_rx: broadcast::Receiver<bool>,
-    service_rx: mpsc::Receiver<ServiceChange>,
+    update_rx: mpsc::Receiver<ConfigChange>,
 ) -> Result<()> {
     let config = config.client.ok_or_else(|| {
         anyhow!(
@@ -42,13 +42,13 @@ pub async fn run_client(
     match config.transport.transport_type {
         TransportType::Tcp => {
             let mut client = Client::<TcpTransport>::from(config).await?;
-            client.run(shutdown_rx, service_rx).await
+            client.run(shutdown_rx, update_rx).await
         }
         TransportType::Tls => {
             #[cfg(feature = "tls")]
             {
                 let mut client = Client::<TlsTransport>::from(config).await?;
-                client.run(shutdown_rx, service_rx).await
+                client.run(shutdown_rx, update_rx).await
             }
             #[cfg(not(feature = "tls"))]
             crate::helper::feature_not_compile("tls")
@@ -57,7 +57,7 @@ pub async fn run_client(
             #[cfg(feature = "noise")]
             {
                 let mut client = Client::<NoiseTransport>::from(config).await?;
-                client.run(shutdown_rx, service_rx).await
+                client.run(shutdown_rx, update_rx).await
             }
             #[cfg(not(feature = "noise"))]
             crate::helper::feature_not_compile("noise")
@@ -91,7 +91,7 @@ impl<T: 'static + Transport> Client<T> {
     async fn run(
         &mut self,
         mut shutdown_rx: broadcast::Receiver<bool>,
-        mut service_rx: mpsc::Receiver<ServiceChange>,
+        mut update_rx: mpsc::Receiver<ConfigChange>,
     ) -> Result<()> {
         for (name, config) in &self.config.services {
             // Create a control channel for each service defined
@@ -116,24 +116,9 @@ impl<T: 'static + Transport> Client<T> {
                     }
                     break;
                 },
-                e = service_rx.recv() => {
+                e = update_rx.recv() => {
                     if let Some(e) = e {
-                        match e {
-                            ServiceChange::ClientAdd(s)=> {
-                                let name = s.name.clone();
-                                let handle = ControlChannelHandle::new(
-                                    s,
-                                    self.config.remote_addr.clone(),
-                                    self.transport.clone(),
-                                    self.config.heartbeat_timeout
-                                );
-                                let _ = self.service_handles.insert(name, handle);
-                            },
-                            ServiceChange::ClientDelete(s)=> {
-                                let _ = self.service_handles.remove(&s);
-                            },
-                            _ => ()
-                        }
+                        self.handle_hot_reload(e).await;
                     }
                 }
             }
@@ -146,6 +131,27 @@ impl<T: 'static + Transport> Client<T> {
 
         Ok(())
     }
+
+    async fn handle_hot_reload(&mut self, e: ConfigChange) {
+        match e {
+            ConfigChange::ClientChange(client_change) => match client_change {
+                ClientServiceChange::Add(cfg) => {
+                    let name = cfg.name.clone();
+                    let handle = ControlChannelHandle::new(
+                        cfg,
+                        self.config.remote_addr.clone(),
+                        self.transport.clone(),
+                        self.config.heartbeat_timeout,
+                    );
+                    let _ = self.service_handles.insert(name, handle);
+                }
+                ClientServiceChange::Delete(s) => {
+                    let _ = self.service_handles.remove(&s);
+                }
+            },
+            ignored => warn!("Ignored {:?} since running as a client", ignored),
+        }
+    }
 }
 
 struct RunDataChannelArgs<T: Transport> {

+ 87 - 85
src/config_watcher.rs

@@ -14,36 +14,30 @@ use tracing::{error, info, instrument};
 #[cfg(feature = "notify")]
 use notify::{EventKind, RecursiveMode, Watcher};
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub enum ConfigChange {
     General(Box<Config>), // Trigger a full restart
-    ServiceChange(ServiceChange),
+    ServerChange(ServerServiceChange),
+    ClientChange(ClientServiceChange),
 }
 
-#[derive(Debug, PartialEq, Eq)]
-pub enum ServiceChange {
-    ClientAdd(ClientServiceConfig),
-    ClientDelete(String),
-    ServerAdd(ServerServiceConfig),
-    ServerDelete(String),
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum ClientServiceChange {
+    Add(ClientServiceConfig),
+    Delete(String),
 }
 
-impl From<ClientServiceConfig> for ServiceChange {
-    fn from(c: ClientServiceConfig) -> Self {
-        ServiceChange::ClientAdd(c)
-    }
-}
-
-impl From<ServerServiceConfig> for ServiceChange {
-    fn from(c: ServerServiceConfig) -> Self {
-        ServiceChange::ServerAdd(c)
-    }
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum ServerServiceChange {
+    Add(ServerServiceConfig),
+    Delete(String),
 }
 
 trait InstanceConfig: Clone {
-    type ServiceConfig: Into<ServiceChange> + PartialEq + Clone;
+    type ServiceConfig: PartialEq + Eq + Clone;
     fn equal_without_service(&self, rhs: &Self) -> bool;
-    fn to_service_change_delete(s: String) -> ServiceChange;
+    fn service_delete_change(s: String) -> ConfigChange;
+    fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange;
     fn get_services(&self) -> &HashMap<String, Self::ServiceConfig>;
 }
 
@@ -62,8 +56,11 @@ impl InstanceConfig for ServerConfig {
 
         left == right
     }
-    fn to_service_change_delete(s: String) -> ServiceChange {
-        ServiceChange::ServerDelete(s)
+    fn service_delete_change(s: String) -> ConfigChange {
+        ConfigChange::ServerChange(ServerServiceChange::Delete(s))
+    }
+    fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange {
+        ConfigChange::ServerChange(ServerServiceChange::Add(cfg))
     }
     fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
         &self.services
@@ -85,8 +82,11 @@ impl InstanceConfig for ClientConfig {
 
         left == right
     }
-    fn to_service_change_delete(s: String) -> ServiceChange {
-        ServiceChange::ClientDelete(s)
+    fn service_delete_change(s: String) -> ConfigChange {
+        ConfigChange::ClientChange(ClientServiceChange::Delete(s))
+    }
+    fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange {
+        ConfigChange::ClientChange(ClientServiceChange::Add(cfg))
     }
     fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
         &self.services
@@ -180,8 +180,9 @@ async fn config_watcher(
                       }
                     };
 
-                    for event in calculate_events(&old, &new) {
-                      event_tx.send(event)?;
+                    let events = calculate_events(&old, &new).into_iter().flatten();
+                    for event in events {
+                        event_tx.send(event)?;
                     }
 
                     old = new;
@@ -198,42 +199,40 @@ async fn config_watcher(
     Ok(())
 }
 
-fn calculate_events(old: &Config, new: &Config) -> Vec<ConfigChange> {
+fn calculate_events(old: &Config, new: &Config) -> Option<Vec<ConfigChange>> {
     if old == new {
-        return vec![];
+        return None;
+    }
+
+    if (old.server.is_some() != new.server.is_some())
+        || (old.client.is_some() != new.client.is_some())
+    {
+        return Some(vec![ConfigChange::General(Box::new(new.clone()))]);
     }
 
     let mut ret = vec![];
 
     if old.server != new.server {
-        if old.server.is_some() != new.server.is_some() {
-            return vec![ConfigChange::General(Box::new(new.clone()))];
-        } else {
-            match calculate_instance_config_events(
-                old.server.as_ref().unwrap(),
-                new.server.as_ref().unwrap(),
-            ) {
-                Some(mut v) => ret.append(&mut v),
-                None => return vec![ConfigChange::General(Box::new(new.clone()))],
-            }
+        match calculate_instance_config_events(
+            old.server.as_ref().unwrap(),
+            new.server.as_ref().unwrap(),
+        ) {
+            Some(mut v) => ret.append(&mut v),
+            None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]),
         }
     }
 
     if old.client != new.client {
-        if old.client.is_some() != new.client.is_some() {
-            return vec![ConfigChange::General(Box::new(new.clone()))];
-        } else {
-            match calculate_instance_config_events(
-                old.client.as_ref().unwrap(),
-                new.client.as_ref().unwrap(),
-            ) {
-                Some(mut v) => ret.append(&mut v),
-                None => return vec![ConfigChange::General(Box::new(new.clone()))],
-            }
+        match calculate_instance_config_events(
+            old.client.as_ref().unwrap(),
+            new.client.as_ref().unwrap(),
+        ) {
+            Some(mut v) => ret.append(&mut v),
+            None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]),
         }
     }
 
-    ret
+    Some(ret)
 }
 
 // None indicates a General change needed
@@ -248,31 +247,17 @@ fn calculate_instance_config_events<T: InstanceConfig>(
     let old = old.get_services();
     let new = new.get_services();
 
-    let mut v = vec![];
-    v.append(&mut calculate_service_delete_events::<T>(old, new));
-    v.append(&mut calculate_service_add_events(old, new));
-
-    Some(v.into_iter().map(ConfigChange::ServiceChange).collect())
-}
-
-fn calculate_service_delete_events<T: InstanceConfig>(
-    old: &HashMap<String, T::ServiceConfig>,
-    new: &HashMap<String, T::ServiceConfig>,
-) -> Vec<ServiceChange> {
-    old.keys()
+    let deletions = old
+        .keys()
         .filter(|&name| new.get(name).is_none())
-        .map(|x| T::to_service_change_delete(x.to_owned()))
-        .collect()
-}
+        .map(|x| T::service_delete_change(x.to_owned()));
 
-fn calculate_service_add_events<T: PartialEq + Clone + Into<ServiceChange>>(
-    old: &HashMap<String, T>,
-    new: &HashMap<String, T>,
-) -> Vec<ServiceChange> {
-    new.iter()
+    let addition = new
+        .iter()
         .filter(|(name, c)| old.get(*name) != Some(*c))
-        .map(|(_, c)| c.clone().into())
-        .collect()
+        .map(|(_, c)| T::service_add_change(c.clone()));
+
+    Some(deletions.chain(addition).collect())
 }
 
 #[cfg(test)]
@@ -378,23 +363,23 @@ mod test {
         let mut expected = [
             vec![ConfigChange::General(Box::new(tests[0].new.clone()))],
             vec![ConfigChange::General(Box::new(tests[1].new.clone()))],
-            vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd(
+            vec![ConfigChange::ServerChange(ServerServiceChange::Add(
                 Default::default(),
             ))],
-            vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete(
+            vec![ConfigChange::ServerChange(ServerServiceChange::Delete(
                 String::from("foo"),
             ))],
             vec![
-                ConfigChange::ServiceChange(ServiceChange::ServerDelete(String::from("foo1"))),
-                ConfigChange::ServiceChange(ServiceChange::ServerAdd(
+                ConfigChange::ServerChange(ServerServiceChange::Delete(String::from("foo1"))),
+                ConfigChange::ServerChange(ServerServiceChange::Add(
                     tests[4].new.server.as_ref().unwrap().services["bar1"].clone(),
                 )),
-                ConfigChange::ServiceChange(ServiceChange::ClientDelete(String::from("foo1"))),
-                ConfigChange::ServiceChange(ServiceChange::ClientDelete(String::from("foo2"))),
-                ConfigChange::ServiceChange(ServiceChange::ClientAdd(
+                ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo1"))),
+                ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo2"))),
+                ConfigChange::ClientChange(ClientServiceChange::Add(
                     tests[4].new.client.as_ref().unwrap().services["bar1"].clone(),
                 )),
-                ConfigChange::ServiceChange(ServiceChange::ClientAdd(
+                ConfigChange::ClientChange(ClientServiceChange::Add(
                     tests[4].new.client.as_ref().unwrap().services["bar2"].clone(),
                 )),
             ],
@@ -403,16 +388,18 @@ mod test {
         assert_eq!(tests.len(), expected.len());
 
         for i in 0..tests.len() {
-            let mut actual = calculate_events(&tests[i].old, &tests[i].new);
+            let mut actual = calculate_events(&tests[i].old, &tests[i].new).unwrap();
 
             let get_key = |x: &ConfigChange| -> String {
                 match x {
                     ConfigChange::General(_) => String::from("g"),
-                    ConfigChange::ServiceChange(sc) => match sc {
-                        ServiceChange::ClientAdd(c) => "c_add_".to_owned() + &c.name,
-                        ServiceChange::ClientDelete(s) => "c_del_".to_owned() + s,
-                        ServiceChange::ServerAdd(c) => "s_add_".to_owned() + &c.name,
-                        ServiceChange::ServerDelete(s) => "s_del_".to_owned() + s,
+                    ConfigChange::ServerChange(sc) => match sc {
+                        ServerServiceChange::Add(c) => "s_add_".to_owned() + &c.name,
+                        ServerServiceChange::Delete(s) => "s_del_".to_owned() + s,
+                    },
+                    ConfigChange::ClientChange(sc) => match sc {
+                        ClientServiceChange::Add(c) => "c_add_".to_owned() + &c.name,
+                        ClientServiceChange::Delete(s) => "c_del_".to_owned() + s,
                     },
                 }
             };
@@ -422,5 +409,20 @@ mod test {
 
             assert_eq!(actual, expected[i]);
         }
+
+        // No changes
+        assert_eq!(
+            calculate_events(
+                &Config {
+                    server: Default::default(),
+                    client: None,
+                },
+                &Config {
+                    server: Default::default(),
+                    client: None,
+                },
+            ),
+            None
+        );
     }
 }

+ 5 - 6
src/lib.rs

@@ -10,7 +10,6 @@ mod transport;
 pub use cli::Cli;
 use cli::KeypairType;
 pub use config::Config;
-use config_watcher::ServiceChange;
 pub use constants::UDP_BUFFER_SIZE;
 
 use anyhow::Result;
@@ -76,7 +75,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
     let (shutdown_tx, _) = broadcast::channel(1);
 
     // (The join handle of the last instance, The service update channel sender)
-    let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChange>)> = None;
+    let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ConfigChange>)> = None;
 
     while let Some(e) = cfg_watcher.event_rx.recv().await {
         match e {
@@ -101,10 +100,10 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
                     service_update_tx,
                 ));
             }
-            ConfigChange::ServiceChange(service_event) => {
-                info!("Service change detcted. {:?}", service_event);
+            ev => {
+                info!("Service change detected. {:?}", ev);
                 if let Some((_, service_update_tx)) = &last_instance {
-                    let _ = service_update_tx.send(service_event).await;
+                    let _ = service_update_tx.send(ev).await;
                 }
             }
         }
@@ -119,7 +118,7 @@ async fn run_instance(
     config: Config,
     args: Cli,
     shutdown_rx: broadcast::Receiver<bool>,
-    service_update: mpsc::Receiver<ServiceChange>,
+    service_update: mpsc::Receiver<ConfigChange>,
 ) {
     let ret: Result<()> = match determine_run_mode(&config, &args) {
         RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),

+ 25 - 23
src/server.rs

@@ -1,5 +1,5 @@
 use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
-use crate::config_watcher::ServiceChange;
+use crate::config_watcher::{ConfigChange, ServerServiceChange};
 use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
 use crate::helper::retry_notify_with_deadline;
 use crate::multi_map::MultiMap;
@@ -40,7 +40,7 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake
 pub async fn run_server(
     config: Config,
     shutdown_rx: broadcast::Receiver<bool>,
-    service_rx: mpsc::Receiver<ServiceChange>,
+    update_rx: mpsc::Receiver<ConfigChange>,
 ) -> Result<()> {
     let config = match config.server {
             Some(config) => config,
@@ -52,13 +52,13 @@ pub async fn run_server(
     match config.transport.transport_type {
         TransportType::Tcp => {
             let mut server = Server::<TcpTransport>::from(config).await?;
-            server.run(shutdown_rx, service_rx).await?;
+            server.run(shutdown_rx, update_rx).await?;
         }
         TransportType::Tls => {
             #[cfg(feature = "tls")]
             {
                 let mut server = Server::<TlsTransport>::from(config).await?;
-                server.run(shutdown_rx, service_rx).await?;
+                server.run(shutdown_rx, update_rx).await?;
             }
             #[cfg(not(feature = "tls"))]
             crate::helper::feature_not_compile("tls")
@@ -67,7 +67,7 @@ pub async fn run_server(
             #[cfg(feature = "noise")]
             {
                 let mut server = Server::<NoiseTransport>::from(config).await?;
-                server.run(shutdown_rx, service_rx).await?;
+                server.run(shutdown_rx, update_rx).await?;
             }
             #[cfg(not(feature = "noise"))]
             crate::helper::feature_not_compile("noise")
@@ -124,7 +124,7 @@ impl<T: 'static + Transport> Server<T> {
     pub async fn run(
         &mut self,
         mut shutdown_rx: broadcast::Receiver<bool>,
-        mut service_rx: mpsc::Receiver<ServiceChange>,
+        mut update_rx: mpsc::Receiver<ConfigChange>,
     ) -> Result<()> {
         // Listen at `server.bind_addr`
         let l = self
@@ -198,7 +198,7 @@ impl<T: 'static + Transport> Server<T> {
                     info!("Shuting down gracefully...");
                     break;
                 },
-                e = service_rx.recv() => {
+                e = update_rx.recv() => {
                     if let Some(e) = e {
                         self.handle_hot_reload(e).await;
                     }
@@ -211,24 +211,26 @@ impl<T: 'static + Transport> Server<T> {
         Ok(())
     }
 
-    async fn handle_hot_reload(&mut self, e: ServiceChange) {
+    async fn handle_hot_reload(&mut self, e: ConfigChange) {
         match e {
-            ServiceChange::ServerAdd(s) => {
-                let hash = protocol::digest(s.name.as_bytes());
-                let mut wg = self.services.write().await;
-                let _ = wg.insert(hash, s);
-
-                let mut wg = self.control_channels.write().await;
-                let _ = wg.remove1(&hash);
-            }
-            ServiceChange::ServerDelete(s) => {
-                let hash = protocol::digest(s.as_bytes());
-                let _ = self.services.write().await.remove(&hash);
+            ConfigChange::ServerChange(server_change) => match server_change {
+                ServerServiceChange::Add(cfg) => {
+                    let hash = protocol::digest(cfg.name.as_bytes());
+                    let mut wg = self.services.write().await;
+                    let _ = wg.insert(hash, cfg);
+
+                    let mut wg = self.control_channels.write().await;
+                    let _ = wg.remove1(&hash);
+                }
+                ServerServiceChange::Delete(s) => {
+                    let hash = protocol::digest(s.as_bytes());
+                    let _ = self.services.write().await.remove(&hash);
 
-                let mut wg = self.control_channels.write().await;
-                let _ = wg.remove1(&hash);
-            }
-            _ => (),
+                    let mut wg = self.control_channels.write().await;
+                    let _ = wg.remove1(&hash);
+                }
+            },
+            ignored => warn!("Ignored {:?} since running as a server", ignored),
         }
     }
 }