Kaynağa Gözat

fix: respect shutdown signal when retry (#121)

* fix: respect shutdown signal when retry

* test: add tests for retry_notify

* chore: drop backoff/tokio
Yujia Qiao 4 yıl önce
ebeveyn
işleme
0278c529dd
5 değiştirilmiş dosya ile 104 ekleme ve 31 silme
  1. 0 3
      Cargo.lock
  2. 1 1
      Cargo.toml
  3. 15 13
      src/client.rs
  4. 62 0
      src/helper.rs
  5. 26 14
      src/server.rs

+ 0 - 3
Cargo.lock

@@ -122,12 +122,9 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fe17f59a06fe8b87a6fc8bf53bb70b3aba76d7685f432487a68cd5552853625"
 dependencies = [
- "futures-core",
  "getrandom 0.2.4",
  "instant",
- "pin-project",
  "rand",
- "tokio",
 ]
 
 [[package]]

+ 1 - 1
Cargo.toml

@@ -57,7 +57,7 @@ bincode = "1"
 lazy_static = "1.4"
 hex = "0.4"
 rand = "0.8"
-backoff = { version="0.3", features=["tokio"] }
+backoff = "0.3"
 tracing = "0.1"
 tracing-subscriber = "0.2"
 socket2 = { version = "0.4", features = ["all"] }

+ 15 - 13
src/client.rs

@@ -1,6 +1,6 @@
 use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
 use crate::config_watcher::ServiceChange;
-use crate::helper::udp_connect;
+use crate::helper::{retry_notify, udp_connect};
 use crate::protocol::Hello::{self, *};
 use crate::protocol::{
     self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
@@ -159,31 +159,33 @@ async fn do_data_channel_handshake<T: Transport>(
     args: Arc<RunDataChannelArgs<T>>,
 ) -> Result<T::Stream> {
     // Retry at least every 100ms, at most for 10 seconds
-    let backoff = ExponentialBackoff {
+    let mut backoff = ExponentialBackoff {
         max_interval: Duration::from_millis(100),
         max_elapsed_time: Some(Duration::from_secs(10)),
         ..Default::default()
     };
 
-    // FIXME: Respect control channel shutdown here
     // Connect to remote_addr
-    let mut conn: T::Stream = backoff::future::retry_notify(
+    let mut conn: T::Stream = retry_notify!(
         backoff,
-        || async {
-            let conn = args
+        {
+            match args
                 .connector
                 .connect(&args.remote_addr)
                 .await
-                .with_context(|| format!("Failed to connect to {}", &args.remote_addr))?;
-            T::hint(&conn, args.socket_opts);
-
-            Ok(conn)
+                .with_context(|| format!("Failed to connect to {}", &args.remote_addr))
+            {
+                Ok(conn) => {
+                    T::hint(&conn, args.socket_opts);
+                    Ok(conn)
+                }
+                Err(e) => Err(e),
+            }
         },
         |e, duration| {
             warn!("{:#}. Retry in {:?}", e, duration);
-        },
-    )
-    .await?;
+        }
+    )?;
 
     // Send nonce
     let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();

+ 62 - 0
src/helper.rs

@@ -51,3 +51,65 @@ pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
     s.connect(addr).await?;
     Ok(s)
 }
+
+/// Almost same as backoff::future::retry_notify
+/// But directly expands to a loop
+macro_rules! retry_notify {
+    ($b: expr, $func: expr, $notify: expr) => {
+        loop {
+            match $func {
+                Ok(v) => break Ok(v),
+                Err(e) => match $b.next_backoff() {
+                    Some(duration) => {
+                        $notify(e, duration);
+                        tokio::time::sleep(duration).await;
+                    }
+                    None => break Err(e),
+                },
+            }
+        }
+    };
+}
+
+pub(crate) use retry_notify;
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use backoff::{backoff::Backoff, ExponentialBackoff};
+    #[tokio::test]
+    async fn test_retry_notify() {
+        let tests = [(3, 3, 2, Ok(())), (4, 3, 2, Err("try again"))];
+        for (try_succ, try_expected, notify_expected, expected) in tests {
+            let mut b = ExponentialBackoff {
+                current_interval: Duration::from_millis(100),
+                initial_interval: Duration::from_millis(100),
+                max_elapsed_time: Some(Duration::from_millis(200)),
+                randomization_factor: 0.0,
+                multiplier: 1.0,
+                ..Default::default()
+            };
+
+            let mut notify_count = 0;
+            let mut try_count = 0;
+            let ret: Result<(), &str> = retry_notify!(
+                b,
+                {
+                    try_count += 1;
+                    if try_count == try_succ {
+                        Ok(())
+                    } else {
+                        Err("try again")
+                    }
+                },
+                |e, duration| {
+                    notify_count += 1;
+                    println!("{}: {}, {:?}", notify_count, e, duration);
+                }
+            );
+            assert_eq!(ret, expected);
+            assert_eq!(try_count, try_expected);
+            assert_eq!(notify_count, notify_expected);
+        }
+    }
+}

+ 26 - 14
src/server.rs

@@ -1,6 +1,7 @@
 use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
 use crate::config_watcher::ServiceChange;
 use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
+use crate::helper::retry_notify;
 use crate::multi_map::MultiMap;
 use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
 use crate::protocol::{
@@ -508,17 +509,21 @@ fn tcp_listen_and_send(
     let (tx, rx) = mpsc::channel(CHAN_SIZE);
 
     tokio::spawn(async move {
-        // FIXME: Respect shutdown signal
-        let l = backoff::future::retry_notify(listen_backoff(), || async {
-            Ok(TcpListener::bind(&addr).await?)
+        let l = retry_notify!(listen_backoff(),  {
+            match shutdown_rx.try_recv() {
+                Err(broadcast::error::TryRecvError::Closed) => Ok(None),
+                _ => TcpListener::bind(&addr).await.map(Some)
+            }
         }, |e, duration| {
             error!("{:#}. Retry in {:?}", e, duration);
         })
-        .await
         .with_context(|| "Failed to listen for the service");
 
         let l: TcpListener = match l {
-            Ok(v) => v,
+            Ok(v) => match v {
+                Some(v) => v,
+                None => return
+            },
             Err(e) => {
                 error!("{:#}", e);
                 return;
@@ -623,20 +628,27 @@ async fn run_udp_connection_pool<T: Transport>(
 ) -> Result<()> {
     // TODO: Load balance
 
-    // FIXME: Respect shutdown signal
-    let l: UdpSocket = backoff::future::retry_notify(
+    let l = retry_notify!(
         listen_backoff(),
-        || async {
-            Ok(UdpSocket::bind(&bind_addr)
-                .await
-                .with_context(|| "Failed to listen for the service")?)
+        {
+            match shutdown_rx.try_recv() {
+                Err(broadcast::error::TryRecvError::Closed) => Ok(None),
+                _ => UdpSocket::bind(&bind_addr).await.map(Some),
+            }
         },
         |e, duration| {
             warn!("{:#}. Retry in {:?}", e, duration);
-        },
+        }
     )
-    .await
-    .with_context(|| "Failed to listen for the service")?;
+    .with_context(|| "Failed to listen for the service");
+
+    let l = match l {
+        Ok(v) => match v {
+            Some(l) => l,
+            None => return Ok(()),
+        },
+        Err(e) => return Err(e),
+    };
 
     info!("Listening at {}", &bind_addr);