Ver Fonte

feat: hot-reload by restarting

Yujia Qiao há 4 anos atrás
pai
commit
8097b6916f
9 ficheiros alterados com 319 adições e 27 exclusões
  1. 120 0
      Cargo.lock
  2. 1 0
      Cargo.toml
  3. 1 1
      src/cli.rs
  4. 10 10
      src/config.rs
  5. 119 0
      src/config_watcher.rs
  6. 61 9
      src/lib.rs
  7. 1 1
      src/main.rs
  8. 2 2
      tests/common/mod.rs
  9. 4 4
      tests/integration_test.rs

+ 120 - 0
Cargo.lock

@@ -272,6 +272,26 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db"
+dependencies = [
+ "cfg-if",
+ "lazy_static",
+]
+
 [[package]]
 name = "crypto-common"
 version = "0.1.1"
@@ -342,6 +362,18 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "filetime"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "975ccf83d8d9d0d84682850a38c8169027be83368805971cc4f238c2b245bc98"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "redox_syscall",
+ "winapi",
+]
+
 [[package]]
 name = "foreign-types"
 version = "0.3.2"
@@ -357,6 +389,15 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
 
+[[package]]
+name = "fsevent-sys"
+version = "4.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c0e564d24da983c053beff1bb7178e237501206840a3e6bf4e267b9e8ae734a"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "futures-core"
 version = "0.3.19"
@@ -476,6 +517,26 @@ dependencies = [
  "hashbrown",
 ]
 
+[[package]]
+name = "inotify"
+version = "0.9.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff"
+dependencies = [
+ "bitflags",
+ "inotify-sys",
+ "libc",
+]
+
+[[package]]
+name = "inotify-sys"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "instant"
 version = "0.1.12"
@@ -491,6 +552,26 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
 
+[[package]]
+name = "kqueue"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "058a107a784f8be94c7d35c1300f4facced2e93d2fbe5b1452b44e905ddca4a9"
+dependencies = [
+ "kqueue-sys",
+ "libc",
+]
+
+[[package]]
+name = "kqueue-sys"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8367585489f01bc55dd27404dcf56b95e6da061a256a666ab23be9ba96a2e587"
+dependencies = [
+ "bitflags",
+ "libc",
+]
+
 [[package]]
 name = "lazy_static"
 version = "1.4.0"
@@ -576,6 +657,24 @@ dependencies = [
  "tempfile",
 ]
 
+[[package]]
+name = "notify"
+version = "5.0.0-pre.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "245d358380e2352c2d020e8ee62baac09b3420f1f6c012a31326cfced4ad487d"
+dependencies = [
+ "bitflags",
+ "crossbeam-channel",
+ "filetime",
+ "fsevent-sys",
+ "inotify",
+ "kqueue",
+ "libc",
+ "mio",
+ "walkdir",
+ "winapi",
+]
+
 [[package]]
 name = "ntapi"
 version = "0.3.6"
@@ -874,6 +973,7 @@ dependencies = [
  "fdlimit",
  "hex",
  "lazy_static",
+ "notify",
  "rand",
  "serde",
  "sha2 0.10.0",
@@ -943,6 +1043,15 @@ version = "1.0.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
 
+[[package]]
+name = "same-file"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
+dependencies = [
+ "winapi-util",
+]
+
 [[package]]
 name = "schannel"
 version = "0.1.19"
@@ -1389,6 +1498,17 @@ version = "0.9.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe"
 
+[[package]]
+name = "walkdir"
+version = "2.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
+dependencies = [
+ "same-file",
+ "winapi",
+ "winapi-util",
+]
+
 [[package]]
 name = "wasi"
 version = "0.9.0+wasi-snapshot-preview1"

+ 1 - 0
Cargo.toml

@@ -50,3 +50,4 @@ tokio-native-tls = { version = "0.3.0", optional = true }
 async-trait = "0.1.52"
 snowstorm = { git = "https://github.com/black-binary/snowstorm", rev = "1887755", optional = true }
 base64 = { version = "0.13.0", optional = true }
+notify = "5.0.0-pre.13"

+ 1 - 1
src/cli.rs

@@ -6,7 +6,7 @@ pub enum KeypairType {
     X448,
 }
 
-#[derive(Parser, Debug, Default)]
+#[derive(Parser, Debug, Default, Clone)]
 #[clap(about, version, setting(AppSettings::DeriveDisplayOrder))]
 #[clap(group(
             ArgGroup::new("cmds")

+ 10 - 10
src/config.rs

@@ -4,7 +4,7 @@ use std::collections::HashMap;
 use std::path::Path;
 use tokio::fs;
 
-#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
+#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq)]
 pub enum TransportType {
     #[serde(rename = "tcp")]
     Tcp,
@@ -20,7 +20,7 @@ impl Default for TransportType {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
 pub struct ClientServiceConfig {
     #[serde(rename = "type", default = "default_service_type")]
     pub service_type: ServiceType,
@@ -30,7 +30,7 @@ pub struct ClientServiceConfig {
     pub token: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
+#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
 pub enum ServiceType {
     #[serde(rename = "tcp")]
     Tcp,
@@ -42,7 +42,7 @@ fn default_service_type() -> ServiceType {
     ServiceType::Tcp
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
 pub struct ServerServiceConfig {
     #[serde(rename = "type", default = "default_service_type")]
     pub service_type: ServiceType,
@@ -52,7 +52,7 @@ pub struct ServerServiceConfig {
     pub token: Option<String>,
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 pub struct TlsConfig {
     pub hostname: Option<String>,
     pub trusted_root: Option<String>,
@@ -64,7 +64,7 @@ fn default_noise_pattern() -> String {
     String::from("Noise_NK_25519_ChaChaPoly_BLAKE2s")
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
 pub struct NoiseConfig {
     #[serde(default = "default_noise_pattern")]
     pub pattern: String,
@@ -73,7 +73,7 @@ pub struct NoiseConfig {
     // TODO: Maybe psk can be added
 }
 
-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
 pub struct TransportConfig {
     #[serde(rename = "type")]
     pub transport_type: TransportType,
@@ -85,7 +85,7 @@ fn default_transport() -> TransportConfig {
     Default::default()
 }
 
-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
 pub struct ClientConfig {
     pub remote_addr: String,
     pub default_token: Option<String>,
@@ -94,7 +94,7 @@ pub struct ClientConfig {
     pub transport: TransportConfig,
 }
 
-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
 pub struct ServerConfig {
     pub bind_addr: String,
     pub default_token: Option<String>,
@@ -103,7 +103,7 @@ pub struct ServerConfig {
     pub transport: TransportConfig,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
 #[serde(deny_unknown_fields)]
 pub struct Config {
     pub server: Option<ServerConfig>,

+ 119 - 0
src/config_watcher.rs

@@ -0,0 +1,119 @@
+use crate::{
+    config::{ClientServiceConfig, ServerServiceConfig},
+    Config,
+};
+use anyhow::{Context, Result};
+use notify::{EventKind, RecursiveMode, Watcher};
+use std::path::PathBuf;
+use tokio::sync::{broadcast, mpsc};
+use tracing::{error, info, instrument};
+
+#[derive(Debug)]
+pub enum ConfigChangeEvent {
+    General(Config), // Trigger a full restart
+    ServiceChange(ServiceChangeEvent),
+}
+
+#[derive(Debug)]
+pub enum ServiceChangeEvent {
+    AddClientService(ClientServiceConfig),
+    DeleteClientService(ClientServiceConfig),
+    AddServerService(ServerServiceConfig),
+    DeleteServerService(ServerServiceConfig),
+}
+
+pub struct ConfigWatcherHandle {
+    pub event_rx: mpsc::Receiver<ConfigChangeEvent>,
+}
+
+impl ConfigWatcherHandle {
+    pub async fn new(path: &PathBuf, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
+        let (event_tx, event_rx) = mpsc::channel(16);
+
+        let origin_cfg = Config::from_file(path).await?;
+
+        tokio::spawn(config_watcher(
+            path.to_owned(),
+            shutdown_rx,
+            event_tx,
+            origin_cfg,
+        ));
+
+        Ok(ConfigWatcherHandle { event_rx })
+    }
+}
+
+#[instrument(skip(shutdown_rx, cfg_event_tx))]
+async fn config_watcher(
+    path: PathBuf,
+    mut shutdown_rx: broadcast::Receiver<bool>,
+    cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
+    mut old: Config,
+) -> Result<()> {
+    let (fevent_tx, mut fevent_rx) = mpsc::channel(16);
+
+    let mut watcher = notify::recommended_watcher(move |res| match res {
+        Ok(event) => {
+            let _ = fevent_tx.blocking_send(event);
+        }
+        Err(e) => error!("watch error: {:?}", e),
+    })?;
+
+    // Initial start
+    cfg_event_tx
+        .send(ConfigChangeEvent::General(old.clone()))
+        .await
+        .unwrap();
+
+    watcher.watch(&path, RecursiveMode::NonRecursive)?;
+    info!("Start watching the config");
+
+    loop {
+        tokio::select! {
+          e = fevent_rx.recv() => {
+            match e {
+              Some(e) => {
+                match e.kind {
+                  EventKind::Modify(_) => {
+                    info!("Configuration modify event is detected");
+                    let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") {
+                      Ok(v) => v,
+                      Err(e) => {
+                        error!("{:?}", e);
+                        // If the config is invalid, just ignore it
+                        continue;
+                      }
+                    };
+
+                    for event in calculate_event(&old, &new) {
+                      cfg_event_tx.send(event).await?;
+                    }
+
+                    old = new;
+                  },
+                  _ => (), // Just ignore other events
+                }
+              },
+              None => break
+            }
+          },
+          _ = shutdown_rx.recv() => break
+        }
+    }
+
+    info!("Config watcher exiting");
+
+    Ok(())
+}
+
+fn calculate_event(old: &Config, new: &Config) -> Vec<ConfigChangeEvent> {
+    let mut ret = Vec::new();
+
+    if old == new {
+        return ret;
+    }
+
+    ret.push(ConfigChangeEvent::General(new.to_owned()));
+
+    ret
+}

+ 61 - 9
src/lib.rs

@@ -1,5 +1,6 @@
 mod cli;
 mod config;
+mod config_watcher;
 mod constants;
 mod helper;
 mod multi_map;
@@ -9,10 +10,11 @@ mod transport;
 pub use cli::Cli;
 use cli::KeypairType;
 pub use config::Config;
+use config_watcher::ServiceChangeEvent;
 pub use constants::UDP_BUFFER_SIZE;
 
-use anyhow::{anyhow, Result};
-use tokio::sync::broadcast;
+use anyhow::Result;
+use tokio::sync::{broadcast, mpsc};
 use tracing::debug;
 
 #[cfg(feature = "client")]
@@ -25,6 +27,8 @@ mod server;
 #[cfg(feature = "server")]
 use server::run_server;
 
+use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle};
+
 const DEFAULT_CURVE: KeypairType = KeypairType::X25519;
 
 fn get_str_from_keypair_type(curve: KeypairType) -> &'static str {
@@ -56,20 +60,68 @@ fn genkey(curve: Option<KeypairType>) -> Result<()> {
     crate::helper::feature_not_compile("nosie")
 }
 
-pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
+pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
     if args.genkey.is_some() {
         return genkey(args.genkey.unwrap());
     }
 
-    let config = Config::from_file(args.config_path.as_ref().unwrap()).await?;
-
-    debug!("{:?}", config);
-
     // Raise `nofile` limit on linux and mac
     fdlimit::raise_fd_limit();
 
-    match determine_run_mode(&config, args) {
-        RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
+    // Spawn a config watcher. The watcher will send a initial signal to start the instance with a config
+    let config_path = args.config_path.as_ref().unwrap();
+    let mut cfg_watcher = ConfigWatcherHandle::new(config_path, shutdown_rx).await?;
+
+    // shutdown_tx owns the instance
+    let (shutdown_tx, _) = broadcast::channel(1);
+
+    // (The join handle of the last instance, The service update channel sender)
+    let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChangeEvent>)> =
+        None;
+
+    while let Some(e) = cfg_watcher.event_rx.recv().await {
+        match e {
+            ConfigChangeEvent::General(config) => {
+                match last_instance {
+                    Some((i, _)) => {
+                        shutdown_tx.send(true)?;
+                        i.await??;
+                    }
+                    None => (),
+                }
+
+                debug!("{:?}", config);
+
+                let (service_update_tx, service_update_rx) = mpsc::channel(1024);
+
+                last_instance = Some((
+                    tokio::spawn(run_instance(
+                        config.clone(),
+                        args.clone(),
+                        shutdown_tx.subscribe(),
+                        service_update_rx,
+                    )),
+                    service_update_tx,
+                ));
+            }
+            ConfigChangeEvent::ServiceChange(service_event) => {
+                if let Some((_, service_update_tx)) = &last_instance {
+                    let _ = service_update_tx.send(service_event).await;
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
+async fn run_instance(
+    config: Config,
+    args: Cli,
+    shutdown_rx: broadcast::Receiver<bool>,
+    _service_update: mpsc::Receiver<ServiceChangeEvent>,
+) -> Result<()> {
+    match determine_run_mode(&config, &args) {
+        RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),
         RunMode::Client => {
             #[cfg(not(feature = "client"))]
             crate::helper::feature_not_compile("client");

+ 1 - 1
src/main.rs

@@ -29,5 +29,5 @@ async fn main() -> Result<()> {
         )
         .init();
 
-    run(&args, shutdown_rx).await
+    run(args, shutdown_rx).await
 }

+ 2 - 2
tests/common/mod.rs

@@ -20,7 +20,7 @@ pub async fn run_rathole_server(
         client: false,
         ..Default::default()
     };
-    rathole::run(&cli, shutdown_rx).await
+    rathole::run(cli, shutdown_rx).await
 }
 
 pub async fn run_rathole_client(
@@ -33,7 +33,7 @@ pub async fn run_rathole_client(
         client: true,
         ..Default::default()
     };
-    rathole::run(&cli, shutdown_rx).await
+    rathole::run(cli, shutdown_rx).await
 }
 
 pub mod tcp {

+ 4 - 4
tests/integration_test.rs

@@ -94,7 +94,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
     // Start the client
     info!("start the client");
     let client = tokio::spawn(async move {
-        run_rathole_client(&config_path, client_shutdown_rx)
+        run_rathole_client(config_path, client_shutdown_rx)
             .await
             .unwrap();
     });
@@ -105,7 +105,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
     // Start the server
     info!("start the server");
     let server = tokio::spawn(async move {
-        run_rathole_server(&config_path, server_shutdown_rx)
+        run_rathole_server(config_path, server_shutdown_rx)
             .await
             .unwrap();
     });
@@ -126,7 +126,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
     info!("restart the client");
     let client_shutdown_rx = client_shutdown_tx.subscribe();
     let client = tokio::spawn(async move {
-        run_rathole_client(&config_path, client_shutdown_rx)
+        run_rathole_client(config_path, client_shutdown_rx)
             .await
             .unwrap();
     });
@@ -147,7 +147,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
     info!("restart the server");
     let server_shutdown_rx = server_shutdown_tx.subscribe();
     let server = tokio::spawn(async move {
-        run_rathole_server(&config_path, server_shutdown_rx)
+        run_rathole_server(config_path, server_shutdown_rx)
             .await
             .unwrap();
     });