Jelajahi Sumber

feat: application layer heartbeat (#136)

* feat: application layer heartbeat

* feat: make heartbeat configurable

* fix: update keepalive params

* docs: update about heartbeat
Yujia Qiao 3 tahun lalu
induk
melakukan
2746a0ea88
7 mengubah file dengan 106 tambahan dan 41 penghapusan
  1. 6 4
      README.md
  2. 20 10
      src/client.rs
  3. 16 0
      src/config.rs
  4. 3 3
      src/lib.rs
  5. 4 2
      src/protocol.rs
  6. 54 19
      src/server.rs
  7. 3 3
      src/transport/mod.rs

+ 6 - 4
README.md

@@ -105,6 +105,7 @@ Here is the full configuration specification:
 [client]
 remote_addr = "example.com:2333" # Necessary. The address of the server
 default_token = "default_token_if_not_specify" # Optional. The default token of services, if they don't define their own ones
+heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer heartbeat test. The value must be greater than `server.heartbeat_interval`. Default: 40 secs
 
 [client.transport] # The whole block is optional. Specify which transport to use
 type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp"
@@ -112,8 +113,8 @@ type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp
 [client.transport.tcp] # Optional
 proxy = "socks5://user:passwd@127.0.0.1:1080" # Optional. Use the proxy to connect to the server
 nodelay = false # Optional. Determine whether to enable TCP_NODELAY, if applicable, to improve the latency but decrease the bandwidth. Default: false
-keepalive_secs = 10 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 10 seconds
-keepalive_interval = 5 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 5 seconds
+keepalive_secs = 20 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 20 seconds
+keepalive_interval = 8 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 8 seconds
 
 [client.transport.tls] # Necessary if `type` is "tls"
 trusted_root = "ca.pem" # Necessary. The certificate of CA that signed the server's certificate
@@ -136,12 +137,13 @@ local_addr = "127.0.0.1:1082"
 [server]
 bind_addr = "0.0.0.0:2333" # Necessary. The address that the server listens for clients. Generally only the port needs to be change.
 default_token = "default_token_if_not_specify" # Optional
+heartbeat_interval = 30 # Optional. The interval between two application-layer heartbeats. Set to 0 to disable sending heartbeats. Default: 30 secs
 
 [server.transport] # Same as `[client.transport]`
 type = "tcp"
 nodelay = false
-keepalive_secs = 10
-keepalive_interval = 5
+keepalive_secs = 20
+keepalive_interval = 8
 
 [server.transport.tls] # Necessary if `type` is "tls"
 pkcs12 = "identify.pfx" # Necessary. pkcs12 file of server's certificate and private key

+ 20 - 10
src/client.rs

@@ -29,11 +29,11 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE
 
 // The entrypoint of running a client
 pub async fn run_client(
-    config: &Config,
+    config: Config,
     shutdown_rx: broadcast::Receiver<bool>,
     service_rx: mpsc::Receiver<ServiceChange>,
 ) -> Result<()> {
-    let config = config.client.as_ref().ok_or(anyhow!(
+    let config = config.client.ok_or(anyhow!(
         "Try to run as a client, but the configuration is missing. Please add the `[client]` block"
     ))?;
 
@@ -67,21 +67,21 @@ type ServiceDigest = protocol::Digest;
 type Nonce = protocol::Digest;
 
 // Holds the state of a client
-struct Client<'a, T: Transport> {
-    config: &'a ClientConfig,
+struct Client<T: Transport> {
+    config: ClientConfig,
     service_handles: HashMap<String, ControlChannelHandle>,
     transport: Arc<T>,
 }
 
-impl<'a, T: 'static + Transport> Client<'a, T> {
+impl<T: 'static + Transport> Client<T> {
     // Create a Client from `[client]` config block
-    async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
+    async fn from(config: ClientConfig) -> Result<Client<T>> {
+        let transport =
+            Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?);
         Ok(Client {
             config,
             service_handles: HashMap::new(),
-            transport: Arc::new(
-                T::new(&config.transport).with_context(|| "Failed to create the transport")?,
-            ),
+            transport,
         })
     }
 
@@ -97,6 +97,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
                 (*config).clone(),
                 self.config.remote_addr.clone(),
                 self.transport.clone(),
+                self.config.heartbeat_timeout,
             );
             self.service_handles.insert(name.clone(), handle);
         }
@@ -122,6 +123,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
                                     s,
                                     self.config.remote_addr.clone(),
                                     self.transport.clone(),
+                                    self.config.heartbeat_timeout
                                 );
                                 let _ = self.service_handles.insert(name, handle);
                             },
@@ -369,6 +371,7 @@ struct ControlChannel<T: Transport> {
     shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
     remote_addr: String,                // `client.remote_addr`
     transport: Arc<T>,                  // Wrapper around the transport layer
+    heartbeat_timeout: u64,             // Application layer heartbeat timeout in secs
 }
 
 // Handle of a control channel
@@ -451,9 +454,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
                                     warn!("{:#}", e);
                                 }
                             }.instrument(Span::current()));
-                        }
+                        },
+                        ControlChannelCmd::HeartBeat => ()
                     }
                 },
+                _ = time::sleep(Duration::from_secs(self.heartbeat_timeout)), if self.heartbeat_timeout != 0 => {
+                    warn!("Heartbeat timed out");
+                    break;
+                }
                 _ = &mut self.shutdown_rx => {
                     break;
                 }
@@ -471,6 +479,7 @@ impl ControlChannelHandle {
         service: ClientServiceConfig,
         remote_addr: String,
         transport: Arc<T>,
+        heartbeat_timeout: u64,
     ) -> ControlChannelHandle {
         let digest = protocol::digest(service.name.as_bytes());
 
@@ -482,6 +491,7 @@ impl ControlChannelHandle {
             shutdown_rx,
             remote_addr,
             transport,
+            heartbeat_timeout,
         };
 
         tokio::spawn(

+ 16 - 0
src/config.rs

@@ -9,6 +9,10 @@ use url::Url;
 
 use crate::transport::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_SECS, DEFAULT_NODELAY};
 
+/// Default application-layer heartbeat interval in secs; the default timeout below must stay greater than it
+const DEFAULT_HEARTBEAT_INTERVAL_SECS: u64 = 30;
+const DEFAULT_HEARTBEAT_TIMEOUT_SECS: u64 = 40;
+
 /// String with Debug implementation that emits "MASKED"
 /// Used to mask sensitive strings when logging
 #[derive(Serialize, Deserialize, Default, PartialEq, Clone)]
@@ -177,6 +181,10 @@ pub struct TransportConfig {
     pub noise: Option<NoiseConfig>,
 }
 
+fn default_heartbeat_timeout() -> u64 {
+    DEFAULT_HEARTBEAT_TIMEOUT_SECS
+}
+
 #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
 #[serde(deny_unknown_fields)]
 pub struct ClientConfig {
@@ -185,6 +193,12 @@ pub struct ClientConfig {
     pub services: HashMap<String, ClientServiceConfig>,
     #[serde(default)]
     pub transport: TransportConfig,
+    #[serde(default = "default_heartbeat_timeout")]
+    pub heartbeat_timeout: u64,
+}
+
+fn default_heartbeat_interval() -> u64 {
+    DEFAULT_HEARTBEAT_INTERVAL_SECS
 }
 
 #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
@@ -195,6 +209,8 @@ pub struct ServerConfig {
     pub services: HashMap<String, ServerServiceConfig>,
     #[serde(default)]
     pub transport: TransportConfig,
+    #[serde(default = "default_heartbeat_interval")]
+    pub heartbeat_interval: u64,
 }
 
 #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]

+ 3 - 3
src/lib.rs

@@ -93,7 +93,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
 
                 last_instance = Some((
                     tokio::spawn(run_instance(
-                        *(config.clone()),
+                        *config,
                         args.clone(),
                         shutdown_tx.subscribe(),
                         service_update_rx,
@@ -127,13 +127,13 @@ async fn run_instance(
             #[cfg(not(feature = "client"))]
             crate::helper::feature_not_compile("client");
             #[cfg(feature = "client")]
-            run_client(&config, shutdown_rx, service_update).await
+            run_client(config, shutdown_rx, service_update).await
         }
         RunMode::Server => {
             #[cfg(not(feature = "server"))]
             crate::helper::feature_not_compile("server");
             #[cfg(feature = "server")]
-            run_server(&config, shutdown_rx, service_update).await
+            run_server(config, shutdown_rx, service_update).await
         }
     };
     ret.unwrap();

+ 4 - 2
src/protocol.rs

@@ -9,9 +9,10 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use tracing::trace;
 
 type ProtocolVersion = u8;
-const PROTO_V0: u8 = 0u8;
+const _PROTO_V0: u8 = 0u8;
+const PROTO_V1: u8 = 1u8;
 
-pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V0;
+pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V1;
 
 pub type Digest = [u8; HASH_WIDTH_IN_BYTES];
 
@@ -48,6 +49,7 @@ impl std::fmt::Display for Ack {
 #[derive(Deserialize, Serialize, Debug)]
 pub enum ControlChannelCmd {
     CreateDataChannel,
+    HeartBeat,
 }
 
 #[derive(Deserialize, Serialize, Debug)]

+ 54 - 19
src/server.rs

@@ -38,11 +38,11 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake
 
 // The entrypoint of running a server
 pub async fn run_server(
-    config: &Config,
+    config: Config,
     shutdown_rx: broadcast::Receiver<bool>,
     service_rx: mpsc::Receiver<ServiceChange>,
 ) -> Result<()> {
-    let config = match &config.server {
+    let config = match config.server {
             Some(config) => config,
             None => {
                 return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
@@ -82,9 +82,9 @@ pub async fn run_server(
 type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;
 
 // Server holds all states of running a server
-struct Server<'a, T: Transport> {
+struct Server<T: Transport> {
     // `[server]` config
-    config: &'a ServerConfig,
+    config: Arc<ServerConfig>,
 
     // `[server.services]` config, indexed by ServiceDigest
     services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
@@ -105,14 +105,18 @@ fn generate_service_hashmap(
     ret
 }
 
-impl<'a, T: 'static + Transport> Server<'a, T> {
+impl<T: 'static + Transport> Server<T> {
     // Create a server from `[server]`
-    pub async fn from(config: &'a ServerConfig) -> Result<Server<'a, T>> {
+    pub async fn from(config: ServerConfig) -> Result<Server<T>> {
+        let config = Arc::new(config);
+        let services = Arc::new(RwLock::new(generate_service_hashmap(&config)));
+        let control_channels = Arc::new(RwLock::new(ControlChannelMap::new()));
+        let transport = Arc::new(T::new(&config.transport)?);
         Ok(Server {
             config,
-            services: Arc::new(RwLock::new(generate_service_hashmap(config))),
-            control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
-            transport: Arc::new(T::new(&config.transport)?),
+            services,
+            control_channels,
+            transport,
         })
     }
 
@@ -171,8 +175,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
                                         Ok(conn) => {
                                             let services = self.services.clone();
                                             let control_channels = self.control_channels.clone();
+                                            let server_config = self.config.clone();
                                             tokio::spawn(async move {
-                                                if let Err(err) = handle_connection(conn, services, control_channels).await {
+                                                if let Err(err) = handle_connection(conn, services, control_channels, server_config).await {
                                                     error!("{:#}", err);
                                                 }
                                             }.instrument(info_span!("connection", %addr)));
@@ -233,12 +238,20 @@ async fn handle_connection<T: 'static + Transport>(
     mut conn: T::Stream,
     services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
     control_channels: Arc<RwLock<ControlChannelMap<T>>>,
+    server_config: Arc<ServerConfig>,
 ) -> Result<()> {
     // Read hello
     let hello = read_hello(&mut conn).await?;
     match hello {
         ControlChannelHello(_, service_digest) => {
-            do_control_channel_handshake(conn, services, control_channels, service_digest).await?;
+            do_control_channel_handshake(
+                conn,
+                services,
+                control_channels,
+                service_digest,
+                server_config,
+            )
+            .await?;
         }
         DataChannelHello(_, nonce) => {
             do_data_channel_handshake(conn, control_channels, nonce).await?;
@@ -252,6 +265,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
     services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
     control_channels: Arc<RwLock<ControlChannelMap<T>>>,
     service_digest: ServiceDigest,
+    server_config: Arc<ServerConfig>,
 ) -> Result<()> {
     info!("Try to handshake a control channel");
 
@@ -321,7 +335,8 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
         conn.flush().await?;
 
         info!(service = %service_config.name, "Control channel established");
-        let handle = ControlChannelHandle::new(conn, service_config);
+        let handle =
+            ControlChannelHandle::new(conn, service_config, server_config.heartbeat_interval);
 
         // Insert the new handle
         let _ = h.insert(service_digest, session_key, handle);
@@ -371,7 +386,11 @@ where
     // Create a control channel handle, where the control channel handling task
     // and the connection pool task are created.
     #[instrument(name = "handle", skip_all, fields(service = %service.name))]
-    fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
+    fn new(
+        conn: T::Stream,
+        service: ServerServiceConfig,
+        heartbeat_interval: u64,
+    ) -> ControlChannelHandle<T> {
         // Create a shutdown channel
         let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
 
@@ -435,6 +454,7 @@ where
             conn,
             shutdown_rx,
             data_ch_req_rx,
+            heartbeat_interval,
         };
 
         // Run the control channel
@@ -460,13 +480,26 @@ struct ControlChannel<T: Transport> {
     conn: T::Stream,                               // The connection of control channel
     shutdown_rx: broadcast::Receiver<bool>,        // Receives the shutdown signal
     data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
+    heartbeat_interval: u64,                       // Application-layer heartbeat interval in secs
 }
 
 impl<T: Transport> ControlChannel<T> {
+    async fn write_and_flush(&mut self, data: &[u8]) -> Result<()> {
+        self.conn
+            .write_all(data)
+            .await
+            .with_context(|| "Failed to write control cmds")?;
+        self.conn
+            .flush()
+            .await
+            .with_context(|| "Failed to flush control cmds")?;
+        Ok(())
+    }
     // Run a control channel
     #[instrument(skip_all)]
     async fn run(mut self) -> Result<()> {
-        let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
+        let create_ch_cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
+        let heartbeat = bincode::serialize(&ControlChannelCmd::HeartBeat).unwrap();
 
         // Wait for data channel requests and the shutdown signal
         loop {
@@ -474,11 +507,7 @@ impl<T: Transport> ControlChannel<T> {
                 val = self.data_ch_req_rx.recv() => {
                     match val {
                         Some(_) => {
-                            if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write control cmds") {
-                                error!("{:#}", e);
-                                break;
-                            }
-                            if let Err(e) = self.conn.flush().await.with_context(|| "Failed to flush control cmds") {
+                            if let Err(e) = self.write_and_flush(&create_ch_cmd).await {
                                 error!("{:#}", e);
                                 break;
                             }
@@ -488,6 +517,12 @@ impl<T: Transport> ControlChannel<T> {
                         }
                     }
                 },
+                _ = time::sleep(Duration::from_secs(self.heartbeat_interval)), if self.heartbeat_interval != 0 => {
+                            if let Err(e) = self.write_and_flush(&heartbeat).await {
+                                error!("{:#}", e);
+                                break;
+                            }
+                }
                 // Wait for the shutdown signal
                 _ = self.shutdown_rx.recv() => {
                     break;

+ 3 - 3
src/transport/mod.rs

@@ -9,10 +9,10 @@ use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::net::{TcpStream, ToSocketAddrs};
 use tracing::{error, trace};
 
-pub static DEFAULT_NODELAY: bool = false;
+pub const DEFAULT_NODELAY: bool = false;
 
-pub static DEFAULT_KEEPALIVE_SECS: u64 = 10;
-pub static DEFAULT_KEEPALIVE_INTERVAL: u64 = 3;
+pub const DEFAULT_KEEPALIVE_SECS: u64 = 20;
+pub const DEFAULT_KEEPALIVE_INTERVAL: u64 = 8;
 
 /// Specify a transport layer, like TCP, TLS
 #[async_trait]