Ver Fonte

feat: make `accept` cancel safe and add a timeout

Yujia Qiao há 4 anos atrás
pai
commit
edbb5ce5c9
5 ficheiros alterados com 62 adições e 23 exclusões
  1. 21 7
      src/server.rs
  2. 3 1
      src/transport/mod.rs
  3. 6 2
      src/transport/noise.rs
  4. 6 1
      src/transport/tcp.rs
  5. 26 12
      src/transport/tls.rs

+ 21 - 7
src/server.rs

@@ -33,6 +33,7 @@ type Nonce = protocol::Digest; // Also called `session_key`
 const TCP_POOL_SIZE: usize = 8; // The number of cached connections for TCP servies
 const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP services
 const CHAN_SIZE: usize = 2048; // The capacity of various chans
+const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake
 
 // The entrypoint of running a server
 pub async fn run_server(
@@ -138,7 +139,6 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
         // Wait for connections and shutdown signals
         loop {
             tokio::select! {
-                // FIXME: This should be cancel safe.
                 // Wait for incoming control and data channels
                 ret = self.transport.accept(&l) => {
                     match ret {
@@ -163,13 +163,27 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
                         Ok((conn, addr)) => {
                             backoff.reset();
 
-                            let services = self.services.clone();
-                            let control_channels = self.control_channels.clone();
-                            tokio::spawn(async move {
-                                if let Err(err) = handle_connection(conn, services, control_channels).await {
-                                    error!("{:?}", err);
+                            // Do transport handshake with a timeout
+                            match time::timeout(Duration::from_secs(HANDSHAKE_TIMEOUT), self.transport.handshake(conn)).await {
+                                Ok(conn) => {
+                                    match conn.with_context(|| "Failed to do transport handshake") {
+                                        Ok(conn) => {
+                                            let services = self.services.clone();
+                                            let control_channels = self.control_channels.clone();
+                                            tokio::spawn(async move {
+                                                if let Err(err) = handle_connection(conn, services, control_channels).await {
+                                                    error!("{:?}", err);
+                                                }
+                                            }.instrument(info_span!("handle_connection", %addr)));
+                                        }, Err(e) => {
+                                            error!("{:?}", e);
+                                        }
+                                    }
+                                },
+                                Err(e) => {
+                                    error!("Transport handshake timeout: {}", e);
                                 }
-                            }.instrument(info_span!("handle_connection", %addr)));
+                            }
                         }
                     }
                 },

+ 3 - 1
src/transport/mod.rs

@@ -10,13 +10,15 @@ use tokio::net::ToSocketAddrs;
 #[async_trait]
 pub trait Transport: Debug + Send + Sync {
     type Acceptor: Send + Sync;
+    type RawStream: Send + Sync;
     type Stream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;
 
     async fn new(config: &TransportConfig) -> Result<Self>
     where
         Self: Sized;
     async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor>;
-    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)>;
+    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)>;
+    async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream>;
     async fn connect(&self, addr: &str) -> Result<Self::Stream>;
 }
 

+ 6 - 2
src/transport/noise.rs

@@ -36,6 +36,7 @@ impl NoiseTransport {
 #[async_trait]
 impl Transport for NoiseTransport {
     type Acceptor = TcpListener;
+    type RawStream = TcpStream;
     type Stream = snowstorm::stream::NoiseStream<TcpStream>;
 
     async fn new(config: &TransportConfig) -> Result<Self> {
@@ -71,17 +72,20 @@ impl Transport for NoiseTransport {
         Ok(TcpListener::bind(addr).await?)
     }
 
-    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> {
+    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> {
         let (conn, addr) = a
             .accept()
             .await
             .with_context(|| "Failed to accept TCP connection")?;
         set_tcp_keepalive(&conn);
+        Ok((conn, addr))
+    }
 
+    async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> {
         let conn = NoiseStream::handshake(conn, self.builder().build_responder()?)
             .await
             .with_context(|| "Failed to do noise handshake")?;
-        Ok((conn, addr))
+        Ok(conn)
     }
 
     async fn connect(&self, addr: &str) -> Result<Self::Stream> {

+ 6 - 1
src/transport/tcp.rs

@@ -14,6 +14,7 @@ pub struct TcpTransport {}
 impl Transport for TcpTransport {
     type Acceptor = TcpListener;
     type Stream = TcpStream;
+    type RawStream = TcpStream;
 
     async fn new(_config: &TransportConfig) -> Result<Self> {
         Ok(TcpTransport {})
@@ -23,12 +24,16 @@ impl Transport for TcpTransport {
         Ok(TcpListener::bind(addr).await?)
     }
 
-    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> {
+    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> {
         let (s, addr) = a.accept().await?;
         set_tcp_keepalive(&s);
         Ok((s, addr))
     }
 
+    async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> {
+        Ok(conn)
+    }
+
     async fn connect(&self, addr: &str) -> Result<Self::Stream> {
         let s = TcpStream::connect(addr).await?;
         set_tcp_keepalive(&s);

+ 26 - 12
src/transport/tls.rs

@@ -14,11 +14,13 @@ use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
 pub struct TlsTransport {
     config: TlsConfig,
     connector: Option<TlsConnector>,
+    tls_acceptor: Option<TlsAcceptor>,
 }
 
 #[async_trait]
 impl Transport for TlsTransport {
-    type Acceptor = (TcpListener, TlsAcceptor);
+    type Acceptor = TcpListener;
+    type RawStream = TcpStream;
     type Stream = TlsStream<TcpStream>;
 
     async fn new(config: &TransportConfig) -> Result<Self> {
@@ -44,34 +46,46 @@ impl Transport for TlsTransport {
             None => None,
         };
 
+        let tls_acceptor = match config.pkcs12.as_ref() {
+            Some(path) => {
+                let ident = Identity::from_pkcs12(
+                    &fs::read(path).await?,
+                    config.pkcs12_password.as_ref().unwrap(),
+                )
+                .with_context(|| "Failed to create identitiy")?;
+                Some(TlsAcceptor::from(
+                    native_tls::TlsAcceptor::new(ident).unwrap(),
+                ))
+            }
+            None => None,
+        };
+
         Ok(TlsTransport {
             config: config.clone(),
             connector,
+            tls_acceptor,
         })
     }
 
     async fn bind<A: ToSocketAddrs + Send + Sync>(&self, addr: A) -> Result<Self::Acceptor> {
-        let ident = Identity::from_pkcs12(
-            &fs::read(self.config.pkcs12.as_ref().unwrap()).await?,
-            self.config.pkcs12_password.as_ref().unwrap(),
-        )
-        .with_context(|| "Failed to create identitiy")?;
         let l = TcpListener::bind(addr)
             .await
             .with_context(|| "Failed to create tcp listener")?;
-        let t = TlsAcceptor::from(native_tls::TlsAcceptor::new(ident).unwrap());
-        Ok((l, t))
+        Ok(l)
     }
 
-    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> {
-        let (conn, addr) = a.0.accept().await?;
+    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> {
+        let (conn, addr) = a.accept().await?;
         set_tcp_keepalive(&conn);
 
-        let conn = a.1.accept(conn).await?;
-
         Ok((conn, addr))
     }
 
+    async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> {
+        let conn = self.tls_acceptor.as_ref().unwrap().accept(conn).await?;
+        Ok(conn)
+    }
+
     async fn connect(&self, addr: &str) -> Result<Self::Stream> {
         let conn = TcpStream::connect(&addr).await?;
         set_tcp_keepalive(&conn);