Pārlūkot izejas kodu

fix: reimplement `retry_notify` with signals (#123)

Yujia Qiao 3 gadi atpakaļ
vecāks
revīzija
9d143dab6a
5 mainītis faili ar 51 papildinājumiem un 96 dzēšanām
  1. 5 2
      Cargo.lock
  2. 1 1
      Cargo.toml
  3. 9 7
      src/client.rs
  4. 25 58
      src/helper.rs
  5. 11 28
      src/server.rs

+ 5 - 2
Cargo.lock

@@ -118,13 +118,16 @@ checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
 
 [[package]]
 name = "backoff"
-version = "0.3.0"
+version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9fe17f59a06fe8b87a6fc8bf53bb70b3aba76d7685f432487a68cd5552853625"
+checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
 dependencies = [
+ "futures-core",
  "getrandom 0.2.4",
  "instant",
+ "pin-project-lite",
  "rand",
+ "tokio",
 ]
 
 [[package]]

+ 1 - 1
Cargo.toml

@@ -57,7 +57,7 @@ bincode = "1"
 lazy_static = "1.4"
 hex = "0.4"
 rand = "0.8"
-backoff = "0.3"
+backoff = { version = "0.4", features = ["tokio"] }
 tracing = "0.1"
 tracing-subscriber = "0.2"
 socket2 = { version = "0.4", features = ["all"] }

+ 9 - 7
src/client.rs

@@ -1,6 +1,6 @@
 use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
 use crate::config_watcher::ServiceChange;
-use crate::helper::{retry_notify, udp_connect};
+use crate::helper::udp_connect;
 use crate::protocol::Hello::{self, *};
 use crate::protocol::{
     self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
@@ -8,8 +8,8 @@ use crate::protocol::{
 };
 use crate::transport::{SocketOpts, TcpTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
-use backoff::backoff::Backoff;
 use backoff::ExponentialBackoff;
+use backoff::{backoff::Backoff, future::retry_notify};
 use bytes::{Bytes, BytesMut};
 use std::collections::HashMap;
 use std::net::SocketAddr;
@@ -159,21 +159,22 @@ async fn do_data_channel_handshake<T: Transport>(
     args: Arc<RunDataChannelArgs<T>>,
 ) -> Result<T::Stream> {
     // Retry at least every 100ms, at most for 10 seconds
-    let mut backoff = ExponentialBackoff {
+    let backoff = ExponentialBackoff {
         max_interval: Duration::from_millis(100),
         max_elapsed_time: Some(Duration::from_secs(10)),
         ..Default::default()
     };
 
     // Connect to remote_addr
-    let mut conn: T::Stream = retry_notify!(
+    let mut conn: T::Stream = retry_notify(
         backoff,
-        {
+        || async {
             match args
                 .connector
                 .connect(&args.remote_addr)
                 .await
                 .with_context(|| format!("Failed to connect to {}", &args.remote_addr))
+                .map_err(backoff::Error::transient)
             {
                 Ok(conn) => {
                     T::hint(&conn, args.socket_opts);
@@ -184,8 +185,9 @@ async fn do_data_channel_handshake<T: Transport>(
         },
         |e, duration| {
             warn!("{:#}. Retry in {:?}", e, duration);
-        }
-    )?;
+        },
+    )
+    .await?;
 
     // Send nonce
     let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();

+ 25 - 58
src/helper.rs

@@ -1,8 +1,11 @@
-use std::{net::SocketAddr, time::Duration};
-
 use anyhow::{anyhow, Result};
+use backoff::{backoff::Backoff, Notify};
 use socket2::{SockRef, TcpKeepalive};
-use tokio::net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket};
+use std::{future::Future, net::SocketAddr, time::Duration};
+use tokio::{
+    net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket},
+    sync::broadcast,
+};
 use tracing::trace;
 
 // Tokio hesitates to expose this option...So we have to do it on our own :(
@@ -52,62 +55,26 @@ pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
     Ok(s)
 }
 
-/// Almost same as backoff::future::retry_notify
-/// But directly expands to a loop
-macro_rules! retry_notify {
-    ($b: expr, $func: expr, $notify: expr) => {
-        loop {
-            match $func {
-                Ok(v) => break Ok(v),
-                Err(e) => match $b.next_backoff() {
-                    Some(duration) => {
-                        $notify(e, duration);
-                        tokio::time::sleep(duration).await;
-                    }
-                    None => break Err(e),
-                },
-            }
+// Wrapper of retry_notify
+pub async fn retry_notify_with_deadline<I, E, Fn, Fut, B, N>(
+    backoff: B,
+    operation: Fn,
+    notify: N,
+    deadline: &mut broadcast::Receiver<bool>,
+) -> Result<I>
+where
+    E: std::error::Error + Send + Sync + 'static,
+    B: Backoff,
+    Fn: FnMut() -> Fut,
+    Fut: Future<Output = std::result::Result<I, backoff::Error<E>>>,
+    N: Notify<E>,
+{
+    tokio::select! {
+        v = backoff::future::retry_notify(backoff, operation, notify) => {
+            v.map_err(anyhow::Error::new)
         }
-    };
-}
-
-pub(crate) use retry_notify;
-
-#[cfg(test)]
-mod test {
-    use super::*;
-    use backoff::{backoff::Backoff, ExponentialBackoff};
-    #[tokio::test]
-    async fn test_retry_notify() {
-        let tests = [(3, Ok(())), (5, Err("try again"))];
-        for (try_succ, expected) in tests {
-            let mut b = ExponentialBackoff {
-                current_interval: Duration::from_millis(100),
-                initial_interval: Duration::from_millis(100),
-                max_elapsed_time: Some(Duration::from_millis(210)),
-                randomization_factor: 0.0,
-                multiplier: 1.0,
-                ..Default::default()
-            };
-
-            let mut notify_count = 0;
-            let mut try_count = 0;
-            let ret: Result<(), &str> = retry_notify!(
-                b,
-                {
-                    try_count += 1;
-                    if try_count == try_succ {
-                        Ok(())
-                    } else {
-                        Err("try again")
-                    }
-                },
-                |e, duration| {
-                    notify_count += 1;
-                    println!("{}: {}, {:?}", notify_count, e, duration);
-                }
-            );
-            assert_eq!(ret, expected);
+        _ = deadline.recv() => {
+            Err(anyhow!("shutdown"))
         }
     }
 }

+ 11 - 28
src/server.rs

@@ -1,7 +1,7 @@
 use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
 use crate::config_watcher::ServiceChange;
 use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
-use crate::helper::retry_notify;
+use crate::helper::retry_notify_with_deadline;
 use crate::multi_map::MultiMap;
 use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
 use crate::protocol::{
@@ -509,21 +509,15 @@ fn tcp_listen_and_send(
     let (tx, rx) = mpsc::channel(CHAN_SIZE);
 
     tokio::spawn(async move {
-        let l = retry_notify!(listen_backoff(),  {
-            match shutdown_rx.try_recv() {
-                Err(broadcast::error::TryRecvError::Closed) => Ok(None),
-                _ => TcpListener::bind(&addr).await.map(Some)
-            }
+        let l = retry_notify_with_deadline(listen_backoff(),  || async {
+            Ok(TcpListener::bind(&addr).await?)
         }, |e, duration| {
             error!("{:#}. Retry in {:?}", e, duration);
-        })
+        }, &mut shutdown_rx).await
         .with_context(|| "Failed to listen for the service");
 
         let l: TcpListener = match l {
-            Ok(v) => match v {
-                Some(v) => v,
-                None => return
-            },
+            Ok(v) => v,
             Err(e) => {
                 error!("{:#}", e);
                 return;
@@ -628,27 +622,16 @@ async fn run_udp_connection_pool<T: Transport>(
 ) -> Result<()> {
     // TODO: Load balance
 
-    let l = retry_notify!(
+    let l = retry_notify_with_deadline(
         listen_backoff(),
-        {
-            match shutdown_rx.try_recv() {
-                Err(broadcast::error::TryRecvError::Closed) => Ok(None),
-                _ => UdpSocket::bind(&bind_addr).await.map(Some),
-            }
-        },
+        || async { Ok(UdpSocket::bind(&bind_addr).await?) },
         |e, duration| {
             warn!("{:#}. Retry in {:?}", e, duration);
-        }
-    )
-    .with_context(|| "Failed to listen for the service");
-
-    let l = match l {
-        Ok(v) => match v {
-            Some(l) => l,
-            None => return Ok(()),
         },
-        Err(e) => return Err(e),
-    };
+        &mut shutdown_rx,
+    )
+    .await
+    .with_context(|| "Failed to listen for the service")?;
 
     info!("Listening at {}", &bind_addr);