Browse Source

feat: service-wise hot reload

Yujia Qiao 4 năm trước cách đây
mục cha
commit
9c0f28caee
4 tập tin đã thay đổi với 173 bổ sung44 xóa
  1. 39 8
      src/client.rs
  2. 84 20
      src/config_watcher.rs
  3. 10 11
      src/lib.rs
  4. 40 5
      src/server.rs

+ 39 - 8
src/client.rs

@@ -1,4 +1,5 @@
 use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
+use crate::config_watcher::ServiceChangeEvent;
 use crate::helper::udp_connect;
 use crate::protocol::Hello::{self, *};
 use crate::protocol::{
@@ -16,7 +17,7 @@ use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
 use tokio::net::{TcpStream, UdpSocket};
 use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
 use tokio::time::{self, Duration};
-use tracing::{debug, error, info, instrument, Instrument, Span};
+use tracing::{debug, error, info, instrument, warn, Instrument, Span};
 
 #[cfg(feature = "noise")]
 use crate::transport::NoiseTransport;
@@ -26,7 +27,11 @@ use crate::transport::TlsTransport;
 use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
 
 // The entrypoint of running a client
-pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
+pub async fn run_client(
+    config: &Config,
+    shutdown_rx: broadcast::Receiver<bool>,
+    service_rx: mpsc::Receiver<ServiceChangeEvent>,
+) -> Result<()> {
     let config = match &config.client {
         Some(v) => v,
         None => {
@@ -37,13 +42,13 @@ pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>)
     match config.transport.transport_type {
         TransportType::Tcp => {
             let mut client = Client::<TcpTransport>::from(config).await?;
-            client.run(shutdown_rx).await
+            client.run(shutdown_rx, service_rx).await
         }
         TransportType::Tls => {
             #[cfg(feature = "tls")]
             {
                 let mut client = Client::<TlsTransport>::from(config).await?;
-                client.run(shutdown_rx).await
+                client.run(shutdown_rx, service_rx).await
             }
             #[cfg(not(feature = "tls"))]
             crate::helper::feature_not_compile("tls")
@@ -52,7 +57,7 @@ pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>)
             #[cfg(feature = "noise")]
             {
                 let mut client = Client::<NoiseTransport>::from(config).await?;
-                client.run(shutdown_rx).await
+                client.run(shutdown_rx, service_rx).await
             }
             #[cfg(not(feature = "noise"))]
             crate::helper::feature_not_compile("noise")
@@ -85,7 +90,11 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
     }
 
     // The entrypoint of Client
-    async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
+    async fn run(
+        &mut self,
+        mut shutdown_rx: broadcast::Receiver<bool>,
+        mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
+    ) -> Result<()> {
         for (name, config) in &self.config.services {
             // Create a control channel for each service defined
             let handle = ControlChannelHandle::new(
@@ -96,7 +105,6 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
             self.service_handles.insert(name.clone(), handle);
         }
 
-        // TODO: Maybe wait for a config change signal for hot reloading
         // Wait for the shutdown signal
         loop {
             tokio::select! {
@@ -109,6 +117,25 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
                     }
                     break;
                 },
+                e = service_rx.recv() => {
+                    if let Some(e) = e {
+                        match e {
+                            ServiceChangeEvent::ClientAdd(s)=> {
+                                let name = s.name.clone();
+                                let handle = ControlChannelHandle::new(
+                                    s,
+                                    self.config.remote_addr.clone(),
+                                    self.transport.clone(),
+                                );
+                                let _ = self.service_handles.insert(name, handle);
+                            },
+                            ServiceChangeEvent::ClientDelete(s)=> {
+                                let _ = self.service_handles.remove(&s);
+                            },
+                            _ => ()
+                        }
+                    }
+                }
             }
         }
 
@@ -399,7 +426,7 @@ impl<T: 'static + Transport> ControlChannel<T> {
                     }
                 },
                 _ = &mut self.shutdown_rx => {
-                    info!( "Shutting down gracefully...");
+                    info!( "Control channel shutting down...");
                     break;
                 }
             }
@@ -433,6 +460,10 @@ impl ControlChannelHandle {
                     .await
                     .with_context(|| "Failed to run the control channel")
                 {
+                    if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) {
+                        break;
+                    }
+
                     let duration = Duration::from_secs(1);
                     error!("{:?}\n\nRetry in {:?}...", err, duration);
                     time::sleep(duration).await;

+ 84 - 20
src/config_watcher.rs

@@ -3,23 +3,26 @@ use crate::{
     Config,
 };
 use anyhow::{Context, Result};
-use notify::{EventKind, RecursiveMode, Watcher};
-use std::path::PathBuf;
+use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher};
+use std::{
+    collections::HashMap,
+    path::{Path, PathBuf},
+};
 use tokio::sync::{broadcast, mpsc};
 use tracing::{error, info, instrument};
 
 #[derive(Debug)]
 pub enum ConfigChangeEvent {
-    General(Config), // Trigger a full restart
+    General(Box<Config>), // Trigger a full restart
     ServiceChange(ServiceChangeEvent),
 }
 
 #[derive(Debug)]
 pub enum ServiceChangeEvent {
-    AddClientService(ClientServiceConfig),
-    DeleteClientService(ClientServiceConfig),
-    AddServerService(ServerServiceConfig),
-    DeleteServerService(ServerServiceConfig),
+    ClientAdd(ClientServiceConfig),
+    ClientDelete(String),
+    ServerAdd(ServerServiceConfig),
+    ServerDelete(String),
 }
 
 pub struct ConfigWatcherHandle {
@@ -27,7 +30,7 @@ pub struct ConfigWatcherHandle {
 }
 
 impl ConfigWatcherHandle {
-    pub async fn new(path: &PathBuf, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
+    pub async fn new(path: &Path, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
         let (event_tx, event_rx) = mpsc::channel(16);
 
         let origin_cfg = Config::from_file(path).await?;
@@ -43,7 +46,7 @@ impl ConfigWatcherHandle {
     }
 }
 
-#[instrument(skip(shutdown_rx, cfg_event_tx))]
+#[instrument(skip(shutdown_rx, cfg_event_tx, old))]
 async fn config_watcher(
     path: PathBuf,
     mut shutdown_rx: broadcast::Receiver<bool>,
@@ -61,7 +64,7 @@ async fn config_watcher(
 
     // Initial start
     cfg_event_tx
-        .send(ConfigChangeEvent::General(old.clone()))
+        .send(ConfigChangeEvent::General(Box::new(old.clone())))
         .await
         .unwrap();
 
@@ -73,9 +76,12 @@ async fn config_watcher(
           e = fevent_rx.recv() => {
             match e {
               Some(e) => {
-                match e.kind {
-                  EventKind::Modify(_) => {
-                    info!("Configuration modify event is detected");
+                if let EventKind::Modify(kind) = e.kind {
+                    match kind {
+                        ModifyKind::Data(_) => (),
+                        _ => continue
+                    }
+                    info!("Rescan the configuration");
                     let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") {
                       Ok(v) => v,
                       Err(e) => {
@@ -90,9 +96,7 @@ async fn config_watcher(
                     }
 
                     old = new;
-                  },
-                  _ => (), // Just ignore other events
-                }
+                  }
               },
               None => break
             }
@@ -109,11 +113,71 @@ async fn config_watcher(
 fn calculate_event(old: &Config, new: &Config) -> Vec<ConfigChangeEvent> {
     let mut ret = Vec::new();
 
-    if old == new {
-        return ret;
+    if old != new {
+        if old.server.is_some() && new.server.is_some() {
+            let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
+                &old.server.as_ref().unwrap().services,
+                &new.server.as_ref().unwrap().services,
+            )
+            .into_iter()
+            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerDelete(x)))
+            .collect();
+            ret.append(&mut e);
+
+            let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
+                &old.server.as_ref().unwrap().services,
+                &new.server.as_ref().unwrap().services,
+            )
+            .into_iter()
+            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerAdd(x)))
+            .collect();
+
+            ret.append(&mut e);
+        } else if old.client.is_some() && new.client.is_some() {
+            let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
+                &old.client.as_ref().unwrap().services,
+                &new.client.as_ref().unwrap().services,
+            )
+            .into_iter()
+            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientDelete(x)))
+            .collect();
+            ret.append(&mut e);
+
+            let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
+                &old.client.as_ref().unwrap().services,
+                &new.client.as_ref().unwrap().services,
+            )
+            .into_iter()
+            .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientAdd(x)))
+            .collect();
+
+            ret.append(&mut e);
+        } else {
+            ret.push(ConfigChangeEvent::General(Box::new(new.clone())));
+        }
     }
 
-    ret.push(ConfigChangeEvent::General(new.to_owned()));
-
     ret
 }
+
+fn calculate_service_delete_event<T: PartialEq>(
+    old_services: &HashMap<String, T>,
+    new_services: &HashMap<String, T>,
+) -> Vec<String> {
+    old_services
+        .keys()
+        .filter(|&name| old_services.get(name) != new_services.get(name))
+        .map(|x| x.to_owned())
+        .collect()
+}
+
+fn calculate_service_add_event<T: PartialEq + Clone>(
+    old_services: &HashMap<String, T>,
+    new_services: &HashMap<String, T>,
+) -> Vec<T> {
+    new_services
+        .iter()
+        .filter(|(name, _)| old_services.get(*name) != new_services.get(*name))
+        .map(|(_, c)| c.clone())
+        .collect()
+}

+ 10 - 11
src/lib.rs

@@ -15,7 +15,7 @@ pub use constants::UDP_BUFFER_SIZE;
 
 use anyhow::Result;
 use tokio::sync::{broadcast, mpsc};
-use tracing::debug;
+use tracing::{debug, info};
 
 #[cfg(feature = "client")]
 mod client;
@@ -82,12 +82,10 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
     while let Some(e) = cfg_watcher.event_rx.recv().await {
         match e {
             ConfigChangeEvent::General(config) => {
-                match last_instance {
-                    Some((i, _)) => {
-                        shutdown_tx.send(true)?;
-                        i.await??;
-                    }
-                    None => (),
+                if let Some((i, _)) = last_instance {
+                    info!("General configuration change detected. Restarting...");
+                    shutdown_tx.send(true)?;
+                    i.await??;
                 }
 
                 debug!("{:?}", config);
@@ -96,7 +94,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
 
                 last_instance = Some((
                     tokio::spawn(run_instance(
-                        config.clone(),
+                        *(config.clone()),
                         args.clone(),
                         shutdown_tx.subscribe(),
                         service_update_rx,
@@ -105,6 +103,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
                 ));
             }
             ConfigChangeEvent::ServiceChange(service_event) => {
+                info!("Service change detcted. {:?}", service_event);
                 if let Some((_, service_update_tx)) = &last_instance {
                     let _ = service_update_tx.send(service_event).await;
                 }
@@ -118,7 +117,7 @@ async fn run_instance(
     config: Config,
     args: Cli,
     shutdown_rx: broadcast::Receiver<bool>,
-    _service_update: mpsc::Receiver<ServiceChangeEvent>,
+    service_update: mpsc::Receiver<ServiceChangeEvent>,
 ) -> Result<()> {
     match determine_run_mode(&config, &args) {
         RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),
@@ -126,13 +125,13 @@ async fn run_instance(
             #[cfg(not(feature = "client"))]
             crate::helper::feature_not_compile("client");
             #[cfg(feature = "client")]
-            run_client(&config, shutdown_rx).await
+            run_client(&config, shutdown_rx, service_update).await
         }
         RunMode::Server => {
             #[cfg(not(feature = "server"))]
             crate::helper::feature_not_compile("server");
             #[cfg(feature = "server")]
-            run_server(&config, shutdown_rx).await
+            run_server(&config, shutdown_rx, service_update).await
         }
     }
 }

+ 40 - 5
src/server.rs

@@ -1,4 +1,5 @@
 use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
+use crate::config_watcher::ServiceChangeEvent;
 use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
 use crate::multi_map::MultiMap;
 use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
@@ -35,7 +36,11 @@ const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP serv
 const CHAN_SIZE: usize = 2048; // The capacity of various chans
 
 // The entrypoint of running a server
-pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
+pub async fn run_server(
+    config: &Config,
+    shutdown_rx: broadcast::Receiver<bool>,
+    service_rx: mpsc::Receiver<ServiceChangeEvent>,
+) -> Result<()> {
     let config = match &config.server {
             Some(config) => config,
             None => {
@@ -47,13 +52,13 @@ pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver<bool>)
     match config.transport.transport_type {
         TransportType::Tcp => {
             let mut server = Server::<TcpTransport>::from(config).await?;
-            server.run(shutdown_rx).await?;
+            server.run(shutdown_rx, service_rx).await?;
         }
         TransportType::Tls => {
             #[cfg(feature = "tls")]
             {
                 let mut server = Server::<TlsTransport>::from(config).await?;
-                server.run(shutdown_rx).await?;
+                server.run(shutdown_rx, service_rx).await?;
             }
             #[cfg(not(feature = "tls"))]
             crate::helper::feature_not_compile("tls")
@@ -62,7 +67,7 @@ pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver<bool>)
             #[cfg(feature = "noise")]
             {
                 let mut server = Server::<NoiseTransport>::from(config).await?;
-                server.run(shutdown_rx).await?;
+                server.run(shutdown_rx, service_rx).await?;
             }
             #[cfg(not(feature = "noise"))]
             crate::helper::feature_not_compile("noise")
@@ -114,7 +119,11 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
     }
 
     // The entry point of Server
-    pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
+    pub async fn run(
+        &mut self,
+        mut shutdown_rx: broadcast::Receiver<bool>,
+        mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
+    ) -> Result<()> {
         // Listen at `server.bind_addr`
         let l = self
             .transport
@@ -172,12 +181,38 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
                 _ = shutdown_rx.recv() => {
                     info!("Shuting down gracefully...");
                     break;
+                },
+                e = service_rx.recv() => {
+                    if let Some(e) = e {
+                        self.handle_hot_reload(e).await;
+                    }
                 }
             }
         }
 
         Ok(())
     }
+
+    async fn handle_hot_reload(&mut self, e: ServiceChangeEvent) {
+        match e {
+            ServiceChangeEvent::ServerAdd(s) => {
+                let hash = protocol::digest(s.name.as_bytes());
+                let mut wg = self.services.write().await;
+                let _ = wg.insert(hash, s);
+
+                let mut wg = self.control_channels.write().await;
+                let _ = wg.remove1(&hash);
+            }
+            ServiceChangeEvent::ServerDelete(s) => {
+                let hash = protocol::digest(s.as_bytes());
+                let _ = self.services.write().await.remove(&hash);
+
+                let mut wg = self.control_channels.write().await;
+                let _ = wg.remove1(&hash);
+            }
+            _ => (),
+        }
+    }
 }
 
 // Handle connections to `server.bind_addr`