Selaa lähdekoodia

test: refactor and add tests for hot-reload

Yujia Qiao 4 vuotta sitten
vanhempi
sitoutus
c8cb60708d
5 muutettua tiedostoa jossa 277 lisäystä ja 119 poistoa
  1. 5 5
      src/client.rs
  2. 8 2
      src/config.rs
  3. 245 87
      src/config_watcher.rs
  4. 9 7
      src/lib.rs
  5. 10 18
      src/server.rs

+ 5 - 5
src/client.rs

@@ -1,5 +1,5 @@
 use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
-use crate::config_watcher::ServiceChangeEvent;
+use crate::config_watcher::ServiceChange;
 use crate::helper::udp_connect;
 use crate::protocol::Hello::{self, *};
 use crate::protocol::{
@@ -30,7 +30,7 @@ use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
 pub async fn run_client(
     config: &Config,
     shutdown_rx: broadcast::Receiver<bool>,
-    service_rx: mpsc::Receiver<ServiceChangeEvent>,
+    service_rx: mpsc::Receiver<ServiceChange>,
 ) -> Result<()> {
     let config = match &config.client {
         Some(v) => v,
@@ -93,7 +93,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
     async fn run(
         &mut self,
         mut shutdown_rx: broadcast::Receiver<bool>,
-        mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
+        mut service_rx: mpsc::Receiver<ServiceChange>,
     ) -> Result<()> {
         for (name, config) in &self.config.services {
             // Create a control channel for each service defined
@@ -120,7 +120,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
                 e = service_rx.recv() => {
                     if let Some(e) = e {
                         match e {
-                            ServiceChangeEvent::ClientAdd(s)=> {
+                            ServiceChange::ClientAdd(s)=> {
                                 let name = s.name.clone();
                                 let handle = ControlChannelHandle::new(
                                     s,
@@ -129,7 +129,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
                                 );
                                 let _ = self.service_handles.insert(name, handle);
                             },
-                            ServiceChangeEvent::ClientDelete(s)=> {
+                            ServiceChange::ClientDelete(s)=> {
                                 let _ = self.service_handles.remove(&s);
                             },
                             _ => ()

+ 8 - 2
src/config.rs

@@ -38,11 +38,17 @@ pub enum ServiceType {
     Udp,
 }
 
+impl Default for ServiceType {
+    fn default() -> Self {
+        ServiceType::Tcp
+    }
+}
+
 fn default_service_type() -> ServiceType {
-    ServiceType::Tcp
+    Default::default()
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
 pub struct ServerServiceConfig {
     #[serde(rename = "type", default = "default_service_type")]
     pub service_type: ServiceType,

+ 245 - 87
src/config_watcher.rs

@@ -1,5 +1,5 @@
 use crate::{
-    config::{ClientServiceConfig, ServerServiceConfig},
+    config::{ClientConfig, ClientServiceConfig, ServerConfig, ServerServiceConfig},
     Config,
 };
 use anyhow::{Context, Result};
@@ -13,22 +13,87 @@ use tracing::{error, info, instrument};
 #[cfg(feature = "notify")]
 use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher};
 
-#[derive(Debug)]
-pub enum ConfigChangeEvent {
+#[derive(Debug, PartialEq)]
+pub enum ConfigChange {
     General(Box<Config>), // Trigger a full restart
-    ServiceChange(ServiceChangeEvent),
+    ServiceChange(ServiceChange),
 }
 
-#[derive(Debug)]
-pub enum ServiceChangeEvent {
+#[derive(Debug, PartialEq)]
+pub enum ServiceChange {
     ClientAdd(ClientServiceConfig),
     ClientDelete(String),
     ServerAdd(ServerServiceConfig),
     ServerDelete(String),
 }
 
+impl From<ClientServiceConfig> for ServiceChange {
+    fn from(c: ClientServiceConfig) -> Self {
+        ServiceChange::ClientAdd(c)
+    }
+}
+
+impl From<ServerServiceConfig> for ServiceChange {
+    fn from(c: ServerServiceConfig) -> Self {
+        ServiceChange::ServerAdd(c)
+    }
+}
+
+trait InstanceConfig: Clone {
+    type ServiceConfig: Into<ServiceChange> + PartialEq + Clone;
+    fn equal_without_service(&self, rhs: &Self) -> bool;
+    fn to_service_change_delete(s: String) -> ServiceChange;
+    fn get_services(&self) -> &HashMap<String, Self::ServiceConfig>;
+}
+
+impl InstanceConfig for ServerConfig {
+    type ServiceConfig = ServerServiceConfig;
+    fn equal_without_service(&self, rhs: &Self) -> bool {
+        let left = ServerConfig {
+            services: Default::default(),
+            ..self.clone()
+        };
+
+        let right = ServerConfig {
+            services: Default::default(),
+            ..rhs.clone()
+        };
+
+        left == right
+    }
+    fn to_service_change_delete(s: String) -> ServiceChange {
+        ServiceChange::ServerDelete(s)
+    }
+    fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
+        &self.services
+    }
+}
+
+impl InstanceConfig for ClientConfig {
+    type ServiceConfig = ClientServiceConfig;
+    fn equal_without_service(&self, rhs: &Self) -> bool {
+        let left = ClientConfig {
+            services: Default::default(),
+            ..self.clone()
+        };
+
+        let right = ClientConfig {
+            services: Default::default(),
+            ..rhs.clone()
+        };
+
+        left == right
+    }
+    fn to_service_change_delete(s: String) -> ServiceChange {
+        ServiceChange::ClientDelete(s)
+    }
+    fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
+        &self.services
+    }
+}
+
 pub struct ConfigWatcherHandle {
-    pub event_rx: mpsc::Receiver<ConfigChangeEvent>,
+    pub event_rx: mpsc::Receiver<ConfigChange>,
 }
 
 impl ConfigWatcherHandle {
@@ -39,7 +104,7 @@ impl ConfigWatcherHandle {
 
         // Initial start
         event_tx
-            .send(ConfigChangeEvent::General(Box::new(origin_cfg.clone())))
+            .send(ConfigChange::General(Box::new(origin_cfg.clone())))
             .await
             .unwrap();
 
@@ -59,30 +124,33 @@ impl ConfigWatcherHandle {
 async fn config_watcher(
     _path: PathBuf,
     mut shutdown_rx: broadcast::Receiver<bool>,
-    _cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
+    _event_tx: mpsc::Sender<ConfigChange>,
     _old: Config,
 ) -> Result<()> {
-    // Do nothing except wating for ctrl-c
+    // Do nothing except waiting for ctrl-c
     let _ = shutdown_rx.recv().await;
     Ok(())
 }
 
 #[cfg(feature = "notify")]
-#[instrument(skip(shutdown_rx, cfg_event_tx, old))]
+#[instrument(skip(shutdown_rx, event_tx, old))]
 async fn config_watcher(
     path: PathBuf,
     mut shutdown_rx: broadcast::Receiver<bool>,
-    cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
+    event_tx: mpsc::Sender<ConfigChange>,
     mut old: Config,
 ) -> Result<()> {
     let (fevent_tx, mut fevent_rx) = mpsc::channel(16);
 
-    let mut watcher = notify::recommended_watcher(move |res| match res {
-        Ok(event) => {
-            let _ = fevent_tx.blocking_send(event);
-        }
-        Err(e) => error!("watch error: {:?}", e),
-    })?;
+    let mut watcher =
+        notify::recommended_watcher(move |res: Result<notify::Event, _>| match res {
+            Ok(e) => {
+                if let EventKind::Modify(ModifyKind::Data(_)) = e.kind {
+                    let _ = fevent_tx.blocking_send(true);
+                }
+            }
+            Err(e) => error!("watch error: {:?}", e),
+        })?;
 
     watcher.watch(&path, RecursiveMode::NonRecursive)?;
     info!("Start watching the config");
@@ -91,12 +159,7 @@ async fn config_watcher(
         tokio::select! {
           e = fevent_rx.recv() => {
             match e {
-              Some(e) => {
-                if let EventKind::Modify(kind) = e.kind {
-                    match kind {
-                        ModifyKind::Data(_) => (),
-                        _ => continue
-                    }
+              Some(_) => {
                     info!("Rescan the configuration");
                     let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") {
                       Ok(v) => v,
@@ -107,12 +170,11 @@ async fn config_watcher(
                       }
                     };
 
-                    for event in calculate_event(&old, &new) {
-                      cfg_event_tx.send(event).await?;
+                    for event in calculate_events(&old, &new) {
+                      event_tx.send(event).await?;
                     }
 
                     old = new;
-                  }
               },
               None => break
             }
@@ -126,74 +188,170 @@ async fn config_watcher(
     Ok(())
 }
 
-fn calculate_event(old: &Config, new: &Config) -> Vec<ConfigChangeEvent> {
-    let mut ret = Vec::new();
-
-    if old != new {
-        if old.server.is_some() && new.server.is_some() {
-            let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
-                &old.server.as_ref().unwrap().services,
-                &new.server.as_ref().unwrap().services,
-            )
-            .into_iter()
-            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerDelete(x)))
-            .collect();
-            ret.append(&mut e);
-
-            let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
-                &old.server.as_ref().unwrap().services,
-                &new.server.as_ref().unwrap().services,
-            )
-            .into_iter()
-            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerAdd(x)))
-            .collect();
-
-            ret.append(&mut e);
-        } else if old.client.is_some() && new.client.is_some() {
-            let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
-                &old.client.as_ref().unwrap().services,
-                &new.client.as_ref().unwrap().services,
-            )
-            .into_iter()
-            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientDelete(x)))
-            .collect();
-            ret.append(&mut e);
-
-            let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
-                &old.client.as_ref().unwrap().services,
-                &new.client.as_ref().unwrap().services,
-            )
-            .into_iter()
-            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientAdd(x)))
-            .collect();
-
-            ret.append(&mut e);
+fn calculate_events(old: &Config, new: &Config) -> Vec<ConfigChange> {
+    if old == new {
+        vec![]
+    } else if old.server != new.server {
+        if old.server.is_some() != new.server.is_some() {
+            vec![ConfigChange::General(Box::new(new.clone()))]
+        } else {
+            match calculate_instance_config_events(
+                old.server.as_ref().unwrap(),
+                new.server.as_ref().unwrap(),
+            ) {
+                Some(v) => v,
+                None => vec![ConfigChange::General(Box::new(new.clone()))],
+            }
+        }
+    } else if old.client != new.client {
+        if old.client.is_some() != new.client.is_some() {
+            vec![ConfigChange::General(Box::new(new.clone()))]
         } else {
-            ret.push(ConfigChangeEvent::General(Box::new(new.clone())));
+            match calculate_instance_config_events(
+                old.client.as_ref().unwrap(),
+                new.client.as_ref().unwrap(),
+            ) {
+                Some(v) => v,
+                None => vec![ConfigChange::General(Box::new(new.clone()))],
+            }
         }
+    } else {
+        vec![]
+    }
+}
+
+// None indicates a General change needed
+fn calculate_instance_config_events<T: InstanceConfig>(
+    old: &T,
+    new: &T,
+) -> Option<Vec<ConfigChange>> {
+    if !old.equal_without_service(new) {
+        return None;
     }
 
-    ret
+    let old = old.get_services();
+    let new = new.get_services();
+
+    let mut v = vec![];
+    v.append(&mut calculate_service_delete_events::<T>(old, new));
+    v.append(&mut calculate_service_add_events(old, new));
+
+    Some(v.into_iter().map(ConfigChange::ServiceChange).collect())
 }
 
-fn calculate_service_delete_event<T: PartialEq>(
-    old_services: &HashMap<String, T>,
-    new_services: &HashMap<String, T>,
-) -> Vec<String> {
-    old_services
-        .keys()
-        .filter(|&name| old_services.get(name) != new_services.get(name))
-        .map(|x| x.to_owned())
+fn calculate_service_delete_events<T: InstanceConfig>(
+    old: &HashMap<String, T::ServiceConfig>,
+    new: &HashMap<String, T::ServiceConfig>,
+) -> Vec<ServiceChange> {
+    old.keys()
+        .filter(|&name| new.get(name).is_none())
+        .map(|x| T::to_service_change_delete(x.to_owned()))
         .collect()
 }
 
-fn calculate_service_add_event<T: PartialEq + Clone>(
-    old_services: &HashMap<String, T>,
-    new_services: &HashMap<String, T>,
-) -> Vec<T> {
-    new_services
-        .iter()
-        .filter(|(name, _)| old_services.get(*name) != new_services.get(*name))
-        .map(|(_, c)| c.clone())
+fn calculate_service_add_events<T: PartialEq + Clone + Into<ServiceChange>>(
+    old: &HashMap<String, T>,
+    new: &HashMap<String, T>,
+) -> Vec<ServiceChange> {
+    new.iter()
+        .filter(|(name, c)| old.get(*name) != Some(*c))
+        .map(|(_, c)| c.clone().into())
         .collect()
 }
+
+#[cfg(test)]
+mod test {
+    use crate::config::ServerConfig;
+
+    use super::*;
+
+    // macro to create map or set literal
+    macro_rules! collection {
+        // map-like
+        ($($k:expr => $v:expr),* $(,)?) => {{
+            use std::iter::{Iterator, IntoIterator};
+            Iterator::collect(IntoIterator::into_iter([$(($k, $v),)*]))
+        }};
+    }
+
+    #[test]
+    fn test_calculate_events() {
+        struct Test {
+            old: Config,
+            new: Config,
+        }
+
+        let tests = [
+            Test {
+                old: Config {
+                    server: Some(Default::default()),
+                    client: None,
+                },
+                new: Config {
+                    server: Some(Default::default()),
+                    client: Some(Default::default()),
+                },
+            },
+            Test {
+                old: Config {
+                    server: Some(ServerConfig {
+                        bind_addr: String::from("127.0.0.1:2334"),
+                        ..Default::default()
+                    }),
+                    client: None,
+                },
+                new: Config {
+                    server: Some(ServerConfig {
+                        bind_addr: String::from("127.0.0.1:2333"),
+                        services: collection!(String::from("foo") => Default::default()),
+                        ..Default::default()
+                    }),
+                    client: None,
+                },
+            },
+            Test {
+                old: Config {
+                    server: Some(Default::default()),
+                    client: None,
+                },
+                new: Config {
+                    server: Some(ServerConfig {
+                        services: collection!(String::from("foo") => Default::default()),
+                        ..Default::default()
+                    }),
+                    client: None,
+                },
+            },
+            Test {
+                old: Config {
+                    server: Some(ServerConfig {
+                        services: collection!(String::from("foo") => Default::default()),
+                        ..Default::default()
+                    }),
+                    client: None,
+                },
+                new: Config {
+                    server: Some(Default::default()),
+                    client: None,
+                },
+            },
+        ];
+        let expected = [
+            vec![ConfigChange::General(Box::new(tests[0].new.clone()))],
+            vec![ConfigChange::General(Box::new(tests[1].new.clone()))],
+            vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd(
+                Default::default(),
+            ))],
+            vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete(
+                String::from("foo"),
+            ))],
+        ];
+
+        assert_eq!(tests.len(), expected.len());
+
+        for i in 0..tests.len() {
+            let actual = calculate_events(&tests[i].old, &tests[i].new);
+            assert_eq!(actual, expected[i]);
+        }
+    }
+}

+ 9 - 7
src/lib.rs

@@ -10,7 +10,7 @@ mod transport;
 pub use cli::Cli;
 use cli::KeypairType;
 pub use config::Config;
-use config_watcher::ServiceChangeEvent;
+use config_watcher::ServiceChange;
 pub use constants::UDP_BUFFER_SIZE;
 
 use anyhow::Result;
@@ -27,7 +27,7 @@ mod server;
 #[cfg(feature = "server")]
 use server::run_server;
 
-use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle};
+use crate::config_watcher::{ConfigChange, ConfigWatcherHandle};
 
 const DEFAULT_CURVE: KeypairType = KeypairType::X25519;
 
@@ -76,12 +76,11 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
     let (shutdown_tx, _) = broadcast::channel(1);
 
     // (The join handle of the last instance, The service update channel sender)
-    let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChangeEvent>)> =
-        None;
+    let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChange>)> = None;
 
     while let Some(e) = cfg_watcher.event_rx.recv().await {
         match e {
-            ConfigChangeEvent::General(config) => {
+            ConfigChange::General(config) => {
                 if let Some((i, _)) = last_instance {
                     info!("General configuration change detected. Restarting...");
                     shutdown_tx.send(true)?;
@@ -102,7 +101,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
                     service_update_tx,
                 ));
             }
-            ConfigChangeEvent::ServiceChange(service_event) => {
+            ConfigChange::ServiceChange(service_event) => {
                 info!("Service change detcted. {:?}", service_event);
                 if let Some((_, service_update_tx)) = &last_instance {
                     let _ = service_update_tx.send(service_event).await;
@@ -110,6 +109,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
             }
         }
     }
+
+    let _ = shutdown_tx.send(true);
+
     Ok(())
 }
 
@@ -117,7 +119,7 @@ async fn run_instance(
     config: Config,
     args: Cli,
     shutdown_rx: broadcast::Receiver<bool>,
-    service_update: mpsc::Receiver<ServiceChangeEvent>,
+    service_update: mpsc::Receiver<ServiceChange>,
 ) -> Result<()> {
     match determine_run_mode(&config, &args) {
         RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),

+ 10 - 18
src/server.rs

@@ -1,5 +1,5 @@
 use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
-use crate::config_watcher::ServiceChangeEvent;
+use crate::config_watcher::ServiceChange;
 use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
 use crate::multi_map::MultiMap;
 use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
@@ -39,7 +39,7 @@ const CHAN_SIZE: usize = 2048; // The capacity of various chans
 pub async fn run_server(
     config: &Config,
     shutdown_rx: broadcast::Receiver<bool>,
-    service_rx: mpsc::Receiver<ServiceChangeEvent>,
+    service_rx: mpsc::Receiver<ServiceChange>,
 ) -> Result<()> {
     let config = match &config.server {
             Some(config) => config,
@@ -122,7 +122,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
     pub async fn run(
         &mut self,
         mut shutdown_rx: broadcast::Receiver<bool>,
-        mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
+        mut service_rx: mpsc::Receiver<ServiceChange>,
     ) -> Result<()> {
         // Listen at `server.bind_addr`
         let l = self
@@ -193,9 +193,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
         Ok(())
     }
 
-    async fn handle_hot_reload(&mut self, e: ServiceChangeEvent) {
+    async fn handle_hot_reload(&mut self, e: ServiceChange) {
         match e {
-            ServiceChangeEvent::ServerAdd(s) => {
+            ServiceChange::ServerAdd(s) => {
                 let hash = protocol::digest(s.name.as_bytes());
                 let mut wg = self.services.write().await;
                 let _ = wg.insert(hash, s);
@@ -203,7 +203,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
                 let mut wg = self.control_channels.write().await;
                 let _ = wg.remove1(&hash);
             }
-            ServiceChangeEvent::ServerDelete(s) => {
+            ServiceChange::ServerDelete(s) => {
                 let hash = protocol::digest(s.as_bytes());
                 let _ = self.services.write().await.remove(&hash);
 
@@ -340,11 +340,8 @@ async fn do_data_channel_handshake<T: 'static + Transport>(
 }
 
 pub struct ControlChannelHandle<T: Transport> {
-    // Shutdown the control channel.
-    // Not used for now, but can be used for hot reloading
-    #[allow(dead_code)]
-    shutdown_tx: broadcast::Sender<bool>,
-    //data_ch_req_tx: mpsc::Sender<bool>,
+    // Shutdown the control channel by dropping it
+    _shutdown_tx: broadcast::Sender<bool>,
     data_ch_tx: mpsc::Sender<T::Stream>,
 }
 
@@ -359,7 +356,7 @@ where
         // Save the name string for logging
         let name = service.name.clone();
 
-        // Create a shutdown channel. The sender is not used for now, but for future use
+        // Create a shutdown channel
         let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
 
         // Store data channels
@@ -417,15 +414,10 @@ where
         });
 
         ControlChannelHandle {
-            shutdown_tx,
+            _shutdown_tx: shutdown_tx,
             data_ch_tx,
         }
     }
-
-    #[allow(dead_code)]
-    fn shutdown(self) {
-        let _ = self.shutdown_tx.send(true);
-    }
 }
 
 // Control channel, using T as the transport layer. P is TcpStream or UdpTraffic